From 7a9a0825053c0386a4078185fe8a384128f3504c Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Mon, 21 Jul 2025 01:04:46 -0700 Subject: [PATCH 001/153] Changed VERSION to 2.7.0.dev0 (#1973) Signed-off-by: Kshitij Janardan Lakhani --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index 2a45a8a5c..ba610dcf0 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.6.0.dev0 +2.7.0.dev0 From 5ba7953f32a8a40809520927421bfa8b490dd9cf Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 21 Jul 2025 08:41:32 -0700 Subject: [PATCH 002/153] [PyTorch] Remove GH pinned deps (#1961) * Remove GH pinned deps Signed-off-by: Kirthi Shankar Sivamani * Pin onnxscript Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- build_tools/pytorch.py | 14 +------------- docs/debug/1_getting_started.rst | 4 ++-- 2 files changed, 3 insertions(+), 15 deletions(-) diff --git a/build_tools/pytorch.py b/build_tools/pytorch.py index 3f299dca2..33a3abfb7 100644 --- a/build_tools/pytorch.py +++ b/build_tools/pytorch.py @@ -14,19 +14,7 @@ def install_requirements() -> List[str]: """Install dependencies for TE/PyTorch extensions.""" - reqs = ["torch>=2.1", "einops", "onnxscript"] - reqs.append( - "nvdlfw-inspect @" - " git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect" - ) - reqs.extend( - [ - "torch>=2.1", - "onnx", - "onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871", - ] - ) - return reqs + return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"] def test_requirements() -> List[str]: diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index bc2b95057..555b9b4b8 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea There are 4 things one needs to do to use Transformer Engine debug features: 1. Create a configuration YAML file to configure the desired features. -2. Import, and initialize the `Nvidia-DL-Framework-Inspect `_ tool, which is installed as the dependency of the Transformer Engine. +2. Import, initialize, and install the `Nvidia-DL-Framework-Inspect `_ tool. 3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically. 4. Invoke ``debug_api.step()`` at the end of one forward-backward pass. @@ -238,4 +238,4 @@ Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_ .. figure:: ./img/tensorboard.png :align: center - Fig 2: TensorBoard with plotted stats. \ No newline at end of file + Fig 2: TensorBoard with plotted stats. 
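
[Editor's illustrative sketch, not part of the patch series.] The doc hunk above now tells users to install and initialize Nvidia-DL-Framework-Inspect themselves, since the git-pinned `nvdlfw-inspect` requirement was removed from install_requirements(). A minimal sketch of those steps is below. The pip URL is taken from the requirement line removed in this patch; the module path `nvdlfw_inspect.api` and the `initialize(config_file=..., feature_dirs=...)` arguments are assumptions based on the tool's getting-started guide, not a verified signature; only `debug_api.step()` appears verbatim in the doc text above.

    # Shell (assumed install command, URL from the removed requirement):
    #   pip install "git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git"

    import nvdlfw_inspect.api as debug_api  # module path assumed from the tool's docs

    # Assumed initialize() signature: a YAML feature config plus TE's debug feature directory.
    debug_api.initialize(
        config_file="./config.yaml",
        feature_dirs=["/path/to/transformer_engine/debug/features"],
    )

    # ... one forward-backward pass of the model ...

    debug_api.step()  # invoked at the end of each forward-backward pass, per the doc above
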
From 78a382124b945179a5dcfa63ecd6d6f8c9f35f7e Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 21 Jul 2025 10:36:25 -0700 Subject: [PATCH 003/153] [PyTorch] Reset FP8 weight workspace if usages are invalid (#1972) Reset FP8 weight workspace if usages are invalid Signed-off-by: Tim Moon --- transformer_engine/pytorch/module/base.py | 30 ++++++++++++++--------- 1 file changed, 19 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 72a6c28ca..e05e83df9 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -42,7 +42,7 @@ from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer from ..tensor._internal.float8_tensor_base import Float8TensorBase from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..utils import torch_get_autocast_gpu_dtype +from ..utils import is_non_tn_fp8_gemm_supported, torch_get_autocast_gpu_dtype from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState @@ -1293,21 +1293,29 @@ def get_weight_workspace( # Try getting workspace from cache out = None - if cache_name is not None: out = self._fp8_workspaces.get(cache_name, None) - if quantizer is not None and isinstance(out, MXFP8TensorBase): + + # Reset cache if workspace is invalid + if out is not None and quantizer is not None: + reset_cache = False + if isinstance(out, Float8TensorBase): + if ( + not is_non_tn_fp8_gemm_supported() + and quantizer.columnwise_usage + and out._transpose is None + ): + reset_cache = True + elif isinstance(out, MXFP8TensorBase): if quantizer.rowwise_usage and out._rowwise_data is None: - out = None - del self._fp8_workspaces[cache_name] + reset_cache = True elif quantizer.columnwise_usage and out._columnwise_data is None: - out = None - del self._fp8_workspaces[cache_name] - - is_debug = isinstance(quantizer, DebugQuantizer) - is_out_debug_tensor = out is not None and isinstance(out, DebugQuantizedTensor) - if is_debug != is_out_debug_tensor: + reset_cache = True + if isinstance(out, DebugQuantizedTensor) != isinstance(quantizer, DebugQuantizer): + reset_cache = True + if reset_cache: out = None + del self._fp8_workspaces[cache_name] # Gather cached Fp8 workspace if it's distributed # NOTE: FSDP sharding is supported only for Fp8 buffers and will not work From ab5cc407218f8237cf537bad73ee2bb3f8efa28f Mon Sep 17 00:00:00 2001 From: yuzhongw-nvidia Date: Tue, 22 Jul 2025 01:51:15 +0800 Subject: [PATCH 004/153] Fix the condition error when checking fp8 attn in `get_attention_backend` (#1965) Update utils.py Fix the condition error of the FP8 attention in `get_attention_backend` Signed-off-by: yuzhongw-nvidia Co-authored-by: Xiaowei Ren <103958965+xrennvidia@users.noreply.github.com> --- .../pytorch/attention/dot_product_attention/utils.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 318353bf0..7c4bf928c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -609,7 +609,7 @@ def get_attention_backend( " bias for THD format" ) use_fused_attention = False - elif fp8 and head_dim_qk != head_dim_v: + 
elif fp8 and fp8_meta["recipe"].fp8_dpa and head_dim_qk != head_dim_v: logger.debug( "Disabling FusedAttention as it does not support context parallelism with FP8" " MLA attention" From 0d8022834679d050f99b056fbb7a8ea1055ea936 Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Mon, 21 Jul 2025 14:23:51 -0700 Subject: [PATCH 005/153] [Common] Skip cuDNN 9.10.0/9.10.1 due to bugs (#1937) * exclude 9.10.0/.1 for certain configs Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix kv_channels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add get_backend to tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * add init files Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix numerics and cuda graph tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix jax tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove prints Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor changes after renaming Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix import structure and rename get_attention_backends Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix docs and benchmarks Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix get backend calls Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "fix get backend calls" This reverts commit 653cbb51c697bc2f975416bb3aac1d85f76c36dc. Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Revert "fix docs and benchmarks" This reverts commit 98cd52e04ff7c53e26b412195f5744e39f7ed0e9. 
Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix docs, benchmarks and pre-commit ci Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix dpa/mha flash attn selection Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix rng states Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ModelConfig Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix backend selection on Ampere Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix issues from last merge Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update tests/pytorch/utils.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove initialization of rng_states to None Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * redefine ModelConfig Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix typo Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix ModelConfig Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix seed for CP tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * Update tests/pytorch/test_sanity.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * move fixture from utils to individual tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix CI Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- benchmarks/attention/benchmark_attention.py | 8 +- .../arbitrary_mask_to_post_scale_bias.py | 2 +- docs/examples/attention/attention.ipynb | 18 +- docs/examples/attention/example_attention.py | 8 +- qa/L0_pytorch_unittest/test.sh | 4 +- qa/L1_pytorch_distributed_unittest/test.sh | 2 +- qa/L3_pytorch_FA_versions_test/test.sh | 2 +- tests/jax/test_fused_attn.py | 2 +- .../run_attention_with_cp.py} | 2 +- .../test_attention.py} | 917 ++++++++++-------- .../test_attention_with_cp.py} | 93 +- .../test_kv_cache.py | 35 +- tests/pytorch/test_cpu_offloading.py | 25 +- tests/pytorch/test_cuda_graphs.py | 45 +- tests/pytorch/test_numerics.py | 320 ++---- tests/pytorch/test_sanity.py | 248 +---- tests/pytorch/utils.py | 187 ++++ .../common/fused_attn/fused_attn.cpp | 16 +- 18 files changed, 979 insertions(+), 955 deletions(-) rename tests/pytorch/{fused_attn/run_fused_attn_with_cp.py => attention/run_attention_with_cp.py} (99%) rename tests/pytorch/{fused_attn/test_fused_attn.py => attention/test_attention.py} (77%) rename tests/pytorch/{fused_attn/test_fused_attn_with_cp.py => attention/test_attention_with_cp.py} (71%) rename tests/pytorch/{fused_attn => 
attention}/test_kv_cache.py (97%) diff --git a/benchmarks/attention/benchmark_attention.py b/benchmarks/attention/benchmark_attention.py index dafafdff4..1df16cc01 100644 --- a/benchmarks/attention/benchmark_attention.py +++ b/benchmarks/attention/benchmark_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention pd.set_option("display.precision", 4) @@ -197,7 +197,7 @@ def main(): ) for model in model_configs.keys(): config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py index e9eec14d9..97f1bcd7e 100644 --- a/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py +++ b/docs/examples/attention/arbitrary_mask_to_post_scale_bias.py @@ -5,7 +5,7 @@ import os import torch from typing import Tuple -from tests.pytorch.fused_attn.test_fused_attn import ModelConfig +from tests.pytorch.utils import ModelConfig from transformer_engine.pytorch.attention import DotProductAttention # Initialize RNG state diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 53a5eede7..6cd56d23d 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -375,7 +375,7 @@ "\n", "Our [unit tests](https://github.com/NVIDIA/TransformerEngine/tree/main/tests) demonstrate the use of Transformer Engine dot product attention APIs. Users are encouraged to use them as a template when integrating Transformer Engine to their ML workflows.\n", "\n", - "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." + "For example, in PyTorch, [test_dot_product_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) offers a variety of use cases of `pytorch.DotProductAttention`, from data types, model configs, checkpointing, to QKV layouts." ] }, { @@ -394,10 +394,10 @@ "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. 
For example,\n", - "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)\n", - "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn_with_cp.py)" + "- sliding window attention: [test_dpa_swa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- MQA/GQA: [test_te_layer_mqa_gqa](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- Multi-Latent Attention: [test_dpa_mla](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)\n", + "- context parallelism: [test_cp_with_fused_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py), [test_cp_with_flash_attention](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention_with_cp.py)" ] }, { @@ -458,7 +458,7 @@ " \n", "\n", "\n", - "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", + "Some example usage of the different layouts can be found at [test_dpa_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_dpa_qkv_layout_thd](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). Transformer Engine also provides a utility function [transformer_engine.pytorch.attention.dot_product_attention.utils.get_qkv_layout](https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/pytorch/attention.py) to help determine which layout a set of `q`, `k`, `v` tensors have (PyTorch only).\n", "\n", "
\n", "Note\n", @@ -548,7 +548,7 @@ "id": "dda4a589", "metadata": {}, "source": [ - "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py).\n", + "Some more examples of running Transformer Engine with different attention masks can be found at [test_dpa_mask](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py).\n", "\n", "### 3.3 Attention Bias\n", "\n", @@ -594,7 +594,7 @@ "\n", "The framework-native backends do not explicitly support `ALiBi`, but users can convert `ALiBi` to a regular `post_scale_bias` bias to achieve the same effect. In PyTorch, this utility function, `transformer_engine.pytorch.attention.get_alibi`, can be used to help with the conversion.\n", "\n", - "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py)." + "More examples of how to use the various attention biases are at [test_dpa_bias](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py)." ] }, { @@ -612,7 +612,7 @@ "\n", "- `DelayedScaling.fp8_mha=True (default=False)`: This option, on top of `fp8_dpa=True`, removes the casting operations at the beginning and end of the `FusedAttention` module. This feature is experimental. \n", "\n", - "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/fused_attn/test_fused_attn.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." + "Examples of using the two features are available at [test_dpa_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py) and [test_mha_fp8_vs_f16](https://github.com/NVIDIA/TransformerEngine/blob/main/tests/pytorch/attention/test_attention.py). To disable FP8 attention for backward and only use it for forward, users can also set `NVTE_FP8_DPA_BWD=0 (default=1)`." 
] } ], diff --git a/docs/examples/attention/example_attention.py b/docs/examples/attention/example_attention.py index 2c32e8b5f..cf650265b 100644 --- a/docs/examples/attention/example_attention.py +++ b/docs/examples/attention/example_attention.py @@ -9,11 +9,11 @@ import torch import nvtx import transformer_engine -from tests.pytorch.fused_attn.test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention # data type dtype = torch.bfloat16 @@ -90,7 +90,7 @@ def main(): models = ["test_0"] for model in models: config = model_configs[model] - available_backends, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 7fe439b37..9a924282b 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -45,8 +45,8 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py || test_fail "test_fused_attn.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/fused_attn/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index 09ef661c4..f0436d4ff 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -28,7 +28,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml 
$TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_attn_with_cp.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py || test_fail "test_fused_attn_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/qa/L3_pytorch_FA_versions_test/test.sh b/qa/L3_pytorch_FA_versions_test/test.sh index 547849e95..7e9616cd0 100644 --- a/qa/L3_pytorch_FA_versions_test/test.sh +++ b/qa/L3_pytorch_FA_versions_test/test.sh @@ -41,6 +41,6 @@ do fi # Run tests - NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py + NVTE_TORCH_COMPILE=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/pytorch/attention/test_attention.py done diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index f9e5c8ad2..29a9bc2b9 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -372,7 +372,7 @@ def _check_configs(self): self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() - if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: + if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip("Unsupported inputs combination or device compute capability.") if ( diff --git a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py b/tests/pytorch/attention/run_attention_with_cp.py similarity index 99% rename from tests/pytorch/fused_attn/run_fused_attn_with_cp.py rename to tests/pytorch/attention/run_attention_with_cp.py index f1db30d99..0ad64204f 100644 --- a/tests/pytorch/fused_attn/run_fused_attn_with_cp.py +++ b/tests/pytorch/attention/run_attention_with_cp.py @@ -13,7 +13,7 @@ get_cu_seqlens_on_cp_rank, ) import transformer_engine_torch as tex -from test_fused_attn_with_cp import model_configs_flash_attn, model_configs_fused_attn +from test_attention_with_cp import model_configs_flash_attn, model_configs_fused_attn from transformer_engine.pytorch.fp8 import fp8_autocast from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor, Float8Quantizer from transformer_engine.common.recipe import DelayedScaling diff --git a/tests/pytorch/fused_attn/test_fused_attn.py b/tests/pytorch/attention/test_attention.py similarity index 77% rename from tests/pytorch/fused_attn/test_fused_attn.py rename to tests/pytorch/attention/test_attention.py index a05e64fca..4dfd54cdb 100644 --- a/tests/pytorch/fused_attn/test_fused_attn.py +++ b/tests/pytorch/attention/test_attention.py @@ -4,8 +4,9 @@ import logging import math import os +import sys +import pathlib from typing import Any, Dict, List, Tuple, Union, Optional -from contextlib import contextmanager import pytest import torch @@ -21,7 +22,6 @@ FlashAttentionUtils, get_attention_backend, check_set_window_size, - AttentionParams, ) from transformer_engine.pytorch.attention 
import InferenceParams from transformer_engine.pytorch.attention import RotaryPositionEmbedding @@ -48,21 +48,22 @@ restore_from_saved, ) +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + reset_rng_states, + ModelConfig, + dtype_tols, + logging_context, + get_available_attention_backends, +) + # Only run FP8 tests on H100 fp8_available, reason_for_no_fp8 = fp8.FP8GlobalStateManager.is_fp8_available() -# Initialize RNG state seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - - -def reset_rng_states() -> None: - """Revert back to initial RNG state""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) +# Reset RNG states +reset_rng_states() @pytest.fixture(autouse=True) @@ -71,170 +72,20 @@ def reset_global_fp8_state(): fp8.FP8GlobalStateManager.reset() -class ModelConfig: - def __init__( - self, - batch_size: int, - num_heads: int, - num_gqa_groups: int, - head_dim_qk: int, - max_seqlen_q: int, - max_seqlen_kv: int, - dropout_p: float, - attn_mask_type: str, - attn_bias_type: str, - head_dim_v: int = None, - alibi_type: str = "none", - num_layers: int = 1, - bias_shape: str = "1hss", - window_size: Tuple[int, int] = (-1, -1), - total_requests: int = None, - max_ctx_len: int = None, - ): - self.batch_size = batch_size - self.num_heads = num_heads - self.num_gqa_groups = num_gqa_groups - self.head_dim_qk = head_dim_qk - self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v - self.hidden_size = num_heads * head_dim_qk - self.hidden_size_kv = num_gqa_groups * self.head_dim_v - self.max_seqlen_q = max_seqlen_q - self.max_seqlen_kv = max_seqlen_kv - self.dropout_p = dropout_p - self.attn_mask_type = attn_mask_type - self.attn_bias_type = attn_bias_type - self.alibi_type = alibi_type - self.attn_type = "self" if (max_seqlen_q == max_seqlen_kv) else "cross" - self.num_layers = num_layers - self.bias_shape = bias_shape - self.window_size = window_size - self.total_requests = total_requests - self.max_ctx_len = max_ctx_len - - -@contextmanager -def logging_context(highest_level=logging.WARNING): - previous_level = logging.root.manager.disable - logging.disable(highest_level) - try: - yield - finally: - logging.disable(previous_level) - - -def _get_attention_backends( - config: ModelConfig, - qkv_dtype: torch.dtype, - qkv_layout: str, - window_size: Tuple[int, int] = (-1, -1), - pad_between_seqs: bool = False, - context_parallel: bool = False, - deterministic: bool = False, - fp8: bool = False, - fp8_meta: Optional[Dict[str, Any]] = None, - is_training: bool = True, - inference_params: Optional[InferenceParams] = None, -) -> Tuple[List, List]: - """Check if what attention backends support a model configuration""" - - os.environ["NVTE_FLASH_ATTN"] = "1" - os.environ["NVTE_FUSED_ATTN"] = "1" - os.environ["NVTE_UNFUSED_ATTN"] = "1" - _attention_backends["backend_selection_requires_update"] = True - - alibi_slopes_shape = None - if config.attn_bias_type == "alibi" and config.alibi_type == "custom": - if config.bias_shape == "1hss": - alibi_slopes_shape = [config.num_heads] - if config.bias_shape == "bhss": - alibi_slopes_shape = [config.batch_size, config.num_heads] - - core_attention_bias_shape = ( - config.bias_shape if config.attn_bias_type == "post_scale_bias" else None - ) - core_attention_bias_requires_grad = False - # d=256 is supported by cuDNN 9.0+ for inference but not training 
- if ( - config.attn_bias_type == "post_scale_bias" - and config.head_dim_qk <= 128 - and config.head_dim_v <= 128 - ): - core_attention_bias_requires_grad = True - - fused_attn_backends = [] - available_backends = None - flash_attention_backend = None - fused_attention_backend = None - - def test(): - attention_params = AttentionParams( - qkv_dtype=qkv_dtype, - qkv_layout=qkv_layout, - batch_size=config.batch_size, - num_heads=config.num_heads, - num_gqa_groups=config.num_gqa_groups, - max_seqlen_q=config.max_seqlen_q, - max_seqlen_kv=config.max_seqlen_kv, - head_dim_qk=config.head_dim_qk, - head_dim_v=config.head_dim_v, - attn_mask_type=config.attn_mask_type, - window_size=window_size, - alibi_slopes_shape=alibi_slopes_shape, - core_attention_bias_type=config.attn_bias_type, - core_attention_bias_shape=core_attention_bias_shape, - core_attention_bias_requires_grad=core_attention_bias_requires_grad, - pad_between_seqs=pad_between_seqs, - attention_dropout=config.dropout_p, - context_parallel=context_parallel, - deterministic=deterministic, - fp8=fp8, - fp8_meta=fp8_meta, - is_training=is_training, - inference_params=inference_params, - ) - ( - use_flash_attention, - use_fused_attention, - flash_attention_backend, - fused_attention_backend, - use_unfused_attention, - available_backends, - ) = get_attention_backend(attention_params) - # Set attention.py _attention_backends var using return value - # from get_attention_backend() - _attention_backends["use_flash_attention"] = use_flash_attention - _attention_backends["use_fused_attention"] = use_fused_attention - _attention_backends["flash_attention_backend"] = flash_attention_backend - _attention_backends["fused_attention_backend"] = fused_attention_backend - _attention_backends["use_unfused_attention"] = use_unfused_attention - _attention_backends["backend_selection_requires_update"] = False - return available_backends, flash_attention_backend, fused_attention_backend - - backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} - with logging_context(): - for i in range(3): - os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) - _attention_backends["backend_selection_requires_update"] = True - available_backends, flash_attention_backend, fused_attention_backend = test() - if fused_attention_backend == FusedAttnBackend[backends[i]]: - fused_attn_backends.append(fused_attention_backend) - return available_backends, flash_attention_backend, fused_attn_backends - - model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias - "base_1_0": ModelConfig(8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "base_1_1": ModelConfig(4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias"), - "base_2_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "base_2_1": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "base_3_0": ModelConfig(8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_3_1": ModelConfig(8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_4_0": ModelConfig(8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_4_1": ModelConfig(8, 16, 16, 192, 128, 2048, 0.0, "no_mask", "no_bias"), - "base_5_0": ModelConfig(8, 16, 16, 512, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_5_1": ModelConfig(8, 16, 16, 512, 128, 2048, 0.0, "no_mask", "no_bias"), - "base_6_0": ModelConfig(8, 16, 16, 1024, 1, 2048, 0.0, "no_mask", "no_bias"), - "base_6_1": ModelConfig(8, 16, 16, 1024, 128, 2048, 0.0, "no_mask", "no_bias"), + "base_1_0": ModelConfig(8, 128, 16, 64), + "base_1_1": 
ModelConfig(4, 128, 16, 64, max_seqlen_kv=256), + "base_2_0": ModelConfig(2, 2048, 24, 128), + "base_2_1": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096), + "base_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048), + "base_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048), + "base_4_0": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048), + "base_4_1": ModelConfig(8, 128, 16, 192, max_seqlen_kv=2048), + "base_5_0": ModelConfig(8, 1, 16, 512, max_seqlen_kv=2048), + "base_5_1": ModelConfig(8, 128, 16, 512, max_seqlen_kv=2048), + "base_6_0": ModelConfig(8, 1, 16, 1024, max_seqlen_kv=2048), + "base_6_1": ModelConfig(8, 128, 16, 1024, max_seqlen_kv=2048), } @@ -278,7 +129,7 @@ def test_dot_product_attention( config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) is_training = True - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -289,7 +140,7 @@ def test_dot_product_attention( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -413,33 +264,19 @@ def test_dpa_checkpoint(dtype, model_configs, model): model_configs_mla = { # test: b, h, hg, dqk, sq, skv, p, mask, bias # attn , backend - "mla_1_0": ModelConfig( - 8, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # self , 0 - "mla_1_1": ModelConfig( - 4, 16, 16, 64, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # cross, 0 - "mla_1_2": ModelConfig( - 4, 16, 16, 192, 128, 256, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # cross, 0 - "mla_2_0": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias", head_dim_v=64 - ), # self , 1 + "mla_1_0": ModelConfig(8, 128, 16, 64, head_dim_v=128), # self , 0 + "mla_1_1": ModelConfig(4, 128, 16, 64, max_seqlen_kv=256, head_dim_v=128), # cross, 0 + "mla_1_2": ModelConfig(4, 128, 16, 192, max_seqlen_kv=256, head_dim_v=128), # cross, 0 + "mla_2_0": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal", head_dim_v=64), # self , 1 "mla_2_1": ModelConfig( - 1, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=64 + 1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=64 ), # cross, 1 "mla_2_2": ModelConfig( - 1, 24, 24, 192, 2048, 4096, 0.0, "causal", "no_bias", head_dim_v=128 + 1, 2048, 24, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 ), # cross, 1 - "mla_3_0": ModelConfig( - 8, 16, 16, 128, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=64 - ), # inference - "mla_3_1": ModelConfig( - 8, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # inference - "mla_3_2": ModelConfig( - 8, 16, 16, 192, 1, 2048, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), # inference + "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference + "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference } @@ -454,40 +291,46 @@ def test_dpa_mla(dtype, model_configs, model): model_configs_mask = { # test: b, h, hg, d, sq, skv, p, mask, bias - "mask_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), 
- "mask_1_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "mask_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "mask_2_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_2_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_2_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "mask_3_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_3_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "mask_3_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "mask_4_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_4_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "mask_4_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "mask_5_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "mask_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), + "mask_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal"), + "mask_1_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), + "mask_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"), + "mask_2_1": ModelConfig( + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="causal_bottom_right" + ), + "mask_2_2": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right" ), + "mask_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "mask_3_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), + "mask_3_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "mask_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "mask_4_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"), + "mask_4_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"), + "mask_5_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "mask_5_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right" ), "mask_5_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" + ), + "mask_6_0": ModelConfig(2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal"), + "mask_6_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal"), + "mask_7_0": ModelConfig( + 2, 1, 16, 128, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right" + ), + "mask_7_1": ModelConfig( + 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="causal_bottom_right" ), - "mask_6_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal", "no_bias"), - "mask_6_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal", "no_bias"), - "mask_7_0": ModelConfig(2, 16, 16, 128, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_7_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "causal_bottom_right", "no_bias"), - "mask_8_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding", "no_bias"), - "mask_8_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding", "no_bias"), - "mask_9_0": ModelConfig(2, 24, 24, 128, 1, 2048, 0.0, "padding_causal", "no_bias"), - 
"mask_9_1": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "padding_causal", "no_bias"), + "mask_8_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding"), + "mask_8_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding"), + "mask_9_0": ModelConfig(2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal"), + "mask_9_1": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal"), "mask_10_0": ModelConfig( - 2, 24, 24, 128, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 1, 24, 128, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right" ), "mask_10_1": ModelConfig( - 2, 16, 16, 256, 1, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 1, 16, 256, max_seqlen_kv=2048, attn_mask_type="padding_causal_bottom_right" ), } @@ -503,44 +346,102 @@ def test_dpa_mask(dtype, model_configs, model): model_configs_bias = { # test: b, h, hg, d, sq, skv, p, mask, bias - "bias_1_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), - "bias_1_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "no_mask", "post_scale_bias"), - "bias_1_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias"), - "bias_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "post_scale_bias"), - "bias_1_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "alibi"), # skipped - "bias_1_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "alibi"), # skipped - "bias_2_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), # skipped - "bias_2_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "post_scale_bias"), # skipped + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias"), + "bias_1_1": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_bias_type="post_scale_bias"), + "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias"), + "bias_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="post_scale_bias"), + "bias_1_4": ModelConfig(4, 2048, 24, 128, attn_bias_type="alibi"), # skipped + "bias_1_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_bias_type="alibi" + ), # skipped + "bias_2_0": ModelConfig( + 4, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" + ), # skipped + "bias_2_1": ModelConfig( + 2, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding", + attn_bias_type="post_scale_bias", + ), # skipped "bias_2_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "padding", "post_scale_bias" + 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="post_scale_bias" ), # skipped "bias_2_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_2_4": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="padding", attn_bias_type="alibi" + ), # skipped + "bias_2_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", attn_bias_type="alibi" ), # skipped - "bias_2_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding", "alibi"), # skipped - "bias_2_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "alibi"), # skipped - "bias_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "bias_3_1": ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "causal", "post_scale_bias"), - "bias_3_2": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", 
"post_scale_bias"), + "bias_3_0": ModelConfig( + 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "bias_3_1": ModelConfig( + 2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "bias_3_2": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), "bias_3_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "causal", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_3_4": ModelConfig(4, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="alibi"), + "bias_3_5": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal", attn_bias_type="alibi" ), # skipped - "bias_3_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi"), - "bias_3_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "alibi"), # skipped "bias_4_0": ModelConfig( - 4, 16, 16, 64, 128, 128, 0.0, "padding_causal", "post_scale_bias" + 4, 128, 16, 64, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" ), # skipped "bias_4_1": ModelConfig( - 2, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias" + 2, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", ), # skipped "bias_4_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "post_scale_bias" + 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="post_scale_bias" ), # skipped "bias_4_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias" + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), # skipped + "bias_4_4": ModelConfig( + 4, 2048, 24, 128, attn_mask_type="padding_causal", attn_bias_type="alibi" + ), # skipped + "bias_4_5": ModelConfig( + 2, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="alibi", ), # skipped - "bias_4_4": ModelConfig(4, 24, 24, 128, 2048, 2048, 0.0, "padding_causal", "alibi"), # skipped - "bias_4_5": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "alibi"), # skipped } @@ -555,33 +456,29 @@ def test_dpa_bias(dtype, model_configs, model): model_configs_bias_shapes = { # test: b, h, hg, d, sq, skv, p, - "bias_1_0": ModelConfig( + "bias_1_0": ModelConfig(4, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="11ss"), + "bias_1_1": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias", bias_shape="1hss"), + "bias_1_2": ModelConfig(4, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="b1ss"), + "bias_1_3": ModelConfig(2, 2048, 24, 128, attn_bias_type="post_scale_bias", bias_shape="bhss"), + "bias_1_4": ModelConfig( 4, - 16, - 16, - 64, - 128, + 2048, + 24, 128, - 0.0, - # mask, bias, bias_shape, - "no_mask", - "post_scale_bias", - bias_shape="11ss", - ), - "bias_1_1": ModelConfig( - 2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias", bias_shape="1hss" - ), - "bias_1_2": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="b1ss" - ), - "bias_1_3": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "post_scale_bias", bias_shape="bhss" - ), - "bias_1_4": ModelConfig( - 4, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="1hss", alibi_type="custom" + attn_mask_type="causal", + attn_bias_type="alibi", + bias_shape="1hss", + alibi_type="custom", ), 
"bias_1_5": ModelConfig( - 2, 24, 24, 128, 2048, 2048, 0.0, "causal", "alibi", bias_shape="bhss", alibi_type="custom" + 2, + 2048, + 24, + 128, + attn_mask_type="causal", + attn_bias_type="alibi", + bias_shape="bhss", + alibi_type="custom", ), } @@ -597,29 +494,31 @@ def test_dpa_bias_shapes(dtype, model_configs, model): model_configs_swa = { # test: b, h, hg, d, sq, skv, p, mask, bias - "swa_1_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "swa_1_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "no_mask", "no_bias"), - "swa_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "swa_2_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal", "no_bias"), - "swa_3_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "causal_bottom_right", "no_bias"), - "swa_3_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "causal_bottom_right", "no_bias"), - "swa_4_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "swa_4_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "swa_4_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "swa_5_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_2": ModelConfig(2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "swa_5_3": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "swa_6_1": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "swa_1_1": ModelConfig(2, 2048, 16, 64), + "swa_1_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4), + "swa_1_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096), + "swa_2_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal"), + "swa_2_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal"), + "swa_2_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal"), + "swa_3_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="causal_bottom_right"), + "swa_3_2": ModelConfig( + 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="causal_bottom_right" ), + "swa_3_3": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="causal_bottom_right" + ), + "swa_4_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "swa_4_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding"), + "swa_4_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "swa_5_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "swa_5_2": ModelConfig(2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), + "swa_5_3": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal"), + "swa_6_1": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "swa_6_2": ModelConfig( - 2, 24, 4, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=4, attn_mask_type="padding_causal_bottom_right" ), "swa_6_3": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" ), } @@ -635,13 +534,31 @@ def test_dpa_sliding_window(dtype, model_configs, model): model_configs_alibi_slopes = { # test: 
b, h, hg, d, sq, skv, p, mask, bias, alibi_type - "alibi_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "alibi", alibi_type="vanilla"), - "alibi_1_1": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "causal", "alibi", alibi_type="vanilla"), + "alibi_1_0": ModelConfig( + 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="vanilla" + ), + "alibi_1_1": ModelConfig( + 1, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="causal", + attn_bias_type="alibi", + alibi_type="vanilla", + ), "alibi_2_0": ModelConfig( - 2, 24, 24, 128, 1024, 1024, 0.0, "causal", "alibi", alibi_type="custom" + 2, 1024, 24, 128, attn_mask_type="causal", attn_bias_type="alibi", alibi_type="custom" ), "alibi_2_1": ModelConfig( - 1, 24, 24, 128, 1024, 2048, 0.0, "causal", "alibi", alibi_type="custom" + 1, + 1024, + 24, + 128, + max_seqlen_kv=2048, + attn_mask_type="causal", + attn_bias_type="alibi", + alibi_type="custom", ), } @@ -671,16 +588,38 @@ def test_dpa_alibi_slopes(dtype, model_configs, model): model_configs_layout = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "no_bias"), - "layout_0_1": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "layout_0_2": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "layout_0_3": ModelConfig(1, 16, 16, 64, 128, 256, 0.0, "padding_causal", "post_scale_bias"), - "layout_1_0": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "layout_1_1": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), - "layout_1_2": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "layout_1_3": ModelConfig(1, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "post_scale_bias"), - "layout_2_0": ModelConfig(2, 16, 16, 256, 1, 2048, 0.0, "no_mask", "no_bias"), - "layout_2_1": ModelConfig(2, 24, 24, 256, 2048, 2048, 0.0, "causal", "post_scale_bias"), + "layout_0_0": ModelConfig(2, 128, 16, 64), + "layout_0_1": ModelConfig( + 2, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "layout_0_2": ModelConfig(1, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"), + "layout_0_3": ModelConfig( + 1, + 128, + 16, + 64, + max_seqlen_kv=256, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), + "layout_1_0": ModelConfig(2, 2048, 24, 128), + "layout_1_1": ModelConfig( + 2, 2048, 24, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "layout_1_2": ModelConfig(1, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "layout_1_3": ModelConfig( + 1, + 2048, + 24, + 128, + max_seqlen_kv=4096, + attn_mask_type="padding_causal", + attn_bias_type="post_scale_bias", + ), + "layout_2_0": ModelConfig(2, 1, 16, 256, max_seqlen_kv=2048), + "layout_2_1": ModelConfig( + 2, 2048, 24, 256, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), } @@ -697,55 +636,54 @@ def test_dpa_qkv_layout(dtype, model_configs, model, qkv_layout): qkv_layouts_thd = ["t3hd", "th3d", "thd_t2hd", "thd_th2d", "thd_thd_thd"] model_configs_layout_thd = { # test: b, h, hg, d, sq, skv, p, mask, bias - "layout_0_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_0_1": ModelConfig(2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "layout_0_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias"), - "layout_1_0": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_1": ModelConfig(2, 24, 
1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "layout_1_2": ModelConfig(2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias"), - "layout_2_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + "layout_0_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding"), + "layout_0_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding"), + "layout_0_2": ModelConfig(2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding"), + "layout_1_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal"), + "layout_1_1": ModelConfig(2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal"), + "layout_1_2": ModelConfig( + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal" ), + "layout_2_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right"), "layout_2_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias" + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal_bottom_right" ), "layout_2_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias" - ), - "layout_3_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" ), + "layout_3_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding", window_size=(4, 4)), "layout_3_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding", "no_bias", window_size=(4, 4) + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding", window_size=(4, 4) ), "layout_3_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding", "no_bias", window_size=(4, 4) - ), - "layout_4_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding", window_size=(4, 4) ), + "layout_4_0": ModelConfig(2, 2048, 16, 64, attn_mask_type="padding_causal", window_size=(4, 0)), "layout_4_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, num_gqa_groups=1, attn_mask_type="padding_causal", window_size=(4, 0) ), "layout_4_2": ModelConfig( - 2, 24, 24, 128, 2048, 4096, 0.0, "padding_causal", "no_bias", window_size=(4, 0) + 2, 2048, 24, 128, max_seqlen_kv=4096, attn_mask_type="padding_causal", window_size=(4, 0) ), "layout_5_0": ModelConfig( - 2, 16, 16, 64, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + 2, 2048, 16, 64, attn_mask_type="padding_causal_bottom_right", window_size=(4, 0) ), "layout_5_1": ModelConfig( - 2, 24, 1, 128, 2048, 2048, 0.0, "padding_causal_bottom_right", "no_bias", window_size=(4, 0) + 2, + 2048, + 24, + 128, + num_gqa_groups=1, + attn_mask_type="padding_causal_bottom_right", + window_size=(4, 0), ), "layout_5_2": ModelConfig( 2, - 24, + 2048, 24, 128, - 2048, - 4096, - 0.0, - "padding_causal_bottom_right", - "no_bias", + max_seqlen_kv=4096, + attn_mask_type="padding_causal_bottom_right", window_size=(4, 0), ), } @@ -1135,16 +1073,22 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_te_layer = { # test: b, h, hg, d, sq, skv, p, mask, bias - "te_1_0": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "no_mask", "post_scale_bias"), - "te_1_1": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "post_scale_bias"), - "te_1_2": ModelConfig(2, 16, 16, 64, 128, 128, 0.0, "padding", "post_scale_bias"), - "te_1_3": 
ModelConfig(2, 16, 16, 64, 128, 256, 0.0, "padding", "no_bias"), - "te_2_0": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "causal", "no_bias"), - "te_2_1": ModelConfig(2, 16, 16, 64, 2048, 2048, 0.0, "no_mask", "no_bias"), - "te_2_2": ModelConfig(1, 16, 16, 64, 2048, 2048, 0.0, "padding", "no_bias"), - "te_2_3": ModelConfig(1, 16, 16, 64, 2048, 4096, 0.0, "padding_causal_bottom_right", "no_bias"), - "te_3_0": ModelConfig(4, 16, 16, 64, 128, 128, 0.0, "causal", "alibi"), - "te_3_1": ModelConfig(4, 16, 16, 64, 2048, 2048, 0.0, "causal", "alibi"), + "te_1_0": ModelConfig(2, 128, 16, 64, attn_bias_type="post_scale_bias"), + "te_1_1": ModelConfig( + 4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="post_scale_bias" + ), + "te_1_2": ModelConfig( + 2, 128, 16, 64, attn_mask_type="padding", attn_bias_type="post_scale_bias" + ), + "te_1_3": ModelConfig(2, 128, 16, 64, max_seqlen_kv=256, attn_mask_type="padding"), + "te_2_0": ModelConfig(1, 2048, 16, 64, attn_mask_type="causal"), + "te_2_1": ModelConfig(2, 2048, 16, 64), + "te_2_2": ModelConfig(1, 2048, 16, 64, attn_mask_type="padding"), + "te_2_3": ModelConfig( + 1, 2048, 16, 64, max_seqlen_kv=4096, attn_mask_type="padding_causal_bottom_right" + ), + "te_3_0": ModelConfig(4, 128, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"), + "te_3_1": ModelConfig(4, 2048, 16, 64, attn_mask_type="causal", attn_bias_type="alibi"), } @@ -1168,7 +1112,7 @@ def test_transformer_layer( # Test backend availability is_training = True - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1179,7 +1123,7 @@ def test_transformer_layer( flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends if not fused_attn_supported: is_training = False - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=( @@ -1492,20 +1436,164 @@ def _run_transformer_layer( return out, inp.grad +model_configs_fp8_extra_state = { + "large": ModelConfig(2, 128, 4, 128, num_layers=1), +} + + +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") +@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") +@pytest.mark.parametrize("model", ["large"]) +@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) +def test_sanity_attention_extra_state(model, dtype): + config = model_configs_fp8_extra_state[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="sb3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not fused_attn_supported and not flash_attn_supported: + pytest.skip("No attention backend available.") + + outputs = _run_attention_extra_state(dtype, config, checkpoint=False) + outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) + outputs_checkpoint_v1_6 = _run_attention_extra_state( + dtype, config, mimic_v1_6=True, checkpoint=True + ) + + # Check that results match + tols = dtype_tols(dtype) + if dtype in (torch.float16, torch.bfloat16): + tols.update(dict(rtol=2e-2, atol=2e-3)) + for i, 
(ref, test) in enumerate(zip(outputs, outputs_checkpoint)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): + torch.testing.assert_close( + test, + ref, + **tols, + ) + + +def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): + steps = 10 + path = "checkpoint.pt" + fp8_enabled = True + fp8_recipe = recipe.DelayedScaling( + margin=0, + fp8_format=recipe.Format.HYBRID, + amax_history_len=1, + amax_compute_algo="most_recent", + fp8_dpa=fp8_enabled, + fp8_mha=False, + ) + + reset_rng_states() + hidden_states = torch.randn( + (config.max_seqlen_q, config.batch_size, config.hidden_size), + dtype=dtype, + device="cuda", + requires_grad=True, + ) + + def get_model(dtype, config): + sigma = 0.023 + init_method = init_method_normal(sigma) + output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) + + with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): + block = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + init_method=init_method, + output_layer_init_method=output_layer_init_method, + hidden_dropout=0.0, + attention_dropout=0.0, + fuse_qkv_params=True, + params_dtype=dtype, + device="cuda", + ) + return block + + block = get_model(dtype, config) + for i in range(steps // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + if checkpoint: + sd = block.state_dict() + if mimic_v1_6: + sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ + "self_attention.core_attention._extra_state" + ] + del sd["self_attention.core_attention._extra_state"] + torch.save(sd, path) + + param_grads = [] + for p in block.parameters(): + if p.requires_grad: + param_grads.append(p.grad.clone()) + + _cpu_rng_state_new = torch.get_rng_state() + _cuda_rng_state_new = torch.cuda.get_rng_state() + + del block + block = get_model(dtype, config) + block.load_state_dict(torch.load(path, weights_only=False)) + torch.set_rng_state(_cpu_rng_state_new) + torch.cuda.set_rng_state(_cuda_rng_state_new) + + for p in block.parameters(): + if p.requires_grad: + p.grad = param_grads.pop(0) + + assert not param_grads, "Oops!" 
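+    # The assertion above verifies that every gradient saved before the reload
+    # was reassigned to a parameter of the rebuilt model, i.e. the set of
+    # trainable parameters is unchanged after loading the checkpoint; training
+    # then resumes below with the restored weights, grads, and RNG states.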
+ + for i in range((steps + 1) // 2): + with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): + output = block(hidden_states, None) + loss = output.sum() + loss.backward() + + torch.cuda.synchronize() + + if os.path.exists(path): + os.remove(path) + + outputs = [output, hidden_states.grad] + for p in block.parameters(): + if p.requires_grad: + outputs.append(p.grad) + + return outputs + + model_configs_fp8_vs_f16 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_9": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_10": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_11": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "no_mask", "no_bias"), - "fp8_12": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_13": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_14": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), - "fp8_15": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "fp8_16": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding", "no_bias"), - "fp8_17": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding", "no_bias"), - "fp8_18": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "fp8_19": ModelConfig(2, 24, 12, 128, 2048, 2048, 0.0, "padding_causal", "no_bias"), - "fp8_20": ModelConfig(1, 32, 4, 128, 8192, 8192, 0.0, "padding_causal", "no_bias"), + "fp8_9": ModelConfig(2, 2048, 16, 128), + "fp8_10": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12), + "fp8_11": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4), + "fp8_12": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), + "fp8_13": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="causal"), + "fp8_14": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), + "fp8_15": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding"), + "fp8_16": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding"), + "fp8_17": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding"), + "fp8_18": ModelConfig(2, 2048, 16, 128, attn_mask_type="padding_causal"), + "fp8_19": ModelConfig(2, 2048, 24, 128, num_gqa_groups=12, attn_mask_type="padding_causal"), + "fp8_20": ModelConfig(1, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="padding_causal"), } param_types_fp8_vs_f16 = [torch.float16, torch.bfloat16] @@ -1554,18 +1642,30 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" config = model_configs_fp8_vs_f16[model] - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, - ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if (len(fused_attn_backends) + flash_attn_supported + unfused_attn_supported) < 2: + pytest.skip("Less than two backends to compare.") + if not fp8_dpa_bwd: + available_backends, _, 
fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_format.replace("hd", "h3d"), + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1591,11 +1691,7 @@ def test_mha_fp8_vs_f16(dtype, model, qkv_format, input_layernorm, fp8_dpa_bwd, rtol = 5e-1 rmse_tol = 0.15 logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1768,23 +1864,34 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): # if get_device_compute_capability() >= (10, 0): # config.dropout_p = 0.1 - if ("padding" in config.attn_mask_type or config.head_dim_qk != 128) and get_cudnn_version() < ( - 9, - 7, - 0, - ): - pytest.skip("FP8 with padding or head_dim != 128 is not supported for cuDNN < 9.7") - if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: - pytest.skip("qkv_layout not applicable for MQA/GQA") - os.environ["NVTE_FP8_DPA_BWD"] = "1" if fp8_dpa_bwd else "0" os.environ["NVTE_ALLOW_NONDETERMINISTIC_ALGO"] = "1" - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + # Test backend availability + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout=qkv_layout, + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + # Skip if only unfused backend is supported + if flash_attn_supported + fused_attn_supported < 1: + pytest.skip("No FP8 attention backend available.") + if not fp8_dpa_bwd: + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") + if config.num_heads != config.num_gqa_groups and "3" in qkv_layout: + pytest.skip("qkv_layout not applicable for MQA/GQA") + + if flash_attn_supported: os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True @@ -1813,11 +1920,7 @@ def test_dpa_fp8_vs_f16(dtype, model, qkv_layout, fp8_dpa_bwd, is_training): rmse_tol = 0.11 bwd_names = ["dq", "dk", "dv"] logging.debug("========== {:^25s} ==========".format("forward output")) - if ( - FlashAttentionUtils.v3_is_installed - and not is_training - and "padding" not in config.attn_mask_type - ): + if flash_attn_supported: _error( flash_attn_fwd_fp8, fused_attn_fwd_f16, @@ -1991,14 +2094,14 @@ def get_dummy_cuda_rng_tracker() -> CudaRNGStatesTracker: model_configs_fp8 = { # test: b, h, hg, d, sq, skv, p, mask, bias - "fp8_1": ModelConfig(1, 1, 1, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_2": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), - "fp8_3": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_4": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "no_mask", "no_bias"), - "fp8_5": ModelConfig(1, 1, 1, 64, 
512, 512, 0.0, "causal", "no_bias"), - "fp8_6": ModelConfig(4, 16, 16, 64, 512, 512, 0.0, "causal", "no_bias"), - "fp8_7": ModelConfig(1, 1, 1, 128, 2048, 2048, 0.0, "causal", "no_bias"), - "fp8_8": ModelConfig(2, 24, 24, 128, 2048, 2048, 0.0, "causal", "no_bias"), + "fp8_1": ModelConfig(1, 512, 1, 64), + "fp8_2": ModelConfig(4, 512, 16, 64), + "fp8_3": ModelConfig(1, 2048, 1, 128), + "fp8_4": ModelConfig(2, 2048, 24, 128), + "fp8_5": ModelConfig(1, 512, 1, 64, attn_mask_type="causal"), + "fp8_6": ModelConfig(4, 512, 16, 64, attn_mask_type="causal"), + "fp8_7": ModelConfig(1, 2048, 1, 128, attn_mask_type="causal"), + "fp8_8": ModelConfig(2, 2048, 24, 128, attn_mask_type="causal"), } param_types_fp8 = [torch.float16, torch.bfloat16] cudnn_frontend_version = int(os.getenv("NVTE_FUSED_ATTN_FE_VER", "1")) @@ -2027,6 +2130,18 @@ def test_custom_mha_fp8_vs_f16(dtype, model): config = model_configs_fp8[model] + # Test backend availability + is_training = True + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=torch.float8_e4m3fn, + qkv_layout="t3hd" if cudnn_frontend_version == 0 else "bs3hd", + is_training=is_training, + ) + flash_attn_supported, fused_attn_supported, unfused_attn_supported = available_backends + if not (fused_attn_backends and unfused_attn_supported): + pytest.skip("Not enough backends to run this test with.") + fused_attn_fwd_fp8, fused_attn_bwd_fp8 = _run_custom_mha_fp8(dtype, config, "FusedAttention") unfused_attn_fwd_f16, unfused_attn_bwd_f16 = _run_ref_mha_f16(dtype, config, "UnfusedAttention") diff --git a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py similarity index 71% rename from tests/pytorch/fused_attn/test_fused_attn_with_cp.py rename to tests/pytorch/attention/test_attention_with_cp.py index 458070c9b..0e8501abf 100644 --- a/tests/pytorch/fused_attn/test_fused_attn_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -4,6 +4,8 @@ import os import subprocess +import sys +import pathlib import pytest import torch @@ -12,26 +14,28 @@ get_cudnn_version, ) from transformer_engine.pytorch.attention.dot_product_attention.utils import FlashAttentionUtils -from test_fused_attn import ModelConfig + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig, get_available_attention_backends + +# Initialize RNG state +seed = 1234 +torch.manual_seed(seed) +torch.cuda.manual_seed(seed) model_configs_flash_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_1_2": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) - ), # MHA - "cp_1_3": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) - ), # MHA - "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_2": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_1_3": ModelConfig(2, 4096, 12, 128, window_size=(512, 512)), # MHA + "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, 
attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA "cp_2_2": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) - ), # GQA - "cp_2_3": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias", window_size=(512, 512) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA + "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA } @@ -43,7 +47,7 @@ def get_bash_arguments(num_gpus_per_node, **kwargs): "--nproc-per-node=" + str(num_gpus_per_node), ] te_path = os.getenv("TE_PATH", "/opt/transformerengine") - script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py") + script_path = os.path.join(te_path, "tests/pytorch/attention/run_attention_with_cp.py") args.append(script_path) for k, v in kwargs.items(): args.append(f"{k}={v}") @@ -93,32 +97,36 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): model_configs_fused_attn = { # test: b, h, hg, d, sq, skv, p, mask, bias - "cp_1_0": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias"), # MHA - "cp_1_1": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # MHA - "cp_1_2": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # MHA - "cp_1_3": ModelConfig(2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # MHA - "cp_1_4": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + "cp_1_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal"), # MHA + "cp_1_1": ModelConfig(2, 4096, 12, 128), # MHA + "cp_1_2": ModelConfig( + 2, 4096, 12, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias" ), # MHA - "cp_2_0": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias"), # GQA - "cp_2_1": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "no_bias"), # GQA - "cp_2_2": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "causal", "post_scale_bias"), # GQA - "cp_2_3": ModelConfig(2, 12, 2, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias"), # GQA + "cp_1_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias"), # MHA + "cp_1_4": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", window_size=(512, 0)), # MHA + "cp_2_0": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal"), # GQA + "cp_2_1": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2), # GQA + "cp_2_2": ModelConfig( + 2, + 4096, + 12, + 128, + num_gqa_groups=2, + attn_mask_type="causal", + attn_bias_type="post_scale_bias", + ), # GQA + "cp_2_3": ModelConfig( + 2, 4096, 12, 128, num_gqa_groups=2, attn_bias_type="post_scale_bias" + ), # GQA "cp_2_4": ModelConfig( - 2, 12, 2, 128, 4096, 4096, 0.0, "causal", "no_bias", window_size=(512, 0) + 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA - "cp_3_0": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=64 - ), # MLA - "cp_3_1": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=64 - ), # MLA + "cp_3_0": ModelConfig(2, 4096, 12, 128, attn_mask_type="causal", head_dim_v=64), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 128, head_dim_v=64), # MLA "cp_3_2": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "causal", "post_scale_bias", head_dim_v=64 - ), # MLA - "cp_3_3": ModelConfig( - 2, 12, 12, 128, 4096, 4096, 0.0, "no_mask", "post_scale_bias", head_dim_v=64 + 2, 4096, 12, 128, 
attn_mask_type="causal", attn_bias_type="post_scale_bias", head_dim_v=64 ), # MLA + "cp_3_3": ModelConfig(2, 4096, 12, 128, attn_bias_type="post_scale_bias", head_dim_v=64), # MLA } @@ -175,6 +183,17 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha pytest.skip("MLA CP currently only support KV P2P!") if dtype == "fp8" and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently does not support FP8 attention!") + dtypes = {"fp16": torch.float16, "bf16": torch.bfloat16, "fp8": torch.bfloat16} + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtypes[dtype], + qkv_layout="_".join([qkv_format] * 3), + window_size=config.window_size, + context_parallel=True, + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("No attention backend available.") subprocess.run( get_bash_arguments( diff --git a/tests/pytorch/fused_attn/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py similarity index 97% rename from tests/pytorch/fused_attn/test_kv_cache.py rename to tests/pytorch/attention/test_kv_cache.py index 967309459..288c5382e 100644 --- a/tests/pytorch/fused_attn/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -5,18 +5,14 @@ from collections import OrderedDict from typing import List import os +import sys +import pathlib import logging import math import pytest import torch -from test_fused_attn import ( - ModelConfig, - reset_rng_states, - _get_attention_backends, -) - from torch.distributions import Exponential from transformer_engine.pytorch import make_graphed_callables from transformer_engine.common import recipe @@ -34,26 +30,25 @@ is_bf16_compatible, ) -# Initialize RNG state -seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ( + ModelConfig, + reset_rng_states, + get_available_attention_backends, +) +# Reset RNG states +reset_rng_states() param_types = [torch.float16] if is_bf16_compatible(): param_types.append(torch.bfloat16) model_configs_infer = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "infer_0": ModelConfig( - 4, 16, 16, 128, 64, 64, 0.0, "no_mask", "no_bias", total_requests=8, max_ctx_len=16 - ), - "infer_1": ModelConfig( - 2, 16, 4, 256, 66, 66, 0.0, "no_mask", "no_bias", total_requests=6, max_ctx_len=16 - ), + # test: b, sq, hq, dqk, + "infer_0": ModelConfig(4, 64, 16, 128, total_requests=8, max_ctx_len=16), + "infer_1": ModelConfig(2, 66, 16, 256, num_gqa_groups=4, total_requests=6, max_ctx_len=16), } qkv_formats = ["bshd", "sbhd", "thd"] @@ -470,7 +465,7 @@ def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_g qkv_layout = qkv_format + "_" + "_".join([inference_params_qkv_format] * 2) if is_paged: qkv_layout = "paged_kv_" + qkv_layout - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 87494f3c2..cd71d5b93 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -10,6 +10,8 @@ import transformer_engine.pytorch as te from transformer_engine.common import recipe from 
transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from utils import ModelConfig, get_available_attention_backends # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -22,10 +24,13 @@ recipe.DelayedScaling(), ] -SIZE = 512 -NUM_HEADS = 8 -NUM_LAYERS = 5 -EPSILON = 0.1 +model_config = { + "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), +} +SIZE = model_config["small"].hidden_size +NUM_HEADS = model_config["small"].num_heads +NUM_LAYERS = model_config["small"].num_layers +EPSILON = model_config["small"].eps # Flash attention saves some internal tensor for the backward pass # that cannot be offloaded to CPU. @@ -130,6 +135,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if model_key in ["multihead_attention", "transformer_layer"]: + available_backends, *_ = get_available_attention_backends( + model_config["small"], + qkv_dtype=torch.bfloat16, + qkv_layout="sbhd_sbhd_sbhd", + ) + _, fused_attn_supported, _ = available_backends + if not fused_attn_supported: + pytest.skip("Fused attention backend not available.") + os.environ["NVTE_FLASH_ATTN"] = "0" + _attention_backends["backend_selection_requires_update"] = True + without_offloading = _measure_memory_between_forward_and_backward( models_list, fp8_recipe, False ) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 7bfe506f2..83837eafd 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -23,7 +23,7 @@ from transformer_engine.pytorch.utils import is_bf16_compatible import transformer_engine.pytorch.ops as te_ops from transformer_engine.common import recipe - +from utils import ModelConfig, reset_rng_states # Check if FP8 is supported. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -32,27 +32,12 @@ ) mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +# Reset RNG states. +reset_rng_states() -# Record initial RNG state. 
-seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() - - -@dataclass -class ModelConfig: - """Data tensor dimensions within Transformer model""" - - sequence_length: int - batch_size: int - hidden_size: int - num_heads: int - kv_channels: int - - -model_configs = {"small": ModelConfig(2, 32, 64, 2, 32)} +model_configs = { + "small": ModelConfig(32, 2, 2, 32), +} fp8_recipes = [ recipe.DelayedScaling(), @@ -67,12 +52,6 @@ class ModelConfig: dtypes.append(torch.bfloat16) -def reset_rng_states() -> None: - """Revert to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield @@ -107,7 +86,7 @@ def generate_data( """Generate synthetic data.""" gen_func = torch.ones if warmup else torch.randn return gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.hidden_size, device="cuda", @@ -389,7 +368,7 @@ def generate_data_for_dot_product_attention( gen_func = torch.ones if warmup else torch.randn return [ gen_func( - model_config.sequence_length, + model_config.max_seqlen_q, model_config.batch_size, model_config.num_heads, model_config.kv_channels, @@ -483,8 +462,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", @@ -510,8 +489,8 @@ def _test_cuda_graphs_with_kwargs( ( model_config.batch_size, 1, - model_config.sequence_length, - model_config.sequence_length, + model_config.max_seqlen_q, + model_config.max_seqlen_kv, ), dtype=torch.bool, device="cuda", diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 440be43a0..790bc7a11 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -40,11 +40,13 @@ from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version from transformer_engine.common import recipe import transformer_engine_torch as tex +from utils import ModelConfig, reset_rng_states, get_available_attention_backends # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -56,33 +58,18 @@ sm_80plus = get_device_compute_capability() >= (8, 0) seed = 1234 -torch.manual_seed(seed) -torch.cuda.manual_seed(seed) -# Record initial RNG state from script run. -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() +# Reset RNG states. 
+reset_rng_states() torch._dynamo.config.recompile_limit = 16 -class ModelConfig: - def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq_len): - self.hidden_size = hidden_size - self.eps = eps - self.num_attention_heads = num_attention_heads - self.embed = embed - self.num_layers = num_layers - self.seq_len = seq_len - - model_configs = { - "small": ModelConfig(128, 1e-5, 8, 36, 4, 128), - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 2048), + "small": ModelConfig(1, 128, 8, 16, num_layers=4), + "126m": ModelConfig(1, 2048, 12, 64, num_layers=12), } - model_configs_inference = { - # hidden_size, eps, num_attention_heads, embed, num_layers, seq_len - "126m": ModelConfig(768, 1e-5, 12, 64, 12, 256), + "126m": ModelConfig(1, 256, 12, 64, num_layers=12), } backends_inference = ["FlashAttention", "UnfusedAttention", "FusedAttention"] module_inference = ["TransformerLayer", "MultiheadAttention"] @@ -124,6 +111,18 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq ] +def is_fused_attn_available( + config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True +): + available_backends, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + is_training=is_training, + ) + return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends + + def get_causal_attn_mask(sq: int) -> torch.Tensor: return torch.triu(torch.ones(sq, sq, device="cuda"), diagonal=1).bool() @@ -173,12 +172,6 @@ def assert_allclose( raise AssertionError(msg) -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - @pytest.fixture(autouse=True) def reset_global_fp8_state(): yield @@ -531,13 +524,13 @@ def _test_e2e_selective_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -546,13 +539,13 @@ def _test_e2e_selective_recompute( ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): te_out = block( @@ -626,13 +619,13 @@ def _test_e2e_full_recompute( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -641,14 +634,14 @@ def _test_e2e_full_recompute( ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=use_reentrant, ) if use_reentrant: te_inp_hidden_states.retain_grad() - te_inp_attn_mask = 
get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if recompute: @@ -757,13 +750,13 @@ def _test_e2e_checkpointing_get_model(config, dtype): return TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -775,7 +768,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= reset_rng_states() te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -805,14 +798,14 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= if p.requires_grad: param_grads.append(p.grad.clone()) - global _cpu_rng_state, _cuda_rng_state _cpu_rng_state = torch.get_rng_state() _cuda_rng_state = torch.cuda.get_rng_state() del block block = _test_e2e_checkpointing_get_model(config, dtype) block.load_state_dict(torch.load(path, weights_only=False)) - reset_rng_states() + torch.set_rng_state(_cpu_rng_state) + torch.cuda.set_rng_state(_cuda_rng_state) for p in block.parameters(): if p.requires_grad: @@ -845,6 +838,8 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] + if not is_fused_attn_available(config, dtype): + pytest.skip("No attention backend available.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -865,13 +860,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) out = block(inp_hidden_states, attention_mask=inp_attn_mask) loss = out.sum() @@ -891,11 +886,13 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] + if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + pytest.skip("No attention backend available.") te_gpt = TransformerLayer( hidden_size=config.hidden_size, ffn_hidden_size=4 * config.hidden_size, - num_attention_heads=config.num_attention_heads, + num_attention_heads=config.num_heads, layernorm_epsilon=config.eps, attention_dropout=0.1, hidden_dropout=0.1, @@ -910,7 +907,7 @@ def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): TorchGPT( config.hidden_size, config.eps, - config.num_attention_heads, + config.num_heads, parallel_attention_mlp=parallel_attention_mlp, ) .to(dtype=dtype) @@ -971,13 +968,13 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), 
+ (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - inp_attn_mask = get_causal_attn_mask(config.seq_len) if mask_type == "causal" else None + inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) if mask_type == "causal" else None forward_kwargs = {} if te: @@ -1002,10 +999,12 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] + if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + pytest.skip("No attention backend available.") te_mha = MultiheadAttention( config.hidden_size, - config.num_attention_heads, + config.num_heads, fuse_qkv_params=True, params_dtype=dtype, qkv_weight_interleaved=False, @@ -1016,7 +1015,7 @@ def test_mha_accuracy(dtype, bs, model, mask_type): torch_mha = ( TorchMHA( config.hidden_size, - config.num_attention_heads, + config.num_heads, ) .to(dtype=dtype) .cuda() @@ -1062,7 +1061,7 @@ def _test_granular_accuracy(block, bs, dtype, config, delay_wgrad_compute=False, FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1094,11 +1093,12 @@ def _test_dpa_accuracy(block, bs, dtype, config): reset_rng_states() mask = torch.triu( - torch.ones(config.seq_len, config.seq_len, dtype=torch.bool, device="cuda"), diagonal=1 + torch.ones(config.max_seqlen_q, config.max_seqlen_kv, dtype=torch.bool, device="cuda"), + diagonal=1, ) query, key, value = [ torch.randn( - (config.seq_len, bs, config.num_attention_heads, config.embed), + (config.max_seqlen_q, bs, config.num_heads, config.kv_channels), dtype=dtype, device="cuda", requires_grad=True, @@ -1127,8 +1127,8 @@ def test_dpa_accuracy(dtype, bs, model): te_dpa = ( DotProductAttention( - config.num_attention_heads, - config.embed, + config.num_heads, + config.kv_channels, attention_dropout=0.0, # disable dropout, FU uses rng differently ) .to(dtype=dtype) @@ -1137,7 +1137,7 @@ def test_dpa_accuracy(dtype, bs, model): torch_dpa = ( TorchDotProductAttention( - config.embed, + config.kv_channels, 0.0, # dropout ) .to(dtype=dtype) @@ -1286,7 +1286,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -1726,7 +1726,7 @@ def _test_grouped_linear_accuracy( FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -1739,14 +1739,14 @@ def _test_grouped_linear_accuracy( split_size = 16 if recipe.mxfp8(): split_size = 128 - m = config.seq_len // split_size + m = config.max_seqlen_q // split_size dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist() dist.append(dist[-1]) # Manually add a zero m_splits = torch.tensor(dist + [m]) - torch.tensor([0] + dist) m_splits = m_splits * split_size - assert m_splits.sum() == config.seq_len and len(m_splits) == num_gemms + assert m_splits.sum() == config.max_seqlen_q and 
len(m_splits) == num_gemms else: - m_splits = torch.tensor([config.seq_len]) + m_splits = torch.tensor([config.max_seqlen_q]) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, GroupedLinear): @@ -1812,7 +1812,7 @@ def test_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -1916,7 +1916,7 @@ def test_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2064,14 +2064,14 @@ def _generate_random_numbers(n, total_sum): FP8GlobalStateManager.reset() inp_hidden_states = torch.randn( - (config.seq_len * bs, config.hidden_size), + (config.max_seqlen_q * bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) inp_hidden_states.retain_grad() - m_splits = _generate_random_numbers(num_gemms, config.seq_len * bs) + m_splits = _generate_random_numbers(num_gemms, config.max_seqlen_q * bs) with fp8_autocast(enabled=fp8, fp8_recipe=recipe): if isinstance(block, TorchGroupedLinearWithPadding): @@ -2124,7 +2124,7 @@ def test_padding_grouped_linear_accuracy( pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2201,7 +2201,7 @@ def test_padding_grouped_linear_accuracy_save_original_input( pytest.skip("DelayedScaling recipe is not supported with save_original_input") config = model_configs[model] - if config.seq_len % 16 != 0 and fp8: + if config.max_seqlen_q % 16 != 0 and fp8: pytest.skip("FP8 requires sequence length to be divisible by 16.") with fp8_model_init(enabled=fp8 and fp8_model_params, recipe=recipe): @@ -2258,9 +2258,11 @@ def _test_gpt_e2e_cuda_graph(block, bs, dtype, config, graph): # Placeholders used for graph capture. 
static_input = torch.randn( - config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype, requires_grad=True + ) + static_target = torch.randn( + config.max_seqlen_q, bs, config.hidden_size, device="cuda", dtype=dtype ) - static_target = torch.randn(config.seq_len, bs, config.hidden_size, device="cuda", dtype=dtype) real_input = torch.rand_like(static_input) real_target = torch.rand_like(static_target) @@ -2324,7 +2326,7 @@ def test_gpt_cuda_graph(dtype, bs, model): block_args = ( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, ) block_kwargs = dict( layernorm_epsilon=config.eps, @@ -2332,7 +2334,7 @@ def test_gpt_cuda_graph(dtype, bs, model): output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2367,13 +2369,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, attention_dropout=0.1, - kv_channels=config.embed, + kv_channels=config.kv_channels, apply_residual_connection_post_layernorm=False, output_layernorm=False, params_dtype=dtype, @@ -2382,13 +2384,13 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): ) te_inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_hidden_states.retain_grad() - te_inp_attn_mask = get_causal_attn_mask(config.seq_len) + te_inp_attn_mask = get_causal_attn_mask(config.max_seqlen_q) with fp8_autocast(enabled=True, fp8_recipe=recipe): te_out = block(te_inp_hidden_states, attention_mask=te_inp_attn_mask) @@ -2451,13 +2453,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_sbhd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2472,13 +2474,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_bshd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2490,13 +2492,13 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): block_thd = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, layernorm_epsilon=config.eps, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0, attention_dropout=0, - kv_channels=config.embed, + 
kv_channels=config.kv_channels, params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, @@ -2511,15 +2513,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): assert torch.all(torch.eq(p1, p2) & torch.eq(p1, p3)), f"{n1}, {n2} and {n3} not identical" x_sbhd = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) x_bshd = x_sbhd.transpose(0, 1).contiguous() - x_thd = x_bshd.reshape(bs * config.seq_len, config.hidden_size).contiguous() - x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.seq_len + x_thd = x_bshd.reshape(bs * config.max_seqlen_q, config.hidden_size).contiguous() + x_thd_cumsum = torch.arange(bs + 1, device="cuda", dtype=torch.int32) * config.max_seqlen_q # To make sure forward is also identical (just in case some module decides # to act fancy) @@ -2546,165 +2548,15 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): x_thd, cu_seqlens_q=x_thd_cumsum, cu_seqlens_kv=x_thd_cumsum, - max_seqlen_q=config.seq_len, - max_seqlen_kv=config.seq_len, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, ) torch.testing.assert_close( y_bshd, - y_thd.reshape(bs, config.seq_len, config.hidden_size).contiguous(), - ) - - -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) -@pytest.mark.parametrize("model_key", model_configs_inference.keys()) -@pytest.mark.parametrize("use_RoPE", all_boolean) -@pytest.mark.parametrize("input_format", input_formats_inference) -@pytest.mark.parametrize("module", module_inference) -@pytest.mark.parametrize("backend", backends_inference) -@pytest.mark.parametrize("is_paged", [False, True]) -def test_kv_cache_accuracy(dtype, bs, model_key, use_RoPE, input_format, module, backend, is_paged): - reset_rng_states() - - if backend in ["FusedAttention", "FlashAttention"] and dtype == torch.float32: - pytest.skip("FusedAttention and FlashAttention do not support FP32") - if use_RoPE: - pytest.skip("KV cache does not support starting positions for RoPE") - if ( - backend == "FusedAttention" - and get_device_compute_capability() == (8, 9) - and get_cudnn_version() < (9, 12, 0) - ): - pytest.skip("Skip KV cache for sm89 and cuDNN < 9.12") - - os.environ["NVTE_FLASH_ATTN"] = "0" - os.environ["NVTE_FUSED_ATTN"] = "0" - os.environ["NVTE_UNFUSED_ATTN"] = "0" - - if backend == "FlashAttention": - os.environ["NVTE_FLASH_ATTN"] = "1" - elif backend == "FusedAttention": - os.environ["NVTE_FUSED_ATTN"] = "1" - elif backend == "UnfusedAttention": - os.environ["NVTE_UNFUSED_ATTN"] = "1" - - config = model_configs_inference[model_key] - - S = config.seq_len - B = bs - H = config.num_attention_heads - D = config.hidden_size - head_size = config.embed - layer_number = 1 - - # Limits the max size of KV-cache - B_max = B - S_max = S - - if module == "TransformerLayer": - model = TransformerLayer( - hidden_size=D, - ffn_hidden_size=4 * D, - num_attention_heads=H, - attn_input_format=input_format, - self_attn_mask_type="causal", - enc_dec_attn_mask_type="causal", - layer_number=layer_number, - attention_dropout=0.0, - params_dtype=dtype, - device="cuda", - ).eval() - else: - model = ( - MultiheadAttention( - hidden_size=D, - num_attention_heads=H, - qkv_format=input_format, - layer_number=layer_number, - attention_dropout=0.0, - attn_mask_type="causal", - params_dtype=dtype, - ) - .cuda() - .eval() + y_thd.reshape(bs, 
config.max_seqlen_q, config.hidden_size).contiguous(), ) - inference_params = InferenceParams( - max_batch_size=B_max, - max_sequence_length=S_max, - num_heads_kv=H, - head_dim_k=head_size, - dtype=dtype, - is_paged=is_paged, - total_num_pages=int(B_max * S_max / 256), - page_size=256, - ) - - rotary_freqs = torch.randn((S_max, 1, 1, head_size), dtype=torch.float, device="cuda") - - input = torch.randn((S, B, D), dtype=dtype, device="cuda") - if input_format == "bshd": - input = input.transpose(0, 1).contiguous() - - incremental_output = torch.zeros_like(input) - - # Generate output for the entire sequence - full_output = model(hidden_states=input, rotary_pos_emb=rotary_freqs if use_RoPE else None) - - # Incrementaly generate outputs using KV-cache - step_dict = OrderedDict(zip(list(range(B)), [1] * B)) - for i in range(S): - inference_params.pre_step(step_dict) - - if input_format == "sbhd": - incremental_input = input[i].view(1, B, D) - else: - incremental_input = input[:, i, :].view(B, 1, D) - - seqlens_q = torch.ones(B, dtype=torch.int32, device="cuda") - cu_seqlens_q = torch.zeros(B + 1, dtype=torch.int32, device="cuda") - cu_seqlens_q[1:] = torch.cumsum(seqlens_q, dim=0) - cu_seqlens_kv = cu_seqlens_q.clone() - - mask_type = "padding" - kwargs = {} - if module == "TransformerLayer": - kwargs["self_attn_mask_type"] = mask_type - else: - kwargs["attn_mask_type"] = mask_type - line_output = model( - hidden_states=incremental_input, - inference_params=inference_params, - rotary_pos_emb=rotary_freqs if use_RoPE else None, - **kwargs, - max_seqlen_q=1, - max_seqlen_kv=S, - cu_seqlens_q=cu_seqlens_q, - cu_seqlens_kv=cu_seqlens_kv, - ) - - if input_format == "sbhd": - incremental_output[i, :, :] = line_output.view(B, D) - else: - incremental_output[:, i, :] = line_output.view(B, D) - - if module == "TransformerLayer": - atol = { - torch.float32: 5e-3, - torch.half: 5e-3, - torch.bfloat16: 5e-2, - } - else: - atol = { - torch.float32: 1e-3, - torch.half: 1e-3, - torch.bfloat16: 1e-2, - } - - # Check if the fully generated output matches the one generated incrementally - assert_allclose(full_output, incremental_output, atol[dtype]) - @pytest.mark.parametrize( "shape", diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 00dff53da..4df6d987a 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -46,7 +46,7 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint -from utils import dtype_tols +from utils import ModelConfig, dtype_tols # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -59,8 +59,6 @@ seed = 1234 torch.manual_seed(seed) torch.cuda.manual_seed(seed) -_cpu_rng_state = torch.get_rng_state() -_cuda_rng_state = torch.cuda.get_rng_state() NVTE_TEST_NVINSPECT_ENABLED = int(os.environ.get("NVTE_TEST_NVINSPECT_ENABLED", "0")) @@ -105,37 +103,22 @@ def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: return torch.min(amax_history, dim=0).values -def reset_rng_states() -> None: - """revert back to initial RNG state.""" - global _cpu_rng_state, _cuda_rng_state - torch.set_rng_state(_cpu_rng_state) - torch.cuda.set_rng_state(_cuda_rng_state) - - -@dataclass -class ModelConfig: - """Transformer model configuration""" - - num_layers: int - seq_len: int - batch_size: int - hidden_size: int - num_attention_heads: int - kv_channels: Optional[int] = None - - def is_fp8_supported(self): - if self.seq_len * self.batch_size % 16: - return False - if self.hidden_size % 16: - return False - return True +def is_fp8_supported(config: ModelConfig): + if ( + config.max_seqlen_q * config.batch_size % 16 + or config.max_seqlen_kv * config.batch_size % 16 + ): + return False + if config.hidden_size % 16 or config.hidden_size_kv % 16: + return False + return True model_configs = { - "126m": ModelConfig(12, 2048, 2, 768, 12), - "small": ModelConfig(2, 32, 2, 64, 2), - "weird": ModelConfig(2, 37, 3, 69, 3), - "large": ModelConfig(1, 128, 2, 512, 4, 128), + "126m": ModelConfig(2, 2048, 12, 64, num_layers=12), + "small": ModelConfig(2, 32, 2, 32, num_layers=2), + "weird": ModelConfig(3, 37, 3, 23, num_layers=2), + "large": ModelConfig(2, 128, 4, 128, num_layers=1), } fp8_recipes = [ @@ -184,7 +167,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): # Placeholders used for capture. 
static_input = torch.randn( - config.seq_len, + config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", @@ -192,7 +175,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): requires_grad=True, ) static_target = torch.randn( - config.seq_len, config.batch_size, config.hidden_size, device="cuda", dtype=dtype + config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype ) real_input = torch.rand_like(static_input) @@ -236,7 +219,7 @@ def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=torch.float32, device="cuda", requires_grad=True, @@ -244,7 +227,7 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states.retain_grad() te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -271,14 +254,14 @@ def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -311,7 +294,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -337,7 +320,7 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -345,7 +328,7 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_q), dtype=torch.bool, device="cuda", ) @@ -363,21 +346,21 @@ def _test_sanity_e2e_bert(block, dtype, config, fp8_recipe, skip_wgrad): def _test_sanity_e2e_T5(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, ) te_inp_attn_mask = torch.randint( 2, - (1, 1, config.seq_len, config.seq_len), + (1, 1, config.max_seqlen_q, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) enc_dec_attn_mask = torch.randint( 2, - (config.batch_size, 1, 1, config.seq_len), + (config.batch_size, 1, 1, config.max_seqlen_kv), dtype=torch.bool, device="cuda", ) @@ -405,7 +388,7 @@ def _test_sanity_common( pytest.skip("No gradient 
computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, device="cuda", requires_grad=not skip_dgrad, @@ -433,7 +416,7 @@ def _test_sanity_normalization_amp(block, dtype, config, skip_wgrad, skip_dgrad) pytest.skip("No gradient computation; Skipping to avoid PyTorch RuntimeError.") te_inp = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), + (config.max_seqlen_q, config.batch_size, config.hidden_size), device="cuda", requires_grad=True, ) @@ -494,7 +477,7 @@ def test_sanity_layernorm_linear( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -528,7 +511,7 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -555,7 +538,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip("Quantized model parameters are not supported in debug mode.") config = model_configs[model] ffn_hidden_size = 4 * config.hidden_size - num_tokens = bs * config.seq_len + num_tokens = bs * config.max_seqlen_q if fp8_recipe is not None: if not fp8_available: @@ -564,7 +547,7 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -600,7 +583,7 @@ def test_sanity_grouped_linear( ffn_hidden_size = 4 * config.hidden_size # Small batch size used to catch bug from https://github.com/NVIDIA/TransformerEngine/pull/1527. 
bs = bs * 16 - num_tokens = bs * config.seq_len * (num_gemms - 1) + num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) if fp8_recipe is not None: if not fp8_available: @@ -609,7 +592,7 @@ def test_sanity_grouped_linear( pytest.skip(reason_for_no_mxfp8) if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: pytest.skip(reason_for_no_fp8_block_scaling) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") use_fp8 = fp8_recipe is not None @@ -621,7 +604,7 @@ def test_sanity_grouped_linear( inp_hidden_states = torch.randn( num_tokens, config.hidden_size, dtype=dtype, requires_grad=True ).cuda() - m_splits = [bs * config.seq_len] * num_gemms + m_splits = [bs * config.max_seqlen_q] * num_gemms if empty_split == "first": m_splits[0] = 0 elif empty_split == "last": @@ -665,7 +648,7 @@ def test_sanity_layernorm_mlp( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -719,7 +702,7 @@ def test_sanity_gpt( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -729,7 +712,7 @@ def test_sanity_gpt( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -788,7 +771,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -798,7 +781,7 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -849,7 +832,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -859,7 +842,7 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -908,7 +891,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -918,7 +901,7 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( 
config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -945,7 +928,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -955,7 +938,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -985,7 +968,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -995,7 +978,7 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1028,7 +1011,7 @@ def test_sanity_gradient_accumulation_fusion( pytest.skip(reason_for_no_fp8_block_scaling) if fp8_recipe.mxfp8() and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1038,7 +1021,7 @@ def test_sanity_gradient_accumulation_fusion( block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1074,7 +1057,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm pytest.skip(reason_for_no_mxfp8) if fp8_recipe.float8_block_scaling(): pytest.skip("cuda graph not supported for float8_block_scaling recipe") - if not config.is_fp8_supported(): + if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") sigma = 0.023 @@ -1084,7 +1067,7 @@ def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamm block = TransformerLayer( config.hidden_size, 4 * config.hidden_size, - config.num_attention_heads, + config.num_heads, init_method=init_method, output_layer_init_method=output_layer_init_method, hidden_dropout=0.1, @@ -1156,133 +1139,6 @@ def test_sanity_fp8_gemm_with_unalignment(N, datatype): torch.cuda.synchronize() -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -@pytest.mark.skipif(get_device_compute_capability() < (9, 0), reason="FP8 tests require Hopper.") -@pytest.mark.skipif(get_cudnn_version() < (9, 3, 0), reason="cuDNN 9.3.0+ is required.") -@pytest.mark.parametrize("model", ["large"]) -@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16]) -def test_sanity_attention_extra_state(model, dtype): - config = model_configs[model] - outputs = _run_attention_extra_state(dtype, config, checkpoint=False) - outputs_checkpoint = _run_attention_extra_state(dtype, config, checkpoint=True) - outputs_checkpoint_v1_6 = 
_run_attention_extra_state( - dtype, config, mimic_v1_6=True, checkpoint=True - ) - - # Check that results match - tols = dtype_tols(dtype) - if dtype in (torch.float16, torch.bfloat16): - tols.update(dict(rtol=2e-2, atol=2e-3)) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - for i, (ref, test) in enumerate(zip(outputs, outputs_checkpoint_v1_6)): - torch.testing.assert_close( - test, - ref, - **tols, - ) - - -def _run_attention_extra_state(dtype, config, checkpoint=False, mimic_v1_6=False): - steps = 10 - path = "checkpoint.pt" - fp8_enabled = True - fp8_recipe = recipe.DelayedScaling( - margin=0, - fp8_format=recipe.Format.HYBRID, - amax_history_len=1, - amax_compute_algo="most_recent", - fp8_dpa=fp8_enabled, - fp8_mha=False, - ) - - reset_rng_states() - hidden_states = torch.randn( - (config.seq_len, config.batch_size, config.hidden_size), - dtype=dtype, - device="cuda", - requires_grad=True, - ) - - def get_model(dtype, config): - sigma = 0.023 - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - - with fp8_model_init(enabled=fp8_enabled, recipe=fp8_recipe): - block = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_attention_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.0, - attention_dropout=0.0, - fuse_qkv_params=True, - params_dtype=dtype, - device="cuda", - ) - return block - - block = get_model(dtype, config) - for i in range(steps // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - if checkpoint: - sd = block.state_dict() - if mimic_v1_6: - sd["self_attention.core_attention.fused_attention._extra_state"] = sd[ - "self_attention.core_attention._extra_state" - ] - del sd["self_attention.core_attention._extra_state"] - torch.save(sd, path) - - param_grads = [] - for p in block.parameters(): - if p.requires_grad: - param_grads.append(p.grad.clone()) - - _cpu_rng_state_new = torch.get_rng_state() - _cuda_rng_state_new = torch.cuda.get_rng_state() - - del block - block = get_model(dtype, config) - block.load_state_dict(torch.load(path, weights_only=False)) - torch.set_rng_state(_cpu_rng_state_new) - torch.cuda.set_rng_state(_cuda_rng_state_new) - - for p in block.parameters(): - if p.requires_grad: - p.grad = param_grads.pop(0) - - assert not param_grads, "Oops!" 
- - for i in range((steps + 1) // 2): - with fp8_autocast(enabled=fp8_enabled, fp8_recipe=fp8_recipe): - output = block(hidden_states, None) - loss = output.sum() - loss.backward() - - torch.cuda.synchronize() - - if os.path.exists(path): - os.remove(path) - - outputs = [output, hidden_states.grad] - for p in block.parameters(): - if p.requires_grad: - outputs.append(p.grad) - - return outputs - - @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_replace_raw_data_for_float8tensor(): """Test the functionality of replace_raw_data""" diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 61ccfc6f2..524bd3289 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -4,12 +4,24 @@ from __future__ import annotations +import logging +import os +from contextlib import contextmanager + +import pytest import torch import transformer_engine import transformer_engine.common.recipe import transformer_engine.pytorch as te import transformer_engine_torch as tex +from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.attention.dot_product_attention.utils import ( + get_attention_backend, + AttentionParams, + AttentionLogging, +) +from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend def str_to_dtype(dtype: str | torch.dtype) -> torch.dtype: @@ -106,3 +118,178 @@ def make_recipe(name: Optional[str]) -> Optional[Recipe]: if name == "fp8_block_scaling": return transformer_engine.common.recipe.Float8BlockScaling() raise ValueError(f"Unsupported quantization scheme ({name})") + + +# Cached RNG state +_rng_states: Optional[Tuple[torch.Tensor, torch.Tensor]] = None + + +def reset_rng_states() -> None: + """Revert to deterministic RNG state""" + global _rng_states + if _rng_states is None: + torch.manual_seed(1234) + torch.cuda.manual_seed(1234) + _rng_states = (torch.get_rng_state(), torch.cuda.get_rng_state()) + else: + cpu_rng_state, cuda_rng_state = _rng_states + torch.set_rng_state(cpu_rng_state) + torch.cuda.set_rng_state(cuda_rng_state) + + +class ModelConfig: + def __init__( + self, + batch_size: int, + max_seqlen_q: int, + num_heads: int, + head_dim_qk: int, + max_seqlen_kv: int = None, + num_gqa_groups: int = None, + head_dim_v: int = None, + dropout_p: float = 0.0, + attn_mask_type: str = "no_mask", + attn_bias_type: str = "no_bias", + alibi_type: str = "none", + bias_shape: str = "1hss", + window_size: Tuple[int, int] = (-1, -1), + total_requests: int = None, + max_ctx_len: int = None, + num_layers: int = 1, + eps: float = 1e-5, + ): + self.batch_size = batch_size + self.max_seqlen_q = max_seqlen_q + self.max_seqlen_kv = max_seqlen_q if max_seqlen_kv is None else max_seqlen_kv + self.num_heads = num_heads + self.num_gqa_groups = num_heads if num_gqa_groups is None else num_gqa_groups + self.head_dim_qk = head_dim_qk + self.head_dim_v = head_dim_qk if head_dim_v is None else head_dim_v + if self.head_dim_qk == self.head_dim_v: + self.kv_channels = self.head_dim_qk + else: + self.kv_channels = (self.head_dim_qk, self.head_dim_v) + self.hidden_size = self.num_heads * self.head_dim_qk + self.hidden_size_kv = self.num_gqa_groups * self.head_dim_v + self.dropout_p = dropout_p + self.attn_mask_type = attn_mask_type + self.attn_bias_type = attn_bias_type + self.alibi_type = alibi_type + self.attn_type = "self" if (self.max_seqlen_q == self.max_seqlen_kv) else "cross" + self.bias_shape = bias_shape + self.window_size = window_size + self.total_requests = total_requests 
+ self.max_ctx_len = max_ctx_len + self.num_layers = num_layers + self.eps = eps + + +@contextmanager +def logging_context(highest_level=logging.WARNING): + previous_level = logging.root.manager.disable + logging.disable(highest_level) + try: + yield + finally: + logging.disable(previous_level) + + +def get_available_attention_backends( + config: ModelConfig, + qkv_dtype: torch.dtype, + qkv_layout: str, + window_size: Tuple[int, int] = (-1, -1), + pad_between_seqs: bool = False, + context_parallel: bool = False, + deterministic: bool = False, + fp8: bool = False, + fp8_meta: Optional[Dict[str, Any]] = None, + is_training: bool = True, + inference_params: Optional[InferenceParams] = None, +) -> Tuple[List, List]: + """Check for all available attention backends that support a model configuration""" + + os.environ["NVTE_FLASH_ATTN"] = "1" + os.environ["NVTE_FUSED_ATTN"] = "1" + os.environ["NVTE_UNFUSED_ATTN"] = "1" + _attention_backends["backend_selection_requires_update"] = True + + alibi_slopes_shape = None + if config.attn_bias_type == "alibi" and config.alibi_type == "custom": + if config.bias_shape == "1hss": + alibi_slopes_shape = [config.num_heads] + if config.bias_shape == "bhss": + alibi_slopes_shape = [config.batch_size, config.num_heads] + + core_attention_bias_shape = ( + config.bias_shape if config.attn_bias_type == "post_scale_bias" else None + ) + core_attention_bias_requires_grad = False + # d=256 is supported by cuDNN 9.0+ for inference but not training + if ( + config.attn_bias_type == "post_scale_bias" + and config.head_dim_qk <= 128 + and config.head_dim_v <= 128 + ): + core_attention_bias_requires_grad = True + + fused_attn_backends = [] + available_backends = None + flash_attention_backend = None + fused_attention_backend = None + + def test(): + attention_params = AttentionParams( + qkv_dtype=qkv_dtype, + qkv_layout=qkv_layout, + batch_size=config.batch_size, + num_heads=config.num_heads, + num_gqa_groups=config.num_gqa_groups, + max_seqlen_q=config.max_seqlen_q, + max_seqlen_kv=config.max_seqlen_kv, + head_dim_qk=config.head_dim_qk, + head_dim_v=config.head_dim_v, + attn_mask_type=config.attn_mask_type, + window_size=window_size, + alibi_slopes_shape=alibi_slopes_shape, + core_attention_bias_type=config.attn_bias_type, + core_attention_bias_shape=core_attention_bias_shape, + core_attention_bias_requires_grad=core_attention_bias_requires_grad, + pad_between_seqs=pad_between_seqs, + attention_dropout=config.dropout_p, + context_parallel=context_parallel, + deterministic=deterministic, + fp8=fp8, + fp8_meta=fp8_meta, + is_training=is_training, + inference_params=inference_params, + ) + ( + use_flash_attention, + use_fused_attention, + flash_attention_backend, + fused_attention_backend, + use_unfused_attention, + available_backends, + ) = get_attention_backend(attention_params) + # Set attention.py _attention_backends var using return value + # from get_attention_backend() + _attention_backends["use_flash_attention"] = use_flash_attention + _attention_backends["use_fused_attention"] = use_fused_attention + _attention_backends["flash_attention_backend"] = flash_attention_backend + _attention_backends["fused_attention_backend"] = fused_attention_backend + _attention_backends["use_unfused_attention"] = use_unfused_attention + _attention_backends["backend_selection_requires_update"] = False + return available_backends, flash_attention_backend, fused_attention_backend + + backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} + if 
AttentionLogging._is_logging_setup is False: + AttentionLogging.setup_logging() + with logging_context(highest_level=AttentionLogging._log_level): + for i in range(3): + os.environ["NVTE_FUSED_ATTN_BACKEND"] = str(i) + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[backends[i]]: + fused_attn_backends.append(fused_attention_backend) + return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 9d4701730..940c1d305 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -183,7 +183,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_MASK || attn_mask_type == NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) && (qkv_format == NVTE_QKV_Format::NVTE_BSHD || qkv_format == NVTE_QKV_Format::NVTE_SBHD) && - !requires_64bit_ragged_offset) { + !requires_64bit_ragged_offset && + // 9.10.0: known bugs with SDPA FP8 + (cudnn_runtime_version != 91000)) { if (cudnn_runtime_version >= 8900) { backend = NVTE_Fused_Attn_Backend::NVTE_FP8; } else { @@ -239,10 +241,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.9: any head_dim + Blackwell + fprop + non_paged + sq > 1 (!is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 90900 && max_seqlen_q > 1 && layout_group != NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD) || - // 9.10: any head_dim + any arch + fprop + paged - // 9.10: any head_dim + any arch + fprop + non_paged + sq > 1 - // 9.10: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} - (!is_training && cudnn_runtime_version >= 91000 && + // 9.10.2: any head_dim + any arch + fprop + paged + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq > 1 + // 9.10.2: any head_dim + any arch + fprop + non_paged + sq = 1 + {no_mask, padding, BRCM, padding_BRCM} + (!is_training && cudnn_runtime_version >= 91002 && (layout_group == NVTE_QKV_Layout_Group::NVTE_Paged_KV_HD_HD_HD || max_seqlen_q > 1 || (max_seqlen_q == 1 && attn_mask_type != NVTE_Mask_Type::NVTE_CAUSAL_MASK && attn_mask_type != NVTE_Mask_Type::NVTE_PADDING_CAUSAL_MASK))) || @@ -358,7 +360,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( max_seqlen_q <= max_seqlen_kv && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS && dropout == 0.0)))) && // check 64-bit ragged offset support - (supported_ragged_offset_size)) { + (supported_ragged_offset_size) && + // 9.10.0/9.10.1: known bugs with SDPA F16 + (cudnn_runtime_version != 91000) && (cudnn_runtime_version != 91001)) { flag_arb = true; } if (((max_seqlen_q > 512) || (max_seqlen_kv > 512)) && (flag_arb == true)) { From 315b47db7ff54384e36b483336b0bba34df78401 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 21 Jul 2025 18:01:40 -0700 Subject: [PATCH 006/153] [PyTorch] Debug linear layer when saving original input and using debug quantizer (#1963) * Debug linear layer when saving original input and using debug quantizer Signed-off-by: Tim Moon * Workaround bugs with quantizing with only column-wise usage Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unused imports Signed-off-by: Tim Moon * Avoid unnecessary 
row-wise data Signed-off-by: Tim Moon * Workaround bugs with quantizing with only column-wise usage FP8 does not support transpose-only cast. Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/module/linear.py | 65 ++++++++++++--------- 1 file changed, 39 insertions(+), 26 deletions(-) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index de55155b9..b1d4196df 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -65,8 +65,6 @@ ) from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer -from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState @@ -170,16 +168,19 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") if not isinstance(inputmat, QuantizedTensorBase): - input_quantizer.set_usage( - rowwise=True, columnwise=backward_needs_input and not save_original_input - ) + own_quantized_input = True + input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) if isinstance( input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) ): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) + if save_original_input: + # No need for column-wise data since this + # tensor will not be cached for backward pass + input_quantizer.set_usage(columnwise=False) + own_quantized_input = False inputmat = input_quantizer(inputmat) - own_quantized_input = True else: inputmat = cast_if_needed(inp, activation_dtype) # Cast for AMP @@ -344,23 +345,29 @@ def forward( inputmat = inp ctx.weight_quantizer = weight_quantizer - saved_inputmat = None ctx.backward_input_needs_gather = ( weight.requires_grad and parallel_mode == "column" and sequence_parallel ) + # Discard unneeded data in input tensor + if ( + backward_needs_input + and own_quantized_input + and isinstance(inputmat, QuantizedTensorBase) + ): + if ctx.backward_input_needs_gather and isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): + # All-gather is not supported with FP8 column-wise data + inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) + else: + # Discard row-wise data since it is not needed in backward pass + inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) + + # Cached input tensor + saved_inputmat = None if backward_needs_input: - if not save_original_input: - if own_quantized_input and isinstance(inputmat, QuantizedTensorBase): - # For sequence parallel in vanilla FP8, rowwise data is - # to gather the input. For MXFP8, columnwise only data - # can be allgathered. - if ( - isinstance(inputmat, (MXFP8TensorBase, Float8BlockwiseQTensorBase)) - or not ctx.backward_input_needs_gather - ): - inputmat.update_usage(rowwise_usage=False, columnwise_usage=True) saved_inputmat = inputmat # Weight with column-wise usage is needed for dgrad GEMM. 
@@ -572,20 +579,26 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], inputmat_total = None inputmat_total_work = None if ctx.requires_wgrad: - input_is_quantized = isinstance(inputmat, QuantizedTensorBase) if ctx.fp8 or ctx.debug: - if not input_is_quantized: + if isinstance(inputmat, QuantizedTensorBase): + # Input tensor is already quantized + pass + elif ctx.debug: + # Debug quantizer will be applied immediately before wgrad GEMM + pass + else: + # Quantize input tensor quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): - quantizer.set_usage( - rowwise=True, - columnwise=not ctx.backward_input_needs_gather, - ) + if ctx.backward_input_needs_gather and isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + ): + # All-gather is not supported with FP8 column-wise data + quantizer.set_usage(rowwise=True, columnwise=False) else: - quantizer.set_usage(rowwise=False, columnwise=True) + quantizer.set_usage(rowwise=True, columnwise=True) inputmat = quantizer(inputmat) else: - if input_is_quantized: + if isinstance(inputmat, QuantizedTensorBase): inputmat = inputmat.dequantize(dtype=ctx.activation_dtype) else: inputmat = cast_if_needed(inputmat, ctx.activation_dtype) From cb504cda15cc8ea77d6603badee4575692cef29b Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Tue, 22 Jul 2025 09:44:15 +0200 Subject: [PATCH 007/153] [Common] Improved performance of mxfp8 cast kernels (#1628) * Fixed conflicts Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor code refactoring to avoid unnecessary checks Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Oleg Goncharov * Fixed dBias accumulation error due to initialization. Minor code refactoring Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Test case to reproduce the init error Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed rowwise dbias error Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Changed ptx API Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Added a struct for two packed FP8 values Signed-off-by: Oleg Goncharov * Rolled back to scalar code for columnwise scaling due to its better performance Signed-off-by: Oleg Goncharov * Minor corrections Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Rebased on main Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixes per code review Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Removed constexpr in C++ test suite to build faster Signed-off-by: Oleg Goncharov * Computed activations are now numerically truncated to InputType before scaling. Improved test suite. 
Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor refactoring Signed-off-by: Oleg Goncharov * Minor refactoring Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Modified mismatches checks of MXFP8 to address FP8 numerics Signed-off-by: Oleg Goncharov * Implemented Jeremy's fixes to JAX test suite with an intermediate downcast Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Reduced the dims of the test tensors to improve CI runtime Signed-off-by: Oleg Goncharov * Fixed memory alignment issue. Compute dbias without downcast. Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed misaligned memory issue also in gated kernels. Reduced size of MXFP8 gated tests Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Oleg Goncharov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/cpp/operator/test_cast_mxfp8.cu | 521 ++++---- .../operator/test_cast_mxfp8_gated_swiglu.cu | 436 +++---- tests/cpp/test_common.cu | 151 ++- tests/cpp/test_common.h | 18 +- tests/jax/test_custom_call_compute.py | 51 +- transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/common.cu | 4 +- transformer_engine/common/common.h | 3 +- .../common/util/cast_gated_kernels.cuh | 1047 ++++++++++------- .../common/util/cast_kernels.cuh | 936 ++++++++------- .../common/util/dequantize_kernels.cuh | 6 +- transformer_engine/common/util/ptx.cuh | 200 ++++ transformer_engine/common/utils.cuh | 43 +- 13 files changed, 2026 insertions(+), 1391 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index bea988736..5a9423745 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -36,95 +36,34 @@ enum ActivationType { SReLU }; -template -void scale_block(const ProcessingMethod processing_method, +template +void compute_ref(const ProcessingMethod processing_method, + float (*OP)(const float), + const bool rowwise, + const bool colwise, const InputType* input, const InputType* grad, - OutputType* output_c, - float* dbias, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { - float amax = 0.0f; - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - dbias[j] += elt; - if (isinf(elt) || isnan(elt)) { - continue; - } - amax = std::max(amax, std::abs(elt)); - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_reciprocal()); - const float 
scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - - // Quantize elements in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - const size_t idx = i * cols + j; - float elt = static_cast(input[idx]); - if (processing_method == ProcessingMethod::CAST_DBIAS) { - // grad is the input - elt = static_cast(grad[idx]); - } - if (processing_method != ProcessingMethod::CAST_ONLY - && processing_method != ProcessingMethod::CAST_DBIAS) { - elt = OP(elt); - } - if (processing_method == ProcessingMethod::CAST_DACT || - processing_method == ProcessingMethod::CAST_DBIAS_DACT) { - elt *= static_cast(grad[idx]); - } - output_c[idx] = static_cast(elt * scale_reciprocal); - } - } -} - -template -void compute_ref_x1(const ProcessingMethod processing_method, - const InputType* input, - const InputType* grad, - OutputType* output_c, - fp8e8m0* output_scales, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) + OutputType* output_rowwise, + OutputType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + InputType* output_dbias, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise) { - const size_t tile_size_Y = std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + const size_t tile_size_Y = 32; + const size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; std::vector output_dbias_fp32(cols, 0); #pragma omp parallel proc_bind(spread) { + // Buffers to cache intermediate computations + std::vector cache_buffer(tile_size_Y * tile_size_X); + std::vector thread_dbias(cols, 0); #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { @@ -133,24 +72,83 @@ void compute_ref_x1(const ProcessingMethod processing_method, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t scale_idx = block_idx_Y * scales_stride + block_idx_X; - scale_block( - processing_method, input, grad, output_c, thread_dbias.data(), - output_scales, scale_idx, i_min, i_max, j_min, j_max, cols); + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(i_min + tile_size_Y, rows); + + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(j_min + tile_size_X, cols); + + // Cache computations + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j 
- j_min); + + float elt = static_cast(input[idx]); + if (processing_method == ProcessingMethod::CAST_DBIAS) { + // grad is the input + elt = static_cast(grad[idx]); + } + if (processing_method != ProcessingMethod::CAST_ONLY + && processing_method != ProcessingMethod::CAST_DBIAS) { + elt = OP(elt); + } + if (processing_method == ProcessingMethod::CAST_DACT || + processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + elt *= static_cast(grad[idx]); + } + thread_dbias[j] += elt; + + // Numerical truncation: after downcast to InputType (BF16/FP16), upcast it back to FP32 + elt = static_cast(static_cast(elt)); + + cache_buffer[cache_idx] = elt; + if (isinf(elt) || isnan(elt)) { + continue; + } + } + } + + if (rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax = 0.0f; + + for (size_t j = j_min; j < j_max; ++j) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const int scale_idx = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t j = j_min; j < j_max; ++j) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } + } + } + if (colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax = 0.0f; + + for (size_t i = i_min; i < i_max; ++i) { + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); + } + + const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); + const int scale_idx = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx] = biased_exponent; + const float scale_reciprocal = exp2f_rcp(biased_exponent); + + for (size_t i = i_min; i < i_max; ++i) { + const int idx = i * cols + j; + const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); + } } } } @@ -166,29 +164,6 @@ void compute_ref_x1(const ProcessingMethod processing_method, } } -template -void compute_ref_x2(const ProcessingMethod processing_method, - const InputType* input, - const InputType* grad, - OutputType* output_rowwise, - OutputType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - InputType* output_dbias, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - processing_method, input, grad, output_rowwise, scales_rowwise, output_dbias, - rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - processing_method, input, grad, output_colwise, scales_colwise, output_dbias, - rows, cols, block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -197,8 +172,9 @@ void compute_ref_x2(const ProcessingMethod processing_method, * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const bool 
rowwise, const bool colwise, @@ -261,28 +237,46 @@ void performTest_x1(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output_c.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output_c.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output_c.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output_c.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output_c.data(), 0); + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output_c.data(), 0); break; } } @@ -291,29 +285,45 @@ void performTest_x1(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - compute_ref_x1(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c.get(), - ref_output_scales.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol); + compute_ref(processing_method, + OP, + rowwise, + colwise, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c.get(), + ref_output_c.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride, + scales_stride); const uint8_t * const gpu_scales_ptr = rowwise ? 
output_c.rowwise_cpu_scale_inv_ptr() : output_c.columnwise_cpu_scale_inv_ptr(); + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + size_t mismatches_scales = 0; compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts); - if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -332,8 +342,9 @@ void performTest_x1(const ProcessingMethod processing_method, * AND * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x2(const ProcessingMethod processing_method, + float (*OP)(const float), const std::vector& shape, const size_t block_size_rows, const size_t block_size_cols, @@ -401,28 +412,46 @@ void performTest_x2(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS_DACT: { - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); + auto nvte_quantize_dbias_dact = &nvte_quantize_dbias_dgelu; + if (OP == &dsilu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsilu; } + else if (OP == &drelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_drelu; } + else if (OP == &dqgelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dqgelu; } + else if (OP == &dsrelu) { nvte_quantize_dbias_dact = &nvte_quantize_dbias_dsrelu; } + + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_quantize_dbias_dgelu(grad.data(), - input.data(), - output.data(), - output_dbias.data(), - workspace.data(), - 0); + nvte_quantize_dbias_dact(grad.data(), + input.data(), + output.data(), + output_dbias.data(), + workspace.data(), + 0); break; } case ProcessingMethod::CAST_DACT: { - nvte_dgelu(grad.data(), input.data(), output.data(), 0); + auto nvte_dact = &nvte_dgelu; + if (OP == &dsilu) { nvte_dact = &nvte_dsilu; } + else if (OP == &drelu) { nvte_dact = &nvte_drelu; } + else if (OP == &dqgelu) { nvte_dact = &nvte_dqgelu; } + else if (OP == &dsrelu) { nvte_dact = &nvte_dsrelu; } + + nvte_dact(grad.data(), input.data(), output.data(), 0); break; } case ProcessingMethod::CAST_ACT: { - nvte_gelu(input.data(), output.data(), 0); + auto nvte_act = &nvte_gelu; + if (OP == &silu) { nvte_act = &nvte_silu; } + else if (OP == &relu) { nvte_act = &nvte_relu; } + else if (OP == &qgelu) { nvte_act = &nvte_qgelu; } + else if (OP == &srelu) { nvte_act = &nvte_srelu; } + + nvte_act(input.data(), output.data(), 0); break; } } @@ -431,32 +460,54 @@ void performTest_x2(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - 
compute_ref_x2(processing_method, - input.rowwise_cpu_dptr(), - grad.rowwise_cpu_dptr(), - ref_output_c_rowwise.get(), - ref_output_c_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_output_dbias.get(), - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); - - auto [atol, rtol] = getTolerances(otype); - compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol); + compute_ref(processing_method, + OP, + true, + true, + input.rowwise_cpu_dptr(), + grad.rowwise_cpu_dptr(), + ref_output_c_rowwise.get(), + ref_output_c_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_output_dbias.get(), + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise); + + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; + + size_t mismatches_scales_rowwise = 0; compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + size_t mismatches_scales_colwise = 0; compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise); + unpadded_blocks_X_colwise, scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; + + auto [atol, rtol] = getTolerances(otype); + compareResults("output_c_rowwise", output, ref_output_c_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); + compareResults("output_c_colwise", output, ref_output_c_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise); - if (processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) { + if (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT) + { auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -475,11 +526,10 @@ std::vector> matrix_sizes = { {128, 128}, {256, 256}, {993, 512}, - {256, 65536}, - {2048, 6144}, - {16384, 128}, - {32768, 160}, - {4096, 1632}, + {511, 6144}, + {8192, 128}, + {2048, 160}, + {577, 1632}, {1024}, {8, 32, 1024}, {16, 8, 4, 512}, @@ -528,26 +578,6 @@ class FusedCastMXFP8TestSuite : public ::testing::TestWithParam transformer_engine::DType, InputsFillCase>> {}; -#define DACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) 
\ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &dgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &dsilu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &drelu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &dqgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &dsrelu; { __VA_ARGS__ } } break; \ -} - -#define ACT_FUNC_SWITCH(OP_FUNC_TYPE, OP, ...) \ -switch (OP_FUNC_TYPE) { \ - case ActivationType::Identity: { constexpr auto OP = &identity; { __VA_ARGS__ } } break; \ - case ActivationType::GeLU: { constexpr auto OP = &gelu; { __VA_ARGS__ } } break; \ - case ActivationType::SiLU: { constexpr auto OP = &silu; { __VA_ARGS__ } } break; \ - case ActivationType::ReLU: { constexpr auto OP = &relu; { __VA_ARGS__ } } break; \ - case ActivationType::QGeLU: { constexpr auto OP = &qgelu; { __VA_ARGS__ } } break; \ - case ActivationType::SReLU: { constexpr auto OP = &srelu; { __VA_ARGS__ } } break; \ -} - TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { // Skip tests for pre-Blackwell architectures if (getDeviceComputeCapability() < blackwellComputeCapability) { @@ -581,35 +611,48 @@ TEST_P(FusedCastMXFP8TestSuite, TestFusedCastMXFP8) { const bool colwise = block_size.first != 1; if (processing_method == ProcessingMethod::CAST_ACT) { // Forward activations - ACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &gelu; break; + case ActivationType::SiLU: OP = &silu; break; + case ActivationType::ReLU: OP = &relu; break; + case ActivationType::QGeLU: OP = &qgelu; break; + case ActivationType::SReLU: OP = &srelu; break; + } + + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if (block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } else { - DACT_FUNC_SWITCH(Act_type, OP, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, - if (block_size.first == 1 || block_size.second == 1) { - performTest_x1( - processing_method, matrix_size, - rowwise, colwise, fill_case); - } else { - performTest_x2( - processing_method, matrix_size, - block_size.first, block_size.second, fill_case); - } - ); + auto OP = &identity; + switch (Act_type) { + case ActivationType::GeLU: OP = &dgelu; break; + case ActivationType::SiLU: OP = &dsilu; break; + case ActivationType::ReLU: OP = &drelu; break; + case ActivationType::QGeLU: OP = &dqgelu; break; + case ActivationType::SReLU: OP = &dsrelu; break; + } + TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OutputType, + if 
(block_size.first == 1 || block_size.second == 1) { + performTest_x1( + processing_method, OP, matrix_size, + rowwise, colwise, fill_case); + } else { + performTest_x2( + processing_method, OP, matrix_size, + block_size.first, block_size.second, fill_case); + } ); ); } diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 2b22942f8..3c7b8c8b7 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -18,107 +18,32 @@ using namespace test; namespace { -template -void scale_block(const IType* grad, +template +void compute_ref(const IType* grad, const IType* input, - OType* output, - fp8e8m0* output_scales, - const size_t scale_idx, - const size_t scale_idx_gate, - float& thread_amax, - const size_t i_min, - const size_t i_max, - const size_t j_min, - const size_t j_max, - const size_t cols) { - - float block_amax = 0.0f; - float block_amax_gate = 0.0f; - const size_t stride = cols * 2; - - // Find the absolute maximum value in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - float gated_amax_act = 0; - float gated_amax_gate = 0; - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - gated_amax_act = abs(after_dsilu); - gated_amax_gate = abs(after_dgate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - gated_amax_act = abs(after_silu); - } - - if (gated_amax_act > block_amax) { block_amax = gated_amax_act; } - if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; } - } - } - - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * - Quantized_Limits::max_reciprocal()); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - output_scales[scale_idx] = biased_exponent; - float scale_reciprocal_gate = 1; - if constexpr (IS_DGATED) { - const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate * - Quantized_Limits::max_reciprocal()); - scale_reciprocal_gate = exp2f_rcp(biased_exponent); - output_scales[scale_idx_gate] = biased_exponent; - } - - - // Quantize elements in the block - for (size_t i = i_min; i < i_max; ++i) { - for (size_t j = j_min; j < j_max; ++j) { - float silu_elt = static_cast(input[i * stride + j]); - float gate_elt = static_cast(input[i * stride + cols + j]); - - if constexpr (IS_DGATED) { - const float grad_elt = static_cast(grad[i * cols + j]); - const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt; - const float after_dgate = silu(silu_elt) * grad_elt; - output[i * stride + j] = static_cast(after_dsilu * scale_reciprocal); - output[i * stride + cols + j] = static_cast(after_dgate * - scale_reciprocal_gate); - } else { - const float after_silu = silu(silu_elt) * gate_elt; - output[i * cols + j] = static_cast(after_silu * scale_reciprocal); - } - - } - } - thread_amax = std::max(thread_amax, block_amax); - thread_amax = std::max(thread_amax, block_amax_gate); -} - -template -void compute_ref_x1(const IType* grad, - const IType* input, - OType* output, - fp8e8m0* output_scales, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride) { - const size_t tile_size_Y = 
std::max(32lu, block_size_Y); - const size_t tile_size_X = std::max(64lu, block_size_X); + OType* output_rowwise, + OType* output_colwise, + fp8e8m0* output_scales_rowwise, + fp8e8m0* output_scales_colwise, + float& ref_amax, + const bool IS_DGATED, + const size_t rows, + const size_t cols, + const size_t scales_stride_rowwise, + const size_t scales_stride_colwise, + const bool is_rowwise, + const bool is_colwise) { + constexpr size_t tile_size_Y = 32; + constexpr size_t tile_size_X = 32; const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y; const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X; - const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y; - const size_t blocks_per_tile_X = tile_size_X / block_size_X; - float amax = 0; #pragma omp parallel reduction(max: amax) proc_bind(spread) { - float thread_amax = 0; + // Buffers to cache intermediate computations + std::vector cache_buffer_act(tile_size_Y * tile_size_X); + std::vector cache_buffer_gate(tile_size_Y * tile_size_X); + float thread_amax = 0.0f; #pragma omp for schedule(static) for (size_t t = 0; t < tiles_num_Y * tiles_num_X; ++t) { const size_t tile_Y = t / tiles_num_X; @@ -126,26 +51,124 @@ void compute_ref_x1(const IType* grad, const size_t tile_offset_Y = tile_Y * tile_size_Y; const size_t tile_offset_X = tile_X * tile_size_X; - for (size_t ii = 0; ii < blocks_per_tile_Y; ++ii) { - const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii; - const size_t block_offset_Y = ii * block_size_Y; - const size_t i_min = tile_offset_Y + block_offset_Y; - if (i_min >= rows) continue; - const size_t i_max = std::min(i_min + block_size_Y, rows); - - for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) { - const size_t block_idx_X = tile_X * blocks_per_tile_X + jj; - const size_t block_offset_X = jj * block_size_X; - const size_t j_min = tile_offset_X + block_offset_X; - if (j_min >= cols) continue; - const size_t j_max = std::min(j_min + block_size_X, cols); - - const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X; - const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X + - cols / block_size_X; - scale_block( - grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate, - thread_amax, i_min, i_max, j_min, j_max, cols); + const size_t stride = cols * 2; + + const size_t i_min = tile_offset_Y; + const size_t i_max = std::min(rows, tile_offset_Y + tile_size_Y); + const size_t j_min = tile_offset_X; + const size_t j_max = std::min(cols, tile_offset_X + tile_size_X); + + // Compute and cache activations for the entire tile + for (size_t i = i_min; i < i_max; ++i) { + for (size_t j = j_min; j < j_max; ++j) { + float silu_elt = static_cast(input[i * stride + j]); + float gate_elt = static_cast(input[i * stride + cols + j]); + + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + + if (IS_DGATED) { + const float x = silu_elt; + const float s = sigmoid(x); + const float act_x = x * s; + const float dact_x = x * s * (1 - s) + s; + + const float grad_elt = static_cast(grad[i * cols + j]); + float after_dsilu = dact_x * grad_elt * gate_elt; + float after_dgate = act_x * grad_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_dsilu = static_cast(static_cast(after_dsilu)); + after_dgate = static_cast(static_cast(after_dgate)); + + cache_buffer_act[cached_idx] = after_dsilu; + cache_buffer_gate[cached_idx] = after_dgate; + thread_amax = std::max(thread_amax, std::abs(after_dsilu)); + thread_amax = 
std::max(thread_amax, std::abs(after_dgate)); + } else { + float after_silu = silu(silu_elt) * gate_elt; + + // Numerical truncation: after downcast to IType (BF16/FP16), upcast it back to FP32 + after_silu = static_cast(static_cast(after_silu)); + + cache_buffer_act[cached_idx] = after_silu; + thread_amax = std::max(thread_amax, std::abs(after_silu)); + } + } + } + + if (is_rowwise) { + for (size_t i = i_min; i < i_max; ++i) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t j = j_min; j < j_max; ++j) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const int scale_idx_act = i * scales_stride_rowwise + tile_X; + output_scales_rowwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; + output_scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t j = j_min; j < j_max; ++j) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_rowwise[i * stride + j] = static_cast(after_act); + output_rowwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_rowwise[i * cols + j] = static_cast(after_act); + } + } + } + } + + if (is_colwise) { + for (size_t j = j_min; j < j_max; ++j) { + float block_amax_act = 0.0f; + float block_amax_gate = 0.0f; + for (size_t i = i_min; i < i_max; ++i) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); + if (IS_DGATED) { + block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); + } + } + const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); + const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); + const int scale_idx_act = tile_Y * scales_stride_colwise + j; + output_scales_colwise[scale_idx_act] = biased_exponent_act; + + float scale_reciprocal_gate; + if (IS_DGATED) { + const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); + const int scale_idx_gate = scale_idx_act + cols; + scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); + output_scales_colwise[scale_idx_gate] = biased_exponent_gate; + } + for (size_t i = i_min; i < i_max; ++i) { + const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; + + if (IS_DGATED) { + const float after_gate = cache_buffer_gate[cached_idx] * scale_reciprocal_gate; + output_colwise[i * stride + j] = static_cast(after_act); + output_colwise[i * stride + cols + j] = static_cast(after_gate); + } else { + output_colwise[i * cols + j] = static_cast(after_act); + } + } } } } @@ -156,26 +179,6 
@@ void compute_ref_x1(const IType* grad, ref_amax = amax; } -template -void compute_ref_x2(const IType* grad, - const IType* input, - OType* output_rowwise, - OType* output_colwise, - fp8e8m0* scales_rowwise, - fp8e8m0* scales_colwise, - float& ref_amax, - const size_t rows, - const size_t cols, - const size_t block_size_Y, - const size_t block_size_X, - const size_t scales_stride_rowwise, - const size_t scales_stride_colwise) { - compute_ref_x1( - grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise); - compute_ref_x1( - grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise); -} - /** * Scaling along single dimension (either rows or columns) * Produces one set of output data and the corresponding data of the fused operation (dbias): @@ -183,12 +186,13 @@ void compute_ref_x2(const IType* grad, * OR * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x1(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -198,12 +202,6 @@ void performTest_x1(const size_t rows, const bool colwise = (block_size_rows == 32) && (block_size_cols == 1); NVTE_CHECK(rowwise || colwise); - // std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl; - // std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl; - // std::cout << "blocks_Y: " << blocks_Y << std::endl; - // std::cout << "blocks_X: " << blocks_X << std::endl; - // std::cout << "scales_stride: " << scales_stride << std::endl; - Tensor grad("grad", std::vector{ rows, cols }, itype); Tensor input("input", std::vector{ rows, cols * 2 }, itype); @@ -229,12 +227,12 @@ void performTest_x1(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -245,30 +243,48 @@ void performTest_x1(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x1(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output.get(), - ref_output_scales.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride); - - auto [atol, rtol] = getTolerances(otype); - compareResults("output", output, ref_output.get(), rowwise, atol, rtol); + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output.get(), + ref_output.get(), + ref_output_scales.get(), + ref_output_scales.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride, + scales_stride, + rowwise, + colwise); + + size_t mismatches_scales = 0; + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; const uint8_t * const gpu_scales_ptr = rowwise ? 
output.rowwise_cpu_scale_inv_ptr() : output.columnwise_cpu_scale_inv_ptr(); if (rowwise) { compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } else { compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), - unpadded_blocks_Y, unpadded_blocks_X, scales_stride); + unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); } + + const size_t mismatches_elts = 32 * mismatches_scales; + auto [atol, rtol] = getTolerances(otype); + compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts); } /** @@ -278,12 +294,13 @@ void performTest_x1(const size_t rows, * AND * 2) Scaled columns + column-wise scaling factors */ -template +template void performTest_x2(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols, - InputsFillCase fill_case) { + InputsFillCase fill_case, + const bool IS_DGATED) { using namespace test; using EncodingType = fp32; DType itype = TypeInfo::dtype; @@ -325,12 +342,12 @@ void performTest_x2(const size_t rows, } // fillCase(&grad, fill_case); - if constexpr (IS_DGATED) { + if (IS_DGATED) { fillUniform(&grad); } fillUniform(&input); - if constexpr (IS_DGATED) { + if (IS_DGATED) { nvte_dswiglu(grad.data(), input.data(), output.data(), 0); } else { nvte_swiglu(input.data(), output.data(), 0); @@ -341,30 +358,49 @@ void performTest_x2(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref_x2(grad.rowwise_cpu_dptr(), - input.rowwise_cpu_dptr(), - ref_output_rowwise.get(), - ref_output_colwise.get(), - ref_scales_rowwise.get(), - ref_scales_colwise.get(), - ref_amax, - rows, - cols, - block_size_rows, - block_size_cols, - scales_stride_rowwise, - scales_stride_colwise); - - auto [atol, rtol] = getTolerances(otype); - auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); - compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol); - compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol); + compute_ref(grad.rowwise_cpu_dptr(), + input.rowwise_cpu_dptr(), + ref_output_rowwise.get(), + ref_output_colwise.get(), + ref_scales_rowwise.get(), + ref_scales_colwise.get(), + ref_amax, + IS_DGATED, + rows, + cols, + scales_stride_rowwise, + scales_stride_colwise, + true, + true); + + const size_t scale_diff_abs_tolerance = 0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; + + size_t mismatches_scales_rowwise = 0; compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, - unpadded_blocks_X_rowwise, scales_stride_rowwise); + unpadded_blocks_X_rowwise, scales_stride_rowwise, + mismatches_scales_rowwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + size_t mismatches_scales_colwise = 0; compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, - unpadded_blocks_X_colwise, scales_stride_colwise); + unpadded_blocks_X_colwise, 
scales_stride_colwise, + mismatches_scales_colwise, + scale_diff_abs_tolerance, + abs_tolerable_mismatches_limit, + rel_tolerable_mismatches_limit); + + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; + const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; + + auto [atol, rtol] = getTolerances(otype); + auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32); + compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol, true, mismatches_elts_rowwise); + compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol, true, mismatches_elts_colwise); } std::vector> matrix_sizes = { @@ -375,8 +411,8 @@ std::vector> matrix_sizes = { {256, 256}, {993, 512}, {768, 1024}, - {65504, 128}, - {16384, 1632}, + {8192, 128}, + {577, 1632}, }; std::vector> block_sizes = { @@ -393,9 +429,9 @@ std::vector input_scenarios = { // InputsFillCase::maxNorm_to_inf }; -std::vector is_dgated_op = { - true, - false +std::vector is_bwd_op = { + false, + true }; } // namespace @@ -427,21 +463,11 @@ TEST_P(CastMXFP8_GatedActTestSuite, TestCastMXFP8Swiglu) { TRANSFORMER_ENGINE_TYPE_SWITCH_FP16_FP32_ONLY(input_type, IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(output_type, OType, if (block_size.first == 1 || block_size.second == 1) { - if (IS_DGATED) { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x1(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x1(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } else { - if (IS_DGATED) { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } else { - performTest_x2(matrix_size.first, matrix_size.second, - block_size.first, block_size.second, fill_case); - } + performTest_x2(matrix_size.first, matrix_size.second, + block_size.first, block_size.second, fill_case, IS_DGATED); } ); ); @@ -456,7 +482,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16), ::testing::Values(DType::kFloat8E4M3, DType::kFloat8E5M2), ::testing::ValuesIn(input_scenarios), - ::testing::ValuesIn(is_dgated_op)), + ::testing::ValuesIn(is_bwd_op)), [](const testing::TestParamInfo& info) { std::string name = std::to_string(std::get<0>(info.param).first) + "X" + std::to_string(std::get<0>(info.param).second) + "X" + @@ -465,6 +491,6 @@ INSTANTIATE_TEST_SUITE_P( test::typeName(std::get<2>(info.param)) + "X" + test::typeName(std::get<3>(info.param)) + "X" + test::caseName(std::get<4>(info.param)) + "X" + - (std::get<5>(info.param) ? "DGATED" : "GATED"); + (std::get<5>(info.param) ? "BWD" : "FWD"); return name; }); diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 0f64d7c01..187742c39 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -523,10 +523,13 @@ std::vector unravel(const size_t i, const NVTEShape &shape) { void compareResults_sequential(const std::string &name, const Tensor &test, const void *ref, const bool rowwise, - double atol, double rtol, bool if_on_gpus) { + double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? 
test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches_num = 0; + int first_mismatch_idx = -1; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); @@ -547,80 +550,102 @@ void compareResults_sequential(const std::string &name, const Tensor &test, assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); } std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(assertion) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + if (assertion) { + mismatches_num++; + if (first_mismatch_idx == -1) { + first_mismatch_idx = i; + } + } + if (mismatches_num > tolerable_mismatches_limit) { + const double first_mismatch_t = static_cast(test_data[first_mismatch_idx]); + const double first_mismatch_r = static_cast(ref_data[first_mismatch_idx]); + + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." << std::endl + << "First mismatch at place " << to_string(unravel(first_mismatch_idx, shape)) + << " (" << std::to_string(first_mismatch_idx) << "): " + << first_mismatch_t << " vs " << first_mismatch_r; + } } ); } template static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, const T* ref_data, - const size_t N, const double atol, const double rtol) { + const size_t N, const double atol, const double rtol, + size_t& mismatches) { int first_mismatch_idx = N; - bool is_mismatch_found = false; - #pragma omp parallel for schedule(static) firstprivate(is_mismatch_found) \ - reduction(min: first_mismatch_idx) proc_bind(spread) - for (size_t i = 0; i < N; ++i) { - if (is_mismatch_found) { // early escape of the omp thread - continue; - } - - double t = static_cast(test_data[i]); - double r = static_cast(ref_data[i]); + #pragma omp parallel reduction(min: first_mismatch_idx) reduction(+: mismatches) proc_bind(spread) + { + size_t thread_mismatches = 0; + #pragma omp for schedule(static) + for (size_t i = 0; i < N; ++i) { + double t = static_cast(test_data[i]); + double r = static_cast(ref_data[i]); - bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); - /* For Float32 the floating point comparison is enough to error out */ - bool assertion = mismatch && (data_type == DType::kFloat32); - if (mismatch && !assertion) { - /* Check if it is just a failure of round to nearest choosing different - side of the real value */ - const double mean = (t + r) / 2; - const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); - const double mean_m = mean >= 0 ? 
mean * (1 - 1e-6) : mean * (1 + 1e-6); - const double cast_mean_p = static_cast(static_cast(mean_p)); - const double cast_mean_m = static_cast(static_cast(mean_m)); - assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); - } - if (assertion && i < first_mismatch_idx) { - first_mismatch_idx = i; - is_mismatch_found = true; + bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); + /* For Float32 the floating point comparison is enough to error out */ + bool assertion = mismatch && (data_type == DType::kFloat32); + if (mismatch && !assertion) { + /* Check if it is just a failure of round to nearest choosing different + side of the real value */ + const double mean = (t + r) / 2; + const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6); + const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6); + const double cast_mean_p = static_cast(static_cast(mean_p)); + const double cast_mean_m = static_cast(static_cast(mean_m)); + assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r)); + } + if (assertion) { + if (i < first_mismatch_idx) { + first_mismatch_idx = i; + } + thread_mismatches++; + } } + mismatches += thread_mismatches; } return first_mismatch_idx; } void compareResults_parallel(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { if (if_on_gpus) test.to_cpu(); const auto& shape = rowwise ? test.rowwise_shape() : test.columnwise_shape(); const size_t N = product(shape); + size_t mismatches = 0; TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T, const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); - const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol); - if (i != N) { + const size_t i = getFirstMismatchIdx(test.dtype(), test_data, ref_data, N, atol, rtol, mismatches); + if ((i != N) && (mismatches > tolerable_mismatches_limit)) { const double t = static_cast(test_data[i]); const double r = static_cast(ref_data[i]); std::string direction = rowwise ? "rowwise" : "columnwise"; - ASSERT_FALSE(true) << "Error in tensor " << name << " in " - << direction << " direction." << std::endl - << "Mismatch at place " << to_string(unravel(i, shape)) - << " (" << std::to_string(i) << "): " << t << " vs " << r; + + GTEST_FAIL() << mismatches << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "." << std::endl + << "Error in tensor " << name << " in " + << direction << " direction." 
<< std::endl + << "Mismatch at place " << to_string(unravel(i, shape)) + << " (" << std::to_string(i) << "): " << t << " vs " << r; } ); } void compareResults(const std::string &name, const Tensor &test, const void *ref, - const bool rowwise, double atol, double rtol, bool if_on_gpus) { + const bool rowwise, double atol, double rtol, bool if_on_gpus, + const size_t tolerable_mismatches_limit) { constexpr bool sequential = false; if constexpr (sequential) { - compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_sequential(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } else { - compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus); + compareResults_parallel(name, test, ref, rowwise, atol, rtol, if_on_gpus, tolerable_mismatches_limit); } } @@ -657,25 +682,39 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t } void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride) + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, const size_t atol, + const double abs_tolerable_mismatches_limit, + const double rel_tolerable_mismatches_limit) { + const size_t N = row_blocks * col_blocks; + const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, + std::floor(N * rel_tolerable_mismatches_limit)); + mismatches_num = 0; + std::vector mismatch_indices; + for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { const int idx = i * stride + j; - ASSERT_FALSE(test[idx] != ref[idx]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[idx]) << " vs " - << static_cast(ref[idx]) << " at index " << idx; - } - } -} + const int test_val = static_cast(test[idx]); + const int ref_val = static_cast(ref[idx]); + const int abs_delta = std::abs(test_val - ref_val); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t N) -{ - for (int i = 0; i < N; i++) { - ASSERT_FALSE(test[i] != ref[i]) << "Error in " << name << std::endl - << "Mismatch: " << static_cast(test[i]) << " vs " - << static_cast(ref[i]) << " at index " << i; + if (abs_delta > atol) { + mismatches_num++; + mismatch_indices.push_back(idx); + } + if (mismatches_num > tolerable_mismatches_limit) { + std::cout << "Error in " << name << std::endl; + for (const int index : mismatch_indices) { + std::cout << "Mismatch at (" << index << "):" + << static_cast(test[index]) << " vs " + << static_cast(ref[index]) << std::endl; + } + GTEST_FAIL() << mismatches_num << " mismatche(s) which is more than tolerable mismatch limit of " + << tolerable_mismatches_limit << "."; + } + } } } diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 3597c94d8..d1e273c6d 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -413,7 +413,12 @@ inline fp8e8m0 float_to_e8m0(float val) { } inline float exp2f_rcp(fp8e8m0 biased_exp) { - return (biased_exp == 0) ? 
1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); + if (biased_exp == 0) { + return 1.0f; + } + int32_t int_val = (254 - biased_exp) << FP32_MANTISSA_BITS; // 127 - (biased_exp - 127) + float fp32_val = *reinterpret_cast(&int_val); + return fp32_val; } inline float identity(const float x) { return x; } @@ -445,15 +450,18 @@ size_t last_dimension(const std::vector &shape); bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2); void compareResults(const std::string &name, const Tensor &test, const void *ref, - bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true); + bool rowwise, double atol = 1e-5, double rtol = 1e-8, bool if_on_gpus = true, + const size_t tolerable_mismatches_limit = 0); void compareResults(const std::string &name, const float test, const float ref, double atol = 1e-5, double rtol = 1e-8); void compareResults(const std::string &name, const uint8_t *test, const uint8_t *ref, size_t N, float mismatch_rate_tol = 0.); void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride); -void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, - const size_t N); + const size_t row_blocks, const size_t col_blocks, const size_t stride, + size_t& mismatches_num, + const size_t scale_diff_abs_tolerance = 0, + const double abs_tolerable_mismatches_limit = 0, + const double rel_tolerable_mismatches_limit = 0); std::array get_scale_tensor_dims(const size_t rows, const size_t cols, const size_t block_size_rows, const size_t block_size_cols); diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index afe7edbe2..1e1467521 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -78,8 +78,14 @@ def is_shape_supported_by_mxfp8(input_shape): return False -def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): +def assert_bitwise_scaled_tensors( + a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True +): if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x): + if not precise_comparison: + assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype) + return + assert a.scaling_mode == b.scaling_mode assert a.scale_inv.dtype == b.scale_inv.dtype if a.scaling_mode.is_tensor_scaling(): @@ -94,8 +100,12 @@ def assert_bitwise_scaled_tensors(a: ScaledTensor, b: ScaledTensor): assert_allclose(a.data, b.data) elif isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x): - assert_bitwise_scaled_tensors(a.rowwise_tensor, b.rowwise_tensor) - assert_bitwise_scaled_tensors(a.colwise_tensor, b.colwise_tensor) + assert_bitwise_scaled_tensors( + a.rowwise_tensor, b.rowwise_tensor, precise_comparison=precise_comparison + ) + assert_bitwise_scaled_tensors( + a.colwise_tensor, b.colwise_tensor, precise_comparison=precise_comparison + ) else: pytest.fail("Unsupported input types") @@ -481,24 +491,7 @@ def _test_norm_forward( # if the input dtype is not float32 precise_comparison = False - if precise_comparison: - assert_bitwise_scaled_tensors(output, ref_out) - else: - if isinstance(ref_out, ScaledTensor1x): - assert_allclose(output.dequantize(), ref_out.dequantize(), dtype=out_dtype) - elif isinstance(ref_out, ScaledTensor2x): - assert_allclose( - output.rowwise_tensor.dequantize(), - ref_out.rowwise_tensor.dequantize(), - dtype=out_dtype, - ) - assert_allclose( - 
output.colwise_tensor.dequantize(), - ref_out.colwise_tensor.dequantize(), - dtype=out_dtype, - ) - else: - pytest.fail("Unsupported output type") + assert_bitwise_scaled_tensors(output, ref_out, precise_comparison=precise_comparison) assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype) if norm_type == "layernorm": @@ -768,12 +761,24 @@ def _test_quantize_dact_dbias( )(dz, x) if is_casted_output: - assert_bitwise_scaled_tensors(te_output, jax_output) + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation + precise_comparison = not ( + in_dtype != jnp.float32 and scaling_mode.is_1d_block_scaling() + ) + assert_bitwise_scaled_tensors( + te_output, jax_output, precise_comparison=precise_comparison + ) else: assert_allclose(te_output, jax_output) if is_dbias: - assert_allclose(te_dbias, jax_dbias) + # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. + precise_comparison = not ( + in_dtype == jnp.bfloat16 and scaling_mode.is_1d_block_scaling() + ) + assert_allclose( + te_dbias, jax_dbias, dtype=in_dtype if precise_comparison else out_dtype + ) @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES) @pytest_parametrize_wrapper("input_shape", ALL_ACTIVATION_SHAPES) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b276240fc..aff282214 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -192,6 +192,7 @@ if (NVTE_BUILD_ACTIVATION_WITH_FAST_MATH) set_source_files_properties(activation/gelu.cu activation/relu.cu activation/swiglu.cu + util/cast.cu PROPERTIES COMPILE_OPTIONS "--use_fast_math") endif() diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 192c915a8..619bf6ca0 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -162,10 +162,10 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, void *dataPtr = reinterpret_cast(reinterpret_cast(tensor.dptr) + (offset_elems * type_num_bits) / 8); - NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_gmem_alignment), + NVTE_CHECK(is_aligned_ptr(dataPtr, TMA_GMEM_ALIGNMENT), "Tensor data pointer must be 16B aligned"); - const int TMA_needed_size = (TMA_gmem_alignment * 8) / type_num_bits; + const int TMA_needed_size = (TMA_GMEM_ALIGNMENT * 8) / type_num_bits; NVTE_CHECK(globalX % TMA_needed_size == 0, "Shape not supported. 
For ", type_num_bits, "-bit data type, expected multiple of ", TMA_needed_size, ", got ", globalX); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 22b448a00..08001671d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -668,7 +668,8 @@ constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; // Alignment requirements for the Tensor Memory Accelerator (TMA) -constexpr int TMA_gmem_alignment = 16; // global memory address alignment +constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment +constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment inline bool is_aligned_ptr(const void *ptr, size_t alignment) { return reinterpret_cast(ptr) % alignment == 0; diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index c24337dcd..82041d9f9 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -27,14 +27,8 @@ namespace transformer_engine { -template -__device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { - return DIVUP(static_cast(N), static_cast(M)) * M; -} - namespace gated_kernels { -constexpr size_t ALIGNMENT_SIZE = 128; constexpr size_t CHUNK_DIM_Y = 128; constexpr size_t CHUNK_DIM_X = 128; constexpr size_t THREADS_PER_CHUNK = 512; @@ -76,18 +70,19 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float amax = 0; const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); constexpr size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; constexpr size_t buff_elems_total = BUFFERS_NUM * buff_elems; constexpr size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); constexpr size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); constexpr size_t grad_mem = IS_DGATED ? 
buff_size_aligned_in : 0; @@ -96,8 +91,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr size_t in_mem = in_act_mem + in_gate_mem; constexpr size_t out_act_mem = buff_size_aligned_out; - - // const size_t in_transaction_size = grad_mem + in_mem; constexpr size_t in_transaction_size = buff_elems * sizeof(IType); // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned @@ -269,9 +262,34 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +namespace mxfp8_kernel { + +constexpr size_t CHUNK_DIM_Y = 64; +constexpr size_t CHUNK_DIM_X = 64; +constexpr size_t THREADS_PER_CHUNK_COLWISE = 128; +constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = CHUNK_DIM_X; + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t BUFF_DIM_Y = 32; +constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; +static_assert(BUFF_DIM_Y == 32); + +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 + template + bool ROWWISE_SCALING, bool COLWISE_SCALING, size_t THREADS_PER_CHUNK> __global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_gated_kernel(const __grid_constant__ CUtensorMap tensor_map_grad, const __grid_constant__ CUtensorMap tensor_map_input_act, @@ -284,43 +302,73 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = CHUNK_DIM_Y; // 128 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = CHUNK_DIM_X / SCALE_DIM_X; // 4 = 128 / 32 + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = CHUNK_DIM_Y / SCALE_DIM_Y; // 4 = 128 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = CHUNK_DIM_X; // 128 + constexpr bool IS_CACHED_ACT_OP = ROWWISE_SCALING && COLWISE_SCALING; + constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); - const int scales_rowwise_chunk_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = blockIdx.x * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = blockIdx.y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = blockIdx.x * SCALES_COLWISE_PER_CHUNK_X; + // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension. 
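// A minimal illustrative sketch of the per-chunk geometry these constants imply,
// assuming the colwise-only variant is launched with THREADS_PER_CHUNK_COLWISE (128)
// threads per 64x64 chunk:
//   - 128 threads spread over 64 columns       -> each colwise pass covers 2 rows at a time;
//   - a 64-row chunk split into 32-row buffers -> 2 pipeline stages per chunk;
//   - a 32-wide MXFP8 block read in 4-element packs -> 8 rowwise waves per block.
// The kSketch* names are illustrative only and are not part of the kernel.
constexpr size_t kSketchColwiseRowsPerPass = DIVUP(THREADS_PER_CHUNK_COLWISE, CHUNK_DIM_X);
constexpr size_t kSketchStagesPerChunk = CHUNK_DIM_Y / BUFF_DIM_Y;
static_assert(kSketchColwiseRowsPerPass == 2 && kSketchStagesPerChunk == 2 && WAVES == 8,
              "illustrative check of the chunk geometry described above");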
+ constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; + const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; + const int gate_scale_idx_offset_colwise = cols; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; - const bool col_out_of_bounds = (chunk_offset_X + thread_offset_X >= cols); + constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; + __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X]; - extern __shared__ char dshmem_unaligned[]; - const uint64_t dshmem_unaligned_as_uint = reinterpret_cast(dshmem_unaligned); - const uint64_t dshmem_aligned_as_uint = - DIVUP(dshmem_unaligned_as_uint, static_cast(ALIGNMENT_SIZE)) * ALIGNMENT_SIZE; - char *dshmem = reinterpret_cast(dshmem_aligned_as_uint); + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! 
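// The round-up just below works because TMA_SHMEM_ALIGNMENT (128) is a power of two:
// adding (A - 1) and masking with ~(A - 1) clears the low address bits, so every address
// in ((k - 1) * A, k * A] maps to k * A. A small sketch with assumed sample addresses:
// 0x1000 is already 128-byte aligned and is left unchanged, while 0x1001 rounds up to 0x1080.
static_assert(((0x1000u + (TMA_SHMEM_ALIGNMENT - 1)) & ~(TMA_SHMEM_ALIGNMENT - 1)) == 0x1000u,
              "aligned addresses are preserved");
static_assert(((0x1001u + (TMA_SHMEM_ALIGNMENT - 1)) & ~(TMA_SHMEM_ALIGNMENT - 1)) == 0x1080u,
              "unaligned addresses round up to the next 128-byte boundary");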
+ uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); - const size_t buff_elems = SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_elems_total = BUFFERS_NUM * buff_elems; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); @@ -329,12 +377,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const size_t in_mem = in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); const size_t out_mem = out_act_mem + out_gate_mem; - // const size_t in_transaction_size = grad_mem + in_mem; - const size_t in_transaction_size = (IS_DGATED ? 3 : 2) * buff_elems * sizeof(IType); - // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned IType *in_grad_sh = reinterpret_cast(dshmem); IType *in_act_sh = reinterpret_cast(dshmem + grad_mem); @@ -346,374 +391,493 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_act_colwise_sh = out_act_rowwise_sh; OType *out_gate_colwise_sh = out_gate_rowwise_sh; - if constexpr (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + if constexpr (ROWWISE_SCALING && COLWISE_SCALING) { out_act_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem); out_gate_colwise_sh = reinterpret_cast(dshmem + grad_mem + in_mem + out_mem + out_act_mem); } - const uint64_t *TMAP_grad_in = reinterpret_cast(&tensor_map_grad); - const uint64_t *TMAP_in_act = reinterpret_cast(&tensor_map_input_act); - const uint64_t *TMAP_in_gate = reinterpret_cast(&tensor_map_input_gate); - const uint64_t *TMAP_output_act_rowwise = - reinterpret_cast(&tensor_map_output_act_rowwise); - const uint64_t *TMAP_output_gate_rowwise = - reinterpret_cast(&tensor_map_output_gate_rowwise); - const uint64_t *TMAP_output_act_colwise = - reinterpret_cast(&tensor_map_output_act_colwise); - const uint64_t *TMAP_output_gate_colwise = - reinterpret_cast(&tensor_map_output_gate_colwise); + IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations + IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values - __shared__ float stage_amax_sh[THREADS_PER_CHUNK_Y][CHUNK_DIM_X]; + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + + const bool is_master_thread = (threadIdx.x == 0); // Initialize shared memory barrier with the number of threads participating in the barrier. #pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[ITERATIONS]; - - const bool is_master_thread = (threadIdx.x == 0); + __shared__ alignas(8) uint64_t mbar[STAGES]; - if (is_master_thread) { -// Initialize barrier. All `blockDim.x * blockDim.y` threads in block participate. 
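// The manual mbarrier setup removed here is assumed to have been factored into the
// initialize_barriers(...) helper invoked a few lines below; a minimal sketch of what
// such a helper would do, mirroring the removed code (names and signature are
// illustrative, not the actual TE helper), kept under #if 0 as illustration only:
#if 0
template <int kNumBarriers, int kNumThreads>
__device__ inline void initialize_barriers_sketch(uint64_t (&mbar)[kNumBarriers],
                                                  const bool is_master_thread) {
  if (is_master_thread) {
#pragma unroll
    for (int it = 0; it < kNumBarriers; ++it) {
      ptx::mbarrier_init(&mbar[it], kNumThreads);  // all participating threads per barrier
    }
    ptx::fence_proxy_async_shared_cta();  // make the initialization visible to the async proxy
  }
  __syncthreads();  // barriers must be visible to every thread before first use
}
#endif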
-#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_init(&mbar[it], THREADS_PER_CHUNK); - } - ptx::fence_proxy_async_shared_cta(); - } - // Syncthreads so initialized barrier is visible to all threads. - __syncthreads(); + initialize_barriers(mbar, is_master_thread); int parity = 0; - // Prefetch data of the first stage - if (is_master_thread) { - // Initiate bulk tensor copy - // Grad - if constexpr (IS_DGATED) { - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_grad_sh[0]), - TMAP_grad_in, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - } - - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_act_sh[0]), - TMAP_in_act, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared(reinterpret_cast(&in_gate_sh[0]), - TMAP_in_gate, chunk_offset_X, chunk_offset_Y, - &mbar[0]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[0], in_transaction_size); + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[0], &tensor_map_grad, block_offset_X, block_offset_Y, + &in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[0]); + copy_2d_to_sharedx2(&in_act_sh[0], &tensor_map_input_act, block_offset_X, block_offset_Y, + &in_gate_sh[0], &tensor_map_input_gate, block_offset_X, block_offset_Y, + shmem_buff_size, &mbar[0], is_master_thread); } #pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - const int buff = it % BUFFERS_NUM; - const int next_it = it + 1; - const size_t row_base = chunk_offset_Y + it * BUFFER_DIM_Y; - if (next_it < ITERATIONS) { - if (is_master_thread) { - const int next_buff = next_it % BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - // Initiate bulk tensor copy - if constexpr (IS_DGATED) { - // Grad - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_grad_sh[next_buff * buff_elems]), TMAP_grad_in, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - } - // Act - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_act_sh[next_buff * buff_elems]), TMAP_in_act, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - // Gate - ptx::cp_async_bulk_tensor_2d_global_to_shared( - reinterpret_cast(&in_gate_sh[next_buff * buff_elems]), TMAP_in_gate, - chunk_it_offset_x, chunk_it_offset_y, &mbar[next_it]); - - // Arrive on the barrier and tell how many bytes are expected to come in. - ptx::mbarrier_arrive_expect_tx(&mbar[next_it], in_transaction_size); + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; + + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); + + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_DIM; + if constexpr (IS_DGATED) { + copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, + global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, + global_offset_X, global_offset_Y, &in_gate_sh[next_buff_offset], + &tensor_map_input_gate, global_offset_X, global_offset_Y, + shmem_buff_size, &mbar[next_stage], is_master_thread); } else { - // Other threads just arrive - ptx::mbarrier_arrive(&mbar[next_it]); + copy_2d_to_sharedx2(&in_act_sh[next_buff_offset], &tensor_map_input_act, global_offset_X, + global_offset_Y, &in_gate_sh[next_buff_offset], &tensor_map_input_gate, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } } ptx::fence_proxy_async_shared_cta(); // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[it], parity); + ptx::mbarrier_wait_parity(&mbar[stage], parity); - IType *in_grad_sh_curr = in_grad_sh + buff * buff_elems; - IType *in_act_sh_curr = in_act_sh + buff * buff_elems; - IType *in_gate_sh_curr = in_gate_sh + buff * buff_elems; - OType *out_act_rowwise_sh_curr = out_act_rowwise_sh + buff * buff_elems; - OType *out_gate_rowwise_sh_curr = out_gate_rowwise_sh + buff * buff_elems; - OType *out_act_colwise_sh_curr = out_act_colwise_sh + buff * buff_elems; - OType *out_gate_colwise_sh_curr = out_gate_colwise_sh + buff * buff_elems; - - // Assuming one iteration covers exactly 32 rows - const int iteration_scale_colwise_offset_Y = scales_colwise_chunk_offset_Y + it; - const int iteration_scale_rowwise_offset_Y = scales_rowwise_chunk_offset_Y + it * BUFFER_DIM_Y; - - float after_dact_reg[BUFFER_STAGES_NUM]; - float after_dgate_reg[BUFFER_STAGES_NUM]; - float thread_Y_mx_block_amax = 0.0f; - float thread_Y_mx_block_amax_gate = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = + buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise; + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + float after_act_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; + float after_gate_colwise[BUFF_DIM_Y / COLWISE_WAVEFRONT_SIZE]; +// 1. Read/Compute elements. 
Find MXFP8-block AMAX #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - - float act_elt = static_cast(in_act_sh_curr[shmem_idx]); - float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const int shmem_offset_colwise = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; - if constexpr (IS_DGATED) { - float grad_elt = static_cast(in_grad_sh_curr[shmem_idx]); - const float x = act_elt; - float act_x; - float dact_x; + float act_elt = static_cast(in_act_sh[shmem_offset_colwise]); + float gate_elt = static_cast(in_gate_sh[shmem_offset_colwise]); + float after_act_elt; + float after_gate_elt; - if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { - const float s = sigmoidf(x); - act_x = x * s; - dact_x = x * s * (1 - s) + s; + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad_sh[shmem_offset_colwise]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; } else { - act_x = ActOP(x, {}); - dact_x = DActOP(x, {}); + after_act_elt = ActOP(act_elt, {}) * gate_elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } } - after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; - } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; - } - if constexpr (USE_ROWWISE_SCALING) { + after_act_colwise[i] = after_act_elt; if constexpr (IS_DGATED) { - // dgate - float amax = fabsf(after_dgate_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); - - out_gate_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dgate_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + (tid_X + cols) / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; - } + after_gate_colwise[i] = after_gate_elt; } - float amax = fabsf(after_dact_reg[stage]); - const float mx_block_X_amax = warp_reduce_max_broadcast(amax); - const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = 
exp2f_rcp(biased_exponent_X); - - out_act_rowwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal_X * after_dact_reg[stage]); - - // Only single thread writes the computed scaling factor - if ((tid_X % SCALE_DIM_X == 0) && !out_of_bounds) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + thread_offset_Y; - const int global_scales_offset_X = scales_rowwise_chunk_offset_X + tid_X / SCALE_DIM_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent_X; + + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(after_act_elt); + if constexpr (IS_DGATED) { + cached_gate_sh[shmem_offset_colwise] = static_cast(after_gate_elt); + } } - } - if constexpr (USE_COLWISE_SCALING) { - __builtin_assume(thread_Y_mx_block_amax >= 0); - __builtin_assume(thread_Y_mx_block_amax_gate >= 0); - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); - if constexpr (IS_DGATED) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } } } - } - - if constexpr (USE_COLWISE_SCALING) { - const bool row_out_of_bounds = (row_base >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); - if constexpr (IS_DGATED) { - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax_gate; + if constexpr (ONLY_COLWISE_SCALING) { + // Threads, whose id along Y-dim is 0, don't need to store to shared memory, + // as they manage the columwise reduction of the amax + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_act; } __syncthreads(); - if (tid_Y == 0) { + if (tid_Y_colwise == 0) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, stage_amax_sh[y][tid_X]); + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_act >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_act = fmaxf(thread_amax_act, other_thread_amax); } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax_gate; // write mx column-block amax + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_act; } __syncthreads(); - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax + // All threads read the reduced amax (ACT) + thread_amax_act = subamax_colwise_buff[0][tid_X_colwise]; + + if constexpr (IS_DGATED) { + // Make sure the previous read of the ACT values has been completed, + // so the data are not rewritten + __syncthreads(); + if (tid_Y_colwise > 0) { + subamax_colwise_buff[tid_Y_colwise - 1][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); + if (tid_Y_colwise == 0) { +#pragma unroll + for (int t = 0; t < SUBAMAX_BUFF_DIM_Y; ++t) { + const float other_thread_amax = 
subamax_colwise_buff[t][tid_X_colwise]; + __builtin_assume(thread_amax_gate >= 0); + __builtin_assume(other_thread_amax >= 0); + + thread_amax_gate = fmaxf(thread_amax_gate, other_thread_amax); + } + subamax_colwise_buff[0][tid_X_colwise] = thread_amax_gate; + } + __syncthreads(); - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + // All threads read the reduced amax (GATE) + thread_amax_gate = subamax_colwise_buff[0][tid_X_colwise]; } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; + const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; + + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx] = biased_exponent_act; + } - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X + cols; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + float block_scale_inverse_gate; + + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + // const int scale_idx_gate = scale_idx + scale_stride_colwise / 2; + const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { + scales_colwise[scale_idx_gate] = biased_exponent_gate; } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + } +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_gate_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dgate_reg[stage]); + for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { + const int shmem_offset_elt = + shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; + if constexpr (IS_DGATED) { + OType2 out_pair; + ptx::floatx2 in_pair = {after_act_colwise[i], after_gate_colwise[i]}; + const ptx::floatx2 block_scale_inverse_2x_pair = {block_scale_inverse_act, + block_scale_inverse_gate}; + ptx::mul_cvt_2x(out_pair, in_pair, block_scale_inverse_2x_pair); + out_act_colwise_sh[shmem_offset_elt] = out_pair.x; + out_gate_colwise_sh[shmem_offset_elt] = out_pair.y; + } else { + const float scaled_out_act = block_scale_inverse_act * after_act_colwise[i]; + out_act_colwise_sh[shmem_offset_elt] = static_cast(scaled_out_act); } } - // Colwise max reduction of the amax element - if (tid_Y > 0) { - stage_amax_sh[tid_Y][tid_X] = thread_Y_mx_block_amax; - } - __syncthreads(); - if (tid_Y == 0) { + } + + if constexpr (ROWWISE_SCALING) { + const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + + float thread_amax_act = 0.0f; + float thread_amax_gate = 0.0f; + + Vec in_cached_act[WAVES]; + Vec in_cached_gate[WAVES]; + + float after_act_rowwise[SCALE_DIM_X]; + float after_gate_rowwise[SCALE_DIM_X]; + + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x_act = {static_cast(0.0f), static_cast(0.0f)}; + IType2 thread_amax_2x_gate = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached_act[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_cached_gate[w].load_from(&cached_gate_sh[shmem_offset_rowwise]); + } + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { #pragma unroll - for (int y = 1; y < THREADS_PER_CHUNK_Y; ++y) { - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, stage_amax_sh[y][tid_X]); + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(in_cached_act[w].data.elt[e])); + if constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(in_cached_gate[w].data.elt[e])); + } + } + } else { +#pragma unroll + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x_act = {in_cached_act[w].data.elt[e], + in_cached_act[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_act, thread_amax_2x_act, in_cached_2x_act); + if constexpr (IS_DGATED) { + const IType2 in_cached_2x_gate = {in_cached_gate[w].data.elt[e], + in_cached_gate[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x_gate, thread_amax_2x_gate, in_cached_2x_gate); + } + } + } + } } - stage_amax_sh[0][tid_X] = thread_Y_mx_block_amax; // write mx column-block amax - } - __syncthreads(); + if constexpr (!std::is_same_v) { + thread_amax_act = static_cast( + __hmax(__habs(thread_amax_2x_act.x), __habs(thread_amax_2x_act.y))); + if constexpr (IS_DGATED) { + thread_amax_gate = static_cast( + __hmax(__habs(thread_amax_2x_gate.x), __habs(thread_amax_2x_gate.y))); + } + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + Vec in_grad; + Vec in_act; + Vec in_gate; + + in_act.load_from(&in_act_sh[shmem_offset_rowwise]); + in_gate.load_from(&in_gate_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + in_grad.load_from(&in_grad_sh[shmem_offset_rowwise]); + } - const float mx_block_Y_amax = stage_amax_sh[0][tid_X]; // read the mx column-block amax +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + + float act_elt = static_cast(in_act.data.elt[e]); + float gate_elt = static_cast(in_gate.data.elt[e]); + float after_act_elt; + float after_gate_elt; + + if constexpr (IS_DGATED) { + float grad_elt = static_cast(in_grad.data.elt[e]); + const float x = act_elt; + float act_x; + float dact_x; + + if constexpr ((ActOP == &silu) && (DActOP == &dsilu)) { + const float s = sigmoidf(x); + act_x = x * s; + dact_x = x * s * (1 - s) + s; + } else { + act_x = ActOP(x, {}); + dact_x = DActOP(x, {}); + } + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; + after_act_rowwise[j] = after_act_elt; + after_gate_rowwise[j] = after_gate_elt; + } else { + after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_act_rowwise[j] = after_act_elt; + } + + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } + } + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + if (!out_of_bounds) { + thread_amax_act = fmaxf(thread_amax_act, fabsf(after_act_elt)); + if 
constexpr (IS_DGATED) { + thread_amax_gate = fmaxf(thread_amax_gate, fabsf(after_gate_elt)); + } + } + } + } + } - // For the scaling along both dimensions, the thread amax is already computed in ROWWISE section - if constexpr (!USE_ROWWISE_SCALING) { - __builtin_assume(mx_block_Y_amax >= 0); + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent_act = + ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; + const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx] = biased_exponent_act; } - const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); - - // Only single thread writes the computed scaling factor - // Also assuming one iteration covers exactly 32 rows - if ((tid_Y == 0) && !out_of_bounds) { - const int global_scales_offset_Y = iteration_scale_colwise_offset_Y; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + const float block_scale_inverse_act = ptx::exp2f_rcp(biased_exponent_act); + const ptx::floatx2 block_scale_inverse_2x_act = {block_scale_inverse_act, + block_scale_inverse_act}; + + float block_scale_inverse_gate; + ptx::floatx2 block_scale_inverse_2x_gate; + if constexpr (IS_DGATED) { + const e8m0_t biased_exponent_gate = + ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); + const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + if (!out_of_bounds_rowwise) { + scales_rowwise[scale_idx_gate] = biased_exponent_gate; + } + block_scale_inverse_gate = ptx::exp2f_rcp(biased_exponent_gate); + block_scale_inverse_2x_gate = {block_scale_inverse_gate, block_scale_inverse_gate}; } +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; - - out_act_colwise_sh_curr[shmem_idx] = - static_cast(scale_reciprocal * after_dact_reg[stage]); + for (int w = 0; w < WAVES; ++w) { + Vec out_act; + Vec out_gate; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in_act; + OType2 &out_act_pair = reinterpret_cast(out_act.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_act.x = in_cached_act[w].data.elt[2 * e]; + in_act.y = in_cached_act[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_act.x = after_act_rowwise[j]; + in_act.y = after_act_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_act_pair, in_act, block_scale_inverse_2x_act); + + if constexpr (IS_DGATED) { + IType2 in_gate; + OType2 &out_gate_pair = reinterpret_cast(out_gate.data.elt[e]); + + if constexpr (IS_CACHED_ACT_OP) { + in_gate.x = in_cached_gate[w].data.elt[2 * e]; + in_gate.y = in_cached_gate[w].data.elt[2 * e + 1]; + } else { + const int j = w * PACK_SIZE + 2 * e; + in_gate.x = after_gate_rowwise[j]; + in_gate.y = after_gate_rowwise[j + 1]; + } + ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); + } + } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); + if constexpr (IS_DGATED) { + out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); + } } - } // endif USE_COLWISE_SCALING + } - // Wait for shared memory writes to be visible to TMA engine (cross-proxy fence) + // Wait for shared memory writes to be visible to TMA engine. ptx::fence_proxy_async_shared_cta(); __syncthreads(); // After syncthreads, writes by all threads are visible to TMA engine. 
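// [Editor's note] Minimal host-side sketch of the E8M0 block-scale math used by the
// rowwise/colwise paths above: the block amax (scaled by max_norm_rcp) is encoded as a
// biased power-of-two exponent, and elements are multiplied by the reciprocal of the
// decoded scale before the FP8 cast. The names ref_float_to_e8m0 / ref_exp2_rcp are
// illustrative stand-ins, not the ptx:: helpers; the device code may round the exponent
// differently, so treat this only as a conceptual reference.
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>

using e8m0_ref_t = uint8_t;

// Encode a non-negative FP32 value as a biased (bias = 127) power-of-two exponent.
inline e8m0_ref_t ref_float_to_e8m0(float val) {
  if (val <= 0.0f) {
    return 0;  // degenerate block (all zeros)
  }
  int exp = static_cast<int>(std::ceil(std::log2(val)));
  exp = std::min(std::max(exp + 127, 0), 254);  // clamp to the representable E8M0 range
  return static_cast<e8m0_ref_t>(exp);
}

// Reciprocal of the decoded scale: multiplying a block element by this value maps the
// block into the representable range of the FP8 output type.
inline float ref_exp2_rcp(e8m0_ref_t biased_exponent) {
  return std::exp2(static_cast<float>(127 - static_cast<int>(biased_exponent)));
}

int main() {
  const float block_val = 3.0f;
  const e8m0_ref_t e = ref_float_to_e8m0(block_val);  // ceil(log2(3)) = 2 -> biased 129
  std::printf("biased exponent = %d, scale inverse = %f\n", e, ref_exp2_rcp(e));  // 129, 0.25
  return 0;
}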
// Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; - // dGeLU - if constexpr (USE_ROWWISE_SCALING) { + if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_rowwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_rowwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_rowwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_rowwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_rowwise_sh[buff_offset])); } } - - // dGeLU - if constexpr (USE_COLWISE_SCALING) { + if constexpr (COLWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_act_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_act_colwise_sh_curr)); - + reinterpret_cast(&tensor_map_output_act_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_act_colwise_sh[buff_offset])); if constexpr (IS_DGATED) { - // dGate ptx::cp_async_bulk_tensor_2d_shared_to_global( - TMAP_output_gate_colwise, chunk_it_offset_x, chunk_it_offset_y, - reinterpret_cast(out_gate_colwise_sh_curr)); + reinterpret_cast(&tensor_map_output_gate_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_gate_colwise_sh[buff_offset])); } } // Create a "bulk async-group" out of the previous bulk copy operation. ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); } } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - // Destroy the barriers. This invalidates the memory region of the barrier. - // If further computations were to take place in the kernel, this allows the - // memory location of the shared memory barrier to be reused. - if (is_master_thread) { -#pragma unroll - for (int it = 0; it < ITERATIONS; ++it) { - ptx::mbarrier_invalid(&mbar[it]); - } - } + parity ^= 1; + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace mxfp8_kernel template @@ -771,17 +935,16 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); const size_t in_act_mem = buff_size_aligned_in; const size_t in_gate_mem = buff_size_aligned_in; const size_t out_act_mem = buff_size_aligned_out; const size_t out_gate_mem = buff_size_aligned_out; - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - const size_t shmem_size = ALIGNMENT_SIZE + grad_mem + (in_act_mem + in_gate_mem) + - (out_act_mem + out_gate_mem); // + mbar_mem; + const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; cudaFuncSetAttribute( cast_fp8_gated_kernel, @@ -809,16 +972,34 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } - // TODO: Make more general - const size_t scale_dim_X_rowwise = USE_ROWWISE_SCALING ? 32 : 1; - const size_t scale_dim_Y_colwise = USE_COLWISE_SCALING ? 32 : 1; + ScalingType scaling_type; + if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { + scaling_type = ScalingType::ROWWISE; + } else if ((!USE_ROWWISE_SCALING) && USE_COLWISE_SCALING) { + scaling_type = ScalingType::COLWISE; + } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { + scaling_type = ScalingType::BIDIMENSIONAL; + } const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; - const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); + constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; + constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; + constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; + + const size_t blocks_Y = DIVUP(rows, mxfp8_kernel::CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, mxfp8_kernel::CHUNK_DIM_X); + + constexpr size_t THREADS_PER_CHUNK_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_COLWISE; + constexpr size_t THREADS_PER_CHUNK_NON_COLWISE = mxfp8_kernel::THREADS_PER_CHUNK_NON_COLWISE; + const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) + ? THREADS_PER_CHUNK_COLWISE + : THREADS_PER_CHUNK_NON_COLWISE; + + const dim3 grid(blocks_X, blocks_Y); + const dim3 block_size(THREADS_PER_CHUNK); size_t scale_stride_rowwise = USE_ROWWISE_SCALING ? output->scale_inv.shape[1] : 1; size_t scale_stride_colwise = USE_COLWISE_SCALING ? output->columnwise_scale_inv.shape[1] : 1; @@ -828,94 +1009,122 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out e8m0_t *const scales_colwise_ptr = USE_COLWISE_SCALING ? 
reinterpret_cast(output->columnwise_scale_inv.dptr) : nullptr; - const dim3 block_dim(THREADS_PER_CHUNK); - const dim3 grid_dim(blocks_X, blocks_Y); + TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( + gated_input.dtype(), IType, + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + output->dtype(), OType, + + alignas(64) CUtensorMap tensor_map_grad{}; + alignas(64) CUtensorMap tensor_map_input_act{}; + alignas(64) CUtensorMap tensor_map_input_gate{}; + alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; + alignas(64) CUtensorMap tensor_map_output_act_colwise{}; + alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; + + if constexpr (IS_DGATED) { + create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X, + cols, 0, input_type_bit_size); + } + + const uint32_t tensor_stride_elems = output_cols; + create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, 0, input_type_bit_size); + create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, BUFF_DIM_Y, + BUFF_DIM_X, cols * 2, cols, input_type_bit_size); + + if (USE_ROWWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } + + if (USE_COLWISE_SCALING) { + create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, rows, cols, + BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, 0, + output_type_bit_size); + create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, rows, + cols, BUFF_DIM_Y, BUFF_DIM_X, tensor_stride_elems, cols, + output_type_bit_size); + } + + const size_t buff_elems_total = BUFFS_NUM * BUFF_DIM_Y * BUFF_DIM_X; + const size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + const size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + const size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + const size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + const size_t grad_mem = (IS_DGATED ? 
buff_size_aligned_in : 0); + const size_t in_act_mem = buff_size_aligned_in; + const size_t in_gate_mem = buff_size_aligned_in; + const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_Y_colwise, SCALE_DIM_Y, - TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( - scale_dim_X_rowwise, SCALE_DIM_X, - TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( - gated_input.dtype(), IType, - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output->dtype(), OType, - - alignas(64) CUtensorMap tensor_map_grad{}; - alignas(64) CUtensorMap tensor_map_input_act{}; - alignas(64) CUtensorMap tensor_map_input_gate{}; - alignas(64) CUtensorMap tensor_map_output_act_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_rowwise{}; - alignas(64) CUtensorMap tensor_map_output_act_colwise{}; - alignas(64) CUtensorMap tensor_map_output_gate_colwise{}; - - if constexpr (IS_DGATED) { - create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, - SHMEM_DIM_X, cols, 0, typeToNumBits(gated_input.dtype())); - } - - const uint32_t tensor_stride_elems = output_cols; - create_2D_tensor_map(tensor_map_input_act, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, 0, - typeToNumBits(gated_input.dtype())); - create_2D_tensor_map(tensor_map_input_gate, gated_input.data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, cols * 2, cols, - typeToNumBits(gated_input.dtype())); - - if (USE_ROWWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, 0, - typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_rowwise, output->data, rows, cols, - SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, cols, - typeToNumBits(output->dtype())); - } - - if (USE_COLWISE_SCALING) { - create_2D_tensor_map(tensor_map_output_act_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - 0, typeToNumBits(output->dtype())); - create_2D_tensor_map(tensor_map_output_gate_colwise, output->columnwise_data, - rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X, tensor_stride_elems, - cols, typeToNumBits(output->dtype())); - } - - const size_t buff_elems_total = BUFFERS_NUM * SHMEM_DIM_Y * SHMEM_DIM_X; - const size_t buff_size_aligned_in = - DIVUP(buff_elems_total * sizeof(IType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - const size_t buff_size_aligned_out = - DIVUP(buff_elems_total * sizeof(OType), ALIGNMENT_SIZE) * ALIGNMENT_SIZE; - - const size_t grad_mem = (IS_DGATED ? buff_size_aligned_in : 0); - const size_t in_act_mem = buff_size_aligned_in; - const size_t in_gate_mem = buff_size_aligned_in; - const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; - - const size_t out_act_mem = buff_size_aligned_out; - const size_t out_gate_mem = buff_size_aligned_out; - size_t out_mem = out_act_mem + out_gate_mem; - if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } - - // const size_t mbar_mem = ITERATIONS * sizeof(uint64_t); - // const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem + mbar_mem; - - const size_t shmem_size = ALIGNMENT_SIZE + in_mem + out_mem; - - cudaFuncSetAttribute( - cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); - - cast_mxfp8_gated_kernel - <<>>( + const size_t out_act_mem = buff_size_aligned_out; + const size_t out_gate_mem = (IS_DGATED ? 
buff_size_aligned_out : 0); + size_t out_mem = out_act_mem + out_gate_mem; + if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } + + const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::COLWISE: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise);); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + mxfp8_kernel::cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + + mxfp8_kernel::cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + }); // NOLINT(*) + ); // NOLINT(*) } template diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 610cbf41f..79209adf5 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -28,36 +28,25 @@ namespace transformer_engine { -constexpr size_t MXFP8_CHUNK_DIM_Y = 64; -constexpr size_t MXFP8_CHUNK_DIM_X = 64; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK_X = 1; -constexpr size_t MXFP8_CHUNKS_PER_BLOCK = MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNKS_PER_BLOCK_X; -constexpr size_t MXFP8_THREADS_PER_CHUNK = 64; -constexpr size_t MXFP8_BUFFERS_NUM = 2; -constexpr size_t MXFP8_PREFETCH_BUFFERS_NUM = 1; -static_assert(MXFP8_PREFETCH_BUFFERS_NUM < MXFP8_BUFFERS_NUM); - -constexpr size_t ELEMS_PER_THREAD = 16; -constexpr size_t MXFP8_BUFFER_DIM_Y = 32; // only 32 is supported -constexpr size_t MXFP8_BUFFER_DIM_X = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_SHMEM_DIM_Y = MXFP8_BUFFER_DIM_Y; // 32 -constexpr size_t MXFP8_SHMEM_DIM_X = MXFP8_BUFFER_DIM_X; // 64 - -constexpr size_t THREADS_PER_CHUNK_X_ROWWISE = - MXFP8_CHUNK_DIM_X / ELEMS_PER_THREAD; // 4 = 64 / 16 -constexpr size_t THREADS_PER_CHUNK_Y_ROWWISE = - MXFP8_THREADS_PER_CHUNK / THREADS_PER_CHUNK_X_ROWWISE; // 16 = 64 / 4 -constexpr size_t THREADS_PER_CHUNK_X_COLWISE = MXFP8_CHUNK_DIM_X; // 64 -constexpr size_t MXFP8_BUFF_STAGES_NUM = - MXFP8_BUFFER_DIM_Y / THREADS_PER_CHUNK_Y_ROWWISE; // 2 = 32 / 16 -constexpr size_t MXFP8_ITERATIONS = MXFP8_CHUNK_DIM_Y / MXFP8_BUFFER_DIM_Y; // 2 = 64 / 32 -static_assert(MXFP8_ITERATIONS >= MXFP8_PREFETCH_BUFFERS_NUM); 
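// [Editor's note] Host-side sketch of the swizzled shared-memory indexing used by the
// rowwise paths in the new kernels. Per the in-code comment, the intent is that threads
// in different bank groups start their per-wave accesses at rotated PACK_SIZE-wide
// offsets, so consecutive threads do not all hit the same 4-byte banks. The constants
// mirror those introduced in the mxfp8_kernel namespace below; the print loop itself is
// purely illustrative.
#include <cstdio>

int main() {
  constexpr int SCALE_DIM_X = 32;      // elements covered by one rowwise scaling factor
  constexpr int PACK_SIZE = 4;         // elements handled per vectorized access
  constexpr int WAVES = SCALE_DIM_X / PACK_SIZE;
  constexpr int THREADS_PER_BANK = 4;  // TOTAL_BANKS_WIDTH / SCALE_DIM_X in the kernel

  for (int lane = 0; lane < 8; ++lane) {
    const int bank_group = lane / THREADS_PER_BANK;
    std::printf("lane %d:", lane);
    for (int w = 0; w < WAVES; ++w) {
      // Same formula as in the kernel: rotate the starting group by bank_group.
      const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X;
      std::printf(" %2d", swizzled_group_idx);
    }
    std::printf("\n");
  }
  return 0;
}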
+namespace mxfp8_kernel { + +constexpr size_t SCALE_DIM_Y = 32; +constexpr size_t SCALE_DIM_X = 32; + +constexpr size_t BUFFS_NUM = 2; +constexpr size_t PACK_SIZE = 4; +constexpr size_t WAVES = SCALE_DIM_X / PACK_SIZE; + +// Number of 1-byte elements that span 32 banks (4-byte each) of shared memory +constexpr size_t TOTAL_BANKS_WIDTH = (32 * 4) / 1; // 128 + +// Number of threads (rowwise scaling) that span 32 banks (4-byte banks) of shared memory +constexpr size_t THREADS_PER_BANK = TOTAL_BANKS_WIDTH / SCALE_DIM_X; // 4 = 128 / 32 template -__global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) + float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING, + bool COLWISE_SCALING, size_t CHUNK_DIM_Y, size_t CHUNK_DIM_X, size_t THREADS_PER_CHUNK> +__global__ void __launch_bounds__(THREADS_PER_CHUNK) cast_mxfp8_2D_kernel(const __grid_constant__ CUtensorMap tensor_map_input, const __grid_constant__ CUtensorMap tensor_map_act_input, const __grid_constant__ CUtensorMap tensor_map_output_rowwise, @@ -67,201 +56,341 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const size_t rows, const size_t cols, const size_t scale_stride_rowwise, const size_t scale_stride_colwise) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - if constexpr (!IS_DBIAS && !IS_DACT && !IS_ACT) { - if (noop != nullptr && noop[0] == 1.0f) return; + constexpr bool COMPUTE_ACTIVATIONS = IS_DACT || IS_ACT; + constexpr bool NO_ACTIVATIONS = !COMPUTE_ACTIVATIONS; + + using IType2 = typename ptx::FPx2; + using OType2 = typename ptx::FPx2; + + if constexpr (NO_ACTIVATIONS) { + if (noop != nullptr && noop[0] == 1.0f) { + return; + } } + constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X; + constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; + + constexpr size_t BUFF_DIM_Y = THREADS_Y; + constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; + constexpr size_t BUFF_DIM = BUFF_DIM_Y * BUFF_DIM_X; + static_assert(BUFF_DIM_Y == 32); + + constexpr size_t STAGES = CHUNK_DIM_Y / BUFF_DIM_Y; + static_assert(STAGES >= 1); + + constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; + + const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const int block_offset_X = blockIdx.x * CHUNK_DIM_X; + const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const int tid_Y_rowwise = threadIdx.x / THREADS_X; + const int tid_X_rowwise = threadIdx.x % THREADS_X; + const int tid_Y_colwise = 0; + const int tid_X_colwise = threadIdx.x; + + const int thread_offset_Y_rowwise = tid_Y_rowwise; + const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const int thread_offset_Y_colwise = tid_Y_colwise; + const int thread_offset_X_colwise = tid_X_colwise; + + const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + + const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); + + const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const int 
scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + + // helps resolving bank conflicts in shmem + const int thread_lane = threadIdx.x % THREADS_PER_WARP; + const int bank_group = thread_lane / THREADS_PER_BANK; + + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(IType), TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(buff_elems_total * sizeof(OType), TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + constexpr size_t out_mem_rowwise = (ROWWISE_SCALING ? buff_size_aligned_out : 0); + + extern __shared__ char dynamic_shmem[]; + uintptr_t base_shmem_ptr = reinterpret_cast(dynamic_shmem); + // Manually align dynamic SHMEM per TMA requirements using padding + // __align__(128) Does not guarantee the pointer to be aligned! + uintptr_t dshmem = (base_shmem_ptr + TMA_SHMEM_ALIGNMENT - 1) & + ~(static_cast(TMA_SHMEM_ALIGNMENT - 1)); + + // The destination shared memory buffer of a bulk tensor operation should be 16-byte aligned + IType *in_sh = reinterpret_cast(dshmem); + IType *act_in_sh = reinterpret_cast(dshmem + elt_input_mem); + OType *out_rowwise_sh = reinterpret_cast(dshmem + in_mem); + OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); + IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer + + constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; - constexpr bool USE_ROWWISE_SCALING = SCALE_DIM_X > 1; - constexpr bool USE_COLWISE_SCALING = SCALE_DIM_Y > 1; - constexpr bool COMPUTE_DBIAS_IN_ROWWISE_SECTION = !USE_COLWISE_SCALING; - - constexpr size_t SCALES_ROWWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_ROWWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X / SCALE_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_Y = - SCALES_ROWWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_ROWWISE_PER_BLOCK_X = - SCALES_ROWWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t SCALES_COLWISE_PER_CHUNK_Y = MXFP8_CHUNK_DIM_Y / SCALE_DIM_Y; // 2 = 64 / 32 - constexpr size_t SCALES_COLWISE_PER_CHUNK_X = MXFP8_CHUNK_DIM_X; // 64 = 64 / 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_Y = - SCALES_COLWISE_PER_CHUNK_Y * MXFP8_CHUNKS_PER_BLOCK_Y; // 2 = 2 * 1 - constexpr size_t SCALES_COLWISE_PER_BLOCK_X = - SCALES_COLWISE_PER_CHUNK_X * MXFP8_CHUNKS_PER_BLOCK_X; // 64 = 64 * 1 - - constexpr size_t THREADS_PER_SCALE_X_ROWWISE = - DIVUP(SCALE_DIM_X, ELEMS_PER_THREAD); // 2 = 32 / 16 - constexpr size_t SUBWARP_WIDTH = THREADS_PER_SCALE_X_ROWWISE; // 2 - - const int block_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y * MXFP8_CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X; - const int scales_rowwise_block_offset_Y = blockIdx.y * SCALES_ROWWISE_PER_BLOCK_Y; - const int scales_rowwise_block_offset_X = blockIdx.x * SCALES_ROWWISE_PER_BLOCK_X; - const int scales_colwise_block_offset_Y = blockIdx.y * SCALES_COLWISE_PER_BLOCK_Y; - const int scales_colwise_block_offset_X = blockIdx.x * SCALES_COLWISE_PER_BLOCK_X; - - const int tid_rowwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_ROWWISE; - const int tid_rowwise_X = threadIdx.x % 
THREADS_PER_CHUNK_X_ROWWISE; - // const int tid_colwise_Y = threadIdx.x / THREADS_PER_CHUNK_X_COLWISE; - const int tid_colwise_X = threadIdx.x % THREADS_PER_CHUNK_X_COLWISE; - - const int thread_offset_Y = tid_rowwise_Y; - const int thread_offset_X_rowwise = tid_rowwise_X * ELEMS_PER_THREAD; - // const int thread_offset_X_colwise = tid_colwise_X; - - const int dbias_rowwise_offset_Y = blockIdx.y * MXFP8_CHUNKS_PER_BLOCK_Y + tid_rowwise_Y; - const int dbias_rowwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + thread_offset_X_rowwise; - const int dbias_colwise_offset_Y = blockIdx.y; - const int dbias_colwise_block_offset_X = - blockIdx.x * MXFP8_CHUNKS_PER_BLOCK_X * MXFP8_CHUNK_DIM_X + tid_colwise_X; - const int dbias_stride = cols; + const bool is_master_thread = (threadIdx.x == 0); - Vec partial_dbias_rowwise[MXFP8_CHUNKS_PER_BLOCK_X]; - float partial_dbias_colwise[MXFP8_CHUNKS_PER_BLOCK_X]; + float partial_dbias_colwise = 0.0f; + float thread_dbias_rowwise[SCALE_DIM_X]; if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { #pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_rowwise[i].clear(); - } - } else { -#pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - partial_dbias_colwise[i] = 0; - } + for (int j = 0; j < SCALE_DIM_X; ++j) { + thread_dbias_rowwise[j] = 0.0f; } } - // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_rowwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - __shared__ alignas(128) - OType out_colwise_sh[MXFP8_BUFFERS_NUM][MXFP8_SHMEM_DIM_Y][MXFP8_SHMEM_DIM_X]; - - constexpr int shmem_buff_size = sizeof(in_sh) / MXFP8_BUFFERS_NUM; - - const bool is_master_thread = (threadIdx.x == 0); - - float block_amax = 0; + float block_amax = 0.0f; // Initialize shared memory barrier with the number of threads participating in the barrier. 
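// [Editor's note] The raw ptx:: mbarrier helpers used in these kernels (initialize_barriers,
// mbarrier_wait_parity with an explicitly tracked parity bit, destroy_barriers) follow the
// same arrive/wait protocol that libcu++ exposes as cuda::barrier, which tracks the phase
// internally. A minimal, self-contained analogue of that protocol is sketched below,
// assuming an sm_70+ build; it is for illustration only and is not the helper API used here.
#include <cstdio>
#include <cuda/barrier>

__global__ void barrier_sketch(int *out) {
  __shared__ int staging[128];
  #pragma nv_diag_suppress static_var_with_dynamic_init
  __shared__ cuda::barrier<cuda::thread_scope_block> bar;
  if (threadIdx.x == 0) {
    init(&bar, blockDim.x);  // all threads of the block participate
  }
  __syncthreads();

  staging[threadIdx.x] = threadIdx.x;  // "producer" step (the TMA copy in the real kernel)
  bar.arrive_and_wait();               // phase handling is done by the barrier object

  // "consumer" step: it is now safe to read what other threads staged in shared memory.
  out[threadIdx.x] = staging[(threadIdx.x + 1) % blockDim.x];
}

int main() {
  int *out = nullptr;
  cudaMalloc(&out, 128 * sizeof(int));
  barrier_sketch<<<1, 128>>>(out);
  cudaDeviceSynchronize();
  int host[128];
  cudaMemcpy(host, out, sizeof(host), cudaMemcpyDeviceToHost);
  std::printf("out[0] = %d\n", host[0]);  // expect 1
  cudaFree(out);
  return 0;
}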
#pragma nv_diag_suppress static_var_with_dynamic_init - __shared__ alignas(8) uint64_t mbar[MXFP8_ITERATIONS]; + __shared__ alignas(8) uint64_t mbar[STAGES]; - initialize_barriers(mbar, is_master_thread); + initialize_barriers(mbar, is_master_thread); int parity = 0; -#pragma unroll - for (int chunk = 0; chunk < MXFP8_CHUNKS_PER_BLOCK; ++chunk) { - const int chunk_Y = chunk / MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_X = chunk % MXFP8_CHUNKS_PER_BLOCK_X; - const int chunk_offset_Y = block_offset_Y + chunk_Y * MXFP8_CHUNK_DIM_Y; - const int chunk_offset_X = block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; + if constexpr (IS_DACT) { + copy_2d_to_sharedx2(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, &act_in_sh[0], + &tensor_map_act_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } else { + copy_2d_to_shared(&in_sh[0], &tensor_map_input, block_offset_X, block_offset_Y, shmem_buff_size, + &mbar[0], is_master_thread); + } - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + chunk_X * MXFP8_CHUNK_DIM_X; +#pragma unroll + for (int stage = 0; stage < STAGES; ++stage) { + const int buff = stage % BUFFS_NUM; + const int next_stage = stage + 1; + const int stage_offset_Y = stage * BUFF_DIM_Y; - const int scales_rowwise_chunk_offset_Y = - scales_rowwise_block_offset_Y + chunk_Y * SCALES_ROWWISE_PER_CHUNK_Y; - const int scales_rowwise_chunk_offset_X = - scales_rowwise_block_offset_X + chunk_X * SCALES_ROWWISE_PER_CHUNK_X; - const int scales_colwise_chunk_offset_Y = - scales_colwise_block_offset_Y + chunk_Y * SCALES_COLWISE_PER_CHUNK_Y; - const int scales_colwise_chunk_offset_X = - scales_colwise_block_offset_X + chunk_X * SCALES_COLWISE_PER_CHUNK_X; + if (next_stage < STAGES) { + // Wait for TMA transfer to have finished reading shared memory. + // I.e. 
the buffer is ready to be written to + ptx::cp_async_bulk_wait_group_read<1>(); -#pragma unroll - for (int prefetch_buff = 0; prefetch_buff < MXFP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * MXFP8_BUFFER_DIM_Y; - const int chunk_stage_offset_X = chunk_offset_X; + const int next_buff = next_stage % BUFFS_NUM; + const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const int global_offset_Y = block_offset_Y + next_stage_offset_Y; + const int global_offset_X = block_offset_X; + const int next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, - chunk_stage_offset_X, chunk_stage_offset_Y, shmem_buff_size, - &mbar[prefetch_buff], is_master_thread); + copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, + global_offset_X, global_offset_Y, shmem_buff_size, &mbar[next_stage], + is_master_thread); } else { - copy_2d_to_shared(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, - chunk_stage_offset_Y, shmem_buff_size, &mbar[prefetch_buff], - is_master_thread); + copy_2d_to_shared(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, + global_offset_Y, shmem_buff_size, &mbar[next_stage], is_master_thread); } } + ptx::fence_proxy_async_shared_cta(); + + // Wait for the data to have arrived + ptx::mbarrier_wait_parity(&mbar[stage], parity); + + float thread_amax = 0.0f; + if constexpr (COLWISE_SCALING) { + const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + thread_amax = 0.0f; + float in_compute_colwise[BUFF_DIM_Y]; + IType in_colwise_IType[BUFF_DIM_Y]; + + // 1. Read/Compute elements. 
Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType thread_amax_f16 = static_cast(0.0f); #pragma unroll - for (int iter = 0; iter < MXFP8_ITERATIONS; ++iter) { - const int buff = iter % MXFP8_BUFFERS_NUM; - const int next_iter = iter + MXFP8_PREFETCH_BUFFERS_NUM; - const size_t row_base = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - - if (next_iter < MXFP8_ITERATIONS) { - const int next_buff = next_iter % MXFP8_BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (IS_DACT) { - copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, - chunk_it_offset_x, chunk_it_offset_y, shmem_buff_size, - &mbar[next_iter], is_master_thread); - } else { - copy_2d_to_shared(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, - chunk_it_offset_y, shmem_buff_size, &mbar[next_iter], is_master_thread); + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + in_colwise_IType[i] = in_sh[shmem_offset_colwise]; + thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); } - } + thread_amax = static_cast(thread_amax_f16); + } else { +#pragma unroll + for (int i = 0; i < BUFF_DIM_Y; ++i) { + const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; - ptx::fence_proxy_async_shared_cta(); + float elt = static_cast(in_sh[shmem_offset_colwise]); + if constexpr (IS_ACT) { + elt = OP(elt, {}); + } + if constexpr (IS_DACT) { + float act_in_elt = static_cast(act_in_sh[shmem_offset_colwise]); + elt *= OP(act_in_elt, {}); + } + if constexpr (IS_DBIAS) { + partial_dbias_colwise += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + // Cache computed activations to avoid computing them again in the 2nd pass along another dimension + if constexpr (IS_CACHED_ACT_OP) { + cached_act_sh[shmem_offset_colwise] = static_cast(elt); + } - // Wait for the data to have arrived - ptx::mbarrier_wait_parity(&mbar[iter], parity); + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y + i >= rows); + const bool out_of_bounds = (col_out_of_bounds_colwise || row_out_of_bounds_colwise); + if (!out_of_bounds) { + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + } else { + // If no activation, elt is 0 so we can safely do this + thread_amax = fmaxf(thread_amax, fabsf(elt)); + } + in_compute_colwise[i] = elt; + } + } + + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - if constexpr (USE_ROWWISE_SCALING) { - Vec in; - Vec act_in; - Vec out_c; + const int global_scales_offset_Y = scales_offset_Y_colwise + stage; + const int global_scales_offset_X = scales_offset_X_colwise; + const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + scales_colwise[scale_idx] = biased_exponent; - const int iteration_scale_rowwise_offset_Y = - scales_rowwise_chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; +// 3. 
Scale elements #pragma unroll - for (int stage = 0; stage < MXFP8_BUFF_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y_ROWWISE; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X_rowwise; + for (int i = 0; i < SCALE_DIM_Y; ++i) { + float in; + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = static_cast(in_colwise_IType[i]); + } else { + in = in_compute_colwise[i]; + } + const float scaled_out = in * block_scale_inverse; - const size_t row = row_base + shmem_offset_y; - const bool row_out_of_bounds = (row >= rows); + const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); + } + } - in.load_from(&in_sh[buff][shmem_offset_y][shmem_offset_x]); - if constexpr (IS_DACT) { - act_in.load_from(&act_in_sh[buff][shmem_offset_y][shmem_offset_x]); - } + if constexpr (ROWWISE_SCALING) { + const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + thread_amax = 0.0f; + float in_compute_rowwise[SCALE_DIM_X]; + Vec in_cached[WAVES]; - float thread_amax = 0; - float in_compute[ELEMS_PER_THREAD]; + // used as an IType container for BF16/FP16 --> MXFP8 CAST ONLY + Vec in_IType[WAVES]; + // 1. Read/Compute elements. Find MXFP8-block AMAX + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + // Load elements + in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_IType[w].data.elt[e]); + } + } + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } else if constexpr (IS_CACHED_ACT_OP) { + // ensures that all writes to cache made in the section above are visible to all threads + __syncthreads(); + IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); + + // Load cached elements + in_cached[w].load_from(&cached_act_sh[shmem_offset_rowwise]); + // Since TMA requirement for the data alignment is 16B (i.e. cols % 8 == 0, in case of BF16 elements) + // only single check (w.r.t. 
column direction) is sufficient to be sure the entire wave is inside the boundaries + if (!out_of_bounds) { + if constexpr (std::is_same_v) { +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + thread_amax = fmaxf(thread_amax, fabsf(in_cached[w].data.elt[e])); + } + } else { #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - const bool col_out_of_bounds = (dbias_rowwise_offset_X + j >= cols); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + for (int e = 0; e < PACK_SIZE; e += 2) { + const IType2 in_cached_2x = {in_cached[w].data.elt[e], + in_cached[w].data.elt[e + 1]}; + ptx::abs_max_2x(thread_amax_2x, thread_amax_2x, in_cached_2x); + } + } + } + } + if constexpr (!std::is_same_v) { + thread_amax = + static_cast(__hmax(__habs(thread_amax_2x.x), __habs(thread_amax_2x.y))); + } + } else { +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; - float elt = static_cast(in.data.elt[j]); + Vec in; + Vec act_in; + + in.load_from(&in_sh[shmem_offset_rowwise]); + if constexpr (IS_DACT) { + act_in.load_from(&act_in_sh[shmem_offset_rowwise]); + } +#pragma unroll + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + // Compute element + float elt = static_cast(in.data.elt[e]); if constexpr (IS_ACT) { elt = OP(elt, {}); } if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in.data.elt[j]); + float act_in_elt = static_cast(act_in.data.elt[e]); elt *= OP(act_in_elt, {}); } - if constexpr (IS_DBIAS && COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - if (!out_of_bounds) { - partial_dbias_rowwise[chunk_X].data.elt[j] += elt; - } - } - in_compute[j] = elt; - if constexpr (IS_ACT || IS_DACT) { + // If DBIAS was computed in the 1st pass (COLWISE) then no need to compute it again + if constexpr (IS_DBIAS && (!COLWISE_SCALING)) { + thread_dbias_rowwise[j] += elt; + } + // Numerical truncation: Downcast to IType (BF16/FP16), then upcast it back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } + if constexpr (COMPUTE_ACTIVATIONS) { + const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); + const bool swizzled_col_out_of_bounds = + (block_offset_X + swizzled_thread_idx >= cols); + const bool out_of_bounds = (row_out_of_bounds_rowwise || swizzled_col_out_of_bounds); if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); } @@ -269,196 +398,141 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) // If no activation, elt is 0 so we can safely do this thread_amax = fmaxf(thread_amax, fabsf(elt)); } + in_compute_rowwise[j] = elt; } - - __builtin_assume(block_amax >= 0); - __builtin_assume(thread_amax >= 0); - block_amax = fmaxf(block_amax, thread_amax); - - const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); - const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp); - - // Only single thread writes the computed scaling factor - if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { - const int global_scales_offset_Y = - iteration_scale_rowwise_offset_Y + stage_offset_Y + tid_rowwise_Y; - const int global_scales_offset_X = - scales_rowwise_chunk_offset_X + tid_rowwise_X / THREADS_PER_SCALE_X_ROWWISE; - const int scale_idx = - global_scales_offset_Y * scale_stride_rowwise + 
global_scales_offset_X; - scales_rowwise[scale_idx] = biased_exponent; - } - - const float block_scale_inverse = exp2f_rcp(biased_exponent); - -#pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - out_c.data.elt[j] = static_cast(in_compute[j] * block_scale_inverse); - } - out_c.store_to(&out_rowwise_sh[buff][shmem_offset_y][shmem_offset_x]); } } - if constexpr (USE_COLWISE_SCALING) { - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - float in_compute[SCALE_DIM_Y]; + // 2. Compute E8M0 scaling factor + const e8m0_t biased_exponent = + ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); + const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const int stage_scales_offset_X = scales_offset_X_rowwise; + const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + scales_rowwise[scale_idx] = biased_exponent; - float amax = 0; -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - const size_t row = row_base + i; - const bool row_out_of_bounds = (row >= rows); - const bool out_of_bounds = (col_out_of_bounds || row_out_of_bounds); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); + const ptx::floatx2 block_scale_inverse_2x = {block_scale_inverse, block_scale_inverse}; - float elt = static_cast(in_sh[buff][i][tid_colwise_X]); - if constexpr (IS_ACT) { - elt = OP(elt, {}); - } - if constexpr (IS_DACT) { - float act_in_elt = static_cast(act_in_sh[buff][i][tid_colwise_X]); - elt *= OP(act_in_elt, {}); - } - if constexpr (IS_DBIAS) { - if (!out_of_bounds) { - partial_dbias_colwise[chunk_X] += elt; - } - } - in_compute[i] = elt; - if constexpr (IS_ACT || IS_DACT) { - if (!out_of_bounds) { - amax = fmaxf(amax, fabsf(elt)); - } + // 3. Scale elements +#pragma unroll + for (int w = 0; w < WAVES; ++w) { + Vec out; +#pragma unroll + for (int e = 0; e < PACK_SIZE / 2; ++e) { + IType2 in; + OType2 &out_pair = reinterpret_cast(out.data.elt[e]); + if constexpr (NO_ACTIVATIONS && (!IS_DBIAS) && (!std::is_same_v)) { + in = in_IType[w].data.elt[e]; + } else if constexpr (IS_CACHED_ACT_OP) { + in.x = in_cached[w].data.elt[2 * e]; + in.y = in_cached[w].data.elt[2 * e + 1]; } else { - // If no activation, elt is 0 so we can safely do this - amax = fmaxf(amax, fabsf(elt)); + const int j = w * PACK_SIZE + 2 * e; + in.x = in_compute_rowwise[j]; + in.y = in_compute_rowwise[j + 1]; } + ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); + } + } - __builtin_assume(block_amax >= 0); - __builtin_assume(amax >= 0); - block_amax = fmaxf(block_amax, amax); - - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp); + __builtin_assume(block_amax >= 0); + __builtin_assume(thread_amax >= 0); + block_amax = fmaxf(block_amax, thread_amax); - const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; - const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; - const int scale_idx = - global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; - scales_colwise[scale_idx] = biased_exponent; + // Wait for shared memory writes to be visible to TMA engine. 
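// [Editor's note] block_amax accumulated above is later reduced across the block
// (reduce_max) and folded into the global amax via atomicMaxFloat. A minimal sketch of
// the warp-level part of such an amax reduction using warp shuffles; the kernel name,
// launch shape, and two-pass combination are illustrative only, not the helpers defined
// elsewhere in common/.
#include <cmath>
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_amax_sketch(const float *in, float *per_warp_amax, int n) {
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  float amax = (tid < n) ? fabsf(in[tid]) : 0.0f;

  // Tree reduction: after log2(32) shuffle steps, lane 0 holds the warp maximum.
  for (int offset = 16; offset > 0; offset /= 2) {
    amax = fmaxf(amax, __shfl_down_sync(0xffffffff, amax, offset));
  }

  const int lane = threadIdx.x % 32;
  if (lane == 0) {
    per_warp_amax[tid / 32] = amax;  // a second pass (or atomics) combines the warps
  }
}

int main() {
  constexpr int N = 256;
  float host_in[N];
  for (int i = 0; i < N; ++i) host_in[i] = (i == 77) ? -42.0f : 0.5f;

  float *d_in = nullptr, *d_out = nullptr;
  cudaMalloc(&d_in, N * sizeof(float));
  cudaMalloc(&d_out, (N / 32) * sizeof(float));
  cudaMemcpy(d_in, host_in, sizeof(host_in), cudaMemcpyHostToDevice);

  warp_amax_sketch<<<1, N>>>(d_in, d_out, N);
  cudaDeviceSynchronize();

  float host_out[N / 32];
  cudaMemcpy(host_out, d_out, sizeof(host_out), cudaMemcpyDeviceToHost);
  float block_amax = 0.0f;
  for (float v : host_out) block_amax = fmaxf(block_amax, v);
  std::printf("block amax = %f\n", block_amax);  // expect 42.0

  cudaFree(d_in);
  cudaFree(d_out);
  return 0;
}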
+ ptx::fence_proxy_async_shared_cta(); + __syncthreads(); + // After syncthreads, writes by all threads are visible to TMA engine. - const float block_scale_inverse = exp2f_rcp(biased_exponent); -#pragma unroll - for (int i = 0; i < SCALE_DIM_Y; ++i) { - out_colwise_sh[buff][i][tid_colwise_X] = - static_cast(in_compute[i] * block_scale_inverse); - } + // Initiate TMA transfer to copy shared memory to global memory + if (is_master_thread) { + const int global_offset_Y = block_offset_Y + stage_offset_Y; + const int global_offset_X = block_offset_X; + const int buff_offset = buff * BUFF_DIM; + + if constexpr (ROWWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_rowwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_rowwise_sh[buff_offset])); } - - // Wait for shared memory writes to be visible to TMA engine. - ptx::fence_proxy_async_shared_cta(); - __syncthreads(); - // After syncthreads, writes by all threads are visible to TMA engine. - - // Initiate TMA transfer to copy shared memory to global memory - if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + iter * MXFP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; - if constexpr (USE_ROWWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_rowwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_rowwise_sh[buff])); - } - if constexpr (USE_COLWISE_SCALING) { - ptx::cp_async_bulk_tensor_2d_shared_to_global( - reinterpret_cast(&tensor_map_output_colwise), chunk_it_offset_x, - chunk_it_offset_y, reinterpret_cast(&out_colwise_sh[buff])); - } - // Create a "bulk async-group" out of the previous bulk copy operation. - ptx::cp_async_bulk_commit_group(); - - // Wait for TMA transfer to have finished reading shared memory. - ptx::cp_async_bulk_wait_group_read(); + if constexpr (COLWISE_SCALING) { + ptx::cp_async_bulk_tensor_2d_shared_to_global( + reinterpret_cast(&tensor_map_output_colwise), global_offset_X, + global_offset_Y, reinterpret_cast(&out_colwise_sh[buff_offset])); } - } - ptx::cp_async_bulk_wait_group_read<0>(); - __syncthreads(); - parity ^= 1; + // Create a "bulk async-group" out of the previous bulk copy operation. 
+ ptx::cp_async_bulk_commit_group(); + } } - if constexpr (IS_DBIAS) { - if constexpr (COMPUTE_DBIAS_IN_ROWWISE_SECTION) { - constexpr size_t CZ = MXFP8_CHUNKS_PER_BLOCK_X; - constexpr size_t Y = THREADS_PER_CHUNK_Y_ROWWISE - 1; - constexpr size_t X = THREADS_PER_CHUNK_X_ROWWISE; - __shared__ float shmem_partial_dbias_rowwise[CZ][Y][X][ELEMS_PER_THREAD]; - - if (tid_rowwise_Y > 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - partial_dbias_rowwise[c].store_to( - &shmem_partial_dbias_rowwise[c][tid_rowwise_Y - 1][tid_rowwise_X]); - } - } - __syncthreads(); + parity ^= 1; - if (tid_rowwise_Y == 0) { -#pragma unroll - for (int c = 0; c < MXFP8_CHUNKS_PER_BLOCK_X; ++c) { - Vec other_row_dbias; - const int dbias_rowwise_offset_X = dbias_rowwise_block_offset_X + c * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_rowwise_offset_Y * dbias_stride + dbias_rowwise_offset_X; + if constexpr (IS_DBIAS) { + float thread_partial_dbias = 0.0f; + if constexpr (COLWISE_SCALING) { + thread_partial_dbias = partial_dbias_colwise; + } else { + // Reusing dshmem (in_sh) as dbias buffer [HEIGHT x WIDTH] + // HEIGHT = THREADS_Y + // WIDTH = THREADS_X * (SCALE_DIM_X + 1) + // Added extra 1-element padding per thread_X to reduce bank conflicts + float *partial_dbias_rowwise = reinterpret_cast(dshmem); - const int left_bound = dbias_rowwise_offset_X; - const int right_bound = dbias_rowwise_offset_X + ELEMS_PER_THREAD - 1; + constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + const int shmem_thread_offset = + tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll - for (int i = 0; i < Y; ++i) { - other_row_dbias.load_from(&shmem_partial_dbias_rowwise[c][i][tid_rowwise_X]); + for (int w = 0; w < WAVES; ++w) { + const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll - for (int j = 0; j < ELEMS_PER_THREAD; ++j) { - partial_dbias_rowwise[c].data.elt[j] += other_row_dbias.data.elt[j]; - } - } - - // Vectorized store when all elements are inside the boundaries - if (right_bound < cols) { - partial_dbias_rowwise[c].store_to(&dbias_workspace[dbias_offset]); - } else if (left_bound < cols && right_bound >= cols) { - // Element-by-element store when some elements cross the boundaries - const int in_bound_elts_count = cols - left_bound; - partial_dbias_rowwise[c].store_to_elts(&dbias_workspace[dbias_offset], 0, - in_bound_elts_count); - } + for (int e = 0; e < PACK_SIZE; ++e) { + const int j = w * PACK_SIZE + e; + const int shmem_elt_idx = swizzled_group_offset + e; + partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } - } else { + __syncthreads(); #pragma unroll - for (int i = 0; i < MXFP8_CHUNKS_PER_BLOCK_X; ++i) { - const int dbias_colwise_offset_X = dbias_colwise_block_offset_X + i * MXFP8_CHUNK_DIM_X; - const int dbias_offset = dbias_colwise_offset_Y * dbias_stride + dbias_colwise_offset_X; - const bool col_out_of_bounds = (dbias_colwise_offset_X >= cols); - if (!col_out_of_bounds) { - dbias_workspace[dbias_offset] = partial_dbias_colwise[i]; - } + for (int i = 0; i < THREADS_Y; ++i) { + // Add extra element offset per MXFP8 scaling block [1x32] + const int scaling_block = threadIdx.x / SCALE_DIM_X; + thread_partial_dbias += + partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } + const int dbias_stride = cols; + const int dbias_offset_Y = blockIdx.y; + const int dbias_offset_X = blockIdx.x * 
CHUNK_DIM_X + threadIdx.x; + const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); + if (!col_out_of_bounds_dbias) { + dbias_workspace[dbias_idx] = thread_partial_dbias; + } } if (amax_ptr != nullptr) { const int warp_id = threadIdx.x / THREADS_PER_WARP; // Reduce the amax over the block - block_amax = reduce_max(block_amax, warp_id); + block_amax = reduce_max(block_amax, warp_id); } if (is_master_thread && amax_ptr != nullptr) { atomicMaxFloat(amax_ptr, block_amax); } - destroy_barriers(mbar, is_master_thread); + destroy_barriers(mbar, is_master_thread); #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) } +} // namespace mxfp8_kernel constexpr size_t FP8_CHUNK_DIM_Y = 128; constexpr size_t FP8_CHUNK_DIM_X = 128; @@ -507,9 +581,12 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + IType act_in_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) + OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; @@ -678,8 +755,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const float scale = (scale_ptr != nullptr) ? *scale_ptr : 1; // The destination shared memory buffer of a bulk tensor operation should be 128-byte aligned - __shared__ alignas(128) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; - __shared__ alignas(128) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; @@ -921,6 +998,7 @@ template has_data(); bool use_colwise_scaling = output->has_columnwise_data(); checkCuDriverContext(stream); @@ -936,16 +1014,24 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, } CheckNoopTensor(*noop, "cast_noop"); - // TODO: Make more general - const size_t scale_dim_X_rowwise = use_rowwise_scaling ? 32 : 1; - const size_t scale_dim_Y_colwise = use_colwise_scaling ? 32 : 1; - const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); - const size_t chunks_Y = DIVUP(rows, MXFP8_CHUNK_DIM_Y); - const size_t chunks_X = DIVUP(cols, MXFP8_CHUNK_DIM_X); - const size_t blocks_Y = DIVUP(chunks_Y, MXFP8_CHUNKS_PER_BLOCK_Y); - const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X); + + constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); + + constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t CHUNK_DIM_X = CAST_DBIAS_ONLY ? 128 : 64; + constexpr size_t THREADS_PER_CHUNK = CAST_DBIAS_ONLY ? 
128 : 64;
+
+  constexpr size_t THREADS_X = CHUNK_DIM_X / SCALE_DIM_X;
+  constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X;
+  constexpr size_t BUFF_DIM_Y = THREADS_Y;
+  constexpr size_t BUFF_DIM_X = CHUNK_DIM_X;
+
+  const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y);
+  const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X);
+  const dim3 grid(blocks_X, blocks_Y);
+  const size_t block_size = THREADS_PER_CHUNK;
 
   const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1;
   const size_t scale_stride_colwise =
@@ -958,6 +1044,15 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
   const size_t dbias_rows = blocks_Y;
   const size_t dbias_cols = cols;
 
+  ScalingType scaling_type;
+  if (use_rowwise_scaling && (!use_colwise_scaling)) {
+    scaling_type = ScalingType::ROWWISE;
+  } else if ((!use_rowwise_scaling) && use_colwise_scaling) {
+    scaling_type = ScalingType::COLWISE;
+  } else if (use_rowwise_scaling && use_colwise_scaling) {
+    scaling_type = ScalingType::BIDIMENSIONAL;
+  }
+
   if constexpr (IS_DBIAS) {
     NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input.");
     NVTE_CHECK(dbias->data.shape == std::vector<size_t>{cols}, "Wrong shape of DBias.");
@@ -972,58 +1067,107 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
 
   float *const workspace_ptr = IS_DBIAS ? reinterpret_cast<float *>(workspace->data.dptr) : nullptr;
   float *const amax_ptr = reinterpret_cast<float *>(output->amax.dptr);
+  const float *noop_ptr = reinterpret_cast<const float *>(noop->data.dptr);
 
-  const dim3 block(MXFP8_THREADS_PER_CHUNK);
-  const dim3 grid(blocks_X, blocks_Y);
+  TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY(
+      input.dtype(), IType,
+      TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
+          output->dtype(), OType,
+
+          alignas(64) CUtensorMap tensor_map_input{};
+          alignas(64) CUtensorMap tensor_map_act_input{};
+          alignas(64) CUtensorMap tensor_map_output_rowwise{};
+          alignas(64) CUtensorMap tensor_map_output_colwise{};
+
+          constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
+          constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
+
+          create_2D_tensor_map(tensor_map_input, input.data, rows, cols, BUFF_DIM_Y, BUFF_DIM_X,
+                               cols, 0, input_type_bit_size);
+
+          if constexpr (IS_DACT) {
+            create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols, BUFF_DIM_Y,
+                                 BUFF_DIM_X, cols, 0, input_type_bit_size);
+          }
+
+          if (use_rowwise_scaling) {
+            create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, BUFF_DIM_Y,
+                                 BUFF_DIM_X, cols, 0, output_type_bit_size);
+          }
+
+          if (use_colwise_scaling) {
+            create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, cols,
+                                 BUFF_DIM_Y, BUFF_DIM_X, cols, 0, output_type_bit_size);
+          }
 
-  TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
-      scale_dim_Y_colwise, SCALE_DIM_Y,
-      TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH(
-          scale_dim_X_rowwise, SCALE_DIM_X,
-          TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
-              input.dtype(), IType,
-              TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
-                  output->dtype(), OType,
-
-                  alignas(64) CUtensorMap tensor_map_input{};
-                  alignas(64) CUtensorMap tensor_map_act_input{};
-                  alignas(64) CUtensorMap tensor_map_output_rowwise{};
-                  alignas(64) CUtensorMap tensor_map_output_colwise{};
-
-                  create_2D_tensor_map(tensor_map_input, input.data, rows, cols, MXFP8_SHMEM_DIM_Y,
-                                       MXFP8_SHMEM_DIM_X, cols, 0, typeToNumBits(input.dtype()));
-
-                  if constexpr (IS_DACT) {
-                    create_2D_tensor_map(tensor_map_act_input, act_input->data, rows, cols,
-                                         MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0,
-                                         typeToNumBits(input.dtype()));
-                  }
-
-                  if 
(use_rowwise_scaling) { - create_2D_tensor_map(tensor_map_output_rowwise, output->data, rows, cols, - MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } - - if (use_colwise_scaling) { - create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data, rows, - cols, MXFP8_SHMEM_DIM_Y, MXFP8_SHMEM_DIM_X, cols, 0, - typeToNumBits(output->dtype())); - } - - cast_mxfp8_2D_kernel<<>>( + constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X; + constexpr size_t buff_elems_total = mxfp8_kernel::BUFFS_NUM * buff_elems; + constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8; + constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8; + constexpr size_t buff_size_aligned_in = + DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT); + constexpr size_t buff_size_aligned_out = + DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT); + + constexpr size_t elt_input_mem = buff_size_aligned_in; + constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0); + constexpr size_t in_mem = elt_input_mem + act_input_mem; + + const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0); + const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0); + const size_t out_mem = out_rowwise_mem + out_colwise_mem; + + const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + + switch (scaling_type) { + case ScalingType::ROWWISE: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, - reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, - rows, cols, scale_stride_rowwise, scale_stride_colwise); - - if constexpr (IS_DBIAS) { - reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - }); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) - ); // NOLINT(*) + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::COLWISE: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + case ScalingType::BIDIMENSIONAL: + cudaFuncSetAttribute( + cast_mxfp8_2D_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + break; + } + + if constexpr (IS_DBIAS) { + reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); + }); // NOLINT(*) + ); // NOLINT(*) } namespace detail { @@ -1117,8 +1261,8 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons case NVTE_DELAYED_TENSOR_SCALING: { if (!IS_DBIAS && !IS_DACT) { if (is_full_tile_1D_tensor(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, 
TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 cast_fp8_1D(input, output, stream); } else { @@ -1127,9 +1271,9 @@ void fp8_quantize_arch_ge_100(const Tensor &input, const Tensor *act_input, cons } } else if (!IS_DBIAS && IS_DACT) { if (dimensions_supported_by_TMA(output) && is_fp8_dtype(output->dtype()) && - is_aligned_tensor_data(input, TMA_gmem_alignment) && - is_aligned_tensor_data(*output, TMA_gmem_alignment) && - is_aligned_tensor_data(*act_input, TMA_gmem_alignment)) { + is_aligned_tensor_data(input, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*output, TMA_GMEM_ALIGNMENT) && + is_aligned_tensor_data(*act_input, TMA_GMEM_ALIGNMENT)) { // Aligned AND FP8 (+dAct) cast_fp8_2D(input, act_input, output, dbias, workspace, stream); diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index e716065ab..a82f11307 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -84,8 +84,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // const int thread_offset_X_colwise = tid_colwise_X; // The destination shared memory buffer of a bulk tensor operation should be 128 e8m0_t aligned - __shared__ alignas(128) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; - __shared__ alignas(128) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; + __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[BUFFERS_NUM][SHMEM_DIM_Y][SHMEM_DIM_X]; constexpr int shmem_buff_size = sizeof(in_sh) / BUFFERS_NUM; constexpr int transaction_size = shmem_buff_size; @@ -166,7 +166,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float block_scale = ptx::exp2f(biased_exponent); if constexpr (USE_ROWWISE_SCALING) { Vec in; diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 55bc247f7..581de9f9f 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -104,6 +104,53 @@ __device__ __forceinline__ void mbarrier_wait_parity(uint64_t *mbar, const uint3 #endif // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) +constexpr uint32_t FP32_MANTISSA_BITS = 23; +constexpr uint32_t FP32_EXPONENT_BIAS = 127; + +__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { + return (biased_exp == 0) ? 1 + : __int_as_float((254 - biased_exp) + << FP32_MANTISSA_BITS); // 127 - (biased_exp - 127) +} + +__device__ __forceinline__ float exp2f(e8m0_t biased_exp) { + return __int_as_float(biased_exp << FP32_MANTISSA_BITS); +} + +__device__ __forceinline__ e8m0_t float_to_e8m0(float val) { +#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ + (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) + uint16_t out; + asm volatile( + "{\n" + "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n" + "}" + : "=h"(out) + : "f"(val)); + return *reinterpret_cast(&out); +#else + // TODO: nan/inf needs to be set for any value + // of nan/inf in input not just amax. 
+  if (isnan(val)) {
+    return 0xFF;
+  }
+  if (isinf(val)) {
+    return 0xFE;
+  }
+  if (val == 0.0f) {
+    return 0x00;
+  }
+  uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
+  e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
+  uint32_t mantissa = val_u32 & 0x7FFFFF;
+  // Round up exponent and deal with satfinite.
+  if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) {
+    ++exponent;
+  }
+  return exponent;
+#endif
+}
+
 #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
 
 // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor
@@ -169,6 +216,159 @@ __device__ __forceinline__ void fence_proxy_async_shared_cta() {
   asm volatile("fence.proxy.async.shared::cta;");
 }
 
+template <typename T>
+struct alignas(2 * sizeof(T)) FPx2 {
+  T x;
+  T y;
+};
+
+using floatx2 = FPx2<float>;
+using bf16x2 = FPx2<bf16>;
+using fp16x2 = FPx2<fp16>;
+using fp8e4m3x2 = FPx2<fp8e4m3>;
+using fp8e5m2x2 = FPx2<fp8e5m2>;
+
+static_assert(sizeof(floatx2) == 8);
+static_assert(sizeof(bf16x2) == 4);
+static_assert(sizeof(fp16x2) == 4);
+static_assert(sizeof(fp8e4m3x2) == 2);
+static_assert(sizeof(fp8e5m2x2) == 2);
+
+// SIMD like "Fused" cast + multiplication (x2)
+__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const floatx2 &in,
+                                           const floatx2 &scale) {
+  asm volatile(
+      "{\n"
+      ".reg.b64 val_pair; \n\t"
+      ".reg.b32 val1; \n\t"
+      ".reg.b32 val2; \n\t"
+      "mul.f32x2 val_pair, %1, %2; \n\t"
+      "mov.b64 {val2,val1}, val_pair; \n\t"
+      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
+      "}"
+      : "=h"(reinterpret_cast<uint16_t &>(out))
+      : "l"(reinterpret_cast<const uint64_t &>(in)),
+        "l"(reinterpret_cast<const uint64_t &>(scale)));
+}
+
+__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const floatx2 &in,
+                                           const floatx2 &scale) {
+  asm volatile(
+      "{\n"
+      ".reg.b64 val_pair; \n\t"
+      ".reg.b32 val1; \n\t"
+      ".reg.b32 val2; \n\t"
+      "mul.f32x2 val_pair, %1, %2; \n\t"
+      "mov.b64 {val2,val1}, val_pair; \n\t"
+      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
+      "}"
+      : "=h"(reinterpret_cast<uint16_t &>(out))
+      : "l"(reinterpret_cast<const uint64_t &>(in)),
+        "l"(reinterpret_cast<const uint64_t &>(scale)));
+}
+
+__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const bf16x2 &in, const floatx2 &scale) {
+  asm volatile(
+      "{\n"
+      ".reg.b64 val_pair_before; \n\t"
+      ".reg.b64 val_pair_after; \n\t"
+      ".reg.b32 val1; \n\t"
+      ".reg.b32 val2; \n\t"
+      ".reg.b16 val1_bf16; \n\t"
+      ".reg.b16 val2_bf16; \n\t"
+      "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
+      "cvt.f32.bf16 val1, val1_bf16; \n\t"
+      "cvt.f32.bf16 val2, val2_bf16; \n\t"
+      "mov.b64 val_pair_before, {val1,val2}; \n\t"
+      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
+      "mov.b64 {val2,val1}, val_pair_after; \n\t"
+      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
+      "}"
+      : "=h"(reinterpret_cast<uint16_t &>(out))
+      : "r"(reinterpret_cast<const uint32_t &>(in)),
+        "l"(reinterpret_cast<const uint64_t &>(scale)));
+}
+
+__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const bf16x2 &in, const floatx2 &scale) {
+  asm volatile(
+      "{\n"
+      ".reg.b64 val_pair_before; \n\t"
+      ".reg.b64 val_pair_after; \n\t"
+      ".reg.b32 val1; \n\t"
+      ".reg.b32 val2; \n\t"
+      ".reg.b16 val1_bf16; \n\t"
+      ".reg.b16 val2_bf16; \n\t"
+      "mov.b32 {val1_bf16, val2_bf16} , %1; \n\t"
+      "cvt.f32.bf16 val1, val1_bf16; \n\t"
+      "cvt.f32.bf16 val2, val2_bf16; \n\t"
+      "mov.b64 val_pair_before, {val1,val2}; \n\t"
+      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
+      "mov.b64 {val2,val1}, val_pair_after; \n\t"
+      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
+      "}"
+      : "=h"(reinterpret_cast<uint16_t &>(out))
+      : "r"(reinterpret_cast<const uint32_t &>(in)),
+        "l"(reinterpret_cast<const uint64_t &>(scale)));
+}
+
+__device__ __forceinline__ void mul_cvt_2x(fp8e4m3x2 &out, const fp16x2 &in, const floatx2 &scale) {
+  asm volatile(
+      "{\n"
+      ".reg.b64 val_pair_before; \n\t"
+      ".reg.b64 val_pair_after; \n\t"
+      ".reg.b32 val1; \n\t"
+      ".reg.b32 val2; \n\t"
+      ".reg.b16 val1_fp16; \n\t"
+      ".reg.b16 val2_fp16; \n\t"
+      "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
+      "cvt.f32.f16 val1, val1_fp16; \n\t"
+      "cvt.f32.f16 val2, val2_fp16; \n\t"
+      "mov.b64 val_pair_before, {val1,val2}; \n\t"
+      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
+      "mov.b64 {val2,val1}, val_pair_after; \n\t"
+      "cvt.rn.satfinite.e4m3x2.f32 %0, val1, val2; \n\t"
+      "}"
+      : "=h"(reinterpret_cast<uint16_t &>(out))
+      : "r"(reinterpret_cast<const uint32_t &>(in)),
+        "l"(reinterpret_cast<const uint64_t &>(scale)));
+}
+
+__device__ __forceinline__ void mul_cvt_2x(fp8e5m2x2 &out, const fp16x2 &in, const floatx2 &scale) {
+  asm volatile(
+      "{\n"
+      ".reg.b64 val_pair_before; \n\t"
+      ".reg.b64 val_pair_after; \n\t"
+      ".reg.b32 val1; \n\t"
+      ".reg.b32 val2; \n\t"
+      ".reg.b16 val1_fp16; \n\t"
+      ".reg.b16 val2_fp16; \n\t"
+      "mov.b32 {val1_fp16, val2_fp16} , %1; \n\t"
+      "cvt.f32.f16 val1, val1_fp16; \n\t"
+      "cvt.f32.f16 val2, val2_fp16; \n\t"
+      "mov.b64 val_pair_before, {val1,val2}; \n\t"
+      "mul.f32x2 val_pair_after, val_pair_before, %2; \n\t"
+      "mov.b64 {val2,val1}, val_pair_after; \n\t"
+      "cvt.rn.satfinite.e5m2x2.f32 %0, val1, val2; \n\t"
+      "}"
+      : "=h"(reinterpret_cast<uint16_t &>(out))
+      : "r"(reinterpret_cast<const uint32_t &>(in)),
+        "l"(reinterpret_cast<const uint64_t &>(scale)));
+}
+
+__device__ __forceinline__ void abs_max_2x(bf16x2 &dst, const bf16x2 &p1, const bf16x2 &p2) {
+  asm volatile("max.xorsign.abs.bf16x2 %0, %1, %2;"
+               : "=r"(reinterpret_cast<uint32_t &>(dst))
+               : "r"(reinterpret_cast<const uint32_t &>(p1)),
+                 "r"(reinterpret_cast<const uint32_t &>(p2)));
+}
+
+__device__ __forceinline__ void abs_max_2x(fp16x2 &dst, const fp16x2 &p1, const fp16x2 &p2) {
+  asm volatile("max.xorsign.abs.f16x2 %0, %1, %2;"
+               : "=r"(reinterpret_cast<uint32_t &>(dst))
+               : "r"(reinterpret_cast<const uint32_t &>(p1)),
+                 "r"(reinterpret_cast<const uint32_t &>(p2)));
+}
+
 #endif  // #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)
 
 }  // namespace ptx
diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh
index e6a54108e..3f5bcc975 100644
--- a/transformer_engine/common/utils.cuh
+++ b/transformer_engine/common/utils.cuh
@@ -905,10 +905,7 @@ using fp8e4m3 = __nv_fp8_e4m3;
 using fp8e5m2 = __nv_fp8_e5m2;
 using e8m0_t = uint8_t;
 
-constexpr uint32_t FP32_MANTISSA_BITS = 23;
-constexpr uint32_t FP32_EXPONENT_BIAS = 127;
-
-enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENTIONAL = 2 };
+enum ScalingType { ROWWISE = 0, COLWISE = 1, BIDIMENSIONAL = 2 };
 
 template <typename T>
 struct Numeric_Traits;
@@ -934,44 +931,6 @@ struct Quantized_Limits {
   static constexpr float emax_rcp = 1.0 / emax;
 };
 
-__device__ __forceinline__ e8m0_t float_to_e8m0(float val) {
-  // TODO: nan/inf needs to be set for any value
-  // of nan/inf in input not just amax.
-  if (isnan(val)) {
-    return 0xFF;
-  }
-  if (isinf(val)) {
-    return 0xFE;
-  }
-#if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \
-     (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL)))
-  uint16_t out;
-  asm volatile(
-      "{\n"
-      "cvt.rp.satfinite.ue8m0x2.f32 %0, 0.0, %1;\n"
-      "}"
-      : "=h"(out)
-      : "f"(val));
-  return *reinterpret_cast<e8m0_t *>(&out);
-#else
-  if (val == 0.0f) {
-    return 0x00;
-  }
-  uint32_t val_u32 = *reinterpret_cast<uint32_t *>(&val);
-  e8m0_t exponent = (val_u32 >> FP32_MANTISSA_BITS);
-  uint32_t mantissa = val_u32 & 0x7FFFFF;
-  // Round up exponent and deal with satfinite. 
- if ((mantissa > 0 && exponent != 0xFE) && !(exponent == 0 && mantissa <= 0x400000)) { - ++exponent; - } - return exponent; -#endif -} - -__device__ __forceinline__ float exp2f_rcp(e8m0_t biased_exp) { - return (biased_exp == 0) ? 1 : exp2f(FP32_EXPONENT_BIAS - static_cast(biased_exp)); -} - } // namespace transformer_engine #endif // TRANSFORMER_ENGINE_COMMON_UTILS_CUH_ From e0204fbbe0ee048d8372e1f7fe17adcf6da3fdc9 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 22 Jul 2025 10:19:08 -0700 Subject: [PATCH 008/153] Refactor `te.ops` (#1951) * Refactor _OperationFuserAutogradFunction.forward to use less parameters Signed-off-by: Jan Bielak (cherry picked from commit f8f59b1bb184e89468058521df4cfff029ad909c) * Rename `BackwardBiasActivation` to `BackwardActivationBias` Signed-off-by: Jan Bielak (cherry picked from commit 397c58fc296f801fe4ad600aadc2daff3b78be45) * Use forward operation order in backward fused operations Signed-off-by: Jan Bielak (cherry picked from commit 2d37a9385069b066e6cdeff3eb9173c2079cb791) * Rename `prev_op_grad_input_quantizer` to `prev_op_grad_output_quantizer` Signed-off-by: Jan Bielak (cherry picked from commit d7ab5dfb23e216866f7f4fc4d7a99f625d329f1e) * Make OperationFuser persistent Signed-off-by: Jan Bielak (cherry picked from commit 77984d9715d31e87519dc6ea1e02c483a81355a7) * Distribute extra inputs to and collect extra outputs from multiple module groups in Sequential Signed-off-by: Jan Bielak (cherry picked from commit 0716aaad542e59f2c1ac4620167965a0334bbf71) * Take requires_grad into account when fusing operations Signed-off-by: Jan Bielak * Change get_quantizer to return None if no quantization recipe is used Signed-off-by: Jan Bielak * Refactor pre_first_forward Signed-off-by: Jan Bielak * Fix for failing `test_make_graphed_callables[fp8_recipe0-*-True-*-linear_op]` Signed-off-by: Jan Bielak * Fix linting errors Signed-off-by: Jan Bielak * Apply suggestions from code review Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Jan Bielak * Fix fp8 meta tensors in CUDA Graph capture Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix failing distributed userbuffers tests Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 70 +++++- tests/pytorch/test_recipe.py | 72 +++--- transformer_engine/pytorch/graph.py | 75 +++--- .../pytorch/ops/basic/activation.py | 26 +- .../pytorch/ops/basic/add_in_place.py | 2 +- .../pytorch/ops/basic/all_gather.py | 2 +- .../pytorch/ops/basic/all_reduce.py | 2 +- .../pytorch/ops/basic/basic_linear.py | 126 ++++------ transformer_engine/pytorch/ops/basic/bias.py | 23 +- .../pytorch/ops/basic/identity.py | 2 +- .../pytorch/ops/basic/l2normalization.py | 2 +- .../pytorch/ops/basic/layer_norm.py | 22 +- .../pytorch/ops/basic/make_extra_output.py | 2 +- .../pytorch/ops/basic/quantize.py | 5 +- .../pytorch/ops/basic/reduce_scatter.py | 2 +- .../pytorch/ops/basic/reshape.py | 5 +- .../pytorch/ops/basic/rmsnorm.py | 22 +- .../pytorch/ops/fused/__init__.py | 6 +- ...ivation.py => backward_activation_bias.py} | 19 +- .../pytorch/ops/fused/backward_linear_add.py | 6 +- .../fused/forward_linear_bias_activation.py | 42 ++-- .../ops/fused/forward_linear_bias_add.py | 41 ++-- .../ops/fused/userbuffers_backward_linear.py 
| 13 +- .../ops/fused/userbuffers_forward_linear.py | 40 ++- transformer_engine/pytorch/ops/fuser.py | 165 ++++++++----- transformer_engine/pytorch/ops/op.py | 232 +++++++----------- transformer_engine/pytorch/ops/sequential.py | 51 ++-- 27 files changed, 516 insertions(+), 559 deletions(-) rename transformer_engine/pytorch/ops/fused/{backward_bias_activation.py => backward_activation_bias.py} (87%) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 778ae687f..10e9dc5e7 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -20,7 +20,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( - BackwardBiasActivation, + BackwardActivationBias, BackwardLinearAdd, ForwardLinearBiasActivation, ForwardLinearBiasAdd, @@ -262,6 +262,65 @@ def test_module_groups(self) -> None: model(torch.zeros(1)) assert len(model._module_groups) == 6 + def test_extra_tensors(self, size: int = 16) -> None: + """Check that extra inputs are distributed properly between module groups + and that extra outputs are properly collected""" + + # Construct sequential container + bias = te_ops.Bias(size=size, device="cpu") + with torch.no_grad(): + bias.bias.copy_(torch.rand((size,))) + model = te_ops.Sequential( # | Inputs | Outputs + torch.nn.Identity(), # | x1 | x1 + te_ops.MakeExtraOutput(), # | x1 | x1 [x1] + bias, # | x1 | h1 (= x1 + b) + te_ops.MakeExtraOutput(), # | h1 | h1 [h1] + te_ops.AddInPlace(), # | h1 [x2] | x2 (= x2 + h1) + te_ops.MakeExtraOutput(), # | x2 | x2 [x2] + torch.nn.Identity(), # | x2 | x2 + bias, # | x2 | h2 (= x2 + b) + te_ops.AddInPlace(), # | h2 [x3] | x3 (= x3 + h2) + te_ops.MakeExtraOutput(), # | x3 | x3 [x3] + te_ops.AddInPlace(), # | x3 [x4] | x4 (= x4 + x3) + torch.nn.Identity(), # | x4 | x4 + te_ops.Identity(), # | x4 | x4 + te_ops.MakeExtraOutput(), # | x4 | x4 [x4] + te_ops.Identity(), # | x4 | x4 + ) + + # Create input tensors + x1 = torch.rand((size,)) + x2 = torch.rand((size,)) + x3 = torch.rand((size,)) + x4 = torch.rand((size,)) + + # Save original input tensor values + x1_orig = x1.clone() + x2_orig = x2.clone() + x3_orig = x3.clone() + x4_orig = x4.clone() + + # Run forward + ys = model(x1, x2, x3, x4) + + # Check whether outputs match (x4, x1, h1, x2, x3, x4) + assert len(ys) == 6 + assert ys[0].data_ptr() == x4.data_ptr() + assert ys[1].data_ptr() == x1.data_ptr() + assert ys[2].data_ptr() not in [x.data_ptr() for x in (x1, x2, x3, x4)] + assert ys[3].data_ptr() == x2.data_ptr() + assert ys[4].data_ptr() == x3.data_ptr() + assert ys[5].data_ptr() == x4.data_ptr() + + # Check whether tensors have correct values + b = bias.bias + h1 = ys[2] + torch.testing.assert_close(x1, x1_orig) + torch.testing.assert_close(h1, x1_orig + b) + torch.testing.assert_close(x2, x2_orig + h1) + torch.testing.assert_close(x3, x3_orig + x2 + b) + torch.testing.assert_close(x4, x4_orig + x3) + class TestFuser: """Tests for operation fusion infrastructure""" @@ -1870,7 +1929,7 @@ def test_forward_linear_bias_add( @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) - def test_backward_bias_activation( + def test_backward_activation_bias( self, *, activation: str, @@ -1879,7 +1938,7 @@ def test_backward_bias_activation( device: torch.device = "cuda", quantization: Optional[str], ) -> None: - 
"""Backward dbias + dact + quantize""" + """Backward dact + dbias + quantize""" # Tensor dimensions in_shape = list(out_shape) @@ -1938,7 +1997,7 @@ def test_backward_bias_activation( backward_ops = model._module_groups[0]._backward_ops if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]: assert len(backward_ops) == 2 - assert isinstance(backward_ops[0][0], BackwardBiasActivation) + assert isinstance(backward_ops[0][0], BackwardActivationBias) assert isinstance(backward_ops[1][0], te_ops.Quantize) else: assert len(backward_ops) == 3 @@ -2185,6 +2244,7 @@ def setup_class(cls) -> None: torch.manual_seed(seed) torch.cuda.manual_seed(seed) + @pytest.mark.parametrize("requires_grad", (False, True)) @pytest.mark.parametrize("bias", (False, True)) @pytest.mark.parametrize("normalization", ("LayerNorm", "RMSNorm")) @pytest.mark.parametrize("quantized_compute", (False, True)) @@ -2194,6 +2254,7 @@ def setup_class(cls) -> None: def test_layernorm_mlp( self, *, + requires_grad: bool, bias: bool, normalization: str, quantized_compute: bool, @@ -2234,6 +2295,7 @@ def test_layernorm_mlp( quantization=quantization, test_dtype=dtype, test_device=device, + requires_grad=requires_grad, ) _, dy_test = make_reference_and_test_tensors( in_shape, diff --git a/tests/pytorch/test_recipe.py b/tests/pytorch/test_recipe.py index 9a7228733..9a51c53e3 100644 --- a/tests/pytorch/test_recipe.py +++ b/tests/pytorch/test_recipe.py @@ -191,12 +191,6 @@ def test_fp8_scale_update_with_linear_fuser_op( amax_compute_algo=amax_compute_algo, ) - # Get FP8 meta tensors - with te.fp8_autocast(fp8_recipe=recipe): - x_fp8_meta = op.get_quantizer("forward", 0) - w_fp8_meta = op.get_quantizer("forward", 1) - dy_fp8_meta = op.get_quantizer("backward", 0) - # Perform training steps x_history = [] w_history = [] @@ -228,19 +222,30 @@ def test_fp8_scale_update_with_linear_fuser_op( y = op(x) y.backward(dy) - def check_amax_history( - fp8_meta: dict, - ref_amax_history: Iterable[float], - ) -> None: - """Check that amax history matches expected values""" - if len(ref_amax_history) > amax_history_len: - ref_amax_history = ref_amax_history[-amax_history_len:] + def check_metas( + test_scale: float, + test_amax_history: torch.Tensor, + ref_amax_history_list: list[float], + stage: str, + ): + """Check that meta tensors match expected values""" + + # Compute amax + if len(ref_amax_history_list) > amax_history_len: + ref_amax_history_list = ref_amax_history_list[-(amax_history_len + 1) :] ref_amax_history = torch.tensor( - ref_amax_history, + ref_amax_history_list, dtype=torch.float32, device=device, ) - test_amax_history = fp8_meta.amax_history[:, 0] + if amax_compute_algo == "max": + ref_amax = max(ref_amax_history_list) + elif amax_compute_algo == "most_recent": + ref_amax = ref_amax_history_list[-1] + else: + raise RuntimeError(f"{amax_compute_algo=} is not supported") + + # Compare amax history tols = dict(rtol=0, atol=0) torch.testing.assert_close( test_amax_history[-(step + 1) :], @@ -248,23 +253,6 @@ def check_amax_history( **tols, ) - def check_scale( - quantizer: Float8Quantizer, - ref_amax_history: Iterable[float], - stage: str, - ): - """Check that scale and scale reciprocal match expected values""" - - # Compute amax - if len(ref_amax_history) > amax_history_len: - ref_amax_history = ref_amax_history[-(amax_history_len + 1) :] - if amax_compute_algo == "max": - ref_amax = max(ref_amax_history) - elif amax_compute_algo == "most_recent": - ref_amax = ref_amax_history[-1] - else: - raise 
RuntimeError(f"{amax_compute_algo=} is not supported") - # Compute scale max_val = { "forward": 448.0, @@ -272,16 +260,26 @@ def check_scale( }[stage] ref_scale = (max_val / ref_amax) / (2**margin) - # Check values in FP8 meta tensors + # Compare scale torch.testing.assert_close( - quantizer.scale.item(), + test_scale, ref_scale, ) + # Get scaling factors + x_test_scale = op.get_quantizer("forward", 0).scale.item() + w_test_scale = op.get_quantizer("forward", 1).scale.item() + dy_test_scale = op.get_quantizer("backward", 0).scale.item() + + # Get amax histories + x_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 0] + w_test_history = op._fp8_metas["forward"][forward_key].amax_history[:, 1] + dy_test_history = op._fp8_metas["backward"][backward_key].amax_history[:, 0] + # Check that results match expected values - check_scale(x_fp8_meta, x_history, "forward") - check_scale(w_fp8_meta, w_history, "forward") - check_scale(dy_fp8_meta, dy_history, "backward") + check_metas(x_test_scale, x_test_history, x_history, "forward") + check_metas(w_test_scale, w_test_history, w_history, "forward") + check_metas(dy_test_scale, dy_test_history, dy_history, "backward") @pytest.mark.parametrize("amax_case", ["zero", "tiny", "normal", "inf", "nan"]) @pytest.mark.parametrize("fused_update", [True, False], ids=["fused", "non-fused"]) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 95f39fc92..6152d3aa7 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -21,6 +21,8 @@ from .distributed import get_all_rng_states, graph_safe_rng_available from .module.base import TransformerEngineBaseModule from .ops.op import BasicOperation +from .ops import Sequential +from .ops.fuser import OperationFuser from .utils import make_weak_ref __all__ = ["make_graphed_callables"] @@ -44,7 +46,7 @@ def set_capture_end() -> None: _IS_GRAPH_CAPTURING = False -def is_graph_capturing() -> None: +def is_graph_capturing() -> bool: """Return whether within `make_graphed_callables`.""" return _IS_GRAPH_CAPTURING @@ -338,6 +340,16 @@ def _make_graphed_callables( def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument if isinstance(module, TransformerEngineBaseModule): visited_te_modules.add(module) + # If forward is called on a BasicOperation directly the hook will run + elif isinstance(module, BasicOperation): + visited_te_modules.add(module) + # If forward is called on a te.ops.Sequential it is not called on its constituent ops + elif isinstance(module, Sequential): + assert module._module_groups is not None, "Should have been initialized by warmup" + for module_group in module._module_groups: + if isinstance(module_group, OperationFuser): + for basic_op in module_group._basic_ops: + visited_te_modules.add(basic_op) # Run warmup and do the above filtering. with torch.cuda.stream(torch.cuda.Stream()): @@ -674,31 +686,35 @@ def new_fwd(*user_args, **user_kwargs): # run the graph, otherwise run the original forward method if func.training == graph_training_state: # Set the FP8 group from global amax reduction. 
- for m in func.modules(): - if ( - isinstance(m, TransformerEngineBaseModule) - and FP8GlobalStateManager.is_fp8_enabled() - ): + if FP8GlobalStateManager.is_fp8_enabled(): + fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() + for m in func.modules(): if m not in visited_te_modules: # Only Set the FP8 meta for the modules included by forward continue - fp8_recipe = FP8GlobalStateManager.get_fp8_recipe() - from transformer_engine.pytorch.attention.dot_product_attention import ( - DotProductAttention, - ) - - if ( - isinstance(m, DotProductAttention) - and not fp8_recipe.fp8_mha - and not fp8_recipe.fp8_dpa - ): - # Don't need to update FP8 meta for non-FP8 DPA - continue - m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - m.fp8_meta, - ) + if isinstance(m, TransformerEngineBaseModule): + from transformer_engine.pytorch.attention.dot_product_attention import ( + DotProductAttention, + ) + + if ( + isinstance(m, DotProductAttention) + and not fp8_recipe.fp8_mha + and not fp8_recipe.fp8_dpa + ): + # Don't need to update FP8 meta for non-FP8 DPA + continue + m.fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() + m.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + m.fp8_meta, + ) + elif isinstance(m, BasicOperation): + for mode in ("forward", "backward"): + if m.num_quantizers(mode): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + m._fp8_metas[mode], + ) return graphed(*user_args, **user_kwargs) return orig_fwd(*user_args, **user_kwargs) @@ -721,7 +737,7 @@ def new_fwd(*user_args, **user_kwargs): def save_fp8_tensors( modules: Iterable[torch.nn.Module], - fp8_recipe: Recipe, + fp8_recipe: Optional[Recipe], ) -> Optional[List[Any]]: """ Returns the FP8 tensors for all modules @@ -740,7 +756,7 @@ def save_fp8_tensors( m.adjust_amax_history_length(fp8_recipe.amax_history_len) module_tensors = m.get_fp8_meta_tensors() elif isinstance(m, BasicOperation): - m.pre_first_forward(recipe=fp8_recipe) + m.reset_recipe_type(recipe=fp8_recipe) module_tensors = m._save_fp8_metas() fp8_tensors.append(module_tensors) return fp8_tensors @@ -777,7 +793,7 @@ def make_graphed_callables( sample_kwargs: Optional[SingleOrTuple[Dict[str, Any]]] = None, fp8_enabled: bool = False, fp8_calibrating: bool = False, - fp8_recipe: Optional[DelayedScaling] = None, + fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, fp8_weight_caching: bool = False, _order: Optional[List[int]] = None, @@ -828,7 +844,7 @@ def make_graphed_callables( data of fp8 tensors even when executing without fp8 enabled. This is useful for saving an inference ready fp8 checkpoint while training using a higher precision. - fp8_recipe: recipe.DelayedScaling, default = `None` + fp8_recipe: Recipe, default = `None` recipe used for FP8 training. fp8_group: torch._C._distributed_c10d.ProcessGroup, default = `None` distributed group over which amaxes for the fp8 tensors @@ -844,7 +860,10 @@ def make_graphed_callables( """ set_capture_start() - fp8_recipe = get_default_fp8_recipe() if fp8_recipe is None else fp8_recipe + if fp8_enabled and fp8_recipe is None: + fp8_recipe = get_default_fp8_recipe() + elif not fp8_enabled: + fp8_recipe = None # Handle single module. 
just_one_callable = False diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index c077829a3..f1b59170e 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -11,7 +11,6 @@ import torch import transformer_engine_torch as tex -from ...fp8 import FP8GlobalStateManager from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -71,7 +70,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: @@ -87,14 +86,8 @@ def op_forward( # Check input tensor x = maybe_dequantize(input_.contiguous(), dtype) - # Check if quantized compute is enabled - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - quantizer = None - if with_quantized_compute: - quantizer = next_op_input_quantizer - # Launch kernel - y = self._activation_forward_impl(x, quantizer) + y = self._activation_forward_impl(x, next_op_input_quantizer) # Quantize input to FP8 before caching if needed if self.cache_quantized_input: @@ -103,10 +96,10 @@ def op_forward( x = input_quantizer(x) # Save state for backward pass - ctx.save_for_backward(x) - ctx.with_quantized_compute = with_quantized_compute - ctx.dtype = dtype - ctx.prev_op_grad_input_quantizer = prev_op_grad_input_quantizer + if ctx.requires_grad: + ctx.save_for_backward(x) + ctx.dtype = dtype + ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer return y @@ -125,13 +118,8 @@ def op_backward( # Check grad output tensor dy = maybe_dequantize(grad_output.contiguous(), x.dtype) - # Check if quantized compute is enabled - quantizer = None - if ctx.with_quantized_compute: - quantizer = ctx.prev_op_grad_input_quantizer - # Launch kernel - dx = self._activation_backward_impl(dy, x, quantizer) + dx = self._activation_backward_impl(dy, x, ctx.prev_op_grad_output_quantizer) # Clear input tensor if possible clear_tensor_data(x) diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_in_place.py index e1493d3c7..3a7f1843b 100644 --- a/transformer_engine/pytorch/ops/basic/add_in_place.py +++ b/transformer_engine/pytorch/ops/basic/add_in_place.py @@ -59,7 +59,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: diff --git a/transformer_engine/pytorch/ops/basic/all_gather.py b/transformer_engine/pytorch/ops/basic/all_gather.py index 0df165a06..bcd3c1417 100644 --- a/transformer_engine/pytorch/ops/basic/all_gather.py +++ b/transformer_engine/pytorch/ops/basic/all_gather.py @@ -40,7 +40,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: out: torch.Tensor diff --git a/transformer_engine/pytorch/ops/basic/all_reduce.py b/transformer_engine/pytorch/ops/basic/all_reduce.py index af928dd24..d8c1eb006 100644 --- 
a/transformer_engine/pytorch/ops/basic/all_reduce.py +++ b/transformer_engine/pytorch/ops/basic/all_reduce.py @@ -42,7 +42,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 59fc09607..5f5b38184 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -22,7 +22,6 @@ from ...fp8 import FP8GlobalStateManager, Recipe from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer -from ...tensor.float8_tensor import Float8Quantizer, Float8CurrentScalingQuantizer from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase @@ -291,6 +290,14 @@ def reset_parameters(self) -> None: # Quantize if needed if self._with_quantized_weight: quantizer = self.get_quantizer("forward", 1) + if quantizer is None: + raise RuntimeError( + "Tried to quantize weight with deferred initialization " + "due to meta device, but no quantizer was available. " + "This is most likely because fp8_model_init was called " + "with enabled=True and recipe=None, instead of providing " + "a recipe to use for quantization." + ) quantizer.set_usage( rowwise=True, columnwise=torch.is_grad_enabled(), @@ -303,62 +310,19 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_first_forward( - self, - *, - recipe: Optional[Recipe], - ) -> None: - super().pre_first_forward(recipe=recipe) - - # Initialize weights if needed - weight = self.weight - if weight.device.type == "meta": + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() + if self.weight.device.type == "meta": self.reset_parameters() - weight = self.weight - # Configure quantizers - if recipe is not None: - input_quantizer = self.get_quantizer("forward", 0) - weight_quantizer = self.get_quantizer("forward", 1) - grad_output_quantizer = self.get_quantizer("backward", 0) + def reset_recipe_type(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_type(recipe=recipe) - # Specify required tensor formats - input_quantizer.internal = True - weight_quantizer.internal = True - grad_output_quantizer.internal = True - - # Recipe-specific configuration - if recipe.float8_current_scaling(): - if any( - not isinstance(q, Float8CurrentScalingQuantizer) - for q in (input_quantizer, weight_quantizer, grad_output_quantizer) - ): - raise RuntimeError( - "FP8 current-scaling recipe is enabled, " - f"but input quantizer is {input_quantizer.__class__.__name__}, " - f"weight quantizer is {weight_quantizer.__class__.__name__}, " - f"grad output quantizer is {grad_output_quantizer.__class__.__name__}" - ) - input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon - grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale - grad_output_quantizer.amax_epsilon_scales = 
recipe.fp8_quant_fwd_inp.amax_epsilon - if self.sequence_parallel and self.tensor_parallel_mode == "column": - input_quantizer.with_amax_reduction = True - input_quantizer.amax_reduction_group = self.tensor_parallel_group - if self.sequence_parallel and self.tensor_parallel_mode == "row": - grad_output_quantizer.with_amax_reduction = True - grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group - - # Make sure weight tensor has correct quantizer - # Note: Quantizer might have changed if quantization - # recipe changed - if isinstance( - weight_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ) and isinstance(weight, Float8TensorBase): - weight._quantizer = weight_quantizer + if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters(): + # Make quantizers use internal tensors + self.get_input_quantizer().internal = True + self.get_grad_output_quantizer().internal = True + self.get_quantizer("forward", 1).internal = True @staticmethod def _functional_forward( @@ -894,7 +858,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: @@ -903,27 +867,34 @@ def op_forward( weight_requires_grad = ctx.requires_grad and self.weight.requires_grad # FP8 metadata + input_quantizer = self.get_quantizer("forward", 0) + weight_quantizer = self.get_quantizer("forward", 1) + output_quantizer = next_op_input_quantizer + grad_output_quantizer = self.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None - output_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None if with_quantized_compute: - - # Get quantizers - input_quantizer = self.get_quantizer("forward", 0) - weight_quantizer = self.get_quantizer("forward", 1) - output_quantizer = next_op_input_quantizer - grad_output_quantizer = self.get_quantizer("backward", 0) - grad_input_quantizer = prev_op_grad_input_quantizer - # Configure quantizers # Note: We cache the quantized input for backward pass, # but discard the quantized weights. 
input_quantizer.set_usage(rowwise=True, columnwise=weight_requires_grad) weight_quantizer.set_usage(rowwise=True, columnwise=False) + recipe = FP8GlobalStateManager.get_fp8_recipe() + if recipe.float8_current_scaling(): + input_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + input_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + weight_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + weight_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + grad_output_quantizer.force_pow_2_scales = recipe.fp8_quant_fwd_inp.power_2_scale + grad_output_quantizer.amax_epsilon_scales = recipe.fp8_quant_fwd_inp.amax_epsilon + if self.sequence_parallel and self.tensor_parallel_mode == "column": + input_quantizer.with_amax_reduction = True + input_quantizer.amax_reduction_group = self.tensor_parallel_group + if self.sequence_parallel and self.tensor_parallel_mode == "row": + grad_output_quantizer.with_amax_reduction = True + grad_output_quantizer.amax_reduction_group = self.tensor_parallel_group + # Get autocast dtype if needed if torch.is_autocast_enabled(): dtype = torch.get_autocast_dtype("cuda") @@ -947,15 +918,16 @@ def op_forward( ) # Save state for backward pass - ctx.save_for_backward(x_local, w) - ctx.with_quantized_compute = with_quantized_compute - ctx.input_quantizer = input_quantizer - ctx.weight_quantizer = weight_quantizer - ctx.grad_output_quantizer = grad_output_quantizer - ctx.grad_input_quantizer = grad_input_quantizer - ctx.dtype = dtype - ctx.input_requires_grad = input_requires_grad - ctx.weight_requires_grad = weight_requires_grad + if ctx.requires_grad: + ctx.save_for_backward(x_local, w) + ctx.with_quantized_compute = with_quantized_compute + ctx.input_quantizer = input_quantizer + ctx.weight_quantizer = weight_quantizer + ctx.grad_output_quantizer = grad_output_quantizer + ctx.grad_input_quantizer = grad_input_quantizer + ctx.dtype = dtype + ctx.input_requires_grad = input_requires_grad + ctx.weight_requires_grad = weight_requires_grad return output diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index a985601e2..4c107b888 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -18,7 +18,6 @@ canonicalize_device, canonicalize_dtype, ) -from ...fp8 import FP8GlobalStateManager from ...tensor import Quantizer @@ -114,8 +113,8 @@ def reset_parameters(self) -> None: bias = torch.nn.Parameter(bias) self.bias = bias - def pre_first_forward(self, *args, **kwargs) -> None: - super().pre_first_forward(*args, **kwargs) + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() if self.bias.device.type == "meta": self.reset_parameters() @@ -123,24 +122,14 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: x = input_ b = self.bias.view([1] * (x.dim() - 1) + [self.local_size]) - # Check if backward pass is needed - requires_grad = ctx.requires_grad - - # Check if previous op quantizes its output's gradient - grad_input_quantizer = None - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - grad_input_quantizer = prev_op_grad_input_quantizer - - if requires_grad: - ctx.with_quantized_compute = with_quantized_compute - ctx.grad_input_quantizer = 
grad_input_quantizer + if ctx.requires_grad: + ctx.grad_input_quantizer = prev_op_grad_output_quantizer return x + b @@ -152,7 +141,7 @@ def op_backward( dy = grad_output if dy.dim() > 1: quantizer = ctx.grad_input_quantizer - if ctx.with_quantized_compute and quantizer is not None: + if quantizer is not None: db, dy = tex.bgrad_quantize(dy, quantizer) else: db = dy.sum(tuple(range(dy.dim() - 1))) diff --git a/transformer_engine/pytorch/ops/basic/identity.py b/transformer_engine/pytorch/ops/basic/identity.py index 3161e77c7..788b3aac8 100644 --- a/transformer_engine/pytorch/ops/basic/identity.py +++ b/transformer_engine/pytorch/ops/basic/identity.py @@ -23,7 +23,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: return input_ diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index d8196c1bd..1e72475ad 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -74,7 +74,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: # Use input directly - torch.compile can handle multi-dimensional tensors diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 8286932f0..3d8862e99 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -13,7 +13,6 @@ import torch from transformer_engine_torch import layernorm_bwd, layernorm_fwd -from ...fp8 import FP8GlobalStateManager from ...constants import TE_DType from ...utils import ( canonicalize_device, @@ -168,8 +167,8 @@ def reset_parameters(self) -> None: self.weight = weight self.bias = bias - def pre_first_forward(self, *args, **kwargs) -> None: - super().pre_first_forward(*args, **kwargs) + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() if self.weight.device.type == "meta" or self.bias.device.type == "meta": self.reset_parameters() @@ -177,7 +176,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: if is_in_onnx_export_mode(): @@ -200,31 +199,22 @@ def op_forward( w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) b = maybe_dequantize(self.bias, dtype).view((inner_dim,)) - # Check if backward pass is needed - requires_grad = ctx.requires_grad - - # Check if output is quantized - output_quantizer = None - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - output_quantizer = next_op_input_quantizer - # Compute layer norm - sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"] y, means, rstdevs = layernorm_fwd( x, w, b, self.eps, None, - output_quantizer, + next_op_input_quantizer, TE_DType[dtype], sm_margin, self.zero_centered_gamma, ) # Save state for backward pass - if requires_grad: + if ctx.requires_grad: 
ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index 81b581ae2..f64b609de 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -59,7 +59,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: diff --git a/transformer_engine/pytorch/ops/basic/quantize.py b/transformer_engine/pytorch/ops/basic/quantize.py index 005e9fd8d..dcfc3c4f7 100644 --- a/transformer_engine/pytorch/ops/basic/quantize.py +++ b/transformer_engine/pytorch/ops/basic/quantize.py @@ -50,7 +50,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: @@ -64,7 +64,8 @@ def op_forward( if quantize_forward and not is_quantized_tensor(out): out = self.get_quantizer("forward", 0)(out) - ctx.quantize_backward = quantize_backward + if ctx.requires_grad: + ctx.quantize_backward = quantize_backward return out def op_backward( diff --git a/transformer_engine/pytorch/ops/basic/reduce_scatter.py b/transformer_engine/pytorch/ops/basic/reduce_scatter.py index 1238b0879..e0017853f 100644 --- a/transformer_engine/pytorch/ops/basic/reduce_scatter.py +++ b/transformer_engine/pytorch/ops/basic/reduce_scatter.py @@ -40,7 +40,7 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: diff --git a/transformer_engine/pytorch/ops/basic/reshape.py b/transformer_engine/pytorch/ops/basic/reshape.py index 8d8b75ff0..50af9fcff 100644 --- a/transformer_engine/pytorch/ops/basic/reshape.py +++ b/transformer_engine/pytorch/ops/basic/reshape.py @@ -38,10 +38,11 @@ def op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: - ctx.input_shape = input_.size() + if ctx.requires_grad: + ctx.input_shape = input_.size() return input_.reshape(*self._shape) def op_backward( diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 33a83daf5..42d3fc101 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -13,7 +13,6 @@ import torch from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd -from ...fp8 import FP8GlobalStateManager from ...constants import TE_DType from ...utils import ( canonicalize_device, @@ -151,8 +150,8 @@ def reset_parameters(self) -> None: weight = torch.nn.Parameter(weight) self.weight = weight - def pre_first_forward(self, *args, **kwargs) -> None: - super().pre_first_forward(*args, **kwargs) + def pre_first_fuser_forward(self) -> None: + super().pre_first_fuser_forward() if self.weight.device.type == "meta": self.reset_parameters() @@ -160,7 +159,7 @@ def 
op_forward( self, ctx: OperationContext, input_: torch.Tensor, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: if is_in_onnx_export_mode(): @@ -182,30 +181,21 @@ def op_forward( x = maybe_dequantize(input_.contiguous(), dtype).view((-1, inner_dim)) w = maybe_dequantize(self.weight, dtype).view((inner_dim,)) - # Check if backward pass is needed - requires_grad = ctx.requires_grad - - # Check if output is quantized - output_quantizer = None - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - if with_quantized_compute: - output_quantizer = next_op_input_quantizer - # Compute RMSNorm - sm_margin = self._sm_margins["forward" if requires_grad else "inference"] + sm_margin = self._sm_margins["forward" if ctx.requires_grad else "inference"] y, _, rstdevs = rmsnorm_fwd( x, w, self.eps, None, - output_quantizer, + next_op_input_quantizer, TE_DType[dtype], sm_margin, self.zero_centered_gamma, ) # Save state for backward pass - if requires_grad: + if ctx.requires_grad: ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 7e15f38cf..3ee23dc7f 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -4,9 +4,9 @@ """Compound tensor operation supported by the operation fuser.""" -from .backward_bias_activation import ( - BackwardBiasActivation, - fuse_backward_bias_activation, +from .backward_activation_bias import ( + BackwardActivationBias, + fuse_backward_activation_bias, ) from .backward_linear_add import ( BackwardLinearAdd, diff --git a/transformer_engine/pytorch/ops/fused/backward_bias_activation.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py similarity index 87% rename from transformer_engine/pytorch/ops/fused/backward_bias_activation.py rename to transformer_engine/pytorch/ops/fused/backward_activation_bias.py index f4b7b9ec3..bf3ff8ca6 100644 --- a/transformer_engine/pytorch/ops/fused/backward_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Fused backward dbias + dact + quantize.""" +"""Fused backward dact + dbias + quantize.""" from __future__ import annotations from typing import Optional @@ -29,8 +29,8 @@ _fusible_activations = tuple(_fused_activations.keys()) -class BackwardBiasActivation(FusedOperation): - """Fused backward dbias + dact + quantize +class BackwardActivationBias(FusedOperation): + """Fused backward dact + dbias + quantize Uses the next operation's input quantizer. 
@@ -66,15 +66,10 @@ def fuser_backward( dy = maybe_dequantize(grad_output.contiguous(), act_input.dtype) # Get previous op quantizer - if not bias_op_ctx.with_quantized_compute: - raise RuntimeError( - "BackwardBiasActivation requires quantized compute, " - "but Bias context has it disabled" - ) quantizer = bias_op_ctx.grad_input_quantizer if quantizer is None: raise RuntimeError( - "BackwardBiasActivation requires previous op's grad output quantizer, " + "BackwardActivationBias requires previous op's grad output quantizer, " "but Bias context has no quantizer" ) @@ -87,11 +82,11 @@ def fuser_backward( return dx, [(), (db,)], [(), ()] -def fuse_backward_bias_activation( +def fuse_backward_activation_bias( ops: list[tuple[FusibleOperation, list[int]]], recipe: Optional[Recipe], ) -> list[tuple[FusibleOperation, list[int]]]: - """Fused backward dbias + dact + quantize + """Fused backward dact + dbias + quantize Parameters ---------- @@ -138,7 +133,7 @@ def fuse_backward_bias_activation( ops = ops[1:] # Replace window with fused op - op = BackwardBiasActivation( + op = BackwardActivationBias( activation=window[0][0], bias=window[1][0], ) diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 54ddfaa5c..286503419 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -29,10 +29,10 @@ class BackwardLinearAdd(FusedOperation): def __init__( self, *, - linear: BasicLinear, backward_add: MakeExtraOutput, + linear: BasicLinear, ) -> None: - super().__init__((linear, backward_add)) + super().__init__((backward_add, linear)) def fuser_backward( self, @@ -47,7 +47,7 @@ def fuser_backward( ]: # Get basic operations - linear_op = self.basic_ops[0] + linear_op = self.basic_ops[1] linear_op_ctx = basic_op_ctxs[0] # Saved tensors from forward pass diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index 5d1223bd8..b87b12f84 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -59,7 +59,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -89,18 +89,12 @@ def fuser_forward( weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad # FP8 metadata + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + output_quantizer = next_op_input_quantizer + grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None - output_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None - if with_quantized_compute: - input_quantizer = linear_op.get_quantizer("forward", 0) - weight_quantizer = linear_op.get_quantizer("forward", 1) - output_quantizer = next_op_input_quantizer - grad_output_quantizer = linear_op.get_quantizer("backward", 0) - grad_input_quantizer = 
prev_op_grad_input_quantizer # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -126,18 +120,18 @@ def fuser_forward( ) # Save state for backward pass - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute - linear_op_ctx.input_quantizer = input_quantizer - linear_op_ctx.weight_quantizer = weight_quantizer - linear_op_ctx.grad_output_quantizer = grad_output_quantizer - linear_op_ctx.grad_input_quantizer = grad_input_quantizer - linear_op_ctx.dtype = dtype - linear_op_ctx.input_requires_grad = input_requires_grad - linear_op_ctx.weight_requires_grad = weight_requires_grad - if bias_op is not None: - bias_op_ctx.with_quantized_compute = with_quantized_compute - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() + if linear_op_ctx.requires_grad: + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None and bias_op_ctx.requires_grad: + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 5055bc60a..608fff01f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -57,7 +57,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -83,17 +83,12 @@ def fuser_forward( weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad # FP8 metadata - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) output_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None - if with_quantized_compute: - input_quantizer = linear_op.get_quantizer("forward", 0) - weight_quantizer = linear_op.get_quantizer("forward", 1) - grad_output_quantizer = linear_op.get_quantizer("backward", 0) - grad_input_quantizer = prev_op_grad_input_quantizer + grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -122,18 +117,18 @@ def fuser_forward( ) # Save state for backward pass - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute - linear_op_ctx.input_quantizer = input_quantizer - linear_op_ctx.weight_quantizer = weight_quantizer - linear_op_ctx.grad_output_quantizer = grad_output_quantizer - 
linear_op_ctx.grad_input_quantizer = grad_input_quantizer - linear_op_ctx.dtype = dtype - linear_op_ctx.input_requires_grad = input_requires_grad - linear_op_ctx.weight_requires_grad = weight_requires_grad - if bias_op is not None: - bias_op_ctx.with_quantized_compute = with_quantized_compute - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() + if linear_op_ctx.requires_grad: + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None and bias_op_ctx.requires_grad: + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 4fbc28482..b8acb02e3 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -48,14 +48,14 @@ def __init__( # Basic operations that comprise this fused operation op_idxs = {"linear": None, "bias": None, "reduce_scatter": None} ops = [] - if reduce_scatter is not None: - op_idxs["reduce_scatter"] = len(ops) - ops.append(reduce_scatter) + op_idxs["linear"] = len(ops) + ops.append(linear) if bias is not None: op_idxs["bias"] = len(ops) ops.append(bias) - op_idxs["linear"] = len(ops) - ops.append(linear) + if reduce_scatter is not None: + op_idxs["reduce_scatter"] = len(ops) + ops.append(reduce_scatter) # Initialize base class super().__init__(ops) @@ -495,7 +495,7 @@ def fuser_backward( # Get basic operations idx = self._op_idxs["linear"] linear_op = self.basic_ops[idx] - linear_op_ctx = basic_op_ctxs[idx] + linear_op_ctx = basic_op_ctxs[-1] bias_op = None if self._op_idxs["bias"] is not None: idx = self._op_idxs["bias"] @@ -556,6 +556,7 @@ def fuser_backward( grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) + grad_params.reverse() grad_extra_inputs = [() for _ in range(len(self.basic_ops))] return grad_input, grad_params, grad_extra_inputs diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 30d9cdaae..9316f3d79 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -282,7 +282,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -307,21 +307,17 @@ def fuser_forward( weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad # Quantization metadata + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + 
grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - input_quantizer = None - weight_quantizer = None - grad_output_quantizer = None - grad_input_quantizer = None if with_quantized_compute: recipe = FP8GlobalStateManager.get_fp8_recipe() if not any((recipe.delayed(), recipe.float8_current_scaling(), recipe.mxfp8())): raise RuntimeError( f"Unsupported recipe for Userbuffers ({recipe.__class__.__name__})" ) - input_quantizer = linear_op.get_quantizer("forward", 0) - weight_quantizer = linear_op.get_quantizer("forward", 1) - grad_output_quantizer = linear_op.get_quantizer("backward", 0) - grad_input_quantizer = prev_op_grad_input_quantizer # Get autocast dtype if needed if torch.is_autocast_enabled(): @@ -356,19 +352,19 @@ def fuser_forward( w = extra_outputs["weight"] # Save state for backward pass - linear_op_ctx.save_for_backward(x_local, w) - linear_op_ctx.with_quantized_compute = with_quantized_compute - linear_op_ctx.input_quantizer = input_quantizer - linear_op_ctx.weight_quantizer = weight_quantizer - linear_op_ctx.grad_output_quantizer = grad_output_quantizer - linear_op_ctx.grad_input_quantizer = grad_input_quantizer - linear_op_ctx.dtype = dtype - linear_op_ctx.input_dims = input_.size() - linear_op_ctx.input_requires_grad = input_requires_grad - linear_op_ctx.weight_requires_grad = weight_requires_grad - if bias_op is not None: - bias_op_ctx.with_quantized_compute = with_quantized_compute - bias_op_ctx.grad_input_quantizer = linear_op.get_grad_input_quantizer() + if linear_op_ctx.requires_grad: + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_dims = input_.size() + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + if bias_op is not None and bias_op_ctx.requires_grad: + bias_op_ctx.grad_input_quantizer = linear_op.get_grad_output_quantizer() return output, [() for _ in range(len(self.basic_ops))] diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 44ae2bb19..e618da349 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -5,19 +5,20 @@ """Manager class for a pipeline of fusible operations.""" from __future__ import annotations -from collections.abc import Callable +from collections.abc import Callable, Iterable from typing import Any, Optional +import itertools import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe, DelayedScaling from transformer_engine.pytorch.ops.op import ( BasicOperation, FusibleOperation, OperationContext, ) from transformer_engine.pytorch.ops.fused import ( - fuse_backward_bias_activation, + fuse_backward_activation_bias, fuse_backward_linear_add, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, @@ -68,8 +69,7 @@ def forward( input_: torch.Tensor, fuser: OperationFuser, basic_op_kwargs: list[dict[str, Any]], - is_grad_enabled: bool, - *params_and_extra_inputs: torch.nn.Parameter, + *params_and_extra_inputs: 
torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass @@ -83,8 +83,6 @@ def forward( Container for the pipeline of operations to run basic_op_kwargs: list of dict Keyword arguments to BasicOperation - is_grad_enabled: bool - Should context be saved for backward *params_and_extra_inputs: torch.Tensor Other tensor inputs to include in autograd graph. Consists of parameter tensors, followed by extra operation inputs. @@ -106,52 +104,53 @@ def forward( tensor.do_not_clear = True # Unflatten list of parameters and extra tensor inputs - extra_inputs = params_and_extra_inputs[-fuser._num_extra_inputs :] + extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :] basic_op_extra_inputs = [] for op in fuser._basic_ops: xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) basic_op_extra_inputs.append(xs) + # Get environment state + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None + is_grad_enabled = func_ctx is not None + + # Attempt to fuse operations if neccesary + fuser.maybe_fuse_ops(is_grad_enabled, recipe, input_, basic_op_extra_inputs) + # Apply forward ops x = input_ - requires_grad = is_grad_enabled and x.requires_grad - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() extra_outputs = [None] * fuser._num_basic_ops for op, basic_op_idxs in fuser._forward_ops: - # Check if backward op is required - if is_grad_enabled: - if not requires_grad: - requires_grad = any(param.requires_grad for param in op.parameters()) - if not requires_grad: - requires_grad = any(any(x.requires_grad for x in xs) for xs in extra_inputs) + # Set if backward op is required for idx in basic_op_idxs: - basic_op_ctxs[idx].requires_grad = requires_grad + basic_op_ctxs[idx].requires_grad = idx >= fuser.first_op_requiring_backward # Forward op extra_inputs = [basic_op_extra_inputs[idx] for idx in basic_op_idxs] prev_op_idx = basic_op_idxs[0] - 1 prev_op = fuser._basic_ops[prev_op_idx] if prev_op_idx >= 0 else None - prev_op_grad_input_quantizer = None - if prev_op is not None and with_quantized_compute: - prev_op_grad_input_quantizer = prev_op.get_grad_input_quantizer() + prev_op_grad_output_quantizer = None + if prev_op is not None: + prev_op_grad_output_quantizer = prev_op.get_grad_output_quantizer() next_op_idx = basic_op_idxs[-1] + 1 next_op = fuser._basic_ops[next_op_idx] if next_op_idx < fuser._num_basic_ops else None next_op_input_quantizer = None - if next_op is not None and with_quantized_compute: + if next_op is not None: next_op_input_quantizer = next_op.get_input_quantizer() x, fused_op_extra_outputs = op.fuser_forward( [basic_op_ctxs[idx] for idx in basic_op_idxs], x, basic_op_extra_inputs=extra_inputs, - prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, + prev_op_grad_output_quantizer=prev_op_grad_output_quantizer, next_op_input_quantizer=next_op_input_quantizer, basic_op_kwargs=[basic_op_kwargs[idx] for idx in basic_op_idxs], ) for idx, ys in zip(basic_op_idxs, fused_op_extra_outputs): for y in ys: - y.requires_grad_(requires_grad) + y.requires_grad_(idx >= fuser.first_op_requiring_backward) extra_outputs[idx] = ys # Flatten list of extra outputs @@ -192,13 +191,13 @@ def forward( func_ctx.backward_ops = fuser._backward_ops func_ctx.basic_ops = fuser._basic_ops func_ctx.basic_op_ctxs = basic_op_ctxs - func_ctx.basic_op_num_params = fuser._num_list_basic_op_params - func_ctx.num_extra_inputs = fuser._num_extra_inputs + 
func_ctx.basic_op_num_params = fuser._basic_op_num_params + func_ctx.num_extra_inputs = fuser.num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() func_ctx.with_quantized_compute = with_quantized_compute - x.requires_grad_(requires_grad) + x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) if extra_outputs_flat: return x, *extra_outputs_flat @@ -304,7 +303,6 @@ def backward( dx, # input_ None, # fuser None, # basic_op_kwargs - None, # is_grad_enabled *grad_params_flat, *grad_extra_inputs_flat, ) @@ -317,19 +315,12 @@ class OperationFuser: ---------- ops: list of FusibleOperation Pipeline of operations - fuse_ops: bool - Whether to attempt fusing operations - recipe: Recipe, optional - Quantization recipe to use when fusing and executing operations. - Note: certain fusions may depend on what kind of recipe is being used. """ def __init__( self, ops: list[FusibleOperation], - fuse_ops: bool, - recipe: Optional[Recipe], ) -> None: # Get list of basic operations @@ -343,25 +334,22 @@ def __init__( self._basic_ops: list[BasicOperation] = basic_ops # Number of extra tensor inputs - self._num_extra_inputs: int = sum(op.num_extra_inputs for op in basic_ops) + self._basic_op_num_extra_inputs: list[int] = list(op.num_extra_inputs for op in basic_ops) + self.num_extra_inputs: int = sum(self._basic_op_num_extra_inputs) - # Ops for forward and backward pass + # Ops for forward and backward pass, will be populated in fuse_ops self._forward_ops: list[tuple[FusibleOperation, list[int]]] self._backward_ops: list[tuple[FusibleOperation, list[int]]] - self._forward_ops = [(op, (idx,)) for idx, op in enumerate(self._basic_ops)] - self._backward_ops = list(reversed(self._forward_ops)) - # Flag for checking if this is the first iteration - self._is_first_forward = True - - # Fuse ops if needed - self.recipe = recipe - if fuse_ops: - self.fuse_ops() + # Cache and detect change of state relevant for fusing operations + self.recipe_type = None + self.first_op_requiring_backward = 0 + self._last_amax_history_len = 0 # Flatten list of parameters - self._basic_op_params = [param for op in self._basic_ops for param in op.parameters()] - self._num_list_basic_op_params = [sum(1 for _ in op.parameters()) for op in self._basic_ops] + self._basic_op_params = [list(op.parameters()) for op in self._basic_ops] + self._basic_op_num_params = list(map(len, self._basic_op_params)) + self._flat_basic_op_params = sum(self._basic_op_params, []) @classmethod def _fuse_forward_ops( @@ -384,13 +372,70 @@ def _fuse_backward_ops( """Attempt to fuse operations in backward pass""" ops = fuse_userbuffers_backward_linear(ops) ops = fuse_backward_linear_add(ops) - ops = fuse_backward_bias_activation(ops, recipe) + ops = fuse_backward_activation_bias(ops, recipe) return ops - def fuse_ops(self) -> None: - """Attempt to fuse operations""" - self._forward_ops = self._fuse_forward_ops(self._forward_ops, self.recipe) - self._backward_ops = self._fuse_backward_ops(self._backward_ops, self.recipe) + def maybe_fuse_ops( + self, + is_grad_enabled: bool, + recipe: Optional[Recipe], + input_: torch.Tensor, + extra_inputs: list[Iterable[torch.Tensor]], + ): + """Attempt to fuse operations if neccesary""" + + # Determine which basic ops require backward + if not is_grad_enabled: + first_op_requiring_backward = self._num_basic_ops + elif input_.requires_grad: + first_op_requiring_backward = 0 + else: + first_op_requiring_backward = 
self._num_basic_ops + for op_idx in range(self._num_basic_ops): + op_inputs = itertools.chain(self._basic_op_params[op_idx], extra_inputs[op_idx]) + if any(tensor.requires_grad for tensor in op_inputs): + first_op_requiring_backward = op_idx + break + + # Early exit if fusion parameters haven't changed + recipe_type = type(recipe) + fusion_params = (recipe_type, first_op_requiring_backward) + if fusion_params == (self.recipe_type, self.first_op_requiring_backward): + return + + # Initialize ops if recipe type has changed + if self.recipe_type != recipe_type: + # Check if this is the first iteration + if self.recipe_type is None: + for op in self._basic_ops: + op.pre_first_fuser_forward() + # Inform ops that the recipe type has changed + for op in self._basic_ops: + op.reset_recipe_type(recipe=recipe) + # Check if amax history was invalidated + elif isinstance(recipe, DelayedScaling): + if recipe.amax_history_len != self._last_amax_history_len: + raise RuntimeError( + "Detected change of amax history length. " + "Changing the length of amax history is currently not supported." + ) + + # Prepare basic op lists for fusions + forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)] + backward_ops = list(reversed(forward_ops[first_op_requiring_backward:])) + + # Fuse ops + self._forward_ops = self._fuse_forward_ops(forward_ops, recipe) + self._backward_ops = self._fuse_backward_ops(backward_ops, recipe) + + # Save current fusion params + self.recipe_type, self.first_op_requiring_backward = fusion_params + + # Save amax history length + if isinstance(recipe, DelayedScaling): + self._last_amax_history_len = recipe.amax_history_len + else: + self._last_amax_history_len = 0 def __call__( self, @@ -399,24 +444,17 @@ def __call__( basic_op_kwargs: Optional[list[dict[str, Any]]] = None, ) -> torch.Tensor | tuple[torch.Tensor, ...]: # Verify extra input count - if len(extra_inputs) != self._num_extra_inputs: + if len(extra_inputs) != self.num_extra_inputs: raise ValueError( - f"Expected {self._num_extra_inputs} extra inputs but got {len(extra_inputs)}" + f"Expected {self.num_extra_inputs} extra inputs but got {len(extra_inputs)}" ) - # Initialization before forward pass - if self._is_first_forward: - for op in self._basic_ops: - op.pre_first_forward(recipe=self.recipe) - self._is_first_forward = False - # Canonicalize op kwargs if basic_op_kwargs is None: basic_op_kwargs = [{}] * self._num_basic_ops # Fuser forward pass - is_grad_enabled = torch.is_grad_enabled() - if is_grad_enabled: + if torch.is_grad_enabled(): forward_func = _OperationFuserAutogradFunction.apply args = [] else: @@ -426,8 +464,7 @@ def __call__( input, self, basic_op_kwargs, - is_grad_enabled, - *self._basic_op_params, + *self._flat_basic_op_params, *extra_inputs, ) return forward_func(*args) diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 8490019e5..740fbd50d 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -15,9 +15,6 @@ from transformer_engine.common.recipe import Recipe from ..fp8 import ( - MXFP8BlockScalingRecipeState, - DelayedScalingRecipeState, - Float8BlockScalingRecipeState, FP8GlobalStateManager, RecipeState, fp8_autocast, @@ -65,18 +62,14 @@ class FusibleOperation(torch.nn.Module, metaclass=abc.ABCMeta): def is_fused_op(self) -> bool: """Whether this op is the fusion of one or more basic ops""" - def pre_first_forward( - self, - *, - recipe: Optional[Recipe], - ) -> None: - """Preprocessing before forward pass""" + 
def pre_first_fuser_forward(self) -> None: + """Preprocessing before first fuser forward pass""" def get_input_quantizer(self) -> Optional[Quantizer]: """Get builder class for quantized input tensor""" - def get_grad_input_quantizer(self) -> Optional[Quantizer]: - """Get builder class for quantized input's grad tensor""" + def get_grad_output_quantizer(self) -> Optional[Quantizer]: + """Get builder class for quantized output's grad tensor""" def fuser_forward( self, @@ -84,7 +77,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: @@ -104,8 +97,8 @@ def fuser_forward( Input tensor basic_op_extra_inputs: list of torch.Tensor Extra tensor inputs to basic operations - prev_op_grad_input_quantizer: Quantizer, optional - The grad_input_quantizer of the preceeding operation + prev_op_grad_output_quantizer: Quantizer, optional + The grad_output_quantizer of the preceeding operation next_op_input_quantizer: Quantizer, optional The input_quantizer of the following operation basic_op_kwargs: list of dict @@ -186,8 +179,11 @@ def __init__(self) -> None: super().__init__() # Objects for quantization - self._quantizers: Optional[dict[str, list[Quantizer]]] = None self._fp8_metas: Optional[dict[str, dict[str, Any]]] = None + self._quantizers: Optional[dict[str, list[Quantizer]]] = None + with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() + recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None + self.reset_recipe_type(recipe=recipe) @property def is_fused_op(self) -> bool: @@ -214,120 +210,90 @@ def get_input_quantizer(self) -> Optional[Quantizer]: return self.get_quantizer("forward", 0) return None - def get_grad_input_quantizer(self) -> Optional[Quantizer]: + def get_grad_output_quantizer(self) -> Optional[Quantizer]: if self.num_quantizers("backward") > 0: return self.get_quantizer("backward", 0) return None - def _reset_quantization_recipe_state( + def reset_recipe_type( self, *, - recipe: Recipe, + recipe: Optional[Recipe], ) -> None: """Construct state for quantization recipe""" - # Quantization recipe state for forward and backward pass - self._fp8_metas = {"forward": None, "backward": None} - self._quantizers = {"forward": [], "backward": []} - for mode in ("forward", "backward"): - num_quantizers = self.num_quantizers(mode) - if num_quantizers == 0: - continue + # Clear quantization state if necessary + if recipe is None: + self._fp8_metas = None + self._quantizers = None + return - if recipe.float8_block_scaling(): - raise NotImplementedError( - "Fusible operations do not support FP8 block scaling recipe" + # Skip resetting recipe type if it did not actually change. 
+ # This could happen for example if calling BasicOperation.forward directly, as in that + # case, the OperationFuser is not persistent, or when loading from a checkpoint + need_to_reset_recipe_state = False + if self._fp8_metas is None or self._quantizers is None: + need_to_reset_recipe_state = True + else: + for mode in ("forward", "backward"): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), ) + if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: + continue + recipe_state = self._fp8_metas[mode][fp8_meta_key] + if not isinstance(recipe, type(recipe_state.recipe)): + need_to_reset_recipe_state = True + break + + if need_to_reset_recipe_state: + # Quantization recipe state for forward and backward pass + self._fp8_metas = {"forward": None, "backward": None} + self._quantizers = {"forward": [], "backward": []} + for mode in ("forward", "backward"): + num_quantizers = self.num_quantizers(mode) + if num_quantizers == 0: + continue - # Construct quantization recipe state - recipe_state = RecipeState.create( - recipe, - mode=mode, - num_quantizers=num_quantizers, - ) - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=(mode == "forward"), - ) - self._fp8_metas[mode] = { - fp8_meta_key: recipe_state, - "recipe": recipe, - "fp8_group": FP8GlobalStateManager.get_fp8_group(), - } + if recipe.float8_block_scaling(): + raise NotImplementedError( + "Fusible operations do not support FP8 block scaling recipe" + ) - # Construct builder class for quantized tensors - self._quantizers[mode] = recipe_state.make_quantizers() + # Construct quantization recipe state + recipe_state = RecipeState.create( + recipe, + mode=mode, + num_quantizers=num_quantizers, + ) + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + self._fp8_metas[mode] = { + fp8_meta_key: recipe_state, + "recipe": recipe, + "fp8_group": FP8GlobalStateManager.get_fp8_group(), + } - def _update_quantization_recipe_state( - self, - *, - recipe: Recipe, - ) -> None: - """Make sure quantizer state matches quantization recipe""" + # Construct builder class for quantized tensors + self._quantizers[mode] = recipe_state.make_quantizers() - # Reset quantization state if needed - if self._fp8_metas is None or self._quantizers is None: - self._reset_quantization_recipe_state(recipe=recipe) - return + # Add meta tensors to global buffer to participate in reduction for mode in ("forward", "backward"): - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=(mode == "forward"), - ) - if self._fp8_metas[mode] is None or fp8_meta_key not in self._fp8_metas[mode]: - continue - recipe_state = self._fp8_metas[mode][fp8_meta_key] - need_to_reset_recipe_state = ( - (recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)) - or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState)) - or ( - recipe.float8_block_scaling() - and not isinstance(recipe_state, Float8BlockScalingRecipeState) + if ( + FP8GlobalStateManager.is_fp8_enabled() + and self.num_quantizers(mode) + and not FP8GlobalStateManager.fp8_graph_capturing() + ): + FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( + self._fp8_metas[mode], ) - ) - if need_to_reset_recipe_state: - self._reset_quantization_recipe_state(recipe=recipe) - return - - # Quantization recipe state for forward and backward pass - for mode in ("forward", "backward"): - num_quantizers = self.num_quantizers(mode) - if num_quantizers == 0: - 
continue - - # Update FP8 metadata - fp8_meta = self._fp8_metas[mode] - fp8_meta["recipe"] = recipe - fp8_meta["fp8_group"] = FP8GlobalStateManager.get_fp8_group() - - # Get recipe state - fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( - forward=(mode == "forward"), - ) - recipe_state = fp8_meta[fp8_meta_key] - - # Reallocate amax history if needed - if not recipe.delayed(): - continue - - current_length = recipe_state.amax_history.size(0) - target_length = recipe.amax_history_len - if current_length != target_length: - with torch.no_grad(): - if target_length < current_length: - recipe_state.amax_history = recipe_state.amax_history[ - :target_length - ].clone() - else: - recipe_state.amax_history = torch.nn.functional.pad( - recipe_state.amax_history, - pad=(0, 0, 0, target_length - current_length), - ) - self._quantizers[mode] = recipe_state.make_quantizers() def get_quantizer( self, mode: str, index: int, - ) -> Quantizer: + ) -> Optional[Quantizer]: """Get builder class for quantized tensor Parameters @@ -337,7 +303,7 @@ def get_quantizer( """ if self._quantizers is None: - self._reset_quantization_recipe_state(recipe=FP8GlobalStateManager.get_fp8_recipe()) + return None return self._quantizers[mode][index] @torch.no_grad() @@ -388,33 +354,13 @@ def _load_fp8_metas(self, fp8_metas: Optional[dict[str, Any]]) -> None: self._fp8_metas[mode][fp8_meta_key].scale.copy_(scale) self._fp8_metas[mode][fp8_meta_key].amax_history.copy_(amax_history) - def pre_first_forward( - self, - *, - recipe: Optional[Recipe], - ) -> None: - """Preprocessing before forward pass""" - - # Initialize FP8 metadata if needed - if recipe is not None: - self._update_quantization_recipe_state(recipe=recipe) - if not FP8GlobalStateManager.fp8_graph_capturing(): - if self.num_quantizers("forward"): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self._fp8_metas["forward"], - ) - if self.num_quantizers("backward"): - FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( - self._fp8_metas["backward"], - ) - @abc.abstractmethod def op_forward( self, ctx: OperationContext, input_: torch.Tensor, *, - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], **kwargs: Any, ) -> torch.Tensor: @@ -426,8 +372,8 @@ def op_forward( Context to coordinate between forward and backward passes input_: torch.Tensor Input tensor - prev_op_grad_input_quantizer: Quantizer, optional - The grad_input_quantizer of the preceeding operation + prev_op_grad_output_quantizer: Quantizer, optional + The grad_output_quantizer of the preceeding operation next_op_input_quantizer: Quantizer, optional The input_quantizer of the following operation @@ -468,7 +414,7 @@ def fuser_forward( input_: torch.Tensor, *, basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], - prev_op_grad_input_quantizer: Optional[Quantizer], + prev_op_grad_output_quantizer: Optional[Quantizer], next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, list[tuple[()]]]: @@ -482,7 +428,7 @@ def fuser_forward( output = self.op_forward( basic_op_ctxs[0], input_, - prev_op_grad_input_quantizer=prev_op_grad_input_quantizer, + prev_op_grad_output_quantizer=prev_op_grad_output_quantizer, next_op_input_quantizer=next_op_input_quantizer, **basic_op_kwargs[0], ) @@ -518,9 +464,7 @@ def forward( """Apply operation""" from .fuser import OperationFuser - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - recipe 
= FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None - return OperationFuser([self], fuse_ops=False, recipe=recipe)( + return OperationFuser([self])( input, *extra_inputs, basic_op_kwargs=[kwargs], @@ -630,7 +574,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: # Get op's quantizer state, initializing if needed if self._fp8_metas is None or self._fp8_metas[mode] is None: with fp8_autocast(fp8_recipe=state[mode]["recipe"]): - self._reset_quantization_recipe_state(recipe=state[mode]["recipe"]) + self.reset_recipe_type(recipe=state[mode]["recipe"]) fp8_meta = self._fp8_metas[mode] # Load extra items @@ -708,13 +652,13 @@ def is_fused_op(self) -> bool: def get_input_quantizer(self) -> Optional[Quantizer]: return self.basic_ops[0].get_input_quantizer() - def get_grad_input_quantizer(self) -> Optional[Quantizer]: - return self.basic_ops[-1].get_grad_input_quantizer() + def get_grad_output_quantizer(self) -> Optional[Quantizer]: + return self.basic_ops[-1].get_grad_output_quantizer() - def pre_first_forward(self, *args, **kwargs) -> None: - """Preprocessing before forward pass""" + def pre_first_fuser_forward(self) -> None: + """Preprocessing before first fuser forward pass""" for op in self.basic_ops: - op.pre_first_forward(*args, **kwargs) + op.pre_first_fuser_forward() def forward( self, @@ -727,9 +671,7 @@ def forward( basic_op_kwargs = [{} for _ in range(len(self.basic_ops))] from .fuser import OperationFuser - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None - return OperationFuser([self], fuse_ops=False, recipe=recipe)( + return OperationFuser([self])( input, *extra_inputs, basic_op_kwargs=basic_op_kwargs, diff --git a/transformer_engine/pytorch/ops/sequential.py b/transformer_engine/pytorch/ops/sequential.py index f18678309..2afda58e4 100644 --- a/transformer_engine/pytorch/ops/sequential.py +++ b/transformer_engine/pytorch/ops/sequential.py @@ -10,7 +10,6 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager, Recipe from transformer_engine.pytorch.ops.op import FusibleOperation from transformer_engine.pytorch.ops.fuser import OperationFuser @@ -147,7 +146,6 @@ def __add__(self, modules: Iterable[torch.nn.Modules]) -> Sequential: def _make_module_groups( cls, modules: Iterable[torch.nn.Module], - recipe: Optional[Recipe], ) -> list[OperationFuser | torch.nn.Module]: """Make list of modules, with fusible operations grouped together""" @@ -162,24 +160,7 @@ def _make_module_groups( groups.append(module) for idx, group in enumerate(groups): if isinstance(group, list): - groups[idx] = OperationFuser(group, fuse_ops=True, recipe=recipe) - - # Check if operations expect extra input or output tensors - # Note: If any op has extra inputs or outputs, then the entire - # Sequential must be made up of TE ops. 
- if len(groups) > 1: - ops = [] - for group in groups: - if isinstance(group, OperationFuser): - ops.extend(group._basic_ops) - num_extra_inputs = sum(op.num_extra_inputs for op in ops) - num_extra_outputs = sum(op.num_extra_outputs for op in ops) - if num_extra_inputs > 0 or num_extra_outputs > 0: - raise RuntimeError( - f"`Sequential` expects {num_extra_inputs} extra inputs " - f"and {num_extra_outputs} extra outputs, " - "but it contains non-fusible operations" - ) + groups[idx] = OperationFuser(group) return groups @@ -190,22 +171,28 @@ def forward( ) -> torch.Tensor | tuple[torch.Tensor, ...]: """Forward pass""" - # Get current global state - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None - global_state = (with_quantized_compute, type(recipe)) - - # Reset module groups is global state changed - if self._last_global_state != global_state: - self._module_groups = None - self._last_global_state = global_state - # Create module groups if needed if self._module_groups is None: - self._module_groups = self._make_module_groups(self._modules.values(), recipe) + self._module_groups = self._make_module_groups(self._modules.values()) # Forward pass for each module group x = input + extra_outputs: list[torch.Tensor] = [] for module_group in self._module_groups: - x = module_group(x, *extra_inputs) + if isinstance(module_group, OperationFuser): + xs, extra_inputs = ( + (x,) + extra_inputs[: module_group.num_extra_inputs], + extra_inputs[module_group.num_extra_inputs :], + ) + xs = module_group(*xs) + if isinstance(xs, tuple): + x, ys = xs[0], xs[1:] + extra_outputs.extend(ys) + else: + x = xs + else: + x = module_group(x) + + if extra_outputs: + return (x,) + tuple(extra_outputs) return x From d1967d5504eb96b3be04448ba6d3f0e31f7935a6 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 23 Jul 2025 08:54:42 +1200 Subject: [PATCH 009/153] fix: Add stream synchronization before destroying MPI communicator (#1979) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 40595ea98..38a6e3e61 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -133,14 +133,21 @@ CommOverlapCore::~CommOverlapCore() { if (_atomic_gemm) cudaFree(_counter.dptr()); - for (size_t i = 0; i < _stream_compute.size(); i++) cudaStreamDestroy(_stream_compute[i]); + for (size_t i = 0; i < _stream_compute.size(); i++) { + cudaStreamSynchronize(_stream_compute[i]); + cudaStreamDestroy(_stream_compute[i]); + } if (_comm_created) { + try { #ifdef NVTE_UB_WITH_MPI - destroy_communicator_mpi(_ub_comm); + destroy_communicator_mpi(_ub_comm); #else - destroy_communicator(_ub_comm); + destroy_communicator(_ub_comm); #endif + } catch (const std::exception &e) { + NVTE_WARN("Error destroying communicator, cleanup may be incomplete:\n", e.what()); + } _comm_created = false; } } From fdb87afc686ac385cd01f9fd260911ab79f64803 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 23 Jul 2025 15:08:27 -0700 Subject: [PATCH 010/153] [PyTorch] Reset recipe state in fusible operations 
when FP8 amax history length changes (#1985) * Fix bug where TE ops were not updating fp8_meta dicts Signed-off-by: Tim Moon * Rename reset_recipe_state function Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update error message when initializing meta device quantized weight without recipe Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/graph.py | 8 +++- .../pytorch/ops/basic/basic_linear.py | 10 ++--- transformer_engine/pytorch/ops/fuser.py | 35 ++++++++-------- transformer_engine/pytorch/ops/op.py | 40 ++++++++++++++++--- 4 files changed, 66 insertions(+), 27 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 6152d3aa7..4a2b2c61c 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -712,6 +712,12 @@ def new_fwd(*user_args, **user_kwargs): elif isinstance(m, BasicOperation): for mode in ("forward", "backward"): if m.num_quantizers(mode): + m._fp8_metas[mode][ + "fp8_group" + ] = FP8GlobalStateManager.get_fp8_group() + m._fp8_metas[mode][ + "recipe" + ] = FP8GlobalStateManager.get_fp8_recipe() FP8GlobalStateManager.add_fp8_tensors_to_global_buffer( m._fp8_metas[mode], ) @@ -756,7 +762,7 @@ def save_fp8_tensors( m.adjust_amax_history_length(fp8_recipe.amax_history_len) module_tensors = m.get_fp8_meta_tensors() elif isinstance(m, BasicOperation): - m.reset_recipe_type(recipe=fp8_recipe) + m.reset_recipe_state(recipe=fp8_recipe) module_tensors = m._save_fp8_metas() fp8_tensors.append(module_tensors) return fp8_tensors diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 5f5b38184..383efc823 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -294,9 +294,9 @@ def reset_parameters(self) -> None: raise RuntimeError( "Tried to quantize weight with deferred initialization " "due to meta device, but no quantizer was available. " - "This is most likely because fp8_model_init was called " - "with enabled=True and recipe=None, instead of providing " - "a recipe to use for quantization." + "This is most likely because the weight was initialized " + "within fp8_model_init, but the forward pass was not " + "performed within fp8_autocast." 
) quantizer.set_usage( rowwise=True, @@ -315,8 +315,8 @@ def pre_first_fuser_forward(self) -> None: if self.weight.device.type == "meta": self.reset_parameters() - def reset_recipe_type(self, *, recipe: Optional[Recipe]) -> None: - super().reset_recipe_type(recipe=recipe) + def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: + super().reset_recipe_state(recipe=recipe) if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters(): # Make quantizers use internal tensors diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index e618da349..19e7bb31a 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -398,27 +398,30 @@ def maybe_fuse_ops( break # Early exit if fusion parameters haven't changed + need_reset = False recipe_type = type(recipe) fusion_params = (recipe_type, first_op_requiring_backward) - if fusion_params == (self.recipe_type, self.first_op_requiring_backward): + if fusion_params != (self.recipe_type, self.first_op_requiring_backward): + # Recipe type or grad requirmenets have changed + need_reset = True + elif ( + recipe is not None + and recipe.delayed() + and self._last_amax_history_len != recipe.amax_history_len + ): + # FP8 delayed scaling has changed amax history length + need_reset = True + if not need_reset: return - # Initialize ops if recipe type has changed - if self.recipe_type != recipe_type: - # Check if this is the first iteration - if self.recipe_type is None: - for op in self._basic_ops: - op.pre_first_fuser_forward() - # Inform ops that the recipe type has changed + # Reset recipe state + for op in self._basic_ops: + op.reset_recipe_state(recipe=recipe) + + # Check if this is the first iteration + if self.recipe_type is None: for op in self._basic_ops: - op.reset_recipe_type(recipe=recipe) - # Check if amax history was invalidated - elif isinstance(recipe, DelayedScaling): - if recipe.amax_history_len != self._last_amax_history_len: - raise RuntimeError( - "Detected change of amax history length. " - "Changing the length of amax history is currently not supported." - ) + op.pre_first_fuser_forward() # Prepare basic op lists for fusions forward_ops = [(op, [idx]) for idx, op in enumerate(self._basic_ops)] diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index 740fbd50d..c2efc5169 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -183,7 +183,7 @@ def __init__(self) -> None: self._quantizers: Optional[dict[str, list[Quantizer]]] = None with_fp8_parameters = FP8GlobalStateManager.with_fp8_parameters() recipe = FP8GlobalStateManager.get_fp8_recipe() if with_fp8_parameters else None - self.reset_recipe_type(recipe=recipe) + self.reset_recipe_state(recipe=recipe) @property def is_fused_op(self) -> bool: @@ -215,7 +215,7 @@ def get_grad_output_quantizer(self) -> Optional[Quantizer]: return self.get_quantizer("backward", 0) return None - def reset_recipe_type( + def reset_recipe_state( self, *, recipe: Optional[Recipe], @@ -228,6 +228,9 @@ def reset_recipe_type( self._quantizers = None return + # Communication group for FP8 amax reductions + fp8_group = FP8GlobalStateManager.get_fp8_group() + # Skip resetting recipe type if it did not actually change. 
# This could happen for example if calling BasicOperation.forward directly, as in that # case, the OperationFuser is not persistent, or when loading from a checkpoint @@ -247,7 +250,7 @@ def reset_recipe_type( break if need_to_reset_recipe_state: - # Quantization recipe state for forward and backward pass + # Construct quantization recipe states self._fp8_metas = {"forward": None, "backward": None} self._quantizers = {"forward": [], "backward": []} for mode in ("forward", "backward"): @@ -272,11 +275,38 @@ def reset_recipe_type( self._fp8_metas[mode] = { fp8_meta_key: recipe_state, "recipe": recipe, - "fp8_group": FP8GlobalStateManager.get_fp8_group(), + "fp8_group": fp8_group, } # Construct builder class for quantized tensors self._quantizers[mode] = recipe_state.make_quantizers() + else: + # Update quantization recipe states + for mode in ("forward", "backward"): + if self._fp8_metas[mode] is None: + continue + self._fp8_metas[mode]["recipe"] = recipe + self._fp8_metas[mode]["fp8_group"] = fp8_group + + # Update amax history for FP8 delayed scaling + if recipe.delayed(): + fp8_meta_key = FP8GlobalStateManager.get_meta_tensor_key( + forward=(mode == "forward"), + ) + recipe_state = self._fp8_metas[mode][fp8_meta_key] + current_length = recipe_state.amax_history.size(0) + target_length = recipe.amax_history_len + if target_length < current_length: + with torch.no_grad(): + recipe_state.amax_history = recipe_state.amax_history[ + :target_length + ].clone() + elif target_length > current_length: + with torch.no_grad(): + recipe_state.amax_history = torch.nn.functional.pad( + recipe_state.amax_history, + pad=(0, 0, 0, target_length - current_length), + ) # Add meta tensors to global buffer to participate in reduction for mode in ("forward", "backward"): @@ -574,7 +604,7 @@ def copy_tensor(src: torch.Tensor, dst: torch.Tensor) -> None: # Get op's quantizer state, initializing if needed if self._fp8_metas is None or self._fp8_metas[mode] is None: with fp8_autocast(fp8_recipe=state[mode]["recipe"]): - self.reset_recipe_type(recipe=state[mode]["recipe"]) + self.reset_recipe_state(recipe=state[mode]["recipe"]) fp8_meta = self._fp8_metas[mode] # Load extra items From 4296b7d0bcd60303c63007601abba313686f112d Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Wed, 23 Jul 2025 16:37:54 -0700 Subject: [PATCH 011/153] Fix the device for cuDNN/cuBLAS handles (#1974) * fix current device for cuDNN/cuBLAS handles Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add unit test Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use weight device and improve tests Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/pytorch/distributed/test_sanity.py | 121 ++++++++++++++++++ .../dot_product_attention.py | 2 +- .../pytorch/module/grouped_linear.py | 4 +- .../pytorch/module/layernorm_linear.py | 4 +- .../pytorch/module/layernorm_mlp.py | 4 +- 
transformer_engine/pytorch/module/linear.py | 4 +- 7 files changed, 135 insertions(+), 5 deletions(-) create mode 100644 tests/pytorch/distributed/test_sanity.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index f0436d4ff..d7a4f054f 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -23,6 +23,7 @@ mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/distributed/test_numerics.py || test_fail "test_numerics.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_PATH/tests/pytorch/distributed/test_torch_fsdp2.py || test_fail "test_torch_fsdp2.py" diff --git a/tests/pytorch/distributed/test_sanity.py b/tests/pytorch/distributed/test_sanity.py new file mode 100644 index 000000000..39494a92b --- /dev/null +++ b/tests/pytorch/distributed/test_sanity.py @@ -0,0 +1,121 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +import pathlib +import sys +import pytest +import torch +import transformer_engine +from transformer_engine.pytorch.attention.dot_product_attention import DotProductAttention +from transformer_engine.pytorch import TransformerLayer, Linear + +_current_file = pathlib.Path(__file__).resolve() +sys.path.append(str(_current_file.parent.parent)) +from utils import ModelConfig + +model_configs = { + "small": ModelConfig(2, 10, 2, 16), +} + + +@pytest.mark.parametrize("model", ["small"]) +@pytest.mark.parametrize("module", ["TransformerLayer", "DotProductAttention", "Linear"]) +def test_current_device(model, module): + """Test cases where current device is different from tensor device""" + + num_devices = torch.cuda.device_count() + assert num_devices > 1, "This test requires more than one GPU!" 
+ tensor_device = num_devices - 1 + dtype = torch.bfloat16 + config = model_configs[model] + + args = [] + kwargs = {} + bwd_args = [] + if module == "TransformerLayer": + model = TransformerLayer( + config.hidden_size, + 4 * config.hidden_size, + config.num_heads, + params_dtype=dtype, + attn_input_format="thd", + self_attn_mask_type="padding", + device=f"cuda:{tensor_device}", + ) + num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() + args = [ + torch.randn( + (num_tokens, config.hidden_size), + dtype=dtype, + device=f"cuda:{tensor_device}", + requires_grad=True, + ) + ] + cu_seqlens_q, cu_seqlens_kv = [ + torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2) + ] + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_kv"] = cu_seqlens_kv + kwargs["max_seqlen_q"] = config.max_seqlen_q + kwargs["max_seqlen_kv"] = config.max_seqlen_kv + if module == "DotProductAttention": + model = DotProductAttention( + config.num_heads, config.head_dim_qk, qkv_format="thd", attn_mask_type="padding" + ) + num_tokens = torch.randint(0, config.max_seqlen_q, (1,)).item() + args = [ + torch.randn( + num_tokens, + config.num_heads, + config.head_dim_qk, + dtype=dtype, + device=tensor_device, + requires_grad=True, + ) + for _ in range(3) + ] + cu_seqlens_q, cu_seqlens_kv = [ + torch.Tensor([0, 2, 3]).to(dtype=torch.int32, device=tensor_device) for _ in range(2) + ] + kwargs["cu_seqlens_q"] = cu_seqlens_q + kwargs["cu_seqlens_kv"] = cu_seqlens_kv + kwargs["max_seqlen_q"] = config.max_seqlen_q + kwargs["max_seqlen_kv"] = config.max_seqlen_kv + bwd_args = [torch.randn(num_tokens, config.hidden_size, dtype=dtype, device=tensor_device)] + elif module == "Linear": + model = Linear( + config.hidden_size, + 4 * config.hidden_size, + params_dtype=dtype, + device=f"cuda:{tensor_device}", + ) + args = [ + torch.randn( + (config.max_seqlen_q, config.batch_size, config.hidden_size), + dtype=dtype, + device=f"cuda:{tensor_device}", + requires_grad=True, + ) + ] + + current_device_before = torch.cuda.current_device() + out = model(*args, **kwargs) + if module == "DotProductAttention": + out.backward(*bwd_args) + else: + loss = out.sum() + loss.backward() + current_device_after = torch.cuda.current_device() + tensor_device_out = out.get_device() + tensor_device_grad = args[0].grad.get_device() + + assert ( + current_device_after == current_device_before + ), "The current device should not have changed!" + assert ( + tensor_device_out == tensor_device + ), "The output tensor should be the same as the input tensors!" + assert ( + tensor_device_grad == tensor_device + ), "The gradient tensor should be the same as the input tensors!" diff --git a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py index 893e2d228..b35b87a83 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/dot_product_attention.py @@ -630,7 +630,7 @@ def forward( If true, there are padding tokens between individual sequences in a packed batch. 
""" - with self.prepare_forward( + with torch.cuda.device(query_layer.device), self.prepare_forward( query_layer, num_gemms=3, allow_non_contiguous=True, diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index da66e68b4..cc472390f 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -742,7 +742,9 @@ def forward( if skip_fp8_weight_update is not None: is_first_microbatch = False - with self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward(inp, num_gemms=self.num_gemms) as inp: weight_tensors = self._get_weight_tensors() bias_tensors = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index a044894d7..659fcd0e1 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1484,7 +1484,9 @@ def forward( if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( inp, allow_non_contiguous=False # removed .contiguous from inside the layer ) as inp: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index ec3f4be25..cec74aa81 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1740,7 +1740,9 @@ def forward( if get_ub("fc2_fprop").is_fp8_ubuf(): fp8_output = True - with self.prepare_forward(inp, num_gemms=2) as inp: + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward(inp, num_gemms=2) as inp: quantizers = ( self._get_quantizers(fp8_output) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index b1d4196df..5b657e848 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1353,7 +1353,9 @@ def forward( if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): fp8_grad = True - with self.prepare_forward( + with torch.cuda.device( + getattr(self, list(self.named_parameters())[0][0]).device + ), self.prepare_forward( inp, allow_non_contiguous=isinstance(inp, QuantizedTensor), ) as inp: From 992ba01d4aacdd2a59e8f22d7c04d23a6c020752 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 23 Jul 2025 16:47:28 -0700 Subject: [PATCH 012/153] [JAX] Fix current scaling test_helper.py and enable test_helper.py in L0 (#1990) Fix current scaling test_helper.py and enable test_helper.py in L0 Signed-off-by: Jeremy Berchtold --- qa/L0_jax_unittest/test.sh | 2 +- tests/jax/test_helper.py | 7 +++---- 2 files changed, 4 insertions(+), 5 deletions(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index 3d00e0346..ab1148505 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -25,7 +25,7 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml 
$TE_PATH/tests/jax -k 'not distributed' --ignore=$TE_PATH/tests/jax/test_helper.py || test_fail "tests/jax/*not_distributed_*" +python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_jax_not_distributed.xml $TE_PATH/tests/jax -k 'not distributed' || test_fail "tests/jax/*not_distributed_*" pip3 install -r $TE_PATH/examples/jax/mnist/requirements.txt || error_exit "Failed to install mnist requirements" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_mnist.xml $TE_PATH/examples/jax/mnist || test_fail "mnist" diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index e237318a4..d0a3efd27 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -58,7 +58,6 @@ def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) def _compare_current_scaling(self, test): - self.assertEqual(QuantizeConfig.MARGIN, test.margin) self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) @@ -91,7 +90,7 @@ def test_fp8_autocast_delayed_scaling(self): self._check_default_state() - @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) + @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_current_scaling(self): QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() @@ -101,14 +100,14 @@ def test_fp8_autocast_current_scaling(self): self._check_default_state() - cs = Float8CurrentScaling(margin=5.0, fp8_format=FP8Format.E4M3) + cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) with fp8_autocast(enabled=True, fp8_recipe=cs): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() - cs = Float8CurrentScaling(margin=3.0, fp8_format=FP8Format.HYBRID) + cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) with fp8_autocast(enabled=True, fp8_recipe=cs): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_current_scaling(cs) From 2a2934567d296fcf924da644fd2dfdf44758ef05 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 23 Jul 2025 20:00:46 -0400 Subject: [PATCH 013/153] [JAX] Helper to disable TE custom calls + disable GemmPrimitive for non-MXFP8 recipes. 
(#1962) * add manage_primitives() helper * disable GEMM primitives for non-MXFP8 recipes * implement the NVTE_JAX_CUSTOM_CALLS + deprecate NVTE_JAX_CUSTOM_CALLS_RE * replace NVTE_JAX_CUSTOM_CALLS_RE with NVTE_JAX_CUSTOM_CALLS in TE tests and examples * fix use_jax_gemm contextmanager Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- qa/L0_jax_unittest/test.sh | 2 +- qa/L2_jax_unittest/test.sh | 2 +- tests/jax/test_custom_call_compute.py | 9 -- tests/jax/utils.py | 10 +- .../jax/cpp_extensions/activation.py | 4 +- transformer_engine/jax/cpp_extensions/base.py | 133 ++++++++++++++++-- .../jax/cpp_extensions/quantization.py | 4 +- transformer_engine/jax/quantize/helper.py | 3 + 8 files changed, 138 insertions(+), 29 deletions(-) diff --git a/qa/L0_jax_unittest/test.sh b/qa/L0_jax_unittest/test.sh index ab1148505..e4a3f4630 100644 --- a/qa/L0_jax_unittest/test.sh +++ b/qa/L0_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS_RE="" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" +NVTE_JAX_CUSTOM_CALLS="false" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py without custom calls" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/qa/L2_jax_unittest/test.sh b/qa/L2_jax_unittest/test.sh index c5c193351..f933a0732 100644 --- a/qa/L2_jax_unittest/test.sh +++ b/qa/L2_jax_unittest/test.sh @@ -36,7 +36,7 @@ export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" # Test without custom calls export XLA_FLAGS="${XLA_FLAGS} --xla_gpu_deterministic_ops" -NVTE_JAX_CUSTOM_CALLS_RE="" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" +NVTE_JAX_CUSTOM_CALLS="false" NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest_test_single_gpu_encoder.xml $TE_PATH/examples/jax/encoder/test_single_gpu_encoder.py || test_fail "test_single_gpu_encoder.py" if [ $RET -ne 0 ]; then echo "Error: some sub-tests failed: $FAILED_CASES" diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1e1467521..aa243be62 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -863,15 +863,6 @@ def test_quantize_dact_dbias_mxfp8_scaling( ] -def _use_jax_fp8_gemm(enabled=False): - import os - - if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" - elif "NVTE_JAX_CUSTOM_CALLS_RE" in 
os.environ: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") - - class TestDense: def _ref_gemm_with_jnp_dot(self, a, b, data_layout): if data_layout[0] == "T": diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 13b2b9148..8ad6dccfe 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -1604,16 +1604,18 @@ def print_debug_tensor_stats(prefix, tensor, hist=False): @contextmanager def use_jax_gemm(enabled=False): - orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS_RE", None) + orig_custom_calls_filter = os.environ.get("NVTE_JAX_CUSTOM_CALLS", None) try: if enabled: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = "^(?!GemmPrimitive$).+$" + os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=false" + else: + os.environ["NVTE_JAX_CUSTOM_CALLS"] = "GemmPrimitive=true" yield finally: if enabled: if orig_custom_calls_filter is None: - os.environ.pop("NVTE_JAX_CUSTOM_CALLS_RE") + os.environ.pop("NVTE_JAX_CUSTOM_CALLS") else: - os.environ["NVTE_JAX_CUSTOM_CALLS_RE"] = orig_custom_calls_filter + os.environ["NVTE_JAX_CUSTOM_CALLS"] = orig_custom_calls_filter diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index 57133f48a..b8dcca66c 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -915,11 +915,11 @@ def shardy_sharding_rule( class DActLuDBiasQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): - """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDActLuDBiasQuantizePrimitive for DBias and fused activation quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): - """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 13120f45a..fcc2108cc 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -4,6 +4,7 @@ """JAX/TE base custom ops""" import os import re +import warnings from abc import ABCMeta, abstractmethod from functools import partial from packaging import version @@ -30,19 +31,77 @@ class BasePrimitive(metaclass=ABCMeta): name = None + _is_enabled = True + + # Default list of primitives to disable for all recipes + _default_disable_names = ["GemmPrimitive"] + @classmethod def enabled(cls): """ - A custom call is marked as disabled if the `cls.__name__` does not fully match the - `NVTE_JAX_CUSTOM_CALLS_RE` pattern. - This uses the Python class name of the primitive definitions that inherit from BasePrimitive. 
- By default, `NVTE_JAX_CUSTOM_CALLS_RE` is set to `.*`, which matches and enables all names. - For example, set `NVTE_JAX_CUSTOM_CALLS_RE='^(?!DBiasQuantizePrimitive$).+$'` to disable `DBiasQuantizePrimitive`. + Determines if a custom call is enabled based on a state variable and environment variables. + Checks `NVTE_JAX_CUSTOM_CALLS` (key/value format) first, then falls back to the deprecated `NVTE_JAX_CUSTOM_CALLS_RE` (regex pattern), + and finally to the internal state `_is_enabled` if neither is set. + + Environment Variables: + 1. `NVTE_JAX_CUSTOM_CALLS`: Preferred key/value format to enable/disable specific primitives or a single value 'true' or 'false' to enable/disable all primitives. + - Example 1 (global enable): 'true' enables all primitives. + - Example 2 (global disable): 'false' disables all primitives. + - Example 3 (specific settings): 'DBiasQuantizePrimitive=false,GemmPrimitive=true' disables DBiasQuantizePrimitive and enables GemmPrimitive, leaving others at their default state. + Note that the default state is set at class level based on _default_disable_names. + 2. `NVTE_JAX_CUSTOM_CALLS_RE`: Deprecated regex pattern to match primitive names. + - Example: 'DBiasQuantizePrimitive' or '^(?!DBiasQuantizePrimitive$).+$' to enable/disable DBiasQuantizePrimitive. + - A deprecation warning is raised if used; it will be removed in future releases. + + Behavior: + 1. Checks if `NVTE_JAX_CUSTOM_CALLS` is set and parses key/value pairs or single true/false value. + 2. If not set, checks `NVTE_JAX_CUSTOM_CALLS_RE` (with deprecation warning) for regex matching. + 3. If neither is set, falls back to the internal state `_is_enabled`. """ - pattern = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE", r".*") - pattern = re.compile(pattern) - is_enabled = pattern.fullmatch(cls.__name__) is not None - return is_enabled + + # Check new key/value environment variable first + custom_calls_str = os.getenv("NVTE_JAX_CUSTOM_CALLS") + if custom_calls_str is not None: + custom_calls_str = custom_calls_str.strip() + if custom_calls_str.lower() == "true": + return True + if custom_calls_str.lower() == "false": + return False + + # Parse key=value pairs + settings = {} + for pair in custom_calls_str.split(","): + pair = pair.strip() + if "=" in pair: + key, value = pair.split("=", 1) + key = key.strip() + value = value.strip().lower() + settings[key] = value == "true" + if cls.__name__ in settings: + return settings[cls.__name__] + + # Check old regex environment variable (deprecated) + pattern_str = os.getenv("NVTE_JAX_CUSTOM_CALLS_RE") + if pattern_str is not None: + warnings.warn( + "NVTE_JAX_CUSTOM_CALLS_RE is deprecated and will be removed in future releases. Use" + " NVTE_JAX_CUSTOM_CALLS with key=value format instead (e.g.," + " 'DBiasQuantizePrimitive=false').", + DeprecationWarning, + ) + pattern = re.compile(pattern_str) + env_enabled = pattern.fullmatch(cls.__name__) is not None + return env_enabled + + # If no environment variable is set, fall back to the internal state + return cls._is_enabled + + @classmethod + def set_enabled(cls, enabled: bool): + """ + Sets the enabled state for this primitive. + """ + cls._is_enabled = enabled @staticmethod @abstractmethod @@ -109,10 +168,19 @@ def shardy_sharding_rule(*args): return "... -> ..." +# Registry to store all registered primitive classes +_primitive_registry = {} + + def register_primitive(cls): """ - register jax primitive + Register a JAX primitive and add it to the internal registry. 
""" + _primitive_registry[cls.__name__] = cls + + # Set default disabled state at class level based on _default_disable_names + if cls.__name__ in BasePrimitive._default_disable_names: + cls.set_enabled(False) def name_of_wrapper_p(): return cls.name + "_wrapper" @@ -145,3 +213,48 @@ def name_of_wrapper_p(): for _name, _value in transformer_engine_jax.registrations().items(): ffi.register_ffi_target(_name, _value, platform="CUDA") + + +def manage_primitives(enable_names=None, disable_names=None, disable_all_first=False): + """ + Helper function to manage primitive states by name without modifying environment variables. + Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. + This helper is used in the QuantizeConfig.initialize() methods. + + Args: + enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. + disable_names: List of strings, each representing the name of a primitive class to disable. Defaults to None. + disable_all_first: Boolean, if True, disables all primitives before applying enable/disable lists. Defaults to False. + + Note: + 1. If `disable_all_first` is True, all primitives are disabled first, then `enable_names` is applied. + 2. Conflicts (a primitive in both enable and disable lists) are resolved by applying disable last. + """ + + enable_set = set(enable_names or []) + disable_set = set(disable_names or []) + + if disable_all_first: + for name, cls in _primitive_registry.items(): + if ( + isinstance(cls, type) + and issubclass(cls, BasePrimitive) + and cls is not BasePrimitive + ): + cls.set_enabled(False) + + # Apply enables + for name in enable_set: + cls = _primitive_registry.get(name) + if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive): + cls.set_enabled(True) + else: + raise ValueError(f"Primitive not found in registry: {name}") + + # Apply disables (overrides enables if there's a conflict) + for name in disable_set: + cls = _primitive_registry.get(name) + if cls and isinstance(cls, type) and issubclass(cls, BasePrimitive): + cls.set_enabled(False) + else: + raise ValueError(f"Primitive not found in registry: {name}") diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 23e821b1a..a7697ce25 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -519,11 +519,11 @@ def shardy_sharding_rule( class DBiasQuantizePrimitive(BaseDBiasQuantizePrimitive): - """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDBiasQuantizePrimitive for DBias quantization. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" class QuantizePrimitive(BaseDBiasQuantizePrimitive): - """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS_RE.""" + """Subclass of BaseDBiasQuantizePrimitive for quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" def _jax_quantize( diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 122265ea2..e31f1852b 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -352,6 +352,9 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: cls.initialize(fp8_recipe) cls.AMAX_HISTORY_LEN = 0 + # Use TE GEMM instead of JAX GEMM for better performance + tex.base.manage_primitives(enable_names=["GemmPrimitive"]) + @staticmethod def finalize() -> None: """Reset the block scaling configuration.""" From dab931a7aea1cc72fb480e20a083778aa4e44a4b Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Thu, 24 Jul 2025 02:14:41 +0200 Subject: [PATCH 014/153] [PyTorch] Improve L2Normalization basic op (#1964) * Increase intermediate precision and reuse tensors from fwd Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * JIT warmup only when required Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Recompute only rsqrt_norm Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Evgeny Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 3 -- transformer_engine/pytorch/jit.py | 36 +++++++++++++------ .../pytorch/ops/basic/l2normalization.py | 12 +++++-- 3 files changed, 35 insertions(+), 16 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 10e9dc5e7..2a0426e34 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1400,9 +1400,6 @@ def test_l2normalization( dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(y_test, y_ref, **tols) - # L2Norm backward pass requires slightly looser atol for bfloat16 - if dtype == torch.bfloat16: - tols["atol"] = 2e-3 torch.testing.assert_close(dx_test, x_ref.grad, **tols) @pytest.mark.parametrize("dtype", _dtypes) diff --git a/transformer_engine/pytorch/jit.py b/transformer_engine/pytorch/jit.py index 4ac2e15b8..f0f77621e 100644 --- a/transformer_engine/pytorch/jit.py +++ b/transformer_engine/pytorch/jit.py @@ -134,30 +134,43 @@ def dgelu_fused_(grad_output: torch.Tensor, inp: torch.Tensor) -> torch.Tensor: @jit_fuser def l2normalization_fused_(x: torch.Tensor, eps: float) -> torch.Tensor: """L2 normalization fused - inference version""" - x_squared = x.pow(2) + x_fp32 = x.float() + x_squared = x_fp32.pow(2) l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) - return x * rsqrt_norm + y_fp32 = x_fp32 * rsqrt_norm + return y_fp32.to(x.dtype) @jit_fuser def l2normalization_fwd_fused_(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]: """L2 normalization fused - training version that returns intermediate values""" - x_squared = x.pow(2) + x_fp32 = x.float() + x_squared = x_fp32.pow(2) l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) - rsqrt_norm = torch.rsqrt(l2_norm_squared + eps) - y = x * rsqrt_norm + l2_norm_squared_eps = l2_norm_squared + eps + rsqrt_norm = torch.rsqrt(l2_norm_squared_eps) + y_fp32 = x_fp32 * rsqrt_norm + y = 
y_fp32.to(x.dtype) return y, rsqrt_norm @jit_fuser def l2normalization_backward_fused_( - grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float + grad_output: torch.Tensor, + x: torch.Tensor, + rsqrt_norm: torch.Tensor, + eps: float, ) -> torch.Tensor: """L2 normalization backward fused""" - x_dy_sum = (x * grad_output).sum(dim=-1, keepdim=True) - x_norm_squared = x.pow(2).sum(dim=-1, keepdim=True) + eps - return rsqrt_norm * (grad_output - x * x_dy_sum / x_norm_squared) + x_fp32 = x.float() + grad_output_fp32 = grad_output.float() + x_dy_sum = (x_fp32 * grad_output_fp32).sum(dim=-1, keepdim=True) + x_squared = x_fp32.pow(2) + l2_norm_squared = x_squared.sum(dim=-1, keepdim=True) + x_norm_squared = l2_norm_squared + eps + dx_fp32 = rsqrt_norm * (grad_output_fp32 - x_fp32 * x_dy_sum / x_norm_squared) + return dx_fp32.to(x.dtype) def bias_gelu_fused(inp: torch.Tensor, bias: torch.Tensor) -> torch.Tensor: @@ -191,7 +204,10 @@ def l2normalization_fwd_fused(x: torch.Tensor, eps: float) -> tuple[torch.Tensor def l2normalization_backward_fused( - grad_output: torch.Tensor, x: torch.Tensor, rsqrt_norm: torch.Tensor, eps: float + grad_output: torch.Tensor, + x: torch.Tensor, + rsqrt_norm: torch.Tensor, + eps: float, ) -> torch.Tensor: """Disable native AMP for l2normalization_backward_fused_""" with gpu_autocast_ctx(enabled=False): diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index 1e72475ad..a340e7d42 100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -6,10 +6,12 @@ from __future__ import annotations from typing import Optional +import os import torch from ...utils import clear_tensor_data +from ... import torch_version from .._common import maybe_dequantize from ..op import BasicOperation, OperationContext from ...jit import ( @@ -60,7 +62,11 @@ def __init__( # JIT warmup for L2Normalization fused operations if seq_length and micro_batch_size: - if torch.cuda.is_available(): + if ( + torch.cuda.is_available() + and torch_version() >= (2, 0, 0) + and bool(int(os.getenv("NVTE_TORCH_COMPILE", "1"))) + ): set_jit_fusion_options() # For L2Normalization, we don't know the hidden size until forward pass, # but we can warm up with common sizes. 
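For reference, an unfused sketch of the forward/backward math implemented by the fused JIT helpers above: intermediates are kept in fp32 and the results cast back to the input dtype, matching the patch; eps is left as an explicit argument rather than assuming a default.

    import torch

    def l2norm_fwd_ref(x: torch.Tensor, eps: float) -> tuple[torch.Tensor, torch.Tensor]:
        # y = x * rsqrt(sum(x^2, dim=-1) + eps), computed in fp32
        x32 = x.float()
        rsqrt_norm = torch.rsqrt(x32.pow(2).sum(dim=-1, keepdim=True) + eps)
        return (x32 * rsqrt_norm).to(x.dtype), rsqrt_norm

    def l2norm_bwd_ref(dy: torch.Tensor, x: torch.Tensor,
                       rsqrt_norm: torch.Tensor, eps: float) -> torch.Tensor:
        # dx = rsqrt_norm * (dy - x * sum(x*dy, dim=-1) / (sum(x^2, dim=-1) + eps))
        x32, dy32 = x.float(), dy.float()
        norm_sq = x32.pow(2).sum(dim=-1, keepdim=True) + eps
        x_dy = (x32 * dy32).sum(dim=-1, keepdim=True)
        return (rsqrt_norm * (dy32 - x32 * x_dy / norm_sq)).to(x.dtype)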
For QK normalization, this will be @@ -86,7 +92,7 @@ def op_forward( # Compute L2 normalization using fused implementation # L2 norm: x / sqrt(sum(x^2) + eps) = x * rsqrt(sum(x^2) + eps) if requires_grad: - # Training: use version that returns both output and intermediate values + # Training: use version that returns output and intermediate values for backward pass y, rsqrt_norm = l2normalization_fwd_fused(x, self.eps) else: # Inference: use lightweight version that only returns output @@ -110,7 +116,7 @@ def op_backward( dy = maybe_dequantize(grad_output) - # Compute L2 norm backward pass using fused implementation + # Compute L2 norm backward pass using fused implementation - recalculates l2_norm_squared_eps dx = l2normalization_backward_fused(dy, x, rsqrt_norm, self.eps) # Clear saved tensors if possible From fe27bf1cb3699bd4c196126ac6594f39be2c0eb5 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 23 Jul 2025 20:53:00 -0400 Subject: [PATCH 015/153] Fix runtime lib loading for cuDNN (#1989) Fix cuDNN lib runtime loading and simplify Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/__init__.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 09a71a80d..834c4fe25 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -246,6 +246,18 @@ def _load_cudnn(): if found: return handle + # Attempt to locate libcudnn via ldconfig + libs = subprocess.check_output( + f"ldconfig -p | grep 'libcudnn{_get_sys_extension()}'", shell=True + ) + libs = libs.decode("utf-8").split("\n") + sos = [] + for lib in libs: + if "libcudnn" in lib and "=>" in lib: + sos.append(lib.split(">")[1].strip()) + if sos: + return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise return ctypes.CDLL(f"libcudnn{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) @@ -267,12 +279,12 @@ def _load_nvrtc(): return handle # Attempt to locate NVRTC via ldconfig - libs = subprocess.check_output("ldconfig -p | grep 'libnvrtc'", shell=True) + libs = subprocess.check_output( + f"ldconfig -p | grep 'libnvrtc{_get_sys_extension()}'", shell=True + ) libs = libs.decode("utf-8").split("\n") sos = [] for lib in libs: - if "stub" in lib or "libnvrtc-builtins" in lib: - continue if "libnvrtc" in lib and "=>" in lib: sos.append(lib.split(">")[1].strip()) if sos: From ee84108417de5b8f1a531b666d6d43966c4ee43c Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Wed, 23 Jul 2025 21:32:21 -0700 Subject: [PATCH 016/153] Add `in_place` kwarg to extra tensor ops (#1983) * Mark output tensors as not deletable in backward Signed-off-by: Jan Bielak * Add `in_place` kwarg to `MakeExtraOutput` Signed-off-by: Jan Bielak * Rename `AddInPlace` to `AddExtraInput` and add an `in_place` kwarg Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 46 ++++++++++--------- .../pytorch/ops/basic/__init__.py | 2 +- .../{add_in_place.py => add_extra_input.py} | 30 ++++++++---- .../pytorch/ops/basic/make_extra_output.py | 30 ++++++++---- .../pytorch/ops/fused/backward_linear_add.py | 2 + 
.../ops/fused/forward_linear_bias_add.py | 10 ++-- transformer_engine/pytorch/ops/fuser.py | 4 ++ 7 files changed, 80 insertions(+), 44 deletions(-) rename transformer_engine/pytorch/ops/basic/{add_in_place.py => add_extra_input.py} (70%) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 2a0426e34..40be0a75a 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -270,22 +270,22 @@ def test_extra_tensors(self, size: int = 16) -> None: bias = te_ops.Bias(size=size, device="cpu") with torch.no_grad(): bias.bias.copy_(torch.rand((size,))) - model = te_ops.Sequential( # | Inputs | Outputs - torch.nn.Identity(), # | x1 | x1 - te_ops.MakeExtraOutput(), # | x1 | x1 [x1] - bias, # | x1 | h1 (= x1 + b) - te_ops.MakeExtraOutput(), # | h1 | h1 [h1] - te_ops.AddInPlace(), # | h1 [x2] | x2 (= x2 + h1) - te_ops.MakeExtraOutput(), # | x2 | x2 [x2] - torch.nn.Identity(), # | x2 | x2 - bias, # | x2 | h2 (= x2 + b) - te_ops.AddInPlace(), # | h2 [x3] | x3 (= x3 + h2) - te_ops.MakeExtraOutput(), # | x3 | x3 [x3] - te_ops.AddInPlace(), # | x3 [x4] | x4 (= x4 + x3) - torch.nn.Identity(), # | x4 | x4 - te_ops.Identity(), # | x4 | x4 - te_ops.MakeExtraOutput(), # | x4 | x4 [x4] - te_ops.Identity(), # | x4 | x4 + model = te_ops.Sequential( # | Inputs | Outputs + torch.nn.Identity(), # | x1 | x1 + te_ops.MakeExtraOutput(in_place=True), # | x1 | x1 [x1] + bias, # | x1 | h1 (= x1 + b) + te_ops.MakeExtraOutput(in_place=True), # | h1 | h1 [h1] + te_ops.AddExtraInput(in_place=True), # | h1 [x2] | x2 (= x2 + h1) + te_ops.MakeExtraOutput(in_place=True), # | x2 | x2 [x2] + torch.nn.Identity(), # | x2 | x2 + bias, # | x2 | h2 (= x2 + b) + te_ops.AddExtraInput(in_place=True), # | h2 [x3] | x3 (= x3 + h2) + te_ops.MakeExtraOutput(in_place=True), # | x3 | x3 [x3] + te_ops.AddExtraInput(in_place=True), # | x3 [x4] | x4 (= x4 + x3) + torch.nn.Identity(), # | x4 | x4 + te_ops.Identity(), # | x4 | x4 + te_ops.MakeExtraOutput(in_place=True), # | x4 | x4 [x4] + te_ops.Identity(), # | x4 | x4 ) # Create input tensors @@ -1402,13 +1402,15 @@ def test_l2normalization( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("in_place", (True, False)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("quantization", _quantization_list) - def test_add_in_place( + def test_add_extra_input( self, *, in_shape: Iterable[int] = (32, 32), + in_place: bool, dtype: torch.dtype, device: torch.device, quantization: Optional[str], @@ -1454,7 +1456,7 @@ def test_add_in_place( dx2_ref = dy_ref # Implementation with fusible operation - op = te_ops.AddInPlace() + op = te_ops.AddExtraInput(in_place=in_place) y_test = op(x1_test, x2_test) y_test.backward(dy_test) @@ -1469,6 +1471,7 @@ def test_add_in_place( torch.testing.assert_close(dx1_test, dx1_ref, rtol=0, atol=0) torch.testing.assert_close(dx2_test, dx2_ref, rtol=0, atol=0) + @pytest.mark.parametrize("in_place", (True, False)) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @pytest.mark.parametrize("quantization", _quantization_list) @@ -1476,6 +1479,7 @@ def test_make_extra_output( self, *, in_shape: Iterable[int] = (32, 32), + in_place: bool, dtype: torch.dtype, device: torch.device, quantization: Optional[str], @@ -1521,7 +1525,7 @@ def test_make_extra_output( (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() # Implementation with 
fusible operation - op = te_ops.MakeExtraOutput() + op = te_ops.MakeExtraOutput(in_place=in_place) y1_test, y2_test = op(x_test) (y1_test * dy1_test + y2_test * dy2_test).sum().backward() @@ -1885,7 +1889,7 @@ def test_forward_linear_bias_add( device=device, dtype=dtype, ), - te_ops.AddInPlace(), + te_ops.AddExtraInput(in_place=True), ) with torch.no_grad(): model[0].weight.copy_(w_test) @@ -2077,7 +2081,7 @@ def test_backward_linear_add( recipe = make_recipe(quantization) with te.fp8_model_init(enabled=quantized_weight): model = te_ops.Sequential( - te_ops.MakeExtraOutput(), + te_ops.MakeExtraOutput(in_place=True), te_ops.Linear( in_features, out_features, diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index c69e3df02..e0e15b703 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -5,7 +5,7 @@ """Single tensor operations supported by the operation fuser.""" from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU -from .add_in_place import AddInPlace +from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce from .basic_linear import BasicLinear diff --git a/transformer_engine/pytorch/ops/basic/add_in_place.py b/transformer_engine/pytorch/ops/basic/add_extra_input.py similarity index 70% rename from transformer_engine/pytorch/ops/basic/add_in_place.py rename to transformer_engine/pytorch/ops/basic/add_extra_input.py index 3a7f1843b..1fcfa0466 100644 --- a/transformer_engine/pytorch/ops/basic/add_in_place.py +++ b/transformer_engine/pytorch/ops/basic/add_extra_input.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. -"""Fusible operation for in-place add.""" +"""Fusible operation for adding extra input tensor.""" from __future__ import annotations from collections.abc import Iterable @@ -18,16 +18,17 @@ from transformer_engine.pytorch.tensor import Quantizer -class AddInPlace(BasicOperation): - """Add in-place +class AddExtraInput(BasicOperation): + """Add extra input tensor This operation requires an extra tensor input to the operation - fuser. The main input is added in-place to the extra input, and a - view of the extra input is output. + user. It returns the sum of the main input and the extra input. + If in_place=True, the main input is added in-place to the extra + input, and a view of the extra input is output. - This operation is considered an advanced feature and most users - are discouraged from using it. In-place operations break some - autograd assumptions and they can result in subtle, esoteric bugs. + Using this operation with in_place=True is considered an advanced + feature and most users are discouraged from it. In-place operations + break some autograd assumptions and they can result in subtle, esoteric bugs. Compare to `MakeExtraOutput`, which does a similar operation in the backward pass. 
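A small usage sketch of the renamed op, mirroring the updated test above. Shapes, device, and dtype are illustrative; te_ops stands for transformer_engine.pytorch.ops.

    import torch
    import transformer_engine.pytorch.ops as te_ops

    x1 = torch.randn(32, 32, device="cuda", requires_grad=True)  # main input
    x2 = torch.randn(32, 32, device="cuda", requires_grad=True)  # extra input

    # in_place=False: ordinary out-of-place sum of the main and extra inputs.
    # in_place=True: the main input is accumulated into the extra input and a view
    # of the extra input is returned (advanced use, see the docstring above).
    op = te_ops.AddExtraInput(in_place=False)
    y = op(x1, x2)                       # y = x1 + x2
    y.backward(torch.ones_like(y))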
@@ -37,6 +38,10 @@ class AddInPlace(BasicOperation): # Operation expects buffer for output tensor num_extra_inputs: int = 1 + def __init__(self, *, in_place: bool = False): + super().__init__() + self._in_place = in_place + def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( "{self.__class__.__name__} operation has " @@ -63,8 +68,13 @@ def fuser_forward( next_op_input_quantizer: Optional[Quantizer], basic_op_kwargs: list[dict[str, Any]], ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: - output = basic_op_extra_inputs[0][0].detach() - output += input_ + extra_input = basic_op_extra_inputs[0][0] + if self._in_place: + extra_input = extra_input.detach() + extra_input += input_ + output = extra_input + else: + output = extra_input + input_ return output, [()] def fuser_backward( diff --git a/transformer_engine/pytorch/ops/basic/make_extra_output.py b/transformer_engine/pytorch/ops/basic/make_extra_output.py index f64b609de..34228affc 100644 --- a/transformer_engine/pytorch/ops/basic/make_extra_output.py +++ b/transformer_engine/pytorch/ops/basic/make_extra_output.py @@ -22,14 +22,20 @@ class MakeExtraOutput(BasicOperation): If this operation is included in the operation fuser, then the operation fuser will return the intermediate tensor as an extra - tensor output. In the backward pass, the gradient is directly - accumulated into the gradient w.r.t. the extra output. + tensor output. - This operation is considered an advanced feature and most users - are discouraged from using it. In-place operations break some - autograd assumptions and they can result in subtle, esoteric bugs. + In the backward pass, the gradient may be directly + accumulated into the gradient w.r.t. the extra output. This is + controlled by the in_place kwarg. Currently, the BackwardLinearAdd + fusion is able to happen only with in_place=True. - Compare to `AddInPlace`, which does a similar operation in the + Using this operation with in_place=True is + considered an advanced feature. Most users are discouraged + from enabling it in-place gradient accumulation, as in-place + operations break some autograd assumptions and they can result + in subtle, esoteric bugs. + + Compare to `AddExtraInput`, which does a similar operation in the backward pass. 
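A matching sketch for the extra-output op, again following the updated test; shapes are illustrative. With in_place=True the incoming gradient would instead be accumulated in place into the extra output's gradient, which is what the BackwardLinearAdd fusion relies on.

    import torch
    import transformer_engine.pytorch.ops as te_ops

    op = te_ops.MakeExtraOutput(in_place=False)
    x = torch.randn(32, 32, device="cuda", requires_grad=True)

    y, extra = op(x)                              # main output plus the extra output
    (y * 2.0 + extra * 3.0).sum().backward()      # gradients from both outputs
    # end up combined into x.grad.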
""" @@ -37,6 +43,10 @@ class MakeExtraOutput(BasicOperation): # Operation expects buffer for output tensor num_extra_outputs: int = 1 + def __init__(self, *, in_place: bool = False): + super().__init__() + self._in_place: bool = in_place + def op_forward(self, *args, **kwargs) -> None: raise RuntimeError( "{self.__class__.__name__} operation has " @@ -76,6 +86,10 @@ def fuser_backward( Iterable[Iterable[Optional[torch.Tensor]]], Iterable[Iterable[Optional[torch.Tensor]]], ]: - grad_input = basic_op_grad_extra_outputs[0][0] - grad_input += grad_output + grad_extra_output = basic_op_grad_extra_outputs[0][0] + if self._in_place: + grad_extra_output += grad_output + grad_input = grad_extra_output + else: + grad_input = grad_extra_output + grad_output return grad_input, [()], [()] diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 286503419..8af46a27c 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -139,6 +139,8 @@ def fuse_backward_linear_add( op, _ = ops[0] if not isinstance(op, MakeExtraOutput): continue + if not op._in_place: + continue window.extend(ops[:1]) ops = ops[1:] diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index 608fff01f..dd59e602f 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -11,7 +11,7 @@ import torch from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import AddInPlace, BasicLinear, Bias +from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias from transformer_engine.pytorch.ops.op import ( FusedOperation, FusibleOperation, @@ -33,7 +33,7 @@ def __init__( *, linear: BasicLinear, bias: Optional[Bias], - add: AddInPlace, + add: AddExtraInput, ) -> None: # Basic operations that comprise this fused operation @@ -179,8 +179,10 @@ def fuse_forward_linear_bias_add( continue op, _ = ops[0] - # Check if next op is add in-place - if not isinstance(op, AddInPlace): + # Check if next op is in-place add extra input + if not isinstance(op, AddExtraInput): + continue + if not op._in_place: continue add = op window.extend(ops[:1]) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 19e7bb31a..9923a5fbe 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -197,6 +197,10 @@ def forward( func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() func_ctx.with_quantized_compute = with_quantized_compute + # Mark output tensors as not deletable in backward + for tensor in [x] + extra_outputs_flat: + tensor.do_not_clear = True + x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) if extra_outputs_flat: From 71b2dd48f2f15ec51faf2d911e3857209f80427f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Thu, 24 Jul 2025 09:52:12 -0700 Subject: [PATCH 017/153] Fix cudnn versioning support in PyTorch DPA and Fused attn (#1991) Fix cudnn versioning in support in PyTorch DPA and Fused attn Signed-off-by: Kshitij Janardan Lakhani --- transformer_engine/common/fused_attn/fused_attn.cpp | 8 ++++---- .../pytorch/attention/dot_product_attention/utils.py | 4 ++-- 2 files 
changed, 6 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 940c1d305..bb30261b9 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -251,10 +251,10 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && - // 9.11 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!(cudnn_runtime_version == 91100 && is_training && sm_arch_ == 90 && head_dim_qk >= 128 && - head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && - head_dim_qk != head_dim_v))) && + // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training && + sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && + !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7c4bf928c..9d6677b62 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -434,8 +434,8 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version < (9, 12, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN < 9.12") + if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0): + logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") From a99c056be2afe186944674bdffa2d563cd53a962 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov <64355998+Oleg-Goncharov@users.noreply.github.com> Date: Thu, 24 Jul 2025 21:39:38 +0200 Subject: [PATCH 018/153] [Common] Fixed integer overflow issue in cast kernels (#1988) * Fixed integer overflow when computing offsets Signed-off-by: Oleg Goncharov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Oleg Goncharov Co-authored-by: Kirthi Shankar Sivamani --- tests/cpp/operator/test_cast_mxfp8.cu | 20 +- .../operator/test_cast_mxfp8_gated_swiglu.cu | 18 +- .../common/util/cast_gated_kernels.cuh | 156 ++++++------ .../common/util/cast_kernels.cuh | 224 +++++++++--------- 4 files changed, 211 insertions(+), 207 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 5a9423745..49bbf1655 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -81,8 +81,8 @@ void compute_ref(const ProcessingMethod processing_method, // Cache computations for (size_t i = i_min; i < i_max; ++i) { for (size_t j = j_min; j < j_max; ++j) { - const int idx = i * cols + j; - const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + const 
size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); float elt = static_cast(input[idx]); if (processing_method == ProcessingMethod::CAST_DBIAS) { @@ -114,18 +114,18 @@ void compute_ref(const ProcessingMethod processing_method, float block_amax = 0.0f; for (size_t j = j_min; j < j_max; ++j) { - const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); } const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); - const int scale_idx = i * scales_stride_rowwise + tile_X; + const size_t scale_idx = i * scales_stride_rowwise + tile_X; output_scales_rowwise[scale_idx] = biased_exponent; const float scale_reciprocal = exp2f_rcp(biased_exponent); for (size_t j = j_min; j < j_max; ++j) { - const int idx = i * cols + j; - const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); output_rowwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); } } @@ -135,18 +135,18 @@ void compute_ref(const ProcessingMethod processing_method, float block_amax = 0.0f; for (size_t i = i_min; i < i_max; ++i) { - const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); block_amax = std::max(block_amax, std::abs(cache_buffer[cache_idx])); } const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits::max_reciprocal()); - const int scale_idx = tile_Y * scales_stride_colwise + j; + const size_t scale_idx = tile_Y * scales_stride_colwise + j; output_scales_colwise[scale_idx] = biased_exponent; const float scale_reciprocal = exp2f_rcp(biased_exponent); for (size_t i = i_min; i < i_max; ++i) { - const int idx = i * cols + j; - const int cache_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t idx = i * cols + j; + const size_t cache_idx = (i - i_min) * tile_size_X + (j - j_min); output_colwise[idx] = static_cast(cache_buffer[cache_idx] * scale_reciprocal); } } diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index 3c7b8c8b7..464b77128 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -64,7 +64,7 @@ void compute_ref(const IType* grad, float silu_elt = static_cast(input[i * stride + j]); float gate_elt = static_cast(input[i * stride + cols + j]); - const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); if (IS_DGATED) { const float x = silu_elt; @@ -101,7 +101,7 @@ void compute_ref(const IType* grad, float block_amax_act = 0.0f; float block_amax_gate = 0.0f; for (size_t j = j_min; j < j_max; ++j) { - const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); if (IS_DGATED) { block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); @@ -109,18 +109,18 @@ void compute_ref(const IType* grad, } const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); - const int scale_idx_act = i * 
scales_stride_rowwise + tile_X; + const size_t scale_idx_act = i * scales_stride_rowwise + tile_X; output_scales_rowwise[scale_idx_act] = biased_exponent_act; float scale_reciprocal_gate; if (IS_DGATED) { const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); - const int scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; + const size_t scale_idx_gate = scale_idx_act + (cols + 32 - 1) / 32; output_scales_rowwise[scale_idx_gate] = biased_exponent_gate; } for (size_t j = j_min; j < j_max; ++j) { - const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; if (IS_DGATED) { @@ -139,7 +139,7 @@ void compute_ref(const IType* grad, float block_amax_act = 0.0f; float block_amax_gate = 0.0f; for (size_t i = i_min; i < i_max; ++i) { - const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); block_amax_act = std::max(block_amax_act, std::abs(cache_buffer_act[cached_idx])); if (IS_DGATED) { block_amax_gate = std::max(block_amax_gate, std::abs(cache_buffer_gate[cached_idx])); @@ -147,18 +147,18 @@ void compute_ref(const IType* grad, } const fp8e8m0 biased_exponent_act = float_to_e8m0(block_amax_act * Quantized_Limits::max_reciprocal()); const float scale_reciprocal_act = exp2f_rcp(biased_exponent_act); - const int scale_idx_act = tile_Y * scales_stride_colwise + j; + const size_t scale_idx_act = tile_Y * scales_stride_colwise + j; output_scales_colwise[scale_idx_act] = biased_exponent_act; float scale_reciprocal_gate; if (IS_DGATED) { const fp8e8m0 biased_exponent_gate = float_to_e8m0(block_amax_gate * Quantized_Limits::max_reciprocal()); - const int scale_idx_gate = scale_idx_act + cols; + const size_t scale_idx_gate = scale_idx_act + cols; scale_reciprocal_gate = exp2f_rcp(biased_exponent_gate); output_scales_colwise[scale_idx_gate] = biased_exponent_gate; } for (size_t i = i_min; i < i_max; ++i) { - const int cached_idx = (i - i_min) * tile_size_X + (j - j_min); + const size_t cached_idx = (i - i_min) * tile_size_X + (j - j_min); const float after_act = cache_buffer_act[cached_idx] * scale_reciprocal_act; if (IS_DGATED) { diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 82041d9f9..d7552835e 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -58,14 +58,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const float *const scale_ptr, const size_t rows, const size_t cols) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int chunk_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t chunk_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t chunk_offset_X = blockIdx.x * CHUNK_DIM_X; - const int tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; - const int tid_X = threadIdx.x % THREADS_PER_CHUNK_X; + const size_t tid_Y = threadIdx.x / THREADS_PER_CHUNK_X; + const size_t tid_X = threadIdx.x % THREADS_PER_CHUNK_X; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; float amax = 0; const float scale = (scale_ptr != nullptr) ? 
*scale_ptr : 1; @@ -131,12 +131,12 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int it = 0; it < ITERATIONS; ++it) { - const int buff = it % BUFFERS_NUM; - const int next_it = it + 1; + const size_t buff = it % BUFFERS_NUM; + const size_t next_it = it + 1; if (next_it < ITERATIONS) { - const int next_buff = next_it % BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t next_buff = next_it % BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; if constexpr (IS_DGATED) { copy_2d_to_sharedx3( &in_grad_sh[next_buff * buff_elems], TMAP_grad_in, chunk_it_offset_x, chunk_it_offset_y, @@ -164,10 +164,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int stage = 0; stage < BUFFER_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage * THREADS_PER_CHUNK_Y; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; - const int shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; + const size_t stage_offset_Y = stage * THREADS_PER_CHUNK_Y; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; + const size_t shmem_idx = shmem_offset_y * SHMEM_DIM_X + shmem_offset_x; float act_elt = static_cast(in_act_sh_curr[shmem_idx]); float gate_elt = static_cast(in_gate_sh_curr[shmem_idx]); @@ -210,8 +210,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t chunk_it_offset_y = chunk_offset_Y + it * BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; // dGeLU ptx::cp_async_bulk_tensor_2d_shared_to_global(TMAP_output_act, chunk_it_offset_x, @@ -312,48 +312,48 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr bool ONLY_COLWISE_SCALING = COLWISE_SCALING && (!ROWWISE_SCALING); // # of rows covered by one wave. Equal to the # of columnwise threads in Y dimension. 
- constexpr int COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); + constexpr size_t COLWISE_WAVEFRONT_SIZE = DIVUP(THREADS_PER_CHUNK, CHUNK_DIM_X); - const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * CHUNK_DIM_X; - const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; constexpr size_t THREADS_X_ROWWISE = CHUNK_DIM_X / SCALE_DIM_X; - const int tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; - const int tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; - const int tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; - const int tid_X_colwise = threadIdx.x % CHUNK_DIM_X; + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X_ROWWISE; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X_ROWWISE; + const size_t tid_Y_colwise = threadIdx.x / CHUNK_DIM_X; + const size_t tid_X_colwise = threadIdx.x % CHUNK_DIM_X; - const int thread_offset_Y_rowwise = tid_Y_rowwise; - const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const int thread_offset_Y_colwise = tid_Y_colwise; - const int thread_offset_X_colwise = tid_X_colwise; + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; - const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const int col_base_rowwise = block_offset_X + thread_offset_X_rowwise; - const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t col_base_rowwise = block_offset_X + thread_offset_X_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; const bool col_out_of_bounds_rowwise = (col_base_rowwise >= cols); const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; - const int gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; - const int 
gate_scale_idx_offset_colwise = cols; + const size_t gate_scale_idx_offset_rowwise = (cols + SCALE_DIM_X - 1) / SCALE_DIM_X; + const size_t gate_scale_idx_offset_colwise = cols; // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; const int bank_group = thread_lane / THREADS_PER_BANK; - constexpr int SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; + constexpr size_t SUBAMAX_BUFF_DIM_Y = ONLY_COLWISE_SCALING ? COLWISE_WAVEFRONT_SIZE - 1 : 1; __shared__ float subamax_colwise_buff[SUBAMAX_BUFF_DIM_Y][CHUNK_DIM_X]; extern __shared__ char dynamic_shmem[]; @@ -400,7 +400,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType *cached_act_sh = in_act_sh; // in_act_sh is used as a cache buffer for activations IType *cached_gate_sh = in_gate_sh; // in_gate_sh is used as a cache buffer for gated values - constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; const bool is_master_thread = (threadIdx.x == 0); @@ -425,20 +425,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { - const int buff = stage % BUFFS_NUM; - const int next_stage = stage + 1; - const int stage_offset_Y = stage * BUFF_DIM_Y; + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (next_stage < STAGES) { // Wait for TMA transfer to have finished reading shared memory. // I.e. the buffer is ready to be written to ptx::cp_async_bulk_wait_group_read<1>(); - const int next_buff = next_stage % BUFFS_NUM; - const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const int global_offset_Y = block_offset_Y + next_stage_offset_Y; - const int global_offset_X = block_offset_X; - const int next_buff_offset = next_buff * BUFF_DIM; + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DGATED) { copy_2d_to_sharedx3(&in_grad_sh[next_buff_offset], &tensor_map_grad, global_offset_X, global_offset_Y, &in_act_sh[next_buff_offset], &tensor_map_input_act, @@ -459,7 +459,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mbarrier_wait_parity(&mbar[stage], parity); if constexpr (COLWISE_SCALING) { - const int shmem_offset_base_colwise = + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_Y_colwise * BUFF_DIM_X + tid_X_colwise; float thread_amax_act = 0.0f; float thread_amax_gate = 0.0f; @@ -469,7 +469,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 1. Read/Compute elements. 
Find MXFP8-block AMAX #pragma unroll for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { - const int shmem_offset_colwise = + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; float act_elt = static_cast(in_act_sh[shmem_offset_colwise]); @@ -581,9 +581,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const e8m0_t biased_exponent_act = ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); - const int global_scales_offset_Y = scales_offset_Y_colwise + stage; - const int global_scales_offset_X = scales_offset_X_colwise; - const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; const bool row_out_of_bounds_colwise = (row_base_colwise + stage_offset_Y) >= rows; const bool out_of_bounds_colwise = row_out_of_bounds_colwise || col_out_of_bounds_colwise; @@ -597,8 +598,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); - // const int scale_idx_gate = scale_idx + scale_stride_colwise / 2; - const int scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; + // const size_t scale_idx_gate = scale_idx + scale_stride_colwise / 2; + const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_colwise; if (tid_Y_colwise == 0 && (!out_of_bounds_colwise)) { scales_colwise[scale_idx_gate] = biased_exponent_gate; } @@ -608,7 +609,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 3. 
Scale elements #pragma unroll for (int i = 0; i < SCALE_DIM_Y / COLWISE_WAVEFRONT_SIZE; ++i) { - const int shmem_offset_elt = + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * COLWISE_WAVEFRONT_SIZE * BUFF_DIM_X; if constexpr (IS_DGATED) { OType2 out_pair; @@ -626,7 +627,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } if constexpr (ROWWISE_SCALING) { - const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; float thread_amax_act = 0.0f; float thread_amax_gate = 0.0f; @@ -645,9 +647,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType2 thread_amax_2x_gate = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); @@ -695,9 +697,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } else { #pragma unroll for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; Vec in_grad; Vec in_act; @@ -765,9 +767,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. 
Compute E8M0 scaling factor const e8m0_t biased_exponent_act = ptx::float_to_e8m0(thread_amax_act * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const size_t stage_scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y) >= rows; const bool out_of_bounds_rowwise = row_out_of_bounds_rowwise || col_out_of_bounds_rowwise; if (!out_of_bounds_rowwise) { @@ -783,7 +785,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (IS_DGATED) { const e8m0_t biased_exponent_gate = ptx::float_to_e8m0(thread_amax_gate * Quantized_Limits::max_norm_rcp); - const int scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; + const size_t scale_idx_gate = scale_idx + gate_scale_idx_offset_rowwise; if (!out_of_bounds_rowwise) { scales_rowwise[scale_idx_gate] = biased_exponent_gate; } @@ -826,9 +828,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) ptx::mul_cvt_2x(out_gate_pair, in_gate, block_scale_inverse_2x_gate); } } - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; out_act.store_to(&out_act_rowwise_sh[shmem_offset_rowwise]); if constexpr (IS_DGATED) { out_gate.store_to(&out_gate_rowwise_sh[shmem_offset_rowwise]); @@ -843,9 +845,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 79209adf5..fcf0a4084 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -80,33 +80,33 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) constexpr bool IS_CACHED_ACT_OP = COMPUTE_ACTIVATIONS && ROWWISE_SCALING && COLWISE_SCALING; - const int block_offset_Y = blockIdx.y * CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * CHUNK_DIM_X; - const int scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; - const int scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; - const int scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; - const int scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; - - const int tid_Y_rowwise = threadIdx.x / THREADS_X; - const int tid_X_rowwise = threadIdx.x % THREADS_X; - const int tid_Y_colwise = 0; - 
const int tid_X_colwise = threadIdx.x; - - const int thread_offset_Y_rowwise = tid_Y_rowwise; - const int thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; - const int thread_offset_Y_colwise = tid_Y_colwise; - const int thread_offset_X_colwise = tid_X_colwise; - - const int row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; - const int row_base_colwise = block_offset_Y + thread_offset_Y_colwise; - const int col_base_colwise = block_offset_X + thread_offset_X_colwise; + const size_t block_offset_Y = blockIdx.y * CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * CHUNK_DIM_X; + const size_t scales_block_offset_Y_rowwise = blockIdx.y * CHUNK_DIM_Y; + const size_t scales_block_offset_X_rowwise = blockIdx.x * CHUNK_DIM_X / SCALE_DIM_X; + const size_t scales_block_offset_Y_colwise = blockIdx.y * CHUNK_DIM_Y / SCALE_DIM_Y; + const size_t scales_block_offset_X_colwise = blockIdx.x * CHUNK_DIM_X; + + const size_t tid_Y_rowwise = threadIdx.x / THREADS_X; + const size_t tid_X_rowwise = threadIdx.x % THREADS_X; + const size_t tid_Y_colwise = 0; + const size_t tid_X_colwise = threadIdx.x; + + const size_t thread_offset_Y_rowwise = tid_Y_rowwise; + const size_t thread_offset_X_rowwise = tid_X_rowwise * SCALE_DIM_X; + const size_t thread_offset_Y_colwise = tid_Y_colwise; + const size_t thread_offset_X_colwise = tid_X_colwise; + + const size_t row_base_rowwise = block_offset_Y + thread_offset_Y_rowwise; + const size_t row_base_colwise = block_offset_Y + thread_offset_Y_colwise; + const size_t col_base_colwise = block_offset_X + thread_offset_X_colwise; const bool col_out_of_bounds_colwise = (col_base_colwise >= cols); - const int scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; - const int scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; - const int scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; - const int scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; + const size_t scales_offset_Y_rowwise = scales_block_offset_Y_rowwise + tid_Y_rowwise; + const size_t scales_offset_X_rowwise = scales_block_offset_X_rowwise + tid_X_rowwise; + const size_t scales_offset_Y_colwise = scales_block_offset_Y_colwise + tid_Y_colwise; + const size_t scales_offset_X_colwise = scales_block_offset_X_colwise + tid_X_colwise; // helps resolving bank conflicts in shmem const int thread_lane = threadIdx.x % THREADS_PER_WARP; @@ -139,7 +139,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) OType *out_colwise_sh = reinterpret_cast(dshmem + in_mem + out_mem_rowwise); IType *cached_act_sh = in_sh; // in_sh is used as a cache buffer - constexpr int shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; + constexpr size_t shmem_buff_size = buff_size_aligned_in / BUFFS_NUM; const bool is_master_thread = (threadIdx.x == 0); @@ -173,20 +173,20 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int stage = 0; stage < STAGES; ++stage) { - const int buff = stage % BUFFS_NUM; - const int next_stage = stage + 1; - const int stage_offset_Y = stage * BUFF_DIM_Y; + const size_t buff = stage % BUFFS_NUM; + const size_t next_stage = stage + 1; + const size_t stage_offset_Y = stage * BUFF_DIM_Y; if (next_stage < STAGES) { // Wait for TMA transfer to have finished reading shared memory. // I.e. 
the buffer is ready to be written to ptx::cp_async_bulk_wait_group_read<1>(); - const int next_buff = next_stage % BUFFS_NUM; - const int next_stage_offset_Y = next_stage * BUFF_DIM_Y; - const int global_offset_Y = block_offset_Y + next_stage_offset_Y; - const int global_offset_X = block_offset_X; - const int next_buff_offset = next_buff * BUFF_DIM; + const size_t next_buff = next_stage % BUFFS_NUM; + const size_t next_stage_offset_Y = next_stage * BUFF_DIM_Y; + const size_t global_offset_Y = block_offset_Y + next_stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t next_buff_offset = next_buff * BUFF_DIM; if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[next_buff_offset], &tensor_map_input, global_offset_X, global_offset_Y, &act_in_sh[next_buff_offset], &tensor_map_act_input, @@ -205,7 +205,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { - const int shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; + const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; thread_amax = 0.0f; float in_compute_colwise[BUFF_DIM_Y]; IType in_colwise_IType[BUFF_DIM_Y]; @@ -215,7 +215,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType thread_amax_f16 = static_cast(0.0f); #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; in_colwise_IType[i] = in_sh[shmem_offset_colwise]; thread_amax_f16 = __hmax(thread_amax_f16, __habs(in_colwise_IType[i])); } @@ -223,7 +223,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } else { #pragma unroll for (int i = 0; i < BUFF_DIM_Y; ++i) { - const int shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; + const size_t shmem_offset_colwise = shmem_offset_base_colwise + i * BUFF_DIM_X; float elt = static_cast(in_sh[shmem_offset_colwise]); if constexpr (IS_ACT) { @@ -263,9 +263,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int global_scales_offset_Y = scales_offset_Y_colwise + stage; - const int global_scales_offset_X = scales_offset_X_colwise; - const int scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; + const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage; + const size_t global_scales_offset_X = scales_offset_X_colwise; + const size_t scale_idx = + global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; scales_colwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -282,13 +283,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const float scaled_out = in * block_scale_inverse; - const int shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; + const size_t shmem_offset_elt = shmem_offset_base_colwise + i * BUFF_DIM_X; out_colwise_sh[shmem_offset_elt] = static_cast(scaled_out); } } if constexpr (ROWWISE_SCALING) { - const int shmem_offset_base_rowwise = buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; + const size_t shmem_offset_base_rowwise = + buff * BUFF_DIM + thread_offset_Y_rowwise * BUFF_DIM_X; thread_amax = 0.0f; float in_compute_rowwise[SCALE_DIM_X]; Vec in_cached[WAVES]; @@ -301,9 +303,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType2 thread_amax_2x = {static_cast(0.0f), 
static_cast(0.0f)}; #pragma unroll for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; // Load elements in_IType[w].load_from(&in_sh[shmem_offset_rowwise]); #pragma unroll @@ -319,9 +321,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) IType2 thread_amax_2x = {static_cast(0.0f), static_cast(0.0f)}; #pragma unroll for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; const bool row_out_of_bounds_rowwise = (row_base_rowwise + stage_offset_Y >= rows); const bool swizzled_col_out_of_bounds = (block_offset_X + swizzled_thread_idx >= cols); @@ -354,9 +356,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } else { #pragma unroll for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; - const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_thread_idx = thread_offset_X_rowwise + swizzled_group_idx; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_thread_idx; Vec in; Vec act_in; @@ -406,9 +408,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // 2. 
Compute E8M0 scaling factor const e8m0_t biased_exponent = ptx::float_to_e8m0(thread_amax * Quantized_Limits::max_norm_rcp); - const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; - const int stage_scales_offset_X = scales_offset_X_rowwise; - const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; + const size_t stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y; + const size_t stage_scales_offset_X = scales_offset_X_rowwise; + const size_t scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X; scales_rowwise[scale_idx] = biased_exponent; const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); @@ -434,9 +436,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } ptx::mul_cvt_2x(out_pair, in, block_scale_inverse_2x); } - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; - const int shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_idx = swizzled_group_idx + thread_offset_X_rowwise; + const size_t shmem_offset_rowwise = shmem_offset_base_rowwise + swizzled_idx; out.store_to(&out_rowwise_sh[shmem_offset_rowwise]); } } @@ -452,9 +454,9 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int global_offset_Y = block_offset_Y + stage_offset_Y; - const int global_offset_X = block_offset_X; - const int buff_offset = buff * BUFF_DIM; + const size_t global_offset_Y = block_offset_Y + stage_offset_Y; + const size_t global_offset_X = block_offset_X; + const size_t buff_offset = buff * BUFF_DIM; if constexpr (ROWWISE_SCALING) { ptx::cp_async_bulk_tensor_2d_shared_to_global( @@ -485,18 +487,18 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Added extra 1-element padding per thread_X to reduce bank conflicts float *partial_dbias_rowwise = reinterpret_cast(dshmem); - constexpr int DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); + constexpr size_t DBIAS_BUFF_WIDTH = THREADS_X * (SCALE_DIM_X + 1); - const int shmem_thread_offset = + const size_t shmem_thread_offset = tid_Y_rowwise * DBIAS_BUFF_WIDTH + tid_X_rowwise * (SCALE_DIM_X + 1); #pragma unroll for (int w = 0; w < WAVES; ++w) { - const int swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; - const int swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; + const size_t swizzled_group_idx = ((w + bank_group) * PACK_SIZE) % SCALE_DIM_X; + const size_t swizzled_group_offset = shmem_thread_offset + swizzled_group_idx; #pragma unroll for (int e = 0; e < PACK_SIZE; ++e) { const int j = w * PACK_SIZE + e; - const int shmem_elt_idx = swizzled_group_offset + e; + const size_t shmem_elt_idx = swizzled_group_offset + e; partial_dbias_rowwise[shmem_elt_idx] = thread_dbias_rowwise[j]; } } @@ -504,15 +506,15 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) #pragma unroll for (int i = 0; i < THREADS_Y; ++i) { // Add extra element offset per MXFP8 scaling block [1x32] - const int scaling_block = threadIdx.x / SCALE_DIM_X; + const size_t scaling_block = threadIdx.x / SCALE_DIM_X; thread_partial_dbias += partial_dbias_rowwise[i * DBIAS_BUFF_WIDTH + threadIdx.x + scaling_block]; } } - const int dbias_stride = cols; - const int dbias_offset_Y = blockIdx.y; - const int dbias_offset_X = blockIdx.x 
* CHUNK_DIM_X + threadIdx.x; - const int dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; + const size_t dbias_stride = cols; + const size_t dbias_offset_Y = blockIdx.y; + const size_t dbias_offset_X = blockIdx.x * CHUNK_DIM_X + threadIdx.x; + const size_t dbias_idx = dbias_offset_Y * dbias_stride + dbias_offset_X; const bool col_out_of_bounds_dbias = (dbias_offset_X >= cols); if (!col_out_of_bounds_dbias) { dbias_workspace[dbias_idx] = thread_partial_dbias; @@ -561,19 +563,19 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) const size_t cols) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; - const int block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; + const size_t block_offset_Y = blockIdx.y * FP8_CHUNK_DIM_Y; + const size_t block_offset_X = blockIdx.x * FP8_CHUNK_DIM_X; - const int tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; - const int tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; + const size_t tid_Y = threadIdx.x / FP8_THREADS_PER_CHUNK; + const size_t tid_X = threadIdx.x % FP8_THREADS_PER_CHUNK; - const int thread_offset_Y = tid_Y; - const int thread_offset_X = tid_X; + const size_t thread_offset_Y = tid_Y; + const size_t thread_offset_X = tid_X; - const int dbias_offset_Y = blockIdx.y + tid_Y; - const int my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; + const size_t dbias_offset_Y = blockIdx.y + tid_Y; + const size_t my_column = blockIdx.x * FP8_CHUNK_DIM_X + thread_offset_X; const bool col_out_of_bounds = my_column >= cols; - const int dbias_stride = cols; + const size_t dbias_stride = cols; float partial_dbias = 0.f; @@ -588,7 +590,7 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[FP8_BUFFERS_NUM][FP8_SHMEM_DIM_Y][FP8_SHMEM_DIM_X]; - constexpr int shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; + constexpr size_t shmem_buff_size = sizeof(in_sh) / FP8_BUFFERS_NUM; const bool is_master_thread = (threadIdx.x == 0); @@ -600,13 +602,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) int parity = 0; - const int chunk_offset_Y = block_offset_Y; - const int chunk_offset_X = block_offset_X; + const size_t chunk_offset_Y = block_offset_Y; + const size_t chunk_offset_X = block_offset_X; #pragma unroll for (int prefetch_buff = 0; prefetch_buff < FP8_PREFETCH_BUFFERS_NUM; ++prefetch_buff) { - const int chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; - const int chunk_stage_offset_X = chunk_offset_X; + const size_t chunk_stage_offset_Y = chunk_offset_Y + prefetch_buff * FP8_BUFFER_DIM_Y; + const size_t chunk_stage_offset_X = chunk_offset_X; if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[prefetch_buff], &tensor_map_input, chunk_stage_offset_X, chunk_stage_offset_Y, &act_in_sh[prefetch_buff], &tensor_map_act_input, @@ -621,13 +623,13 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) #pragma unroll for (int iter = 0; iter < FP8_ITERATIONS; ++iter) { - const int buff = iter % FP8_BUFFERS_NUM; - const int next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; + const size_t buff = iter % FP8_BUFFERS_NUM; + const size_t next_iter = iter + FP8_PREFETCH_BUFFERS_NUM; const size_t row_base = block_offset_Y + iter * FP8_BUFFER_DIM_Y; if (next_iter < FP8_ITERATIONS) { - const int next_buff = next_iter % FP8_BUFFERS_NUM; - const int chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t next_buff = next_iter % 
FP8_BUFFERS_NUM; + const size_t chunk_it_offset_y = chunk_offset_Y + next_iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; if constexpr (IS_DACT) { copy_2d_to_sharedx2(&in_sh[next_buff], &tensor_map_input, chunk_it_offset_x, chunk_it_offset_y, &act_in_sh[next_buff], &tensor_map_act_input, @@ -644,9 +646,9 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) #pragma unroll for (int stage = 0; stage < FP8_BUFF_STAGES_NUM; ++stage) { - const int stage_offset_Y = stage; - const int shmem_offset_y = thread_offset_Y + stage_offset_Y; - const int shmem_offset_x = thread_offset_X; + const size_t stage_offset_Y = stage; + const size_t shmem_offset_y = thread_offset_Y + stage_offset_Y; + const size_t shmem_offset_x = thread_offset_X; const size_t row = row_base + shmem_offset_y; const bool row_out_of_bounds = row >= rows; const bool out_of_bounds = col_out_of_bounds || row_out_of_bounds; @@ -685,8 +687,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) // Initiate TMA transfer to copy shared memory to global memory if (is_master_thread) { - const int chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; - const int chunk_it_offset_x = chunk_offset_X; + const size_t chunk_it_offset_y = chunk_offset_Y + iter * FP8_BUFFER_DIM_Y; + const size_t chunk_it_offset_x = chunk_offset_X; ptx::cp_async_bulk_tensor_2d_shared_to_global( reinterpret_cast(&tensor_map_output), chunk_it_offset_x, chunk_it_offset_y, reinterpret_cast(&out_sh[buff])); @@ -704,8 +706,8 @@ __global__ void __launch_bounds__(FP8_THREADS_PER_CHUNK) parity ^= 1; if constexpr (IS_DBIAS) { - const int dbias_offset_X = my_column; - const int dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; + const size_t dbias_offset_X = my_column; + const size_t dbias_offset = dbias_offset_Y * dbias_stride + dbias_offset_X; if (!col_out_of_bounds) { dbias_workspace[dbias_offset] = partial_dbias; } @@ -747,7 +749,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) float *const scale_inv_ptr, const float *const scale_ptr, const size_t N) { #if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 1000) - const int block_offset = blockIdx.x * ELEMS_PER_BLOCK; + const size_t block_offset = blockIdx.x * ELEMS_PER_BLOCK; const IType *input = input_ptr + block_offset; OType *output = output_ptr + block_offset; @@ -758,8 +760,8 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) __shared__ alignas(TMA_SHMEM_ALIGNMENT) IType in_sh[SHMEM_BUFFERS][SHMEM_DIM]; __shared__ alignas(TMA_SHMEM_ALIGNMENT) OType out_sh[SHMEM_BUFFERS][SHMEM_DIM]; - constexpr int transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; - constexpr int transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_IN = sizeof(in_sh) / SHMEM_BUFFERS; + constexpr size_t transaction_size_OUT = sizeof(out_sh) / SHMEM_BUFFERS; const bool is_master_thread = (threadIdx.x == 0); @@ -775,12 +777,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #pragma unroll for (int iter = 0; iter < ITERATIONS; ++iter) { - const int buff = iter % SHMEM_BUFFERS; - const int it_offset = iter * SHMEM_DIM; + const size_t buff = iter % SHMEM_BUFFERS; + const size_t it_offset = iter * SHMEM_DIM; - const int next_iter = iter + 1; - const int next_buff = next_iter % SHMEM_BUFFERS; - const int next_iter_offset = next_iter * SHMEM_DIM; + const size_t next_iter = iter + 1; + const size_t next_buff = next_iter % SHMEM_BUFFERS; + const size_t next_iter_offset = next_iter * SHMEM_DIM; if (next_iter < ITERATIONS) { 
copy_1d_to_shared(&(in_sh[next_buff]), input + next_iter_offset, transaction_size_IN, @@ -794,7 +796,7 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) #pragma unroll for (int chunk = 0; chunk < CHUNKS_PER_ITERATION; ++chunk) { - const int shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; + const size_t shmem_offset = chunk * CHUNK_SIZE + threadIdx.x; float elt = static_cast(in_sh[buff][shmem_offset]); if constexpr (IS_ACT) { elt = OP(elt, {}); @@ -847,12 +849,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) constexpr size_t DBIAS_THREADS_PER_BLOCK = 256; template __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) - reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, const int rows, - const int cols) { + reduce_dbias_kernel(OType *const dbias_output, const float *const dbias_partial, + const size_t rows, const size_t cols) { using ComputeVec = Vec; using OutputVec = Vec; - const int thread_id = blockIdx.x * blockDim.x + threadIdx.x; + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; if (thread_id * nvec >= cols) { return; @@ -883,8 +885,8 @@ __global__ void __launch_bounds__(DBIAS_THREADS_PER_BLOCK) template void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, const size_t cols, cudaStream_t stream) { - constexpr int reduce_dbias_store_bytes = 8; // stg.64 - constexpr int reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); NVTE_CHECK(cols % reduce_dbias_nvec == 0, "Unsupported shape."); const size_t reduce_dbias_num_blocks = DIVUP(cols, DBIAS_THREADS_PER_BLOCK * reduce_dbias_nvec); @@ -1244,8 +1246,8 @@ static bool is_full_tile_1D_tensor(const Tensor *const t) { bool dimensions_supported_by_TMA(const Tensor *const t) { const size_t cols = t->flat_last_dim(); - constexpr int TMA_bytes = 16; - const int alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); + constexpr size_t TMA_bytes = 16; + const size_t alignment_requirement = (TMA_bytes * 8) / typeToNumBits(t->dtype()); return cols % alignment_requirement == 0; } From 25a82192b738281ed05a128d945f72be584667c5 Mon Sep 17 00:00:00 2001 From: Alp Dener Date: Thu, 24 Jul 2025 18:21:27 -0500 Subject: [PATCH 019/153] [JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parallelism correctly for sequence-parallel inputs (#1980) * updated GemmPrimitive partitioning rules to explicitly control all-reduce vs. 
reduce-scatter for sequence-parallelism Signed-off-by: Alp Dener * corrected handling of FSDP sharding for the RHS operand Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * use correct logical axes variable to identify sequence-parallel dim in LayerNormDenseGeneral Signed-off-by: Alp Dener * fixed linting issues Signed-off-by: Alp Dener * added assert on sequence-parallel options when GemmPrimitive is disabled Signed-off-by: Alp Dener * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Alp Dener Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/jax/cpp_extensions/gemm.py | 230 ++++++++++-------- transformer_engine/jax/dense.py | 66 ++++- transformer_engine/jax/flax/module.py | 4 + transformer_engine/jax/flax/transformer.py | 1 + transformer_engine/jax/layernorm_dense.py | 6 + transformer_engine/jax/layernorm_mlp.py | 12 +- transformer_engine/jax/sharding.py | 48 +++- 7 files changed, 257 insertions(+), 110 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c4c744643..d2e65d265 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15) + impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) inner_primitive = None outer_primitive = None @@ -177,8 +177,14 @@ def abstract( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator + del ( + sequence_parallel_output, + sequence_dim, + ) def _dims_are_consecutive(dims): if len(dims) <= 1: @@ -343,8 +349,12 @@ def lowering( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype + del sequence_parallel_output, sequence_dim + lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -393,6 +403,8 @@ def impl( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -430,6 +442,8 @@ def impl( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ) return outputs[:-3] # discard workspace arrays @@ -447,6 +461,8 @@ def batcher( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, ): assert GemmPrimitive.outer_primitive is not None lhs, _, rhs, *_ = batched_args @@ -489,6 +505,8 @@ def batcher( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -510,7 +528,13 @@ def _decompose_operand_specs(specs, contracting_dims, batch_dims): return bspecs, lspecs, cspecs @staticmethod - def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): + def 
_parse_operand_output_specs( + arg_infos, + contracting_dims, + batched_dims, + sequence_parallel_output, + sequence_dim, + ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( @@ -556,96 +580,66 @@ def _parse_operand_output_specs(arg_infos, contracting_dims, batched_dims): ) # Extract single leading and contracting dimension specs - (lhs_lspec, rhs_lspec, lhs_cspec, rhs_cspec) = map( + (lhs_cspec, rhs_cspec) = map( lambda specs: None if len(specs) == 0 else specs[0], - (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none), + (lhs_cspec_not_none, rhs_cspec_not_none), ) - # Reproducing jax.nn.scaled_matmul() custom partitioning for arbitrary GEMM layouts - # with row-wise LHS:(B, M, K1) and row-wise RHS:(B, N, K2) operands. - # 1. K1 == K2 != None and N == None - # LHS: (B, M, K) - # RHS: (B, None, K) - # OUT: (B, M, None) --(AR)-> (B, M, None) - # 2. K1 == K2 != None and M == N != None - # LHS: (B, M, K) - # RHS: (B, N, K)--(AG)->(B, None, K) - # OUT: (B, M, None) --(RS)--> (B, M, N) - # 3. M == N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, M, K)--(AG)->(B, None, None) - # OUT: (B, M, None) - # 4. M != N - # LHS: (B, M, K)--(AG)->(B, M, None) - # RHS: (B, N, K)--(AG)->(B, N, None) - # OUT: (B, M, N) - reduce_flag = lhs_cspec is not None and lhs_cspec == rhs_cspec - all_reduce_output = reduce_flag and rhs_lspec is None - reduce_scatter_output = reduce_flag and lhs_lspec is not None and lhs_lspec == rhs_lspec - all_reduce_spec = reduce_scatter_spec = scatter_dim = None + # Partitioning rules: + # ([B], M, K1) x ([B], N, K2)^T = ([B], M, N) + # 1. K1 == K2 != None + # - Require non-batched non-contracting dims of both LHS and RHS to be unsharded. + # - If `sequence_parallel_output=True`, then reduce-scatter the output. + # - Otherwise, all-reduce the output. + # 2. Otherwise + # - Require contracting dimensions of both LHS and RHS to be unsharded. + # - Require non-batched non-contracting dims of LHS to be unsharded. + reduce_output = rhs_cspec is not None and lhs_cspec == rhs_cspec + reduce_spec = scatter_dim = None + if reduce_output: + reduce_spec = rhs_cspec + if sequence_parallel_output: + # If the sequence dimension is not specified, assume it to be the first + # non-batched non-contracting dimension of the LHS operand. + scatter_dim = sequence_dim if sequence_dim is not None else lhs_ldims[0] + + # Always require the non-batched non-contracting dims of LHS to be unsharded + # NOTE: This will all-gather sequence-parallel inputs and preserve tensor-parallel params. + lhs_specs = tuple( + lhs_specs[i] if i in set(lhs_bdims + lhs_cdims) else None for i in range(lhs_ndim) + ) + if reduce_output: + # When reducing GEMM output, require non-batched non-contracting dims of the RHS + # operand to be unsharded (i.e. 
FSDP) + rhs_specs = tuple( + None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] + for i in range(rhs_ndim) + ) + else: + # Otherwise, require contracting dims of both operands to be unsharded + lhs_specs = tuple(None if i in lhs_cdims else lhs_specs[i] for i in range(lhs_ndim)) + rhs_specs = tuple(None if i in rhs_cdims else rhs_specs[i] for i in range(rhs_ndim)) + # Combine modified LHS and RHS specs into the output lhs_non_contracting_specs, rhs_non_contracting_specs = map( lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), (lhs_specs, rhs_specs), (lhs_cdims, rhs_cdims), ) - out_specs = (*lhs_non_contracting_specs, *rhs_non_contracting_specs) - if reduce_scatter_output: - # All-gather (if necessary) the non-batch non-contracting dimension of RHS - # (B, N, K) --(AG)-> (B, None, K) - # (B, M, K) x (B, None, K)^T = (B, M, None) --(RS)-> (B, M, N) - rhs_spec = tuple( - rhs_spec[i] if i in set(rhs_bdims + rhs_cdims) else None for i in range(rhs_ndim) - ) - reduce_scatter_spec = lhs_cspec - scatter_dim = out_specs.index(rhs_lspec) - - elif all_reduce_output: - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) - all_reduce_spec = lhs_cspec - else: - # All-gather (if necessary) the non-batch contracting dimensions - # (B, M, K) --(AG)-> (B, M, None) - # (B, N, K) --(AG)-> (B, N, None) - # (B, M, None) x (B, N, None)^T = (B, M, N) - lhs_specs = tuple( - None if i in lhs_cdims and i not in lhs_bdims else lhs_specs[i] - for i in range(lhs_ndim) - ) - rhs_specs = tuple( - None if i in rhs_cdims and i not in rhs_bdims else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Check if RHS non-contracting spec also appears in the LHS non-contracting specs - if rhs_lspec is not None and rhs_lspec in tuple( - lhs_specs[i] for i in range(lhs_ndim) if i not in lhs_cdims - ): - # All-gather (if necessary) the non-batch non-contracting dimensions of RHS - # (B, N, None) --(AG)-> (B, None, None) - # (B, M, None) x (B, None, None)^T = (B, M, None) - rhs_specs = tuple( - None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] - for i in range(rhs_ndim) - ) - # Set all output trailing dimensions to zero - out_specs = ( - *lhs_non_contracting_specs, - *[None for _ in range(len(rhs_non_contracting_specs))], - ) + out_specs = [*lhs_non_contracting_specs, *rhs_non_contracting_specs] - # Bias and Pre-GeLU sharding is based on GEMM output - bias_specs = out_specs[len(lhs_non_contracting_specs) :] - gelu_specs = out_specs + # Bias and Pre-GeLU sharding is based on GEMM output before any scatter + bias_specs = tuple(list(out_specs[len(lhs_non_contracting_specs) :]).copy()) + gelu_specs = tuple(list(out_specs).copy()) + + # Set output scatter dim to the tensor-parallel spec + if sequence_parallel_output: + out_specs[scatter_dim] = reduce_spec return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), - all_reduce_spec, - reduce_scatter_spec, + reduce_spec, scatter_dim, ) @@ -661,6 +655,8 @@ def infer_sharding_from_operands( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, mesh, arg_infos, result_infos, @@ -675,7 +671,13 @@ def infer_sharding_from_operands( del use_split_accumulator, result_infos (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + GemmPrimitive._parse_operand_output_specs( + arg_infos, + 
contracting_dims, + batched_dims, + sequence_parallel_output, + sequence_dim, + ) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -703,6 +705,8 @@ def partition( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, mesh, arg_infos, result_infos, @@ -712,10 +716,15 @@ def partition( ( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), - all_reduce_spec, - reduce_scatter_spec, + reduce_spec, scatter_dim, - ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims, batched_dims) + ) = GemmPrimitive._parse_operand_output_specs( + arg_infos, + contracting_dims, + batched_dims, + sequence_parallel_output, + sequence_dim, + ) # Assemble argument shardings # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. @@ -770,20 +779,17 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ) # All-Reduce/Reduce-Scatter GEMM output - if all_reduce_spec is not None: - outputs[0] = jax.lax.psum(outputs[0], all_reduce_spec) - if fuse_gelu and not grad: - outputs[2] = jax.lax.psum(outputs[2], all_reduce_spec) - elif reduce_scatter_spec is not None: - outputs[0] = jax.lax.psum_scatter( - outputs[0], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True - ) - if fuse_gelu and not grad: - outputs[2] = jax.lax.psum_scatter( - outputs[2], reduce_scatter_spec, scatter_dimension=scatter_dim, tiled=True + if reduce_spec is not None: + if scatter_dim is None: + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) + else: + outputs[0] = jax.lax.psum_scatter( + outputs[0], reduce_spec, scatter_dimension=scatter_dim, tiled=True ) return outputs @@ -802,12 +808,14 @@ def shardy_sharding_rule( fuse_gelu, grad, use_split_accumulator, + sequence_parallel_output, + sequence_dim, mesh, operand_types, result_types, ): del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator - del mesh, result_types + del sequence_parallel_output, sequence_dim, mesh, result_types prefix = "GemmPrimitive_" @@ -896,6 +904,8 @@ def _te_gemm( fuse_gelu: bool = False, grad: bool = False, use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, + sequence_parallel_output: bool = False, + sequence_dim: int = None, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -969,6 +979,8 @@ def _te_gemm( fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, + sequence_parallel_output=sequence_parallel_output, + sequence_dim=sequence_dim, ) @@ -1307,9 +1319,9 @@ def gemm( Tuple of sequences representing the contracting dimensions of the operands. batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), Tuple of sequences representing the batched dimensions of the operands. This is *not* used - to perform a batched matrix multiplication, but it is required to avoid a potentially - undesirable reduction in any batched contracting dimensions when invoked with sharded - operands (e.g. when computing weight gradients in a Flax module). + to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM + call to avoid a potentially undesirable reduction in any batched contracting dimensions + when invoked with sharded operands (e.g. when computing weight gradients in a Flax module). 
bias: jax.Array, default = None Optional additive bias term, required for forward GEMM with bias fusion. Only supported with TE's custom call to cuBLAS GEMM. @@ -1327,7 +1339,17 @@ def gemm( TE's custom call to cuBLAS GEMM. use_split_accumulator: bool, default = True Enable promoting some intermediate sums to higher precision when accumulating the result in - the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. + the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only + supported with TE's custom call to cuBLAS GEMM. + sequence_parallel_output: bool, default = False + Produces an output with the first non-batched non-contracting dimension sharded with the + same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum` + for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to + cuBLAS GEMM. + sequence_dim: int, default = None + Index of the sequence dimension for the LHS operand. This controls which dimension of the + GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first + non-batched non-contracting dimension is assumed to be the sequence dimension. Returns ------- @@ -1358,12 +1380,20 @@ def gemm( if not GemmPrimitive.enabled(): assert kwargs.get("bias", None) is None and not fuse_gelu, ( "TE GEMM was invoked with bias fusion options that are not supported by the " - "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) assert kwargs.get("gelu_input", None) is None and not fuse_bias, ( "TE GEMM was invoked with GeLU fusion options that are not supported by the " - "`jax.lax.dot_general` and `jnp.scaled_matmul` backends used when the custom cuBLAS " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " + "GEMM primitive is disabled." + ) + assert ( + not kwargs.get("sequence_parallel_output", False) + and kwargs.get("sequence_dim", None) is None + ), ( + "TE GEMM was invoked with sequence-parallelism options that are not supported by the " + "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backedns used when the custom cuBLAS " "GEMM primitive is disabled." ) return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index a0fc7b7af..5be551dbd 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -22,6 +22,7 @@ TensorUsage, ) +from .sharding import get_sequence_parallel_dim DENSE_BATCH_FIRST_WARNING_ISSUED = False @@ -41,6 +42,7 @@ def dense( input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, batch_first: bool = True, + sequence_parallel_output: bool = False, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -55,6 +57,8 @@ def dense( bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract batch_first: Assume that X is batched in the first dimension. + sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only + supported for TE custom GEMM with row-parallel kernel axes. 
quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -69,13 +73,31 @@ def dense( output += jnp.reshape(bias, bias_new_shape) else: output = _dense( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + sequence_parallel_output, + quantizer_set, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6)) -def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set): +@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) +def _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + sequence_parallel_output, + quantizer_set, +): """Internal implementation of dense layer transformation with custom VJP. This function implements the core dense layer transformation logic with support @@ -88,20 +110,38 @@ def _dense(x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_fir contracting_dims: Contracting dimensions specification input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix - quantizer_set: QuantizerSet which contains quantizers for different tensor types batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. + sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only + supported for TE custom GEMM with row-parallel kernel axes. + quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: Transformed output tensor """ output, _ = _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + sequence_parallel_output, + quantizer_set, ) return output def _dense_fwd_rule( - x, kernel, bias, contracting_dims, input_axes, kernel_axes, batch_first, quantizer_set + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + batch_first, + sequence_parallel_output, + quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -161,6 +201,7 @@ def _dense_fwd_rule( batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, + sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(), ) if use_bias and tex.gemm_uses_jax_dot(): @@ -181,7 +222,7 @@ def _dense_fwd_rule( def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, batch_first, ctx, grad + contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. 
@@ -220,11 +261,22 @@ def _dense_bwd_rule( k_contracting_dim = tuple( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) + + # Get sequence-parallel dimension of the FWD input (if it exists) + sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,)) dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), batched_dims=((x_bdim,), ()), + sequence_parallel_output=( + sequence_dim is not None + and not sequence_parallel_output + and not tex.gemm_uses_jax_dot() + ), + sequence_dim=( + None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim + ), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 5992d3607..6670377f7 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -415,6 +415,8 @@ class DenseGeneral(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + sequence_parallel_output: bool, default = False + Produce a sequence-parallel output with the first non-batch dimension sharded over Optimization parameters ----------------------- @@ -439,6 +441,7 @@ class DenseGeneral(TransformerEngineBase): dtype: DType = jnp.float32 transpose_batch_sequence: bool = False input_axes: Tuple[str, ...] = () + sequence_parallel_output: bool = False def __post_init__(self): if self.transpose_batch_sequence: @@ -511,6 +514,7 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, + sequence_parallel_output=self.sequence_parallel_output, ) if self.enable_low_rank_adaptation: diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index f2c0bc2a1..5f309820c 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1425,6 +1425,7 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, name="out", + sequence_parallel_output=self.enable_sequence_parallel, )(x) out = checkpoint_name(out, "out_proj") diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index 5ccfc71c2..c616aa699 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -24,6 +24,7 @@ with_sharding_constraint_by_logical_axes, TensorUsage, ) +from .sharding import get_sequence_parallel_dim LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False @@ -324,11 +325,16 @@ def _layernorm_dense_bwd_rule( ) # NT GEMM + sequence_dim = get_sequence_parallel_dim( + layernorm_input_axes, x_contracting_dims_in_fwd, (x_bdim,) + ) dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, contracting_dims=(g_constracting_dim, k_constracting_dim), batched_dims=((x_bdim,), ()), + sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), + sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 507c49c7e..8dd045100 100644 --- 
a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -29,7 +29,10 @@ noop_quantizer_set, TensorUsage, ) -from .sharding import get_non_contracting_logical_axes +from .sharding import ( + get_non_contracting_logical_axes, + get_sequence_parallel_dim, +) LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False @@ -342,6 +345,7 @@ def _layernorm_mlp_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) + sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,)) dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), @@ -349,6 +353,8 @@ def _layernorm_mlp_fwd_rule( batched_dims=((x_bdim,), ()), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, + sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), + sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -377,6 +383,7 @@ def _layernorm_mlp_fwd_rule( use_bias_2, quantizer_sets, x_bdim, + sequence_dim, ) return dot_2_output, ctx @@ -431,6 +438,7 @@ def _layernorm_mlp_bwd_rule( use_bias_2, quantizer_sets, x_bdim, + sequence_dim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -501,6 +509,8 @@ def _layernorm_mlp_bwd_rule( casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), batched_dims=((x_bdim,), ()), + sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), + sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index e59c9de12..a7bbef997 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -86,17 +86,61 @@ def get_sharding_map_logic_axis_to_mesh_axis(): return te_logical_axis_to_mesh_axis -def generate_pspec(logical_axis_names): +def get_sequence_parallel_dim(logical_axes, contracting_dims, batch_dims): + """ + Get the index for the sequence-parallel dimension based on the given logical axes. + + The sequence-parallel dimension is assumed to be the only sharded non-batched non-contracting + dimension. + """ + if not logical_axes: + return None + + pspec = generate_pspec(logical_axes, with_flax_rules=True, padded=True) + ldims = [i for i in range(len(logical_axes)) if i not in set(contracting_dims + batch_dims)] + lspecs = [pspec[i] for i in ldims if pspec[i] is not None] + if len(lspecs) == 0: + return None + + assert len(lspecs) == 1, ( + "Expected only 1 non-batched non-contracting dimension to be sharded for " + f"sequence-parallelism, but found {len(lspecs)}: {pspec} @ idx {ldims}" + ) + + return pspec.index(lspecs[0]) + + +def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False): """ Convert logical axes to PartitionSpec """ - rules = get_sharding_map_logic_axis_to_mesh_axis() + rules = None + if with_flax_rules: + try: + import flax + + rules = dict(flax.linen.get_logical_axis_rules()) + except ImportError: + pass + + if rules is None: + warnings.warn( + "Transformer Engine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated" + " and removed in a future version. 
Please use Flax logical axes with the" + " `flax.linen.logical_axis_rules()` context and optionally use" + " `transformer_engine.jax.flax.extend_logical_axis_rules()` to extend Flax axis rules" + " with Transformer Engine logical axes.", + DeprecationWarning, + ) + rules = get_sharding_map_logic_axis_to_mesh_axis() # mesh_axis_names = [rules[name] for name in logical_axis_names] mesh_axis_names = [] for name in logical_axis_names: axis_name = rules[name] if name in rules else None mesh_axis_names.append(axis_name) pspec = jax.sharding.PartitionSpec(*mesh_axis_names) + if padded: + pspec = get_padded_spec(pspec, len(mesh_axis_names)) return pspec From e950ceb0ad5be6997a71f0e0c10c9e4a3786d692 Mon Sep 17 00:00:00 2001 From: buptzyb Date: Fri, 25 Jul 2025 13:20:29 +0800 Subject: [PATCH 020/153] [PyTorch] Optimize cudagraph static_grad_outputs reuse (#1992) * optimize static grad outputs Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/graph.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 4a2b2c61c..432a47985 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -422,7 +422,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_static_grad_inputs = [None] * len(flatten_sample_args) fwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks - static_grad_outputs = None + static_grad_outputs_dict = {} previous_per_callable_bwd_idx = None for c_id in _order: if c_id > 0: @@ -454,9 +454,21 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument static_outputs = per_callable_static_outputs[per_callable_bwd_idx] bwd_graph = bwd_graphs[per_callable_bwd_idx] # For now, assumes all static_outputs require grad - if not _reuse_graph_input_output_buffers or static_grad_outputs is None: + if _reuse_graph_input_output_buffers: # Note for _reuse_graph_input_output_buffers: grad output is only used # within backward, so we can reuse the same static buffers every time. 
+ static_grad_outputs_keys = tuple( + (o.shape, o.dtype, o.layout) for o in static_outputs if o.requires_grad + ) + if static_grad_outputs_keys in static_grad_outputs_dict: + static_grad_outputs = static_grad_outputs_dict[static_grad_outputs_keys] + else: + static_grad_outputs = tuple( + torch.empty_like(o) if o.requires_grad else None + for o in static_outputs + ) + static_grad_outputs_dict[static_grad_outputs_keys] = static_grad_outputs + else: static_grad_outputs = tuple( torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) From 374849e35575cc6d55677546011e8065910aced9 Mon Sep 17 00:00:00 2001 From: Evgeny Tsykunov Date: Fri, 25 Jul 2025 09:10:41 +0200 Subject: [PATCH 021/153] [PyTorch] Enable generic QK norm support (+ RMSNorm/LayerNorm) (#1966) * Support RMSNorm for QK Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * rms -> RMSNorm, l2 -> L2Normalization (align with current pattern) Signed-off-by: Evgeny * Support LayerNorm + init refactor Signed-off-by: Evgeny * Before/after RoPE Signed-off-by: Evgeny * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix pylint Signed-off-by: Evgeny --------- Signed-off-by: Evgeny Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_qk_norm.py | 217 +++++++++++++++--- .../pytorch/attention/multi_head_attention.py | 128 +++++++++-- transformer_engine/pytorch/transformer.py | 32 ++- 3 files changed, 311 insertions(+), 66 deletions(-) diff --git a/tests/pytorch/test_qk_norm.py b/tests/pytorch/test_qk_norm.py index 6f4e62f81..d45ec283c 100644 --- a/tests/pytorch/test_qk_norm.py +++ b/tests/pytorch/test_qk_norm.py @@ -8,10 +8,10 @@ import torch -@pytest.mark.parametrize("use_qk_norm", [False, True]) +@pytest.mark.parametrize("qk_norm_type", [None, "L2Normalization", "RMSNorm", "LayerNorm"]) @pytest.mark.parametrize("attention_type", ["self", "cross"]) @pytest.mark.parametrize("qk_norm_eps", [1e-6, 1e-5]) -def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None: +def test_qk_norm_functionality(qk_norm_type, attention_type, qk_norm_eps) -> None: """Test QK normalization functionality, module structure, and numerical behavior.""" hidden_size = 256 num_attention_heads = 8 @@ -22,25 +22,59 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None hidden_size=hidden_size, num_attention_heads=num_attention_heads, attention_type=attention_type, - use_qk_norm=use_qk_norm, + qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, bias=False, device="cuda", ).cuda() - # Check module structure based on use_qk_norm parameter - if use_qk_norm: - assert hasattr(mha, "qk_norm"), "Should have qk_norm module when use_qk_norm=True" - assert not hasattr(mha, "q_l2norm"), "Should not have separate q_l2norm module" - assert not hasattr(mha, "k_l2norm"), "Should not have separate k_l2norm module" - # Check that the module is L2Norm type - from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization - - assert isinstance( - mha.qk_norm, L2Normalization - ), "qk_norm should be an L2Normalization module" + # Check module structure based on qk_norm_type parameter + if qk_norm_type is not None: + assert mha.q_norm is not None, "Should have q_norm module when qk_norm_type is not None" + assert mha.k_norm is not None, "Should have k_norm module when qk_norm_type is not None" + + # Check that the modules are 
of the correct type + if qk_norm_type == "L2Normalization": + from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization + + assert isinstance( + mha.q_norm, L2Normalization + ), "q_norm should be an L2Normalization module" + assert isinstance( + mha.k_norm, L2Normalization + ), "k_norm should be an L2Normalization module" + # For L2 normalization, q_norm and k_norm should be the same instance (parameter-free) + assert ( + mha.q_norm is mha.k_norm + ), "q_norm and k_norm should be the same instance for L2 normalization" + + elif qk_norm_type == "RMSNorm": + from transformer_engine.pytorch.module.rmsnorm import RMSNorm + + assert isinstance(mha.q_norm, RMSNorm), "q_norm should be an RMSNorm module" + assert isinstance(mha.k_norm, RMSNorm), "k_norm should be an RMSNorm module" + # For RMS normalization, q_norm and k_norm should be separate instances + assert ( + mha.q_norm is not mha.k_norm + ), "q_norm and k_norm should be separate instances for RMS normalization" + + elif qk_norm_type == "LayerNorm": + from transformer_engine.pytorch.module.layernorm import LayerNorm + + assert isinstance(mha.q_norm, LayerNorm), "q_norm should be a LayerNorm module" + assert isinstance(mha.k_norm, LayerNorm), "k_norm should be a LayerNorm module" + # For LayerNorm, q_norm and k_norm should be separate instances + assert ( + mha.q_norm is not mha.k_norm + ), "q_norm and k_norm should be separate instances for LayerNorm" + + else: + # For extensibility - just ensure they exist + assert mha.q_norm is not None, f"q_norm should exist for qk_norm_type={qk_norm_type}" + assert mha.k_norm is not None, f"k_norm should exist for qk_norm_type={qk_norm_type}" else: - assert not hasattr(mha, "qk_norm"), "Should not have qk_norm module when use_qk_norm=False" + assert mha.q_norm is None, "Should not have q_norm module when qk_norm_type is None" + assert mha.k_norm is None, "Should not have k_norm module when qk_norm_type is None" # Create input tensors batch_size = 2 # Use a fixed batch size for testing @@ -89,17 +123,14 @@ def test_qk_norm_functionality(use_qk_norm, attention_type, qk_norm_eps) -> None assert not torch.isinf(output_with_rope).any(), "RoPE output contains Inf" -def test_qk_norm_output_difference() -> None: +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_output_difference(qk_norm_type) -> None: """Test that QK normalization actually changes the output compared to no normalization.""" hidden_size = 256 num_attention_heads = 8 seq_len = 128 batch_size = 2 - # Use same random seed to ensure identical weight initialization - current_rng_state = torch.get_rng_state() - current_cuda_rng_state = torch.cuda.get_rng_state() - # Reset to a known seed for reproducible initialization torch.manual_seed(42) torch.cuda.manual_seed(42) @@ -108,7 +139,7 @@ def test_qk_norm_output_difference() -> None: mha_with_norm = MultiheadAttention( hidden_size=hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=True, + qk_norm_type=qk_norm_type, bias=False, device="cuda", ).cuda() @@ -121,7 +152,7 @@ def test_qk_norm_output_difference() -> None: mha_no_norm = MultiheadAttention( hidden_size=hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=False, + qk_norm_type=None, bias=False, device="cuda", ).cuda() @@ -139,10 +170,11 @@ def test_qk_norm_output_difference() -> None: # Outputs should be different when QK normalization is enabled assert not torch.allclose( output_with_norm, output_no_norm, atol=1e-6 - ), "QK 
normalization should change the output, but outputs are identical" + ), f"QK normalization ({qk_norm_type}) should change the output, but outputs are identical" -def test_qk_norm_with_fused_qkv() -> None: +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_with_fused_qkv(qk_norm_type) -> None: """Test QK normalization works with fused QKV parameters.""" hidden_size = 256 num_attention_heads = 8 @@ -152,7 +184,7 @@ def test_qk_norm_with_fused_qkv() -> None: hidden_size=hidden_size, num_attention_heads=num_attention_heads, fuse_qkv_params=True, - use_qk_norm=True, + qk_norm_type=qk_norm_type, bias=False, device="cuda", ).cuda() @@ -173,7 +205,8 @@ def test_qk_norm_with_fused_qkv() -> None: ), f"Output shape mismatch: {output.shape}" -def test_qk_norm_transformer_layer_output_difference() -> None: +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_transformer_layer_output_difference(qk_norm_type) -> None: """Test that QK normalization actually changes TransformerLayer output compared to no normalization.""" from transformer_engine.pytorch import TransformerLayer @@ -183,10 +216,6 @@ def test_qk_norm_transformer_layer_output_difference() -> None: seq_len = 128 batch_size = 2 - # Use same random seed to ensure identical weight initialization - current_rng_state = torch.get_rng_state() - current_cuda_rng_state = torch.cuda.get_rng_state() - # Reset to a known seed for reproducible initialization torch.manual_seed(42) torch.cuda.manual_seed(42) @@ -196,7 +225,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None: hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=True, + qk_norm_type=qk_norm_type, bias=False, device="cuda", ).cuda() @@ -210,7 +239,7 @@ def test_qk_norm_transformer_layer_output_difference() -> None: hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size, num_attention_heads=num_attention_heads, - use_qk_norm=False, + qk_norm_type=None, bias=False, device="cuda", ).cuda() @@ -226,9 +255,10 @@ def test_qk_norm_transformer_layer_output_difference() -> None: output_no_norm = transformer_no_norm(hidden_states) # Outputs should be different when QK normalization is enabled - assert not torch.allclose( - output_with_norm, output_no_norm, atol=1e-6 - ), "QK normalization should change the TransformerLayer output, but outputs are identical" + assert not torch.allclose(output_with_norm, output_no_norm, atol=1e-6), ( + f"QK normalization ({qk_norm_type}) should change the TransformerLayer output, but outputs" + " are identical" + ) # Check that outputs have expected shapes and properties assert output_with_norm.shape == ( @@ -240,3 +270,120 @@ def test_qk_norm_transformer_layer_output_difference() -> None: assert not torch.isinf(output_with_norm).any(), "Output with QK norm contains Inf" assert not torch.isnan(output_no_norm).any(), "Output without QK norm contains NaN" assert not torch.isinf(output_no_norm).any(), "Output without QK norm contains Inf" + + +@pytest.mark.parametrize("qk_norm_type", ["L2Normalization", "RMSNorm", "LayerNorm"]) +def test_qk_norm_before_after_rope(qk_norm_type) -> None: + """Test that QK normalization before and after RoPE works without errors.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 64 + batch_size = 2 + + # Create model with QK norm after RoPE (default) + mha_after = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + 
qk_norm_type=qk_norm_type, + qk_norm_before_rope=False, + bias=False, + device="cuda", + ).cuda() + + # Create model with QK norm before RoPE + mha_before = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + qk_norm_type=qk_norm_type, + qk_norm_before_rope=True, + bias=False, + device="cuda", + ).cuda() + + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Create RoPE embeddings + head_dim = hidden_size // num_attention_heads + rotary_dim = head_dim // 2 + rotary_pos_emb = torch.randn(seq_len, 1, 1, rotary_dim, device="cuda", dtype=torch.float32) + + with torch.no_grad(): + output_after_rope = mha_after(hidden_states, rotary_pos_emb=rotary_pos_emb) + output_before_rope = mha_before(hidden_states, rotary_pos_emb=rotary_pos_emb) + + output_after_no_rope = mha_after(hidden_states) + output_before_no_rope = mha_before(hidden_states) + + # Check output shapes and properties + expected_shape = (seq_len, batch_size, hidden_size) + for output in [ + output_after_rope, + output_before_rope, + output_after_no_rope, + output_before_no_rope, + ]: + assert output.shape == expected_shape, f"Output shape mismatch: {output.shape}" + assert not torch.isnan(output).any(), "Output contains NaN" + assert not torch.isinf(output).any(), "Output contains Inf" + + assert output_after_rope.shape == output_before_rope.shape, "Outputs should have same shape" + assert mha_after.qk_norm_before_rope == False, "mha_after should have qk_norm_before_rope=False" + assert mha_before.qk_norm_before_rope == True, "mha_before should have qk_norm_before_rope=True" + + +def test_different_qk_norm_types_produce_different_outputs() -> None: + """Test that different QK normalization types produce different outputs.""" + hidden_size = 256 + num_attention_heads = 8 + seq_len = 128 + batch_size = 2 + + # Use same random seed to ensure identical weight initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create model with L2 normalization + mha_l2 = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + qk_norm_type="L2Normalization", + bias=False, + device="cuda", + ).cuda() + + # Reset to same seed for identical initialization + torch.manual_seed(42) + torch.cuda.manual_seed(42) + + # Create model with RMS normalization + mha_rms = MultiheadAttention( + hidden_size=hidden_size, + num_attention_heads=num_attention_heads, + qk_norm_type="RMSNorm", + bias=False, + device="cuda", + ).cuda() + + # Create input tensors + hidden_states = torch.randn( + seq_len, batch_size, hidden_size, device="cuda", dtype=torch.float32 + ) + + # Compare outputs with identical weights but different QK norm types + with torch.no_grad(): + output_l2 = mha_l2(hidden_states) + output_rms = mha_rms(hidden_states) + + # Outputs should be different when using different normalization types + assert not torch.allclose( + output_l2, output_rms, atol=1e-6 + ), "L2 and RMS normalization should produce different outputs, but outputs are identical" + + # Check that outputs have expected shapes and properties + assert output_l2.shape == output_rms.shape, "L2 and RMS outputs should have same shape" + assert not torch.isnan(output_l2).any(), "L2 output contains NaN" + assert not torch.isinf(output_l2).any(), "L2 output contains Inf" + assert not torch.isnan(output_rms).any(), "RMS output contains NaN" + assert not torch.isinf(output_rms).any(), "RMS output contains Inf" diff --git 
a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 142044240..f25a09fbe 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -11,7 +11,7 @@ from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.float8_tensor import Float8Tensor from transformer_engine.pytorch.module.base import TransformerEngineBaseModule -from transformer_engine.pytorch.module import LayerNormLinear, Linear +from transformer_engine.pytorch.module import LayerNormLinear, Linear, RMSNorm, LayerNorm from transformer_engine.pytorch.ops.basic.l2normalization import L2Normalization from transformer_engine.pytorch.utils import ( SplitAlongDim, @@ -175,14 +175,23 @@ class MultiheadAttention(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument `fuse_wgrad_accumulation`. - use_qk_norm: bool, default = 'False' - if set to `True`, L2 normalization is applied to query and key tensors - after RoPE (if applicable) but before attention computation. - This follows the Llama4 approach for QK normalization to improve - training stability and model performance. + qk_norm_type: Optional[str], default = None + type of normalization to apply to query and key tensors. + Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied. + When 'L2Normalization', L2 normalization is applied to query and key tensors. + When 'RMSNorm', RMS normalization is applied to query and key tensors. + When 'LayerNorm', layer normalization is applied to query and key tensors. + Normalization is applied after RoPE (if applicable) but before attention computation + when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach + for QK normalization to improve training stability and model performance. qk_norm_eps: float, default = 1e-6 - epsilon value for L2 normalization of query and key tensors. - Only used when `use_qk_norm` is True. + epsilon value for normalization of query and key tensors. + Only used when `qk_norm_type` is not None. + qk_norm_before_rope: bool, default = `False` + if set to `True`, query and key normalization is applied before rotary position + embedding. When `False` (default), normalization is applied after RoPE. + This parameter allows supporting different architectural variants that apply + QK normalization at different points. seq_length: Optional[int], default = `None` sequence length of input samples. 
Needed for JIT Warmup, a technique where jit fused functions are warmed up before training to ensure same kernels are used for @@ -231,8 +240,9 @@ def __init__( device: Union[torch.device, str] = "cuda", qkv_format: str = "sbhd", name: str = None, - use_qk_norm: bool = False, + qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, + qk_norm_before_rope: bool = False, seq_length: Optional[int] = None, micro_batch_size: Optional[int] = None, ) -> None: @@ -264,6 +274,7 @@ def __init__( qkv_weight_interleaved = False self.qkv_weight_interleaved = qkv_weight_interleaved self.rotary_pos_interleaved = rotary_pos_interleaved + self.qk_norm_before_rope = qk_norm_before_rope assert attention_type in AttnTypes, f"attention_type {attention_type} not supported" if layer_number is not None: @@ -288,7 +299,6 @@ def __init__( self.hidden_size_kv = self.hidden_size_per_attention_head * self.num_gqa_groups self.name = name - self.use_qk_norm = use_qk_norm common_gemm_kwargs = { "fuse_wgrad_accumulation": fuse_wgrad_accumulation, @@ -300,13 +310,9 @@ def __init__( "device": device, } - # Initialize L2 normalization modules for query and key if enabled - if self.use_qk_norm: - self.qk_norm = L2Normalization( - eps=qk_norm_eps, - seq_length=seq_length, - micro_batch_size=micro_batch_size, - ) + self.q_norm, self.k_norm = self._create_qk_norm_modules( + qk_norm_type, qk_norm_eps, device, seq_length, micro_batch_size + ) qkv_parallel_mode = "column" if set_parallel_mode else None @@ -427,6 +433,78 @@ def __init__( **common_gemm_kwargs, ) + def _create_qk_norm_modules( + self, + qk_norm_type: Optional[str], + qk_norm_eps: float, + device: Union[torch.device, str], + seq_length: Optional[int] = None, + micro_batch_size: Optional[int] = None, + ) -> Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]]: + """ + Create query and key normalization modules based on the specified normalization type. + + Parameters + ---------- + qk_norm_type : Optional[str] + Type of normalization to apply. Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm' + qk_norm_eps : float + Epsilon value for numerical stability + device : Union[torch.device, str] + Device to place the normalization modules on + seq_length : Optional[int], default = None + Sequence length for L2Normalization optimization + micro_batch_size : Optional[int], default = None + Micro batch size for L2Normalization optimization + + Returns + ------- + Tuple[Optional[torch.nn.Module], Optional[torch.nn.Module]] + Query and key normalization modules (q_norm, k_norm) + """ + if qk_norm_type is None: + return None, None + + if qk_norm_type == "L2Normalization": + l2_norm = L2Normalization( + eps=qk_norm_eps, + seq_length=seq_length, + micro_batch_size=micro_batch_size, + ) + # L2Normalization is parameter-free, so we can share the same instance + return l2_norm, l2_norm + + if qk_norm_type == "RMSNorm": + q_norm = RMSNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + k_norm = RMSNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + return q_norm, k_norm + + if qk_norm_type == "LayerNorm": + q_norm = LayerNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + k_norm = LayerNorm( + normalized_shape=self.hidden_size_per_attention_head, + eps=qk_norm_eps, + device=device, + ) + return q_norm, k_norm + + raise ValueError( + f"Unsupported QK norm type: {qk_norm_type}. 
" + "Supported types: ['L2Normalization', 'RMSNorm', 'LayerNorm']" + ) + def set_tensor_parallel_group(self, tp_group: Union[dist_group_type, None]) -> None: """ Set the tensor parallel group for the given @@ -789,6 +867,14 @@ def forward( ) query_layer = query_layer.view(*new_tensor_shape) + # =========================== + # Apply normalization to query and key tensors (before RoPE if configured) + # =========================== + + if self.q_norm is not None and self.qk_norm_before_rope: + query_layer = self.q_norm(query_layer) + key_layer = self.k_norm(key_layer) + # ====================================================== # Apply relative positional encoding (rotary embedding) # ====================================================== @@ -843,12 +929,12 @@ def forward( ) # =========================== - # Apply L2 normalization to query and key tensors + # Apply normalization to query and key tensors (after RoPE if not applied before) # =========================== - if self.use_qk_norm: - query_layer = self.qk_norm(query_layer) - key_layer = self.qk_norm(key_layer) + if self.q_norm is not None and not self.qk_norm_before_rope: + query_layer = self.q_norm(query_layer) + key_layer = self.k_norm(key_layer) # =========================== # Core attention computation diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index b9d59f496..1a98f2f52 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -236,14 +236,23 @@ class TransformerLayer(torch.nn.Module): parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument `fuse_wgrad_accumulation`. - use_qk_norm: bool, default = 'False' - if set to `True`, L2 normalization is applied to query and key tensors - after RoPE (if applicable) but before attention computation. - This follows the Llama4 approach for QK normalization to improve - training stability and model performance. + qk_norm_type: Optional[str], default = None + type of normalization to apply to query and key tensors. + Options: None, 'L2Normalization', 'RMSNorm', 'LayerNorm'. When None, no normalization is applied. + When 'L2Normalization', L2 normalization is applied to query and key tensors. + When 'RMSNorm', RMS normalization is applied to query and key tensors. + When 'LayerNorm', layer normalization is applied to query and key tensors. + Normalization is applied after RoPE (if applicable) but before attention computation + when `qk_norm_before_rope` is False. This follows the e.g. Llama4 approach for + QK normalization to improve training stability and model performance. qk_norm_eps: float, default = 1e-6 - epsilon value for L2 normalization of query and key tensors. - Only used when `use_qk_norm` is True. + epsilon value for normalization of query and key tensors. + Only used when `qk_norm_type` is not None. + qk_norm_before_rope: bool, default = `False` + if set to `True`, query and key normalization is applied before rotary position + embedding. When `False` (default), normalization is applied after RoPE. + This parameter allows supporting different architectural variants that apply + QK normalization at different points. 
""" def __init__( @@ -293,8 +302,9 @@ def __init__( device: Union[torch.device, str] = "cuda", attn_input_format: str = "sbhd", name: str = None, - use_qk_norm: bool = False, + qk_norm_type: Optional[str] = None, qk_norm_eps: float = 1e-6, + qk_norm_before_rope: bool = False, ) -> None: super().__init__() @@ -397,8 +407,9 @@ def __init__( return_bias=not self.parallel_attention_mlp, normalization=normalization, device=device, - use_qk_norm=use_qk_norm, + qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, + qk_norm_before_rope=qk_norm_before_rope, name=name + ".self_attention" if name is not None else None, ) @@ -413,8 +424,9 @@ def __init__( return_bias=True, normalization=normalization, device=device, - use_qk_norm=use_qk_norm, + qk_norm_type=qk_norm_type, qk_norm_eps=qk_norm_eps, + qk_norm_before_rope=qk_norm_before_rope, name=name + ".inter_attention" if name is not None else None, ) From 1470116ebc147a890f2e9384c940d1c1efa36d9f Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Fri, 25 Jul 2025 08:52:58 -0400 Subject: [PATCH 022/153] [C][PyTorch] Remove deprecated `device_id` arg for multi tensor API (#1994) * Remove deprecated device arg Signed-off-by: Kirthi Shankar Sivamani * Remove test Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- tests/pytorch/test_fused_optimizer.py | 14 ---- .../include/transformer_engine/multi_tensor.h | 49 ++++------- .../common/multi_tensor/adam.cu | 83 +++++++++---------- .../common/multi_tensor/compute_scale.cu | 19 +++-- .../common/multi_tensor/l2norm.cu | 25 +++--- .../multi_tensor/multi_tensor_apply.cuh | 51 +----------- .../common/multi_tensor/scale.cu | 9 +- transformer_engine/common/multi_tensor/sgd.cu | 23 +++-- .../csrc/extensions/multi_tensor/adam.cpp | 16 ++-- .../extensions/multi_tensor/compute_scale.cpp | 3 +- .../csrc/extensions/multi_tensor/l2norm.cpp | 7 +- .../csrc/extensions/multi_tensor/scale.cpp | 3 +- .../csrc/extensions/multi_tensor/sgd.cpp | 3 +- 13 files changed, 100 insertions(+), 205 deletions(-) diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index e04f0477b..f1e80e698 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -112,13 +112,6 @@ def test_half(self): def test_bfloat16(self): self.gen_single_type_test(param_type=torch.bfloat16, skip_assert=True) - @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - def test_multi_params(self): sizes = [[4096, 1024], [4096], [4096, 2048], [32320, 1024], [1]] @@ -530,13 +523,6 @@ def test_float(self): def test_half(self): self.gen_single_type_test(param_type=torch.float16) - @pytest.mark.skipif(torch.cuda.device_count() < 2, reason="more than 1 GPU required") - def test_multi_device(self): - devices = ("cuda:0", "cuda:1") - for current_dev, tensor_dev in product(devices, devices): - with torch.cuda.device(current_dev): - self.gen_single_type_test(param_type=torch.float, device=tensor_dev) - class Model(torch.nn.Module): def __init__(self): diff --git a/transformer_engine/common/include/transformer_engine/multi_tensor.h b/transformer_engine/common/include/transformer_engine/multi_tensor.h index c21fd2627..a01b2e5da 100644 --- 
a/transformer_engine/common/include/transformer_engine/multi_tensor.h +++ b/transformer_engine/common/include/transformer_engine/multi_tensor.h @@ -20,7 +20,6 @@ extern "C" { /*! \brief Computes L2 norm for a list of tensors. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -33,22 +32,19 @@ extern "C" { * \param[out] ret_per_tensor L2 norm for each tensor. * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, int per_tensor, - int max_chunks_per_tensor, const int device_id, - cudaStream_t stream); + int max_chunks_per_tensor, cudaStream_t stream); /*! \brief Computes L2 norm for a list of tensors after unscaling. * * Unscaling is only done for computing the L2 norm. The tensors themselves are not updated. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -62,7 +58,6 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen * \param[in] inv_scale Scalar for the unscaling operation. * \param[in] per_tensor Whether to calculate per tensor or cumulative norm. * \param[in] max_chunks_per_tensor Maximum number of chunks in any input tensor. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, @@ -71,12 +66,11 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, NVTETensor inv_scale, int per_tensor, int max_chunks_per_tensor, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -91,7 +85,6 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, * \param[in] mode Whether to use AdamW (L2 penalty applied to params). * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. 
*/ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, @@ -99,13 +92,12 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * where the master parameters only store the remainder bits. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -120,20 +112,18 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso * \param[in] mode Whether to use AdamW (L2 penalty applied to params). * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_param_remainder_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, - const float weight_decay, const int device_id, cudaStream_t stream); + const float weight_decay, cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * when model parameters are in Float8 precision. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -149,7 +139,6 @@ void nvte_multi_tensor_adam_param_remainder_cuda( * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. * \param[in] fp8_dtype FP8 data type for model parameters. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, @@ -158,13 +147,12 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const NVTEDType fp8_dtype, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * with CUDA graph support and LR scheduling. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. 
@@ -180,20 +168,18 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. * \param[in] inv_scale Scalar for the unscaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for Adam optimizer * with CUDA graph support, LR scheduling, and FP32 master weights. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -209,19 +195,17 @@ void nvte_multi_tensor_adam_capturable_cuda( * \param[in] bias_correction Whether to apply correction factor for moment estimates. * \param[in] weight_decay L2 penalty for weight decay. * \param[in] inv_scale Scalar for the unscaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_adam_capturable_master_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream); + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream); /*! \brief Compute and apply gradient update to parameters for SGD optimizer. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -236,19 +220,17 @@ void nvte_multi_tensor_adam_capturable_master_cuda( * \param[in] first_run Whether momentum buffers have been initialized. * \param[in] wd_after_momentum Whether to applied weight decay after momentum update. * \param[in] scale Scalar for the scaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, float wd, float momentum, float dampening, float lr, int nesterov, int first_run, int wd_after_momentum, float scale, - const int device_id, cudaStream_t stream); + cudaStream_t stream); /*! \brief Check overflow and scale a list of tensors. 
* * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -256,17 +238,15 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor * \param[in] num_tensor_lists Size (dim0) of tensor_lists. * \param[in] num_tensors_per_list Size (dim1) of tensor_lists. * \param[in] scale Scalar for the scaling operation. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. */ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, - float scale, const int device_id, cudaStream_t stream); + float scale, cudaStream_t stream); /*! \brief Check overflow and scale a list of tensors. * * \warning This API is **experimental** and subject to change. - * \warning Argument device_id is deprecated and will be removed in a future release. * * \param[in] chunk_size Number of tensor elements processed by a CUDA block. * \param[in] noop_flag If this single element tensor has non-zero value, kernel will exit immediately. @@ -276,13 +256,14 @@ void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETens * \param[in] max_fp8 Maximum representible value in underlying FP8 format. * \param[in] force_pow_2_scales Ensure scaling factors are a power of 2. * \param[in] epsilon Term added to the denominator for numerical stability. - * \param[in] device_id [DEPRECATED] CUDA device ID for this operation. * \param[in] stream CUDA stream used for this operation. 
*/ -void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( - int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, - const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, - const int device_id, cudaStream_t stream); +void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + float max_fp8, int force_pow_2_scales, + float epsilon, cudaStream_t stream); #ifdef __cplusplus } // extern "C" diff --git a/transformer_engine/common/multi_tensor/adam.cu b/transformer_engine/common/multi_tensor/adam.cu index 2e117eb6b..9dec2c178 100644 --- a/transformer_engine/common/multi_tensor/adam.cu +++ b/transformer_engine/common/multi_tensor/adam.cu @@ -576,7 +576,7 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, - const float weight_decay, const int device_id, cudaStream_t stream) { + const float weight_decay, cudaStream_t stream) { // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -643,20 +643,20 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<4>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctor(), device_id, - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( p_in_type_te, p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, - multi_tensor_apply<5>( - (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), device_id, stream, - beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, + tensor_lists, + AdamFunctorMaster(), + stream, beta1, beta2, bias_correction1, bias_correction2, + epsilon, lr, (adamMode_t)mode, weight_decay);)); } } else { if (num_tensor_lists == 4) { @@ -666,9 +666,9 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamFunctor(), device_id, - stream, beta1, beta2, bias_correction1, bias_correction2, - epsilon, lr, (adamMode_t)mode, weight_decay);)); + AdamFunctor(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + (adamMode_t)mode, weight_decay);)); } else { // g, p, m, v, p_master TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( @@ -677,9 +677,8 @@ void multi_tensor_adam_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctorMaster(), - device_id, stream, beta1, beta2, bias_correction1, - bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + stream, beta1, beta2, bias_correction1, bias_correction2, + epsilon, lr, (adamMode_t)mode, weight_decay);)); } } NVTE_CHECK_CUDA(cudaGetLastError()); @@ -690,7 +689,7 @@ void 
multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -732,8 +731,8 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( g_in_type_te, g_in_type, multi_tensor_apply<5>((int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMasterParamRemainder(), device_id, - stream, beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, + AdamFunctorMasterParamRemainder(), stream, + beta1, beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -743,7 +742,7 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const DType fp8_dtype, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { // Handle bias correction mode float bias_correction1 = 1.0f, bias_correction2 = 1.0f; if (bias_correction == 1) { @@ -813,9 +812,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5, true>( (int64_t)BLOCK_SIZE, (int64_t)chunk_size, noop_flag, tensor_lists, - AdamFunctorMaster(), device_id, stream, beta1, - beta2, bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + AdamFunctorMaster(), stream, beta1, beta2, + bias_correction1, bias_correction2, epsilon, lr, (adamMode_t)mode, weight_decay);)); } else { TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( fp8_dtype, FP8_T, @@ -823,9 +821,8 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, Tensor noop_flag, g_in_type_te, g_in_type, multi_tensor_apply<5, true>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, AdamFunctorMaster(), - device_id, stream, beta1, beta2, bias_correction1, - bias_correction2, epsilon, lr, (adamMode_t)mode, - weight_decay);)); + stream, beta1, beta2, bias_correction1, bias_correction2, + epsilon, lr, (adamMode_t)mode, weight_decay);)); } NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -835,7 +832,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, const float beta1, const float beta2, const float epsilon, Tensor step, const int mode, const int bias_correction, const float weight_decay, Tensor inv_scale, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { // Check tensor list sizes // 4 tensor lists: g, p, m, v const size_t num_tensor_lists = tensor_lists.size(); @@ -867,7 +864,7 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableFunctor(), device_id, stream, beta1, beta2, + AdamCapturableFunctor(), stream, beta1, beta2, reinterpret_cast(step.data.dptr), bias_correction, epsilon, reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, reinterpret_cast(inv_scale.data.dptr));) @@ -880,8 +877,7 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, Tensor lr, const float beta1, const float beta2, 
const float epsilon, Tensor step, const int mode, const int bias_correction, const float weight_decay, - Tensor inv_scale, const int device_id, - cudaStream_t stream) { + Tensor inv_scale, cudaStream_t stream) { // Check tensor list sizes // 4 tensor lists: g, p, m, v, p_master const size_t num_tensor_lists = tensor_lists.size(); @@ -916,10 +912,10 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, Tensor noop_flag, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<5>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - AdamCapturableMasterFunctor(), device_id, stream, beta1, - beta2, reinterpret_cast(step.data.dptr), bias_correction, - epsilon, reinterpret_cast(lr.data.dptr), (adamMode_t)mode, - weight_decay, reinterpret_cast(inv_scale.data.dptr));) + AdamCapturableMasterFunctor(), stream, beta1, beta2, + reinterpret_cast(step.data.dptr), bias_correction, epsilon, + reinterpret_cast(lr.data.dptr), (adamMode_t)mode, weight_decay, + reinterpret_cast(inv_scale.data.dptr));) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -932,28 +928,28 @@ void nvte_multi_tensor_adam_cuda(int chunk_size, NVTETensor noop_flag, NVTETenso const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_cuda); using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, - epsilon, step, mode, bias_correction, weight_decay, device_id, stream); + epsilon, step, mode, bias_correction, weight_decay, stream); } void nvte_multi_tensor_adam_param_remainder_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, const float lr, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, - const float weight_decay, const int device_id, cudaStream_t stream) { + const float weight_decay, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_param_remainder_cuda); using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_param_remainder_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, - epsilon, step, mode, bias_correction, weight_decay, device_id, stream); + epsilon, step, mode, bias_correction, weight_decay, stream); } void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, @@ -962,22 +958,21 @@ void nvte_multi_tensor_adam_fp8_cuda(int chunk_size, NVTETensor noop_flag, const float beta1, const float beta2, const float epsilon, const int step, const int mode, const int bias_correction, const float weight_decay, const NVTEDType fp8_dtype, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_fp8_cuda); using namespace transformer_engine; multi_tensor_adam::multi_tensor_adam_fp8_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), lr, beta1, beta2, - epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), device_id, - stream); + epsilon, step, mode, bias_correction, weight_decay, 
static_cast(fp8_dtype), stream); } void nvte_multi_tensor_adam_capturable_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) { + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_capturable_cuda); using namespace transformer_engine; @@ -985,14 +980,14 @@ void nvte_multi_tensor_adam_capturable_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, - bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); + bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream); } void nvte_multi_tensor_adam_capturable_master_cuda( int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor lr, const float beta1, const float beta2, const float epsilon, NVTETensor step, const int mode, const int bias_correction, - const float weight_decay, NVTETensor inv_scale, const int device_id, cudaStream_t stream) { + const float weight_decay, NVTETensor inv_scale, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_adam_capturable_master_cuda); using namespace transformer_engine; @@ -1000,5 +995,5 @@ void nvte_multi_tensor_adam_capturable_master_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(lr), beta1, beta2, epsilon, *convertNVTETensorCheck(step), mode, - bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), device_id, stream); + bias_correction, weight_decay, *convertNVTETensorCheck(inv_scale), stream); } diff --git a/transformer_engine/common/multi_tensor/compute_scale.cu b/transformer_engine/common/multi_tensor/compute_scale.cu index ebdcfbb56..dc4eb8714 100644 --- a/transformer_engine/common/multi_tensor/compute_scale.cu +++ b/transformer_engine/common/multi_tensor/compute_scale.cu @@ -58,26 +58,27 @@ struct ComputeScaleAndScaleInvFunctor { void multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, float max_fp8, bool force_pow_2_scales, - float epsilon, const int device_id, - cudaStream_t stream) { + float epsilon, cudaStream_t stream) { multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - ComputeScaleAndScaleInvFunctor(), device_id, stream, max_fp8, - force_pow_2_scales, epsilon); + ComputeScaleAndScaleInvFunctor(), stream, max_fp8, force_pow_2_scales, + epsilon); NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace multi_tensor_compute_scale } // namespace transformer_engine -void nvte_multi_tensor_compute_scale_and_scale_inv_cuda( - int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, - const size_t num_tensors_per_list, float max_fp8, int force_pow_2_scales, float epsilon, - const int device_id, cudaStream_t stream) { +void nvte_multi_tensor_compute_scale_and_scale_inv_cuda(int chunk_size, NVTETensor noop_flag, + NVTETensor **tensor_lists, + const size_t num_tensor_lists, + const size_t num_tensors_per_list, + float 
max_fp8, int force_pow_2_scales, + float epsilon, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_compute_scale_and_scale_inv_cuda); using namespace transformer_engine; multi_tensor_compute_scale::multi_tensor_compute_scale_and_scale_inv_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), max_fp8, - force_pow_2_scales, epsilon, device_id, stream); + force_pow_2_scales, epsilon, stream); } diff --git a/transformer_engine/common/multi_tensor/l2norm.cu b/transformer_engine/common/multi_tensor/l2norm.cu index f27beee3a..ca2fce27a 100644 --- a/transformer_engine/common/multi_tensor/l2norm.cu +++ b/transformer_engine/common/multi_tensor/l2norm.cu @@ -393,13 +393,12 @@ __global__ void cleanup_v2(float *output, float *output_per_tensor, float *ret, void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, Tensor output, Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor, - bool per_tensor, int max_chunks_per_tensor, const int device_id, - cudaStream_t stream) { + bool per_tensor, int max_chunks_per_tensor, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<1>( - BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor(), device_id, - stream, reinterpret_cast(output.data.dptr), + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, L2NormFunctor(), stream, + reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor);) @@ -408,7 +407,6 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, // This involves one more small kernel launches, but will be negligible end to end. // I could get rid of these by hacking the functor + multi tensor harness with persistence // logic, but keeping it simple for now - const OptionalCUDAGuard device_guard(device_id); cleanup<<>>( reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, @@ -421,13 +419,12 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, Tensor output, Tensor output_per_tensor, Tensor ret, Tensor ret_per_tensor, Tensor inv_scale, bool per_tensor, - int max_chunks_per_tensor, const int device_id, - cudaStream_t stream) { + int max_chunks_per_tensor, cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), dtype, multi_tensor_apply<1>( - BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor(), device_id, - stream, reinterpret_cast(inv_scale.data.dptr), + BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, UnscaleL2NormFunctor(), stream, + reinterpret_cast(inv_scale.data.dptr), reinterpret_cast(output.data.dptr), per_tensor ? reinterpret_cast(output_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor);) @@ -437,7 +434,6 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, // This involves one more small kernel launches, but will be negligible end to end. // I could get rid of these by hacking the functor + multi tensor harness with persistence // logic, but keeping it simple for now - const OptionalCUDAGuard device_guard(device_id); cleanup<<>>( reinterpret_cast(output.data.dptr), per_tensor ? 
reinterpret_cast(output_per_tensor.data.dptr) : nullptr, @@ -453,8 +449,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen const size_t num_tensor_lists, const size_t num_tensors_per_list, NVTETensor output, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, int per_tensor, - int max_chunks_per_tensor, const int device_id, - cudaStream_t stream) { + int max_chunks_per_tensor, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_l2norm_cuda); using namespace transformer_engine; @@ -463,7 +458,7 @@ void nvte_multi_tensor_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETen convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), per_tensor, - max_chunks_per_tensor, device_id, stream); + max_chunks_per_tensor, stream); } void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, @@ -472,7 +467,7 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor output_per_tensor, NVTETensor ret, NVTETensor ret_per_tensor, NVTETensor inv_scale, int per_tensor, int max_chunks_per_tensor, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_unscale_l2norm_cuda); using namespace transformer_engine; @@ -481,5 +476,5 @@ void nvte_multi_tensor_unscale_l2norm_cuda(int chunk_size, NVTETensor noop_flag, convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), *convertNVTETensorCheck(output), *convertNVTETensorCheck(output_per_tensor), *convertNVTETensorCheck(ret), *convertNVTETensorCheck(ret_per_tensor), - *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, device_id, stream); + *convertNVTETensorCheck(inv_scale), per_tensor, max_chunks_per_tensor, stream); } diff --git a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh index 4727f3964..b78612181 100644 --- a/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh +++ b/transformer_engine/common/multi_tensor/multi_tensor_apply.cuh @@ -14,53 +14,6 @@ // This header is the one-stop shop for all your multi-tensor apply needs. -// Change device if needed. 
-class OptionalCUDAGuard { - public: - explicit OptionalCUDAGuard(int new_device) { - if (new_device < 0) return; - - int current_device; - NVTE_CHECK_CUDA(cudaGetDevice(¤t_device)); - - if (new_device != current_device) { - NVTE_CHECK_CUDA(cudaSetDevice(new_device)); - device_changed_ = true; - prev_device_ = current_device; - } - } - - OptionalCUDAGuard(const OptionalCUDAGuard &) = delete; - OptionalCUDAGuard &operator=(const OptionalCUDAGuard &) = delete; - - OptionalCUDAGuard(OptionalCUDAGuard &&other) noexcept - : prev_device_(other.prev_device_), device_changed_(other.device_changed_) { - other.device_changed_ = false; - } - - OptionalCUDAGuard &operator=(OptionalCUDAGuard &&other) noexcept { - if (this != &other) { - if (device_changed_) { - cudaSetDevice(prev_device_); - } - prev_device_ = other.prev_device_; - device_changed_ = other.device_changed_; - other.device_changed_ = false; - } - return *this; - } - - ~OptionalCUDAGuard() { - if (device_changed_) { - cudaSetDevice(prev_device_); - } - } - - private: - int prev_device_; - bool device_changed_ = false; -}; - // TODO: Kernel arg size limit may be <4KB for some other cards (ie Jetson) constexpr int depth_to_max_tensors[6] = {110, 64, 48, 36, 30, 24}; constexpr int depth_to_max_blocks[6] = {320, 320, 320, 320, 320, 320}; @@ -94,7 +47,7 @@ template void multi_tensor_apply(int64_t block_size, int64_t chunk_size, const transformer_engine::Tensor &noop_flag, std::vector> tensor_lists, - T callable, const int device_id, cudaStream_t stream, ArgTypes... args) { + T callable, cudaStream_t stream, ArgTypes... args) { const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensors_per_list = tensor_lists[0].size(); @@ -108,8 +61,6 @@ void multi_tensor_apply(int64_t block_size, int64_t chunk_size, TensorListMetadata tl; - const OptionalCUDAGuard device_guard(device_id); - tl.start_tensor_this_launch = 0; int loc_block_info = 0; int loc_tensor_info = 0; diff --git a/transformer_engine/common/multi_tensor/scale.cu b/transformer_engine/common/multi_tensor/scale.cu index 66a173bdb..ac457adb0 100644 --- a/transformer_engine/common/multi_tensor/scale.cu +++ b/transformer_engine/common/multi_tensor/scale.cu @@ -104,13 +104,13 @@ struct ScaleFunctor { void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, float scale, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[0][0]->dtype(), p_in_type, TRANSFORMER_ENGINE_TYPE_SWITCH_NON_FP8ONLY( tensor_lists[1][0]->dtype(), g_in_type, multi_tensor_apply<2>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - ScaleFunctor(), device_id, stream, scale);)) + ScaleFunctor(), stream, scale);)) NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -119,12 +119,11 @@ void multi_tensor_scale_cuda(int chunk_size, Tensor noop_flag, void nvte_multi_tensor_scale_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor **tensor_lists, const size_t num_tensor_lists, const size_t num_tensors_per_list, - float scale, const int device_id, cudaStream_t stream) { + float scale, cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_scale_cuda); using namespace transformer_engine; multi_tensor_scale::multi_tensor_scale_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), - convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, device_id, - stream); + convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), scale, stream); } diff --git 
a/transformer_engine/common/multi_tensor/sgd.cu b/transformer_engine/common/multi_tensor/sgd.cu index 05106e46d..9235de330 100644 --- a/transformer_engine/common/multi_tensor/sgd.cu +++ b/transformer_engine/common/multi_tensor/sgd.cu @@ -127,8 +127,7 @@ struct SGDFunctor { void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, std::vector> tensor_lists, float wd, float momentum, float dampening, float lr, bool nesterov, bool first_run, - bool wd_after_momentum, float scale, const int device_id, - cudaStream_t stream) { + bool wd_after_momentum, float scale, cudaStream_t stream) { const size_t num_tensor_lists = tensor_lists.size(); const size_t num_tensors_per_list = tensor_lists[0].size(); @@ -154,29 +153,29 @@ void multi_tensor_sgd_cuda(int chunk_size, Tensor noop_flag, // Case 1. fp16, fp16, fp16, No if (grad_type == DType::kFloat16 && weight_type == DType::kFloat16 && num_tensor_lists == 3) { multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<3, fp16, fp16>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<3, fp16, fp16>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } // Case 2. fp32, fp32, fp32, No else if (grad_type == DType::kFloat32 && // NOLINT(*) weight_type == DType::kFloat32 && num_tensor_lists == 3) { multi_tensor_apply<3>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<3, float, float>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<3, float, float>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } // Case 3. fp16, fp32, fp32, Yes else if (grad_type == DType::kFloat16 && // NOLINT(*) weight_type == DType::kFloat32 && num_tensor_lists == 4) { multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<4, fp16, float>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<4, fp16, float>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } // Case 4. 
fp32, fp32, fp32, Yes else if (grad_type == DType::kFloat32 && // NOLINT(*) weight_type == DType::kFloat32 && num_tensor_lists == 4) { multi_tensor_apply<4>(BLOCK_SIZE, chunk_size, noop_flag, tensor_lists, - SGDFunctor<4, float, float>(), device_id, stream, wd, momentum, dampening, - lr, nesterov, first_run, wd_after_momentum, scale); + SGDFunctor<4, float, float>(), stream, wd, momentum, dampening, lr, + nesterov, first_run, wd_after_momentum, scale); } else { NVTE_ERROR("Unsupported combination of weight and gradient types."); } @@ -191,12 +190,12 @@ void nvte_multi_tensor_sgd_cuda(int chunk_size, NVTETensor noop_flag, NVTETensor const size_t num_tensor_lists, const size_t num_tensors_per_list, float wd, float momentum, float dampening, float lr, int nesterov, int first_run, int wd_after_momentum, float scale, - const int device_id, cudaStream_t stream) { + cudaStream_t stream) { NVTE_API_CALL(nvte_multi_tensor_sgd_cuda); using namespace transformer_engine; multi_tensor_sgd::multi_tensor_sgd_cuda( chunk_size, *convertNVTETensorCheck(noop_flag), convert_tensor_array(tensor_lists, num_tensor_lists, num_tensors_per_list), wd, momentum, - dampening, lr, nesterov, first_run, wd_after_momentum, scale, device_id, stream); + dampening, lr, nesterov, first_run, wd_after_momentum, scale, stream); } diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp index 21d3e0574..acf04900e 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/adam.cpp @@ -16,11 +16,10 @@ void multi_tensor_adam_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction, - weight_decay, device_id, at::cuda::getCurrentCUDAStream()); + weight_decay, at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag, @@ -31,12 +30,10 @@ void multi_tensor_adam_param_remainder_cuda(int chunk_size, at::Tensor noop_flag auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_param_remainder_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, - beta2, epsilon, step, mode, bias_correction, weight_decay, device_id, - at::cuda::getCurrentCUDAStream()); + beta2, epsilon, step, mode, bias_correction, weight_decay, at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, @@ -47,12 +44,11 @@ void multi_tensor_adam_fp8_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_fp8_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr, beta1, beta2, epsilon, step, mode, bias_correction, weight_decay, static_cast(fp8_dtype), - 
device_id, at::cuda::getCurrentCUDAStream()); + at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, @@ -67,12 +63,11 @@ void multi_tensor_adam_capturable_cuda(int chunk_size, at::Tensor noop_flag, auto lr_cu = makeTransformerEngineTensor(lr); auto step_cu = makeTransformerEngineTensor(step); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_capturable_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay, - inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream()); + inv_scale_cu.data(), at::cuda::getCurrentCUDAStream()); } void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_flag, @@ -87,12 +82,11 @@ void multi_tensor_adam_capturable_master_cuda(int chunk_size, at::Tensor noop_fl auto lr_cu = makeTransformerEngineTensor(lr); auto step_cu = makeTransformerEngineTensor(step); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_adam_capturable_master_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, lr_cu.data(), beta1, beta2, epsilon, step_cu.data(), mode, bias_correction, weight_decay, - inv_scale_cu.data(), device_id, at::cuda::getCurrentCUDAStream()); + inv_scale_cu.data(), at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp index 290f70b57..8a1a34698 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/compute_scale.cpp @@ -14,11 +14,10 @@ void multi_tensor_compute_scale_and_scale_inv_cuda( auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_compute_scale_and_scale_inv_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, max_fp8, - force_pow_2_scales, epsilon, device_id, at::cuda::getCurrentCUDAStream()); + force_pow_2_scales, epsilon, at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp index 1e8eb44d9..d33a2520e 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/l2norm.cpp @@ -43,12 +43,11 @@ std::tuple multi_tensor_l2norm_cuda( auto output_per_tensor_cu = makeTransformerEngineTensor(output_per_tensor); auto ret_cu = makeTransformerEngineTensor(ret); auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_l2norm_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(), per_tensor, - max_chunks_per_tensor, device_id, at::cuda::getCurrentCUDAStream()); + max_chunks_per_tensor, at::cuda::getCurrentCUDAStream()); 
return std::tuple(ret, ret_per_tensor); } @@ -91,13 +90,11 @@ std::tuple multi_tensor_unscale_l2norm_cuda( auto ret_cu = makeTransformerEngineTensor(ret); auto ret_per_tensor_cu = makeTransformerEngineTensor(ret_per_tensor); auto inv_scale_cu = makeTransformerEngineTensor(inv_scale); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_unscale_l2norm_cuda( chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, output_cu.data(), output_per_tensor_cu.data(), ret_cu.data(), ret_per_tensor_cu.data(), - inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, device_id, - at::cuda::getCurrentCUDAStream()); + inv_scale_cu.data(), per_tensor, max_chunks_per_tensor, at::cuda::getCurrentCUDAStream()); return std::tuple(ret, ret_per_tensor); } diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp index ba33f04bf..2db936f84 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/scale.cpp @@ -13,10 +13,9 @@ void multi_tensor_scale_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_scale_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, - num_tensors, scale, device_id, at::cuda::getCurrentCUDAStream()); + num_tensors, scale, at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp b/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp index de3209535..2c6a6b7c4 100644 --- a/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp +++ b/transformer_engine/pytorch/csrc/extensions/multi_tensor/sgd.cpp @@ -15,11 +15,10 @@ void multi_tensor_sgd_cuda(int chunk_size, at::Tensor noop_flag, auto noop_flag_cu = makeTransformerEngineTensor(noop_flag); auto [_, __, tensor_lists_ptr, num_lists, num_tensors] = makeTransformerEngineTensorList(tensor_lists); - int device_id = tensor_lists[0][0].device().index(); nvte_multi_tensor_sgd_cuda(chunk_size, noop_flag_cu.data(), tensor_lists_ptr.data(), num_lists, num_tensors, wd, momentum, dampening, lr, nesterov, first_run, - wd_after_momentum, scale, device_id, at::cuda::getCurrentCUDAStream()); + wd_after_momentum, scale, at::cuda::getCurrentCUDAStream()); } } // namespace transformer_engine::pytorch From 38c26dd8dfcf8386e712f7eb176fee603a29baf9 Mon Sep 17 00:00:00 2001 From: Selvaraj Anandaraj Date: Fri, 25 Jul 2025 08:42:57 -0700 Subject: [PATCH 023/153] Fixed double buffering issue for assymetric layers (#1984) * Fixed double buffering issue for assymetric layers Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/cpu_offload.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 1c03e3d37..75d3e1b2e 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -556,21 +556,33 @@ def 
bulk_reload_group(self, group_to_reload): for tensor_label, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] + else: + reload_buffer = None + if isinstance(state, tuple): recovered_tensor = SynchronizedGroupOffloadHandler.reload( - state, True, self.reload_double_buffer[double_buffer_idx][buffer_idx] + state, True, reload_buffer ) buffer_idx = buffer_idx + 1 self.tensor_tag_to_state[tensor_label] = recovered_tensor elif isinstance(state, list): tensor_list = [] for state_tuple in state: + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][ + buffer_idx + ] + else: + reload_buffer = None + if isinstance(state_tuple, tuple): tensor_list.append( SynchronizedGroupOffloadHandler.reload( state_tuple, True, - self.reload_double_buffer[double_buffer_idx][buffer_idx], + reload_buffer, ) ) buffer_idx = buffer_idx + 1 From c6c1f50eba26bd8b02f2a069083909b3be8332d4 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 25 Jul 2025 09:18:22 -0700 Subject: [PATCH 024/153] [PyTorch] Add ops for dropout and constant scale (#1995) * Add ops for dropout and constant scale Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 123 ++++++++++++++---- .../pytorch/ops/basic/__init__.py | 2 + .../pytorch/ops/basic/constant_scale.py | 40 ++++++ .../pytorch/ops/basic/dropout.py | 67 ++++++++++ 4 files changed, 209 insertions(+), 23 deletions(-) create mode 100644 transformer_engine/pytorch/ops/basic/constant_scale.py create mode 100644 transformer_engine/pytorch/ops/basic/dropout.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 40be0a75a..13b27a0b1 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -36,9 +36,7 @@ import transformer_engine_torch as tex # Import utility functions -_current_file = pathlib.Path(__file__).resolve() -sys.path.append(str(_current_file.parent)) -from utils import dtype_tols, make_recipe +from utils import dtype_tols, make_recipe, reset_rng_states # Check if FP8 is supported fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() @@ -327,10 +325,7 @@ class TestFuser: @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) def test_fp8_scale_update( @@ -544,10 +539,7 @@ class TestBasicOps: @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("device", ("cuda", "cpu")) @@ -1693,16 +1685,107 @@ def test_swiglu( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) + @pytest.mark.parametrize("shape", ((), (1, 13), (4, 4, 2))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("device", _devices) + def test_constant_scale( + self, + *, + scale: float, + shape: Iterable[int], + 
dtype: torch.dtype, + device: torch.device, + ): + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = scale * x_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operation + op = te_ops.ConstantScale(scale) + y_test = op(x_test) + y_test.backward(dy_test) + + # Check results + tols = dtype_tols(dtype) + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + + @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75)) + @pytest.mark.parametrize("is_training", (True, False)) + @pytest.mark.parametrize("shape", ((101,), (2, 4, 16))) + @pytest.mark.parametrize("dtype", _dtypes) + def test_dropout( + self, + *, + prob: float, + is_training: bool, + shape: Iterable[int], + dtype: torch.dtype, + device: torch.device = "cuda", + ): + + # Random data + x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5 + x_test = x_ref.clone().requires_grad_() + dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5 + dy_test = dy_ref.clone() + + # Apply dropout + op = te_ops.Dropout(prob) + if is_training: + op.train() + else: + op.eval() + y = op(x_test) + y.backward(dy_test) + + # Check values + if is_training: + mask = ((y != 0) / (1 - prob)).to(dtype=dtype) + torch.testing.assert_close(y, x_ref * mask) + torch.testing.assert_close(x_test.grad, dy_ref * mask) + else: + torch.testing.assert_close(y, x_ref, rtol=0, atol=0) + torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0) + + # Hypothesis testing for number of zeros + # Note: A Bernoulli random variable with probability p has + # mean p and standard deviation sqrt(p*(1-p)). By the central + # limit theorem, the mean of n iid Bernoulli variables + # converges to a normal random variable with mean p and + # standard deviation sqrt(p*(1-p)/n). If the observed mean is + # below the 0.5th or above the 99.5th percentiles, then the + # p-value is less than 1% and we assume that the dropout + # distribution is incorrect. 
+ if is_training: + prob_observed = 1 - torch.count_nonzero(y).item() / y.numel() + z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel()) + assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval" + class TestFusedOps: """Tests for fused operations""" @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.parametrize("weight_shape", ((32, 64), (3, 5))) @pytest.mark.parametrize("in_shape", ((-1,), (1, 7, -1), (8, 2, 10, -1))) @@ -2125,10 +2208,7 @@ class TestCheckpointing: @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.parametrize("quantization", _quantization_list) @pytest.mark.parametrize("quantized_weight", (False, True)) @@ -2240,10 +2320,7 @@ class TestSequentialModules: @staticmethod def setup_class(cls) -> None: - # Configure RNG - seed = 1234 - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) + reset_rng_states() @pytest.mark.parametrize("requires_grad", (False, True)) @pytest.mark.parametrize("bias", (False, True)) diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index e0e15b703..843bfc1bd 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -10,6 +10,8 @@ from .all_reduce import AllReduce from .basic_linear import BasicLinear from .bias import Bias +from .constant_scale import ConstantScale +from .dropout import Dropout from .identity import Identity from .l2normalization import L2Normalization from .layer_norm import LayerNorm diff --git a/transformer_engine/pytorch/ops/basic/constant_scale.py b/transformer_engine/pytorch/ops/basic/constant_scale.py new file mode 100644 index 000000000..4de70c0e9 --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/constant_scale.py @@ -0,0 +1,40 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fusible operation for constant scaling.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from ...tensor import Quantizer + + +class ConstantScale(BasicOperation): + """Multiply by a constant""" + + def __init__(self, scale: float) -> None: + super().__init__() + self.scale = scale + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + return input_ * self.scale + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + return grad_output * self.scale, () diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py new file mode 100644 index 000000000..958e9b06c --- /dev/null +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -0,0 +1,67 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fusible operation for dropout.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from transformer_engine.pytorch.ops.op import ( + BasicOperation, + OperationContext, +) +from ...tensor import Quantizer + + +class Dropout(BasicOperation): + """Randomly zero out tensor entries during training + + During training, tensor entries are randomly set to zero with + probability :math:`p` and remaining entries are scaled by + :math:`1/(1-p)`. + + """ + + def __init__(self, p: float) -> None: + super().__init__() + self.dropout_probability = p + + def op_forward( + self, + ctx: OperationContext, + input_: torch.Tensor, + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + ) -> torch.Tensor: + + # Compute dropout if training + out = input_ + is_training = self.training + mask = None + if is_training: + keep_prob = 1 - self.dropout_probability + mask = torch.empty_like(input_) + mask.bernoulli_(keep_prob) + mask *= 1 / keep_prob + out = out * mask + + # Save context for backward + if ctx.requires_grad: + ctx.save_for_backward(mask) + ctx.is_training = is_training + + return out + + def op_backward( + self, + ctx: OperationContext, + grad_output: torch.Tensor, + ) -> tuple[torch.Tensor, tuple[()]]: + (mask,) = ctx.saved_tensors + grad_input = grad_output + if ctx.is_training: + grad_input = grad_input * mask + return grad_input, () From aac744276046e280d8e7ced1e83517b0739d4203 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 28 Jul 2025 20:47:57 -0400 Subject: [PATCH 025/153] [PyTorch] Prune L0 unit test (#1999) * Add verbosity only for failing tests Signed-off-by: Kirthi Shankar Sivamani * Prune some tests and preinit recipe Signed-off-by: Kirthi Shankar Sivamani * Prune further tests Signed-off-by: Kirthi Shankar Sivamani * fix multitensor Signed-off-by: Kirthi Shankar Sivamani * Minor fixes Signed-off-by: Kirthi Shankar Sivamani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix a100 Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_unittest/test.sh | 48 +-- tests/pytorch/attention/test_attention.py | 5 +- tests/pytorch/test_cpu_offloading.py | 19 +- tests/pytorch/test_cuda_graphs.py | 44 +-- tests/pytorch/test_fused_optimizer.py | 1 - tests/pytorch/test_fused_router.py | 3 +- tests/pytorch/test_hf_integration.py | 1 - tests/pytorch/test_numerics.py | 98 ++--- tests/pytorch/test_onnx_export.py | 385 +++++++++---------- tests/pytorch/test_parallel_cross_entropy.py | 1 - tests/pytorch/test_sanity.py | 281 +------------- 11 files changed, 276 insertions(+), 610 deletions(-) diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 9a924282b..482ae6dca 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -26,30 +26,30 @@ pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime" pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml 
$TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" -PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" -NVTE_FLASH_ATTN=0 python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention.xml $TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" -NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail 
"test_checkpoint.py" -python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_deferred_init.xml $TE_PATH/tests/pytorch/test_deferred_init.py || test_fail "test_deferred_init.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_numerics.xml $TE_PATH/tests/pytorch/test_numerics.py || test_fail "test_numerics.py" +PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cuda_graphs.xml $TE_PATH/tests/pytorch/test_cuda_graphs.py || test_fail "test_cuda_graphs.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_jit.xml $TE_PATH/tests/pytorch/test_jit.py || test_fail "test_jit.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_rope.xml $TE_PATH/tests/pytorch/test_fused_rope.py || test_fail "test_fused_rope.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8tensor.xml $TE_PATH/tests/pytorch/test_float8tensor.py || test_fail "test_float8tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8blockwisetensor.xml $TE_PATH/tests/pytorch/test_float8blockwisetensor.py || test_fail "test_float8blockwisetensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_scaling_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_scaling_exact.py || test_fail "test_float8_blockwise_scaling_exact.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise_gemm_exact.xml $TE_PATH/tests/pytorch/test_float8_blockwise_gemm_exact.py || test_fail "test_float8_blockwise_gemm_exact.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml $TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" +NVTE_FLASH_ATTN=0 python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_cpu_offloading.xml $TE_PATH/tests/pytorch/test_cpu_offloading.py || test_fail "test_cpu_offloading.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_attention.xml 
$TE_PATH/tests/pytorch/attention/test_attention.py || test_fail "test_attention.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_kv_cache.xml $TE_PATH/tests/pytorch/attention/test_kv_cache.py || test_fail "test_kv_cache.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_hf_integration.xml $TE_PATH/tests/pytorch/test_hf_integration.py || test_fail "test_hf_integration.py" +NVTE_TEST_CHECKPOINT_ARTIFACT_PATH=$TE_PATH/artifacts/tests/pytorch/test_checkpoint python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_checkpoint.xml $TE_PATH/tests/pytorch/test_checkpoint.py || test_fail "test_checkpoint.py" +python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_router.xml $TE_PATH/tests/pytorch/test_fused_router.py || test_fail "test_fused_router.py" if [ "$RET" -ne 0 ]; then echo "Error in the following test cases:$FAILED_CASES" diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 4dfd54cdb..3088853a2 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -6,7 +6,7 @@ import os import sys import pathlib -from typing import Any, Dict, List, Tuple, Union, Optional +from typing import Any, Dict, Tuple, Union import pytest import torch @@ -20,10 +20,8 @@ from transformer_engine.pytorch.attention.multi_head_attention import MultiheadAttention from transformer_engine.pytorch.attention.dot_product_attention.utils import ( FlashAttentionUtils, - get_attention_backend, check_set_window_size, ) -from transformer_engine.pytorch.attention import InferenceParams from transformer_engine.pytorch.attention import RotaryPositionEmbedding import transformer_engine.pytorch.cpp_extensions as ext from transformer_engine.pytorch.cpp_extensions.fused_attn import ( @@ -54,7 +52,6 @@ reset_rng_states, ModelConfig, dtype_tols, - logging_context, get_available_attention_backends, ) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index cd71d5b93..0b0732dfa 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -14,15 +14,12 @@ from utils import ModelConfig, get_available_attention_backends # Check if FP8 is supported -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, _ = FP8GlobalStateManager.is_fp8_available() -fp8_recipes = [ - None, # non-fp8 - # recipe.MXFP8BlockScaling(), - scale inverse tensors offloading doest not work yet - recipe.Float8CurrentScaling(), - recipe.DelayedScaling(), -] +fp8_recipes = [None] +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) model_config = { "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), @@ -129,12 +126,6 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: model_cls = model_types[model_key] models_list = [model_cls() for _ in range(NUM_LAYERS)] - if fp8_recipe and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None: - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if model_key in ["multihead_attention", "transformer_layer"]: available_backends, *_ = get_available_attention_backends( model_config["small"], diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 83837eafd..9b5118e6e 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ 
b/tests/pytorch/test_cuda_graphs.py @@ -2,9 +2,7 @@ # # See LICENSE for license information. -from dataclasses import dataclass -import itertools -from typing import Iterable, List, Tuple, Union +from typing import Iterable, List, Union import pytest import torch @@ -26,11 +24,9 @@ from utils import ModelConfig, reset_rng_states # Check if FP8 is supported. -fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() -) -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() # Reset RNG states. reset_rng_states() @@ -39,12 +35,14 @@ "small": ModelConfig(32, 2, 2, 32), } -fp8_recipes = [ - recipe.DelayedScaling(), - recipe.MXFP8BlockScaling(), - recipe.Float8CurrentScaling(), - recipe.Float8BlockScaling(), -] +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_block_scaling_available: + fp8_recipes.append(recipe.Float8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) # Supported data types dtypes: List[torch.dtype] = [torch.float32, torch.float16] @@ -277,35 +275,27 @@ def _test_cuda_graphs( @pytest.mark.parametrize("module", _test_cuda_graphs_modules) @pytest.mark.parametrize("dtype", dtypes) -@pytest.mark.parametrize("fp8", (False, True)) @pytest.mark.parametrize("fp8_params", (False, True)) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes + [None]) def test_make_graphed_callables( *, module: str, model_config: str = "small", num_layers: int = 3, dtype: torch.dtype, - fp8: bool, fp8_params: bool, fp8_recipe: recipe.Recipe, fp8_weight_caching: bool = False, ) -> None: - # Skip invalid configurations. - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) + fp8 = fp8_recipe is not None if fp8_params and not fp8: pytest.skip("FP8 needed for FP8 parameters.") if fp8_weight_caching and not fp8: pytest.skip("FP8 needed for FP8 parameters.") - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - - if fp8_recipe.float8_block_scaling() and module == "linear_op": + if fp8 and fp8_recipe.float8_block_scaling() and module == "linear_op": pytest.skip("Module not yet supported for float8_block_scaling with CUDA graphs") + # Run model with different CUDA graph settings. 
model_config = model_configs[model_config] kwargs = dict( @@ -336,7 +326,6 @@ def test_make_graphed_callables( ] -@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) @pytest.mark.parametrize( "module", _test_make_graphed_callables_with_fp8_weight_caching_modules, @@ -352,7 +341,6 @@ def test_make_graphed_callables_with_fp8_weight_caching( test_make_graphed_callables( module=module, dtype=torch.float32, - fp8=True, fp8_params=fp8_params, fp8_recipe=fp8_recipe, fp8_weight_caching=True, diff --git a/tests/pytorch/test_fused_optimizer.py b/tests/pytorch/test_fused_optimizer.py index f1e80e698..8adabd751 100644 --- a/tests/pytorch/test_fused_optimizer.py +++ b/tests/pytorch/test_fused_optimizer.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. -from itertools import product import copy from contextlib import nullcontext diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index d2cb85dd3..aacad9081 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -2,8 +2,7 @@ # # See LICENSE for license information. import torch -import math -from typing import Optional, Dict +from typing import Optional from transformer_engine.pytorch.router import ( fused_topk_with_score_function, fused_compute_score_for_moe_aux_loss, diff --git a/tests/pytorch/test_hf_integration.py b/tests/pytorch/test_hf_integration.py index 0b2468510..e74b16022 100644 --- a/tests/pytorch/test_hf_integration.py +++ b/tests/pytorch/test_hf_integration.py @@ -7,7 +7,6 @@ from transformers.modeling_utils import PreTrainedModel from transformer_engine.pytorch.transformer import TransformerLayer -from transformer_engine.pytorch.utils import is_bf16_compatible class SimpleTEModel(PreTrainedModel): diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 790bc7a11..543f5f08d 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -2,7 +2,6 @@ # # See LICENSE for license information. -from collections import OrderedDict import math import os from typing import Dict, List, Tuple, Optional @@ -37,23 +36,20 @@ Fp8Padding, Fp8Unpadding, ) -from transformer_engine.pytorch.attention.inference import InferenceParams from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace -from transformer_engine.pytorch.utils import get_device_compute_capability, get_cudnn_version +from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states, get_available_attention_backends # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() -) +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -103,18 +99,21 @@ feature_dirs=os.environ["NVTE_TEST_NVINSPECT_FEATURE_DIRS"], ) -fp8_recipes = [ - recipe.MXFP8BlockScaling(), - recipe.DelayedScaling(), - recipe.Float8CurrentScaling(), - recipe.Float8BlockScaling(), -] + +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_block_scaling_available: + fp8_recipes.append(recipe.Float8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) def is_fused_attn_available( config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True ): - available_backends, _, fused_attn_backends = get_available_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -571,14 +570,8 @@ def _test_e2e_selective_recompute( @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -687,14 +680,8 @@ def _test_e2e_full_recompute( def test_gpt_full_activation_recompute( dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant ): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -1263,8 +1250,8 @@ def test_linear_accuracy_delay_wgrad_compute(dtype, bs, model, bias, fuse_wgrad_ te_linear_ref, bs, dtype, config, delay_wgrad_compute=False ) - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): + # Should be bit-wise match + for _, (o, o_ref) in enumerate(zip(te_outputs, te_outputs_ref)): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) @@ -1276,12 +1263,7 @@ def test_linear_accuracy_save_original_input(dtype, model, recipe): fuse_wgrad_accumulation = True fp8_model_params = False fp8 = recipe is not None - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) + if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") @@ -1649,14 +1631,12 @@ def 
test_layernorm_mlp_accuracy(dtype, bs, model, activation, normalization, ret @pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("bs", [2]) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("activation", all_activations) -@pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("bias", all_boolean) @pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) def test_layernorm_mlp_accuracy_delay_wgrad_compute( - dtype, bs, model, activation, normalization, bias, fuse_wgrad_accumulation + dtype, bs, model, bias, fuse_wgrad_accumulation ): config = model_configs[model] @@ -1665,7 +1645,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ffn_hidden_size=4 * config.hidden_size, eps=config.eps, bias=bias, - normalization=normalization, params_dtype=dtype, device="cuda", delay_wgrad_compute=True, @@ -1677,7 +1656,6 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( ffn_hidden_size=4 * config.hidden_size, eps=config.eps, bias=bias, - normalization=normalization, params_dtype=dtype, device="cuda", delay_wgrad_compute=False, @@ -1687,8 +1665,7 @@ def test_layernorm_mlp_accuracy_delay_wgrad_compute( # Share params with torch.no_grad(): ln_mlp_ref.layer_norm_weight = Parameter(ln_mlp.layer_norm_weight.clone()) - if normalization != "RMSNorm": - ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) + ln_mlp_ref.layer_norm_bias = Parameter(ln_mlp.layer_norm_bias.clone()) ln_mlp_ref.fc1_weight = Parameter(ln_mlp.fc1_weight.clone()) ln_mlp_ref.fc2_weight = Parameter(ln_mlp.fc2_weight.clone()) if bias: @@ -1802,14 +1779,8 @@ def test_grouped_linear_accuracy( parallel_mode=None, ): fp8 = recipe is not None - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ -1904,14 +1875,8 @@ def test_grouped_linear_accuracy_save_original_input( parallel_mode=None, ): fp8 = recipe is not None - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8 and recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if fp8 and recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") @@ -2114,14 +2079,8 @@ def test_padding_grouped_linear_accuracy( fp8_model_params, parallel_mode=None, ): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] if config.max_seqlen_q % 16 != 0 and fp8: @@ 
-2189,14 +2148,8 @@ def test_padding_grouped_linear_accuracy_save_original_input( fp8_model_params, parallel_mode=None, ): - if fp8 and not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) if fp8 and recipe.delayed(): pytest.skip("DelayedScaling recipe is not supported with save_original_input") @@ -2410,14 +2363,8 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): @pytest.mark.parametrize("model", ["126m"]) @pytest.mark.parametrize("recipe", fp8_recipes) def test_gpt_fp8_parameters(dtype, bs, model, recipe): - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) config = model_configs[model] @@ -2645,9 +2592,8 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): (16, 4096, 128, 512), ], ) -@pytest.mark.parametrize("fp8_dtype", [tex.DType.kFloat8E4M3, tex.DType.kFloat8E5M2]) @pytest.mark.parametrize("accumulate", [False, True]) -def test_fp8_grouped_gemm(shape, fp8_dtype, accumulate): +def test_fp8_grouped_gemm(shape, accumulate): if not fp8_available: pytest.skip(reason_for_no_fp8) diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index ea9c85e37..839fb8dff 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -27,7 +27,6 @@ import numpy as np import onnxruntime as ort import torch -import random from torch import nn as nn from typing import Optional, Union, Tuple, List from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op @@ -59,14 +58,13 @@ fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() -skip_FP8 = pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) -skip_MXFP8 = pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) -fp8_recipes = [ - None, - recipe.DelayedScaling(), - recipe.MXFP8BlockScaling(), -] +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.DelayedScaling()) +fp8_recipes.append(None) supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] @@ -369,14 +367,6 @@ def create_ort_input_dict(session, inputs): ) -def create_meta(scale_factor: float, size: int = 1): - meta = tex.FP8TensorMeta() - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor - return meta - - def dtype2str(dtype: torch.dtype, fake_bf16_io=False): if fake_bf16_io: assert dtype == torch.bfloat16 @@ -413,36 +403,12 @@ def get_attn_mask_str(use_mask, attn_mask_type): """ -@pytest.mark.parametrize("scale_factor", [112]) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -# Returning the bias is a TE fusion optimization we don't care about. 
-@pytest.mark.parametrize("return_bias", [True, False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, False), - (torch.float16, True), - # Todo: cannot configure BF16 when bias is disabled (ORT issue?) - (torch.bfloat16, False), - # Todo: cannot configure BF16 when bias is enabled (ORT issue?) - (torch.bfloat16, True), - ], -) -def test_export_linear( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - use_bias: bool, - return_bias: bool, - precision: torch.dtype, +def _test_export_linear( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_bias: bool = True, + return_bias: bool = False, + precision: torch.dtype = torch.float32, ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if return_bias and not use_bias: pytest.skip("Cannot return bias when bias is disabled") @@ -498,32 +464,28 @@ def forward(self, inp): ) -@pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize( - "precision", - [ - torch.float32, - torch.float16, - torch.bfloat16, - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - precision: torch.dtype, - zero_centered_gamma: bool, - normalization: str, -): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision): + _test_export_linear(fp8_recipe=fp8_recipe, precision=precision) + + +@pytest.mark.parametrize("use_bias", [True, False]) +def test_export_linear_use_bias(seed_default_rng, use_bias): + _test_export_linear(use_bias=use_bias) + +@pytest.mark.parametrize("return_bias", [True, False]) +def test_export_linear_return_bias(seed_default_rng, return_bias): + _test_export_linear(return_bias=return_bias) + + +def _test_export_layernorm( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + precision: torch.dtype = torch.float32, + zero_centered_gamma: bool = False, + normalization: str = all_normalizations[0], +): # Set dimensions (these are arbitrary). 
batch_size = 4 in_features = 64 @@ -564,39 +526,31 @@ def test_export_layernorm( ) -@pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("return_bias", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision): + _test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_layernorm_zero_centered_gamma(seed_default_rng): + _test_export_layernorm(zero_centered_gamma=True) + + @pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_linear( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - normalization: str, +def test_export_layernorm_normalization(seed_default_rng, normalization): + _test_export_layernorm(normalization=normalization) + + +def _test_export_layernorm_linear( + scale_factor: float = 112, + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_bias: bool = True, + return_bias: bool = False, + return_layernorm_output: bool = False, + precision: torch.dtype = torch.float32, + zero_centered_gamma: bool = False, + normalization: str = all_normalizations[0], ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if return_bias and not use_bias: pytest.skip("Cannot return bias when bias is disabled") @@ -644,41 +598,44 @@ def test_export_layernorm_linear( ) -@pytest.mark.parametrize("scale_factor", [112]) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("return_bias", [True, False]) -@pytest.mark.parametrize("return_layernorm_output", [True, False]) -@pytest.mark.parametrize( - "precision, use_bias", - [ - (torch.float32, False), - (torch.float32, True), - (torch.float16, True), - (torch.float16, False), - (torch.bfloat16, True), - (torch.bfloat16, False), - ], -) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_export_layernorm_mlp( - seed_default_rng, - scale_factor: float, - fp8_recipe: recipe.Recipe, - use_bias: bool, - return_bias: bool, - return_layernorm_output: bool, - precision: torch.dtype, - zero_centered_gamma: bool, - activation: str, - normalization: str, +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision): + _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_layernorm_linear_return_ln_out(seed_default_rng): + _test_export_layernorm_linear(return_layernorm_output=True) + + +def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng): + _test_export_layernorm_linear(zero_centered_gamma=True) + + 
+@pytest.mark.parametrize("normalization", all_normalizations[1:]) +def test_export_layernorm_linear_normalization(seed_default_rng, normalization): + _test_export_layernorm_linear(normalization=normalization) + + +def test_export_layernorm_linear_no_bias(seed_default_rng): + _test_export_layernorm_linear(use_bias=False) + + +def test_export_layernorm_linear_return_bias(seed_default_rng): + _test_export_layernorm_linear(return_bias=True) + + +def _test_export_layernorm_mlp( + scale_factor: float = 112, + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_bias: bool = True, + return_bias: bool = False, + return_layernorm_output: bool = False, + precision: torch.dtype = torch.float32, + zero_centered_gamma: bool = False, + activation: str = supported_activations[0], + normalization: str = all_normalizations[0], ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if return_bias and not use_bias: pytest.skip("Cannot return bias when bias is disabled") @@ -720,6 +677,38 @@ def test_export_layernorm_mlp( ) +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision): + _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng): + _test_export_layernorm_mlp(return_layernorm_output=True) + + +def test_export_layernorm_mlp_return_bias(seed_default_rng): + _test_export_layernorm_mlp(return_bias=True) + + +def test_export_layernorm_mlp_no_bias(seed_default_rng): + _test_export_layernorm_mlp(use_bias=False) + + +def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng): + _test_export_layernorm_mlp(zero_centered_gamma=True) + + +@pytest.mark.parametrize("normalization", all_normalizations[1:]) +def test_export_layernorm_mlp_normalization(seed_default_rng, normalization): + _test_export_layernorm_mlp(normalization=normalization) + + +@pytest.mark.parametrize("activation", supported_activations[1:]) +def test_export_layernorm_mlp_activation(seed_default_rng, activation): + _test_export_layernorm_mlp(activation=activation) + + @pytest.mark.parametrize( "precision, use_mask, attn_mask_type", [ @@ -734,8 +723,6 @@ def test_export_layernorm_mlp( ], ) def test_export_core_attention( - seed_default_rng, - set_max_seq_len, precision: torch.dtype, use_mask: bool, attn_mask_type: str, @@ -777,11 +764,6 @@ def test_export_core_attention( ) -test_configs_multihead_attention = [ - # "use_mask, attn_mask_type" - (False, "no_mask"), # calls ScaledSoftmax - (True, "arbitrary"), # calls ScaledMaskedSoftmax -] test_configs_attention_type = [ # "input_layernorm, attention_type, fuse_qkv_params" (True, "self", True), @@ -795,31 +777,14 @@ def test_export_core_attention( ] -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("return_layernorm_output", [False]) -@pytest.mark.parametrize( - "input_layernorm, attention_type, fuse_qkv_params", test_configs_attention_type -) -def test_export_multihead_attention( - seed_default_rng, - set_max_seq_len, - fp8_recipe: recipe.Recipe, - use_mask: bool, - 
attn_mask_type: str, - precision: torch.dtype, - return_layernorm_output: bool, - input_layernorm: bool, - attention_type: str, - fuse_qkv_params: bool, +def _test_export_multihead_attention( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_mask: bool = True, + precision: torch.dtype = torch.float32, + input_layernorm: bool = True, + attention_type: str = "self", + fuse_qkv_params: bool = True, ): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - hidden_size = 256 sequence_length = 128 batch_size = 4 @@ -837,6 +802,7 @@ def test_export_multihead_attention( init_method, output_layer_init_method, ) + attn_mask_type = "arbitrary" if use_mask else "no_mask" hidden_states_context = torch.randn( sequence_length, batch_size, hidden_size, dtype=precision, device="cuda" @@ -868,7 +834,7 @@ def test_export_multihead_attention( *attention_args, attn_mask_type=attn_mask_type, params_dtype=precision, - return_layernorm_output=return_layernorm_output, + return_layernorm_output=False, input_layernorm=input_layernorm, attention_type=attention_type, fuse_qkv_params=fuse_qkv_params, @@ -960,30 +926,37 @@ def test_export_multihead_attention( @pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("use_mask, attn_mask_type", test_configs_multihead_attention) -@pytest.mark.parametrize("output_layernorm", [True, False]) @pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("fuse_qkv_params", [False, True]) -@pytest.mark.parametrize("zero_centered_gamma", [False, True]) -@pytest.mark.parametrize("activation", supported_activations) -def test_export_transformer_layer( - seed_default_rng, - set_max_seq_len, - fp8_recipe: recipe.Recipe, - use_mask: bool, - attn_mask_type: str, - output_layernorm: bool, - precision: torch.dtype, - fuse_qkv_params: bool, - zero_centered_gamma: bool, - activation: str, -): - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) +def test_export_multihead_attention_recipe(fp8_recipe, precision): + _test_export_multihead_attention(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_multihead_attention_no_mask(): + _test_export_multihead_attention(use_mask=False) + + +def test_export_multihead_attention_no_input_layernorm(): + _test_export_multihead_attention(input_layernorm=False) + +def test_export_multihead_attention_cross_attn(): + _test_export_multihead_attention(attention_type="cross") + + +def test_export_multihead_attention_unfused_qkv_params(): + _test_export_multihead_attention(fuse_qkv_params=False) + + +def _test_export_transformer_layer( + fp8_recipe: recipe.Recipe = fp8_recipes[0], + use_mask: bool = True, + attn_mask_type: str = "arbitrary", + output_layernorm: bool = False, + precision: torch.dtype = torch.float32, + fuse_qkv_params: bool = True, + zero_centered_gamma: bool = False, + activation: str = supported_activations[0], +): # Layer configuration hidden_size = 64 sequence_length = 128 @@ -1043,28 +1016,43 @@ def test_export_transformer_layer( ) -@skip_FP8 -@skip_MXFP8 +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16]) +def 
test_export_transformer_layer_recipe(fp8_recipe, precision): + _test_export_transformer_layer(fp8_recipe=fp8_recipe, precision=precision) + + +def test_export_transformer_layer_no_mask(): + _test_export_transformer_layer(use_mask=False) + + +def test_export_transformer_layer_output_layernorm(): + _test_export_transformer_layer(output_layernorm=True) + + +def test_export_transformer_layer_unfused_qkv_params(): + _test_export_transformer_layer(fuse_qkv_params=False) + + +def test_export_transformer_layer_zero_centered_gamma(): + _test_export_transformer_layer(zero_centered_gamma=True) + + +@pytest.mark.parametrize("activation", supported_activations[1:]) +def test_export_transformer_layer_activation(activation): + _test_export_transformer_layer(activation=activation) + + @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16]) -@pytest.mark.parametrize("zero_centered_gamma", [True]) def test_export_gpt_generation( - seed_default_rng, - set_max_seq_len, fp8_recipe: recipe.Recipe, precision: torch.dtype, - zero_centered_gamma: bool, ): """Test that the ONNX model can correctly handle inputs with different shapes and that the attention mask is adjusted on-the-fly to different sequence lengths. """ - # Skip FP8 tests on non-hopper devices - if fp8_recipe is not None and not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe is not None and fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - # Layer configuration hidden_size = 64 sequence_length = 128 @@ -1091,7 +1079,6 @@ def test_export_gpt_generation( output_layernorm=output_layernorm, params_dtype=precision, fuse_qkv_params=fuse_qkv_params, - zero_centered_gamma=zero_centered_gamma, ).to(device="cuda") # "Context phase": use full input sequence length diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index dd6c6a3b0..77bea2b36 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -3,7 +3,6 @@ # See LICENSE for license information. import random -import pytest import torch from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 4df6d987a..07c636ab1 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -2,9 +2,7 @@ # # See LICENSE for license information. -from dataclasses import dataclass from typing import Optional -from contextlib import nullcontext import torch import pytest @@ -17,11 +15,9 @@ fp8_model_init, ) from transformer_engine.pytorch.utils import ( - get_device_compute_capability, init_method_normal, scaled_init_method_normal, is_bf16_compatible, - get_cudnn_version, ) from transformer_engine.pytorch import ( LayerNormLinear, @@ -31,7 +27,6 @@ TransformerLayer, RMSNorm, LayerNorm, - get_cpu_offload_context, ) from transformer_engine.common import recipe import transformer_engine_torch as tex @@ -46,13 +41,11 @@ from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor.utils import replace_raw_data from transformer_engine.pytorch.distributed import checkpoint -from utils import ModelConfig, dtype_tols +from utils import ModelConfig # Only run FP8 tests on supported devices. 
fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - FP8GlobalStateManager.is_fp8_block_scaling_available() -) +fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() # Record initial RNG state from script run. @@ -76,33 +69,6 @@ ) -def create_meta(scale_factor: float, size: int = 1): - meta = tex.FP8TensorMeta() - meta.amax_history = torch.zeros(1, size, dtype=torch.float32, device="cuda") - meta.scale_inv = torch.ones(size, dtype=torch.float32, device="cuda") / scale_factor - meta.scale = torch.ones(size, dtype=torch.float32, device="cuda") * scale_factor - return meta - - -def custom_amax_to_scale( - amax: torch.Tensor, - scale: torch.Tensor, - fp8_max: torch.Tensor, - recipe: recipe.DelayedScaling, -) -> torch.Tensor: - """Custom func to test recipe.""" - sf = fp8_max / amax - sf = torch.where(amax > 0.0, sf, scale) - sf = torch.where(torch.isfinite(amax), sf, scale) - - return sf - - -def custom_amax_compute(amax_history: torch.Tensor) -> torch.Tensor: - """Custom func to test recipe.""" - return torch.min(amax_history, dim=0).values - - def is_fp8_supported(config: ModelConfig): if ( config.max_seqlen_q * config.batch_size % 16 @@ -121,22 +87,15 @@ def is_fp8_supported(config: ModelConfig): "large": ModelConfig(2, 128, 4, 128, num_layers=1), } -fp8_recipes = [ - None, # Test non-FP8 - recipe.MXFP8BlockScaling(), # Test default - recipe.Float8CurrentScaling(), # Test default - recipe.Float8BlockScaling(), # Test default - recipe.DelayedScaling(), # Test default - recipe.DelayedScaling( # Test most_recent algo - amax_history_len=16, - amax_compute_algo="most_recent", - ), - recipe.DelayedScaling( # Test custom amax and scale compute algo - fp8_format=recipe.Format.E4M3, - amax_compute_algo=custom_amax_compute, - scaling_factor_compute_algo=custom_amax_to_scale, - ), -] +fp8_recipes = [] +if mxfp8_available: + fp8_recipes.append(recipe.MXFP8BlockScaling()) +if fp8_block_scaling_available: + fp8_recipes.append(recipe.Float8BlockScaling()) +if fp8_available: + fp8_recipes.append(recipe.Float8CurrentScaling()) + fp8_recipes.append(recipe.DelayedScaling()) +fp8_recipes.append(None) param_types = [torch.float32, torch.float16] if is_bf16_compatible(): # bf16 requires sm_80 or higher @@ -160,63 +119,6 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() -def _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad): - # Initialize loss function and optimizer. - loss_fn = torch.nn.MSELoss() - optimizer = torch.optim.SGD(block.parameters(), lr=0.1) - - # Placeholders used for capture. - static_input = torch.randn( - config.max_seqlen_q, - config.batch_size, - config.hidden_size, - device="cuda", - dtype=dtype, - requires_grad=True, - ) - static_target = torch.randn( - config.max_seqlen_q, config.batch_size, config.hidden_size, device="cuda", dtype=dtype - ) - - real_input = torch.rand_like(static_input) - real_target = torch.rand_like(static_target) - - use_fp8 = fp8_recipe is not None - if skip_wgrad: - _disable_wgrads(block) - - # Pre graph capture warmup in a separate stream. 
- s = torch.cuda.Stream() - s.wait_stream(torch.cuda.current_stream()) - with torch.cuda.stream(s): - for _ in range(3): - optimizer.zero_grad(set_to_none=True) - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): - out = block(static_input) - loss = loss_fn(out, static_target) - loss.backward() - optimizer.step() - torch.cuda.current_stream().wait_stream(s) - - # Capture. - g = torch.cuda.CUDAGraph() - optimizer.zero_grad(set_to_none=True) - with torch.cuda.graph(g): - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, _graph=True): - static_output = block(static_input) - static_loss = loss_fn(static_output, static_target) - static_loss.backward() - optimizer.step() - - # Fills the graph's input memory with new data to compute on - with torch.no_grad(): - static_input.copy_(real_input) - static_target.copy_(real_target) - g.replay() - - torch.cuda.synchronize() - - def _test_sanity_e2e_amp(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( (config.max_seqlen_q, config.batch_size, config.hidden_size), @@ -292,7 +194,7 @@ def _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_reci assert len(failed_grads) == 0, f"Gradient not accumulated for {failed_grads}." -def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): +def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad): te_inp_hidden_states = torch.randn( (config.max_seqlen_q, config.batch_size, config.hidden_size), dtype=dtype, @@ -303,16 +205,9 @@ def _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload): if skip_wgrad: _disable_wgrads(block) - if cpu_offload: - offload_context, sync_function = get_cpu_offload_context(enabled=True) - else: - offload_context = nullcontext() - sync_function = lambda x: x - use_fp8 = fp8_recipe is not None - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe), offload_context: + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): te_out = block(te_inp_hidden_states) - te_out = sync_function(te_out) loss = te_out.sum() loss.backward() torch.cuda.synchronize() @@ -471,12 +366,6 @@ def test_sanity_layernorm_linear( config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -505,12 +394,6 @@ def test_sanity_linear(dtype, fp8_recipe, model, skip_wgrad, skip_dgrad, microba config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -541,12 +424,6 @@ def test_sanity_linear_with_zero_tokens(dtype, bs, model, fp8_recipe, fp8_model_ num_tokens = bs * config.max_seqlen_q if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): 
pytest.skip("Model config does not support FP8") @@ -586,12 +463,6 @@ def test_sanity_grouped_linear( num_tokens = bs * config.max_seqlen_q * (num_gemms - 1) if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -642,12 +513,6 @@ def test_sanity_layernorm_mlp( config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -673,35 +538,23 @@ def test_sanity_layernorm_mlp( @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("bias", all_boolean) -@pytest.mark.parametrize("activation", all_activations) +@pytest.mark.parametrize("activation", ["gelu", "swiglu"]) @pytest.mark.parametrize("normalization", all_normalizations) @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) -@pytest.mark.parametrize("cpu_offload", all_boolean) def test_sanity_gpt( dtype, fp8_recipe, model, skip_wgrad, - zero_centered_gamma, bias, activation, normalization, parallel_attention_mlp, - cpu_offload, ): - if cpu_offload and NVTE_TEST_NVINSPECT_ENABLED: - pytest.skip("CPU offload is not supported in debug mode.") config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -721,7 +574,6 @@ def test_sanity_gpt( params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, bias=bias, activation=activation, normalization=normalization, @@ -729,7 +581,7 @@ def test_sanity_gpt( parallel_attention_mlp=parallel_attention_mlp, ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, cpu_offload) + _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) def test_sanity_gpt_126m(): @@ -746,12 +598,10 @@ def test_sanity_gpt_126m(): fp8_recipe=fp8_recipe, model="126m", skip_wgrad=False, - zero_centered_gamma=True, bias=True, activation="gelu", normalization="LayerNorm", parallel_attention_mlp=False, - cpu_offload=False, ) @@ -759,18 +609,13 @@ def test_sanity_gpt_126m(): @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): +def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, normalization): config = model_configs[model] if fp8_recipe is not 
None: if not fp8_available: pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -790,7 +635,6 @@ def test_sanity_bert(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, params_dtype=dtype, apply_residual_connection_post_layernorm=True, output_layernorm=True, - zero_centered_gamma=zero_centered_gamma, self_attn_mask_type="causal", normalization=normalization, device="cuda", @@ -811,7 +655,6 @@ def test_sanity_bert_126m(): fp8_recipe=fp8_recipe, model="126m", skip_wgrad=False, - zero_centered_gamma=False, normalization="LayerNorm", ) @@ -820,18 +663,13 @@ def test_sanity_bert_126m(): @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) @pytest.mark.parametrize("normalization", all_normalizations) -def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): +def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, normalization): config = model_configs[model] if fp8_recipe is not None: if not fp8_available: pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -852,7 +690,6 @@ def test_sanity_T5(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, no apply_residual_connection_post_layernorm=False, output_layernorm=False, layer_type="decoder", - zero_centered_gamma=zero_centered_gamma, normalization=normalization, device="cuda", ) @@ -872,7 +709,6 @@ def test_sanity_T5_126m(): fp8_recipe=fp8_recipe, model="126m", skip_wgrad=False, - zero_centered_gamma=False, normalization="LayerNorm", ) @@ -885,12 +721,6 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -917,17 +747,10 @@ def test_sanity_amp_and_nvfuser(dtype, fp8_recipe, model, skip_wgrad): @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("skip_wgrad", all_boolean) -def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): +def test_sanity_drop_path(dtype, fp8_recipe, model): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -951,7 +774,7 @@ def test_sanity_drop_path(dtype, fp8_recipe, model, skip_wgrad): 
device="cuda", ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) + _test_sanity_e2e(block, dtype, config, fp8_recipe, False) @pytest.mark.parametrize("dtype", param_types) @@ -962,12 +785,6 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -991,26 +808,17 @@ def test_sanity_fused_qkv_params(dtype, fp8_recipe, model, skip_wgrad): device="cuda", ) - _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad, False) + _test_sanity_e2e(block, dtype, config, fp8_recipe, skip_wgrad) @pytest.mark.parametrize("dtype", param_types) @pytest.mark.parametrize("fp8_recipe", fp8_recipes) @pytest.mark.parametrize("model", ["small"]) @pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) -def test_sanity_gradient_accumulation_fusion( - dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma -): +def test_sanity_gradient_accumulation_fusion(dtype, fp8_recipe, model, skip_wgrad): config = model_configs[model] if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) if not is_fp8_supported(config): pytest.skip("Model config does not support FP8") @@ -1030,7 +838,6 @@ def test_sanity_gradient_accumulation_fusion( params_dtype=dtype, apply_residual_connection_post_layernorm=False, output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, fuse_qkv_params=True, fuse_wgrad_accumulation=True, device="cuda", @@ -1039,52 +846,6 @@ def test_sanity_gradient_accumulation_fusion( _test_sanity_e2e_gradient_accumulation_fusion(block, dtype, config, fp8_recipe, skip_wgrad) -@pytest.mark.parametrize("dtype", param_types) -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("model", ["small"]) -@pytest.mark.parametrize("skip_wgrad", all_boolean) -@pytest.mark.parametrize("zero_centered_gamma", all_boolean) -@pytest.mark.parametrize("normalization", all_normalizations) -def test_gpt_cuda_graph(dtype, fp8_recipe, model, skip_wgrad, zero_centered_gamma, normalization): - config = model_configs[model] - - if fp8_recipe is not None: - if not fp8_available: - pytest.skip(reason_for_no_fp8) - if fp8_recipe.float8_block_scaling() and not fp8_block_scaling_available: - pytest.skip(reason_for_no_fp8_block_scaling) - if fp8_recipe.mxfp8() and not mxfp8_available: - pytest.skip(reason_for_no_mxfp8) - if fp8_recipe.float8_block_scaling(): - pytest.skip("cuda graph not supported for float8_block_scaling recipe") - if not is_fp8_supported(config): - pytest.skip("Model config does not support FP8") - - sigma = 0.023 - init_method = init_method_normal(sigma) - output_layer_init_method = scaled_init_method_normal(sigma, config.num_layers) - - block = TransformerLayer( - config.hidden_size, - 4 * config.hidden_size, - config.num_heads, - init_method=init_method, - output_layer_init_method=output_layer_init_method, - hidden_dropout=0.1, - attention_dropout=0.1, - 
kv_channels=config.kv_channels, - params_dtype=dtype, - apply_residual_connection_post_layernorm=False, - output_layernorm=False, - zero_centered_gamma=zero_centered_gamma, - fuse_qkv_params=True, - normalization=normalization, - device="cuda", - ) - - _test_sanity_e2e_cuda_graph(block, dtype, config, fp8_recipe, skip_wgrad) - - def test_model_multiple_cast(): a = torch.zeros((16, 16), device="cuda") m = Linear(16, 32) From 5a495a396d2588e405a3c078db635c782b560ff9 Mon Sep 17 00:00:00 2001 From: Przemyslaw Tredak Date: Mon, 28 Jul 2025 21:14:05 -0700 Subject: [PATCH 026/153] Fix the use-after-free bug in unfused normalization (#2002) Signed-off-by: Przemek Tredak --- transformer_engine/pytorch/csrc/extensions/normalization.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 88404a2e1..0d2011ba7 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -108,9 +108,9 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe } } TensorWrapper unquantized_out_cu; + py::object unquantized_out; if (force_unfused_kernel) { NoneQuantizer q{none}; - py::object unquantized_out; std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); } TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; @@ -269,9 +269,9 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w } } TensorWrapper unquantized_out_cu; + py::object unquantized_out; if (force_unfused_kernel) { NoneQuantizer q{none}; - py::object unquantized_out; std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); } TensorWrapper &kernel_out_cu = force_unfused_kernel ? 
unquantized_out_cu : out_cu; From cb5013bd90673b520dcc911b07a390b095c82a06 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 29 Jul 2025 08:11:16 -0700 Subject: [PATCH 027/153] [PyTorch] Refactor C++ quantizer infrastructure (#1952) * remove reciprocal op Signed-off-by: zhongboz * Refactor Quantizer::create_tensor function Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix bug when constructing FP8 tensor Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add quantize function to C++ quantizers Signed-off-by: Tim Moon * Prototype function to coerce Python quantized tensors to match quantizer Signed-off-by: Tim Moon * Use quantizer class in tex.quantize Signed-off-by: Tim Moon * Add FP8 current scaling support for activation backward Signed-off-by: Tim Moon * Disable quantized GEMM output with FP8 current scaling Signed-off-by: Tim Moon * Add coerce_tensor functions for MXFP8 and DSv3 Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Avoid quantizing empty tensors Signed-off-by: Tim Moon * Use consistent shapes for FP8 transposes Signed-off-by: Tim Moon * In attention impl, construct FP8 tensors with pre-initialized scale-invs Signed-off-by: Tim Moon * Initialize MXFP8 scales to zero Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Store copy of quantizer when creating quantized tensors Signed-off-by: Tim Moon * Fix linter warnings Signed-off-by: Tim Moon * Make sure quantized tensors have private quantizer Avoid problems with in-place ops after quantizer usages are changed externally. 
Signed-off-by: Tim Moon * Rename "coerce_tensor" to "convert_and_update_tensor" Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make sure CUDA context is available when launching NVRTC kernel Signed-off-by: Tim Moon * Expose CUDA context creation function externally Signed-off-by: Tim Moon --------- Signed-off-by: zhongboz Signed-off-by: Tim Moon Co-authored-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_fusible_ops.py | 7 +- .../common/libtransformer_engine.version | 1 + .../common/util/cuda_driver.cpp | 13 + transformer_engine/common/util/cuda_driver.h | 8 + transformer_engine/common/util/rtc.h | 1 + transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 80 +- .../pytorch/csrc/extensions/activation.cpp | 127 ++- .../pytorch/csrc/extensions/attention.cpp | 53 +- .../pytorch/csrc/extensions/cast.cpp | 65 +- .../pytorch/csrc/extensions/transpose.cpp | 30 +- transformer_engine/pytorch/csrc/quantizer.cpp | 803 ++++++++++++++---- .../pytorch/ops/basic/basic_linear.py | 29 +- .../ops/fused/userbuffers_forward_linear.py | 2 +- .../_internal/float8_blockwise_tensor_base.py | 2 +- .../tensor/_internal/float8_tensor_base.py | 2 +- .../tensor/_internal/mxfp8_tensor_base.py | 2 +- .../pytorch/tensor/float8_blockwise_tensor.py | 2 +- .../pytorch/tensor/float8_tensor.py | 9 +- .../pytorch/tensor/mxfp8_tensor.py | 2 +- 20 files changed, 864 insertions(+), 376 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 13b27a0b1..9b9bb58ac 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -837,10 +837,9 @@ def _test_basic_linear( pytest.skip("FP8 output is only supported with FP8 GEMMs") if quantized_grad_input and not quantized_compute: pytest.skip("FP8 grad input is only supported with FP8 GEMMs") - if quantization == "mxfp8" and quantized_output: - pytest.skip("MXFP8 output is not supported with MXFP8 GEMMs") - if quantization == "mxfp8" and quantized_grad_input: - pytest.skip("MXFP8 grad input is not supported with MXFP8 GEMMs") + if quantization not in (None, "fp8"): + if quantized_output or quantized_grad_input: + pytest.skip("Recipe does not support quantized GEMM output") # Random data x_ref, x_test = make_reference_and_test_tensors( diff --git a/transformer_engine/common/libtransformer_engine.version b/transformer_engine/common/libtransformer_engine.version index 4412d0c5f..706c237cc 100644 --- a/transformer_engine/common/libtransformer_engine.version +++ b/transformer_engine/common/libtransformer_engine.version @@ -8,6 +8,7 @@ transformer_engine::cuda::stream_priority_range*; transformer_engine::cuda::current_device*; transformer_engine::cuda_driver::get_symbol*; + transformer_engine::cuda_driver::ensure_context_exists*; transformer_engine::ubuf_built_with_mpi*; *transformer_engine::rtc*; transformer_engine::nvte_cudnn_handle_init*; diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 59d490e58..4812435f7 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -44,6 +44,19 @@ void *get_symbol(const char *symbol, int cuda_version) { return entry_point; } +void ensure_context_exists() { + CUcontext context; + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context); + if (context == nullptr) { + // Add primary context to context stack + CUdevice device; 
+ NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device()); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device); + } +} + } // namespace cuda_driver } // namespace transformer_engine diff --git a/transformer_engine/common/util/cuda_driver.h b/transformer_engine/common/util/cuda_driver.h index a0fcd65c8..3425e0af3 100644 --- a/transformer_engine/common/util/cuda_driver.h +++ b/transformer_engine/common/util/cuda_driver.h @@ -39,6 +39,14 @@ inline CUresult call(const char *symbol, ArgTs... args) { return (*func)(args...); } +/*! \brief Ensure that the calling thread has a CUDA context + * + * Each thread maintains a stack of CUDA contexts. If the calling + * thread has an empty stack, the primary context is added to the + * stack. + */ +void ensure_context_exists(); + } // namespace cuda_driver } // namespace transformer_engine diff --git a/transformer_engine/common/util/rtc.h b/transformer_engine/common/util/rtc.h index 820b16c20..7de1e4d55 100644 --- a/transformer_engine/common/util/rtc.h +++ b/transformer_engine/common/util/rtc.h @@ -59,6 +59,7 @@ class Kernel { template void launch(int device_id, const dim3 grid_dim, const dim3 block_dim, unsigned int shared_mem_bytes, cudaStream_t stream, ArgTs &&...args) { + cuda_driver::ensure_context_exists(); void *arg_ptrs[] = {const_cast(static_cast(&args))...}; NVTE_CALL_CHECK_CUDA_DRIVER(cuLaunchKernel, get_function(device_id), grid_dim.x, grid_dim.y, grid_dim.z, block_dim.x, block_dim.y, block_dim.z, shared_mem_bytes, diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index f86b60f61..ab3b7abec 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ b/transformer_engine/pytorch/csrc/common.cpp @@ -12,7 +12,7 @@ namespace transformer_engine::pytorch { -std::vector getTensorShape(at::Tensor t) { +std::vector getTensorShape(const at::Tensor& t) { std::vector shape; for (auto s : t.sizes()) { shape.push_back(s); diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index b5b63f757..be3b995a1 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -98,9 +98,21 @@ class Quantizer { virtual void set_quantization_params(TensorWrapper* tensor) const = 0; - virtual std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const = 0; + /*! @brief Construct a tensor with uninitialized data */ + virtual std::pair create_tensor(const std::vector& shape, + DType dtype) const = 0; + + /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor + * + * The PyTorch tensor's attributes are modified to match the + * quantizer's configuration. + */ + virtual std::pair convert_and_update_tensor( + py::object tensor) const = 0; + + /*! @brief Convert to a quantized data format */ + virtual void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) = 0; virtual ~Quantizer() = default; @@ -121,9 +133,17 @@ class NoneQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override {} - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! 
@brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const std::vector& shape, DType dtype, + at::Tensor data) const; + + std::pair convert_and_update_tensor(py::object tensor) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8Quantizer : public Quantizer { @@ -139,9 +159,19 @@ class Float8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + /*! @brief Construct a tensor with pre-initialized data */ + std::pair create_tensor(const std::vector& shape, DType dtype, + std::optional data, + std::optional transpose, + std::optional scale_inv) const; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8CurrentScalingQuantizer : public Quantizer { @@ -161,9 +191,13 @@ class Float8CurrentScalingQuantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; }; class Float8BlockQuantizer : public Quantizer { @@ -195,9 +229,13 @@ class Float8BlockQuantizer : public Quantizer { // Create a python Float8BlockQuantized tensor and C++ wrapper // for the tensor. Should set quantized data, scales for rowwise // and optionally columnwise usage. 
- std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; @@ -212,16 +250,20 @@ class MXFP8Quantizer : public Quantizer { void set_quantization_params(TensorWrapper* tensor) const override; - std::pair create_tensor( - const std::vector& shape, DType dtype, - std::optional rowwise_data = std::nullopt) const override; + std::pair create_tensor(const std::vector& shape, + DType dtype) const override; + + std::pair convert_and_update_tensor(py::object shape) const override; + + void quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt) override; std::vector get_scale_shape(const std::vector& shape, bool columnwise) const; }; std::unique_ptr convert_quantizer(py::handle quantizer); -std::vector getTensorShape(at::Tensor t); +std::vector getTensorShape(const at::Tensor& t); transformer_engine::DType getTransformerEngineFP8Type(bool e4m3_if_hybrid, const std::string& fp8_recipe); diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index dfc8a8291..c9eae092b 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -13,87 +13,74 @@ namespace transformer_engine::pytorch { template py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - input_shape[input_shape.size() - 1] /= shape_divisor; - auto fake_tensor_type = input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - - // for current scaling, we need to compute amax first and then quantize - // because cache cannot fit in the entire tensor to compute amax and quantize - // the quantizer should not need amax reduction, no process group needed here - if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // activation function might change the input data range, we need to first call the activation function - // and then find the amax and scale of that and then do the quantization - // get a NoneQuantizer to calculate amax of activation output - auto my_quantizer_none = std::make_unique(py::none()); - auto [te_output_act, out_act] = - my_quantizer_none->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_input.data(), te_output_act.data(), at::cuda::getCurrentCUDAStream()); - // use te_output_act as input to the compute amax and find the amax of activated tensor - nvte_compute_amax(te_output_act.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - if 
(my_quantizer_cs->with_amax_reduction) { - NVTE_ERROR( - "per-tensor current scaling amax reduction is not supported in activation functions."); - } - QuantizationConfigWrapper quant_config; - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - te_output.set_amax(nullptr, DType::kFloat32, te_output.defaultShape); - nvte_quantize_v2(te_output_act.data(), te_output.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - // sanity check, since activation fusion is not supported for blockwise quantization yet - // need to raise an error here instead of silently going into act_func with wrong numerics - NVTE_ERROR("Activation fusion is not supported for blockwise quantization yet."); + // Input tensor + auto input_tensor = input.contiguous(); + const TensorWrapper& input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct output tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape = input_cpp.shape(); + std::vector output_shape(input_shape.data, input_shape.data + input_shape.ndim); + output_shape.back() /= shape_divisor; + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [out_cpp, out_py] = quantizer_cpp->create_tensor(output_shape, fake_dtype); + + // Compute activation + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation directly + NVTE_SCOPED_GIL_RELEASE( + { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); } else { + // Compute activation in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE( - { act_func(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); }); + { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp->quantize(temp_cpp, out_cpp); } - return out; + return out_py; } -template -py::object dactivation_helper(const at::Tensor& grad, const at::Tensor& input, +template +py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& input, py::handle quantizer) { init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - auto input_tensor = input.contiguous(); - auto grad_tensor = grad.contiguous(); - - const TensorWrapper& te_input = makeTransformerEngineTensor(input_tensor); - const TensorWrapper& te_grad = makeTransformerEngineTensor(grad_tensor); - const auto& te_input_shape = te_input.shape(); - std::vector input_shape(te_input_shape.data, te_input_shape.data + te_input_shape.ndim); - auto fake_tensor_type = input.scalar_type(); - - auto [te_output, out] = - my_quantizer->create_tensor(input_shape, GetTransformerEngineDType(fake_tensor_type)); - NVTE_SCOPED_GIL_RELEASE({ - act_func(te_grad.data(), te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); - }); + // Grad output and input tensors + auto grad_output_tensor = grad_output.contiguous(); + auto input_tensor = input.contiguous(); + const TensorWrapper& grad_output_cpp = makeTransformerEngineTensor(grad_output_tensor); + const TensorWrapper& 
input_cpp = makeTransformerEngineTensor(input_tensor); + + // Construct grad input tensor + auto quantizer_cpp = convert_quantizer(quantizer); + const auto input_shape_te = input_cpp.shape(); + const std::vector input_shape(input_shape_te.data, + input_shape_te.data + input_shape_te.ndim); + auto fake_dtype = GetTransformerEngineDType(input_tensor.scalar_type()); + auto [grad_input_cpp, grad_input_py] = quantizer_cpp->create_tensor(input_shape, fake_dtype); + + // Compute activation backward + if (quantizer.is_none() || detail::IsFloat8Quantizers(quantizer.ptr()) || + detail::IsMXFP8Quantizers(quantizer.ptr())) { + // Compute activation backward directly + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + } else { + // Compute activation backward in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp->quantize(temp_cpp, grad_input_cpp); + } - return out; + return grad_input_py; } py::object gelu(const at::Tensor& input, py::handle quantizer) { diff --git a/transformer_engine/pytorch/csrc/extensions/attention.cpp b/transformer_engine/pytorch/csrc/extensions/attention.cpp index 71a8062b1..6d835a5c9 100644 --- a/transformer_engine/pytorch/csrc/extensions/attention.cpp +++ b/transformer_engine/pytorch/csrc/extensions/attention.cpp @@ -18,7 +18,7 @@ void mha_fill(const transformer_engine::TensorWrapper &self, const at::Tensor &s auto max_tokens = shape[0]; auto fcd_size = 1; - for (int i = 1; i <= shape.size(); i++) { + for (size_t i = 1; i <= shape.size(); i++) { fcd_size *= shape[i]; } @@ -103,8 +103,20 @@ std::vector fused_attn_fwd( auto o_shape = std::vector{q_shape.begin(), q_shape.end()}; o_shape[o_shape.size() - 1] = v_shape[v_shape.size() - 1]; py::object o_python, s_python; - std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + // Initialize FP8 tensor with scale-inverse + auto *O_quantizer_fp8 = dynamic_cast(O_quantizer.get()); + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + NVTE_CHECK(O_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_O, o_python) = O_quantizer_fp8->create_tensor(o_shape, fake_dtype_te, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + } else { + std::tie(te_O, o_python) = O_quantizer->create_tensor(o_shape, fake_dtype_te); + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + } auto o_shape_int64 = std::vector{o_shape.begin(), o_shape.end()}; // construct NVTE tensors @@ -284,8 +296,20 @@ std::vector fused_attn_bwd( py::object s_python, dp_python; std::unique_ptr S_quantizer = convert_quantizer(s_quantizer); std::unique_ptr dP_quantizer = convert_quantizer(dp_quantizer); - std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); - std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + + if (qkv_type == DType::kFloat8E4M3 || qkv_type == 
DType::kFloat8E5M2) { + auto *S_quantizer_fp8 = dynamic_cast(S_quantizer.get()); + auto *dP_quantizer_fp8 = dynamic_cast(dP_quantizer.get()); + NVTE_CHECK(S_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + NVTE_CHECK(dP_quantizer_fp8 != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_S, s_python) = S_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + std::tie(te_dP, dp_python) = dP_quantizer_fp8->create_tensor({0}, DType::kFloat32, std::nullopt, + std::nullopt, std::nullopt); + } else { + std::tie(te_S, s_python) = S_quantizer->create_tensor({0}, DType::kFloat32); + std::tie(te_dP, dp_python) = dP_quantizer->create_tensor({0}, DType::kFloat32); + } std::vector q_shape = convertShape(te_Q.shape()); std::vector k_shape = convertShape(te_K.shape()); @@ -374,9 +398,22 @@ std::vector fused_attn_bwd( default: NVTE_ERROR("QKV layout not supported!"); } - std::tie(te_dQ, py_dQ) = dQKV_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); - std::tie(te_dK, py_dK) = dQKV_quantizer->create_tensor(k_shape, fake_dtype_te, dK); - std::tie(te_dV, py_dV) = dQKV_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { + auto *fp8_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(fp8_quantizer != nullptr, "Expected Float8Quantizer when dtype is FP8"); + std::tie(te_dQ, py_dQ) = + fp8_quantizer->create_tensor(q_shape, fake_dtype_te, dQ, std::nullopt, std::nullopt); + std::tie(te_dK, py_dK) = + fp8_quantizer->create_tensor(k_shape, fake_dtype_te, dK, std::nullopt, std::nullopt); + std::tie(te_dV, py_dV) = + fp8_quantizer->create_tensor(v_shape, fake_dtype_te, dV, std::nullopt, std::nullopt); + } else { + auto *none_quantizer = dynamic_cast(dQKV_quantizer.get()); + NVTE_CHECK(none_quantizer != nullptr, "Expected NoneQuantizer when dtype is not FP8"); + std::tie(te_dQ, py_dQ) = none_quantizer->create_tensor(q_shape, fake_dtype_te, dQ); + std::tie(te_dK, py_dK) = none_quantizer->create_tensor(k_shape, fake_dtype_te, dK); + std::tie(te_dV, py_dV) = none_quantizer->create_tensor(v_shape, fake_dtype_te, dV); + } // construct NVTE tensors if (qkv_type == DType::kFloat8E4M3 || qkv_type == DType::kFloat8E5M2) { diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 07f2be9df..5408cf1a6 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -28,60 +28,6 @@ std::vector get_tensor_shape(const TensorWrapper &tensor) { return std::vector(shape.data, shape.data + shape.ndim); } -void quantize_impl(const TensorWrapper &input, py::handle &quantizer_py, - std::unique_ptr &quantizer_cpp, TensorWrapper &output, - TensorWrapper &noop_flag) { - // Check tensor dims - NVTE_CHECK(get_tensor_shape(input) == get_tensor_shape(output), - "Input tensor (shape=", get_tensor_shape(input), - ") and output tensor (shape=", get_tensor_shape(output), ") do not match"); - if (input.numel() == 0) { - return; - } - - // Recipe-specific configuration - QuantizationConfigWrapper quant_config; - quant_config.set_noop_tensor(noop_flag.data()); - if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { - auto my_quantizer_cs = static_cast(quantizer_cpp.get()); - NVTE_SCOPED_GIL_RELEASE( - { nvte_compute_amax(input.data(), output.data(), at::cuda::getCurrentCUDAStream()); }); - // check if we need to do amax reudction (depending 
on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - // this config is used for cs scaling factor computation - // because compute scale is cannot be fused with quantize kernel - // so in nvte_quantize_v2 with current scaling, the quant config is not used again - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in output TensorWrapper to avoid atomic amax updates in kernel - output.set_amax(nullptr, DType::kFloat32, output.defaultShape); - } else if (detail::IsFloat8BlockwiseQuantizers(quantizer_py.ptr())) { - auto my_quantizer_bw = static_cast(quantizer_cpp.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } - } - - // Perform quantization - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(input.data(), output.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); -} - } // namespace py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::object &output, @@ -101,18 +47,17 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob const auto fake_dtype = input_cpp.dtype(); std::tie(output_cpp, output_py) = quantizer_cpp->create_tensor(shape, fake_dtype); } else { - output_py = output; - output_cpp = makeTransformerEngineTensor(output_py, quantizer); + std::tie(output_cpp, output_py) = quantizer_cpp->convert_and_update_tensor(output); } // Initialize no-op flag - TensorWrapper noop_flag_cpp; + std::optional noop_flag_cpp; if (noop_flag.has_value()) { noop_flag_cpp = makeTransformerEngineTensor(*noop_flag); } // Perform quantization - quantize_impl(input_cpp, quantizer, quantizer_cpp, output_cpp, noop_flag_cpp); + quantizer_cpp->quantize(input_cpp, output_cpp, noop_flag_cpp); return output_py; } @@ -182,10 +127,8 @@ void multi_tensor_quantize_impl(const std::vector &input_list, }); } else { // Quantize kernels individually - TensorWrapper dummy_noop_flag; for (size_t i = 0; i < num_tensors; ++i) { - quantize_impl(input_list[i], quantizer_py_list[i], quantizer_cpp_list[i], output_list[i], - dummy_noop_flag); + quantizer_cpp_list[i]->quantize(input_list[i], output_list[i]); } } } diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index d2f7107fe..d6ae0c86a 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -18,27 +18,35 @@ namespace pytorch { at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output) { init_extension(); - const auto dim = input.dim(); - NVTE_CHECK(dim >= 2, "Need at least 2D tensor to transpose."); - - if (input.dim() > 
2) { - input = input.view({-1, input.size(dim - 1)}); + // Tensor dimensions + const auto shape = getTensorShape(input); + std::vector transpose_shape_int64; + if (shape.size() > 0) { + transpose_shape_int64.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + transpose_shape_int64.push_back(shape[i]); + } } + const size_t M = shape.size() > 0 ? product(shape) / shape.back() : 1; + const size_t N = shape.size() > 0 ? shape.back() : 1; - size_t M = static_cast(input.size(0)); - size_t N = static_cast(input.size(1)); - + // Output tensor at::Tensor out; if (output.has_value()) { out = *output; } else { - out = allocateTorchTensor(input.size(1), input.size(0), DType::kByte); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + out = at::empty(transpose_shape_int64, opts); } - if (M == 0 || N == 0) return out; + // Return immediately if tensor is empty + if (M == 0 || N == 0) { + return out; + } + + // Compute transpose auto input_cu = makeTransformerEngineTensor(input.data_ptr(), std::vector{M, N}, otype); auto output_cu = makeTransformerEngineTensor(out.data_ptr(), std::vector{N, M}, otype); - nvte_transpose(input_cu.data(), output_cu.data(), at::cuda::getCurrentCUDAStream()); return out; diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0ce1fc90e..a7b7f5889 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -12,6 +12,27 @@ namespace transformer_engine::pytorch { +namespace { + +/*! @brief Transposed tensor shape + * + * The tensor is interpreted as a 2D matrix by flattening all but the + * last dimension, and then transposed. + */ +template +std::vector make_transpose_shape(const std::vector& shape) { + std::vector ret; + if (shape.size() > 0) { + ret.push_back(shape.back()); + for (size_t i = 0; i < shape.size() - 1; ++i) { + ret.push_back(shape[i]); + } + } + return ret; +} + +} // namespace + constexpr size_t MXFP8_BLOCK_SIZE = 32; Quantizer::Quantizer(const py::handle& quantizer) { @@ -37,24 +58,36 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti this->dtype = type; } -std::pair NoneQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { - at::TensorOptions opts; - opts = opts.dtype(GetATenDType(dtype)).device(torch::kCUDA); - std::vector torch_shape; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - } - at::Tensor ret; - if (rowwise_data.has_value()) { - ret = std::move(*rowwise_data); - } else { - ret = at::empty(torch_shape, opts); - } +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype) const { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); + return create_tensor(shape, dtype, at::empty(shape_int64, opts)); +} - TensorWrapper tensor; - tensor.set_rowwise_data(ret.data_ptr(), dtype, shape); - return {std::move(tensor), py::cast(ret)}; +std::pair NoneQuantizer::create_tensor(const std::vector& shape, + DType dtype, + at::Tensor data) const { + TensorWrapper out_cpp; + out_cpp.set_rowwise_data(data.data_ptr(), dtype, shape); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), py::cast(data)}; +} + +std::pair NoneQuantizer::convert_and_update_tensor( + py::object tensor) const { + auto tensor_pyt = tensor.cast(); + TensorWrapper out_cpp; + 
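  // No new storage is allocated on this path: the TensorWrapper below simply
  // aliases the existing at::Tensor's data pointer, dtype, and shape.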
out_cpp.set_rowwise_data(tensor_pyt.data_ptr(), + GetTransformerEngineDType(tensor_pyt.scalar_type()), + getTensorShape(tensor_pyt)); + set_quantization_params(&out_cpp); + return {std::move(out_cpp), std::move(tensor)}; +} + +void NoneQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + NVTE_ERROR("NoneQuantizer does not support quantization"); } void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { @@ -76,68 +109,180 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { } std::pair Float8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + at::Tensor scale_inv = at::empty(std::vector{1}, opts); + return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv)); +} + +std::pair Float8Quantizer::create_tensor( + const std::vector& shape, DType dtype, std::optional data, + std::optional transpose, std::optional scale_inv) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); + // Initialize data tensor + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data && !data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data = at::empty(shape_int64, opts); + } else if (!with_data && data) { + data.reset(); } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); + py::object data_py = with_data ? py::cast(*data) : py::none(); + + // Initialize transpose tensor + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose && !transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose = at::empty(transpose_shape, opts); + } else if (!with_transpose && transpose) { + transpose.reset(); } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(rowwise_torch_shape, opts); - } - } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + py::object transpose_py = with_transpose ? py::cast(*transpose) : py::none(); + + // Initialize scale-inverse tensor + if (!scale_inv) { + scale_inv = at::reciprocal(scale); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - opts = opts.dtype(torch::kFloat32); - // TODO: Replace with an empty tensor. 
- at::Tensor scale_inv = at::reciprocal(scale); - py::object ret; + + // Construct Python FP8 tensor + py::object out_py; if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = *scale_inv, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); + if (with_data) { + out_cpp.set_rowwise_data(data->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, std::vector{1}); + } + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv->data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair Float8Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + NVTE_CHECK(need_data || need_transpose, "Invalid usages for Float8Quantizer."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Float8Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { + data_tensor = data_py.cast(); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + if (has_transpose) { + transpose_tensor = transpose_py.cast(); + } + at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + shape.push_back(transpose_shape.front()); + } + if (has_data) { + 
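      // For example (hypothetical sizes): a _transpose of shape {128, 16, 64}
      // corresponds to a logical data shape of {16, 64, 128}, which must match
      // _data's shape whenever both buffers are present.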
auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); } - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); } - this->set_quantization_params(&tensor); - return {std::move(tensor), std::move(ret)}; + + // Coerce data tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && need_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void Float8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); } Float8CurrentScalingQuantizer::Float8CurrentScalingQuantizer(const py::handle& quantizer) @@ -187,71 +332,198 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso } std::pair Float8CurrentScalingQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; - std::vector rowwise_torch_shape; - std::vector columnwise_torch_shape; - std::vector scale_inv_torch_shape = {1}; // Shape of 1 element for scale_inv - if (!shape.empty()) { - columnwise_torch_shape.emplace_back(static_cast(shape.back())); - } - for (size_t i = 0; i < shape.size(); ++i) { - if (i < shape.size() - 1) { - columnwise_torch_shape.emplace_back(static_cast(shape[i])); - } - rowwise_torch_shape.emplace_back(static_cast(shape[i])); + // Initialize data tensor + 
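  // With current scaling, no scale is applied at construction time: the data,
  // optional transpose, and scale_inv buffers below are allocated but left
  // unfilled, and are only populated once quantize() computes the amax and the
  // scale (nvte_compute_scale_from_amax followed by nvte_quantize_v2).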
at::Tensor data_tensor; + const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + if (with_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); } - at::TensorOptions opts; - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(rowwise_torch_shape, opts); - } - } - const py::object py_data = rowwise_usage ? py::cast(data) : py::none(); - at::Tensor columnwise_data; - bool create_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); - if (create_transpose) { - columnwise_data = at::empty(columnwise_torch_shape, opts); + + // Initialize transpose tensor + at::Tensor transpose_tensor; + const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); } - const py::object py_columnwise_data = create_transpose ? py::cast(columnwise_data) : py::none(); - // In current scaling, scale is not known but we initialize it with 1 to avoid division by zero. If scale is already calculated, it can be correctly set. - at::Tensor scale_inv = at::reciprocal(scale); + // Initialize scale-inverse tensor + at::Tensor scale_inv_tensor; + { + const std::vector scale_inv_shape = {1}; + const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + scale_inv_tensor = at::empty(scale_inv_shape, opts); + } - py::object ret; + // Construct Python FP8 tensor + py::object out_py; + py::object data_py = with_data ? py::cast(data_tensor) : py::none(); + py::object transpose_py = with_transpose ? 
py::cast(transpose_tensor) : py::none(); if (internal) { py::handle Float8TensorClass(reinterpret_cast(Float8TensorBasePythonClass)); - ret = Float8TensorClass("data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + out_py = Float8TensorClass("data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } else { py::handle Float8TensorClass(reinterpret_cast(Float8TensorPythonClass)); - ret = Float8TensorClass("shape"_a = rowwise_torch_shape, "dtype"_a = GetATenDType(dtype), - "data"_a = py_data, "fp8_scale_inv"_a = scale_inv, - "fp8_dtype"_a = this->dtype, "data_transpose"_a = py_columnwise_data, - "quantizer"_a = this->quantizer); + const std::vector shape_int64(shape.begin(), shape.end()); + out_py = Float8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "data"_a = data_py, "fp8_scale_inv"_a = scale_inv_tensor, + "fp8_dtype"_a = this->dtype, "data_transpose"_a = transpose_py, + "quantizer"_a = this->quantizer); } - TensorWrapper tensor(this->get_scaling_mode()); - if (rowwise_usage) { - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + + // Construct C++ FP8 tensor + TensorWrapper out_cpp(this->get_scaling_mode()); + if (with_data) { + out_cpp.set_rowwise_data(data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (with_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor.data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); } - if (create_transpose) { - std::vector transposed_shape; - for (auto s : columnwise_torch_shape) { - transposed_shape.emplace_back(static_cast(s)); + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair Float8CurrentScalingQuantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), + "Float8CurrentScalingQuantizer must output to Float8Tensor."); + + // Expected buffers + const bool need_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported(); + const bool need_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported(); + NVTE_CHECK(need_data || need_transpose, "Invalid quantizer usages."); + + // Extract buffers from Python tensor + auto data_py = tensor.attr("_data"); + auto transpose_py = tensor.attr("_transpose"); + const bool has_data = !data_py.is_none(); + const bool has_transpose = !transpose_py.is_none(); + NVTE_CHECK(has_data || has_transpose, "Tensor has no data."); + std::optional data_tensor, transpose_tensor; + if (has_data) { + data_tensor = data_py.cast(); + } + if (has_transpose) { + transpose_tensor = transpose_py.cast(); + } + at::Tensor scale_inv_tensor = tensor.attr("_scale_inv").cast(); + + // Tensor dimensions + std::vector shape; + if (has_transpose) { + const auto transpose_shape = getTensorShape(*transpose_tensor); + if (transpose_shape.size() > 0) { + for (size_t i = 1; i < transpose_shape.size(); ++i) { + shape.push_back(transpose_shape[i]); + } + shape.push_back(transpose_shape.front()); } - 
tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, transposed_shape); - tensor.set_columnwise_scale_inv(scale_inv.data_ptr(), DType::kFloat32, std::vector{1}); + if (has_data) { + auto expected_shape = getTensorShape(*data_tensor); + NVTE_CHECK(shape == expected_shape, "FP8 data (shape=", expected_shape, + ") and transpose (shape=", transpose_shape, ") do not match"); + } + } else { // Already checked has_data == true + shape = getTensorShape(*data_tensor); } - this->set_quantization_params(&tensor); - return {std::move(tensor), std::move(ret)}; + // Coerce data tensor in Python tensor + if (has_data && !need_data) { + data_tensor.reset(); + data_py = py::none(); + tensor.attr("_data") = data_py; + } else if (!has_data && need_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + data_tensor = at::empty(shape_int64, opts); + data_py = py::cast(data_tensor); + tensor.attr("_data") = data_py; + } + + // Coerce transpose tensor + if (has_transpose && !need_transpose) { + transpose_tensor.reset(); + transpose_py = py::none(); + tensor.attr("_transpose") = transpose_py; + } else if (!has_transpose && need_transpose) { + const auto transpose_shape = make_transpose_shape(shape); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + transpose_tensor = at::empty(transpose_shape, opts); + transpose_py = py::cast(transpose_tensor); + tensor.attr("_transpose") = transpose_py; + } + tensor.attr("_transpose_invalid") = !need_transpose; + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ FP8 tensor + TensorWrapper out_cpp; + if (data_tensor) { + out_cpp.set_rowwise_data(data_tensor->data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + if (transpose_tensor) { + const auto transpose_shape = make_transpose_shape(shape); + out_cpp.set_columnwise_data(transpose_tensor->data_ptr(), this->dtype, transpose_shape); + out_cpp.set_columnwise_scale_inv(scale_inv_tensor.data_ptr(), DType::kFloat32, + std::vector{1}); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + auto stream = at::cuda::getCurrentCUDAStream(); + + // Nothing to be done if input is empty + if (input.numel() == 0) { + return; + } + + // Quantization configs + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + + // Compute amax + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + + // Perform amax reduction if needed + if (with_amax_reduction) { + // allreduce amax tensor + c10d::AllreduceOptions opts; + opts.reduceOp = c10d::ReduceOp::MAX; + std::vector tensors = {amax}; + NVTE_SCOPED_GIL_RELEASE({ amax_reduction_group->allreduce(tensors, opts)->wait(); }); + } + + // Compute scaling factor + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_scale_from_amax(out.data(), quant_config, stream); }); + + // Cast to FP8 + out.set_amax(nullptr, DType::kFloat32, out.defaultShape); // Avoid atomic amax updates + NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); } 
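// A minimal usage sketch of the unified Quantizer interface above, mirroring
// how extensions/cast.cpp drives it (quantizer_py and torch_input are
// placeholder inputs; the shape and bf16 fake dtype are purely illustrative):
//
//   std::unique_ptr<Quantizer> q = convert_quantizer(quantizer_py);
//   TensorWrapper input_cpp = makeTransformerEngineTensor(torch_input);
//   auto [out_cpp, out_py] = q->create_tensor({128, 256}, DType::kBFloat16);
//   q->quantize(input_cpp, out_cpp);  // dispatches to the recipe-specific path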
Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { @@ -280,7 +552,7 @@ void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const } std::pair Float8BlockQuantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { + const std::vector& shape, DType dtype) const { using namespace pybind11::literals; std::vector torch_shape; for (auto s : shape) { @@ -299,11 +571,7 @@ std::pair Float8BlockQuantizer::create_tensor( : Float8BlockScaleTensorFormat::GEMM_READY); if (rowwise_usage) { - if (rowwise_data.has_value()) { - data_rowwise = std::move(*rowwise_data); - } else { - data_rowwise = at::empty(torch_shape, opts); - } + data_rowwise = at::empty(torch_shape, opts); auto scale_shape = get_scale_shape(shape, false); size_t sinv0 = scale_shape[0]; size_t sinv1 = scale_shape[1]; @@ -373,6 +641,62 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::pair Float8BlockQuantizer::convert_and_update_tensor( + py::object tensor) const { + const DType dtype = tensor.attr("_fp8_dtype").cast(); + bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); + + // Check the data matches quantizer usages + NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage, + "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=", + !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage); + NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage, + "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=", + !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage); + + auto ret = TensorWrapper(is_2D_scaled ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); + + if (rowwise_usage) { + const at::Tensor& data_rowwise = tensor.attr("_rowwise_data").cast(); + const at::Tensor& scale_inv_rowwise = tensor.attr("_rowwise_scale_inv").cast(); + void* scale_inv_rowwise_dptr = scale_inv_rowwise.data_ptr(); + const auto& rowwise_shape = getTensorShape(data_rowwise); + ret.set_rowwise_data(data_rowwise.data_ptr(), dtype, rowwise_shape); + const auto scale_inv_rowwise_shape = getTensorShape(scale_inv_rowwise); + ret.set_rowwise_scale_inv(scale_inv_rowwise_dptr, DType::kFloat32, scale_inv_rowwise_shape); + } + if (columnwise_usage) { + const at::Tensor& data_colwise = tensor.attr("_columnwise_data").cast(); + const at::Tensor& scale_inv_colwise = tensor.attr("_columnwise_scale_inv").cast(); + void* scale_inv_colwise_dptr = scale_inv_colwise.data_ptr(); + const auto& shape = getTensorShape(data_colwise); + ret.set_columnwise_data(data_colwise.data_ptr(), dtype, shape); + const auto scale_inv_colwise_shape = getTensorShape(scale_inv_colwise); + ret.set_columnwise_scale_inv(scale_inv_colwise_dptr, DType::kFloat32, scale_inv_colwise_shape); + } + set_quantization_params(&ret); + return {std::move(ret), std::move(tensor)}; +} + +void Float8BlockQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + quant_config.set_force_pow_2_scales(force_pow_2_scales); + quant_config.set_amax_epsilon(amax_epsilon); + if (all_gather_usage) { + quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); + } + NVTE_SCOPED_GIL_RELEASE({ + 
nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); +} + std::vector Float8BlockQuantizer::get_scale_shape(const std::vector& shape, bool columnwise) const { size_t numel = 1; @@ -465,71 +789,204 @@ void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { columnwise_data.shape); } -std::pair MXFP8Quantizer::create_tensor( - const std::vector& shape, DType dtype, std::optional rowwise_data) const { +std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, + DType dtype) const { using namespace pybind11::literals; - std::vector torch_shape; - size_t numel = 1; - for (auto s : shape) { - torch_shape.emplace_back(static_cast(s)); - numel *= s; - } - TensorWrapper tensor(NVTE_MXFP8_1D_SCALING); - at::TensorOptions opts; - at::Tensor rowwise_data1, columnwise_data, rowwise_scale_inv, - columnwise_scale_inv; // TODO(pgadzinski) - change - opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); - - at::Tensor data; - if (rowwise_usage) { - if (rowwise_data.has_value()) { - data = std::move(*rowwise_data); - } else { - data = at::empty(torch_shape, opts); + // Tensor dimensions + const std::vector shape_int64(shape.begin(), shape.end()); + size_t flat_first_dim = 1; + if (shape.size() > 0) { + for (size_t i = 0; i < shape.size() - 1; ++i) { + flat_first_dim *= shape[i]; } - auto scale_shape = get_scale_shape(shape, false); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - rowwise_scale_inv = at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); - tensor.set_rowwise_data(data.data_ptr(), this->dtype, shape); - tensor.set_rowwise_scale_inv( - rowwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); } + const size_t flat_last_dim = shape.size() > 0 ? 
shape.back() : 1; + NVTE_CHECK(flat_first_dim % MXFP8_BLOCK_SIZE == 0 && flat_last_dim % MXFP8_BLOCK_SIZE == 0, + "MXFP8 requires tensor dims that are divisble by ", MXFP8_BLOCK_SIZE, + " (got shape=", shape, ")"); + const auto rowwise_scale_inv_shape = get_scale_shape(shape, false); + const auto columnwise_scale_inv_shape = get_scale_shape(shape, true); + // Allocate tensors + at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor; + at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor; + const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + if (rowwise_usage) { + const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), + rowwise_scale_inv_shape.end()); + rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + } if (columnwise_usage) { - auto scale_shape = get_scale_shape(shape, true); - size_t sinv0 = scale_shape[0]; - size_t sinv1 = scale_shape[1]; - columnwise_data = at::empty(torch_shape, opts); - columnwise_scale_inv = - at::zeros({static_cast(sinv0), static_cast(sinv1)}, opts); - - tensor.set_columnwise_data(columnwise_data.data_ptr(), this->dtype, shape); - tensor.set_columnwise_scale_inv( - columnwise_scale_inv.data_ptr(), DType::kFloat8E8M0, - std::vector{static_cast(sinv0), static_cast(sinv1)}); + const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), + columnwise_scale_inv_shape.end()); + columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); } - this->set_quantization_params(&tensor); - py::object ret; + // Convert tensors to Python + auto py_cast = [](at::Tensor& tensor, bool need_cast) -> py::object { + return need_cast ? 
py::cast(tensor) : py::none(); + }; + auto rowwise_data_py = py_cast(rowwise_data_tensor, rowwise_usage); + auto rowwise_scale_inv_py = py_cast(rowwise_scale_inv_tensor, rowwise_usage); + auto columnwise_data_py = py_cast(columnwise_data_tensor, columnwise_usage); + auto columnwise_scale_inv_py = py_cast(columnwise_scale_inv_tensor, columnwise_usage); + + // Construct Python MXFP8 tensor + py::object out_py; if (internal) { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorBasePythonClass)); - ret = MXFP8TensorClass("rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } else { py::handle MXFP8TensorClass(reinterpret_cast(MXFP8TensorPythonClass)); - ret = MXFP8TensorClass("shape"_a = torch_shape, "dtype"_a = GetATenDType(dtype), - "rowwise_data"_a = data, "columnwise_data"_a = columnwise_data, - "rowwise_scale_inv"_a = rowwise_scale_inv, - "columnwise_scale_inv"_a = columnwise_scale_inv, - "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); + out_py = MXFP8TensorClass("shape"_a = shape_int64, "dtype"_a = GetATenDType(dtype), + "rowwise_data"_a = rowwise_data_py, + "columnwise_data"_a = columnwise_data_py, + "rowwise_scale_inv"_a = rowwise_scale_inv_py, + "columnwise_scale_inv"_a = columnwise_scale_inv_py, + "fp8_dtype"_a = this->dtype, "quantizer"_a = this->quantizer); } - return {std::move(tensor), std::move(ret)}; + // Construct C++ MXFP8 tensor + TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + rowwise_scale_inv_shape); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data_tensor.data_ptr(), this->dtype, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv_tensor.data_ptr(), DType::kFloat8E8M0, + columnwise_scale_inv_shape); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(out_py)}; +} + +std::pair MXFP8Quantizer::convert_and_update_tensor( + py::object tensor) const { + NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); + + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + NVTE_CHECK(rowwise_data || columnwise_data, "MXFP8Tensor has no data."); + + // Tensor dimensions + std::vector shape; + if (columnwise_data) { + shape = getTensorShape(*columnwise_data); + if (rowwise_data) { + auto expected_shape = getTensorShape(*rowwise_data); + NVTE_CHECK(shape == expected_shape, "MXFP8 row-wise data (shape=", expected_shape, + ") and column-wise data (shape=", shape, ") do not match"); + } + } else { // Already checked 
columnwise_data_tensor == true + shape = getTensorShape(*rowwise_data); + } + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_data = at::empty(shape_int64, opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, false); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + rowwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + if (!columnwise_data) { + const std::vector shape_int64(shape.begin(), shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_data = at::empty(shape_int64, opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + const auto scale_inv_shape = get_scale_shape(shape, true); + const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), + scale_inv_shape.end()); + const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + columnwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + } + + // Coerce other attrs + tensor.attr("_fp8_dtype") = dtype; + + // Construct C++ MXFP8 tensor + TensorWrapper out_cpp(NVTE_MXFP8_1D_SCALING); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, shape); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, shape); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*columnwise_scale_inv)); + } + this->set_quantization_params(&out_cpp); + + return {std::move(out_cpp), std::move(tensor)}; +} + +void MXFP8Quantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + if (input.numel() == 0) { + return; + } + QuantizationConfigWrapper quant_config; + if (noop_flag) { + quant_config.set_noop_tensor(noop_flag->data()); + } + NVTE_SCOPED_GIL_RELEASE({ + nvte_quantize_v2(input.data(), out.data(), quant_config, at::cuda::getCurrentCUDAStream()); + }); } std::vector MXFP8Quantizer::get_scale_shape(const std::vector& shape, diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 383efc823..7f10336ce 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -22,8 +22,7 @@ from ...fp8 import FP8GlobalStateManager, Recipe from 
...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD from ...tensor import Quantizer -from ...tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ...tensor.mxfp8_tensor import MXFP8Quantizer +from ...tensor.float8_tensor import Float8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize, is_quantized_tensor @@ -480,18 +479,11 @@ def _functional_forward( raise ValueError("Output tensor is quantized, but quantizer was not provided") else: output_quantizer = None - if isinstance(output_quantizer, MXFP8Quantizer): - raise RuntimeError( - "Attempting to generate MXFP8 output tensor, " - "but GEMM with MXFP8 output is not supported" - ) - if isinstance(output_quantizer, Float8BlockQuantizer): - raise RuntimeError( - "Attempting to generate Float8BlockQuantized output tensor, " - "but GEMM with Float8BlockQuantized output is not supported" - ) - if output_quantizer is not None: + if not isinstance(output_quantizer, Float8Quantizer): + raise RuntimeError( + "Attempting to generate quantized output tensor with unsupported quantizer" + ) output_quantizer.set_usage(rowwise=True, columnwise=False) # Check if accumulating into output tensor @@ -765,11 +757,12 @@ def _functional_backward( ) else: grad_input_quantizer = None - if isinstance(grad_input_quantizer, MXFP8Quantizer): - raise RuntimeError( - "Attempting to generate MXFP8 grad input tensor, " - "but GEMM with MXFP8 output is not supported" - ) + if grad_input_quantizer is not None: + if not isinstance(grad_input_quantizer, Float8Quantizer): + raise RuntimeError( + "Attempting to generate quantized grad input tensor " + "with unsupported quantizer" + ) # Check if accumulating into grad input tensor if accumulate_into_grad_input: diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 9316f3d79..61853f9f4 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -182,7 +182,7 @@ def _functional_forward( if weight_quantizer is None: raise ValueError("Missing quantizer for weight tensor") if output_quantizer is not None: - raise ValueError("FP8 output is not supported") + raise ValueError("Quantized output is not supported") else: input_quantizer = None weight_quantizer = None diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 882650ffb..787c322a0 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -59,7 +59,7 @@ def __new__( instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index c0dc6e651..a88ae33f0 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ 
b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -86,7 +86,7 @@ def __new__( else: instance = super().__new__(cls, *args, **kwargs) instance._data = data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._scale_inv = fp8_scale_inv instance._transpose = data_transpose diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index 8f87e5c73..a093904bc 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -83,7 +83,7 @@ def __new__( instance = super().__new__(cls, *args, **kwargs) instance._rowwise_data = rowwise_data instance._columnwise_data = columnwise_data - instance._quantizer = quantizer + instance._quantizer = quantizer.copy() if quantizer is not None else None instance._fp8_dtype = fp8_dtype instance._rowwise_scale_inv = rowwise_scale_inv instance._columnwise_scale_inv = columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py index bac715949..0e41fc9c5 100644 --- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py @@ -521,7 +521,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: def _set_from_tensor(dst: Float8BlockwiseQTensor, src: Float8BlockwiseQTensor): dst._rowwise_data = src._rowwise_data dst._columnwise_data = src._columnwise_data - dst._quantizer = src._quantizer + dst._quantizer = src._quantizer.copy() dst._fp8_dtype = src._fp8_dtype dst._rowwise_scale_inv = src._rowwise_scale_inv dst._columnwise_scale_inv = src._columnwise_scale_inv diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index bccfc49db..895e68bf0 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -108,10 +108,9 @@ def make_empty( # Allocate FP8 data transpose if needed data_transpose = None if self.columnwise_usage: - inner_dim = data.size(-1) + transpose_shape = [data.size(-1)] + list(data.shape[:-1]) data_transpose = torch.empty( - inner_dim, - data.numel() // inner_dim, + transpose_shape, dtype=torch.uint8, device=device, ) @@ -230,7 +229,7 @@ def __init__( amax_epsilon: float = 0.0, ) -> None: super().__init__(rowwise=rowwise, columnwise=columnwise) - self.scale = torch.ones(1, dtype=torch.float32, device=device) + self.scale = torch.empty(1, dtype=torch.float32, device=device) self.amax = torch.empty(1, dtype=torch.float32, device=device) self.dtype = fp8_dtype self.with_amax_reduction = with_amax_reduction @@ -690,7 +689,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: # Float8Tensor attributes self._data = tensor._data - self._quantizer = tensor._quantizer + self._quantizer = tensor._quantizer.copy() self._fp8_dtype = tensor._fp8_dtype self._scale_inv = tensor._scale_inv self._transpose = tensor._transpose diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 10b587e17..b96575d37 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -433,7 +433,7 @@ def _set_data(self, tensor: torch.Tensor) -> None: super(MXFP8Tensor, 
type(self)).data.__set__(self, dummy_tensor) self._rowwise_data = tensor._rowwise_data self._columnwise_data = tensor._columnwise_data - self._quantizer = tensor._quantizer + self._quantizer = tensor._quantizer.copy() self._fp8_dtype = tensor._fp8_dtype self._rowwise_scale_inv = tensor._rowwise_scale_inv self._columnwise_scale_inv = tensor._columnwise_scale_inv From f858dc351cee49a8611e11c97637d35140b0327c Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 29 Jul 2025 12:30:00 -0700 Subject: [PATCH 028/153] Rename `do_not_clear` to `_do_not_clear` (#1977) Signed-off-by: Jan Bielak Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/cpu_offload.py | 2 +- transformer_engine/pytorch/ops/fuser.py | 2 +- transformer_engine/pytorch/utils.py | 4 ++-- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 75d3e1b2e..3fdf8b14f 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -431,7 +431,7 @@ def tensor_pop(self, tensor_tag, **kwargs): tensor = self.fp8_tensor_object_map.pop(tensor_tag) if self.double_buffering: - tensor.do_not_clear = True + tensor._do_not_clear = True self.tensor_tag_to_buf.pop(tensor_tag, None) # the tensor should have been copied back in on_group_commit_backward() diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 9923a5fbe..6a26de2aa 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -101,7 +101,7 @@ def forward( # Mark input tensors as not deletable in backward for tensor in (input_,) + params_and_extra_inputs: - tensor.do_not_clear = True + tensor._do_not_clear = True # Unflatten list of parameters and extra tensor inputs extra_inputs = params_and_extra_inputs[-fuser.num_extra_inputs :] diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 25b7a65dc..6420f3e12 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -41,10 +41,10 @@ def clear_tensor_data(*tensors: Tuple[Optional[torch.Tensor], ...]) -> None: for t in tensors: if t is not None: # Workaround for double buffering in cpu offload - if hasattr(t, "do_not_clear"): + if hasattr(t, "_do_not_clear"): continue if hasattr(t, "get_data_tensors"): - if any(hasattr(tensor, "do_not_clear") for tensor in t.get_data_tensors()): + if any(hasattr(tensor, "_do_not_clear") for tensor in t.get_data_tensors()): continue if hasattr(t, "clear"): From feda5b558ff9cfa9033edd976795f3276a1a3707 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 29 Jul 2025 16:48:51 -0700 Subject: [PATCH 029/153] Fuse amax computation into activation kernel (#2004) * Compute amax in activation kernels when the output pointer is provided, even for non-fp8 outputs Signed-off-by: Jan Bielak (cherry picked from commit 9f13fe2fefc58cae93bc467d87d01ecf792a0381) * Initialize metatensor values Signed-off-by: Jan Bielak * Fuse computation of amax into the activation kernel for fp8 current scaling Signed-off-by: Jan Bielak (cherry picked from commit 2b54327ac9c931a5340983a79e99de5caa0399dd) Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Zero out amax in `create_hp_tensor_with_amax` instead of relying on `Float8CurrentScalingQuantizer.__init__` to zero-initialize it Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/transformer_engine.cpp | 3 +- .../common/util/vectorized_pointwise.h | 93 +++++++++++-------- transformer_engine/pytorch/csrc/common.h | 16 ++++ .../pytorch/csrc/extensions/activation.cpp | 18 ++++ transformer_engine/pytorch/csrc/quantizer.cpp | 31 ++++++- 5 files changed, 117 insertions(+), 44 deletions(-) diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index 858945251..a33f3d959 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -197,7 +197,8 @@ void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empt } } else { NVTE_CHECK(t.scale.dptr == nullptr, "Scale is not supported for non-FP8 output ", name); - NVTE_CHECK(t.amax.dptr == nullptr, "Amax is not supported for non-FP8 output ", name); + // Note: amax is supported for non-FP8 output as it can be fused into the computation + // and later used for quantization with no need to compute it separately NVTE_CHECK(t.scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 output ", name); NVTE_CHECK(t.columnwise_scale_inv.dptr == nullptr, "Scale_inv is not supported for non-FP8 input ", name); diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 420b9ed3b..6e4507eef 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -183,6 +183,7 @@ __launch_bounds__(unary_kernel_threads) __global__ VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -196,27 +197,28 @@ __launch_bounds__(unary_kernel_threads) __global__ for (int i = 0; i < nvec; ++i) { const ComputeType val = static_cast(loader.separate()[i]); ComputeType temp = OP(val, p); - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); - + } + if constexpr (is_fp8::value) { temp = temp * s; } - storer.separate()[i] = static_cast(temp); } storer.store(tid, N); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); @@ -236,6 +238,7 @@ __launch_bounds__(unary_kernel_threads) __global__ VectorizedStorer storer(output, N); ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -251,10 +254,11 @@ __launch_bounds__(unary_kernel_threads) __global__ const ComputeType val = static_cast(loader.separate()[i]); const ComputeType g = static_cast(grad_loader.separate()[i]); ComputeType temp = OP(val, p) * g; - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = 
fmaxf(fabsf(temp), max); - + } + if constexpr (is_fp8::value) { temp = temp * s; } @@ -262,16 +266,17 @@ __launch_bounds__(unary_kernel_threads) __global__ } storer.store(tid, N); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); @@ -406,6 +411,7 @@ __launch_bounds__(unary_kernel_threads) __global__ const size_t M = num_aligned_elements * m; ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -425,25 +431,28 @@ __launch_bounds__(unary_kernel_threads) __global__ const ComputeType val = static_cast(loader0.separate()[i]); const ComputeType val2 = static_cast(loader1.separate()[i]); ComputeType temp = static_cast(Activation(val, p) * val2); - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(temp), max); + } + if constexpr (is_fp8::value) { temp = temp * s; } storer.separate()[i] = static_cast(static_cast(temp)); } storer.store(id_x, n); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); @@ -497,6 +506,7 @@ __launch_bounds__(unary_kernel_threads) __global__ const size_t M = num_aligned_elements * m; ComputeType max = 0; ComputeType s = 1; + const bool requires_amax = (amax != nullptr); if constexpr (is_fp8::value) { if (scale != nullptr) s = *scale; } @@ -524,11 +534,13 @@ __launch_bounds__(unary_kernel_threads) __global__ ComputeType after_dgelu = Dactivation(gelu_in, p) * grad_val * gate_in; ComputeType after_dgate = grad_val * Activation(gelu_in, p); - if constexpr (is_fp8::value) { + if (requires_amax) { __builtin_assume(max >= 0); max = fmaxf(fabsf(after_dgelu), max); - after_dgelu = after_dgelu * s; max = fmaxf(fabsf(after_dgate), max); + } + if constexpr (is_fp8::value) { + after_dgelu = after_dgelu * s; after_dgate = after_dgate * s; } @@ -538,16 +550,17 @@ __launch_bounds__(unary_kernel_threads) __global__ storer0.store(id_x, n); storer1.store(id_x, n); } - if constexpr (is_fp8::value) { - // Reduce amax over block - if (amax != nullptr) { - max = reduce_max(max, warp_id); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(amax, max); - } + + // Reduce amax over block + if (requires_amax) { + max = reduce_max(max, warp_id); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(amax, max); } + } + if constexpr (is_fp8::value) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && scale_inv != nullptr) { reciprocal(scale_inv, s); diff --git 
a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index be3b995a1..45e3291ef 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -194,10 +194,26 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + /*! @brief Construct a high precision tensor giving it this quantizer's amax + + Note: this member function also zeros out the amax, as it is meant to be used in conjunction with + a kernel computing the amax, which might expect the amax to be initialized to zero + */ + std::pair create_hp_tensor_with_amax(const std::vector& shape, + DType dtype); + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag = std::nullopt) override; + + /*! @brief Convert to a quantized data format avoiding amax computation */ + void quantize_with_amax(TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag = std::nullopt); + + private: + void quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, bool compute_amax); }; class Float8BlockQuantizer : public Quantizer { diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index c9eae092b..2ef7a869a 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -32,8 +32,17 @@ py::object activation_helper(const at::Tensor& input, py::handle quantizer, int // Compute activation directly NVTE_SCOPED_GIL_RELEASE( { act_func(input_cpp.data(), out_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // Compute activation in high-precision fused together with amax, then quantize. + + auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(output_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE( + { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + quantizer_cpp_cs->quantize_with_amax(temp_cpp, out_cpp); } else { // Compute activation in high-precision, then quantize + auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(output_shape, fake_dtype); NVTE_SCOPED_GIL_RELEASE( { act_func(input_cpp.data(), temp_cpp.data(), at::cuda::getCurrentCUDAStream()); }); @@ -70,6 +79,15 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i dact_func(grad_output_cpp.data(), input_cpp.data(), grad_input_cpp.data(), at::cuda::getCurrentCUDAStream()); }); + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + // Compute activation backward in high-precision fused together with amax, then quantize. 
+ auto quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + auto [temp_cpp, _] = quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, fake_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_cpp.data(), input_cpp.data(), temp_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + quantizer_cpp_cs->quantize_with_amax(temp_cpp, grad_input_cpp); } else { // Compute activation backward in high-precision, then quantize auto [temp_cpp, _] = NoneQuantizer(py::none()).create_tensor(input_shape, fake_dtype); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index a7b7f5889..f0e0aba00 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -397,6 +397,15 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8CurrentScalingQuantizer::create_hp_tensor_with_amax( + const std::vector& shape, DType dtype) { + amax.zero_(); + auto [out_cpp, out_py] = NoneQuantizer(py::none()).create_tensor(shape, dtype); + out_cpp.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8CurrentScalingQuantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), @@ -489,8 +498,9 @@ std::pair Float8CurrentScalingQuantizer::convert_and_ return {std::move(out_cpp), std::move(tensor)}; } -void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, - const std::optional& noop_flag) { +void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag, + bool compute_amax) { auto stream = at::cuda::getCurrentCUDAStream(); // Nothing to be done if input is empty @@ -507,7 +517,9 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorW quant_config.set_amax_epsilon(amax_epsilon); // Compute amax - NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + if (compute_amax) { + NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + } // Perform amax reduction if needed if (with_amax_reduction) { @@ -526,6 +538,19 @@ void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorW NVTE_SCOPED_GIL_RELEASE({ nvte_quantize_v2(input.data(), out.data(), quant_config, stream); }); } +void Float8CurrentScalingQuantizer::quantize(const TensorWrapper& input, TensorWrapper& out, + const std::optional& noop_flag) { + this->quantize_impl(input, out, noop_flag, true); +} + +void Float8CurrentScalingQuantizer::quantize_with_amax( + TensorWrapper& input, TensorWrapper& out, const std::optional& noop_flag) { + NVTE_CHECK(input.get_amax().data_ptr == amax.data_ptr(), + "Input does not use the appropriate amax tensor"); + input.set_amax(nullptr, DType::kFloat32, input.defaultShape); + this->quantize_impl(input, out, noop_flag, false); +} + Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quantizer(quantizer) { this->dtype = quantizer.attr("dtype").cast(); this->block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); From 020428f07303ce012213455db304f3f817f3cc5f Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 29 Jul 2025 17:41:56 -0700 Subject: [PATCH 030/153] [PyTorch] Fix bug with clearing op outputs during backward 
(#2008) Fix merge conflict bug with clearing op outputs Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/fuser.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 6a26de2aa..98b3468a2 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -199,7 +199,7 @@ def forward( # Mark output tensors as not deletable in backward for tensor in [x] + extra_outputs_flat: - tensor.do_not_clear = True + tensor._do_not_clear = True x.requires_grad_(fuser.first_op_requiring_backward < fuser._num_basic_ops) From 11ac24cfa22cdda3b57d36a1f732ac582863b1a2 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 29 Jul 2025 17:42:48 -0700 Subject: [PATCH 031/153] Refactor normalization.cpp to use quantizer logic introduced in #1952 (#2006) Refactor normalization.cpp to use quantizer logic introduced in #1952 instead of manual quantization Signed-off-by: Jan Bielak --- .../pytorch/csrc/extensions/normalization.cpp | 80 +------------------ 1 file changed, 2 insertions(+), 78 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 0d2011ba7..45d4bf870 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -139,45 +139,7 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { - QuantizationConfigWrapper quant_config; - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), - at::cuda::getCurrentCUDAStream()); - }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = - my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); - } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } - } - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); + my_quantizer->quantize(unquantized_out_cu, 
out_cu); } return {out, py::cast(mu), py::cast(rsigma)}; @@ -300,45 +262,7 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { - QuantizationConfigWrapper quant_config; - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(my_quantizer.get()); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_amax(unquantized_out_cu.data(), out_cu.data(), - at::cuda::getCurrentCUDAStream()); - }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = - my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor &amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(out_cu.data(), quant_config, at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - out_cu.set_amax(nullptr, DType::kFloat32, out_cu.defaultShape); - } else if (IsFloat8BlockwiseQuantizers(quantizer.ptr())) { - auto my_quantizer_bw = static_cast(my_quantizer.get()); - quant_config.set_force_pow_2_scales(my_quantizer_bw->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_bw->amax_epsilon); - if (my_quantizer_bw->all_gather_usage) { - quant_config.set_float8_block_scale_tensor_format(Float8BlockScaleTensorFormat::COMPACT); - } - } - NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_v2(unquantized_out_cu.data(), out_cu.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); + my_quantizer->quantize(unquantized_out_cu, out_cu); } return {out, py::none(), py::cast(rsigma)}; From 858755c0c3898d5bc971fcec031b26ff26e7e4fb Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 30 Jul 2025 06:31:14 -0700 Subject: [PATCH 032/153] [JAX] TE GEMM checkpointing policies (#2003) * TE primitive checkpointing policies Signed-off-by: Jeremy Berchtold * Remove batched gemm policy Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- docs/api/jax.rst | 4 ++ transformer_engine/jax/attention.py | 16 ++++++-- transformer_engine/jax/checkpoint_policies.py | 38 +++++++++++++++++++ transformer_engine/jax/flax/module.py | 18 +++++---- transformer_engine/jax/flax/transformer.py | 7 ++++ 5 files changed, 72 insertions(+), 11 deletions(-) create mode 100644 transformer_engine/jax/checkpoint_policies.py diff --git a/docs/api/jax.rst b/docs/api/jax.rst index d72af37ec..1af5cd1d0 100644 --- a/docs/api/jax.rst +++ b/docs/api/jax.rst @@ -19,6 +19,10 @@ Variables are available in `transformer_engine.jax.sharding`. * JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded. +Checkpointing +------------------------------------ +When using checkpointing with Transformer Engine JAX, please be aware of the checkpointing policy being applied to your model. 
Any JAX checkpointing policy using `dot`, such as `jax.checkpoint_policies.dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine as they do not always use the `jax.lax.dot_general` primitive. Instead, you can use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or similar policies that are designed to work with Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policies that do not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`. + Modules ------------------------------------ .. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index fe4109cee..093146162 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -855,7 +855,7 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)) +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], @@ -872,6 +872,7 @@ def _fused_attn( context_parallel_strategy: CPStrategy, context_parallel_causal_load_balanced: bool, context_parallel_axis: str, + context_checkpoint_name: str = "context", ): output, _ = _fused_attn_fwd_rule( qkv, @@ -889,6 +890,7 @@ def _fused_attn( context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, + context_checkpoint_name=context_checkpoint_name, ) return output @@ -909,6 +911,7 @@ def _fused_attn_fwd_rule( context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, + context_checkpoint_name, ): output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, @@ -927,9 +930,9 @@ def _fused_attn_fwd_rule( context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, ) - output = checkpoint_name(output, "context") - softmax_aux = checkpoint_name(softmax_aux, "context") - rng_state = checkpoint_name(rng_state, "context") + output = checkpoint_name(output, context_checkpoint_name) + softmax_aux = checkpoint_name(softmax_aux, context_checkpoint_name) + rng_state = checkpoint_name(rng_state, context_checkpoint_name) return output, ( qkv, bias, @@ -952,9 +955,11 @@ def _fused_attn_bwd_rule( context_parallel_strategy, context_parallel_causal_load_balanced, context_parallel_axis, + context_checkpoint_name, ctx, dz, ): + del context_checkpoint_name ( qkv, bias, @@ -1012,6 +1017,7 @@ def fused_attn( context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, context_parallel_causal_load_balanced: bool = False, context_parallel_axis: str = "", + context_checkpoint_name: str = "context", ): """ Perform cuDNN fused attention. @@ -1044,6 +1050,7 @@ def fused_attn( context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + context_checkpoint_name (str): The name of the context checkpoint for the custom VJP forward pass. Returns: (jnp.ndarray): The output tensor from the fused attention. 
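As an illustration of the checkpointing guidance above, here is a minimal sketch of how the policies added in this patch might be combined with ``jax.checkpoint``; the wrapped ``layer_fn`` and its arguments are placeholder assumptions for the example, not part of this change.

# Sketch only: `layer_fn`, `params`, and `x` are placeholder assumptions.
import jax
from transformer_engine.jax.checkpoint_policies import dots_and_te_gemms_with_no_batch_dims

def layer_fn(params, x):
    ...  # a block built from TE flax modules and plain jax.lax.dot_general calls
    return x

# Save TE GEMM outputs and no-batch-dim jax.lax.dot_general outputs; recompute the rest.
layer_remat = jax.checkpoint(layer_fn, policy=dots_and_te_gemms_with_no_batch_dims)

# Alternatively, select activations by checkpoint name rather than by primitive,
# e.g. the "context" name emitted by fused_attn (see context_checkpoint_name above)
# and the "ffn1"/"ffn2" names used by LayerNormMLP.
layer_remat_by_name = jax.checkpoint(
    layer_fn,
    policy=jax.checkpoint_policies.save_only_these_names("context", "ffn1", "ffn2"),
)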
@@ -1116,6 +1123,7 @@ def fused_attn( context_parallel_strategy=context_parallel_strategy, context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, context_parallel_axis=context_parallel_axis, + context_checkpoint_name=context_checkpoint_name, ) return output diff --git a/transformer_engine/jax/checkpoint_policies.py b/transformer_engine/jax/checkpoint_policies.py new file mode 100644 index 000000000..a03db09b9 --- /dev/null +++ b/transformer_engine/jax/checkpoint_policies.py @@ -0,0 +1,38 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. +"""Checkpoint policies for Transformer Engine in JAX. + +This module provides JAX checkpoint policies that are compatible with Transformer Engine's custom primitives. +""" + +import jax +from .cpp_extensions.gemm import GemmPrimitive, GroupedGemmPrimitive + + +__all__ = [ + "te_gemms_saveable", + "dots_and_te_gemms_with_no_batch_dims", + "checkpoint_dots_and_te_gemms", +] + + +def te_gemms_saveable(prim, *_, **__) -> bool: + """Checkpoint policy for Transformer Engine GEMMs.""" + is_te_gemm = prim in {GemmPrimitive.outer_primitive, GroupedGemmPrimitive.outer_primitive} + # Workaround to include JAX's scaled_matmul until JAX checkpoint policies for dots are + # updated to include it. + is_jax_scaled_matmul = prim.name == "scaled_matmul_wrapper" + + return is_te_gemm or is_jax_scaled_matmul + + +dots_and_te_gemms_with_no_batch_dims = jax.checkpoint_policies.save_from_both_policies( + jax.checkpoint_policies.checkpoint_dots_with_no_batch_dims, + te_gemms_saveable, +) + +checkpoint_dots_and_te_gemms = jax.checkpoint_policies.save_from_both_policies( + jax.checkpoint_policies.checkpoint_dots, + te_gemms_saveable, +) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 6670377f7..ba5ee6d13 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -940,6 +940,11 @@ class LayerNormMLP(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input of 2nd dot, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. + ffn1_ckpt_name: str = "ffn1" + Checkpoint name for the output of the first fully-connected layer in the MLP block. + ffn2_ckpt_name: str = "ffn2" + Checkpoint name for the output of the second fully-connected layer in the MLP block. + Optimization parameters ----------------------- @@ -981,6 +986,8 @@ class LayerNormMLP(TransformerEngineBase): layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] 
= None + ffn1_ckpt_name: str = "ffn1" + ffn2_ckpt_name: str = "ffn2" def __post_init__(self): if self.transpose_batch_sequence: @@ -1150,9 +1157,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): bias_1 = None bias_2 = None - ffn1_ckpt_name = "ffn1" - ffn2_ckpt_name = "ffn2" - if use_fused_layernorm_mlp: out = layernorm_mlp( y, @@ -1168,8 +1172,8 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_2_input_axes=self.dot_2_input_axes, kernel_1_axes=self.kernel_axes_1, kernel_2_axes=self.kernel_axes_2, - ffn1_ckpt_name=ffn1_ckpt_name, - ffn2_ckpt_name=ffn2_ckpt_name, + ffn1_ckpt_name=self.ffn1_ckpt_name, + ffn2_ckpt_name=self.ffn2_ckpt_name, activation_type=normalized_acts, quantizer_sets=(ffn1_quantizer_set, ffn2_quantizer_set), ) @@ -1251,7 +1255,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if self.use_bias: x += jnp.reshape(bias_1, bias_1_shape) - x = checkpoint_name(x, ffn1_ckpt_name) + x = checkpoint_name(x, self.ffn1_ckpt_name) if is_act_implemented: z = activation(x, normalized_acts) else: @@ -1314,7 +1318,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): if self.use_bias: out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) - out = checkpoint_name(out, ffn2_ckpt_name) + out = checkpoint_name(out, self.ffn2_ckpt_name) assert out.dtype == input_dtype return out, ln_output # Output, layner_norm_output diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 5f309820c..2d13f25ca 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -274,6 +274,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_checkpoint_name: str = "context" @nn.compact def __call__( @@ -322,6 +323,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_kvpacked(): """kvpacked format, treat @@ -348,6 +350,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_separate(): if self.transpose_batch_sequence: @@ -369,6 +372,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_checkpoint_name=self.context_checkpoint_name, ) else: raise ValueError(f"Unsupported {self.qkv_layout=}.") @@ -501,6 +505,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention. 
Optimization parameters ----------------------- @@ -524,6 +529,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_checkpoint_name: str = "context" @nn.compact def __call__( @@ -690,6 +696,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_checkpoint_name=self.context_checkpoint_name, )( query, key, From 44a581c1fbb05225e9a3edff91224d198d23c0a5 Mon Sep 17 00:00:00 2001 From: Dupel <36371030+dupeljan@users.noreply.github.com> Date: Thu, 31 Jul 2025 17:31:04 +0200 Subject: [PATCH 033/153] [PyTorch Debug] Minor fix in docs. (#1947) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Update 1_getting_started.rst Signed-off-by: dupeljan * Update docs/debug/1_getting_started.rst Signed-off-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> * Update docs/debug/1_getting_started.rst Signed-off-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> * Update docs/debug/1_getting_started.rst Signed-off-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> --------- Signed-off-by: dupeljan Signed-off-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> --- docs/debug/1_getting_started.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/debug/1_getting_started.rst b/docs/debug/1_getting_started.rst index 555b9b4b8..a7b86dad3 100644 --- a/docs/debug/1_getting_started.rst +++ b/docs/debug/1_getting_started.rst @@ -141,7 +141,7 @@ Adjusting Python file In the modified code above, the following changes were made: 1. Added an import for ``nvdlfw_inspect.api``. -2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. +2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. The directory with Transformer Engine features is located `here `_. The full parameters description could be found :doc:`here <3_api_debug_setup>`. 3. Added ``debug_api.step()`` after each of the forward-backward pass. 
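For illustration, a minimal sketch of the ``debug_api.initialize()`` call described in point 2 above; the keyword argument names and paths are assumptions for the example, and the exact API is documented on the linked setup page.

# Sketch only: argument names and paths are assumptions, not taken from this patch.
import nvdlfw_inspect.api as debug_api

debug_api.initialize(
    config_file="./config.yaml",  # YAML file selecting the debug features to enable
    feature_dirs=["/path/to/transformer_engine/debug/features"],  # TE feature directory mentioned above
    log_dir="./log",  # directory where collected statistics and logs are written
)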
Inspecting the logs From 51eb6362be7db24da501d078302b4d9d8110272d Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Thu, 31 Jul 2025 14:05:17 -0700 Subject: [PATCH 034/153] Fuse amax computation into normalization kernel for current scaling (#2013) * Compute amax in normalization kernels as long as the pointer is provided, even if using non quantized output Signed-off-by: Jan Bielak * Fuse amax computation into normalization forward Signed-off-by: Jan Bielak * Use TE lahyernorm kernel instead of raising error about unsupported cuDNN feature Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/normalization/layernorm/ln_api.cpp | 4 +++ .../layernorm/ln_fwd_kernels.cuh | 22 +++++++----- .../normalization/rmsnorm/rmsnorm_api.cpp | 4 +++ .../rmsnorm/rmsnorm_fwd_kernels.cuh | 22 +++++++----- .../pytorch/csrc/extensions/normalization.cpp | 34 +++++++++++++++---- 5 files changed, 62 insertions(+), 24 deletions(-) diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 3f7a71014..cf5678e40 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -65,6 +65,10 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool is_aligned = true; bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { + cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + } + bool gamma_in_weight_dtype = false; if (cudnn_backend) { // TODO: add check for GPU ARCH diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index eb2f62b4b..417e84a56 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -75,6 +75,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != nullptr; for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { Ivec x[LDGS]; @@ -120,9 +121,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( compute_t b_ij = beta[it].data.elt[jt]; compute_t temp_output = g_ij * y_ij + b_ij; - if (params.fp8_out) { + if (requires_amax) { __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(temp_output)); + } + if (params.fp8_out) { temp_output = temp_output * scale; } @@ -132,16 +135,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_tuned_kernel( idx += VEC_COLS_PER_LDG; } } - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && 
threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 4ee2b42c3..499c0ef69 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -51,6 +51,10 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool is_aligned = true; bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); + if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { + cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + } + bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index c63184739..da3f8192c 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -71,6 +71,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != nullptr; for (int row = r; row < params.rows; row += params.ctas_per_col * ROWS_PER_CTA) { Ivec x[LDGS]; @@ -112,9 +113,11 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke } compute_t temp_output = g_ij * y_ij; - if (params.fp8_out) { + if (requires_amax) { __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(temp_output)); + } + if (params.fp8_out) { temp_output = temp_output * scale; } @@ -124,16 +127,17 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_tuned_ke idx += VEC_COLS_PER_LDG; } } - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 45d4bf870..e5a1a2a78 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -110,8 +110,14 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + std::tie(unquantized_out_cu, unquantized_out) = + my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); + } else { + NoneQuantizer q{none}; + std::tie(unquantized_out_cu, unquantized_out) = 
q.create_tensor(size, out_dtype); + } } TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; @@ -139,7 +145,12 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { - my_quantizer->quantize(unquantized_out_cu, out_cu); + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); + } else { + my_quantizer->quantize(unquantized_out_cu, out_cu); + } } return {out, py::cast(mu), py::cast(rsigma)}; @@ -233,8 +244,14 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - NoneQuantizer q{none}; - std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + std::tie(unquantized_out_cu, unquantized_out) = + my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); + } else { + NoneQuantizer q{none}; + std::tie(unquantized_out_cu, unquantized_out) = q.create_tensor(size, out_dtype); + } } TensorWrapper &kernel_out_cu = force_unfused_kernel ? unquantized_out_cu : out_cu; @@ -262,7 +279,12 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { - my_quantizer->quantize(unquantized_out_cu, out_cu); + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); + my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); + } else { + my_quantizer->quantize(unquantized_out_cu, out_cu); + } } return {out, py::none(), py::cast(rsigma)}; From 8dfdb9115b0a8b5dc8a1a03be128f0a1248c6872 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 31 Jul 2025 23:46:49 +0200 Subject: [PATCH 035/153] [PyTorch] Tutorial for the ONNX export (#1586) * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: Kirthi Shankar Sivamani --- docs/examples/onnx/onnx_export.ipynb | 345 +++++++++++++++++++++++++++ docs/examples/onnx/utils.py | 25 ++ docs/index.rst | 1 + 3 files changed, 371 insertions(+) create mode 100644 docs/examples/onnx/onnx_export.ipynb create mode 100644 docs/examples/onnx/utils.py diff --git a/docs/examples/onnx/onnx_export.ipynb b/docs/examples/onnx/onnx_export.ipynb new file mode 100644 index 000000000..91fc38003 --- /dev/null +++ b/docs/examples/onnx/onnx_export.ipynb @@ -0,0 +1,345 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Export to ONNX and inference using TensorRT\n", + "\n", + "
\n", + "\n", + "Note:\n", + "\n", + "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n", + "\n", + "
\n", + "\n", + "Transformer Engine (TE) is a library designed primarily for training DL models in low precision. It is not specifically optimized for inference tasks, so other dedicated solutions should be used. NVIDIA provides several [inference tools](https://www.nvidia.com/en-us/solutions/ai/inference/) that enhance the entire inference pipeline. Two prominent NVIDIA inference SDKs are [TensorRT](https://github.com/NVIDIA/TensorRT) and [TensorRT-LLM](https://github.com/NVIDIA/TensorRT-LLM).\n", + "\n", + "This tutorial illustrates how one can export a PyTorch model to ONNX format and subsequently perform inference with TensorRT. This approach is particularly beneficial if model integrates Transformer Engine layers within more complex architectures. It's important to highlight that for Transformer-based large language models (LLMs), TensorRT-LLM could provide a more optimized inference experience. However, the ONNX-to-TensorRT approach described here may be more suitable for other models, such as diffusion-based architectures or vision transformers.\n", + "\n", + "#### Creating models with TE\n", + "\n", + "Let's begin by defining a simple model composed of layers both from Transformer Engine and standard PyTorch:\n" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import torch \n", + "import torch.nn as nn\n", + "import transformer_engine as te\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "# batch size, sequence length, hidden dimension\n", + "B, S, H = 256, 512, 256\n", + "\n", + "class Model(torch.nn.Module):\n", + " def __init__(self, hidden_dim=H, num_non_te_layers=16, num_te_layers=4, num_te_heads=4):\n", + " super(Model, self).__init__()\n", + " self.non_te_part = nn.Sequential(\n", + " *[nn.Sequential(nn.Linear(hidden_dim, hidden_dim), nn.GELU()) for _ in range(num_non_te_layers)]\n", + " )\n", + " self.te_part = nn.Sequential(\n", + " *[te.pytorch.TransformerLayer(hidden_dim, hidden_dim, num_te_heads) for _ in range(num_te_layers)]\n", + " )\n", + "\n", + " def forward(self, x):\n", + " x = self.non_te_part(x)\n", + " return self.te_part(x)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run some simple inference benchmarks:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average inference time FP32: 0.065 ms\n", + "Average inference time FP8: 0.062 ms\n" + ] + } + ], + "source": [ + "from utils import _measure_time\n", + "\n", + "model = Model().eval().cuda()\n", + "inps = (torch.randn([S, B, H], device=\"cuda\"),)\n", + "def _inference(fp8_enabled):\n", + " with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8_enabled):\n", + " model(*inps)\n", + "\n", + "te_fp32_time = _measure_time(lambda: _inference(fp8_enabled=False))\n", + "te_fp8_time = _measure_time(lambda: _inference(fp8_enabled=True))\n", + "\n", + "print(f\"Average inference time FP32: {te_fp32_time} ms\")\n", + "print(f\"Average inference time FP8: {te_fp8_time} ms\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Exporting the TE Model to ONNX Format\n", + "\n", + "PyTorch developed a new [ONNX exporter](https://pytorch.org/docs/stable/onnx.html) built on TorchDynamo and plans to phase out the existing TorchScript exporter. 
As this feature is currently in active development, we recommend running this process with the latest PyTorch version.\n", + "\n", + "\n", + "To export a Transformer Engine model into ONNX format, follow these steps:\n", + "\n", + "- Conduct warm-up run within autocast using the recipe intended for export.\n", + "- Encapsulate your export-related code within `te.onnx_export`, ensuring warm-up runs remain outside this wrapper.\n", + "- Use the PyTorch Dynamo ONNX exporter by invoking: `torch.onnx.export(..., dynamo=True)`." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Exporting model_fp8.onnx\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n", + "[torch.onnx] Run decomposition...\n", + "[torch.onnx] Run decomposition... ✅\n", + "[torch.onnx] Translate the graph into ONNX...\n", + "[torch.onnx] Translate the graph into ONNX... ✅\n", + "Applied 12 of general pattern rewrite rules.\n", + "Exporting model_fp32.onnx\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`...\n", + "[torch.onnx] Obtain model graph for `Model([...]` with `torch.export.export(..., strict=False)`... ✅\n", + "[torch.onnx] Run decomposition...\n", + "[torch.onnx] Run decomposition... ✅\n", + "[torch.onnx] Translate the graph into ONNX...\n", + "[torch.onnx] Translate the graph into ONNX... ✅\n", + "Applied 12 of general pattern rewrite rules.\n" + ] + } + ], + "source": [ + "from transformer_engine.pytorch.export import te_translation_table\n", + "\n", + "def export(model, fname, inputs, fp8=True):\n", + " with torch.no_grad(), te.pytorch.fp8_autocast(enabled=fp8):\n", + " # ! IMPORTANT !\n", + " # Transformer Engine models must have warm-up run\n", + " # before export. FP8 recipe during warm-up should \n", + " # match the recipe used during export.\n", + " model(*inputs)\n", + " \n", + " # Only dynamo=True mode is supported;\n", + " # dynamo=False is deprecated and unsupported.\n", + " #\n", + " # te_translation_table contains necessary ONNX translations\n", + " # for FP8 quantize/dequantize operators.\n", + " print(f\"Exporting {fname}\")\n", + " with te.pytorch.onnx_export(enabled=True):\n", + " torch.onnx.export(\n", + " model,\n", + " inputs,\n", + " fname,\n", + " output_names=[\"output\"],\n", + " dynamo=True,\n", + " custom_translation_table=te_translation_table\n", + " )\n", + "\n", + "# Example usage:\n", + "export(model, \"model_fp8.onnx\", inps, fp8=True)\n", + "export(model, \"model_fp32.onnx\", inps, fp8=False)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "#### Inference with TensorRT\n", + "\n", + "TensorRT is a high-performance deep learning inference optimizer and runtime developed by NVIDIA. It enables optimized deployment of neural network models by maximizing inference throughput and reducing latency on NVIDIA GPUs. TensorRT performs various optimization techniques, including layer fusion, precision calibration, kernel tuning, and memory optimization. \n", + "For detailed information and documentation, refer to the official [TensorRT documentation](https://developer.nvidia.com/tensorrt).\n", + "\n", + "When using TensorRT, ONNX model must first be compiled into a TensorRT engine. 
This compilation step involves converting the ONNX model into an optimized representation tailored specifically to the target GPU platform. The compiled engine file can then be loaded into applications for rapid and efficient inference execution." + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "!trtexec --onnx=model_fp32.onnx --saveEngine=model_fp32.engine > output_fp32.log 2>&1\n", + "!trtexec --onnx=model_fp8.onnx --saveEngine=model_fp8.engine > output_fp8.log 2>&1" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's run the benchmarks for inference:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Average inference time without TRT (FP32 for all layers): 0.065 ms\n", + "Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): 0.062 ms, speedup = 1.05x\n", + "Average inference time with TRT (FP32 for all layers): 0.0500 ms, speedup = 1.30x\n", + "Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): 0.0470 ms, speedup = 1.38x\n" + ] + } + ], + "source": [ + "import tensorrt as trt\n", + "\n", + "# Output tensor is allocated - TRT needs static memory address.\n", + "output_tensor = torch.empty_like(model(*inps))\n", + "\n", + "# Loads TRT engine from file.\n", + "def load_engine(engine_file_path):\n", + " logger = trt.Logger(trt.Logger.WARNING)\n", + " runtime = trt.Runtime(logger)\n", + " \n", + " with open(engine_file_path, \"rb\") as f:\n", + " engine_data = f.read()\n", + " \n", + " engine = runtime.deserialize_cuda_engine(engine_data)\n", + " return engine\n", + "\n", + "def benchmark_inference(model_name):\n", + " engine = load_engine(model_name)\n", + " context = engine.create_execution_context()\n", + " stream = torch.cuda.Stream()\n", + " \n", + " # TRT need static input and output addresses.\n", + " # Here they are set.\n", + " for i in range(len(inps)):\n", + " context.set_tensor_address(engine.get_tensor_name(i), inps[i].data_ptr()) \n", + " context.set_tensor_address(\"output\", output_tensor.data_ptr())\n", + " \n", + " def _inference():\n", + " # The data is loaded from static input addresses\n", + " # and output is written to static output address.\n", + " context.execute_async_v3(stream_handle=stream.cuda_stream)\n", + " stream.synchronize()\n", + " \n", + " return _measure_time(_inference)\n", + "\n", + "\n", + "trt_fp8_time = benchmark_inference(\"model_fp8.engine\")\n", + "trt_fp32_time = benchmark_inference(\"model_fp32.engine\")\n", + "\n", + "print(f\"Average inference time without TRT (FP32 for all layers): {te_fp32_time} ms\")\n", + "print(f\"Average inference time without TRT (FP8 for TE layers, FP32 for non-TE layers): {te_fp8_time} ms, speedup = {te_fp32_time/te_fp8_time:.2f}x\")\n", + "print(f\"Average inference time with TRT (FP32 for all layers): {trt_fp32_time:.4f} ms, speedup = {te_fp32_time/trt_fp32_time:.2f}x\")\n", + "print(f\"Average inference time with TRT (FP8 for TE layers, FP32 for non-TE layers): {trt_fp8_time:.4f} ms, speedup = {te_fp32_time/trt_fp8_time:.2f}x\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "

\n", + "\n", + "\n", + "| Run | Inference Time (ms) | Speedup |\n", + "| ----------------------------------| ------------------- | ------------------- |\n", + "| PyTorch + TE | 0.065 | 1.00x |\n", + "| PyTorch + TE (FP8 for TE layers) | 0.062 | 1.05x |\n", + "| TRT | 0.0500 | 1.30x |\n", + "| TRT (FP8 for TE layers) | 0.047 | 1.38x |\n", + "\n", + "Note that this example highlights how TensorRT can speed up models composed of both TE and non-TE layers.\n", + "If a larger part of the model's layers were implemented with TE, the benefits of using FP8 for inference could be greater.\n", + "\n", + "

\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "We clearly observe performance improvements when using FP8 and the TensorRT inference engine. These improvements may become even more significant with more complex models, as TensorRT could potentially identify additional optimization opportunities.\n", + "\n", + "#### Appendix: Low Precision Operators in ONNX and TensorRT\n", + "\n", + "The ONNX standard does not currently support all precision types provided by the Transformer Engine. All available ONNX operators are listed on [this website](https://onnx.ai/onnx/operators/). Consequently, TensorRT and the Transformer Engine utilize certain specialized low-precision operators, detailed below.\n", + "\n", + "**TRT_FP8_QUANTIZE**\n", + "\n", + "- **Name**: TRT_FP8_QUANTIZE\n", + "- **Domain**: trt\n", + "- **Inputs**:\n", + " - `x`: float32 tensor\n", + " - `scale`: float32 scalar\n", + "- **Outputs**:\n", + " - `y`: int8 tensor\n", + "\n", + "Produces an int8 tensor that represents the binary encoding of FP8 values.\n", + "\n", + "**TRT_FP8_DEQUANTIZE**\n", + "\n", + "- **Name**: TRT_FP8_DEQUANTIZE\n", + "- **Domain**: trt\n", + "- **Inputs**:\n", + " - `x`: int8 tensor\n", + " - `scale`: float32 scalar\n", + "- **Outputs**:\n", + " - `y`: float32 tensor\n", + "\n", + "Converts FP8-encoded int8 tensor data back into float32 precision.\n", + "\n", + "
\n", + "\n", + "Note:\n", + "\n", + "Since standard ONNX operators do not support certain input and output precision types, a workaround is employed: tensors are dequantized to higher precision (float32) before input into these operators or quantized to lower precision after processing. TensorRT recognizes such quantize-dequantize patterns and replaces them with optimized operations. More details are available in [this section](https://docs.nvidia.com/deeplearning/tensorrt/latest/inference-library/work-quantized-types.html#tensorrt-processing-of-q-dq-networks) of the TensorRT documentation.\n", + "\n", + "
" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/docs/examples/onnx/utils.py b/docs/examples/onnx/utils.py new file mode 100644 index 000000000..7acf2ffc6 --- /dev/null +++ b/docs/examples/onnx/utils.py @@ -0,0 +1,25 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Utility functions for ONNX export. +""" + +import time +import torch + + +def _measure_time(f): + + time_taken = [] + num_iterations = 10 + f() # warm-up + + for _ in range(num_iterations): + start_time = time.time() + f() + torch.cuda.synchronize() + end_time = time.time() + time_taken.append(end_time - start_time) + return round(sum(time_taken) / num_iterations, 3) diff --git a/docs/index.rst b/docs/index.rst index bbdb4fea6..e678b1d46 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,7 @@ Transformer Engine documentation examples/fp8_primer.ipynb examples/advanced_optimizations.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb + examples/onnx/onnx_export.ipynb .. toctree:: :hidden: From 8e2d37e958c67641a45d993cda3ba3ba27985fb6 Mon Sep 17 00:00:00 2001 From: Autumn1998 <1515848689@qq.com> Date: Fri, 1 Aug 2025 10:38:08 +0800 Subject: [PATCH 036/153] [PyTorch] Fix corner case in router fuson (#2009) * fix bug if all values<0 Signed-off-by: tongliu * minor fix Signed-off-by: tongliu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tongliu Co-authored-by: tongliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao --- tests/pytorch/test_fused_router.py | 55 ++++++++++----- .../common/fused_router/fused_moe_aux_loss.cu | 4 +- .../fused_score_for_moe_aux_loss.cu | 9 ++- .../fused_topk_with_score_function.cu | 6 +- .../common/fused_router/utils.h | 69 +++++++++++++------ 5 files changed, 95 insertions(+), 48 deletions(-) diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index aacad9081..fa134ba4b 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -148,11 +148,21 @@ def run_comparison( # Set some parameters if score_function == "sigmoid": # Construct the special logits to avoid inf in the sigmoid function - offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 - logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + logits = ( + torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + ) logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) else: - logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 + logits = ( + torch.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + device="cuda", + dtype=dtype, + ) + * 1e-4 + ) logits = logits.view(num_tokens, num_experts) logits.requires_grad = True if enable_bias and score_function == "sigmoid": @@ -281,11 +291,21 @@ def test_topk_softmax( 
def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): if score_function == "sigmoid": # Construct the special logits to avoid inf in the sigmoid function - offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 - logits = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + logits = ( + torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 + ) logits = logits.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) else: - logits = torch.arange(num_tokens * num_experts, device="cuda", dtype=dtype) * 1e-4 + logits = ( + torch.arange( + -num_tokens * num_experts // 2, + num_tokens * num_experts // 2, + device="cuda", + dtype=dtype, + ) + * 1e-4 + ) logits = logits.view(num_tokens, num_experts) logits.requires_grad = True @@ -321,8 +341,8 @@ def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_f @pytest.mark.parametrize("topk", [4]) def test_fused_moe_aux_loss(dtype, num_tokens, num_experts, topk): # Construct the special probs to avoid inf in the sigmoid function - offset = torch.arange(0, num_tokens, dtype=dtype, device="cuda") * 1e-4 - probs = torch.arange(num_experts, device="cuda", dtype=dtype) * 1e-2 + offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 + probs = torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 probs = probs.unsqueeze(0).repeat(num_tokens, 1) + offset.unsqueeze(1) probs = probs.view(num_tokens, num_experts) probs.requires_grad = True @@ -379,15 +399,12 @@ def profile_topk_softmax( if __name__ == "__main__": - test_fused_scores_for_aux_loss( - dtype=torch.float32, num_tokens=2, num_experts=32, topk=8, score_function="softmax" + test_topk_softmax( + dtype=torch.float32, + num_tokens=1024, + num_experts=128, + topk=4, + use_pre_softmax=False, + group_topk=None, + scaling_factor=None, ) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=32, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=128, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=2048, num_experts=256, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=32, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=128, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=7168, num_experts=256, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=32, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=128, topk=4) - test_fused_moe_aux_loss(dtype=torch.float32, num_tokens=14234, num_experts=256, topk=4) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 221963b11..f64b75d97 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -90,7 +90,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Reduce to get the sum of aggregated_probs_per_expert */ CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id); __syncwarp(); if (lane_id == 0) { @@ -146,7 +146,7 @@ __global__ 
void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Reduce to get the sum of aggregated_probs_per_expert */ CompType intermediate_result = - warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, sum, lane_id); + warp_reduce_on_shmem(aggregated_probs_per_expert, num_cols, ReduceFuncType::SUM, lane_id); __syncwarp(); if (lane_id == 0) { diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 91a4bbb53..47d215057 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -107,7 +107,8 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi if (score_function == 0) { if (topk > 1) { - auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, sum, lane_id); + auto sum_logits = + warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_logits[i] = static_cast(static_cast(local_logits[i]) / (static_cast(sum_logits) + epsilon)); @@ -231,13 +232,15 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int */ // Sigmoid Post-processing bwd when topk > 1 if (topk > 1 && score_function == 0) { - auto sum_fwd_input = warp_reduce_on_shmem(local_act_from_fwd, num_experts, sum, lane_id); + auto sum_fwd_input = + warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id); // Put the result of output * grad to the comp_buf for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i]; } __syncwarp(); - auto sum_Output_x_Grad = warp_reduce_on_shmem(local_comp_buf, num_experts, sum, lane_id); + auto sum_Output_x_Grad = + warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id); // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_grad[i] = diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 06f97afc1..a1785c663 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -220,7 +220,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( // score_function == 0 means sigmoid if (score_function == 0) { if (topk > 1) { - double sum_scores = warp_reduce_on_shmem(topk_scores, topk, sum, lane_id); + double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { topk_scores[i] = static_cast(topk_scores[i]) / (sum_scores + epsilon); } @@ -362,7 +362,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( /*data ptr = */ local_act_from_fwd, /*mask ptr = */ local_routing_map, /*data size = */ num_experts, - /*reduce func = */ sum, lane_id); + /*reduce func = */ ReduceFuncType::SUM, lane_id); // Put the result of output * grad to the comp_buf for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { local_comp_buf[i] = (local_routing_map[i] ? 
static_cast(local_grad[i]) * @@ -374,7 +374,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( /*data ptr = */ local_comp_buf, /*mask ptr = */ local_routing_map, /*data size = */ num_experts, - /*reduce func = */ sum, lane_id); + /*reduce func = */ ReduceFuncType::SUM, lane_id); // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index b5d8c231b..46e0ba632 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -26,14 +26,28 @@ __device__ inline T sum(T a, T b) { return a + b; } +enum ReduceFuncType { + SUM, + MAX, +}; + template -__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, T (*reduce_func)(T, T), +__device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type, int lane_id) { + T (*reduce_func)(T, T); + double default_val = 0; + if (type == ReduceFuncType::SUM) { + reduce_func = sum; + default_val = 0; + } else if (type == ReduceFuncType::MAX) { + reduce_func = max; + default_val = -std::numeric_limits::infinity(); + } + // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = - lane_id < data_size ? static_cast(data_ptr[lane_id]) : static_cast(0); + volatile double val = lane_id < data_size ? static_cast(data_ptr[lane_id]) : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { val = reduce_func(val, data_ptr[i]); } @@ -57,13 +71,22 @@ __device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, i template __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, - T (*reduce_func)(T, T), int lane_id) { + ReduceFuncType type, int lane_id) { + T (*reduce_func)(T, T); + double default_val = 0; + if (type == ReduceFuncType::SUM) { + reduce_func = sum; + default_val = 0; + } else if (type == ReduceFuncType::MAX) { + reduce_func = max; + default_val = -std::numeric_limits::infinity(); + } + // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = lane_id < data_size && mask[lane_id] - ? static_cast(data_ptr[lane_id]) - : static_cast(0); + volatile double val = + lane_id < data_size && mask[lane_id] ? static_cast(data_ptr[lane_id]) : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { if (mask[i]) { val = reduce_func(val, data_ptr[i]); @@ -108,7 +131,7 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ float sum_Output_x_Grad = warp_reduce_on_shmem( /*data ptr = */ comp_buf, /*data size = */ data_size, - /*reduce func = */ sum, lane_id); + /*reduce func = */ ReduceFuncType::SUM, lane_id); // In-place update for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { if (mask) { @@ -127,14 +150,16 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ template __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) { // 1. compute the max of value - float max_val = static_cast(warp_reduce_on_shmem(scores, data_size, max, lane_id)); + float max_val = + static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id)); // 2. 
value -> exp_value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { scores[i] = static_cast(exp(static_cast(scores[i]) - max_val)); } __syncwarp(); // 3. compute the sum of exp_value - float sum_val = static_cast(warp_reduce_on_shmem(scores, data_size, sum, lane_id)); + float sum_val = + static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id)); // 4. update the softmax value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { scores[i] = static_cast(scores[i]) / sum_val; @@ -145,19 +170,29 @@ __device__ inline void apply_softmax_on_float(DataType *scores, int data_size, i template __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, T *topk_scores, int lane_id) { + // Check if the index is masked by the later iteration + auto is_masked = [&topk_indices](int k, int index) { + if (k == 0) return false; + for (int i = 0; i < k; i++) { + if (topk_indices[i] == index) return true; + } + return false; + }; // Topk Times: Find the max value and its index // Then mask it, and record the index in the topk_indices // After looping topk times, the topk_indices will be the topk indices for (int k = 0; k < topk; k++) { // Find the max value and its index - volatile double val = - (lane_id < data_size) ? static_cast(scores[lane_id]) : static_cast(0); + volatile double val = (lane_id < data_size && !is_masked(k, lane_id)) + ? static_cast(scores[lane_id]) + : -std::numeric_limits::infinity(); volatile int index = (lane_id < data_size) ? lane_id : 0; // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - volatile double cur_val = scores[i]; + volatile double cur_val = (is_masked(k, i)) ? 
-std::numeric_limits::infinity() + : static_cast(scores[i]); if (cur_val > val) { val = cur_val; index = i; @@ -175,17 +210,9 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i if (lane_id == 0) { topk_indices[k] = index; topk_scores[k] = val; - scores[index] = - static_cast(-1.0) - val; // make the selected experts using val = - 1 - val } __syncwarp(); } - - // Reset the scores to the original value - for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - scores[topk_indices[i]] = - static_cast(-1.0) - static_cast(scores[topk_indices[i]]); - } } // Current TE only support float32/bf16/fp16, float64 probs should be considered in the future From 1258bbe022a54f14256123ed51c5e5b648dc0f4c Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 1 Aug 2025 20:39:48 +0800 Subject: [PATCH 037/153] Manually launch wgrad accumulation and reduce in backward_dw() instead of backward() (#1976) * disable wgrad accumulation and reduce in backward() And manually launch it in backward_dw() Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * refactor Signed-off-by: Hongbin Liu * refactor Signed-off-by: Hongbin Liu * set skip_backward_post_hook to True only if delay_wgrad_compute is True Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/module/base.py | 18 +++++++++++++++-- .../pytorch/module/grouped_linear.py | 20 +++++++++++++------ .../pytorch/module/layernorm_linear.py | 5 +++++ .../pytorch/module/layernorm_mlp.py | 6 ++++++ transformer_engine/pytorch/module/linear.py | 5 +++++ 5 files changed, 46 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index e05e83df9..b0da6e5fc 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -582,6 +582,7 @@ def __init__(self) -> None: self.fsdp_group = None self._fp8_workspaces: Dict[str, QuantizedTensor] = {} self.activation_dtype: Optional[torch.dtype] = None + self.wgrad_accumulation_and_reduce_hooks = [] if not TEDebugState.debug_enabled: TEDebugState.initialize() @@ -1383,6 +1384,16 @@ def _load_from_state_dict( state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs ) + def register_wgrad_accumulation_and_reduce_hooks(self, wgrad_accumulation_and_reduce_hook): + """ + This method is used to manually control the weight gradient accumulation and reduce. + This method should be called before the backward() method. + Set the skip_wgrad_accumulation_and_reduce to True to skip the weight gradient accumulation + and reduce in backward(); + And register the wgrad_accumulation_and_reduce_func to be called in backward_dw() method. + """ + self.wgrad_accumulation_and_reduce_hooks.append(wgrad_accumulation_and_reduce_hook) + def backward_dw(self): """ Execute the delayed weight gradient computation. 
@@ -1393,14 +1404,17 @@ def backward_dw(self): with torch.cuda.nvtx.range(f"_{self.__class__.__name__}_wgrad"): (wgrad, bgrad), _ = self.wgrad_store.pop() if not self.fuse_wgrad_accumulation: - unfused_weights = [getattr(self, name) for name in self.weight_names] - weight_tensor = noop_cat(unfused_weights) + weight_tensor = noop_cat(self._get_weight_tensors()) if weight_tensor.grad is None: weight_tensor.grad = wgrad.to(weight_tensor.dtype) if self.use_bias: bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) if bias_tensor.grad is None: bias_tensor.grad = bgrad.to(bias_tensor.dtype) + del wgrad + del bgrad + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() def _validate_name(self): """ diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index cc472390f..3d7a5efac 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -662,6 +662,12 @@ def __init__( self.reset_parameters(defer_init=device == "meta") + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + for i in range(self.num_gemms): + if name in (f"weight{i}", f"bias{i}"): + param.skip_backward_post_hook = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) @@ -819,19 +825,21 @@ def backward_dw(self): with torch.cuda.nvtx.range("_GroupedLinear_wgrad"): (_, grad_biases_, _), tensor_list = self.wgrad_store.pop() wgrad_list = tensor_list[2] + weight_params = [getattr(self, f"weight{i}") for i in range(self.num_gemms)] + bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): - weight_param = getattr(self, f"weight{i}") - if weight_param.grad is None: - weight_param.grad = wgrad_list[i].to(weight_param.dtype) + if weight_params[i].grad is None: + weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: for i in range(self.num_gemms): - bias_param = getattr(self, f"bias{i}") - if bias_param.grad is None: - bias_param.grad = grad_biases_[i].to(bias_param.dtype) + if bias_params[i].grad is None: + bias_params[i].grad = grad_biases_[i].to(bias_params[i].dtype) del grad_biases_ del wgrad_list del tensor_list + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe) -> None: """Customize quantizers based on current scaling recipe + linear.""" diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 659fcd0e1..5e45b5c25 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1382,6 +1382,11 @@ def __init__( self.bwd_ln_sm_margin = int(os.getenv("NVTE_BWD_LAYERNORM_SM_MARGIN", "0")) self.inf_ln_sm_margin = int(os.getenv("NVTE_INF_LAYERNORM_SM_MARGIN", "0")) + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in self.weight_names or name in self.bias_names: + param.skip_backward_post_hook = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) 
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index cec74aa81..31ba65478 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1642,6 +1642,10 @@ def __init__( warmup_jit_bias_gelu_all_dtypes( self.size_per_partition, seq_length, micro_batch_size ) + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in ["fc1_weight", "fc2_weight", "fc1_bias", "fc2_bias"]: + param.skip_backward_post_hook = True # These many SMs are subtracted from the total SM count when calling forward # and backward LayerNorm C APIs. These envvars can be used to prevent the LN @@ -2152,3 +2156,5 @@ def backward_dw(self): del fc2_wgrad del fc1_wgrad del fc1_bias_grad + for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: + wgrad_accumulation_and_reduce_hook() diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 5b657e848..f2a6871a8 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1270,6 +1270,11 @@ def __init__( else: self.gemm_bias_unfused_add = False + if self.wgrad_store.delay_wgrad_compute(): + for name, param in self.named_parameters(): + if name in self.weight_names or name in self.bias_names: + param.skip_backward_post_hook = True + def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None: """Init scales and amaxes for fwd | bwd.""" super().set_meta_tensor(fwd, recipe) From c444bf5340fa91aa95152e2a25a49c2eda4c37f6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 1 Aug 2025 16:07:11 +0200 Subject: [PATCH 038/153] [PyTorch Debug] Fix debug tests (#2021) fix Signed-off-by: Pawel Gadzinski --- qa/L0_pytorch_debug_unittest/test.sh | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index c94edba2b..6c0c79251 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -21,7 +21,7 @@ pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_ NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 # standard sanity and numerics tests with initialized debug -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 -NVTE_TEST_NVINSPECT_ENABLED=True NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 +NVTE_TEST_NVINSPECT_ENABLED=1 
NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py || FAIL=1 exit $FAIL From 13cae89e02111d5af619847b132859db5e790917 Mon Sep 17 00:00:00 2001 From: Shang Zhang Date: Fri, 1 Aug 2025 13:58:12 -0700 Subject: [PATCH 039/153] Tensor numel() return dtype to be size_t (#2022) Fix tensor numel() return dtype The original dytpe int would be an issue if the tensor element numbers are larger than int32's range (which is not a super large number in real workloads) Signed-off-by: Shang Zhang --- transformer_engine/common/common.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index 08001671d..aa47f2c3d 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -88,7 +88,7 @@ struct SimpleTensor { nvte_make_shape(this->shape.data(), this->shape.size())}; } - int numel() const { + size_t numel() const { size_t acc = 1; for (const auto &dim : shape) { acc *= dim; From 1f2df735acaf229719581c5e94cb437b33589620 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 1 Aug 2025 15:49:51 -0700 Subject: [PATCH 040/153] Fix JAX and PyTorch wheel builds for v2.6 (#2005) * Fix L0_jax_wheel Signed-off-by: Jeremy Berchtold * Update Signed-off-by: Jeremy Berchtold * remove commented line Signed-off-by: Jeremy Berchtold * Reduce usage of --no-deps Signed-off-by: Jeremy Berchtold * Also fix pytorch wheel build Signed-off-by: Jeremy Berchtold * Revert test_sanity_import.py changes as it is also used on CPU-only GitHub build jobs Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- qa/L0_jax_wheel/test.sh | 13 +++++++------ qa/L0_pytorch_wheel/test.sh | 12 ++++++------ 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/qa/L0_jax_wheel/test.sh b/qa/L0_jax_wheel/test.sh index e1400b10b..bf9e4a461 100644 --- a/qa/L0_jax_wheel/test.sh +++ b/qa/L0_jax_wheel/test.sh @@ -26,23 +26,24 @@ pip3 uninstall -y transformer-engine transformer-engine-cu12 transformer-engine- VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" + # Core wheel. -NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" -wheel unpack dist/* || error_exit "Failed to unpack dist/*" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . 
|| error_exit "Failed to setup bdist_wheel" +wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" rm dist/*.whl || error_exit "Failed to remove dist/*.whl" mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup metapackage" cd transformer_engine/jax -NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup sdist" -pip3 install dist/* || error_exit "Failed to install dist/*" +pip3 install --no-build-isolation --no-deps -vvv dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" +pip3 install --no-build-isolation --no-deps -vvv dist/*.whl || error_exit "Failed to install dist/*.whl --no-deps" python3 $TE_PATH/tests/jax/test_sanity_import.py || test_fail "test_sanity_import.py" diff --git a/qa/L0_pytorch_wheel/test.sh b/qa/L0_pytorch_wheel/test.sh index ffd5ca290..3056547ef 100644 --- a/qa/L0_pytorch_wheel/test.sh +++ b/qa/L0_pytorch_wheel/test.sh @@ -27,22 +27,22 @@ VERSION=`cat $TE_PATH/build_tools/VERSION.txt` WHL_BASE="transformer_engine-${VERSION}" # Core wheel. -NVTE_RELEASE_BUILD=1 python3 setup.py bdist_wheel || error_exit "Failed to setup bdist_wheel" -wheel unpack dist/* || error_exit "Failed to unpack dist/*" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation -vvv --wheel-dir ./dist . || error_exit "Failed to setup bdist_wheel" +wheel unpack dist/${WHL_BASE}-* || error_exit "Failed to unpack dist/${WHL_BASE}-*.whl" sed -i "s/Name: transformer-engine/Name: transformer-engine-cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" sed -i "s/Name: transformer_engine/Name: transformer_engine_cu12/g" "transformer_engine-${VERSION}/transformer_engine-${VERSION}.dist-info/METADATA" mv "${WHL_BASE}/${WHL_BASE}.dist-info" "${WHL_BASE}/transformer_engine_cu12-${VERSION}.dist-info" || error_exit "Failed to move ${WHL_BASE}.dist-info to transformer_engine_cu12-${VERSION}.dist-info" wheel pack ${WHL_BASE} || error_exit "Failed to pack ${WHL_BASE}" rm dist/*.whl || error_exit "Failed to remove dist/*.whl" mv *.whl dist/ || error_exit "Failed to move *.whl to dist/" -NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 python3 setup.py bdist_wheel || error_exit "Failed to setup metapackage" +NVTE_RELEASE_BUILD=1 NVTE_BUILD_METAPACKAGE=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . 
|| error_exit "Failed to setup metapackage" cd transformer_engine/pytorch -NVTE_RELEASE_BUILD=1 python3 setup.py sdist || error_exit "Failed to setup sdist" +NVTE_RELEASE_BUILD=1 pip3 wheel --no-build-isolation --no-deps -vvv --wheel-dir ./dist . || error_exit "Failed to setup sdist" -pip3 install dist/* || error_exit "Failed to install dist/*" +pip3 install --no-build-isolation --no-deps -vvv dist/* || error_exit "Failed to install dist/*" cd $TE_PATH -pip3 install dist/*.whl --no-deps || error_exit "Failed to install dist/*.whl --no-deps" +pip3 install --no-build-isolation --no-deps -vvv dist/*.whl || error_exit "Failed to install dist/*.whl --no-deps" python3 $TE_PATH/tests/pytorch/test_sanity_import.py || test_fail "test_sanity_import.py" From c3f8a9f5cd593905d431cc87e4a222e61c872c13 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 4 Aug 2025 11:19:15 -0700 Subject: [PATCH 041/153] [Core] Kernel that swaps first two tensor dimensions (#1998) * Add basic kernel for swapping first two tensor dims Signed-off-by: Tim Moon * Add NVRTC kernel for swapping first dims Signed-off-by: Tim Moon * Add PyTorch extension for swap first dims kernel Signed-off-by: Tim Moon * Tweak variable names Signed-off-by: Tim Moon * Tune kernel Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Make sure writes are contiguous Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/cpp/operator/CMakeLists.txt | 1 + tests/cpp/operator/test_swap_first_dims.cu | 112 ++++++++++ transformer_engine/common/CMakeLists.txt | 3 + .../include/transformer_engine/transpose.h | 8 + .../common/transpose/rtc/swap_first_dims.cu | 37 ++++ .../common/transpose/swap_first_dims.cu | 209 ++++++++++++++++++ transformer_engine/pytorch/csrc/extensions.h | 2 + .../pytorch/csrc/extensions/pybind.cpp | 3 + .../pytorch/csrc/extensions/transpose.cpp | 25 +++ 9 files changed, 400 insertions(+) create mode 100644 tests/cpp/operator/test_swap_first_dims.cu create mode 100644 transformer_engine/common/transpose/rtc/swap_first_dims.cu create mode 100644 transformer_engine/common/transpose/swap_first_dims.cu diff --git a/tests/cpp/operator/CMakeLists.txt b/tests/cpp/operator/CMakeLists.txt index ff889c281..498c1d394 100644 --- a/tests/cpp/operator/CMakeLists.txt +++ b/tests/cpp/operator/CMakeLists.txt @@ -28,6 +28,7 @@ add_executable(test_operator test_multi_unpadding.cu test_causal_softmax.cu test_swizzle.cu + test_swap_first_dims.cu ../test_common.cu) find_package(OpenMP REQUIRED) diff --git a/tests/cpp/operator/test_swap_first_dims.cu b/tests/cpp/operator/test_swap_first_dims.cu new file mode 100644 index 000000000..4c2cf415f --- /dev/null +++ b/tests/cpp/operator/test_swap_first_dims.cu @@ -0,0 +1,112 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include +#include +#include + +#include +#include + +#include +#include "../test_common.h" + +using namespace transformer_engine; + +namespace { + +template +void compute_ref(const Type *input, Type *output, + const std::vector &shape) { + const size_t dim0 = shape[0]; + const size_t dim1 = shape[1]; + size_t dim2 = 1; + for (size_t i = 2; i < shape.size(); ++i) { + dim2 *= shape[i]; + } + for (size_t i = 0; i < dim0; ++i) { + for (size_t j = 0; j < dim1; ++j) { + for (size_t k = 0; k < dim2; ++k) { + const size_t in_offset = i * dim1 * dim2 + j * dim2 + k; + const size_t out_offset = j * dim0 * dim2 + i * dim2 + k; + output[out_offset] = input[in_offset]; + } + } + } +} + +template +void performTest(const std::vector &in_shape) { + using namespace test; + + DType dtype = TypeInfo::dtype; + + // Tensor dimensions + std::vector out_shape = in_shape; + out_shape[0] = in_shape[1]; + out_shape[1] = in_shape[0]; + size_t numel = 1; + for (const auto& dim : in_shape) { + numel *= dim; + } + + // Transformer engine implementation + Tensor input("input", in_shape, dtype); + Tensor output("output", out_shape, dtype); + fillUniform(&input); + nvte_swap_first_dims(input.data(), output.data(), 0); + + // Reference implementation + std::unique_ptr ref_output = std::make_unique(numel); + compute_ref(input.rowwise_cpu_dptr(), ref_output.get(), in_shape); + + // Check for CUDA failure + cudaDeviceSynchronize(); + auto err = cudaGetLastError(); + ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); + + // Check for exact numerics + compareResults("output", output, ref_output.get(), true, 0, 0); +} + +std::vector> test_cases = {{4, 64, 1280}, + {48, 8, 128, 16}, + {229, 173}, // Primes 50, 40 + {113, 71, 1, 1, 1, 29, 1, 1}}; // Primes 30, 20, 10 +} // namespace + +class SwapFirstDimsTestSuite : public ::testing::TestWithParam>> {}; + +TEST_P(SwapFirstDimsTestSuite, TestSwapFirstDims) { + using namespace transformer_engine; + using namespace test; + + const DType type = std::get<0>(GetParam()); + const auto shape = std::get<1>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T, + performTest(shape); + ); +} + + + +INSTANTIATE_TEST_SUITE_P( + OperatorTest, + SwapFirstDimsTestSuite, + ::testing::Combine( + ::testing::ValuesIn(test::all_fp_types), + ::testing::ValuesIn(test_cases)), + [](const testing::TestParamInfo& info) { + std::string name = test::typeName(std::get<0>(info.param)); + for (const auto& dim : std::get<1>(info.param)) { + name += "X"; + name += std::to_string(dim); + } + return name; + }); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index aff282214..b51e61929 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -67,6 +67,7 @@ list(APPEND transformer_engine_SOURCES transpose/multi_cast_transpose.cu transpose/quantize_transpose_square_blockwise.cu transpose/quantize_transpose_vector_blockwise.cu + transpose/swap_first_dims.cu activation/gelu.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu @@ -166,6 +167,8 @@ make_string_header_from_file(transpose/rtc/cast_transpose.cu string_code_transpose_rtc_cast_transpose_cu) make_string_header_from_file(transpose/rtc/transpose.cu string_code_transpose_rtc_transpose_cu) +make_string_header_from_file(transpose/rtc/swap_first_dims.cu + string_code_transpose_rtc_swap_first_dims_cu) make_string_header_from_file(utils.cuh string_code_utils_cuh) 
make_string_header_from_file(util/math.h diff --git a/transformer_engine/common/include/transformer_engine/transpose.h b/transformer_engine/common/include/transformer_engine/transpose.h index a7db5cba4..cc069ee3e 100644 --- a/transformer_engine/common/include/transformer_engine/transpose.h +++ b/transformer_engine/common/include/transformer_engine/transpose.h @@ -318,6 +318,14 @@ void nvte_dqgeglu_cast_transpose(const NVTETensor input, const NVTETensor act_in void nvte_dsreglu_cast_transpose(const NVTETensor input, const NVTETensor act_input, NVTETensor output, cudaStream_t stream); +/*! \brief Swap the first two tensor dimensions. + * + * \param[in] input Input tensor of shape [M, N, ...]. + * \param[out] output Output tensor of shape [N, M, ...]. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/transpose/rtc/swap_first_dims.cu b/transformer_engine/common/transpose/rtc/swap_first_dims.cu new file mode 100644 index 000000000..89a07697a --- /dev/null +++ b/transformer_engine/common/transpose/rtc/swap_first_dims.cu @@ -0,0 +1,37 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "utils.cuh" + +using namespace transformer_engine; + +namespace { + +// Parameters +using VectorType = BytesToType<__VECTOR_SIZE__>::Type; +constexpr size_t block_size = __BLOCK_SIZE__; + +} // namespace + +__global__ void __launch_bounds__(block_size) + swap_first_dims_kernel(const VectorType* __restrict__ const input, + VectorType* __restrict__ const output, const size_t dim0, + const size_t dim1, const size_t dim2) { + const size_t gid = threadIdx.x + blockIdx.x * block_size; +#if __SINGLE_LOAD_STORE__ + const auto idx = gid; +#else + const size_t nthreads = gridDim.x * block_size; + for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads) +#endif // __SINGLE_LOAD_STORE__ + { + const auto idx2 = idx % dim2; + const auto idx1 = (idx / dim2) % dim1; + const auto idx0 = (idx / dim2) / dim1; + const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2; + output[idx] = input[in_offset]; + } +} diff --git a/transformer_engine/common/transpose/swap_first_dims.cu b/transformer_engine/common/transpose/swap_first_dims.cu new file mode 100644 index 000000000..08249a823 --- /dev/null +++ b/transformer_engine/common/transpose/swap_first_dims.cu @@ -0,0 +1,209 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include +#include + +#include +#include +#include + +#include "../common.h" +#include "../util/cuda_runtime.h" +#include "../util/logging.h" +#include "../util/rtc.h" +#include "../util/string.h" + +namespace transformer_engine { + +namespace { + +// String with RTC kernel implementation +#include "string_code_transpose_rtc_swap_first_dims_cu.h" + +// Hard-coded kernel parameters +constexpr size_t block_size = 128; + +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + /* Vector load/store size */ + size_t vector_size; + + /* Whether config is valid */ + bool valid = false; + /* Number of CUDA blocks */ + size_t num_blocks = 0; + /* Whether each thread needs to make exactly one load/store */ + bool single_load_store = true; + + /* Number of active SMs */ + size_t active_sm_count = 0; + /* Used bytes per L1 cache load */ + size_t bytes_per_load = 0; + /* Used bytes per L1 cache store */ + size_t bytes_per_store = 0; + + KernelConfig(size_t dim0, size_t dim1, size_t dim2, size_t sm_count, size_t vector_size_) + : vector_size{vector_size_} { + // Check that tiles are correctly aligned + if (dim2 % vector_size_ != 0) { + return; + } + valid = true; + + // Number of CUDA blocks + num_blocks = DIVUP(dim0 * dim1 * dim2 / vector_size, block_size); + if (num_blocks > 2147483647ull) { + // Maximum number of CUDA blocks + single_load_store = false; + num_blocks = 2147483647ull; + } else if (num_blocks * block_size != dim0 * dim1 * dim2 / vector_size) { + single_load_store = false; + } + + // SM occupancy + constexpr size_t warp_size = 32; + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * block_size / warp_size, warps_per_sm), sm_count); + + // L1 cache efficiency + constexpr size_t cache_line_size = 128; + bytes_per_store = std::min(cache_line_size, warp_size * vector_size); // Contiguous writes + bytes_per_load = bytes_per_store; + if (dim2 % (vector_size * warp_size) != 0) { + // Some warps are reading from two non-contiguous regions + bytes_per_load /= 2; + } + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/bytes_per_load + 1/bytes_per_store) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->bytes_per_load; + const auto &s1 = this->bytes_per_store; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.bytes_per_load; + const auto &s2 = other.bytes_per_store; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * s1 * p1 * l2 * s2 * p2; + const auto cost1 = (scale / l1 + scale / s1) / p1; + const auto cost2 = (scale / l2 + scale / s2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } + } +}; + +template +__global__ void __launch_bounds__(block_size) + swap_first_dims_untuned_kernel(const Type *__restrict__ input, Type *__restrict__ output, + const size_t dim0, const size_t dim1, const size_t dim2) { + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t idx = gid; idx < dim0 * dim1 * dim2; idx += nthreads) { + const auto idx2 = idx % dim2; + const auto idx1 = (idx / dim2) % dim1; + const auto idx0 = (idx / dim2) / dim1; + const auto in_offset = idx1 * dim0 * dim2 + idx0 * dim2 + idx2; + output[idx] = input[in_offset]; + } +} + +} // namespace + +void 
swap_first_dims(const Tensor &input, Tensor &output, cudaStream_t stream) { + // Check tensors + NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be simple tensor, but scaling mode is ", + to_string(input.scaling_mode), "."); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be simple tensor, but scaling mode is ", + to_string(output.scaling_mode), "."); + NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()), + ") and output tensor (dtype=", to_string(output.dtype()), ") do not match."); + NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); + NVTE_CHECK(output.data.dptr != nullptr, "Output is not allocated."); + + // Check tensor dimensions + const auto input_shape = input.shape(); + const auto output_shape = output.shape(); + NVTE_CHECK(input_shape.size() >= 2, "Invalid input tensor dimensions (shape=", input_shape, ")."); + NVTE_CHECK(output_shape.size() == input_shape.size(), "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + NVTE_CHECK(input_shape[0] == output_shape[1], "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + NVTE_CHECK(input_shape[1] == output_shape[0], "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + for (size_t i = 2; i < input_shape.size(); ++i) { + NVTE_CHECK(input_shape[i] == output_shape[i], "Input tensor (shape=", input_shape, + ") and output tensor (shape=", output_shape, ") do not match."); + } + + // Reinterpret tensors as 3D tensors of bytes + const size_t dim0 = output_shape[0]; + const size_t dim1 = output_shape[1]; + size_t dim2 = 1; + for (size_t i = 2; i < output_shape.size(); ++i) { + dim2 *= output_shape[i]; + } + dim2 = get_buffer_size_bytes(dim2, output.dtype()); + + // Choose kernel config with performance heuristics + const size_t sm_count = static_cast(cuda::sm_count()); + KernelConfig config(dim0, dim1, dim2, sm_count, 1); + if (rtc::is_enabled()) { + auto try_config = [&](size_t vector_size) { + KernelConfig new_config(dim0, dim1, dim2, sm_count, vector_size); + if (new_config < config) { + config = new_config; + } + }; + try_config(16); + try_config(8); + try_config(4); + try_config(2); + } + const size_t vector_size = config.vector_size; + + // Launch kernel + if (vector_size == 1) { + // General kernel + swap_first_dims_untuned_kernel<<>>( + static_cast(input.data.dptr), static_cast(output.data.dptr), + dim0, dim1, dim2); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else { + // Compile NVRTC kernel if needed + auto &rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = + concat_strings("swap_first_dims,vector_size=", vector_size, + ",single_load_store=", config.single_load_store); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_swap_first_dims_cu; + code = regex_replace(code, "__VECTOR_SIZE__", vector_size); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + code = + regex_replace(code, "__SINGLE_LOAD_STORE__", static_cast(config.single_load_store)); + rtc_manager.compile(kernel_label, "swap_first_dims_kernel", code, + "transformer_engine/common/transpose/rtc/swap_first_dims.cu"); + } + + // Launch NVRTC kernel + rtc_manager.launch(kernel_label, config.num_blocks, block_size, 0, stream, input.data.dptr, + output.data.dptr, dim0, dim1, dim2 / vector_size); + } +} + +} // 
namespace transformer_engine + +void nvte_swap_first_dims(const NVTETensor input, NVTETensor output, cudaStream_t stream) { + NVTE_API_CALL(nvte_swap_first_dims); + using namespace transformer_engine; + swap_first_dims(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), stream); +} diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 25e858222..0b2ace76a 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -143,6 +143,8 @@ std::optional> te_general_grouped_gemm( at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional output = std::nullopt); +at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = std::nullopt); + /*************************************************************************************************** * Activations **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index c9b5a67a7..af06bb9fb 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -211,6 +211,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { m.def("fp8_transpose", &transformer_engine::pytorch::fp8_transpose, "Transpose with FP8 I/O", py::arg("input"), py::arg("dtype"), py::kw_only(), py::arg("out"), py::call_guard()); + m.def("swap_first_dims", &transformer_engine::pytorch::swap_first_dims, + "Swap first two tensor dimensions", py::arg("tensor"), py::kw_only(), py::arg("out"), + py::call_guard()); m.def("get_fused_attn_backend", &transformer_engine::pytorch::get_fused_attn_backend, "Get Fused Attention backend", py::call_guard()); m.def("compute_amax", &transformer_engine::pytorch::compute_amax, diff --git a/transformer_engine/pytorch/csrc/extensions/transpose.cpp b/transformer_engine/pytorch/csrc/extensions/transpose.cpp index d6ae0c86a..7dfdf9954 100644 --- a/transformer_engine/pytorch/csrc/extensions/transpose.cpp +++ b/transformer_engine/pytorch/csrc/extensions/transpose.cpp @@ -52,5 +52,30 @@ at::Tensor fp8_transpose(at::Tensor input, DType otype, std::optional out) { + init_extension(); + + // Make sure input is contiguous + const auto &input = tensor.contiguous(); + + // Allocate output tensor if needed + if (!out) { + auto in_shape = getTensorShape(input); + NVTE_CHECK(in_shape.size() >= 2, "Invalid input tensor dimensions (shape=", in_shape, ")"); + std::vector out_shape_int64(in_shape.begin(), in_shape.end()); + out_shape_int64[0] = static_cast(in_shape[1]); + out_shape_int64[1] = static_cast(in_shape[0]); + auto opts = at::TensorOptions().dtype(input.dtype()).device(input.device()); + out = at::empty(out_shape_int64, opts); + } + + // Launch kernel + const TensorWrapper te_input = makeTransformerEngineTensor(input); + TensorWrapper te_output = makeTransformerEngineTensor(*out); + nvte_swap_first_dims(te_input.data(), te_output.data(), at::cuda::getCurrentCUDAStream()); + + return std::move(*out); +} + } // namespace pytorch } // namespace transformer_engine From 06947e87b5511f8ad69ccd00286de9227f0fad24 Mon Sep 17 00:00:00 2001 From: buptzyb Date: Tue, 5 Aug 2025 04:24:11 +0800 Subject: [PATCH 042/153] [PyTorch] Fix cudagraph static_input and static_grad_input reuse (#2018) * fix graph static grad input reuse Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, 
see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/graph.py | 78 ++++++++++++++++++----------- 1 file changed, 49 insertions(+), 29 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 432a47985..866f0b639 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -179,24 +179,17 @@ def _make_graphed_callables( assert isinstance( sample_args, list ), "sample_args must be a list for _reuse_graph_input_output_buffers." - len_args = len(sample_args[0]) - for i, arg in enumerate(sample_args): - assert len_args == len( - arg - ), "Arguments must have same length and shape for `_reuse_graph_input_output_buffers`." - len_kwargs = len(sample_kwargs[0]) - assert isinstance( - sample_kwargs, list - ), "sample_kwargs must be a list for _reuse_graph_input_output_buffers." - for i, kwarg in enumerate(sample_kwargs): - assert len_kwargs == len(kwarg), ( - "Keyword arguments must have same length and shape for" - " `_reuse_graph_input_output_buffers`." - ) # Reorganize args and kwargs for input tensor reuse. + # fwd_sample_qs is keyed by model chunk index. The value is a queue of tuples. + # Each tuple contains the sample key signature and its fwd_idx. When we finish a backward + # chunk, we pop the corresponding fwd_idx and push to the consumed_sample_q. + # consumed_sample_q is keyed by the sample key signature. The value is a queue of the + # fwd_idx whose backward has been called so that we can reuse the same static buffers. + # In this way, we can reuse the same static input buffers for the non-overlapping samples + # with the same input signature. 
fwd_sample_qs = {} - consumed_sample_q = [] + consumed_sample_q = {} fwd_idx = [0] * num_model_chunks for c_id in _order: m_chunk = abs(c_id) - 1 @@ -208,10 +201,21 @@ def _make_graphed_callables( fwd_sample_idx = [ sample_start_idx + i for i in range(_num_layers_per_chunk[m_chunk]) ] - fwd_sample_qs[m_chunk] = fwd_sample_qs.get(m_chunk, []) + fwd_sample_idx + if m_chunk not in fwd_sample_qs: + fwd_sample_qs[m_chunk] = [] for per_callable_fwd_idx in fwd_sample_idx: - if consumed_sample_q: - reuse_fwd_idx = consumed_sample_q.pop(0) + sample_args_keys = tuple( + (t.shape, t.dtype, t.layout) for t in sample_args[per_callable_fwd_idx] + ) + sample_kwargs_keys = tuple( + (k, v.shape, v.dtype, v.layout) + for k, v in sorted(sample_kwargs[per_callable_fwd_idx].items()) + ) + sample_keys = sample_args_keys + sample_kwargs_keys + + fwd_sample_qs[m_chunk].append((sample_keys, per_callable_fwd_idx)) + if consumed_sample_q.get(sample_keys, []): + reuse_fwd_idx = consumed_sample_q[sample_keys].pop(0) sample_args[per_callable_fwd_idx] = sample_args[reuse_fwd_idx] sample_kwargs[per_callable_fwd_idx] = sample_kwargs[reuse_fwd_idx] fwd_idx[m_chunk] += 1 @@ -219,7 +223,12 @@ def _make_graphed_callables( num_consumed_samples = min( len(fwd_sample_qs[m_chunk]), _num_layers_per_chunk[m_chunk] ) - consumed_sample_q += fwd_sample_qs[m_chunk][:num_consumed_samples] + for sample_keys, per_callable_fwd_idx in fwd_sample_qs[m_chunk][ + :num_consumed_samples + ]: + if sample_keys not in consumed_sample_q: + consumed_sample_q[sample_keys] = [] + consumed_sample_q[sample_keys].append(per_callable_fwd_idx) fwd_sample_qs[m_chunk] = fwd_sample_qs[m_chunk][num_consumed_samples:] if fp8_weight_caching: @@ -423,7 +432,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument fwd_idx = [0] * num_model_chunks bwd_idx = [0] * num_model_chunks static_grad_outputs_dict = {} - previous_per_callable_bwd_idx = None + previous_chunk_last_callable_bwd_idx = None for c_id in _order: if c_id > 0: # Capture forward graph for model chunk c_id, microbatch fwd_idx[c_id-1] @@ -446,6 +455,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument else: # Capture backward graph for model chunk c_id, microbatch bwd_idx[-c_id-1] m_chunk = -c_id - 1 + previous_per_callable_bwd_idx = None for l_no in list(reversed(range(_num_layers_per_chunk[m_chunk]))): per_callable_bwd_idx = (_prefix_num_layers[m_chunk] * num_microbatches) + ( bwd_idx[m_chunk] * _num_layers_per_chunk[m_chunk] + l_no @@ -508,19 +518,29 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_static_outputs[per_callable_bwd_idx] = make_weak_ref( static_outputs ) - # Weak ref the static grad inputs of the previous backward pass. - # Note: After a backward pass, we assume Mcore will send the - # grad input to another pipeline parallel rank and that the - # communication is finished before the end of the next backward - # pass. + + # Weak ref the static grad inputs of the previous backward pass within the + # same chunk. if previous_per_callable_bwd_idx is not None: - per_callable_static_grad_inputs[previous_per_callable_bwd_idx] = ( - make_weak_ref( - per_callable_static_grad_inputs[previous_per_callable_bwd_idx] - ) + idx = previous_per_callable_bwd_idx + per_callable_static_grad_inputs[idx] = make_weak_ref( + per_callable_static_grad_inputs[idx] ) previous_per_callable_bwd_idx = per_callable_bwd_idx + # Weak ref the static grad inputs of the previous chunk's last backward + # pass. 
+ # Note: After a chunk's backward pass, we assume Mcore will send the grad + # input to another pipeline parallel rank and that the communication is + # finished before the end of the next chunk's backward pass. + if l_no == 0: + if previous_chunk_last_callable_bwd_idx is not None: + idx = previous_chunk_last_callable_bwd_idx + per_callable_static_grad_inputs[idx] = make_weak_ref( + per_callable_static_grad_inputs[idx] + ) + previous_chunk_last_callable_bwd_idx = per_callable_bwd_idx + bwd_idx[m_chunk] += 1 else: # Capture forward graphs From 3e6859e22f5a3b7969f6068f00e148bd825775ad Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 5 Aug 2025 10:37:13 -0400 Subject: [PATCH 043/153] [JAX] Sharding specs for TE GEMM custom call operands (#2023) * new gemm operand specs processing Signed-off-by: Phuong Nguyen * fix for lhs_non_specs Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 159 ++++++------------ 1 file changed, 50 insertions(+), 109 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d2e65d265..897baaa26 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -511,22 +511,6 @@ def batcher( (out_bdims, bias_bdims, pre_gelu_bdims), ) - @staticmethod - def _decompose_operand_specs(specs, contracting_dims, batch_dims): - ndims = len(specs) - cdims, bdims = map(sanitize_dims, (ndims, ndims), (contracting_dims, batch_dims)) - - # Batch specs - bspecs = tuple(specs[i] for i in bdims) - - # Non-batch leading dimension specs - lspecs = tuple(specs[i] for i in range(ndims) if i not in cdims + bdims) - - # Non-batch contracting dimension specs - cspecs = tuple(specs[i] for i in range(ndims) if i in cdims and i not in bdims) - - return bspecs, lspecs, cspecs - @staticmethod def _parse_operand_output_specs( arg_infos, @@ -535,112 +519,74 @@ def _parse_operand_output_specs( sequence_parallel_output, sequence_dim, ): + del sequence_dim, sequence_parallel_output, batched_dims lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) - lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) - lhs_cdims, rhs_cdims, lhs_bdims, rhs_bdims = map( - sanitize_dims, 2 * [lhs_ndim, rhs_ndim], contracting_dims + batched_dims - ) - ( - (lhs_bspecs, lhs_lspecs, lhs_cspecs), - (rhs_bspecs, rhs_lspecs, rhs_cspecs), - ) = map( - GemmPrimitive._decompose_operand_specs, - (lhs_specs, rhs_specs), - (lhs_cdims, rhs_cdims), - (lhs_bdims, rhs_bdims), - ) - - # Batched dimensions must have the same sharding - if len(lhs_bdims) > 0 and len(rhs_bdims) > 0: - assert all( - lhs_bspec == rhs_bspec for lhs_bspec, rhs_bspec in zip(lhs_bspecs, rhs_bspecs) - ), ( - "cuBLAS GEMM operand batch dimensions must have the same sharding: " - f"{lhs_specs} @ idx {lhs_bdims} x {rhs_specs} @ idx {rhs_bdims}." 
- ) - # Only one each of the non-batched leading dimensions and non-batched contracting - # dimensions can be sharded - lhs_ldims, rhs_ldims = map( - lambda ndim, exclude: tuple(dim for dim in range(ndim) if dim not in exclude), + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) + lhs_non_cdims, rhs_non_cdims = map( + lambda ndim, cdims: tuple(i for i in range(ndim) if i not in cdims), (lhs_ndim, rhs_ndim), - (lhs_bdims + lhs_cdims, rhs_bdims + rhs_cdims), - ) - (lhs_lspec_not_none, rhs_lspec_not_none, lhs_cspec_not_none, rhs_cspec_not_none) = map( - lambda specs: tuple(spec for spec in specs if spec is not None), - (lhs_lspecs, rhs_lspecs, lhs_cspecs, rhs_cspecs), - ) - assert len(lhs_lspec_not_none) <= 1 and len(rhs_lspec_not_none) <= 1, ( - "cuBLAS GEMM operands can have only one sharded non-batched leading dimension: " - f"{lhs_specs} @ idx {lhs_ldims} x {rhs_specs} @ idx {rhs_ldims}." - ) - assert len(lhs_cspec_not_none) <= 1 and len(rhs_cspec_not_none) <= 1, ( - "cuBLAS GEMM operands can have only one sharded non-batched contracting dimension: " - f"{lhs_specs} @ idx {lhs_cdims} x {rhs_specs} @ idx {rhs_cdims}." + (lhs_cdims, rhs_cdims), ) - - # Extract single leading and contracting dimension specs - (lhs_cspec, rhs_cspec) = map( - lambda specs: None if len(specs) == 0 else specs[0], - (lhs_cspec_not_none, rhs_cspec_not_none), + lhs_non_cspecs, lhs_cspecs, rhs_non_cspecs, rhs_cspecs = map( + lambda specs, dims: tuple(specs[i] for i in dims), + (lhs_specs, lhs_specs, rhs_specs, rhs_specs), + (lhs_non_cdims, lhs_cdims, rhs_non_cdims, rhs_cdims), ) - # Partitioning rules: - # ([B], M, K1) x ([B], N, K2)^T = ([B], M, N) - # 1. K1 == K2 != None - # - Require non-batched non-contracting dims of both LHS and RHS to be unsharded. - # - If `sequence_parallel_output=True`, then reduce-scatter the output. - # - Otherwise, all-reduce the output. - # 2. Otherwise - # - Require contracting dimensions of both LHS and RHS to be unsharded. - # - Require non-batched non-contracting dims of LHS to be unsharded. - reduce_output = rhs_cspec is not None and lhs_cspec == rhs_cspec - reduce_spec = scatter_dim = None - if reduce_output: - reduce_spec = rhs_cspec - if sequence_parallel_output: - # If the sequence dimension is not specified, assume it to be the first - # non-batched non-contracting dimension of the LHS operand. - scatter_dim = sequence_dim if sequence_dim is not None else lhs_ldims[0] - - # Always require the non-batched non-contracting dims of LHS to be unsharded - # NOTE: This will all-gather sequence-parallel inputs and preserve tensor-parallel params. - lhs_specs = tuple( - lhs_specs[i] if i in set(lhs_bdims + lhs_cdims) else None for i in range(lhs_ndim) - ) - if reduce_output: - # When reducing GEMM output, require non-batched non-contracting dims of the RHS - # operand to be unsharded (i.e. FSDP) - rhs_specs = tuple( - None if i not in set(rhs_bdims + rhs_cdims) else rhs_specs[i] - for i in range(rhs_ndim) + reduce_spec = None + for l in lhs_cspecs: + for r in rhs_cspecs: + if l is not None and l == r: + assert reduce_spec is None, "Multiple reduce dimension is detected!" + reduce_spec = l + + if reduce_spec is not None: + # Other non-reduce cdims (if exists) need to be unsharded + lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) + rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) + + # Non-batched non-contracting dims of RHS needs to be unsharded (i.e. 
FSDP) + # Check if spec is not the batch-dim is not needed as rhs_non_cspecs never includes batch-dim + # rhs_specs only includes batch-dim in the Wgrad GEMM, but there batch-dim belongs to rhs_cspecs + rhs_non_cspecs = tuple( + None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs ) else: # Otherwise, require contracting dims of both operands to be unsharded - lhs_specs = tuple(None if i in lhs_cdims else lhs_specs[i] for i in range(lhs_ndim)) - rhs_specs = tuple(None if i in rhs_cdims else rhs_specs[i] for i in range(rhs_ndim)) + lhs_cspecs = (None,) * len(lhs_cspecs) + rhs_cspecs = (None,) * len(rhs_cspecs) - # Combine modified LHS and RHS specs into the output - lhs_non_contracting_specs, rhs_non_contracting_specs = map( - lambda specs, cdims: tuple(specs[i] for i in range(len(specs)) if i not in cdims), - (lhs_specs, rhs_specs), + # Non-batched non-contracting dims of LHS to be unsharded, i.e gather SP dim + # The spec for batch_dim in lhs_non_cspecs won't ever appear in the rhs_non_cspecs as + # rhs_non_cspecs never has batch-dim. Hence, spec for batch_dim of lhs_non_cspecs won't be + # overwrite + # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for + # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. + lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) + + out_specs = lhs_non_cspecs + rhs_non_cspecs + + # specs = merge(cspecs, non_cspecs) + lhs_specs, rhs_specs = map( + lambda cdims, cspecs, non_cspecs: ( + cspecs + non_cspecs if cdims[0] == 0 else non_cspecs + cspecs + ), (lhs_cdims, rhs_cdims), + (lhs_cspecs, rhs_cspecs), + (lhs_non_cspecs, rhs_non_cspecs), ) - out_specs = [*lhs_non_contracting_specs, *rhs_non_contracting_specs] # Bias and Pre-GeLU sharding is based on GEMM output before any scatter - bias_specs = tuple(list(out_specs[len(lhs_non_contracting_specs) :]).copy()) + bias_specs = tuple(list(rhs_non_cspecs).copy()) gelu_specs = tuple(list(out_specs).copy()) - # Set output scatter dim to the tensor-parallel spec - if sequence_parallel_output: - out_specs[scatter_dim] = reduce_spec - return ( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), reduce_spec, - scatter_dim, + 0, ) @staticmethod @@ -717,7 +663,7 @@ def partition( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), reduce_spec, - scatter_dim, + _, ) = GemmPrimitive._parse_operand_output_specs( arg_infos, contracting_dims, @@ -785,12 +731,7 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): # All-Reduce/Reduce-Scatter GEMM output if reduce_spec is not None: - if scatter_dim is None: - outputs[0] = jax.lax.psum(outputs[0], reduce_spec) - else: - outputs[0] = jax.lax.psum_scatter( - outputs[0], reduce_spec, scatter_dimension=scatter_dim, tiled=True - ) + outputs[0] = jax.lax.psum(outputs[0], reduce_spec) return outputs From 6c970612715e2a493a2468256c05ce40a11e8556 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 5 Aug 2025 16:55:30 -0400 Subject: [PATCH 044/153] [JAX] Disable TE Norm Custom Calls (#1993) Disable Norm custom calls Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index fcc2108cc..0d19785a0 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ 
b/transformer_engine/jax/cpp_extensions/base.py @@ -34,7 +34,7 @@ class BasePrimitive(metaclass=ABCMeta): _is_enabled = True # Default list of primitives to disable for all recipes - _default_disable_names = ["GemmPrimitive"] + _default_disable_names = ["GemmPrimitive", "NormFwdPrimitive", "NormBwdPrimitive"] @classmethod def enabled(cls): From 7101f4bec6d0c344b134f501ad6bdec4f326fd0b Mon Sep 17 00:00:00 2001 From: xiaoxi-wangfj <690912414@qq.com> Date: Wed, 6 Aug 2025 16:18:35 +0800 Subject: [PATCH 045/153] [PyTorch] Fix zero initialization in permute kernel for padded slots (#2026) Signed-off-by: xiaoxi-wangfj <690912414@qq.com> --- transformer_engine/pytorch/triton/permutation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/triton/permutation.py b/transformer_engine/pytorch/triton/permutation.py index 9ce01362f..ceb88108f 100644 --- a/transformer_engine/pytorch/triton/permutation.py +++ b/transformer_engine/pytorch/triton/permutation.py @@ -359,7 +359,7 @@ def _permute_kernel( if prob == 0.0: # for routing_map padding # dst_row != -1 and prob == 0.0 means that this slot is padded - tl.store(output_ptr + output_off, 0, mask=mask) + tl.store(output_ptr + output_off, 0.0, mask=mask) else: tl.store(output_ptr + output_off, inp, mask=mask) else: From ed42b5ac6fd0e6cf48ef037adcb875f268d94151 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 6 Aug 2025 12:19:15 -0400 Subject: [PATCH 046/153] [JAX] Remove `dot_1_output_axes` usage in LayerNormMLP (#2029) * remove dot1_output_axes Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/flax/module.py | 10 ---------- transformer_engine/jax/layernorm_mlp.py | 10 +--------- transformer_engine/jax/sharding.py | 21 --------------------- 3 files changed, 1 insertion(+), 40 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index ba5ee6d13..60c39a037 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -31,7 +31,6 @@ jax_scaled_upper_triang_masked_softmax, ) from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode -from ..sharding import get_non_contracting_logical_axes PRNGKey = Any Shape = Tuple[int, ...] 
@@ -1206,15 +1205,6 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): quantizer_set=ffn1_quantizer_set, ) - if self.dot_1_input_axes is not None and self.kernel_axes_1 is not None: - dot_1_output_axes = ( - *get_non_contracting_logical_axes(y.ndim, self.dot_1_input_axes, axis), - *get_non_contracting_logical_axes( - kernel_1.ndim, self.kernel_axes_1, contract_ind - ), - ) - x = with_sharding_constraint_by_logical_axes(x, dot_1_output_axes) - if self.enable_low_rank_adaptation: wi_lora_a_kernel_each_shape = ( kernel_1_each_shape[: len(axis)], diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8dd045100..5b738e46b 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -30,7 +30,6 @@ TensorUsage, ) from .sharding import ( - get_non_contracting_logical_axes, get_sequence_parallel_dim, ) @@ -259,7 +258,7 @@ def _layernorm_mlp_fwd_rule( Returns: Tuple of (output, context) for automatic differentiation """ - del kernel_2_axes + del kernel_1_axes, kernel_2_axes ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -318,13 +317,6 @@ def _layernorm_mlp_fwd_rule( fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) - if dot_1_input_axes is not None and kernel_1_axes is not None: - dot_1_output_axes = ( - *get_non_contracting_logical_axes(x.ndim, dot_1_input_axes, x_contracting_dims), - *get_non_contracting_logical_axes(kernel_1.ndim, kernel_1_axes, k_contracting_dims), - ) - dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) - if use_bias_1 and tex.gemm_uses_jax_dot(): bias_1_shape = bias_1.shape bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index a7bbef997..6dd2e88a6 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -427,24 +427,3 @@ class ShardingType(Enum): TP_ROW = (MajorShardingType.TP, "tp_row") DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col") DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row") - - -def get_non_contracting_logical_axes( - ndim, logical_axes: tuple[Optional[str]], contracting_dims -) -> tuple[Optional[str]]: - """Get logical axes for non-contracting dimensions. - - Args: - ndim: Number of dimensions in the tensor. - logical_axes: Tuple of logical axes for each dimension. - contracting_dims: Set of dimensions that are being contracted. - - Returns: - Tuple of logical axes for non-contracting dimensions. - """ - assert logical_axes is not None, "Logical axes must be a tuple and cannot be None." - assert len(logical_axes) == ndim, "Logical axes must match the number of dimensions." 
- - non_contracting_dims = [i for i in range(ndim) if i not in contracting_dims] - non_contracting_logical_axes = tuple(logical_axes[i] for i in non_contracting_dims) - return non_contracting_logical_axes From 6d178b4e0465cb6739d7213a6814325aa92d2e38 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 6 Aug 2025 09:19:27 -0700 Subject: [PATCH 047/153] [JAX] Reduce L1 tests/jax/test_distributed_softmax.py test runtime (#2031) * Pytest timings Signed-off-by: Jeremy Berchtold * Reduce softmax test shape sizes Signed-off-by: Jeremy Berchtold * Switch softmax tests to use shardy by default Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- tests/jax/conftest.py | 53 +++++++++++++++++++++++++++ tests/jax/test_distributed_softmax.py | 8 ++-- 2 files changed, 57 insertions(+), 4 deletions(-) diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py index 663a95418..cb5676d51 100644 --- a/tests/jax/conftest.py +++ b/tests/jax/conftest.py @@ -5,6 +5,8 @@ import os import jax import pytest +from collections import defaultdict +import time import transformer_engine.jax @@ -32,3 +34,54 @@ def enable_fused_attn_after_hopper(): yield if "NVTE_FUSED_ATTN" in os.environ: del os.environ["NVTE_FUSED_ATTN"] + + +class TestTimingPlugin: + """ + Plugin to measure test execution time. Enable test timing by setting NVTE_JAX_TEST_TIMING=1 + in the environment. + """ + + def __init__(self): + self.test_timings = defaultdict(list) + + @pytest.hookimpl(tryfirst=True) + def pytest_runtest_setup(self, item): + item._timing_start = time.time() + + @pytest.hookimpl(trylast=True) + def pytest_runtest_teardown(self, item, nextitem): + if hasattr(item, "_timing_start"): + duration = time.time() - item._timing_start + + # Extract base function name without parameters + test_name = item.name + if "[" in test_name: + base_name = test_name.split("[")[0] + else: + base_name = test_name + + self.test_timings[base_name].append(duration) + + def pytest_sessionfinish(self, session, exitstatus): + print("\n" + "=" * 80) + print("TEST RUNTIME SUMMARY (grouped by function)") + print("=" * 80) + + total_overall = 0 + for test_name, durations in sorted(self.test_timings.items()): + total_time = sum(durations) + count = len(durations) + avg_time = total_time / count if count > 0 else 0 + total_overall += total_time + + print(f"{test_name:<60} | {count:3}x | {total_time:7.2f}s | avg: {avg_time:6.2f}s") + + print("=" * 80) + print(f"{'TOTAL RUNTIME':<60} | {'':>3} | {total_overall:7.2f}s |") + print("=" * 80) + + +def pytest_configure(config): + if os.getenv("NVTE_JAX_TEST_TIMING", "0") == "1": + config.pluginmanager.register(TestTimingPlugin(), "test_timing") diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index cb30c34ab..8d2ad6fad 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -135,7 +135,7 @@ def impl_test_softmax( ) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [64, 16, 1024, 1024]]) + @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]]) @pytest.mark.parametrize( "softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED, SoftmaxType.SCALED_UPPER_TRIANG_MASKED], @@ -168,14 +168,14 @@ def test_softmax( dtype, bad_sharding, broadcast_batch_mask, - use_shardy=False, + use_shardy=True, ) 
@pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED]) @pytest.mark.parametrize("bad_sharding", [False, True]) @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) - def test_softmax_shardy( + def test_softmax_gspmd( self, device_count, mesh_shape, @@ -196,5 +196,5 @@ def test_softmax_shardy( dtype=DTYPES[0], bad_sharding=bad_sharding, broadcast_batch_mask=broadcast_batch_mask, - use_shardy=True, + use_shardy=False, ) From c0d2f1a54be61162bc336bcb60fcb7cfaf647018 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 7 Aug 2025 02:39:22 +0800 Subject: [PATCH 048/153] [PyTorch] Multi-tensor swizzle scaling factors for MXFP8 and fuse padding zeros (#2019) * for loop Signed-off-by: Xin Yao * bulk alloc Signed-off-by: Xin Yao * multi-tensor swizzle Signed-off-by: Xin Yao * pad zeros in swizzle kernels Signed-off-by: Xin Yao * unify single- and multi-tensor swizzle Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix empty tensor list Signed-off-by: Xin Yao * fix bug for col swizzle Signed-off-by: Xin Yao * check context & fix signifiers Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- benchmarks/linear/benchmark_grouped_linear.py | 2 +- transformer_engine/common/common.cu | 1 + .../include/transformer_engine/swizzle.h | 14 + transformer_engine/common/swizzle/swizzle.cu | 439 ++++++++++++++++-- transformer_engine/common/util/padding.cu | 1 + .../pytorch/csrc/extensions/cast.cpp | 10 +- .../pytorch/csrc/extensions/gemm.cpp | 19 +- transformer_engine/pytorch/csrc/quantizer.cpp | 8 +- transformer_engine/pytorch/csrc/util.cpp | 95 ++++ transformer_engine/pytorch/csrc/util.h | 11 +- 10 files changed, 533 insertions(+), 67 deletions(-) diff --git a/benchmarks/linear/benchmark_grouped_linear.py b/benchmarks/linear/benchmark_grouped_linear.py index 0dbee212d..44f1c8967 100644 --- a/benchmarks/linear/benchmark_grouped_linear.py +++ b/benchmarks/linear/benchmark_grouped_linear.py @@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4): num_gemms_list = [8] if args.profile: - mkns = [(4096, 4096, 4096)] + mkns = [(4096 * 8, 4096, 4096)] # in profile mode, only run one recipe specified in args.recipe assert args.recipe != "all", ( "In profile mode, only one recipe can be specified, please specify the recipe as" diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 619bf6ca0..9831bbb24 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -138,6 +138,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits) { + cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { diff --git a/transformer_engine/common/include/transformer_engine/swizzle.h b/transformer_engine/common/include/transformer_engine/swizzle.h index de5a11eb7..079feb4a7 100644 --- 
a/transformer_engine/common/include/transformer_engine/swizzle.h +++ b/transformer_engine/common/include/transformer_engine/swizzle.h @@ -30,6 +30,20 @@ extern "C" { */ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Swizzling scaling factors into the required interleaved layout for GEMM + * + * \param[in] inputs Input tensors with non-swizzled scale_inv. + * \param[in,out] outputs Output tensors which hosts swizzled scale_inv. + * \param[in] stream CUDA stream used for the operation. + * + * Requirements: + * - scale_inv is stored in row-major. + * - scale_inv size is padded to 128x4 for row-scale and 4x128 for col-scale. + * - data is quantitized along K-dimension, i.e. 1D-scaling block lies along the K-dimension. + */ +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index cea0e5080..37d7491d9 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -15,15 +15,17 @@ #include "../util/logging.h" #include "transformer_engine/transformer_engine.h" +namespace transformer_engine { namespace { -constexpr int TB_DIM = 32; -constexpr int NEW_SF_TILE_DIM_K = 16; -constexpr int N_SF_PER_TD_PER_TILE = 4; +constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; +constexpr __device__ __host__ int TB_DIM = 32; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; +constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; // output is in ~K-major interleaved blocks -constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; -constexpr int NEW_SF_TILE_DIM_M_I32 = 32; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; template __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { @@ -51,8 +53,11 @@ __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { } template -__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; @@ -66,21 +71,24 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons int m_tiles_in_tb = N_TILE_PER_TD; int k_tiles_in_tb = TB_DIM; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { k_tiles_in_tb = (K_i32 / SF_TILE_DIM_K_I32 - 1) % k_tiles_in_tb + 1; } - if (blockIdx.y == gridDim.y - 1) { + if (bid_y == grid_dim_y - 1) { m_tiles_in_tb = (M_i32 / SF_TILE_DIM_M_I32 - 1) % m_tiles_in_tb + 1; } - const int32_t* input_i32 = reinterpret_cast(input) + - blockIdx.x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + - blockIdx.y * N_TILE_PER_TD * SF_TILE_DIM_M_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + + const int input_offset = + bid_x * TB_DIM * SF_TILE_DIM_K_I32 * M_i32 + bid_y * N_TILE_PER_TD * 
SF_TILE_DIM_M_I32; + const int32_t* input_i32 = reinterpret_cast(input) + input_offset; int32_t* output_i32[N_TILE_PER_TD]; #pragma unroll for (int i = 0; i < m_tiles_in_tb; i++) { - output_i32[i] = reinterpret_cast(output) + blockIdx.x * TB_DIM * SF_TILE_SIZE_I32 + - (blockIdx.y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; + output_i32[i] = reinterpret_cast(output) + bid_x * TB_DIM * SF_TILE_SIZE_I32 + + (bid_y * N_TILE_PER_TD + i) * SF_TILE_DIM_M_I32 * K_i32; } extern __shared__ int slm[]; @@ -90,8 +98,18 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons threadIdx.y < k_tiles_in_tb) { #pragma unroll for (int i = 0; i < N_SF_PER_TD_PER_TILE; i++) { - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD)); + const int thread_offset = + (threadIdx.y * SF_TILE_DIM_K_I32 + i) * M_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + // Pad zeros + if (padding_m || padding_k) { + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (input_offset + thread_offset) * sizeof(int) + j; + if (index / M >= original_K || index % M >= original_M) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // local shuffle @@ -126,6 +144,14 @@ __global__ void swizzle_col_scaling_kernel(const void* input, void* output, cons } } +template +__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K) { + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} + template __device__ inline void regs_shuffle(LType* regs_vec) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); @@ -143,8 +169,11 @@ __device__ inline void regs_shuffle(LType* regs_vec) { } template -__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, - const int K) { +__device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K, const int bid_x, + const int bid_y, const int grid_dim_x, + const int grid_dim_y) { constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; @@ -154,14 +183,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons int n_tiles_in_tb = N_TILES_IN_TB; const int K_i32 = K / 4; - if (blockIdx.x == gridDim.x - 1) { + if (bid_x == grid_dim_x - 1) { n_tiles_in_tb = (K_i32 - 1) % N_TILES_IN_TB + 1; } - const int* input_i32 = reinterpret_cast(input) + - blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + blockIdx.x * N_TILES_IN_TB; - int* output_i32 = reinterpret_cast(output) + blockIdx.y * SF_TILE_DIM_M_I32 * K_i32 + - blockIdx.x * N_TILES_IN_TB * SF_TILE_SIZE_I32; + bool padding_m = (bid_y == grid_dim_y - 1) && (original_M < M); + bool padding_k = (bid_x == grid_dim_x - 1) && (original_K < K); + + const int input_offset = bid_y * SF_TILE_DIM_M_I32 * K_i32 + bid_x * N_TILES_IN_TB; + const int* input_i32 = reinterpret_cast(input) + input_offset; + int* output_i32 = reinterpret_cast(output) + bid_y * SF_TILE_DIM_M_I32 * K_i32 + + bid_x * N_TILES_IN_TB * SF_TILE_SIZE_I32; extern __shared__ int4 slm_v4i[]; @@ -170,8 +202,17 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons if (threadIdx.x * N_TILE_PER_TD < n_tiles_in_tb) { #pragma unroll for (int i = 
0; i < N_SF_PER_TD_PER_TILE; i++) { - regs_vec[i] = __ldg(reinterpret_cast( - input_i32 + (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD)); + const int thread_offset = (i * TB_DIM + threadIdx.y) * K_i32 + threadIdx.x * N_TILE_PER_TD; + regs_vec[i] = __ldg(reinterpret_cast(input_i32 + thread_offset)); + if (padding_m || padding_k) { + // Pad zeros + for (int j = 0; j < N_TILE_PER_TD * sizeof(int); j++) { + const int index = (input_offset + thread_offset) * sizeof(int) + j; + if (index / K >= original_M || index % K >= original_K) { + reinterpret_cast(regs_vec + i)[j] = 0; + } + } + } } // shuffle regs @@ -196,9 +237,99 @@ __global__ void swizzle_row_scaling_kernel(const void* input, void* output, cons } } -} // namespace +template +__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, + const int K, const int original_M, + const int original_K) { + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); +} -namespace transformer_engine { +constexpr int kMaxTensorsPerKernel = 64; // Args must be <4 KB +struct MultiSwizzleArgs { + // (input) Data buffers for input scaling factors + void* input_list[kMaxTensorsPerKernel]; + // (output) Data buffers for swizzled scaling factors + void* output_list[kMaxTensorsPerKernel]; + // Input scaling factor m + int m_list[kMaxTensorsPerKernel]; + // Input scaling factor k + int k_list[kMaxTensorsPerKernel]; + // Input scaling factor m before padding + int original_m_list[kMaxTensorsPerKernel]; + // Input scaling factor k before padding + int original_k_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each + // tensor + int block_range[kMaxTensorsPerKernel + 1]; + // Number of tensors being processed by kernel + int num_tensors; +}; + +template +__global__ void multi_tensor_swizzle_row_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD; + + // Get block index in grid. Emulate 2D grid. 
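+  // block_range holds a prefix sum (with a leading zero) of per-tensor block counts, so
+  // subtracting block_range[tensor_id] gives this tensor's local block id, which is then
+  // split into (bid_x, bid_y) to emulate the 2D grid of the single-tensor kernel.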
+ const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB); + const int grid_dim_y = num_tiles_m; + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_row_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +template +__global__ void multi_tensor_swizzle_col_scaling_kernel(MultiSwizzleArgs kernel_args) { + // Find tensor corresponding to block + const int bid = blockIdx.x; + int tensor_id = 0; + while (kernel_args.block_range[tensor_id + 1] <= bid) { + ++tensor_id; + } + // Get args corresponding to block + const void* input = kernel_args.input_list[tensor_id]; + void* output = kernel_args.output_list[tensor_id]; + const int M = kernel_args.m_list[tensor_id]; + const int K = kernel_args.k_list[tensor_id]; + const int original_M = kernel_args.original_m_list[tensor_id]; + const int original_K = kernel_args.original_k_list[tensor_id]; + + constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int); + constexpr int N_SF_PER_TD = N_TILE_PER_TD * N_SF_PER_TD_PER_TILE; + constexpr int SF_TILE_SIZE_I32 = SF_TILE_DIM_M * SF_TILE_DIM_K / 4; + + // Get block index in grid. Emulate 2D grid. + const int num_tiles_k = K / SF_TILE_DIM_K; + const int num_tiles_m = M / SF_TILE_DIM_M; + const int grid_dim_x = DIVUP(num_tiles_k, TB_DIM); + const int grid_dim_y = DIVUP(num_tiles_m, N_TILE_PER_TD); + const int bid_x = (bid - kernel_args.block_range[tensor_id]) / grid_dim_y; + const int bid_y = (bid - kernel_args.block_range[tensor_id]) % grid_dim_y; + + swizzle_col_scaling_kernel_impl( + input, output, M, K, original_M, original_K, bid_x, bid_y, grid_dim_x, grid_dim_y); +} + +} // namespace void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t stream) { if (!is_fp8_dtype(input->dtype()) || is_delayed_tensor_scaling(input->scaling_mode)) { @@ -252,27 +383,29 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, n_tiles_in_tb), num_tiles_m); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_first_dim(); + const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 2: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 1: cudaFuncSetAttribute(swizzle_row_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_row_scaling_kernel - <<>>(input->scale_inv.dptr, - output->scale_inv.dptr, m, k); + <<>>( + input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; default: NVTE_ERROR("Not valid vec_load_size."); @@ -285,27 +418,32 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s int n_tiles_in_tb = TB_DIM * 
vec_load_size; dim3 num_blocks(DIVUP(num_tiles_k, TB_DIM), DIVUP(num_tiles_m, vec_load_size)); int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + const int original_M = input->flat_last_dim(); + const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 2: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; case 1: cudaFuncSetAttribute(swizzle_col_scaling_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); swizzle_col_scaling_kernel - <<>>( - input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k); + <<>>(input->columnwise_scale_inv.dptr, + output->columnwise_scale_inv.dptr, m, + k, original_M, original_K); break; default: NVTE_ERROR("Not valid vec_load_size."); @@ -317,10 +455,212 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s } else { NVTE_ERROR("Not implemented for scaling_mode " + to_string(input->scaling_mode) + ", trans."); } - cudaError_t err = cudaGetLastError(); - if (err != cudaSuccess) { - printf("CUDA Error: %s\n", cudaGetErrorString(err)); - exit(-1); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, + const int vec_load_size, const bool is_rowwise, + cudaStream_t stream) { + int n_tiles_in_tb = TB_DIM * vec_load_size; + int slm_size = n_tiles_in_tb * SF_TILE_DIM_M * SF_TILE_DIM_K * sizeof(int8_t); + /* Calculate number of CUDA blocks needed for each tensor. + * We have to do it here because we have to iterate over all tensors in this batch to + * get the minimum vec_load_size. 
+ */ + for (size_t j = 0; j < kernel_args.num_tensors; j++) { + const int m = kernel_args.m_list[j]; + const int k = kernel_args.k_list[j]; + int num_tiles_m = m / SF_TILE_DIM_M; + int num_tiles_k = k / SF_TILE_DIM_K; + if (is_rowwise) { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + DIVUP(num_tiles_k, n_tiles_in_tb) * num_tiles_m; + } else { + kernel_args.block_range[j + 1] = + kernel_args.block_range[j] + + DIVUP(num_tiles_k, TB_DIM) * DIVUP(num_tiles_m, vec_load_size); + } + } + // Launch kernel + const int num_blocks = kernel_args.block_range[kernel_args.num_tensors]; + dim3 block_size(TB_DIM, TB_DIM); + if (is_rowwise) { + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 2: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + case 1: + cudaFuncSetAttribute( + multi_tensor_swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_row_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } else { + switch (vec_load_size) { + case 4: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 2: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + case 1: + cudaFuncSetAttribute( + multi_tensor_swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + multi_tensor_swizzle_col_scaling_kernel + <<>>(kernel_args); + break; + default: + NVTE_ERROR("Not valid vec_load_size."); + break; + } + } + NVTE_CHECK_CUDA(cudaGetLastError()); +} +void multi_tensor_swizzle_scaling_factors(const std::vector& input, + std::vector& output, cudaStream_t stream) { + auto num_tensors = input.size(); + bool all_has_data = true; + bool all_has_columnwise_data = true; + for (size_t i = 0; i < num_tensors; i++) { + if (!is_fp8_dtype(input[i]->dtype()) || !is_mxfp_scaling(input[i]->scaling_mode)) { + NVTE_ERROR("Not implemented caling mode " + to_string(input[i]->scaling_mode) + "."); + } + // We don't allow empty tensors. They should be filtered out before calling this function. 
+ if (input[i]->data.numel() == 0) { + NVTE_ERROR("Tensor input[" + std::to_string(i) + "] is empty."); + } + CheckInputTensor(*input[i], "scaling_factor_input[" + std::to_string(i) + "]"); + CheckInputTensor(*output[i], "scaling_factor_output[" + std::to_string(i) + "]"); + all_has_data &= input[i]->has_data(); + all_has_columnwise_data &= input[i]->has_columnwise_data(); + } + NVTE_CHECK(all_has_data || all_has_columnwise_data, + "All tensors should have data or columnwise data."); + + constexpr int SF_TILE_DIM_M = 128; + constexpr int SF_TILE_DIM_K = 4; + if (all_has_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->scale_inv.shape[0]; + const int k = input[i]->scale_inv.shape[1]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK( + m * k == std::accumulate(output[i]->scale_inv.shape.begin(), + output[i]->scale_inv.shape.end(), 1, std::multiplies()), + "Input.scale_inv size is not equal to Output.scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_first_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_last_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, true, stream); + } + + if (all_has_columnwise_data) { + MultiSwizzleArgs kernel_args; + kernel_args.num_tensors = 0; + kernel_args.block_range[0] = 0; + int vec_load_size = 4; + for (size_t i = 0; i < num_tensors; i++) { + //Launch kernel if argument struct is full + if (kernel_args.num_tensors == kMaxTensorsPerKernel) { + // There is no int3 and misaligned if using int4/int2. 
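+    // vec_load_size == 3 cannot be used directly: there is no int3 vector type and int4/int2
+    // loads would be misaligned, so fall back to scalar int loads.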
+ if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); + // Reset the argument struct and vec_load_size + kernel_args.num_tensors = 0; + vec_load_size = 4; + } + const int m = input[i]->columnwise_scale_inv.shape[1]; + const int k = input[i]->columnwise_scale_inv.shape[0]; + + NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!"); + NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!"); + NVTE_CHECK(k > 0, "Input scale inverse should be 2D!"); + NVTE_CHECK(m * k == std::accumulate(output[i]->columnwise_scale_inv.shape.begin(), + output[i]->columnwise_scale_inv.shape.end(), 1, + std::multiplies()), + "Input.columnwise_scale_inv size is not equal to " + "Output.columnwise_scale_inv size!"); + + int num_tiles_k = k / SF_TILE_DIM_K; + int vec_load_size_i = (num_tiles_k - 1) % 4 + 1; + // We use the minimum vec_load_size across all tensors. + vec_load_size = std::min(vec_load_size, vec_load_size_i); + + const int pos = kernel_args.num_tensors; + kernel_args.input_list[pos] = const_cast(input[i]->columnwise_scale_inv.dptr); + kernel_args.output_list[pos] = output[i]->columnwise_scale_inv.dptr; + kernel_args.m_list[pos] = m; + kernel_args.k_list[pos] = k; + kernel_args.original_m_list[pos] = input[i]->flat_last_dim(); + kernel_args.original_k_list[pos] = input[i]->flat_first_dim() / MXFP8_BLOCK_SIZE; + kernel_args.num_tensors++; + } + // Launch the remaining tensors + // There is no int3 and misaligned if using int4/int2. + if (vec_load_size == 3) vec_load_size = 1; + launch_multi_tensor_swizzle_scaling_factors( + kernel_args, vec_load_size, false, stream); } } } // namespace transformer_engine @@ -335,3 +675,16 @@ void nvte_swizzle_scaling_factors(const NVTETensor input, NVTETensor output, cud using namespace transformer_engine; swizzle_scaling_factors(convertNVTETensorCheck(input), convertNVTETensorCheck(output), stream); } + +void nvte_multi_tensor_swizzle_scaling_factors(const NVTETensor* inputs, NVTETensor* outputs, + const size_t num_tensors, cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_tensor_swizzle_scaling_factors); + using namespace transformer_engine; + NVTE_CHECK(num_tensors > 0, "Number of tensors should be greater than 0."); + std::vector input_list, output_list; + for (size_t i = 0; i < num_tensors; i++) { + input_list.push_back(convertNVTETensorCheck(inputs[i])); + output_list.push_back(convertNVTETensorCheck(outputs[i])); + } + multi_tensor_swizzle_scaling_factors(input_list, output_list, stream); +} diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index a1899d5b1..ad6cf2a2e 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -35,6 +35,7 @@ struct MultiPaddingArgs { int padded_num_rows_list[kMaxTensorsPerKernel]; // Input matrix widths int row_length_list[kMaxTensorsPerKernel]; + // Prefix sum (with leading zero) of CUDA blocks needed for each // tensor int block_range[kMaxTensorsPerKernel + 1]; // Number of tensors being processed by kernel diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5408cf1a6..fe7aecbc2 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -398,11 +398,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero padding 
is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { @@ -441,11 +438,8 @@ std::tuple, std::vector> bulk_allocate_mx } // Allocate full buffer - // TODO(zhongbo): use torch.empty if zero padding is added to the swizzle kernel auto buffer = std::make_shared( - at::zeros({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); - // auto buffer = std::make_shared( - // at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); + at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8))); // Construct tensor views for (size_t i = 0; i < num_tensors; ++i) { diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 99bb4e69f..4f1ab3e56 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -326,10 +326,8 @@ std::optional> te_general_grouped_gemm( size_t workspaceSize, bool accumulate, bool use_split_accumulator, int math_sm_count) { std::vector te_A_vector, te_B_vector, te_D_vector, te_bias_vector, te_pre_gelu_out_vector, te_workspace_vector; - std::vector wrappers; + std::vector te_A_wrappers, te_B_wrappers, wrappers; std::vector D_vectors; - // Keep the swizzled scaling factor tensors alive during the GEMMs. - std::vector> swizzled_scale_inverses_list; auto none = py::none(); @@ -396,10 +394,6 @@ std::optional> te_general_grouped_gemm( continue; } - // Optionally swizzle the scaling factors - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_A, transa))); - swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(te_B, !transb))); - auto te_D = makeTransformerEngineTensor(out_tensor); auto te_bias = makeTransformerEngineTensor(bias[i]); auto te_pre_gelu_out = makeTransformerEngineTensor(pre_gelu_out[i]); @@ -419,18 +413,25 @@ std::optional> te_general_grouped_gemm( te_bias_vector.emplace_back(te_bias.data()); te_pre_gelu_out_vector.emplace_back(te_pre_gelu_out.data()); - wrappers.emplace_back(std::move(te_A)); - wrappers.emplace_back(std::move(te_B)); + te_A_wrappers.emplace_back(std::move(te_A)); + te_B_wrappers.emplace_back(std::move(te_B)); wrappers.emplace_back(std::move(te_D)); wrappers.emplace_back(std::move(te_bias)); wrappers.emplace_back(std::move(te_pre_gelu_out)); } + + // Optionally swizzle the scaling factors + // Keep the swizzled scaling factor tensors alive during the GEMMs. + auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); + auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); + for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), std::vector{workspaceSize}, DType::kByte); te_workspace_vector.emplace_back(wsp.data()); wrappers.emplace_back(std::move(wsp)); } + // For now, we only have multi-stream cublas backend. 
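+  // The multi-tensor swizzle above batches the scaling-factor preparation for the whole group
+  // into a few kernel launches instead of separate swizzle launches for each GEMM's operands.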
NVTE_SCOPED_GIL_RELEASE({ nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index f0e0aba00..fc5f99dcb 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -841,13 +841,13 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve const std::vector scale_inv_shape_int64(rowwise_scale_inv_shape.begin(), rowwise_scale_inv_shape.end()); rowwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - rowwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + rowwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } if (columnwise_usage) { const std::vector scale_inv_shape_int64(columnwise_scale_inv_shape.begin(), columnwise_scale_inv_shape.end()); columnwise_data_tensor = at::empty(shape_int64, uint8_tensor_opts); - columnwise_scale_inv_tensor = at::zeros(scale_inv_shape_int64, uint8_tensor_opts); + columnwise_scale_inv_tensor = at::empty(scale_inv_shape_int64, uint8_tensor_opts); } // Convert tensors to Python @@ -939,7 +939,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - rowwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + rowwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; } } else { // rowwise_usage == false @@ -966,7 +966,7 @@ std::pair MXFP8Quantizer::convert_and_update_tensor( const std::vector scale_inv_shape_int64(scale_inv_shape.begin(), scale_inv_shape.end()); const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); - columnwise_scale_inv = at::zeros(scale_inv_shape_int64, opts); + columnwise_scale_inv = at::empty(scale_inv_shape_int64, opts); tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; } } else { // columnwise_usage == false diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index a878345ff..92f2d3a50 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -75,3 +75,98 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap return swizzled_scale_inv; } + +std::optional multi_tensor_swizzle_scaling_factors( + std::vector& tensors, bool rowwise) { + using namespace transformer_engine::pytorch; + + if (tensors.empty()) { + return std::nullopt; + } + + bool all_same_scaling_mode = std::all_of( + tensors.cbegin(), tensors.cend(), [&tensors](const transformer_engine::TensorWrapper& val) { + return val.scaling_mode() == tensors.front().scaling_mode(); + }); + NVTE_CHECK(all_same_scaling_mode, "Scaling mode of the input tensors must be the same."); + + if (tensors.front().scaling_mode() == NVTE_INVALID_SCALING) { + NVTE_ERROR("Invalid scaling mode for swizzle."); + } else if (tensors.front().scaling_mode() != NVTE_MXFP8_1D_SCALING) { + return std::nullopt; + } + + std::vector wrappers; + std::vector input_tensors, output_tensors; + + // Collect scale_inv shapes and calculate buffer size and offsets for scale_invs + std::vector> scale_inv_shapes; + std::vector scale_inv_dptrs; + size_t buffer_size = 0; + std::vector scale_inv_offsets; + constexpr size_t scale_elem_size = 1; + for (auto& tensor : tensors) { + NVTEBasicTensor scale_inv; + if 
(rowwise) { + scale_inv = tensor.get_rowwise_scale_inv(); + } else { + scale_inv = tensor.get_columnwise_scale_inv(); + } + auto scale_inv_shape = nvte_shape_to_vector(scale_inv.shape); + buffer_size = roundup(buffer_size, 16); // align to 16B + scale_inv_offsets.push_back(buffer_size); + buffer_size += product(scale_inv_shape) * scale_elem_size; + scale_inv_shapes.emplace_back(scale_inv_shape); + scale_inv_dptrs.push_back(scale_inv.data_ptr); + } + + // Allocate full buffer + auto buffer = at::empty({(int64_t)buffer_size}, at::device(at::kCUDA).dtype(torch::kUInt8)); + + for (size_t i = 0; i < tensors.size(); ++i) { + auto& tensor = tensors[i]; + void* scale_inv_dptr = scale_inv_dptrs[i]; + void* swizzled_scale_inv_dptr = getDataPtr(buffer, scale_inv_offsets[i]); + auto input_shape = nvte_shape_to_vector(tensor.shape()); + + // Reconstruct input only to avoid swizzling both directions if not needed. + // Use any 8 bit type, it's irrelevant. + transformer_engine::TensorWrapper input_cu(NVTE_MXFP8_1D_SCALING); + transformer_engine::TensorWrapper output_cu(NVTE_MXFP8_1D_SCALING); + if (rowwise) { + input_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, input_shape); + input_cu.set_rowwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_rowwise_data(tensor.dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + output_cu.set_rowwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. + tensor.set_rowwise_scale_inv(swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + } else { + input_cu.set_columnwise_data(tensor.columnwise_dptr(), transformer_engine::DType::kFloat8E4M3, + input_shape); + input_cu.set_columnwise_scale_inv(scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, + scale_inv_shapes[i]); + output_cu.set_columnwise_data(tensor.columnwise_dptr(), + transformer_engine::DType::kFloat8E4M3, input_shape); + output_cu.set_columnwise_scale_inv( + swizzled_scale_inv_dptr, transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + // Set the swizzled scaling factor to the original tensor. + tensor.set_columnwise_scale_inv(swizzled_scale_inv_dptr, + transformer_engine::DType::kFloat8E8M0, scale_inv_shapes[i]); + } + + input_tensors.emplace_back(input_cu.data()); + output_tensors.emplace_back(output_cu.data()); + wrappers.emplace_back(std::move(input_cu)); + wrappers.emplace_back(std::move(output_cu)); + } + + // Launch kernel + nvte_multi_tensor_swizzle_scaling_factors(input_tensors.data(), output_tensors.data(), + input_tensors.size(), at::cuda::getCurrentCUDAStream()); + + return buffer; +} diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 0cfeb81f5..4b2686096 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -13,11 +13,18 @@ #include "transformer_engine/transformer_engine.h" -/* Swizzle the scaling factor of the input tensor. +/*! \brief Swizzle the scaling factor of the input tensor. * * The returned swizzled scaling factor tensor should be kept alive during the GEMM. */ std::optional swizzle_scaling_factors(transformer_engine::TensorWrapper &input, - bool trans); + bool rowwise); + +/*! \brief Swizzle the scaling factor of the input tensors. + * + * The returned swizzled scaling factor tensors should be kept alive during the GEMMs. 
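+ * Returns std::nullopt if the tensor list is empty or the tensors do not use MXFP8 1D
+ * scaling, in which case no swizzling is needed.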
+ */ +std::optional multi_tensor_swizzle_scaling_factors( + std::vector &inputs, bool rowwise); #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ From de69ca0e7e6a2c2f045f30b23fb47b8f11fca8d6 Mon Sep 17 00:00:00 2001 From: hx Date: Thu, 7 Aug 2025 05:18:03 +0800 Subject: [PATCH 049/153] [PyTorch] fix input_quantizer usage for save_original_input; fix blockwise FP8 convert_and_update_tensor (#1978) * fix input_quantizer in save_original_input bwd Signed-off-by: Hongxiao Bai * fix get shape of blockwise tensor with only compact colwise data Signed-off-by: Hongxiao Bai * fix blockwise FP8 convert_and_update_tensor Signed-off-by: Hongxiao Bai * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongxiao Bai Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_float8blockwisetensor.py | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 129 +++++++++++++++++- transformer_engine/pytorch/module/linear.py | 11 +- 3 files changed, 129 insertions(+), 13 deletions(-) diff --git a/tests/pytorch/test_float8blockwisetensor.py b/tests/pytorch/test_float8blockwisetensor.py index 1f23be362..39062b442 100644 --- a/tests/pytorch/test_float8blockwisetensor.py +++ b/tests/pytorch/test_float8blockwisetensor.py @@ -219,7 +219,7 @@ def test_quantize_dequantize_compact_format( rowwise=True, columnwise=dq_columnwise, block_scaling_dim=block_scaling_dim, - all_gather_usage=True, + all_gather_usage=(block_scaling_dim == 1), ) self._test_quantize_dequantize( quantizer=quantizer, diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index fc5f99dcb..0c75789ed 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -671,13 +671,128 @@ std::pair Float8BlockQuantizer::convert_and_update_te const DType dtype = tensor.attr("_fp8_dtype").cast(); bool is_2D_scaled = tensor.attr("_is_2D_scaled").cast(); - // Check the data matches quantizer usages - NVTE_CHECK(!tensor.attr("_rowwise_data").is_none() == rowwise_usage, - "Float8BlockwiseQTensor does not match quantizer usages (has_rowwise_data=", - !tensor.attr("_rowwise_data").is_none(), ", rowwise_usage=", rowwise_usage); - NVTE_CHECK(!tensor.attr("_columnwise_data").is_none() == columnwise_usage, - "Float8BlockwiseQTensor does not match quantizer usages (has_columnwise_data=", - !tensor.attr("_columnwise_data").is_none(), ", columnwise_usage=", columnwise_usage); + // Extract buffers from Python tensor + auto get_tensor = [&tensor](const char* name) -> std::optional { + auto attr_py = tensor.attr(name); + if (attr_py.is_none()) { + return std::nullopt; + } + return attr_py.cast(); + }; + auto rowwise_data = get_tensor("_rowwise_data"); + auto rowwise_scale_inv = get_tensor("_rowwise_scale_inv"); + auto columnwise_data = get_tensor("_columnwise_data"); + auto columnwise_scale_inv = get_tensor("_columnwise_scale_inv"); + NVTE_CHECK(rowwise_data || columnwise_data, "FP8BlockwiseTensor has no data."); + + // Tensor options and dimensions + at::TensorOptions opts; + at::TensorOptions scale_opts; + opts = opts.dtype(torch::kUInt8).device(torch::kCUDA); + scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA); + + auto get_columnwise_shape = [&columnwise_data](bool all_gather_usage) -> std::vector { + if 
(!columnwise_data) { + return std::vector(); + } + if (all_gather_usage) { + return getTensorShape(*columnwise_data); + } + std::vector shape = getTensorShape(*columnwise_data); + std::vector shape_transposed(shape.size()); + for (size_t i = 0; i + 1 < shape.size(); ++i) { + shape_transposed[i] = shape[i + 1]; + } + if (shape.size() > 0) { + shape_transposed[shape.size() - 1] = shape[0]; + } + return shape_transposed; + }; + std::vector shape; + if (rowwise_data) { + shape = getTensorShape(*rowwise_data); + if (columnwise_data) { + auto expected_shape = get_columnwise_shape(all_gather_usage); + NVTE_CHECK(shape == expected_shape, "BlockwiseFP8 row-wise data (shape=", shape, + ") and column-wise data (shape=", expected_shape, ") do not match"); + } + } else { + shape = get_columnwise_shape(all_gather_usage); + } + std::vector torch_shape; + for (auto s : shape) { + torch_shape.emplace_back(static_cast(s)); + } + + // Coerce row-wise data + if (rowwise_usage) { + if (!rowwise_data) { + rowwise_data = at::empty(torch_shape, opts); + tensor.attr("_rowwise_data") = *rowwise_data; + } + if (!rowwise_scale_inv) { + auto scale_shape = get_scale_shape(shape, false); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; + rowwise_scale_inv = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); + tensor.attr("_rowwise_scale_inv") = *rowwise_scale_inv; + } + } else { // rowwise_usage == false + if (rowwise_data) { + rowwise_data.reset(); + tensor.attr("_rowwise_data") = py::none(); + } + if (rowwise_scale_inv) { + rowwise_scale_inv.reset(); + tensor.attr("_rowwise_scale_inv") = py::none(); + } + } + + // Coerce column-wise data + if (columnwise_usage) { + std::vector columnwise_shape; + std::vector torch_columnwise_shape; + if (torch_shape.size() > 0) { + if (!all_gather_usage) { + torch_columnwise_shape.reserve(torch_shape.size()); + columnwise_shape.reserve(shape.size()); + torch_columnwise_shape.push_back(torch_shape[torch_shape.size() - 1]); + columnwise_shape.push_back(shape[shape.size() - 1]); + for (size_t i = 0; i < torch_shape.size() - 1; ++i) { + torch_columnwise_shape.push_back(torch_shape[i]); + columnwise_shape.push_back(shape[i]); + } + } else { + // assert we are doing 1D scaling + NVTE_CHECK(block_scaling_dim == 1, + "Compact columnwise format is not supported for 128x128 2D block scaling."); + torch_columnwise_shape = torch_shape; + columnwise_shape = shape; + } + } + if (!columnwise_data) { + columnwise_data = at::empty(torch_columnwise_shape, opts); + tensor.attr("_columnwise_data") = *columnwise_data; + } + if (!columnwise_scale_inv) { + auto scale_shape = get_scale_shape(shape, true); + size_t sinv0 = scale_shape[0]; + size_t sinv1 = scale_shape[1]; + columnwise_scale_inv = + at::empty({static_cast(sinv0), static_cast(sinv1)}, scale_opts); + tensor.attr("_columnwise_scale_inv") = *columnwise_scale_inv; + } + } else { // columnwise_usage == false + if (columnwise_data) { + columnwise_data.reset(); + tensor.attr("_columnwise_data") = py::none(); + } + if (columnwise_scale_inv) { + columnwise_scale_inv.reset(); + tensor.attr("_columnwise_scale_inv") = py::none(); + } + } auto ret = TensorWrapper(is_2D_scaled ? 
NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D); diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index f2a6871a8..a5dae9f30 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -589,13 +589,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: # Quantize input tensor quantizer = ctx.input_quantizer - if ctx.backward_input_needs_gather and isinstance( - quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ): + if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): # All-gather is not supported with FP8 column-wise data - quantizer.set_usage(rowwise=True, columnwise=False) + quantizer.set_usage( + rowwise=True, + columnwise=not ctx.backward_input_needs_gather, + ) else: - quantizer.set_usage(rowwise=True, columnwise=True) + quantizer.set_usage(rowwise=False, columnwise=True) inputmat = quantizer(inputmat) else: if isinstance(inputmat, QuantizedTensorBase): From c5ee5fd01ba15d84f4b8ed2d8161ee339c20714e Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 6 Aug 2025 17:23:55 -0400 Subject: [PATCH 050/153] Revert "[JAX] Disable TE Norm Custom Calls" (#2035) Revert "[JAX] Disable TE Norm Custom Calls (#1993)" This reverts commit 6c970612715e2a493a2468256c05ce40a11e8556. --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 0d19785a0..fcc2108cc 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -34,7 +34,7 @@ class BasePrimitive(metaclass=ABCMeta): _is_enabled = True # Default list of primitives to disable for all recipes - _default_disable_names = ["GemmPrimitive", "NormFwdPrimitive", "NormBwdPrimitive"] + _default_disable_names = ["GemmPrimitive"] @classmethod def enabled(cls): From bfab8c679f17bed5b63ae5c904c205f164beaae4 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 7 Aug 2025 08:08:28 +0800 Subject: [PATCH 051/153] [Common] PDL for Quantization Kernels (#2001) * PDL for MXFP8 Quantize Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/util/cast_kernels.cuh | 54 +++++++++++-------- 1 file changed, 33 insertions(+), 21 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index fcf0a4084..5590cee10 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -203,6 +203,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], parity); + // Trigger the next kernel, so its TMA load can be overlapped with the current kernel + if (stage == STAGES - 1) { + cudaTriggerProgrammaticLaunchCompletion(); + } + float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; @@ -1121,6 +1126,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; + cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, 
stream, NULL, 0}; + // This kernel will only be called on sm100+, so no need to check sm_arch + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; cfg.attrs = attribute; + cfg.numAttrs = 1; + switch (scaling_type) { case ScalingType::ROWWISE: cudaFuncSetAttribute( @@ -1128,13 +1140,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + cudaLaunchKernelEx( + &cfg, + cast_mxfp8_2D_kernel, + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); break; case ScalingType::COLWISE: cudaFuncSetAttribute( @@ -1142,13 +1154,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + cudaLaunchKernelEx( + &cfg, + cast_mxfp8_2D_kernel, + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); break; case ScalingType::BIDIMENSIONAL: cudaFuncSetAttribute( @@ -1156,13 +1168,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cast_mxfp8_2D_kernel - <<>>( - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, - scale_stride_colwise); + cudaLaunchKernelEx( + &cfg, + cast_mxfp8_2D_kernel, + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); break; } From dd083bdfca296294322f2b5f8143ce54ee15db22 Mon Sep 17 00:00:00 2001 From: ldl <140483453+lvdunlin@users.noreply.github.com> Date: Thu, 7 Aug 2025 08:14:58 +0800 Subject: [PATCH 052/153] [PyTorch] Fix numeric overflow caused by int-type parameters and return value in the roundup function (#2034) Signed-off-by: lvdunlin Co-authored-by: lvdunlin Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/csrc/common.cpp | 2 +- transformer_engine/pytorch/csrc/common.h | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/common.cpp b/transformer_engine/pytorch/csrc/common.cpp index ab3b7abec..dffb899f7 100644 --- a/transformer_engine/pytorch/csrc/common.cpp +++ 
b/transformer_engine/pytorch/csrc/common.cpp @@ -286,7 +286,7 @@ std::vector convertShape(const NVTEShape& shape) { return std::vector(shape.data, shape.data + shape.ndim); } -int roundup(const int value, const int multiple) { +size_t roundup(const size_t value, const size_t multiple) { assert(multiple > 0); return ((value + multiple - 1) / multiple) * multiple; } diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index 45e3291ef..2d35de852 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -417,7 +417,7 @@ void* getDataPtr(at::Tensor tensor, int offset = 0); std::vector convertShape(const NVTEShape& shape); -int roundup(const int value, const int multiple); +size_t roundup(const size_t value, const size_t multiple); NVTEShape convertTorchShape(const c10::IntArrayRef torch_shape); } // namespace transformer_engine::pytorch From cae1c436027cc028ee49c83463141ba67f0adca0 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Thu, 7 Aug 2025 15:27:23 -0400 Subject: [PATCH 053/153] [JAX] TE Gemm custom call clean up (#2030) * rm batch_dim, sequence_dim, sequence_parallel_output Signed-off-by: Phuong Nguyen * rm lhs_quantized_colwise and rhs_quantized_colwise Signed-off-by: Phuong Nguyen * rm unnecessary transpose_batch_sequence arg from some modules Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- tests/jax/test_distributed_layernorm_mlp.py | 2 - transformer_engine/jax/cpp_extensions/gemm.py | 199 +++--------------- transformer_engine/jax/dense.py | 73 +------ transformer_engine/jax/flax/module.py | 48 +---- transformer_engine/jax/flax/transformer.py | 9 - transformer_engine/jax/layernorm_dense.py | 45 +--- transformer_engine/jax/layernorm_mlp.py | 52 +---- transformer_engine/jax/sharding.py | 24 --- 8 files changed, 41 insertions(+), 411 deletions(-) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 015b37603..79186aa47 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -333,7 +333,6 @@ def _test_layernorm_mlp( with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, - transpose_batch_sequence=False, # input: [batch, seqlen, hidden] intermediate_dim=INTERMEDIATE, activations=activation_type, use_bias=use_bias, @@ -352,7 +351,6 @@ def _test_layernorm_mlp( ): ln_mlp_sharded = LayerNormMLP( layernorm_type=layernorm_type, - transpose_batch_sequence=False, intermediate_dim=INTERMEDIATE, activations=activation_type, scale_axes=LN_SCALE_AXES, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 897baaa26..5c2438906 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -155,7 +155,7 @@ class GemmPrimitive(BasePrimitive): name = "te_gemm_ffi" multiple_results = True - impl_static_args = (6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17) + impl_static_args = (6, 7, 8, 9, 10, 11, 12) inner_primitive = None outer_primitive = None @@ -169,22 +169,13 @@ def abstract( gelu_input, out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, - sequence_parallel_output, - sequence_dim, ): - del lhs_quantized_colwise, rhs_quantized_colwise, use_split_accumulator - del ( - sequence_parallel_output, - 
sequence_dim, - ) + del use_split_accumulator def _dims_are_consecutive(dims): if len(dims) <= 1: @@ -207,27 +198,6 @@ def _dims_are_consecutive(dims): f"{rhs_contracting_dims}." ) - ( - lhs_batch_dims, - rhs_batch_dims, - ) = map(sanitize_dims, operand_ndims, batched_dims) - assert _dims_are_consecutive(lhs_batch_dims), ( - "cuBLAS GEMM expected consecutive batch dimensions for LHS operand, but got " - f"{lhs_batch_dims}." - ) - assert _dims_are_consecutive(rhs_batch_dims), ( - "cuBLAS GEMM expected consecutive batch dimensions for RHS operand, but got " - f"{rhs_batch_dims}." - ) - if len(lhs_batch_dims) == 0: - assert ( - len(rhs_batch_dims) == 0 - ), "cuBLAS GEMM RHS operand cannot be batched if LHS operand is not batched." - elif len(rhs_batch_dims) != 0: - assert all(bdim in lhs_contracting_dims for bdim in lhs_batch_dims) and all( - bdim in rhs_contracting_dims for bdim in rhs_batch_dims - ), "cuBLAS GEMM batched dimensions must be contracting when both operands are batched." - lhs_contracting_size, rhs_contracting_size = map( lambda shape, dims: reduce(operator.mul, [shape[dim] for dim in dims]), (lhs.shape, rhs.shape), @@ -341,19 +311,13 @@ def lowering( gelu_input, out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, - sequence_parallel_output, - sequence_dim, ): - del batched_dims, lhs_quantized_colwise, rhs_quantized_colwise, out_dtype - del sequence_parallel_output, sequence_dim + del out_dtype lhs_aval, _, rhs_aval, *_ = ctx.avals_in lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_aval.ndim, rhs_aval.ndim), contracting_dims) @@ -395,16 +359,11 @@ def impl( gelu_input, out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, - sequence_parallel_output, - sequence_dim, ): lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) lhs_transposed, rhs_transposed = _get_gemm_layout( @@ -414,14 +373,14 @@ def impl( lhs_scale_inv, scaling_mode, lhs.shape, - is_colwise=lhs_quantized_colwise, + is_colwise=lhs_transposed, flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), ) rhs_scale_inv = apply_padding_to_scale_inv( rhs_scale_inv, scaling_mode, rhs.shape, - is_colwise=rhs_quantized_colwise, + is_colwise=not rhs_transposed, flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, ) @@ -434,55 +393,34 @@ def impl( gelu_input, out_dtype=out_dtype, contracting_dims=contracting_dims, - batched_dims=batched_dims, - lhs_quantized_colwise=lhs_quantized_colwise, - rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, - sequence_parallel_output=sequence_parallel_output, - sequence_dim=sequence_dim, ) return outputs[:-3] # discard workspace arrays @staticmethod def batcher( batched_args, - jax_batch_dims, + batch_dims, out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, - sequence_parallel_output, - sequence_dim, ): assert GemmPrimitive.outer_primitive is not None - lhs, _, rhs, *_ = batched_args - lhs_bdims, _, rhs_bdims, *_ = jax_batch_dims - arg_lhs_bdims, arg_rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) - arg_lhs_bdims = (None,) if len(arg_lhs_bdims) == 0 else arg_lhs_bdims - assert 
all(bdim == arg_bdim for bdim, arg_bdim in zip(lhs_bdims, arg_lhs_bdims)), ( - "User-specified batch dimension(s) for cuBLAS GEMM LHS operand does not match batch " - f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." - ) - arg_rhs_bdims = (None,) if len(arg_rhs_bdims) == 0 else arg_rhs_bdims - assert all(bdim == arg_bdim for bdim, arg_bdim in zip(rhs_bdims, arg_rhs_bdims)), ( - "User-specified batch dimension(s) for cuBLAS GEMM RHS operand does not match batch " - f"dimensions inferred by JAX/XLA, expected {lhs_bdims} but got {arg_lhs_bdims}." - ) + lhs_bdims, _, rhs_bdims, *_ = batch_dims - # Output is batched like the non-contracting batch dimensions of the LHS operand - lhs_cdims = sanitize_dims(lhs.ndim, contracting_dims) - lhs_non_contracting_bdims = tuple(dim for dim in lhs_bdims if dim not in lhs_cdims) - out_bdims = (None,) if len(lhs_non_contracting_bdims) == 0 else lhs_non_contracting_bdims + # Batched GEMM is not supported + assert ( + lhs_bdims is None and rhs_bdims is None + ), f"(Batching is not supported, got lhs_bdims={lhs_bdims}, rhs_bdims={rhs_bdims})" + out_bdims = (None,) # Bias gradient is never batched bias_bdims = (None,) @@ -497,16 +435,11 @@ def batcher( *batched_args, out_dtype=out_dtype, contracting_dims=contracting_dims, - batched_dims=batched_dims, - lhs_quantized_colwise=lhs_quantized_colwise, - rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, - sequence_parallel_output=sequence_parallel_output, - sequence_dim=sequence_dim, ), (out_bdims, bias_bdims, pre_gelu_bdims), ) @@ -515,11 +448,7 @@ def batcher( def _parse_operand_output_specs( arg_infos, contracting_dims, - batched_dims, - sequence_parallel_output, - sequence_dim, ): - del sequence_dim, sequence_parallel_output, batched_dims lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) @@ -586,44 +515,30 @@ def _parse_operand_output_specs( (lhs_specs, rhs_specs, bias_specs, gelu_specs), (out_specs, bias_specs, gelu_specs), reduce_spec, - 0, ) @staticmethod def infer_sharding_from_operands( out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, - sequence_parallel_output, - sequence_dim, mesh, arg_infos, result_infos, ): del ( out_dtype, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, grad, ) del use_split_accumulator, result_infos - (_, (out_specs, dbias_specs, pre_gelu_specs), *_) = ( - GemmPrimitive._parse_operand_output_specs( - arg_infos, - contracting_dims, - batched_dims, - sequence_parallel_output, - sequence_dim, - ) + (_, (out_specs, dbias_specs, pre_gelu_specs), _) = ( + GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) ) out_sharding = NamedSharding(mesh, PartitionSpec(*out_specs)) @@ -643,16 +558,11 @@ def infer_sharding_from_operands( def partition( out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, - sequence_parallel_output, - sequence_dim, mesh, arg_infos, result_infos, @@ -663,14 +573,7 @@ def partition( (lhs_specs, rhs_specs, bias_input_specs, gelu_input_specs), (out_specs, dbias_specs, pre_gelu_specs), reduce_spec, - _, - ) = GemmPrimitive._parse_operand_output_specs( - arg_infos, - contracting_dims, - batched_dims, - 
sequence_parallel_output, - sequence_dim, - ) + ) = GemmPrimitive._parse_operand_output_specs(arg_infos, contracting_dims) # Assemble argument shardings # NOTE: Block scale inverses match their operands, but tensor scale inverses are unsharded. @@ -717,19 +620,14 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): gelu_input, out_dtype=out_dtype, contracting_dims=contracting_dims, - batched_dims=batched_dims, - lhs_quantized_colwise=lhs_quantized_colwise, - rhs_quantized_colwise=rhs_quantized_colwise, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, - sequence_parallel_output=sequence_parallel_output, - sequence_dim=sequence_dim, ) - # All-Reduce/Reduce-Scatter GEMM output + # All-Reduce GEMM output if reduce_spec is not None: outputs[0] = jax.lax.psum(outputs[0], reduce_spec) @@ -741,54 +639,42 @@ def _sharded_impl(lhs, lhs_scale_inv, rhs, rhs_scale_inv, bias, gelu_input): def shardy_sharding_rule( out_dtype, contracting_dims, - batched_dims, - lhs_quantized_colwise, - rhs_quantized_colwise, scaling_mode, fuse_bias, fuse_gelu, grad, use_split_accumulator, - sequence_parallel_output, - sequence_dim, mesh, operand_types, result_types, ): - del lhs_quantized_colwise, rhs_quantized_colwise, out_dtype, grad, use_split_accumulator - del sequence_parallel_output, sequence_dim, mesh, result_types + del out_dtype, grad, use_split_accumulator + del mesh, result_types prefix = "GemmPrimitive_" - def _generate_operand_rules(name, ndim, cdims, bdims): + def _generate_operand_rules(name, ndim, cdims): specs = [] - ldims = tuple(i for i in range(ndim) if i not in bdims + cdims) + ldims = tuple(i for i in range(ndim) if i not in cdims) for i in range(ndim): dim_name = None - if i in bdims: - dim_idx = bdims.index(i) if len(bdims) > 1 else "" - dim_name = f"b{dim_idx}" - elif i in cdims: - dim_idx = cdims.index(i) if len(cdims) > 1 else "" + if i in cdims: + dim_idx = cdims.index(i) dim_name = f"k{dim_idx}" else: - dim_idx = ldims.index(i) if len(ldims) > 1 else "" + dim_idx = ldims.index(i) dim_name = f"{name}_l{dim_idx}" specs.append(prefix + dim_name) return specs lhs, _, rhs, *_ = operand_types operand_ndims = (len(lhs.shape), len(rhs.shape)) - (lhs_cdims, rhs_cdims), (lhs_bdims, rhs_bdims) = map( - lambda dims: map(sanitize_dims, operand_ndims, dims), - (contracting_dims, batched_dims), - ) + (lhs_cdims, rhs_cdims) = map(sanitize_dims, operand_ndims, contracting_dims) lhs_specs, rhs_specs = map( _generate_operand_rules, ("lhs", "rhs"), operand_ndims, (lhs_cdims, rhs_cdims), - (lhs_bdims, rhs_bdims), ) lhs_scale_specs = ("…1",) rhs_scale_specs = ("…2",) @@ -840,13 +726,10 @@ def _te_gemm( lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, - sequence_parallel_output: bool = False, - sequence_dim: int = None, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -857,7 +740,6 @@ def _te_gemm( scaling_mode = ScalingMode.NO_SCALING lhs_is_transposed, rhs_is_transposed = _get_gemm_layout((lhs.ndim, rhs.ndim), contracting_dims) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) - lhs_bdims, rhs_bdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), batched_dims) # Quantize operands (if 
necessary) lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) @@ -876,7 +758,6 @@ def _te_gemm( lhs_scale_inv = lhs_q.scale_inv if lhs_q.data_layout == "T": lhs_cdims = transpose_dims(lhs_q.ndim, lhs_cdims, flatten_axis=lhs_q.flatten_axis) - lhs_bdims = transpose_dims(lhs_q.ndim, lhs_bdims, flatten_axis=lhs_q.flatten_axis) if isinstance(rhs_q, ScaledTensor): assert isinstance(lhs_q, ScaledTensor) or lhs_quantizer is not None, ( @@ -894,7 +775,6 @@ def _te_gemm( rhs_scale_inv = rhs_q.scale_inv if rhs_q.data_layout == "T": rhs_cdims = transpose_dims(rhs_q.ndim, rhs_cdims, flatten_axis=rhs_q.flatten_axis) - rhs_bdims = transpose_dims(rhs_q.ndim, rhs_bdims, flatten_axis=rhs_q.flatten_axis) # Dummy empties for bias and gelu out_dtype = lhs_q.dq_dtype if isinstance(lhs_q, ScaledTensor) else lhs_data.dtype @@ -912,16 +792,11 @@ def _te_gemm( gelu_input, out_dtype=out_dtype, contracting_dims=(lhs_cdims, rhs_cdims), - batched_dims=(lhs_bdims, rhs_bdims), - lhs_quantized_colwise=lhs_q.is_colwise if isinstance(lhs_q, ScaledTensor) else False, - rhs_quantized_colwise=rhs_q.is_colwise if isinstance(rhs_q, ScaledTensor) else False, scaling_mode=scaling_mode, fuse_bias=fuse_bias, fuse_gelu=fuse_gelu, grad=grad, use_split_accumulator=use_split_accumulator, - sequence_parallel_output=sequence_parallel_output, - sequence_dim=sequence_dim, ) @@ -1124,10 +999,8 @@ def _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision): (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) = dim_nums if lhs.data_layout == "T": lhs_contract = transpose_dims(lhs.data.ndim, lhs_contract, flatten_axis=lhs.flatten_axis) - lhs_batch = transpose_dims(lhs.data.ndim, lhs_batch, flatten_axis=lhs.flatten_axis) if rhs.data_layout == "T": rhs_contract = transpose_dims(rhs.data.ndim, rhs_contract, flatten_axis=rhs.flatten_axis) - rhs_batch = transpose_dims(rhs.data.ndim, rhs_batch, flatten_axis=rhs.flatten_axis) dim_nums = (lhs_contract, rhs_contract), (lhs_batch, rhs_batch) @@ -1239,7 +1112,6 @@ def gemm( lhs: Union[jnp.ndarray, ScaledTensor], rhs: Union[jnp.ndarray, ScaledTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), - batched_dims: Tuple[Sequence[int], Sequence[int]] = ((), ()), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, **kwargs, @@ -1258,11 +1130,6 @@ def gemm( Object for down-casting the RHS operand for quantized GEMM. contracting_dims: Tuple[Sequence[int], Sequence[int]], default = ((-1, ), (0, )) Tuple of sequences representing the contracting dimensions of the operands. - batched_dims: Tuple[Sequence[int], Sequence[int]], default = ((), ()), - Tuple of sequences representing the batched dimensions of the operands. This is *not* used - to perform a batched matrix multiplication, but it is required for TE's custom cuBLAS GEMM - call to avoid a potentially undesirable reduction in any batched contracting dimensions - when invoked with sharded operands (e.g. when computing weight gradients in a Flax module). bias: jax.Array, default = None Optional additive bias term, required for forward GEMM with bias fusion. Only supported with TE's custom call to cuBLAS GEMM. @@ -1282,15 +1149,6 @@ def gemm( Enable promoting some intermediate sums to higher precision when accumulating the result in the cuBLAS GEMM kernel. Disabling this trades off numerical accuracy for speed. Only supported with TE's custom call to cuBLAS GEMM. 
- sequence_parallel_output: bool, default = False - Produces an output with the first non-batched non-contracting dimension sharded with the - same spec as operand contracting dimensions. This effectively converts the `jax.lax.psum` - for the GEMM output into a `jax.lax.psum_scatter`. Only supported with TE's custom call to - cuBLAS GEMM. - sequence_dim: int, default = None - Index of the sequence dimension for the LHS operand. This controls which dimension of the - GEMM output is scattered when `sequence_parallel_output=True`. When `None`, the first - non-batched non-contracting dimension is assumed to be the sequence dimension. Returns ------- @@ -1329,14 +1187,6 @@ def gemm( "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backends used when the custom cuBLAS " "GEMM primitive is disabled." ) - assert ( - not kwargs.get("sequence_parallel_output", False) - and kwargs.get("sequence_dim", None) is None - ), ( - "TE GEMM was invoked with sequence-parallelism options that are not supported by the " - "`jax.lax.dot_general` and `jax.nn.scaled_matmul` backedns used when the custom cuBLAS " - "GEMM primitive is disabled." - ) return _jax_gemm(lhs, rhs, contracting_dims, lhs_quantizer, rhs_quantizer) outputs = _te_gemm( @@ -1345,7 +1195,6 @@ def gemm( lhs_quantizer=lhs_quantizer, rhs_quantizer=rhs_quantizer, contracting_dims=contracting_dims, - batched_dims=batched_dims, **kwargs, ) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 5be551dbd..4a50fe0e5 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -8,7 +8,7 @@ It implements matrix multiplication with optional bias addition and supports customizable contracting dimensions for flexible tensor operations. """ -import warnings + from typing import Tuple, Sequence from functools import partial import jax @@ -22,17 +22,6 @@ TensorUsage, ) -from .sharding import get_sequence_parallel_dim - -DENSE_BATCH_FIRST_WARNING_ISSUED = False - - -def _issue_batch_first_warning(msg): - global DENSE_BATCH_FIRST_WARNING_ISSUED - if not DENSE_BATCH_FIRST_WARNING_ISSUED: - warnings.warn(msg, UserWarning) - DENSE_BATCH_FIRST_WARNING_ISSUED = True - def dense( x: jnp.ndarray, @@ -41,8 +30,6 @@ def dense( contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (0,)), input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, - batch_first: bool = True, - sequence_parallel_output: bool = False, quantizer_set: QuantizerSet = noop_quantizer_set, ): """Perform dense layer transformation with optional quantization. @@ -56,9 +43,6 @@ def dense( kernel: Weight matrix for the dense layer transformation bias: Optional bias tensor to add after the transformation contracting_dims: Tuple of sequences specifying which dimensions to contract - batch_first: Assume that X is batched in the first dimension. - sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only - supported for TE custom GEMM with row-parallel kernel axes. 
quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -79,14 +63,19 @@ def dense( contracting_dims, input_axes, kernel_axes, - batch_first, - sequence_parallel_output, quantizer_set, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 4, 5, 6, 7)) +@partial( + jax.custom_vjp, + nondiff_argnums=( + 3, + 4, + 5, + ), +) def _dense( x, kernel, @@ -94,8 +83,6 @@ def _dense( contracting_dims, input_axes, kernel_axes, - batch_first, - sequence_parallel_output, quantizer_set, ): """Internal implementation of dense layer transformation with custom VJP. @@ -110,9 +97,6 @@ def _dense( contracting_dims: Contracting dimensions specification input_axes: Logical axes for sharding the activation input kernel_axes: Logical axes for sharding the weight matrix - batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. - sequence_parallel_output: Produce an output that sharded in the first non-batched dim. Only - supported for TE custom GEMM with row-parallel kernel axes. quantizer_set: QuantizerSet which contains quantizers for different tensor types Returns: @@ -125,8 +109,6 @@ def _dense( contracting_dims, input_axes, kernel_axes, - batch_first, - sequence_parallel_output, quantizer_set, ) return output @@ -139,8 +121,6 @@ def _dense_fwd_rule( contracting_dims, input_axes, kernel_axes, - batch_first, - sequence_parallel_output, quantizer_set, ): """Forward pass rule for dense layer transformation. @@ -159,23 +139,6 @@ def _dense_fwd_rule( not x_is_transposed and not k_is_transposed ), "Dense layer only supports `NN` layout inputs, i.e. non-transposed X and Kernel." - # Determine X batch dimension - # - If `batch_first=True` -> (batch, leading..., contracting...) - # - Otherwise -> (leading..., batch, contracting...) - # NOTE: Always assume a single batch dimension - x_bdim = None - num_cdims = len(x_contracting_dims) - if x.ndim >= num_cdims + 2: - # Assume X is batched if it has at least +2 dimensions more than the number of contracting - # dimensions. - if not batch_first: - _issue_batch_first_warning( - "TE/JAX `dense()` layer implementation does not officially support sequence-first " - "inputs and may produce incorrect results when `batch_first=False`. Use " - "sequence-first inputs at your own discretion.", - ) - x_bdim = 0 if batch_first else x.ndim - num_cdims - 1 - flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) @@ -198,10 +161,8 @@ def _dense_fwd_rule( casted_x.get_tensor(usage=TensorUsage.LHS), casted_kernel.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, - sequence_parallel_output=sequence_parallel_output and not tex.gemm_uses_jax_dot(), ) if use_bias and tex.gemm_uses_jax_dot(): @@ -216,13 +177,12 @@ def _dense_fwd_rule( use_bias, quantizer_set, flatten_axis_k, - x_bdim, ) return output, ctx def _dense_bwd_rule( - contracting_dims, input_axes, kernel_axes, batch_first, sequence_parallel_output, ctx, grad + contracting_dims, input_axes, kernel_axes, ctx, grad ): # pylint: disable=unused-argument """Backward pass rule for dense layer transformation. 
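# A minimal usage sketch of the simplified `dense()` call above (shapes, dtype and the
# zero-valued inputs are assumptions for illustration only; actually running it still
# requires a CUDA device and the usual TE/JAX setup, and the import path simply follows
# the module shown in this diff):
import jax.numpy as jnp
from transformer_engine.jax.dense import dense

x = jnp.zeros((8, 128, 1024), jnp.bfloat16)     # (batch, seqlen, hidden_in)
kernel = jnp.zeros((1024, 4096), jnp.bfloat16)  # (hidden_in, hidden_out)

# Contract the last dim of x with the first dim of kernel; after this cleanup there is
# no batch_first / sequence_parallel_output / batched_dims argument to pass.
out = dense(x, kernel, None, contracting_dims=((2,), (0,)))  # out: (8, 128, 4096)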
@@ -237,7 +197,6 @@ def _dense_bwd_rule( use_bias, quantizer_set, flatten_axis_k, - x_bdim, ) = ctx fwd_x_contracting_dims, fwd_k_contracting_dims = map( @@ -262,21 +221,10 @@ def _dense_bwd_rule( dim for dim in range(len(kernel_shape)) if dim not in fwd_k_contracting_dims ) - # Get sequence-parallel dimension of the FWD input (if it exists) - sequence_dim = get_sequence_parallel_dim(input_axes, fwd_x_contracting_dims, (x_bdim,)) dgrad = tex.gemm( casted_grad.get_tensor(usage=TensorUsage.LHS), casted_kernel_rhs, contracting_dims=(g_contracting_dim, k_contracting_dim), - batched_dims=((x_bdim,), ()), - sequence_parallel_output=( - sequence_dim is not None - and not sequence_parallel_output - and not tex.gemm_uses_jax_dot() - ), - sequence_dim=( - None if sequence_parallel_output or tex.gemm_uses_jax_dot() else sequence_dim - ), ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, input_axes) @@ -290,7 +238,6 @@ def _dense_bwd_rule( casted_x_lhs, casted_grad.get_tensor(usage=TensorUsage.RHS), contracting_dims=(x_contracting_dim, g_contracting_dim), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 60c39a037..e923991e4 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -15,12 +15,12 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name -from ..dense import dense, _issue_batch_first_warning as _dense_warning +from ..dense import dense from ..layernorm import canonicalize_norm_type from ..layernorm import layernorm -from ..layernorm_dense import layernorm_dense, _issue_batch_first_warning as _ln_dense_warning -from ..layernorm_mlp import layernorm_mlp, _issue_batch_first_warning as _ln_mlp_warning +from ..layernorm_dense import layernorm_dense +from ..layernorm_mlp import layernorm_mlp from ..activation import activation from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes @@ -273,10 +273,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = False - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). """ epsilon: float = 1e-6 @@ -287,7 +283,6 @@ class LayerNorm(nn.Module): # pylint: disable=too-few-public-methods bias_init: Initializer = nn.initializers.zeros bias_axes: Tuple[str, ...] = ("embed",) dtype: DType = jnp.float32 - transpose_batch_sequence: bool = False def __post_init__(self): self.scale_init = _obtain_default_layernorm_scale_init_if_need( @@ -414,17 +409,11 @@ class DenseGeneral(TransformerEngineBase): Indicate the logical axes of sharding constraint to the input, like (BATCH_AXES, SEQLEN_AXES, HIDDEN_AXES). Default is None, which means not to insert sharding constraint. - sequence_parallel_output: bool, default = False - Produce a sequence-parallel output with the first non-batch dimension sharded over Optimization parameters ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. 
- transpose_batch_sequence : bool, default = True - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). """ features: Union[Iterable[int], int] @@ -438,17 +427,9 @@ class DenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - transpose_batch_sequence: bool = False input_axes: Tuple[str, ...] = () - sequence_parallel_output: bool = False def __post_init__(self): - if self.transpose_batch_sequence: - _dense_warning( - "TE/JAX DenseGeneral() module does not officially support sequence-first inputs " - "and may produce incorrect results when `transpose_batch_sequence=True`. Use " - "sequence-first inputs at your own discretion." - ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype @@ -513,7 +494,6 @@ def __call__(self, inputs: Array) -> Array: input_axes=self.input_axes, kernel_axes=self.kernel_axes, quantizer_set=quantizer_set, - sequence_parallel_output=self.sequence_parallel_output, ) if self.enable_low_rank_adaptation: @@ -631,10 +611,6 @@ class LayerNormDenseGeneral(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). depth_scaling: float, default = None The factor to scale the output from `DenseGeneral`. It should be a float value or None. When None is set, then no scaling is applied. @@ -660,18 +636,11 @@ class LayerNormDenseGeneral(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_input_axes: Tuple[str, ...] = None depth_scaling: float = None def __post_init__(self): - if self.transpose_batch_sequence: - _ln_dense_warning( - "TE/JAX LayerNormDenseGeneral() module does not officially support sequence-first " - "inputs and may produce incorrect results when `transpose_batch_sequence=True`. " - "Use sequence-first inputs at your own discretion." - ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, @@ -949,10 +918,6 @@ class LayerNormMLP(TransformerEngineBase): ----------------------- dtype: jax.numpy.dtype, default = jax.numpy.float32 The data type used to allocate the initial parameters. - transpose_batch_sequence : bool, default = True - Indicate whether the input tensors were switched axis of batch - and sequence length dimension. If set to True, the input tensors - should be in (seqlen, batch, hidden), otherwise (batch, seqlen, hidden). """ intermediate_dim: int = 2048 @@ -981,7 +946,6 @@ class LayerNormMLP(TransformerEngineBase): low_rank_adaptation_alpha: float = None axis: Union[Iterable[int], int] = -1 dtype: DType = jnp.float32 - transpose_batch_sequence: bool = True layernorm_input_axes: Tuple[str, ...] = None dot_1_input_axes: Tuple[str, ...] = None dot_2_input_axes: Tuple[str, ...] 
= None @@ -989,12 +953,6 @@ class LayerNormMLP(TransformerEngineBase): ffn2_ckpt_name: str = "ffn2" def __post_init__(self): - if self.transpose_batch_sequence: - _ln_mlp_warning( - "TE/JAX LayerNormMLP() module does not officially support sequence-first inputs " - "and may produce incorrect results when `transpose_batch_sequence=True`. Use " - "sequence-first inputs at your own discretion." - ) if self.kernel_init is None: self.kernel_init = nn.initializers.variance_scaling( 1.0, "fan_in", "truncated_normal", dtype=self.dtype diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 2d13f25ca..d85593c1e 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -1167,7 +1167,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): epsilon=self.layernorm_epsilon, axis=-1, features=(3, self.num_attention_heads * self.head_dim), - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=self.return_layernorm_output, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), @@ -1194,7 +1193,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): epsilon=self.layernorm_epsilon, axis=-1, features=self.num_attention_heads * self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=(self.return_layernorm_output or is_self_attn), scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), @@ -1219,7 +1217,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): kv_proj = DenseGeneral( axis=-1, features=(2, self.num_gqa_groups * self.head_dim), - transpose_batch_sequence=self.transpose_batch_sequence, kernel_axes=(W_FSDP_AXES, W_JOINED_AXES, W_TP_AXES), kernel_init=kv_init, use_bias=self.use_bias, @@ -1238,7 +1235,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): DenseGeneral, axis=-1, features=self.num_gqa_groups * self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, kernel_axes=(W_FSDP_AXES, W_TP_AXES), use_bias=self.use_bias, bias_init=self.bias_init, @@ -1255,7 +1251,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): epsilon=self.layernorm_epsilon, axis=-1, features=self.num_attention_heads * self.head_dim, - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=True, scale_axes=(W_NO_SHARD_AXES,), ln_bias_axes=(W_NO_SHARD_AXES,), @@ -1420,7 +1415,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): out = DenseGeneral( features=inputs_q.shape[-1], - transpose_batch_sequence=self.transpose_batch_sequence, axis=-1, kernel_init=self.kernel_init, kernel_axes=(W_TP_AXES, W_FSDP_AXES), @@ -1432,7 +1426,6 @@ def generate_batch_seqlen_logical_axes(is_sharded_seq): low_rank_adaptation_alpha=self.low_rank_adaptation_alpha, dtype=self.dtype, name="out", - sequence_parallel_output=self.enable_sequence_parallel, )(x) out = checkpoint_name(out, "out_proj") @@ -2023,7 +2016,6 @@ def hidden_dropout(x, deterministic): layernorm_type=self.layernorm_type, zero_centered_gamma=self.zero_centered_gamma, epsilon=self.layernorm_epsilon, - transpose_batch_sequence=self.transpose_batch_sequence, return_layernorm_output=self.apply_residual_connection_post_layernorm, intermediate_dim=self.mlp_hidden_size, activations=self.mlp_activations, @@ -2078,7 +2070,6 @@ def hidden_dropout(x, deterministic): epsilon=self.layernorm_epsilon, scale_axes=(W_NO_SHARD_AXES,), bias_axes=(W_NO_SHARD_AXES,), - transpose_batch_sequence=self.transpose_batch_sequence, dtype=self.dtype, 
name="output_layernorm", )(z) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index c616aa699..b830cdb4f 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -9,7 +9,6 @@ distributed training through sharding constraints. """ -import warnings from functools import partial from typing import Tuple @@ -24,17 +23,6 @@ with_sharding_constraint_by_logical_axes, TensorUsage, ) -from .sharding import get_sequence_parallel_dim - - -LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = False - - -def _issue_batch_first_warning(msg): - global LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED - if not LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED: - warnings.warn(msg, UserWarning) - LAYERNORM_DENSE_BATCH_FIRST_WARNING_ISSUED = True def layernorm_dense( @@ -49,7 +37,6 @@ def layernorm_dense( layernorm_input_axes: Tuple[str, ...] = None, dot_input_axes: Tuple[str, ...] = None, kernel_axes: Tuple[str, ...] = None, - batch_first: bool = True, quantizer_set: QuantizerSet = noop_quantizer_set, ) -> jnp.ndarray: """Apply layer normalization followed by dense layer transformation. @@ -70,7 +57,6 @@ def layernorm_dense( layernorm_input_axes: Logical axes for sharding the layernorm input dot_input_axes: Logical axes for sharding the matrix multiplication input kernel_axes: Logical axes for sharding the weight matrix - batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_set: Set of quantizers for different tensor types Returns: @@ -94,7 +80,6 @@ def layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, - batch_first, quantizer_set, ) return output @@ -109,7 +94,6 @@ def layernorm_dense( 8, 9, 10, - 11, ), ) def _layernorm_dense( @@ -124,7 +108,6 @@ def _layernorm_dense( layernorm_input_axes: Tuple[str, ...], dot_input_axes: Tuple[str, ...], kernel_axes: Tuple[str, ...], - batch_first: bool, quantizer_set, ): """Internal implementation of layernorm_dense with custom VJP. @@ -144,7 +127,6 @@ def _layernorm_dense( epsilon: Small constant for numerical stability layernorm_input_axes: Logical axes for layernorm sharding dot_input_axes: Logical axes for matrix multiplication sharding - batch_first: Assume that X is batched in the first dimension. quantizer_set: Set of quantizers Returns: @@ -162,7 +144,6 @@ def _layernorm_dense( layernorm_input_axes, dot_input_axes, kernel_axes, - batch_first, quantizer_set, ) return output @@ -180,7 +161,6 @@ def _layernorm_dense_fwd_rule( layernorm_input_axes, dot_input_axes, kernel_axes, - batch_first, quantizer_set, ): """Forward pass rule for layernorm_dense. @@ -198,17 +178,6 @@ def _layernorm_dense_fwd_rule( k_contracting_dims = (0,) assert x.shape[-1] == kernel.shape[0] - x_bdim = None - if x.ndim > 2: - if not batch_first: - _issue_batch_first_warning( - "TE/JAX `layernorm_dense()` fused-layer implementation does not officially " - "support sequence-first inputs and may produce incorrect results when " - "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " - "inputs at your own discretion." 
- ) - x_bdim = 0 if batch_first else x.ndim - 2 - x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) casted_ln_out, mu, rsigma = tex.normalization_fwd( @@ -237,7 +206,6 @@ def _layernorm_dense_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias if not tex.gemm_uses_jax_dot() else False, ) @@ -261,7 +229,6 @@ def _layernorm_dense_fwd_rule( use_bias, quantizer_set, flatten_axis, - x_bdim, ) return output, ctx @@ -272,9 +239,8 @@ def _layernorm_dense_bwd_rule( zero_centered_gamma, epsilon, layernorm_input_axes, - dot_input_axes, # pylint: disable=unused-argument + dot_input_axes, kernel_axes, - batch_first, # pylint: disable=unused-argument ctx, grad, ): @@ -289,6 +255,7 @@ def _layernorm_dense_bwd_rule( Returns: Tuple of gradients for all input parameters """ + del dot_input_axes ( casted_ln_out, casted_kernel, @@ -304,7 +271,6 @@ def _layernorm_dense_bwd_rule( use_bias, quantizer_set, flatten_axis, - x_bdim, ) = ctx casted_grad, dbias = tex.quantize_dbias( @@ -325,16 +291,10 @@ def _layernorm_dense_bwd_rule( ) # NT GEMM - sequence_dim = get_sequence_parallel_dim( - layernorm_input_axes, x_contracting_dims_in_fwd, (x_bdim,) - ) dgrad = tex.gemm( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel, contracting_dims=(g_constracting_dim, k_constracting_dim), - batched_dims=((x_bdim,), ()), - sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), - sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) dgrad = with_sharding_constraint_by_logical_axes(dgrad, layernorm_input_axes) @@ -348,7 +308,6 @@ def _layernorm_dense_bwd_rule( casted_ln_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_constracting_dim, g_constracting_dim), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad = with_sharding_constraint_by_logical_axes(wgrad, kernel_axes) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 5b738e46b..8727ea7e3 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -13,7 +13,6 @@ quantization, and distributed training through sharding constraints. """ -import warnings from typing import List, Tuple, Sequence, Union, Callable from functools import partial @@ -29,19 +28,6 @@ noop_quantizer_set, TensorUsage, ) -from .sharding import ( - get_sequence_parallel_dim, -) - - -LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = False - - -def _issue_batch_first_warning(msg): - global LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED - if not LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED: - warnings.warn(msg, UserWarning) - LAYERNORM_MLP_BATCH_FIRST_WARNING_ISSUED = True def layernorm_mlp( @@ -61,7 +47,6 @@ def layernorm_mlp( ffn1_ckpt_name: str = "ffn1", ffn2_ckpt_name: str = "ffn2", activation_type: Sequence[Union[str, Callable]] = ("gelu",), - batch_first: bool = True, quantizer_sets: Tuple[QuantizerSet] = (noop_quantizer_set, noop_quantizer_set), ) -> jnp.ndarray: """Apply layer normalization followed by MLP block. 
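# A generic jax.custom_vjp sketch (a toy function, not TE code) illustrating the
# nondiff_argnums bookkeeping behind the index renumbering in these VJP rules: the
# indices name argument positions, and every non-differentiable argument is handed to
# the backward rule ahead of the residuals, so dropping one (e.g. batch_first) shifts
# all later indices down by one.
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(2,))  # `scale` is static / non-differentiable
def scaled_matmul(x, w, scale):
    return scale * (x @ w)

def _scaled_matmul_fwd(x, w, scale):
    # Forward rule sees the same arguments as the primal and stashes residuals.
    return scaled_matmul(x, w, scale), (x, w)

def _scaled_matmul_bwd(scale, ctx, grad):
    # Non-diff arguments arrive first, then residuals, then the cotangent; the return
    # tuple has one entry per differentiable argument (x, w).
    x, w = ctx
    return (scale * grad @ w.T, scale * x.T @ grad)

scaled_matmul.defvjp(_scaled_matmul_fwd, _scaled_matmul_bwd)

dx, dw = jax.grad(lambda a, b: scaled_matmul(a, b, 2.0).sum(), argnums=(0, 1))(
    jnp.ones((4, 3)), jnp.ones((3, 5)))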
@@ -93,7 +78,6 @@ def layernorm_mlp( ffn1_ckpt_name: Name for checkpointing the first feed-forward network ffn2_ckpt_name: Name for checkpointing the second feed-forward network activation_type: Activation function(s) to apply after the first dense layer transformation - batch_first: Assume that X is batched in the first dimension if it has more than 2 dims. quantizer_sets: Tuple of two quantizer sets for the two dense layer transformations Returns: @@ -139,13 +123,12 @@ def layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, quantizer_sets, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) +@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) def _layernorm_mlp( x: jnp.ndarray, gamma: jnp.ndarray, @@ -165,7 +148,6 @@ def _layernorm_mlp( ffn1_ckpt_name: str, ffn2_ckpt_name: str, activation_type: Sequence[Union[str, Callable]], - batch_first: bool, quantizer_sets, ): """Internal implementation of layernorm_mlp with custom VJP. @@ -191,7 +173,6 @@ def _layernorm_mlp( ffn1_ckpt_name: Name for first feed-forward network checkpointing ffn2_ckpt_name: Name for second feed-forward network checkpointing activation_type: Activation function(s) - batch_first: Assume that X is batched in the first dimension. quantizer_sets: Tuple of quantizer sets Returns: @@ -216,7 +197,6 @@ def _layernorm_mlp( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, quantizer_sets, ) return output @@ -241,7 +221,6 @@ def _layernorm_mlp_fwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, quantizer_sets, ): """Forward pass rule for layernorm_mlp. @@ -274,17 +253,6 @@ def _layernorm_mlp_fwd_rule( assert x.shape[x_contracting_dims[0]] == kernel_1.shape[k_contracting_dims[0]] - x_bdim = None - if x.ndim > 2: - if not batch_first: - _issue_batch_first_warning( - "TE/JAX `layernorm_mlp()` fused-layer implementation does not officially " - "support sequence-first inputs and may produce incorrect results when " - "`batch_first=False` or `transpose_batch_sequence=True`. Use sequence-first " - "inputs at your own discretion." 
- ) - x_bdim = 0 if batch_first else x.ndim - 2 - use_bias_1 = bias_1 is not None use_bias_2 = bias_1 is not None @@ -312,7 +280,6 @@ def _layernorm_mlp_fwd_rule( casted_ln_out.get_tensor(TensorUsage.LHS), casted_kernel_1.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias_1 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_1 if not tex.gemm_uses_jax_dot() else False, ) @@ -337,16 +304,12 @@ def _layernorm_mlp_fwd_rule( # NN GEMM # (batch..., hidden_in) x (hidden_out, hidden_in) - sequence_dim = get_sequence_parallel_dim(norm_input_axes, x_contracting_dims, (x_bdim,)) dot_2_output = tex.gemm( casted_act_out.get_tensor(TensorUsage.LHS), casted_kernel_2.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, k_contracting_dims), - batched_dims=((x_bdim,), ()), bias=bias_2 if not tex.gemm_uses_jax_dot() else None, fuse_bias=use_bias_2 if not tex.gemm_uses_jax_dot() else False, - sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), - sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) if use_bias_2 and tex.gemm_uses_jax_dot(): @@ -374,8 +337,6 @@ def _layernorm_mlp_fwd_rule( use_bias_1, use_bias_2, quantizer_sets, - x_bdim, - sequence_dim, ) return dot_2_output, ctx @@ -393,7 +354,6 @@ def _layernorm_mlp_bwd_rule( ffn1_ckpt_name, ffn2_ckpt_name, activation_type, - batch_first, ctx, grad, ): @@ -410,7 +370,7 @@ def _layernorm_mlp_bwd_rule( Returns: Tuple of gradients for all input parameters """ - del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, batch_first + del norm_input_axes, ffn1_ckpt_name, ffn2_ckpt_name ( x, mu, @@ -429,8 +389,6 @@ def _layernorm_mlp_bwd_rule( use_bias_1, use_bias_2, quantizer_sets, - x_bdim, - sequence_dim, ) = ctx ffn1_quantizer_set, ffn2_quantizer_set = quantizer_sets @@ -457,7 +415,6 @@ def _layernorm_mlp_bwd_rule( casted_grad.get_tensor(TensorUsage.LHS), casted_kernel_2, contracting_dims=(g_contracting_dims_2, k_contracting_dims_2), - batched_dims=((x_bdim,), ()), ) dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) @@ -472,7 +429,6 @@ def _layernorm_mlp_bwd_rule( casted_act_out, casted_grad.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_2 = with_sharding_constraint_by_logical_axes(wgrad_2, kernel_2_axes) @@ -500,9 +456,6 @@ def _layernorm_mlp_bwd_rule( casted_dact_out.get_tensor(TensorUsage.LHS), casted_kernel_1, contracting_dims=(g_contracting_dims_1, k_contracting_dims_1), - batched_dims=((x_bdim,), ()), - sequence_parallel_output=sequence_dim is not None and not tex.gemm_uses_jax_dot(), - sequence_dim=sequence_dim if not tex.gemm_uses_jax_dot() else None, ) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, dot_1_input_axes) @@ -513,7 +466,6 @@ def _layernorm_mlp_bwd_rule( casted_ln_out, casted_dact_out.get_tensor(TensorUsage.RHS), contracting_dims=(x_contracting_dims, g_contracting_dims), - batched_dims=((x_bdim,), (x_bdim,)), ) wgrad_1 = with_sharding_constraint_by_logical_axes(wgrad_1, kernel_1_axes) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 6dd2e88a6..606c233c9 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -86,30 +86,6 @@ def get_sharding_map_logic_axis_to_mesh_axis(): return te_logical_axis_to_mesh_axis -def get_sequence_parallel_dim(logical_axes, contracting_dims, batch_dims): - """ - Get the 
index for the sequence-parallel dimension based on the given logical axes. - - The sequence-parallel dimension is assumed to be the only sharded non-batched non-contracting - dimension. - """ - if not logical_axes: - return None - - pspec = generate_pspec(logical_axes, with_flax_rules=True, padded=True) - ldims = [i for i in range(len(logical_axes)) if i not in set(contracting_dims + batch_dims)] - lspecs = [pspec[i] for i in ldims if pspec[i] is not None] - if len(lspecs) == 0: - return None - - assert len(lspecs) == 1, ( - "Expected only 1 non-batched non-contracting dimension to be sharded for " - f"sequence-parallelism, but found {len(lspecs)}: {pspec} @ idx {ldims}" - ) - - return pspec.index(lspecs[0]) - - def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False): """ Convert logical axes to PartitionSpec From 9f9b48168f106a172b28aeb44cb12f2c2c232181 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 8 Aug 2025 14:26:49 -0400 Subject: [PATCH 054/153] [JAX] Remove cudaGraph compatible trait from GroupedGemmFFI and GroupedQuantizeFFI (#2048) * rm cudaGraph compatible trait from GroupedGEMM and groupedQuantize Signed-off-by: Phuong Nguyen * add grouped_gemm jitting in the unit test Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- tests/jax/test_custom_call_compute.py | 37 +++++++------------ .../jax/csrc/extensions/gemm.cpp | 3 +- .../jax/csrc/extensions/quantization.cpp | 3 +- 3 files changed, 15 insertions(+), 28 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index aa243be62..d5f21651d 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -673,10 +673,6 @@ def test_grouped_qdq( n_groups=n_groups, ) - # grouped_quantize does not work with cudaGraph yet, so the jitting will breaks - # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to - # disable cudaGraph, then use the following jitted function - scaled_tensor = tex.grouped_quantize( x, group_sizes=group_sizes, flatten_axis=flatten_axis, quantizer=grouped_quantizer ) @@ -1312,16 +1308,14 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - # grouped_gemm does not work with cudaGraph yet, so the jitting will breaks - # To test it locally, export XLA_FLAGS="--xla_gpu_enable_command_buffer= $XLA_FLAGS" to - # disable cudaGraph, then use the following jitted function - # jitting grouped_gemm - # prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - # lhs, rhs, group_sizes, contracting_dims, - # ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs, + rhs, + group_sizes, + contracting_dims, + ) - prim_out = tex.grouped_gemm(lhs, rhs, group_sizes, contracting_dims) self._assert_grouped_gemm_output(prim_out, group_sizes, ref_out, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) @@ -1350,12 +1344,7 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - # jitting grouped_gemm - # prim_out = jax.jit(tex.grouped_gemm, static_argnames=('contracting_dims',))( - # lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set - # ) - - prim_out = tex.grouped_gemm( + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( lhs, rhs, group_sizes, 
contracting_dims, quantizer_set=quantizer_set ) @@ -1391,9 +1380,9 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) # jitting the grouped_dense - # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), - # static_argnums=(4,)) - value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = jit( + value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,) + ) ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( x, kernel, bias, group_sizes, contracting_dims @@ -1432,9 +1421,9 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): value_n_grad_ref_func = value_and_grad(self._ref_sum_grouped_dense, (0, 1, 2)) # jitting the grouped_dense - # value_n_grad_prim_func = jit(value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), - # static_argnums=(4,)) - value_n_grad_prim_func = value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)) + value_n_grad_prim_func = jit( + value_and_grad(self._primitive_sum_grouped_dense, (0, 1, 2)), static_argnums=(4,) + ) ref_out_sum, (ref_dgrad, ref_wgrad, ref_dbias) = value_n_grad_ref_func( x, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index ba2d65e3e..29d0fbfa6 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -592,8 +592,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad"), - FFI_CudaGraph_Traits); + .Attr("is_grouped_dense_wgrad")); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index a92934193..7bea11f91 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -410,8 +410,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Ret() // amax .Attr("scaling_mode") .Attr("q_layout") - .Attr("flatten_axis"), - FFI_CudaGraph_Traits); + .Attr("flatten_axis")); } // namespace jax } // namespace transformer_engine From b6b3abcee696451d1d049ec90a52d460d1ea7fdb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Fri, 8 Aug 2025 20:49:38 +0200 Subject: [PATCH 055/153] [PyTorch debug] Improve precision debug tools performance (#1909) * turn on userbuffers for layers without debug Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * working change Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tests and fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * 
fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * update nvinspect version Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix ci Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- qa/L0_pytorch_debug_unittest/test.sh | 8 + qa/L1_pytorch_distributed_unittest/test.sh | 5 + tests/pytorch/debug/test_api_features.py | 77 +++--- .../debug/test_configs/log_config.yaml | 19 ++ .../debug/test_configs/perf_config.yaml | 13 + tests/pytorch/debug/test_log.py | 59 +++++ tests/pytorch/debug/test_perf.py | 76 ++++++ transformer_engine/debug/features/api.py | 98 ++++++-- .../debug/features/disable_fp8_gemm.py | 2 +- .../debug/features/disable_fp8_layer.py | 2 +- .../debug/features/fake_quant.py | 4 +- .../debug/features/log_fp8_tensor_stats.py | 12 +- .../debug/features/log_tensor_stats.py | 11 +- .../debug/features/per_tensor_scaling.py | 4 +- .../debug/features/utils/__init__.py | 35 +++ .../debug/features/utils/stats_buffer.py | 39 ++- .../debug/pytorch/debug_quantization.py | 235 +++++++++++++----- .../debug/pytorch/debug_state.py | 7 + transformer_engine/debug/pytorch/utils.py | 19 ++ transformer_engine/pytorch/distributed.py | 77 +++--- transformer_engine/pytorch/module/base.py | 71 ++++-- .../pytorch/module/layernorm_linear.py | 30 +-- .../pytorch/module/layernorm_mlp.py | 27 +- transformer_engine/pytorch/module/linear.py | 32 ++- .../pytorch/tensor/float8_tensor.py | 12 + .../pytorch/tensor/quantized_tensor.py | 4 + 26 files changed, 740 insertions(+), 238 deletions(-) create mode 100644 tests/pytorch/debug/test_configs/log_config.yaml create mode 100644 tests/pytorch/debug/test_configs/perf_config.yaml create mode 100644 tests/pytorch/debug/test_log.py create mode 100644 tests/pytorch/debug/test_perf.py diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index 6c0c79251..414899aa4 100644 --- 
a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -14,11 +14,19 @@ FAIL=0 +# It is not installed as a requirement, +# because it is not available on PyPI. +pip uninstall -y nvdlfw-inspect +pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git + pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 + # standard sanity and numerics tests with initialized debug NVTE_TEST_NVINSPECT_ENABLED=1 NVTE_TEST_NVINSPECT_CONFIG_FILE=$NVTE_TEST_NVINSPECT_DUMMY_CONFIG_FILE NVTE_TEST_NVINSPECT_FEATURE_DIRS=$NVTE_TEST_NVINSPECT_FEATURE_DIRS PYTORCH_JIT=0 NVTE_TORCH_COMPILE=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py || FAIL=1 diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index d7a4f054f..e5b4b5861 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -21,6 +21,11 @@ FAILED_CASES="" mkdir -p "$XML_LOG_DIR" +# It is not installed as a requirement, +# because it is not available on PyPI. 
+pip uninstall -y nvdlfw-inspect +pip install git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git + pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/distributed/test_sanity.py || test_fail "test_sanity.py" diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index f9cd234ba..2a2ef1fe8 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -24,22 +24,22 @@ def test_transformer_engine_no_config(feature_dirs): # FP8 enabled - true by the default assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] - # modify_tensor_enabled - False by default + # modify_tensor_enabled - (False, None) by default assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] - # inspect_tensor_enabled - False by default + # inspect_tensor_enabled - (False, None) by default assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.1.attn.qkv", tensor_name="activation", iteration=0 - ) + )[0] - # inspect_tensor_postquantize - False by default + # inspect_tensor_postquantize - (False, None) by default assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -51,24 +51,24 @@ def test_disable_fp8_gemm(configs_dir, feature_dirs): assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="dgrad", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="wgrad", iteration=0 - ) + )[0] # caching assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="dgrad", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="wgrad", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -80,22 +80,22 @@ def test_disable_fp8_layer(configs_dir, feature_dirs): assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.mlp.fc1", gemm="fprop", iteration=0 - ) + )[0] assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.mlp.fc1", gemm="wgrad", iteration=0 - ) + )[0] assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.mlp.fc1", gemm="dgrad", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="fprop", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="wgrad", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.attn.qkv", gemm="dgrad", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -111,22 +111,22 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): # check modify_tensor_enabled assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="fprop", tensor_name="weight", iteration=0 - ) + )[0] assert 
debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="weight", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="gradient", iteration=0 - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="wgrad", tensor_name="activation", iteration=0 - ) + )[0] # check modify_tensor @@ -168,14 +168,14 @@ def test_per_tensor_scaling(configs_dir, feature_dirs): gemm="wgrad", tensor_name="gradient", iteration=0, - ) + )[0] assert not debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc4", gemm="fprop", tensor_name="activation", iteration=0, - ) + )[0] finally: debug_api.end_debug() @@ -191,11 +191,11 @@ def test_fake_quant(configs_dir, feature_dirs): # modify_tensor_enabled assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="fprop", tensor_name="activation", iteration=0 - ) + )[0] assert debug_api.transformer_engine.modify_tensor_enabled( "decoder.1.mlp.fc1", gemm="dgrad", tensor_name="gradient", iteration=0 - ) + )[0] # modify_tensor debug_api.transformer_engine.modify_tensor( @@ -218,11 +218,11 @@ def test_fake_quant(configs_dir, feature_dirs): assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.fc2", gemm="wgrad", iteration=0 - ) + )[0] # caching assert debug_api.transformer_engine.fp8_gemm_enabled( "decoder.1.fc2", gemm="wgrad", iteration=0 - ) + )[0] finally: debug_api.end_debug() @@ -265,21 +265,20 @@ def assert_empty(): assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max() assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.1.mlp.fc1", tensor_name="activation", iteration=201 - ) + )[0] assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.2.mlp.fc1", tensor_name="activation", iteration=200 - ) + )[0] assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 - ) + )[0] expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5) - expected_overflows = (tensor_fp8._data == 126).sum() * 100 / (100 * 100 * 5) # TE FP8 tensor stats -- assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled( "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 - ) + )[0] debug_api.transformer_engine.inspect_tensor_postquantize( "decoder.1.mlp.fc1", tensor=tensor_fp8, @@ -295,10 +294,10 @@ def assert_empty(): assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201 - ) + )[0] assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 - ) + )[0] # Second config in same yaml tensor = torch.rand((100, 100, 5)) @@ -328,7 +327,7 @@ def assert_empty(): assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.7.mlp.fc1", tensor_name="weight", iteration=201 - ) + )[0] assert_empty() finally: diff --git a/tests/pytorch/debug/test_configs/log_config.yaml b/tests/pytorch/debug/test_configs/log_config.yaml new file mode 100644 index 000000000..04f490b9d --- /dev/null +++ b/tests/pytorch/debug/test_configs/log_config.yaml @@ -0,0 +1,19 @@ +test: + enabled: True + layers: + 
layer_name_regex_pattern: .* + transformer_engine: + LogTensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + start_step: 1 + freq: 3 + LogFp8TensorStats: + enabled: True + tensors: weight + stats: [underflows%] + start_step: 1 + freq: 5 + \ No newline at end of file diff --git a/tests/pytorch/debug/test_configs/perf_config.yaml b/tests/pytorch/debug/test_configs/perf_config.yaml new file mode 100644 index 000000000..4ef5e51cd --- /dev/null +++ b/tests/pytorch/debug/test_configs/perf_config.yaml @@ -0,0 +1,13 @@ +test: + enabled: True + layers: + layer_name_regex_pattern: .*1 + transformer_engine: + LogTensorStats: + enabled: True + tensors_struct: + - tensor: activation + stats: [cur_amax, dynamic_range, mean, std, l1_norm] + start_step: 0 + freq: 100000 + \ No newline at end of file diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py new file mode 100644 index 000000000..fb0988d76 --- /dev/null +++ b/tests/pytorch/debug/test_log.py @@ -0,0 +1,59 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +import pytest +import torch +import transformer_engine.pytorch as te +import tempfile +import os + +import nvdlfw_inspect.api as debug_api + +from transformer_engine.debug.pytorch.debug_state import TEDebugState + + +@pytest.mark.parametrize("layer", ["linear", "transformer"]) +def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): + # If layer does not invoke any feature in current iteration, + # then it changed into non-debug mode. + # This test checks whether this works correctly - + # non-quantized statistics should be logged every 3 iterations, + # and quantized statistics should be logged every 5 iterations. + with tempfile.TemporaryDirectory() as temp_dir: + debug_api.initialize( + config_file=configs_dir + "/log_config.yaml", + feature_dirs=feature_dirs, + log_dir=temp_dir, + ) + + if layer == "linear": + model = te.Linear(128, 128, name="linear1") + elif layer == "transformer": + model = te.TransformerLayer(128, 128, 4, name="transformer1") + else: + raise ValueError(f"Invalid layer: {layer}") + + for i in range(11): + x = torch.randn(4, 128, 128).cuda() + with te.fp8_autocast(enabled=True): + y = model(x) + y.sum().backward() + debug_api.step() + + with open( + os.path.join( + temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log" + ), + "r", + ) as f: + file_content = f.read() + for i in range(1, 11): + if i % 3 == 0 or i % 5 == 0: + assert f"iteration={i:06d}" in file_content + else: + assert f"iteration={i:06d}" not in file_content + + debug_api.end_debug() + TEDebugState._reset() diff --git a/tests/pytorch/debug/test_perf.py b/tests/pytorch/debug/test_perf.py new file mode 100644 index 000000000..2d4b62b23 --- /dev/null +++ b/tests/pytorch/debug/test_perf.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +import pytest +import torch +import transformer_engine.pytorch as te +import time + +import nvdlfw_inspect.api as debug_api + +from transformer_engine.debug.pytorch.debug_state import TEDebugState + + +def _run_cpu_overhead(debug_tools_initialized, layer, configs_dir, feature_dirs): + debug_api.end_debug() + TEDebugState._reset() + if debug_tools_initialized: + # This config log stats starting from 0, every N iterations for huge N >> NUM_ITERS. 
+ # So after 1 warm-up iteration, this layers should work in non-debug mode. + debug_api.initialize( + config_file=configs_dir + "/perf_config.yaml", feature_dirs=feature_dirs + ) + + try: + if layer == "linear": + model = torch.nn.Sequential( + te.Linear(1, 1, name="linear1"), te.Linear(1, 1, name="linear2") + ).cuda() + NUM_ITERS = 18000 + elif layer == "transformer": + model = torch.nn.Sequential( + te.TransformerLayer(1, 1, 1, name="transformer1"), + te.TransformerLayer(1, 1, 1, name="transformer2"), + ).cuda() + NUM_ITERS = 2000 + + x = torch.randn(1, 1, 1).cuda() + + y = model(x) + y.sum().backward() + debug_api.step() + torch.cuda.synchronize() + + time_start = time.time() + for i in range(NUM_ITERS): + y = model(x) + y.sum().backward() + if debug_tools_initialized: + debug_api.step() + torch.cuda.synchronize() + time_end = time.time() + + finally: + if debug_tools_initialized: + debug_api.end_debug() + + return time_end - time_start + + +@pytest.mark.parametrize("layer", ["linear", "transformer"]) +def test_cpu_overhead(layer, configs_dir, feature_dirs): + # runs one layer many times on very small tensor + # - gpu time should be negligible, so time should be dominated by cpu time. + # if layers does not invoke any feature in current iteration, + # then it changed into non-debug mode and should not have any non-negligible cpu overhead + # compared to layer without debug tools initialized. + + with_debug_tools = _run_cpu_overhead(True, layer, configs_dir, feature_dirs) + without_debug_tools = _run_cpu_overhead(False, layer, configs_dir, feature_dirs) + + print(f"with_debug_tools: {with_debug_tools} s") + print(f"without_debug_tools: {without_debug_tools} s") + + assert with_debug_tools < without_debug_tools * 1.25 # 25% overhead margin diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index 13ab6040d..ff37f57bf 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -5,7 +5,7 @@ """API definition for nvidia-dlframework-inspect.""" import copy -from typing import Dict, Union +from typing import Dict, Union, Tuple, Optional from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper from nvdlfw_inspect.registry import Registry @@ -101,13 +101,23 @@ def _process_transformer_engine_config(self, config, **kwargs): class TEDefaultFeatures: """Transformer Engine API calls default behavior.""" - def fp8_gemm_enabled(self, config: Dict, layer_name: str, gemm: str, iteration: int) -> bool: + def fp8_gemm_enabled( + self, + config: Dict, + layer_name: str, + gemm: str, + iteration: int, + ) -> bool | Tuple[bool, Optional[int]]: """ If the tensor is not processed using *modify_tensor* and the fp8 recipe is enabled, then the decision whether to cast it to fp8 is based on the value returned by the call *fp8_gemm_enabled*. If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled, the result of this call does not matter. + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + It can return (bool, None) if the feature will never be enabled for that layer and gemm. + Returning the next enabled iteration can help optimize CPU usage. 
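[Editor's illustration, not part of the patch] The (bool, Optional[int]) convention described above can be sketched with a hypothetical feature hook; the `freq` config key and the function body below are assumptions for the sketch, not code from this commit:

    def example_fp8_gemm_enabled(config, layer_name, gemm, iteration):
        # Run the feature every `freq` iterations; `freq` is an assumed config key.
        freq = config.get("freq", 1)
        run_now = iteration % freq == 0
        # Smallest multiple of `freq` strictly greater than the current iteration,
        # i.e. the next step on which this feature would be enabled again.
        next_run = ((iteration // freq) + 1) * freq
        return run_now, next_run

Returning `(run_now, next_run)` lets the framework skip the layer's debug path entirely until `next_run`, which is the CPU-overhead optimization this patch is after.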
+ Parameters ---------- @@ -122,9 +132,9 @@ def fp8_gemm_enabled(self, config: Dict, layer_name: str, gemm: str, iteration: Returns ------- - bool - default is True + Union[bool, Tuple[bool, Optional[int]]] - default is (True, None) """ - return True # if it is false, fp8_gemm will be turned off. Otherwise nothing happens. + return True, None # if it is false, fp8_gemm will be turned off. Otherwise nothing happens. def modify_tensor_enabled( self, @@ -133,9 +143,16 @@ def modify_tensor_enabled( gemm: str, tensor_name: str, iteration: int, - ) -> bool: + ) -> bool | Tuple[bool, Optional[int]]: """ - It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name. It has **higher priority** than fp8_gemm, if *modify_tensor_enabled* returns True, then modify_tensor call is invoked for the respective tensor no matter what. + It is used to determine whether *modify_tensor* will be run for a given GEMM and tensor name. + It has **higher priority** than fp8_gemm; if *modify_tensor_enabled* returns True or (True, next_enabled_iter), + then modify_tensor call is invoked for the respective tensor no matter what. + + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor. + Returning the next enabled iteration can help optimize CPU usage, especially when the interval between modify_tensor is large. + Returning only a bool is deprecated. Parameters ---------- @@ -153,9 +170,9 @@ def modify_tensor_enabled( Returns ------- - bool - default is False + Union[bool, Tuple[bool, Optional[int]]] - default is (False, None) """ - return False + return False, None def modify_tensor( self, @@ -167,7 +184,7 @@ def modify_tensor( default_quantizer: Quantizer, iteration: int, out: Union[torch.Tensor, QuantizedTensor], - ) -> Union[torch.Tensor, QuantizedTensor, None]: + ) -> torch.Tensor | QuantizedTensor | None: """ It allows tensor modification. For example, feature `FakeQuant` uses it to emulate casting to FP8. @@ -298,9 +315,15 @@ def inspect_tensor_enabled( layer_name: str, tensor_name: str, iteration: int, - ) -> bool: + ) -> bool | Tuple[bool, Optional[int]]: """ - It is a routing call, which is run at the initialization of the layer. If it returns true, then *inspect_tensor* for a given GEMM and tensor will be invoked. + It is a routing call, which is run at the initialization of the layer. + Determines if *inspect_tensor* for a given GEMM and tensor will be invoked. + + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + It can return (bool, None) if the feature will never be enabled for that layer and tensor. + Returning the next enabled iteration can help optimize CPU usage, especially when the interval between inspect_tensor is large. + Returning only a bool is deprecated. Parameters ---------- @@ -316,9 +339,9 @@ def inspect_tensor_enabled( Returns ------- - bool - default is False + Union[bool, Tuple[bool, Optional[int]]] - default is (False, None) """ - return False + return False, None def inspect_tensor_postquantize_enabled( self, @@ -327,11 +350,16 @@ def inspect_tensor_postquantize_enabled( gemm: str, tensor_name: str, iteration: int, - ) -> bool: + ) -> bool | Tuple[bool, Optional[int]]: """ It is a routing call, which is run at the initialization of the layer. 
- If it returns true, then *inspect_tensor_postquantize* for - a given GEMM and tensor will be invoked. + Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked. + + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + It can return (bool, None) if the feature will never be enabled for that layer, gemm and tensor name. + Returning the next enabled iteration can help optimize CPU usage, + especially when the interval between inspect_tensor_postquantize is large. + Returning only a bool is deprecated. Parameters ---------- @@ -349,9 +377,9 @@ def inspect_tensor_postquantize_enabled( Returns ------- - bool - default is False + Union[bool, Tuple[bool, Optional[int]]] - default is (False, None) """ - return False + return False, None @Registry.register_namespace_api(namespace="transformer_engine") @@ -420,7 +448,7 @@ def routing_condition(self, api_name, config, _, feature_obj, **kwargs): def output_assertions_hook(self, api_name, ret, **kwargs): """Output hooks used to check correctness of the outputs of the API calls.""" if "enabled" in api_name or api_name == "fp8_gemm": - assert isinstance(ret, bool) + assert isinstance(ret, (bool, tuple)) if api_name in ["inspect_tensor", "inspect_tensor_postquantize"]: assert ret is None if api_name == "modify_tensor": @@ -432,6 +460,38 @@ def output_assertions_hook(self, api_name, ret, **kwargs): if kwargs["dtype"] is not None: assert ret.dtype == kwargs["dtype"] + def handle_multi_feature_output( + self, api_name, multi_feature_outputs, features_to_invoke, **kwargs + ): + """ + Handle multi-tensor output of the API calls. + """ + if "enabled" in api_name: + # *_enabled feature calls can return bool, or tuple (bool, Optional[int]). + # If any of them returns bool, then we return bool - this means that we cannot state anything + # about enablement in the next steps. + # If all of them return a tuple (bool, Optional[int]), we return the minimum value, + # representing the number of steps after the feature will be enabled next time. + # If the second value is None, that means that the feature will never be enabled. + all_ret_tuple = all( + isinstance(feature_output, tuple) + for feature_output in multi_feature_outputs.values() + ) + if all_ret_tuple: + run_current = any( + feature_output[0] for feature_output in multi_feature_outputs.values() + ) + next_iter = None + for feature_output in multi_feature_outputs.values(): + if feature_output[1] is not None: + next_iter = min(next_iter, feature_output[1]) + return run_current, next_iter + run_current = any(feature_output for feature_output in multi_feature_outputs.values()) + return run_current, None + return super().handle_multi_feature_output( + api_name, multi_feature_outputs, features_to_invoke, **kwargs + ) + def step(self): """This function is called by the nvidia-dlframework-inspect after every debug_api.step()""" STATS_BUFFERS.log_stats() diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index b2400d1cd..11822fd08 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -50,4 +50,4 @@ def fp8_gemm_enabled( # If this feature is invoked, then FP8 GEMM is disabled. # If not, then default behaviour in TransformerEngineAPI # is that fp8_gemm() API call returns True. 
- return False + return False, None diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index 7e885fe5e..d4f9b1b12 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -41,7 +41,7 @@ def fp8_gemm_enabled( # If this feature is invoked, then FP8 GEMM is disabled. # If not, then default behavior in TransformerEngineAPI # is that fp8_gemm() API call returns True. - return False + return False, None def parse_config_and_api(self, config, **_kwargs): """Determines whether to run the API diff --git a/transformer_engine/debug/features/fake_quant.py b/transformer_engine/debug/features/fake_quant.py index 4a5b6c34a..4b01f9712 100644 --- a/transformer_engine/debug/features/fake_quant.py +++ b/transformer_engine/debug/features/fake_quant.py @@ -127,14 +127,14 @@ def fp8_gemm_enabled( self, config, layer_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call responsible for selecting between high-precision and FP8 GEMM execution.""" - return False + return False, None @api_method def modify_tensor_enabled( self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run process_tensor() in the forward.""" - return True + return True, iteration + 1 @api_method def modify_tensor( diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index e5c84a9bd..c1528bb05 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -13,6 +13,7 @@ from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.debug.features.utils import next_enabled_iter from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor @@ -92,8 +93,15 @@ def inspect_tensor_postquantize_enabled( self, config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run inspect_tensor_postquantize() in the forward.""" - # check whether logging should happen in this iteration - return self._check_params(config, layer_name, iteration=iteration) + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter + return run_current, next_iter @api_method def inspect_tensor_postquantize( diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 75ff81d13..402750c28 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -19,6 +19,7 @@ from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS +from transformer_engine.debug.features.utils import next_enabled_iter 
@Registry.register_feature(namespace="transformer_engine") @@ -97,7 +98,15 @@ def inspect_tensor_enabled( self, config: Dict, layer_name: str, tensor_name: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run look_at_tensor_before_process() in the forward.""" - return self._check_params(config, layer_name, iteration=iteration) + run_current, next_iter = next_enabled_iter( + config.get("start_step", None), + config.get("end_step", None), + config.get("start_end_list", None), + config.get("freq", 1), + iteration, + ) + STATS_BUFFERS.layers_to_next_iter[layer_name] = next_iter + return run_current, next_iter @api_method def inspect_tensor( diff --git a/transformer_engine/debug/features/per_tensor_scaling.py b/transformer_engine/debug/features/per_tensor_scaling.py index 7b4de0a18..dd1f42cf0 100644 --- a/transformer_engine/debug/features/per_tensor_scaling.py +++ b/transformer_engine/debug/features/per_tensor_scaling.py @@ -91,14 +91,14 @@ def fp8_gemm( self, config, layer_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call responsible for selecting between high-precision and FP8 GEMM execution.""" - return False + return False, None @api_method def modify_tensor_enabled( self, config, layer_name: str, tensor_name: str, gemm: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run process_tensor() in the forward.""" - return True + return True, iteration + 1 @api_method def modify_tensor( diff --git a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index 951e25063..60f6b0a21 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -5,3 +5,38 @@ """ Utils for the debug features. """ + + +def next_enabled_iter(start_step, end_step, start_end_list, freq, iteration): + """ + Determines whether the feature should be enabled at the current iteration, + and computes the next iteration at which the feature will be enabled. + + Returns + ------- + run_current : bool + True if the feature should be enabled at the current iteration. + next_iter : int + The next iteration index at which the feature will be enabled. 
+ """ + + run_current = False + + if start_end_list: + intervals = sorted(start_end_list) + else: + start_step = 0 if start_step is None else start_step + end = float("inf") if end_step is None else end_step + intervals = [(start_step, end)] + + for s, e in intervals: + if iteration % freq == 0 and s <= iteration <= e: + run_current = True + + first = max(iteration + 1, s) + offset = first % freq + candidate = first if offset == 0 else first + (freq - offset) + if candidate <= e: + return run_current, candidate + + return run_current, None # No next iteration found diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 4be465f8e..7ccef20bc 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -10,6 +10,7 @@ from collections import defaultdict +from typing import Dict import torch from nvdlfw_inspect.utils import gather_along_first_dim @@ -20,6 +21,7 @@ DEPENDENCIES, stats_to_num, ) +from transformer_engine.debug.pytorch.debug_state import TEDebugState class _Buffer: @@ -146,10 +148,41 @@ def __init__(self): self.buffers = {} # (layer_name, tensor_name) -> buffer self.reduction_group_to_buffer = defaultdict(list) + # Logging stats involves synchronization between nodes + # and non-trivial cpu overhead. + # It should be only done if absolutely necessary. + # This variables helps to determine if we can reduce. + self.at_least_one_layer_fed = False + self.layers_to_next_iter: Dict[str, int] = {} + + def _if_run_reduction(self) -> bool: + """ + Returns True if reduction should be run. + + This is the case if at least one layer logged stats. + If not, it may be the case that some layer was not run on this node. + If we know that such layers on all other nodes do not log this time, + we can not reduce. If this in not the case, we should reduce. + + To ensure corretness, we assume that every layer is invoked at first forward pass. + If this is not the case, hang might happen. + """ + if self.at_least_one_layer_fed: + return True + iteration = TEDebugState.get_iteration() + for _, next_iter in self.layers_to_next_iter.items(): + # Note that layer can be not run for many iterations, + # in this case we will synchronize until every step until we get any information from it. 
+ if iteration >= next_iter: + return True + return False + def reset(self): """Resets all buffers.""" self.buffers = {} # (layer_name, tensor_name) -> buffer self.reduction_group_to_buffer = defaultdict(list) + self.at_least_one_layer_fed = False + self.layers_to_next_iter: Dict[str, int] = {} def try_add_buffer( self, layer_name, tensor_name, stats, options, reduction_group, reduce_within_microbatch @@ -163,12 +196,16 @@ def try_add_buffer( def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction): """Feeds the tensor into the respective buffer.""" + self.at_least_one_layer_fed = True buffer = self.buffers[(layer_name, tensor_name, options)] buffer.feed(tensor, iteration) buffer.skip_reduction = skip_reduction def log_stats(self): """Logs the stats from all the buffers.""" + if not self._if_run_reduction(): + return {} + output = {} for reduction_group, buffers in self.reduction_group_to_buffer.items(): changed_buffers = [ @@ -181,7 +218,7 @@ def log_stats(self): for _, buffer in changed_buffers: stats = buffer.log() output.update(stats) - + self.at_least_one_layer_fed = False return output diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 2b859800a..98feb3180 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -22,6 +22,7 @@ prepare_for_saving, restore_from_saved, ) +from transformer_engine.debug.pytorch.debug_state import TEDebugState aten = torch.ops.aten @@ -53,14 +54,13 @@ def __init__( parent_quantizer: Optional[Quantizer], tp_group: torch.distributed.ProcessGroup, ): - import nvdlfw_inspect.api as debug_api super().__init__(rowwise=True, columnwise=True) self.layer_name = layer_name self.tensor_name = tensor_name self.parent_quantizer = parent_quantizer self.tp_group = tp_group # used in inspect_tensor calls - self.iteration = debug_api.DEBUG_MANAGER._trainer_iteration_count + self.iteration = TEDebugState.get_iteration() # .internal = True is slightly faster, but results # in errors when caching the weights. @@ -70,6 +70,12 @@ def __init__( self.rowwise_gemm_name, self.columnwise_gemm_name = _tensor_to_gemm_names_map[tensor_name] + # next iteration when this quantizer will call any API + # it is None at the init and it is computed after_enabled api calls. + # None at the beginning means that if nothing will be done, + # this quantizer will never call any API. + self.next_debug_iter = None + # The values of the inspect_tensor_enabled, inspect_tensor_postquantize_enabled, # rowwise_tensor_plan, and columnwise_tensor_plan are computed. # These fields indicate the path where API calls will be inserted. 
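[Editor's illustration, not part of the patch] A few worked calls of the `next_enabled_iter` helper added above, using the same settings as the LogTensorStats entry in log_config.yaml (start_step 1, freq 3); the expected results follow directly from the function body shown in features/utils/__init__.py:

    from transformer_engine.debug.features.utils import next_enabled_iter

    # (start_step, end_step, start_end_list, freq, iteration) -> (run_current, next_iter)
    assert next_enabled_iter(1, None, None, 3, 0) == (False, 3)   # before start_step, next run at 3
    assert next_enabled_iter(1, None, None, 3, 3) == (True, 6)    # enabled now, next run at 6
    assert next_enabled_iter(1, None, None, 3, 4) == (False, 6)   # skipped now, next run at 6
    assert next_enabled_iter(1, 5, None, 3, 4) == (False, None)   # no enabled step left before end_step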
@@ -102,15 +108,21 @@ def get_plans_for_output_tensors(self) -> Tuple[bool, str]: """ import nvdlfw_inspect.api as debug_api - inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( - layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + inspect_tensor_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) ) - modify_enabled = debug_api.transformer_engine.modify_tensor_enabled( - layer_name=self.layer_name, - gemm=self.rowwise_gemm_name, - tensor_name=self.tensor_name, - iteration=self.iteration, + + modify_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) ) + plan = API_CALL_MODIFY if modify_enabled else HIGH_PRECISION return inspect_tensor_enabled, plan @@ -121,10 +133,13 @@ def get_enabled_look_at_tensors(self): """ import nvdlfw_inspect.api as debug_api - inspect_tensor_enabled = debug_api.transformer_engine.inspect_tensor_enabled( - layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + inspect_tensor_enabled = self.process_enabled_api_call( + debug_api.transformer_engine.inspect_tensor_enabled( + layer_name=self.layer_name, tensor_name=self.tensor_name, iteration=self.iteration + ) ) - inspect_tensor_postquantize_enabled_rowwise = ( + + inspect_tensor_postquantize_enabled_rowwise = self.process_enabled_api_call( debug_api.transformer_engine.inspect_tensor_postquantize_enabled( layer_name=self.layer_name, tensor_name=self.tensor_name, @@ -132,7 +147,8 @@ def get_enabled_look_at_tensors(self): gemm=self.rowwise_gemm_name, ) ) - inspect_tensor_postquantize_enabled_columnwise = ( + + inspect_tensor_postquantize_enabled_columnwise = self.process_enabled_api_call( debug_api.transformer_engine.inspect_tensor_postquantize_enabled( layer_name=self.layer_name, tensor_name=self.tensor_name, @@ -158,42 +174,54 @@ def get_tensors_plan(self): rowwise_plan = None columnwise_plan = None - modify_rowwise = debug_api.transformer_engine.modify_tensor_enabled( - layer_name=self.layer_name, - gemm=self.rowwise_gemm_name, - tensor_name=self.tensor_name, - iteration=self.iteration, + modify_rowwise = self.process_enabled_api_call( + debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) ) + if modify_rowwise: rowwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( - layer_name=self.layer_name, - gemm=self.rowwise_gemm_name, - iteration=self.iteration, + fp8_quantize = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.rowwise_gemm_name, + iteration=self.iteration, + ) ) + if fp8_quantize: rowwise_plan = STANDARD_FP8_QUANTIZE if rowwise_plan is None: rowwise_plan = HIGH_PRECISION if self.columnwise_gemm_name is not None: - modify_columnwise = debug_api.transformer_engine.modify_tensor_enabled( - layer_name=self.layer_name, - gemm=self.columnwise_gemm_name, - tensor_name=self.tensor_name, - iteration=self.iteration, + modify_columnwise = self.process_enabled_api_call( + debug_api.transformer_engine.modify_tensor_enabled( + layer_name=self.layer_name, + 
gemm=self.columnwise_gemm_name, + tensor_name=self.tensor_name, + iteration=self.iteration, + ) ) + if modify_columnwise: columnwise_plan = API_CALL_MODIFY else: if self.parent_quantizer is not None: - fp8_quantize = debug_api.transformer_engine.fp8_gemm_enabled( - layer_name=self.layer_name, - gemm=self.columnwise_gemm_name, - iteration=self.iteration, + fp8_quantize = self.process_enabled_api_call( + debug_api.transformer_engine.fp8_gemm_enabled( + layer_name=self.layer_name, + gemm=self.columnwise_gemm_name, + iteration=self.iteration, + ) ) + if fp8_quantize: columnwise_plan = STANDARD_FP8_QUANTIZE if columnwise_plan is None: @@ -229,7 +257,7 @@ def _call_inspect_tensor_api( "layer_name": self.layer_name, "tensor": tensor, "tensor_name": self.tensor_name, - "iteration": debug_api.DEBUG_MANAGER._trainer_iteration_count, + "iteration": TEDebugState.get_iteration(), "tp_group": self.tp_group, } if tensor is not None and self.inspect_tensor_enabled: @@ -270,22 +298,14 @@ def quantize( # 1. If there is fp8 quantization in at least one of the gemms, # the quantization using the self.parent_quantizer is performed. - # rowwise gemm corresponds to the rowwise_usage in fp8, similarly with columnwise - rowwise_gemm_quantize = ( - self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE - ) - columnwise_gemm_quantize = ( - self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE - ) - if columnwise_gemm_quantize and not rowwise_gemm_quantize: - rowwise_gemm_quantize = True # only columnwise quantization not implemented + self._update_parent_quantizer_usage() + # Only columnwise quantization is not supported. + if self.parent_quantizer is not None: + if not self.parent_quantizer.rowwise_usage and self.parent_quantizer.columnwise_usage: + self.parent_quantizer.set_usage(rowwise=True) rowwise_gemm_tensor, columnwise_gemm_tensor = None, None if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: - self.parent_quantizer.set_usage( - rowwise=True, - columnwise=columnwise_gemm_quantize, # columnwise usage only is not supported - ) quantized_tensor = self.parent_quantizer(tensor) # if both rowwise_tensor_plan and columnwise_tensor_plan need to be in fp8, # one tensor with columnwise=True and rowwise=True is computed @@ -341,7 +361,6 @@ def quantize( quantizer=self, layer_name=self.layer_name, tensor_name=self.tensor_name, - original_tensor=tensor, ) def process_gemm_output(self, tensor: torch.Tensor): @@ -375,6 +394,25 @@ def make_empty( return self.parent_quantizer.make_empty(shape, dtype=dtype, device=device) return torch.empty(shape, dtype=dtype, device=device) + def any_feature_enabled(self) -> bool: + """Returns bool if there is at least one API call enabled.""" + if self.output_tensor: + return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + if ( + self.inspect_tensor_enabled + or self.inspect_tensor_postquantize_enabled_rowwise + or self.inspect_tensor_postquantize_enabled_columnwise + or self.rowwise_tensor_plan == API_CALL_MODIFY + or self.columnwise_tensor_plan == API_CALL_MODIFY + ): + return True + if self.parent_quantizer is not None: + if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: + return True + return False + def calibrate(self, tensor: torch.Tensor): """Calibration override, should not be invoked.""" raise RuntimeError("[NVTORCH-INSPECT ERROR] Calibration with debug is not supported") @@ -446,29 +484,70 @@ def 
update_quantized( self._call_inspect_tensor_api(src, dst.rowwise_gemm_tensor, dst.columnwise_gemm_tensor) - def any_feature_enabled(self) -> bool: - """Returns bool if there is at least one API call enabled.""" - if self.output_tensor: - return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY - if ( - self.inspect_tensor_enabled - or self.inspect_tensor_postquantize_enabled_rowwise - or self.inspect_tensor_postquantize_enabled_columnwise - or self.rowwise_tensor_plan == API_CALL_MODIFY - or self.columnwise_tensor_plan == API_CALL_MODIFY - ): - return True - if self.parent_quantizer is not None: - if self.rowwise_tensor_plan != STANDARD_FP8_QUANTIZE: - return True - if self.columnwise_tensor_plan != STANDARD_FP8_QUANTIZE: - return True - return False + def get_next_debug_iter(self) -> Optional[int]: + """ + Returns the next iteration for which the debug is enabled for this tensor. + If the next iteration is None, then the debug is not enabled for this tensor. + """ + return self.next_debug_iter def _get_compatible_recipe(self) -> Union[type[Recipe], None]: """Probably not needed for debug quantizer""" return None + def process_enabled_api_call( + self, enabled_call_output: bool | Tuple[bool, Optional[int]] + ) -> bool: + """ + Process enabled API call output. + Updates self.next_debug_iter field accordingly. + Return the bool representing if the API call is enabled. + """ + if isinstance(enabled_call_output, tuple): + assert len(enabled_call_output) == 2, "Expected a tuple of length 2" + enabled_bool, next_iter = enabled_call_output + else: + enabled_bool = enabled_call_output + next_iter = self.iteration + 1 + + if self.next_debug_iter is None: + self.next_debug_iter = next_iter + elif next_iter is not None: + # If next iter is None, that means that call will never be enabled. + self.next_debug_iter = min(self.next_debug_iter, next_iter) + + return enabled_bool + + def supports_only_rowwise_all_gather(self) -> bool: + if self.parent_quantizer is not None: + return self.parent_quantizer.supports_only_rowwise_all_gather() + return False + + def _update_parent_quantizer_usage(self): + """ + Updates the usage of the parent quantizer. + """ + rowwise_gemm_quantize = ( + self.rowwise_usage and self.rowwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + columnwise_gemm_quantize = ( + self.columnwise_usage and self.columnwise_tensor_plan == STANDARD_FP8_QUANTIZE + ) + + if STANDARD_FP8_QUANTIZE in [self.rowwise_tensor_plan, self.columnwise_tensor_plan]: + self.parent_quantizer.set_usage( + rowwise=rowwise_gemm_quantize, + columnwise=columnwise_gemm_quantize, + ) + + def set_usage(self, rowwise: bool = None, columnwise: bool = None): + """ + Sets the usage of the quantizer. 
+ """ + super().set_usage(rowwise=rowwise, columnwise=columnwise) + if not self.output_tensor: + self._update_parent_quantizer_usage() + class DebugQuantizedTensor(QuantizedTensorBase): """ @@ -484,7 +563,6 @@ def __init__( quantizer, layer_name=None, tensor_name=None, - original_tensor=None, ): self.rowwise_gemm_tensor = rowwise_gemm_tensor @@ -492,7 +570,6 @@ def __init__( self.quantizer = quantizer self._layer_name = layer_name self._tensor_name = tensor_name - self._original_tensor = original_tensor def prepare_for_saving(self): """ " Prepare for saving method override""" @@ -501,6 +578,7 @@ def prepare_for_saving(self): if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor else [self.rowwise_gemm_tensor] ) + tensor_list, tensor_objects_list = prepare_for_saving(*self.tensors_to_save) self.tensors_to_save = tensor_objects_list # pylint: disable=unbalanced-tuple-unpacking @@ -519,6 +597,7 @@ def restore_from_saved(self, tensors): else: self.rowwise_gemm_tensor = tensor_objects_list[0] self.columnwise_gemm_tensor = self.rowwise_gemm_tensor + return saved_tensors def quantize_(self, tensor, *, noop_flag=None): @@ -542,3 +621,27 @@ def size(self): def update_usage(self, rowwise_usage: bool = None, columnwise_usage: bool = None): """Update usage of the tensor.""" + if self.rowwise_gemm_tensor is not self.columnwise_gemm_tensor: + # If the same object is used both for rowwise and columnwise gemms, + # there is no benefit in erasing the usage of one of them. + # And there are scenarios when not deleting the usage of one of them is needed. + # For example when we want to recreate columnwise from rowwise. + if rowwise_usage is False: + self.rowwise_gemm_tensor = None + if columnwise_usage is False: + self.columnwise_gemm_tensor = None + + if isinstance(self.rowwise_gemm_tensor, QuantizedTensor): + self.rowwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + if isinstance(self.columnwise_gemm_tensor, QuantizedTensor): + self.columnwise_gemm_tensor.update_usage(rowwise_usage, columnwise_usage) + + if rowwise_usage and self.rowwise_gemm_tensor is None: + raise RuntimeError( + "Cannot recreate rowwise tensor from columnwise tensor in debug mode." + ) + + if columnwise_usage and self.columnwise_gemm_tensor is None: + raise RuntimeError( + "Cannot recreate columnwise tensor from rowwise tensor is debug mode." + ) diff --git a/transformer_engine/debug/pytorch/debug_state.py b/transformer_engine/debug/pytorch/debug_state.py index 11edb3641..c47e859bb 100644 --- a/transformer_engine/debug/pytorch/debug_state.py +++ b/transformer_engine/debug/pytorch/debug_state.py @@ -62,6 +62,13 @@ def set_weight_tensor_tp_group_reduce(cls, enabled): """Sets weight tensor reduction mode.""" cls.weight_tensor_tp_group_reduce = enabled + @classmethod + def get_iteration(cls): + """Returns the current iteration.""" + import nvdlfw_inspect.api as debug_api + + return debug_api.DEBUG_MANAGER._trainer_iteration_count + def set_weight_tensor_tp_group_reduce(enabled): """Sets weight tensor reduction mode.""" diff --git a/transformer_engine/debug/pytorch/utils.py b/transformer_engine/debug/pytorch/utils.py index 4aea05333..18ed3556f 100644 --- a/transformer_engine/debug/pytorch/utils.py +++ b/transformer_engine/debug/pytorch/utils.py @@ -4,6 +4,25 @@ """Utils functions for the debug module.""" +from typing import Optional + + +def next_iter_when_debug_should_be_run(quantizers) -> Optional[int]: + """ + Returns next iteration at which the debug should be run. 
+ If debug will never be run for this layer, returns None. + """ + + out = None + for q in quantizers: + if q.get_next_debug_iter() is not None: + if out is None: + out = q.get_next_debug_iter() + else: + out = min(out, q.get_next_debug_iter()) + + return out + def any_feature_enabled(quantizers): """Returns True if at least one API call is made from DebugQuantizer.""" diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 868fc3a27..c3b42c5c4 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -981,6 +981,15 @@ def _all_gather_fp8( return out, handle +def _get_quantizer_format(quantizer: Quantizer) -> Optional[bool]: + """Get quantizer format.""" + if isinstance(quantizer, DebugQuantizer): + quantizer = quantizer.parent_quantizer + if isinstance(quantizer, Float8BlockQuantizer): + return quantizer.all_gather_usage + return None + + def _set_quantizer_format(quantizer: Quantizer, compact: bool = False) -> None: """Make quantizer compact""" _quantizer = quantizer @@ -1343,6 +1352,44 @@ def gather_along_first_dim( inp = quantizer(inp) return inp, None + # Debug case - call gather_along_first_dim on each tensor + if isinstance(inp, DebugQuantizedTensor): + out_obj = DebugQuantizedTensor( + rowwise_gemm_tensor=inp.rowwise_gemm_tensor, + columnwise_gemm_tensor=inp.columnwise_gemm_tensor, + quantizer=inp.quantizer, + layer_name=inp._layer_name, + tensor_name=inp._tensor_name, + ) + rowwise = inp.get_tensor(False) + columnwise = inp.get_tensor(True) + # shapes + final_quantizer = ( + None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer + ) + rowwise_total = None + if rowwise is not None: + rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[ + 0 + ] + out_obj.rowwise_gemm_tensor = rowwise_total + if rowwise is not columnwise: + final_quantizer_columnwise = ( + None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer + ) + columnwise_total = None + if columnwise is not None: + columnwise_total, _ = gather_along_first_dim( + columnwise, process_group, False, final_quantizer_columnwise + ) + out_obj.columnwise_gemm_tensor = columnwise_total + else: + # Sometimes the same object is used both for rowwise and columnwise gemms, + # and we want to avoid double all-gathers. 
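[Editor's illustration, not part of the patch] Stepping back to the `next_iter_when_debug_should_be_run` helper added in debug/pytorch/utils.py above: it returns the earliest pending debug iteration across a layer's quantizers, or None when none of them will ever need the debug path again. A minimal sketch of its behaviour, using a stub object in place of a real DebugQuantizer:

    from transformer_engine.debug.pytorch.utils import next_iter_when_debug_should_be_run

    class _StubQuantizer:
        """Stand-in exposing only get_next_debug_iter(), like DebugQuantizer does."""
        def __init__(self, next_iter):
            self._next_iter = next_iter
        def get_next_debug_iter(self):
            return self._next_iter

    # Earliest pending debug step wins; quantizers reporting None are ignored.
    assert next_iter_when_debug_should_be_run(
        [_StubQuantizer(None), _StubQuantizer(12), _StubQuantizer(8)]
    ) == 8
    # If every quantizer reports None, the layer never needs the debug path again.
    assert next_iter_when_debug_should_be_run([_StubQuantizer(None)]) is None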
+ out_obj.columnwise_gemm_tensor = out_obj.rowwise_gemm_tensor + + return out_obj, None + # Output tensor dims out_shape = list(inp.size()) out_shape[0] *= world_size @@ -1380,34 +1427,6 @@ def gather_along_first_dim( out_shape=out_shape, ) - # Debug case - call gather_along_first_dim on each tensor - if isinstance(inp, DebugQuantizedTensor): - out_obj = inp - rowwise = inp.get_tensor(False) - columnwise = inp.get_tensor(True) - final_quantizer = ( - None if not needs_quantized_gemm(inp, rowwise=True) else quantizer.parent_quantizer - ) - # Temporary fix for TP communication of Float8BlockwiseQTensorBase - if isinstance(rowwise, Float8BlockwiseQTensorBase): - rowwise = inp._original_tensor - rowwise_total = gather_along_first_dim(rowwise, process_group, False, final_quantizer)[0] - out_obj.rowwise_gemm_tensor = rowwise_total - if rowwise is not columnwise: - final_quantizer_columnwise = ( - None if not needs_quantized_gemm(inp, rowwise=False) else quantizer.parent_quantizer - ) - # Temporary fix for TP communication of Float8BlockwiseQTensorBase - if isinstance(columnwise, Float8BlockwiseQTensorBase): - columnwise = inp._original_tensor - columnwise_total, _ = gather_along_first_dim( - columnwise, process_group, False, final_quantizer_columnwise - ) - out_obj.columnwise_gemm_tensor = columnwise_total - else: - out_obj.rowwise_gemm_tensor = out_obj.rowwise_gemm_tensor - return out_obj, None - # High-precision communication for quantized tensors if quantizer is not None: warnings.warn( @@ -1418,6 +1437,7 @@ def gather_along_first_dim( inp = inp.dequantize() # Falling back to high-precision all-gather for Float8BlockQuantizer # means that it should directly output GEMM_READY format + compact = _get_quantizer_format(quantizer) _set_quantizer_format(quantizer, compact=False) out = torch.empty( out_shape, @@ -1427,6 +1447,7 @@ def gather_along_first_dim( ) torch.distributed.all_gather_into_tensor(out, inp, group=process_group) out = quantizer(out) + _set_quantizer_format(quantizer, compact=compact) return out, None # Dequantize quantized tensor if not supported diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b0da6e5fc..b28b9db98 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -47,6 +47,7 @@ from ...common.recipe import DelayedScaling, Recipe from ...debug.pytorch.debug_state import TEDebugState from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor +from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled __all__ = ["initialize_ub", "destroy_ub"] @@ -564,6 +565,7 @@ def __init__(self) -> None: super().__init__() assert torch.cuda.is_available(), "TransformerEngine needs CUDA." self.name = None + self.next_iter_when_debug_should_be_run = 0 self.fp8_initialized = False self.fp8 = False self.fp8_calibration = False @@ -1416,12 +1418,55 @@ def backward_dw(self): for wgrad_accumulation_and_reduce_hook in self.wgrad_accumulation_and_reduce_hooks: wgrad_accumulation_and_reduce_hook() + def is_debug_iter(self) -> bool: + """ + This function checks if the debug should be enabled for this layer. 
+ """ + debug = TEDebugState.debug_enabled + if not debug: + return False + self._validate_name() + + # If layer is run first time in new iteration, + # we need to check if the debug should be enabled for this layer - + # maybe in previous iterations debug features returned information + # that no feature will be active for this layer for multiple next iterations. + started_new_iteration = TEDebugState.get_iteration() != getattr( + self, "debug_last_iteration", None + ) + if started_new_iteration: + if self.next_iter_when_debug_should_be_run is None: + debug = False + else: + debug = TEDebugState.get_iteration() >= self.next_iter_when_debug_should_be_run + self.debug_last_iteration = TEDebugState.get_iteration() + return debug + + def no_debug_features_active(self, quantizers): + """ + Checks if any debug feature is active for this layer. + """ + run_current = any_feature_enabled(quantizers) + + # Sometimes features inform that they will not be enabled for particular layer + # for multiple next iterations. + self.next_iter_when_debug_should_be_run = next_iter_when_debug_should_be_run(quantizers) + + if not run_current: + return True + + if self.primary_weights_in_fp8: + raise RuntimeError("FP8 weights are not supported in debug mode.") + return False + def _validate_name(self): """ Validate name passed to the module. This is invoked in the forward() method as module names are assigned after Model is initialized in Megatron-LM. If no name is assigned, it creates a default name with layer count as the variable. """ + if self.name is not None: + return assert TEDebugState.debug_enabled import nvdlfw_inspect.api as debug_api @@ -1470,29 +1515,3 @@ def _check_weight_tensor_recipe_correspondence(self) -> None: " Please check the recipes assigned during fp8_model_init() and" " fp8_autocast() calls." ) - - def _turn_off_unsupported_features_in_debug(self): - if ( - getattr(self, "ub_bulk_wgrad", False) - or getattr(self, "ub_bulk_dgrad", False) - or getattr(self, "ub_overlap_ag", False) - or getattr(self, "ub_overlap_rs_dgrad", False) - or getattr(self, "ub_overlap_rs", False) - ): - import nvdlfw_inspect.api as debug_api - - debug_api.log_message( - "UserBuffers are not supported in debug module. " - "Using UB optimization will not affect the debug module. 
", - level=logging.WARNING, - ) - if hasattr(self, "ub_bulk_wgrad"): - self.ub_bulk_wgrad = None - if hasattr(self, "ub_bulk_dgrad"): - self.ub_bulk_dgrad = None - if hasattr(self, "ub_overlap_ag"): - self.ub_overlap_ag = None - if hasattr(self, "ub_overlap_rs_dgrad"): - self.ub_overlap_rs_dgrad = None - if hasattr(self, "ub_overlap_rs"): - self.ub_overlap_rs = None diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 5e45b5c25..1413dd172 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -62,9 +62,7 @@ restore_from_saved, ) from ...debug.pytorch.debug_state import TEDebugState -from ...debug.pytorch.utils import any_feature_enabled from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer -from ..tensor.float8_tensor import Float8CurrentScalingQuantizer, Float8Quantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase @@ -162,6 +160,13 @@ def forward( with_input_all_gather = parallel_mode == "column" and sequence_parallel # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_ag_fprop = False + ub_overlap_rs_fprop = False + ub_overlap_ag_dgrad = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False ub_obj = None ub_type = None ub_overlap_ag_fprop = ( @@ -179,9 +184,7 @@ def forward( if input_quantizer is None: raise ValueError("Missing quantizer for input tensor") input_quantizer.set_usage(rowwise=True, columnwise=backward_needs_input) - if with_input_all_gather and isinstance( - input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ): + if with_input_all_gather and input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data input_quantizer.set_usage(columnwise=False) @@ -638,7 +641,7 @@ def backward( quantizer = None if ctx.input_quantizer is not None: quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) else: @@ -1163,8 +1166,6 @@ def __init__( self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) self.name = name - if TEDebugState.debug_enabled: - self._turn_off_unsupported_features_in_debug() # turn off userbuffers if tp_group is None: self.tp_size = tp_size @@ -1471,9 +1472,8 @@ def forward( """ if is_in_onnx_export_mode(): return self.onnx_forward(inp, fp8_output) - debug = TEDebugState.debug_enabled - if debug: - self._validate_name() + + debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1504,13 +1504,9 @@ def forward( else self._get_debug_quantizers(fp8_output, fp8_grad) ) if debug: - if not any_feature_enabled(quantizers): - # If no feature is used, then run faster implementation with debug = False. 
- quantizers = self._get_quantizers(fp8_output, fp8_grad) + if self.no_debug_features_active(quantizers): debug = False - - if isinstance(weight_tensor, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + quantizers = self._get_quantizers(fp8_output, fp8_grad) ( input_quantizer, diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 31ba65478..4149ab73c 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -68,7 +68,6 @@ from ._common import apply_normalization, WeightGradStore from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ..tensor.quantized_tensor import ( - QuantizedTensor, QuantizedTensorBase, Quantizer, prepare_for_saving, @@ -78,7 +77,6 @@ general_gemm, ) from ..export import is_in_onnx_export_mode, assert_warmed_up -from ...debug.pytorch.utils import any_feature_enabled from ...debug.pytorch.debug_state import TEDebugState __all__ = ["LayerNormMLP"] @@ -223,6 +221,12 @@ def forward( device = inp.device # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_ag = False + ub_overlap_rs = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False ub_overlap_ag = ub_overlap_ag and is_grad_enabled and not return_layernorm_output_gathered ub_overlap_rs = ub_overlap_rs and is_grad_enabled @@ -238,9 +242,7 @@ def forward( if fc1_input_quantizer is None: raise ValueError("Missing quantizer for FC1 input tensor") fc1_input_quantizer.set_usage(rowwise=True, columnwise=backwards_needs_fc1_input) - if sequence_parallel and isinstance( - fc1_input_quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) - ): + if sequence_parallel and fc1_input_quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data fc1_input_quantizer.set_usage(columnwise=False) @@ -1523,9 +1525,6 @@ def __init__( ) self.name = name - if TEDebugState.debug_enabled: - self._turn_off_unsupported_features_in_debug() # turn off userbuffers - self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if tp_group is None: @@ -1728,9 +1727,8 @@ def forward( """ if is_in_onnx_export_mode(): return self.onnx_forward(inp) - debug = TEDebugState.debug_enabled - if debug: - self._validate_name() + + debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1754,12 +1752,9 @@ def forward( else self._get_debug_quantizers(fp8_output) ) if debug: - if not any_feature_enabled(quantizers): - quantizers = self._get_quantizers(fp8_output) + if self.no_debug_features_active(quantizers): debug = False - - if isinstance(self.fc1_weight, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + quantizers = self._get_quantizers(fp8_output) # Get quantizers ( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index a5dae9f30..c5dc43e6e 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -68,7 +68,6 @@ from ..export import is_in_onnx_export_mode, assert_warmed_up from ..cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...debug.pytorch.debug_state import TEDebugState -from ...debug.pytorch.utils import any_feature_enabled __all__ = 
["Linear"] @@ -137,6 +136,12 @@ def forward( ) # Configure Userbuffers communication (comm+GEMM overlap) + if debug: # turn off userbuffers in debug mode + ub_overlap_rs_fprop = False + ub_overlap_ag_fprop = False + ub_overlap_rs_dgrad = False + ub_bulk_wgrad = False + ub_bulk_dgrad = False ub_obj = None ub_type = None if ub_overlap_rs_fprop: @@ -356,8 +361,9 @@ def forward( and own_quantized_input and isinstance(inputmat, QuantizedTensorBase) ): - if ctx.backward_input_needs_gather and isinstance( - quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer) + if ( + ctx.backward_input_needs_gather + and weight_quantizer.supports_only_rowwise_all_gather() ): # All-gather is not supported with FP8 column-wise data inputmat.update_usage(rowwise_usage=True, columnwise_usage=False) @@ -589,7 +595,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], else: # Quantize input tensor quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + if quantizer.supports_only_rowwise_all_gather(): # All-gather is not supported with FP8 column-wise data quantizer.set_usage( rowwise=True, @@ -607,7 +613,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], quantizer = None if ctx.fp8 or ctx.debug: quantizer = ctx.input_quantizer - if isinstance(quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer)): + if quantizer.supports_only_rowwise_all_gather(): # If data is in FP8, we compute FP8 transposes manually quantizer.set_usage(rowwise=True, columnwise=False) else: @@ -1077,9 +1083,6 @@ def __init__( self.save_original_input = save_original_input self.name = name - if TEDebugState.debug_enabled: - self._turn_off_unsupported_features_in_debug() # turn off userbuffers - self.wgrad_store = WeightGradStore(delay_wgrad_compute, ub_bulk_wgrad) if device == "meta": @@ -1341,9 +1344,7 @@ def forward( if is_in_onnx_export_mode(): return self.onnx_forward(inp, fp8_output) - debug = TEDebugState.debug_enabled - if debug: - self._validate_name() + debug = self.is_debug_iter() if FP8GlobalStateManager.fp8_graph_capturing(): skip_fp8_weight_update = FP8GlobalStateManager.get_skip_fp8_weight_update_tensor() @@ -1373,14 +1374,11 @@ def forward( if not debug else self._get_debug_quantizers(fp8_output, fp8_grad) ) + if debug: - if not any_feature_enabled(quantizers): - # If no feature is used, then run faster implementation with debug = False. 
- quantizers = self._get_quantizers(fp8_output, fp8_grad) + if self.no_debug_features_active(quantizers): debug = False - - if isinstance(weight_tensor, QuantizedTensor): - raise RuntimeError("FP8 weights are not supported in debug mode.") + quantizers = self._get_quantizers(fp8_output, fp8_grad) ( input_quantizer, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index 895e68bf0..acc03ba78 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -184,6 +184,12 @@ def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return DelayedScaling + def supports_only_rowwise_all_gather(self) -> bool: + """ + Float8Quantizer supports only rowwise all-gather + """ + return True + class Float8CurrentScalingQuantizer(Quantizer): """Builder class for FP8 tensors with per-tensor current scaling @@ -361,6 +367,12 @@ def _canonicalized_amax_reduction_group(self) -> dist_group_type: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: return Float8CurrentScaling + def supports_only_rowwise_all_gather(self) -> bool: + """ + Float8CurrentScalingQuantizer supports only rowwise all-gather + """ + return True + class Float8Tensor(Float8TensorBase, QuantizedTensor): """Experimental tensor class with FP8 data diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index d28b1583b..656eda46c 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -260,6 +260,10 @@ def onnx_dequantize(self, tensor) -> torch.Tensor: def _get_compatible_recipe(self) -> Union[type[Recipe], None]: """Returns recipe class that is compatible with this quantizer""" + def supports_only_rowwise_all_gather(self) -> bool: + """Returns True if the quantizer supports only rowwise all-gather""" + return False + class _QuantizeFunc(torch.autograd.Function): """Cast to FP8 from other dtype""" From 235c8d0008ce9688807605d7e26e9ce22fad5356 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 8 Aug 2025 18:50:39 -0400 Subject: [PATCH 056/153] [JAX] Enable TE GEMM custom call for all recipes (#2047) * enabled TE GEMM for all recipes Signed-off-by: Phuong Nguyen * add warnings Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix lint Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/layernorm_mlp.py | 23 +++++++++++++++++++ transformer_engine/jax/quantize/helper.py | 3 --- 3 files changed, 24 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index fcc2108cc..22842e4f3 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -34,7 +34,7 @@ class BasePrimitive(metaclass=ABCMeta): _is_enabled = True # Default list of primitives to disable for all recipes - _default_disable_names = ["GemmPrimitive"] + _default_disable_names = [] @classmethod def enabled(cls): diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8727ea7e3..ce3ebc78a 100644 --- 
a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -15,6 +15,7 @@ from typing import List, Tuple, Sequence, Union, Callable from functools import partial +import warnings import jax import jax.numpy as jnp @@ -92,6 +93,28 @@ def layernorm_mlp( """ assert len(kernels) == 2 + # For MaxText TP (= Megatron TP + sharding in hidden dimension of remaining unsharded + # activations), JAX dot_general may perform better then TE GEMM custom call + # This inspection only works if either norm_input_axes or dot_1_input_axes is set + is_mxfp8 = ( + False + if quantizer_sets[0] == noop_quantizer_set + else quantizer_sets[0].x.scaling_mode.is_1d_block_scaling() + ) + inspect_axes = norm_input_axes or dot_1_input_axes + if ( + inspect_axes is not None + and len(inspect_axes) == x.ndim + and inspect_axes[-1] is not None + and not is_mxfp8 + ): + warnings.warn( + "Detected sharding in the hidden dimension of the MLP activation input. For improved" + " performance, consider using JAX’s built-in `dot_general` implementation. To try" + " this, set the environment variable: `NVTE_JAX_CUSTOM_CALLS='GemmPrimitive=false'`", + UserWarning, + ) + kernel_1 = kernels[0] kernel_2 = kernels[1] bias_1 = biases[0] diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index e31f1852b..122265ea2 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -352,9 +352,6 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: cls.initialize(fp8_recipe) cls.AMAX_HISTORY_LEN = 0 - # Use TE GEMM instead of JAX GEMM for better performance - tex.base.manage_primitives(enable_names=["GemmPrimitive"]) - @staticmethod def finalize() -> None: """Reset the block scaling configuration.""" From 077e26c319d0fffbff75e56124b26b8b04aac0e5 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Sat, 9 Aug 2025 13:11:39 +1200 Subject: [PATCH 057/153] Use userbuffers for MXFP8 wgrad all-gather overlap (#1982) * fix: Add stream synchronization before destroying MPI communicator (#1979) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * feat: Implement column-wise userbuffer overlap for comm+GEMM operations Add support for overlapping column-wise allgather communication with GEMM operations to improve training performance: * **Core infrastructure changes:** - Update bulk_overlap_columnwise_ag() to accept explicit stream parameter - Modify userbuffers send/recv loops to use rank-ordered iteration - Add userbuffers_send_all/recv_all function declarations * **Python integration:** - Add bulk_overlap_ag_with_external_gemm() C++ extension function - Expose new overlap function via pybind11 bindings - Update overlap method configurations to include more ring_exchange ops * **LayerNorm MLP optimization:** - Enable column-wise quantization for FC2 gradient output - Implement overlap of allgather communication with FC2 DGRAD GEMM - Use fill_userbuffers_buffer_for_all_gather for efficient buffering This optimization allows overlapping communication and computation phases more effectively, reducing training wall-clock time by hiding allgather latency behind GEMM execution. 
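For reference, a minimal sketch (not the in-tree implementation) of the overlap pattern this change adds to the PyTorch module backward passes, assuming Userbuffers have already been initialized with the corresponding "*_dgrad" (ring_exchange) and "*_wgrad" (external) buffers; the helper name `overlap_grad_output_allgather` is illustrative only, the calls themselves mirror the diffs below:

    # Sketch only: reuse the dgrad GEMM's communication streams so the
    # grad-output all-gather is hidden behind the external GEMM.
    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.module.base import (
        fill_userbuffers_buffer_for_all_gather,
        get_ub,
    )

    def overlap_grad_output_allgather(ub_obj_dgrad, ub_name, grad_output, quantizer, tp_group):
        # Send/recv streams of the already-launched dgrad GEMM.
        send_stream, recv_stream = ub_obj_dgrad.get_communication_stream()
        ub_obj_wgrad = get_ub(ub_name + "_wgrad")
        quantizer.set_usage(rowwise=False, columnwise=True)
        # Stage the local shard into the wgrad userbuffer on the send stream.
        with torch.cuda.stream(send_stream):
            gathered, _ = fill_userbuffers_buffer_for_all_gather(
                ub_obj_wgrad, grad_output, quantizer, tp_group
            )
        # Ring all-gather on the dgrad streams, overlapped with the external GEMM.
        tex.bulk_overlap_ag_with_external_gemm(ub_obj_wgrad, send_stream, recv_stream)
        return gathered
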
Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * fix: Working userbuffer overlapping API Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * fix: Fix overwriting bulk overlap UB object for layernormLinear Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * fix: Update external overlap to use tp size instead of nvsize to determine number of copies Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * fix: Fix linter error Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * fix: Explanatory comments of overlap logic Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * fix: Fix the UB fused ops tests Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> * fix: Fix linter errors Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --------- Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../distributed/run_layer_with_overlap.py | 1 + .../comm_gemm_overlap/comm_gemm_overlap.cpp | 25 +++++++++++++ .../userbuffers/userbuffers.cu | 24 ++++++++++++ .../userbuffers/userbuffers.h | 8 ++++ .../transformer_engine/comm_gemm_overlap.h | 20 +++++++++- .../common/util/pybind_helper.h | 4 +- transformer_engine/pytorch/csrc/extensions.h | 15 +++++++- .../csrc/extensions/comm_gemm_overlap.cpp | 18 +++++++-- .../pytorch/csrc/extensions/pybind.cpp | 7 ++++ transformer_engine/pytorch/module/base.py | 29 +++++++++++++-- .../pytorch/module/layernorm_linear.py | 33 +++++++++++------ .../pytorch/module/layernorm_mlp.py | 35 ++++++++++++------ transformer_engine/pytorch/module/linear.py | 34 +++++++++++------ .../ops/fused/userbuffers_backward_linear.py | 37 ++++++++++++------- 14 files changed, 228 insertions(+), 62 deletions(-) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 8638c1bce..2fc4537f0 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -519,6 +519,7 @@ def run_fwd_bwd(model, x): if opts.use_cuda_graphs: del test_graph + torch.cuda.synchronize() te.module.base.destroy_ub() dist_print("Destroying Userbuffers objects...", debug=True) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 38a6e3e61..9ba6688ce 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -138,6 +138,11 @@ CommOverlapCore::~CommOverlapCore() { cudaStreamDestroy(_stream_compute[i]); } + auto error = cudaGetLastError(); + if (error != cudaSuccess) { + NVTE_WARN("Error detected while destroying communicator: ", cudaGetErrorString(error)); + } + if (_comm_created) { try { #ifdef NVTE_UB_WITH_MPI @@ -289,6 +294,7 @@ CommOverlapBase::CommOverlapBase(const std::vector &buffer_shape, DType CommOverlapBase::~CommOverlapBase() { cudaEventDestroy(_start_d2dcopy); + cudaStreamSynchronize(_stream_comm); cudaStreamDestroy(_stream_comm); } @@ -591,6 +597,25 @@ void CommOverlapBase::split_overlap_rs(const TensorWrapper &A, bool transa, cons NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } // CommOverlapBase::split_overlap_rs +void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + int comm_bytes = _ubuf.bytes(); + 
int comm_bytes_per_rank = comm_bytes / _tp_size; + + // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush + userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, + send_stream); + userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, + recv_stream); + + for (auto stream : {send_stream, recv_stream}) { + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); + // We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf + NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); + } +} + /*************************************************************************************************** * Comm+GEMM Overlap P2P Base (Ring-Exchange) **************************************************************************************************/ diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 1211392e4..893644ce6 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2535,6 +2535,30 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds } } +void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, communicator *comm, cudaStream_t stream) { + for (int j = 1; j < tp_size; j++) { + int i = (tp_rank + j) % tp_size; + int send_offset = srcoffset + bytes_per_slice * tp_rank; + int recv_offset = dstoffset + bytes_per_slice * tp_rank; + userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, + stream); + } +} + +void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, communicator *comm, cudaStream_t stream) { + for (int j = tp_size - 1; j > 0; j--) { + int i = (tp_rank + j) % tp_size; + int send_offset = srcoffset + bytes_per_slice * i; + int recv_offset = dstoffset + bytes_per_slice * i; + userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, + stream); + } +} + // producer static __global__ void producer_kernel(void *atomic_ptr, int chunk_i) { // Decrement atomic val to signal current output tile finish diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 03e45b978..34d6ff72f 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -304,4 +304,12 @@ void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inp void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cudaStream_t stream); +void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, + const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, communicator *comm, cudaStream_t stream); + +void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, + 
const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, + int tp_size, communicator *comm, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h index 293c57526..4d65e26ce 100644 --- a/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h +++ b/transformer_engine/common/include/transformer_engine/comm_gemm_overlap.h @@ -36,7 +36,8 @@ enum class CommOverlapAlgo { SPLIT_PIPELINED_RS_P2P = 4, ATOMIC_GEMM_RS = 5, ATOMIC_GEMM_AG_P2P = 6, - ATOMIC_GEMM_RS_P2P = 7 + ATOMIC_GEMM_RS_P2P = 7, + EXTERNAL_BULK_OVERLAP_AG = 8, }; class CommOverlapCore { @@ -133,6 +134,11 @@ class CommOverlapCore { cudaStream_t stream_main) { NVTE_ERROR("Operation is not implemented."); } + + virtual void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) { + NVTE_ERROR("Operation is not implemented."); + } }; // CommOverlapCore class CommOverlapBase : public CommOverlapCore { @@ -198,6 +204,9 @@ class CommOverlapBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + + void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) override; }; // CommOverlapBase class CommOverlapP2PBase : public CommOverlapCore { @@ -277,6 +286,15 @@ class CommOverlapP2PBase : public CommOverlapCore { TensorWrapper &workspace, bool grad, bool accumulate, bool use_split_accumulator, TensorWrapper &rs_output, cudaStream_t stream_main) override; + + /* + ** This function overlaps the AG for the current communicator object with the GEMM for the overlap_gemm object. + ** The gemm for overlap_gemm is assumed to have been previously started. 
+ */ + void bulk_overlap_external_ag(cudaStream_t send_stream, cudaStream_t recv_stream, + cudaStream_t stream_main) override { + NVTE_ERROR("Operation not supported."); + } }; // CommOverlapP2PBase } // namespace transformer_engine diff --git a/transformer_engine/common/util/pybind_helper.h b/transformer_engine/common/util/pybind_helper.h index a1cd85ba2..67d21f618 100644 --- a/transformer_engine/common/util/pybind_helper.h +++ b/transformer_engine/common/util/pybind_helper.h @@ -94,7 +94,9 @@ transformer_engine::CommOverlapAlgo::SPLIT_PIPELINED_RS_P2P) \ .value("ATOMIC_GEMM_RS", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS) \ .value("ATOMIC_GEMM_AG_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_AG_P2P) \ - .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P); \ + .value("ATOMIC_GEMM_RS_P2P", transformer_engine::CommOverlapAlgo::ATOMIC_GEMM_RS_P2P) \ + .value("EXTERNAL_BULK_OVERLAP_AG", \ + transformer_engine::CommOverlapAlgo::EXTERNAL_BULK_OVERLAP_AG); \ py::class_>(m, "CommOverlapCore", \ pybind11::module_local()) \ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 0b2ace76a..1f2460cbf 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -11,6 +11,10 @@ #include "common.h" +class CommOverlapHelper; +class CommOverlap; +class CommOverlapP2P; + namespace transformer_engine::pytorch { /*************************************************************************************************** @@ -419,6 +423,13 @@ void nvshmem_wait_on_current_stream(at::Tensor signal, const std::string &wait_k void nvshmem_finalize(); +/*************************************************************************************************** + * Comm+GEMM Overlap Wrappers + **************************************************************************************************/ + +void bulk_overlap_ag_with_external_gemm(CommOverlap &allgather_communicator, at::Stream send_stream, + at::Stream recv_stream); + } // namespace transformer_engine::pytorch /*************************************************************************************************** @@ -468,7 +479,7 @@ class CommOverlap : torch::CustomClassHolder, public transformer_engine::CommOve at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); - at::Stream get_communication_stream(); + std::pair get_communication_stream(); }; // CommOverlap @@ -489,7 +500,7 @@ class CommOverlapP2P : torch::CustomClassHolder, public transformer_engine::Comm at::Tensor get_buffer(bool local_chunk = false, std::optional> shape = std::nullopt); - at::Stream get_communication_stream(); + std::pair get_communication_stream(); }; // CommOverlapP2P diff --git a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp index 0e7bca25b..38947c5a9 100644 --- a/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp +++ b/transformer_engine/pytorch/csrc/extensions/comm_gemm_overlap.cpp @@ -216,8 +216,10 @@ at::Tensor CommOverlap::get_buffer(bool local_chunk, std::optional CommOverlap::get_communication_stream() { + // Return the same stream for both send and recv + return {at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device()), + at::cuda::getStreamFromExternal(_stream_comm, at::cuda::current_device())}; } 
/*************************************************************************************************** @@ -305,6 +307,14 @@ at::Tensor CommOverlapP2P::get_buffer(bool local_chunk, std::optional CommOverlapP2P::get_communication_stream() { + return {at::cuda::getStreamFromExternal(_stream_send[0], at::cuda::current_device()), + at::cuda::getStreamFromExternal(_stream_recv, at::cuda::current_device())}; +} + +void transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm( + CommOverlap &allgather_communicator, at::Stream send_stream, at::Stream recv_stream) { + auto main_stream = at::cuda::getCurrentCUDAStream(); + allgather_communicator.bulk_overlap_external_ag(at::cuda::CUDAStream(send_stream), + at::cuda::CUDAStream(recv_stream), main_stream); } diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index af06bb9fb..dceaa5b15 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -374,6 +374,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda, "Fused compute scale and scale_inv from amax", py::call_guard()); + // Comm+GEMM Overlap + m.def("bulk_overlap_ag_with_external_gemm", + &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm, + "Bulk overlap All-Gather with a GEMM operation launched by another communicator", + py::call_guard(), py::arg("allgather_communicator"), + py::arg("send_stream"), py::arg("recv_stream")); + // Data structures py::class_(m, "FP8TensorMeta") .def(py::init<>()) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index b28b9db98..5d04b29f7 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -151,7 +151,7 @@ def initialize_ub( ``` for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", - "fc2_fprop", "fc2_dgrad"]`. + "fc2_fprop", "fc2_wgrad"]`. bootstrap_backend : str = None `torch.distributed` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are @@ -250,22 +250,31 @@ def initialize_ub( "qkv_fprop", "qkv_dgrad", "proj_dgrad", + "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", + "fc2_wgrad", ] layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] dgrad_reduce_scatter_overlap = ["qkv_dgrad", "fc1_dgrad"] # Default overlap methods for layers methods = { - "ring_exchange": ["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], + "ring_exchange": [ + "qkv_fprop", + "fc1_fprop", + "proj_dgrad", + "fc2_dgrad", + ], "pipeline": ["proj_fprop", "fc2_fprop"], "bulk": ["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], + "external": ["proj_wgrad", "fc2_wgrad"], } # AG-RS overlap pairs of layers forming a tensor-parallel block ag_rs_pairs = {"qkv_fprop": "proj_fprop", "fc1_fprop": "fc2_fprop"} rs_ag_pairs = {v: k for k, v in ag_rs_pairs.items()} + external_gemm_to_overlap = {"proj_wgrad": "proj_dgrad", "fc2_wgrad": "fc2_dgrad"} global layers_atomic_ring_exchange layers_atomic_ring_exchange = [] @@ -319,7 +328,7 @@ def add_ub( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." 
- if method == "bulk": + if method in ("bulk", "external"): warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." "Defaulting to `atomic_gemm=False`." @@ -348,6 +357,16 @@ def add_ub( if atomic_gemm and method == "ring_exchange": assert rs_ag_pairs[name] in layers_atomic_ring_exchange, assert_message + if name in external_gemm_to_overlap: + assert method == "external", ( + f"At {name}, `external` overlap method is specified, but the selected method is" + f" {method}" + ) + assert external_gemm_to_overlap[name] in methods["ring_exchange"], ( + f"At {name}, `external` overlap method is specified, but the external gemm" + f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" + ) + buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype if method == "ring_exchange": ub_obj = tex.CommOverlapP2P( @@ -396,7 +415,9 @@ def add_ub( new_method = ub_cfgs[name]["method"] methods[new_method].append(name) - for name in methods["ring_exchange"] + methods["pipeline"] + methods["bulk"]: + for name in ( + methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] + ): ub_cfg = get_default_config(name) if ub_cfgs is not None and name in ub_cfgs: fp8_buf = (name in layers_all_gather_overlap) or ( diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 1413dd172..04e3eba7d 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -758,27 +758,36 @@ def backward( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() + + # This object is separate from the ub_obj_wgrad object which is passed to the GEMM + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad") + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - grad_output, mxfp8_grad_output_work = gather_along_first_dim( + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. 
+ with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, grad_outputs[0], + ctx.grad_output_quantizer, ctx.tp_group, - async_op=True, - quantizer=ctx.grad_output_quantizer, ) - # Synchronize with the main stream - mxfp8_grad_output_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) # Prepare input tensor # Note: Synchronize tensor-parallel communication and diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 4149ab73c..c384dc3a7 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -851,26 +851,37 @@ def backward( # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.ub_overlap_ag and isinstance(ctx.fc2_grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ( + ub_obj_fc2_dgrad.get_communication_stream() + ) + + ub_obj_fc2_wgrad = get_ub("fc2_wgrad") + ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_obj_fc2_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - grad_output, mxfp8_fc2_grad_output_work = gather_along_first_dim( + + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. 
+ with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_fc2_wgrad, grad_outputs[0], + ctx.fc2_grad_output_quantizer, ctx.tp_group, - async_op=True, - quantizer=ctx.fc2_grad_output_quantizer, ) - # Synchronize with the main stream - mxfp8_fc2_grad_output_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_fc2_wgrad, dgrad_send_stream, dgrad_recv_stream + ) # Prepare input tensor # Note: Synchronize tensor-parallel communication and diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index c5dc43e6e..8b05e71d7 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -745,26 +745,36 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Note: Synchronize tensor-parallel communication and # make sure required data is available if ctx.ub_overlap_ag and isinstance(ctx.grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() + + # This object is separate from the ub_obj_wgrad object which is passed to the GEMM + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad") + ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_obj_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - grad_output, grad_output_work = gather_along_first_dim( + + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. 
+ with torch.cuda.stream(dgrad_send_stream): + grad_output, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, grad_output_arg, + ctx.grad_output_quantizer, ctx.tp_group, - async_op=True, - quantizer=ctx.grad_output_quantizer, ) - # Synchronize with the main stream - grad_output_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + tex.bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) if ctx.fp8 or ctx.debug: if isinstance(grad_output, QuantizedTensorBase): diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index b8acb02e3..54a4d49db 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -10,9 +10,9 @@ import torch -from transformer_engine_torch import CommOverlapType +from transformer_engine_torch import CommOverlapType, bulk_overlap_ag_with_external_gemm from ...cpp_extensions import general_gemm -from ...distributed import gather_along_first_dim, get_distributed_world_size +from ...distributed import get_distributed_world_size from ...module.base import ( fill_userbuffers_buffer_for_all_gather, get_ub, @@ -398,26 +398,35 @@ def _functional_backward( # Initialize grad output if tensor_parallel_mode == "row" and isinstance(grad_output_quantizer, MXFP8Quantizer): - # UB does not support overlapping grad output + # UB does not support pipelined overlapping grad output # all-gather with wgrad GEMM. Also, we can't # convert row-scaled MXFP8 to column-scaled, so we # can't reuse the grad output that was gathered # for the dgrad GEMM. We work around by explicitly - # overlapping the NCCL operation with the dgrad GEMM. + # overlapping the AG operation with the dgrad GEMM. + + # Get the communication stream from the dgrad GEMM to use for the AG + dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream() + + ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad") + grad_output_quantizer.set_usage(rowwise=False, columnwise=True) - # Get the communication stream from the dgrad GEMM and set it as the current torch stream - dgrad_comm_stream = ub_comm_dgrad.get_communication_stream() - with torch.cuda.stream(dgrad_comm_stream): - # Syncs with the current stream (dgrad_comm_stream) before starting the all-gather - # This ensures that we don't start until all communication for the dgrad GEMM is complete - dy, dy_work = gather_along_first_dim( + + # We use the send stream to copy into the userbuffers. + # This is the same stream that we will use to access the data in the AG, + # so we dont need to add any syncs yet. 
+ with torch.cuda.stream(dgrad_send_stream): + dy, _ = fill_userbuffers_buffer_for_all_gather( + ub_obj_overlap_wgrad, dy_local, + grad_output_quantizer, tensor_parallel_group, - async_op=True, - quantizer=grad_output_quantizer, ) - # Synchronize with the main stream - dy_work.wait() + + # Allgather grad_outputs[0] using the dgrad streams so we can overlap with the fc2_dgrad gemm + bulk_overlap_ag_with_external_gemm( + ub_obj_overlap_wgrad, dgrad_send_stream, dgrad_recv_stream + ) if tensor_parallel_mode == "column": dy = dy_local From de6afe24b2b33d13a86770d761c165804ab0400e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 11 Aug 2025 15:09:24 -0700 Subject: [PATCH 058/153] [PyTorch] Fix high-precision dtype for MXFP8 AG (#2058) * Fix high-precision dtype for MXFP8 AG Signed-off-by: Kirthi Shankar Sivamani * Comment Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/distributed.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index c3b42c5c4..709b4f3b8 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1225,14 +1225,12 @@ def _all_gather_mxfp8( if inp._rowwise_data is not None: in_shape = inp._rowwise_data.size() device = inp._rowwise_data.device - dtype = inp._rowwise_data.dtype elif inp._columnwise_data is not None: in_shape = inp._columnwise_data.size() device = inp._columnwise_data.device - dtype = inp._columnwise_data.dtype else: raise ValueError("Got MXFP8 input tensor without any data") - dtype = torch.bfloat16 + dtype = torch.bfloat16 # Guess high-precision dtype. else: raise ValueError( "Invalid type for input tensor (expected torch.Tensor or MXFP8TensorBase, " From bfca2e33a4c4f832b1189c037b5893683190d6de Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 11 Aug 2025 21:48:31 -0700 Subject: [PATCH 059/153] [PyTorch] Update amax pointers when reallocating amax history in fusible ops (#2044) * Update amax pointers when reallocating amax history in fusible ops Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update weight tensor quantizer when recipe state is reset Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../pytorch/ops/basic/basic_linear.py | 32 +++++++++++--- transformer_engine/pytorch/ops/fuser.py | 43 ++++++++++--------- transformer_engine/pytorch/ops/op.py | 22 +++++++++- 3 files changed, 70 insertions(+), 27 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 7f10336ce..5a151a362 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -301,6 +301,7 @@ def reset_parameters(self) -> None: rowwise=True, columnwise=torch.is_grad_enabled(), ) + quantizer.internal = False with torch.no_grad(): weight = quantizer(weight) @@ -317,11 +318,32 @@ def pre_first_fuser_forward(self) -> None: def reset_recipe_state(self, *, recipe: Optional[Recipe]) -> None: super().reset_recipe_state(recipe=recipe) - if recipe is not None and not FP8GlobalStateManager.with_fp8_parameters(): - # Make quantizers use internal tensors - self.get_input_quantizer().internal = 
True - self.get_grad_output_quantizer().internal = True - self.get_quantizer("forward", 1).internal = True + # Input/grad output quantizers use internal tensors + input_quantizer = self.get_quantizer("forward", 0) + grad_output_quantizer = self.get_quantizer("backward", 0) + if input_quantizer is not None: + input_quantizer.internal = True + if grad_output_quantizer is not None: + grad_output_quantizer.internal = True + + # Handle weight quantizer + # Note: This function may be called in base class constructor, + # before any basic linear attrs have been set. + weight_quantizer = self.get_quantizer("forward", 1) + if weight_quantizer is None: + pass + elif is_quantized_tensor(getattr(self, "weight", None)): + # Make sure weight param has correct quantizer + weight_quantizer.set_usage(rowwise=True, columnwise=torch.is_grad_enabled()) + weight_quantizer.internal = False + self.weight.update_quantizer(weight_quantizer.copy()) + else: + # Use internal tensors if quantized weights will not be + # exposed externally + weight_quantizer.internal = ( + not FP8GlobalStateManager.with_fp8_parameters() + and not getattr(self, "_with_quantized_weight", False) + ) @staticmethod def _functional_forward( diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 98b3468a2..2ee476779 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -110,14 +110,6 @@ def forward( xs, extra_inputs = _split_tuple(extra_inputs, op.num_extra_inputs) basic_op_extra_inputs.append(xs) - # Get environment state - with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() - recipe = FP8GlobalStateManager.get_fp8_recipe() if with_quantized_compute else None - is_grad_enabled = func_ctx is not None - - # Attempt to fuse operations if neccesary - fuser.maybe_fuse_ops(is_grad_enabled, recipe, input_, basic_op_extra_inputs) - # Apply forward ops x = input_ extra_outputs = [None] * fuser._num_basic_ops @@ -167,7 +159,7 @@ def forward( extra_outputs_flat.extend(ys) # Save context for backward pass - if is_grad_enabled: + if func_ctx is not None: # Flatten list of saved tensors to_save = [] @@ -180,12 +172,9 @@ def forward( ctx._saved_tensors_range = (range_start, range_end) # Save tensors for backward - if with_quantized_compute: - tensors_to_save, tensor_objects = prepare_for_saving(*to_save) - func_ctx.save_for_backward(*tensors_to_save) - func_ctx.tensor_objects = tensor_objects - else: - func_ctx.save_for_backward(*to_save) + tensors_to_save, tensor_objects = prepare_for_saving(*to_save) + func_ctx.save_for_backward(*tensors_to_save) + func_ctx.tensor_objects = tensor_objects # Other context func_ctx.backward_ops = fuser._backward_ops @@ -195,7 +184,6 @@ def forward( func_ctx.num_extra_inputs = fuser.num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() - func_ctx.with_quantized_compute = with_quantized_compute # Mark output tensors as not deletable in backward for tensor in [x] + extra_outputs_flat: @@ -223,10 +211,7 @@ def backward( basic_op_ctxs = func_ctx.basic_op_ctxs # Restore saved tensors - if func_ctx.with_quantized_compute: - saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors) - else: - saved_tensors = func_ctx.saved_tensors + saved_tensors = restore_from_saved(func_ctx.tensor_objects, func_ctx.saved_tensors) # Unflatten list of saved tensors for ctx in basic_op_ctxs: @@ -460,8 +445,24 @@ def __call__( if 
basic_op_kwargs is None: basic_op_kwargs = [{}] * self._num_basic_ops + # Unflatten list of extra tensor inputs + extra_inputs_copy = list(extra_inputs) + basic_op_extra_inputs = [] + for op in self._basic_ops: + xs, extra_inputs_copy = _split_tuple(extra_inputs_copy, op.num_extra_inputs) + basic_op_extra_inputs.append(xs) + + # Get environment state + recipe = None + if FP8GlobalStateManager.is_fp8_enabled(): + recipe = FP8GlobalStateManager.get_fp8_recipe() + is_grad_enabled = torch.is_grad_enabled() + + # Attempt to fuse operations if neccesary + self.maybe_fuse_ops(is_grad_enabled, recipe, input, basic_op_extra_inputs) + # Fuser forward pass - if torch.is_grad_enabled(): + if is_grad_enabled: forward_func = _OperationFuserAutogradFunction.apply args = [] else: diff --git a/transformer_engine/pytorch/ops/op.py b/transformer_engine/pytorch/ops/op.py index c2efc5169..903bc49d5 100644 --- a/transformer_engine/pytorch/ops/op.py +++ b/transformer_engine/pytorch/ops/op.py @@ -294,6 +294,8 @@ def reset_recipe_state( forward=(mode == "forward"), ) recipe_state = self._fp8_metas[mode][fp8_meta_key] + + # Reallocate amax history if needed current_length = recipe_state.amax_history.size(0) target_length = recipe.amax_history_len if target_length < current_length: @@ -308,6 +310,25 @@ def reset_recipe_state( pad=(0, 0, 0, target_length - current_length), ) + # Update quantizers with new amax pointers + self._quantizers[mode] = recipe_state.make_quantizers() + + # Update the global buffers with new amax pointers + if FP8GlobalStateManager.get_buffer_info() in self._fp8_metas[mode]: + pos, buffer_key = self._fp8_metas[mode][ + FP8GlobalStateManager.get_buffer_info() + ] + if buffer_key in FP8GlobalStateManager.global_amax_buffer: + assert ( + buffer_key in FP8GlobalStateManager.global_amax_history_buffer + ), "TE internal error during amax history change." 
+ FP8GlobalStateManager.global_amax_buffer[buffer_key][pos] = ( + recipe_state.amax_history[0] + ) + FP8GlobalStateManager.global_amax_history_buffer[buffer_key][ + pos + ] = recipe_state.amax_history + # Add meta tensors to global buffer to participate in reduction for mode in ("forward", "backward"): if ( @@ -686,7 +707,6 @@ def get_grad_output_quantizer(self) -> Optional[Quantizer]: return self.basic_ops[-1].get_grad_output_quantizer() def pre_first_fuser_forward(self) -> None: - """Preprocessing before first fuser forward pass""" for op in self.basic_ops: op.pre_first_fuser_forward() From f947e703cb4f78ba94b3a9986eec32db92ec0e1c Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 11 Aug 2025 22:13:30 -0700 Subject: [PATCH 060/153] [PyTorch] Fix bug when deducing dtype in linear functional API (#2017) Fix bug when deducing dtype in linear functional API Signed-off-by: Tim Moon --- transformer_engine/pytorch/ops/basic/basic_linear.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 5a151a362..c0ec991ff 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -426,7 +426,7 @@ def _functional_forward( if dtype is None: if out is not None and isinstance(out, torch.Tensor): dtype = out.dtype - elif weight is not None and isinstance(out, torch.Tensor): + elif weight is not None and isinstance(weight, torch.Tensor): dtype = weight.dtype else: raise ValueError( From 6a4e871ef6bb01d8e0670d8bf4d1b6fa9bdf1a7f Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 12 Aug 2025 13:57:11 -0700 Subject: [PATCH 061/153] [JAX] Support custom recipe and custom collection name when creating quantizer sets (#2059) * Support setting collection name for quantizer set Flax variables in TransformerEngineBase flax module Signed-off-by: Jeremy Berchtold * Support creating quantizer set from a recipe directly Signed-off-by: Jeremy Berchtold * Fix debug error format string in gemm.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 2 +- transformer_engine/jax/flax/module.py | 15 ++++++++++---- transformer_engine/jax/quantize/quantizer.py | 20 +++++++++++++++++-- 3 files changed, 30 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 5c2438906..be2dfabb3 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1090,7 +1090,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): if lhs.scaling_mode == ScalingMode.MXFP8_1D_SCALING: return _jax_gemm_mxfp8_1d(lhs, rhs, dim_nums) - raise NotImplementedError("Unsupported ScalingMode: {lhs.scaling_mode}") + raise NotImplementedError(f"Unsupported ScalingMode: {lhs.scaling_mode}") lhs_q, rhs_q = _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_dims) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index e923991e4..8c7135210 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -337,21 +337,28 @@ class TransformerEngineBase(nn.Module): # pylint: disable=too-few-public-method Base class of transformer engine """ - def 
generate_quantizer_set(self, postfix: str = ""): + def generate_quantizer_set( + self, postfix: str = "", variable_collection: str = None, fp8_recipe=None + ): """ Generate a set of FP8 meta for a GEMM. """ def generate_quantize_meta(quantizer_name: str): + collection_name = ( + variable_collection + if variable_collection is not None + else QuantizeConfig.COLLECTION_NAME + ) scale = self.variable( - QuantizeConfig.COLLECTION_NAME, + collection_name, f"{quantizer_name}{postfix}_scale", jnp.ones, (1,), jnp.float32, ).value amax_history = self.variable( - QuantizeConfig.COLLECTION_NAME, + collection_name, f"{quantizer_name}{postfix}_amax_history", jnp.zeros, (QuantizeConfig.AMAX_HISTORY_LEN,), @@ -368,7 +375,7 @@ def generate_quantize_meta(quantizer_name: str): else: kwargs = {} - quantizer_set = QuantizerFactory.create_set(**kwargs) + quantizer_set = QuantizerFactory.create_set(fp8_recipe=fp8_recipe, **kwargs) return quantizer_set diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 881f3a74b..09856065c 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -16,12 +16,14 @@ import jax.numpy as jnp from jax.tree_util import register_pytree_node_class from transformer_engine_jax import QuantizeLayout +from transformer_engine.common import recipe from .scaling_modes import ScalingMode from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .helper import ( QuantizeConfig, AmaxComputeAlgo, + _get_scaling_mode, ) from .device_utils import is_fp8_gemm_with_all_layouts_supported @@ -878,11 +880,12 @@ def _create_set( @staticmethod def create_set( n_quantizer_sets: int = 1, - scaling_mode: ScalingMode = None, + scaling_mode: Optional[ScalingMode] = None, fwd_dtype: jnp.dtype = None, bwd_dtype: jnp.dtype = None, is_2x2x: bool = None, n_groups: int = None, + fp8_recipe: Optional[recipe.Recipe] = None, **kwargs, ) -> tuple[Union[tuple[Quantizer], None]]: """Create one or more sets of quantizers. @@ -894,12 +897,25 @@ def create_set( bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X n_groups: + fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. **kwargs: Additional arguments for quantizer initialization Returns: A single quantizer set or tuple of quantizer sets """ - scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE + + assert scaling_mode is None or fp8_recipe is None, ( + "Cannot specify both scaling_mode and fp8_recipe when creating a quantizer set. Scaling" + " mode can be specified directly via the scaling_mode parameter or indirectly via" + " recipe. Recipe is preferred as it will support additional recipes in future where" + " scaling mode differs between x, kernel, and grad in the quantizer set." 
+ ) + + if fp8_recipe is not None: + # TODO(jberchtold): once recipe and scaling mode are decoupled update this logic + scaling_mode = _get_scaling_mode(fp8_recipe) + else: + scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE if is_2x2x is None: From 05d3b7b511ff142a2a0e6d7b46d56b81d1323461 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 12 Aug 2025 15:10:27 -0700 Subject: [PATCH 062/153] [PyTorch] Fix normalization+amax forward CS fusion to work for untuned kernels (#2061) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../layernorm/ln_fwd_kernels.cuh | 24 ++++++++++--------- .../rmsnorm/rmsnorm_fwd_kernels.cuh | 24 ++++++++++--------- 2 files changed, 26 insertions(+), 22 deletions(-) diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh index 417e84a56..6050b164d 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_kernels.cuh @@ -215,6 +215,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != nullptr; for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { const int row = cta_row + warp_m; @@ -283,14 +284,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } // Apply fp8 factors - if (params.fp8_out) { + if (params.fp8_out || requires_amax) { #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; + if (params.fp8_out) { + z.data.elt[jt] = z_ij * scale; + } } } } @@ -302,17 +305,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void ln_fwd_general_kerne } } - // Finalize fp8 factors - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh index da3f8192c..fc093b73a 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_kernels.cuh @@ -205,6 +205,7 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ scale = *reinterpret_cast(params.scale); } compute_t amax = 0; + const bool requires_amax = params.amax != 
nullptr; for (int cta_row = bidm * bdimm; cta_row < params.rows; cta_row += gdimm) { const int row = cta_row + warp_m; @@ -258,14 +259,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ } // Apply fp8 factors - if (params.fp8_out) { + if (params.fp8_out || requires_amax) { #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { if (col + jt < params.cols) { compute_t z_ij = z.data.elt[jt]; __builtin_assume(amax >= 0); amax = fmaxf(amax, fabsf(z_ij)); - z.data.elt[jt] = z_ij * scale; + if (params.fp8_out) { + z.data.elt[jt] = z_ij * scale; + } } } } @@ -277,17 +280,16 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_fwd_general_ } } - // Finalize fp8 factors - if (params.fp8_out) { - // Reduce amax over block - if (params.amax != nullptr) { - amax = reduce_max(amax, warp); - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - atomicMaxFloat(reinterpret_cast(params.amax), amax); - } + // Reduce amax over block + if (requires_amax) { + amax = reduce_max(amax, warp); + if (threadIdx.x == 0) { + static_assert(std::is_same::value); + atomicMaxFloat(reinterpret_cast(params.amax), amax); } + } + if (params.fp8_out) { // Update scale-inverse if (blockIdx.x == 0 && threadIdx.x == 0 && params.scale_inv != nullptr) { reciprocal(reinterpret_cast(params.scale_inv), scale); From ec65ba3cd65bc4dd9dab57e9bc357f02996a11ca Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 12 Aug 2025 17:21:39 -0700 Subject: [PATCH 063/153] [JAX] Add L2_jax_distributed_unittest (#2060) * Add L2_jax_distributed_unittest Signed-off-by: Jeremy Berchtold * Add L1 entry for NORM_INPUT_SHAPES that was missing Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- qa/L1_jax_distributed_unittest/test.sh | 2 +- qa/L2_jax_distributed_unittest/test.sh | 11 +++++++++++ tests/jax/test_distributed_layernorm.py | 1 + 3 files changed, 13 insertions(+), 1 deletion(-) create mode 100644 qa/L2_jax_distributed_unittest/test.sh diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index 5deb77af9..f332e32e8 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -8,4 +8,4 @@ set -xe : ${XML_LOG_DIR:=/logs} mkdir -p "$XML_LOG_DIR" -python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* diff --git a/qa/L2_jax_distributed_unittest/test.sh b/qa/L2_jax_distributed_unittest/test.sh new file mode 100644 index 000000000..0b7372650 --- /dev/null +++ b/qa/L2_jax_distributed_unittest/test.sh @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +set -xe + +: ${TE_PATH:=/opt/transformerengine} +: ${XML_LOG_DIR:=/logs} +mkdir -p "$XML_LOG_DIR" + +NVTE_JAX_UNITTEST_LEVEL="L2" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index be5c8ef98..a777e2f43 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -25,6 +25,7 @@ NORM_INPUT_SHAPES = { "L0": [[64, 64]], + "L1": [[64, 64]], "L2": [[64, 64]], } From ebca61532000c72113cdb2987d50b9fba08d0d8c Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Wed, 13 Aug 2025 09:54:11 +0800 Subject: [PATCH 064/153] [Common] PDL for Blockwise Quantization (#2066) * enable PDL for blockwise qunatization kernels Signed-off-by: Xin Yao * add comment Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xin Yao --- .../quantize_transpose_square_blockwise.cu | 63 +++++++++++++------ .../quantize_transpose_vector_blockwise.cu | 54 ++++++++++++---- 2 files changed, 84 insertions(+), 33 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 79d8d215f..0b70f3f40 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,6 +14,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" +#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -167,6 +168,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } } +// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's +// store to global memory. +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // Step 3: Store cast output, Step 4: do transpose within thread tile OVecCast tmp_output_c; @@ -390,6 +397,12 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } +// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's +// store to global memory. 
+#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + cudaTriggerProgrammaticLaunchCompletion(); +#endif + // Step 3: Store cast output, Step 4: do transpose within thread tile // Edge case: in the non-full tile case, there are three subcases // for full thread tile, it's the same thing here @@ -511,6 +524,15 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); + dim3 grid(num_blocks_x, num_blocks_y, 1); + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; + cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0}; + if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) { + cfg.attrs = attribute; + cfg.numAttrs = 1; + } TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -521,7 +543,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_SWITCH_CONDITION( return_transpose, kReturnTranspose, - dim3 grid(num_blocks_x, num_blocks_y, 1); const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; @@ -531,26 +552,28 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor tensor_map_output_trans = get_tensor_map(output_t, num_rows, row_length); } - block_scaled_cast_transpose_kernel - <<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - tensor_map_output_trans, pow_2_scale); + cudaLaunchKernelEx(&cfg, + block_scaled_cast_transpose_kernel, + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, + scale_t_stride_y, epsilon, tensor_map_output_trans, pow_2_scale); } else { - block_scaled_cast_transpose_kernel_notaligned - <<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - pow_2_scale); + cudaLaunchKernelEx( + &cfg, + block_scaled_cast_transpose_kernel_notaligned, + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + pow_2_scale); } // full-tile ) // return_transpose ) // OutputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 6f5c0f3a6..5bf2f5201 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,6 +17,7 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" +#include 
"common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -234,6 +235,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo __syncthreads(); +// If not return columnwise, we trigger the next kernel here so that it's load from global memory +// can overlap with this kernel's return rowwise. +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if (!return_columnwise_gemm_ready && !return_columnwise_compact) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + // Step 2: Cast and store to output_c if (return_rowwise) { constexpr int r_stride = @@ -325,6 +334,14 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } +// If return columnwise, we trigger the next kernel here so that it's load from global memory +// can overlap with this kernel's return columnwise. +#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) + if (return_columnwise_gemm_ready || return_columnwise_compact) { + cudaTriggerProgrammaticLaunchCompletion(); + } +#endif + // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t if (return_columnwise_gemm_ready) { constexpr int c_stride = @@ -584,6 +601,10 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + dim3 grid(num_blocks_x, num_blocks_y, 1); + cudaLaunchAttribute attribute[1]; + attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; + attribute[0].val.programmaticStreamSerializationAllowed = 1; TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -591,31 +612,38 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output.dtype, OutputType, - dim3 grid(num_blocks_x, num_blocks_y, 1); - const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; TRANSFORMER_ENGINE_SWITCH_CONDITION( full_tile, kAligned, size_t smem_bytes = kSMemSize * sizeof(InputType); + + cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0}; + if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= + 90) { + cfg.attrs = attribute; + cfg.numAttrs = 1; + } // shared memory must be requested up if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( &block_scaled_1d_cast_transpose_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); - } block_scaled_1d_cast_transpose_kernel - <<>>( - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, - scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, - columnwise_option, pow2_scale);) // kAligned - ) // OutputType - ) // InputType + } cudaLaunchKernelEx(&cfg, + block_scaled_1d_cast_transpose_kernel, + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, + scale_t_stride_y, epsilon, rowwise_option, columnwise_option, + pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } From 
6afca29c092b85019e30e412dece146f37748fc5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 13 Aug 2025 12:23:15 +0200 Subject: [PATCH 065/153] [PyTorch Debug] More advanced stats for Quantized Tensors (#1897) * code drop Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * turn on userbuffers for layers without debug Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * working change Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tests and fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * update nvinspect version Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * docs change Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * test Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix default Signed-off-by: Pawel Gadzinski * fix default Signed-off-by: Pawel Gadzinski * fix default Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, 
see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * tests fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fixes Signed-off-by: Pawel Gadzinski * fixes Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kirthi Shankar Sivamani --- docs/debug/3_api_te_calls.rst | 16 +- docs/debug/img/api_calls2.svg | 1 - qa/L0_pytorch_debug_unittest/test.sh | 1 + tests/pytorch/debug/run_distributed.py | 34 +++ tests/pytorch/debug/test_api_features.py | 89 +++--- .../debug/test_configs/log_config.yaml | 2 +- tests/pytorch/debug/test_log.py | 205 ++++++++++++- transformer_engine/debug/features/api.py | 60 +++- .../debug/features/disable_fp8_gemm.py | 2 +- .../debug/features/disable_fp8_layer.py | 2 +- .../debug/features/log_fp8_tensor_stats.py | 284 ++++++++++++++---- .../debug/features/log_tensor_stats.py | 25 +- .../debug/features/utils/__init__.py | 20 ++ .../debug/features/utils/stats_buffer.py | 25 +- .../debug/features/utils/stats_computation.py | 229 ++++++++++++-- .../debug/pytorch/debug_quantization.py | 10 +- transformer_engine/pytorch/distributed.py | 7 +- .../_internal/float8_blockwise_tensor_base.py | 10 +- .../tensor/_internal/float8_tensor_base.py | 10 +- .../tensor/_internal/mxfp8_tensor_base.py | 10 +- 20 files changed, 846 insertions(+), 196 deletions(-) delete mode 100644 docs/debug/img/api_calls2.svg diff --git a/docs/debug/3_api_te_calls.rst b/docs/debug/3_api_te_calls.rst index eb66c8ff2..1435d41d7 100644 --- a/docs/debug/3_api_te_calls.rst +++ b/docs/debug/3_api_te_calls.rst @@ -12,14 +12,7 @@ Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine w Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 1 Linear Layer. 
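As a rough, hypothetical sketch of the two categories (the real features later in this series register via ``@Registry.register_feature(namespace="transformer_engine")`` and mark each API with ``@api_method``), a feature pairs a cheap routing call with the GEMM call it gates::

    class ExampleFeature:  # illustrative only, not a shipped feature
        def inspect_tensor_enabled(self, config, layer_name, tensor_name, iteration):
            # Routing call: decide whether to run now and report when to check again.
            run_now = iteration % 2 == 0
            return run_now, (iteration + 2 if run_now else iteration + 1)

        def inspect_tensor(self, config, layer_name, tensor_name, iteration,
                           tp_group, tensor, **kwargs):
            # GEMM call: receives the high-precision tensor for logging/processing.
            print(f"{layer_name}/{tensor_name} max abs: {tensor.abs().max().item()}")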
For tensors mentioned in ``config.yaml``, behavior of ``modify_tensor_enabled()`` and ``modify_tensor()`` calls are substituted with definitions from the feature class. Other calls return default values - in fact they do nothing. -In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. The order of these calls is illustrated in the image below. - -.. figure:: ./img/api_calls2.svg - :align: center - - Fig 2: The calls to Nvidia-DL-Framework-Inspect done for Transformer Engine. There are 2 types of calls: GEMM calls and routing calls. - - +In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. There are 2 categories of API calls, each is used for different purposes: - GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them, @@ -32,14 +25,15 @@ if fusions happen. An important remark is that if no feature is used for the lay .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor -.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor - -.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor + +.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize + .. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled .. 
autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled diff --git a/docs/debug/img/api_calls2.svg b/docs/debug/img/api_calls2.svg deleted file mode 100644 index 5df72fc2e..000000000 --- a/docs/debug/img/api_calls2.svg +++ /dev/null @@ -1 +0,0 @@ -Tensor Ainspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeGEMMinspect_tensormodify_tensorinspect_tensor_enabledinspect_tensor_postquantize_enabledfp8_gemm_enabledmodify_tensor_enabledTensor Binspect_tensorfp8 castmodify_tensorinspect_tensor_postquantizeRouting callsGEMM calls \ No newline at end of file diff --git a/qa/L0_pytorch_debug_unittest/test.sh b/qa/L0_pytorch_debug_unittest/test.sh index 414899aa4..b4bf0a024 100644 --- a/qa/L0_pytorch_debug_unittest/test.sh +++ b/qa/L0_pytorch_debug_unittest/test.sh @@ -23,6 +23,7 @@ pip install pytest==8.2.1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_sanity.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_config.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_numerics.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS || FAIL=1 +pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/debug/test_api_features.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_log.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 pytest -v -s $TE_PATH/tests/pytorch/debug/test_perf.py --feature_dirs=$NVTE_TEST_NVINSPECT_FEATURE_DIRS --configs_dir=$NVTE_TEST_NVINSPECT_CONFIGS_DIR || FAIL=1 diff --git a/tests/pytorch/debug/run_distributed.py b/tests/pytorch/debug/run_distributed.py index b12f8c3d3..716c16056 100644 --- a/tests/pytorch/debug/run_distributed.py +++ b/tests/pytorch/debug/run_distributed.py @@ -364,6 +364,40 @@ def get_stat(tensor, stat): set_weight_tensor_tp_group_reduce(True) # reset +@run_debug_test +def sanity_test_log_quantized_stats(parallel_mode, gather_weight, **kwargs): + from test_log import LOG_QUANTIZED_CONFIG + + kwargs["config_file"].write(LOG_QUANTIZED_CONFIG) + kwargs["config_file"].flush() + _init_debug(kwargs["config_file"].name, kwargs["log_dir"], FEATURE_DIRS) + set_weight_tensor_tp_group_reduce(gather_weight) + if WORLD_SIZE % 2 != 0: + return # skip + TP_SIZE = WORLD_SIZE // 2 + DP_SIZE = 2 + TP_RANK = WORLD_RANK % TP_SIZE + DP_RANK = (WORLD_RANK - TP_RANK) // TP_SIZE + + debug_api.set_tensor_reduction_group(NCCL_WORLD) + + x, weight = _get_tensors( + parallel_mode, + weight_seed=TP_RANK * 1234, + data_seed=DP_RANK * 1234, + tp_size=TP_SIZE, + tp_rank=TP_RANK, + ) + + tp_group_ranks = [i for i in range(DP_RANK * TP_SIZE, (DP_RANK + 1) * TP_SIZE)] + tp_group = dist.new_group(ranks=tp_group_ranks) + + model = _init_model(weight, parallel_mode=parallel_mode, tp_group=tp_group) + _run_forward_backward(x, model, parallel_mode=parallel_mode, group=tp_group) + + set_weight_tensor_tp_group_reduce(True) # reset + + @run_debug_test def test_log_expert_parallel(**kwargs): """ diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index 2a2ef1fe8..974772599 100644 --- a/tests/pytorch/debug/test_api_features.py +++ 
b/tests/pytorch/debug/test_api_features.py @@ -36,11 +36,6 @@ def test_transformer_engine_no_config(feature_dirs): "decoder.1.attn.qkv", tensor_name="activation", iteration=0 )[0] - # inspect_tensor_postquantize - (False, None) by default - assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.1.attn.qkv", gemm="fprop", tensor_name="activation", iteration=0 - )[0] - finally: debug_api.end_debug() @@ -236,13 +231,12 @@ def test_statistics_collection(configs_dir, feature_dirs): ) tensor = torch.randn((100, 100, 5)).cuda() - tensor_fp8 = Float8Tensor( - data=tensor.to(torch.uint8).cuda(), - fp8_scale_inv=torch.full([1], 1.0).cuda(), + quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda(), + amax=torch.full([1], 1.0).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, - shape=tensor.shape, - dtype=torch.float32, ) + tensor_fp8 = quantizer(tensor) def log(): from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS @@ -260,6 +254,9 @@ def assert_empty(): tensor_name="activation", iteration=200, tp_group=None, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() assert stats[("decoder.1.mlp.fc1", "activation", "cur_amax", 200)] == tensor.abs().max() @@ -269,44 +266,52 @@ def assert_empty(): assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.2.mlp.fc1", tensor_name="activation", iteration=200 )[0] - assert not debug_api.transformer_engine.inspect_tensor_enabled( + + expected_underflows = ( + ((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) + ) + + assert debug_api.transformer_engine.inspect_tensor_enabled( "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 )[0] - expected_underflows = (tensor_fp8._data == 0).sum() * 100 / (100 * 100 * 5) - # TE FP8 tensor stats -- - assert debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.1.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 + assert debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="gradient", iteration=200 )[0] - debug_api.transformer_engine.inspect_tensor_postquantize( + debug_api.transformer_engine.inspect_tensor( "decoder.1.mlp.fc1", - tensor=tensor_fp8, tensor_name="gradient", iteration=200, - rowwise=True, tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() torch.testing.assert_close( stats[("decoder.1.mlp.fc1", "gradient", "underflows%", 200)], expected_underflows ) - assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.1.mlp.fc1", tensor_name="activation", gemm="fprop", iteration=201 + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.1.mlp.fc1", tensor_name="activation", iteration=201 )[0] - assert not debug_api.transformer_engine.inspect_tensor_postquantize_enabled( - "decoder.2.mlp.fc1", tensor_name="gradient", gemm="wgrad", iteration=200 + assert not debug_api.transformer_engine.inspect_tensor_enabled( + "decoder.2.mlp.fc1", tensor_name="gradient", iteration=200 )[0] # Second config in same yaml tensor = torch.rand((100, 100, 5)) debug_api.transformer_engine.inspect_tensor( "decoder.6.mlp.fc1", - tensor=tensor, tensor_name="activation", iteration=200, tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() stats_names = [x[3] for x 
in stats.keys()] @@ -315,10 +320,13 @@ def assert_empty(): debug_api.transformer_engine.inspect_tensor( "decoder.7.mlp.fc1", - tensor=tensor, tensor_name="weight", iteration=200, tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) stats = log() stats_names = [x[3] for x in stats.keys()] @@ -342,21 +350,16 @@ def test_statistics_multi_run(configs_dir, feature_dirs): default_logging_enabled=False, ) - def feed(tensor, tensor_fp8): + def feed(tensor, tensor_fp8, quantizer): debug_api.transformer_engine.inspect_tensor( "decoder.5.mlp.fc1", tensor=tensor, tensor_name="activation", iteration=1, tp_group=None, - ) - debug_api.transformer_engine.inspect_tensor_postquantize( - "decoder.5.mlp.fc1", - tensor=tensor_fp8, - tensor_name="activation", - iteration=1, - rowwise=True, - tp_group=None, + quantizer=quantizer, + rowwise_quantized_tensor=tensor_fp8, + columnwise_quantized_tensor=tensor_fp8, ) def log_stats(): @@ -364,26 +367,26 @@ def log_stats(): return STATS_BUFFERS.log_stats() + quantizer = Float8Quantizer( + scale=torch.full([1], 1.0).cuda(), + amax=torch.full([1], 1.0).cuda(), + fp8_dtype=tex.DType.kFloat8E4M3, + ) + def fp8_tensor(t): - return Float8Tensor( - data=t.to(torch.uint8).cuda(), - fp8_scale_inv=torch.ones([1]).cuda(), - fp8_dtype=tex.DType.kFloat8E4M3, - shape=t.shape, - dtype=torch.float32, - ) + return quantizer(t.cuda()) shape = [1024, 1024] tensors = [torch.randn(shape) for _ in range(2)] tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] - feed(tensors[0], tensors_fp8[0]) - feed(tensors[1], tensors_fp8[1]) + feed(tensors[0], tensors_fp8[0], quantizer) + feed(tensors[1], tensors_fp8[1], quantizer) stats1 = log_stats() tensor2 = torch.cat((tensors[0], tensors[1])).cuda() fp8tensor2 = fp8_tensor(tensor2) - feed(tensor2, fp8tensor2) + feed(tensor2, fp8tensor2, quantizer) stats2 = log_stats() assert len(stats1.keys()) > 0 diff --git a/tests/pytorch/debug/test_configs/log_config.yaml b/tests/pytorch/debug/test_configs/log_config.yaml index 04f490b9d..3e94006d9 100644 --- a/tests/pytorch/debug/test_configs/log_config.yaml +++ b/tests/pytorch/debug/test_configs/log_config.yaml @@ -12,7 +12,7 @@ test: freq: 3 LogFp8TensorStats: enabled: True - tensors: weight + tensors: activation stats: [underflows%] start_step: 1 freq: 5 diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index fb0988d76..0b0adb451 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -2,16 +2,207 @@ # # See LICENSE for license information. 
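# Spelled out for reference (a hedged, standalone restatement; the helper name
# is illustrative): the "underflows%" statistic exercised by the tests below
# counts elements that were non-zero originally but dequantize to exactly
# zero, as a percentage of all elements.
import torch

def underflow_pct(original: torch.Tensor, dequantized: torch.Tensor) -> float:
    new_zeros = (dequantized == 0).sum() - (original == 0).sum()
    return (100.0 * new_zeros / original.numel()).item()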
- -import pytest -import torch +import nvdlfw_inspect.api as debug_api +import transformer_engine.debug import transformer_engine.pytorch as te +import torch import tempfile +from transformer_engine.common import recipe +from transformer_engine.pytorch.fp8 import RecipeState +import pytest +import contextlib import os +from transformer_engine.pytorch.fp8 import FP8GlobalStateManager +from transformer_engine.debug.pytorch.debug_state import TEDebugState -import nvdlfw_inspect.api as debug_api -from transformer_engine.debug.pytorch.debug_state import TEDebugState +fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + FP8GlobalStateManager.is_fp8_block_scaling_available() +) + +LOG_QUANTIZED_CONFIG_BASE = """ +log: + layers: + layer_name_regex_pattern: .* + enabled: + True + transformer_engine: + LogFp8TensorStats: + enabled: True + stats: [ + {stats} + ] + tensors: [activation, gradient, weight] + freq: 2 + start_step: 0 + end_step: 10 +""" +recipes = [ + "fp8_delayed_scaling", + "fp8_current_scaling", + "fp8_block_scaling", + "mxfp8", +] + +bare_stats = [ + "underflows%", + "scale_inv_min", + "scale_inv_max", + "mse", +] + +all_stats = [] + +for r in recipes: + for stat in bare_stats: + for columnwise_postfix in ["", "_columnwise"]: + if ( + r in ["fp8_current_scaling", "fp8_block_scaling"] + and torch.cuda.get_device_capability()[0] < 9 + ): + # hopper is needed for current-scaling, block-scaling + continue + if r == "mxfp8" and torch.cuda.get_device_capability()[0] < 10: + # blackwell is needed for mxfp8 + continue + if ( + r in ["fp8_delayed_scaling", "fp8_current_scaling"] + and columnwise_postfix == "_columnwise" + ): + # columnwise stats are not supported for fp8_delayed_scaling and fp8_current_scaling + continue + + all_stats.append(f"{r}_{stat}{columnwise_postfix}") + +all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows% + + +@contextlib.contextmanager +def debug_session(config_str: str, feature_dirs): + """ + Helper context manager that + 1. writes the YAML `config_str` to a temporary file, + 2. starts a debug session, and + 3. yields the directory that contains the statistics log. + + The session is closed automatically – even on exceptions – so every test + stays concise and leak-free. 
+ """ + with tempfile.NamedTemporaryFile( + mode="w", delete=False + ) as cfg_file, tempfile.TemporaryDirectory() as log_dir: + cfg_file.write(config_str) + cfg_file.flush() + + debug_api.initialize( + config_file=cfg_file.name, + feature_dirs=feature_dirs, + log_dir=log_dir, + ) + try: + yield log_dir + finally: + debug_api.end_debug() + + +def read_log(log_dir: str) -> str: + """Return the content of the statistics log produced by `debug_session`.""" + stat_path = os.path.join( + log_dir, + "nvdlfw_inspect_statistics_logs", + "nvdlfw_inspect_globalrank-0.log", + ) + with open(stat_path, "r") as f: + return f.read() + + +def test_sanity(feature_dirs): + log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats)) + with debug_session(log_all_stats_config, feature_dirs) as log_dir: + model = te.Linear(128, 128, params_dtype=torch.bfloat16) + inp = torch.zeros(128, 128, dtype=torch.bfloat16).cuda() + + for _ in range(10): + with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): + output = model(inp) + loss = output.sum() + loss.backward() + debug_api.step() + + output = read_log(log_dir) + + assert output, "Output is empty" + for stat in all_stats: + assert stat in output, f"Stat {stat} not found in output" + + +fp8_recipes = [ + recipe.MXFP8BlockScaling(), + recipe.DelayedScaling(), + recipe.Float8CurrentScaling(), + recipe.Float8BlockScaling(), +] + + +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +def test_numerics(fp8_recipe, feature_dirs): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if not mxfp8_available and fp8_recipe == recipe.MXFP8BlockScaling(): + pytest.skip(reason_for_no_mxfp8) + if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling(): + pytest.skip(reason_for_no_fp8_block_scaling) + + log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats)) + + with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: + recipe_state = RecipeState.create( + fp8_recipe, + mode="forward", + num_quantizers=3, + ) + + tensor = torch.zeros(1024, 1024).cuda() + tensor[0, :] = 1000 + quantizer = recipe_state.make_quantizers()[0] + quantized_tensor = quantizer(tensor) + + debug_api.transformer_engine.inspect_tensor( + layer_name="layer_name", + tensor_name="activation", + iteration=0, + tp_group=None, + tensor=tensor, + quantizer=quantizer, + rowwise_quantized_tensor=quantized_tensor, + columnwise_quantized_tensor=quantized_tensor, + ) + debug_api.step() + + dequantized_tensor = quantized_tensor.dequantize() + output = read_log(log_dir) + + for line in output.splitlines(): + if "underflows%" in line: + underflows = float(line.split("value=")[1]) + expected = ( + ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) + / dequantized_tensor.numel() + * 100 + ) + assert underflows == pytest.approx(expected.cpu(), abs=1e-4) + if "mse" in line: + mse = float(line.split("value=")[1]) + expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean") + assert mse == pytest.approx(expected.cpu(), abs=1e-6) + if "overflows%" in line: + overflows = float(line.split("value=")[1]) + expected = ( + (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100 + ) + assert overflows == pytest.approx(expected.cpu(), abs=1e-4) @pytest.mark.parametrize("layer", ["linear", "transformer"]) @@ -35,7 +226,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): else: raise ValueError(f"Invalid layer: {layer}") - for i in range(11): + for i in 
range(20): x = torch.randn(4, 128, 128).cuda() with te.fp8_autocast(enabled=True): y = model(x) @@ -49,7 +240,7 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): "r", ) as f: file_content = f.read() - for i in range(1, 11): + for i in range(1, 20): if i % 3 == 0 or i % 5 == 0: assert f"iteration={i:06d}" in file_content else: diff --git a/transformer_engine/debug/features/api.py b/transformer_engine/debug/features/api.py index ff37f57bf..94fc6d129 100644 --- a/transformer_engine/debug/features/api.py +++ b/transformer_engine/debug/features/api.py @@ -5,6 +5,7 @@ """API definition for nvidia-dlframework-inspect.""" import copy +import warnings from typing import Dict, Union, Tuple, Optional from nvdlfw_inspect.base import BaseNamespaceAPI, BaseConfigAPIMapper from nvdlfw_inspect.registry import Registry @@ -114,7 +115,7 @@ def fp8_gemm_enabled( If the tensor is processed using *modify_tensor* or fp8 autocast is not enabled, the result of this call does not matter. - This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be enabled. + This method may return a tuple (bool, Optional[int]), where the int indicates the next iteration when the feature will be disabled. It can return (bool, None) if the feature will never be enabled for that layer and gemm. Returning the next enabled iteration can help optimize CPU usage. @@ -244,6 +245,9 @@ def inspect_tensor( layer_name: str, tensor_name: str, tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor], + columnwise_quantized_tensor: Optional[torch.Tensor], + quantizer: Optional[Quantizer], iteration: int, tp_group: torch.distributed.ProcessGroup, ) -> None: @@ -260,6 +264,12 @@ def inspect_tensor( one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`], tensor: torch.Tensor tensor in high precision, + rowwise_quantized_tensor: Optional[torch.Tensor] + rowwise quantized tensor, + columnwise_quantized_tensor: Optional[torch.Tensor] + columnwise quantized tensor, + quantizer: Optional[Quantizer] + quantizer, iteration: int iteration number - equal to the number of times `debug_api.step()` was called. tp_group: torch.distributed.ProcessGroup @@ -277,12 +287,15 @@ def inspect_tensor_postquantize( config: Dict, layer_name: str, tensor_name: str, - gemm: str, tensor: torch.Tensor, iteration: int, tp_group: torch.distributed.ProcessGroup, + rowwise: bool, ) -> None: """ + + This is deprecated call, we advise to use *inspect_tensor* instead. + Similar to *inspect_tensor*, but is run after one of the: fp8 cast, modify_tensor if they are run. If none of the fp8 cast or modify_tensor is invoked, then *inspect_tensor_postquantize* is also not invoked. The feature LogFp8Stats uses this call to collect FP8 statistics after the quantization. Parameters @@ -295,8 +308,6 @@ def inspect_tensor_postquantize( one of [`activation`, `weight`, `gradient`, `output`, `wgrad`, `dgrad`], tensor: torch.Tensor tensor in fp8 or processed tensor after the modify_tensor call, - gemm: str - one of [`fprop`, `dgrad`, `wgrad`], iteration: int iteration number - equal to the number of times `debug_api.step()` was called. tp_group: torch.distributed.ProcessGroup @@ -352,6 +363,8 @@ def inspect_tensor_postquantize_enabled( iteration: int, ) -> bool | Tuple[bool, Optional[int]]: """ + This is deprecated call, we advise to use *inspect_tensor* and *inspect_tensor_enabled* instead. + It is a routing call, which is run at the initialization of the layer. 
Determines if *inspect_tensor_postquantize* for a given GEMM and tensor will be invoked. @@ -399,8 +412,8 @@ def __init__(self): "modify_tensor": ["tensor_name", "gemm"], "inspect_tensor": ["tensor_name"], "inspect_tensor_postquantize": ["tensor_name"], - "inspect_tensor_enabled": ["tensor_name"], - "inspect_tensor_postquantize_enabled": ["tensor_name"], + "inspect_tensor_enabled": ["tensor_name", "iteration"], + "inspect_tensor_postquantize_enabled": ["tensor_name", "iteration"], "modify_tensor_enabled": ["tensor_name"], } @@ -460,6 +473,26 @@ def output_assertions_hook(self, api_name, ret, **kwargs): if kwargs["dtype"] is not None: assert ret.dtype == kwargs["dtype"] + def call_feature(self, call, feat_config, layer_name, **kwargs): + """ + For backward compatibility, remove kwargs that are not needed for the call + """ + if call.__name__ == "inspect_tensor": + kwargs_copy = kwargs.copy() + for k in ["quantizer", "columnwise_quantized_tensor", "rowwise_quantized_tensor"]: + if k not in call.__code__.co_varnames: + kwargs_copy.pop(k) + else: + kwargs_copy = kwargs + + if call.__name__ == "inspect_tensor_postquantize": + warnings.warn( + "inspect_tensor_postquantize is deprecated, use inspect_tensor instead.", + DeprecationWarning, + ) + + return call(feat_config, layer_name, **kwargs_copy) + def handle_multi_feature_output( self, api_name, multi_feature_outputs, features_to_invoke, **kwargs ): @@ -474,19 +507,18 @@ def handle_multi_feature_output( # representing the number of steps after the feature will be enabled next time. # If the second value is None, that means that the feature will never be enabled. all_ret_tuple = all( - isinstance(feature_output, tuple) - for feature_output in multi_feature_outputs.values() + isinstance(feature_output, tuple) for feature_output in multi_feature_outputs ) if all_ret_tuple: - run_current = any( - feature_output[0] for feature_output in multi_feature_outputs.values() - ) + run_current = any(feature_output[0] for feature_output in multi_feature_outputs) next_iter = None - for feature_output in multi_feature_outputs.values(): - if feature_output[1] is not None: + for feature_output in multi_feature_outputs: + if next_iter is None: + next_iter = feature_output[1] + elif feature_output[1] is not None: next_iter = min(next_iter, feature_output[1]) return run_current, next_iter - run_current = any(feature_output for feature_output in multi_feature_outputs.values()) + run_current = any(feature_output for feature_output in multi_feature_outputs) return run_current, None return super().handle_multi_feature_output( api_name, multi_feature_outputs, features_to_invoke, **kwargs diff --git a/transformer_engine/debug/features/disable_fp8_gemm.py b/transformer_engine/debug/features/disable_fp8_gemm.py index 11822fd08..ef2cccbe4 100644 --- a/transformer_engine/debug/features/disable_fp8_gemm.py +++ b/transformer_engine/debug/features/disable_fp8_gemm.py @@ -50,4 +50,4 @@ def fp8_gemm_enabled( # If this feature is invoked, then FP8 GEMM is disabled. # If not, then default behaviour in TransformerEngineAPI # is that fp8_gemm() API call returns True. 
- return False, None + return False, iteration + 1 diff --git a/transformer_engine/debug/features/disable_fp8_layer.py b/transformer_engine/debug/features/disable_fp8_layer.py index d4f9b1b12..c3b0e4cca 100644 --- a/transformer_engine/debug/features/disable_fp8_layer.py +++ b/transformer_engine/debug/features/disable_fp8_layer.py @@ -41,7 +41,7 @@ def fp8_gemm_enabled( # If this feature is invoked, then FP8 GEMM is disabled. # If not, then default behavior in TransformerEngineAPI # is that fp8_gemm() API call returns True. - return False, None + return False, iteration + 1 def parse_config_and_api(self, config, **_kwargs): """Determines whether to run the API diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index c1528bb05..31620211d 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -4,55 +4,119 @@ """LogFp8TensorStats Feature support for nvidia-dlframework-inspect""" -from typing import Dict, Union +from typing import Dict, Optional, List, Tuple +from contextlib import contextmanager import torch - import nvdlfw_inspect.api as debug_api + + from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS -from transformer_engine.debug.features.utils import next_enabled_iter -from transformer_engine.pytorch.tensor import QuantizedTensor -from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor -from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor -from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase -from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from transformer_engine.debug.pytorch.debug_state import TEDebugState +from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter + + +ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] + + +def _get_recipe_name(quantizer: Optional[Quantizer]): + if quantizer is None: + return "" + if isinstance(quantizer, Float8Quantizer): + return "fp8_delayed_scaling" + if isinstance(quantizer, Float8CurrentScalingQuantizer): + return "fp8_current_scaling" + if isinstance(quantizer, MXFP8Quantizer): + return "mxfp8" + if isinstance(quantizer, Float8BlockQuantizer): + return "fp8_block_scaling" + raise ValueError(f"Unsupported quantizer type: {type(quantizer)}") + + +def _get_new_quantizer(recipe_name, fp8_dtype): + if recipe_name == "fp8_block_scaling": + return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + if recipe_name == "fp8_current_scaling": + return Float8CurrentScalingQuantizer( + fp8_dtype=fp8_dtype, device=torch.device("cuda"), rowwise=True, columnwise=True + ) + if recipe_name == "mxfp8": + return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) + if recipe_name == "fp8_delayed_scaling": + raise ValueError("Cannot recreate 
quantizer for fp8_delayed_scaling") + raise ValueError(f"Unsupported recipe name: {recipe_name}") @Registry.register_feature(namespace="transformer_engine") class LogFp8TensorStats(BaseLogTensorStats): """ - This feature handles logging of FP8 tensor stats. + Logs statistics of quantized tensors. + Supports computing statistics for current recipe, but also + allows to see what would happend if different recipes were used for these tensors in current iteration. + For example, during delayed-scaling training you may wish to track + "current_scaling_underflows%" to measure the accuracy of the current scaling + factors; note that this requires an extra cast and therefore adds overhead. + Using a logging frequency (`freq`) greater than 1 is recommended in this case. + Computing the stats matching the training recipe does not require an extra cast. - In a distributed setting, the auxiliary stats are computed on each rank and gathered after - the `debug_api.step()` call. Do not forget to invoke `debug_api.step()` at every step to log - stats! + Statistics are identified by the pattern `_` with optional `_columnwise` suffix (e.g. + `delayed_scaling_underflows%` or `mxfp8_scale_inv_min_columnwise`). + One can provide `` only, then the current training recipe is used. - `LogFp8TensorStats` supports micro-batching. If multiple forward/backward passes are invoked - per `debug_api.step()`, then stats for all tensors except weights will be accumulated. + Stats for delayed-scaling cannot be collected if delayed-scaling is not the current training recipe. - `LogFp8TensorStats` can induce significant overhead. To mitigate this issue, logging stats - with `freq > 1` is recommended. If `LogFp8TensorStats` is not used in a given step, the - overhead is smaller. If no other feature is used for the layer, the TE layer will - run as fast as it would without `debug_api` initialized. + In distributed runs each rank first computes its local statistics; the values + are gathered the next time `debug_api.step()` is called. Remember to call + `debug_api.step()` every training step so the logs are flushed. + + The feature is micro-batch aware: if several forward/backward passes occur + between successive `debug_api.step()` calls, statistics are accumulated for all + tensors except weights. + + Collecting FP8 statistics is expensive. Choosing a larger `freq` reduces the + overhead, and if the feature is skipped for a step the additional cost is + minimal. When no other debug feature is active, the layer runs at normal + Transformer Engine speed. Parameters ---------- stats: List[str] - list of statistics to log + Each stat is a string of the form `_`, with an optional `_columnwise` suffix (i.e., `__columnwise`). + If only `` is omitted, the current training recipe is used. + For mxfp8 and fp8_block_scaling `_columnwise` suffix can be provided. Then stat is computed on columnwise(transpose) + version of the tensor, which can be numerically different from rowwise (non-transpose) tensors. + "_columnwise" suffix is not supported for fp8_delayed_scaling and fp8_current_scaling. 
+ + recipes: + - fp8_delayed_scaling, + - fp8_current_scaling, + - mxfp8, + - fp8_block_scaling, + + stats: + - underflows% - percentage of non-zero elements of tensor clipped to 0 after quantization, + - overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling, + - scale_inv_min - minimum of the inverse of the scaling factors, + - scale_inv_max - maximum of the inverse of the scaling factors, + - mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements, - - underflows% - percentage of elements of the tensor equal to 0, tensors/tensors_struct: List[str] list of tensors to log + - activation, + - gradient, + - weight, - - activation - - gradient - - weight freq: Optional[int], default = 1 frequency of logging stats, stats will be logged every `freq` steps start_step: Optional[int], default = None @@ -75,7 +139,7 @@ class LogFp8TensorStats(BaseLogTensorStats): enabled: True tensors_struct: - tensor: activation - stats: [underflows%] + stats: [mxfp8_underflows%] freq: 1 - tensor: gradient stats: [underflows%] @@ -84,13 +148,106 @@ class LogFp8TensorStats(BaseLogTensorStats): end_step: 80 """ - def _get_supported_stats_list(self): - """Returns stats this feature can log.""" - return {"underflows%"} + def check_if_stat_is_supported(self, stat: str, current_recipe: str): + """Returns True if stat is supported, raises ValueError otherwise.""" + columnwise = stat.endswith("_columnwise") + if columnwise: + stat = stat[: -len("_columnwise")] + recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe) + stat_without_recipe = stat.replace(recipe_from_stat + "_", "") + + if current_recipe == "" and recipe_from_stat == "": + raise ValueError( + f"Stat {stat} does not contain a recipe name and the current recipe is not set." + ) + + if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: + raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") + + if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: + raise ValueError( + f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" + " fp8_delayed_scaling and fp8_current_scaling." 
+ ) + + if recipe_from_stat == "fp8_delayed_scaling" and stat_without_recipe == "overflows%": + return True + + if recipe_from_stat in ["fp8_block_scaling"] and torch.cuda.get_device_capability()[0] < 9: + raise ValueError(f"Stat {stat} needs Hopper or later GPU.") + + if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10: + raise ValueError(f"Stat {stat} needs Blackwell or later GPU.") + + supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"] + if stat_without_recipe not in supported_stats: + raise ValueError( + f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}" + ) + + return True + + def get_recipe_from_stat(self, stat: str, default_recipe: str = ""): + """Returns the recipe name from the stat string.""" + columnwise_stat = stat.endswith("_columnwise") + for recipe_name in ALL_RECIPE_NAMES: + if recipe_name in stat: + return recipe_name, columnwise_stat + return default_recipe, columnwise_stat + + @contextmanager + def update_aux_dict( + self, + aux_dict: Dict, + recipe_name: str, + quantized_tensor: QuantizedTensor, + quantizer: Quantizer, + original_tensor: torch.Tensor, + recipes_in_stats: List[Tuple[str, bool]], + ): + """ + Updates the aux_dict with the quantized tensor for each recipe provided in recipes_in_stats. + It allows to compute stats for different recipes in the same iteration, + without recomputing the quantized tensor for each recipe for each stat. + Also updates usage of the quantized tensor with rowwise and columnwise usage. + Yields the aux_dict. + Needs to clean after usage, because it possibly change the usage of the quantized tensor. + """ + fp8_dtype = None + if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: + assert isinstance( + quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) + ) + fp8_dtype = quantizer.dtype + + aux_dict = { + recipe_name: quantized_tensor, + } + + old_rowwise_usage = quantizer.rowwise_usage + old_columnwise_usage = quantizer.columnwise_usage + for cur_recipe_name, cur_columnwise_stat in recipes_in_stats: + if recipe_name is not cur_recipe_name: + quantizer = _get_new_quantizer(cur_recipe_name, fp8_dtype) + aux_dict[cur_recipe_name] = quantizer(original_tensor) + elif isinstance(quantized_tensor, QuantizedTensor): + if cur_columnwise_stat: + quantized_tensor.update_usage(columnwise_usage=True) + else: + quantized_tensor.update_usage(rowwise_usage=True) + aux_dict[""] = quantized_tensor + aux_dict[cur_recipe_name] = quantized_tensor + try: + yield aux_dict + finally: + if isinstance(quantized_tensor, QuantizedTensor): + quantized_tensor.update_usage( + rowwise_usage=old_rowwise_usage, columnwise_usage=old_columnwise_usage + ) @api_method - def inspect_tensor_postquantize_enabled( - self, config: Dict, layer_name: str, gemm: str, tensor_name: str, iteration: int + def inspect_tensor_enabled( + self, config: Dict, layer_name: str, tensor_name: str, iteration: int ): # pylint: disable=unused-argument """API call used to determine whether to run inspect_tensor_postquantize() in the forward.""" run_current, next_iter = next_enabled_iter( @@ -104,29 +261,34 @@ def inspect_tensor_postquantize_enabled( return run_current, next_iter @api_method - def inspect_tensor_postquantize( + def inspect_tensor( self, config: Dict, layer_name: str, tensor_name: str, - tensor: Union[torch.Tensor, QuantizedTensor], - rowwise: bool, iteration: int, tp_group: torch.distributed.ProcessGroup, + tensor: torch.Tensor, + 
rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, ): """ API call used to collect the data about the tensor after process_tensor()/quantization. """ + assert rowwise_quantized_tensor is columnwise_quantized_tensor + assert ( + quantizer is not None + ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats cannot be run without low-precision recipe." - assert type(tensor) in [Float8Tensor, Float8TensorBase, MXFP8Tensor, MXFP8TensorBase], ( - f"[NVTORCH INSPECT ERROR] Tensor {tensor_name} must be a quantized tensor when using" - " log_fp8_tensor_stats. Use log_tensor_stats for high precision tensors." - ) + quantized_tensor = rowwise_quantized_tensor + assert isinstance( + quantized_tensor, QuantizedTensor + ), "[NVTORCH INSPECT ERROR] LogFp8TensorStats quantized_tensor must be a QuantizedTensor." + recipe_name = _get_recipe_name(quantizer) - # This API can be invoked twice - with the tensor and with the transpose. - # We want to collect the stats once. - if not rowwise: - return # tensor was already seen rowwise in the other gemm + for stat in config["stats"]: + self.check_if_stat_is_supported(stat, recipe_name) options = ( config.get("start_step", None), @@ -135,19 +297,9 @@ def inspect_tensor_postquantize( "fp8", ) - skip_reduction = False - reduction_group = debug_api.get_tensor_reduction_group() - reduce_within_microbatch = tensor_name != "weight" - if tensor_name == "weight": - if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group - else: - skip_reduction = True - - for stat in config["stats"]: - assert ( - stat in self._get_supported_stats_list() - ), f"[NVTORCH INSPECT ERROR] Statistic {stat} is not supported." 
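+        # get_reduction_params(): weight stats are reduced over the TP group (or skipped
+        # entirely when TP-group weight reduction is disabled), while all other tensors
+        # use the global tensor-reduction group and accumulate across micro-batches.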
+ skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group + ) STATS_BUFFERS.try_add_buffer( layer_name=layer_name, @@ -158,10 +310,30 @@ def inspect_tensor_postquantize( reduce_within_microbatch=reduce_within_microbatch, ) - STATS_BUFFERS.feed(layer_name, tensor_name, options, tensor, iteration, skip_reduction) + recipes_in_stats = [ + self.get_recipe_from_stat(stat, default_recipe=recipe_name) for stat in config["stats"] + ] + + with self.update_aux_dict( + aux_dict={}, + recipe_name=recipe_name, + quantized_tensor=quantized_tensor, + quantizer=quantizer, + original_tensor=tensor, + recipes_in_stats=recipes_in_stats, + ) as aux_dict: + STATS_BUFFERS.feed( + layer_name, + tensor_name, + options, + tensor, + iteration, + skip_reduction, + aux_dict=aux_dict, + ) debug_api.log_message( - f"Feature={self.__class__.__name__}, API=inspect_tensor_postquantize: {tensor_name}", + f"Feature={self.__class__.__name__}, API=inspect_tensor: {tensor_name}", layer_name, extra_cachable_args=(tensor_name,), ) diff --git a/transformer_engine/debug/features/log_tensor_stats.py b/transformer_engine/debug/features/log_tensor_stats.py index 402750c28..7ba2f9f77 100644 --- a/transformer_engine/debug/features/log_tensor_stats.py +++ b/transformer_engine/debug/features/log_tensor_stats.py @@ -4,7 +4,7 @@ """LogTensorStats Feature support for nvidia-dlframework-inspect""" -from typing import Dict, Union +from typing import Dict, Optional import torch @@ -12,14 +12,13 @@ from nvdlfw_inspect.registry import Registry, api_method import nvdlfw_inspect.api as debug_api -from transformer_engine.pytorch.tensor import QuantizedTensor +from transformer_engine.pytorch.tensor import QuantizedTensor, Quantizer from transformer_engine.pytorch.tensor.float8_tensor import Float8Tensor from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Tensor from transformer_engine.pytorch.tensor._internal.float8_tensor_base import Float8TensorBase from transformer_engine.pytorch.tensor._internal.mxfp8_tensor_base import MXFP8TensorBase -from transformer_engine.debug.pytorch.debug_state import TEDebugState from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS -from transformer_engine.debug.features.utils import next_enabled_iter +from transformer_engine.debug.features.utils import next_enabled_iter, get_reduction_params @Registry.register_feature(namespace="transformer_engine") @@ -114,10 +113,13 @@ def inspect_tensor( config: Dict, layer_name: str, tensor_name: str, - tensor: Union[torch.Tensor, QuantizedTensor], iteration: int, tp_group: torch.distributed.ProcessGroup, - ): + tensor: torch.Tensor, + rowwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + columnwise_quantized_tensor: Optional[torch.Tensor | QuantizedTensor] = None, + quantizer: Optional[Quantizer] = None, + ): # pylint: disable=unused-argument """API call used to collect the data about the tensor before process_tensor()/quantization.""" assert ( @@ -134,14 +136,9 @@ def inspect_tensor( config.get("start_end_list", None), ) - skip_reduction = False - reduction_group = debug_api.get_tensor_reduction_group() - reduce_within_microbatch = tensor_name != "weight" - if tensor_name == "weight": - if TEDebugState.weight_tensor_tp_group_reduce: - reduction_group = tp_group - else: - skip_reduction = True + skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group + ) for stat in config["stats"]: assert ( diff --git 
a/transformer_engine/debug/features/utils/__init__.py b/transformer_engine/debug/features/utils/__init__.py index 60f6b0a21..aae2ec4e9 100644 --- a/transformer_engine/debug/features/utils/__init__.py +++ b/transformer_engine/debug/features/utils/__init__.py @@ -6,6 +6,26 @@ Utils for the debug features. """ +import torch +import nvdlfw_inspect.api as debug_api + +from transformer_engine.debug.pytorch.debug_state import TEDebugState + + +def get_reduction_params(tensor_name: str, tp_group: torch.distributed.ProcessGroup): + """ + Returns the statistics reduction parameters for the tensor. + """ + skip_reduction = False + reduction_group = debug_api.get_tensor_reduction_group() + reduce_within_microbatch = tensor_name != "weight" + if tensor_name == "weight": + if TEDebugState.weight_tensor_tp_group_reduce: + reduction_group = tp_group + else: + skip_reduction = True + return skip_reduction, reduction_group, reduce_within_microbatch + def next_enabled_iter(start_step, end_step, start_end_list, freq, iteration): """ diff --git a/transformer_engine/debug/features/utils/stats_buffer.py b/transformer_engine/debug/features/utils/stats_buffer.py index 7ccef20bc..f07602d23 100644 --- a/transformer_engine/debug/features/utils/stats_buffer.py +++ b/transformer_engine/debug/features/utils/stats_buffer.py @@ -67,14 +67,17 @@ def _gather_helper_stats(self) -> torch.Tensor: gathered_buffer, _ = gather_along_first_dim( self._buffer.unsqueeze(0), process_group=self.reduction_group ) - return gathered_buffer[mask.to(bool)] + return gathered_buffer[mask.to(torch.bool)] - def feed(self, tensor, iteration): + def feed(self, tensor, iteration, aux_dict=None): """ feed() is used to add tensor for computing the statistics. Because of the microbatching, feed() can be used multiple times for one log(). + The aux_dict is used to share common computation between different stats. + For example for LogFp8TensorStats in can contain quantized tensors in different precisions. + The main reason of this design: need to combine results for already processed tensors with the result of the new tensor. """ @@ -97,7 +100,7 @@ def feed(self, tensor, iteration): # save stats for tensor to tmp buffer for stat_name in self.stats_to_compute: fn, _ = STATS[stat_name] - self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor) + self._tmp_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict) # [num_buffers, num_stats] buffers = torch.cat((self._buffer.unsqueeze(0), self._tmp_buffer.unsqueeze(0)), dim=0) @@ -108,7 +111,7 @@ def feed(self, tensor, iteration): self._new_buffer[stats_to_num[stat_name]] = combinator(buffers) else: fn = STATS[stat_name][0] - self._new_buffer[stats_to_num[stat_name]] = fn(tensor) + self._new_buffer[stats_to_num[stat_name]] = fn(tensor, aux_dict) self._buffer.copy_(self._new_buffer) @@ -127,7 +130,6 @@ def log(self): for stat_name in self.stats_to_log: combiner = STATS[stat_name][1] stat_value = combiner(gathered_helper_stats) - MetricLogger.log_scalar( f"{self.layer_name}_{self.tensor_name}_{stat_name}", stat_value, self.iteration ) @@ -194,11 +196,18 @@ def try_add_buffer( self.buffers[(layer_name, tensor_name, options)] = buffer self.reduction_group_to_buffer[reduction_group].append(buffer) - def feed(self, layer_name, tensor_name, options, tensor, iteration, skip_reduction): - """Feeds the tensor into the respective buffer.""" + def feed( + self, layer_name, tensor_name, options, tensor, iteration, skip_reduction, aux_dict=None + ): + """ + Feeds the tensor into the respective buffer. 
+ + The aux_dict is used to share common computation between different stats. + For example for LogFp8TensorStats in can contain quantized tensors in different precisions. + """ self.at_least_one_layer_fed = True buffer = self.buffers[(layer_name, tensor_name, options)] - buffer.feed(tensor, iteration) + buffer.feed(tensor, iteration, aux_dict) buffer.skip_reduction = skip_reduction def log_stats(self): diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index ed32de1ae..3842ab1c5 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -8,8 +8,9 @@ import math import torch - -MAX_FP8_VALUE_INT8 = 126 +import torch.nn.functional as F +import transformer_engine_torch as tex +from transformer_engine.common.recipe import Format @torch.compile @@ -49,6 +50,29 @@ def compute_std(variances, numels, sums): return torch.sqrt(compute_variance(variances, numels, sums)) +def compute_fp8_delayed_scaling_overflows_num(tensor, quantized_tensor): + """Computes the overflows of the tensor.""" + scale_inv = quantized_tensor._scale_inv + dtype = quantized_tensor._fp8_dtype + + # Map each supported FP8 dtype to its corresponding max forward value. + dtype_to_max = { + tex.DType.kFloat8E4M3: Format.E4M3.value.max_fwd, + tex.DType.kFloat8E5M2: Format.E5M2.value.max_fwd, + } + + if dtype not in dtype_to_max: + raise ValueError( + f"Unsupported FP8 dtype {dtype} passed to compute_fp8_delayed_scaling_overflows_num()." + ) + + fp8_max = dtype_to_max[dtype] + fp8_min = -fp8_max + + overflows = (tensor > fp8_max * scale_inv) | (tensor < fp8_min * scale_inv) + return overflows.sum() + + # buffers is tensor of shape [nr_buffers, nr_stats] def _get(buffers, stat_name): stat_nr = stats_to_num[stat_name] @@ -68,10 +92,12 @@ def _get(buffers, stat_name): "cur_amax": 9, "dynamic_range_top": 10, "dynamic_range_bottom": 11, - "underflows_num": 12, - "std": 13, - "dynamic_range": 14, - "underflows%": 15, + "std": 12, + "dynamic_range": 13, + "fp8_delayed_scaling_overflows_num": 14, + "fp8_delayed_scaling_overflows%": 15, + "overflows_num": 16, + "overflows%": 17, } DEPENDENCIES = { @@ -87,62 +113,207 @@ def _get(buffers, stat_name): "cur_amax": {"cur_amax"}, "dynamic_range_top": {"dynamic_range_top"}, "dynamic_range_bottom": {"dynamic_range_bottom"}, - "underflows_num": {"underflows_num"}, "std": {"variance", "numel", "sum"}, "dynamic_range": {"dynamic_range_top", "dynamic_range_bottom"}, - "underflows%": {"underflows_num", "numel"}, + "fp8_delayed_scaling_overflows_num": {"fp8_delayed_scaling_overflows_num"}, + "fp8_delayed_scaling_overflows%": {"fp8_delayed_scaling_overflows_num", "numel"}, + "overflows_num": {"overflows_num"}, + "overflows%": {"overflows_num", "numel"}, } STATS = { - "min": (torch.min, lambda buffers: min(_get(buffers, "min"))), - "max": (torch.max, lambda buffers: max(_get(buffers, "max"))), - "sum": (torch.sum, lambda buffers: sum(_get(buffers, "sum"))), - "mean": (torch.mean, lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel"))), + "min": (lambda x, aux_dict: torch.min(x), lambda buffers: min(_get(buffers, "min"))), + "max": (lambda x, aux_dict: torch.max(x), lambda buffers: max(_get(buffers, "max"))), + "sum": (lambda x, aux_dict: torch.sum(x), lambda buffers: sum(_get(buffers, "sum"))), + "mean": ( + lambda x, aux_dict: torch.mean(x), + lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")), + ), 
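+    # Each STATS entry pairs two callables: the first computes the stat for a single fed
+    # tensor (aux_dict carries shared intermediates such as pre-quantized copies), and the
+    # second combines stat values across micro-batches and, after gathering, across ranks.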
"numel": ( - lambda x: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), + lambda x, aux_dict: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), lambda buffers: sum(_get(buffers, "numel")), ), - "l1_norm": (lambda x: torch.norm(x, p=1), lambda buffers: sum(_get(buffers, "l1_norm"))), + "l1_norm": ( + lambda x, aux_dict: torch.norm(x, p=1), + lambda buffers: sum(_get(buffers, "l1_norm")), + ), "l2_norm_square": ( - lambda x: torch.sum(x**2), + lambda x, aux_dict: torch.sum(x**2), lambda buffers: sum(_get(buffers, "l2_norm_square")), ), "l2_norm": ( - lambda x: torch.norm(x, p=2), + lambda x, aux_dict: torch.norm(x, p=2), lambda buffers: math.sqrt(sum(_get(buffers, "l2_norm_square"))), ), "variance": ( - torch.var, + lambda x, aux_dict: torch.var(x), lambda buffers: compute_variance( _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") ), ), - "cur_amax": (lambda x: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))), + "cur_amax": (lambda x, aux_dict: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))), "dynamic_range_top": ( - _compute_dynamic_range_top, + lambda x, aux_dict: _compute_dynamic_range_top(x), lambda buffers: max(_get(buffers, "dynamic_range_top")), ), "dynamic_range_bottom": ( - _compute_dynamic_range_bottom, + lambda x, aux_dict: _compute_dynamic_range_bottom(x), lambda buffers: min(_get(buffers, "dynamic_range_bottom")), ), - "underflows_num": ( - lambda x: (x.get_data_tensors()[0] == 0).sum(), - lambda buffers: sum(_get(buffers, "underflows_num")), - ), "std": ( - torch.std, + lambda x, aux_dict: torch.std(x), lambda buffers: compute_std( _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") ), ), "dynamic_range": ( - lambda x: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x), + lambda x, aux_dict: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x), lambda buffers: max(_get(buffers, "dynamic_range_top")) - min(_get(buffers, "dynamic_range_bottom")), ), - "underflows%": ( - lambda x: (x.get_data_tensors()[0] == 0).sum() / x.get_data_tensors()[0].numel() * 100, - lambda buffers: 100 * sum(_get(buffers, "underflows_num")) / sum(_get(buffers, "numel")), + "fp8_delayed_scaling_overflows_num": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num( + x, aux_dict["fp8_delayed_scaling"] + ), + lambda buffers: sum(_get(buffers, "fp8_delayed_scaling_overflows_num")), + ), + "fp8_delayed_scaling_overflows%": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num( + x, aux_dict["fp8_delayed_scaling"] + ) + / x.numel() + * 100, + lambda buffers: 100 + * sum(_get(buffers, "fp8_delayed_scaling_overflows_num")) + / sum(_get(buffers, "numel")), + ), + "overflows_num": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]), + lambda buffers: sum(_get(buffers, "overflows_num")), + ), + "overflows%": ( + lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]) + / x.numel() + * 100, + lambda buffers: 100 * sum(_get(buffers, "overflows_num")) / sum(_get(buffers, "numel")), ), } + + +def add_underflows_stats(recipe_name: str, columnwise: bool = False): + """Register *both* underflow stats (num and %) for the given recipe.""" + columnwise_suffix = "_columnwise" if columnwise else "" + + # Stat names + stat_num = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows_num{columnwise_suffix}" + stat_pct = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}" + + 
stats_to_num[stat_num] = len(stats_to_num) + stats_to_num[stat_pct] = len(stats_to_num) + + STATS[stat_num] = ( + lambda x, aux_dict: ( + aux_dict[recipe_name].get_data_tensors( + rowwise_data=not columnwise, columnwise_data=columnwise + ) + == 0 + ).sum() + - (x == 0).sum(), + lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), + ) + STATS[stat_pct] = ( + lambda x, aux_dict: ( + aux_dict[recipe_name].get_data_tensors( + rowwise_data=not columnwise, columnwise_data=columnwise + ) + == 0 + ).sum() + / aux_dict[recipe_name].numel() + * 100, + lambda buffers, _sn_num=stat_num: 100 + * sum(_get(buffers, _sn_num)) + / sum(_get(buffers, "numel")), + ) + + DEPENDENCIES[stat_num] = {stat_num} + DEPENDENCIES[stat_pct] = {stat_num, "numel"} + + +def add_scale_inv_stats(recipe_name: str, columnwise: bool = False): + """Register *both* scale-inv min and max stats for a given recipe. + + This replaces the earlier separate helpers and avoids duplicated boilerplate. + """ + # Determine which attribute holds the scale-inverse tensor. + + def get_scale_inv(quantized_tensor, columnwise): + if hasattr(quantized_tensor, "_scale_inv"): + return getattr(quantized_tensor, "_scale_inv") + if columnwise: + return getattr(quantized_tensor, "_columnwise_scale_inv") + return getattr(quantized_tensor, "_rowwise_scale_inv") + + columnwise_suffix = "_columnwise" if columnwise else "" + # Prepare stat names. + stat_name_min = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_min{columnwise_suffix}" + ) + stat_name_max = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_max{columnwise_suffix}" + ) + + # Assign indices in `stats_to_num` (order matters — keep insertion order deterministic). + stats_to_num[stat_name_min] = len(stats_to_num) + stats_to_num[stat_name_max] = len(stats_to_num) + + # Capture the attribute name inside lambdas via default args to avoid late binding. 
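+    # The `_sn=...` / `_col=...` defaults bake the stat name and tensor orientation into
+    # each lambda, so every registered combiner reads its own slot from the gathered buffers.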
+ STATS[stat_name_min] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(), + lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)), + ) + STATS[stat_name_max] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(), + lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)), + ) + + DEPENDENCIES[stat_name_min] = {stat_name_min} + DEPENDENCIES[stat_name_max] = {stat_name_max} + + +def add_mse_stats(recipe_name: str, columnwise: bool = False): + """Register mse and total_square_error stats for the recipe.""" + columnwise_suffix = "_columnwise" if columnwise else "" + + stat_mse = f"{recipe_name}{'_' if recipe_name != '' else ''}mse{columnwise_suffix}" + stat_err = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}" + ) + + stats_to_num[stat_mse] = len(stats_to_num) + stats_to_num[stat_err] = len(stats_to_num) + + STATS[stat_mse] = ( + lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="mean"), + lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)) + / sum(_get(buffers, "numel")), + ) + STATS[stat_err] = ( + lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="sum"), + lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)), + ) + + DEPENDENCIES[stat_err] = {stat_err} + DEPENDENCIES[stat_mse] = {stat_mse, stat_err, "numel"} + + +for _columnwise in [True, False]: + for _recipe_name in [ + "", # default recipe + "fp8_delayed_scaling", + "mxfp8", + "fp8_current_scaling", + "fp8_block_scaling", + ]: + add_underflows_stats(_recipe_name, _columnwise) + add_scale_inv_stats(_recipe_name, _columnwise) + add_mse_stats(_recipe_name, _columnwise) diff --git a/transformer_engine/debug/pytorch/debug_quantization.py b/transformer_engine/debug/pytorch/debug_quantization.py index 98feb3180..d564ca8e9 100644 --- a/transformer_engine/debug/pytorch/debug_quantization.py +++ b/transformer_engine/debug/pytorch/debug_quantization.py @@ -156,7 +156,6 @@ def get_enabled_look_at_tensors(self): gemm=self.columnwise_gemm_name, ) ) - return ( inspect_tensor_enabled, inspect_tensor_postquantize_enabled_rowwise, @@ -259,6 +258,9 @@ def _call_inspect_tensor_api( "tensor_name": self.tensor_name, "iteration": TEDebugState.get_iteration(), "tp_group": self.tp_group, + "columnwise_quantized_tensor": columnwise_gemm_tensor, + "rowwise_quantized_tensor": rowwise_gemm_tensor, + "quantizer": self.parent_quantizer, } if tensor is not None and self.inspect_tensor_enabled: debug_api.transformer_engine.inspect_tensor(**args) @@ -266,6 +268,10 @@ def _call_inspect_tensor_api( if self.output_tensor: return + del args["columnwise_quantized_tensor"] + del args["rowwise_quantized_tensor"] + del args["quantizer"] + if ( self.rowwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] and self.inspect_tensor_postquantize_enabled_rowwise @@ -273,6 +279,7 @@ def _call_inspect_tensor_api( args["tensor"] = rowwise_gemm_tensor args["rowwise"] = True debug_api.transformer_engine.inspect_tensor_postquantize(**args) + if ( self.columnwise_tensor_plan in [API_CALL_MODIFY, STANDARD_FP8_QUANTIZE] and self.inspect_tensor_postquantize_enabled_columnwise @@ -398,6 +405,7 @@ def any_feature_enabled(self) -> bool: """Returns bool if there is at least one API call enabled.""" if self.output_tensor: return self.inspect_tensor_enabled or self.rowwise_tensor_plan == API_CALL_MODIFY + # pylint: disable=too-many-boolean-expressions if ( 
self.inspect_tensor_enabled or self.inspect_tensor_postquantize_enabled_rowwise diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 709b4f3b8..217cb98c7 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -1138,6 +1138,10 @@ def _all_gather_fp8_blockwise( "Dequantizing and requantizing to Float8BlockwiseQTensor." ) inp = quantizer(inp.dequantize()) + + # Construct Float8BlockwiseQTensor output tensor + out = quantizer.make_empty(out_shape, dtype=dtype, device=device) + quantizer.all_gather_usage = orig_all_gather_usage # Begin to do network communication, need to make sure compact format @@ -1147,9 +1151,6 @@ def _all_gather_fp8_blockwise( f"but found data_format={inp._data_format}" ) - # Construct Float8BlockwiseQTensor output tensor - out = quantizer.make_empty(out_shape, dtype=dtype, device=device) - # Coalesce NCCL collectives with torch.distributed._coalescing_manager( group=process_group, diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index 787c322a0..adffe7c58 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -124,9 +124,15 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def get_data_tensors(self): + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" - return self._rowwise_data, self._columnwise_data + if rowwise_data and columnwise_data: + return self._rowwise_data, self._columnwise_data + if rowwise_data: + return self._rowwise_data + if columnwise_data: + return self._columnwise_data + raise ValueError("No data to get, both rowwise_data and columnwise_data are False") def _transpose_dq_columnwise_output(self, columnwise_dq: torch.Tensor) -> torch.Tensor: """Takes dequantized columnwise data and permutes to a rowwise shape""" diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index a88ae33f0..61edc999a 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -128,9 +128,15 @@ def restore_from_saved( self._scale_inv = tensors[2] return tensors[3:] - def get_data_tensors(self): + def get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" - return self._data, self._transpose + if rowwise_data and columnwise_data: + return self._data, self._transpose + if rowwise_data: + return self._data + if columnwise_data: + return self._transpose + raise ValueError("No data to get, both rowwise_data and columnwise_data are False") def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Dequantize to a higher precision.""" diff --git a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py index a093904bc..5a7dd6b44 100644 --- a/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/mxfp8_tensor_base.py @@ -136,9 +136,15 @@ def restore_from_saved( self._columnwise_scale_inv = tensors[3] return tensors[4:] - def get_data_tensors(self): + def 
get_data_tensors(self, rowwise_data: bool = True, columnwise_data: bool = True): """Get this Tensor's data.""" - return self._rowwise_data, self._columnwise_data + if rowwise_data and columnwise_data: + return self._rowwise_data, self._columnwise_data + if rowwise_data: + return self._rowwise_data + if columnwise_data: + return self._columnwise_data + raise ValueError("No data to get, both rowwise_data and columnwise_data are False") def dequantize(self, *, dtype: torch.dtype = torch.float32) -> torch.Tensor: """Dequantize to a higher precision.""" From aa0659e5914933711bf1df92078431bc1330805a Mon Sep 17 00:00:00 2001 From: Kate Cheng Date: Wed, 13 Aug 2025 10:35:31 -0700 Subject: [PATCH 066/153] Remove if-else and torch.tensor to meet cudagraph requirement (#1997) * Remove if-else and torch.tensor to meet cudagraph requirement Signed-off-by: Kate Cheng * Add is_cg_capturable flag to guard the if-else statement Signed-off-by: Kate Cheng --------- Signed-off-by: Kate Cheng Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/cross_entropy.py | 15 +++++++++++++-- .../pytorch/triton/cross_entropy.py | 10 +++++++--- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/transformer_engine/pytorch/cross_entropy.py b/transformer_engine/pytorch/cross_entropy.py index 75b5de37b..076dbec0d 100644 --- a/transformer_engine/pytorch/cross_entropy.py +++ b/transformer_engine/pytorch/cross_entropy.py @@ -29,6 +29,7 @@ def forward( reduce_loss=False, dist_process_group=None, ignore_idx=-100, + is_cg_capturable=False, ): """ The forward pass of the Cross Entropy loss. If dist_process_group is passed for distributed loss calculation, the input to each @@ -47,10 +48,16 @@ def forward( tensor: The computed loss. """ loss, _input = triton_cross_entropy.cross_entropy_forward( - _input, target, label_smoothing, reduce_loss, dist_process_group, ignore_idx + _input, + target, + label_smoothing, + reduce_loss, + dist_process_group, + ignore_idx, ) ctx.save_for_backward(_input.detach()) + ctx.is_cg_capturable = is_cg_capturable return loss @staticmethod @@ -66,13 +73,17 @@ def backward(ctx, grad_output): tuple: A tuple with the gradients with respect to the inputs. The elements are tensors or None. """ (_input,) = ctx.saved_tensors - _input = triton_cross_entropy.cross_entropy_backward(_input, grad_output) + _input = triton_cross_entropy.cross_entropy_backward( + _input, grad_output, ctx.is_cg_capturable + ) return ( _input, None, None, None, None, + None, + None, ) diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 45ff9f9c5..323a93922 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -340,13 +340,17 @@ def cross_entropy_forward( return loss, _input -def cross_entropy_backward(_input: torch.Tensor, grad_output: torch.Tensor): +def cross_entropy_backward( + _input: torch.Tensor, grad_output: torch.Tensor, is_cg_capturable: bool = False +): """Backward implementation of cross entropy loss kernel""" # If cross entropy is the last layer, grad_output is 1.0. 
Skip the mul to save time - if torch.equal(grad_output, torch.tensor(1.0, device=grad_output.device)): + # Only check torch.equal when not in CUDA graph capturable mode + if not is_cg_capturable and torch.equal( + grad_output, torch.tensor(1.0, device=grad_output.device) + ): pass - else: B, SQ, V = _input.shape n_rows = B * SQ From 8dc2756ed818d51508e0c53e3422f61cf9661472 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 13 Aug 2025 17:41:48 -0400 Subject: [PATCH 067/153] [JAX] Manual axis filter in `with_sharding_constraint` (#2069) * add manual axis filer to sharding_constraint impl Signed-off-by: Phuong Nguyen * fix lint Signed-off-by: Phuong Nguyen * use abstract_mesh instead of physical_mesh Signed-off-by: Phuong Nguyen * add a comment Signed-off-by: Phuong Nguyen * cleanup Signed-off-by: Phuong Nguyen * clean unused var Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/sharding.py | 19 ++++++++++++++----- 1 file changed, 14 insertions(+), 5 deletions(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 606c233c9..6a6e25da1 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -15,10 +15,10 @@ from enum import Enum from typing import Callable, Optional import warnings -from jax.interpreters import pxla import jax import jax.numpy as jnp -from jax.sharding import PartitionSpec +from jax.interpreters import pxla +from jax.sharding import PartitionSpec, get_abstract_mesh import numpy as np _PXLA_THREAD_RESOURCES = pxla.thread_resources @@ -122,8 +122,10 @@ def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False): def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): """ - A wrapper function to jax.lax.with_sharding_constraint to - support the case that Mesh is empty. + A wrapper function to jax.lax.with_sharding_constraint + 1. Does nothing if mesh is empty. + 2. If all mesh axes are manual axes, replaces pspec with all Nones. + 3. Otherwise, strips only the manual axes. 
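A minimal sketch of case 3 above (illustrative only; the axis names "data" and "tp" are assumed here, not taken from TE):

    from jax.sharding import PartitionSpec

    manual_axes = {"tp"}  # axes already handled by shard_map (assumed)
    pspec = PartitionSpec("data", "tp")
    cleaned = PartitionSpec(*(name if name not in manual_axes else None for name in pspec))
    # cleaned == PartitionSpec('data', None); only the manual axis is stripped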
""" if pspec is None: return x @@ -131,7 +133,14 @@ def with_sharding_constraint(x: jnp.array, pspec: PartitionSpec): mesh = _PXLA_THREAD_RESOURCES.env.physical_mesh if mesh.empty: return x - return jax.lax.with_sharding_constraint(x, pspec) + + # We want to exclude the axes that already used by shard_map and shard_map + # only sets those in the abstract_mesh, not the physical one + manual_axis_names = get_abstract_mesh().manual_axes + cleaned_axis_names = tuple(name if name not in manual_axis_names else None for name in pspec) + + cleaned_pspec = PartitionSpec(*cleaned_axis_names) + return jax.lax.with_sharding_constraint(x, cleaned_pspec) def with_sharding_constraint_by_logical_axes( From bbddcb92896f604a478c9e94ab697c71d838638f Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 13 Aug 2025 17:42:34 -0400 Subject: [PATCH 068/153] [JAX] Cleanup the MLP warning for TE GEMM + TP (#2054) * fix pspec check Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * cleaning Signed-off-by: Phuong Nguyen * add docstring Signed-off-by: Phuong Nguyen * use dict.get() Signed-off-by: Phuong Nguyen * fix lint Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/jax/layernorm_mlp.py | 23 -------------- transformer_engine/jax/sharding.py | 41 ++++++++----------------- 2 files changed, 12 insertions(+), 52 deletions(-) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index ce3ebc78a..8727ea7e3 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -15,7 +15,6 @@ from typing import List, Tuple, Sequence, Union, Callable from functools import partial -import warnings import jax import jax.numpy as jnp @@ -93,28 +92,6 @@ def layernorm_mlp( """ assert len(kernels) == 2 - # For MaxText TP (= Megatron TP + sharding in hidden dimension of remaining unsharded - # activations), JAX dot_general may perform better then TE GEMM custom call - # This inspection only works if either norm_input_axes or dot_1_input_axes is set - is_mxfp8 = ( - False - if quantizer_sets[0] == noop_quantizer_set - else quantizer_sets[0].x.scaling_mode.is_1d_block_scaling() - ) - inspect_axes = norm_input_axes or dot_1_input_axes - if ( - inspect_axes is not None - and len(inspect_axes) == x.ndim - and inspect_axes[-1] is not None - and not is_mxfp8 - ): - warnings.warn( - "Detected sharding in the hidden dimension of the MLP activation input. For improved" - " performance, consider using JAX’s built-in `dot_general` implementation. To try" - " this, set the environment variable: `NVTE_JAX_CUSTOM_CALLS='GemmPrimitive=false'`", - UserWarning, - ) - kernel_1 = kernels[0] kernel_2 = kernels[1] bias_1 = biases[0] diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 6a6e25da1..6d4894fd8 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -86,37 +86,20 @@ def get_sharding_map_logic_axis_to_mesh_axis(): return te_logical_axis_to_mesh_axis -def generate_pspec(logical_axis_names, with_flax_rules=False, padded=False): +def _generate_pspec(logical_axis_names): """ - Convert logical axes to PartitionSpec + Convert TransformerEngine logical axes (e.g. BATCH_AXES) to a JAX PartitionSpec. + Note, this method does not support Flax logical axes. 
+ + Args: + logical_axis_names: TransformerEngine logical axes to convert to a JAX PartitionSpec. + Returns: + A JAX PartitionSpec with the mesh axes corresponding to the given TransformerEngine logical axis names """ - rules = None - if with_flax_rules: - try: - import flax - - rules = dict(flax.linen.get_logical_axis_rules()) - except ImportError: - pass - - if rules is None: - warnings.warn( - "Transformer Engine logical axes, such as BATCH_AXES, SEQLEN_AXES, etc. are deprecated" - " and removed in a future version. Please use Flax logical axes with the" - " `flax.linen.logical_axis_rules()` context and optionally use" - " `transformer_engine.jax.flax.extend_logical_axis_rules()` to extend Flax axis rules" - " with Transformer Engine logical axes.", - DeprecationWarning, - ) - rules = get_sharding_map_logic_axis_to_mesh_axis() - # mesh_axis_names = [rules[name] for name in logical_axis_names] - mesh_axis_names = [] - for name in logical_axis_names: - axis_name = rules[name] if name in rules else None - mesh_axis_names.append(axis_name) + rules = get_sharding_map_logic_axis_to_mesh_axis() + + mesh_axis_names = [rules.get(name) for name in logical_axis_names] pspec = jax.sharding.PartitionSpec(*mesh_axis_names) - if padded: - pspec = get_padded_spec(pspec, len(mesh_axis_names)) return pspec @@ -188,7 +171,7 @@ def with_sharding_constraint_by_logical_axes( # If no logical axis rules are available from Flax, fallback to TE's hardcoded logical axis rule table assert len(x.shape) == len(logical_axis_names) - pspec = generate_pspec(logical_axis_names) + pspec = _generate_pspec(logical_axis_names) return with_sharding_constraint(x, pspec) From 44fbe9e68163d15f451a399e0222c42bb4f1f079 Mon Sep 17 00:00:00 2001 From: Kshiteej K Date: Thu, 14 Aug 2025 02:54:27 +0200 Subject: [PATCH 069/153] fix: update grad_output quant to avoid redundant work (#1736) * fix: update grad_output quant to avoid redundant work Signed-off-by: kshitij12345 * add test Signed-off-by: kshitij12345 * don't keep only columnwise quant if requires_dgrad=False Signed-off-by: kshitij12345 * fix stray merge Signed-off-by: kshitij12345 * fix for ctx.use_bias is True case Signed-off-by: kshitij12345 * Skip if FP8 not available Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: kshitij12345 Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- tests/pytorch/test_sanity.py | 26 +++++++++++++++++++ .../common/util/cast_kernels.cuh | 2 +- transformer_engine/pytorch/module/linear.py | 13 ++++++++++ 3 files changed, 40 insertions(+), 1 deletion(-) diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 07c636ab1..5f61772d9 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -995,6 +995,32 @@ def backward(ctx, grad_output): torch.testing.assert_close(grad_checkpoint, grad_standard) +@pytest.mark.skipif(not fp8_available, reason=reason_for_no_fp8) +def test_linear_frozen_weights_memory_default_recipe(): + """Test that memory usage is optimized when weights are frozen for MXFP8.""" + dim = 1024 + linear = Linear(dim, dim, bias=False) + x = torch.randn(dim, dim, requires_grad=True, device="cuda") + + # Freeze weights + linear.weight.requires_grad = False + + # Forward and backward pass with FP8 + with fp8_autocast(): + o = linear(x) + g_o = torch.randn_like(o) + + max_memory_before_backward = torch.cuda.max_memory_allocated() + o.backward(g_o) + max_memory_after_backward = torch.cuda.max_memory_allocated() + + memory_diff = (max_memory_after_backward 
- max_memory_before_backward) / 1e6 + assert memory_diff < 5.5, ( + f"Memory usage with frozen weights ({memory_diff}MB) should be less than 5.5MB as the" + " grad_output should be quantized only columnwise." + ) + + @pytest.mark.parametrize( "module_name", ("Linear", "LayerNormLinear", "LayerNormMLP", "GroupedLinear", "ops.Linear"), diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 5590cee10..7885bbaf3 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1320,7 +1320,7 @@ void fp8_quantize_arch_l_100(const Tensor &input, const Tensor *act_input, const if (!is_tensor_scaling(output->scaling_mode) || IS_DBIAS) { // zhongboz: should we just ignore IS_ACT here? NVTE_ERROR("Not implemented scaling mode or fusion: " + to_string(output->scaling_mode) + - " on GPU with compute capability < 10.0."); + " or IS_DBIAS=true" + " on GPU with compute capability < 10.0."); } switch (output->scaling_mode) { case NVTE_DELAYED_TENSOR_SCALING: { diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 8b05e71d7..695cbb4e6 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -559,6 +559,19 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # usage for only dgrad GEMM. quantizer.set_usage(columnwise=False) + # Adjust the quantization direction approach depending + # on whether wgrad calculations will be performed. + # NOTE: If requires_dgrad is False, disabling `rowwise` quantization and keeping `columnwise` quantization + # results in `Assertion failed: output_tensor->has_data(). Quantizing in only the columnwise direction not supported yet!` + # NOTE: For `ctx.bias is True`, selected quantize kernel errors with + # `cast_kernels.cuh:1322 in function fp8_quantize_arch_l_100: Not implemented scaling mode or fusion: NVTE_DELAYED_TENSOR_SCALING or IS_DBIAS=true on GPU with compute capability < 10.0.` + if ( + not ctx.use_bias + and not ctx.requires_wgrad + and ctx.grad_output_quantizer is not None + ): + ctx.grad_output_quantizer.set_usage(columnwise=False) + # Prepare grad output tensor nvtx_range_push(f"{nvtx_label}.grad_output_preprocess") ( From c582f6bef75f7b84b369d917b6d646cecd048e03 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Thu, 14 Aug 2025 10:35:19 +0800 Subject: [PATCH 070/153] [Common] Reduce CUDA driver calls (#2067) * reduce driver calls Signed-off-by: Xin Yao * reduce driver calls Signed-off-by: Xin Yao * adjust tests to capture this Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_cuda_graphs.py | 8 +++++-- transformer_engine/common/common.cu | 4 +++- .../quantize_transpose_square_blockwise.cu | 2 ++ .../common/util/cast_gated_kernels.cuh | 5 +++- .../common/util/cast_kernels.cuh | 3 ++- .../common/util/cuda_driver.cpp | 23 +++++++++++-------- 6 files changed, 30 insertions(+), 15 deletions(-) diff --git a/tests/pytorch/test_cuda_graphs.py b/tests/pytorch/test_cuda_graphs.py index 9b5118e6e..90e624c94 100644 --- a/tests/pytorch/test_cuda_graphs.py +++ b/tests/pytorch/test_cuda_graphs.py @@ -122,10 +122,12 @@ def forward(self, input_: torch.Tensor, **kwargs) -> torch.Tensor: # Supported modules 
_test_cuda_graphs_modules: List[str] = [ + # Put linear first to test the case where the cuda context might not be set in + # creating TMA descriptor for MXFP8 quantization. + "linear", "transformer", "layernorm_mlp", "layernorm_linear", - "linear", "mha", "linear_op", ] @@ -308,9 +310,11 @@ def test_make_graphed_callables( fp8_weight_caching=fp8_weight_caching, fp8_recipe=fp8_recipe, ) - outputs = _test_cuda_graphs(graph_mode="none", **kwargs) + # Put graphed callables first to test the case where the cuda context might not be set in + # creating TMA descriptor for MXFP8 quantization. graph_outputs_mode1 = _test_cuda_graphs(graph_mode="full", **kwargs) graph_outputs_mode2 = _test_cuda_graphs(graph_mode="individual", **kwargs) + outputs = _test_cuda_graphs(graph_mode="none", **kwargs) # Check that results match. assert_all_equal(outputs, graph_outputs_mode1) diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 9831bbb24..4e697979d 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -95,6 +95,9 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream } // extern "C" void checkCuDriverContext(CUstream stream) { + // Ensure the thread's "current" CUDA context is set. + cuda_driver::ensure_context_exists(); + CUcontext ctx; const CUresult driver_status = cuda_driver::call("cuStreamGetCtx", stream, &ctx); switch (driver_status) { @@ -138,7 +141,6 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits) { - cuda_driver::ensure_context_exists(); // Get a function pointer to the cuTensorMapEncodeTiled driver API // Note: PFN_cuTensorMapEncodeTiled is not defined in cuda13 static PFN_cuTensorMapEncodeTiled_v12000 cuDriverTensorMapEncodeTiled = []() { diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 0b70f3f40..3a2247f5c 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -488,6 +488,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor const bool return_transpose, const bool pow_2_scale, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_square_blockwise); + checkCuDriverContext(stream); + NVTE_CHECK(input.shape == output.shape, "Input and output must have the same shape."); const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; size_t num_rows = 1; diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index d7552835e..83359eb05 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -885,6 +885,8 @@ template void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, cudaStream_t stream) { + checkCuDriverContext(stream); + if (output->has_data()) { NVTE_CHECK(output->scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } @@ -964,6 +966,8 @@ template void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, cudaStream_t stream) { + checkCuDriverContext(stream); + const bool USE_ROWWISE_SCALING = output->has_data(); const bool USE_COLWISE_SCALING = output->has_columnwise_data(); @@ -1206,7 +1210,6 @@ template void quantize_gated(const Tensor &grad, const Tensor &gated_input, Tensor *output, cudaStream_t stream) { - checkCuDriverContext(stream); constexpr bool allow_empty = false; CheckInputTensor(gated_input, "gated_input"); CheckOutputTensor(*output, "output", allow_empty); diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 7885bbaf3..c084c3116 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1006,9 +1006,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, // TODO (ksivamani) Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { using namespace mxfp8_kernel; + checkCuDriverContext(stream); + bool use_rowwise_scaling = output->has_data(); bool use_colwise_scaling = output->has_columnwise_data(); - checkCuDriverContext(stream); NVTE_CHECK(input.has_data(), "Cannot quantize tensor without rowwise data."); NVTE_CHECK(is_fp8_dtype(output->dtype()), "Output must have FP8 type."); diff --git a/transformer_engine/common/util/cuda_driver.cpp b/transformer_engine/common/util/cuda_driver.cpp index 4812435f7..01e3edf57 100644 --- a/transformer_engine/common/util/cuda_driver.cpp +++ b/transformer_engine/common/util/cuda_driver.cpp @@ -45,16 +45,19 @@ void *get_symbol(const char *symbol, int cuda_version) { } void ensure_context_exists() { - CUcontext context; - NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context); - if (context == nullptr) { - // Add primary context to context stack - CUdevice device; - NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device()); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); - NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); - NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device); - } + static thread_local bool need_check = []() { + CUcontext context; + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxGetCurrent, &context); + if (context == nullptr) { + // Add primary context to context stack + CUdevice device; + NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, cuda::current_device()); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); + NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRelease, device); + } + return false; + }(); } } // namespace cuda_driver From ccbc8cf40131456b8263e783302ea2d1b8bcfe3a Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 13 Aug 2025 19:51:59 -0700 Subject: 
[PATCH 071/153] [PyTorch] Register weight and bias params in linear op (#2027) * Register weight/bias params in linear op Signed-off-by: Tim Moon * Tweak docs Signed-off-by: Tim Moon * Make sure linear op checkpoint is backward-compatible Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix linter warning Signed-off-by: Tim Moon * Check for invalid case before setting bias Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/ops/linear.py | 98 +++++++++++++++++------- 1 file changed, 69 insertions(+), 29 deletions(-) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 8ed2702a7..8686c1853 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -6,7 +6,7 @@ from __future__ import annotations from collections.abc import Callable -from typing import Optional +from typing import Any, Optional import torch @@ -91,6 +91,8 @@ def __init__( # Construct basic ops ops = [] + linear_idx = None + bias_idx = None linear_kwargs = { "in_features": in_features, "out_features": out_features, @@ -111,14 +113,16 @@ def __init__( } if tensor_parallel_mode == "row": # Row TP: GEMM + bias + reduction + linear_idx = len(ops) linear_kwargs["in_features"] = local_in_features linear_kwargs["out_features"] = local_out_features linear_kwargs["tensor_parallel_mode"] = None linear_kwargs["tensor_parallel_group"] = None linear_kwargs["sequence_parallel"] = False - bias_kwargs["size"] *= tensor_parallel_size ops.append(BasicLinear(**linear_kwargs)) if bias: + bias_idx = len(ops) + bias_kwargs["size"] *= tensor_parallel_size ops.append(Bias(**bias_kwargs)) if sequence_parallel: ops.append(ReduceScatter(tensor_parallel_group)) @@ -126,45 +130,81 @@ def __init__( ops.append(AllReduce(tensor_parallel_group)) else: # Column TP or no TP: (gather + GEMM) + bias + linear_idx = len(ops) ops.append(BasicLinear(**linear_kwargs)) if bias: + bias_idx = len(ops) ops.append(Bias(**bias_kwargs)) # Initialize base class super().__init__(ops) - self._has_bias: bool = bias + # Register parameters + self._linear_idx: Optional[int] = linear_idx + self._bias_idx: Optional[int] = bias_idx + self.register_parameter("weight", self.basic_ops[self._linear_idx].weight) + bias = None + if self._bias_idx is not None: + bias = self.basic_ops[self._bias_idx].bias + self.register_parameter("bias", bias) - @property - def weight(self) -> torch.nn.Parameter: - """Weight tensor + def register_parameter(self, name: str, param: Optional[torch.nn.Parameter]) -> None: + """Add a parameter to the module - Parameter is owned by `BasicLinear` operation. + Also updates the basic operation that owns the parameter. """ - return self.basic_ops[0].weight - - @weight.setter - def weight(self, value: Optional[torch.nn.Parameter]) -> None: - self.basic_ops[0].weight = value - - @property - def bias(self) -> Optional[torch.nn.Parameter]: - """Bias tensor - - Parameter is owned by `Bias` operation. 
- - """ - if self._has_bias: - return self.basic_ops[1].bias - return None - - @bias.setter - def bias(self, value: Optional[torch.nn.Parameter]) -> None: - if self._has_bias: - self.basic_ops[1].bias = value - elif value is not None: + if name == "bias" and self._bias_idx is None and param is not None: raise ValueError( "Attempted to set bias parameter in Linear operation " "that does not have bias enabled" ) + super().register_parameter(name, param) + if name == "weight": + self.basic_ops[self._linear_idx].weight = param + elif name == "bias" and self._bias_idx is not None: + self.basic_ops[self._bias_idx].bias = param + + def state_dict(self, *, prefix: str = "", **kwargs) -> dict[str, Any]: + """Save state""" + state_dict = super().state_dict(prefix=prefix, **kwargs) + + # Remove basic op params from state dict + # Note: Logically, basic ops own params and fused ops are + # considered as stateless. However, we register weight and + # bias params in the linear op for convenience. We remove + # these redudant params from the checkpoint for backward + # compatibility. + if f"{prefix}weight" in state_dict: + del state_dict[f"{prefix}weight"] + if f"{prefix}bias" in state_dict: + del state_dict[f"{prefix}bias"] + + return state_dict + + def _load_from_state_dict( + self, + state_dict: dict[str, Any], + prefix: str, + *args, + **kwargs, + ) -> None: + + # Add basic op params to state dict + # Note: Logically, basic ops own params and fused ops are + # considered as stateless. However, we register weight and + # bias params in the linear op for convenience. We remove + # these redudant params from the checkpoint for backward + # compatibility. + if f"{prefix}weight" not in state_dict: + state_dict[f"{prefix}weight"] = state_dict[ + f"{prefix}basic_ops.{self._linear_idx}.weight" + ] + if f"{prefix}bias" not in state_dict: + if self._bias_idx is None: + state_dict[f"{prefix}bias"] = None + else: + state_dict[f"{prefix}bias"] = state_dict[f"{prefix}basic_ops.{self._bias_idx}.bias"] + + # Load state dict + super()._load_from_state_dict(state_dict, prefix, *args, **kwargs) From 26b4b71acf8a9d0a1796de526401c67af867b46d Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 13 Aug 2025 19:52:06 -0700 Subject: [PATCH 072/153] [PyTorch] Avoid registering FP8 scale update in ops without backward pass (#2063) Avoid registering FP8 recipe update in ops without backward pass Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/ops/fuser.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 2ee476779..b46bfb4b7 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -176,6 +176,11 @@ def forward( func_ctx.save_for_backward(*tensors_to_save) func_ctx.tensor_objects = tensor_objects + # Whether to perform recipe update in backward pass + is_first_module = False + if fuser.first_op_requiring_backward < fuser._num_basic_ops: + is_first_module = FP8GlobalStateManager.is_first_fp8_module() + # Other context func_ctx.backward_ops = fuser._backward_ops func_ctx.basic_ops = fuser._basic_ops @@ -183,7 +188,7 @@ def forward( func_ctx.basic_op_num_params = fuser._basic_op_num_params func_ctx.num_extra_inputs = fuser.num_extra_inputs func_ctx.num_extra_outputs = len(extra_outputs_flat) - func_ctx.is_first_module = FP8GlobalStateManager.is_first_fp8_module() 
+ func_ctx.is_first_module = is_first_module # Mark output tensors as not deletable in backward for tensor in [x] + extra_outputs_flat: From a169e9e709d51b34806babd7fa1afaa7ccbfeeb7 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 13 Aug 2025 20:11:49 -0700 Subject: [PATCH 073/153] [PyTorch] Disable fused dbias-quantize kernel for unsupported recipes (#2007) * Unfused impl for dbias-quantize Signed-off-by: Tim Moon * Unfused impl for dact-dbias-quantize Signed-off-by: Tim Moon * Disable fused bgrad-quantize for unsupported recipes Signed-off-by: Tim Moon * Remove unfused dbias-quantize impls Not supported in the core lib. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Support unfused impls in tex functions Signed-off-by: Tim Moon * Tweaks Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unused imports Signed-off-by: Tim Moon --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 3 +- .../pytorch/csrc/extensions/bias.cpp | 253 ++++++++++++++---- .../pytorch/csrc/extensions/cast.cpp | 61 ----- transformer_engine/pytorch/ops/basic/bias.py | 16 +- .../ops/fused/backward_activation_bias.py | 2 +- 5 files changed, 206 insertions(+), 129 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9b9bb58ac..6638fdb57 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -2078,7 +2078,7 @@ def test_backward_activation_bias( # Check that backward operations have been fused backward_ops = model._module_groups[0]._backward_ops - if with_quantization and quantization in ["fp8_delayed_scaling", "mxfp8"]: + if with_quantization: assert len(backward_ops) == 2 assert isinstance(backward_ops[0][0], BackwardActivationBias) assert isinstance(backward_ops[1][0], te_ops.Quantize) @@ -2093,6 +2093,7 @@ def test_backward_activation_bias( if with_quantization: tols = dtype_tols(tex.DType.kFloat8E4M3) + # Check results y_test = y_test.to(dtype=torch.float64, device="cpu") dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") db_test = model[1].bias.grad.to(dtype=torch.float64, device="cpu") diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 63455e3c0..a80cb35f2 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -4,80 +4,223 @@ * See LICENSE for license information. 
************************************************************************/ +#include +#include + +#include +#include + #include "common.h" +#include "extensions.h" #include "pybind.h" #include "transformer_engine/cast.h" #include "transformer_engine/transformer_engine.h" -namespace transformer_engine::pytorch { +namespace transformer_engine { +namespace pytorch { -std::vector bgrad_quantize(const at::Tensor& input, py::handle py_quantizer) { - auto quantizer = convert_quantizer(py_quantizer); +std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle quantizer) { + using namespace transformer_engine::pytorch::detail; + init_extension(); - auto input_tensor = makeTransformerEngineTensor(input); + // Grad output tensor + auto grad_output_torch = grad_output.contiguous(); + const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); + const auto shape = getTensorShape(grad_output_torch); + auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); - auto dbias = allocateTorchTensor(input.size(-1), input_tensor.dtype()); + // Construct grad bias tensor + const int64_t bias_size = static_cast(shape.back()); + auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); + auto grad_bias_nvte = makeTransformerEngineTensor(grad_bias_torch); - std::vector output_shape; - for (auto s : input.sizes()) { - output_shape.emplace_back(static_cast(s)); + // Unquantized impl only requires computing grad bias + if (quantizer.is_none()) { + if (product(shape) == 0) { + grad_bias_torch.zero_(); + } else { + at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); + } + return {py::cast(std::move(grad_bias_torch)), py::cast(std::move(grad_output_torch))}; } - auto [out_tensor, out] = quantizer->create_tensor(output_shape, input_tensor.dtype()); - // Return immediately if tensors are empty - if (product(output_shape) == 0) { - return {py::cast(dbias.zero_()), out}; + // Construct grad input tensor + auto quantizer_cpp = convert_quantizer(quantizer); + auto [grad_input_nvte, grad_input_py] = quantizer_cpp->create_tensor(shape, grad_output_dtype); + + // Trivial impl if tensors are empty + if (product(shape) == 0) { + grad_bias_torch.zero_(); + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; + } + + // Unfused impl if quantizer is not supported + const bool with_fused_dbias_quantize_kernel = + detail::IsFloat8Quantizers(quantizer.ptr()) || detail::IsMXFP8Quantizers(quantizer.ptr()); + if (!with_fused_dbias_quantize_kernel) { + at::sum_out(grad_bias_torch, grad_output_torch.reshape({-1, bias_size}), {0}); + quantizer_cpp->quantize(grad_output_nvte, grad_input_nvte); + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; } - auto dbias_tensor = makeTransformerEngineTensor(dbias); - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; + // Query workspace size + TensorWrapper workspace_nvte; + at::Tensor workspace_torch; + auto stream = at::cuda::getCurrentCUDAStream(); NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(), + workspace_nvte.data(), stream); }); - void* workspace_data_ptr = nullptr; - if (workspace.shape().ndim > 0) { - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace_data_ptr = 
workspace_data.data_ptr(); - } - workspace = makeTransformerEngineTensor(workspace_data_ptr, workspace.shape(), workspace.dtype()); - - // Launch kernel - if (detail::IsFloat8CurrentScalingQuantizers(py_quantizer.ptr())) { - // my_quantizer here has to be a Float8CurrentScalingQuantizer - auto my_quantizer_cs = static_cast(quantizer.get()); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_amax(input_tensor.data(), out_tensor.data(), at::cuda::getCurrentCUDAStream()); - }); - // check if we need to do amax reudction (depending on model parallel configs) - if (my_quantizer_cs->with_amax_reduction) { - c10::intrusive_ptr process_group_ptr = my_quantizer_cs->amax_reduction_group; - // construct torch tesnor from NVTEBasicTensor without reallocating memory - at::Tensor& amax_tensor_torch = my_quantizer_cs->amax; - std::vector tensors = {amax_tensor_torch}; - // allreduce amax tensor - c10d::AllreduceOptions allreduce_opts; - allreduce_opts.reduceOp = c10d::ReduceOp::MAX; - process_group_ptr->allreduce(tensors, allreduce_opts)->wait(); - } - QuantizationConfigWrapper quant_config; - quant_config.set_force_pow_2_scales(my_quantizer_cs->force_pow_2_scales); - quant_config.set_amax_epsilon(my_quantizer_cs->amax_epsilon); - NVTE_SCOPED_GIL_RELEASE({ - nvte_compute_scale_from_amax(out_tensor.data(), quant_config, - at::cuda::getCurrentCUDAStream()); - }); - // set amax ptr to null in te_output TensorWrapper to avoid atomic amax updates in kernel - out_tensor.set_amax(nullptr, DType::kFloat32, out_tensor.defaultShape); + // Allocate workspace + if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) { + workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype()); + workspace_nvte = makeTransformerEngineTensor(workspace_torch.data_ptr(), workspace_nvte.shape(), + workspace_nvte.dtype()); } + + // Launch fused kernel NVTE_SCOPED_GIL_RELEASE({ - nvte_quantize_dbias(input_tensor.data(), out_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); + nvte_quantize_dbias(grad_output_nvte.data(), grad_input_nvte.data(), grad_bias_nvte.data(), + workspace_nvte.data(), stream); }); - return {py::cast(dbias), out}; + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; +} + +namespace { + +std::vector dact_dbias( + void (*dact_dbias_func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor, + cudaStream_t), + void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t), + at::Tensor grad_output_torch, at::Tensor act_input_torch, py::handle quantizer_py) { + using namespace transformer_engine::pytorch::detail; + init_extension(); + + // Grad output and activation input tensors + grad_output_torch = grad_output_torch.contiguous(); + const TensorWrapper &grad_output_nvte = makeTransformerEngineTensor(grad_output_torch); + const auto output_shape = getTensorShape(grad_output_torch); + auto grad_output_dtype = GetTransformerEngineDType(grad_output_torch.scalar_type()); + act_input_torch = act_input_torch.contiguous(); + const TensorWrapper &act_input_nvte = makeTransformerEngineTensor(act_input_torch); + const auto input_shape = getTensorShape(act_input_torch); + + // Construct tensors + auto quantizer_cpp = convert_quantizer(quantizer_py); + auto [grad_input_nvte, grad_input_py] = + quantizer_cpp->create_tensor(input_shape, grad_output_dtype); + const int64_t bias_size = static_cast(input_shape.back()); + auto grad_bias_torch = allocateTorchTensor(bias_size, grad_output_dtype); + auto grad_bias_nvte = 
makeTransformerEngineTensor(grad_bias_torch); + + // Return immediately if tensors are empty + if (product(output_shape) == 0) { + grad_bias_torch.zero_(); + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; + } + + // Choose implementation + enum class Impl { UNFUSED, FUSED_DACT_DBIAS_QUANTIZE, FUSED_DACT_AMAX }; + Impl impl = Impl::UNFUSED; + if (detail::IsFloat8Quantizers(quantizer_py.ptr()) || + detail::IsMXFP8Quantizers(quantizer_py.ptr())) { + impl = Impl::FUSED_DACT_DBIAS_QUANTIZE; + } else if (detail::IsFloat8CurrentScalingQuantizers(quantizer_py.ptr())) { + impl = Impl::FUSED_DACT_AMAX; + } + + // Perform compute + auto stream = at::cuda::getCurrentCUDAStream(); + switch (impl) { + case Impl::UNFUSED: + // Unfused dact, dbias, quantize + { + auto [temp_nvte, temp_py] = + NoneQuantizer(py::none()).create_tensor(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); + quantizer_cpp->quantize(temp_nvte, grad_input_nvte); + break; + } + case Impl::FUSED_DACT_DBIAS_QUANTIZE: + // Fused dact-dbias-quantize kernel + { + // Query workspace size + TensorWrapper workspace_nvte; + NVTE_SCOPED_GIL_RELEASE({ + dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(), + grad_bias_nvte.data(), workspace_nvte.data(), stream); + }); + + // Allocate workspace + at::Tensor workspace_torch; + if (workspace_nvte.ndim() > 0 && workspace_nvte.numel() > 0) { + workspace_torch = allocateSpace(workspace_nvte.shape(), workspace_nvte.dtype()); + workspace_nvte = makeTransformerEngineTensor( + workspace_torch.data_ptr(), workspace_nvte.shape(), workspace_nvte.dtype()); + } + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE({ + dact_dbias_func(grad_output_nvte.data(), act_input_nvte.data(), grad_input_nvte.data(), + grad_bias_nvte.data(), workspace_nvte.data(), stream); + }); + break; + } + case Impl::FUSED_DACT_AMAX: + // Fused dact-amax kernel, unfused dbias and quantize + { + auto *quantizer_cpp_cs = dynamic_cast(quantizer_cpp.get()); + NVTE_CHECK(quantizer_cpp_cs != nullptr, + "Invalid quantizer for fused dact-amax kernel impl"); + auto [temp_nvte, temp_py] = + quantizer_cpp_cs->create_hp_tensor_with_amax(input_shape, grad_output_dtype); + NVTE_SCOPED_GIL_RELEASE({ + dact_func(grad_output_nvte.data(), act_input_nvte.data(), temp_nvte.data(), stream); + }); + const auto temp_torch = temp_py.cast(); + at::sum_out(grad_bias_torch, temp_torch.reshape({-1, bias_size}), {0}); + quantizer_cpp_cs->quantize_with_amax(temp_nvte, grad_input_nvte); + break; + } + default: + NVTE_ERROR("Invalid implementation"); + } + + return {py::cast(std::move(grad_bias_torch)), std::move(grad_input_py)}; +} + +} // namespace + +std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_dgelu, nvte_dgelu, grad_output, act_input, quantizer); +} + +std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_dsilu, nvte_dsilu, grad_output, act_input, quantizer); +} + +std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_drelu, nvte_drelu, grad_output, act_input, quantizer); +} + +std::vector 
dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_dqgelu, nvte_dqgelu, grad_output, act_input, quantizer); +} + +std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, + py::handle quantizer) { + return dact_dbias(nvte_quantize_dbias_dsrelu, nvte_dsrelu, grad_output, act_input, quantizer); } -} // namespace transformer_engine::pytorch +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index fe7aecbc2..819d3e518 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -587,66 +587,5 @@ std::vector split_quantize(const at::Tensor &tensor, return output_py_list; } -template -std::vector dbias_dact(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - init_extension(); - auto my_quantizer = convert_quantizer(quantizer); - - auto grad_tensor = makeTransformerEngineTensor(grad_output); - - auto grad_bias = allocateTorchTensor(grad_output.size(-1), grad_tensor.dtype()); - auto act_input_tensor = makeTransformerEngineTensor(act_input); - - const auto &shape = convertShape(grad_tensor.shape()); - auto [dact_tensor, dact] = my_quantizer->create_tensor(shape, act_input_tensor.dtype()); - - auto dbias_tensor = makeTransformerEngineTensor(grad_bias); - - // Query workspace size and allocate workspace - transformer_engine::TensorWrapper workspace; - NVTE_SCOPED_GIL_RELEASE({ - func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - }); - auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); - workspace = - makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); - - // Launch kernel - NVTE_SCOPED_GIL_RELEASE({ - func(grad_tensor.data(), act_input_tensor.data(), dact_tensor.data(), dbias_tensor.data(), - workspace.data(), at::cuda::getCurrentCUDAStream()); - }); - - return {py::cast(grad_bias), dact}; -} - -std::vector dbias_dgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_dsilu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_drelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - -std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, - py::handle quantizer) { - return dbias_dact(grad_output, act_input, quantizer); -} - } // namespace pytorch } // namespace transformer_engine diff --git a/transformer_engine/pytorch/ops/basic/bias.py b/transformer_engine/pytorch/ops/basic/bias.py index 4c107b888..5ec0d2ce5 100644 --- a/transformer_engine/pytorch/ops/basic/bias.py +++ b/transformer_engine/pytorch/ops/basic/bias.py @@ -10,14 +10,8 @@ import torch import transformer_engine_torch as tex -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) -from 
...utils import ( - canonicalize_device, - canonicalize_dtype, -) +from ..op import BasicOperation, OperationContext +from ...utils import canonicalize_device, canonicalize_dtype from ...tensor import Quantizer @@ -141,10 +135,10 @@ def op_backward( dy = grad_output if dy.dim() > 1: quantizer = ctx.grad_input_quantizer - if quantizer is not None: - db, dy = tex.bgrad_quantize(dy, quantizer) - else: + if quantizer is None: db = dy.sum(tuple(range(dy.dim() - 1))) + else: + db, dy = tex.bgrad_quantize(dy, quantizer) else: db = dy return dy, (db,) diff --git a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py index bf3ff8ca6..40510c856 100644 --- a/transformer_engine/pytorch/ops/fused/backward_activation_bias.py +++ b/transformer_engine/pytorch/ops/fused/backward_activation_bias.py @@ -104,7 +104,7 @@ def fuse_backward_activation_bias( """ # Check if recipe supports bias activation fusion - if recipe is None or not (recipe.delayed() or recipe.mxfp8()): + if recipe is None: return ops # Scan through ops, fusing if possible From 12065ac2b275e30b3a461d4420ed8462d45ab457 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 14 Aug 2025 16:10:30 -0700 Subject: [PATCH 074/153] [Core] Add launch bounds to swizzle kernels (#2076) Add launch bounds to swizzle kernel, use empty scale inv Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/common/swizzle/swizzle.cu | 12 ++++++------ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 4 ++-- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 37d7491d9..fcb379a82 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -145,9 +145,9 @@ __device__ void swizzle_col_scaling_kernel_impl(const void* input, void* output, } template -__global__ void swizzle_col_scaling_kernel(const void* input, void* output, const int M, - const int K, const int original_M, - const int original_K) { +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_col_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { swizzle_col_scaling_kernel_impl( input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); } @@ -238,9 +238,9 @@ __device__ void swizzle_row_scaling_kernel_impl(const void* input, void* output, } template -__global__ void swizzle_row_scaling_kernel(const void* input, void* output, const int M, - const int K, const int original_M, - const int original_K) { +__global__ void __launch_bounds__(TB_DIM* TB_DIM) + swizzle_row_scaling_kernel(const void* input, void* output, const int M, const int K, + const int original_M, const int original_K) { swizzle_row_scaling_kernel_impl( input, output, M, K, original_M, original_K, blockIdx.x, blockIdx.y, gridDim.x, gridDim.y); } diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index b96575d37..321c351dd 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -100,7 +100,7 @@ def make_empty( # Allocate FP8 data data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_inv = torch.zeros( + scale_inv = torch.empty( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(shape[-1] // 
MXFP8_BLOCK_SCALING_SIZE, 4), dtype=torch.uint8, @@ -112,7 +112,7 @@ def make_empty( columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty_like(data) - columnwise_scale_inv = torch.zeros( + columnwise_scale_inv = torch.empty( round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, From 92f431bfee3bebbf49d1cc6f6bc37796bffd8bb7 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Thu, 14 Aug 2025 18:07:13 -0700 Subject: [PATCH 075/153] [JAX] Trim dist fused attn tests in L1 (#2050) * Move some dist fused attn tests to L2 1. TestReorderCausalLoadBalancing: Run two (non symmetric) BSHD/SBHD data shape combination 2. TestDistributedSelfAttn: Run only one (smaller) BSHD type data shape combination 3. TestDistributedCrossAttn: Run only one (smaller) BSHD type data shape combination 4. TestDistributedContextParallelSelfAttn: Run all cp1 combinations Signed-off-by: Kshitij Janardan Lakhani * Use pytest_parametrize_wrapper for splitting fused attn distributed JAX tests as L1 and L2 Signed-off-by: Kshitij Janardan Lakhani * Undo pytest -k split commands in qa scripts Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix usage of pytest_parametrize_wrapper in test_distributed_fused_attn Signed-off-by: Kshitij Janardan Lakhani * Remove test code for L2 dist residing in L2 test.sh Signed-off-by: Kshitij Janardan Lakhani * Add comments for code. Swap the test data shapes in REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add L0 to the data shape dictionaries in the distributed test Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Janardan Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Janardan Lakhani Signed-off-by: Kshitij Janardan Lakhani Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Kshitij Janardan Lakhani --- tests/jax/distributed_test_base.py | 20 +++++--- tests/jax/test_distributed_fused_attn.py | 63 ++++++++++++++---------- 2 files changed, 51 insertions(+), 32 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 3b86481bd..bda42f5f7 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -39,8 +39,10 @@ def generate_configs(): return configs -def generate_context_parallel_configs(): - configs = [] +def generate_context_parallel_configs_for_attn(): + """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only""" + configsL1 = [] + configsL2 = [] mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") axes = ("dp", "cp", "tp") DP_sizes = (1, 2) @@ -49,10 +51,16 @@ def generate_context_parallel_configs(): for dp, cp, tp in product(DP_sizes, CP_sizes, TP_sizes): ndev = cp * tp * dp if is_devices_enough(ndev): - configs.append( - pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") - ) - + # Do not run cp1 case in L1 as that is already 
covered in TestDistributedSelfAttn and TestDistributedCrossAttn (as these do not have any cp combinations) + if cp != 1: + configsL1.append( + pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") + ) + else: + configsL2.append( + pytest.param(ndev, (dp, cp, tp), axes, mr, id=f"n{ndev}_dp{dp}_cp{cp}_tp{tp}") + ) + configs = {"L0": [], "L1": configsL1, "L2": configsL2} return configs diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index e88108155..ea29736e7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -9,10 +9,11 @@ from jax import random from distributed_test_base import ( generate_configs, - generate_context_parallel_configs, + generate_context_parallel_configs_for_attn, generate_collectives_count, ) from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat +from utils import pytest_parametrize_wrapper from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, AttnBiasType, @@ -28,6 +29,12 @@ DTYPES = [jnp.bfloat16] +DISTRIBUTED_SELF_ATTN_DATA_SHAPES = { + "L0": [()], + "L1": [(32, 1024, 16, 128)], + "L2": [(32, 512, 12, 64)], +} + class TestDistributedSelfAttn: @@ -64,7 +71,6 @@ def impl_test_self_attn( jax.config.update("jax_use_shardy_partitioner", use_shardy) dropout_prob = 0.0 is_training = True - batch, seqlen, num_head, hidden = data_shape if not is_fused_attn_kernel_available( @@ -119,13 +125,7 @@ def impl_test_self_attn( runner.test_backward() @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize( - "data_shape", - [ - pytest.param((32, 512, 12, 64), id="32-512-12-64"), - pytest.param((32, 1024, 16, 128), id="32-1024-16-128"), - ], - ) + @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize( "attn_bias_type, bias_shape", [ @@ -193,6 +193,13 @@ def test_self_attn_shardy( ) +DISTRIBUTED_CROSS_ATTN_DATA_SHAPES = { + "L0": [()], + "L1": [[32, 512, 16, 64]], + "L2": [[32, 128, 12, 64]], +} + + class TestDistributedCrossAttn: def generate_collectives_count_ref(self): @@ -201,7 +208,7 @@ def generate_collectives_count_ref(self): return generate_collectives_count(allreduce=all_reduce_loss_bytes, allgather=0, other=0) @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) - @pytest.mark.parametrize("data_shape", [[32, 128, 12, 64], [32, 512, 16, 64]]) + @pytest_parametrize_wrapper("data_shape", DISTRIBUTED_CROSS_ATTN_DATA_SHAPES) @pytest.mark.parametrize( "attn_mask_type", [AttnMaskType.PADDING_MASK, AttnMaskType.CAUSAL_MASK] ) @@ -390,8 +397,9 @@ def check_has_backend_for_mask(mask_type): runner.test_backward() del os.environ["NVTE_FUSED_RING_ATTENTION_USE_SCAN"] - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @@ -426,8 +434,9 @@ def test_context_parallel_allgather_attn_shardy( use_shardy=True, ) - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + 
generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("kv_groups", [1, 8]) @@ -468,8 +477,9 @@ def test_context_parallel_allgather_attn( use_shardy=False, ) - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES) @pytest.mark.parametrize("kv_groups", [1, 8]) @@ -532,8 +542,9 @@ def test_context_parallel_ring_attn( window_size=window_size, ) - @pytest.mark.parametrize( - "device_count,mesh_shape,mesh_axes,mesh_resource", generate_context_parallel_configs() + @pytest_parametrize_wrapper( + "device_count,mesh_shape,mesh_axes,mesh_resource", + generate_context_parallel_configs_for_attn(), ) @pytest.mark.parametrize("data_shape", DISTRIBUTED_CONTEXT_SELF_ATTN_DATA_SHAPES[:1]) @pytest.mark.parametrize("dtype", [pytest.param(jnp.bfloat16, id="BF16")]) @@ -570,16 +581,16 @@ def test_context_parallel_ring_attn_shardy( ) +REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES = { + "L0": [[]], + "L1": [[3, 32, 8, 64]], + "L2": [[4, 32, 12, 32], [1, 16, 1, 1]], +} + + class TestReorderCausalLoadBalancing: @pytest.mark.parametrize("cp_size", [2, 4, 8]) - @pytest.mark.parametrize( - "shape", - [ - pytest.param([1, 16, 1, 1], id="1-16-1-1"), - pytest.param([4, 32, 12, 32], id="4-32-12-32"), - pytest.param([3, 32, 8, 64], id="3-32-8-64"), - ], - ) + @pytest_parametrize_wrapper("shape", REORDER_CAUSAL_LOAD_BALANCING_DATA_SHAPES) @pytest.mark.parametrize("qkv_format", [QKVFormat.BSHD, QKVFormat.SBHD]) @pytest.mark.parametrize( "reorder_strategy", From c654e4fe08062acf6f71d1efd5bfaa421d51d114 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Fri, 15 Aug 2025 14:42:10 -0700 Subject: [PATCH 076/153] Fuse linear+scale+add (#2042) * Add `nvte_cublas_gemm_scaled` Signed-off-by: Jan Bielak * Support use of `alpha` and `beta` in `tex.generic_gemm` Signed-off-by: Jan Bielak * Support use of `alpha` and `beta` in `general_gemm` Signed-off-by: Jan Bielak * Support use of `alpha` and `beta` in `BasicLinear._functional_forward` and `BasicLinear._functional_backward` Signed-off-by: Jan Bielak * Add `ForwardLinearScaleAdd` fusion Signed-off-by: Jan Bielak * Add `BackwardLinearScale` fusion Signed-off-by: Jan Bielak * Apply suggestions from code review Signed-off-by: Jan Bielak * Remove calls to `validate_gemm_scale` from `BasicLinear` Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 198 ++++++++++++++++++ .../common/gemm/cublaslt_gemm.cu | 38 +++- .../common/include/transformer_engine/gemm.h | 30 +++ .../pytorch/cpp_extensions/gemm.py | 16 ++ transformer_engine/pytorch/csrc/extensions.h | 3 +- .../pytorch/csrc/extensions/gemm.cpp | 22 +- .../pytorch/csrc/extensions/pybind.cpp | 3 +- .../pytorch/ops/basic/basic_linear.py | 24 +++ .../pytorch/ops/fused/__init__.py | 8 + .../ops/fused/backward_linear_scale.py | 155 ++++++++++++++ .../ops/fused/forward_linear_scale_add.py | 176 ++++++++++++++++ transformer_engine/pytorch/ops/fuser.py | 4 + 12 files changed, 660 insertions(+), 17 deletions(-) create mode 100644 transformer_engine/pytorch/ops/fused/backward_linear_scale.py create mode 100644 
transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 6638fdb57..0164c0446 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -22,8 +22,10 @@ from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, BackwardLinearAdd, + BackwardLinearScale, ForwardLinearBiasActivation, ForwardLinearBiasAdd, + ForwardLinearScaleAdd, ) from transformer_engine.pytorch.tensor import QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( @@ -2008,6 +2010,109 @@ def test_forward_linear_bias_add( db_test = model[0].bias.grad.to(dtype=torch.float64, device="cpu") torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_forward_linear_scale_add( + self, + *, + scale: float, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool = False, + ) -> None: + """Forward GEMM + scale + add""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x1_ref, x1_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + x2_ref, x2_test = make_reference_and_test_tensors( + out_shape, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x1_ref, w_ref) * scale + x2_ref + y_ref.backward(dy_ref) + + # Implementation with fusible operations + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight, recipe=recipe): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=False, + device=device, + dtype=dtype, + ), + te_ops.ConstantScale(scale), + te_ops.AddExtraInput(in_place=True), + te_ops.Quantize(), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + del w_test + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x1_test, x2_test) + y_test.backward(dy_test) + + # Check that forward operations have been fused + forward_ops = model._module_groups[0]._forward_ops + assert len(forward_ops) == 2 + assert isinstance(forward_ops[0][0], ForwardLinearScaleAdd) + assert isinstance(forward_ops[1][0], te_ops.Quantize) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # 
Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx1_test = x1_test.grad.to(dtype=torch.float64, device="cpu") + dx2_test = x2_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx1_test, x1_ref.grad, **tols) + torch.testing.assert_close(dx2_test, x2_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("activation", ("relu", "gelu")) @pytest.mark.parametrize("out_shape", ((32, 32), (32, 1, 32), (8, 2, 2, 32))) @pytest.mark.parametrize("dtype", _dtypes) @@ -2202,6 +2307,99 @@ def test_backward_linear_add( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("scale", (1, 0, -2.5, 3.5)) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("quantization", _quantization_list) + def test_backward_linear_scale( + self, + *, + scale: float, + weight_shape: tuple[int, int] = (32, 32), + in_shape: Iterable[int] = (32, -1), + dtype: torch.dtype, + device: torch.device = "cuda", + quantization: Optional[str], + quantized_weight: bool = False, + ) -> None: + """Backward dgrad GEMM + scale""" + + # Make input and weight shapes consistent + out_features, in_features = weight_shape + in_shape = list(in_shape)[:-1] + [in_features] + out_shape = in_shape[:-1] + [out_features] + + # Skip invalid configurations + quantized_compute = quantization is not None + maybe_skip_quantization(quantization, dims=in_shape, device=device) + maybe_skip_quantization(quantization, dims=out_shape) + if quantized_compute and dtype not in (torch.float16, torch.bfloat16): + pytest.skip("FP8 GEMM is only supported with FP8, FP16, or BF16 output") + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + (out_features, in_features), + quantization=quantization, + test_dtype=dtype, + test_device=device, + ) + dy_ref, dy_test = make_reference_and_test_tensors( + out_shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + y_ref = torch.nn.functional.linear(x_ref, w_ref) * scale + y_ref.backward(dy_ref) + + # Implementation with fusible operations + recipe = make_recipe(quantization) + with te.fp8_model_init(enabled=quantized_weight): + model = te_ops.Sequential( + te_ops.Linear( + in_features, + out_features, + bias=False, + device=device, + dtype=dtype, + ), + te_ops.ConstantScale(scale), + ) + with torch.no_grad(): + model[0].weight.copy_(w_test) + del w_test + with te.fp8_autocast(enabled=quantized_compute, fp8_recipe=recipe): + y_test = model(x_test) + (y_test * dy_test).sum().backward() + + # Check that backward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], BackwardLinearScale) + + # Expected numerical error + tols = dtype_tols(dtype) + if dtype == torch.float32: + tols = dtype_tols(torch.float16) # TF32 GEMM + if quantized_compute: + tols = dtype_tols(tex.DType.kFloat8E4M3) + + # Check results + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[0].weight.grad.to(dtype=torch.float64, 
device="cpu") + torch.testing.assert_close(y_test, y_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + class TestCheckpointing: """Tests for checkpointing""" diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 29430e43f..1c4af23eb 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -238,8 +238,9 @@ using cublasHandleManager = detail::HandleManagerflat_first_dim(); const int A1 = inputA->flat_last_dim(); @@ -295,13 +296,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, "fp8 Aux output for gemm + gelu fusion not supported!"); } if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); + NVTE_CHECK(beta == 0.0f, "Accumulation mode not supported with FP8 GEMM output!"); } - float one = 1.0; - float zero = 0.0; - float beta = (accumulate) ? one : zero; - cublasLtHandle_t handle = cublasHandleManager::Instance().GetHandle(); cublasLtMatmulDesc_t operationDesc = nullptr; @@ -586,7 +583,7 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, // D = alpha * (A * B) + beta * C NVTE_CHECK_CUBLAS(cublasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ Bdesc, static_cast(&beta), /* beta */ @@ -629,7 +626,26 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); + 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false, + nullptr, stream); +} + +void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, + const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, + bool transb, bool grad, NVTETensor workspace, float alpha, float beta, + bool use_split_accumulator, int math_sm_count, cudaStream_t stream) { + NVTE_API_CALL(nvte_cublas_gemm_scaled); + using namespace transformer_engine; + const Tensor *inputA = convertNVTETensorCheck(A); + const Tensor *inputB = convertNVTETensorCheck(B); + Tensor *outputD = convertNVTETensor(D); + const Tensor *biasTensor = convertNVTETensor(bias); + Tensor *outputGelu = convertNVTETensor(pre_gelu_out); + Tensor *wspace = convertNVTETensor(workspace); + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], + alpha, beta, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, @@ -671,8 +687,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor "Atomic GEMM only supports delayed scaling."); cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], - accumulate, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, - inputCounter, stream); + 1.0f, (accumulate) ? 
1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, + n_split, gemm_producer, inputCounter, stream); } void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index d7e257cc1..50b33909f 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -44,6 +44,36 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons NVTETensor workspace, bool accumulate, bool use_split_accumulator, int math_sm_count, cudaStream_t stream); +/*! \brief Compute matrix multiplication of 2 matrices, potentially fused with other operations, + * allowing for using a scaling factor for the GEMM result and the accumulation input + * + * Computes: + * - `D = alpha*AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = alpha*AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(alpha*AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * \param[in] A The A matrix. + * \param[in] B The B matrix. + * \param[in,out] D Output matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_gelu_out Output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of the + * gradient computation. + * \param[out] workspace Workspace tensor. + * \param[in] alpha Scaling factor applied to the result of the GEMM + * \param[in] beta Scaling factor applied to original value of D when + * accumulating into it. beta=0 means no accumulation. + * \param[in] use_split_accumulator Whether to use split accumulator in the FP8 GEMM. + * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_cublas_gemm_scaled(const NVTETensor A, const NVTETensor B, NVTETensor D, + const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, + bool transb, bool grad, NVTETensor workspace, float alpha, float beta, + bool use_split_accumulator, int math_sm_count, cudaStream_t stream); + /*! \brief Compute matrix multiplication of 2 matrices with chunking and atomic counters. * * \warning Cublas atomic gemm uses a beta API and is not tested for all use cases. 
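The header comment above defines the contract of the new scaled GEMM entry point: D = alpha*AB, optionally plus bias, a GELU epilogue, and a beta-scaled accumulation into the existing D. The short PyTorch sketch below restates only the alpha/beta arithmetic as a reference for readers; the function name is hypothetical and not part of this patch, and the transa/transb flags, GELU epilogue, FP8 quantization, and workspace handling of the real kernel are deliberately omitted.

import torch

def scaled_gemm_reference(A, B, D=None, *, alpha=1.0, beta=0.0, bias=None):
    # Reference semantics only (hypothetical helper, not part of this patch):
    # D_new = alpha * (A @ B) [+ bias] + beta * D_old, where beta == 0 means no accumulation.
    out = alpha * (A @ B)
    if bias is not None:
        out = out + bias
    if beta != 0.0:
        if D is None:
            raise ValueError("beta != 0 requires an existing output tensor to accumulate into")
        out = out + beta * D
    return out

For example, scaled_gemm_reference(A, B, D, alpha=0.5, beta=1.0) mirrors a scaled GEMM that halves the product and accumulates it into the existing output, which is the case the fused linear+scale+add operation below relies on.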
diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 9f3921d36..e4f4e619f 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -21,6 +21,15 @@ ] +def validate_gemm_scale(scale: Optional[float], required: bool) -> float: + """Validate whether a GEMM scaling factor is consistent with its usage""" + if required: + return scale if scale is not None else 1.0 + if scale not in (0.0, None): + raise ValueError("scale must be zero") + return 0.0 + + def general_gemm( A: torch.Tensor, B: torch.Tensor, @@ -29,6 +38,8 @@ def general_gemm( quantization_params: Optional[Quantizer] = None, gelu: bool = False, gelu_in: torch.Tensor = None, + alpha: float = 1.0, + beta: Optional[float] = None, accumulate: bool = False, layout: str = "TN", out: Optional[torch.Tensor] = None, @@ -47,6 +58,9 @@ def general_gemm( transb = layout[1] == "T" # assert quantization_params is None, "FP8 output not supported yet" + alpha = validate_gemm_scale(alpha, True) + beta = validate_gemm_scale(beta, accumulate) + if ub_type is not None: assert ub is not None, ( f"{'AG+GEMM' if ub_type == tex.CommOverlapType.AG else 'GEMM+RS'} overlap requires" @@ -108,6 +122,8 @@ def general_gemm( "comm_type": ub_type, "extra_output": extra_output, "bulk_overlap": bulk_overlap, + "alpha": alpha, + "beta": beta, } out, bias_grad, gelu_input, extra_output = tex.generic_gemm(*args, **kwargs) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 1f2460cbf..2f4414328 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -122,7 +122,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, CommOverlapCore *comm_overlap = nullptr, std::optional comm_type = std::nullopt, - MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false); + MaybeTensor extra_output = std::nullopt, bool bulk_overlap = false, + float alpha = 1.0f, std::optional beta = std::nullopt); void te_atomic_gemm(at::Tensor A, at::Tensor A_scale_inverse, DType A_type, std::vector A_scaling_mode, bool transa, at::Tensor B, diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 4f1ab3e56..f4768bb9b 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -92,7 +92,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans at::Tensor workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, CommOverlapCore* comm_overlap, std::optional comm_type, MaybeTensor extra_output, - bool bulk_overlap) { + bool bulk_overlap, float alpha, std::optional beta) { // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -110,6 +110,19 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans NVTE_CHECK(A_shape.ndim >= 1, "Tensor A needs to have at least 1 dimension"); NVTE_CHECK(B_shape.ndim >= 1, "Tensor B needs to have at least 1 dimension"); + // Check scaling factors + if (accumulate) { + if (!beta) { + beta = 1.0f; + } + } else { + if (!beta) { + beta = 0.0f; + } + NVTE_CHECK(beta == 0.0, "Trying to use non-zero beta while not accumulating ", + "into D tensor. 
Beta has nothing to be applied to."); + } + // Output tensor TensorWrapper D_tensor; if (D.is_none()) { @@ -238,9 +251,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm(A_tensor.data(), B_tensor.data(), D_tensor.data(), bias_tensor.data(), - te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), - accumulate, use_split_accumulator, num_math_sms, main_stream); + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), + bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, + te_workspace.data(), alpha, *beta, use_split_accumulator, + num_math_sms, main_stream); }); } } else { diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index dceaa5b15..d38348ae9 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -111,7 +111,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("gelu"), py::arg("gelu_in"), py::arg("grad"), py::arg("workspace"), py::arg("workspace_size"), py::arg("accumulate"), py::arg("use_split_accumulator"), py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, - py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false); + py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false, + py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), py::arg("quantizer")); m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index c0ec991ff..877596824 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -350,10 +350,12 @@ def _functional_forward( input: torch.Tensor, # pylint: disable=redefined-builtin weight: torch.Tensor, *, + alpha: float = 1.0, bias: Optional[torch.Tensor] = None, device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, out: Optional[torch.Tensor] = None, + beta: Optional[float] = None, accumulate_into_out: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, @@ -373,6 +375,8 @@ def _functional_forward( Input tensor weight: torch.Tensor Weight tensor + alpha: float, default = 1.0 + Scaling factor applied to the result of the GEMM bias: torch.Tensor, optional Bias tensor device: torch.device, default = default CUDA device @@ -381,6 +385,8 @@ def _functional_forward( Tensor datatype out: torch.Tensor, optional Output tensor + beta: float, optional + Scaling factor applied to original value of out when accumulating into it accumulate_into_out: bool, default = `False` Add result to output tensor instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` @@ -530,6 +536,8 @@ def _functional_forward( get_workspace(), out_dtype=dtype, quantization_params=output_quantizer, + alpha=alpha, + beta=beta, accumulate=accumulate_into_out, out=y, bias=bias, @@ -567,13 +575,17 @@ def _functional_backward( input: Optional[torch.Tensor], # pylint: disable=redefined-builtin weight: Optional[torch.Tensor], *, + grad_input_alpha: Optional[float] = None, input_requires_grad: bool = True, + 
grad_weight_alpha: Optional[float] = None, weight_requires_grad: bool = True, device: Optional[torch.device] = None, # pylint: disable=unused-argument dtype: Optional[torch.dtype] = None, grad_weight: Optional[torch.Tensor] = None, + grad_weight_beta: Optional[float] = None, accumulate_into_grad_weight: bool = False, grad_input: Optional[torch.Tensor] = None, + grad_input_beta: Optional[float] = None, accumulate_into_grad_input: bool = False, tensor_parallel_mode: Optional[str] = None, tensor_parallel_group: Optional[torch.distributed.ProcessGroup] = None, @@ -596,8 +608,12 @@ def _functional_backward( weight: torch.Tensor, optional Weight tensor. Required to compute loss gradient w.r.t. input. + grad_input_alpha: float, optional + Scaling factor applied to the result of the dgrad GEMM input_requires_grad: bool Whether to compute loss gradient w.r.t. input tensor + grad_weight_alpha: float, optional + Scaling factor applied to the result of the wgrad GEMM weight_requires_grad: bool Whether to compute loss gradient w.r.t. weight tensor device: torch.device, default = default CUDA device @@ -606,10 +622,14 @@ def _functional_backward( Tensor datatype grad_weight: torch.Tensor, optional Loss gradient w.r.t. weight tensor + grad_weight_beta: float, optional + Scaling factor applied to original value of grad_weight when accumulating into it accumulate_into_grad_weight: bool, default = `False` Add result to weight grad instead of overwriting grad_input: torch.Tensor, optional Loss gradient w.r.t. input tensor + grad_input_beta: float, optional + Scaling factor applied to original value of grad_input when accumulating into it accumulate_into_grad_input: bool, default = `False` Add result to input grad instead of overwriting tensor_parallel_mode: {`None`, "column", "row"}, default = `None` @@ -806,6 +826,8 @@ def _functional_backward( get_workspace(), out_dtype=dtype, quantization_params=grad_input_quantizer, + alpha=grad_input_alpha, + beta=grad_input_beta, accumulate=accumulate_into_grad_input, layout="NN", out=dx, @@ -856,6 +878,8 @@ def _functional_backward( dy, get_workspace(), out_dtype=dw_dtype, + alpha=grad_weight_alpha, + beta=grad_weight_beta, accumulate=accumulate_into_grad_weight, layout="NT", out=dw, diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index 3ee23dc7f..b21be1924 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -12,6 +12,10 @@ BackwardLinearAdd, fuse_backward_linear_add, ) +from .backward_linear_scale import ( + BackwardLinearScale, + fuse_backward_linear_scale, +) from .forward_linear_bias_activation import ( ForwardLinearBiasActivation, fuse_forward_linear_bias_activation, @@ -20,6 +24,10 @@ ForwardLinearBiasAdd, fuse_forward_linear_bias_add, ) +from .forward_linear_scale_add import ( + ForwardLinearScaleAdd, + fuse_forward_linear_scale_add, +) from .userbuffers_backward_linear import ( UserbuffersBackwardLinear, fuse_userbuffers_backward_linear, diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py new file mode 100644 index 000000000..630a63157 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -0,0 +1,155 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +"""Fused backward dgrad GEMM + scale.""" + +from __future__ import annotations +from typing import Optional + +import torch + +from ..basic import BasicLinear, ConstantScale +from ..op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) +from ...utils import clear_tensor_data + + +class BackwardLinearScale(FusedOperation): + """Fused backward dgrad GEMM + scale + + Column tensor parallelism is not supported since that requires + communication immediately after the dgrad GEMM. + + """ + + def __init__( + self, + *, + scale: ConstantScale, + linear: BasicLinear, + ) -> None: + super().__init__((linear, scale)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + linear_op = self.basic_ops[0] + linear_op_ctx = basic_op_ctxs[1] + scale_op = self.basic_ops[1] + + # Saved tensors from forward pass + (x_local, w) = linear_op_ctx.saved_tensors + + # wgrad fusion + accumulate_into_main_grad = linear_op._accumulate_into_main_grad + grad_weight = None + if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: + if hasattr(linear_op.weight, "__fsdp_param__"): + linear_op.weight.main_grad = linear_op.weight.get_main_grad() + + if not hasattr(linear_op.weight, "main_grad"): + raise RuntimeError( + "BasicLinear op is configured with " + "accumulate_into_main_grad=True, " + "but weight parameter does not have main_grad attribute" + ) + grad_weight = linear_op.weight.main_grad.detach() + else: + accumulate_into_main_grad = False + + # Linear backward pass + grad_input, grad_weight = BasicLinear._functional_backward( + grad_output=grad_output, + input=x_local, + weight=w, + input_requires_grad=linear_op_ctx.input_requires_grad, + grad_input_alpha=scale_op.scale, + weight_requires_grad=linear_op_ctx.weight_requires_grad, + grad_weight_alpha=scale_op.scale, + dtype=linear_op_ctx.dtype, + grad_weight=grad_weight, + accumulate_into_grad_weight=accumulate_into_main_grad, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_quantized_compute=linear_op_ctx.with_quantized_compute, + input_quantizer=linear_op_ctx.input_quantizer, + weight_quantizer=linear_op_ctx.weight_quantizer, + grad_output_quantizer=linear_op_ctx.grad_output_quantizer, + grad_input_quantizer=linear_op_ctx.grad_input_quantizer, + ) + if accumulate_into_main_grad: + grad_weight = None + + # Clear input tensor if possible + clear_tensor_data(x_local) + + return grad_input, [(), (grad_weight,)], [(), ()] + + +def fuse_backward_linear_scale( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fused backward dgrad GEMM + constant scale + + Parameters + ---------- + ops: list of tuples + Backward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is constant scale + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, ConstantScale): + continue + + # Check if second op is linear + op, _ = ops[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "column": + # Column tensor-parallelism requires communication after the dgrad GEMM + continue + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = BackwardLinearScale( + scale=window[0][0], + linear=window[1][0], + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py new file mode 100644 index 000000000..448f72763 --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -0,0 +1,176 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Fused operation for forward GEMM + scale + add.""" + +from __future__ import annotations +from collections.abc import Iterable +from typing import Any, Optional + +import torch + +from ...fp8 import FP8GlobalStateManager +from ..basic import AddExtraInput, BasicLinear, ConstantScale +from ..op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) +from ...tensor import Quantizer + + +class ForwardLinearScaleAdd(FusedOperation): + """Fused forward GEMM + scale + add + + Row tensor parallelism is not supported since that requires + communication immediately after the GEMM. 
+ + """ + + def __init__( + self, + *, + linear: BasicLinear, + scale: ConstantScale, + add: AddExtraInput, + ) -> None: + super().__init__((linear, scale, add)) + + def fuser_forward( + self, + basic_op_ctxs: list[OperationContext], + input_: torch.Tensor, + *, + basic_op_extra_inputs: list[tuple[torch.Tensor, ...]], + prev_op_grad_output_quantizer: Optional[Quantizer], + next_op_input_quantizer: Optional[Quantizer], + basic_op_kwargs: list[dict[str, Any]], + ) -> tuple[torch.Tensor, Iterable[Iterable[torch.Tensor]]]: + + # Get basic operations + linear_op = self.basic_ops[0] + linear_op_ctx = basic_op_ctxs[0] + scale_op = self.basic_ops[1] + + # Check which grads are required + input_requires_grad = linear_op_ctx.requires_grad + weight_requires_grad = linear_op_ctx.requires_grad and linear_op.weight.requires_grad + + # FP8 metadata + input_quantizer = linear_op.get_quantizer("forward", 0) + weight_quantizer = linear_op.get_quantizer("forward", 1) + output_quantizer = None + grad_output_quantizer = linear_op.get_quantizer("backward", 0) + grad_input_quantizer = prev_op_grad_output_quantizer + with_quantized_compute = FP8GlobalStateManager.is_fp8_enabled() + + # Get extra input tensor for add operation + extra_input = basic_op_extra_inputs[2][0] + + # Get autocast dtype if needed + if torch.is_autocast_enabled(): + dtype = torch.get_autocast_dtype("cuda") + else: + dtype = linear_op.weight.dtype + + # Linear forward + output, x_local, w = BasicLinear._functional_forward( + input=input_, + weight=linear_op.weight, + alpha=scale_op.scale, + dtype=dtype, + out=extra_input, + accumulate_into_out=True, + tensor_parallel_mode=linear_op.tensor_parallel_mode, + tensor_parallel_group=linear_op.tensor_parallel_group, + sequence_parallel=linear_op.sequence_parallel, + with_quantized_compute=with_quantized_compute, + input_quantizer=input_quantizer, + weight_quantizer=weight_quantizer, + output_quantizer=output_quantizer, + input_requires_grad=input_requires_grad, + weight_requires_grad=weight_requires_grad, + ) + + # Save state for backward pass + if linear_op_ctx.requires_grad: + linear_op_ctx.save_for_backward(x_local, w) + linear_op_ctx.with_quantized_compute = with_quantized_compute + linear_op_ctx.input_quantizer = input_quantizer + linear_op_ctx.weight_quantizer = weight_quantizer + linear_op_ctx.grad_output_quantizer = grad_output_quantizer + linear_op_ctx.grad_input_quantizer = grad_input_quantizer + linear_op_ctx.dtype = dtype + linear_op_ctx.input_requires_grad = input_requires_grad + linear_op_ctx.weight_requires_grad = weight_requires_grad + + return output, [() for _ in range(len(self.basic_ops))] + + +def fuse_forward_linear_scale_add( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fuse forward GEMM + scale + add + + Parameters + ---------- + ops: list of tuples + Forward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated forward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 3: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, BasicLinear): + continue + if op.tensor_parallel_mode == "row": + # Row tensor-parallelism requires communication after the + # GEMM + continue + linear = op + op, _ = ops[0] + + # Check if next op is constant scale + if not isinstance(op, ConstantScale): + continue + scale = op + window.extend(ops[:1]) + ops = ops[1:] + op, _ = ops[0] + + # Check if next op is in-place add extra input + if not isinstance(op, AddExtraInput): + continue + if not op._in_place: + continue + add = op + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = ForwardLinearScaleAdd( + linear=linear, + scale=scale, + add=add, + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index b46bfb4b7..448c7d6c9 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -20,8 +20,10 @@ from transformer_engine.pytorch.ops.fused import ( fuse_backward_activation_bias, fuse_backward_linear_add, + fuse_backward_linear_scale, fuse_forward_linear_bias_activation, fuse_forward_linear_bias_add, + fuse_forward_linear_scale_add, fuse_userbuffers_backward_linear, fuse_userbuffers_forward_linear, ) @@ -355,6 +357,7 @@ def _fuse_forward_ops( ops = fuse_userbuffers_forward_linear(ops) ops = fuse_forward_linear_bias_add(ops) ops = fuse_forward_linear_bias_activation(ops) + ops = fuse_forward_linear_scale_add(ops) return ops @classmethod @@ -366,6 +369,7 @@ def _fuse_backward_ops( """Attempt to fuse operations in backward pass""" ops = fuse_userbuffers_backward_linear(ops) ops = fuse_backward_linear_add(ops) + ops = fuse_backward_linear_scale(ops) ops = fuse_backward_activation_bias(ops, recipe) return ops From 6ba98d439190901fec85c2ae3c2cea235d9f2196 Mon Sep 17 00:00:00 2001 From: jomitchellnv <148147880+jomitchellnv@users.noreply.github.com> Date: Fri, 15 Aug 2025 22:03:49 -0700 Subject: [PATCH 077/153] fix: fixes multi head attention for context parallel: rotary embedding to use padded cu_seq_lens (#2077) fix: fixes mha to use padded cu_seq_lens during cp Signed-off-by: Jonathan Mitchell --- .../pytorch/attention/multi_head_attention.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index f25a09fbe..9c82442af 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -907,12 +907,19 @@ def forward( q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] 
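The ForwardLinearScaleAdd and BackwardLinearScale fusions introduced above target a Linear -> ConstantScale -> in-place AddExtraInput pattern built with the fusible-ops API. A minimal sketch (sizes, the 0.5 scale, and the exact op constructors follow the existing fusible-ops tests and are illustrative):

    import torch
    import transformer_engine.pytorch.ops as te_ops

    model = te_ops.Sequential(
        te_ops.Linear(256, 256, bias=False, device="cuda", dtype=torch.bfloat16),
        te_ops.ConstantScale(0.5),
        te_ops.AddExtraInput(in_place=True),
    )
    x = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16, requires_grad=True)
    residual = torch.randn(32, 256, device="cuda", dtype=torch.bfloat16)
    y = model(x, residual)  # forward may fuse into ForwardLinearScaleAdd
    y.sum().backward()      # dgrad/wgrad may fuse into BackwardLinearScale

As noted in the docstrings, the forward fusion is skipped for row tensor-parallel layers and the backward fusion for column tensor-parallel layers, since those require a communication step right after the GEMM.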
+ if pad_between_seqs: + rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded + rotary_pos_cu_seq_lens_kv = cu_seqlens_kv_padded + else: + rotary_pos_cu_seq_lens_q = cu_seqlens_q + rotary_pos_cu_seq_lens_kv = cu_seqlens_kv + query_layer = apply_rotary_pos_emb( query_layer, q_pos_emb, self.qkv_format, fused=True, - cu_seqlens=cu_seqlens_q, + cu_seqlens=rotary_pos_cu_seq_lens_q, cp_size=self.cp_size, cp_rank=self.cp_rank, interleaved=self.rotary_pos_interleaved, @@ -922,7 +929,7 @@ def forward( k_pos_emb, self.qkv_format, fused=True, - cu_seqlens=cu_seqlens_kv, + cu_seqlens=rotary_pos_cu_seq_lens_kv, cp_size=self.cp_size, cp_rank=self.cp_rank, interleaved=self.rotary_pos_interleaved, From 757fd1cfd7ef09c664df5dc854e1051665e3317f Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Mon, 18 Aug 2025 09:37:46 -0700 Subject: [PATCH 078/153] [JAX] Fix Flax variable creation when quantizers are created directly from a recipe (#2079) Fix flax variables when creating quantizers directly from a recipe Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/flax/module.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8c7135210..dc9d0209b 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -15,6 +15,8 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name +from transformer_engine.common import recipe + from ..dense import dense from ..layernorm import canonicalize_norm_type @@ -366,7 +368,9 @@ def generate_quantize_meta(quantizer_name: str): ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: + if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance( + fp8_recipe, recipe.DelayedScaling + ): x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") From 988af0fdd5b35dcfe58f6be1f81f8eeef9d9bf21 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 18 Aug 2025 10:19:29 -0700 Subject: [PATCH 079/153] Update list of authorized CI users (#2078) * Update list of authorized CI users Signed-off-by: Tim Moon * Update .github/workflows/trigger-ci.yml Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Tim Moon Signed-off-by: Kirthi Shankar Sivamani Co-authored-by: Kirthi Shankar Sivamani --- .github/workflows/trigger-ci.yml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 66400ffd7..dc2b98e07 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -53,7 +53,8 @@ jobs: || github.actor == 'lhb8125' || github.actor == 'kunlunl' || github.actor == 'pstjohn' - || github.actor == 'mk-61' + || github.actor == 'vcherepanov-nv' + || github.actor == 'tdophung' ) steps: - name: Check if comment is issued by authorized person From 0e3e270fc4474b500d3d57f1a174e12364c11870 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 19 Aug 2025 01:46:20 +0800 Subject: [PATCH 080/153] [PyTorch] Check if the given recipe is supported in `fp8_autocast` (#2073) * check if the given recipe is supported in fp8_autocast Signed-off-by: Xin Yao * resolve comments Signed-off-by: Xin Yao * check only when enabled Signed-off-by: Xin Yao --------- 
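The recipe check added in this commit surfaces unsupported-recipe errors when entering fp8_autocast rather than later inside a kernel call, and it is only applied when enabled=True. A minimal usage sketch (the module and input are placeholders):

    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling

    fp8_recipe = DelayedScaling()  # or MXFP8BlockScaling(), Float8CurrentScaling(), ...
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        out = model(inp)  # placeholder module and input

On hardware without support for the chosen recipe, the context manager now fails with the reason string returned by the corresponding check_*_support helper instead of an opaque downstream error.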
Signed-off-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/fp8.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/fp8.py b/transformer_engine/pytorch/fp8.py index c74fc3759..8f9dbd88d 100644 --- a/transformer_engine/pytorch/fp8.py +++ b/transformer_engine/pytorch/fp8.py @@ -64,14 +64,26 @@ def check_fp8_block_scaling_support() -> Tuple[bool, str]: return False, "FP8 block scaled GEMM requires Hopper and CUDA >= 12.9." +def check_recipe_support(recipe: Recipe) -> None: + """Check if the given recipe is supported.""" + recipe_supported = True + unsupported_reason = "" + if isinstance(recipe, (DelayedScaling, Float8CurrentScaling)): + recipe_supported, unsupported_reason = check_fp8_support() + elif isinstance(recipe, Float8BlockScaling): + recipe_supported, unsupported_reason = check_fp8_block_scaling_support() + elif isinstance(recipe, MXFP8BlockScaling): + recipe_supported, unsupported_reason = check_mxfp8_support() + assert recipe_supported, unsupported_reason + + def get_default_fp8_recipe() -> Recipe: """FP8 recipe with default args.""" if check_mxfp8_support()[0]: - # This is a temporary restriction until MXFP8 is supported for all - # gemm layouts. - if get_device_compute_capability() >= (12, 0): - return Float8BlockScaling() return MXFP8BlockScaling() + if get_device_compute_capability() >= (12, 0): + # This is a temporary restriction until MXFP8 is supported for all gemm layouts. + return Float8CurrentScaling() return DelayedScaling() @@ -648,6 +660,8 @@ def fp8_autocast( distributed group over which amaxes for the fp8 tensors are reduced at the end of each training step. """ + if enabled: + check_recipe_support(fp8_recipe) fp8_state = FP8GlobalStateManager.get_fp8_autocast_state() FP8GlobalStateManager.fp8_autocast_enter( enabled=enabled, From 3fc1e4bf8d46850f215698b39b5625310194bab9 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 18 Aug 2025 13:56:09 -0400 Subject: [PATCH 081/153] [JAX] Fix for TE GEMM - Always AllGather RHS non-contracting dims with FSDP axis (#2075) * fix fsdp Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index be2dfabb3..9975f558b 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -476,21 +476,24 @@ def _parse_operand_output_specs( lhs_cspecs = tuple(s if s == reduce_spec else None for s in lhs_cspecs) rhs_cspecs = tuple(s if s == reduce_spec else None for s in rhs_cspecs) - # Non-batched non-contracting dims of RHS needs to be unsharded (i.e. FSDP) - # Check if spec is not the batch-dim is not needed as rhs_non_cspecs never includes batch-dim - # rhs_specs only includes batch-dim in the Wgrad GEMM, but there batch-dim belongs to rhs_cspecs + # Non-contracting dims of RHS always needs to be gathered, i.e. for TP + activation_hidden + # No batch-dim check needed as `rhs_non_cspecs` never contains batch-dim. + # In `rhs_specs`, the batch dim appears only in Wgrad GEMM under `rhs_cspecs`. 
rhs_non_cspecs = tuple( None if spec in lhs_non_cspecs else spec for spec in rhs_non_cspecs ) + else: # Otherwise, require contracting dims of both operands to be unsharded lhs_cspecs = (None,) * len(lhs_cspecs) rhs_cspecs = (None,) * len(rhs_cspecs) - # Non-batched non-contracting dims of LHS to be unsharded, i.e gather SP dim - # The spec for batch_dim in lhs_non_cspecs won't ever appear in the rhs_non_cspecs as - # rhs_non_cspecs never has batch-dim. Hence, spec for batch_dim of lhs_non_cspecs won't be - # overwrite + # Non-contracting dims of RHS always needs to be gathered along the FSDP axis + rhs_non_cspecs = tuple( + None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs + ) + + # Non-contracting dims of LHS to be gathered along the SP axis. # Minor note: This causes MaxText TP (= Megatron TP + activation_hidden sharding) gathering x for # dW1 = x^T * dY1 which is unexpected. This is a known issue and no solution has found yet. lhs_non_cspecs = tuple(None if spec in rhs_non_cspecs else spec for spec in lhs_non_cspecs) From 734bcedd9d86e4be30ce44f1ef67af5f69f3670d Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 18 Aug 2025 16:25:25 -0700 Subject: [PATCH 082/153] Changed VERSION to 2.8.0.dev0 Signed-off-by: Przemek Tredak --- build_tools/VERSION.txt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/build_tools/VERSION.txt b/build_tools/VERSION.txt index ba610dcf0..81006d78c 100644 --- a/build_tools/VERSION.txt +++ b/build_tools/VERSION.txt @@ -1 +1 @@ -2.7.0.dev0 +2.8.0.dev0 From 1d075c0682a3b10b2ee7a381adbd8f0ab50f7664 Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 19 Aug 2025 15:29:22 -0700 Subject: [PATCH 083/153] Add user to TE CI (#2089) Update trigger-ci.yml Signed-off-by: Kirthi Shankar Sivamani --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index dc2b98e07..6d7410619 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -55,6 +55,7 @@ jobs: || github.actor == 'pstjohn' || github.actor == 'vcherepanov-nv' || github.actor == 'tdophung' + || github.actor == 'vthumbe1503' ) steps: - name: Check if comment is issued by authorized person From 5b4d89c3227fa31743bdb186d25f646e4636c668 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 19 Aug 2025 18:48:51 -0700 Subject: [PATCH 084/153] Add backward RMSNorm+Add fusion (#2028) * Add rmsnorm_bwd_add Signed-off-by: Jan Bielak * Add BackwardAddRMSNorm fused operation Signed-off-by: Jan Bielak * Try to optimize register usage in kernels Signed-off-by: Jan Bielak * Add separate BackwardAdd stage for the fused backward add Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak --- tests/cpp/operator/test_normalization.cu | 96 ++++++++---- tests/cpp/operator/test_normalization.h | 6 +- tests/cpp/test_common.cu | 11 +- tests/pytorch/test_fusible_ops.py | 89 +++++++++++ .../transformer_engine/normalization.h | 39 +++-- .../common/normalization/common.cpp | 15 +- .../common/normalization/common.h | 19 ++- .../common/normalization/layernorm/ln_api.cpp | 3 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 85 ++++++++++- .../rmsnorm/rmsnorm_bwd_kernels.cuh | 55 ++++++- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 139 +++++++++++++++--- transformer_engine/pytorch/csrc/extensions.h | 5 + .../pytorch/csrc/extensions/normalization.cpp | 46 ++++++ .../pytorch/csrc/extensions/pybind.cpp | 2 + .../pytorch/ops/fused/__init__.py | 4 + 
.../pytorch/ops/fused/backward_add_rmsnorm.py | 133 +++++++++++++++++ transformer_engine/pytorch/ops/fuser.py | 2 + 17 files changed, 667 insertions(+), 82 deletions(-) create mode 100644 transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py diff --git a/tests/cpp/operator/test_normalization.cu b/tests/cpp/operator/test_normalization.cu index e9a125968..20ad38ca2 100644 --- a/tests/cpp/operator/test_normalization.cu +++ b/tests/cpp/operator/test_normalization.cu @@ -27,10 +27,19 @@ namespace { template void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, - NormType norm_type, bool use_cudnn, const bool zero_centered_gamma_in_weight_dtype) { + const NormType norm_type, const bool use_cudnn, + const bool zero_centered_gamma_in_weight_dtype, const bool fused_bwd_add) { if (sizeof(InputType) < sizeof(OutputType)) { GTEST_SKIP() << "LN kernel does not support OutputType > InputType"; - return; + } + + if (norm_type == LayerNorm && fused_bwd_add) { + GTEST_SKIP() << "Fused LN backward+add not currently supported"; + } + + if (fused_bwd_add && zero_centered_gamma_in_weight_dtype) { + GTEST_SKIP() << "zero_centered_gamma_in_weight_dtype not currently supported " + << "in fused norm backward+add"; } if (getDeviceComputeCapability() < hopperComputeCapability && use_cudnn) { @@ -45,7 +54,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, if ((itype == DType::kBFloat16 && otype == DType::kFloat16) || (itype == DType::kFloat16 && otype == DType::kBFloat16)) { GTEST_SKIP() << "LN kernel does not support mixing Float16 and BFloat16"; - return; } Tensor input("input", std::vector{ N, H }, itype); @@ -55,6 +63,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, Tensor mu("mu", std::vector{ N }, DType::kFloat32); Tensor rsigma("rsigma", std::vector{ N }, DType::kFloat32); Tensor dz("dz", std::vector{ N, H }, wtype); + Tensor bwd_add("bwd_add", std::vector{ N, H }, wtype); Tensor dx("dx", std::vector{ N, H }, itype); Tensor dgamma("dgamma", std::vector{ H }, wtype); Tensor dbeta("dbeta", std::vector{ H }, wtype); @@ -65,6 +74,11 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, fillUniform(&beta); setRandomScale(&z); fillUniform(&dz); + if (fused_bwd_add) { + fillUniform(&bwd_add); + } else { + fillCase(&bwd_add, zeros); + } std::unique_ptr ref_output = std::make_unique(N * H); std::unique_ptr ref_mu = std::make_unique(N); @@ -85,7 +99,6 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, nvte_enable_cudnn_norm_fwd(true); nvte_enable_cudnn_norm_bwd(true); - // Zero-centered gamma in weight dtype only supported by CuDNN backend currently if (zero_centered_gamma_in_weight_dtype) { nvte_enable_zero_centered_gamma_in_weight_dtype(true); @@ -125,15 +138,23 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, z.data(), rsigma.data(), workspace_fwd.data(), prop.multiProcessorCount, zero_centered_gamma, 0); - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), - workspace_bwd.data(), - prop.multiProcessorCount, zero_centered_gamma, 0); - workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); - nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), - dx.data(), dgamma.data(), - workspace_bwd.data(), - prop.multiProcessorCount, zero_centered_gamma, 0); + if (fused_bwd_add) { + nvte_rmsnorm_bwd_add(dz.data(), input.data(), 
bwd_add.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + nvte_rmsnorm_bwd_add(dz.data(), input.data(), bwd_add.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), + prop.multiProcessorCount, zero_centered_gamma, 0); + } else { + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, + zero_centered_gamma, 0); + workspace_bwd = Tensor("workspace", workspace_bwd.rowwise_shape(), workspace_bwd.dtype()); + nvte_rmsnorm_bwd(dz.data(), input.data(), rsigma.data(), gamma.data(), + dx.data(), dgamma.data(), workspace_bwd.data(), prop.multiProcessorCount, + zero_centered_gamma, 0); + } } if (use_cudnn){ @@ -167,6 +188,7 @@ void performTest(const size_t N, const size_t H, const bool zero_centered_gamma, use_cudnn, zero_centered_gamma_in_weight_dtype); compute_ref_backward(norm_type, dz.rowwise_cpu_dptr(), + bwd_add.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), mu.rowwise_cpu_dptr(), rsigma.rowwise_cpu_dptr(), gamma.rowwise_cpu_dptr(), @@ -214,30 +236,40 @@ std::vector> test_cases = { } // namespace class NormTestSuite : public ::testing::TestWithParam, - bool, - bool>> {}; + NormType, + transformer_engine::DType, + transformer_engine::DType, + std::pair, + bool, + bool, + bool>> {}; TEST_P(NormTestSuite, TestNorm) { - using namespace transformer_engine; - using namespace test; + using namespace transformer_engine; + using namespace test; const bool use_cudnn = std::get<0>(GetParam()); const NormType norm_type = std::get<1>(GetParam()); - const DType input_type = std::get<2>(GetParam()); - const DType output_type = std::get<3>(GetParam()); - const auto size = std::get<4>(GetParam()); - const bool zero_centered_gamma = std::get<5>(GetParam()); - const bool cudnn_zero_centered_gamm_in_weight_dtype = std::get<6>(GetParam()); - - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, - performTest(size.first, size.second, zero_centered_gamma, norm_type, use_cudnn, cudnn_zero_centered_gamm_in_weight_dtype); + const DType input_type = std::get<2>(GetParam()); + const DType output_type = std::get<3>(GetParam()); + const auto size = std::get<4>(GetParam()); + const bool zero_centered_gamma = std::get<5>(GetParam()); + const bool cudnn_zero_centered_gamma_in_weight_dtype = std::get<6>(GetParam()); + const bool fused_bwd_add = std::get<7>(GetParam()); + + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType, + performTest( + size.first, + size.second, + zero_centered_gamma, + norm_type, + use_cudnn, + cudnn_zero_centered_gamma_in_weight_dtype, + fused_bwd_add ); ); + ); } INSTANTIATE_TEST_SUITE_P( @@ -250,6 +282,7 @@ INSTANTIATE_TEST_SUITE_P( ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16, DType::kFloat8E4M3), ::testing::ValuesIn(test_cases), ::testing::Values(false, true), + ::testing::Values(false, true), ::testing::Values(false, true)), [](const testing::TestParamInfo& info) { auto backend = std::get<0>(info.param) == false ? 
"Te" : "Cudnn"; @@ -261,6 +294,7 @@ INSTANTIATE_TEST_SUITE_P( std::to_string(std::get<4>(info.param).first) + "X" + std::to_string(std::get<4>(info.param).second) + "X" + std::to_string(std::get<5>(info.param)) + "X" + - std::to_string(std::get<6>(info.param)); + std::to_string(std::get<6>(info.param)) + "X" + + std::to_string(std::get<7>(info.param)); return name; }); diff --git a/tests/cpp/operator/test_normalization.h b/tests/cpp/operator/test_normalization.h index f8dfb9f6e..fe69852d0 100644 --- a/tests/cpp/operator/test_normalization.h +++ b/tests/cpp/operator/test_normalization.h @@ -126,7 +126,8 @@ void compute_ref_output(NormType norm_type, template -void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, const InputType *data, +void compute_ref_backward(const NormType norm_type, const OutputType *output_grad, + const OutputType *add, const InputType *data, const float *mu, const float *rsigma, const InputType *gamma, InputType *data_grad, @@ -165,7 +166,8 @@ void compute_ref_backward(const NormType norm_type, const OutputType *output_gra compute_t g = compute_gamma(gamma[j], zero_centered_gamma, use_cudnn, cudnn_zero_centered_gamma_in_weight_dtype); const compute_t dz = static_cast(output_grad[i * H + j]); const compute_t dy = g * dz; - const compute_t dx = rsigma[i] * (dy - mdyy * y - mdy); + const compute_t a = static_cast(add[i * H + j]); + const compute_t dx = a + rsigma[i] * (dy - mdyy * y - mdy); data_grad[i * H + j] = static_cast(dx); } } diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 187742c39..f974d9083 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -844,9 +844,18 @@ void fillCase(Tensor *t, const InputsFillCase fill_case) { } } +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); +template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); template void fillCase(Tensor *t, const InputsFillCase fill_case); -template void fillCase(Tensor *t, const InputsFillCase fill_case); +#if FP4_TYPE_SUPPORTED +template void fillCase(Tensor *t, const InputsFillCase fill_case); +#endif void setRandomScale(Tensor *t) { std::uniform_real_distribution<> dis(-2.0, 1.0); diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 0164c0446..1b9d9acbf 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -21,6 +21,7 @@ import transformer_engine.pytorch.ops as te_ops from transformer_engine.pytorch.ops.fused import ( BackwardActivationBias, + BackwardAddRMSNorm, BackwardLinearAdd, BackwardLinearScale, ForwardLinearBiasActivation, @@ -2206,6 +2207,94 @@ def test_backward_activation_bias( torch.testing.assert_close(dx_test, x_ref.grad, **tols) torch.testing.assert_close(db_test, b_ref.grad, **tols) + @pytest.mark.parametrize("weight_shape", ((19,), (64,))) + @pytest.mark.parametrize("in_shape", ((-1,), (6, 16, -1))) + @pytest.mark.parametrize("dtype", _dtypes) + @pytest.mark.parametrize("zero_centered_gamma", (False, True)) + def test_backward_add_rmsnorm( + self, + *, + weight_shape: Iterable[int], + in_shape: Iterable[int], + dtype: 
torch.dtype, + device: torch.device = "cuda", + eps: float = 0.3, + zero_centered_gamma: bool, + ) -> None: + """Fused backward RMNorm + add""" + + # Make input and weight shapes consistent + in_shape = list(in_shape)[:-1] + list(weight_shape) + + # Random data + x_ref, x_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + ) + w_ref, w_test = make_reference_and_test_tensors( + weight_shape, + test_dtype=dtype, + test_device=device, + ) + dy1_ref, dy1_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + dy2_ref, dy2_test = make_reference_and_test_tensors( + in_shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) + + # Plain PyTorch implementation + inner_dims = tuple(range(len(in_shape) - len(weight_shape), len(in_shape))) + var_ref = x_ref.square().sum(dim=inner_dims, keepdim=True) / math.prod(weight_shape) + if zero_centered_gamma: + y1_ref = x_ref / torch.sqrt(eps + var_ref) * (1 + w_ref) + else: + y1_ref = x_ref / torch.sqrt(eps + var_ref) * w_ref + y2_ref = x_ref + (y1_ref * dy1_ref + y2_ref * dy2_ref).sum().backward() + + # Implementation with fusible operations + model = te_ops.Sequential( + te_ops.MakeExtraOutput(), + te_ops.RMSNorm( + weight_shape, + eps=eps, + device=device, + dtype=dtype, + zero_centered_gamma=zero_centered_gamma, + ), + ) + with torch.no_grad(): + model[1].weight.copy_(w_test) + del w_test + y1_test, y2_test = model(x_test) + (y1_test * dy1_test + y2_test * dy2_test).sum().backward() + + # Check that backward operations have been fused + backward_ops = model._module_groups[0]._backward_ops + assert len(backward_ops) == 1 + assert isinstance(backward_ops[0][0], BackwardAddRMSNorm) + + # Expected numerical error + tols = dtype_tols(dtype) + + # Check results + y1_test = y1_test.to(dtype=torch.float64, device="cpu") + y2_test = y2_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") + dw_test = model[1].weight.grad.to(dtype=torch.float64, device="cpu") + torch.testing.assert_close(y1_test, y1_ref, **tols) + torch.testing.assert_close(y2_test, y2_ref, **tols) + torch.testing.assert_close(dx_test, x_ref.grad, **tols) + torch.testing.assert_close(dw_test, w_ref.grad, **tols) + @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) def test_backward_linear_add( diff --git a/transformer_engine/common/include/transformer_engine/normalization.h b/transformer_engine/common/include/transformer_engine/normalization.h index 9c194e9da..651ae87b4 100644 --- a/transformer_engine/common/include/transformer_engine/normalization.h +++ b/transformer_engine/common/include/transformer_engine/normalization.h @@ -24,7 +24,7 @@ extern "C" { * y = \frac{x - E[x]}{\sqrt{Var[x] + \varepsilon}} \gamma + \beta * @f] * - * Calling this function with workspace set to empty tensor will not perform the operation, + * Calling this function with workspace set to an empty tensor will not perform the operation, * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] x Input tensor of shape [N, H]. @@ -55,8 +55,8 @@ void nvte_layernorm_fwd(const NVTETensor x, const NVTETensor gamma, const NVTETe * else * with respect to \f$x\f$, \f$\gamma\f$ and \f$\beta\f$. 
* - * Calling this function with workspace set to empty tensor will not perform the operation, - * but instead set the shape and type of these tensors to the required values. + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] dz Incoming gradient tensor of shape [N, H]. * \param[in] x Forward input tensor of shape [N, H]. @@ -90,9 +90,8 @@ void nvte_layernorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETenso * RMS_\varepsilon(x) = \sqrt{\frac{1}{n}\sum_{i=0}^{n-1} x_i^2 + \varepsilon} * @f] * - * Calling this function with workspace and barrier set to empty tensor will not - * perform the operation, but instead set the shape and type of the workspace - * and barrier tensors to the required values. + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] x Input tensor of shape [N, H]. * \param[in] gamma Gamma tensor of shape [H]. @@ -121,9 +120,8 @@ void nvte_rmsnorm_fwd(const NVTETensor x, const NVTETensor gamma, const float ep * @f] * with respect to \f$x\f$ and \f$gamma\f$. * - * Calling this function with workspace, barrier, dgamma_part set - * to empty tensor will not perform the operation, but instead set the shape and type - * of these tensors to the required values. + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. * * \param[in] dz Incoming gradient tensor of shape [N, H]. * \param[in] x Forward input tensor of shape [N, H]. @@ -142,6 +140,29 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, const NVTETensor x, const NVTETensor NVTETensor workspace, const int multiprocessorCount, const bool zero_centered_gamma, cudaStream_t stream); +/*! \brief Compute backward of RMSNorm and add additional tensor to output gradient + * + * Calling this function with workspace set to an empty tensor will not perform the operation, + * but instead set the shape and type of the workspace tensor to the required values. + * + * \param[in] dz Incoming gradient tensor of shape [N, H]. + * \param[in] x Forward input tensor of shape [N, H]. + * \param[in] add Additional tensor to add to output gradient [N, H]. + * \param[in] rsigma Reciprocal of the root mean square of the input + * calculated over the last dimension. Shape: [N]. + * \param[in] gamma Gamma tensor of shape [H]. + * \param[out] dx Output gradient of shape [N, H]. + * \param[out] dgamma Gradient for gamma tensor of shape [H]. + * \param[out] workspace Workspace tensor. + * \param[in] multiprocessorCount Number of SMs in the device. + * \param[in] zero_centered_gamma Multiply normalized values by @f$ \gamma+1 @f$ instead of @f$ \gamma @f$ + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_rmsnorm_bwd_add(const NVTETensor dz, const NVTETensor x, const NVTETensor add, + const NVTETensor rsigma, const NVTETensor gamma, NVTETensor dx, + NVTETensor dgamma, NVTETensor workspace, const int multiprocessorCount, + const bool zero_centered_gamma, cudaStream_t stream); + /*! 
\brief Helper to enable cuDNN backend for normalization * * \param[in] bool Enable if True diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index 9df81a917..c280c1c35 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -156,7 +156,7 @@ void TeNormalizationPlan::_set_workspace() { template <> void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, - void* dx_dptr, void* dz_dptr, + void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) { NVTE_ERROR("Forward normalization should not call the backward execute function!"); @@ -166,8 +166,9 @@ template <> void TeNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, void* dz_dptr, - void* dbeta_dptr, void* dgamma_dptr, - void* workspace_dptr, cudaStream_t stream) { + void* add_dptr, void* dbeta_dptr, + void* dgamma_dptr, void* workspace_dptr, + cudaStream_t stream) { _launch_params.stream = stream; auto& kernel_params = _launch_params.params; @@ -177,6 +178,7 @@ void TeNormalizationPlan::execute(void* x_dptr, void* gamm kernel_params.rs = rsigma_dptr; kernel_params.dx = dx_dptr; kernel_params.dz = dz_dptr; + kernel_params.add = add_dptr; kernel_params.dgamma = dgamma_dptr; if (_is_layernorm) { @@ -447,8 +449,11 @@ void CudnnNormalizationPlan::execute(Tensor* z, void* x_dptr, void* gamma_dptr, void CudnnNormalizationPlan::execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, void* dz_dptr, - void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, - cudaStream_t stream) { + void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) { + // cuDNN does not currently support fused backward+add + NVTE_CHECK(add_dptr == nullptr); + // Binding data pointers to graph tensors _variant_pack = { {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}}; diff --git a/transformer_engine/common/normalization/common.h b/transformer_engine/common/normalization/common.h index 0ec16046e..37144052a 100644 --- a/transformer_engine/common/normalization/common.h +++ b/transformer_engine/common/normalization/common.h @@ -126,6 +126,9 @@ struct BackwardKernelParams : public KernelParamsBase { // Input: gradient wrt. LN FWD output. void* dz; + // Input: extra tensor to add for fused backward+add + void* add; + // Workspace for Wgrad pre-reduction. 
void* dbeta_part; void* dgamma_part; @@ -137,8 +140,10 @@ struct BackwardKernelParams : public KernelParamsBase { void* dgamma; }; +using BackwardAddKernelParams = BackwardKernelParams; + enum class NVTE_Norm_Backend { Te, Cudnn }; -enum class NVTE_Norm_Stage { Forward, Backward }; +enum class NVTE_Norm_Stage { Forward, Backward, BackwardAdd }; using TupleKeyType = std::tuple; struct TupleHash { @@ -221,8 +226,8 @@ class NormalizationPlanBase { cudaStream_t stream) = 0; virtual void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, - void* dx_dptr, void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, - void* workspace_dptr, cudaStream_t stream) = 0; + void* dx_dptr, void* dz_dptr, void* add_dptr, void* dbeta_dptr, + void* dgamma_dptr, void* workspace_dptr, cudaStream_t stream) = 0; private: virtual void _build() = 0; @@ -241,8 +246,8 @@ class TeNormalizationPlan : public NormalizationPlanBase { cudaStream_t stream) override; void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, - void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, - cudaStream_t stream) override; + void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) override; private: void _set_workspace(); @@ -270,8 +275,8 @@ class CudnnNormalizationPlan : public NormalizationPlanBase { cudaStream_t stream) override; void execute(void* x_dptr, void* gamma_dptr, void* mean_dptr, void* rsigma_dptr, void* dx_dptr, - void* dz_dptr, void* dbeta_dptr, void* dgamma_dptr, void* workspace_dptr, - cudaStream_t stream) override; + void* dz_dptr, void* add_dptr, void* dbeta_dptr, void* dgamma_dptr, + void* workspace_dptr, cudaStream_t stream) override; private: void _build() override; diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index cf5678e40..af19300a9 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -185,7 +185,8 @@ void layernorm_bwd(const Tensor& dz, const Tensor& x, const Tensor& mu, const Te } else { NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); plan->execute(x.data.dptr, gamma.data.dptr, mu.data.dptr, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, dbeta->data.dptr, dgamma->data.dptr, workspace->data.dptr, stream); + dz.data.dptr, nullptr /*add*/, dbeta->data.dptr, dgamma->data.dptr, + workspace->data.dptr, stream); } return; } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 499c0ef69..1aae72e15 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -162,7 +162,74 @@ void rmsnorm_bwd(const Tensor &dz, const Tensor &x, const Tensor &rsigma, const } else { NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr, - dz.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, workspace->data.dptr, stream); + dz.data.dptr, nullptr /*add*/, nullptr /*dbeta*/, dgamma->data.dptr, + workspace->data.dptr, stream); + } + return; +} + +void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const Tensor &rsigma, + const Tensor &gamma, Tensor *dx, Tensor *dgamma, Tensor *workspace, + const int 
multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + using namespace transformer_engine; + + NVTE_CHECK(dz.data.dtype == gamma.data.dtype); + NVTE_CHECK(add.data.dtype == gamma.data.dtype); + NVTE_CHECK(rsigma.data.dtype == DType::kFloat32); + + NVTE_CHECK(x.data.shape.size() == 2); + NVTE_CHECK(dz.data.shape == x.data.shape); + NVTE_CHECK(add.data.shape == x.data.shape); + + NVTE_CHECK(gamma.data.shape[0] == x.data.shape[1]); + + NVTE_CHECK(dx->data.shape == x.data.shape); + NVTE_CHECK(dx->data.dtype == x.data.dtype); + + NVTE_CHECK(dgamma->data.shape == gamma.data.shape); + NVTE_CHECK(dgamma->data.dtype == gamma.data.dtype); + + if (!workspace->data.shape.empty()) { + CheckInputTensor(dz, "dz"); + CheckInputTensor(x, "x"); + CheckInputTensor(add, "add"); + CheckInputTensor(rsigma, "rsigma"); + CheckInputTensor(gamma, "gamma"); + CheckOutputTensor(*dx, "dx"); + CheckOutputTensor(*dgamma, "dgamma"); + } + + // cuDNN does not currently support fused backward+add + NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; + + // TE backend does not currently support zero_centered_gamma_in_weight_dtype + NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(), + "zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add"); + + bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, dgamma->data.dptr, add.data.dptr); + bool gamma_in_weight_dtype = false; + + auto plan = NormalizationPlanRegistry::getInstance().getNormalizationPlan( + norm_backend, NVTE_Norm_Type::RMSNorm, NVTE_Norm_Stage::BackwardAdd, + gamma.data.dtype, // wtype + x.data.dtype, // itype + gamma.data.dtype, // otype + x.data.shape[0], // batch_size + x.data.shape[1], // hidden_size + multiprocessorCount, zero_centered_gamma, is_aligned, NVTE_DELAYED_TENSOR_SCALING, true, + gamma_in_weight_dtype); + + if (workspace->data.shape.empty()) { + workspace->data.shape = plan->getWorkspaceShape(); + workspace->data.dtype = DType::kByte; + return; + } else { + NVTE_CHECK(workspace->data.shape == plan->getWorkspaceShape()); + plan->execute(x.data.dptr, gamma.data.dptr, nullptr /*mu*/, rsigma.data.dptr, dx->data.dptr, + dz.data.dptr, add.data.dptr, nullptr /*dbeta*/, dgamma->data.dptr, + workspace->data.dptr, stream); } return; } @@ -195,3 +262,19 @@ void nvte_rmsnorm_bwd(const NVTETensor dz, // Nxhidden_size convertNVTETensor(dx), convertNVTETensor(dgamma), convertNVTETensor(workspace), multiprocessorCount, zero_centered_gamma, stream); } + +void nvte_rmsnorm_bwd_add(const NVTETensor dz, // Nxhidden_size + const NVTETensor x, // Nxhidden_size + const NVTETensor add, // Nxhidden_size + const NVTETensor rsigma, // N, FP32! 
+ const NVTETensor gamma, // hidden_size + NVTETensor dx, NVTETensor dgamma, NVTETensor workspace, + const int multiprocessorCount, const bool zero_centered_gamma, + cudaStream_t stream) { + NVTE_API_CALL(nvte_rmsnorm_bwd_add); + using namespace transformer_engine; + rmsnorm_bwd_add(*convertNVTETensorCheck(dz), *convertNVTETensorCheck(x), + *convertNVTETensorCheck(add), *convertNVTETensorCheck(rsigma), + *convertNVTETensorCheck(gamma), convertNVTETensor(dx), convertNVTETensor(dgamma), + convertNVTETensor(workspace), multiprocessorCount, zero_centered_gamma, stream); +} diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh index 5d8a5b765..3f3cdd065 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_kernels.cuh @@ -7,13 +7,31 @@ #ifndef TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ #define TRANSFORMER_ENGINE_COMMON_RMSNORM_RMSNORM_BWD_KERNELS_CUH_ +#include + #include "../../utils.cuh" #include "../common.h" namespace transformer_engine { namespace normalization { -template +struct maybe_not_t {}; + +template +using maybe_t = std::conditional_t; + +template +union dx_add_t { + using add_t = maybe_t; + using dx_t = Ivec; + struct { + char _padding[sizeof(dx_t) > sizeof(add_t) ? sizeof(dx_t) - sizeof(add_t) : 0]; + add_t add; + }; + dx_t dx; +}; + +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_kernel( BackwardKernelParams params) { enum { ROWS_PER_CTA = Ktraits::ROWS_PER_CTA }; @@ -111,10 +129,19 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke } } + dx_add_t temp[LDGS]; + if constexpr (FusedAdd) { + idx = row * Ktraits::VEC_COLS + c; +#pragma unroll + for (int it = 0; it < LDGS; it++) { + temp[it].add.load_from(params.add, idx); + idx += Ktraits::VEC_COLS_PER_LDG; + } + } + reduce_t result = reducer.allreduce({0, mdyy_local}, sum); mdyy_local = Get<1>::of(result) * rn; - Ivec dx[LDGS]; idx = row * Ktraits::VEC_COLS + c; #pragma unroll for (int it = 0; it < LDGS; it++) { @@ -123,9 +150,13 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_tuned_ke compute_t dy_tmp = dy[it * NUM_ELTS + jt]; compute_t y_tmp = y[it * NUM_ELTS + jt]; compute_t dx_tmp = rs_r * (dy_tmp - (mdyy_local * y_tmp)); - dx[it].data.elt[jt] = dx_tmp; + if constexpr (FusedAdd) { + compute_t add_tmp = temp[it].add.data.elt[jt]; + dx_tmp += add_tmp; + } + temp[it].dx.data.elt[jt] = dx_tmp; } - dx[it].store_to(params.dx, idx); + temp[it].dx.store_to(params.dx, idx); idx += Ktraits::VEC_COLS_PER_LDG; } } // end: grid stride loop @@ -274,7 +305,7 @@ __global__ __launch_bounds__(Kernel_traits::THREADS_PER_CTA) void rmsnorm_bwd_fi } } -template +template __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_kernel( BackwardKernelParams params) { enum { LDGS = Ktraits::LDGS }; @@ -379,14 +410,22 @@ __global__ __launch_bounds__(Ktraits::THREADS_PER_CTA) void rmsnorm_bwd_general_ #pragma unroll for (int it = 0, col = gidn * NUM_ELTS; it < LDGS && row < params.rows && col < params.cols; it++, col += gdimn * NUM_ELTS) { - Ivec dx; + dx_add_t temp; + if constexpr (FusedAdd) { + temp.add.load_from_elts(params.add, row * params.cols + col, params.cols - col); + } #pragma unroll for (int jt = 0; jt < NUM_ELTS; jt++) { compute_t dy_ij = dy[it].data.elt[jt]; compute_t y_ij = y[it].data.elt[jt]; - 
dx.data.elt[jt] = rs * (dy_ij - (mdyy * y_ij)); + compute_t dx_ij = rs * (dy_ij - (mdyy * y_ij)); + if constexpr (FusedAdd) { + compute_t add_ij = temp.add.data.elt[jt]; + dx_ij += add_ij; + } + temp.dx.data.elt[jt] = dx_ij; } - dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); + temp.dx.store_to_elts(params.dx, row * params.cols + col, params.cols - col); } } diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index fb5741b35..0a7b38000 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -12,17 +12,17 @@ using namespace transformer_engine::normalization; template + int BYTES_PER_LDG_MAIN, int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false> void launch_tuned_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_tuned_kernel; + auto kernel = &rmsnorm_bwd_tuned_kernel; if (configure_params) { - int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + int ctas_per_sm = 0; + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -52,9 +52,9 @@ void launch_tuned_(LaunchParams &launch_params, dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, - stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES, stream)); } using Kernel_traits_f = @@ -69,7 +69,7 @@ void launch_tuned_(LaunchParams &launch_params, template + int BYTES_PER_LDG_FINAL, bool FUSED_ADD = false> void launch_general_(LaunchParams &launch_params, const bool configure_params) { // NOLINT(*) auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; @@ -77,7 +77,7 @@ void launch_general_(LaunchParams &launch_params, // Instantiate kernel using Kernel_traits = Kernel_traits; - auto kernel = &rmsnorm_bwd_general_kernel; + auto kernel = &rmsnorm_bwd_general_kernel; // Configure kernel params const int rows = launch_params.params.rows; @@ -85,9 +85,9 @@ void launch_general_(LaunchParams &launch_params, int ctas_per_col = launch_params.params.ctas_per_col; int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { - int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); + int ctas_per_sm = 0; + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -112,8 +112,8 @@ void launch_general_(LaunchParams &launch_params, kernel<<>>(launch_params.params); } else { void 
*params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } // Launch finalization kernel @@ -143,7 +143,7 @@ void launch_general_(LaunchParams &launch_params, norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE); \ } // namespace -// Create rmsnorm tuned launch function and register. Macro signature: +// Create rmsnorm bwd tuned launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL @@ -171,7 +171,7 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp32, fp32, fp32, fp32, 1 REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4); -// Create rmsnorm general launch function and register. Macro signature: +// Create rmsnorm bwd general launch function and register. Macro signature: // HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... // WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL @@ -204,3 +204,108 @@ REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp16, fp16, fp32, REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4); REGISTER_NORM_LAUNCHER(RMSNorm, Backward, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4); + +// Create fused rmsnorm bwd + add tuned launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, CTAS_PER_ROW, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 512, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 768, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp32, fp32, fp32, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, fp16, fp16, fp16, fp32, 1, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 1024, bf16, bf16, bf16, fp32, 1, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 2048, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 4096, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp32, fp32, fp32, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, fp16, fp16, fp16, fp32, 1, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, tuned, 8192, bf16, bf16, bf16, fp32, 1, 1, 4, 16, 4, + true); + +// Create fused rmsnorm bwd + add general launch function and register. Macro signature: +// HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, ... 
+// WARPS_M, WARPS_N, BYTES_PER_LDG, BYTES_PER_LDG_FINAL + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp32, fp32, fp32, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp16, fp16, fp32, 4, 1, 8, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, fp16, fp32, fp16, fp32, 4, 1, 8, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, bf16, bf16, fp32, 4, 1, 8, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 128, bf16, fp32, bf16, fp32, 4, 1, 8, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp32, fp32, fp32, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp16, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, fp16, fp32, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, bf16, bf16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 512, bf16, fp32, bf16, fp32, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp32, fp32, fp32, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp16, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, fp16, fp32, fp16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, bf16, bf16, fp32, 4, 1, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 1024, bf16, fp32, bf16, fp32, 4, 1, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp32, fp32, fp32, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp16, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, fp16, fp32, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, bf16, bf16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 2048, bf16, fp32, bf16, fp32, 1, 4, 16, 4, + true); + +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp32, fp32, fp32, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp16, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, fp16, fp32, fp16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, bf16, bf16, fp32, 1, 4, 16, 4, + true); +REGISTER_NORM_LAUNCHER(RMSNorm, BackwardAdd, general, 4096, bf16, fp32, bf16, fp32, 1, 4, 16, 4, + true); diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 2f4414328..9df220ba7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -208,6 +208,11 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, const at::Tensor &rsigma, const at::Tensor &gamma, const int sm_margin, const bool zero_centered_gamma); +std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &add, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma); + std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object ln_out, py::handle quantizer, DType otype, const int sm_margin, const bool 
zero_centered_gamma); diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index e5a1a2a78..59bac8fe5 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -199,6 +199,52 @@ std::vector rmsnorm_bwd(const at::Tensor &dz, const at::Tensor &x, return {py::cast(dx), py::cast(dgamma)}; } +std::vector rmsnorm_bwd_add(const at::Tensor &dz, const at::Tensor &x, + const at::Tensor &add, const at::Tensor &rsigma, + const at::Tensor &gamma, const int sm_margin, + const bool zero_centered_gamma) { + const auto &dz_ = dz.contiguous(); + const auto &x_ = x.contiguous(); + const auto &add_ = add.contiguous(); + const auto &rsigma_ = rsigma.contiguous(); + const auto &gamma_ = gamma.contiguous(); + + auto dx = at::empty_like(x_); + auto dgamma = at::empty_like(gamma_); + TensorWrapper workspace; + + auto dz_cu = makeTransformerEngineTensor(dz_); + auto x_cu = makeTransformerEngineTensor(x_); + auto add_cu = makeTransformerEngineTensor(add_); + auto rsigma_cu = makeTransformerEngineTensor(rsigma_); + auto gamma_cu = makeTransformerEngineTensor(gamma_); + auto dx_cu = makeTransformerEngineTensor(dx); + auto dgamma_cu = makeTransformerEngineTensor(dgamma); + + // This call populates tensors with the required config. + NVTE_SCOPED_GIL_RELEASE({ + nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), + gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + }); + + // Alloc space for Tensors. + auto workspace_data = allocateSpace(workspace.shape(), workspace.dtype()); + workspace = + makeTransformerEngineTensor(workspace_data.data_ptr(), workspace.shape(), workspace.dtype()); + + // Actual call to bwd kernel. 
+ NVTE_SCOPED_GIL_RELEASE({ + nvte_rmsnorm_bwd_add(dz_cu.data(), x_cu.data(), add_cu.data(), rsigma_cu.data(), + gamma_cu.data(), dx_cu.data(), dgamma_cu.data(), workspace.data(), + at::cuda::getCurrentDeviceProperties()->multiProcessorCount - sm_margin, + zero_centered_gamma, at::cuda::getCurrentCUDAStream()); + }); + + return {py::cast(dx), py::cast(dgamma)}; +} + std::vector rmsnorm_fwd(const py::handle &input, const py::handle &weight, float eps, py::object out, py::handle quantizer, DType out_dtype, const int sm_margin, const bool zero_centered_gamma) { diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index d38348ae9..235516d0c 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -202,6 +202,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("weight"), py::arg("eps"), py::arg("ln_out"), py::arg("quantizer"), py::arg("otype"), py::arg("sm_margin"), py::arg("zero_centered_gamma")); m.def("rmsnorm_bwd", &transformer_engine::pytorch::rmsnorm_bwd, "Backward of RMSNorm"); + m.def("rmsnorm_bwd_add", &transformer_engine::pytorch::rmsnorm_bwd_add, + "Fused backward of RMSNorm + add"); m.def("multi_tensor_quantize", &transformer_engine::pytorch::multi_tensor_quantize, "Multi-tensor quantize", py::arg("tensor_list"), py::arg("quantizer_list")); m.def("split_quantize", &transformer_engine::pytorch::split_quantize, diff --git a/transformer_engine/pytorch/ops/fused/__init__.py b/transformer_engine/pytorch/ops/fused/__init__.py index b21be1924..21113c212 100644 --- a/transformer_engine/pytorch/ops/fused/__init__.py +++ b/transformer_engine/pytorch/ops/fused/__init__.py @@ -8,6 +8,10 @@ BackwardActivationBias, fuse_backward_activation_bias, ) +from .backward_add_rmsnorm import ( + BackwardAddRMSNorm, + fuse_backward_add_rmsnorm, +) from .backward_linear_add import ( BackwardLinearAdd, fuse_backward_linear_add, diff --git a/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py new file mode 100644 index 000000000..54a23395a --- /dev/null +++ b/transformer_engine/pytorch/ops/fused/backward_add_rmsnorm.py @@ -0,0 +1,133 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
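For orientation before the Python fusion below: the new `tex.rmsnorm_bwd_add` extension is expected to behave like the existing `tex.rmsnorm_bwd` followed by an elementwise accumulation of the residual gradient into dx, just without the extra kernel launch and intermediate tensor traffic. A minimal unfused reference, assuming that interpretation (the helper name is illustrative; the `rmsnorm_bwd` signature is the one declared in extensions.h above):

    import transformer_engine_torch as tex

    def rmsnorm_bwd_add_reference(dz, x, add, rsigma, gamma, sm_margin, zero_centered_gamma):
        """Unfused sketch of what the fused rmsnorm_bwd_add call is assumed to compute."""
        # Plain RMSNorm backward: gradients w.r.t. the normalized input and gamma.
        dx, dgamma = tex.rmsnorm_bwd(dz, x, rsigma, gamma, sm_margin, zero_centered_gamma)
        # Fold the residual-branch gradient into the input gradient.
        dx = dx + add
        return dx, dgamma
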
+ +"""Fused backward RMNorm + add.""" + +from __future__ import annotations +from typing import Optional +import math + +import torch + +import transformer_engine_torch as tex +from transformer_engine.pytorch.ops.basic import MakeExtraOutput, RMSNorm + +from transformer_engine.pytorch.ops.op import ( + FusedOperation, + FusibleOperation, + OperationContext, +) +from ...utils import clear_tensor_data +from .._common import maybe_dequantize + + +class BackwardAddRMSNorm(FusedOperation): + """Fused backward RMNorm + add""" + + def __init__(self, *, add: MakeExtraOutput, rmsnorm: RMSNorm): + super().__init__((add, rmsnorm)) + + def fuser_backward( + self, + basic_op_ctxs: list[OperationContext], + grad_output: torch.Tensor, + *, + basic_op_grad_extra_outputs: list[tuple[torch.Tensor, ...]], + ) -> tuple[ + torch.Tensor, + list[tuple[Optional[torch.Tensor], ...]], + list[tuple[()]], + ]: + + # Get basic operations + rmsnorm_op = self.basic_ops[1] + rmsnorm_op_ctx = basic_op_ctxs[0] + + # Saved tensors from forward pass + x, rstdevs = rmsnorm_op_ctx.saved_tensors + + # Tensor dims + weight_dims = rmsnorm_op.weight.size() + inner_dim = math.prod(weight_dims) + + # Check input tensors + dtype = rmsnorm_op_ctx.dtype + extra_grad = basic_op_grad_extra_outputs[1][0] + dy = maybe_dequantize(grad_output.contiguous(), dtype).view(x.size()) + w = maybe_dequantize(rmsnorm_op.weight, dtype).view((inner_dim,)) + add = maybe_dequantize(extra_grad.contiguous(), dtype).view(x.size()) + + # Compute RMSNorm backward pass + dx, dw = tex.rmsnorm_bwd_add( + dy, + x, + add, + rstdevs, + w, + rmsnorm_op._sm_margins["backward"], + rmsnorm_op.zero_centered_gamma, + ) + + # Clear saved tensors if possible + clear_tensor_data(x) + clear_tensor_data(rstdevs) + + # Reshape results + grad_input = dx.view(grad_output.size()) + grad_weight = dw.view(weight_dims) + + return grad_input, [(grad_weight,), ()], [(), ()] + + +def fuse_backward_add_rmsnorm( + ops: list[tuple[FusibleOperation, list[int]]], +) -> list[tuple[FusibleOperation, list[int]]]: + """Fused backward RMNorm + add + + Parameters + ---------- + ops: list of tuples + Backward pass operations and the indices of the corresponding + basic operations. 
+ + Returns + ------- + ops: list of tuples + Updated backward pass operations + + """ + + # Scan through ops, fusing if possible + out = [] + window = [] + while len(ops) >= 2: + out.extend(window) + + # Check if first op is linear + window, ops = ops[:1], ops[1:] + op, _ = window[0] + if not isinstance(op, RMSNorm): + continue + + # Check if second op is "make extra output" + op, _ = ops[0] + if not isinstance(op, MakeExtraOutput): + continue + if op._in_place: + continue + window.extend(ops[:1]) + ops = ops[1:] + + # Replace window with fused op + op = BackwardAddRMSNorm( + rmsnorm=window[0][0], + add=window[1][0], + ) + basic_op_idxs = [basic_op_idxs[0] for _, basic_op_idxs in window] + window = [(op, basic_op_idxs)] + + # Return list of ops + out.extend(window) + out.extend(ops) + return out diff --git a/transformer_engine/pytorch/ops/fuser.py b/transformer_engine/pytorch/ops/fuser.py index 448c7d6c9..ccd7ee52b 100644 --- a/transformer_engine/pytorch/ops/fuser.py +++ b/transformer_engine/pytorch/ops/fuser.py @@ -19,6 +19,7 @@ ) from transformer_engine.pytorch.ops.fused import ( fuse_backward_activation_bias, + fuse_backward_add_rmsnorm, fuse_backward_linear_add, fuse_backward_linear_scale, fuse_forward_linear_bias_activation, @@ -371,6 +372,7 @@ def _fuse_backward_ops( ops = fuse_backward_linear_add(ops) ops = fuse_backward_linear_scale(ops) ops = fuse_backward_activation_bias(ops, recipe) + ops = fuse_backward_add_rmsnorm(ops) return ops def maybe_fuse_ops( From 51f19fdc5bf58fd0ca6ccac2c597d2e43c0580a2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Wed, 20 Aug 2025 10:09:52 +0200 Subject: [PATCH 085/153] [PyTorch] Add test for TRT integration + fix for mxfp8 export (#2083) * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- qa/L0_pytorch_unittest/test.sh | 3 - qa/L1_pytorch_onnx_unittest/test.sh | 11 ++++ tests/pytorch/test_onnx_export.py | 59 ++++++++++++++++++- transformer_engine/pytorch/onnx_extensions.py | 8 +-- 4 files changed, 73 insertions(+), 8 deletions(-) create mode 100644 qa/L1_pytorch_onnx_unittest/test.sh diff --git a/qa/L0_pytorch_unittest/test.sh b/qa/L0_pytorch_unittest/test.sh index 482ae6dca..394273ca4 100644 --- a/qa/L0_pytorch_unittest/test.sh +++ b/qa/L0_pytorch_unittest/test.sh @@ -23,8 +23,6 @@ set -x mkdir -p "$XML_LOG_DIR" pip3 install pytest==8.2.1 || error_exit "Failed to install pytest" -pip3 install onnxruntime==1.20.1 || error_exit "Failed to install onnxruntime" -pip3 install onnxruntime_extensions==0.13.0 || error_exit "Failed to install onnxruntime_extensions" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_sanity.xml $TE_PATH/tests/pytorch/test_sanity.py || test_fail "test_sanity.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_recipe.xml $TE_PATH/tests/pytorch/test_recipe.py || test_fail "test_recipe.py" @@ -40,7 +38,6 @@ python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_float8_blockwise python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_gqa.xml $TE_PATH/tests/pytorch/test_gqa.py || test_fail "test_gqa.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fused_optimizer.xml 
$TE_PATH/tests/pytorch/test_fused_optimizer.py || test_fail "test_fused_optimizer.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_multi_tensor.xml $TE_PATH/tests/pytorch/test_multi_tensor.py || test_fail "test_multi_tensor.py" -python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_onnx_export.xml $TE_PATH/tests/pytorch/test_onnx_export.py || test_fail "test_onnx_export.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops.xml $TE_PATH/tests/pytorch/test_fusible_ops.py || test_fail "test_fusible_ops.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_permutation.xml $TE_PATH/tests/pytorch/test_permutation.py || test_fail "test_permutation.py" python3 -m pytest --tb=auto --junitxml=$XML_LOG_DIR/pytest_test_parallel_cross_entropy.xml $TE_PATH/tests/pytorch/test_parallel_cross_entropy.py || test_fail "test_parallel_cross_entropy.py" diff --git a/qa/L1_pytorch_onnx_unittest/test.sh b/qa/L1_pytorch_onnx_unittest/test.sh new file mode 100644 index 000000000..1486d5097 --- /dev/null +++ b/qa/L1_pytorch_onnx_unittest/test.sh @@ -0,0 +1,11 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + + +pip3 install onnxruntime==1.20.1 +pip3 install onnxruntime_extensions==0.13.0 + +: ${TE_PATH:=/opt/transformerengine} + +python3 -m pytest --tb=auto $TE_PATH/tests/pytorch/test_onnx_export.py diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index 839fb8dff..b353333a5 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -36,6 +36,7 @@ from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.utils import get_default_init_method +import tensorrt as trt # Global test configuration knobs. 
@@ -113,7 +114,7 @@ def trt_fp8_dequantize(t, scale): @onnx_op( - op_type="trt::TRT_MXFP8QuantizeLinear", + op_type="trt::TRT_MXFP8DynamicQuantize", domain="trt", inputs=[ PyCustomOpDef.dt_float, @@ -1139,3 +1140,59 @@ def test_export_ctx_manager(enabled): with te.onnx_export(enabled): assert is_in_onnx_export_mode() == enabled assert is_in_onnx_export_mode() == False + + +@pytest.mark.parametrize("fp8_recipe", fp8_recipes) +def test_trt_integration(fp8_recipe: recipe.Recipe): + + model = te.TransformerLayer( + hidden_size=128, + ffn_hidden_size=128, + num_attention_heads=4, + ).eval() + inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) + + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + out_ref = model(*inps) + + onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx") + os.close(onnx_fd) + try: + with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): + with te.onnx_export(enabled=True): + torch.onnx.export( + model, + inps, + onnx_path, + output_names=["output"], + dynamo=True, + custom_translation_table=te_translation_table, + ) + + os.system(f"trtexec --onnx={onnx_path} --saveEngine={onnx_path}.engine") + + # Run TRT engine + logger = trt.Logger(trt.Logger.WARNING) + runtime = trt.Runtime(logger) + with open(onnx_path + ".engine", "rb") as f: + engine_data = f.read() + engine = runtime.deserialize_cuda_engine(engine_data) + context = engine.create_execution_context() + context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr()) + stream = torch.cuda.Stream() + + out = torch.zeros_like(out_ref) + context.set_tensor_address("output", out.data_ptr()) + + context.execute_async_v3(stream_handle=stream.cuda_stream) + stream.synchronize() + + # Compare TRT and TE outputs + atol = 5e-2 if fp8_recipe is not None else 1e-4 + rtol = 5e-2 if fp8_recipe is not None else 1e-4 + torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol) + finally: + try: + os.remove(onnx_path) + except FileNotFoundError: + pass diff --git a/transformer_engine/pytorch/onnx_extensions.py b/transformer_engine/pytorch/onnx_extensions.py index e34fd7846..42f5a1d55 100644 --- a/transformer_engine/pytorch/onnx_extensions.py +++ b/transformer_engine/pytorch/onnx_extensions.py @@ -194,12 +194,12 @@ def onnx_quantize_mxfp8_symbolic( tensor: onnxscript.onnx_types.TensorType, ) -> Tuple[onnxscript.onnx_types.TensorType, onnxscript.onnx_types.TensorType]: """Symbolic quantize to MXFP8Tensor used for inference.""" - tensor_out, scale_inv_out = TRT_MXFP8QuantizeLinear(tensor) + tensor_out, scale_inv_out = TRT_MXFP8DynamicQuantize(tensor) return tensor_out, scale_inv_out schema = defs.OpSchema( - name="TRT_MXFP8QuantizeLinear", + name="TRT_MXFP8DynamicQuantize", domain="trt", since_version=1, doc="TRT MXFP8 Quantize Linear used for inference.", @@ -214,8 +214,8 @@ def onnx_quantize_mxfp8_symbolic( ], ) -TRT_MXFP8QuantizeLinear = onnxscript.values.Op( - opset=trt_opset, name="TRT_MXFP8QuantizeLinear", op_schema=schema +TRT_MXFP8DynamicQuantize = onnxscript.values.Op( + opset=trt_opset, name="TRT_MXFP8DynamicQuantize", op_schema=schema ) From bc99a88da65fa2e47a0eadff575b456bd4ec02e1 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 20 Aug 2025 08:36:10 -0700 Subject: [PATCH 086/153] [JAX] Error checking for mesh resource and update GemmPrimitive to use global_mesh_resource().fsdp_resource (#2088) * Enforce global MeshResource is set Signed-off-by: Jeremy Berchtold * Use 
global_mesh_resource().fsdp_resource in gemm primitive Signed-off-by: Jeremy Berchtold * Update tests Signed-off-by: Jeremy Berchtold * Update gemm.py Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- .../jax/encoder/test_single_gpu_encoder.py | 4 +++- examples/jax/mnist/test_single_gpu_mnist.py | 4 +++- tests/jax/test_distributed_layernorm_mlp.py | 4 ++-- tests/jax/test_layer.py | 21 +++++++++++++++---- transformer_engine/jax/cpp_extensions/gemm.py | 4 +++- transformer_engine/jax/quantize/helper.py | 3 --- transformer_engine/jax/sharding.py | 7 ++++++- 7 files changed, 34 insertions(+), 13 deletions(-) diff --git a/examples/jax/encoder/test_single_gpu_encoder.py b/examples/jax/encoder/test_single_gpu_encoder.py index b4c8767a5..826d0d2fc 100644 --- a/examples/jax/encoder/test_single_gpu_encoder.py +++ b/examples/jax/encoder/test_single_gpu_encoder.py @@ -219,7 +219,9 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast( + enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + ): encoder = Net(num_embed) # We use nn.Embed, thus inputs need to be in int inputs = jnp.zeros(input_shape, dtype=jnp.int32) diff --git a/examples/jax/mnist/test_single_gpu_mnist.py b/examples/jax/mnist/test_single_gpu_mnist.py index 110705d01..92baf4b0c 100644 --- a/examples/jax/mnist/test_single_gpu_mnist.py +++ b/examples/jax/mnist/test_single_gpu_mnist.py @@ -193,7 +193,9 @@ def train_and_evaluate(args): else: fp8_recipe = None - with te.fp8_autocast(enabled=args.use_fp8, fp8_recipe=fp8_recipe): + with te.fp8_autocast( + enabled=args.use_fp8, fp8_recipe=fp8_recipe, mesh_resource=te.sharding.MeshResource() + ): cnn = Net(args.use_te) var_collect = cnn.init(init_rngs, jnp.empty(input_shape, dtype=jnp.bfloat16)) tx = optax.sgd(args.lr, args.momentum) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 79186aa47..e3b1ecac9 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -173,7 +173,7 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): single_jitter = jax.jit( value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), @@ -330,7 +330,7 @@ def _test_layernorm_mlp( with use_jax_gemm(enabled=with_jax_gemm): # Single GPUs - with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe): + with fp8_autocast(enabled=use_fp8, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): ln_mlp_single = LayerNormMLP( layernorm_type=layernorm_type, intermediate_dim=INTERMEDIATE, diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index d59e13053..0d0dba547 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -28,6 +28,7 @@ is_fp8_available, update_collections, ) +from transformer_engine.jax.sharding import MeshResource, global_shard_guard @pytest.fixture(autouse=True, scope="function") @@ -490,19 +491,28 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" QuantizeConfig.finalize() # Ensure FP8 disabled. 
- self.runner(attrs).test_forward(data_shape, dtype) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" QuantizeConfig.finalize() # Ensure FP8 disabled. - self.runner(attrs).test_backward(data_shape, dtype) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @@ -510,7 +520,10 @@ def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) + with global_shard_guard( + MeshResource() + ): # Empty MeshResource is used as we are running on a single device + self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) QuantizeConfig.finalize() diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9975f558b..7dec4d757 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -34,6 +34,7 @@ is_fp8_gemm_with_all_layouts_supported, apply_padding_to_scale_inv, ) +from ..sharding import global_mesh_resource from .misc import get_padded_spec @@ -490,7 +491,8 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and "fsdp" in spec else spec for spec in rhs_non_cspecs + None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec + for spec in rhs_non_cspecs ) # Non-contracting dims of LHS to be gathered along the SP axis. 
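The gemm.py hunk above replaces a substring test ("fsdp" in spec) with an exact comparison against the FSDP axis registered in the global MeshResource, so FSDP gathering no longer depends on how the mesh axis happens to be spelled. A small illustration of the difference (the axis name "zero" is made up for the example):

    from transformer_engine.jax.sharding import (
        MeshResource,
        global_mesh_resource,
        global_shard_guard,
    )

    with global_shard_guard(MeshResource(fsdp_resource="zero")):
        spec = "zero"  # partition-spec entry on a non-contracting RHS dim
        # Old predicate: only fires for axis names containing the literal "fsdp".
        old_is_fsdp = spec is not None and "fsdp" in spec
        # New predicate: fires exactly for the configured FSDP mesh axis.
        new_is_fsdp = spec is not None and spec == global_mesh_resource().fsdp_resource
        assert (old_is_fsdp, new_is_fsdp) == (False, True)

This is also why the same change starts enforcing that a MeshResource is installed (via global_shard_guard or fp8_autocast's mesh_resource argument) even on a single device.
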
diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index 122265ea2..f8d18983e 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -404,9 +404,6 @@ def fp8_autocast( if fp8_recipe is None: fp8_recipe = recipe.DelayedScaling() - if mesh_resource is None: - mesh_resource = MeshResource() - Config = DelayedScalingQuantizeConfig if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): Config = BlockScalingQuantizeConfig diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 6d4894fd8..480989dcd 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -286,7 +286,7 @@ class MeshResource: cp_resource: str = None -_GLOBAL_MESH_RESOURCE = MeshResource() +_GLOBAL_MESH_RESOURCE = None @contextmanager @@ -314,6 +314,11 @@ def global_mesh_resource() -> MeshResource: Returns: The current MeshResource instance """ + assert _GLOBAL_MESH_RESOURCE is not None, ( + "Global mesh resource is not set. Please set the MeshResource via a global_shard_guard" + " context. If you are not using multiple GPUs, you can use an empty MeshResource by" + " wrapping your program in 'with global_shard_guard(MeshResource()):'" + ) return _GLOBAL_MESH_RESOURCE From 96944a81f68df34dc41cc5869825aa9a70e317fd Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 20 Aug 2025 11:03:10 -0700 Subject: [PATCH 087/153] [PyTorch] Avoid garbage collection when capturing a CUDA Graph (#2092) Avoid garbage collection when capturing a CUDA Graph Signed-off-by: Tim Moon --- transformer_engine/pytorch/graph.py | 29 +++++++++++++++++++++++++---- 1 file changed, 25 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index 866f0b639..eda18a185 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -4,6 +4,8 @@ """Functions for CUDA Graphs support in FP8""" from collections.abc import Iterable +import contextlib +import gc from typing import Any, Callable, Dict, List, Optional, Tuple, TypeVar, Union import torch @@ -58,6 +60,25 @@ def graph_pool_handle(): return _graph_pool_handle() +@contextlib.contextmanager +def _graph_context_wrapper(*args, **kwargs): + """Wrapper around `torch.cuda.graph`. + + This wrapper is a temporary workaround for a PyTorch bug: + automatic garbage collection can destroy a graph while another + graph is being captured, resulting in a CUDA error. See + https://github.com/pytorch/pytorch/pull/161037. 
+ + """ + gc_is_enabled = gc.isenabled() + if gc_is_enabled: + gc.disable() + with torch.cuda.graph(*args, **kwargs): + yield + if gc_is_enabled: + gc.enable() + + def _make_graphed_callables( callables: SingleOrTuple[Callable], sample_args: SingleOrTuple[Tuple[torch.Tensor, ...]], @@ -445,7 +466,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument args = sample_args[per_callable_fwd_idx] kwargs = sample_kwargs[per_callable_fwd_idx] fwd_graph = fwd_graphs[per_callable_fwd_idx] - with torch.cuda.graph(fwd_graph, pool=mempool): + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) flatten_outputs, spec = _tree_flatten(outputs) per_callable_static_outputs[per_callable_fwd_idx] = tuple(flatten_outputs) @@ -483,7 +504,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with torch.cuda.graph(bwd_graph, pool=mempool): + with _graph_context_wrapper(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), @@ -548,7 +569,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument per_callable_output_unflatten_spec = [] graph_id = 0 for func, args, kwargs, fwd_graph in zip(callables, sample_args, sample_kwargs, fwd_graphs): - with torch.cuda.graph(fwd_graph, pool=mempool): + with _graph_context_wrapper(fwd_graph, pool=mempool): outputs = func(*args, **kwargs) graph_callables[graph_id] = func graph_id += 1 @@ -570,7 +591,7 @@ def hook_fn(module, inputs, outputs): # pylint: disable=unused-argument torch.empty_like(o) if o.requires_grad else None for o in static_outputs ) if is_training: - with torch.cuda.graph(bwd_graph, pool=mempool): + with _graph_context_wrapper(bwd_graph, pool=mempool): grad_inputs = torch.autograd.grad( outputs=tuple(o for o in static_outputs if o.requires_grad), inputs=tuple(i for i in static_input_surface if i.requires_grad), From 406e2c9d9ef93a20c5f658f1a05b62dab5fc0fd7 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:24:33 -0700 Subject: [PATCH 088/153] Fix incorrect version checks for atomic GEMM (#2095) * Fix incorrect version checks for atomic GEMM Signed-off-by: Tim Moon * Fix typo Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/gemm/cublaslt_gemm.cu | 28 ++++++++++--------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 1c4af23eb..d65cd7b55 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -517,22 +517,22 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, &epilogue, sizeof(epilogue))); if (counter != nullptr) { -#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) - NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", +#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); #endif #if !(CUBLAS_VERSION >= 120205 && 
CUBLAS_VERSION < 130000) NVTE_ERROR( - "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", CUBLAS_VERSION); #endif #if CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 120205 && CUDA_VERSION < 13000 && \ CUBLAS_VERSION < 130000 NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, - "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA verson is ", + "Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but run-time CUDA version is ", cuda::cudart_version()); NVTE_CHECK(cublas_version() >= 120205 && cublas_version() < 130000, - "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); if (m_split == 0) m_split = 1; if (n_split == 0) n_split = 1; @@ -658,20 +658,22 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor using namespace transformer_engine; // Check CUDA and cuBLAS versions -#if !(CUDA_VERSION >= 12020 && CUBLAS_VERSION >= 13000) - NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA verson is ", +#if !(CUDA_VERSION >= 12020 && CUDA_VERSION < 13000) + NVTE_ERROR("Atomic GEMM requires CUDA >=12.2.0 and <13.0.0, but compile-time CUDA version is ", CUDA_VERSION); #endif #if !(CUBLAS_VERSION >= 120205 && CUBLAS_VERSION < 130000) - NVTE_ERROR("Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS verson is ", - CUBLAS_VERSION); + NVTE_ERROR( + "Atomic GEMM requires cuBLAS >=12.2.5 and <13.0.0, but compile-time cuBLAS version is ", + CUBLAS_VERSION); #endif - NVTE_CHECK(cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, - "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA verson is ", - cuda::cudart_version()); + NVTE_CHECK( + cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", + cuda::cudart_version()); NVTE_CHECK( cublas_version() >= 120205 && cublas_version() < 130000, - "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS verson is ", + "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", cublas_version()); const Tensor *inputA = convertNVTETensorCheck(A); From f1b18ed040b9e474ee639dd67c28ef5764211938 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Wed, 20 Aug 2025 15:24:53 -0700 Subject: [PATCH 089/153] Update list of authorized CI users (#2081) Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 6d7410619..85a81a6d4 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -56,6 +56,7 @@ jobs: || github.actor == 'vcherepanov-nv' || github.actor == 'tdophung' || github.actor == 'vthumbe1503' + || github.actor == 'janekb04' ) steps: - name: Check if comment is issued by authorized person From 20be25a3d9606897f7c88d817cd301c29137d9bc Mon Sep 17 00:00:00 2001 From: Md Fahim Faysal Khan Date: Thu, 21 Aug 2025 06:55:55 -0700 Subject: [PATCH 090/153] [ TE-JAX ] Expose cp_strategy argument to DPA api (#2090) * added cp strategy arg to DPA api 
Signed-off-by: Md Fahim Faysal Khan * converted DPA cp_strategy to string Signed-off-by: Md Fahim Faysal Khan --------- Signed-off-by: Md Fahim Faysal Khan --- transformer_engine/jax/flax/transformer.py | 26 ++++++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index d85593c1e..fb3ac7b9a 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -26,6 +26,7 @@ from ..attention import AttnBiasType, AttnMaskType, QKVLayout, SequenceDescriptor from ..attention import is_fused_attn_kernel_available, make_swa_mask, canonicalize_attn_mask_type from ..attention import fused_attn +from ..attention import CPStrategy from ..softmax import SoftmaxType from ..sharding import num_of_devices from ..sharding import get_sharding_map_logic_axis_to_mesh_axis @@ -274,6 +275,7 @@ class _FusedDotProductAttention(nn.Module): # pylint: disable=too-few-public-me max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT context_checkpoint_name: str = "context" @nn.compact @@ -323,6 +325,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_kvpacked(): @@ -350,6 +353,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, ) elif self.qkv_layout.is_separate(): @@ -372,6 +376,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=self.context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, ) else: @@ -505,6 +510,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods context_parallel_causal_load_balanced (bool): Indicates the sequences are ordered for causal mask load balancing when running context parallelism. context_parallel_axis (str): The name of the context parallel axis. + context_parallel_strategy (CPStrategy): The strategy of context parallel. 0: DEFAULT, 1: ALL_GATHER, 2: RING. context_checkpoint_name (str): The name of the context checkpoint in the forward pass of fused attention. 
Optimization parameters @@ -529,6 +535,7 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-methods max_segments_per_seq: Optional[int] = 1 context_parallel_causal_load_balanced: bool = False context_parallel_axis: str = "" + context_parallel_strategy: str = "DEFAULT" context_checkpoint_name: str = "context" @nn.compact @@ -648,6 +655,24 @@ def __call__( scale_factor = self.scale_factor del self.scale_factor + # case-insensitive mapping for context parallel strategy + cp_strategy_map = { + "DEFAULT": CPStrategy.DEFAULT, + "ALL_GATHER": CPStrategy.ALL_GATHER, + "ALLGATHER": CPStrategy.ALL_GATHER, # Alternative spelling + "RING": CPStrategy.RING, + } + + strategy_key = self.context_parallel_strategy.upper() + if strategy_key in cp_strategy_map: + context_parallel_strategy = cp_strategy_map[strategy_key] + else: + valid_strategies = list(cp_strategy_map.keys()) + raise ValueError( + f"Invalid context parallel strategy: {self.context_parallel_strategy}. " + f"Valid options are: {valid_strategies} (case insensitive)" + ) + if not use_fused_attn: # unfused attention only supports splitted query, key, value if qkv_layout.is_qkvpacked(): @@ -696,6 +721,7 @@ def __call__( max_segments_per_seq=self.max_segments_per_seq, context_parallel_causal_load_balanced=self.context_parallel_causal_load_balanced, context_parallel_axis=self.context_parallel_axis, + context_parallel_strategy=context_parallel_strategy, context_checkpoint_name=self.context_checkpoint_name, )( query, From 40dde4dd57c3bbe45a9fa439ecefe66004bd8565 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Fri, 22 Aug 2025 11:40:29 -0400 Subject: [PATCH 091/153] Update NGC version to 25.08 (#2085) update NGC version Signed-off-by: Phuong Nguyen --- README.rst | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/README.rst b/README.rst index cfd5c687e..19ab1a7d9 100644 --- a/README.rst +++ b/README.rst @@ -176,15 +176,15 @@ For example to use the NGC PyTorch container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.04-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3 For example to use the NGC JAX container interactively, .. code-block:: bash - docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.04-py3 + docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.08-py3 -Where 25.04 (corresponding to April 2025 release) is the container version. +Where 25.08 (corresponding to August 2025 release) is the container version. 
**Benefits of using NGC containers:** From d88137c40618f62996b5538ea4d182481d9b7c4f Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Fri, 22 Aug 2025 17:04:36 -0700 Subject: [PATCH 092/153] [PyTorch] Debug Mcore wgrad fusion with te.ops (#2097) * Return dummy wgrad tensors when requested by Mcore Signed-off-by: Tim Moon * Apply suggestions from code review Co-authored-by: Jan Bielak Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Jan Bielak --- .../pytorch/ops/basic/basic_linear.py | 41 ++++++++++++++----- .../pytorch/ops/fused/backward_linear_add.py | 38 ++++++++++------- .../ops/fused/backward_linear_scale.py | 38 ++++++++++------- .../ops/fused/userbuffers_backward_linear.py | 34 ++++++++++----- transformer_engine/pytorch/ops/linear.py | 3 +- 5 files changed, 104 insertions(+), 50 deletions(-) diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 877596824..833633055 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -12,7 +12,6 @@ import torch -from transformer_engine.pytorch.module.base import get_workspace from ...cpp_extensions import general_gemm from ...distributed import ( CudaRNGStatesTracker, @@ -20,18 +19,24 @@ reduce_scatter_along_first_dim, ) from ...fp8 import FP8GlobalStateManager, Recipe -from ...module.base import _2X_ACC_FPROP, _2X_ACC_DGRAD, _2X_ACC_WGRAD +from ...module.base import ( + _2X_ACC_FPROP, + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, + get_dummy_wgrad, + get_workspace, +) from ...tensor import Quantizer from ...tensor.float8_tensor import Float8Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase -from ..op import BasicOperation, OperationContext -from .._common import maybe_dequantize, is_quantized_tensor from ...utils import ( canonicalize_device, canonicalize_dtype, clear_tensor_data, devices_match, ) +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize, is_quantized_tensor def _wait_async(handle: Optional[Any]) -> None: @@ -73,7 +78,8 @@ class BasicLinear(BasicOperation): weight's `main_grad` attribute instead of relying on PyTorch autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be - meaningful. + meaningful. This is primarily intented to integrate with + Megatron-LM. userbuffers_options, dict, optional Options for overlapping tensor-parallel communication with compute using Userbuffers. This feature is highly @@ -979,20 +985,22 @@ def op_backward( # Saved tensors from forward pass (x_local, w) = ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. 
accumulate_into_main_grad = self._accumulate_into_main_grad grad_weight = None if ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(self.weight, "__fsdp_param__"): - self.weight.main_grad = self.weight.get_main_grad() - - if not hasattr(self.weight, "main_grad"): + weight_param = self.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = self.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -1019,6 +1027,17 @@ def op_backward( # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: grad_weight = None + weight_param = self.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [grad_weight] diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_add.py b/transformer_engine/pytorch/ops/fused/backward_linear_add.py index 8af46a27c..845ba262a 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_add.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_add.py @@ -9,13 +9,10 @@ import torch -from transformer_engine.pytorch.ops.basic import BasicLinear, MakeExtraOutput -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...module.base import get_dummy_wgrad from ...utils import clear_tensor_data +from ..basic import BasicLinear, MakeExtraOutput +from ..op import FusedOperation, FusibleOperation, OperationContext class BackwardLinearAdd(FusedOperation): @@ -53,20 +50,22 @@ def fuser_backward( # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -92,12 +91,23 @@ def fuser_backward( grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) - if accumulate_into_main_grad: - grad_weight = None # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
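The repeated "Megatron-LM wgrad fusion" notes in these hunks assume the following external contract: the framework allocates a main_grad buffer on the weight, the wgrad GEMM accumulates straight into it, and autograd is handed a dummy tensor of the right shape plus a grad_added_to_main_grad flag instead of a real weight gradient. A rough sketch of that contract, where everything outside the TE op itself is hypothetical framework code:

    import torch
    import transformer_engine.pytorch.ops as te_ops

    op = te_ops.BasicLinear(1024, 1024, accumulate_into_main_grad=True)
    weight = op.weight
    # Normally owned and zeroed by the Megatron-LM-style distributed optimizer.
    weight.main_grad = torch.zeros_like(weight, dtype=torch.float32)
    weight.grad_added_to_main_grad = False

    x = torch.randn(32, 1024, device=weight.device, dtype=weight.dtype)
    op(x).sum().backward()

    # The wgrad was accumulated into weight.main_grad; weight.grad only holds a
    # dummy tensor returned by get_dummy_wgrad, and the flag tells the framework
    # to ignore it.
    assert weight.grad_added_to_main_grad
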
+ if accumulate_into_main_grad: + grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [(grad_weight,), ()], [(), ()] diff --git a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py index 630a63157..a9595d516 100644 --- a/transformer_engine/pytorch/ops/fused/backward_linear_scale.py +++ b/transformer_engine/pytorch/ops/fused/backward_linear_scale.py @@ -9,13 +9,10 @@ import torch -from ..basic import BasicLinear, ConstantScale -from ..op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...module.base import get_dummy_wgrad from ...utils import clear_tensor_data +from ..basic import BasicLinear, ConstantScale +from ..op import FusedOperation, FusibleOperation, OperationContext class BackwardLinearScale(FusedOperation): @@ -54,20 +51,22 @@ def fuser_backward( # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -92,12 +91,23 @@ def fuser_backward( grad_output_quantizer=linear_op_ctx.grad_output_quantizer, grad_input_quantizer=linear_op_ctx.grad_input_quantizer, ) - if accumulate_into_main_grad: - grad_weight = None # Clear input tensor if possible clear_tensor_data(x_local) + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. 
+ if accumulate_into_main_grad: + grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + return grad_input, [(), (grad_weight,)], [(), ()] diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index 54a4d49db..c59532521 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -14,11 +14,12 @@ from ...cpp_extensions import general_gemm from ...distributed import get_distributed_world_size from ...module.base import ( + _2X_ACC_DGRAD, + _2X_ACC_WGRAD, fill_userbuffers_buffer_for_all_gather, + get_dummy_wgrad, get_ub, get_workspace, - _2X_ACC_DGRAD, - _2X_ACC_WGRAD, ) from ...tensor.quantized_tensor import Quantizer from ...tensor.mxfp8_tensor import MXFP8Quantizer @@ -513,20 +514,22 @@ def fuser_backward( # Saved tensors from forward pass (x_local, w) = linear_op_ctx.saved_tensors - # wgrad fusion + # Megatron-LM wgrad fusion + # Note: Get grad tensor from param so we can accumulate + # directly into it. accumulate_into_main_grad = linear_op._accumulate_into_main_grad grad_weight = None if linear_op_ctx.weight_requires_grad and accumulate_into_main_grad: - if hasattr(linear_op.weight, "__fsdp_param__"): - linear_op.weight.main_grad = linear_op.weight.get_main_grad() - - if not hasattr(linear_op.weight, "main_grad"): + weight_param = linear_op.weight + if hasattr(weight_param, "__fsdp_param__"): + weight_param.main_grad = weight_param.get_main_grad() + if not hasattr(weight_param, "main_grad"): raise RuntimeError( "BasicLinear op is configured with " "accumulate_into_main_grad=True, " "but weight parameter does not have main_grad attribute" ) - grad_weight = linear_op.weight.main_grad.detach() + grad_weight = weight_param.main_grad.detach() else: accumulate_into_main_grad = False @@ -558,10 +561,21 @@ def fuser_backward( # Clear input tensor if possible clear_tensor_data(x_local) - # Return gradients - grad_params = [() for _ in range(len(self.basic_ops))] + # Megatron-LM wgrad fusion + # Note: Return dummy tensor for grad weight if needed. if accumulate_into_main_grad: grad_weight = None + weight_param = linear_op.weight + if hasattr(weight_param, "grad_added_to_main_grad"): + weight_param.grad_added_to_main_grad = True + grad_weight = get_dummy_wgrad( + list(weight_param.size()), + weight_param.dtype, + zero=getattr(weight_param, "zero_out_wgrad", False), + ) + + # Return gradients + grad_params = [() for _ in range(len(self.basic_ops))] grad_params[self._op_idxs["linear"]] = (grad_weight,) if bias_op is not None: grad_params[self._op_idxs["bias"]] = (grad_bias,) diff --git a/transformer_engine/pytorch/ops/linear.py b/transformer_engine/pytorch/ops/linear.py index 8686c1853..325126a3d 100644 --- a/transformer_engine/pytorch/ops/linear.py +++ b/transformer_engine/pytorch/ops/linear.py @@ -54,7 +54,8 @@ class Linear(FusedOperation): weight's `main_grad` attribute instead of relying on PyTorch autograd. The weight's `main_grad` must be set externally and there is no guarantee that `grad` will be set or be - meaningful. + meaningful. This is primarily intented to integrate with + Megatron-LM. 
""" From 78e097f17df26cda6b78152d4d1ad737b3e2e36b Mon Sep 17 00:00:00 2001 From: Ace Eldeib Date: Sun, 24 Aug 2025 23:11:39 -0400 Subject: [PATCH 093/153] [Jax] Fix narrowing conversions (#2094) Signed-off-by: Ace Eldeib Co-authored-by: Xin Yao --- .../jax/csrc/extensions/activation.cpp | 16 ++++++++-------- .../jax/csrc/extensions/normalization.cpp | 4 ++-- 2 files changed, 10 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/activation.cpp b/transformer_engine/jax/csrc/extensions/activation.cpp index cf75c850b..17fa9906b 100644 --- a/transformer_engine/jax/csrc/extensions/activation.cpp +++ b/transformer_engine/jax/csrc/extensions/activation.cpp @@ -37,9 +37,9 @@ Error_Type ActLuFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scal auto is_2x = static_cast(is_2x_int); auto flatten_axis = output_buf->dimensions().size() - 1; // output does not have act axis - auto input_shape = std::vector{m, act_len * n}; - auto output_shape = std::vector{m, n}; - auto output_trans_shape = std::vector{n, m}; + auto input_shape = std::vector{m, static_cast(act_len * n)}; + auto output_shape = std::vector{m, static_cast(n)}; + auto output_trans_shape = std::vector{static_cast(n), m}; auto input_tensor = TensorWrapper(input, input_shape, static_cast(in_dtype)); auto output_tensor = TensorWrapper(get_nvte_scaling_mode(scaling_mode)); output_tensor.set_rowwise_data(output, static_cast(out_dtype), output_shape); @@ -253,11 +253,11 @@ Error_Type DActLuDBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, auto m = product(act_input_dims, 0, act_input_dims.size() - 2); auto n = input_dims.back(); - auto input_shape = std::vector{m, n}; - auto act_input_shape = std::vector{m, n * act_len}; - auto output_shape = std::vector{m, n * act_len}; - auto output_trans_shape = std::vector{n * act_len, m}; - auto dbias_shape = std::vector{n * act_len}; + auto input_shape = std::vector{m, static_cast(n)}; + auto act_input_shape = std::vector{m, static_cast(n * act_len)}; + auto output_shape = std::vector{m, static_cast(n * act_len)}; + auto output_trans_shape = std::vector{static_cast(n * act_len), m}; + auto dbias_shape = std::vector{static_cast(n * act_len)}; std::vector workspace_shape(workspace_dims.begin(), workspace_dims.end()); auto input_tensor = diff --git a/transformer_engine/jax/csrc/extensions/normalization.cpp b/transformer_engine/jax/csrc/extensions/normalization.cpp index b07404eb7..c35bc6668 100644 --- a/transformer_engine/jax/csrc/extensions/normalization.cpp +++ b/transformer_engine/jax/csrc/extensions/normalization.cpp @@ -118,7 +118,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc convert_ffi_datatype_to_te_dtype(scale_inv_buf->element_type()), std::vector{ product(scale_inv_buf->dimensions(), 0, scale_inv_buf->dimensions().size() - 1), - scale_inv_buf->dimensions().back()}); + static_cast(scale_inv_buf->dimensions().back())}); } if (scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING && is_fp8_dtype(out_dtype)) { @@ -135,7 +135,7 @@ Error_Type NormForwardFFI(cudaStream_t stream, Buffer_Type x_buf, Buffer_Type sc convert_ffi_datatype_to_te_dtype(colwise_scale_inv_buf->element_type()), std::vector{product(colwise_scale_inv_buf->dimensions(), 0, colwise_scale_inv_buf->dimensions().size() - 1), - colwise_scale_inv_buf->dimensions().back()}); + static_cast(colwise_scale_inv_buf->dimensions().back())}); } if (_norm_type == NVTE_Norm_Type::LayerNorm) { From 2e23ad7127217728b55f6f11a1ed12638ee20a8a Mon Sep 17 00:00:00 
2001 From: Phuong Nguyen Date: Mon, 25 Aug 2025 08:59:30 -0400 Subject: [PATCH 094/153] [JAX] Add Shardy warning in GEMM custom call (#2101) * added shardy warning Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/gemm.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 7dec4d757..188b37601 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -8,6 +8,7 @@ from collections.abc import Iterable from typing import Tuple, Sequence, Union from functools import partial, reduce +import warnings import jax import jax.numpy as jnp @@ -658,6 +659,12 @@ def shardy_sharding_rule( prefix = "GemmPrimitive_" + warnings.warn( + "Known issues with TE GemmPrimitives when Shardy propagation is enabled. For now," + " please turn off Shardy by exporting the environment variable" + " 'JAX_USE_SHARDY_PARTITIONER=0' if you experience any problems." + ) + def _generate_operand_rules(name, ndim, cdims): specs = [] ldims = tuple(i for i in range(ndim) if i not in cdims) From 47ab4a743e2d76d7835d0c0e6f9a725e658885b1 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Mon, 25 Aug 2025 06:00:52 -0700 Subject: [PATCH 095/153] [JAX] Add Transformer Layer tests for pre_scale_bias and post_scale_bias (#2104) Add Transformer Layer tests for pre_scale_bias and post_scale_bias Signed-off-by: Kshitij Lakhani --- tests/jax/test_layer.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 0d0dba547..8fe7ebae3 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -263,6 +263,16 @@ def enable_fused_attn(): _KEY_OF_RELATIVE_EMBEDDING: False, _KEY_OF_WINDOW_SIZE: (2, 2), }, + # attrs29 + { + _KEY_OF_RELATIVE_EMBEDDING: True, + _KEY_OF_SELF_ATTN_BIAS_TYPE: "pre_scale_bias", + }, + # attrs30 + { + _KEY_OF_RELATIVE_EMBEDDING: True, + _KEY_OF_SELF_ATTN_BIAS_TYPE: "post_scale_bias", + }, ] ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS] From ccc1abf94bd3b4c9ac6639b0b41b07532f31f3be Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 25 Aug 2025 18:30:08 -0700 Subject: [PATCH 096/153] [Pytorch] Fix `UnboundLocalError` during build (#2116) Fix UnboundLocalError Signed-off-by: Kirthi Shankar Sivamani --- build_tools/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/build_tools/utils.py b/build_tools/utils.py index 0dc5e3689..23fb56598 100644 --- a/build_tools/utils.py +++ b/build_tools/utils.py @@ -13,7 +13,7 @@ import subprocess import sys from pathlib import Path -from importlib.metadata import version +from importlib.metadata import version as get_version from subprocess import CalledProcessError from typing import List, Optional, Tuple, Union @@ -269,7 +269,7 @@ def cuda_version() -> Tuple[int, ...]: return tuple(int(v) for v in version) try: - version_str = version("nvidia-cuda-runtime-cu12") + version_str = get_version("nvidia-cuda-runtime-cu12") version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit()) return version_tuple except importlib.metadata.PackageNotFoundError: From 07db17b5a968a1651832ebd856c3091710fec75f Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Tue, 26 Aug 2025 09:31:45 +0800 Subject: [PATCH 097/153] [PyTorch] Expose more activation functions (#2106) expose more activation functions 
Signed-off-by: Xin Yao --- tests/pytorch/test_fusible_ops.py | 30 +++- tests/pytorch/test_numerics.py | 24 ++- tests/pytorch/test_sanity.py | 13 +- transformer_engine/pytorch/csrc/extensions.h | 37 +++-- .../pytorch/csrc/extensions/activation.cpp | 62 ++++--- .../pytorch/csrc/extensions/pybind.cpp | 35 ++-- .../pytorch/module/layernorm_mlp.py | 44 +++-- .../pytorch/ops/basic/__init__.py | 2 +- .../pytorch/ops/basic/activation.py | 152 ++++++++++++++++-- transformer_engine/pytorch/transformer.py | 3 +- 10 files changed, 314 insertions(+), 88 deletions(-) diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 1b9d9acbf..9325f5d1e 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1532,7 +1532,10 @@ def test_make_extra_output( torch.testing.assert_close(y2_test, y2_ref, rtol=0, atol=0) torch.testing.assert_close(dx_test, x_ref.grad, **tols) - @pytest.mark.parametrize("activation", ("relu", "gelu", "geglu", "reglu", "swiglu")) + @pytest.mark.parametrize( + "activation", + ("gelu", "geglu", "qgelu", "qgeglu", "relu", "reglu", "srelu", "sreglu", "silu", "swiglu"), + ) @pytest.mark.parametrize("out_shape", ((37,), (2, 13), (32, 1, 32))) @pytest.mark.parametrize("dtype", _dtypes) @pytest.mark.parametrize("quantization", _quantization_list) @@ -1551,7 +1554,7 @@ def test_activation( # Tensor dimensions in_shape = list(out_shape) - if activation in ("geglu", "reglu", "swiglu"): + if activation in ("geglu", "qgeglu", "reglu", "sreglu", "swiglu"): in_shape[-1] *= 2 # Skip invalid configurations @@ -1578,14 +1581,26 @@ def test_activation( y_ref: torch.Tensor if activation == "gelu": y_ref = torch.nn.functional.gelu(x_ref, approximate="tanh") - elif activation == "relu": - y_ref = torch.nn.functional.relu(x_ref) elif activation == "geglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.gelu(x1, approximate="tanh") * x2 + elif activation == "qgelu": + y_ref = x_ref * torch.sigmoid(1.702 * x_ref) + elif activation == "qgeglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = x1 * torch.sigmoid(1.702 * x1) * x2 + elif activation == "relu": + y_ref = torch.nn.functional.relu(x_ref) elif activation == "reglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.relu(x1) * x2 + elif activation == "srelu": + y_ref = torch.nn.functional.relu(x_ref) ** 2 + elif activation == "sreglu": + x1, x2 = x_ref.chunk(2, dim=-1) + y_ref = torch.nn.functional.relu(x1) ** 2 * x2 + elif activation == "silu": + y_ref = torch.nn.functional.silu(x_ref) elif activation == "swiglu": x1, x2 = x_ref.chunk(2, dim=-1) y_ref = torch.nn.functional.silu(x1) * x2 @@ -1597,9 +1612,14 @@ def test_activation( recipe = make_recipe(quantization) make_op = dict( gelu=te_ops.GELU, - relu=te_ops.ReLU, geglu=te_ops.GEGLU, + qgelu=te_ops.QGELU, + qgeglu=te_ops.QGEGLU, + relu=te_ops.ReLU, reglu=te_ops.ReGLU, + srelu=te_ops.SReLU, + sreglu=te_ops.SReGLU, + silu=te_ops.SiLU, swiglu=te_ops.SwiGLU, )[activation] forward = te_ops.Sequential( diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 543f5f08d..b76f3d2b2 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -79,7 +79,18 @@ all_boolean = [True, False] -all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "qgelu", "srelu"] +all_activations = [ + "gelu", + "geglu", + "qgelu", + "qgeglu", + "relu", + "reglu", + "srelu", + "sreglu", + "silu", + "swiglu", +] all_normalizations = ["LayerNorm", "RMSNorm"] @@ -427,13 +438,16 @@ def forward(self, 
inp: torch.Tensor, m_splits: List[int]) -> torch.Tensor: _supported_act = { - "geglu": nn.GELU(approximate="tanh"), "gelu": nn.GELU(approximate="tanh"), - "reglu": nn.ReLU(), - "relu": nn.ReLU(), - "swiglu": nn.SiLU(), + "geglu": nn.GELU(approximate="tanh"), "qgelu": TorchQuickGELU(), + "qgeglu": TorchQuickGELU(), + "relu": nn.ReLU(), + "reglu": nn.ReLU(), "srelu": TorchSquaredRELU(), + "sreglu": TorchSquaredRELU(), + "silu": nn.SiLU(), + "swiglu": nn.SiLU(), } diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 5f61772d9..5151aa96e 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -104,7 +104,18 @@ def is_fp8_supported(config: ModelConfig): all_boolean = [True, False] batch_sizes_with_zero = [0, 1, 2] -all_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "srelu", "qgelu", "qgeglu"] +all_activations = [ + "gelu", + "geglu", + "qgelu", + "qgeglu", + "relu", + "reglu", + "srelu", + "sreglu", + "silu", + "swiglu", +] all_normalizations = ["LayerNorm", "RMSNorm"] diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index 9df220ba7..d0e92a59b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -154,38 +154,49 @@ at::Tensor swap_first_dims(at::Tensor tensor, std::optional out = st * Activations **************************************************************************************************/ +/* GELU and variants*/ py::object gelu(const at::Tensor &input, py::handle quantizer); -py::object relu(const at::Tensor &input, py::handle quantizer); +py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); py::object geglu(const at::Tensor &input, py::handle quantizer); -py::object qgeglu(const at::Tensor &input, py::handle quantizer); +py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object reglu(const at::Tensor &input, py::handle quantizer); +py::object qgelu(const at::Tensor &input, py::handle quantizer); -py::object swiglu(const at::Tensor &input, py::handle quantizer); +py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object qgelu(const at::Tensor &input, py::handle quantizer); +py::object qgeglu(const at::Tensor &input, py::handle quantizer); -py::object srelu(const at::Tensor &input, py::handle quantizer); +py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object dgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +/* ReLU and variants*/ +py::object relu(const at::Tensor &input, py::handle quantizer); py::object drelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object dgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object dqgeglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object reglu(const at::Tensor &input, py::handle quantizer); py::object dreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); -py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); - -py::object dqgelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object srelu(const at::Tensor &input, py::handle quantizer); py::object dsrelu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); +py::object sreglu(const at::Tensor &input, py::handle quantizer); + 
+py::object dsreglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +/* Silu and variants*/ +py::object silu(const at::Tensor &input, py::handle quantizer); + +py::object dsilu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + +py::object swiglu(const at::Tensor &input, py::handle quantizer); + +py::object dswiglu(const at::Tensor &grad, const at::Tensor &input, py::handle quantizer); + /*************************************************************************************************** * LayerNorm **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2ef7a869a..7851cc5ff 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -101,6 +101,7 @@ py::object dactivation_helper(const at::Tensor& grad_output, const at::Tensor& i return grad_input_py; } +/* GELU and variants*/ py::object gelu(const at::Tensor& input, py::handle quantizer) { return activation_helper(input, quantizer); } @@ -109,30 +110,39 @@ py::object dgelu(const at::Tensor& grad, const at::Tensor& input, py::handle qua return dactivation_helper(grad, input, quantizer); } -py::object relu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object geglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object geglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object qgelu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -py::object qgeglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object dgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object qgeglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } py::object dqgeglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { return dactivation_helper(grad, input, quantizer); } +/* ReLU and variants*/ +py::object relu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); +} + +py::object drelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + py::object reglu(const at::Tensor& input, py::handle quantizer) { return activation_helper(input, quantizer, 2); } @@ -141,28 +151,36 @@ py::object dreglu(const at::Tensor& grad, const at::Tensor& input, py::handle qu return dactivation_helper(grad, input, quantizer); } -py::object swiglu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer, 2); +py::object srelu(const at::Tensor& input, py::handle quantizer) { + return 
activation_helper(input, quantizer); } -py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object qgelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +py::object sreglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } -py::object dqgelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsreglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); } -py::object srelu(const at::Tensor& input, py::handle quantizer) { - return activation_helper(input, quantizer); +/* Silu and variants*/ +py::object silu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer); } -py::object dsrelu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { - return dactivation_helper(grad, input, quantizer); +py::object dsilu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} + +py::object swiglu(const at::Tensor& input, py::handle quantizer) { + return activation_helper(input, quantizer, 2); } +py::object dswiglu(const at::Tensor& grad, const at::Tensor& input, py::handle quantizer) { + return dactivation_helper(grad, input, quantizer); +} } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 235516d0c..6442b05da 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -113,38 +113,53 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("comm_overlap") = nullptr, py::arg("comm_type") = std::nullopt, py::arg("extra_output") = std::nullopt, py::arg("bulk_overlap") = false, py::arg("alpha") = 1.0f, py::arg("beta") = std::nullopt); + /* GELU and variants*/ m.def("gelu", transformer_engine::pytorch::gelu, "GeLU activation", py::arg("input"), py::arg("quantizer")); - m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), - py::arg("quantizer")); m.def("geglu", transformer_engine::pytorch::geglu, "GeGLU activation", py::arg("input"), py::arg("quantizer")); + m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + py::arg("quantizer")); m.def("qgeglu", transformer_engine::pytorch::qgeglu, "QuickGeGLU activation", py::arg("input"), py::arg("quantizer")); + /* ReLU and variants */ + m.def("relu", transformer_engine::pytorch::relu, "ReLU activation", py::arg("input"), + py::arg("quantizer")); m.def("reglu", transformer_engine::pytorch::reglu, "ReGLU activation", py::arg("input"), py::arg("quantizer")); - m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), + m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), py::arg("quantizer")); - m.def("qgelu", transformer_engine::pytorch::qgelu, "QuickGELU activation", py::arg("input"), + m.def("sreglu", transformer_engine::pytorch::sreglu, "Squared ReGLU activation", py::arg("input"), py::arg("quantizer")); - 
m.def("srelu", transformer_engine::pytorch::srelu, "Squared ReLU activation", py::arg("input"), + /* SwiGLU and variants */ + m.def("silu", transformer_engine::pytorch::silu, "SiLU activation", py::arg("input"), + py::arg("quantizer")); + m.def("swiglu", transformer_engine::pytorch::swiglu, "SwiGLU activation", py::arg("input"), py::arg("quantizer")); + /* Backward of GELU and variants */ m.def("dgelu", transformer_engine::pytorch::dgelu, "Backward of GeLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), - py::arg("fwd_input"), py::arg("quantizer")); m.def("dgeglu", transformer_engine::pytorch::dgeglu, "Backward of GeGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); m.def("dqgeglu", transformer_engine::pytorch::dqgeglu, "Backward of QuickGeGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + /* Backward of ReLU and variants */ + m.def("drelu", transformer_engine::pytorch::drelu, "Backward of ReLU", py::arg("grad"), + py::arg("fwd_input"), py::arg("quantizer")); m.def("dreglu", transformer_engine::pytorch::dreglu, "Backward of ReGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), + m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dqgelu", transformer_engine::pytorch::dqgelu, "Backward of QuickGELU", py::arg("grad"), + m.def("dsreglu", transformer_engine::pytorch::dsreglu, "Backward of Squared ReGLU", + py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + /* Backward of SiLU and variants */ + m.def("dsilu", transformer_engine::pytorch::dsilu, "Backward of SiLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); - m.def("dsrelu", transformer_engine::pytorch::dsrelu, "Backward of Squared ReLU", py::arg("grad"), + m.def("dswiglu", transformer_engine::pytorch::dswiglu, "Backward of SwiGLU", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); + /* DBias + DAct fusions*/ m.def("dbias_dgelu", transformer_engine::pytorch::dbias_dgelu, "DGeLU + DBias + Quantize", py::arg("grad"), py::arg("fwd_input"), py::arg("quantizer")); m.def("dbias_dsilu", transformer_engine::pytorch::dbias_dsilu, "DSiLU + DBias + Quantize", diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index c384dc3a7..2e51ac948 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -87,39 +87,45 @@ def _get_act_func_supported_list(recipe: Optional[Recipe] = None): # bf16 (recipe is None): return { "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "relu": (tex.relu, tex.drelu, None), + "reglu": (tex.reglu, tex.dreglu, None), "srelu": (tex.srelu, tex.dsrelu, None), + "sreglu": (tex.sreglu, tex.dsreglu, None), + "silu": (tex.silu, tex.dsilu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), } if recipe.delayed() or recipe.mxfp8(): # Delayed scaling, 
fusion supported list: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] # MXFP8: [tex.dbias_dgelu, tex.dbias_drelu, tex.dbias_dqgelu, tex.dbias_dsrelu] return { "gelu": (tex.gelu, tex.dgelu, tex.dbias_dgelu), - "relu": (tex.relu, tex.drelu, tex.dbias_drelu), "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), "qgelu": (tex.qgelu, tex.dqgelu, tex.dbias_dqgelu), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "relu": (tex.relu, tex.drelu, tex.dbias_drelu), + "reglu": (tex.reglu, tex.dreglu, None), "srelu": (tex.srelu, tex.dsrelu, tex.dbias_dsrelu), + "sreglu": (tex.sreglu, tex.dsreglu, None), + "silu": (tex.silu, tex.dsilu, tex.dbias_dsilu), + "swiglu": (tex.swiglu, tex.dswiglu, None), } # no activation fusion written yet # Per-tensor current scaling or fp8 blockwise scaling: [] if recipe.float8_current_scaling() or recipe.float8_block_scaling(): return { "gelu": (tex.gelu, tex.dgelu, None), - "relu": (tex.relu, tex.drelu, None), "geglu": (tex.geglu, tex.dgeglu, None), - "reglu": (tex.reglu, tex.dreglu, None), - "swiglu": (tex.swiglu, tex.dswiglu, None), "qgelu": (tex.qgelu, tex.dqgelu, None), "qgeglu": (tex.qgeglu, tex.dqgeglu, None), + "relu": (tex.relu, tex.drelu, None), + "reglu": (tex.reglu, tex.dreglu, None), "srelu": (tex.srelu, tex.dsrelu, None), + "sreglu": (tex.sreglu, tex.dsreglu, None), + "silu": (tex.silu, tex.dsilu, None), + "swiglu": (tex.swiglu, tex.dswiglu, None), } raise NotImplementedError(f"Unhandled recipe type {recipe}") @@ -1375,7 +1381,7 @@ def fc1_wgrad_gemm( class LayerNormMLP(TransformerEngineBaseModule): r""" Applies layer normalization on the input followed by the MLP module, consisting of - 2 successive linear transformations, separated by the GeLU activation. + 2 successive linear transformations, separated by the activation function. Parameters ---------- @@ -1391,7 +1397,8 @@ class LayerNormMLP(TransformerEngineBaseModule): type of normalization applied. activation : str, default = 'gelu' activation function used. - Options: 'gelu', 'geglu', 'relu', 'reglu', 'squared_relu', 'swiglu', 'qgelu', 'srelu'. + Options: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', + 'silu', and 'swiglu'. init_method : Callable, default = `None` used for initializing FC1 weights in the following way: `init_method(weight)`. When set to `None`, defaults to `torch.nn.init.normal_(mean=0.0, std=0.023)`. 
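As a brief usage sketch of the expanded activation option documented above (the sizes and the choice of "sreglu" are hypothetical illustrations, not part of the patch; assumes the public transformer_engine.pytorch API with its default CUDA device):

    import torch
    import transformer_engine.pytorch as te

    # "sreglu" is one of the gated options, so FC1 internally produces twice
    # the intermediate width and the gating halves it again before FC2.
    mlp = te.LayerNormMLP(
        hidden_size=1024,
        ffn_hidden_size=4096,
        activation="sreglu",
    )

    x = torch.randn(16, 1024, device="cuda")
    y = mlp(x)  # output has the same shape as x: [16, 1024]
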
@@ -1592,7 +1599,7 @@ def __init__( self.layer_norm_bias = None # FC1 init - if self.activation in ["reglu", "geglu", "qgeglu", "swiglu"]: + if self.activation in ["geglu", "qgeglu", "reglu", "sreglu", "swiglu"]: fc1_output_features = 2 * self.size_per_partition else: fc1_output_features = self.size_per_partition @@ -1973,14 +1980,17 @@ def onnx_forward(self, inp: torch.Tensor) -> Union[torch.Tensor, Tuple[torch.Ten activation_map = { "gelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), - "relu": torch.nn.functional.relu, "geglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], - "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], - "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), "qgeglu": lambda x: torch.nn.functional.gelu(x.chunk(2, -1)[0], approximate="tanh") * x.chunk(2, -1)[1], - "qgelu": lambda x: torch.nn.functional.gelu(x, approximate="tanh"), - "srelu": torch.nn.functional.softplus, + "relu": torch.nn.functional.relu, + "reglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], + "srelu": lambda x: torch.nn.functional.relu(x) ** 2, + "sreglu": lambda x: torch.nn.functional.relu(x.chunk(2, -1)[0]) ** 2 + * x.chunk(2, -1)[1], + "silu": torch.nn.functional.silu, + "swiglu": lambda x: torch.nn.functional.silu(x.chunk(2, -1)[0]) * x.chunk(2, -1)[1], } if self.activation not in activation_map: raise ValueError(f"Unsupported activation in onnx export: {self.activation}") diff --git a/transformer_engine/pytorch/ops/basic/__init__.py b/transformer_engine/pytorch/ops/basic/__init__.py index 843bfc1bd..2c903675f 100644 --- a/transformer_engine/pytorch/ops/basic/__init__.py +++ b/transformer_engine/pytorch/ops/basic/__init__.py @@ -4,7 +4,7 @@ """Single tensor operations supported by the operation fuser.""" -from .activation import GELU, ReLU, GEGLU, ReGLU, SwiGLU +from .activation import GELU, GEGLU, QGELU, QGEGLU, ReLU, ReGLU, SReLU, SReGLU, SiLU, SwiGLU from .add_extra_input import AddExtraInput from .all_gather import AllGather from .all_reduce import AllReduce diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index f1b59170e..5ef421bc1 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -16,6 +16,19 @@ from ..op import BasicOperation, OperationContext from .._common import maybe_dequantize +__all__ = [ + "GELU", + "GEGLU", + "QGELU", + "QGEGLU", + "ReLU", + "ReGLU", + "SReLU", + "SReGLU", + "SiLU", + "SwiGLU", +] + class _ActivationOperation(BasicOperation, metaclass=abc.ABCMeta): r"""Apply activation function @@ -147,37 +160,75 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dgelu(*args, **kwargs) -class ReLU(_ActivationOperation): - r"""Rectified linear unit +class GEGLU(_ActivationOperation): + r"""Gaussian Error Gated Linear Unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: .. math:: - \text{ReLU}(x) = \max(x,0) + \text{GEGLU}(a,b) = \text{GELU}(a) * b + + where + + .. math:: + + \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. 
Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + See `GLU Variants Improve Transformer`__. """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.relu(*args, **kwargs) + return tex.geglu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.drelu(*args, **kwargs) + return tex.dgeglu(*args, **kwargs) -class GEGLU(_ActivationOperation): - r"""Gaussian error gated linear unit +class QGELU(_ActivationOperation): + r"""Quick Gaussian Error Linear Unit + + Quick GELU from `HuggingFace`__ + and `paper`__. + + .. math:: + + \text{QGELU}(x) \approx x * \sigma(1.702 * x) + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.qgelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dqgelu(*args, **kwargs) + + +class QGEGLU(_ActivationOperation): + r"""Quick Gaussian Error Gated Linear Unit The input tensor is split into chunks :math:`a` and :math:`b` along the last dimension and the following is computed: .. math:: - \text{GEGLU}(a,b) = \text{GELU}(a) * b + \text{QGEGLU}(a,b) = \text{QGELU}(a) * b where .. math:: - \text{GELU}(x) \approx \frac{x}{2} \left( 1 + \tanh\left( 0.797x+0.036 x^3 \right) \right) + \text{QGELU}(x) \approx x * \sigma(1.702 * x) .. warning:: @@ -187,19 +238,33 @@ class GEGLU(_ActivationOperation): the first half of the input tensor, while PyTorch applies it to the second half. - See `GLU Variants Improve Transformer`__. + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.qgeglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dqgeglu(*args, **kwargs) + + +class ReLU(_ActivationOperation): + r"""Rectified Linear Unit + + .. math:: + + \text{ReLU}(x) = \max(x,0) """ def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.geglu(*args, **kwargs) + return tex.relu(*args, **kwargs) def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: - return tex.dgeglu(*args, **kwargs) + return tex.drelu(*args, **kwargs) class ReGLU(_ActivationOperation): - r"""Rectified gated linear unit + r"""Rectified Gated Linear Unit The input tensor is split into chunks :math:`a` and :math:`b` along the last dimension and the following is computed: @@ -227,6 +292,67 @@ def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: return tex.dreglu(*args, **kwargs) +class SReLU(_ActivationOperation): + r"""Squared Rectified Linear Unit + + .. math:: + + \text{SReLU}(x) = \max(x^2,0) + + See `Primer: Searching for Efficient Transformers for Language Modeling`__. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.srelu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dsrelu(*args, **kwargs) + + +class SReGLU(_ActivationOperation): + r"""Squared Rectified Gated Linear Unit + + The input tensor is split into chunks :math:`a` and :math:`b` + along the last dimension and the following is computed: + + .. math:: + + \text{SReGLU}(a,b) = \max(a^2,0) * b + + .. warning:: + + Transformer Engine's gated activations and PyTorch's GLU + activation follow opposite conventions for :math:`a` and + :math:`b`. 
Transformer Engine applies the gating function to + the first half of the input tensor, while PyTorch applies it to + the second half. + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.sreglu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dsreglu(*args, **kwargs) + + +class SiLU(_ActivationOperation): + r"""Sigmoid Linear Unit + + .. math:: + + \text{SiLU}(x) = x \sigma(x) = \frac{x}{1+\exp(-x)} + + """ + + def _activation_forward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.silu(*args, **kwargs) + + def _activation_backward_impl(self, *args, **kwargs) -> torch.Tensor: + return tex.dsilu(*args, **kwargs) + + class SwiGLU(_ActivationOperation): r"""Swish gated linear unit diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index 1a98f2f52..89e43f845 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -175,7 +175,8 @@ class TransformerLayer(torch.nn.Module): if set to `False`, the transformer layer will not learn any additive biases. activation : str, default = 'gelu' Type of activation used in MLP block. - Options are: 'gelu', 'relu', 'reglu', 'geglu', 'swiglu', 'qgelu' and 'srelu'. + Options are: 'gelu', 'geglu', 'qgelu', 'qgeglu', 'relu', 'reglu', 'srelu', 'sreglu', + 'silu', and 'swiglu'. device : Union[torch.device, str], default = "cuda" The device on which the parameters of the model will be allocated. It is the user's responsibility to ensure all parameters are moved to the GPU before running the From 3d0ea80a77d35eb07647c60aa3a3b54ce67c82c8 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 26 Aug 2025 09:37:52 -0400 Subject: [PATCH 098/153] [JAX] `ScaledTensor1x` to store `amax` (#2117) * added amax as an optional arg Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../jax/cpp_extensions/activation.py | 2 +- .../jax/cpp_extensions/normalization.py | 7 ++- .../jax/cpp_extensions/quantization.py | 2 +- transformer_engine/jax/quantize/quantizer.py | 12 ++--- transformer_engine/jax/quantize/tensor.py | 44 ++++++++++++++++--- 5 files changed, 51 insertions(+), 16 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index b8dcca66c..fe2253598 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1037,7 +1037,7 @@ def act_lu( out = out.reshape(output_shape) if noop_scaled_tensor: return ScaledTensorFactory.create_2x( - out, None, out, None, ScalingMode.NO_SCALING, dq_dtype=out.dtype + out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype ) return out diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 3b563efbd..7296afc72 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -1324,7 +1324,12 @@ def normalization_fwd( if quantizer is None and noop_scaled_tensor: return ( ScaledTensorFactory.create_2x( - output, None, output, None, ScalingMode.NO_SCALING, dq_dtype=output.dtype + output, + None, + output, + None, + scaling_mode=ScalingMode.NO_SCALING, + dq_dtype=output.dtype, ), mu, rsigma, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 
a7697ce25..0b2755744 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -591,7 +591,7 @@ def _quantize_dbias_impl( None, x, None, - ScalingMode.NO_SCALING, + scaling_mode=ScalingMode.NO_SCALING, dq_dtype=x.dtype, data_layout="NN", flatten_axis=flatten_axis, diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 09856065c..9a65f99bf 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -494,7 +494,7 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> return ScaledTensorFactory.create_1x( x_q, scales_q, - self.scaling_mode, + scaling_mode=self.scaling_mode, is_colwise=is_colwise, dq_dtype=dq_dtype, flatten_axis=flatten_axis, @@ -640,11 +640,11 @@ def _create_grouped_tensor_from_tensor_list( return ScaledTensorFactory.create_1x( grouped_data, grouped_scale_inv, - self.scaling_mode, - tensor_list[0].dq_dtype, - tensor_list[0].is_colwise, - tensor_list[0].data_layout, - tensor_list[0].flatten_axis, + scaling_mode=self.scaling_mode, + dq_dtype=tensor_list[0].dq_dtype, + is_colwise=tensor_list[0].is_colwise, + data_layout=tensor_list[0].data_layout, + flatten_axis=tensor_list[0].flatten_axis, group_sizes=group_sizes, original_shape=original_shape, group_axis=group_axis, diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 97e127269..1459175b7 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -104,6 +104,7 @@ class ScaledTensor1x(ScaledTensor): Attributes: data: The quantized tensor data scale_inv: The inverse scaling factors + amax: The maximum absolute value of the tensor scaling_mode: The scaling mode used for quantization dq_dtype: The data type for dequantized values _dq_func: The dequantization function @@ -114,6 +115,7 @@ class ScaledTensor1x(ScaledTensor): data: jnp.ndarray scale_inv: jnp.ndarray + amax: jnp.ndarray scaling_mode: ScalingMode dq_dtype: jnp.dtype _dq_func: Callable @@ -152,7 +154,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv) + children = (self.data, self.scale_inv, self.amax) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -224,6 +226,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st return ScaledTensor1x( data=data, scale_inv=scale_inv, + amax=self.amax, scaling_mode=self.scaling_mode, dq_dtype=self.dq_dtype, _dq_func=self._dq_func, @@ -255,6 +258,7 @@ def __init__( self, data, scale_inv, + amax, group_sizes, scaling_mode, dq_dtype, @@ -270,7 +274,15 @@ def __init__( self.original_shape = original_shape self.group_axis = group_axis super().__init__( - data, scale_inv, scaling_mode, dq_dtype, _dq_func, is_colwise, data_layout, flatten_axis + data, + scale_inv, + amax, + scaling_mode, + dq_dtype, + _dq_func, + is_colwise, + data_layout, + flatten_axis, ) def __post_init__(self): @@ -308,7 +320,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.group_sizes) + children = (self.data, self.scale_inv, self.amax, self.group_sizes) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -413,7 +425,8 @@ class ScaledTensorFactory: def create_1x( data, scale_inv, - scaling_mode, + amax=None, + scaling_mode=ScalingMode.NO_SCALING, 
dq_dtype=jnp.bfloat16, is_colwise=False, data_layout="N", @@ -427,18 +440,22 @@ def create_1x( Args: data: The quantized tensor data scale_inv: The inverse scaling factors + amax: The maximum absolute value of the tensor scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) is_colwise: Whether to use column-wise quantization (default: False) data_layout: The data_layout specification (default: "N") flatten_axis: The quantization axis for the tensor - group_sizes: Arra of ints containing the size of each group (default: None) + group_sizes: Array of ints containing the size of each group (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) Returns: A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided """ + if amax is None: + amax = jnp.empty((1,), dtype=jnp.float32) + dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) if group_sizes is not None: @@ -468,6 +485,7 @@ def create_1x( return GroupedScaledTensor1x( data=data, scale_inv=scale_inv, + amax=amax, scaling_mode=scaling_mode, dq_dtype=dq_dtype, _dq_func=dequantizer.grouped_dequantize, @@ -487,6 +505,7 @@ def create_1x( return ScaledTensor1x( data, scale_inv, + amax, scaling_mode, dq_dtype, dequantizer.dequantize, @@ -501,7 +520,8 @@ def create_2x( scale_inv, colwise_data, colwise_scale_inv, - scaling_mode, + amax=None, + scaling_mode=ScalingMode.NO_SCALING, dq_dtype=jnp.bfloat16, data_layout="NN", flatten_axis=-1, @@ -516,6 +536,7 @@ def create_2x( scale_inv: The row-wise inverse scaling factors colwise_data: The column-wise quantized data colwise_scale_inv: The column-wise inverse scaling factors + amax: The maximum absolute value of the tensor scaling_mode: The scaling mode for quantization dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") @@ -527,10 +548,14 @@ def create_2x( Returns: A ScaledTensor2x instance """ + if amax is None: + amax = jnp.empty((1,), dtype=jnp.float32) + assert len(data_layout) == 2, f"Expect 2 layouts, got {data_layout}" rowwise_tensor = ScaledTensorFactory.create_1x( data, scale_inv, + amax, scaling_mode, dq_dtype, is_colwise=False, @@ -543,6 +568,7 @@ def create_2x( colwise_tensor = ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, + amax, scaling_mode, dq_dtype, is_colwise=True, @@ -560,7 +586,8 @@ def create( scale_inv: jnp.ndarray, colwise_data: jnp.ndarray, colwise_scale_inv: jnp.ndarray, - scaling_mode: ScalingMode, + amax=None, + scaling_mode: ScalingMode = ScalingMode.NO_SCALING, dq_dtype: jnp.dtype = jnp.bfloat16, data_layout: str = "NN", q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, @@ -594,6 +621,7 @@ def create( scale_inv, colwise_data, colwise_scale_inv, + amax, scaling_mode, dq_dtype, data_layout=data_layout, @@ -608,6 +636,7 @@ def create( return ScaledTensorFactory.create_1x( colwise_data, colwise_scale_inv, + amax, scaling_mode, dq_dtype, is_colwise=is_colwise, @@ -621,6 +650,7 @@ def create( return ScaledTensorFactory.create_1x( data, scale_inv, + amax, scaling_mode, dq_dtype, is_colwise=is_colwise, From d972e76d0688b7f3441df15fc6e36555106c4817 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 26 Aug 2025 08:49:36 -0700 Subject: [PATCH 099/153] Revert "[Common] PDL for Quantization Kernels" (#2114) Revert 
"[Common] PDL for Quantization Kernels (#2001)" This reverts commit bfab8c679f17bed5b63ae5c904c205f164beaae4. Signed-off-by: Jeremy Berchtold --- .../common/util/cast_kernels.cuh | 54 ++++++++----------- 1 file changed, 21 insertions(+), 33 deletions(-) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index c084c3116..9a02d71f2 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -203,11 +203,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) // Wait for the data to have arrived ptx::mbarrier_wait_parity(&mbar[stage], parity); - // Trigger the next kernel, so its TMA load can be overlapped with the current kernel - if (stage == STAGES - 1) { - cudaTriggerProgrammaticLaunchCompletion(); - } - float thread_amax = 0.0f; if constexpr (COLWISE_SCALING) { const size_t shmem_offset_base_colwise = buff * BUFF_DIM + tid_X_colwise; @@ -1127,13 +1122,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; - cudaLaunchConfig_t cfg = {grid, block_size, dshmem_size, stream, NULL, 0}; - // This kernel will only be called on sm100+, so no need to check sm_arch - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; cfg.attrs = attribute; - cfg.numAttrs = 1; - switch (scaling_type) { case ScalingType::ROWWISE: cudaFuncSetAttribute( @@ -1141,13 +1129,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, false, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cudaLaunchKernelEx( - &cfg, - cast_mxfp8_2D_kernel, - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); break; case ScalingType::COLWISE: cudaFuncSetAttribute( @@ -1155,13 +1143,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cudaLaunchKernelEx( - &cfg, - cast_mxfp8_2D_kernel, - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); break; case ScalingType::BIDIMENSIONAL: cudaFuncSetAttribute( @@ -1169,13 +1157,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, true, CHUNK_DIM_Y, CHUNK_DIM_X, THREADS_PER_CHUNK>, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); - cudaLaunchKernelEx( - &cfg, - cast_mxfp8_2D_kernel, - tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, - tensor_map_output_colwise, 
scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, - workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + cast_mxfp8_2D_kernel + <<>>( + tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise, + tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, + workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); break; } From d770886f02b6c43da1afb65d19487a44cbb8cc88 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 26 Aug 2025 12:29:17 -0400 Subject: [PATCH 100/153] [JAX] Add `tpsp_resource` in the `MeshResource` map (#2113) * clean up sharding Signed-off-by: Phuong Nguyen * added tpsp_resource Signed-off-by: Phuong Nguyen * update tests Signed-off-by: Phuong Nguyen * rework test for MeshResource Signed-off-by: Phuong Nguyen * add mesh_resource into fp8_autocast in test_helper.py Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- .../encoder/test_model_parallel_encoder.py | 5 +- examples/jax/encoder/test_multigpu_encoder.py | 2 +- .../encoder/test_multiprocessing_encoder.py | 5 +- tests/jax/distributed_test_base.py | 10 +- tests/jax/test_distributed_fused_attn.py | 4 +- tests/jax/test_distributed_helper.py | 35 ++++++ tests/jax/test_distributed_layernorm_mlp.py | 26 ++--- tests/jax/test_distributed_softmax.py | 4 +- tests/jax/test_fused_attn.py | 6 +- tests/jax/test_helper.py | 47 +++----- tests/jax/test_sharding.py | 38 ------- transformer_engine/jax/__init__.py | 12 -- transformer_engine/jax/cpp_extensions/gemm.py | 15 ++- transformer_engine/jax/sharding.py | 104 +++++------------- 14 files changed, 125 insertions(+), 188 deletions(-) create mode 100644 tests/jax/test_distributed_helper.py delete mode 100644 tests/jax/test_sharding.py diff --git a/examples/jax/encoder/test_model_parallel_encoder.py b/examples/jax/encoder/test_model_parallel_encoder.py index 382133360..41832650f 100644 --- a/examples/jax/encoder/test_model_parallel_encoder.py +++ b/examples/jax/encoder/test_model_parallel_encoder.py @@ -267,7 +267,10 @@ def train_and_evaluate(args): ) as mesh, te.fp8_autocast( enabled=args.use_fp8, fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + mesh_resource=te.MeshResource( + dp_resource=DEVICE_DP_AXIS, + tpsp_resource=DEVICE_TP_AXIS, + ), ): rng = jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) diff --git a/examples/jax/encoder/test_multigpu_encoder.py b/examples/jax/encoder/test_multigpu_encoder.py index a4bd83d2b..bc6a56752 100644 --- a/examples/jax/encoder/test_multigpu_encoder.py +++ b/examples/jax/encoder/test_multigpu_encoder.py @@ -264,7 +264,7 @@ def train_and_evaluate(args): ) as mesh, te.fp8_autocast( enabled=args.use_fp8, fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, None, None, None), + mesh_resource=te.MeshResource(dp_resource=DEVICE_DP_AXIS), ): rng = jax.random.PRNGKey(args.seed) diff --git a/examples/jax/encoder/test_multiprocessing_encoder.py b/examples/jax/encoder/test_multiprocessing_encoder.py index f112740a3..abf6a407b 100644 --- a/examples/jax/encoder/test_multiprocessing_encoder.py +++ b/examples/jax/encoder/test_multiprocessing_encoder.py @@ -382,7 +382,10 @@ def train_and_evaluate(args): ) as mesh, te.fp8_autocast( enabled=args.use_fp8, fp8_recipe=fp8_recipe, - mesh_resource=te.MeshResource(DEVICE_DP_AXIS, DEVICE_TP_AXIS, None, None), + mesh_resource=te.MeshResource( + dp_resource=DEVICE_DP_AXIS, + tpsp_resource=DEVICE_TP_AXIS, + ), ): rng = 
jax.random.PRNGKey(args.seed) rng, params_rng = jax.random.split(rng) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index bda42f5f7..7c08539c3 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -22,7 +22,7 @@ def generate_configs(): pytest.param(2, (2,), ("dp",), MeshResource(dp_resource="dp"), id="n2_dp2_tp1") ) configs.append( - pytest.param(2, (2,), ("tp",), MeshResource(tp_resource="tp"), id="n2_dp1_tp2") + pytest.param(2, (2,), ("tpsp",), MeshResource(tpsp_resource="tpsp"), id="n2_dp1_tp2") ) if is_devices_enough(4): @@ -30,8 +30,8 @@ def generate_configs(): pytest.param( 4, (2, 2), - ("dp", "tp"), - MeshResource(dp_resource="dp", tp_resource="tp"), + ("dp", "tpsp"), + MeshResource(dp_resource="dp", tpsp_resource="tpsp"), id=f"n4_dp2_tp2", ) ) @@ -43,8 +43,8 @@ def generate_context_parallel_configs_for_attn(): """Generate CP combinations along with TP+DP for TestDistributedContextParallelSelfAttn only""" configsL1 = [] configsL2 = [] - mr = MeshResource(dp_resource="dp", cp_resource="cp", tp_resource="tp") - axes = ("dp", "cp", "tp") + mr = MeshResource(dp_resource="dp", cp_resource="cp", tpsp_resource="tpsp") + axes = ("dp", "cp", "tpsp") DP_sizes = (1, 2) CP_sizes = (1, 2, 4, 8) TP_sizes = (1, 2) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index ea29736e7..ef8e370b6 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -45,8 +45,8 @@ def generate_collectives_count_ref( _, seqlen, heads, _ = shape is_dp_enabled = mesh_resource.dp_resource is not None tp_size = 1 - if mesh_resource.tp_resource is not None: - idx = mesh_axes.index(mesh_resource.tp_resource) + if mesh_resource.tpsp_resource is not None: + idx = mesh_axes.index(mesh_resource.tpsp_resource) tp_size = mesh_shape[idx] all_reduce_loss_bytes = 4 # 1 * FP32 diff --git a/tests/jax/test_distributed_helper.py b/tests/jax/test_distributed_helper.py new file mode 100644 index 000000000..e74e9aa6f --- /dev/null +++ b/tests/jax/test_distributed_helper.py @@ -0,0 +1,35 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import unittest + +import jax +import numpy as np + +from utils import pytest_parametrize_wrapper, is_devices_enough +from transformer_engine.jax.sharding import MeshResource, global_mesh_resource +from transformer_engine.jax import fp8_autocast + + +def generate_mesh_configs(): + configs = [] + if is_devices_enough(2): + configs.append( + [2, (1, 2), ("dp", "tpsp"), MeshResource(dp_resource="dp", tpsp_resource="tpsp")] + ) + if is_devices_enough(4): + configs.append( + [4, (2, 2), ("fsdp", "tp"), MeshResource(tp_resource="tp", fsdp_resource="fsdp")] + ) + return configs + + +class TestMeshResource(unittest.TestCase): + def test_fp8_autocast_with_mesh_resource(self): + for mesh_config in generate_mesh_configs(): + device_count, mesh_shape, mesh_axes, mesh_resource = mesh_config + devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) + mesh = jax.sharding.Mesh(devices, mesh_axes) + with mesh, fp8_autocast(enabled=False, mesh_resource=mesh_resource): + self.assertEqual(mesh_resource, global_mesh_resource()) diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index e3b1ecac9..90b762c24 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -62,16 +62,16 @@ INTERMEDIATE = 64 -# Only test with FSDP and TP as DP is not used -def generate_fsdp_and_tp_configs(): +# Only test with FSDP and TPSP as DP is not used +def generate_fsdp_and_tpsp_configs(): configs = [] if is_devices_enough(2): configs.append( - [2, (1, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] + [2, (1, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] ) if is_devices_enough(4): configs.append( - [4, (2, 2), ("fsdp", "tp"), MeshResource(fsdp_resource="fsdp", tp_resource="tp")] + [4, (2, 2), ("fsdp", "tpsp"), MeshResource(fsdp_resource="fsdp", tpsp_resource="tpsp")] ) return configs @@ -186,12 +186,12 @@ def _test_layernorm_mlp_grad( with mesh, fp8_autocast( enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource ): - k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tp")) - k2_sharding = NamedSharding(mesh, PartitionSpec("tp", "fsdp")) + k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) + k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) k1_ = jax.device_put(k1, k1_sharding) k2_ = jax.device_put(k2, k2_sharding) if use_bias: - b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tp")) + b1_sharding = NamedSharding(mesh, PartitionSpec(None, "tpsp")) b1_ = jax.device_put(b1, b1_sharding) else: b1_sharding = b1_ = None @@ -247,7 +247,7 @@ def _test_layernorm_mlp_grad( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @@ -276,7 +276,7 @@ def test_layernorm_mlp_grad( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @@ 
-408,7 +408,7 @@ def _test_layernorm_mlp( assert_allclose(mlp_out_sharded, mlp_out_single, dtype=dtype, atol=atol, rtol=rtol) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @@ -429,7 +429,7 @@ def test_layernorm_mlp_layer( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) @@ -452,7 +452,7 @@ def test_layernorm_mlp_layer_fp8( ) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("silu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) @@ -473,7 +473,7 @@ def test_layernorm_mlp_layer_shardy( ) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tp_configs()) + @pytest_parametrize_wrapper("mesh_config", generate_fsdp_and_tpsp_configs()) @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("use_bias", [True, False]) @pytest_parametrize_wrapper("input_shape", INPUT_SHAPE) diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 8d2ad6fad..d9eaf314a 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -41,11 +41,11 @@ def generate_inputs( if not bad_sharding: x_pspec = PartitionSpec( - mesh_resource.dp_resource, mesh_resource.tp_resource, None, None + mesh_resource.dp_resource, mesh_resource.tpsp_resource, None, None ) else: x_pspec = PartitionSpec( - mesh_resource.dp_resource, None, None, mesh_resource.tp_resource + mesh_resource.dp_resource, None, None, mesh_resource.tpsp_resource ) if broadcast_batch_mask: diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 29a9bc2b9..ec530a395 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -397,7 +397,7 @@ def _setup_inputs(self): self.mesh = Mesh(self.devices, self.mesh_axes) self.dp_size = self.mesh.shape.get(self.mesh_resource.dp_resource, 1) self.cp_size = self.mesh.shape.get(self.mesh_resource.cp_resource, 1) - self.tp_size = self.mesh.shape.get(self.mesh_resource.tp_resource, 1) + self.tp_size = self.mesh.shape.get(self.mesh_resource.tpsp_resource, 1) key = jax.random.PRNGKey(0) q_key, k_key, v_key, bias_key, dropout_key = jax.random.split(key, 5) @@ -630,7 +630,7 @@ def generate_random_segment_ids( self.qkvo_psec = PartitionSpec( self.mesh_resource.dp_resource, self.mesh_resource.cp_resource, - self.mesh_resource.tp_resource, + self.mesh_resource.tpsp_resource, None, ) self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec) @@ -658,7 +658,7 @@ def to_dp_shardings(x): if self.bias_shape == BiasShape._1HSS: 
self.bias_pspec = PartitionSpec( - None, self.mesh_resource.tp_resource, self.mesh_resource.cp_resource, None + None, self.mesh_resource.tpsp_resource, self.mesh_resource.cp_resource, None ) elif self.bias_shape == BiasShape._B1SS: self.bias_pspec = PartitionSpec( diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index d0a3efd27..9b67de6dd 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -71,20 +71,20 @@ def test_fp8_autocast_delayed_scaling(self): QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() - with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling()): + with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()): self._check_default_state() self._check_default_state() ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds): + with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) self._check_default_state() ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) - with fp8_autocast(enabled=True, fp8_recipe=ds): + with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) @@ -95,20 +95,22 @@ def test_fp8_autocast_current_scaling(self): QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() - with fp8_autocast(enabled=False, fp8_recipe=Float8CurrentScaling()): + with fp8_autocast( + enabled=False, fp8_recipe=Float8CurrentScaling(), mesh_resource=MeshResource() + ): self._check_default_state() self._check_default_state() cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) - with fp8_autocast(enabled=True, fp8_recipe=cs): + with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) - with fp8_autocast(enabled=True, fp8_recipe=cs): + with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_current_scaling(cs) @@ -119,46 +121,23 @@ def test_fp8_autocast_mxfp8_block_scaling(self): QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. 
self._check_default_state() - with fp8_autocast(enabled=False, fp8_recipe=MXFP8BlockScaling()): + with fp8_autocast( + enabled=False, fp8_recipe=MXFP8BlockScaling(), mesh_resource=MeshResource() + ): self._check_default_state() self._check_default_state() bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) - with fp8_autocast(enabled=True, fp8_recipe=bs): + with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) - with fp8_autocast(enabled=True, fp8_recipe=bs): + with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): self.assertTrue(QuantizeConfig.is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() - - @unittest.skipIf(not is_fp8_supported, reason=reason) - def test_fp8_autocast_with_sharding_resource(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. - self._check_default_state() - - ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) - - mesh_s = ( - (MeshResource(None, None)), - (MeshResource("dp", None)), - (MeshResource(None, "tp")), - (MeshResource("dp", "tp")), - ) - # TODO (Ming Huang): Support multi-GPUs testing. # pylint: disable=fixme - mesh_shape = (1, 1) - devices = np.asarray(jax.devices()[:1]).reshape(*mesh_shape) - with jax.sharding.Mesh(devices, ("dp", "tp")): - for sr in mesh_s: - with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=sr): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) - self._compare_delay_scaling(get_delayed_scaling(), ds) - self.assertEqual(sr, global_mesh_resource()) - - self._check_default_state() diff --git a/tests/jax/test_sharding.py b/tests/jax/test_sharding.py deleted file mode 100644 index 0d50b7345..000000000 --- a/tests/jax/test_sharding.py +++ /dev/null @@ -1,38 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. 
- -import pytest - -from transformer_engine.jax.flax import extend_logical_axis_rules -from transformer_engine.jax.sharding import global_shard_guard, MeshResource - -LOGICAL_RULES = [ - [(("a1", None), ("a2", "ma2")), False], - [(("a1", None), ("a2", "ma2"), ("a3", ("ma31", "ma32"))), True], - [(("a1", None), ("a2", "ma2"), ("a3", "ma31"), ("a3", "ma32")), False], - [(("a1", None), ("a2", "ma2"), ("batch", "batch_1200234")), True], - [(("a1", None), ("a2", "ma2"), ("a2", "ma1"), ("batch", "model"), ("batch", "data")), True], -] - -MeshS = [ - MeshResource(), - MeshResource("data", None), - MeshResource(None, "model"), - MeshResource("data", "model"), -] - - -class TestShardingSideAPI: - - @pytest.mark.parametrize("base_rules,need_assert", LOGICAL_RULES) - @pytest.mark.parametrize("sr", MeshS) - def test_extend_logical_axis_rules(self, base_rules, need_assert, sr): - with global_shard_guard(sr): - try: - target_te_rules = extend_logical_axis_rules(tuple()) - extended_rules = extend_logical_axis_rules(base_rules) - assert extended_rules == (*base_rules, *target_te_rules) - assert not need_assert - except AssertionError as ae: - assert need_assert, f"{ae.args}" diff --git a/transformer_engine/jax/__init__.py b/transformer_engine/jax/__init__.py index 55ffad93e..0b5e43402 100644 --- a/transformer_engine/jax/__init__.py +++ b/transformer_engine/jax/__init__.py @@ -38,19 +38,10 @@ from .quantize import NVTE_FP8_COLLECTION_NAME from .sharding import MeshResource -from .sharding import MajorShardingType, ShardingResource, ShardingType from ..common.utils import deprecate_wrapper from ..common.utils import DeprecatedEnum -MajorShardingType = DeprecatedEnum( - MajorShardingType, "MajorShardingType is deprecating in the near feature." -) -ShardingType = DeprecatedEnum(ShardingType, "ShardingType is deprecating in the near feature.") -ShardingResource = deprecate_wrapper( - ShardingResource, - "ShardingResource is renamed to MeshResource, and will be removed in the near feature.", -) __all__ = [ "NVTE_FP8_COLLECTION_NAME", @@ -58,9 +49,6 @@ "update_collections", "get_delayed_scaling", "MeshResource", - "MajorShardingType", - "ShardingResource", - "ShardingType", "flax", "quantize", ] diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 188b37601..95ef42821 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -453,6 +453,19 @@ def _parse_operand_output_specs( ): lhs_specs, _, rhs_specs, *_ = map(get_padded_spec, arg_infos) + gsr = global_mesh_resource() + + # Ensure that tensor sequence parallelism is not used via setting tp_resource + if gsr.tp_resource is not None: + for i in range(len(lhs_specs) - 1): + if lhs_specs[i] == gsr.tp_resource and lhs_specs[i + 1] == gsr.tp_resource: + warnings.warn( + "Tensor sequence parallelism is detected as" + f" tp_resource='{gsr.tp_resource}' appears twice consecutively in" + f" lhs_specs: {lhs_specs}. Please setting MeshResource.tpsp_resource for" + " tensor sequence parallelism to avoid potential issues." 
+ ) + lhs_ndim, rhs_ndim = map(len, (lhs_specs, rhs_specs)) lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs_ndim, rhs_ndim), contracting_dims) lhs_non_cdims, rhs_non_cdims = map( @@ -492,7 +505,7 @@ def _parse_operand_output_specs( # Non-contracting dims of RHS always needs to be gathered along the FSDP axis rhs_non_cspecs = tuple( - None if spec is not None and spec == global_mesh_resource().fsdp_resource else spec + None if spec is not None and spec == gsr.fsdp_resource else spec for spec in rhs_non_cspecs ) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 480989dcd..caa2a4620 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -9,10 +9,8 @@ parallelism (FSDP). It includes functions for sharding constraints, mesh management, and collective operations. """ -import os from contextlib import contextmanager from dataclasses import dataclass -from enum import Enum from typing import Callable, Optional import warnings import jax @@ -43,44 +41,46 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): return mesh.shape[resource], resource -def get_sharding_map_logic_axis_to_mesh_axis(): - """ - Generate a dict to map logical axes to mesh axes. - """ +def _validate_mesh_resource_configuration(): + """Validate that the mesh resource configuration is consistent and conflict-free.""" gsr = global_mesh_resource() - IS_FSDP_OUTER = bool(int(os.environ.get("NVTE_OUTER_BATCH_FSDP_DIM", False))) + is_dp_enabled = gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1 + is_tp_enabled = gsr.tp_resource is not None and get_mesh_axis_size(gsr.tp_resource) > 1 + is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 + is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1 - batch_resources = ( - [gsr.fsdp_resource, gsr.dp_resource] - if IS_FSDP_OUTER - else [gsr.dp_resource, gsr.fsdp_resource] + assert not (is_dp_enabled and is_fsdp_enabled), ( + "Data parallelism and full-sharded data parallelism cannot be enabled at the same time." + f" Got dp_resource={gsr.dp_resource} and fsdp_resource={gsr.fsdp_resource}" + ) + assert not (is_tp_enabled and is_tpsp_enabled), ( + "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time." + f" Got tp_resource={gsr.tp_resource} and tpsp_resource={gsr.tpsp_resource}" ) - batch_dim_rule = [] - for resource in batch_resources: - if resource is not None and resource not in batch_dim_rule: - batch_dim_rule.append(resource) - if len(batch_dim_rule) <= 0: - batch_dim_rule = None - elif len(batch_dim_rule) == 1: - batch_dim_rule = batch_dim_rule[0] - else: - batch_dim_rule = tuple(batch_dim_rule) +def get_sharding_map_logic_axis_to_mesh_axis(): + """ + Generate a dict to map logical axes to mesh axes. 
+ """ + gsr = global_mesh_resource() + + is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 + is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1 te_logical_axis_to_mesh_axis = { - BATCH_AXES: batch_dim_rule, + BATCH_AXES: gsr.fsdp_resource if is_fsdp_enabled else gsr.dp_resource, SEQLEN_AXES: None, - SEQLEN_TP_AXES: gsr.tp_resource, + SEQLEN_TP_AXES: gsr.tpsp_resource, SEQLEN_CP_AXES: gsr.cp_resource, - HEAD_AXES: gsr.tp_resource, + HEAD_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource, HIDDEN_AXES: None, - HIDDEN_TP_AXES: gsr.tp_resource, + HIDDEN_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource, JOINED_AXES: None, W_NO_SHARD_AXES: None, W_FSDP_AXES: gsr.fsdp_resource, - W_TP_AXES: gsr.tp_resource, + W_TP_AXES: gsr.tpsp_resource if is_tpsp_enabled else gsr.tp_resource, W_JOINED_AXES: None, } return te_logical_axis_to_mesh_axis @@ -274,6 +274,7 @@ class MeshResource: Attributes: dp_resource: Axis name for data parallelism (batch sharding), default is None tp_resource: Axis name for tensor parallelism (hidden dimension sharding), default is None + tpsp_resource: Axis name for tensor sequence parallelism (hidden and sequence sharding), default is None fsdp_resource: Axis name for full-sharded data parallelism, default is None pp_resource: Axis name for pipeline parallelism (layer sharding), default is None cp_resource: Axis name for context parallelism (sequence sharding), default is None @@ -281,6 +282,7 @@ class MeshResource: dp_resource: str = None tp_resource: str = None + tpsp_resource: str = None fsdp_resource: str = None pp_resource: str = None cp_resource: str = None @@ -303,6 +305,7 @@ def global_shard_guard(resource: MeshResource): old_resources = _GLOBAL_MESH_RESOURCE try: _GLOBAL_MESH_RESOURCE = resource + _validate_mesh_resource_configuration() yield finally: _GLOBAL_MESH_RESOURCE = old_resources @@ -351,52 +354,3 @@ def all_reduce_max_along_all_axes_except_PP(x: jnp.array, mesh: jax.sharding.Mes if axis != global_mesh_resource().pp_resource: x = lax_paral_op(x, jax.lax.pmax, axis, mesh) return x - - -# Deprecating Items --------------------------------------------------------------- -ShardingResource = MeshResource - -global_shard_resource = global_mesh_resource - - -class MajorShardingType(Enum): - """Enumeration of major sharding types for distributed training. - - This enum defines the basic sharding patterns available for distributed - training. Note that this class is deprecated and will be removed in the future. - - Values: - SINGLE: Single process training - DP: Data parallel training - TP: Standard tensor parallel training - DPTP: Data and standard tensor parallel training - """ - - SINGLE = 0 - DP = 1 - TP = 2 - DPTP = 3 - - -class ShardingType(Enum): - """Enumeration of detailed sharding types for distributed training. - - This enum defines specific sharding patterns for distributed training, - including combinations of data parallelism and different tensor parallelism - strategies. Note that this class is deprecated and will be removed in the future. 
- - Values: - SINGLE: No sharding - DP: Sharding along data parallelism - TP_COL: Sharding along column-split tensor parallelism - TP_ROW: Sharding along row-split tensor parallelism - DP_TP_COL: Sharding along data and column-split tensor parallelism - DP_TP_ROW: Sharding along data and row-split tensor parallelism - """ - - SINGLE = (MajorShardingType.SINGLE, "single") - DP = (MajorShardingType.DP, "dp") - TP_COL = (MajorShardingType.TP, "tp_col") - TP_ROW = (MajorShardingType.TP, "tp_row") - DP_TP_COL = (MajorShardingType.DPTP, "dp_tp_col") - DP_TP_ROW = (MajorShardingType.DPTP, "dp_tp_row") From 54c0c8579b42123c0674a5b223770f0eba05eba4 Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Tue, 26 Aug 2025 10:52:33 -0700 Subject: [PATCH 101/153] Bump cuDNN FE to 1.14.0 (#2072) * Bump cuDNN FE to 1.14.0 Signed-off-by: Vladimir Cherepanov * Change submodule hash Signed-off-by: Vladimir Cherepanov * Pick up a cuDNN FE fix Signed-off-by: Vladimir Cherepanov * New model configs in tests Signed-off-by: Vladimir Cherepanov * Exclude cuDNN backend for some configs Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov --- 3rdparty/cudnn-frontend | 2 +- tests/pytorch/attention/test_attention.py | 2 ++ transformer_engine/common/fused_attn/fused_attn.cpp | 5 +++-- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/3rdparty/cudnn-frontend b/3rdparty/cudnn-frontend index 9793df569..deda80e53 160000 --- a/3rdparty/cudnn-frontend +++ b/3rdparty/cudnn-frontend @@ -1 +1 @@ -Subproject commit 9793df569ce413f4b1844a9176f7ae24dd981603 +Subproject commit deda80e5372d50e925d7bf4f76c5db779be3fbd5 diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 3088853a2..56bfa1423 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -274,6 +274,8 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mla_3_0": ModelConfig(8, 1, 16, 128, max_seqlen_kv=2048, head_dim_v=64), # inference "mla_3_1": ModelConfig(8, 1, 16, 256, max_seqlen_kv=2048, head_dim_v=128), # inference "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference + "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference } diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index bb30261b9..60b10862e 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -252,8 +252,9 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200) && is_training && - sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && + (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 || + cudnn_runtime_version == 91300) && + is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || From d3706087318ec95dc961a606e07947f15f0224a5 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia 
<158520091+jberchtold-nvidia@users.noreply.github.com> Date: Tue, 26 Aug 2025 11:09:12 -0700 Subject: [PATCH 102/153] Revert "[Common] PDL for Blockwise Quantization" (#2115) Revert "[Common] PDL for Blockwise Quantization (#2066)" This reverts commit ebca61532000c72113cdb2987d50b9fba08d0d8c. Signed-off-by: Jeremy Berchtold --- .../quantize_transpose_square_blockwise.cu | 63 ++++++------------- .../quantize_transpose_vector_blockwise.cu | 54 ++++------------ 2 files changed, 33 insertions(+), 84 deletions(-) diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index 3a2247f5c..a603d1f1a 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -14,7 +14,6 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" -#include "common/util/cuda_runtime.h" #include "common/util/ptx.cuh" #include "common/utils.cuh" @@ -168,12 +167,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } } -// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's -// store to global memory. -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - cudaTriggerProgrammaticLaunchCompletion(); -#endif - // Step 3: Store cast output, Step 4: do transpose within thread tile OVecCast tmp_output_c; @@ -397,12 +390,6 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose } } -// Trigger the next kernel here so that it's load from global memory can overlap with this kernel's -// store to global memory. -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - cudaTriggerProgrammaticLaunchCompletion(); -#endif - // Step 3: Store cast output, Step 4: do transpose within thread tile // Edge case: in the non-full tile case, there are three subcases // for full thread tile, it's the same thing here @@ -526,15 +513,6 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, BLOCK_TILE_DIM); const size_t num_blocks_y = DIVUP(num_rows, BLOCK_TILE_DIM); - dim3 grid(num_blocks_x, num_blocks_y, 1); - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; - cudaLaunchConfig_t cfg = {grid, THREADS_PER_BLOCK, 0, stream, NULL, 0}; - if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= 90) { - cfg.attrs = attribute; - cfg.numAttrs = 1; - } TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -545,6 +523,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_SWITCH_CONDITION( return_transpose, kReturnTranspose, + dim3 grid(num_blocks_x, num_blocks_y, 1); const bool full_tile = row_length % BLOCK_TILE_DIM == 0 && num_rows % BLOCK_TILE_DIM == 0; @@ -554,28 +533,26 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor tensor_map_output_trans = get_tensor_map(output_t, num_rows, row_length); } - cudaLaunchKernelEx(&cfg, - block_scaled_cast_transpose_kernel, - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, - scale_t_stride_y, epsilon, 
tensor_map_output_trans, pow_2_scale); + block_scaled_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + tensor_map_output_trans, pow_2_scale); } else { - cudaLaunchKernelEx( - &cfg, - block_scaled_cast_transpose_kernel_notaligned, - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - pow_2_scale); + block_scaled_cast_transpose_kernel_notaligned + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, + scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, + pow_2_scale); } // full-tile ) // return_transpose ) // OutputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 5bf2f5201..6f5c0f3a6 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -17,7 +17,6 @@ #include "common/common.h" #include "common/recipe/recipe_common.cuh" #include "common/transpose/cast_transpose.h" -#include "common/util/cuda_runtime.h" #include "common/utils.cuh" namespace transformer_engine { @@ -235,14 +234,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo __syncthreads(); -// If not return columnwise, we trigger the next kernel here so that it's load from global memory -// can overlap with this kernel's return rowwise. -#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - if (!return_columnwise_gemm_ready && !return_columnwise_compact) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - // Step 2: Cast and store to output_c if (return_rowwise) { constexpr int r_stride = @@ -334,14 +325,6 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo } } -// If return columnwise, we trigger the next kernel here so that it's load from global memory -// can overlap with this kernel's return columnwise. 
-#if (defined __CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) - if (return_columnwise_gemm_ready || return_columnwise_compact) { - cudaTriggerProgrammaticLaunchCompletion(); - } -#endif - // Step 3 (return_columnwise_gemm_ready): Transpose, cast and store to output_t if (return_columnwise_gemm_ready) { constexpr int c_stride = @@ -601,10 +584,6 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); - dim3 grid(num_blocks_x, num_blocks_y, 1); - cudaLaunchAttribute attribute[1]; - attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization; - attribute[0].val.programmaticStreamSerializationAllowed = 1; TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -612,38 +591,31 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output.dtype, OutputType, + dim3 grid(num_blocks_x, num_blocks_y, 1); + const bool full_tile = row_length % kTileDim == 0 && num_rows % kTileDim == 0; TRANSFORMER_ENGINE_SWITCH_CONDITION( full_tile, kAligned, size_t smem_bytes = kSMemSize * sizeof(InputType); - - cudaLaunchConfig_t cfg = {grid, kThreadsPerBlock, smem_bytes, stream, NULL, 0}; - if (transformer_engine::cuda::sm_arch(transformer_engine::cuda::current_device()) >= - 90) { - cfg.attrs = attribute; - cfg.numAttrs = 1; - } // shared memory must be requested up if (smem_bytes >= 48 * 1024) { cudaError_t err = cudaFuncSetAttribute( &block_scaled_1d_cast_transpose_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); NVTE_CHECK(err == cudaSuccess, "Failed to set dynamic shared memory size."); - } cudaLaunchKernelEx(&cfg, - block_scaled_1d_cast_transpose_kernel, - reinterpret_cast(input.dptr), - reinterpret_cast(output.dptr), - reinterpret_cast(output_t.dptr), - reinterpret_cast(scale_inv.dptr), - reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, - scale_stride_x, scale_stride_y, scale_t_stride_x, - scale_t_stride_y, epsilon, rowwise_option, columnwise_option, - pow2_scale);) // kAligned - ) // OutputType - ) // InputType + } block_scaled_1d_cast_transpose_kernel + <<>>( + reinterpret_cast(input.dptr), + reinterpret_cast(output.dptr), + reinterpret_cast(output_t.dptr), + reinterpret_cast(scale_inv.dptr), + reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, + scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, + columnwise_option, pow2_scale);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } From 1398fa5f36bfb780acfb7348dcd297a84fd3e705 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Tue, 26 Aug 2025 20:58:34 +0200 Subject: [PATCH 103/153] [PyTorch Debug] Skip log test on device if it does not support fp8. 
(#2109) fix test on old device Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 0b0adb451..ca8e10ad6 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -119,6 +119,9 @@ def read_log(log_dir: str) -> str: def test_sanity(feature_dirs): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats)) with debug_session(log_all_stats_config, feature_dirs) as log_dir: model = te.Linear(128, 128, params_dtype=torch.bfloat16) @@ -207,6 +210,9 @@ def test_numerics(fp8_recipe, feature_dirs): @pytest.mark.parametrize("layer", ["linear", "transformer"]) def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): + if not fp8_available: + pytest.skip(reason_for_no_fp8) + # If layer does not invoke any feature in current iteration, # then it changed into non-debug mode. # This test checks whether this works correctly - From 8dba2963435f7cbd97b6664c0a0b9424c81cfb87 Mon Sep 17 00:00:00 2001 From: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Date: Tue, 26 Aug 2025 15:03:01 -0700 Subject: [PATCH 104/153] Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... 
Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak --- qa/L0_cppunittest/test.sh | 2 +- qa/L1_cpp_distributed/test.sh | 15 + setup.py | 13 + tests/cpp/CMakeLists.txt | 2 + tests/cpp/comm_gemm/CMakeLists.txt | 19 + tests/cpp/comm_gemm/test_comm_gemm.cu | 441 +++++++++++++++ transformer_engine/common/CMakeLists.txt | 27 + .../common/comm_gemm/comm_gemm.cpp | 519 ++++++++++++++++++ transformer_engine/common/common.cu | 18 + transformer_engine/common/common.h | 16 +- .../common/gemm/cublaslt_gemm.cu | 18 - .../include/transformer_engine/comm_gemm.h | 156 ++++++ transformer_engine/common/util/logging.h | 17 + 13 files changed, 1242 insertions(+), 21 deletions(-) create mode 
100755 qa/L1_cpp_distributed/test.sh create mode 100644 tests/cpp/comm_gemm/CMakeLists.txt create mode 100644 tests/cpp/comm_gemm/test_comm_gemm.cu create mode 100644 transformer_engine/common/comm_gemm/comm_gemm.cpp create mode 100644 transformer_engine/common/include/transformer_engine/comm_gemm.h diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index cd46b0b63..aa56d69ed 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp cmake -GNinja -Bbuild . cmake --build build export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) -ctest --test-dir build -j$NUM_PARALLEL_JOBS +ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)' diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh new file mode 100755 index 000000000..f4f914b3e --- /dev/null +++ b/qa/L1_cpp_distributed/test.sh @@ -0,0 +1,15 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +set -e + +# Find TE +: ${TE_PATH:=/opt/transformerengine} +TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') +export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH + +cd $TE_PATH/tests/cpp +cmake -GNinja -S. -Bbuild +cmake --build build +mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm diff --git a/setup.py b/setup.py index 0b1b52327..52adaf923 100644 --- a/setup.py +++ b/setup.py @@ -4,6 +4,7 @@ """Installation script.""" +from importlib import metadata import os import time from pathlib import Path @@ -66,6 +67,18 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_BUILD_ACTIVATION_WITH_FAST_MATH", "0"))): cmake_flags.append("-DNVTE_BUILD_ACTIVATION_WITH_FAST_MATH=ON") + if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))): + cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON") + cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution( + "nvidia-cublasmp-cu12" + ).locate_file("nvidia/cublasmp/cu12") + cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") + nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( + "nvidia-nvshmem-cu12" + ).locate_file("nvidia/nvshmem") + cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") + print("CMAKE_FLAGS:", cmake_flags[-2:]) + # Add custom CMake arguments from environment variable nvte_cmake_extra_args = os.getenv("NVTE_CMAKE_EXTRA_ARGS") if nvte_cmake_extra_args: diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index eb2825ba4..412c5d34d 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -37,10 +37,12 @@ find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." ${TE_LIB_ message(STATUS "Found transformer_engine library: ${TE_LIB}") include_directories(../../transformer_engine/common/include) include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) diff --git a/tests/cpp/comm_gemm/CMakeLists.txt b/tests/cpp/comm_gemm/CMakeLists.txt new file mode 100644 index 000000000..55f5207ac --- /dev/null +++ b/tests/cpp/comm_gemm/CMakeLists.txt @@ -0,0 +1,19 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# +# See LICENSE for license information. + +add_executable(test_comm_gemm + test_comm_gemm.cu + ../test_common.cu) + +find_package(OpenMP REQUIRED) +find_package(MPI REQUIRED) +find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) +target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) +target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) + +include(GoogleTest) +gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/comm_gemm/test_comm_gemm.cu b/tests/cpp/comm_gemm/test_comm_gemm.cu new file mode 100644 index 000000000..b34d4db4b --- /dev/null +++ b/tests/cpp/comm_gemm/test_comm_gemm.cu @@ -0,0 +1,441 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "../test_common.h" +#include "common.h" + +using transformer_engine::DType; +using transformer_engine::TypeInfo; + +#define CHECK_MPI(expr) \ + do { \ + int err = (expr); \ + if (err != MPI_SUCCESS) { \ + char err_str[MPI_MAX_ERROR_STRING + 1]{}; \ + int _len{}; \ + MPI_Error_string(err, err_str, &_len); \ + EXPECT_TRUE(false) << "MPI error: " << err << ": " << err_str; \ + } \ + } while (false) + +#define CHECK_NCCL(expr) \ + do { \ + ncclResult_t err = (expr); \ + if (err != ncclSuccess) { \ + EXPECT_TRUE(false) << "NCCL error: " << err << ": " << ncclGetErrorString(err); \ + } \ + } while (false) + +#define CHECK_CU(expr) \ + do { \ + CUresult err = (expr); \ + if (err != CUDA_SUCCESS) { \ + const char* str{}; \ + CUresult e_str = cuGetErrorString(err, &str); \ + if (e_str != CUDA_SUCCESS) str = "(unknown)"; \ + EXPECT_TRUE(false) << "CU error: " << err << ": " << str; \ + } \ + } while (false) + +int main(int argc, char* argv[]) { + ::testing::InitGoogleTest(&argc, argv); + CHECK_MPI(MPI_Init(&argc, &argv)); + auto ret = RUN_ALL_TESTS(); + CHECK_MPI(MPI_Finalize()); + return ret; +} + +bool IsMulticastSupported(int device_id) { + int supported = 0; + CHECK_CU(cuDeviceGetAttribute(&supported, CU_DEVICE_ATTRIBUTE_MULTICAST_SUPPORTED, device_id)); + return supported; +} + +template +std::vector CopyMatrix(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld) { + std::vector ret(msize * nsize); + size_t dst = 0; + for (size_t j = nstart; j < nstart + nsize; ++j) { + for (size_t i = mstart; i < mstart + msize; ++i) { + ret[dst++] = data[j * ld + i]; + } + } + return ret; +} + +template +test::Tensor Make(size_t m, size_t n, float scale) { + test::Tensor ret("", std::vector{n, m}, TypeInfo::dtype); + ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + return ret; +} + +template +test::Tensor MakeFromData(const std::vector& data, size_t mstart, size_t nstart, size_t msize, + size_t nsize, size_t ld, float scale) { + test::Tensor ret("", std::vector{nsize, msize}, TypeInfo::dtype); + ret.set_scale(scale); + ret.set_scale_inv(1.0 / scale); + auto local = CopyMatrix(data, mstart, nstart, msize, nsize, ld); + NVTE_CHECK_CUDA(cudaMemcpy(ret.rowwise_dptr(), local.data(), local.size() * sizeof local[0], + cudaMemcpyDefault)); 
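+  // At this point ret's rowwise device buffer holds the requested msize x nsize sub-block
+  // (column-major, as laid out by CopyMatrix), with scale and scale_inv already populated.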
+ return ret; +} + +template +float GetScale(float amax) { + if constexpr (sizeof(T) > 1) return 1.0; + return static_cast(static_cast(std::numeric_limits::max())) / amax; +} + +struct Params { + DType a_type; + DType b_type; + DType d_type; + bool transa; + bool transb; + size_t m; + size_t n; + size_t k; + float tol; +}; + +class CommGemmFixure : public ::testing::TestWithParam { + protected: + CommGemmFixure() { + CHECK_MPI(MPI_Comm_size(MPI_COMM_WORLD, &nranks_)); + CHECK_MPI(MPI_Comm_rank(MPI_COMM_WORLD, &rank_)); + NVTE_CHECK_CUDA(cudaSetDevice(rank_)); + ncclUniqueId id{}; + if (rank_ == 0) CHECK_NCCL(ncclGetUniqueId(&id)); + CHECK_MPI(MPI_Bcast(&id, sizeof(id), MPI_BYTE, 0, MPI_COMM_WORLD)); + CHECK_NCCL(ncclCommInitRank(&comm_, nranks_, id, rank_)); + ctx_ = nvte_comm_gemm_ctx_create(comm_, nranks_, rank_); + } + ~CommGemmFixure() { + nvte_comm_gemm_ctx_destroy(ctx_); + ncclCommDestroy(comm_); + } + + struct PatternDims { + int64_t a_rows_start; + int64_t a_rows_num; + int64_t a_cols_start; + int64_t a_cols_num; + int64_t b_rows_start; + int64_t b_rows_num; + int64_t b_cols_start; + int64_t b_cols_num; + int64_t d_rows_start; + int64_t d_rows_num; + int64_t d_cols_start; + int64_t d_cols_num; + }; + + virtual PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) = 0; + + virtual void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) = 0; + + template + void Run(bool transa, bool transb, size_t m, size_t n, size_t k, float tol) { + cudaStream_t stream{}; + NVTE_CHECK_CUDA(cudaStreamCreate(&stream)); + + constexpr float MAX_IN = 1.0; + std::mt19937 rng(12); + std::uniform_real_distribution dist(0.0, MAX_IN); + + float a_scale = GetScale(MAX_IN); + float b_scale = GetScale(MAX_IN); + float d_scale = GetScale(MAX_IN * MAX_IN * k); + float bias_scale = GetScale(MAX_IN); + + std::vector adata(m * k); + std::generate(adata.begin(), adata.end(), + [&rng, &dist, a_scale] { return static_cast(dist(rng) * a_scale); }); + std::vector bdata(k * n); + std::generate(bdata.begin(), bdata.end(), + [&rng, &dist, b_scale] { return static_cast(dist(rng) * b_scale); }); + std::vector biasdata(m * n); + std::generate(biasdata.begin(), biasdata.end(), [&rng, &dist, bias_scale] { + return static_cast(dist(rng) * bias_scale); + }); + + auto ga = transa ? MakeFromData(adata, 0, 0, k, m, k, a_scale) + : MakeFromData(adata, 0, 0, m, k, m, a_scale); + auto gb = transb ? MakeFromData(bdata, 0, 0, n, k, n, b_scale) + : MakeFromData(bdata, 0, 0, k, n, k, b_scale); + auto gbias = MakeFromData(biasdata, 0, 0, m, n, m, bias_scale); + auto gd = Make(m, n, d_scale); + auto gaux = Make(m, n, d_scale); + + auto dims = DistributeTensors(m, n, k); + auto a = transa ? MakeFromData(adata, dims.a_rows_start, dims.a_cols_start, + dims.a_rows_num, dims.a_cols_num, k, a_scale) + : MakeFromData(adata, dims.a_cols_start, dims.a_rows_start, + dims.a_cols_num, dims.a_rows_num, m, a_scale); + auto b = transb ? 
MakeFromData(bdata, dims.b_cols_start, dims.b_rows_start, + dims.b_cols_num, dims.b_rows_num, n, b_scale) + : MakeFromData(bdata, dims.b_rows_start, dims.b_cols_start, + dims.b_rows_num, dims.b_cols_num, k, b_scale); + auto bias = MakeFromData(biasdata, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m, bias_scale); + auto d = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + auto aux = Make(dims.d_rows_num, dims.d_cols_num, d_scale); + + bool grad = false; + bool accumulate = false; + CommGemm(m, n, k, a.data(), b.data(), d.data(), bias.data(), aux.data(), transa, transb, grad, + accumulate, 0 /*comm_sm_count*/, stream); + auto workspace = Make(1, 32 << 20, 1.0); + nvte_cublas_gemm(ga.data(), gb.data(), gd.data(), gbias.data(), gaux.data(), transa, transb, + grad, workspace.data(), accumulate, false /* use_split_accumulator */, + 0 /* math_sm_count */, stream); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); + NVTE_CHECK_CUDA(cudaStreamDestroy(stream)); + std::vector out(dims.d_rows_num * dims.d_cols_num); + NVTE_CHECK_CUDA( + cudaMemcpy(out.data(), d.rowwise_dptr(), out.size() * sizeof out[0], cudaMemcpyDefault)); + std::vector out_golden_global(m * n); + NVTE_CHECK_CUDA(cudaMemcpy(out_golden_global.data(), gd.rowwise_dptr(), + out_golden_global.size() * sizeof out_golden_global[0], + cudaMemcpyDefault)); + + auto out_golden = CopyMatrix(out_golden_global, dims.d_rows_start, dims.d_cols_start, + dims.d_rows_num, dims.d_cols_num, m); + NVTE_CHECK(out.size() == out_golden.size()); + for (size_t i = 0; i < out.size(); ++i) { + EXPECT_NEAR(static_cast(out[i]), static_cast(out_golden[i]), tol * k); + } + } + + NVTECommGemmCtx* ctx_{}; + int nranks_{}; + int rank_{}; + ncclComm_t comm_{}; +}; + +struct AgGemm : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto a_cols_num = nvte_comm_gemm_numroc(ctx_, m); + auto b_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t a_cols_start{}; + int64_t b_cols_start{}; + MPI_Exscan(&a_cols_num, &a_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&b_cols_num, &b_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = 0, + .a_rows_num = k, + .a_cols_start = a_cols_start, + .a_cols_num = a_cols_num, + .b_rows_start = 0, + .b_rows_num = k, + .b_cols_start = b_cols_start, + .b_cols_num = b_cols_num, + .d_rows_start = a_cols_start, + .d_rows_num = a_cols_num, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_all_gather_gemm(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmRs : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + auto d_cols_num = nvte_comm_gemm_numroc(ctx_, n); + + int64_t rows_start{}; + int64_t d_cols_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + MPI_Exscan(&d_cols_num, &d_cols_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + 
.b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = d_cols_start, + .d_cols_num = d_cols_num, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_gemm_reduce_scatter(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } +}; + +struct GemmAr : public CommGemmFixure { + PatternDims DistributeTensors(int64_t m, int64_t n, int64_t k) override { + auto rows_num = nvte_comm_gemm_numroc(ctx_, k); + + int64_t rows_start{}; + MPI_Exscan(&rows_num, &rows_start, 1, MPI_INT64_T, MPI_SUM, MPI_COMM_WORLD); + + return PatternDims{ + .a_rows_start = rows_start, + .a_rows_num = rows_num, + .a_cols_start = 0, + .a_cols_num = m, + .b_rows_start = rows_start, + .b_rows_num = rows_num, + .b_cols_start = 0, + .b_cols_num = n, + .d_rows_start = 0, + .d_rows_num = m, + .d_cols_start = 0, + .d_cols_num = n, + }; + } + + void CommGemm(int64_t m, int64_t n, int64_t k, const NVTETensor a, const NVTETensor b, + const NVTETensor d, const NVTETensor bias, const NVTETensor pre_act_out, + bool transa, bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t stream) override { + nvte_gemm_all_reduce(ctx_, m, n, k, a, b, d, bias, pre_act_out, transa, transb, grad, + accumulate, comm_sm_count, stream, kNVTECommGemmAlgoDefault); + } + + void SetUp() override { + if (!IsMulticastSupported(rank_)) + GTEST_SKIP() << "Multicast is not supported on device " << rank_; + } +}; + +TEST_P(AgGemm, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmRs, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +TEST_P(GemmAr, Gemm) { + auto [a_type, b_type, d_type, transa, transb, m, n, k, tol] = GetParam(); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + a_type, AType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + b_type, BType, + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT( + d_type, DType, Run(transa, transb, m, n, k, tol);))); +} + +std::string ParamSuffix(const testing::TestParamInfo& info) { + const auto [a_type, b_type, d_type, transa, transb, m, n, k, _tol] = info.param; + std::ostringstream ss; + ss << static_cast(a_type) << "_" << static_cast(b_type) << "_" + << static_cast(d_type) << "_" << (transa ? "T" : "N") << (transb ? 
"T" : "N") << "_" << m + << "x" << n << "x" << k; + return ss.str(); +} + +INSTANTIATE_TEST_SUITE_P(AgGemm, AgGemm, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 256, 128, 64, 1e-3}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, false, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 256, 128, 64, 1e-3}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 256, 128, 64, 1e-3}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P(GemmRs, GemmRs, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + false, true, 64, 128, 256, 5e-2}, + Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, + true, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, false, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, false, true, 64, 128, 256, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, + DType::kBFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, + DType::kFloat16, true, false, 64, 128, 256, 5e-2}), + &ParamSuffix); + +INSTANTIATE_TEST_SUITE_P( + GemmAr, GemmAr, + testing::Values(Params{DType::kFloat16, DType::kFloat16, DType::kFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kBFloat16, DType::kBFloat16, DType::kBFloat16, true, false, 64, + 64 * 4, 64 * 4, 5e-2}, + Params{DType::kFloat8E5M2, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E5M2, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}, + Params{DType::kFloat8E4M3, DType::kFloat8E4M3, DType::kFloat16, true, false, + 128, 128 * 4, 128 * 4, 5e-2}), + &ParamSuffix); diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index b51e61929..183a7a72e 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -110,6 +110,12 @@ list(APPEND transformer_engine_SOURCES comm_gemm_overlap/userbuffers/userbuffers-host.cpp comm_gemm_overlap/userbuffers/userbuffers.cu comm_gemm_overlap/comm_gemm_overlap.cpp) + +if (NVTE_WITH_CUBLASMP) +list(APPEND transformer_engine_SOURCES + comm_gemm/comm_gemm.cpp) +endif() + add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") @@ -123,6 +129,8 @@ target_link_libraries(transformer_engine PUBLIC CUDNN::cudnn_all) target_include_directories(transformer_engine PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) +target_include_directories(transformer_engine 
SYSTEM PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI @@ -141,6 +149,25 @@ if (NVTE_ENABLE_NVSHMEM) target_include_directories(transformer_engine PUBLIC ${NVSHMEMAPI_INCLUDE_DIR}) endif() +option(NVTE_WITH_CUBLASMP "Use cuBLASMp for tensor parallel GEMMs" OFF) +if (NVTE_WITH_CUBLASMP) + target_compile_definitions(transformer_engine PRIVATE NVTE_WITH_CUBLASMP) + target_include_directories(transformer_engine PRIVATE ${CUBLASMP_DIR}/include ${NVSHMEM_DIR}/include) + find_library(CUBLASMP_LIB + NAMES cublasmp libcublasmp + PATHS ${CUBLASMP_DIR} + PATH_SUFFIXES lib + REQUIRED) + find_library(NVSHMEM_HOST_LIB + NAMES nvshmem_host libnvshmem_host.so.3 + PATHS ${NVSHMEM_DIR} + PATH_SUFFIXES lib + REQUIRED) + target_link_libraries(transformer_engine PUBLIC ${CUBLASMP_LIB} ${NVSHMEM_HOST_LIB}) + message(STATUS "Using cuBLASMp at: ${CUBLASMP_DIR}") + message(STATUS "Using nvshmem at: ${NVSHMEM_DIR}") +endif() + # Hack to enable dynamic loading in cuDNN frontend target_compile_definitions(transformer_engine PUBLIC NV_CUDNN_FRONTEND_USE_DYNAMIC_LOADING) diff --git a/transformer_engine/common/comm_gemm/comm_gemm.cpp b/transformer_engine/common/comm_gemm/comm_gemm.cpp new file mode 100644 index 000000000..76f46298d --- /dev/null +++ b/transformer_engine/common/comm_gemm/comm_gemm.cpp @@ -0,0 +1,519 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include "transformer_engine/comm_gemm.h" + +#include +#include +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "../common.h" +#include "../util/logging.h" + +using namespace transformer_engine; + +namespace { + +// TODO: log warnings on failures of the *Destroy calls below, once TE has such ability. +// For now, just silently ignoring the errors, since the only diag available in TE is throwing +// exceptions, but these calls will typically be made from destructors, so cannot throw. + +template +auto CreateWithCudaCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + NVTE_CHECK_CUDA(create_fn(&raw, std::forward(args)...)); + return std::unique_ptr(raw, destroy_fn); +} + +using CudaStream = + std::unique_ptr, decltype(&cudaStreamDestroy)>; + +CudaStream CudaStreamCreate() { + return CreateWithCudaCheck(cudaStreamCreate, cudaStreamDestroy); +} + +using CudaEvent = std::unique_ptr, decltype(&cudaEventDestroy)>; + +CudaEvent CudaEventCreate(unsigned flags) { + return CreateWithCudaCheck(cudaEventCreateWithFlags, cudaEventDestroy, flags); +} + +template +auto CreateWithCublasMpCheck(CreateFn create_fn, DestroyFn destroy_fn, Args&&... 
args) { + using Handle = std::remove_pointer_t; + HandlePtr raw{}; + if constexpr (raw_last) { + NVTE_CHECK_CUBLASMP(create_fn(std::forward(args)..., &raw)); + } else { + NVTE_CHECK_CUBLASMP(create_fn(&raw, std::forward(args)...)); + } + return std::unique_ptr(raw, destroy_fn); +} + +using CublasMp = + std::unique_ptr, decltype(&cublasMpDestroy)>; + +CublasMp CublasMpCreate(cudaStream_t stream) { + return CreateWithCublasMpCheck(cublasMpCreate, cublasMpDestroy, stream); +} + +using CublasMpGrid = + std::unique_ptr, decltype(&cublasMpGridDestroy)>; + +CublasMpGrid CublasMpGridCreate(int64_t nprow, int64_t npcol, cublasMpGridLayout_t layout, + ncclComm_t comm) { + return CreateWithCublasMpCheck(cublasMpGridCreate, cublasMpGridDestroy, + nprow, npcol, layout, comm); +} + +using CublasMpMatrixDesc = std::unique_ptr, + decltype(&cublasMpMatrixDescriptorDestroy)>; + +CublasMpMatrixDesc CublasMpMatrixDescCreate(int64_t m, int64_t n, int64_t mb, int64_t nb, + int64_t rsrc, int64_t csrc, int64_t lld, + cudaDataType_t type, cublasMpGrid_t grid) { + return CreateWithCublasMpCheck( + cublasMpMatrixDescriptorCreate, cublasMpMatrixDescriptorDestroy, m, n, mb, nb, rsrc, csrc, + lld, type, grid); +} + +using CublasMpMatmulDesc = std::unique_ptr, + decltype(&cublasMpMatmulDescriptorDestroy)>; + +CublasMpMatmulDesc CublasMpMatmulDescCreate(cublasComputeType_t compute_type) { + return CreateWithCublasMpCheck( + cublasMpMatmulDescriptorCreate, cublasMpMatmulDescriptorDestroy, compute_type); +} + +} // namespace + +struct NVTECommGemmCtx { + int64_t nranks; + int64_t rank; + ncclComm_t comm; + CudaStream stream; + CudaEvent event; + CublasMp cublas_mp; + CublasMpGrid grid_col_major; + CublasMpGrid grid_row_major; + CublasMpMatrixDesc a_desc; + CublasMpMatrixDesc b_desc; + CublasMpMatrixDesc d_desc; + CublasMpMatmulDesc matmul_desc; + void* workspace; + size_t workspace_size; +}; + +namespace { + +int64_t block_size(NVTECommGemmCtx* ctx, int64_t global_size) { + // Use non-cyclic layout to maximize opportunity for comm overlap. 
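+  // Ceil-divide so every rank owns one contiguous block of at most this size; e.g. with
+  // global_size=130 and nranks=4 the block size is 33 (ranks 0-2 hold 33, rank 3 holds 31).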
+ return (global_size + ctx->nranks - 1) / ctx->nranks; +} + +void AgGemmInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a1 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, k, block_size(ctx, m), 0, 0, k, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a0 == k, "Unsupported tensor dimension in A: expected ", k, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, block_size(ctx, m), k, 0, 0, + block_size(ctx, m), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b0 == k, "Unsupported tensor dimensionin B: expected ", k, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(n, k, block_size(ctx, n), k, 0, 0, + block_size(ctx, n), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b1 == k, "Unsupported tensor dimension in B: expected ", k, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, k, block_size(ctx, n), 0, 0, k, + get_cuda_dtype(b->dtype()), + ctx->grid_row_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d0 == n, "Unsupported tensor dimension in D: expected ", n, ", got ", d0); + *ldd = block_size(ctx, m); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, block_size(ctx, m), block_size(ctx, n), 0, + 0, *ldd, get_cuda_dtype(d->dtype()), + ctx->grid_col_major.get(), ctx->d_desc.get())); +} + +void GemmRsInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_CHECK(a1 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, k, m, block_size(ctx, k), 0, 0, m, + get_cuda_dtype(a->dtype()), + ctx->grid_row_major.get(), ctx->a_desc.get())); + } + if (transb) { + NVTE_CHECK(b1 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b1); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + n, k, block_size(ctx, n), block_size(ctx, k), 0, 0, block_size(ctx, n), + get_cuda_dtype(b->dtype()), ctx->grid_row_major.get(), ctx->b_desc.get())); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit( + k, n, block_size(ctx, k), block_size(ctx, n), 0, 0, block_size(ctx, k), + get_cuda_dtype(b->dtype()), ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: 
expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n, m, block_size(ctx, n), 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); +} + +void GemmArInitMatrices(NVTECommGemmCtx* ctx, int64_t* ldd, int64_t m, int64_t n, int64_t k, + const Tensor* a, const Tensor* b, const Tensor* d, bool transa, + bool transb) { + const auto a0 = a->flat_first_dim(); + const auto a1 = a->flat_last_dim(); + const auto b0 = b->flat_first_dim(); + const auto b1 = b->flat_last_dim(); + const auto d0 = d->flat_first_dim(); + const auto d1 = d->flat_last_dim(); + + if (transa) { + NVTE_CHECK(a0 == m, "Unsupported tensor dimension in A: expected ", m, ", got ", a0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, m, block_size(ctx, k), m, 0, 0, + block_size(ctx, k), get_cuda_dtype(a->dtype()), + ctx->grid_col_major.get(), ctx->a_desc.get())); + } else { + NVTE_ERROR("N transpose flag is not supported for input A"); + } + if (transb) { + NVTE_ERROR("T transpose flag is not supported for input B"); + } else { + NVTE_CHECK(b0 == n, "Unsupported tensor dimension in B: expected ", n, ", got ", b0); + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(k, n, block_size(ctx, k), n, 0, 0, + block_size(ctx, k), get_cuda_dtype(b->dtype()), + ctx->grid_col_major.get(), ctx->b_desc.get())); + } + NVTE_CHECK(d1 == m, "Unsupported tensor dimension in D: expected ", m, ", got ", d1); + *ldd = m; + NVTE_CHECK_CUBLASMP(cublasMpMatrixDescriptorInit(m, n * ctx->nranks, m, n, 0, 0, *ldd, + get_cuda_dtype(d->dtype()), + ctx->grid_row_major.get(), ctx->d_desc.get())); + + const cublasMpMatmulEpilogue_t epilogue = CUBLASMP_MATMUL_EPILOGUE_ALLREDUCE; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); +} + +using InitMatricesFn = void (*)(NVTECommGemmCtx*, int64_t*, int64_t, int64_t, int64_t, + const Tensor*, const Tensor*, const Tensor*, bool, bool); + +cublasMpMatmulAlgoType_t cublasmp_algo(NVTECommGemmAlgoType algo) { + static const std::unordered_map s_map{ + {kNVTECommGemmAlgoDefault, CUBLASMP_MATMUL_ALGO_TYPE_DEFAULT}, + {kNVTECommGemmAlgoSplitP2P, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_P2P}, + {kNVTECommGemmAlgoSplitMulticast, CUBLASMP_MATMUL_ALGO_TYPE_SPLIT_MULTICAST}, + {kNVTECommGemmAlgoAtomicP2P, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_P2P}, + {kNVTECommGemmAlgoAtomicMulticast, CUBLASMP_MATMUL_ALGO_TYPE_ATOMIC_MULTICAST}, + }; + auto it = s_map.find(algo); + return it != s_map.end() ? it->second : static_cast(algo); +} + +void cublasmp_gemm(InitMatricesFn init_matrices_fn, NVTECommGemmCtx* ctx, NVTECommGemmAlgoType algo, + int64_t m, int64_t n, int64_t k, const Tensor* a, const Tensor* b, + const Tensor* d, const Tensor* bias, const Tensor* pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream) { + for (auto t : {a, b, d}) { + NVTE_CHECK(is_tensor_scaling(t->scaling_mode), + "Unsupported scaling mode: " + std::to_string(t->scaling_mode)); + } + + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorInit(ctx->matmul_desc.get(), CUBLAS_COMPUTE_32F)); + + int64_t ldd{}; + init_matrices_fn(ctx, &ldd, m, n, k, a, b, d, transa, transb); + + const cublasOperation_t trans_a = transa ? CUBLAS_OP_T : CUBLAS_OP_N; + const cublasOperation_t trans_b = transb ? 
CUBLAS_OP_T : CUBLAS_OP_N; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSA, &trans_a, + sizeof trans_a)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_TRANSB, &trans_b, + sizeof trans_b)); + cublasMpMatmulAlgoType_t algo_attr = cublasmp_algo(algo); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_ALGO_TYPE, &algo_attr, + sizeof algo_attr)); + + const cublasMpMatmulMatrixScale_t scale_mode = CUBLASMP_MATMUL_MATRIX_SCALE_SCALAR_FP32; + if (is_fp8_dtype(a->dtype())) { + NVTE_CHECK(a->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_A_SCALE_POINTER, + &a->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(b->dtype())) { + NVTE_CHECK(b->scale_inv.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_B_SCALE_POINTER, + &b->scale_inv.dptr, sizeof(void*))); + } + if (is_fp8_dtype(d->dtype())) { + NVTE_CHECK(d->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_MODE, &scale_mode, + sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_D_SCALE_POINTER, + &d->scale.dptr, sizeof(void*))); + if (d->amax.dptr) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_AMAX_D_POINTER, + &d->amax.dptr, sizeof(void*))); + } + } + + // Might be set to ALLREDUCE before, need to OR with the new flags to set. + cublasMpMatmulEpilogue_t epilogue{}; + size_t size_read{}; + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeGet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue, &size_read)); + NVTE_CHECK(size_read == sizeof epilogue); + // (bias, gelu, grad) -> epilogue + const std::map, cublasMpMatmulEpilogue_t> flags_to_epilogue{ + {{true, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX_BIAS}, + {{true, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU_BGRAD}, + {{true, false, false}, CUBLASMP_MATMUL_EPILOGUE_BIAS}, + {{true, false, true}, CUBLASMP_MATMUL_EPILOGUE_BGRADB}, + {{false, true, false}, CUBLASMP_MATMUL_EPILOGUE_GELU_AUX}, + {{false, true, true}, CUBLASMP_MATMUL_EPILOGUE_DGELU}, + }; + if (auto it = + flags_to_epilogue.find({bias ? bias->data.dptr != nullptr : false, + pre_act_out ? 
pre_act_out->data.dptr != nullptr : false, grad}); + it != flags_to_epilogue.end()) { + epilogue = static_cast(epilogue | it->second); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE, &epilogue, + sizeof epilogue)); + } + + if (bias && bias->data.dptr) { + cudaDataType_t bias_type = get_cuda_dtype(bias->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_DATA_TYPE, &bias_type, + sizeof bias_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_BIAS_POINTER, &bias->data.dptr, + sizeof bias->data.dptr)); + } + + if (pre_act_out && pre_act_out->data.dptr) { + cudaDataType_t aux_type = get_cuda_dtype(pre_act_out->data.dtype); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_DATA_TYPE, + &aux_type, sizeof aux_type)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_POINTER, + &pre_act_out->data.dptr, sizeof pre_act_out->data.dptr)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_LD, &ldd, + sizeof ldd)); + if (is_fp8_dtype(pre_act_out->dtype())) { + NVTE_CHECK(pre_act_out->scale.dptr, "Scaling must be set for FP8 dtype"); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_MODE, + &scale_mode, sizeof scale_mode)); + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_SCALE_POINTER, + &pre_act_out->scale.dptr, sizeof(void*))); + if (pre_act_out->amax.dptr) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_EPILOGUE_AUX_AMAX_POINTER, + &pre_act_out->amax.dptr, sizeof(void*))); + } + } + } + + if (comm_sm_count) { + NVTE_CHECK_CUBLASMP(cublasMpMatmulDescriptorAttributeSet( + ctx->matmul_desc.get(), CUBLASMP_MATMUL_DESCRIPTOR_ATTRIBUTE_COMMUNICATION_SM_COUNT, + &comm_sm_count, sizeof comm_sm_count)); + } + + NVTE_CHECK_CUBLASMP(cublasMpStreamSet(ctx->cublas_mp.get(), main_stream)); + + size_t wrksp_size_device{}; + size_t wrksp_size_host{}; + + float alpha = 1.0; + float beta = accumulate ? 1.0 : 0.0; + std::tuple args{ctx->cublas_mp.get(), + ctx->matmul_desc.get(), + m, + n, + k, + &alpha, + a->data.dptr, + 1, + 1, + ctx->a_desc.get(), + b->data.dptr, + 1, + 1, + ctx->b_desc.get(), + &beta, + accumulate ? d->data.dptr : nullptr, + 1, + 1, + accumulate ? 
ctx->d_desc.get() : nullptr, + d->data.dptr, + 1, + 1, + ctx->d_desc.get()}; + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul_bufferSize, + std::tuple_cat(args, std::tuple{&wrksp_size_device, &wrksp_size_host}))); + + std::vector workspace_host(wrksp_size_host); + if (ctx->workspace_size < wrksp_size_device) { + nvshmem_free(ctx->workspace); + ctx->workspace = nvshmem_malloc(wrksp_size_device); + ctx->workspace_size = wrksp_size_device; + } + + NVTE_CHECK_CUBLASMP( + std::apply(cublasMpMatmul, + std::tuple_cat(args, std::tuple{ctx->workspace, ctx->workspace_size, + workspace_host.data(), workspace_host.size()}))); + + NVTE_CHECK_CUDA(cudaEventRecord(ctx->event.get(), main_stream)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(ctx->stream.get(), ctx->event.get(), 0)); +} + +} // namespace + +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank) { + NVTE_API_CALL(nvte_comm_gemm_ctx_create); + auto stream = CudaStreamCreate(); + auto event = CudaEventCreate(cudaEventDisableTiming); + auto cublas_mp = CublasMpCreate(stream.get()); + + auto col_major = CublasMpGridCreate(nranks, 1, CUBLASMP_GRID_LAYOUT_COL_MAJOR, comm); + auto row_major = CublasMpGridCreate(1, nranks, CUBLASMP_GRID_LAYOUT_ROW_MAJOR, comm); + + // Pre-creating matrix descriptors here, will be initialized with the actual params later. + auto a_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto b_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + auto d_desc = CublasMpMatrixDescCreate(1, 1, 1, 1, 0, 0, 1, CUDA_R_16F, row_major.get()); + + auto matmul_desc = CublasMpMatmulDescCreate(CUBLAS_COMPUTE_32F); + + return new NVTECommGemmCtx{ + .nranks = nranks, + .rank = rank, + .comm = comm, + .stream = std::move(stream), + .event = std::move(event), + .cublas_mp = std::move(cublas_mp), + .grid_col_major = std::move(col_major), + .grid_row_major = std::move(row_major), + .a_desc = std::move(a_desc), + .b_desc = std::move(b_desc), + .d_desc = std::move(d_desc), + .matmul_desc = std::move(matmul_desc), + }; +} + +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx) { + NVTE_API_CALL(nvte_comm_gemm_ctx_destroy); + nvshmemx_sync_all_on_stream(ctx->stream.get()); + delete ctx; +} + +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_all_gather_gemm); + cublasmp_gemm(AgGemmInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_reduce_scatter); + cublasmp_gemm(GemmRsInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, 
main_stream); +} + +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo) { + NVTE_API_CALL(nvte_gemm_all_reduce); + cublasmp_gemm(GemmArInitMatrices, ctx, algo, m, n, k, convertNVTETensorCheck(a), + convertNVTETensorCheck(b), convertNVTETensorCheck(d), convertNVTETensorCheck(bias), + convertNVTETensorCheck(pre_act_out), transa, transb, grad, accumulate, + comm_sm_count, main_stream); +} + +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size) { + NVTE_API_CALL(nvte_comm_gemm_numroc); + return cublasMpNumroc(global_size, block_size(ctx, global_size), ctx->rank, 0, ctx->nranks); +} diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 4e697979d..a810fb471 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -26,6 +26,24 @@ __global__ void __launch_bounds__(1) } // namespace +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { + using namespace transformer_engine; + switch (t) { + case DType::kFloat16: + return CUDA_R_16F; + case DType::kFloat32: + return CUDA_R_32F; + case DType::kBFloat16: + return CUDA_R_16BF; + case DType::kFloat8E4M3: + return CUDA_R_8F_E4M3; + case DType::kFloat8E5M2: + return CUDA_R_8F_E5M2; + default: + NVTE_ERROR("Invalid type"); + } +} + void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { NVTE_CHECK(t->scale_inv.dptr != nullptr, "Tensor should have allocated scale_inv."); diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index aa47f2c3d..e2a3c52aa 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -270,6 +270,8 @@ struct QuantizationConfig { }; }; +cudaDataType_t get_cuda_dtype(const transformer_engine::DType t); + template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); @@ -382,9 +384,19 @@ struct BitsNumber { template struct TypeInfo { #if FP4_TYPE_SUPPORTED - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #else - using types = std::tuple; + using types = std::tuple= 12080 + , + fp8e8m0 +#endif + >; #endif template diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index d65cd7b55..9e6c5417b 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -22,24 +22,6 @@ namespace { -cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { - using namespace transformer_engine; - switch (t) { - case DType::kFloat16: - return CUDA_R_16F; - case DType::kFloat32: - return CUDA_R_32F; - case DType::kBFloat16: - return CUDA_R_16BF; - case DType::kFloat8E4M3: - return CUDA_R_8F_E4M3; - case DType::kFloat8E5M2: - return CUDA_R_8F_E5M2; - default: - NVTE_ERROR("Invalid type"); - } -} - uint32_t _getAlignment(uintptr_t address) { // alignment are in bytes uint32_t alignment = 256; diff --git a/transformer_engine/common/include/transformer_engine/comm_gemm.h b/transformer_engine/common/include/transformer_engine/comm_gemm.h new file mode 100644 index 000000000..14cf56a00 --- /dev/null +++ 
b/transformer_engine/common/include/transformer_engine/comm_gemm.h @@ -0,0 +1,156 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file comm_gemm.h + * \brief Functions for distributed (multi-GPU) matrix multiplication. + * + * This API is a TE-native binding to cuBLASMp library. + * Refer here: https://docs.nvidia.com/cuda/cublasmp/usage/tp.html for specific + * patterns, which allow communication-computation overlap. + * + * All GEMM functions here have the same computation semantic, as expressed + * on global matrices, similar to nvte_cublas_gemm call: + * - `D = AB` if both `bias` and `pre_gelu_out` are empty tensors + * - `D = AB + bias` if `pre_gelu_out` is empty and `bias` is not empty + * - `D = GELU(AB + bias)` if both `bias` and `pre_gelu_out` are not empty tensors + * + * Functions differ in matrix distribution patterns + */ + +#ifndef TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ +#define TRANSFORMER_ENGINE_COMMON_COMM_GEMM_H_ + +#include +#include + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#else +#include +#endif + +typedef struct NVTECommGemmCtx NVTECommGemmCtx; + +enum NVTECommGemmAlgoType { + kNVTECommGemmAlgoDefault = 0, + kNVTECommGemmAlgoSplitP2P = 1, + kNVTECommGemmAlgoSplitMulticast = 2, + kNVTECommGemmAlgoAtomicP2P = 3, + kNVTECommGemmAlgoAtomicMulticast = 4 +}; + +/*! \brief Create a comm-gemm context. + * + * \param[in] comm NCCL communicator. + * \param[in] nranks Number of ranks. + * \param[in] rank Local rank. + */ +NVTECommGemmCtx* nvte_comm_gemm_ctx_create(ncclComm_t comm, int nranks, int rank); + +/*! \brief Destroy a comm-gemm context. + * + * \param[in] ctx Context to destroy. + */ +void nvte_comm_gemm_ctx_destroy(NVTECommGemmCtx* ctx); + +/*! \brief Perform AllGather communication followed by GEMM + * + * Gathers distributed data from all ranks, then computes matrix multiplication. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_all_gather_gemm(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by ReduceScatter communication + * + * Computes matrix multiplication, then distributes results across ranks with reduction. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. 
+ * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_gemm_reduce_scatter(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, + const NVTETensor a, const NVTETensor b, const NVTETensor d, + const NVTETensor bias, const NVTETensor pre_act_out, bool transa, + bool transb, bool grad, bool accumulate, int comm_sm_count, + cudaStream_t main_stream, NVTECommGemmAlgoType algo); + +/*! \brief Perform GEMM followed by AllReduce communication + * + * Computes matrix multiplication, then reduces results across all ranks. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] m Global m dimension. + * \param[in] n Global n dimension. + * \param[in] k Global k dimension. + * \param[in] a Local part of A matrix. + * \param[in] b Local part of B matrix. + * \param[in,out] d Local part of D matrix. + * \param[in] bias Bias tensor. + * \param[in,out] pre_act_out Local part of output matrix before GELU activation. + * \param[in] transa Whether A matrix is transposed. + * \param[in] transb Whether B matrix is transposed. + * \param[in] grad Whether this operation is part of gradient computation. + * \param[in] accumulate Whether to accumulate the result into the D matrix. + * \param[in] comm_sm_count Number of GPU SMs to use for communication (default=0: use heuristics) + * \param[in] main_stream CUDA stream used for computation. + * \param[in] algo Algorithm to use. + */ +void nvte_gemm_all_reduce(NVTECommGemmCtx* ctx, int64_t m, int64_t n, int64_t k, const NVTETensor a, + const NVTETensor b, const NVTETensor d, const NVTETensor bias, + const NVTETensor pre_act_out, bool transa, bool transb, bool grad, + bool accumulate, int comm_sm_count, cudaStream_t main_stream, + NVTECommGemmAlgoType algo); + +/*! \brief Get local number of rows or columns. + * + * Utility function to get local dimension. + * Block size, nranks and local rank is derived from the context ctx. + * + * \param[in] ctx Comm-GEMM context. + * \param[in] global_size Global dimension. 
+ */ +int64_t nvte_comm_gemm_numroc(NVTECommGemmCtx* ctx, int64_t global_size); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif // TRANSFORMER_ENGINE_COMM_GEMM_H_ diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 173aad52a..941899b28 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -12,8 +12,13 @@ #include #include +#ifdef NVTE_WITH_CUBLASMP +#include +#endif // NVTE_WITH_CUBLASMP + #include #include +#include #include "../util/string.h" @@ -87,4 +92,16 @@ } \ } while (false) +#ifdef NVTE_WITH_CUBLASMP + +#define NVTE_CHECK_CUBLASMP(expr) \ + do { \ + const cublasMpStatus_t status = (expr); \ + if (status != CUBLASMP_STATUS_SUCCESS) { \ + NVTE_ERROR("cuBLASMp Error: ", std::to_string(status)); \ + } \ + } while (false) + +#endif // NVTE_WITH_CUBLASMP + #endif // TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ From 62a57dd45ad8ec02943214059917ff94b644ae35 Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Wed, 27 Aug 2025 10:29:48 -0400 Subject: [PATCH 105/153] FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen --- qa/L1_jax_distributed_unittest/test.sh | 1 + tests/jax/multi_process_launch.sh | 23 +++ ..._multi_process_distributed_grouped_gemm.py | 164 ++++++++++++++++++ .../jax/cpp_extensions/quantization.py | 7 +- .../jax/csrc/extensions/gemm.cpp | 5 +- transformer_engine/jax/dense.py | 131 +++++++++++++- 6 files changed, 320 insertions(+), 11 deletions(-) create mode 100644 tests/jax/multi_process_launch.sh create mode 100644 tests/jax/test_multi_process_distributed_grouped_gemm.py diff --git a/qa/L1_jax_distributed_unittest/test.sh b/qa/L1_jax_distributed_unittest/test.sh index f332e32e8..8ecc5a917 100644 --- a/qa/L1_jax_distributed_unittest/test.sh +++ b/qa/L1_jax_distributed_unittest/test.sh @@ -9,3 +9,4 @@ set -xe mkdir -p "$XML_LOG_DIR" NVTE_JAX_UNITTEST_LEVEL="L1" python3 -m pytest -c $TE_PATH/tests/jax/pytest.ini -v --junitxml=$XML_LOG_DIR/pytest.xml $TE_PATH/tests/jax/test_distributed_* +SCRIPT_NAME=test_multi_process_distributed_grouped_gemm.py bash $TE_PATH/tests/jax/multi_process_launch.sh diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh new file mode 100644 index 000000000..3e0852f39 --- /dev/null +++ b/tests/jax/multi_process_launch.sh @@ -0,0 +1,23 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
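+
+# This launcher runs SCRIPT_NAME once per visible GPU: every rank other than 0
+# is started in the background with its own CUDA_VISIBLE_DEVICES and its output
+# discarded, while rank 0 runs in the foreground on GPU 0. Each process receives
+# the coordinator address (127.0.0.1:12345), its process index, and the total
+# process count, which the test script passes to jax.distributed.initialize().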
+ +#!/bin/bash + +SCRIPT_NAME="${SCRIPT_NAME:-test.py}" + + +XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true + --xla_gpu_enable_command_buffer=''" + +export XLA_FLAGS="${XLA_BASE_FLAGS}" + +NUM_RUNS=$(nvidia-smi --query-gpu=count --format=csv,noheader) +for ((i=1; i /dev/null 2>&1 & +done + +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC + +wait diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py new file mode 100644 index 000000000..6fce62d8c --- /dev/null +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -0,0 +1,164 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from functools import partial + +import jax +import jax.numpy as jnp + +from transformer_engine.jax.dense import grouped_dense as te_grouped_dense +from transformer_engine.jax.quantize import ( + QuantizerFactory, + ScalingMode, +) + +from utils import assert_allclose + + +N_GROUP = 8 +MESH_AXIS_NAME = "fsdp" + + +def test_grouped_gemm_fp8_allgather(data_shapes, kernel_fsdp_axis): + assert kernel_fsdp_axis in [1, 2] + x_shape, w_shape = data_shapes + + x_sharding = NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None, None, None)) + w_sharding = ( + NamedSharding(mesh, PartitionSpec(None, None, MESH_AXIS_NAME)) + if kernel_fsdp_axis == 2 + else NamedSharding(mesh, PartitionSpec(None, MESH_AXIS_NAME, None)) + ) + w_no_sharding = NamedSharding(mesh, PartitionSpec(None, None, None)) + + def init_data(): + x_key = jax.random.PRNGKey(0) + w_key = jax.random.PRNGKey(1) + x = jax.random.normal(x_key, shape=(N_GROUP, *x_shape), dtype=jnp.bfloat16) + w = jax.random.normal(w_key, shape=(N_GROUP, *w_shape), dtype=jnp.bfloat16) + w_amax = jnp.max(jnp.abs(w), axis=range(1, w.ndim)) + return x, w, w, w_amax + + def test_func(outter_x, outter_w, outter_w_amax): + in_specs = (x_sharding.spec, w_sharding.spec, None) + out_specs = x_sharding.spec + + @partial( + shard_map.shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + def sharded_group_gemm(x, w, w_amax): + group_size = x.shape[0] + x_reshaped = x.reshape(-1, x.shape[-1]) + n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + is_2x2x=True, + n_groups=group_size, + ) + + output = te_grouped_dense( + x_reshaped, + w, + n_groups, + kernel_amax=w_amax, + quantizer_set=quantizer_set, + kernel_fsdp_info=(MESH_AXIS_NAME, kernel_fsdp_axis), + ) + output = output.reshape(*x.shape[:-1], -1) + return output + + def run(x, w, w_amax): + output = sharded_group_gemm(x, w, w_amax) + return output + + output, vjp_fn = jax.vjp(run, outter_x, outter_w, outter_w_amax) + dx, dw, _ = vjp_fn(output) + return output, dx, dw + + def ref_func(outter_x, outter_w): + + in_specs = (x_sharding.spec, w_no_sharding.spec) + out_specs = x_sharding.spec + + @partial( + shard_map.shard_map, + mesh=mesh, + in_specs=in_specs, + out_specs=out_specs, + check_rep=False, + ) + def sharded_group_gemm(x, w): + group_size = x.shape[0] + x_reshaped = x.reshape(-1, x.shape[-1]) + n_groups = jnp.full(group_size, x_reshaped.shape[0] // group_size) + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.CURRENT_TENSOR_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e5m2, + 
is_2x2x=True, + n_groups=group_size, + ) + output = te_grouped_dense(x_reshaped, w, n_groups, quantizer_set=quantizer_set) + output = output.reshape(*x.shape[:-1], -1) + return output + + def run(x, w): + output = sharded_group_gemm(x, w) + return output + + output, vjp_fn = jax.vjp(run, outter_x, outter_w) + dx, dw = vjp_fn(output) + return output, dx, dw + + init_func = jax.jit(init_data, out_shardings=(x_sharding, w_sharding, w_no_sharding, None)) + x, w, w_global, w_amax = init_func() + + o_sharding = x_sharding + test_func_jitted = jax.jit( + test_func, + in_shardings=(x_sharding, w_sharding, None), + out_shardings=(o_sharding, x_sharding, w_sharding), + ) + ref_func_jitted = jax.jit( + ref_func, + in_shardings=(x_sharding, w_no_sharding), + out_shardings=(o_sharding, x_sharding, w_no_sharding), + ) + + out, dx, dw = test_func_jitted(x, w, w_amax) + ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) + + assert_allclose(out, ref_out, dtype=jnp.float8_e4m3fn) + assert_allclose(dx, ref_dx, dtype=jnp.float8_e5m2) + assert_allclose(dw, ref_dw, dtype=jnp.float8_e5m2) + + +if __name__ == "__main__": + from jax.sharding import NamedSharding, PartitionSpec + from jax.experimental import shard_map + import sys + + coord_addr = sys.argv[1] + proc_id = int(sys.argv[2]) + num_procs = int(sys.argv[3]) + + jax.distributed.initialize( + coordinator_address=coord_addr, num_processes=num_procs, process_id=proc_id + ) + + mesh = jax.make_mesh((num_procs,), (MESH_AXIS_NAME,)) + + with mesh: + data_shapes = [((4, 16, 128, 7168), (7168, 2048))] + for data_shape in data_shapes: + for kernel_fsdp_axis in [1, 2]: + test_grouped_gemm_fp8_allgather(data_shape, kernel_fsdp_axis) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 0b2755744..04795dc3b 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -931,6 +931,7 @@ def grouped_quantize( x: jnp.ndarray, quantizer: GroupedQuantizer, group_sizes: jnp.ndarray = None, + amax: jnp.ndarray = None, flatten_axis: int = -1, ) -> GroupedScaledTensor1x: """Quantize a tensor in grouped manner. @@ -943,6 +944,7 @@ def grouped_quantize( x: Input tensor to quantize quantizer: The quantizer to use for quantization group_sizes: Array of ints containing the size of each group (default: None) + amax: The amax of x; if None, it is auto-generated. 
(default: None) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) Returns: @@ -985,7 +987,10 @@ def grouped_quantize( scale = scale.at[i].set(quantizer_i.scale[0]) if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: - row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) + if amax is not None: + row_amax = amax + else: + row_amax = jnp.max(jnp.abs(x), axis=range(group_axis + 1, x.ndim)) segment_ids = jnp.repeat( jnp.arange(n_groups), group_sizes, total_repeat_length=x.shape[group_axis] ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 29d0fbfa6..032ac9eb7 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -285,18 +285,17 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t out_dtype_bytes = te_dtype_bytes(out_dtype); if (is_tensor_scaling) { - cudaStream_t stream_0 = nvte_get_compute_stream(0); size_t dpitch = tensor_scaling_sinv_aligment; size_t spitch = lhs_sinv_dtype_bytes; size_t width = lhs_sinv_dtype_bytes; size_t height = lhs_sinv_size; cudaMemcpy2DAsync(lhs_scatter_aligned_ptr, dpitch, lhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream_0); + cudaMemcpyDeviceToDevice, stream); spitch = rhs_sinv_dtype_bytes; width = rhs_sinv_dtype_bytes; height = rhs_sinv_size; cudaMemcpy2DAsync(rhs_scatter_aligned_ptr, dpitch, rhs_sinv_ptr, spitch, width, height, - cudaMemcpyDeviceToDevice, stream_0); + cudaMemcpyDeviceToDevice, stream); lhs_sinv_ptr = lhs_scatter_aligned_ptr; rhs_sinv_ptr = rhs_scatter_aligned_ptr; } diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 4a50fe0e5..65d65e7d4 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -16,13 +16,45 @@ from . import cpp_extensions as tex from .quantize import ( + ScaledTensorFactory, + ScalingMode, + QuantizeLayout, QuantizerSet, noop_quantizer_set, with_sharding_constraint_by_logical_axes, + is_fp8_gemm_with_all_layouts_supported, TensorUsage, ) +def _all_gather_kernel(kernel, mesh_axis, axis_idx): + assert mesh_axis is not None + assert 0 < axis_idx < len(kernel.shape) + + # TODO(Ming Hunag): Add a condition branch for with/without shmap. + kernel_shape = kernel.shape + kernel_whole_shape = (*kernel_shape[:axis_idx], -1, *kernel_shape[axis_idx + 1 :]) + global_kernel = jax.lax.all_gather(kernel, mesh_axis, axis=axis_idx) + global_kernel = global_kernel.reshape(*kernel_whole_shape) + return global_kernel + + +def _psum_scatter_kernel(kernel, scattered_kernel_shape, mesh_axis, axis_idx): + assert mesh_axis is not None + assert 0 < axis_idx < len(scattered_kernel_shape) + + # TODO(Ming Hunag): Add a condition branch for with/without shmap. 
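+    # Split the gathered axis into (num_shards, local_dim) blocks, reduce-scatter
+    # (sum) across the FSDP mesh axis so each rank keeps only the shard it owns,
+    # then reshape back to the local scattered kernel shape.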
+ kernel = kernel.reshape( + *scattered_kernel_shape[:axis_idx], + -1, + scattered_kernel_shape[axis_idx], + *scattered_kernel_shape[axis_idx + 1 :], + ) + kernel = jax.lax.psum_scatter(kernel, mesh_axis, scatter_dimension=axis_idx) + kernel = kernel.reshape(scattered_kernel_shape) + return kernel + + def dense( x: jnp.ndarray, kernel: jnp.ndarray, @@ -253,10 +285,12 @@ def grouped_dense( group_sizes: jnp.ndarray, contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (1,)), bias: jnp.ndarray = None, + kernel_amax: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, preferred_element_type: jnp.dtype = None, group_offset: jnp.array = None, quantizer_set: QuantizerSet = noop_quantizer_set, + kernel_fsdp_info: Tuple[str, int] = (None, -1), ): """ Perform grouped dense (linear) layer transformation with optional quantization. @@ -268,10 +302,15 @@ def grouped_dense( contracting_dims: Tuple of sequences specifying which dimensions to contract (currently only supports ((1,), (1,))) bias: Bias tensor of shape (G, N) + kernel_amax: The amax values of weight matrix of shape (G,) precision: JAX precision for the GEMM operation preferred_element_type: Preferred data type for the output tensor group_offset: 1D array containing offsets for each group (not yet implemented) quantizer_set: Set of quantizers for FP8 quantization of the input and output + kernel_fsdp_info: A tuple containing FSDP-related information for a weight matrix + represented in the format (str, int). The first element is the + FSDP mesh axis, and the second element is the dimension along + which the weight is sharded. Returns: A jnp.ndarray containing the result of the grouped linear operation @@ -282,25 +321,29 @@ def grouped_dense( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ) return output -@partial(jax.custom_vjp, nondiff_argnums=(3, 5, 6, 7)) +@partial(jax.custom_vjp, nondiff_argnums=(3, 6, 7, 8, 10)) def _grouped_dense( x, kernel, group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ): output, _ = _grouped_dense_fwd_rule( x, @@ -308,10 +351,12 @@ def _grouped_dense( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ) return output @@ -322,21 +367,31 @@ def _grouped_dense_fwd_rule( group_sizes, contracting_dims, bias, + kernel_amax, precision, preferred_element_type, group_offset, quantizer_set, + kernel_fsdp_info, ): use_bias = bias is not None is_noop_quantizer_set = quantizer_set == noop_quantizer_set + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + kernel_fsdp_enabled = kernel_fsdp_mesh_axis is not None + if is_noop_quantizer_set: grouped_gemm_x = x grouped_gemm_kernel = kernel ctx_x = x ctx_kernel = kernel flatten_axis_k = None + + if kernel_fsdp_enabled: + kernel = _all_gather_kernel(kernel, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx) else: + original_quantizer_set_kernel_q_layout = quantizer_set.kernel.q_layout + x_contracting_dims, k_contracting_dims = contracting_dims flatten_axis_x = -len(x_contracting_dims) flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) + 1 # +1 for G axis @@ -352,10 +407,24 @@ def _grouped_dense_fwd_rule( ) casted_x = tex.grouped_quantize( - x, quantizer_set.x, group_sizes, flatten_axis=flatten_axis_x + x, + quantizer_set.x, + group_sizes, + flatten_axis=flatten_axis_x, ) + + 
ctx_kernel_usage = TensorUsage.RHS_TRANS + if kernel_fsdp_enabled: + assert quantizer_set.kernel.scaling_mode in [ + ScalingMode.CURRENT_TENSOR_SCALING, + ScalingMode.DELAYED_TENSOR_SCALING, + ] + # Perform `cast` only + ctx_kernel_usage = TensorUsage.LHS + quantizer_set.kernel.q_layout = QuantizeLayout.ROWWISE + casted_kernel = tex.grouped_quantize( - kernel, quantizer_set.kernel, flatten_axis=flatten_axis_k + kernel, quantizer_set.kernel, amax=kernel_amax, flatten_axis=flatten_axis_k ) contracting_dims = (x_contracting_dims, k_contracting_dims) @@ -363,9 +432,51 @@ def _grouped_dense_fwd_rule( # rowwise_casted_x.original_shape == (M, K) # colwise_casted_kernel.original_shape == (G, N, K) grouped_gemm_x = casted_x.get_tensor(usage=TensorUsage.LHS) - grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) ctx_x = casted_x.get_tensor(usage=TensorUsage.LHS_TRANS) - ctx_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS_TRANS) + ctx_kernel = casted_kernel.get_tensor(usage=ctx_kernel_usage) + + if kernel_fsdp_enabled: + ctx_kernel_in_original_shape = ctx_kernel.data.reshape(ctx_kernel.original_shape) + global_ctx_kernel_data = _all_gather_kernel( + ctx_kernel_in_original_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) + kernel_shape = global_ctx_kernel_data.shape + + ctx_kernel = ScaledTensorFactory.create_1x( + global_ctx_kernel_data.reshape(-1), + ctx_kernel.scale_inv, + ctx_kernel.scaling_mode, + dq_dtype=ctx_kernel.dq_dtype, + is_colwise=False, + data_layout="N", + flatten_axis=ctx_kernel.flatten_axis, + group_sizes=ctx_kernel.group_sizes, + original_shape=kernel_shape, + group_axis=ctx_kernel.group_axis, + ) + + if is_fp8_gemm_with_all_layouts_supported(): + grouped_gemm_kernel = ctx_kernel + else: + grouped_gemm_kernel_data = global_ctx_kernel_data.transpose(0, 2, 1) + grouped_gemm_kernel = ScaledTensorFactory.create_1x( + grouped_gemm_kernel_data.reshape(-1), + ctx_kernel.scale_inv, + ctx_kernel.scaling_mode, + dq_dtype=ctx_kernel.dq_dtype, + is_colwise=True, + data_layout="T", + flatten_axis=ctx_kernel.flatten_axis, + group_sizes=ctx_kernel.group_sizes, + original_shape=kernel_shape, + group_axis=ctx_kernel.group_axis, + ) + else: + grouped_gemm_kernel = casted_kernel.get_tensor(usage=TensorUsage.RHS) + + # Reset quantizer_set.kernel.q_layout to align the PyTree as the given one. + # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. 
+ quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout output = tex.grouped_gemm( grouped_gemm_x, @@ -393,7 +504,7 @@ def _grouped_dense_fwd_rule( def _grouped_dense_bwd_rule( - contracting_dims, precision, preferred_element_type, group_offset, ctx, grad + contracting_dims, precision, preferred_element_type, group_offset, kernel_fsdp_info, ctx, grad ): fwd_x_contracting_dims, fwd_k_contracting_dims = contracting_dims @@ -474,11 +585,17 @@ def _grouped_dense_bwd_rule( preferred_element_type=preferred_element_type, group_offset=group_offset, ) + kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx = kernel_fsdp_info + if kernel_fsdp_mesh_axis is not None: + wgrad = _psum_scatter_kernel( + wgrad, kernel_shape, kernel_fsdp_mesh_axis, kernel_fsdp_axis_idx + ) group_sizes_grad = None dbias = tex.grouped_dbias(grad, group_sizes) if use_bias else None + dkernel_amax = None - return dgrad, wgrad, group_sizes_grad, dbias, quantizer_set + return dgrad, wgrad, group_sizes_grad, dbias, dkernel_amax, quantizer_set _grouped_dense.defvjp(_grouped_dense_fwd_rule, _grouped_dense_bwd_rule) From 04add79d520e960034063a752f231463ba85f426 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 27 Aug 2025 13:07:45 -0700 Subject: [PATCH 106/153] [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen --- transformer_engine/jax/sharding.py | 30 ++++++++++++++++++++---------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index caa2a4620..578517d62 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -41,22 +41,32 @@ def _get_mesh_info(resource: str, mesh: jax.sharding.Mesh): return mesh.shape[resource], resource -def _validate_mesh_resource_configuration(): +def _validate_mesh_resource_configuration(mesh_resource): """Validate that the mesh resource configuration is consistent and conflict-free.""" - gsr = global_mesh_resource() - - is_dp_enabled = gsr.dp_resource is not None and get_mesh_axis_size(gsr.dp_resource) > 1 - is_tp_enabled = gsr.tp_resource is not None and get_mesh_axis_size(gsr.tp_resource) > 1 - is_tpsp_enabled = gsr.tpsp_resource is not None and get_mesh_axis_size(gsr.tpsp_resource) > 1 - is_fsdp_enabled = gsr.fsdp_resource is not None and get_mesh_axis_size(gsr.fsdp_resource) > 1 + is_dp_enabled = ( + mesh_resource.dp_resource is not None and get_mesh_axis_size(mesh_resource.dp_resource) > 1 + ) + is_tp_enabled = ( + mesh_resource.tp_resource is not None and get_mesh_axis_size(mesh_resource.tp_resource) > 1 + ) + is_tpsp_enabled = ( + mesh_resource.tpsp_resource is not None + and get_mesh_axis_size(mesh_resource.tpsp_resource) > 1 + ) + is_fsdp_enabled = ( + mesh_resource.fsdp_resource is not None + and get_mesh_axis_size(mesh_resource.fsdp_resource) > 1 + ) assert not (is_dp_enabled and is_fsdp_enabled), ( "Data parallelism and full-sharded data parallelism cannot be enabled at the same time." - f" Got dp_resource={gsr.dp_resource} and fsdp_resource={gsr.fsdp_resource}" + f" Got dp_resource={mesh_resource.dp_resource} and" + f" fsdp_resource={mesh_resource.fsdp_resource}" ) assert not (is_tp_enabled and is_tpsp_enabled), ( "Tensor parallelism and tensor sequence parallelism cannot be enabled at the same time." 
- f" Got tp_resource={gsr.tp_resource} and tpsp_resource={gsr.tpsp_resource}" + f" Got tp_resource={mesh_resource.tp_resource} and" + f" tpsp_resource={mesh_resource.tpsp_resource}" ) @@ -305,7 +315,6 @@ def global_shard_guard(resource: MeshResource): old_resources = _GLOBAL_MESH_RESOURCE try: _GLOBAL_MESH_RESOURCE = resource - _validate_mesh_resource_configuration() yield finally: _GLOBAL_MESH_RESOURCE = old_resources @@ -322,6 +331,7 @@ def global_mesh_resource() -> MeshResource: " context. If you are not using multiple GPUs, you can use an empty MeshResource by" " wrapping your program in 'with global_shard_guard(MeshResource()):'" ) + _validate_mesh_resource_configuration(_GLOBAL_MESH_RESOURCE) return _GLOBAL_MESH_RESOURCE From c95080002d87eb722f245f5e5766b2f7dd20a4c7 Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Wed, 27 Aug 2025 13:22:32 -0700 Subject: [PATCH 107/153] [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> --- tests/jax/test_helper.py | 37 +-- tests/jax/test_layer.py | 45 ++-- transformer_engine/jax/cpp_extensions/base.py | 2 +- transformer_engine/jax/cpp_extensions/gemm.py | 6 +- transformer_engine/jax/flax/module.py | 31 ++- transformer_engine/jax/quantize/helper.py | 250 ++++++++++-------- transformer_engine/jax/quantize/quantizer.py | 79 ++++-- .../jax/quantize/scaling_modes.py | 2 +- 8 files changed, 260 insertions(+), 192 deletions(-) diff --git a/tests/jax/test_helper.py b/tests/jax/test_helper.py index 9b67de6dd..e4511e1fe 100644 --- a/tests/jax/test_helper.py +++ b/tests/jax/test_helper.py @@ -14,10 +14,11 @@ from transformer_engine.common.recipe import Format as FP8Format from transformer_engine.jax import fp8_autocast, get_delayed_scaling from transformer_engine.jax.quantize import ( - QuantizeConfig, + get_quantize_config, is_fp8_available, ScalingMode, update_collections, + TensorSource, ) from transformer_engine.jax.sharding import MeshResource, global_mesh_resource @@ -49,7 +50,7 @@ def test_update_collections(self): class TestFP8Functions(unittest.TestCase): def _check_default_state(self): - self.assertFalse(QuantizeConfig.is_fp8_enabled()) + self.assertFalse(get_quantize_config().is_fp8_enabled()) def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.margin == test.margin) @@ -58,17 +59,23 @@ def _compare_delay_scaling(self, ref, test): self.assertTrue(ref.amax_compute_algo == test.amax_compute_algo) def _compare_current_scaling(self, test): - self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) - self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.CURRENT_TENSOR_SCALING) + self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + for tensor_source in TensorSource: + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), + ScalingMode.CURRENT_TENSOR_SCALING, + ) def _compare_mxfp8_scaling(self, test): - self.assertEqual(QuantizeConfig.MARGIN, test.margin) - 
self.assertEqual(QuantizeConfig.FP8_FORMAT, test.fp8_format) - self.assertEqual(QuantizeConfig.SCALING_MODE, ScalingMode.MXFP8_1D_SCALING) + self.assertEqual(get_quantize_config().MARGIN, test.margin) + self.assertEqual(get_quantize_config().FP8_FORMAT, test.fp8_format) + for tensor_source in TensorSource: + self.assertEqual( + get_quantize_config().get_scaling_mode(tensor_source), ScalingMode.MXFP8_1D_SCALING + ) @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_delayed_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() with fp8_autocast(enabled=False, fp8_recipe=DelayedScaling(), mesh_resource=MeshResource()): @@ -78,21 +85,20 @@ def test_fp8_autocast_delayed_scaling(self): ds = DelayedScaling(margin=5.0, fp8_format=FP8Format.E4M3, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) self._check_default_state() ds = DelayedScaling(margin=3.0, fp8_format=FP8Format.HYBRID, amax_history_len=1) with fp8_autocast(enabled=True, fp8_recipe=ds, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_delay_scaling(get_delayed_scaling(), ds) self._check_default_state() @unittest.skipIf(not is_fp8_supported, reason=reason) def test_fp8_autocast_current_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. self._check_default_state() with fp8_autocast( @@ -104,21 +110,20 @@ def test_fp8_autocast_current_scaling(self): cs = Float8CurrentScaling(fp8_format=FP8Format.E4M3) with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() cs = Float8CurrentScaling(fp8_format=FP8Format.HYBRID) with fp8_autocast(enabled=True, fp8_recipe=cs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_current_scaling(cs) self._check_default_state() @unittest.skipIf(not is_mxfp8_supported, reason=mxfp8_reason) def test_fp8_autocast_mxfp8_block_scaling(self): - QuantizeConfig.finalize() # Ensure the testing not affect by previous tests. 
self._check_default_state() with fp8_autocast( @@ -130,14 +135,14 @@ def test_fp8_autocast_mxfp8_block_scaling(self): bs = MXFP8BlockScaling(margin=5.0, fp8_format=FP8Format.E4M3) with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() bs = MXFP8BlockScaling(margin=3.0, fp8_format=FP8Format.HYBRID) with fp8_autocast(enabled=True, fp8_recipe=bs, mesh_resource=MeshResource()): - self.assertTrue(QuantizeConfig.is_fp8_enabled()) + self.assertTrue(get_quantize_config().is_fp8_enabled()) self._compare_mxfp8_scaling(bs) self._check_default_state() diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 8fe7ebae3..6f672ade7 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -23,12 +23,14 @@ from transformer_engine.common import recipe from transformer_engine.jax.flax import TransformerLayer, TransformerLayerType from transformer_engine.jax.quantize import ( - QuantizeConfig, + get_quantize_config, ScalingMode, is_fp8_available, update_collections, + TensorSource, + fp8_autocast, ) -from transformer_engine.jax.sharding import MeshResource, global_shard_guard +from transformer_engine.jax.sharding import MeshResource @pytest.fixture(autouse=True, scope="function") @@ -356,7 +358,7 @@ def test_backward( ref_params, test_params = self._sync_params(ref_params, test_params) - if QuantizeConfig.is_fp8_enabled(): + if get_quantize_config().is_fp8_enabled(): for _ in range(4): _, updated_state = jax.value_and_grad(self._loss_fn, argnums=(3,), has_aux=False)( inputs, @@ -365,12 +367,15 @@ def test_backward( test_others, test_layer, ) - if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING: + if ( + get_quantize_config().get_scaling_mode(TensorSource.X) + == ScalingMode.DELAYED_TENSOR_SCALING + ): _, updated_quantize_meta = flax.core.pop( - updated_state[0], QuantizeConfig.COLLECTION_NAME + updated_state[0], get_quantize_config().COLLECTION_NAME ) test_others = update_collections( - {QuantizeConfig.COLLECTION_NAME: updated_quantize_meta}, test_others + {get_quantize_config().COLLECTION_NAME: updated_quantize_meta}, test_others ) del updated_quantize_meta del updated_state @@ -500,41 +505,33 @@ class BaseTester: def test_forward(self, data_shape, dtype, attrs): """Test normal datatype forward""" - QuantizeConfig.finalize() # Ensure FP8 disabled. - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Ensure FP8 disabled. + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype) def test_backward(self, data_shape, dtype, attrs): """Test normal datatype backward""" - QuantizeConfig.finalize() # Ensure FP8 disabled. - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Ensure FP8 disabled. 
+ # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=False, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype) @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_forward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test forward with fp8 enabled""" - QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_forward(data_shape, dtype, rtol=1e-4, atol=1e-3) - QuantizeConfig.finalize() @pytest.mark.skipif(not is_fp8_supported, reason=reason) @pytest.mark.parametrize("fp8_recipe", QUANTIZE_RECIPES) def test_backward_with_fp8(self, data_shape, dtype, attrs, fp8_recipe): """Test backward with fp8 enabled""" - QuantizeConfig.initialize(fp8_recipe=fp8_recipe) - with global_shard_guard( - MeshResource() - ): # Empty MeshResource is used as we are running on a single device + # Empty MeshResource is used as we are running on a single device + with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): self.runner(attrs).test_backward(data_shape, dtype, rtol=1e-4, atol=1e-3) - QuantizeConfig.finalize() class TestEncoderLayer(BaseTester): diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index 22842e4f3..a27cec001 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -219,7 +219,7 @@ def manage_primitives(enable_names=None, disable_names=None, disable_all_first=F """ Helper function to manage primitive states by name without modifying environment variables. Allows enabling specific primitives, disabling specific primitives, or disabling all primitives. - This helper is used in the QuantizeConfig.initialize() methods. + This helper is used in the get_quantize_config().initialize() methods. Args: enable_names: List of strings, each representing the name of a primitive class to enable. Defaults to None. 
diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 95ef42821..be73f708e 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -28,7 +28,7 @@ ScalingMode, Quantizer, GroupedQuantizer, - QuantizeConfig, + get_quantize_config, QuantizerSet, QuantizeLayout, noop_quantizer_set, @@ -754,7 +754,7 @@ def _te_gemm( fuse_bias: bool = False, fuse_gelu: bool = False, grad: bool = False, - use_split_accumulator: bool = QuantizeConfig.FP8_2X_ACC_FPROP, + use_split_accumulator: bool = get_quantize_config().FP8_2X_ACC_FPROP, ) -> Tuple[jax.Array, ...]: # Prepare non-quantized GEMM operands @@ -1107,7 +1107,7 @@ def _jax_gemm_fp8_impl(lhs, rhs): ), f"rhs.scaling_mode={rhs.scaling_mode} != lhs.scaling_mode={lhs.scaling_mode}" precision = ( jax.lax.Precision.HIGHEST - if QuantizeConfig.FP8_2X_ACC_FPROP + if get_quantize_config().FP8_2X_ACC_FPROP else jax.lax.Precision.DEFAULT ) return _jax_gemm_tensor_scaling_fp8(lhs, rhs, dim_nums, precision) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index dc9d0209b..c548c54ef 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -32,7 +32,14 @@ jax_scaled_masked_softmax, jax_scaled_upper_triang_masked_softmax, ) -from ..quantize import QuantizerFactory, QuantizeConfig, QuantizeMeta, QuantizeMetaSet, ScalingMode +from ..quantize import ( + QuantizerFactory, + get_quantize_config, + QuantizeMeta, + QuantizeMetaSet, + ScalingMode, + TensorSource, +) PRNGKey = Any Shape = Tuple[int, ...] @@ -350,7 +357,7 @@ def generate_quantize_meta(quantizer_name: str): collection_name = ( variable_collection if variable_collection is not None - else QuantizeConfig.COLLECTION_NAME + else get_quantize_config().COLLECTION_NAME ) scale = self.variable( collection_name, @@ -363,14 +370,14 @@ def generate_quantize_meta(quantizer_name: str): collection_name, f"{quantizer_name}{postfix}_amax_history", jnp.zeros, - (QuantizeConfig.AMAX_HISTORY_LEN,), + (get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32, ).value return QuantizeMeta(scale=scale, amax_history=amax_history) - if QuantizeConfig.SCALING_MODE == ScalingMode.DELAYED_TENSOR_SCALING or isinstance( - fp8_recipe, recipe.DelayedScaling - ): + if get_quantize_config().get_scaling_mode( + TensorSource.X + ) == ScalingMode.DELAYED_TENSOR_SCALING or isinstance(fp8_recipe, recipe.DelayedScaling): x_meta = generate_quantize_meta("x") kernel_meta = generate_quantize_meta("kernel") grad_meta = generate_quantize_meta("grad") @@ -483,7 +490,7 @@ def __call__(self, inputs: Array) -> Array: self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel = kernel.astype(input_dtype) if self.use_bias: @@ -692,7 +699,7 @@ def __call__(self, inputs: Array) -> Array: quantizer_set = self.generate_quantizer_set() fuse_layernorm = ( - QuantizeConfig.is_fp8_enabled() + get_quantize_config().is_fp8_enabled() and not self.return_layernorm_output and self.enable_layernorm ) @@ -743,7 +750,7 @@ def __call__(self, inputs: Array) -> Array: kernel_shape, self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel = kernel.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) @@ -1005,7 +1012,7 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: # TODO(Phuong): use fuse_layernorm for high-precision # when NoOpQuantizer and 
Tensor are implemented fuse_layernorm = ( - QuantizeConfig.is_fp8_enabled() + get_quantize_config().is_fp8_enabled() and not self.return_layernorm_output and self.enable_layernorm ) @@ -1088,7 +1095,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel_1 = kernel_1.astype(input_dtype) hidden_size = inputs.shape[-1] @@ -1100,7 +1107,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): kernel_2_shape, self.dtype, ) - if not QuantizeConfig.is_fp8_enabled(): + if not get_quantize_config().is_fp8_enabled(): kernel_2 = kernel_2.astype(input_dtype) contract_ind = tuple(range(0, len(axis))) diff --git a/transformer_engine/jax/quantize/helper.py b/transformer_engine/jax/quantize/helper.py index f8d18983e..3d460e81a 100644 --- a/transformer_engine/jax/quantize/helper.py +++ b/transformer_engine/jax/quantize/helper.py @@ -7,9 +7,11 @@ This module provides configuration and helper functions for managing quantization metadata in JAX, including support for different scaling modes and datatypes. """ +from abc import ABC, abstractmethod from contextlib import contextmanager +from dataclasses import dataclass from enum import Enum -from typing import Optional, Tuple, Dict, Union, Sequence +from typing import Optional, Tuple, Dict, Union, Sequence, Type from functools import reduce import operator @@ -26,7 +28,7 @@ from .device_utils import get_device_compute_capability __all__ = [ - "QuantizeConfig", + "get_quantize_config", "fp8_autocast", "is_fp8_available", "update_collections", @@ -34,12 +36,15 @@ "apply_padding_to_scale_inv", "remove_padding_from_scale_inv", "NVTE_FP8_COLLECTION_NAME", + "TensorSource", ] _is_fp8_available = None _reason_for_no_fp8 = "" Collection = Union[Dict, FrozenDict] +NVTE_FP8_COLLECTION_NAME = "fp8_metas" + def _check_delayed_scaling_fp8_support(gpu_arch) -> Tuple[bool, str]: """Check if delayed scaling FP8 is supported on the given GPU architecture. @@ -154,6 +159,17 @@ def _format2dtypes(format_: recipe.Format): return jnp.bfloat16, jnp.bfloat16 +class TensorSource(Enum): + """Enumeration for where a tensor's data comes from.""" + + # Input data + X = 0 + # Model parameters + KERNEL = 1 + # Gradients in the backward pass + DGRAD = 2 + + class AmaxComputeAlgo(Enum): """Enumeration for AMAX computation algorithms. @@ -166,28 +182,8 @@ class AmaxComputeAlgo(Enum): MOST_RECENT = "most_recent" -def _get_scaling_mode(fp8_recipe: recipe.Recipe) -> ScalingMode: - """Convert recipe.Recipe to ScalingMode. - - Args: - fp8_recipe: The FP8 recipe to convert - - Returns: - The corresponding ScalingMode - - Raises: - ValueError: If the recipe type is not supported - """ - if isinstance(fp8_recipe, recipe.DelayedScaling): - return ScalingMode.DELAYED_TENSOR_SCALING - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - return ScalingMode.MXFP8_1D_SCALING - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): - return ScalingMode.CURRENT_TENSOR_SCALING - raise ValueError("Invalid fp8_recipe!") - - -class QuantizeConfig: +@dataclass +class BaseQuantizeConfig(ABC): """Configuration class for quantization settings. 
This class manages global quantization settings including FP8 formats, @@ -204,14 +200,13 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: Whether to use 2x accumulation for data gradients FP8_2X_ACC_WGRAD: Whether to use 2x accumulation for weight gradients INFERENCE_MODE: Whether to enable optimization for inference - SCALING_MODE: Scaling mode AMAX_HISTORY_LEN: Length of AMAX history for delayed scaling AMAX_COMPUTE_ALGO: Algorithm for AMAX computation """ INITIALIZED = False MARGIN: float = 0.0 - COLLECTION_NAME: str = "fp8_metas" + COLLECTION_NAME: str = NVTE_FP8_COLLECTION_NAME FP8_FORMAT: recipe.Format = recipe.Format.HYBRID FWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[0] BWD_DTYPE: DType = _format2dtypes(recipe.Format.HYBRID)[1] @@ -219,61 +214,82 @@ class QuantizeConfig: FP8_2X_ACC_DGRAD: bool = False FP8_2X_ACC_WGRAD: bool = False INFERENCE_MODE: bool = False - SCALING_MODE: ScalingMode = ScalingMode.NO_SCALING # DelayedScaling AMAX_HISTORY_LEN: int = 1024 AMAX_COMPUTE_ALGO: AmaxComputeAlgo = AmaxComputeAlgo.MAX - @staticmethod - def is_fp8_enabled(): + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize the quantization configuration. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + """ + self.INITIALIZED = True + self.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 + self.FP8_FORMAT = fp8_recipe.fp8_format + self.FWD_DTYPE, self.BWD_DTYPE = _format2dtypes(self.FP8_FORMAT) + + def is_fp8_enabled(self) -> bool: """Check if FP8 quantization is enabled. Returns: bool: True if quantization is enabled, False otherwise """ - return QuantizeConfig.INITIALIZED + return self.INITIALIZED - @classmethod - def initialize(cls, fp8_recipe: recipe.Recipe) -> None: - """Initialize the quantization configuration. + @abstractmethod + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type. Args: - fp8_recipe: The FP8 recipe to use for initialization + tensor_source: The usage type for which to get the scaling mode. + + Returns: + The scaling mode for the specified usage type. + """ + + def is_supported(self) -> tuple[bool, str]: + """Check if this QuantizeConfig class is supported on the available devices. + + Returns: + bool: True if the class is supported, False otherwise + str: Reason for being unsupported, if applicable. 
""" - cls.INITIALIZED = True - cls.MARGIN = fp8_recipe.margin if "margin" in dir(fp8_recipe) else 0.0 - cls.FP8_FORMAT = fp8_recipe.fp8_format - cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = _get_scaling_mode(fp8_recipe) - - @classmethod - def finalize(cls) -> None: - """Reset the quantization configuration to default values.""" - cls.INITIALIZED = False - cls.MARGIN = 0.0 - cls.FP8_FORMAT = recipe.Format.HYBRID - cls.FWD_DTYPE, cls.BWD_DTYPE = _format2dtypes(cls.FP8_FORMAT) - cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.FP8_2X_ACC_FPROP = False - cls.FP8_2X_ACC_DGRAD = False - cls.FP8_2X_ACC_WGRAD = False - cls.SCALING_MODE = ScalingMode.NO_SCALING - cls.INFERENCE_MODE = False - # DelayedScaling - cls.AMAX_HISTORY_LEN = 1024 - cls.AMAX_COMPUTE_ALGO = AmaxComputeAlgo.MAX - - -class DelayedScalingQuantizeConfig: + + x_scaling_mode = self.get_scaling_mode(TensorSource.X) + kernel_scaling_mode = self.get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = self.get_scaling_mode(TensorSource.DGRAD) + for scaling_mode in [x_scaling_mode, kernel_scaling_mode, grad_scaling_mode]: + is_supported, reason = is_fp8_available(scaling_mode=scaling_mode) + if not is_supported: + return is_supported, reason + return True, None + + +class NoOpQuantizeConfig(BaseQuantizeConfig): + """Configuration class higher-precision non-quantized operation.""" + + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: + """Initialize no-op configuration.""" + raise NotImplementedError( + "NoOpQuantizeConfig cannot be initialize from a recipe as it represents" + " higher-precision when no quantized recipe is set." + ) + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.NO_SCALING + + +class DelayedScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for delayed scaling FP8 recipe. This class provides specific initialization and finalization for delayed scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize delayed scaling FP8 configuration. Args: @@ -282,6 +298,8 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: Raises: AssertionError: If recipe parameters are not supported """ + super().initialize_from_recipe(fp8_recipe) + assert fp8_recipe.amax_compute_algo in [ "max", "most_recent", @@ -291,71 +309,88 @@ def initialize(fp8_recipe: recipe.Recipe) -> None: ), "DelayedScaling scaling_factor_compute_algo isn't supported by TE/JAX." assert fp8_recipe.reduce_amax, "DelayedScaling reduce_amax should be enabled for TE/JAX." 
- cls = QuantizeConfig - cls.initialize(fp8_recipe) - - cls.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len + self.AMAX_HISTORY_LEN = fp8_recipe.amax_history_len string_to_amax_compute_algo = { "max": AmaxComputeAlgo.MAX, "most_recent": AmaxComputeAlgo.MOST_RECENT, } - cls.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] + self.AMAX_COMPUTE_ALGO = string_to_amax_compute_algo[fp8_recipe.amax_compute_algo] - cls.FP8_2X_ACC_DGRAD = True - cls.FP8_2X_ACC_WGRAD = True + self.FP8_2X_ACC_DGRAD = True + self.FP8_2X_ACC_WGRAD = True - @staticmethod - def finalize() -> None: - """Reset the delayed scaling configuration.""" - QuantizeConfig.finalize() + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.DELAYED_TENSOR_SCALING -class CurrentScalingQuantizeConfig: +class CurrentScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for current scaling FP8 recipe. This class provides specific initialization and finalization for current scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize current scaling FP8 configuration. Args: fp8_recipe: The FP8 recipe to use for initialization """ - cls = QuantizeConfig - cls.initialize(fp8_recipe) - cls.AMAX_HISTORY_LEN = 0 + super().initialize_from_recipe(fp8_recipe) + self.AMAX_HISTORY_LEN = 0 - @staticmethod - def finalize() -> None: - """Reset the current scaling configuration.""" - QuantizeConfig.finalize() + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.CURRENT_TENSOR_SCALING -class BlockScalingQuantizeConfig: +class BlockScalingQuantizeConfig(BaseQuantizeConfig): """Configuration class for block scaling FP8 recipe. This class provides specific initialization and finalization for block scaling FP8 quantization mode. """ - @staticmethod - def initialize(fp8_recipe: recipe.Recipe) -> None: + def initialize_from_recipe(self, fp8_recipe: recipe.Recipe) -> None: """Initialize block scaling FP8 configuration. Args: fp8_recipe: The FP8 recipe to use for initialization """ - cls = QuantizeConfig - cls.initialize(fp8_recipe) - cls.AMAX_HISTORY_LEN = 0 + super().initialize_from_recipe(fp8_recipe) + self.AMAX_HISTORY_LEN = 0 + + def get_scaling_mode(self, tensor_source: TensorSource) -> ScalingMode: + """Gets the scaling mode for a specific tensor's usage type.""" + return ScalingMode.MXFP8_1D_SCALING + + +_QUANTIZE_CONFIG = NoOpQuantizeConfig() + - @staticmethod - def finalize() -> None: - """Reset the block scaling configuration.""" - QuantizeConfig.finalize() +def get_quantize_config(): + """Global instance of BaseQuantizeConfig set by fp8_autocast context.""" + return _QUANTIZE_CONFIG + + +def get_quantize_config_class( + fp8_recipe: recipe.Recipe, +) -> Type[BaseQuantizeConfig]: + """Get the quantization configuration based on the FP8 recipe. + + Args: + fp8_recipe: The FP8 recipe to use for initialization + Returns: + The quantization config class corresponding to the given recipe. 
+ """ + if isinstance(fp8_recipe, recipe.DelayedScaling): + return DelayedScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): + return BlockScalingQuantizeConfig + if isinstance(fp8_recipe, recipe.Float8CurrentScaling): + return CurrentScalingQuantizeConfig + raise ValueError(f"Unsupported recipe type: {type(fp8_recipe)}") @contextmanager @@ -404,22 +439,22 @@ def fp8_autocast( if fp8_recipe is None: fp8_recipe = recipe.DelayedScaling() - Config = DelayedScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.MXFP8BlockScaling): - Config = BlockScalingQuantizeConfig - if isinstance(fp8_recipe, recipe.Float8CurrentScaling): - Config = CurrentScalingQuantizeConfig + global _QUANTIZE_CONFIG + + old_quantize_config = _QUANTIZE_CONFIG + + _QUANTIZE_CONFIG = NoOpQuantizeConfig() try: with global_shard_guard(mesh_resource): if enabled: - fp8_available, reason_for_no_fp8 = is_fp8_available(_get_scaling_mode(fp8_recipe)) - assert fp8_available, reason_for_no_fp8 - - Config.initialize(fp8_recipe) + _QUANTIZE_CONFIG = get_quantize_config_class(fp8_recipe)() + is_supported, reason = _QUANTIZE_CONFIG.is_supported() + assert is_supported, reason + _QUANTIZE_CONFIG.initialize_from_recipe(fp8_recipe) yield finally: - Config.finalize() + _QUANTIZE_CONFIG = old_quantize_config def get_delayed_scaling(): @@ -437,12 +472,12 @@ def get_delayed_scaling(): an instance of DelayedScaling which is set via fp8_autocast. """ amax_compute_algo = ( - "max" if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" + "max" if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX else "most_recent" ) return recipe.DelayedScaling( - margin=int(QuantizeConfig.MARGIN), - fp8_format=QuantizeConfig.FP8_FORMAT, - amax_history_len=QuantizeConfig.AMAX_HISTORY_LEN, + margin=int(get_quantize_config().MARGIN), + fp8_format=get_quantize_config().FP8_FORMAT, + amax_history_len=get_quantize_config().AMAX_HISTORY_LEN, amax_compute_algo=amax_compute_algo, ) @@ -581,6 +616,3 @@ def apply_padding_to_scale_inv( # Pad the scales with the lowest representable value (2^-127) and return pad_width = tuple((0, a - b) for a, b in zip(padded_scale_shape, unpadded_scale_shape)) return jnp.pad(scale_inv, pad_width=pad_width, mode="constant", constant_values=2**-127) - - -NVTE_FP8_COLLECTION_NAME = QuantizeConfig.COLLECTION_NAME diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 9a65f99bf..6cecfa361 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -21,9 +21,10 @@ from .scaling_modes import ScalingMode from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory from .helper import ( - QuantizeConfig, + get_quantize_config, + get_quantize_config_class, AmaxComputeAlgo, - _get_scaling_mode, + TensorSource, ) from .device_utils import is_fp8_gemm_with_all_layouts_supported @@ -56,7 +57,7 @@ def compute_scale_from_amax( fp8_max = jnp.astype(jnp.finfo(q_dtype).max, jnp.float32) if scale is None: scale = jnp.ones((1,)) - sf = (fp8_max / amax) / (2**QuantizeConfig.MARGIN) + sf = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) sf = jnp.where(amax > 0.0, sf, scale) sf = jnp.where(jnp.isfinite(amax), sf, scale) return sf @@ -234,7 +235,7 @@ def _quantize_func( dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) amax = jnp.max(jnp.abs(x)).reshape((1,)) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) - scale = (fp8_max / amax) / 
(2**QuantizeConfig.MARGIN) + scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) scaled_x = x.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) @@ -320,7 +321,7 @@ class DelayedScaleQuantizer(CurrentScaleQuantizer): scale: jnp.ndarray = field(default_factory=lambda: jnp.ones((1,), jnp.float32)) amax_history: jnp.ndarray = field( - default_factory=lambda: jnp.zeros((QuantizeConfig.AMAX_HISTORY_LEN,), jnp.float32) + default_factory=lambda: jnp.zeros((get_quantize_config().AMAX_HISTORY_LEN,), jnp.float32) ) def tree_flatten(self): @@ -397,7 +398,7 @@ def _compute_scale(amax_history, scale, q_dtype): Updated scale value """ # 2. Calculate the current scale - if QuantizeConfig.AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: + if get_quantize_config().AMAX_COMPUTE_ALGO is AmaxComputeAlgo.MAX: amax = jnp.max(amax_history, axis=-1, keepdims=True) else: amax = amax_history[0:1] @@ -827,12 +828,21 @@ def create( @staticmethod def _create_set( - scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + x_scaling_mode, + kernel_scaling_mode, + grad_scaling_mode, + fwd_dtype, + bwd_dtype, + is_2x2x, + n_groups, + **kwargs, ) -> QuantizerSet: """Create a set of quantizers for forward and backward passes. Args: - scaling_mode: Scaling mode to use + x_scaling_mode: Scaling mode to use for input tensor 'x' + kernel_scaling_mode: Scaling mode to use for kernel tensor + grad_scaling_mode: Scaling mode to use for gradient tensor fwd_dtype: Data type for forward pass bwd_dtype: Data type for backward pass is_2x2x: Whether to use 2x2x quantization @@ -846,9 +856,9 @@ def _create_set( q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE_COLWISE else: q_layout_x = q_layout_kernel = q_layout_dgrad = QuantizeLayout.ROWWISE - if scaling_mode.is_1d_block_scaling(): + if kernel_scaling_mode.is_1d_block_scaling(): q_layout_kernel = QuantizeLayout.COLWISE - if QuantizeConfig.INFERENCE_MODE: + if get_quantize_config().INFERENCE_MODE: q_layout_dgrad = None if "quantize_meta_set" in kwargs: @@ -868,12 +878,12 @@ def _create_set( else: args_x = args_kernel = args_grad = {} - q_x = QuantizerFactory.create(1, scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) + q_x = QuantizerFactory.create(1, x_scaling_mode, fwd_dtype, q_layout_x, n_groups, **args_x) q_kernel = QuantizerFactory.create( - 1, scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel + 1, kernel_scaling_mode, fwd_dtype, q_layout_kernel, n_groups, **args_kernel ) q_dgrad = QuantizerFactory.create( - 1, scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad + 1, grad_scaling_mode, bwd_dtype, q_layout_dgrad, n_groups, **args_grad ) return QuantizerSet(x=q_x, kernel=q_kernel, dgrad=q_dgrad) @@ -892,10 +902,10 @@ def create_set( Args: n_quantizer_sets: Number of quantizer sets to create - scaling_mode: Scaling mode to use, default is QuantizeConfig.SCALING_MODE - fwd_dtype: Data type for forward pass, default is QuantizeConfig.FWD_DTYPE - bwd_dtype: Data type for backward pass, default is QuantizeConfig.BWD_DTYPE - is_2x2x: Whether to use 2x2x quantization, default is QuantizeConfig.IF_QUANTIZE_2X + scaling_mode: Scaling mode to use, default is get_quantize_config().get_scaling_mode + fwd_dtype: Data type for forward pass, default is get_quantize_config().FWD_DTYPE + bwd_dtype: Data type for backward pass, default is get_quantize_config().BWD_DTYPE + is_2x2x: Whether to use 2x2x quantization, default is get_quantize_config().IF_QUANTIZE_2X n_groups: 
fp8_recipe: Recipe to use for quantization. Scaling mode can be specified directly via the scaling_mode parameter or indirectly via recipe. Recipe is preferred as it will support additional recipes in future where scaling mode differs between x, kernel, and grad in the quantizer set. **kwargs: Additional arguments for quantizer initialization @@ -912,27 +922,44 @@ def create_set( ) if fp8_recipe is not None: - # TODO(jberchtold): once recipe and scaling mode are decoupled update this logic - scaling_mode = _get_scaling_mode(fp8_recipe) + quantize_config = get_quantize_config_class(fp8_recipe)() + x_scaling_mode = quantize_config.get_scaling_mode(TensorSource.X) + kernel_scaling_mode = quantize_config.get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = quantize_config.get_scaling_mode(TensorSource.DGRAD) + elif scaling_mode is not None: + x_scaling_mode = scaling_mode + kernel_scaling_mode = scaling_mode + grad_scaling_mode = scaling_mode else: - scaling_mode = scaling_mode or QuantizeConfig.SCALING_MODE - fwd_dtype = fwd_dtype or QuantizeConfig.FWD_DTYPE - bwd_dtype = bwd_dtype or QuantizeConfig.BWD_DTYPE + x_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.X) + kernel_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.KERNEL) + grad_scaling_mode = get_quantize_config().get_scaling_mode(TensorSource.DGRAD) + + fwd_dtype = fwd_dtype or get_quantize_config().FWD_DTYPE + bwd_dtype = bwd_dtype or get_quantize_config().BWD_DTYPE if is_2x2x is None: - if scaling_mode.is_1d_block_scaling(): + # TODO(Jeremy): check x, kernel, grad separately for 2x + if x_scaling_mode.is_1d_block_scaling(): is_2x2x = True - elif scaling_mode.is_tensor_scaling(): + elif x_scaling_mode.is_tensor_scaling(): is_2x2x = not is_fp8_gemm_with_all_layouts_supported() else: # NO_SCALING ignores is_2x2x for now is_2x2x = False - is_inference_mode = QuantizeConfig.INFERENCE_MODE + is_inference_mode = get_quantize_config().INFERENCE_MODE assert not is_inference_mode, "Inference mode is not supported yet!" 
q_set = [] for _ in range(n_quantizer_sets): q_set.append( QuantizerFactory._create_set( - scaling_mode, fwd_dtype, bwd_dtype, is_2x2x, n_groups, **kwargs + x_scaling_mode=x_scaling_mode, + kernel_scaling_mode=kernel_scaling_mode, + grad_scaling_mode=grad_scaling_mode, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=is_2x2x, + n_groups=n_groups, + **kwargs, ) ) diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index fc4fd1353..868570f73 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -396,7 +396,7 @@ def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: The quantize layout for the tensor usage """ # If we need to support 1x1x for inference in the future - # if QuantizeConfig.INFERENCE_MODE: + # if get_quantize_config().INFERENCE_MODE: # assert usage not in (TensorUsage.LHS_TRANS, TensorUsage.RHS_TRANS), (f"Invalid usage {usage} as we are in MXFP8_1D_SCALING 1x1x (FWD only) mode so no transposed usage is needed!") # if usage == TensorUsage.LHS: # return QuantizeLayout.ROWWISE From a282136c7aa111fadb9f2c0866b11ad236f44485 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 Aug 2025 18:03:58 -0400 Subject: [PATCH 108/153] [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/layernorm_mlp.py | 7 +++++++ transformer_engine/jax/sharding.py | 2 +- 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 8727ea7e3..00e3ddc3e 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -289,6 +289,13 @@ def _layernorm_mlp_fwd_rule( bias_1_new_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1_shape dot_1_output += jnp.reshape(bias_1, bias_1_new_shape) + # This sharding constraint is needed to correct the Shardy sharding propagation + if dot_2_input_axes is not None: + dot_1_output_axes = ( + dot_2_input_axes[:-1] + (None,) + dot_2_input_axes[-1:] + ) # add the act_num axis + dot_1_output = with_sharding_constraint_by_logical_axes(dot_1_output, dot_1_output_axes) + dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) # (batch..., hidden_in) -> (batch..., hidden) diff --git a/transformer_engine/jax/sharding.py b/transformer_engine/jax/sharding.py index 578517d62..339e74e2f 100644 --- a/transformer_engine/jax/sharding.py +++ b/transformer_engine/jax/sharding.py @@ -165,7 +165,7 @@ def with_sharding_constraint_by_logical_axes( flax_rules = flax.linen.get_logical_axis_rules() if len(flax_rules) > 0: return flax.linen.with_logical_constraint( - x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.NO_CONSTRAINT + x, logical_axis_names, fallback=flax.linen.spmd.RulesFallback.AXIS_IS_UNSHARDED ) except ImportError: pass From 1e2c68d6e6e8783b4cdde4c867e88f87f63245d5 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Wed, 27 Aug 2025 19:23:20 -0400 Subject: [PATCH 109/153] [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong 
Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../jax/cpp_extensions/quantization.py | 57 +++++++++++-------- .../jax/csrc/extensions/quantization.cpp | 11 ++-- 2 files changed, 39 insertions(+), 29 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 04795dc3b..198beb55e 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -57,14 +57,14 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): name = "te_dbias_quantize_ffi" multiple_results = True impl_static_args = ( - 2, 3, 4, 5, 6, 7, 8, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer + 9, + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval inner_primitive = None outer_primitive = None @@ -72,6 +72,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): def abstract( x_aval, scale_aval, + amax_aval, *, out_dtype, scaling_mode, @@ -95,7 +96,7 @@ def abstract( rowwise_out_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - updated_amax_aval = jax.core.ShapedArray(shape=(1,), dtype=jnp.float32) + updated_amax_aval = amax_aval rowwise_scale_inv_shape, colwise_scale_inv_shape = ScalingMode( scaling_mode @@ -168,6 +169,7 @@ def lowering( ctx, x, scale, + amax, *, out_dtype, scaling_mode, @@ -181,13 +183,17 @@ def lowering( te_dbias_quantize_p lowering rules """ del out_dtype, scale_dtype, is_outer - x_aval, scale_aval = ctx.avals_in + x_aval, scale_aval, amax_aval = ctx.avals_in assert x_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] - assert scale_aval.dtype == jnp.float32 - return ffi.ffi_lowering(BaseDBiasQuantizePrimitive.name)( + assert scale_aval.dtype == amax_aval.dtype == jnp.float32 + return ffi.ffi_lowering( + BaseDBiasQuantizePrimitive.name, + operand_output_aliases={2: 4}, # donate amax buffer to updated_amax + )( ctx, x, scale, + amax, scaling_mode=scaling_mode.value, q_layout=q_layout, flatten_axis=flatten_axis, @@ -198,6 +204,7 @@ def lowering( def impl( x, scale, + amax, out_dtype, scaling_mode, q_layout, @@ -222,6 +229,7 @@ def impl( ) = BaseDBiasQuantizePrimitive.inner_primitive.bind( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -268,15 +276,15 @@ def batcher( del is_outer check_valid_batch_dims(batch_dims) assert BaseDBiasQuantizePrimitive.outer_primitive is not None - x, scale = batched_args - x_bdim, scale_bdim = batch_dims - amax_bdim = scale_bdim + x, scale, amax = batched_args + x_bdim, scale_bdim, amax_bdim = batch_dims out_bdims = x_bdim, x_bdim, scale_bdim, scale_bdim, amax_bdim, x_bdim return ( BaseDBiasQuantizePrimitive.outer_primitive.bind( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -303,7 +311,7 @@ def infer_sharding_from_operands( del (out_dtype, result_infos, scale_dtype, is_outer) # Unused. 
x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), @@ -329,10 +337,8 @@ def infer_sharding_from_operands( desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -341,14 +347,14 @@ def infer_sharding_from_operands( scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*scale_inv_spec), desc="BaseDBiasQuantizePrimitive.scale_inv" ) - amax_sharding = NamedSharding( - mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" - ) colwise_scale_inv_sharding = NamedSharding( mesh, PartitionSpec(*colwise_scale_inv_spec), desc="BaseDBiasQuantizePrimitive.colwise_scale_inv", ) + amax_sharding = NamedSharding( + mesh, PartitionSpec(*amax_spec), desc="BaseDBiasQuantizePrimitive.amax" + ) return ( out_sharding, @@ -375,7 +381,7 @@ def partition( del result_infos, is_outer x_spec = get_padded_spec(arg_infos[0]) - scale_spec = get_padded_spec(arg_infos[1]) + amax_spec = get_padded_spec(arg_infos[2]) out_sharding = NamedSharding( mesh, PartitionSpec(*x_spec), @@ -401,10 +407,8 @@ def partition( desc="BaseDBiasQuantizePrimitive.dbias_sharding", ) - scale_inv_spec = amax_spec = colwise_scale_inv_spec = (None,) - if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING.value: - scale_inv_spec = amax_spec = scale_spec - elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: + scale_inv_spec = colwise_scale_inv_spec = (None,) + if scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: scale_inv_spec = x_spec if q_layout in (QuantizeLayout.COLWISE.value, QuantizeLayout.ROWWISE_COLWISE.value): @@ -432,7 +436,7 @@ def partition( dbias_sharding, ) - def sharded_impl(x, scale): + def sharded_impl(x, scale, amax): ( local_x, local_colwise_x, @@ -443,6 +447,7 @@ def sharded_impl(x, scale): ) = BaseDBiasQuantizePrimitive.impl( x, scale, + amax, out_dtype=out_dtype, scaling_mode=scaling_mode, q_layout=q_layout, @@ -510,7 +515,7 @@ def shardy_sharding_rule( amax = (prefix + "amax",) return SdyShardingRule( - (x_axes, ("…1",)), + (x_axes, ("…1",), amax), (out, colwise_out, scale_rules.rowwise_rule, colwise_scale_inv, amax, dbias), ) @@ -638,6 +643,9 @@ def _quantize_dbias_impl( elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale + # Make sure amax is init with zero + amax = jnp.zeros((1,), jnp.float32) + # It is faster to use 1x quantization for tensor scaling is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) force_1x_quantization = ( @@ -659,6 +667,7 @@ def _quantize_dbias_impl( ) = PrimitiveClass.outer_primitive.bind( x, scale, + amax, out_dtype=quantizer.q_dtype, scaling_mode=quantizer.scaling_mode.value, q_layout=q_layout.value, diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index 7bea11f91..d17d83ec1 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -72,9 +72,10 @@ pybind11::tuple 
GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_ } Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_Type scale_buf, - Result_Type output_buf, Result_Type output_trans_buf, - Result_Type scale_inv_buf, Result_Type colwise_scale_inv_buf, - Result_Type amax_buf, Result_Type dbias_buf, Result_Type workspace_buf, + Buffer_Type amax_buf, Result_Type output_buf, + Result_Type output_trans_buf, Result_Type scale_inv_buf, + Result_Type colwise_scale_inv_buf, Result_Type updated_amax_buf, + Result_Type dbias_buf, Result_Type workspace_buf, JAXX_Scaling_Mode scaling_mode, int64_t quantize_layout_enum, bool is_dbias, int64_t flatten_axis) { auto in_dtype = convert_ffi_datatype_to_te_dtype(input_buf.element_type()); @@ -119,11 +120,10 @@ Error_Type DBiasQuantizeFFI(cudaStream_t stream, Buffer_Type input_buf, Buffer_T if (is_fp8_dtype(out_dtype)) { if (is_tensor_scaling) { float *scale = reinterpret_cast(scale_buf.untyped_data()); - float *amax = reinterpret_cast(amax_buf->untyped_data()); + float *amax = reinterpret_cast(updated_amax_buf->untyped_data()); NVTE_CHECK(scale != nullptr, "scale must be provided for delayed tensor scaling"); NVTE_CHECK(amax != nullptr, "amax must be provided for delayed tensor scaling"); output_tensor.set_scale(scale, DType::kFloat32, std::vector{1}); - nvte_memset(amax, 0, sizeof(float), stream); output_tensor.set_amax(amax, DType::kFloat32, std::vector{1}); output_tensor.set_rowwise_scale_inv( scale_inv_buf->untyped_data(), @@ -183,6 +183,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(DBiasQuantizeHandler, DBiasQuantizeFFI, .Ctx() // stream .Arg() // input .Arg() // scale + .Arg() // amax .Ret() // output .Ret() // colwise output .Ret() // scale_inv From de81b7dfb6a9266f20f60fb42e634436fd8324f9 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Wed, 27 Aug 2025 16:31:29 -0700 Subject: [PATCH 110/153] Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani --- .../pytorch/attention/dot_product_attention/utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 9d6677b62..1f88800a6 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -434,8 +434,8 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version <= (9, 12, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.12") + if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0): + logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") From c77614197d8ffa8e1fb177cc7fd125c11aa09e6e Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Wed, 27 Aug 2025 22:20:25 -0700 Subject: [PATCH 111/153] Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov --- tests/cpp/CMakeLists.txt | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 412c5d34d..c2c9d0d91 100644 --- 
a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -43,6 +43,5 @@ include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) -add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) From a5c79876add9543a4db8e0dda22e2aced0d6615e Mon Sep 17 00:00:00 2001 From: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Date: Thu, 28 Aug 2025 10:35:04 -0700 Subject: [PATCH 112/153] [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 17 +++++++++++++---- tests/pytorch/utils.py | 2 +- .../attention/dot_product_attention/utils.py | 12 +++++++++--- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index b76f3d2b2..e72067367 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -122,13 +122,18 @@ def is_fused_attn_available( - config: ModelConfig, dtype: torch.dtype, qkv_layout="bshd_bshd_bshd", is_training=True + config: ModelConfig, + dtype: torch.dtype, + qkv_layout="bshd_bshd_bshd", + is_training=True, + deterministic=False, ): _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, is_training=is_training, + deterministic=deterministic, ) return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends @@ -839,7 +844,7 @@ def _test_e2e_checkpointing(bs, dtype, config, checkpoint=False, steps=10, path= @pytest.mark.parametrize("model", ["126m"]) def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] - if not is_fused_attn_available(config, dtype): + if not is_fused_attn_available(config, dtype, deterministic=True): pytest.skip("No attention backend available.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -887,7 +892,9 @@ def _test_e2e_gpt_accuracy(block, bs, dtype, config): @pytest.mark.parametrize("parallel_attention_mlp", all_boolean) def test_gpt_accuracy(dtype, bs, model, parallel_attention_mlp): config = model_configs[model] - if not is_fused_attn_available(config, dtype, qkv_layout="sb3hd", is_training=False): + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): pytest.skip("No attention backend available.") te_gpt = TransformerLayer( @@ -1000,7 +1007,9 @@ def _test_mha_accuracy(block, bs, dtype, config, mask_type, te=True): @pytest.mark.parametrize("mask_type", mask_types) def test_mha_accuracy(dtype, bs, model, mask_type): config = model_configs[model] - if not is_fused_attn_available(config, 
dtype, qkv_layout="sb3hd", is_training=False): + if not is_fused_attn_available( + config, dtype, qkv_layout="sb3hd", is_training=True, deterministic=True + ): pytest.skip("No attention backend available.") te_mha = MultiheadAttention( diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 524bd3289..38f400f65 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -266,8 +266,8 @@ def test(): ) ( use_flash_attention, - use_fused_attention, flash_attention_backend, + use_fused_attention, fused_attention_backend, use_unfused_attention, available_backends, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 1f88800a6..7097f4ba0 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -822,7 +822,7 @@ def get_attention_backend( # flash-attn >=2.4.1 | yes # FusedAttention | # sub-backend 0 | yes - # sub-backend 1 | workspace optimization path and sm90+: yes; + # sub-backend 1 | workspace optimization path and sm90: yes; # | otherwise: no # sub-backend 2 | no # UnfusedDotProductAttention | yes @@ -838,8 +838,9 @@ def get_attention_backend( use_flash_attention_2 = False if use_fused_attention and deterministic: if fused_attention_backend == FusedAttnBackend["FP8"] and is_training: - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with FP8") use_fused_attention = False + fused_attention_backend = None if ( fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"] and is_training @@ -849,8 +850,13 @@ def get_attention_backend( or cudnn_version < (8, 9, 5) ) ): - logger.debug("Disabling FusedAttention for determinism reasons") + logger.debug("Disabling FusedAttention for determinism reasons with post_scale_bias") + use_fused_attention = False + fused_attention_backend = None + if is_training and device_compute_capability >= (10, 0) and cudnn_version <= (9, 14, 0): + logger.debug("Disabling FusedAttention for determinism reasons on Blackwell") use_fused_attention = False + fused_attention_backend = None # use_flash_attention may have been set above use_flash_attention_2 = use_flash_attention and use_flash_attention_2 From 06a38cc067ba5cfebda992ee1e6a721e1608b98c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Thu, 28 Aug 2025 20:18:02 +0200 Subject: [PATCH 113/153] [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: 
pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- docs/examples/onnx/onnx_export.ipynb | 2 +- tests/pytorch/test_onnx_export.py | 18 +++++-- transformer_engine/pytorch/onnx_extensions.py | 54 ++++++++++++++++--- .../pytorch/tensor/float8_tensor.py | 22 +++++--- 4 files changed, 77 insertions(+), 19 deletions(-) diff --git a/docs/examples/onnx/onnx_export.ipynb b/docs/examples/onnx/onnx_export.ipynb index 91fc38003..26ac71188 100644 --- a/docs/examples/onnx/onnx_export.ipynb +++ b/docs/examples/onnx/onnx_export.ipynb @@ -10,7 +10,7 @@ "\n", "Note:\n", "\n", - "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling and MXFP8.\n", + "Currently, export to ONNX is supported only for high precision, FP8 delayed scaling, FP8 current scaling and MXFP8.\n", "\n", "
\n", "\n", diff --git a/tests/pytorch/test_onnx_export.py b/tests/pytorch/test_onnx_export.py index b353333a5..e5368497d 100644 --- a/tests/pytorch/test_onnx_export.py +++ b/tests/pytorch/test_onnx_export.py @@ -65,6 +65,7 @@ fp8_recipes.append(recipe.MXFP8BlockScaling()) if fp8_available: fp8_recipes.append(recipe.DelayedScaling()) + fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(None) supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu"] @@ -81,11 +82,11 @@ ], outputs=[PyCustomOpDef.dt_uint8], ) -def trt_fp8_quantize(t, scale): +def trt_fp8_quantize(t, scale_inv): """FP8 quantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() q = te.tensor.float8_tensor.Float8Quantizer( - scale=1 / torch.from_numpy(scale).cuda(), + scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, ) @@ -101,11 +102,11 @@ def trt_fp8_quantize(t, scale): ], outputs=[PyCustomOpDef.dt_float], ) -def trt_fp8_dequantize(t, scale): +def trt_fp8_dequantize(t, scale_inv): """FP8 dequantization extension for ONNX Runtime.""" x = torch.from_numpy(t).cuda() q = te.tensor.float8_tensor.Float8Quantizer( - scale=1 / torch.from_numpy(scale).cuda(), + scale=1 / torch.from_numpy(scale_inv).cuda(), amax=torch.zeros([1]).cuda(), fp8_dtype=tex.DType.kFloat8E4M3, ) @@ -593,7 +594,9 @@ def _test_export_layernorm_linear( fname, inp, model, - atol=1e-3, + # For current scaling we use Float8Quantizer in tests + amax computed by hand, + # which has slightly different numerics than Float8CurrentScalingQuantizer. + atol=1e-3 if fp8_recipe.__class__ is not recipe.Float8CurrentScaling else 2e-2, is_fp8=fp8_recipe is not None, te_outputs=te_outputs, ) @@ -1150,6 +1153,11 @@ def test_trt_integration(fp8_recipe: recipe.Recipe): ffn_hidden_size=128, num_attention_heads=4, ).eval() + + if type(fp8_recipe) == recipe.Float8CurrentScaling: + # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling + model = te.LayerNormMLP(128, 128) + inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),) with te.fp8_autocast(enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe): diff --git a/transformer_engine/pytorch/onnx_extensions.py b/transformer_engine/pytorch/onnx_extensions.py index 42f5a1d55..38df5fc54 100644 --- a/transformer_engine/pytorch/onnx_extensions.py +++ b/transformer_engine/pytorch/onnx_extensions.py @@ -112,7 +112,9 @@ def onnx_quantize_fp8_symbolic( doc="TRT FP8 Quantize Linear used for inference.", inputs=[ defs.OpSchema.FormalParameter("tensor", "tensor(float)", "Input tensor to quantize"), - defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for quantization"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(float)", "Inverse scale factor for quantization" + ), ], outputs=[defs.OpSchema.FormalParameter("output", "tensor(uint8)", "Quantized output tensor")], ) @@ -126,11 +128,10 @@ def onnx_quantize_fp8_symbolic( @torch.library.custom_op("tex::fp8_dequantize", mutates_args=[]) -def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale: float) -> torch.Tensor: +def onnx_dequantize_fp8_op(tensor: torch.Tensor, scale_inv: torch.Tensor) -> torch.Tensor: """Dequantize from Float8Tensor used for inference.""" - scale_tensor = torch.tensor(scale, dtype=torch.float32, device=tensor.device) quantizer = Float8Quantizer( - scale_tensor, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 + 1 / scale_inv, torch.zeros(1).to(tensor.device), tex.DType.kFloat8E4M3 ) quantizer_tensor = 
quantizer.create_tensor_from_data(tensor, fake_dtype=torch.float32) return quantizer_tensor.dequantize() @@ -143,10 +144,9 @@ def _(tensor: torch.Tensor, _) -> torch.Tensor: def onnx_dequantize_fp8_symbolic( - tensor: onnxscript.onnx_types.TensorType, scale: float + tensor: onnxscript.onnx_types.TensorType, scale_inv: onnxscript.onnx_types.TensorType ) -> onnxscript.onnx_types.TensorType: """Symbolic dequantize from Float8Tensor used for inference.""" - scale_inv = op.Constant(value_float=1 / scale) return TRT_FP8DequantizeLinear(tensor, scale_inv) @@ -157,7 +157,9 @@ def onnx_dequantize_fp8_symbolic( doc="TRT FP8 Dequantize Linear from Float8Tensor used for inference.", inputs=[ defs.OpSchema.FormalParameter("tensor", "tensor(uint8)", "Input tensor to dequantize"), - defs.OpSchema.FormalParameter("scale", "tensor(float)", "Scale factor for dequantization"), + defs.OpSchema.FormalParameter( + "scale_inv", "tensor(float)", "Inverse scale factor for dequantization" + ), ], outputs=[defs.OpSchema.FormalParameter("output", "tensor(float)", "Dequantized output tensor")], ) @@ -166,6 +168,43 @@ def onnx_dequantize_fp8_symbolic( opset=trt_opset, name="TRT_FP8DequantizeLinear", op_schema=schema ) +# ONNX FP8 Current Scaling Quantization + + +@torch.library.custom_op("tex::fp8_cs_quantize", mutates_args=[]) +def onnx_cs_quantize_fp8_op(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """Quantize to FP8 with current scaling; returns (uint8, scale_inv).""" + if tensor.dtype != torch.float32: + tensor = tensor.to(torch.float32) + amax = tensor.abs().max() + eps = torch.tensor(1e-12, dtype=torch.float32, device=tensor.device) + amax = torch.maximum(amax, eps) + fp8_max = torch.tensor(448, dtype=torch.float32, device=tensor.device) + scale = fp8_max / amax + q = torch.ops.tex.fp8_quantize(tensor, scale) + scale_inv = 1 / scale + return q, scale_inv + + +@onnx_cs_quantize_fp8_op.register_fake +def _(tensor: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + return torch.empty(tensor.shape, dtype=torch.uint8, device=tensor.device), torch.ones( + 1, dtype=torch.float32, device=tensor.device + ) + + +def onnx_quantize_fp8_cs_symbolic( + tensor: onnxscript.onnx_types.TensorType, +): + """Symbolic quantize with current scaling; computes scale_inv from tensor.""" + # scale_inv = 1 / max(abs(tensor)) + amax = op.ReduceMax(op.Abs(tensor), keepdims=0) + eps = op.Constant(value_float=1.0e-12) + amax = op.Max(amax, eps) + scale_inv = op.Div(amax, op.Constant(value_float=448.0)) + q = TRT_FP8QuantizeLinear(tensor, scale_inv) + return q, scale_inv + # ONNX MXFP8 Quantization @@ -356,6 +395,7 @@ def onnx_attention_mask_func( torch.ops.tex.gemm_inf.default: onnx_gemm_inf_symbolic, torch.ops.tex.fp8_quantize.default: onnx_quantize_fp8_symbolic, torch.ops.tex.fp8_dequantize.default: onnx_dequantize_fp8_symbolic, + torch.ops.tex.fp8_cs_quantize.default: onnx_quantize_fp8_cs_symbolic, torch.ops.tex.mxfp8_quantize.default: onnx_quantize_mxfp8_symbolic, torch.ops.tex.mxfp8_dequantize.default: onnx_dequantize_mxfp8_symbolic, torch.ops.tex.layernorm.default: onnx_layernorm_symbolic, diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py index acc03ba78..1524584aa 100644 --- a/transformer_engine/pytorch/tensor/float8_tensor.py +++ b/transformer_engine/pytorch/tensor/float8_tensor.py @@ -177,7 +177,7 @@ def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: """Function using 
primitives with ONNX defined translations.""" - out = torch.ops.tex.fp8_dequantize(tensor._data, self.scale.item()) + out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv) out = out.to(tensor.dtype) return out @@ -350,15 +350,25 @@ def create_tensor_from_data( def onnx_quantize(self, tensor: torch.Tensor) -> QuantizedTensor: """Function using primitives with ONNX defined translations.""" - raise NotImplementedError( - "Float8CurrentScalingQuantizer does not support ONNX quantization yet." + if tensor.dtype != torch.float32: + tensor = tensor.to(torch.float32) + data, scale_inv = torch.ops.tex.fp8_cs_quantize(tensor) + return Float8Tensor( + shape=data.shape, + dtype=torch.float32, + data=data, + fp8_scale_inv=scale_inv, + fp8_dtype=self.dtype, + requires_grad=False, + data_transpose=None, + quantizer=self, ) def onnx_dequantize(self, tensor: QuantizedTensor) -> torch.Tensor: """Function using primitives with ONNX defined translations.""" - raise NotImplementedError( - "Float8CurrentScalingQuantizer does not support ONNX dequantization yet." - ) + out = torch.ops.tex.fp8_dequantize(tensor._data, tensor._scale_inv) + out = out.to(tensor.dtype) + return out def _canonicalized_amax_reduction_group(self) -> dist_group_type: """Get process group for amax reduction""" From c449c6cfbd30bf806a968bafd05aca51717c3533 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu <42691305+zhongbozhu@users.noreply.github.com> Date: Thu, 28 Aug 2025 15:13:16 -0700 Subject: [PATCH 114/153] [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani --- .../pytorch/csrc/extensions/cast.cpp | 14 ++++---------- 1 file changed, 4 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 819d3e518..e9647b44f 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -205,11 +205,8 @@ std::tuple, std::vector> bulk_allocate_fp auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); - // in the case where full buffer is empty because local rank receives no tokens for all the experts - // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob - // but in the case where some experts receive tokens, some not, we want to leverage from_blob - // as much as possible to avoid CPU overhead - if (buffer->data_ptr() == nullptr) { + bool is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } return at::from_blob( @@ -359,11 +356,8 @@ std::tuple, std::vector> bulk_allocate_mx auto make_torch_view = [](std::shared_ptr &buffer, const std::vector &shape, size_t offset, at::ScalarType dtype) -> at::Tensor { std::vector shape_int64(shape.begin(), shape.end()); - // in the case where full buffer is empty because local rank receives no tokens for all the experts - // then the data_ptr is nullptr, we need to return an empty tensor instead of calling from_blob - // but in the case where some experts receive tokens, some not, we want to leverage from_blob - // as much as possible to avoid CPU overhead - if (buffer->data_ptr() == nullptr) { + bool 
is_empty_shape = product(shape) == 0; + if (buffer->data_ptr() == nullptr || is_empty_shape) { return at::empty(shape_int64, at::device(at::kCUDA).dtype(dtype)); } return at::from_blob( From f98e305321224fe7a9a10807ca974f271de4fd33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?oliver=20k=C3=B6nig?= Date: Fri, 29 Aug 2025 08:10:07 +0200 Subject: [PATCH 115/153] build: pull cached wheels (#2127) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/setup.py | 105 ++++++++++++++++++++++++++-- 1 file changed, 101 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/setup.py b/transformer_engine/pytorch/setup.py index ae1b5780b..46543acf2 100644 --- a/transformer_engine/pytorch/setup.py +++ b/transformer_engine/pytorch/setup.py @@ -10,14 +10,30 @@ import os import shutil from pathlib import Path - +import platform +import urllib import setuptools +from wheel.bdist_wheel import bdist_wheel as _bdist_wheel +from packaging.version import parse try: + import torch from torch.utils.cpp_extension import BuildExtension except ImportError as e: raise RuntimeError("This package needs Torch to build.") from e +FORCE_BUILD = os.getenv("NVTE_PYTORCH_FORCE_BUILD", "FALSE") == "TRUE" +FORCE_CXX11_ABI = os.getenv("NVTE_PYTORCH_FORCE_CXX11_ABI", "FALSE") == "TRUE" +SKIP_CUDA_BUILD = os.getenv("NVTE_PYTORCH_SKIP_CUDA_BUILD", "FALSE") == "TRUE" +PACKAGE_NAME = "transformer_engine_torch" +BASE_WHEEL_URL = ( + "https://github.com/NVIDIA/TransformerEngine/releases/download/{tag_name}/{wheel_name}" +) +# HACK: The compiler flag -D_GLIBCXX_USE_CXX11_ABI is set to be the same as +# torch._C._GLIBCXX_USE_CXX11_ABI +# https://github.com/pytorch/pytorch/blob/8472c24e3b5b60150096486616d98b7bea01500b/torch/utils/cpp_extension.py#L920 +if FORCE_CXX11_ABI: + torch._C._GLIBCXX_USE_CXX11_ABI = True current_file_path = Path(__file__).parent.resolve() build_tools_dir = current_file_path.parent.parent / "build_tools" @@ -31,13 +47,94 @@ from build_tools.build_ext import get_build_ext from build_tools.utils import copy_common_headers from build_tools.te_version import te_version -from build_tools.pytorch import setup_pytorch_extension, install_requirements, test_requirements +from build_tools.pytorch import ( + setup_pytorch_extension, + install_requirements, + test_requirements, +) os.environ["NVTE_PROJECT_BUILDING"] = "1" CMakeBuildExtension = get_build_ext(BuildExtension, True) +def get_platform(): + """ + Returns the platform name as used in wheel filenames. 
+ """ + if sys.platform.startswith("linux"): + return f"linux_{platform.uname().machine}" + if sys.platform == "darwin": + mac_version = ".".join(platform.mac_ver()[0].split(".")[:2]) + return f"macosx_{mac_version}_x86_64" + if sys.platform == "win32": + return "win_amd64" + + raise ValueError(f"Unsupported platform: {sys.platform}") + + +def get_wheel_url(): + """Construct the wheel URL for the current platform.""" + torch_version_raw = parse(torch.__version__) + python_version = f"cp{sys.version_info.major}{sys.version_info.minor}" + platform_name = get_platform() + nvte_version = te_version() + torch_version = f"{torch_version_raw.major}.{torch_version_raw.minor}" + cxx11_abi = str(torch._C._GLIBCXX_USE_CXX11_ABI).upper() + + # Determine the version numbers that will be used to determine the correct wheel + # We're using the CUDA version used to build torch, not the one currently installed + # _, cuda_version_raw = get_cuda_bare_metal_version(CUDA_HOME) + torch_cuda_version = parse(torch.version.cuda) + # For CUDA 11, we only compile for CUDA 11.8, and for CUDA 12 we only compile for CUDA 12.3 + # to save CI time. Minor versions should be compatible. + torch_cuda_version = parse("11.8") if torch_cuda_version.major == 11 else parse("12.3") + # cuda_version = f"{cuda_version_raw.major}{cuda_version_raw.minor}" + cuda_version = f"{torch_cuda_version.major}" + + # Determine wheel URL based on CUDA version, torch version, python version and OS + wheel_filename = f"{PACKAGE_NAME}-{nvte_version}+cu{cuda_version}torch{torch_version}cxx11abi{cxx11_abi}-{python_version}-{python_version}-{platform_name}.whl" + + wheel_url = BASE_WHEEL_URL.format(tag_name=f"v{nvte_version}", wheel_name=wheel_filename) + + return wheel_url, wheel_filename + + +class CachedWheelsCommand(_bdist_wheel): + """ + The CachedWheelsCommand plugs into the default bdist wheel, which is ran by pip when it cannot + find an existing wheel (which is currently the case for all grouped gemm installs). We use + the environment parameters to detect whether there is already a pre-built version of a compatible + wheel available and short-circuits the standard full build pipeline. + """ + + def run(self): + if FORCE_BUILD: + super().run() + + wheel_url, wheel_filename = get_wheel_url() + print("Guessing wheel URL: ", wheel_url) + try: + urllib.request.urlretrieve(wheel_url, wheel_filename) + + # Make the archive + # Lifted from the root wheel processing command + # https://github.com/pypa/wheel/blob/cf71108ff9f6ffc36978069acb28824b44ae028e/src/wheel/bdist_wheel.py#LL381C9-L381C85 + if not os.path.exists(self.dist_dir): + os.makedirs(self.dist_dir) + + impl_tag, abi_tag, plat_tag = self.get_tag() + archive_basename = f"{self.wheel_dist_name}-{impl_tag}-{abi_tag}-{plat_tag}" + + wheel_path = os.path.join(self.dist_dir, archive_basename + ".whl") + print("Raw wheel path", wheel_path) + os.rename(wheel_filename, wheel_path) + except (urllib.error.HTTPError, urllib.error.URLError): + print("Precompiled wheel not found. 
Building from source...") + # If the wheel could not be downloaded, build from source + super().run() + + if __name__ == "__main__": # Extensions common_headers_dir = "common_headers" @@ -50,11 +147,11 @@ # Configure package setuptools.setup( - name="transformer_engine_torch", + name=PACKAGE_NAME, version=te_version(), description="Transformer acceleration library - Torch Lib", ext_modules=ext_modules, - cmdclass={"build_ext": CMakeBuildExtension}, + cmdclass={"build_ext": CMakeBuildExtension, "bdist_wheel": CachedWheelsCommand}, install_requires=install_requirements(), tests_require=test_requirements(), ) From 715c3bb82c3bcdef14a8c7ceff6c659f0e39aa66 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Fri, 29 Aug 2025 21:03:19 +1200 Subject: [PATCH 116/153] feat: Add support for multiple quantization modes in the UB communicators (#2043) --- docs/api/pytorch.rst | 5 +- .../te_layer_with_overlap.py | 4 +- .../distributed/run_layer_with_overlap.py | 77 +++++++++-- .../test_fusible_ops_with_userbuffers.py | 8 +- .../userbuffers/userbuffers-host.cpp | 2 +- .../userbuffers/userbuffers.h | 2 +- transformer_engine/pytorch/__init__.py | 1 + transformer_engine/pytorch/module/__init__.py | 2 +- transformer_engine/pytorch/module/base.py | 120 +++++++++++++----- .../pytorch/module/layernorm_linear.py | 22 ++-- .../pytorch/module/layernorm_mlp.py | 24 ++-- transformer_engine/pytorch/module/linear.py | 22 ++-- .../ops/fused/userbuffers_backward_linear.py | 10 +- .../ops/fused/userbuffers_forward_linear.py | 2 +- 14 files changed, 216 insertions(+), 85 deletions(-) diff --git a/docs/api/pytorch.rst b/docs/api/pytorch.rst index 3229298f2..04b49fac2 100644 --- a/docs/api/pytorch.rst +++ b/docs/api/pytorch.rst @@ -49,7 +49,7 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.moe_permute -.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs +.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs .. autoapifunction:: transformer_engine.pytorch.moe_unpermute @@ -62,3 +62,6 @@ pyTorch .. autoapifunction:: transformer_engine.pytorch.initialize_ub .. autoapifunction:: transformer_engine.pytorch.destroy_ub + +.. 
autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode + :members: FP8, NONE \ No newline at end of file diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index e510df176..eeb79c235 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -263,7 +263,9 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) te.module.base.initialize_ub( [batched_size, hidden_size], tp_size, - use_fp8=opts.fp8, + quantization_modes=[ + UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE + ], dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ) diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 2fc4537f0..1dabf6e45 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -12,6 +12,8 @@ import warnings import pprint import yaml +from contextlib import nullcontext +from functools import partial import torch import torch.distributed as dist @@ -35,9 +37,10 @@ def __init__(self, module, num_layers, *args, **kwargs): self.num_layers = num_layers self.layers = torch.nn.ModuleList([module(*args, **kwargs) for _ in range(num_layers)]) - def forward(self, x): - for layer in self.layers: - x = layer(x) + def forward(self, x, layer_contexts): + for layer, context in zip(self.layers, layer_contexts): + with context(): + x = layer(x) return x @@ -237,12 +240,46 @@ def _parse_args(argv=None, namespace=None): default=False, help="Print out additional debug information.", ) + parser.add_argument( + "--first-last-layers-bf16", + action="store_true", + default=False, + help="Use bf16 for first and last N layers.", + ) + parser.add_argument( + "--num-layers-at-start-in-bf16", + type=int, + default=0, + help="Number of layers at the start to run in bf16.", + ) + parser.add_argument( + "--num-layers-at-end-in-bf16", + type=int, + default=0, + help="Number of layers at the end to run in bf16.", + ) args = parser.parse_args(argv, namespace) if args.use_cuda_graphs and args.layer_type in [te.MultiheadAttention, te.TransformerLayer]: warnings.warn(f"{args.layer_type.__name__} does not support CUDA Graphs!") args.use_cuda_graphs = False + if not args.first_last_layers_bf16 and ( + args.num_layers_at_start_in_bf16 > 0 or args.num_layers_at_end_in_bf16 > 0 + ): + warnings.warn( + "num-layers-at-start-in-bf16 and num-layers-at-end-in-bf16 are only supported when" + " first-last-layers-bf16 is enabled!" + ) + args.num_layers_at_start_in_bf16 = 0 + args.num_layers_at_end_in_bf16 = 0 + + if args.num_layers_at_start_in_bf16 + args.num_layers_at_end_in_bf16 > args.num_layers: + raise ValueError( + "num-layers-at-start-in-bf16 + num-layers-at-end-in-bf16 must be less than or equal to" + " num-layers!" 
+ ) + return args @@ -381,10 +418,17 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): "qkv_dgrad": {"method": "ring_exchange"}, "fc1_dgrad": {"method": "ring_exchange"}, } + + quantization_modes = [ + UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE + ] + if opts.first_last_layers_bf16 and opts.fp8: + quantization_modes.append(UserBufferQuantizationMode.NONE) + te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], opts.tp, - use_fp8=opts.fp8, + quantization_modes=quantization_modes, dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, ub_cfgs=ub_cfgs if opts.ub_cfg is None else opts.ub_cfg, @@ -423,6 +467,16 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): elif opts.quantization == "mxfp8": fp8_recipe = MXFP8BlockScaling() + layer_contexts = [ + ( + partial(te.fp8_autocast, enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world) + if opts.num_layers_at_start_in_bf16 <= i + and i < (opts.num_layers - opts.num_layers_at_end_in_bf16) + else nullcontext + ) + for i in range(opts.num_layers) + ] + # Prepare random input tensors test_x = torch.randn(input_shape, dtype=torch.float32, device="cuda", requires_grad=True) test_x.retain_grad() @@ -435,14 +489,13 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): # Execute fwd/bwd and collect tensors to test def run_fwd_bwd(model, x): with torch.amp.autocast("cuda", dtype=torch.bfloat16): - with te.fp8_autocast(enabled=opts.fp8, fp8_recipe=fp8_recipe, fp8_group=nccl_world): - y = model(x) - if isinstance(y, tuple): - out, *_ = y - else: - out = y - loss = out.sum() - loss.backward() + y = model(x, layer_contexts) + if isinstance(y, tuple): + out, *_ = y + else: + out = y + loss = out.sum() + loss.backward() return out torch_rng_state = torch.get_rng_state() diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 37f0e8669..17d351292 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -506,7 +506,13 @@ def main() -> None: model_config.num_heads * model_config.head_dim, ], torch.distributed.get_world_size(group), - use_fp8=model_config.quantization is not None, + quantization_modes=[ + ( + UserBufferQuantizationMode.FP8 + if model_config.quantization is not None + else UserBufferQuantizationMode.NONE + ) + ], dtype=model_config.dtype, bootstrap_backend=bootstrap_backend, ub_cfgs=userbuffer_configs, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp index 65da58d5f..1ce89c512 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -511,7 +511,7 @@ void destroy_communicator_mpi(communicator *comm) { } int register_user_buffer_collective(void **gpubuff, size_t bytes, communicator *comm, bool alloc) { - if (comm->free_region > NVTE_MAX_REGIONS) return -1; + if (comm->free_region >= NVTE_MAX_REGIONS) return -1; int hndl = comm->free_region; comm->peer_ptr[hndl] = reinterpret_cast(malloc(sizeof(void *) * (comm->nvsize))); size_t aligned_size = bytes; diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h 
b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 34d6ff72f..8077f90be 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -27,7 +27,7 @@ using ExtAllgatherOp = std::function; using ExtBarrierOp = std::function; -#define NVTE_MAX_REGIONS 16 +#define NVTE_MAX_REGIONS 32 #define NVTE_MAX_SMS 32 #define NVTE_MAX_OPS 32 #define NVTE_MAX_PEERS 8192 diff --git a/transformer_engine/pytorch/__init__.py b/transformer_engine/pytorch/__init__.py index 2e86a77a5..3bdbe4089 100644 --- a/transformer_engine/pytorch/__init__.py +++ b/transformer_engine/pytorch/__init__.py @@ -33,6 +33,7 @@ def torch_version() -> tuple[int, ...]: from transformer_engine.pytorch.module import Fp8Padding, Fp8Unpadding from transformer_engine.pytorch.module import initialize_ub from transformer_engine.pytorch.module import destroy_ub +from transformer_engine.pytorch.module import UserBufferQuantizationMode from transformer_engine.pytorch.attention import DotProductAttention from transformer_engine.pytorch.attention import MultiheadAttention from transformer_engine.pytorch.attention import InferenceParams diff --git a/transformer_engine/pytorch/module/__init__.py b/transformer_engine/pytorch/module/__init__.py index 5074d32aa..ac682190c 100644 --- a/transformer_engine/pytorch/module/__init__.py +++ b/transformer_engine/pytorch/module/__init__.py @@ -11,4 +11,4 @@ from .rmsnorm import RMSNorm from .fp8_padding import Fp8Padding from .fp8_unpadding import Fp8Unpadding -from .base import initialize_ub, destroy_ub +from .base import initialize_ub, destroy_ub, UserBufferQuantizationMode diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 5d04b29f7..3bbfaacdf 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -8,6 +8,7 @@ import os import pickle import warnings +from enum import Enum from abc import ABC, abstractmethod from typing import Any, Dict, Generator, List, Optional, Set, Tuple, Union from contextlib import contextmanager @@ -49,7 +50,7 @@ from ...debug.pytorch.debug_quantization import DebugQuantizer, DebugQuantizedTensor from ...debug.pytorch.utils import next_iter_when_debug_should_be_run, any_feature_enabled -__all__ = ["initialize_ub", "destroy_ub"] +__all__ = ["initialize_ub", "destroy_ub", "UserBufferQuantizationMode"] _2X_ACC_FPROP = False _2X_ACC_DGRAD = True @@ -63,6 +64,15 @@ layers_atomic_ring_exchange = [] +class UserBufferQuantizationMode(Enum): + """ + UserBufferQuantizationMode is an enum that represents the quantization mode of the UserBuffer. 
+ """ + + NONE = "none" + FP8 = "fp8" + + def get_cublas_workspace_size_bytes() -> None: """Return 32 MiB if using hopper, 4 MiB for all other architectures.""" if torch.cuda.get_device_properties(torch.cuda.current_device()).major >= 9: @@ -111,8 +121,9 @@ def initialize_ub( shape: list, tp_size: int, use_fp8: bool = False, + quantization_modes: List[UserBufferQuantizationMode] = None, dtype: torch.dtype = torch.bfloat16, - ub_cfgs: Optional[dict] = None, + ub_cfgs: Optional[Union[dict, List[dict]]] = None, bootstrap_backend: Union[str, torch.distributed.Backend] = None, ) -> None: r""" @@ -128,7 +139,11 @@ def initialize_ub( tp_size : int number of GPUs in the tensor-parallel process group use_fp8 : bool = False - allocate the communication buffer for FP8 GEMM inputs/outputs + allocate the communication buffer for FP8 GEMM inputs/outputs. + DEPRECATED: Please use `quantization_modes` instead. + quantization_modes : List[UserBufferQuantizationMode] = None + if a list of UserBufferQuantizationMode is provided, a UB communicator is created for each quantization setting in the list. + falls back to the legacy `use_fp8` parameter if `None` is provided. dtype : torch.dtype = torch.bfloat16 non-FP8 data type of the communication buffer when `use_fp8 = False` ub_cfgs: dict = None @@ -152,6 +167,7 @@ def initialize_ub( for `te.TransformerLayer` GEMM layers in `["qkv_fprop", "qkv_dgrad", "qkv_wgrad", "proj_fprop", "proj_dgrad", "proj_wgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad", "fc2_fprop", "fc2_wgrad"]`. + a list may be provided to specify different overlap configurations for different the quantization settings in `quantization_modes` bootstrap_backend : str = None `torch.distributed` communication backend for the all-gather, broadcast and barrier collectives during Userbuffers initialization. Not all backends are @@ -168,6 +184,28 @@ def initialize_ub( + "CUDA Multicast. Launch app with UB_SKIPMC=1 to try CUDA IPC instead." ) + if not quantization_modes: + warnings.warn( + "Initializing Userbuffers with use_fp8 is deprecated. Please use quantization_modes" + " instead.", + DeprecationWarning, + ) + quantization_modes = [ + UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE + ] + else: + assert isinstance(quantization_modes, list), "quantization_modes must be a list" + assert all( + isinstance(mode, UserBufferQuantizationMode) for mode in quantization_modes + ), "quantization_modes must be a list of UserBufferQuantizationMode" + + if isinstance(ub_cfgs, dict) or ub_cfgs is None: + ub_cfgs = [ub_cfgs] * len(quantization_modes) + else: + assert len(ub_cfgs) == len( + quantization_modes + ), "Number of ub_cfgs settings must match number of quantization configurations" + global _ub_communicators assert _ub_communicators is None, "UB communicators are already initialized." _ub_communicators = {} @@ -309,6 +347,7 @@ def get_default_config(name): def add_ub( name: str, + quantization_mode: UserBufferQuantizationMode, method: str, is_reduce_scatter: bool, num_sm: int = 16, @@ -327,7 +366,9 @@ def add_ub( warnings.warn( "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." ) - assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." + assert ( + quantization_mode == UserBufferQuantizationMode.FP8 + ), "Atomic GEMM overlap supported only for FP8 GEMM." if method in ("bulk", "external"): warnings.warn( f"At {name}, atoimic GEMM not is supported for a bulk overlap." 
@@ -367,7 +408,11 @@ def add_ub( f" {external_gemm_to_overlap[name]} is not using `ring_exchange` overlap method" ) - buffer_dtype = torch.uint8 if (use_fp8 and fp8_buf) else dtype + buffer_dtype = ( + torch.uint8 + if (quantization_mode == UserBufferQuantizationMode.FP8 and fp8_buf) + else dtype + ) if method == "ring_exchange": ub_obj = tex.CommOverlapP2P( shape, # Communication buffer shape @@ -401,38 +446,47 @@ def add_ub( comm_priority=comm_priority, rs_overlap_first_gemm=pipeline_rs_overlap_first_gemm, ) - _ub_communicators[name] = ub_obj - - if ub_cfgs is not None: - for name in dgrad_reduce_scatter_overlap: - if name in ub_cfgs and "method" in ub_cfgs[name] and ub_cfgs[name]["method"] != "bulk": - wgrad_name = name.replace("dgrad", "wgrad") - assert wgrad_name not in ub_cfgs - layers_reduce_scatter_overlap.remove(wgrad_name) - layers_all_gather_overlap.remove(name) - layers_reduce_scatter_overlap.append(name) - methods["bulk"].remove(name) - new_method = ub_cfgs[name]["method"] - methods[new_method].append(name) - - for name in ( - methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] - ): - ub_cfg = get_default_config(name) - if ub_cfgs is not None and name in ub_cfgs: - fp8_buf = (name in layers_all_gather_overlap) or ( - ub_cfgs[name].get("fp8_buf", False) and name in methods["pipeline"] - ) - ub_cfg.update(ub_cfgs[name]) - ub_cfg["fp8_buf"] = fp8_buf - add_ub(name, **ub_cfg) + _ub_communicators[(name, quantization_mode)] = ub_obj + + for quantization_mode, user_ub_cfg in zip(quantization_modes, ub_cfgs): + if user_ub_cfg is not None: + for name in dgrad_reduce_scatter_overlap: + if ( + name in user_ub_cfg + and "method" in user_ub_cfg[name] + and user_ub_cfg[name]["method"] != "bulk" + ): + wgrad_name = name.replace("dgrad", "wgrad") + assert wgrad_name not in user_ub_cfg + layers_reduce_scatter_overlap.remove(wgrad_name) + layers_all_gather_overlap.remove(name) + layers_reduce_scatter_overlap.append(name) + methods["bulk"].remove(name) + new_method = user_ub_cfg[name]["method"] + methods[new_method].append(name) + + for name in ( + methods["ring_exchange"] + methods["pipeline"] + methods["bulk"] + methods["external"] + ): + ub_cfg = get_default_config(name) + if user_ub_cfg is not None and name in user_ub_cfg: + fp8_buf = (name in layers_all_gather_overlap) or ( + user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] + ) + ub_cfg.update(ub_cfgs[name]) + ub_cfg["fp8_buf"] = fp8_buf + add_ub(name, quantization_mode, **ub_cfg) -def get_ub(name: str): +def get_ub(name: str, use_fp8: bool): """Get userbuffer communicator corresponding to give key.""" + # For now use `use_fp8` boolean input as it matches the current design in the modules + # So favour simplicity until the correct design becomes clear. + # This is mainly an internal API so we don't need to worry about future changes + key = (name, UserBufferQuantizationMode.FP8 if use_fp8 else UserBufferQuantizationMode.NONE) assert _ub_communicators is not None, "UB manager is not initialized." - assert name in _ub_communicators, f"UB for {name} is not registered." - return _ub_communicators[name] + assert key in _ub_communicators, f"UB for {name} with use_fp8={use_fp8} is not registered." 
+ return _ub_communicators[key] def destroy_ub(): diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 04e3eba7d..cd02f3113 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -173,10 +173,10 @@ def forward( ub_overlap_ag_fprop and is_grad_enabled and not return_layernorm_output ) if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.RS elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG # Configure quantizer for norm output @@ -575,23 +575,23 @@ def backward( dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -769,7 +769,7 @@ def backward( dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1492,10 +1492,14 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_grad = True with torch.cuda.device( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 2e51ac948..182bf99f8 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -307,7 +307,7 @@ def forward( fc1_input_quantizer.set_usage(rowwise=True, columnwise=False) if ub_overlap_ag: # Copy into Userbuffers buffer - ub_obj_lnout = get_ub("fc1_fprop") + ub_obj_lnout = get_ub("fc1_fprop", fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_lnout, ln_out, @@ -458,7 +458,7 @@ def forward( ub_obj_fc2out = None reduce_scatter_out = None if ub_overlap_rs: - ub_obj_fc2out = get_ub("fc2_fprop") + ub_obj_fc2out = get_ub("fc2_fprop", fp8) dim_size 
= list(act_out.size()) dim_size[0] //= tp_world_size dim_size[-1] = fc2_weight.size(0) @@ -740,7 +740,7 @@ def backward( # Note: Cast to expected dtype and perform tensor-parallel communication ub_obj_fc2_dgrad = None if ctx.ub_overlap_ag: - ub_obj_fc2_dgrad = get_ub("fc2_dgrad") + ub_obj_fc2_dgrad = get_ub("fc2_dgrad", ctx.fp8) ctx.ub_obj_gradout = ub_obj_fc2_dgrad ( grad_output, @@ -764,7 +764,7 @@ def backward( # wgrad GEMM requires input with column-wise usage quantizer.set_usage(rowwise=False, columnwise=True) if ctx.ub_bulk_dgrad: - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ln_out_total, _ = fill_userbuffers_buffer_for_all_gather( ub_obj_fc1_dgrad, ln_out, @@ -869,7 +869,7 @@ def backward( ub_obj_fc2_dgrad.get_communication_stream() ) - ub_obj_fc2_wgrad = get_ub("fc2_wgrad") + ub_obj_fc2_wgrad = get_ub("fc2_wgrad", ctx.fp8) ctx.fc2_grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1036,16 +1036,16 @@ def fc2_wgrad_gemm( fc1_dgrad_shape = [reduce(multiply_op, inputmat.shape[:-1]), inputmat.shape[-1]] if ctx.ub_overlap_rs_dgrad: # Overlap DGRAD+RS - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap ln_out all-gather with DGRAD compute - ub_obj_fc1_dgrad = get_ub("fc1_dgrad") + ub_obj_fc1_dgrad = get_ub("fc1_dgrad", ctx.fp8) ub_type_fc1_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap FC1 DGRAD reduce-scatter with WGRAD compute - ub_obj_fc1_wgrad = get_ub("fc1_wgrad") + ub_obj_fc1_wgrad = get_ub("fc1_wgrad", ctx.fp8) ub_type_fc1_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -1539,7 +1539,11 @@ def __init__( self.gemm_gelu_fusion = ( bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and self.activation == "gelu" - and ((_ub_communicators is None) or (not get_ub("fc1_fprop").is_atomic_gemm())) + and all( + ("fc1_fprop", use_fp8) not in _ub_communicators + or not get_ub("fc1_fprop", use_fp8).is_atomic_gemm() + for use_fp8 in [False, True] + ) ) self.name = name @@ -1757,7 +1761,7 @@ def forward( fp8_output = False if self.ub_overlap_rs: - if get_ub("fc2_fprop").is_fp8_ubuf(): + if get_ub("fc2_fprop", FP8GlobalStateManager.is_fp8_enabled()).is_fp8_ubuf(): fp8_output = True with torch.cuda.device( diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 695cbb4e6..2ce6fb4c1 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -145,10 +145,10 @@ def forward( ub_obj = None ub_type = None if ub_overlap_rs_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.RS elif ub_overlap_ag_fprop: - ub_obj = get_ub(ub_name + "_fprop") + ub_obj = get_ub(ub_name + "_fprop", fp8) ub_type = tex.CommOverlapType.AG # ------------------------------------------------------ @@ -520,23 +520,23 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_shape = [reduce(multiply_op, ctx.inp_shape[:-1]), ctx.inp_shape[-1]] if ctx.ub_overlap_ag: # Overlap grad_output all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG elif ctx.ub_overlap_rs_dgrad: # Overlap dgrad reduce-scatter with dgrad compute - ctx.ub_obj_gradout = 
get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.RS else: if ctx.ub_bulk_dgrad: # Overlap inputmat all-gather with dgrad compute - ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad") + ctx.ub_obj_gradout = get_ub(ctx.ub_name + "_dgrad", ctx.fp8) ub_obj_dgrad = ctx.ub_obj_gradout ub_type_dgrad = tex.CommOverlapType.AG if ctx.ub_bulk_wgrad: # Overlap dgrad reduce-scatter with wgrad compute - ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ub_type_wgrad = tex.CommOverlapType.RS # -------------------------------------------------- @@ -769,7 +769,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], dgrad_send_stream, dgrad_recv_stream = ub_obj_dgrad.get_communication_stream() # This object is separate from the ub_obj_wgrad object which is passed to the GEMM - ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad") + ub_obj_overlap_wgrad = get_ub(ctx.ub_name + "_wgrad", ctx.fp8) ctx.grad_output_quantizer.set_usage(rowwise=False, columnwise=True) @@ -1377,10 +1377,14 @@ def forward( is_first_microbatch = False if self.ub_overlap_rs_fprop: - if get_ub(self.ub_name + "_fprop").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_fprop", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_output = True if self.ub_overlap_rs_dgrad: - if get_ub(self.ub_name + "_dgrad").is_fp8_ubuf(): + if get_ub( + self.ub_name + "_dgrad", FP8GlobalStateManager.is_fp8_enabled() + ).is_fp8_ubuf(): fp8_grad = True with torch.cuda.device( diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py index c59532521..1ecdba625 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_backward_linear.py @@ -241,16 +241,16 @@ def _functional_backward( with_dgrad_all_gather_x = False with_wgrad_reduce_scatter_dx = False if tensor_parallel_mode == "row": - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.AG with_dgrad_all_gather_dy = True elif tensor_parallel_mode == "column": if input_requires_grad and weight_requires_grad: with_bulk_overlap = True - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.AG with_dgrad_all_gather_x = True - ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad") + ub_comm_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute) ub_type_wgrad = CommOverlapType.RS with_wgrad_reduce_scatter_dx = True if ub_comm_wgrad.is_fp8_ubuf(): @@ -258,7 +258,7 @@ def _functional_backward( "Userbuffers reduce-scatter is not supported with FP8 buffers" ) else: - ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad") + ub_comm_dgrad = get_ub(ub_comm_name + "_dgrad", with_quantized_compute) ub_type_dgrad = CommOverlapType.RS with_dgrad_reduce_scatter_dx = True if ub_comm_dgrad.is_fp8_ubuf(): @@ -409,7 +409,7 @@ def _functional_backward( # Get the communication stream from the dgrad GEMM to use for the AG dgrad_send_stream, dgrad_recv_stream = ub_comm_dgrad.get_communication_stream() - ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad") + ub_obj_overlap_wgrad = get_ub(ub_comm_name + "_wgrad", with_quantized_compute) 
grad_output_quantizer.set_usage(rowwise=False, columnwise=True) diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 61853f9f4..574642794 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -189,7 +189,7 @@ def _functional_forward( output_quantizer = None # Get Userbuffers communicator - ub_comm = get_ub(ub_comm_name + "_fprop") + ub_comm = get_ub(ub_comm_name + "_fprop", with_quantized_compute) with_ub_all_gather = tensor_parallel_mode == "column" with_ub_reduce_scatter = tensor_parallel_mode == "row" ub_type = CommOverlapType.AG if with_ub_all_gather else CommOverlapType.RS From 4285874da3733c731cae6aadedecd1c876735751 Mon Sep 17 00:00:00 2001 From: Xin Yao Date: Fri, 29 Aug 2025 17:43:26 +0800 Subject: [PATCH 117/153] [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../comm_gemm_overlap/comm_gemm_overlap.cpp | 26 +++++---- .../userbuffers/userbuffers.cu | 13 ++++- transformer_engine/common/common.cu | 4 +- .../common/fused_attn/context_parallel.cu | 9 +++ .../common/fused_attn/flash_attn.cu | 2 + .../fused_attn_f16_arbitrary_seqlen.cu | 4 ++ .../common/fused_attn/fused_attn_fp8.cu | 4 ++ .../common/fused_attn/kv_cache.cu | 4 ++ transformer_engine/common/fused_attn/utils.cu | 8 ++- .../common/fused_router/fused_moe_aux_loss.cu | 12 ++-- .../fused_score_for_moe_aux_loss.cu | 2 + .../fused_topk_with_score_function.cu | 2 + .../scaled_aligned_causal_masked_softmax.cu | 2 + .../fused_softmax/scaled_masked_softmax.cu | 3 + .../scaled_upper_triang_masked_softmax.cu | 2 + .../common/multi_tensor/l2norm.cu | 2 + .../common/normalization/common.cpp | 4 +- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 34 ++++++----- .../layernorm/ln_fwd_cuda_kernel.cu | 31 +++++----- .../rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu | 16 ++++-- .../rmsnorm/rmsnorm_fwd_cuda_kernel.cu | 31 +++++----- .../common/nvshmem_api/nvshmem_waitkernel.cu | 15 +++-- .../common/permutation/permutation.cu | 6 ++ .../common/recipe/current_scaling.cu | 2 +- .../common/recipe/fp8_block_scaling.cu | 2 + transformer_engine/common/swizzle/swizzle.cu | 56 +++++++++++-------- .../common/transformer_engine.cpp | 4 +- .../common/transpose/cast_transpose.cu | 1 + .../common/transpose/cast_transpose_fusion.cu | 16 ++++-- .../common/transpose/multi_cast_transpose.cu | 4 ++ .../common/transpose/transpose.cu | 1 + .../common/transpose/transpose_fusion.cu | 13 +++-- .../common/util/cast_gated_kernels.cuh | 24 ++++---- .../common/util/cast_kernels.cuh | 18 ++++-- .../common/util/dequantize_kernels.cuh | 1 + transformer_engine/common/util/padding.cu | 4 ++ .../common/util/vectorized_pointwise.h | 4 ++ 37 files changed, 256 insertions(+), 130 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 9ba6688ce..d90dd3abc 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -101,10 
+101,10 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl DType::kInt32); } // CUDA event creation - cudaEventCreateWithFlags(&_start_compute, 0); - cudaEventCreateWithFlags(&_stop_compute, 0); - cudaEventCreateWithFlags(&_start_comm, 0); - cudaEventCreateWithFlags(&_stop_comm, 0); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_compute, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_start_comm, 0)); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_stop_comm, 0)); /* Defining the launcher order between the communication and GEMM kernels @@ -114,11 +114,11 @@ CommOverlapCore::CommOverlapCore(int myrank, int numranks, int mylocal, int numl */ int max_connection = transformer_engine::getenv("CUDA_DEVICE_MAX_CONNECTIONS", 8); int runtime_version = 0; - cudaRuntimeGetVersion(&runtime_version); + NVTE_CHECK_CUDA(cudaRuntimeGetVersion(&runtime_version)); cudaDeviceProp deviceProp; - cudaGetDeviceProperties(&deviceProp, 0); + NVTE_CHECK_CUDA(cudaGetDeviceProperties(&deviceProp, 0)); if (runtime_version >= 12030 && deviceProp.major == 9 && max_connection > 1) { - cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming); + NVTE_CHECK_CUDA(cudaEventCreateWithFlags(&_comm_launch_event, cudaEventDisableTiming)); } else { _comm_launch_event = 0; } @@ -129,9 +129,13 @@ CommOverlapCore::~CommOverlapCore() { cudaEventDestroy(_start_comm); cudaEventDestroy(_stop_compute); cudaEventDestroy(_start_compute); - if (_comm_launch_event) cudaEventDestroy(_comm_launch_event); + if (_comm_launch_event) { + cudaEventDestroy(_comm_launch_event); + } - if (_atomic_gemm) cudaFree(_counter.dptr()); + if (_atomic_gemm) { + cudaFree(_counter.dptr()); + } for (size_t i = 0; i < _stream_compute.size(); i++) { cudaStreamSynchronize(_stream_compute[i]); @@ -698,7 +702,9 @@ CommOverlapP2PBase::~CommOverlapP2PBase() { cudaEventDestroy(_stop_recv); cudaEventDestroy(_stop_send); cudaStreamDestroy(_stream_recv); - for (size_t i = 0; i < _stream_send.size(); i++) cudaStreamDestroy(_stream_send[i]); + for (size_t i = 0; i < _stream_send.size(); i++) { + cudaStreamDestroy(_stream_send[i]); + } } TensorWrapper CommOverlapP2PBase::get_buffer_chunk_by_id(const TensorWrapper &source, diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 893644ce6..17f3cf658 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2319,6 +2319,7 @@ void userbuffers_send(const int srchandler, const size_t srcoffset, const int ds if (comm->push == 0) { kuserbuffers_pullsend<<<1, 1, 0, stream>>>(comm->myrank, peer, &(comm->send_id[peer]), reinterpret_cast(flagptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *srcptr = reinterpret_cast(comm->mem_ptr[srchandler]) + srcoffset; void *dstptr = reinterpret_cast(comm->peer_ptr[dsthandler][peerlocal]) + dstoffset; @@ -2516,8 +2517,11 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds &(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler]), reinterpret_cast(flagptr), reinterpret_cast(srcptr), reinterpret_cast(dstptr), signalonly ? 
0 : bytes / 16, comm->ub_timeout); - if (!signalonly) + NVTE_CHECK_CUDA(cudaGetLastError()); + if (!signalonly) { kuserbuffers_inc<<<1, 1, 0, stream>>>(&(comm->recv_id[peer * NVTE_MAX_REGIONS + dsthandler])); + NVTE_CHECK_CUDA(cudaGetLastError()); + } if (comm->use_ce) { NVTE_CHECK_CUDA(cudaMemcpyAsync(dstptr, srcptr, bytes, cudaMemcpyDeviceToDevice, stream)); } @@ -2532,6 +2536,7 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds reinterpret_cast(0 ? // temporary disable GET_RECV_PTR_BY_INDEX(peer, comm, dsthandler, 2) : nullptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -2612,24 +2617,28 @@ void producer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); producer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_kernel<<>>(atomic_ptr, chunk_i); + NVTE_CHECK_CUDA(cudaGetLastError()); } void consumer_batch(void *atomic_ptr, int first_chunk_i, int num_chunks, cudaStream_t stream) { dim3 block(1); dim3 grid(1); consumer_batch_kernel<<>>(atomic_ptr, first_chunk_i, num_chunks); + NVTE_CHECK_CUDA(cudaGetLastError()); } void reset_counters(void *atomic_ptr, int num_chunks, bool allgather, cudaStream_t stream) { dim3 block(1); dim3 grid(1); reset_counters_kernel<<>>(atomic_ptr, num_chunks, allgather); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -2683,6 +2692,7 @@ void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_in reduce_fp8_in_bf16_out_cuda <<>>(inputs, output, scale, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(void *inputs, void *output, float *scale, @@ -2738,4 +2748,5 @@ void reduce_bf16(void *inputs, void *output, int num_inputs, int input_size, cud dim3 grid(num_blocks); reduce_bf16_cuda<<>>( inputs, output, num_inputs, input_size, num_aligned_elements_per_input, tot_input_size); + NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index a810fb471..8b7f92aff 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -50,6 +50,7 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { update_tensor_scale_inv_kernel<<<1, 1, 0, stream>>>( reinterpret_cast(t->scale.dptr), reinterpret_cast(t->scale_inv.dptr)); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -91,6 +92,7 @@ __global__ void __launch_bounds__(kThreadsPerBlock) dim3 grid(numBlocks, 1, 1); \ memset_kernel \ <<>>(ptr, value, size_in_bytes); \ + NVTE_CHECK_CUDA(cudaGetLastError()); \ return; \ } @@ -101,7 +103,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream if (size_in_bytes > 4096) { // Use cudaMemsetAsync for larger sizes. 
- cudaMemsetAsync(ptr, value, size_in_bytes, stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(ptr, value, size_in_bytes, stream)); return; } diff --git a/transformer_engine/common/fused_attn/context_parallel.cu b/transformer_engine/common/fused_attn/context_parallel.cu index 15708d2d5..5921d97d5 100644 --- a/transformer_engine/common/fused_attn/context_parallel.cu +++ b/transformer_engine/common/fused_attn/context_parallel.cu @@ -341,6 +341,7 @@ void thd_read_half_tensor(const Tensor &tensor, const Tensor &cu_seqlens, Tensor thd_read_half_tensor_kernel<<>>( half.data.dptr, tensor.data.dptr, reinterpret_cast(cu_seqlens.data.dptr), batch, hidden_size_in_bytes, half_idx, tensor_shape[seq_dim]); + NVTE_CHECK_CUDA(cudaGetLastError()); } /*************************************************************************************************** @@ -397,11 +398,13 @@ void thd_second_half_lse_correction(Tensor lse, const Tensor &lse_per_step, reinterpret_cast(lse.data.dptr), reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_lse_kernel<<>>( reinterpret_cast(lse.data.dptr), reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -446,11 +449,13 @@ void thd_read_second_half_lse(const Tensor &lse, const Tensor &cu_seqlens, Tenso reinterpret_cast(lse.data.dptr), reinterpret_cast(half_lse.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_lse_kernel<<>>( reinterpret_cast(lse.data.dptr), reinterpret_cast(half_lse.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, lse_seqlen, second_half_lse_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -519,6 +524,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen, lse_per_step_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { thd_out_correction_kernel <<>>( @@ -528,6 +534,7 @@ static void thd_out_correction_helper(Tensor out, const Tensor &out_per_step, co reinterpret_cast(lse_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, num_heads, dim_per_head, lse_seqlen, lse_per_step_seqlen); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -602,6 +609,7 @@ static void thd_grad_correction_helper(Tensor grad, const Tensor &grad_per_step, reinterpret_cast(grad.data.dptr), reinterpret_cast(grad_per_step.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, hidden_size, total_tokens); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -667,6 +675,7 @@ void thd_get_partitioned_indices(const Tensor &cu_seqlens, Tensor output, int to thd_partition_indices_kernel<<>>( reinterpret_cast(output.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), batch, total_tokens, world_size, rank); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace context_parallel diff --git a/transformer_engine/common/fused_attn/flash_attn.cu b/transformer_engine/common/fused_attn/flash_attn.cu index 0c261d0fa..59207d59a 100644 --- a/transformer_engine/common/fused_attn/flash_attn.cu +++ b/transformer_engine/common/fused_attn/flash_attn.cu @@ -91,6 +91,7 @@ void prepare_flash_attn_fwd(Tensor qkvi, Tensor qkv, cudaStream_t stream) { 
prepare_kernel_fwd<<>>( reinterpret_cast(qkvi.data.dptr), reinterpret_cast(qkv.data.dptr), shape[1], shape[2], shape[3], shape[4]);); + NVTE_CHECK_CUDA(cudaGetLastError()); } void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream_t stream) { @@ -129,6 +130,7 @@ void prepare_flash_attn_bwd(Tensor q, Tensor k, Tensor v, Tensor qkv, cudaStream reinterpret_cast(q.data.dptr), reinterpret_cast(k.data.dptr), reinterpret_cast(v.data.dptr), reinterpret_cast(qkv.data.dptr), q_shape[0], q_shape[1], q_shape[2], q_shape[3]);); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace flash_attention diff --git a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu index 0932b2cf8..4e6c3c858 100644 --- a/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu +++ b/transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu @@ -416,6 +416,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -454,6 +455,7 @@ void fused_attn_arbitrary_seqlen_fwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); if (is_ragged_q) { variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_o] = devOffsetsO; @@ -883,6 +885,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( actual_b, b, static_cast(devPtrCuSeqlensQ), static_cast(devPtrCuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -916,6 +919,7 @@ void fused_attn_arbitrary_seqlen_bwd_impl( layout_group, actual_b, b, h, hg, d_qk, d_v, static_cast(devPtrSeqOffsetsQ), static_cast(devPtrSeqOffsetsKV), ragged_offset_type, devOffsetsQ, devOffsetsK, devOffsetsV, devOffsetsO, devOffsetsS); + NVTE_CHECK_CUDA(cudaGetLastError()); if (is_ragged_q) { variant_pack[offset_q] = devOffsetsQ; variant_pack[offset_o] = devOffsetsO; diff --git a/transformer_engine/common/fused_attn/fused_attn_fp8.cu b/transformer_engine/common/fused_attn/fused_attn_fp8.cu index 3e38a5066..d7f098376 100644 --- a/transformer_engine/common/fused_attn/fused_attn_fp8.cu +++ b/transformer_engine/common/fused_attn/fused_attn_fp8.cu @@ -1111,6 +1111,7 @@ void fused_attn_fp8_fwd_impl(int64_t b, int64_t h, int64_t s_q, int64_t s_kv, in cu_seqlens_to_offsets<<>>( b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + NVTE_CHECK_CUDA(cudaGetLastError()); void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); void* devPtrMNKOverride = reinterpret_cast(actual_seqlens_q); @@ -1577,6 +1578,7 @@ void fused_attn_fp8_bwd_impl( cu_seqlens_to_offsets<<>>( b, h, d, reinterpret_cast(devPtrcuSeqlensQ), actual_seqlens_q, qkv_ragged_offset, o_ragged_offset); + NVTE_CHECK_CUDA(cudaGetLastError()); void* devPtrQKVRaggedOffset = reinterpret_cast(qkv_ragged_offset); void* devPtrORaggedOffset = reinterpret_cast(o_ragged_offset); void* devPtrMNKOverride = 
reinterpret_cast(actual_seqlens_q); @@ -1933,6 +1935,7 @@ void fused_attn_fp8_fwd_impl_v1( b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } @@ -2329,6 +2332,7 @@ void fused_attn_fp8_bwd_impl_v1( b, b, static_cast(devPtrcuSeqlensQ), // TODO(pass max_b) static_cast(devPtrcuSeqlensKV), static_cast(devActualSeqlenQ), static_cast(devActualSeqlenKV)); + NVTE_CHECK_CUDA(cudaGetLastError()); variant_pack[seq_q] = devActualSeqlenQ; variant_pack[seq_kv] = devActualSeqlenKV; } diff --git a/transformer_engine/common/fused_attn/kv_cache.cu b/transformer_engine/common/fused_attn/kv_cache.cu index 9bdc41e9e..67119c323 100644 --- a/transformer_engine/common/fused_attn/kv_cache.cu +++ b/transformer_engine/common/fused_attn/kv_cache.cu @@ -157,6 +157,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso reinterpret_cast(page_table.data.dptr), reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), h_kv, d_k, d_v, b, max_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } dim3 grid_size(b, max_ctx_len); copy_to_kv_cache_kernel<<>>( @@ -166,6 +167,7 @@ void copy_to_kv_cache_launcher(Tensor new_k, Tensor new_v, Tensor k_cache, Tenso reinterpret_cast(cu_new_lens.data.dptr), reinterpret_cast(cu_cached_lens.data.dptr), qkv_format, h_kv, d_k, d_v, b, max_ctx_len, max_seq_len, max_pages_per_seq, is_non_paged); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -215,6 +217,7 @@ void convert_thd_to_bshd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se reinterpret_cast(tensor.data.dptr), reinterpret_cast(new_tensor.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), b, max_seq_len, h, d); + NVTE_CHECK_CUDA(cudaGetLastError()); } void convert_thd_to_bshd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int b, @@ -254,6 +257,7 @@ void convert_bshd_to_thd_launcher(Tensor tensor, Tensor new_tensor, Tensor cu_se reinterpret_cast(tensor.data.dptr), reinterpret_cast(new_tensor.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), b, max_seq_len, h, d); + NVTE_CHECK_CUDA(cudaGetLastError()); } void convert_bshd_to_thd(Tensor tensor, Tensor cu_seqlens, Tensor new_tensor, int t, diff --git a/transformer_engine/common/fused_attn/utils.cu b/transformer_engine/common/fused_attn/utils.cu index 768dbd99f..df1eae0dd 100644 --- a/transformer_engine/common/fused_attn/utils.cu +++ b/transformer_engine/common/fused_attn/utils.cu @@ -600,13 +600,14 @@ uint32_t GetRuntimeNumSegments(void *cu_seqlen, void *workspace, size_t len, cud // workspace size requires 4 bytes uint32_t *dout = static_cast(workspace); uint32_t hout{}; - cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(dout, 0, sizeof(uint32_t), stream)); constexpr int threads = 128; const int blocks = (len - 1) / threads + 1; get_runtime_num_segments_kernel<<>>(static_cast(cu_seqlen), len, dout); - cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream); - cudaStreamSynchronize(stream); + NVTE_CHECK_CUDA(cudaGetLastError()); + NVTE_CHECK_CUDA(cudaMemcpyAsync(&hout, dout, sizeof(uint32_t), cudaMemcpyDeviceToHost, stream)); + NVTE_CHECK_CUDA(cudaStreamSynchronize(stream)); return hout; } @@ -633,4 +634,5 @@ void nvte_extract_seed_and_offset(int64_t *rng_state_ptr, int captured, int64_t fused_attn::extract_seed_and_offset<<<1, 1, 0, stream>>>( 
rng_state_ptr, captured, seed_ptr, seed_val, offset_ptr, offset_val, offset_intragraph); + NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index f64b75d97..a738be873 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -177,9 +177,9 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, config.stream = stream; // Update the max cluster size based on the device - cudaOccupancyMaxPotentialClusterSize( + NVTE_CHECK_CUDA(cudaOccupancyMaxPotentialClusterSize( &cluster_size, - reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config); + reinterpret_cast(fused_moe_aux_loss_forward_kernel), &config)); cudaLaunchAttribute attribute[1]; attribute[0].id = cudaLaunchAttributeClusterDimension; @@ -189,14 +189,15 @@ void fused_moe_aux_loss_forward_kernel_launcher(const DataType* probs, config.numAttrs = 1; config.attrs = attribute; - cudaLaunchKernelEx(&config, fused_moe_aux_loss_forward_kernel, probs, - tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, - coeff, aux_loss, Const_buf); + NVTE_CHECK_CUDA(cudaLaunchKernelEx( + &config, fused_moe_aux_loss_forward_kernel, probs, tokens_per_expert, + total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf)); } else { size_t smem_size = sizeof(CompType) * num_cols; fused_moe_aux_loss_forward_kernel <<<1, 1024, smem_size, stream>>>(probs, tokens_per_expert, total_num_tokens, num_experts, num_rows, num_cols, topk, coeff, aux_loss, Const_buf); + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -247,6 +248,7 @@ void fused_moe_aux_loss_backward_kernel_launcher(const float* Const_buf, int grid_size = (num_rows + block_size - 1) / block_size; fused_moe_aux_loss_backward_kernel<<>>( Const_buf, tokens_per_expert, num_rows, num_cols, grad_aux_loss, grad_probs); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_moe_aux_loss_backward(const Tensor& Const_buf, const Tensor& tokens_per_expert, diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 47d215057..03d22942b 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -151,6 +151,7 @@ void fused_score_for_moe_aux_loss_forward_kernel_launcher( <<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, int num_experts, @@ -286,6 +287,7 @@ void fused_score_for_moe_aux_loss_backward_kernel_launcher( <<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index a1785c663..03e972332 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -257,6 +257,7 @@ void fused_topk_with_score_function_forward_kernel_launcher( <<>>( logits, 
num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, int num_experts, @@ -447,6 +448,7 @@ void fused_topk_with_score_function_backward_kernel_launcher( <<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, score_function, grad_logits); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fused_topk_with_score_function_backward(const Tensor &routing_map, diff --git a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu index 1f54f2e72..bbe722a8f 100644 --- a/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_aligned_causal_masked_softmax.cu @@ -353,6 +353,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_forward( scaled_aligned_causal_masked_softmax_warp_forward <<>>(dst, src, scale, microbatches, query_seq_len, key_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -363,6 +364,7 @@ void call_kernel_scaled_aligned_causal_masked_softmax_backward( scaled_aligned_causal_masked_softmax_warp_backward <<>>(gradInput, grad, output, scale, microbatches, query_seq_len, key_seq_len); + NVTE_CHECK_CUDA(cudaGetLastError()); } template diff --git a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu index 02f315372..79318cd28 100644 --- a/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_masked_softmax.cu @@ -513,6 +513,7 @@ void dispatch_scaled_softmax_forward(output_t *dst, const input_t *src, const in default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -625,6 +626,7 @@ void dispatch_scaled_masked_softmax_forward(output_t *dst, const input_t *src, c default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -736,6 +738,7 @@ void dispatch_scaled_masked_softmax_backward(output_t *grad_input, const input_t default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu index 351f4946c..03cdd6827 100644 --- a/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu +++ b/transformer_engine/common/fused_softmax/scaled_upper_triang_masked_softmax.cu @@ -445,6 +445,7 @@ void dispatch_scaled_upper_triang_masked_softmax_forward(output_t *dst, const in default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -561,6 +562,7 @@ void dispatch_scaled_upper_triang_masked_softmax_backward(output_t *grad_input, default: break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/multi_tensor/l2norm.cu b/transformer_engine/common/multi_tensor/l2norm.cu index ca2fce27a..cc66562af 100644 --- a/transformer_engine/common/multi_tensor/l2norm.cu +++ b/transformer_engine/common/multi_tensor/l2norm.cu @@ -413,6 +413,7 @@ void multi_tensor_l2norm_cuda(int chunk_size, Tensor noop_flag, reinterpret_cast(ret.data.dptr), per_tensor ? 
reinterpret_cast(ret_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor); + NVTE_CHECK_CUDA(cudaGetLastError()); } void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, @@ -440,6 +441,7 @@ void multi_tensor_unscale_l2norm_cuda(int chunk_size, Tensor noop_flag, reinterpret_cast(ret.data.dptr), per_tensor ? reinterpret_cast(ret_per_tensor.data.dptr) : nullptr, per_tensor, max_chunks_per_tensor); + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace multi_tensor_l2norm diff --git a/transformer_engine/common/normalization/common.cpp b/transformer_engine/common/normalization/common.cpp index c280c1c35..337b16508 100644 --- a/transformer_engine/common/normalization/common.cpp +++ b/transformer_engine/common/normalization/common.cpp @@ -138,8 +138,8 @@ void TeNormalizationPlan::_set_workspace() { if (_launch_params.barrier_bytes > 0) { _launch_params.params.barrier = reinterpret_cast(workspace_dptr + _launch_params.workspace_bytes); - cudaMemsetAsync(_launch_params.params.barrier, 0, _launch_params.barrier_bytes, - _launch_params.stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(_launch_params.params.barrier, 0, + _launch_params.barrier_bytes, _launch_params.stream)); } if constexpr (std::is_same_v) { _launch_params.params.dgamma_part = diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index f63edfb64..1eeb08415 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -14,16 +14,16 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_bwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_bwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -49,13 +49,14 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), Kernel_traits::SMEM_BYTES, - stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES, stream)); } using Kernel_traits_f = @@ -66,13 +67,14 @@ void launch_tuned_(LaunchParams &launch_params, auto kernel_f = &ln_bwd_finalize_tuned_kernel; kernel_f<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_bwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) auto 
ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; // Instantiate kernel @@ -87,8 +89,8 @@ void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaOccupancyMaxActiveBlocksPerMultiprocessor(&ctas_per_sm, kernel, - Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -109,10 +111,11 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } // Launch finalization kernel @@ -126,6 +129,7 @@ void launch_general_(LaunchParams &launch_params, dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); kernel_final<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -134,8 +138,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_ln_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu index 9336abc26..787c75ef8 100644 --- a/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_fwd_cuda_kernel.cu @@ -13,15 +13,15 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_fwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_fwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -45,19 +45,21 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 
block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES_FWD, stream)); } } template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_ln_fwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_fwd_general_kernel; @@ -70,8 +72,8 @@ void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -91,10 +93,11 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } } @@ -104,8 +107,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_ln_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu index 0a7b38000..9bd56c4ec 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_bwd_semi_cuda_kernel.cu @@ -13,8 +13,8 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_bwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_bwd_tuned_kernel; @@ -48,6 +48,7 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); @@ -65,13 +66,14 @@ void launch_tuned_(LaunchParams &launch_params, auto kernel_f = &rmsnorm_bwd_finalize_tuned_kernel; kernel_f<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } template -void launch_general_(LaunchParams &launch_params, - const bool 
configure_params) { // NOLINT(*) +void launch_rmsnorm_bwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) auto ceil_div = [](int x, int y) -> int { return (x + y - 1) / y; }; // Instantiate kernel @@ -110,6 +112,7 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, @@ -127,6 +130,7 @@ void launch_general_(LaunchParams &launch_params, dim3 block_final(Kernel_traits::THREADS_PER_WARP * WARPS_N_FINAL, WARPS_M_FINAL); dim3 grid_final(ceil_div(cols, ELTS_N_PER_CTA_FINAL), 1); kernel_final<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } #define REGISTER_NORM_LAUNCHER(NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, \ @@ -135,8 +139,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_rmsnorm_bwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu index 25bed95dc..90b4f1340 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_fwd_cuda_kernel.cu @@ -13,16 +13,16 @@ using namespace transformer_engine::normalization; template -void launch_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_fwd_tuned_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_fwd_tuned_kernel; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, Kernel_traits::SMEM_BYTES_FWD)); launch_params.params.ctas_per_row = CTAS_PER_ROW; launch_params.params.ctas_per_col = launch_params.multiprocessorCount * ctas_per_sm / launch_params.params.ctas_per_row; @@ -46,19 +46,21 @@ void launch_tuned_(LaunchParams &launch_params, if (ctas_per_row == 1) { kernel<<>>( launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { dim3 grid(ctas_per_row * ctas_per_col); dim3 block(Kernel_traits::THREADS_PER_CTA); void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel((void *)kernel, grid, block, (void **)¶ms_, // NOLINT(*) - Kernel_traits::SMEM_BYTES_FWD, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), + Kernel_traits::SMEM_BYTES_FWD, stream)); } } template -void launch_general_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*) +void launch_rmsnorm_fwd_general_(LaunchParams &launch_params, + const bool configure_params) { // NOLINT(*) using 
Kernel_traits = Kernel_traits; auto kernel = &rmsnorm_fwd_general_kernel; @@ -71,8 +73,8 @@ void launch_general_(LaunchParams &launch_params, int ctas_per_row = launch_params.params.ctas_per_row; if (configure_params) { int ctas_per_sm; - cudaError status_ = cudaOccupancyMaxActiveBlocksPerMultiprocessor( - &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0); + NVTE_CHECK_CUDA(cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &ctas_per_sm, kernel, Kernel_traits::THREADS_PER_CTA, 0)); const int max_ctas = launch_params.multiprocessorCount * ctas_per_sm; ctas_per_row = ceil_div(cols, HIDDEN_SIZE); ctas_per_col = std::min(ceil_div(rows, WARPS_M), max_ctas / ctas_per_row); @@ -92,10 +94,11 @@ void launch_general_(LaunchParams &launch_params, dim3 block(Kernel_traits::THREADS_PER_CTA); if (ctas_per_row == 1) { kernel<<>>(launch_params.params); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { void *params_ = reinterpret_cast(&launch_params.params); - cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, - reinterpret_cast(¶ms_), 0, stream); + NVTE_CHECK_CUDA(cudaLaunchCooperativeKernel(reinterpret_cast(kernel), grid, block, + reinterpret_cast(¶ms_), 0, stream)); } } @@ -105,8 +108,8 @@ void launch_general_(LaunchParams &launch_params, void \ norm_##NORM_TYPE##_##NORM_STAGE##_##LAUNCH_TYPE##_##HIDDEN_SIZE##_##WTYPE##_##ITYPE##_##OTYPE##_##CTYPE( \ LaunchParams &launch_params, const bool configure_params) { \ - launch_##LAUNCH_TYPE##_( \ - launch_params, configure_params); \ + launch_rmsnorm_fwd_##LAUNCH_TYPE##_(launch_params, configure_params); \ } \ REGISTER_NORM_BASE( \ NORM_TYPE, NORM_STAGE, LAUNCH_TYPE, HIDDEN_SIZE, WTYPE, ITYPE, OTYPE, CTYPE, \ diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu index a18ea6d4a..d5f6aeecc 100644 --- a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu +++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu @@ -35,17 +35,20 @@ void nvshmem_wait_on_stream(uint64_t* sig_addr, WaitKind wait_kind, cudaStream_t switch (wait_kind) { case WaitKind::KERNEL_WAIT: wait_until_on_stream_and_reset<<<1, 1, 0, cur_stream>>>(sig_addr, wait_value, signal_reset); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case WaitKind::NVSHMEM_WAIT: nvshmemx_uint64_wait_until_on_stream(sig_addr, NVSHMEM_CMP_EQ, wait_value, cur_stream); - cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, - CU_STREAM_WRITE_VALUE_DEFAULT); + NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)signal_reset, + CU_STREAM_WRITE_VALUE_DEFAULT)); break; case WaitKind::STREAM_WAIT: - cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)wait_value, - CU_STREAM_WAIT_VALUE_GEQ); - cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, (cuuint64_t)signal_reset, - CU_STREAM_WRITE_VALUE_DEFAULT); + NVTE_CHECK_CUDA_DRIVER(cuStreamWaitValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)wait_value, CU_STREAM_WAIT_VALUE_GEQ)); + NVTE_CHECK_CUDA_DRIVER(cuStreamWriteValue64((CUstream)cur_stream, (CUdeviceptr)sig_addr, + (cuuint64_t)signal_reset, + CU_STREAM_WRITE_VALUE_DEFAULT)); break; } } diff --git a/transformer_engine/common/permutation/permutation.cu b/transformer_engine/common/permutation/permutation.cu index 5716196fe..d66298b69 100644 --- a/transformer_engine/common/permutation/permutation.cu +++ b/transformer_engine/common/permutation/permutation.cu @@ -243,11 
+243,13 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, moe_permute_row_map<<>>(sorted_row_id, row_id_map, num_rows, topK, num_out_tokens); + NVTE_CHECK_CUDA(cudaGetLastError()); blocks = num_rows; threads = std::min(num_cols / kElementsPerAccess, 1024); moe_permute_kernel<<>>( input, nullptr, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_bwd @@ -259,6 +261,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, moe_permute_kernel<<>>( input, input_fwd, output, nullptr, nullptr, row_id_map, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_bwd with probs @@ -282,6 +285,7 @@ void nvte_permute_launcher(const T *input, T *output, const int *sorted_row_id, } else { NVTE_ERROR("topK cannot exceed 128."); } + NVTE_CHECK_CUDA(cudaGetLastError()); } } } @@ -306,11 +310,13 @@ void nvte_unpermute_launcher(const T *input, T *output, int *row_id_map, const f moe_unpermute_kernel<<>>( input, output, row_id_map, nullptr, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { // moe_unpermute_fwd with probs moe_unpermute_kernel<<>>( input, output, row_id_map, prob, num_rows, topK, num_cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index f8642cfb6..e1657b77a 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -60,7 +60,7 @@ __launch_bounds__(amax_kernel_threads) __global__ template void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { // Zero out amax so we can update with atomic max - cudaMemsetAsync(amax, 0, sizeof(float), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); // Return immediately if tensor is empty if (N == 0) { diff --git a/transformer_engine/common/recipe/fp8_block_scaling.cu b/transformer_engine/common/recipe/fp8_block_scaling.cu index 759197dc8..42a7b8d69 100644 --- a/transformer_engine/common/recipe/fp8_block_scaling.cu +++ b/transformer_engine/common/recipe/fp8_block_scaling.cu @@ -183,6 +183,7 @@ void fp8_block_scaling_compute_partial_amax(const Tensor inp, Tensor amax, size_ reinterpret_cast(amax.data.dptr), amax_stride_h, amax_stride_w, h, w, start_offset, len);) + NVTE_CHECK_CUDA(cudaGetLastError()); } void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor scale, size_t h, @@ -215,6 +216,7 @@ void fp8_block_scaling_partial_cast(const Tensor inp, Tensor out, const Tensor s reinterpret_cast(out.data.dptr), reinterpret_cast(scale.data.dptr), scale_stride_h, scale_stride_w, h, w, start_offset, len);))) + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace fp8_block_scaling_recipe diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index fcb379a82..9ec86a37c 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -387,22 +387,25 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s const int original_K = input->flat_last_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + 
cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_row_scaling_kernel <<>>( input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 2: - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_row_scaling_kernel <<>>( input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); break; case 1: - cudaFuncSetAttribute(swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_row_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_row_scaling_kernel <<>>( input->scale_inv.dptr, output->scale_inv.dptr, m, k, original_M, original_K); @@ -411,6 +414,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s NVTE_ERROR("Not valid vec_load_size."); break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } if (input->has_columnwise_data()) { int vec_load_size = (num_tiles_m - 1) % 4 + 1; @@ -422,24 +426,27 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s const int original_K = input->flat_first_dim() / MXFP8_BLOCK_SIZE; switch (vec_load_size) { case 4: - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_col_scaling_kernel <<>>(input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k, original_M, original_K); break; case 2: - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_col_scaling_kernel <<>>(input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, k, original_M, original_K); break; case 1: - cudaFuncSetAttribute(swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(swizzle_col_scaling_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); swizzle_col_scaling_kernel <<>>(input->columnwise_scale_inv.dptr, output->columnwise_scale_inv.dptr, m, @@ -449,6 +456,7 @@ void swizzle_scaling_factors(const Tensor* input, Tensor* output, cudaStream_t s NVTE_ERROR("Not valid vec_load_size."); break; } + NVTE_CHECK_CUDA(cudaGetLastError()); } // 2D block scaling @@ -489,23 +497,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, if (is_rowwise) { switch (vec_load_size) { case 4: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; case 2: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; case 1: - cudaFuncSetAttribute( + 
NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_row_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_row_scaling_kernel <<>>(kernel_args); break; @@ -516,23 +524,23 @@ void launch_multi_tensor_swizzle_scaling_factors(MultiSwizzleArgs& kernel_args, } else { switch (vec_load_size) { case 4: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; case 2: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; case 1: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( multi_tensor_swizzle_col_scaling_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, slm_size)); multi_tensor_swizzle_col_scaling_kernel <<>>(kernel_args); break; diff --git a/transformer_engine/common/transformer_engine.cpp b/transformer_engine/common/transformer_engine.cpp index a33f3d959..55654989a 100644 --- a/transformer_engine/common/transformer_engine.cpp +++ b/transformer_engine/common/transformer_engine.cpp @@ -544,11 +544,11 @@ void nvte_zero_tensor(const NVTETensor tensor, cudaStream_t stream) { // Zero out tensor data if allocated if (t.data.dptr != nullptr) { const size_t size_in_bytes = nvte_tensor_size_bytes(tensor); - cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(t.data.dptr, 0, size_in_bytes, stream)); } // Set amax to 0 if allocated if (t.amax.dptr != nullptr) { - cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream); + NVTE_CHECK_CUDA(cudaMemsetAsync(t.amax.dptr, 0, sizeof(float), stream)); } } diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu index 723dbb4a9..648070c8d 100644 --- a/transformer_engine/common/transpose/cast_transpose.cu +++ b/transformer_engine/common/transpose/cast_transpose.cu @@ -335,6 +335,7 @@ void cast_transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cu static_cast(output.scale.dptr), static_cast(output.amax.dptr), static_cast(output.scale_inv.dptr), row_length, num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } } else { NVTE_ERROR("Not implemented scaling mode: ", to_string(output.scaling_mode)); diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index ca48a055a..6329e79ae 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -264,6 +264,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reinterpret_cast(dbias->data.dptr), reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, reduce_dbias_num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } template , - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); cast_transpose_fused_kernel_notaligned <<>>( param, row_length, num_rows, num_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } if 
constexpr (IS_DBIAS) { @@ -1197,10 +1199,10 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu const size_t shmem_size = cast_transpose_num_threads / n_warps_per_tile * (THREADS_PER_WARP + 1) * sizeof(Vec); if (full_tile) { - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( dgated_act_cast_transpose_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); dgated_act_cast_transpose_kernel @@ -1213,11 +1215,12 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( dgated_act_cast_transpose_kernel_notaligned, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); dgated_act_cast_transpose_kernel_notaligned <<>>( @@ -1229,6 +1232,7 @@ void dgated_act_cast_transpose(const Tensor &input, const Tensor &gated_act_inpu reinterpret_cast(output->amax.dptr), reinterpret_cast(output->scale_inv.dptr), row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); }); // NOLINT(*) ); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/multi_cast_transpose.cu b/transformer_engine/common/transpose/multi_cast_transpose.cu index 2be365465..bf3856568 100644 --- a/transformer_engine/common/transpose/multi_cast_transpose.cu +++ b/transformer_engine/common/transpose/multi_cast_transpose.cu @@ -258,6 +258,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_aligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args_aligned.num_tensors = 0; } if (kernel_args_unaligned.num_tensors == kMaxTensorsPerKernel) { @@ -271,6 +272,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_unaligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args_unaligned.num_tensors = 0; } @@ -311,6 +313,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_aligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } if (kernel_args_unaligned.num_tensors > 0) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( @@ -323,6 +326,7 @@ void multi_cast_transpose(const std::vector input_list, std::vector <<>>(kernel_args_unaligned);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu index 103f45cf1..9f0acd807 100644 --- a/transformer_engine/common/transpose/transpose.cu +++ b/transformer_engine/common/transpose/transpose.cu @@ -279,6 +279,7 @@ void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, cudaStr static_cast(noop.data.dptr), static_cast(output.data.dptr), row_length, num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); }); // NOLINT(*) } diff --git a/transformer_engine/common/transpose/transpose_fusion.cu b/transformer_engine/common/transpose/transpose_fusion.cu index 7a19c1285..3c51ce3da 100644 --- a/transformer_engine/common/transpose/transpose_fusion.cu +++ b/transformer_engine/common/transpose/transpose_fusion.cu @@ -416,6 +416,7 @@ void reduce_dbias(const Tensor &workspace, Tensor *dbias, const size_t row_lengt reinterpret_cast(dbias->data.dptr), 
reinterpret_cast(workspace.data.dptr), reduce_dbias_row_length, reduce_dbias_num_rows); + NVTE_CHECK_CUDA(cudaGetLastError()); } void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor *dbias, @@ -472,17 +473,21 @@ void fp8_transpose_dbias(const Tensor &input, Tensor *transposed_output, Tensor param.workspace = reinterpret_cast(workspace->data.dptr); if (full_tile) { - cudaFuncSetAttribute(transpose_dbias_kernel, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + NVTE_CHECK_CUDA(cudaFuncSetAttribute(transpose_dbias_kernel, + cudaFuncAttributePreferredSharedMemoryCarveout, + 100)); transpose_dbias_kernel <<>>( param, row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } else { - cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, - cudaFuncAttributePreferredSharedMemoryCarveout, 100); + NVTE_CHECK_CUDA( + cudaFuncSetAttribute(transpose_dbias_kernel_notaligned, + cudaFuncAttributePreferredSharedMemoryCarveout, 100)); transpose_dbias_kernel_notaligned <<>>( param, row_length, num_rows, n_tiles); + NVTE_CHECK_CUDA(cudaGetLastError()); } reduce_dbias(*workspace, dbias, row_length, num_rows, nvec_out, diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 83359eb05..50ff82d85 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -950,16 +950,17 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu const size_t shmem_size = grad_mem + (in_act_mem + in_gate_mem) + (out_act_mem + out_gate_mem) + TMA_SHMEM_ALIGNMENT; - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_fp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); cast_fp8_gated_kernel <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act, tensor_map_output_gate, amax_ptr, scale_inv_ptr, scale_ptr, rows, - cols);); // NOLINT(*) - ); // NOLINT(*) + cols); + NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) + ); // NOLINT(*) } template , - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); mxfp8_kernel::cast_mxfp8_gated_kernel @@ -1096,13 +1097,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); mxfp8_kernel::cast_mxfp8_gated_kernel @@ -1112,13 +1114,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); mxfp8_kernel::cast_mxfp8_gated_kernel 
@@ -1128,6 +1131,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out tensor_map_output_act_colwise, tensor_map_output_gate_colwise, scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; }); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 9a02d71f2..1158132e3 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -894,6 +894,7 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, reduce_dbias_kernel <<>>( reinterpret_cast(dbias->data.dptr), workspace_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -925,6 +926,7 @@ static void cast_fp8_1D(const Tensor &input, Tensor *output, cudaStream_t stream cast_fp8_1D_kernel<<>>( input_ptr, output_ptr, amax_ptr, scale_inv_ptr, scale_ptr, N);); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } template @@ -988,6 +990,7 @@ void cast_fp8_2D(const Tensor &input, const Tensor *act_input, Tensor *output, T <<>>(tensor_map_input, tensor_map_act_input, tensor_map_output, workspace_ptr, amax_ptr, scale_inv_ptr, scale_ptr, rows, cols); + NVTE_CHECK_CUDA(cudaGetLastError()); if constexpr (IS_DBIAS) { reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); @@ -1124,10 +1127,10 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, switch (scaling_type) { case ScalingType::ROWWISE: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); cast_mxfp8_2D_kernel @@ -1136,12 +1139,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); cast_mxfp8_2D_kernel @@ -1150,12 +1154,13 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::BIDIMENSIONAL: - cudaFuncSetAttribute( + NVTE_CHECK_CUDA(cudaFuncSetAttribute( cast_mxfp8_2D_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size); + cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size)); cast_mxfp8_2D_kernel @@ -1164,6 +1169,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, tensor_map_output_colwise, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); break; } diff --git a/transformer_engine/common/util/dequantize_kernels.cuh b/transformer_engine/common/util/dequantize_kernels.cuh index a82f11307..e2d8d34f3 100644 --- a/transformer_engine/common/util/dequantize_kernels.cuh +++ b/transformer_engine/common/util/dequantize_kernels.cuh @@ -329,6 +329,7 @@ static void 
mxfp8_dequantize(const Tensor &input, Tensor *output, cudaStream_t s ); // NOLINT(*) ); // NOLINT(*) ); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } // namespace dequantization diff --git a/transformer_engine/common/util/padding.cu b/transformer_engine/common/util/padding.cu index ad6cf2a2e..0d92b243a 100644 --- a/transformer_engine/common/util/padding.cu +++ b/transformer_engine/common/util/padding.cu @@ -248,6 +248,7 @@ void multi_padding(const std::vector input_list, std::vector o const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_padding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args.num_tensors = 0; } @@ -277,6 +278,7 @@ void multi_padding(const std::vector input_list, std::vector o const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_padding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -322,6 +324,7 @@ void multi_unpadding(const std::vector input_list, std::vector const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_unpadding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); kernel_args.num_tensors = 0; } @@ -349,6 +352,7 @@ void multi_unpadding(const std::vector input_list, std::vector const int n_blocks = kernel_args.block_range[kernel_args.num_tensors]; multi_unpadding_kernel <<>>(kernel_args);); // NOLINT(*) + NVTE_CHECK_CUDA(cudaGetLastError()); } } diff --git a/transformer_engine/common/util/vectorized_pointwise.h b/transformer_engine/common/util/vectorized_pointwise.h index 6e4507eef..0d667a0ec 100644 --- a/transformer_engine/common/util/vectorized_pointwise.h +++ b/transformer_engine/common/util/vectorized_pointwise.h @@ -364,6 +364,7 @@ void VectorizedUnaryKernelLauncher(const InputType *input, const fp32 *noop, Out break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -398,6 +399,7 @@ void VectorizedUnaryGradKernelLauncher(const InputTypeGrad *grad, const InputTyp break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -491,6 +493,7 @@ void GatedActivationKernelLauncher(const InputType *input, OutputType *output, c break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } @@ -602,6 +605,7 @@ void DGatedActivationKernelLauncher(const InputType *grad, const InputType *inpu break; } } + NVTE_CHECK_CUDA(cudaGetLastError()); } } From 607fcc432cce05b45f6ecaf3eac2bb0c1691976b Mon Sep 17 00:00:00 2001 From: buptzyb Date: Sat, 30 Aug 2025 01:56:34 +0800 Subject: [PATCH 118/153] [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/graph.py | 31 +++++++++++++++++++---------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/transformer_engine/pytorch/graph.py b/transformer_engine/pytorch/graph.py index eda18a185..f0fe557c0 100644 --- a/transformer_engine/pytorch/graph.py +++ b/transformer_engine/pytorch/graph.py @@ -850,7 +850,7 @@ def make_graphed_callables( num_warmup_iters: int = 3, allow_unused_input: bool = False, sample_kwargs: 
Optional[SingleOrTuple[Dict[str, Any]]] = None, - fp8_enabled: bool = False, + fp8_enabled: SingleOrTuple[bool] = False, fp8_calibrating: bool = False, fp8_recipe: Optional[Recipe] = None, fp8_group: Optional[dist_group_type] = None, @@ -896,8 +896,9 @@ def make_graphed_callables( FP8-related parameters ---------------------- - fp8_enabled: bool, default = `True` - whether or not to enable fp8 + fp8_enabled: (tuple of) bool, default = `False` + whether or not to enable fp8. + If tuple, the length must match the number of modules. fp8_calibrating: bool, default = `False` calibration mode allows collecting statistics such as amax and scale data of fp8 tensors even when executing without fp8 enabled. This is @@ -919,17 +920,25 @@ def make_graphed_callables( """ set_capture_start() - if fp8_enabled and fp8_recipe is None: - fp8_recipe = get_default_fp8_recipe() - elif not fp8_enabled: - fp8_recipe = None - # Handle single module. just_one_callable = False if not isinstance(modules, tuple): just_one_callable = True modules = (modules,) + if not isinstance(fp8_enabled, tuple): + assert isinstance(fp8_enabled, bool), "fp8_enabled must be a bool or a tuple of bools" + fp8_enabled = (fp8_enabled,) * len(modules) + else: + assert len(fp8_enabled) == len( + modules + ), f"fp8_enabled length ({len(fp8_enabled)}) must match modules length ({len(modules)})" + if any(fp8_enabled) and fp8_recipe is None: + fp8_recipe = get_default_fp8_recipe() + elif not any(fp8_enabled): + fp8_recipe = None + module_uses_fp8 = dict(zip((id(m) for m in modules), fp8_enabled)) + # Store FP8 tensors to reset later. saved_fp8_tensors = save_fp8_tensors(modules, fp8_recipe=fp8_recipe) @@ -944,15 +953,15 @@ def wrap_autocast(block): old_call_funcs[block_cls] = block_cls.__call__ # Wrap the original call function of the module class. - def call_func(*args, **kwargs): + def call_func(self, *args, **kwargs): with fp8_autocast( - enabled=fp8_enabled, + enabled=module_uses_fp8.get(id(self), False), calibrating=fp8_calibrating, fp8_recipe=fp8_recipe, fp8_group=fp8_group, _graph=True, ): - outputs = old_call_funcs[block_cls](*args, **kwargs) + outputs = old_call_funcs[block_cls](self, *args, **kwargs) return outputs block_cls.__call__ = call_func From e0e3d1235d9da00dfca3f1cd3461187950bfd84e Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Sun, 31 Aug 2025 09:15:41 -0700 Subject: [PATCH 119/153] Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. 
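Concretely, each uniform random byte r in [0, 255] is compared against the byte
threshold floor(p * 256); with strict less-than the realized drop probability is
exactly floor(p * 256) / 256, while less-or-equal would bias it upward by 1/256.
A minimal host-side sketch of the per-element rule (the function name is
illustrative, not part of the kernels):

    #include <cmath>
    #include <cstdint>

    // Reference model of the drop decision, assuming r is a uniform random
    // byte in [0, 255] and p is the dropout probability.
    inline float dropout_reference(float x, uint8_t r, float p) {
      const uint32_t threshold = static_cast<uint32_t>(std::floor(p * 256.0f));
      const bool drop = r < threshold;      // strict <, so P(drop) == threshold / 256
      return drop ? 0.0f : x / (1.0f - p);  // inverted-dropout scaling, as in the kernels
    }

For example, p = 0.25 gives threshold = 64 and a drop probability of 64/256 = 0.25.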
Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_fusible_ops.py | 56 ++- transformer_engine/common/CMakeLists.txt | 1 + transformer_engine/common/__init__.py | 33 ++ transformer_engine/common/dropout/dropout.cu | 355 ++++++++++++++++++ .../include/transformer_engine/dropout.h | 51 +++ transformer_engine/pytorch/csrc/extensions.h | 11 + .../pytorch/csrc/extensions/dropout.cpp | 89 +++++ .../pytorch/csrc/extensions/pybind.cpp | 7 + .../pytorch/ops/basic/dropout.py | 69 +++- 9 files changed, 639 insertions(+), 33 deletions(-) create mode 100644 transformer_engine/common/dropout/dropout.cu create mode 100644 transformer_engine/common/include/transformer_engine/dropout.h create mode 100644 transformer_engine/pytorch/csrc/extensions/dropout.cpp diff --git a/tests/pytorch/test_fusible_ops.py b/tests/pytorch/test_fusible_ops.py index 9325f5d1e..bb07e87d9 100644 --- a/tests/pytorch/test_fusible_ops.py +++ b/tests/pytorch/test_fusible_ops.py @@ -1749,25 +1749,44 @@ def test_constant_scale( torch.testing.assert_close(y_test, y_ref, **tols) torch.testing.assert_close(dx_test, x_ref.grad, **tols) - @pytest.mark.parametrize("prob", (0.1, 0.5, 0.75)) + @pytest.mark.parametrize("prob", (0.0625, 0.5, 0.75)) @pytest.mark.parametrize("is_training", (True, False)) - @pytest.mark.parametrize("shape", ((101,), (2, 4, 16))) + @pytest.mark.parametrize("quantization", (None, "fp8_current_scaling")) + @pytest.mark.parametrize("shape", ((101,), (2, 4, 16), (128, 128))) @pytest.mark.parametrize("dtype", _dtypes) def test_dropout( self, *, prob: float, is_training: bool, + quantization: Optional[str], shape: Iterable[int], dtype: torch.dtype, device: torch.device = "cuda", ): + # Skip invalid configurations + quantized_input = quantization is not None + maybe_skip_quantization(quantization, dims=shape, device=device) + # Random data - x_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5 - x_test = x_ref.clone().requires_grad_() - dy_ref = torch.rand(shape, dtype=dtype, device=device) + 0.5 - dy_test = dy_ref.clone() + # Note: Shift values to make sure inputs are non-zero + x_ref, x_test = make_reference_and_test_tensors( + shape, + quantization=quantization, + test_dtype=dtype, + test_device=device, + test_is_quantized=quantized_input, + ) + with torch.no_grad(): + x_test += 1 + x_ref.copy_(x_test) + dy_ref, dy_test = make_reference_and_test_tensors( + shape, + test_dtype=dtype, + test_device=device, + requires_grad=False, + ) # Apply dropout op = te_ops.Dropout(prob) @@ -1775,17 +1794,20 @@ def test_dropout( op.train() else: op.eval() - y = op(x_test) - y.backward(dy_test) + y_test = op(x_test) + y_test.backward(dy_test) # Check values + y_test = y_test.to(dtype=torch.float64, device="cpu") + dx_test = x_test.grad.to(dtype=torch.float64, device="cpu") if is_training: - mask = ((y != 0) / (1 - prob)).to(dtype=dtype) - torch.testing.assert_close(y, x_ref * mask) - torch.testing.assert_close(x_test.grad, dy_ref * mask) + tols = dtype_tols(dtype) + mask = ((y_test != 0) / (1 - prob)).to(dtype=dtype) + torch.testing.assert_close(y_test, x_ref 
* mask, **tols) + torch.testing.assert_close(dx_test, dy_ref * mask, **tols) else: - torch.testing.assert_close(y, x_ref, rtol=0, atol=0) - torch.testing.assert_close(x_test.grad, dy_ref, rtol=0, atol=0) + torch.testing.assert_close(y_test, x_ref, rtol=0, atol=0) + torch.testing.assert_close(dx_test, dy_ref, rtol=0, atol=0) # Hypothesis testing for number of zeros # Note: A Bernoulli random variable with probability p has @@ -1797,9 +1819,11 @@ def test_dropout( # p-value is less than 1% and we assume that the dropout # distribution is incorrect. if is_training: - prob_observed = 1 - torch.count_nonzero(y).item() / y.numel() - z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y.numel()) - assert abs(z_score) < 2.5758, "Number of zeros is outside 99% confidence interval" + prob_observed = 1 - torch.count_nonzero(y_test).item() / y_test.numel() + z_score = (prob_observed - prob) / math.sqrt(prob * (1 - prob) / y_test.numel()) + assert ( + abs(z_score) < 2.5758 + ), f"Number of zeros is outside 99% confidence interval ({prob=}, {prob_observed=})" class TestFusedOps: diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 183a7a72e..cb9f13b89 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -69,6 +69,7 @@ list(APPEND transformer_engine_SOURCES transpose/quantize_transpose_vector_blockwise.cu transpose/swap_first_dims.cu activation/gelu.cu + dropout/dropout.cu fused_attn/flash_attn.cu fused_attn/context_parallel.cu fused_attn/kv_cache.cu diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 834c4fe25..7feb5fda5 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -294,6 +294,38 @@ def _load_nvrtc(): return ctypes.CDLL(f"libnvrtc{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) +@functools.lru_cache(maxsize=None) +def _load_curand(): + """Load cuRAND shared library.""" + # Attempt to locate cuRAND in CUDA_HOME, CUDA_PATH or /usr/local/cuda + cuda_home = os.environ.get("CUDA_HOME") or os.environ.get("CUDA_PATH") or "/usr/local/cuda" + libs = glob.glob(f"{cuda_home}/**/libcurand{_get_sys_extension()}*", recursive=True) + libs = list(filter(lambda x: not ("stub" in x), libs)) + libs.sort(reverse=True, key=os.path.basename) + if libs: + return ctypes.CDLL(libs[0], mode=ctypes.RTLD_GLOBAL) + + # Attempt to locate cuRAND in Python dist-packages + found, handle = _load_nvidia_cuda_library("curand") + if found: + return handle + + # Attempt to locate cuRAND via ldconfig + libs = subprocess.check_output( + f"ldconfig -p | grep 'libcurand{_get_sys_extension()}'", shell=True + ) + libs = libs.decode("utf-8").split("\n") + sos = [] + for lib in libs: + if "libcurand" in lib and "=>" in lib: + sos.append(lib.split(">")[1].strip()) + if sos: + return ctypes.CDLL(sos[0], mode=ctypes.RTLD_GLOBAL) + + # If all else fails, assume that it is in LD_LIBRARY_PATH and error out otherwise + return ctypes.CDLL(f"libcurand{_get_sys_extension()}", mode=ctypes.RTLD_GLOBAL) + + @functools.lru_cache(maxsize=None) def _load_core_library(): """Load shared library with Transformer Engine C extensions""" @@ -303,6 +335,7 @@ def _load_core_library(): if "NVTE_PROJECT_BUILDING" not in os.environ or bool(int(os.getenv("NVTE_RELEASE_BUILD", "0"))): _CUDNN_LIB_CTYPES = _load_cudnn() _NVRTC_LIB_CTYPES = _load_nvrtc() + _CURAND_LIB_CTYPES = _load_curand() _CUBLAS_LIB_CTYPES = _load_nvidia_cuda_library("cublas") _CUDART_LIB_CTYPES = 
_load_nvidia_cuda_library("cuda_runtime") _TE_LIB_CTYPES = _load_core_library() diff --git a/transformer_engine/common/dropout/dropout.cu b/transformer_engine/common/dropout/dropout.cu new file mode 100644 index 000000000..bab349161 --- /dev/null +++ b/transformer_engine/common/dropout/dropout.cu @@ -0,0 +1,355 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +#include +#include +#include + +#include + +#include "../common.h" +#include "../utils.cuh" +#include "transformer_engine/dropout.h" + +namespace transformer_engine { +namespace { + +// RNG kernels process chunks of 16 entries +constexpr size_t rng_chunk_size = 16; + +// CUDA block size +constexpr size_t block_size = 128; + +// Vector class to help with vectorized memory accesses +template +union Vector { + using StorageType = typename BytesToType::Type; + StorageType storage; + T entries[kSize]; +}; + +/* Byte-wise less-than comparison + * + * Results are stored in each byte's most-significant bit (MSB). All + * other bits are zero. + */ +__device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) { + // Compare low bits by masking MSBs and subtracting. The resulting + // MSBs are 0 if the low bits of a are less than the low bits of b. + uint32_t result = (a | 0x80808080) - (b & 0x7F7F7F7F); + + // Bitwise logical op to get answer in MSBs + // Equivalent logic: result = (a == b) ? !result : b + asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result)); + + // Mask out everything except MSBs and return + result &= 0x80808080; + return result; +} + +/* Generate dropout mask with 16 bits. + * + * 1 corresponds to keep and 0 to drop. + * + * Consumes 4 values from cuRAND Philox generator. + */ +__device__ __forceinline__ uint16_t make_16bit_mask(uint64_t chunk_idx, uint64_t rng_seed, + uint64_t rng_offset, + uint32_t bytewise_drop_prob) { + // Generate random bits + curandStatePhilox4_32_10_t state; + curand_init(rng_seed, chunk_idx, rng_offset, &state); + const uint4 rand_bits = curand4(&state); + + // Compute mask + // Note: bytewise_less_than fills MSBs (bits 7, 15, 23, 31). By + // shifting 2 bits after every call, every other bit will be filled. + uint32_t result = bytewise_less_than(rand_bits.x, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.y, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.z, bytewise_drop_prob); + result = (result >> 2) | bytewise_less_than(rand_bits.w, bytewise_drop_prob); + + // Consolidate mask in lowest 16 bits + result |= result >> 17; + + // Flip bits so 0 corresponds to drop + result = ~result; + + return result; +} + +// Dropout forward with FP16/BF16 input and output. 
+template +__global__ void __launch_bounds__(block_size) + dropout_kernel_fwd_f16(const T *__restrict__ input_ptr, T *__restrict__ output_ptr, + uint8_t *__restrict__ mask_ptr, + const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks, + uint32_t bytewise_drop_prob, float scale) { + static_assert(sizeof(T) == 2); + + // Each thread processes a chunk of 16 entries + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Generate dropout mask + auto local_mask = + make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob); + reinterpret_cast(mask_ptr)[chunk_idx] = local_mask; + + // Read input data + using VectorType = Vector; + VectorType local_data; + local_data = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask +#pragma unroll + for (size_t i = 0; i < rng_chunk_size; i++) { + float val = static_cast(local_data.entries[i]); + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_data.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_data; + } +} + +// Dropout forward with FP8 input and FP16/BF16 output. +template +__global__ void __launch_bounds__(block_size) + dropout_kernel_fwd_fp8(const InputType *__restrict__ input_ptr, + const float *__restrict__ input_scale_inv_ptr, + OutputType *__restrict__ output_ptr, uint8_t *__restrict__ mask_ptr, + const uint64_t *__restrict__ rng_state_ptr, size_t num_chunks, + uint32_t bytewise_drop_prob, float scale) { + static_assert(sizeof(InputType) == 1); + static_assert(sizeof(OutputType) == 2); + const float input_scale_inv = *input_scale_inv_ptr; + + // Each thread processes a chunk of 16 entries + const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Generate dropout mask + auto local_mask = + make_16bit_mask(chunk_idx, rng_state_ptr[0], rng_state_ptr[1], bytewise_drop_prob); + reinterpret_cast(mask_ptr)[chunk_idx] = local_mask; + + // Read input data + using InputVectorType = Vector; + InputVectorType local_input; + local_input = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask + using OutputVectorType = Vector; + OutputVectorType local_output; +#pragma unroll + for (size_t i = 0; i < rng_chunk_size; i++) { + float val = static_cast(local_input.entries[i]); + val *= input_scale_inv; + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_output.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_output; + } +} + +// Apply dropout mask and scale. +template +__global__ void __launch_bounds__(block_size) + apply_dropout_mask(const T *__restrict__ input_ptr, const uint8_t *__restrict__ mask_ptr, + T *__restrict__ output_ptr, size_t num_chunks, float scale) { + // Each thread processes a chunk of 8 entries. 
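+  // (8 rather than 16 because the mask is consumed one byte at a time here:
+  // each mask byte packs 8 keep/drop bits, with a set bit meaning keep.)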
+ const size_t gid = threadIdx.x + blockIdx.x * block_size; + const size_t nthreads = gridDim.x * block_size; + constexpr size_t chunk_size = 8; + for (size_t chunk_idx = gid; chunk_idx < num_chunks; chunk_idx += nthreads) { + // Read dropout mask + uint8_t local_mask = mask_ptr[chunk_idx]; + + // Read input data + using VectorType = Vector; + VectorType local_data; + local_data = reinterpret_cast(input_ptr)[chunk_idx]; + + // Apply dropout based on mask +#pragma unroll + for (size_t i = 0; i < chunk_size; i++) { + float val = static_cast(local_data.entries[i]); + if ((local_mask & 0x1) == 0) { + val = 0; + } + val *= scale; + local_data.entries[i] = static_cast(val); + local_mask >>= 1; + } + + // Write output data + reinterpret_cast(output_ptr)[chunk_idx] = local_data; + } +} + +} // namespace + +void dropout_fwd(const Tensor &input, Tensor &output, Tensor &mask, Tensor &rng_state, + float dropout_probability, cudaStream_t stream) { + // Check tensors + const size_t numel = input.numel(); + NVTE_CHECK(input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ", + "but scaling mode is ", to_string(input.scaling_mode), "."); + NVTE_CHECK(output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Output tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(output.scaling_mode), "."); + NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, "Mask tensor must be plain tensor, ", + "but scaling mode is ", to_string(mask.scaling_mode), "."); + NVTE_CHECK(rng_state.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "RNG state tensor must be INT64 tensor with two entries, ", "but scaling mode is ", + to_string(rng_state.scaling_mode), "."); + NVTE_CHECK(output.dtype() == DType::kFloat16 || output.dtype() == DType::kBFloat16, + "Output tensor must be FP16/BF16 tensor, but dtype is ", to_string(output.dtype()), + "."); + NVTE_CHECK(rng_state.dtype() == DType::kInt64, + "RNG state tensor must be INT64 tensor with two entries, but dtype is ", + to_string(rng_state.dtype()), "."); + NVTE_CHECK(numel % 16 == 0, + "Input tensor number of elements must be divisible by 16, but shape is ", + input.shape(), "."); + NVTE_CHECK(numel == output.numel(), "Input tensor (shape=", input.shape(), + ") and output tensor (shape=", output.shape(), ") do not match."); + NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel, + " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), "."); + NVTE_CHECK(rng_state.numel() == 2, "RNG state tensor must be INT64 tensor with two entries, ", + "but shape is ", rng_state.shape(), "."); + NVTE_CHECK(input.data.dptr != nullptr, "Input tensor is missing data."); + NVTE_CHECK(output.data.dptr != nullptr, "Output tensor is missing data."); + NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data."); + NVTE_CHECK(rng_state.data.dptr != nullptr, "RNG state tensor is missing data."); + + // Convert dropout probablity to scale and 8-bit representation + NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (", + dropout_probability, ")."); + const float scale = 1 / (1 - dropout_probability); + uint32_t bytewise_drop_prob = static_cast(std::floor(dropout_probability * 256)); + bytewise_drop_prob |= bytewise_drop_prob << 8; + bytewise_drop_prob |= bytewise_drop_prob << 16; + + // CUDA config + const size_t num_chunks = numel / rng_chunk_size; + const size_t num_blocks = DIVUP(num_chunks, 
block_size); + + // Launch kernel depending on input dtype + if (input.dtype() == DType::kFloat16 || input.dtype() == DType::kBFloat16) { + NVTE_CHECK(input.dtype() == output.dtype(), "Input tensor (dtype=", to_string(input.dtype()), + ") and output tensor (dtype=", to_string(output.dtype()), ") do not match."); + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + input.dtype(), DType, + dropout_kernel_fwd_f16<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(rng_state.data.dptr), num_chunks, bytewise_drop_prob, + scale);); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else if (input.dtype() == DType::kFloat8E4M3 || input.dtype() == DType::kFloat8E5M2) { + NVTE_CHECK(input.scale_inv.dptr != nullptr, "Input tensor scale-inverse is not allocated."); + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( + input.dtype(), InputType, + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + output.dtype(), OutputType, + dropout_kernel_fwd_fp8<<>>( + reinterpret_cast(input.data.dptr), + reinterpret_cast(input.scale_inv.dptr), + reinterpret_cast(output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(rng_state.data.dptr), num_chunks, + bytewise_drop_prob, scale); + + );); + NVTE_CHECK_CUDA(cudaGetLastError()); + } else { + NVTE_ERROR("Input tensor must be FP16/BF16 tensor or tensor-scaled FP8 tensor, ", + "but dtype is ", to_string(input.dtype()), "."); + } +} + +void dropout_bwd(const Tensor &grad_output, const Tensor &mask, Tensor &grad_input, + float dropout_probability, cudaStream_t stream) { + // Check tensors + const size_t numel = grad_output.numel(); + NVTE_CHECK(grad_output.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Grad output tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(grad_output.scaling_mode), "."); + NVTE_CHECK(grad_input.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Grad input tensor must be FP16/BF16 tensor, ", "but scaling mode is ", + to_string(grad_input.scaling_mode), "."); + NVTE_CHECK(mask.scaling_mode == NVTE_DELAYED_TENSOR_SCALING, + "Mask tensor must be a plain tensor, but scaling mode is ", + to_string(mask.scaling_mode), "."); + NVTE_CHECK(grad_output.dtype() == DType::kFloat16 || grad_output.dtype() == DType::kBFloat16, + "Grad output tensor must be FP16/BF16 tensor, but dtype is ", + to_string(grad_output.dtype()), "."); + NVTE_CHECK(grad_output.dtype() == grad_input.dtype(), + "Grad output tensor (dtype=", to_string(grad_output.dtype()), + ") and grad input tensor (dtype=", to_string(grad_input.dtype()), ") do not match."); + NVTE_CHECK(numel % 16 == 0, + "Grad output tensor number of elements must be divisible by 16, but shape is ", + grad_output.shape(), "."); + NVTE_CHECK(numel == grad_input.numel(), "Grad output tensor (shape=", grad_output.shape(), + ") and grad input tensor (shape=", grad_input.shape(), ") do not match."); + NVTE_CHECK(typeToNumBits(mask.dtype()) * mask.numel() == numel, "Mask tensor must have ", numel, + " bits, but found dtype=", to_string(mask.dtype()), " and shape=", mask.shape(), "."); + NVTE_CHECK(grad_output.data.dptr != nullptr, "Grad output tensor is missing data."); + NVTE_CHECK(grad_input.data.dptr != nullptr, "Grad input tensor is missing data."); + NVTE_CHECK(mask.data.dptr != nullptr, "Mask tensor is missing data."); + + // Convert dropout probablity to scale + NVTE_CHECK(dropout_probability >= 0 && dropout_probability < 1, "Invalid dropout probability (", + dropout_probability, ")."); + const float scale = 1 / (1 - dropout_probability); 
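
The forward and backward paths above share one masking convention: each mask bit maps to one tensor element (LSB first within a chunk), a zero bit zeroes that element, and surviving elements are scaled by 1/(1 - p). The sketch below is only an illustrative PyTorch reference of that convention, not the CUDA kernels; the function names and the bit-unpacking helper are assumptions for illustration, and the real mask is produced on the GPU from the Philox RNG state rather than supplied by the caller.

```python
import torch

def unpack_mask_bits(mask_bytes: torch.Tensor, numel: int) -> torch.Tensor:
    # Bit i of byte j corresponds to element j * 8 + i (LSB first), matching the
    # order in which the kernels shift `local_mask` right after each element.
    shifts = torch.arange(8, device=mask_bytes.device)
    bits = (mask_bytes.to(torch.int32).unsqueeze(-1) >> shifts) & 1
    return bits.reshape(-1)[:numel]

def reference_dropout_fwd(x: torch.Tensor, mask_bytes: torch.Tensor, p: float) -> torch.Tensor:
    # Zero dropped entries, then scale kept entries by 1 / (1 - p), as the kernels do.
    scale = 1.0 / (1.0 - p)
    keep = unpack_mask_bits(mask_bytes, x.numel()).reshape(x.shape).to(x.dtype)
    return x * keep * scale

def reference_dropout_bwd(dy: torch.Tensor, mask_bytes: torch.Tensor, p: float) -> torch.Tensor:
    # Backward reuses the saved mask: grad_input = grad_output * mask * 1 / (1 - p).
    return reference_dropout_fwd(dy, mask_bytes, p)
```
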
+ + // CUDA config + const size_t num_chunks = numel / 8; + const size_t num_blocks = DIVUP(num_chunks, block_size); + + // Launch kernel + TRANSFORMER_ENGINE_TYPE_SWITCH_16BIT( + grad_output.dtype(), DType, + apply_dropout_mask<<>>( + reinterpret_cast(grad_output.data.dptr), + reinterpret_cast(mask.data.dptr), + reinterpret_cast(grad_input.data.dptr), num_chunks, scale);); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +} // namespace transformer_engine + +void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask, + NVTETensor rng_state, float dropout_probability, cudaStream_t stream) { + NVTE_API_CALL(nvte_dropout_fwd); + using namespace transformer_engine; + dropout_fwd(*convertNVTETensorCheck(input), *convertNVTETensorCheck(output), + *convertNVTETensorCheck(mask), *convertNVTETensorCheck(rng_state), + dropout_probability, stream); +} + +void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input, + float dropout_probability, cudaStream_t stream) { + NVTE_API_CALL(nvte_dropout_bwd); + using namespace transformer_engine; + dropout_bwd(*convertNVTETensorCheck(grad_output), *convertNVTETensorCheck(mask), + *convertNVTETensorCheck(grad_input), dropout_probability, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/dropout.h b/transformer_engine/common/include/transformer_engine/dropout.h new file mode 100644 index 000000000..6ba1ab912 --- /dev/null +++ b/transformer_engine/common/include/transformer_engine/dropout.h @@ -0,0 +1,51 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + ************************************************************************/ + +/*! \file dropout.h + * \brief Functions for dropout. + */ + +#ifndef TRANSFORMER_ENGINE_DROPOUT_FP8_H_ +#define TRANSFORMER_ENGINE_DROPOUT_FP8_H_ + +#include "transformer_engine.h" + +#ifdef __cplusplus +extern "C" { +#endif + +/*! \brief Dropout forward kernel. + * + * \param[in] input Input tensor. + * \param[out] output Output tensor. + * \param[out] mask Mask tensor. Each bit corresponds to an + * output tensor entry. Ones indicate kept + * entries and zeros indicate dropped entries. + * \param[in] rng_state RNG engine inputs. + * \param[in] dropout_probability Dropout probability. + * \param[in] stream CUDA stream used for this operation. + */ +void nvte_dropout_fwd(const NVTETensor input, NVTETensor output, NVTETensor mask, + NVTETensor rng_state, float dropout_probability, cudaStream_t stream); + +/*! \brief Dropout backward kernel. + * + * \param[in] grad_output Gradient of output tensor. + * \param[out] mask Mask tensor. Each bit corresponds to an + * output tensor entry. Ones indicate kept + * entries and zeros indicate dropped entries. + * \param[out] grad_input Gradient of input tensor. + * \param[in] dropout_probability Dropout probability. + * \param[in] stream CUDA stream used for this operation. 
+ */ +void nvte_dropout_bwd(const NVTETensor grad_output, const NVTETensor mask, NVTETensor grad_input, + float dropout_probability, cudaStream_t stream); + +#ifdef __cplusplus +} // extern "C" +#endif + +#endif diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d0e92a59b..a6b65562e 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -265,6 +265,17 @@ std::vector dbias_dqgelu(const at::Tensor &grad_output, const at::Te std::vector dbias_dsrelu(const at::Tensor &grad_output, const at::Tensor &act_input, py::handle quantizer); +/*************************************************************************************************** + * Dropout + **************************************************************************************************/ + +std::vector dropout_fwd(const py::handle &input, const float dropout_probability, + std::optional out = std::nullopt); + +py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, + const float dropout_probability, + std::optional grad_input = std::nullopt); + /*************************************************************************************************** * Softmax **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/dropout.cpp b/transformer_engine/pytorch/csrc/extensions/dropout.cpp new file mode 100644 index 000000000..e6f29d0da --- /dev/null +++ b/transformer_engine/pytorch/csrc/extensions/dropout.cpp @@ -0,0 +1,89 @@ +/************************************************************************* + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ ************************************************************************/ + +#include "transformer_engine/dropout.h" + +#include +#include + +#include + +#include "../common.h" +#include "../extensions.h" +#include "../pybind.h" +#include "transformer_engine/transformer_engine.h" + +namespace transformer_engine { +namespace pytorch { + +std::vector dropout_fwd(const py::handle &input, float dropout_probability, + std::optional out) { + using namespace transformer_engine::pytorch::detail; + + // Input tensor + const TensorWrapper input_nvte = makeTransformerEngineTensor(input, py::none()); + + // Allocate output tensor if needed + if (!out) { + at::ScalarType dtype = GetATenDType(input_nvte.dtype()); + if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2) { + dtype = input.attr("dtype").cast(); + } + const auto shape_uint64 = convertShape(input_nvte.shape()); + const std::vector shape_int64(shape_uint64.begin(), shape_uint64.end()); + const auto opts = at::TensorOptions().dtype(dtype).device(torch::kCUDA); + out = at::empty(shape_int64, opts); + } + TensorWrapper out_nvte = makeTransformerEngineTensor(*out); + + // Mask tensor + auto mask_pyt = allocateTorchTensor(input_nvte.numel() / 8, DType::kByte); + auto mask_nvte = makeTransformerEngineTensor(mask_pyt); + + // RNG state tensor + auto gen = at::get_generator_or_default( + std::nullopt, at::cuda::detail::getDefaultCUDAGenerator()); + at::PhiloxCudaState philox_args; + { + std::lock_guard lock(gen->mutex_); + constexpr int64_t rng_elts_per_thread = 4; + philox_args = gen->philox_cuda_state(rng_elts_per_thread); + } + auto rng_state_pyt = allocateTorchTensor(2, DType::kInt64); + NVTE_SCOPED_GIL_RELEASE({ + nvte_extract_seed_and_offset( + reinterpret_cast(rng_state_pyt.data_ptr()), philox_args.captured_, + philox_args.seed_.ptr, philox_args.seed_.val, philox_args.offset_.ptr, + philox_args.offset_.val, philox_args.offset_intragraph_, at::cuda::getCurrentCUDAStream()); + }); + auto rng_state_nvte = makeTransformerEngineTensor(rng_state_pyt); + + // Launch kernel + NVTE_SCOPED_GIL_RELEASE({ + nvte_dropout_fwd(input_nvte.data(), out_nvte.data(), mask_nvte.data(), rng_state_nvte.data(), + dropout_probability, at::cuda::getCurrentCUDAStream()); + }); + + return {py::cast(std::move(*out)), py::cast(mask_pyt)}; +} + +py::object dropout_bwd(const at::Tensor &grad_output, const at::Tensor &mask, + const float dropout_probability, std::optional grad_input) { + const auto grad_output_nvte = makeTransformerEngineTensor(grad_output); + const auto mask_nvte = makeTransformerEngineTensor(mask); + if (!grad_input) { + grad_input = at::empty_like(grad_output); + } + auto grad_input_nvte = makeTransformerEngineTensor(*grad_input); + NVTE_SCOPED_GIL_RELEASE({ + nvte_dropout_bwd(grad_output_nvte.data(), mask_nvte.data(), grad_input_nvte.data(), + dropout_probability, at::cuda::getCurrentCUDAStream()); + }); + return py::cast(std::move(*grad_input)); +} + +} // namespace pytorch +} // namespace transformer_engine diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 6442b05da..541b16848 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -305,6 +305,13 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("Const_buf"), py::arg("tokens_per_expert"), py::arg("num_rows"), py::arg("num_cols"), py::arg("grad_aux_loss"), "Fused aux loss bwd"); + // Dropout + m.def("dropout_fwd", 
transformer_engine::pytorch::dropout_fwd, "Dropout forward with 8-bit RNG", + py::arg("input"), py::arg("dropout_probability"), py::arg("out") = std::nullopt); + m.def("dropout_bwd", transformer_engine::pytorch::dropout_bwd, "Dropout backward with 8-bit RNG", + py::arg("grad_output"), py::arg("mask"), py::arg("dropout_probability"), + py::arg("grad_input") = std::nullopt); + // Misc m.def("get_cublasLt_version", &transformer_engine::pytorch::get_cublasLt_version, "Get cublasLt version", py::call_guard()); diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index 958e9b06c..f0f55322c 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -8,12 +8,11 @@ from typing import Optional import torch - -from transformer_engine.pytorch.ops.op import ( - BasicOperation, - OperationContext, -) +import transformer_engine_torch as tex from ...tensor import Quantizer +from ...tensor._internal.float8_tensor_base import Float8TensorBase +from .._common import maybe_autocast_dtype, maybe_dequantize +from ..op import BasicOperation, OperationContext class Dropout(BasicOperation): @@ -27,7 +26,7 @@ class Dropout(BasicOperation): def __init__(self, p: float) -> None: super().__init__() - self.dropout_probability = p + self.dropout_probability: float = p def op_forward( self, @@ -37,21 +36,44 @@ def op_forward( next_op_input_quantizer: Optional[Quantizer], ) -> torch.Tensor: - # Compute dropout if training - out = input_ - is_training = self.training - mask = None - if is_training: + # Output dtype + dtype = maybe_autocast_dtype(default_dtype=input_.dtype) + + # Choose implementation + impl = None + if not self.training: + impl = "evaluation" + elif input_.numel() % 16 == 0 and dtype in (torch.float16, torch.bfloat16): + impl = "fused" + else: + impl = "unfused" + + # Perform dropout + out: torch.Tensor + mask: Optional[torch.Tensor] = None + if impl == "evaluation": + out = input_ + elif impl == "fused": + x = input_ + if not isinstance(x, Float8TensorBase): + x = maybe_dequantize(x, dtype=dtype) + out, mask = tex.dropout_fwd(x, self.dropout_probability) + elif impl == "unfused": + x = maybe_dequantize(input_, dtype=dtype) keep_prob = 1 - self.dropout_probability - mask = torch.empty_like(input_) + mask = torch.empty_like(x) mask.bernoulli_(keep_prob) mask *= 1 / keep_prob - out = out * mask + out = x * mask + else: + raise ValueError(f"Unsupported forward implementation {impl}") # Save context for backward if ctx.requires_grad: ctx.save_for_backward(mask) - ctx.is_training = is_training + ctx.impl = impl + ctx.dropout_probability = self.dropout_probability + ctx.dtype = dtype return out @@ -60,8 +82,21 @@ def op_backward( ctx: OperationContext, grad_output: torch.Tensor, ) -> tuple[torch.Tensor, tuple[()]]: + + # Saved tensors from forward pass (mask,) = ctx.saved_tensors - grad_input = grad_output - if ctx.is_training: - grad_input = grad_input * mask + + # Perform dropout backward pass + grad_input: torch.Tensor + if ctx.impl == "evaluation": + grad_input = grad_output + elif ctx.impl == "fused": + dy = maybe_dequantize(grad_output, dtype=ctx.dtype) + grad_input = tex.dropout_bwd(dy, mask, ctx.dropout_probability) + elif ctx.impl == "unfused": + dy = maybe_dequantize(grad_output, dtype=ctx.dtype) + grad_input = dy * mask + else: + raise ValueError(f"Unsupported backward implementation {ctx.impl}") + return grad_input, () From 67fcc15255248a26be124de3854a47f84102f285 Mon Sep 17 00:00:00 
2001 From: Selvaraj Anandaraj Date: Tue, 2 Sep 2025 02:14:13 -0700 Subject: [PATCH 120/153] Create GPU reload buffers on main stream (#2131) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> --- transformer_engine/pytorch/cpu_offload.py | 30 +++++++++++++++-------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/cpu_offload.py b/transformer_engine/pytorch/cpu_offload.py index 3fdf8b14f..179c80a65 100644 --- a/transformer_engine/pytorch/cpu_offload.py +++ b/transformer_engine/pytorch/cpu_offload.py @@ -551,17 +551,23 @@ def bulk_reload_group(self, group_to_reload): buffer_idx = 0 double_buffer_idx = group_to_reload % 2 + main_stream = torch.cuda.current_stream() + with torch.cuda.stream(self.h2d_stream): # move back tensors for tensor_label, state in self.tensor_tag_to_state.items(): group_id, _ = tensor_label if group_id == group_to_reload: - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] - else: - reload_buffer = None if isinstance(state, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][buffer_idx] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state[1], device=torch.cuda.current_device() + ) + recovered_tensor = SynchronizedGroupOffloadHandler.reload( state, True, reload_buffer ) @@ -570,14 +576,18 @@ def bulk_reload_group(self, group_to_reload): elif isinstance(state, list): tensor_list = [] for state_tuple in state: - if self.double_buffering: - reload_buffer = self.reload_double_buffer[double_buffer_idx][ - buffer_idx - ] - else: - reload_buffer = None if isinstance(state_tuple, tuple): + if self.double_buffering: + reload_buffer = self.reload_double_buffer[double_buffer_idx][ + buffer_idx + ] + else: + with torch.cuda.stream(main_stream): + reload_buffer = torch.empty_like( + state_tuple[1], device=torch.cuda.current_device() + ) + tensor_list.append( SynchronizedGroupOffloadHandler.reload( state_tuple, From 3b4366be34ec8e5b96c73ba20ec22f0dfac1b97a Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 3 Sep 2025 19:27:04 +1200 Subject: [PATCH 121/153] Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../pytorch/comm_gemm_overlap/te_layer_with_overlap.py | 6 +++++- tests/pytorch/distributed/run_layer_with_overlap.py | 8 ++++++-- .../distributed/test_fusible_ops_with_userbuffers.py | 4 ++-- transformer_engine/pytorch/module/base.py | 2 +- 4 files changed, 14 insertions(+), 6 deletions(-) diff --git a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py index eeb79c235..d52e97d65 100644 --- a/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py +++ 
b/examples/pytorch/comm_gemm_overlap/te_layer_with_overlap.py @@ -264,7 +264,11 @@ def dist_print(msg, end="\n", group=nccl_world, src=0, debug=False, error=False) [batched_size, hidden_size], tp_size, quantization_modes=[ - UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) ], dtype=torch.bfloat16, bootstrap_backend=opts.bootstrap_backend, diff --git a/tests/pytorch/distributed/run_layer_with_overlap.py b/tests/pytorch/distributed/run_layer_with_overlap.py index 1dabf6e45..2a6e55b2c 100644 --- a/tests/pytorch/distributed/run_layer_with_overlap.py +++ b/tests/pytorch/distributed/run_layer_with_overlap.py @@ -420,10 +420,14 @@ def dist_print(msg, src=None, end="\n", debug=False, error=False): } quantization_modes = [ - UserBufferQuantizationMode.FP8 if opts.fp8 else UserBufferQuantizationMode.NONE + ( + te.module.base.UserBufferQuantizationMode.FP8 + if opts.fp8 + else te.module.base.UserBufferQuantizationMode.NONE + ) ] if opts.first_last_layers_bf16 and opts.fp8: - quantization_modes.append(UserBufferQuantizationMode.NONE) + quantization_modes.append(te.module.base.UserBufferQuantizationMode.NONE) te.module.base.initialize_ub( [opts.seq_length * opts.batch_size, opts.num_heads * opts.head_dim], diff --git a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py index 17d351292..d6ddfe27c 100644 --- a/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py +++ b/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py @@ -508,9 +508,9 @@ def main() -> None: torch.distributed.get_world_size(group), quantization_modes=[ ( - UserBufferQuantizationMode.FP8 + te.module.base.UserBufferQuantizationMode.FP8 if model_config.quantization is not None - else UserBufferQuantizationMode.NONE + else te.module.base.UserBufferQuantizationMode.NONE ) ], dtype=model_config.dtype, diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 3bbfaacdf..a6275abd1 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -473,7 +473,7 @@ def add_ub( fp8_buf = (name in layers_all_gather_overlap) or ( user_ub_cfg[name].get("fp8_buf", False) and name in methods["pipeline"] ) - ub_cfg.update(ub_cfgs[name]) + ub_cfg.update(user_ub_cfg[name]) ub_cfg["fp8_buf"] = fp8_buf add_ub(name, quantization_mode, **ub_cfg) From f378eaf2899f1148c68b567f62595e940556da7f Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Wed, 3 Sep 2025 14:15:52 -0700 Subject: [PATCH 122/153] [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij 
Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/jax/test_fused_attn.py | 9 +++++++++ transformer_engine/jax/cpp_extensions/attention.py | 6 ++++++ transformer_engine/jax/cpp_extensions/misc.py | 10 ++++++++++ 3 files changed, 25 insertions(+) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index ec530a395..87dfc113c 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -41,6 +41,7 @@ from transformer_engine_jax import ( NVTE_Fused_Attn_Backend, get_cudnn_version, + get_device_compute_capability, ) from distributed_test_base import assert_equal_collectives @@ -348,6 +349,14 @@ def _check_configs(self): "seqlen_q > seqlen_kv is not supported with sliding window attention in cuDNN" ) + if ( + get_device_compute_capability(0) == 100 + and self.dropout_prob == 0.1 + and self.attn_bias_type is not AttnBiasType.NO_BIAS + ): + pytest.skip( + "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + ) # Test the MLA case where head dims for qk differ from head dims for v, only if the tensors # are provided in BSHD_BSHD_BSHD or THD_THD_THD formats if self.head_dim_qk != self.head_dim_v and not self.qkv_layout.is_separate(): diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 089ef75f1..df89174b2 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -34,6 +34,7 @@ te_dtype_to_jax_dtype, get_padded_spec, get_cudnn_version, + get_all_device_compute_capability, ) from ..sharding import ( global_mesh_resource, @@ -2745,6 +2746,11 @@ def fused_attn_bwd( assert bias is None bias = jnp.zeros(0, dtype=qkv[0].dtype) + if 100 in get_all_device_compute_capability(): + assert not ( + attn_bias_type != AttnBiasType.NO_BIAS and dropout_probability != 0 + ), "For sm100, bprop kernel support for dropout + determinism (bias) is not supported" + fused_config = _FusedAttnConfig( attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, diff --git a/transformer_engine/jax/cpp_extensions/misc.py b/transformer_engine/jax/cpp_extensions/misc.py index 94dfaa45a..3bda37128 100644 --- a/transformer_engine/jax/cpp_extensions/misc.py +++ b/transformer_engine/jax/cpp_extensions/misc.py @@ -193,6 +193,16 @@ def get_min_device_compute_capability(): ) +def get_all_device_compute_capability(): + """ + Returns a list of compute capability of all local devices. 
+ """ + return tuple( + transformer_engine_jax.get_device_compute_capability(local_gpu_id) + for local_gpu_id in range(len(jax.local_devices())) + ) + + def should_apply_1x_fused_dbias_war_for_arch_l_100(is_dbias: bool = False, quantizer=None): """ Fused dbias is not supported for arch < 100 for 1x quantization, so we need to apply a workaround to From 0f68f7b2f9e6e94d7037513942389432f9e58d68 Mon Sep 17 00:00:00 2001 From: Zhongbo Zhu <42691305+zhongbozhu@users.noreply.github.com> Date: Thu, 4 Sep 2025 10:11:33 -0700 Subject: [PATCH 123/153] [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- .../include/transformer_engine/recipe.h | 15 +++++ .../common/recipe/current_scaling.cu | 66 ++++++++++++++++--- .../common/transpose/cast_transpose.h | 5 +- .../quantize_transpose_square_blockwise.cu | 20 ++++-- .../quantize_transpose_vector_blockwise.cu | 18 +++-- .../common/util/cast_kernels.cuh | 11 ++-- transformer_engine/pytorch/csrc/quantizer.cpp | 3 +- 7 files changed, 110 insertions(+), 28 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index 50fb696ea..2fc8c1095 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -84,6 +84,21 @@ void nvte_delayed_scaling_recipe_amax_and_scale_update_after_reduction( */ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); +/*! \brief Compute an FP8 tensor's amax with quantization config. + * + * The amax (maximum absolute value) of the input tensor is computed + * and written to the amax buffer of the output tensor, using the provided + * quantization configuration. + * One useful config is the noop tensor, which is needed by cuda graph. + * + * \param[in] input Input tensor. Must be unquantized. + * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[in] config Quantization configuration. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig config, cudaStream_t stream); + /*! \brief Update an FP8 tensor's scale based on its amax. * * This is only supported for FP8 tensors with per-tensor scaling. 
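
The new `nvte_compute_amax_with_config` entry point exists so that a quantization config carrying a noop tensor can reach the kernel: the kernels changed later in this patch early-return when that single-element device tensor reads 1.0, which lets a captured CUDA graph skip the amax/scale update by flipping a device-side flag instead of re-capturing the graph. The PyTorch snippet below is a minimal conceptual sketch of that device-side-flag pattern under stated assumptions; the function name and the "1.0 means skip" convention mirror the diff, but the code is not the Transformer Engine implementation.

```python
import torch

def compute_amax_with_noop(x: torch.Tensor, amax: torch.Tensor, noop: torch.Tensor) -> None:
    # All operands stay on the GPU; torch.where keeps the skip decision inside the
    # captured work instead of branching on a host-side Python value.
    new_amax = x.abs().amax()
    amax.copy_(torch.where(noop[0] == 1.0, amax, new_amax))

if torch.cuda.is_available():
    x = torch.randn(1024, device="cuda")
    amax = torch.zeros((), device="cuda")
    noop = torch.zeros(1, device="cuda")  # 0.0 -> update amax, 1.0 -> leave it untouched
    compute_amax_with_noop(x, amax, noop)
    noop.fill_(1.0)  # flip the flag; the same call now leaves amax unchanged
    compute_amax_with_noop(torch.randn(1024, device="cuda"), amax, noop)
```
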
diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index e1657b77a..fd907efcb 100644 --- a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -23,7 +23,11 @@ constexpr int amax_kernel_threads = 512; template __launch_bounds__(amax_kernel_threads) __global__ void amax_kernel(const InputType *input, float *amax, const size_t N, - const size_t num_aligned_elements) { + const size_t num_aligned_elements, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + VectorizedLoader loader(input, N); InputType max = 0.f; const int warp_id = threadIdx.x / THREADS_PER_WARP; @@ -58,7 +62,8 @@ __launch_bounds__(amax_kernel_threads) __global__ } template -void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cudaStream_t stream) { +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, const float *noop_ptr, + cudaStream_t stream) { // Zero out amax so we can update with atomic max NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); @@ -81,16 +86,17 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud switch (align) { case Alignment::SAME_ALIGNED: amax_kernel - <<>>(input, amax, N, num_aligned_elements); + <<>>(input, amax, N, num_aligned_elements, noop_ptr); break; case Alignment::SAME_UNALIGNED: amax_kernel - <<>>(input, amax, N, num_aligned_elements); + <<>>(input, amax, N, num_aligned_elements, noop_ptr); break; case Alignment::DIFFERENT: { // This case is a logic error, since there is only one pointer (input) // in the alignment check. Still safe to process without vectorization. - amax_kernel<1, true, InputType><<>>(input, amax, N, N); + amax_kernel<1, true, InputType> + <<>>(input, amax, N, N, noop_ptr); break; } } @@ -102,8 +108,10 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, cud } // namespace } // namespace transformer_engine -void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { - NVTE_API_CALL(nvte_compute_amax); +namespace { + +void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream, + const NVTEQuantizationConfig config_) { using namespace transformer_engine; // Check input tensor @@ -138,12 +146,35 @@ void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaSt to_string(output.amax.dtype), ")"); CheckOutputTensor(output, "output_compute_amax", true); + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? 
convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Compute amax TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.data.dtype, IType, constexpr int nvec = 32 / sizeof(IType); launch_amax_kernel(reinterpret_cast(input.data.dptr), reinterpret_cast(output.amax.dptr), input.data.numel(), - stream);); // NOLINT(*) + noop_ptr, stream);); // NOLINT(*) +} + +} // anonymous namespace + +void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax); + compute_amax_impl(input_, output_, stream, nullptr); +} + +void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_, + const NVTEQuantizationConfig config_, cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax_with_config); + compute_amax_impl(input_, output_, stream, config_); } namespace transformer_engine { @@ -151,7 +182,11 @@ namespace { __global__ void compute_scale_from_amax_kernel(const float *amax_ptr, float *scale_ptr, const float max_fp8, const bool force_pow_2_scales, - const float epsilon) { + const float epsilon, const float *noop_ptr) { + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + *scale_ptr = compute_scale_from_amax(*amax_ptr, max_fp8, force_pow_2_scales, epsilon, std::numeric_limits::max()); } @@ -197,10 +232,21 @@ void nvte_compute_scale_from_amax(NVTETensor output_, const NVTEQuantizationConf TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(output.data.dtype, DType, max_fp8 = Quantized_Limits::max_norm;); + // noop tensor for cuda graph + float *noop_ptr = nullptr; + if (config_ != nullptr) { + const QuantizationConfig *config_cpp = reinterpret_cast(config_); + + // extract noop tensor from quant_config_cpp if it's not null + const NVTETensor noop = config_cpp ? config_cpp->noop_tensor : nullptr; + noop_ptr = reinterpret_cast( + (noop != nullptr ? 
convertNVTETensorCheck(noop)->data.dptr : nullptr)); + } + // Update scale compute_scale_from_amax_kernel<<<1, 1, 0, stream>>>( reinterpret_cast(output.amax.dptr), reinterpret_cast(output.scale.dptr), max_fp8, config.force_pow_2_scales, - config.amax_epsilon); + config.amax_epsilon, noop_ptr); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/transpose/cast_transpose.h b/transformer_engine/common/transpose/cast_transpose.h index a73723926..abfa226e8 100644 --- a/transformer_engine/common/transpose/cast_transpose.h +++ b/transformer_engine/common/transpose/cast_transpose.h @@ -27,7 +27,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &scale_inv_t, SimpleTensor &output, SimpleTensor &output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream); + const SimpleTensor &noop_tensor, cudaStream_t stream); // enum class for rowwise usage enum class FP8BlockwiseRowwiseOption { @@ -59,7 +59,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor &input, SimpleTensor SimpleTensor &output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scale, cudaStream_t stream); + const bool pow_2_scale, const SimpleTensor &noop_tensor, + cudaStream_t stream); } // namespace transformer_engine::detail diff --git a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu index a603d1f1a..c3f085b87 100644 --- a/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_square_blockwise.cu @@ -70,11 +70,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, const __grid_constant__ CUtensorMap tensor_map_output_t, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -249,11 +253,15 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) block_scaled_cast_transpose CType* const tile_scales_inv_c, CType* const tile_scales_inv_t, const size_t row_length, const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, - bool pow_2_scaling) { + bool pow_2_scaling, const float* noop_ptr) { using IVec = Vec; using OVecCast = Vec; using OVecTrans = Vec; + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + // shared mem for amax reduction in entire block, each warp produces one amax, there are // NUM_WARPS_IN_BLOCK amax to reduce __shared__ CType block_tile_amax_shared[NUM_WARPS_IN_BLOCK]; @@ -473,7 +481,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& scale_inv_t, SimpleTensor& output, SimpleTensor& output_t, const float epsilon, const bool return_transpose, const bool pow_2_scale, - cudaStream_t stream) { + const SimpleTensor& noop_tensor, cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_square_blockwise); 
checkCuDriverContext(stream); @@ -494,6 +502,8 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor size_t scale_t_stride_x = 0; size_t scale_t_stride_y = 0; + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + if (return_transpose) { NVTE_CHECK(output_t.shape.size() == input.shape.size(), "output_t must have same number of dimensions as input."); @@ -541,7 +551,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - tensor_map_output_trans, pow_2_scale); + tensor_map_output_trans, pow_2_scale, noop_ptr); } else { block_scaled_cast_transpose_kernel_notaligned @@ -552,7 +562,7 @@ void quantize_transpose_square_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, - pow_2_scale); + pow_2_scale, noop_ptr); } // full-tile ) // return_transpose ) // OutputType diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 6f5c0f3a6..4c82b8c81 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -172,7 +172,12 @@ __global__ void __launch_bounds__(kThreadsPerBlock) block_scaled_1d_cast_transpo const size_t num_rows, const size_t scale_stride_x, const size_t scale_stride_y, const size_t scale_t_stride_x, const size_t scale_t_stride_y, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow_2_scaling) { + const bool pow_2_scaling, const float* noop_ptr) { + // skip execution if noop + if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { + return; + } + bool return_rowwise = rowwise_option != FP8BlockwiseRowwiseOption::NONE; bool return_columnwise_gemm_ready = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; @@ -520,7 +525,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor SimpleTensor& output_t, const float epsilon, FP8BlockwiseRowwiseOption rowwise_option, FP8BlockwiseColumnwiseOption columnwise_option, - const bool pow2_scale, cudaStream_t stream) { + const bool pow2_scale, const SimpleTensor& noop_tensor, + cudaStream_t stream) { NVTE_API_CALL(quantize_transpose_vector_blockwise); const size_t row_length = input.shape.size() > 0 ? 
input.shape.at(input.shape.size() - 1) : 1u; @@ -585,6 +591,8 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); + const float* noop_ptr = reinterpret_cast(noop_tensor.dptr); + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( input.dtype, InputType, @@ -613,9 +621,9 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor reinterpret_cast(scale_inv.dptr), reinterpret_cast(scale_inv_t.dptr), row_length, num_rows, scale_stride_x, scale_stride_y, scale_t_stride_x, scale_t_stride_y, epsilon, rowwise_option, - columnwise_option, pow2_scale);) // kAligned - ) // OutputType - ) // InputType + columnwise_option, pow2_scale, noop_ptr);) // kAligned + ) // OutputType + ) // InputType NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 1158132e3..8d8735118 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1427,7 +1427,8 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o quantize_transpose_square_blockwise( input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, output_tensor->data, output_tensor->columnwise_data, epsilon, - /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, stream); + /*return_transpose=*/output_tensor->has_columnwise_data(), force_pow_2_scales, + /*noop_tensor=*/noop_tensor.data, stream); break; } case NVTE_BLOCK_SCALING_1D: { @@ -1455,10 +1456,10 @@ void quantize_helper(const NVTETensor input, const NVTETensor grad, NVTETensor o ? 
FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT : FP8BlockwiseColumnwiseOption::COLUMNWISE_GEMM_READY; } - quantize_transpose_vector_blockwise(input_tensor->data, output_tensor->scale_inv, - output_tensor->columnwise_scale_inv, output_tensor->data, - output_tensor->columnwise_data, epsilon, rowwise_option, - columnwise_option, force_pow_2_scales, stream); + quantize_transpose_vector_blockwise( + input_tensor->data, output_tensor->scale_inv, output_tensor->columnwise_scale_inv, + output_tensor->data, output_tensor->columnwise_data, epsilon, rowwise_option, + columnwise_option, force_pow_2_scales, noop_tensor.data, stream); break; } default: diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 0c75789ed..c690cd522 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -518,7 +518,8 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te // Compute amax if (compute_amax) { - NVTE_SCOPED_GIL_RELEASE({ nvte_compute_amax(input.data(), out.data(), stream); }); + NVTE_SCOPED_GIL_RELEASE( + { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); } // Perform amax reduction if needed From e9a5fa4e368464f3b310b90ab7f670f35319344b Mon Sep 17 00:00:00 2001 From: Casper Date: Thu, 4 Sep 2025 22:39:53 +0200 Subject: [PATCH 124/153] =?UTF-8?q?[PyTorch]=C2=A0fix=20cross=20entropy=20?= =?UTF-8?q?vanishing=20gradients=20(#2139)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * fix cross entropy Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Casper * fix comments Signed-off-by: Casper * fix: few more style issues Signed-off-by: Casper * fix: remove grad_output_stride (unnecessary) Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix: only backward was broken Signed-off-by: Casper * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Generalize cross entropy backward kernel to handle reduced and unreduced loss Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Casper Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon --- tests/pytorch/test_parallel_cross_entropy.py | 59 ++++++++++++------- .../pytorch/triton/cross_entropy.py | 3 + 2 files changed, 41 insertions(+), 21 deletions(-) diff --git a/tests/pytorch/test_parallel_cross_entropy.py b/tests/pytorch/test_parallel_cross_entropy.py index 77bea2b36..fa56852ff 100644 --- a/tests/pytorch/test_parallel_cross_entropy.py +++ b/tests/pytorch/test_parallel_cross_entropy.py @@ -6,6 +6,8 @@ import torch from transformer_engine.pytorch.cross_entropy import parallel_cross_entropy +from utils import dtype_tols + class TestParallelCrossEntropy: @@ -18,19 +20,25 @@ def generate_infra(self, reduce_loss: bool, label_smoothing: float): label_smoothing=label_smoothing, reduction="mean" if reduce_loss else "none" ) - def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): - + def generate_input( + self, + dtype: torch.dtype, + swap_dim: bool, + ignore_idx: bool, + 
device: torch.device = "cuda", + ): SQ = random.choice([64, 128]) batch = random.choice([1, 2]) vocab = random.choice([64000, 128000]) ignore = random.sample(range(0, SQ - 1), 5) + # Generate random data if swap_dim: - self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (SQ, batch)).cuda() + self.input_test = torch.rand((SQ, batch, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (SQ, batch), device=device) else: - self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype).cuda() - self.tar_test = torch.randint(0, vocab, (batch, SQ)).cuda() + self.input_test = torch.rand((batch, SQ, vocab), dtype=dtype, device=device) + self.tar_test = torch.randint(0, vocab, (batch, SQ), device=device) if ignore_idx: for i in ignore: @@ -40,9 +48,14 @@ def generate_input(self, dtype: torch.dtype, swap_dim: bool, ignore_idx: bool): else: self.tar_test[0][i] = -100 + # Make copy of data for reference implementation self.input_ref = torch.reshape(self.input_test.clone().detach(), (batch * SQ, vocab)) self.tar_ref = torch.reshape(self.tar_test.clone().detach(), (batch * SQ,)) + # Enable autograd + self.input_test.requires_grad_() + self.input_ref.requires_grad_() + def one_iteration_test( self, dtype: torch.dtype, @@ -52,18 +65,20 @@ def one_iteration_test( ignore_idx: bool = False, ): + # Random data self.generate_input(dtype, swap_dim, ignore_idx) - self.input_test.requires_grad_(True) - self.input_ref.requires_grad_(True) - + # Forward pass test_loss = self.test_loss_func( self.input_test, self.tar_test, label_smoothing, reduce_loss, None ) - ref_loss = self.ref_loss_func(self.input_ref, self.tar_ref) - # Handle backward pass based on the test scenario + # Compute square to avoid trivial backward pass + test_loss = torch.square(test_loss) + ref_loss = torch.square(ref_loss) + + # Backward pass if reduce_loss: test_loss.backward() ref_loss.backward() @@ -71,16 +86,18 @@ def one_iteration_test( test_loss.sum().backward() ref_loss.sum().backward() - test_loss = torch.flatten(test_loss) if not reduce_loss else test_loss - - if ignore_idx: - print(test_loss, ref_loss) - - # Compare gradients when backward pass was called - torch.testing.assert_close( - torch.flatten(self.input_test.grad, start_dim=0, end_dim=1), self.input_ref.grad - ) - + # Check that loss and grad input match + tols = dtype_tols(dtype) + test_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = test_loss.to(dtype=torch.float64, device="cpu") + ref_loss = ref_loss.reshape(test_loss.size()) + test_grad_input = self.input_test.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = self.input_ref.grad.to(dtype=torch.float64, device="cpu") + ref_grad_input = ref_grad_input.reshape(test_grad_input.size()) + torch.testing.assert_close(test_loss, ref_loss, **tols) + torch.testing.assert_close(test_grad_input, ref_grad_input, **tols) + + # Reset data self.input_test = None self.input_ref = None self.tar_test = None diff --git a/transformer_engine/pytorch/triton/cross_entropy.py b/transformer_engine/pytorch/triton/cross_entropy.py index 323a93922..7cfff1da9 100644 --- a/transformer_engine/pytorch/triton/cross_entropy.py +++ b/transformer_engine/pytorch/triton/cross_entropy.py @@ -230,6 +230,7 @@ def element_mul_kernel( X_ptr, X_stride, grad_output_ptr, + grad_output_stride, n_cols, BLOCK_SIZE: tl.constexpr, ): @@ -252,6 +253,7 @@ def element_mul_kernel( X_ptr += program_id * X_stride # Load the gradient output value + grad_output_ptr += 
program_id * grad_output_stride grad_output = tl.load(grad_output_ptr) # Perform the element-wise multiplication @@ -360,6 +362,7 @@ def cross_entropy_backward( _input, _input.stride(-2), grad_output, + 1 if grad_output.numel() > 1 else 0, V, BLOCK_SIZE=BLOCK_SIZE, num_warps=32, From 11e9d669ae827b13dce309fefa79f8938da34352 Mon Sep 17 00:00:00 2001 From: Hongbin Liu Date: Fri, 5 Sep 2025 11:24:22 +0800 Subject: [PATCH 125/153] Fix bug when enabling --overlap-grad-reduce in mcore (#2142) * fix bugs when enabling --overlap-grad-reduce in mcore Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix CI Signed-off-by: Hongbin Liu * format Signed-off-by: Hongbin Liu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Hongbin Liu Co-authored-by: Hongbin Liu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/pytorch/module/base.py | 3 +-- transformer_engine/pytorch/module/grouped_linear.py | 6 +----- transformer_engine/pytorch/module/layernorm_mlp.py | 7 ++----- 3 files changed, 4 insertions(+), 12 deletions(-) diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index a6275abd1..0f2e3c4de 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1482,8 +1482,7 @@ def backward_dw(self): (wgrad, bgrad), _ = self.wgrad_store.pop() if not self.fuse_wgrad_accumulation: weight_tensor = noop_cat(self._get_weight_tensors()) - if weight_tensor.grad is None: - weight_tensor.grad = wgrad.to(weight_tensor.dtype) + weight_tensor.grad = wgrad.to(weight_tensor.dtype) if self.use_bias: bias_tensor = noop_cat([getattr(self, name) for name in self.bias_names]) if bias_tensor.grad is None: diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index 3d7a5efac..e9189ccc5 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -452,9 +452,6 @@ def handle_custom_ddp_from_mcore(weight, wgrad): else: wgrad_list = [None] * ctx.num_gemms - if ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute(): - wgrad_list = [None] * ctx.num_gemms - if not ctx.use_bias or ( ctx.wgrad_store is not None and ctx.wgrad_store.delay_wgrad_compute() @@ -829,8 +826,7 @@ def backward_dw(self): bias_params = [getattr(self, f"bias{i}") for i in range(self.num_gemms)] if not self.fuse_wgrad_accumulation: for i in range(self.num_gemms): - if weight_params[i].grad is None: - weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) + weight_params[i].grad = wgrad_list[i].to(weight_params[i].dtype) if self.use_bias: for i in range(self.num_gemms): if bias_params[i].grad is None: diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 182bf99f8..a6c55ceb7 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -1197,7 +1197,6 @@ def fc1_wgrad_gemm( "with Userbuffers (tensor-parallel communication overlapping)" ) ctx.wgrad_store.put([ln_out_total, dact], fc1_wgrad_gemm) - fc1_wgrad = None if fuse_gemm_and_bias_fc1_wgrad: fc1_bias_grad = None else: @@ -2168,10 +2167,8 @@ def backward_dw(self): if self.fc1_bias.grad is None: self.fc1_bias.grad = 
fc1_bias_grad.to(self.fc1_bias.dtype) if not self.fuse_wgrad_accumulation: - if self.fc2_weight.grad is None: - self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) - if self.fc1_weight.grad is None: - self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) + self.fc2_weight.grad = fc2_wgrad.to(self.fc2_weight.dtype) + self.fc1_weight.grad = fc1_wgrad.to(self.fc1_weight.dtype) del fc2_bias_grad_ del fc2_wgrad del fc1_wgrad From b10f436aa28ec8d885eb0d9bf134c320ffb6353d Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Thu, 4 Sep 2025 22:09:55 -0700 Subject: [PATCH 126/153] Fix CUDA version in setup.py (#2132) * Fix CUDA version in setup.py Signed-off-by: Vladimir Cherepanov * Re-enable building comm-gemm tests Signed-off-by: Vladimir Cherepanov * WAR for nvidia-nvshmem package Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- setup.py | 7 ++++--- tests/cpp/CMakeLists.txt | 1 + transformer_engine/common/__init__.py | 5 +++++ 3 files changed, 10 insertions(+), 3 deletions(-) diff --git a/setup.py b/setup.py index 52adaf923..ed1f5b8a9 100644 --- a/setup.py +++ b/setup.py @@ -17,6 +17,7 @@ from build_tools.te_version import te_version from build_tools.utils import ( cuda_archs, + cuda_version, get_frameworks, remove_dups, ) @@ -70,11 +71,11 @@ def setup_common_extension() -> CMakeExtension: if bool(int(os.getenv("NVTE_WITH_CUBLASMP", "0"))): cmake_flags.append("-DNVTE_WITH_CUBLASMP=ON") cublasmp_dir = os.getenv("CUBLASMP_HOME") or metadata.distribution( - "nvidia-cublasmp-cu12" - ).locate_file("nvidia/cublasmp/cu12") + f"nvidia-cublasmp-cu{cuda_version()[0]}" + ).locate_file(f"nvidia/cublasmp/cu{cuda_version()[0]}") cmake_flags.append(f"-DCUBLASMP_DIR={cublasmp_dir}") nvshmem_dir = os.getenv("NVSHMEM_HOME") or metadata.distribution( - "nvidia-nvshmem-cu12" + f"nvidia-nvshmem-cu{cuda_version()[0]}" ).locate_file("nvidia/nvshmem") cmake_flags.append(f"-DNVSHMEM_DIR={nvshmem_dir}") print("CMAKE_FLAGS:", cmake_flags[-2:]) diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index c2c9d0d91..412c5d34d 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -43,5 +43,6 @@ include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py index 7feb5fda5..dd1ec480b 100644 --- a/transformer_engine/common/__init__.py +++ b/transformer_engine/common/__init__.py @@ -218,6 +218,11 @@ def _nvidia_cudart_include_dir() -> str: except ModuleNotFoundError: return "" + # Installing some nvidia-* packages, like nvshmem, create nvidia name, so "import nvidia" + # above doesn't through. However, they don't set "__file__" attribute. 
+ if nvidia.__file__ is None: + return "" + include_dir = Path(nvidia.__file__).parent / "cuda_runtime" return str(include_dir) if include_dir.exists() else "" From c47f329b2084406093124851a3aeecb935183def Mon Sep 17 00:00:00 2001 From: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Date: Fri, 5 Sep 2025 09:56:15 -0700 Subject: [PATCH 127/153] [JAX] NoScaleTensor wrapper for non-quantized data (#2136) * Custom call tests passing Signed-off-by: Jeremy Berchtold * Fix test_layer.py Signed-off-by: Jeremy Berchtold * Lint Signed-off-by: Jeremy Berchtold * Fix comments Signed-off-by: Jeremy Berchtold * Support using amax on HighPrecision tensor if it exists instead of recomputing for current scaling Signed-off-by: Jeremy Berchtold * Fix shardy issue with amax being shape 1,1,1 instead of shape (1,) Signed-off-by: Jeremy Berchtold * Add higher-precision VJP tests to test_distributed_layernorm_mlp Signed-off-by: Jeremy Berchtold * Cast non-quantized kernels to input dtype in VJPs Signed-off-by: Jeremy Berchtold * Rename HighPrecisionTensor to NoScaleTensor Signed-off-by: Jeremy Berchtold * Use NoScaleTensor in pure JAX impls where it was missing Signed-off-by: Jeremy Berchtold * Fix tests Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 16 ++- tests/jax/test_distributed_layernorm_mlp.py | 17 ++- transformer_engine/jax/activation.py | 12 +- .../jax/cpp_extensions/activation.py | 49 +++---- transformer_engine/jax/cpp_extensions/gemm.py | 16 ++- .../jax/cpp_extensions/normalization.py | 34 ++--- .../jax/cpp_extensions/quantization.py | 70 +++++----- transformer_engine/jax/dense.py | 37 +++--- transformer_engine/jax/layernorm.py | 5 +- transformer_engine/jax/layernorm_dense.py | 12 +- transformer_engine/jax/layernorm_mlp.py | 23 +++- transformer_engine/jax/quantize/quantizer.py | 42 ++++-- .../jax/quantize/scaling_modes.py | 86 +++++++++++- transformer_engine/jax/quantize/tensor.py | 124 +++++++++++++----- 14 files changed, 359 insertions(+), 184 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index d5f21651d..11f07d913 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -31,6 +31,7 @@ from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version from transformer_engine.jax import cpp_extensions as tex from transformer_engine.jax.quantize import ( + NoScaleTensor, ScaledTensor, ScaledTensor1x, ScaledTensor2x, @@ -182,7 +183,7 @@ def assert_dequantized_grouped_scaled_tensor( class TestActivation: def ref_act(self, x, activation_type): - return _jax_act_lu(x, activation_type) + return _jax_act_lu(x, activation_type).data def value_n_grad_ref_func(self, x, activation_type): jitted_reference = jit( @@ -337,8 +338,8 @@ def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantize ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - # if isinstance(ln_out, ScaledTensor): - # ln_out = ln_out.dequantize() + # This is a no-op for non-quantized data + ln_out = ln_out.dequantize() return ln_out key = jax.random.PRNGKey(0) @@ -765,7 +766,9 @@ def _test_quantize_dact_dbias( te_output, jax_output, precise_comparison=precise_comparison ) else: - assert_allclose(te_output, jax_output) + assert isinstance(te_output, NoScaleTensor) + assert isinstance(jax_output, NoScaleTensor) + 
assert_allclose(te_output.data, jax_output.data) if is_dbias: # TE kernels cast the intermediate results to the input dtype which reduces precision compared to the JAX implementation, for dbias this typically only affects bfloat16. @@ -1020,8 +1023,7 @@ def _ref_jax_norm_impl(x, gamma, beta, norm_type, zero_centered_gamma, eps, quan ln_out, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer) else: ln_out, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer) - if isinstance(ln_out, ScaledTensor): - ln_out = ln_out.dequantize() + ln_out = ln_out.dequantize() return ln_out @@ -1177,7 +1179,7 @@ def _ref_func_impl(x, gamma, kernel_1, kernel_2, bias_1, bias_2): bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape linear_1_out += jnp.reshape(bias_1, bias_1_shape) - x = _jax_act_lu(linear_1_out, activation_type) + x = _jax_act_lu(linear_1_out, activation_type).data linear_2_out = jax.lax.dot_general(x, kernel_2, (((1,), (0,)), ((), ()))) if use_bias: bias_2_shape = (1,) * (linear_2_out.ndim - bias_2.ndim) + bias_2.shape diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 90b762c24..a44921c64 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -173,7 +173,9 @@ def _test_layernorm_mlp_grad( ) # Single GPU - with fp8_autocast(enabled=True, fp8_recipe=fp8_recipe, mesh_resource=MeshResource()): + with fp8_autocast( + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=MeshResource() + ): single_jitter = jax.jit( value_and_grad_func, static_argnums=range(len(inputs), len(static_inputs) + len(inputs)), @@ -184,7 +186,7 @@ def _test_layernorm_mlp_grad( devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) mesh = Mesh(devices, mesh_axes) with mesh, fp8_autocast( - enabled=True, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource + enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe, mesh_resource=mesh_resource ): k1_sharding = NamedSharding(mesh, PartitionSpec("fsdp", None, "tpsp")) k2_sharding = NamedSharding(mesh, PartitionSpec("tpsp", "fsdp")) @@ -226,7 +228,12 @@ def _test_layernorm_mlp_grad( fwd_test_type = dtype if fp8_recipe is None else jnp.float8_e4m3fn bwd_test_type = dtype if fp8_recipe is None else jnp.float8_e5m2 - assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + + if fwd_test_type == jnp.float16 and use_bias: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type, atol=0.04, rtol=1.5) + else: + assert_allclose(multi_fwd, single_fwd, dtype=fwd_test_type) + for i in range(len(inputs)): if multi_grads[i] is not None: if isinstance(multi_grads[i], list): @@ -252,7 +259,7 @@ def _test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) @pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad( self, @@ -281,7 +288,7 @@ def test_layernorm_mlp_grad( @pytest_parametrize_wrapper("activation_type", [("gelu",), ("gelu", "linear")]) @pytest_parametrize_wrapper("dtype", DTYPES) @pytest_parametrize_wrapper("use_bias", [True, False]) - @pytest_parametrize_wrapper("fp8_recipe", SUPPORTED_RECIPES) + @pytest_parametrize_wrapper("fp8_recipe", [None] + SUPPORTED_RECIPES) 
@pytest_parametrize_wrapper("with_jax_gemm", [False, True]) def test_layernorm_mlp_grad_shardy( self, diff --git a/transformer_engine/jax/activation.py b/transformer_engine/jax/activation.py index ef6def2d0..12b35ec43 100644 --- a/transformer_engine/jax/activation.py +++ b/transformer_engine/jax/activation.py @@ -14,7 +14,7 @@ from . import cpp_extensions as tex -from .quantize.tensor import ScaledTensor +from .quantize.tensor import NoScaleTensor from .quantize.quantizer import Quantizer @@ -22,7 +22,7 @@ def activation( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, -) -> Union[jnp.ndarray, ScaledTensor]: +) -> jnp.ndarray: """Apply activation functions to input tensor with optional quantization. This function applies a sequence of activation functions to the input tensor. @@ -72,8 +72,8 @@ def _activation_fwd_rule(x, activation_type, quantizer): Tuple of (output, context) for backward pass """ fwd_output = tex.act_lu(x, activation_type, quantizer) - if isinstance(fwd_output, ScaledTensor): - fwd_output = fwd_output.dequantize() + # This is a no-op for higher-precision tensors + fwd_output = fwd_output.dequantize() return fwd_output, (x, quantizer) @@ -91,6 +91,10 @@ def _activation_bwd_rule(activation_type, ctx, g): (x, _) = ctx assert x.dtype == g.dtype dx = tex.dact_lu(g, x, activation_type) + # No quantization is used in this VJP backward, so the output should + # always be a NoScaleTensor + assert isinstance(dx, NoScaleTensor) + dx = dx.data return (dx, None) diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index fe2253598..d3c7d2b08 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -29,7 +29,7 @@ ) from .quantization import _jax_dbias, _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -922,7 +922,7 @@ class DActLuQuantizePrimitive(BaseDActLuDBiasQuantizePrimitive): """Subclass of BaseDActLuDBiasQuantizePrimitive for fused activation quantization without dbias. 
No change in functionality from the base primitive but named differently for use in more granular disabling of primitives via NVTE_JAX_CUSTOM_CALLS.""" -def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, ScaledTensor]: +def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[NoScaleTensor, ScaledTensor]: """ JAX native activation implementation """ @@ -941,11 +941,11 @@ def _jax_act_lu(inputs, activation_type, quantizer=None) -> Union[jnp.ndarray, S x = jnp.squeeze(x, axis=-2) if quantizer: return quantizer.quantize(x, flatten_axis=-1) - return x + return NoScaleTensor(data=x, amax=None) def _jax_quantize_dact_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], is_dbias: bool = True, @@ -963,7 +963,9 @@ def _jax_quantize_dact_dbias( _, vjp_func = jax.vjp( partial(_jax_act_lu, activation_type=activation_type), x.astype(jnp.float32) ) - (dx,) = vjp_func(dz.astype(jnp.float32)) + # VJP is using non-quantized backward for dact, so the input should always be wrapped in NoScaleTensor regardless of whether the forward pass used quantization or this dact will quantize afterwards. + dz = NoScaleTensor(data=dz.astype(jnp.float32), amax=None) + (dx,) = vjp_func(dz) dbias = None if is_dbias: @@ -973,6 +975,7 @@ def _jax_quantize_dact_dbias( dx = quantizer.quantize(dx, dq_dtype=x.dtype, flatten_axis=-2) else: dx = dx.astype(x.dtype) + dx = NoScaleTensor(data=dx, amax=None) return dx, dbias @@ -981,7 +984,6 @@ def act_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """Activation with optional quantization. @@ -990,7 +992,6 @@ def act_lu( Shape: (..., ACT_DIM, K) where ACT_DIM is 1 for non-gated activations and 2 for gated activations activation_type: Type of activation function to apply. quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: If quantizer is None: @@ -1035,10 +1036,10 @@ def act_lu( is_outer=True, ) out = out.reshape(output_shape) - if noop_scaled_tensor: - return ScaledTensorFactory.create_2x( - out, None, out, None, scaling_mode=ScalingMode.NO_SCALING, dq_dtype=out.dtype - ) + out = NoScaleTensor( + data=out, + amax=None, + ) return out if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1092,7 +1093,6 @@ def quantize_dact_dbias( activation_type: Sequence[Union[str, Callable]] = ("gelu",), is_dbias: bool = True, quantizer: Optional[Quantizer] = None, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor, jnp.ndarray]: """Compute gradients of activation and bias with optional quantization. @@ -1103,7 +1103,6 @@ def quantize_dact_dbias( activation_type: Type of activation function used in the forward pass. Defaults to ("gelu",). is_dbias: If True, compute bias gradient. Defaults to True. quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. 
Returns: Tuple[ScaledTensor, jnp.ndarray]: A tuple containing: @@ -1146,19 +1145,10 @@ def quantize_dact_dbias( if is_dbias: dbias = _jax_dbias(output, dtype=x.dtype, flatten_axis=-2) - if noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, - None, - output, - None, - ScalingMode.NO_SCALING, - dq_dtype=output.dtype, - ), - dbias, - ) - + output = NoScaleTensor( + data=output, + amax=None, + ) return output, dbias # TE/common does not support 1x dact_dbias_quantize on arch < 100 yet @@ -1167,7 +1157,7 @@ def quantize_dact_dbias( dz.astype(jnp.float32), x.astype(jnp.float32), activation_type, quantizer=None ) return _quantize_dbias_impl( - out, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 + out.data, quantizer, is_dbias=True, dq_dtype=x.dtype, flatten_axis=-2 ) is_gated = act_len == 2 @@ -1194,7 +1184,7 @@ def quantize_dact_dbias( quantizer=None, ) out, dbias = _quantize_dbias_impl( - out, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 + out.data, is_dbias=is_dbias, quantizer=quantizer, dq_dtype=x.dtype, flatten_axis=-2 ) return out, dbias @@ -1258,7 +1248,6 @@ def dact_lu( x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]], quantizer: Optional[Quantizer] = None, - noop_scale_tensor: bool = False, ) -> Union[jnp.ndarray, ScaledTensor]: """ Backward pass for activation with optional quantization. @@ -1268,7 +1257,6 @@ def dact_lu( x: Input tensor that was used in forward pass. activation_type: Type of activation function that was applied. quantizer: Optional quantizer for FP8 quantization of the output gradient. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: The gradient of the activation with respect to the input. @@ -1279,6 +1267,5 @@ def dact_lu( activation_type=activation_type, is_dbias=False, quantizer=quantizer, - noop_scaled_tensor=noop_scale_tensor, ) return output diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index be73f708e..acc8d6727 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -22,6 +22,8 @@ from .base import BasePrimitive, register_primitive from .quantization import grouped_quantize from ..quantize import ( + AbstractBaseTensor, + NoScaleTensor, ScaledTensor, ScaledTensor2x, GroupedScaledTensor1x, @@ -228,6 +230,11 @@ def _dims_are_consecutive(dims): "require non-transposed LHS and transposed RHS operands " "(`contracting_dims=((-1, ), (-1, ))`)." ) + else: + assert lhs.dtype == rhs.dtype, ( + "For TE cuBLAS GEMM for non-quantized inputs, the operand dtypes must be equal." + f" LHS dtype != RHS dtype, lhs.dtype={lhs.dtype}, rhs.dtype={rhs.dtype}" + ) # Determine output shape and dtype assert ( @@ -1134,8 +1141,8 @@ def _jax_gemm_fp8_impl(lhs, rhs): def gemm( - lhs: Union[jnp.ndarray, ScaledTensor], - rhs: Union[jnp.ndarray, ScaledTensor], + lhs: Union[jnp.ndarray, AbstractBaseTensor], + rhs: Union[jnp.ndarray, AbstractBaseTensor], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((-1,), (0,)), lhs_quantizer: Quantizer = None, rhs_quantizer: Quantizer = None, @@ -1191,6 +1198,11 @@ def gemm( compute the GeLU contribution to the gradient. Only supported with TE's custom call to cuBLAS GEMM. 
""" + if isinstance(lhs, NoScaleTensor): + lhs = lhs.data + if isinstance(rhs, NoScaleTensor): + rhs = rhs.data + # Try to get LHS and RHS quantizers from a quantizer set for backward compatibility if lhs_quantizer is None or rhs_quantizer is None: quantizer_set = kwargs.get("quantizer_set", None) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index 7296afc72..de1877de5 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -30,7 +30,7 @@ ) from .quantization import _quantize_dbias_impl from ..sharding import all_reduce_max_along_all_axes_except_PP, all_reduce_sum_along_dp_fsdp -from ..quantize import ScaledTensor, ScaledTensorFactory +from ..quantize import ScaledTensor, ScaledTensorFactory, NoScaleTensor from ..quantize import ( Quantizer, QuantizeLayout, @@ -845,6 +845,7 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(mean, axis=-1), jnp.squeeze(rsigma, axis=-1) @@ -869,6 +870,7 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) + ln_out = NoScaleTensor(data=ln_out, amax=None) return ln_out, jnp.squeeze(rsigma, axis=-1) @@ -930,7 +932,7 @@ def layernorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, mu, rsigma + return NoScaleTensor(data=output, amax=None), mu, rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING @@ -1064,7 +1066,7 @@ def layernorm_bwd( ) mu_empty = jnp.zeros(mu.shape, mu.dtype) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, mu_empty, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, amax=None), mu_empty, rsigma_empty)) return NormBwdPrimitive.outer_primitive.bind( dz, x, @@ -1133,14 +1135,14 @@ def rmsnorm_fwd( scale_dtype=jnp.float32, is_outer=True, ) - return output, rsigma + return NoScaleTensor(data=output, amax=None), rsigma if ( quantizer.scaling_mode == ScalingMode.MXFP8_1D_SCALING and get_cudnn_version() < FUSED_MXFP8_NORM_CUDNN_MIN_VERSION ): out, rsigma = rmsnorm_fwd(x, gamma, zero_centered_gamma, epsilon, quantizer=None) - out, _ = _quantize_dbias_impl(out, quantizer) + out, _ = _quantize_dbias_impl(out.data, quantizer) return out, rsigma if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: @@ -1152,7 +1154,9 @@ def rmsnorm_fwd( epsilon=epsilon, quantizer=None, ) - out, _ = _quantize_dbias_impl(out, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype) + out, _ = _quantize_dbias_impl( + out.data, is_dbias=False, quantizer=quantizer, dq_dtype=x.dtype + ) return out, rsigma is_2x2x = quantizer.is_2x2x() @@ -1254,7 +1258,7 @@ def rmsnorm_bwd( gamma, ) rsigma_empty = jnp.zeros(rsigma.shape, rsigma.dtype) - return vjp_func((dz, rsigma_empty)) + return vjp_func((NoScaleTensor(data=dz, amax=None), rsigma_empty)) mu = jnp.empty(()) dx, dgamma, _ = NormBwdPrimitive.outer_primitive.bind( dz, @@ -1276,7 +1280,6 @@ def normalization_fwd( epsilon: float, norm_type: str, quantizer: Optional[Quantizer], - noop_scaled_tensor: bool = False, ): """Common wrapper for normalization forward pass. 
@@ -1293,7 +1296,6 @@ def normalization_fwd( - 'layernorm': Layer normalization - 'rmsnorm': Root mean square normalization quantizer: Optional quantizer for FP8 quantization of the output. - noop_scaled_tensor: Wrap the unquantized output as a ScaledTensor2x when quantizer is None. Returns: A tuple containing: @@ -1321,20 +1323,6 @@ def normalization_fwd( else: raise ValueError(f"{norm_type=} is not supported.") - if quantizer is None and noop_scaled_tensor: - return ( - ScaledTensorFactory.create_2x( - output, - None, - output, - None, - scaling_mode=ScalingMode.NO_SCALING, - dq_dtype=output.dtype, - ), - mu, - rsigma, - ) - return output, mu, rsigma diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 198beb55e..1813734b5 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -4,7 +4,7 @@ """JAX/TE custom ops for quantization""" import operator from functools import reduce -from typing import Tuple, Optional +from typing import Tuple, Optional, Union import math from packaging import version @@ -38,6 +38,7 @@ QuantizeLayout, ScalingMode, compute_scale_from_amax, + NoScaleTensor, ) if version.parse(jax.__version__) >= version.parse("0.5.0"): @@ -64,7 +65,7 @@ class BaseDBiasQuantizePrimitive(BasePrimitive): 7, 8, 9, - ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer, amax_aval + ) # out_dtype, scaling_mode, q_layout, flatten_axis, scale_dtype, is_dbias, is_outer inner_primitive = None outer_primitive = None @@ -535,11 +536,15 @@ def _jax_quantize( x, quantizer: Quantizer = None, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1 ): if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) return quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis) -def _jax_dbias(dx: jnp.ndarray, dtype=None, flatten_axis: int = -1): +def _jax_dbias(dx: Union[jnp.ndarray, NoScaleTensor], dtype=None, flatten_axis: int = -1): + if isinstance(dx, NoScaleTensor): + dx = dx.data sum_axis = dx.ndim + flatten_axis if flatten_axis < 0 else flatten_axis assert sum_axis < dx.ndim, "Flatten axis out of bounds!" 
dtype = dtype or dx.dtype @@ -558,7 +563,9 @@ def _jax_quantize_dbias( flatten_axis: int = -1, ): if quantizer is None: - return x, None + if isinstance(x, NoScaleTensor): + return x, None + return NoScaleTensor(data=x, amax=None), None return ( quantizer.quantize(x, dq_dtype=dq_dtype, flatten_axis=flatten_axis), _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis), @@ -566,12 +573,11 @@ def _jax_quantize_dbias( def _quantize_dbias_impl( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = False, dq_dtype: Optional[jnp.dtype] = None, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """ Cast wrapper @@ -581,28 +587,15 @@ def _quantize_dbias_impl( quantizer is not None ), "quantizer must be provided if dq_dtype is provided" + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + # Early-exit for non-quantized call - dq_dtype = dq_dtype or x.dtype + dq_dtype = dq_dtype or x.data.dtype if quantizer is None: dbias = None if is_dbias: - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) - if noop_scaled_tensor: - # Return a dummy ScaledTensor2x to ensure .get_rowwise_tensor() and .get_colwise_tensor() - # always works. - return ( - ScaledTensorFactory.create_2x( - x, - None, - x, - None, - scaling_mode=ScalingMode.NO_SCALING, - dq_dtype=x.dtype, - data_layout="NN", - flatten_axis=flatten_axis, - ), - dbias, - ) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return x, dbias # If TE/common custom quantize op is disabled, or if quantizer layout is COLWISE, @@ -630,21 +623,25 @@ def _quantize_dbias_impl( dq_dtype=dq_dtype, flatten_axis=flatten_axis, ) - dbias = _jax_dbias(x, dtype=dq_dtype, flatten_axis=flatten_axis) + dbias = _jax_dbias(x.data, dtype=dq_dtype, flatten_axis=flatten_axis) return out, dbias scale = jnp.empty((), jnp.float32) + amax = None if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Globally reduce amax across all devices for current scaling so we have a single global scale. # This differs from the PyTorch implementation which uses a local amax and scale per-device and persists this # until the tensor is dequantized (e.g. in the GEMM). - amax = jnp.amax(jnp.abs(x), keepdims=True).astype(jnp.float32) + amax = x.amax + if amax is None: + amax = jnp.amax(jnp.abs(x.data), keepdims=True).astype(jnp.float32).reshape((1,)) scale = compute_scale_from_amax(amax, quantizer.q_dtype) elif quantizer.scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING: scale = quantizer.scale # Make sure amax is init with zero - amax = jnp.zeros((1,), jnp.float32) + if amax is None: + amax = jnp.zeros((1,), jnp.float32) # It is faster to use 1x quantization for tensor scaling is_1x_kernel_supported = not (is_dbias and get_min_device_compute_capability() < 100) @@ -665,7 +662,7 @@ def _quantize_dbias_impl( updated_amax, dbias, ) = PrimitiveClass.outer_primitive.bind( - x, + x.data, scale, amax, out_dtype=quantizer.q_dtype, @@ -706,10 +703,9 @@ def _quantize_dbias_impl( def quantize( - x: jnp.ndarray, + x: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor]: """Quantize input tensor according to the quantizer. @@ -719,7 +715,6 @@ def quantize( quantizer: Quantizer for FP8 quantization of the output. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. 
- noop_scaled_tensor: If True, wraps the output into a dummy ScaledTensor2x when quantizer is None. Returns: @@ -729,17 +724,15 @@ def quantize( x, quantizer=quantizer, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) return out def quantize_dbias( - dz: jnp.ndarray, + dz: Union[jnp.ndarray, NoScaleTensor], quantizer: Quantizer, is_dbias: bool = True, flatten_axis: int = -1, - noop_scaled_tensor: bool = False, ) -> Tuple[ScaledTensor2x, jnp.ndarray]: """Quantize input tensor and compute bias gradient. @@ -750,8 +743,6 @@ def quantize_dbias( is_dbias: If True, compute bias gradient. Defaults to True. flatten_axis: The quantization axis in which input data can be flattened to 2D for quantization. Defaults to -1. - noop_scaled_tensor: If True, wraps the unquantized output into a dummy ScaledTensor2x when - quantizer is None. Returns: A tuple containing: @@ -765,7 +756,6 @@ def quantize_dbias( quantizer=quantizer, is_dbias=is_dbias, flatten_axis=flatten_axis, - noop_scaled_tensor=noop_scaled_tensor, ) @@ -968,7 +958,9 @@ def grouped_quantize( """ if quantizer is None: - return x + if isinstance(x, NoScaleTensor): + return x + return NoScaleTensor(data=x, amax=None) # TODO(Phuong): add support for flatten_axis = -2 assert flatten_axis in ( diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 65d65e7d4..b0ba734e5 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -24,6 +24,7 @@ with_sharding_constraint_by_logical_axes, is_fp8_gemm_with_all_layouts_supported, TensorUsage, + get_quantize_config, ) @@ -80,23 +81,19 @@ def dense( Returns: Transformed output tensor """ - # Remove when tex.quantize() can handle quantizer=None - if quantizer_set == noop_quantizer_set and tex.gemm_uses_jax_dot(): - x = with_sharding_constraint_by_logical_axes(x, input_axes) - output = tex.gemm(x, kernel, contracting_dims=contracting_dims) - if bias is not None: - bias_new_shape = (1,) * (output.ndim - bias.ndim) + bias.shape - output += jnp.reshape(bias, bias_new_shape) - else: - output = _dense( - x, - kernel, - bias, - contracting_dims, - input_axes, - kernel_axes, - quantizer_set, - ) + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + + output = _dense( + x, + kernel, + bias, + contracting_dims, + input_axes, + kernel_axes, + quantizer_set, + ) return output @@ -175,7 +172,9 @@ def _dense_fwd_rule( flatten_axis_k = len(k_contracting_dims) - len(kernel.shape) casted_x = tex.quantize( - x, flatten_axis=flatten_axis_x, quantizer=quantizer_set.x, noop_scaled_tensor=True + x, + flatten_axis=flatten_axis_x, + quantizer=quantizer_set.x, ) casted_x = with_sharding_constraint_by_logical_axes(casted_x, input_axes) @@ -183,7 +182,6 @@ def _dense_fwd_rule( kernel, flatten_axis=flatten_axis_k, quantizer=quantizer_set.kernel, - noop_scaled_tensor=True, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -240,7 +238,6 @@ def _dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis_k, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # GEMM NT diff --git a/transformer_engine/jax/layernorm.py b/transformer_engine/jax/layernorm.py index 7a3ad597b..0f5c6aeef 100644 --- a/transformer_engine/jax/layernorm.py +++ b/transformer_engine/jax/layernorm.py @@ -17,7 +17,6 @@ from . 
import cpp_extensions as tex from .quantize import ( - ScaledTensor, Quantizer, ) @@ -112,8 +111,8 @@ def _layernorm_fwd_rule(x, gamma, beta, norm_type: str, zero_centered_gamma, eps output, mu, rsigma = tex.normalization_fwd( x, gamma, beta, zero_centered_gamma, epsilon, norm_type, quantizer ) - if isinstance(output, ScaledTensor): - output = output.dequantize() + # This is a no-op for higher-precision tensors + output = output.dequantize() return output, (x, mu, rsigma, gamma, beta, quantizer) diff --git a/transformer_engine/jax/layernorm_dense.py b/transformer_engine/jax/layernorm_dense.py index b830cdb4f..fb9783075 100644 --- a/transformer_engine/jax/layernorm_dense.py +++ b/transformer_engine/jax/layernorm_dense.py @@ -22,6 +22,7 @@ noop_quantizer_set, with_sharding_constraint_by_logical_axes, TensorUsage, + get_quantize_config, ) @@ -68,6 +69,11 @@ def layernorm_dense( - The function supports automatic differentiation through JAX's custom VJP - Quantization is applied to both the normalized input and kernel """ + + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel = kernel.astype(input_dtype) + output = _layernorm_dense( x, kernel, @@ -188,14 +194,15 @@ def _layernorm_dense_fwd_rule( epsilon, norm_type, quantizer=quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_input_axes) # Kernel in (hidden_in, hidden_out...) flatten_axis = 1 - len(kernel.shape) casted_kernel = tex.quantize( - kernel, flatten_axis=flatten_axis, quantizer=quantizer_set.kernel, noop_scaled_tensor=True + kernel, + flatten_axis=flatten_axis, + quantizer=quantizer_set.kernel, ) casted_kernel = with_sharding_constraint_by_logical_axes(casted_kernel, kernel_axes) @@ -278,7 +285,6 @@ def _layernorm_dense_bwd_rule( is_dbias=use_bias, flatten_axis=flatten_axis, quantizer=quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel.ndim diff --git a/transformer_engine/jax/layernorm_mlp.py b/transformer_engine/jax/layernorm_mlp.py index 00e3ddc3e..fc957801a 100644 --- a/transformer_engine/jax/layernorm_mlp.py +++ b/transformer_engine/jax/layernorm_mlp.py @@ -27,6 +27,7 @@ QuantizerSet, noop_quantizer_set, TensorUsage, + get_quantize_config, ) @@ -104,6 +105,11 @@ def layernorm_mlp( not zero_centered_gamma ), "zero_centered_gamma is not supported if norm_type is 'rmsnorm'" + if not get_quantize_config().is_fp8_enabled(): + input_dtype = x.dtype + kernel_1 = kernel_1.astype(input_dtype) + kernel_2 = kernel_2.astype(input_dtype) + output = _layernorm_mlp( x, gamma, @@ -266,12 +272,13 @@ def _layernorm_mlp_fwd_rule( epsilon, norm_type, quantizer=ffn1_quantizer_set.x, - noop_scaled_tensor=True, ) casted_ln_out = with_sharding_constraint_by_logical_axes(casted_ln_out, dot_1_input_axes) casted_kernel_1 = tex.quantize( - kernel_1, flatten_axis=-2, quantizer=ffn1_quantizer_set.kernel, noop_scaled_tensor=True + kernel_1, + flatten_axis=-2, + quantizer=ffn1_quantizer_set.kernel, ) # NN GEMM @@ -300,13 +307,16 @@ def _layernorm_mlp_fwd_rule( # (batch..., hidden_in) -> (batch..., hidden) casted_act_out = tex.act_lu( - dot_1_output, activation_type, quantizer=ffn2_quantizer_set.x, noop_scaled_tensor=True + dot_1_output, + activation_type, + quantizer=ffn2_quantizer_set.x, ) casted_act_out = with_sharding_constraint_by_logical_axes(casted_act_out, dot_2_input_axes) casted_kernel_2 = tex.quantize( - kernel_2, quantizer=ffn2_quantizer_set.kernel, noop_scaled_tensor=True 
+ kernel_2, + quantizer=ffn2_quantizer_set.kernel, ) # NN GEMM @@ -404,7 +414,9 @@ def _layernorm_mlp_bwd_rule( grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) casted_grad, dbias_2 = tex.quantize_dbias( - grad, is_dbias=use_bias_2, quantizer=ffn1_quantizer_set.dgrad, noop_scaled_tensor=True + grad, + is_dbias=use_bias_2, + quantizer=ffn1_quantizer_set.dgrad, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim @@ -445,7 +457,6 @@ def _layernorm_mlp_bwd_rule( activation_type=activation_type, is_dbias=use_bias_1, quantizer=ffn2_quantizer_set.dgrad, - noop_scaled_tensor=True, ) # k_non_contracting_dims calibrated with the shape difference of grad.ndim vs kernel_1.ndim diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index 6cecfa361..306603bbe 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -19,7 +19,13 @@ from transformer_engine.common import recipe from .scaling_modes import ScalingMode -from .tensor import ScaledTensor, ScaledTensor1x, ScaledTensor2x, ScaledTensorFactory +from .tensor import ( + ScaledTensor, + ScaledTensor1x, + ScaledTensor2x, + ScaledTensorFactory, + NoScaleTensor, +) from .helper import ( get_quantize_config, get_quantize_config_class, @@ -217,7 +223,11 @@ class CurrentScaleQuantizer(Quantizer): data_layout: str = "NT" def _quantize_func( - self, x: jnp.ndarray, is_colwise=False, dq_dtype=None, flatten_axis=-1 + self, + x: Union[jnp.ndarray, NoScaleTensor], + is_colwise=False, + dq_dtype=None, + flatten_axis=-1, ) -> ScaledTensor1x: """Quantize function helper for delayed scaling FP8. @@ -229,14 +239,17 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - amax = jnp.max(jnp.abs(x)).reshape((1,)) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) fp8_max = jnp.astype(jnp.finfo(self.q_dtype).max, jnp.float32) scale = (fp8_max / amax) / (2 ** get_quantize_config().MARGIN) - scaled_x = x.astype(compute_dtype) * scale + scaled_x = x.data.astype(compute_dtype) * scale clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / scale @@ -263,7 +276,10 @@ def quantize( Returns: A ScaledTensor1x or ScaledTensor2x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype if flatten_axis < 0: flatten_axis += x.ndim assert 0 < flatten_axis < x.ndim, "flatten_axis is out of bounds!" 
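
For readers following the NoScaleTensor and current-scaling changes above, the snippet below is a minimal, self-contained sketch of the arithmetic that `CurrentScaleQuantizer._quantize_func` performs: derive a scale from the tensor's current amax, scale and clip the data into the FP8-representable range, and keep the reciprocal scale for dequantization. It is illustrative only and not part of any patch in this series; the `q_dtype` and `margin` defaults are assumptions standing in for `self.q_dtype` and `get_quantize_config().MARGIN`.

    # Illustrative sketch only (not part of the patches); assumes a recent JAX with float8 dtypes.
    import jax.numpy as jnp

    def current_scaling_quantize(x: jnp.ndarray, q_dtype=jnp.float8_e4m3fn, margin: int = 0):
        # Scale is derived from the tensor's current absolute maximum (its "amax").
        amax = jnp.max(jnp.abs(x)).astype(jnp.float32).reshape((1,))
        fp8_max = jnp.asarray(jnp.finfo(q_dtype).max, dtype=jnp.float32)
        scale = (fp8_max / amax) / (2.0**margin)
        # Quantize: scale, clip into the representable range, cast to FP8.
        q = jnp.clip(x.astype(jnp.float32) * scale, -fp8_max, fp8_max).astype(q_dtype)
        scale_inv = 1.0 / scale  # kept alongside q for later dequantization
        return q, scale_inv

    # Dequantization then recovers an approximation of x: q.astype(jnp.float32) * scale_inv
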
@@ -347,11 +363,14 @@ def _quantize_func( Returns: A ScaledTensor1x containing the quantized data """ - dq_dtype = dq_dtype if dq_dtype is not None else x.dtype + if isinstance(x, jnp.ndarray): + x = NoScaleTensor(data=x, amax=None) + + dq_dtype = dq_dtype if dq_dtype is not None else x.data.dtype compute_dtype = jnp.float32 dtype_max = (jnp.finfo(self.q_dtype).max).astype(compute_dtype) - scaled_x = x.astype(compute_dtype) * self.scale + scaled_x = x.data.astype(compute_dtype) * self.scale # quantize() in the old dot.py do this way, leave this code block here for future debugging # compute_dtype = x.dtype @@ -360,7 +379,8 @@ def _quantize_func( clipped_scaled_x = jnp.clip(scaled_x, -dtype_max, dtype_max).astype(self.q_dtype) scale_inv = 1.0 / self.scale - self.update(jnp.max(jnp.abs(x)).reshape((1,))) + amax = x.amax or jnp.max(jnp.abs(x.data)).reshape((1,)) + self.update(amax) return ScaledTensorFactory.create_1x( data=clipped_scaled_x, scale_inv=scale_inv, @@ -460,6 +480,10 @@ def _quantize_func(self, x, is_colwise=False, dq_dtype=None, flatten_axis=-1) -> Returns: A ScaledTensor1x containing the quantized data """ + if isinstance(x, NoScaleTensor): + # No need for amax in MXFP8 block scaling, so simply extract the jnp.ndarray data tensor from the NoScaleTensor x. + x = x.data + # TODO(Phuong): use quantize_func from JAX if flatten_axis < 0: flatten_axis = x.ndim + flatten_axis diff --git a/transformer_engine/jax/quantize/scaling_modes.py b/transformer_engine/jax/quantize/scaling_modes.py index 868570f73..e81a614f0 100644 --- a/transformer_engine/jax/quantize/scaling_modes.py +++ b/transformer_engine/jax/quantize/scaling_modes.py @@ -166,6 +166,90 @@ def get_shardy_sharding_rules( """ +class NoScalingModeMetadataImpl(ScalingModeMetadataImpl): + """Implementation for no scaling mode. + + This implementation provides metadata for no scaling mode, for using non-quantized higher-precision datatypes such as bf16. + """ + + def get_scale_dtype(self) -> jnp.dtype: + """Get the data type for scale tensors. This is a placeholder and won't be used for higher-precision values that don't have scaling. + + Returns: + The data type used for scale tensors (float32) + """ + return jnp.float32 + + def get_scale_shape( + self, + data_shape: Tuple[int, ...], + is_colwise: bool = False, + is_padded: bool = True, + flatten_axis: int = -1, + ) -> Tuple[int, ...]: + """Get the shape for scale tensors. This always returns an empty shape because this mode applies no scaling. + + Args: + data_shape: The shape of the tensor being scaled + is_colwise: Whether the scaling is column-wise + is_padded: Whether to return padded shape + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors - (1,) + """ + del data_shape, is_colwise, is_padded, flatten_axis + return (0,) + + @lru_cache(maxsize=4) + def get_quantize_layout(self, usage: TensorUsage) -> QuantizeLayout: + """Get the quantize layout for the tensor usage. + + Args: + usage: The usage of the tensor + + Returns: + The quantize layout for the tensor usage + """ + return QuantizeLayout.ROWWISE + + def get_grouped_scale_shape( + self, data_shape, n_groups, group_axis, is_colwise, is_padded=True, flatten_axis=-1 + ) -> Tuple[int]: + """Get the shape for scale tensors in this mode. 
+ + Args: + data_shape: Original shape of the data tensor + is_colwise: Whether to use column-wise scaling + is_padded: Whether to use padded shapes + flatten_axis: Axis along which data can be flattened to 2D for quantization. Defaults to -1. + + Returns: + The shape for scale tensors + """ + del data_shape, group_axis, is_colwise + assert isinstance(n_groups, int) + return (n_groups,) + + def get_shardy_sharding_rules( + self, input_rank, unique_var, flatten_axis + ) -> QuantizeShardyRules: + """Sharding rules for the input and (row, col)wise scale tensors. + + Args: + input_rank: The rank of the input tensor (for which we produce the scale tensor) + unique_var: An otherwise unused Shardy variable name prefix + flatten_axis: Axis along which data can be flattened to 2D for quantization. + + Returns: + The Shardy rules for the scaling mode + """ + del flatten_axis + input_spec = tuple(f"{unique_var}{i}" for i in range(input_rank)) + scale_var = BATCHING + unique_var + "_scale_inv" + return QuantizeShardyRules(input_spec, (scale_var,), (scale_var,), {}) + + class CurrentScalingModeMetadataImpl(ScalingModeMetadataImpl): """Implementation for current scaling mode. @@ -740,5 +824,5 @@ def tree_unflatten(cls, aux_data, _children): ScalingMode.MXFP8_1D_SCALING: BlockScalingModeMetadataImpl(block_dims=(1, 32)), # WAR ScalingMode.CURRENT_TENSOR_SCALING: CurrentScalingModeMetadataImpl(), - ScalingMode.NO_SCALING: DelayedScalingModeMetadataImpl(), + ScalingMode.NO_SCALING: NoScalingModeMetadataImpl(), } diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 1459175b7..dbbac4abc 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -25,6 +25,8 @@ __all__ = [ "TensorUsage", + "AbstractBaseTensor", + "NoScaleTensor", "ScaledTensor", "ScaledTensor1x", "ScaledTensor2x", @@ -34,14 +36,9 @@ ] -@register_pytree_node_class @dataclass -class ScaledTensor(ABC): - """Abstract base class for scaled tensors. - - This class defines the interface for all scaled tensor implementations, - providing methods for dequantization and accessing row/column-wise components. - """ +class AbstractBaseTensor(ABC): + """Abstract base class for all tensor types.""" @classmethod def tree_unflatten(cls, aux_data, children): @@ -93,9 +90,76 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st """ +@dataclass +class AbstractBaseTensor1x(AbstractBaseTensor): + """Abstract base class for single layout tensors.""" + + data: jnp.ndarray + amax: jnp.ndarray + + @register_pytree_node_class @dataclass -class ScaledTensor1x(ScaledTensor): +class NoScaleTensor(AbstractBaseTensor1x): + """Higher-precision tensor.""" + + def __post_init__(self): + assert isinstance(self.data, jnp.ndarray), "NoScaleTensor's data must be a jnp.ndarray." + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations. 
+ + Returns: + A tuple containing (children, aux_data) for tree operations + """ + children = (self.data, self.amax) + aux_data = () + return (children, aux_data) + + @property + def ndim(self): + """Number of dimensions of the underlying array.""" + return self.data.ndim + + def dequantize(self): + """This is a no-op for a higher-precision tensor so this simply returns the tensor's data.""" + return self.data + + def get_tensor(self, usage: TensorUsage): + """Returns the tensor based on the tensor usage.""" + q_layout = ScalingMode.NO_SCALING.get_quantize_layout(usage) + assert ( + q_layout == QuantizeLayout.ROWWISE + ), "Only ROWWISE layout is supported for NoScaleTensor" + return self + + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): + """Applies sharding constraints to a tensor based on logical axis names. + + Args: + logical_axis_names: Tuple of logical axis names for sharding + + Returns: + The tensor with applied sharding constraints + """ + if not logical_axis_names: + return self + + data = with_sharding_constraint_by_logical_axes(self.data, logical_axis_names) + + return NoScaleTensor( + data=data, + amax=self.amax, + ) + + +class ScaledTensor(ABC): + """Abstract base class for scaled tensors.""" + + +@register_pytree_node_class +@dataclass +class ScaledTensor1x(AbstractBaseTensor1x, ScaledTensor): """Single-scale quantized tensor implementation. This class represents a tensor quantized with a single scaling factor, @@ -113,9 +177,7 @@ class ScaledTensor1x(ScaledTensor): flatten_axis: The quantization axis for the tensor """ - data: jnp.ndarray scale_inv: jnp.ndarray - amax: jnp.ndarray scaling_mode: ScalingMode dq_dtype: jnp.dtype _dq_func: Callable @@ -154,7 +216,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.amax) + children = (self.data, self.amax, self.scale_inv) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -274,15 +336,15 @@ def __init__( self.original_shape = original_shape self.group_axis = group_axis super().__init__( - data, - scale_inv, - amax, - scaling_mode, - dq_dtype, - _dq_func, - is_colwise, - data_layout, - flatten_axis, + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=_dq_func, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) def __post_init__(self): @@ -339,7 +401,7 @@ def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[st @register_pytree_node_class @dataclass -class ScaledTensor2x(ScaledTensor): +class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): """Double-scale quantized tensor implementation. This class represents a tensor quantized with both row-wise and column-wise scaling factors. @@ -503,15 +565,15 @@ def create_1x( flatten_axis = data.ndim - flatten_axis return ScaledTensor1x( - data, - scale_inv, - amax, - scaling_mode, - dq_dtype, - dequantizer.dequantize, - is_colwise, - data_layout, - flatten_axis, + data=data, + scale_inv=scale_inv, + amax=amax, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=dequantizer.dequantize, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, ) @staticmethod @@ -675,7 +737,7 @@ def with_sharding_constraint_by_logical_axes(x, logical_axis_names: Tuple[str, . 
if isinstance(x, GroupedScaledTensor1x): raise NotImplementedError - if isinstance(x, ScaledTensor): + if isinstance(x, AbstractBaseTensor): return x.apply_sharding_constraint_by_logical_axes(logical_axis_names) return original_with_sharding_constraint_by_logical_axes(x, logical_axis_names) From 5b3d65cc1a157ac76d0e4c6342db0c5d80f69984 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Mon, 8 Sep 2025 11:26:52 -0400 Subject: [PATCH 128/153] [JAX] Fix GroupedScaledTensor creation with keyword arg (#2154) Fix GroupedScaledTensor creation Signed-off-by: Phuong Nguyen --- transformer_engine/jax/dense.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index b0ba734e5..8087159a3 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -442,7 +442,7 @@ def _grouped_dense_fwd_rule( ctx_kernel = ScaledTensorFactory.create_1x( global_ctx_kernel_data.reshape(-1), ctx_kernel.scale_inv, - ctx_kernel.scaling_mode, + scaling_mode=ctx_kernel.scaling_mode, dq_dtype=ctx_kernel.dq_dtype, is_colwise=False, data_layout="N", @@ -459,7 +459,7 @@ def _grouped_dense_fwd_rule( grouped_gemm_kernel = ScaledTensorFactory.create_1x( grouped_gemm_kernel_data.reshape(-1), ctx_kernel.scale_inv, - ctx_kernel.scaling_mode, + scaling_mode=ctx_kernel.scaling_mode, dq_dtype=ctx_kernel.dq_dtype, is_colwise=True, data_layout="T", From aa06107cbc1cc7378c665809c1608c53070447ea Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Mon, 8 Sep 2025 11:28:13 -0400 Subject: [PATCH 129/153] Fixing few issues with multi-process launching. (#2155) * Fixing few issues with multi-process launching. Signed-off-by: Ming Huang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Ming Huang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Phuong Nguyen --- tests/jax/multi_process_launch.sh | 6 +++--- ...est_multi_process_distributed_grouped_gemm.py | 16 ++++++++++++---- 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/tests/jax/multi_process_launch.sh b/tests/jax/multi_process_launch.sh index 3e0852f39..fcb066de7 100644 --- a/tests/jax/multi_process_launch.sh +++ b/tests/jax/multi_process_launch.sh @@ -12,12 +12,12 @@ XLA_BASE_FLAGS="--xla_gpu_enable_latency_hiding_scheduler=true export XLA_FLAGS="${XLA_BASE_FLAGS}" -NUM_RUNS=$(nvidia-smi --query-gpu=count --format=csv,noheader) +NUM_RUNS=$(nvidia-smi -L | wc -l) for ((i=1; i /dev/null 2>&1 & + CUDA_VISIBLE_DEVICES=$i python $SCRIPT_NAME 127.0.0.1:12345 $i $NUM_RUNS > /dev/null 2>&1 & done -CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_PROC +CUDA_VISIBLE_DEVICES=0 python $SCRIPT_NAME 127.0.0.1:12345 0 $NUM_RUNS wait diff --git a/tests/jax/test_multi_process_distributed_grouped_gemm.py b/tests/jax/test_multi_process_distributed_grouped_gemm.py index 6fce62d8c..31209d1bc 100644 --- a/tests/jax/test_multi_process_distributed_grouped_gemm.py +++ b/tests/jax/test_multi_process_distributed_grouped_gemm.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +import jax.experimental.multihost_utils as jem from transformer_engine.jax.dense import grouped_dense as te_grouped_dense from transformer_engine.jax.quantize import ( @@ -13,7 +14,7 @@ ScalingMode, ) -from utils import assert_allclose +from utils import assert_allclose, dtype_tols N_GROUP = 8 @@ -137,9 +138,16 @@ def run(x, w): out, dx, dw = test_func_jitted(x, w, w_amax) 
ref_out, ref_dx, ref_dw = ref_func_jitted(x, w_global) - assert_allclose(out, ref_out, dtype=jnp.float8_e4m3fn) - assert_allclose(dx, ref_dx, dtype=jnp.float8_e5m2) - assert_allclose(dw, ref_dw, dtype=jnp.float8_e5m2) + e4m3_tols = dtype_tols(jnp.float8_e4m3fn) + e5m2_tols = dtype_tols(jnp.float8_e5m2) + + out, ref_out = jem.process_allgather((out, ref_out)) + dx, ref_dx = jem.process_allgather((dx, ref_dx)) + dw, ref_dw = jem.process_allgather((dw, ref_dw)) + + jnp.allclose(out, ref_out, **e4m3_tols) + jnp.allclose(dx, ref_dx, **e5m2_tols) + jnp.allclose(dw, ref_dw, **e5m2_tols) if __name__ == "__main__": From 603dbf72e4529868bcefd68bd5f901b84093626e Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 8 Sep 2025 10:55:57 -0700 Subject: [PATCH 130/153] Update list of authorized CI users (#2152) Signed-off-by: Tim Moon --- .github/workflows/trigger-ci.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/trigger-ci.yml b/.github/workflows/trigger-ci.yml index 85a81a6d4..f12a95d79 100644 --- a/.github/workflows/trigger-ci.yml +++ b/.github/workflows/trigger-ci.yml @@ -57,6 +57,7 @@ jobs: || github.actor == 'tdophung' || github.actor == 'vthumbe1503' || github.actor == 'janekb04' + || github.actor == 'shengfangd' ) steps: - name: Check if comment is issued by authorized person From 84fa28d2477c2243ab32cc02ba83faebc59a9e6b Mon Sep 17 00:00:00 2001 From: vasunvidia <108759426+vasunvidia@users.noreply.github.com> Date: Mon, 8 Sep 2025 14:54:26 -0700 Subject: [PATCH 131/153] Fused RoPE with combined QKV input. (#2122) * Fused RoPE with combined QKV input. Initial commit for Dropout with 8-bit RNG Fix documentation Initial commit for Fused QKV RoPE WIP Initial tests passing Enable rotary percent and margin Enable CP2, start_positions, interleaved Cleanup test Revert "Fix documentation" This reverts commit 53df10044e7769982bd4af2ae2628e6b7717e715. Revert "Initial commit for Dropout with 8-bit RNG" This reverts commit 301505e24031cbcd679069e1c2cd4d00eedf2dca. Cleanup. Minor cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernels Signed-off-by: Vasudevan Rengasamy * Misc. 
Cleanup Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Optimize kernel performance Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vasudevan Rengasamy * Move fused_qkv_rope test to test_fused_rope.py Signed-off-by: Vasudevan Rengasamy * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * apply shared memory optimization to separate fused rope kernels Signed-off-by: Xin Yao * fix lint Signed-off-by: Xin Yao --------- Signed-off-by: Vasudevan Rengasamy Signed-off-by: Xin Yao Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- tests/pytorch/test_fused_rope.py | 147 +++++- .../common/fused_rope/fused_rope.cu | 457 ++++++++++++++++-- .../include/transformer_engine/fused_rope.h | 63 +++ transformer_engine/pytorch/attention/rope.py | 163 ++++++- transformer_engine/pytorch/csrc/extensions.h | 12 + .../pytorch/csrc/extensions/apply_rope.cpp | 97 ++++ .../pytorch/csrc/extensions/pybind.cpp | 4 + 7 files changed, 898 insertions(+), 45 deletions(-) diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index ae25af949..62d80b552 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -1,25 +1,32 @@ # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. -from typing import Callable, Tuple, Union +from typing import Callable, Tuple, Union, List import math import torch import pytest from transformer_engine.pytorch.attention.rope import ( RotaryPositionEmbedding, apply_rotary_pos_emb, + apply_fused_qkv_rotary_pos_emb, ) # Gradient is a broadcasted scalar -def _overlapping_grad(output: torch.Tensor) -> torch.Tensor: - return output.sum() * 2 +def _overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(t.sum() * 2 for t in output) + else: + return output.sum() * 2 # Gradient is a full tensor -def _non_overlapping_grad(output: torch.Tensor) -> torch.Tensor: - t = torch.ones_like(output) - return torch.sum(output * t) +def _non_overlapping_grad(output: Union[List[torch.Tensor], torch.Tensor]) -> torch.Tensor: + if isinstance(output, List): + return sum(torch.sum(t * torch.ones_like(t)) for t in output) + else: + t = torch.ones_like(output) + return torch.sum(output * t) @pytest.mark.parametrize("start_positions", [True, False]) @@ -238,3 +245,131 @@ def test_fused_rope_thd( torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() + + +@pytest.mark.parametrize("start_positions", [True, False]) +@pytest.mark.parametrize("dtype", [torch.float32, torch.bfloat16, torch.float16]) +@pytest.mark.parametrize("seq_length", [2, 8, 2048, 4096]) +@pytest.mark.parametrize("hidden_size", [64, 128, 256]) +@pytest.mark.parametrize("rotary_percent", [0.5, 1.0]) +@pytest.mark.parametrize("margin", [0, 10]) +@pytest.mark.parametrize("tensor_format", ["sbhd", "bshd"]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad, _non_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [1, 2]) +@pytest.mark.parametrize("interleaved", [True, False]) +def test_fused_qkv_rope( + dtype: 
torch.dtype, + seq_length: int, + hidden_size: int, + rotary_percent: float, + margin: int, + tensor_format: str, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + if margin == 0 and start_positions == True: + # This makes sure that the `start_positions` offsets being applied + # are with the maximum length of the rope embeddings. + pytest.skip("Skipping test with margin=0 and start_positions=True") + + if start_positions == True and cp_size > 1: + # `start_positions` is only supported for `cp_size=1` and inference. + pytest.skip("Skipping test with cp_size>1 and start_positions=True") + + if seq_length - margin < 0: + pytest.skip("Skipping test with seq_length - margin < 0") + + device = torch.device("cuda:0") + batch_size, head_num = 2, 64 + + t = torch.rand( + (seq_length - margin, batch_size, head_num, hidden_size * 6), + dtype=dtype, + device=device, + ) + + # Get arbitrary offsets to be used with RoPE for all the sequences + start_positions = ( + torch.randint(0, margin, (batch_size,), dtype=torch.int32, device=device) + if start_positions + else None + ) + + if tensor_format == "bshd": + t = t.transpose(0, 1).contiguous() + t.requires_grad = True + + rotary_pos_emb_q = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_q = rotary_pos_emb_q(seq_length * cp_size) + rotary_pos_emb_k = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb_k = rotary_pos_emb_k(seq_length * cp_size) + + for cp_rank in range(cp_size): + # unfused + # The fused kernel computes in float32 internally, so we force the unfused func to use float32 + # for more accurate comparison + + t_clone = t.clone() + (query, key, value) = torch.split( + t_clone, [hidden_size * 4, hidden_size, hidden_size], dim=3 + ) + query = query.reshape(query.shape[0], query.shape[1], head_num * 4, hidden_size) + + query_unfused = apply_rotary_pos_emb( + query, + emb_q, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + key_unfused = apply_rotary_pos_emb( + key, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + fused=True, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + value_unfused = value + loss_unfused = loss_func([query_unfused, key_unfused, value_unfused]) + + if not isinstance(start_positions, torch.Tensor): + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() + + t.grad = None + + # fused + query_fused, key_fused, value_fused = apply_fused_qkv_rotary_pos_emb( + t, + emb_q, + emb_k, + tensor_format=tensor_format, + start_positions=start_positions, + interleaved=interleaved, + cp_size=cp_size, + cp_rank=cp_rank, + qkv_split_arg_list=[hidden_size * 4, hidden_size, hidden_size], + ) + loss_fused = loss_func([query_fused, key_fused, value_fused]) + + if not isinstance(start_positions, torch.Tensor): + loss_fused.backward() + grad_fused = t.grad.detach().clone() + t.grad = None + + torch.testing.assert_close(query_fused, query_unfused) + torch.testing.assert_close(key_fused, key_unfused) + torch.testing.assert_close(value_fused, value_unfused) + + if not isinstance(start_positions, torch.Tensor): + torch.testing.assert_close(grad_fused, grad_unfused) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index df9ea6ee5..ccd0bc44c 100644 --- 
a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -21,12 +21,21 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos, v_sin; - sincosf(freqs[s_id * d2 + d_id], &v_sin, &v_cos); + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + float v_cos = shared_mem_cos[d_id]; + float v_sin = shared_mem_sin[d_id]; int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; @@ -49,12 +58,12 @@ __device__ void fused_rope_block_forward(const scalar_t *src, const float *freqs // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -67,47 +76,54 @@ __device__ void fused_rope_block_backward(const scalar_t *src, const float *freq const int h, const int d, const int d2, const int stride_h, const int stride_d, const int o_stride_h, const int o_stride_d) { + extern __shared__ float shared_mem_cos_sin[]; + float *shared_mem_cos = shared_mem_cos_sin; + float *shared_mem_sin = shared_mem_cos_sin + d2; + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + __syncthreads(); + #pragma unroll - for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { - float v_cos = cosf(freqs[s_id * d2 + d_id]); - float v_sin; - if (!interleaved) { - v_sin = (d_id + d2 / 2 < d2) ? sinf(freqs[s_id * d2 + d_id + d2 / 2]) - : -sinf(freqs[s_id * d2 + d_id + d2 / 2 - d2]); - } else { - v_sin = - (d_id % 2 == 0) ? 
sinf(freqs[s_id * d2 + d_id + 1]) : -sinf(freqs[s_id * d2 + d_id - 1]); - } + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { int offset_src = offset_block + h_id * stride_h + d_id * stride_d; int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; float v_src = src[offset_src]; - float v_src_rotate; + float v_cos = shared_mem_cos[d_id]; + float v_src_rotate, v_sin; if (!interleaved) { - v_src_rotate = (d_id + d2 / 2 < d2) - ? static_cast(src[offset_src + (d2 / 2) * stride_d]) - : static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(src[offset_src + (d2 / 2) * stride_d]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(src[offset_src + (d2 / 2 - d2) * stride_d]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } } else { - v_src_rotate = (d_id % 2 == 0) - // d_id + 1 - ? static_cast(src[offset_src + stride_d]) - // d_id - 1 - : static_cast(src[offset_src - stride_d]); + if (d_id % 2 == 0) { + v_src_rotate = static_cast(src[offset_src + stride_d]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(src[offset_src - stride_d]); + v_sin = -shared_mem_sin[d_id - 1]; + } } dst[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; } } - // handle the tail + // copy the rest if (d > d2) { #pragma unroll - for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { - int offset_head = offset_block + h_id * stride_h; - int offset_head_dst = offset_block_dst + h_id * o_stride_h; + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { #pragma unroll - for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { - dst[offset_head_dst + d_id * o_stride_d] = src[offset_head + d_id * stride_d]; + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { + int offset_src = offset_block + h_id * stride_h + d_id * stride_d; + int offset_dst = offset_block_dst + h_id * o_stride_h + d_id * o_stride_d; + dst[offset_dst] = src[offset_src]; } } } @@ -198,6 +214,251 @@ __global__ void fused_rope_backward_kernel( offset_block_dst, h, d, d2, stride_h, stride_d, o_stride_h, o_stride_d); } +template +__device__ void fused_qkv_rope_block_forward(const scalar_t *src, const float *freqs, scalar_t *out, + const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + // Split the shared memory into cos and sin parts for q or k + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); + +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_src = 
offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + v_sin = shared_mem_sin[d_id]; + float v_src = src[offset_src]; + float v_src_rotate; + if (!interleaved) { + v_src_rotate = (d_id + d2 / 2 < d2) + ? -static_cast(src[offset_src + (d2 / 2)]) + : static_cast(src[offset_src + (d2 / 2 - d2)]); + } else { + v_src_rotate = (d_id % 2 == 0) ? -static_cast(src[offset_src + 1]) + : static_cast(src[offset_src - 1]); + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = src[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int d_id = d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { + int offset_src = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_dst = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = src[offset_src]; + } + } + } + } +} + +template +__device__ void fused_qkv_rope_block_backward(const scalar_t *grad_out, const float *freqs, + scalar_t *out, const bool interleaved, const int s_id, + const int offset_block, const int offset_block_dst, + const int h, const int d, const int d2, + const int row_offset, const int in_row_length, + const int out_row_length) { + extern __shared__ float shared_mem_cos_sin_qk[]; + float *shared_mem_cos = nullptr; + float *shared_mem_sin = nullptr; + // Split the shared memory into cos and sin parts for q or k + if (row_offset == 0) { // q + shared_mem_cos = shared_mem_cos_sin_qk; + shared_mem_sin = shared_mem_cos_sin_qk + d2; + } else { // k + shared_mem_cos = shared_mem_cos_sin_qk + 2 * d2; + shared_mem_sin = shared_mem_cos_sin_qk + 3 * d2; + } + if (freqs != nullptr) { + int tid = threadIdx.x * blockDim.y + threadIdx.y; + for (int i = tid; i < d2; i += blockDim.x * blockDim.y) { + sincosf(freqs[s_id * d2 + i], &shared_mem_sin[i], &shared_mem_cos[i]); + } + } + __syncthreads(); +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = threadIdx.x; d_id < d2; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + + float v_src = grad_out[offset_src]; + if (freqs != nullptr) { + float v_cos, v_sin; + v_cos = shared_mem_cos[d_id]; + float v_src_rotate; + if (!interleaved) { + if (d_id + d2 / 2 < d2) { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2)]); + v_sin = shared_mem_sin[d_id + d2 / 2]; + } else { + v_src_rotate = static_cast(grad_out[offset_src + (d2 / 2 - d2)]); + v_sin = -shared_mem_sin[d_id + d2 / 2 - d2]; + } + } else { + if (d_id % 2 == 0) { + v_src_rotate = static_cast(grad_out[offset_src + 1]); + v_sin = shared_mem_sin[d_id + 1]; + } else { + v_src_rotate = static_cast(grad_out[offset_src - 1]); + v_sin = -shared_mem_sin[d_id - 1]; + } + } + out[offset_dst] = v_src * v_cos + v_src_rotate * v_sin; + } else { + out[offset_dst] = grad_out[offset_src]; + } + } + } + } + // copy the rest + if (d > d2) { +#pragma unroll + for (int h_id = threadIdx.y; h_id < h; h_id += blockDim.y) { +#pragma unroll + for (int i = 0; i < out_row_length; i += d) { +#pragma unroll + for (int d_id = 
d2 + threadIdx.x; d_id < d; d_id += blockDim.x) { + int offset_dst = offset_block + h_id * in_row_length + (row_offset + i) + d_id; + int offset_src = offset_block_dst + h_id * out_row_length + i + d_id; + out[offset_dst] = grad_out[offset_src]; + } + } + } + } +} + +template +__global__ void fused_qkv_rope_forward_kernel( + const scalar_t *qkv_input, const float *q_freqs, const float *k_freqs, + const int *start_positions, scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int total_d = q_split_arg + k_split_arg + v_split_arg; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + int begin_offset = (start_positions == nullptr) ? 
0 : start_positions[b_id]; + s_id_for_freqs = s_id + begin_offset; + } + fused_qkv_rope_block_forward(qkv_input, q_freqs, q_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_q, h, d, d2, 0, total_d, q_split_arg); + fused_qkv_rope_block_forward(qkv_input, k_freqs, k_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_k, h, d, d2, q_limit, total_d, k_split_arg); + fused_qkv_rope_block_forward(qkv_input, nullptr, v_out, interleaved, s_id_for_freqs, offset_block, + offset_block_dst_v, h, d, d2, k_limit, total_d, v_split_arg); +} + +template +__global__ void fused_qkv_rope_backward_kernel( + const scalar_t *grad_out_q, const scalar_t *grad_out_k, const scalar_t *grad_out_v, + const float *q_freqs, const float *k_freqs, scalar_t *qkv_grad, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, const int q_split_arg, + const int k_split_arg, const int v_split_arg) { + int s_id = blockIdx.x, b_id = blockIdx.y; + int cur_seqlens = s; + int offset_block, offset_block_dst_q, offset_block_dst_k, offset_block_dst_v; + int total_d = q_split_arg + k_split_arg + v_split_arg; + if (qkv_format == NVTE_QKV_Format::NVTE_SBHD) { + offset_block = s_id * b * h * total_d + b_id * h * total_d; + offset_block_dst_q = s_id * b * h * q_split_arg + b_id * h * q_split_arg; + offset_block_dst_k = s_id * b * h * k_split_arg + b_id * h * k_split_arg; + offset_block_dst_v = s_id * b * h * v_split_arg + b_id * h * v_split_arg; + } else { + offset_block = b_id * s * h * total_d + s_id * h * total_d; + offset_block_dst_q = b_id * s * h * q_split_arg + s_id * h * q_split_arg; + offset_block_dst_k = b_id * s * h * k_split_arg + s_id * h * k_split_arg; + offset_block_dst_v = b_id * s * h * v_split_arg + s_id * h * v_split_arg; + } + int q_limit = q_split_arg; + int k_limit = q_limit + k_split_arg; + int s_id_for_freqs; + if (cp_size > 1) { + assert(cur_seqlens % 2 == 0); + if (s_id < cur_seqlens / 2) { + s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + } else { + s_id_for_freqs = + cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + } + } else { + s_id_for_freqs = s_id; + } + fused_qkv_rope_block_backward(grad_out_q, q_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_q, h, d, d2, 0, total_d, + q_split_arg); + fused_qkv_rope_block_backward(grad_out_k, k_freqs, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_k, h, d, d2, q_limit, total_d, + k_split_arg); + fused_qkv_rope_block_backward(grad_out_v, nullptr, qkv_grad, interleaved, s_id_for_freqs, + offset_block, offset_block_dst_v, h, d, d2, k_limit, total_d, + v_split_arg); +} + template void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, const float *freqs, const int *start_positions, scalar_t *output, @@ -209,6 +470,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c int warps_per_block = h < 16 ? 
4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -224,7 +486,7 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_forward_kernel<<>>( + fused_rope_forward_kernel<<>>( input, cu_seqlens, freqs, start_positions, output, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); @@ -242,6 +504,7 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 2 * d2 * sizeof(float); // cos, sin int o_stride_s_or_t, o_stride_b; if (qkv_format == NVTE_QKV_Format::NVTE_THD) { NVTE_CHECK(cu_seqlens != nullptr, "cu_seqlens is required for THD format"); @@ -257,13 +520,58 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_h = d; const int o_stride_d = 1; - fused_rope_backward_kernel<<>>( + fused_rope_backward_kernel<<>>( output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void fused_qkv_rope_forward_launcher(const scalar_t *qkv_input, const float *q_freqs, + const float *k_freqs, const int *start_positions, + scalar_t *q_out, scalar_t *k_out, scalar_t *v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + int warps_per_block = (h <= 8) ? h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_forward_kernel<<>>( + qkv_input, q_freqs, k_freqs, start_positions, q_out, k_out, v_out, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +template +void fused_qkv_rope_backward_launcher(const scalar_t *q_grad_out, const scalar_t *k_grad_out, + const scalar_t *v_grad_out, const float *q_freqs, + const float *k_freqs, scalar_t *qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, + const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + const int THREADS_PER_WARP = 32; + const int warps_per_block = (h <= 8) ? 
h : 8; + dim3 blocks(s, b); + dim3 threads(THREADS_PER_WARP, warps_per_block); + const int shared_mem_size = 4 * d2 * sizeof(float); // cos, sin * q ,k + + fused_qkv_rope_backward_kernel<<>>( + q_grad_out, k_grad_out, v_grad_out, q_freqs, k_freqs, qkv_grad_input, qkv_format, interleaved, + cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Tensor &freqs, const Tensor &start_positions, Tensor *output, const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, @@ -297,6 +605,46 @@ void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, c stride_b, stride_h, stride_d, stream);); } +void fused_qkv_rope_forward(const Tensor &qkv_input, const Tensor &q_freqs, const Tensor &k_freqs, + const Tensor &start_positions, Tensor *q_out, Tensor *k_out, + Tensor *v_out, const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int qkv_split_arg_list_0, + const int qkv_split_arg_list_1, const int qkv_split_arg_list_2, + cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + qkv_input.data.dtype, scalar_t, + fused_qkv_rope_forward_launcher(reinterpret_cast(qkv_input.data.dptr), + reinterpret_cast(q_freqs.data.dptr), + reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), + reinterpret_cast(q_out->data.dptr), + reinterpret_cast(k_out->data.dptr), + reinterpret_cast(v_out->data.dptr), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} + +void fused_qkv_rope_backward(const Tensor &q_grad_out, const Tensor &k_grad_out, + const Tensor &v_grad_out, const Tensor &q_freqs, const Tensor &k_freqs, + Tensor *qkv_grad_input, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( + q_grad_out.data.dtype, scalar_t, + fused_qkv_rope_backward_launcher(reinterpret_cast(q_grad_out.data.dptr), + reinterpret_cast(k_grad_out.data.dptr), + reinterpret_cast(v_grad_out.data.dptr), + reinterpret_cast(q_freqs.data.dptr), + reinterpret_cast(k_freqs.data.dptr), + reinterpret_cast(qkv_grad_input->data.dptr), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream);); +} } // end namespace transformer_engine void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens, @@ -328,3 +676,38 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } + +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, 
+ const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_forward); + using namespace transformer_engine; + fused_qkv_rope_forward(*convertNVTETensorCheck(qkv_input), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(q_out), convertNVTETensorCheck(k_out), + convertNVTETensorCheck(v_out), qkv_format, interleaved, cp_size, cp_rank, + s, b, h, d, d2, qkv_split_arg_list_0, qkv_split_arg_list_1, + qkv_split_arg_list_2, stream); +} + +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream) { + NVTE_API_CALL(nvte_fused_qkv_rope_backward); + using namespace transformer_engine; + fused_qkv_rope_backward(*convertNVTETensorCheck(q_grad_out), *convertNVTETensorCheck(k_grad_out), + *convertNVTETensorCheck(v_grad_out), *convertNVTETensorCheck(q_freqs), + *convertNVTETensorCheck(k_freqs), convertNVTETensorCheck(qkv_grad_input), + qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, + qkv_split_arg_list_0, qkv_split_arg_list_1, qkv_split_arg_list_2, stream); +} diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index f0817a97f..610868f93 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -75,6 +75,69 @@ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu const int stride_b, const int stride_h, const int stride_d, cudaStream_t stream); +/*! \brief Apply rotary positional embedding to the combined QKV input tensor. + * + * \param[in] qkv_input Combined QKV input tensor for fused rope. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. + * \param[out] q_out Output tensor for Q. + * \param[out] k_out Output tensor for K. + * \param[out] v_out Output tensor for V. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. 
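+ *
+ * \note The per-head hidden sizes qkv_split_arg_list_0/1/2 describe how the last dimension of
+ *       qkv_input is laid out (Q first, then K, then V); their sum is expected to equal that
+ *       dimension. The forward kernel stages the cos/sin values of q_freqs and k_freqs in
+ *       dynamic shared memory (4 * d2 floats per thread block).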
+ */ +void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, + const NVTETensor k_freqs, const NVTETensor start_positions, + NVTETensor q_out, NVTETensor k_out, NVTETensor v_out, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + +/*! \brief Compute the backward of the fused qkv rope. + * + * \param[in] q_grad_out Incoming gradient tensor for Q. + * \param[in] k_grad_out Incoming gradient tensor for K. + * \param[in] v_grad_out Incoming gradient tensor for V. + * \param[in] q_freqs The freqs tensor for Q. + * \param[in] k_freqs The freqs tensor for K. + * \param[out] qkv_grad_input Input gradient tensor to calculate. + * \param[in] qkv_format QKV format. + * \param[in] interleaved Whether to use interleaved rotary position embedding. + * \param[in] cp_size Context parallel world size. + * \param[in] cp_rank Context parallel rank. + * \param[in] s Length of the s dimension of input. + * \param[in] b Length of the b dimension of input. + * \param[in] h Length of the h dimension of input. + * \param[in] d Length of the d dimension of input. + * \param[in] d2 Length of the d dimension of freqs. + * \param[in] qkv_split_arg_list_0 The hidden size for Q. + * \param[in] qkv_split_arg_list_1 The hidden size for K. + * \param[in] qkv_split_arg_list_2 The hidden size for V. + * \param[in] stream CUDA stream used for the operation. + */ +void nvte_fused_qkv_rope_backward(const NVTETensor q_grad_out, const NVTETensor k_grad_out, + const NVTETensor v_grad_out, const NVTETensor q_freqs, + const NVTETensor k_freqs, NVTETensor qkv_grad_input, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, + const int qkv_split_arg_list_0, const int qkv_split_arg_list_1, + const int qkv_split_arg_list_2, cudaStream_t stream); + #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index 60685a31d..139381f2d 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -5,14 +5,14 @@ """ Rotary Position Embedding implementation of different types along with helper functions """ -from typing import Optional, Tuple, Union +from typing import Optional, Tuple, Union, List import torch import transformer_engine_torch as tex from transformer_engine.pytorch.cpp_extensions.fused_attn import QKVFormat -__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb"] +__all__ = ["RotaryPositionEmbedding", "apply_rotary_pos_emb", "apply_fused_qkv_rotary_pos_emb"] class RotaryPositionEmbedding(torch.nn.Module): @@ -170,6 +170,86 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], return grad_input, None, None, None, None, None, None, None +class FusedQKVRoPEFunc(torch.autograd.Function): + """ + Function for FusedQKVRoPE + + This implementation accepts combined QKV tensor in `bshd` or `sbhd` format. Q and K RoPE tensors are the additional required inputs. + The RoPE tensors should be of shape (s, 1, 1, d). It produces 3 outputs: Q, K after RoPE, V is the same as input. 
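+
+    A minimal usage sketch (shapes follow the unit tests; ``d`` is the per-head size and the
+    4x query split below is just one example of a GQA-style head layout):
+
+        s, b, h, d = 128, 2, 8, 64
+        qkv = torch.randn(s, b, h, 6 * d, device="cuda", dtype=torch.bfloat16)
+        freqs = RotaryPositionEmbedding(d)(s)  # [s, 1, 1, d], float32
+        q, k, v = FusedQKVRoPEFunc.apply(qkv, freqs, freqs, [4 * d, d, d])
+        # q: [s, b, 4 * h, d]; k and v: [s, b, h, d]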
+ """ + + @staticmethod + def forward( + ctx, + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + start_positions: Union[torch.Tensor, None] = None, + tensor_format: str = "sbhd", + interleaved: bool = False, + cp_size: int = 1, + cp_rank: int = 0, + ) -> torch.Tensor: + """Fused RoPE forward.""" + + if q_freqs.dtype != torch.float32: + q_freqs = q_freqs.float() + if k_freqs.dtype != torch.float32: + k_freqs = k_freqs.float() + assert tensor_format in ( + "sbhd", + "bshd", + ), f"Unsupported tensor_format: {tensor_format}." + assert qkv.is_contiguous(), "QKV Tensor should be contiguous." + assert q_freqs.is_contiguous(), "q_freqs Tensor should be contiguous." + assert k_freqs.is_contiguous(), "k_freqs Tensor should be contiguous." + output = tex.fused_qkv_rope_forward( + qkv, + q_freqs, + k_freqs, + start_positions, + qkv_split_arg_list, + QKVFormat[tensor_format], + interleaved, + cp_size, + cp_rank, + ) + ctx.save_for_backward(q_freqs, k_freqs) + ctx.tensor_format = tensor_format + ctx.qkv_split_arg_list = qkv_split_arg_list + ctx.cp_size = cp_size + ctx.cp_rank = cp_rank + ctx.interleaved = interleaved + return output + + @staticmethod + def backward( + ctx, grad_output_q: torch.Tensor, grad_output_k: torch.Tensor, grad_output_v: torch.Tensor + ) -> Tuple[Union[torch.Tensor, None], ...]: + """Fused RoPE backward.""" + q_freqs, k_freqs = ctx.saved_tensors + + grad_output_q = grad_output_q.contiguous() + grad_output_k = grad_output_k.contiguous() + grad_output_v = grad_output_v.contiguous() + + grad_input = tex.fused_qkv_rope_backward( + grad_output_q, + grad_output_k, + grad_output_v, + q_freqs, + k_freqs, + ctx.qkv_split_arg_list, + QKVFormat[ctx.tensor_format], + ctx.interleaved, + ctx.cp_size, + ctx.cp_rank, + ) + + return grad_input, None, None, None, None, None, None, None, None + + def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: """Change sign so the last dimension becomes [-odd, +even] @@ -393,3 +473,82 @@ def apply_rotary_pos_emb( tensor_format, interleaved=interleaved, ) + + +def apply_fused_qkv_rotary_pos_emb( + qkv: torch.Tensor, + q_freqs: torch.Tensor, + k_freqs: torch.Tensor, + qkv_split_arg_list: List[int], + tensor_format: str = "sbhd", + start_positions: Union[torch.Tensor, None] = None, + interleaved: bool = False, + cu_seqlens: Union[torch.Tensor, None] = None, # pylint: disable=unused-argument + cp_size: int = 1, + cp_rank: int = 0, +) -> torch.Tensor: + """ + Apply rotary positional embedding tensor to the input qkv tensor. + + Support matrix: + Fused: + Training: + qkv_formats: "bshd", "sbhd" + context parallel: yes + start_positions: no + interleaving: yes + Inference: + qkv_formats: "bshd", "sbhd" + context parallelism: no + start_positions: yes + interleaving: yes + + Parameters + ---------- + qkv: torch.Tensor + Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which + rotary positional embedding will be applied. This tensor has q, k, v concatenated + along the last dimension. + q_freqs: torch.Tensor + Rotary positional embedding Q tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + k_freqs: torch.Tensor + Rotary positional embedding K tensor of shape `[s2, 1, 1, d2]` and dtype 'float', + with `s2 >= s` and `d2 <= d`. + qkv_split_arg_list: List[int] + List of integers that specify the split of the qkv tensor. 
The list should have 3 elements, + the first element is the number of elements in the q tensor, the second element is the number + of elements in the k tensor, and the third element is the number of elements in the v tensor. + The sum of the elements in the list should be equal to the last dimension of the qkv tensor. + start_positions: torch.Tensor, default = None. + Tokens in a sequence `i` should be applied with position encoding offset by + `start_positions[i]`. If `start_positions=None`, there's no offset. + tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' + is `bshd` if `qkv` is of shape `[bs, seq, ...]`, or `sbhd` if `qkv` is + of shape `[seq, bs, ...]`. + interleaved: bool, default = False + Whether to use interleaved rotary position embedding. + cp_size: int, default = 1. + Context parallel world size. + cp_rank: int, default = 0. + Context parallel rank. + """ + + # `start_positions` is only supported for `cp_size=1` and inference. + assert not ( + cp_size > 1 and start_positions is not None + ), """start_positions != None with CP SIZE > 1 is not supported!""" + + assert tensor_format != "thd", "'thd' tensor_format not supported currently." + + return FusedQKVRoPEFunc.apply( + qkv, + q_freqs, + k_freqs, + qkv_split_arg_list, + start_positions, + tensor_format, + interleaved, + cp_size, + cp_rank, + ) diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index a6b65562e..4cb05725b 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -338,6 +338,18 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor const std::optional cu_seqlens, const int cp_size, const int cp_rank); +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, const int cp_rank); + +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank); + /*************************************************************************************************** * Miscellaneous **************************************************************************************************/ diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 6f6f82725..d1ba1a351 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -102,6 +102,65 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, return output; } +std::tuple fused_qkv_rope_forward( + const at::Tensor &qkv_input, const at::Tensor &q_freqs, const at::Tensor &k_freqs, + const std::optional start_positions, const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, const int cp_size, + const int cp_rank) { + TORCH_CHECK(q_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(q_freqs.size(1) == 1 && q_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(q_freqs.scalar_type() == at::ScalarType::Float, + 
"Dtype of the freqs tensor must be float"); + TORCH_CHECK(k_freqs.dim() == 4, "expected 4D tensor"); + TORCH_CHECK(k_freqs.size(1) == 1 && k_freqs.size(2) == 1, + "expected the second and third dims of the freqs tensor equal 1"); + TORCH_CHECK(k_freqs.scalar_type() == at::ScalarType::Float, + "Dtype of the freqs tensor must be float"); + // output + auto act_options = at::TensorOptions().dtype(qkv_input.scalar_type()).device(qkv_input.device()); + auto q_out_size = qkv_input.sizes().vec(); + q_out_size[2] = q_out_size[2] * qkv_split_arg_list[0] / qkv_split_arg_list[1]; + q_out_size[3] = qkv_split_arg_list[1]; + auto q_out = at::empty(q_out_size, act_options); + auto k_out_size = qkv_input.sizes().vec(); + k_out_size[3] = qkv_split_arg_list[1]; + auto k_out = at::empty(k_out_size, act_options); + auto v_out_size = qkv_input.sizes().vec(); + v_out_size[3] = qkv_split_arg_list[2]; + auto v_out = at::empty(v_out_size, act_options); + + auto qkv_cu = makeTransformerEngineTensor(qkv_input); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto q_out_cu = makeTransformerEngineTensor(q_out); + auto k_out_cu = makeTransformerEngineTensor(k_out); + auto v_out_cu = makeTransformerEngineTensor(v_out); + + auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + if (start_positions) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + } + + TORCH_CHECK(qkv_input.dim() == 4, "expected 4D input tensor"); + TORCH_CHECK(qkv_input.is_contiguous(), "input tensor must be contiguous"); + + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? qkv_input.size(0) : qkv_input.size(1); + const int b = is_sbhd ? qkv_input.size(1) : qkv_input.size(0); + const int h = qkv_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + nvte_fused_qkv_rope_forward(qkv_cu.data(), q_freqs_cu.data(), k_freqs_cu.data(), + start_positions_cu.data(), q_out_cu.data(), k_out_cu.data(), + v_out_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, h, + d, d2, qkv_split_arg_list[0], qkv_split_arg_list[1], + qkv_split_arg_list[2], at::cuda::getCurrentCUDAStream()); + + return std::make_tuple(q_out, k_out, v_out); +} + at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, @@ -193,4 +252,42 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor return input_grads; } +at::Tensor fused_qkv_rope_backward(const at::Tensor &q_grad_out, const at::Tensor &k_grad_out, + const at::Tensor &v_grad_out, const at::Tensor &q_freqs, + const at::Tensor &k_freqs, + const std::vector &qkv_split_arg_list, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank) { + auto act_options = + at::TensorOptions().dtype(q_grad_out.scalar_type()).device(q_grad_out.device()); + auto qkv_grad_size = q_grad_out.sizes().vec(); + auto total_hd = + (q_grad_out.size(2) + k_grad_out.size(2) + v_grad_out.size(2)) * q_grad_out.size(3); + auto total_d = qkv_split_arg_list[0] + qkv_split_arg_list[1] + qkv_split_arg_list[2]; + qkv_grad_size[2] = total_hd / total_d; + qkv_grad_size[3] = total_d; + auto qkv_grad_input = at::empty(qkv_grad_size, act_options); + const bool is_sbhd = qkv_format == NVTE_QKV_Format::NVTE_SBHD; + const int s = is_sbhd ? 
q_grad_out.size(0) : q_grad_out.size(1); + const int b = is_sbhd ? q_grad_out.size(1) : q_grad_out.size(0); + const int h = qkv_grad_input.size(2); + const int d = qkv_split_arg_list[2]; + const int d2 = q_freqs.size(3); + + auto q_grad_out_cu = makeTransformerEngineTensor(q_grad_out); + auto k_grad_out_cu = makeTransformerEngineTensor(k_grad_out); + auto v_grad_out_cu = makeTransformerEngineTensor(v_grad_out); + auto q_freqs_cu = makeTransformerEngineTensor(q_freqs); + auto k_freqs_cu = makeTransformerEngineTensor(k_freqs); + auto qkv_grad_cu = makeTransformerEngineTensor(qkv_grad_input); + + nvte_fused_qkv_rope_backward(q_grad_out_cu.data(), k_grad_out_cu.data(), v_grad_out_cu.data(), + q_freqs_cu.data(), k_freqs_cu.data(), qkv_grad_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, qkv_split_arg_list[0], + qkv_split_arg_list[1], qkv_split_arg_list[2], + at::cuda::getCurrentCUDAStream()); + + return qkv_grad_input; +} + } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 541b16848..7649ccb6d 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -278,6 +278,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { "Fused Apply RoPE FWD", py::call_guard()); m.def("fused_rope_backward", &transformer_engine::pytorch::fused_rope_backward, "Fused Apply RoPE BWD", py::call_guard()); + m.def("fused_qkv_rope_forward", &transformer_engine::pytorch::fused_qkv_rope_forward, + "Fused Apply QKV RoPE FWD", py::call_guard()); + m.def("fused_qkv_rope_backward", &transformer_engine::pytorch::fused_qkv_rope_backward, + "Fused Apply QKV RoPE BWD", py::call_guard()); // fused router m.def("fused_topk_with_score_function_fwd", From a26a7f1f660a416ad790123b577ba665191222db Mon Sep 17 00:00:00 2001 From: Autumn1998 <1515848689@qq.com> Date: Tue, 9 Sep 2025 10:22:24 +0800 Subject: [PATCH 132/153] Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162) * add bf16/fp32 token-per-expert on the moe-loss-computation on router fusion Signed-off-by: tongliu * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: tongliu Co-authored-by: tongliu Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/fused_router/fused_moe_aux_loss.cu | 2 +- transformer_engine/common/fused_router/utils.h | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index a738be873..94082594f 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -229,7 +229,7 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, // Loop: for all positions in each row for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { float C_coeff = Const_buf[0]; - IndexType tokens_per_expert_i = tokens_per_expert[i]; + double tokens_per_expert_i = static_cast(tokens_per_expert[i]); double grad_aux_loss_value = static_cast(grad_aux_loss[0]); // Loop: for all rows for (int j = global_warp_id; j < num_rows; j += global_warp_num) { diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 46e0ba632..b6f9d87bd 100644 --- 
a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -246,6 +246,14 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i using type = int64_t; \ { __VA_ARGS__ } \ } break; \ + case DType::kBFloat16: { \ + using type = bf16; \ + { __VA_ARGS__ } \ + } break; \ + case DType::kFloat32: { \ + using type = float; \ + { __VA_ARGS__ } \ + } break; \ default: \ NVTE_ERROR("Invalid type."); \ } From 5f2b83100c75fb633ef416d3685efb5e23062f5c Mon Sep 17 00:00:00 2001 From: Phuong Nguyen Date: Tue, 9 Sep 2025 07:43:28 -0400 Subject: [PATCH 133/153] [JAX] Scale swizzling via JAX transpose op (#2163) * add swizzle in jax Signed-off-by: Phuong Nguyen * added outer_impl Signed-off-by: Phuong Nguyen * clean up FFI Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen --- transformer_engine/jax/cpp_extensions/base.py | 9 +- transformer_engine/jax/cpp_extensions/gemm.py | 97 +++++++++++++------ .../jax/csrc/extensions/gemm.cpp | 52 ++-------- 3 files changed, 81 insertions(+), 77 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/base.py b/transformer_engine/jax/cpp_extensions/base.py index a27cec001..c05570566 100644 --- a/transformer_engine/jax/cpp_extensions/base.py +++ b/transformer_engine/jax/cpp_extensions/base.py @@ -134,6 +134,13 @@ def impl(): """ return NotImplemented + @classmethod + def outer_impl(cls, *args, **kwargs): + """ + to describe implementation for outer primitive + """ + return cls.impl(*args, **kwargs) + @staticmethod @abstractmethod def batcher(): @@ -196,7 +203,7 @@ def name_of_wrapper_p(): outer_p = core.Primitive(name_of_wrapper_p()) dispatch.prim_requires_devices_during_lowering.add(outer_p) outer_p.multiple_results = cls.multiple_results - outer_p.def_impl(cls.impl) + outer_p.def_impl(cls.outer_impl) outer_p.def_abstract_eval(cls.outer_abstract) batching.primitive_batchers[outer_p] = cls.batcher outer_p_lower = custom_partitioning(cls.impl, static_argnums=cls.impl_static_args) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index acc8d6727..2acc3fb68 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -152,6 +152,21 @@ def _quantize_gemm_operands(lhs, rhs, lhs_quantizer, rhs_quantizer, contracting_ return lhs_q, rhs_q +@partial(jax.jit, static_argnums=(1, 2)) +def swizzled_scale(scale_inv, flatten_axis, is_colwise): + "Swizzle scale_inv via JAX transpose ops" + original_shape = scale_inv.shape + shape_2d = (math.prod(original_shape[:flatten_axis]), math.prod(original_shape[flatten_axis:])) + if is_colwise: + scale_inv = jnp.transpose(scale_inv.reshape(shape_2d)) + cols, rows = shape_2d + else: + rows, cols = shape_2d + reshape = scale_inv.reshape(rows // 128, 4, 32, cols // 4, 4) + swizzled = jnp.transpose(reshape, (0, 3, 2, 1, 4)) + return swizzled.reshape(original_shape) + + class GemmPrimitive(BasePrimitive): """ Primitive for cuBLAS GEMM @@ -286,28 +301,18 @@ def _dims_are_consecutive(dims): ) pre_gelu_out = jax.core.ShapedArray(shape=pre_gelu_shape, dtype=pre_gelu_dtype) - # Need extra workspace for swizzled scale factors - lhs_swizzle_size = 0 - rhs_swizzle_size = 0 - swizzle_dtype = jnp.uint8 - if scaling_mode == ScalingMode.MXFP8_1D_SCALING: - lhs_swizzle_size = lhs_scale_inv.size - rhs_swizzle_size = rhs_scale_inv.size - lhs_swizzle = jax.core.ShapedArray(shape=(lhs_swizzle_size,), dtype=swizzle_dtype) - rhs_swizzle = 
jax.core.ShapedArray(shape=(rhs_swizzle_size,), dtype=swizzle_dtype) - # Declare cuBLAS workspace # cuBLAS workspace ptr must be 256 bytes aligned but JAX buffers are not # necessarily 256 bytes aligned, we add some padding to ensure alignment. workspace_size = get_cublas_workspace_size_bytes() + 256 workspace = jax.core.ShapedArray(shape=(workspace_size,), dtype=jnp.uint8) - return output, bias_grad, pre_gelu_out, lhs_swizzle, rhs_swizzle, workspace + return output, bias_grad, pre_gelu_out, workspace @staticmethod def outer_abstract(*args, **kwargs): outputs = GemmPrimitive.abstract(*args, **kwargs) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array @staticmethod def lowering( @@ -374,24 +379,22 @@ def impl( grad, use_split_accumulator, ): - lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) - lhs_transposed, rhs_transposed = _get_gemm_layout( - (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) - ) - lhs_scale_inv = apply_padding_to_scale_inv( - lhs_scale_inv, - scaling_mode, - lhs.shape, - is_colwise=lhs_transposed, - flatten_axis=max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims), - ) - rhs_scale_inv = apply_padding_to_scale_inv( - rhs_scale_inv, - scaling_mode, - rhs.shape, - is_colwise=not rhs_transposed, - flatten_axis=min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1, - ) + if scaling_mode.is_1d_block_scaling(): + lhs_cdims, rhs_cdims = map(sanitize_dims, (lhs.ndim, rhs.ndim), contracting_dims) + lhs_transposed, rhs_transposed = _get_gemm_layout( + (lhs.ndim, rhs.ndim), (lhs_cdims, rhs_cdims) + ) + lhs_flatten_axis = max(lhs_cdims) + 1 if lhs_transposed else min(lhs_cdims) + rhs_flatten_axis = min(rhs_cdims) if rhs_transposed else max(rhs_cdims) + 1 + + lhs_scale_inv = apply_padding_to_scale_inv( + lhs_scale_inv, scaling_mode, lhs.shape, lhs_transposed, lhs_flatten_axis + ) + rhs_scale_inv = apply_padding_to_scale_inv( + rhs_scale_inv, scaling_mode, rhs.shape, not rhs_transposed, rhs_flatten_axis + ) + lhs_scale_inv = swizzled_scale(lhs_scale_inv, lhs_flatten_axis, lhs_transposed) + rhs_scale_inv = swizzled_scale(rhs_scale_inv, rhs_flatten_axis, not rhs_transposed) outputs = GemmPrimitive.inner_primitive.bind( lhs, @@ -408,7 +411,39 @@ def impl( grad=grad, use_split_accumulator=use_split_accumulator, ) - return outputs[:-3] # discard workspace arrays + return outputs[:-1] # discard workspace array + + @staticmethod + def outer_impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ): + return GemmPrimitive.impl( + lhs, + lhs_scale_inv, + rhs, + rhs_scale_inv, + bias, + gelu_input, + out_dtype, + contracting_dims, + scaling_mode, + fuse_bias, + fuse_gelu, + grad, + use_split_accumulator, + ) @staticmethod def batcher( diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 032ac9eb7..113072131 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -28,8 +28,8 @@ static uint8_t *move_ptr_to_next_256B_aligned(uint8_t *ptr) { } std::tuple> xla_buffer_to_nvte_gemm_operand( - cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, Result_Type swizzled_scale_inv, - JAXX_Scaling_Mode scaling_mode, size_t axis_boundary, bool rowwise) { + cudaStream_t stream, Buffer_Type buffer, Buffer_Type scale_inv, JAXX_Scaling_Mode scaling_mode, + size_t 
axis_boundary, bool rowwise) { // Set tensor data with collapsed 2D shape auto buffer_dims = buffer.dimensions(); std::vector input_shape = {product(buffer_dims, 0, axis_boundary), @@ -61,40 +61,6 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( } else { input.set_columnwise_scale_inv(scale_inv.untyped_data(), scale_dtype, scale_shape); } - - // Swizzle scaling factors for MXFP8 - if (scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING) { - // Get the swizzle buffer - NVTE_CHECK(swizzled_scale_inv->element_count() > 0, - "Missing swizzled inverse scale buffer in the JAX primitive."); - auto scale_inv_dtype = convert_ffi_datatype_to_te_dtype(scale_inv.element_type()); - auto swizzled_scale_inv_dtype = - convert_ffi_datatype_to_te_dtype(swizzled_scale_inv->element_type()); - NVTE_CHECK(typeToSize(scale_inv_dtype) == 1 && typeToSize(swizzled_scale_inv_dtype) == 1, - "Inverse scale factors need to have an 8-bit data type."); - - // Create tensor to hold swizzled scale factor - TensorWrapper output(get_nvte_scaling_mode(scaling_mode)); - if (rowwise) { - output.set_rowwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - output.set_columnwise_data(buffer.untyped_data(), input_dtype, input_shape); - output.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - - // Launch swizzle kernel - nvte_swizzle_scaling_factors(input.data(), output.data(), stream); - - // Set swizzled scales into the input tensor - if (rowwise) { - input.set_rowwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, scale_shape); - } else { - input.set_columnwise_scale_inv(swizzled_scale_inv->untyped_data(), scale_dtype, - scale_shape); - } - } } return std::make_tuple(std::move(input), input_shape); @@ -103,21 +69,19 @@ std::tuple> xla_buffer_to_nvte_gemm_operand( Error_Type GemmFFI(cudaStream_t stream, Buffer_Type lhs, Buffer_Type lhs_scale_inv, Buffer_Type rhs, Buffer_Type rhs_scale_inv, Buffer_Type bias, Buffer_Type gelu_input, Result_Type output, Result_Type bias_grad, Result_Type pre_gelu_out, - Result_Type lhs_swizzle, Result_Type rhs_swizzle, Result_Type workspace, - JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, + Result_Type workspace, JAXX_Scaling_Mode scaling_mode, int64_t lhs_axis_boundary, int64_t rhs_axis_boundary, bool lhs_transposed, bool rhs_transposed, bool fuse_bias, bool fuse_gelu, bool grad, bool use_split_accumulator) { - // Operands (this includes swizzling MXFP8 scaling factors) // NOTE: TensorWrapper operands are always rowwise for full-precision GEMM, or FP8 GEMM when // device supports non-TN layouts (compute capability >= 10.0, excluding 12.x) bool always_rowwise = (scaling_mode == JAXX_Scaling_Mode::NO_SCALING || (is_tensor_scaling(scaling_mode) && nvte_is_non_tn_fp8_gemm_supported())); bool make_lhs_rowwise = (always_rowwise) ? true : !lhs_transposed; bool make_rhs_rowwise = (always_rowwise) ? 
true : rhs_transposed; - auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, lhs, lhs_scale_inv, lhs_swizzle, scaling_mode, lhs_axis_boundary, make_lhs_rowwise); - auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand( - stream, rhs, rhs_scale_inv, rhs_swizzle, scaling_mode, rhs_axis_boundary, make_rhs_rowwise); + auto [lhs_, lhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, lhs, lhs_scale_inv, scaling_mode, + lhs_axis_boundary, make_lhs_rowwise); + auto [rhs_, rhs_shape] = xla_buffer_to_nvte_gemm_operand(stream, rhs, rhs_scale_inv, scaling_mode, + rhs_axis_boundary, make_rhs_rowwise); // Output tensor std::vector out_shape = {(lhs_transposed) ? lhs_shape[1] : lhs_shape[0], @@ -188,8 +152,6 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GemmHandler, GemmFFI, .Ret() // output .Ret() // bias_grad .Ret() // pre_gelu_out - .Ret() // lhs_swizzled - .Ret() // rhs_swizzled .Ret() // workspace .Attr("scaling_mode") .Attr("lhs_axis_boundary") From 4903f947d6de871cd92c3478c8b5b78f835d5b7f Mon Sep 17 00:00:00 2001 From: vcherepanov-nv Date: Tue, 9 Sep 2025 23:52:01 -0700 Subject: [PATCH 134/153] Extract cpp distributed tests into a separate project (#2165) * Extract cpp distributed tests into a separate project Signed-off-by: Vladimir Cherepanov * Remove obsolete exclusion Signed-off-by: Vladimir Cherepanov * Run L1_cpp_distributed tests if at least 4 GPUs Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov --- qa/L0_cppunittest/test.sh | 2 +- qa/L1_cpp_distributed/test.sh | 10 ++-- tests/cpp/CMakeLists.txt | 1 - tests/cpp/comm_gemm/CMakeLists.txt | 19 ------- tests/cpp_distributed/CMakeLists.txt | 57 +++++++++++++++++++ .../test_comm_gemm.cu | 2 +- 6 files changed, 65 insertions(+), 26 deletions(-) delete mode 100644 tests/cpp/comm_gemm/CMakeLists.txt create mode 100644 tests/cpp_distributed/CMakeLists.txt rename tests/{cpp/comm_gemm => cpp_distributed}/test_comm_gemm.cu (99%) diff --git a/qa/L0_cppunittest/test.sh b/qa/L0_cppunittest/test.sh index aa56d69ed..cd46b0b63 100755 --- a/qa/L0_cppunittest/test.sh +++ b/qa/L0_cppunittest/test.sh @@ -17,4 +17,4 @@ cd $TE_PATH/tests/cpp cmake -GNinja -Bbuild . cmake --build build export OMP_NUM_THREADS=$((NUM_PHYSICAL_CORES / NUM_PARALLEL_JOBS)) -ctest --test-dir build -j$NUM_PARALLEL_JOBS -E '(AgGemm|GemmRs|GemmAr)' +ctest --test-dir build -j$NUM_PARALLEL_JOBS diff --git a/qa/L1_cpp_distributed/test.sh b/qa/L1_cpp_distributed/test.sh index f4f914b3e..e074b46ae 100755 --- a/qa/L1_cpp_distributed/test.sh +++ b/qa/L1_cpp_distributed/test.sh @@ -9,7 +9,9 @@ set -e TE_LIB_PATH=$(pip3 show transformer-engine | grep -E "Location:|Editable project location:" | tail -n 1 | awk '{print $NF}') export LD_LIBRARY_PATH=$TE_LIB_PATH:$LD_LIBRARY_PATH -cd $TE_PATH/tests/cpp -cmake -GNinja -S. -Bbuild -cmake --build build -mpirun --allow-run-as-root --np 4 --oversubscribe ./build/comm_gemm/test_comm_gemm +if [[ $(nvidia-smi --list-gpus | wc -l) -ge 4 ]]; then + cd $TE_PATH/tests/cpp_distributed + cmake -GNinja -S. 
-Bbuild + cmake --build build + mpirun --allow-run-as-root --np 4 --oversubscribe ./build/test_comm_gemm +fi diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index 412c5d34d..c2c9d0d91 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -43,6 +43,5 @@ include_directories(${CMAKE_SOURCE_DIR}) find_package(CUDAToolkit REQUIRED) include(${CMAKE_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) -add_subdirectory(comm_gemm) add_subdirectory(operator) add_subdirectory(util) diff --git a/tests/cpp/comm_gemm/CMakeLists.txt b/tests/cpp/comm_gemm/CMakeLists.txt deleted file mode 100644 index 55f5207ac..000000000 --- a/tests/cpp/comm_gemm/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. -# -# See LICENSE for license information. - -add_executable(test_comm_gemm - test_comm_gemm.cu - ../test_common.cu) - -find_package(OpenMP REQUIRED) -find_package(MPI REQUIRED) -find_library(NCCL_LIB - NAMES nccl libnccl - PATH_SUFFIXES lib - REQUIRED) -target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) -target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc CUDNN::cudnn MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) - -include(GoogleTest) -gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp_distributed/CMakeLists.txt b/tests/cpp_distributed/CMakeLists.txt new file mode 100644 index 000000000..ed3ddeb88 --- /dev/null +++ b/tests/cpp_distributed/CMakeLists.txt @@ -0,0 +1,57 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +cmake_minimum_required(VERSION 3.18) + +if(NOT DEFINED CMAKE_CUDA_ARCHITECTURES) + if (CUDAToolkit_VERSION VERSION_GREATER_EQUAL 12.8) + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90 100 120) + else () + set(CMAKE_CUDA_ARCHITECTURES 75 80 89 90) + endif() +endif() + + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +project(transformer_engine_distributed_tests LANGUAGES CUDA CXX) + +add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) + +include_directories(${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) + +if(NOT DEFINED TE_LIB_PATH) + execute_process(COMMAND bash -c "python3 -c 'import transformer_engine as te; print(te.__file__)'" + OUTPUT_VARIABLE TE_LIB_FILE + OUTPUT_STRIP_TRAILING_WHITESPACE) + get_filename_component(TE_LIB_PATH ${TE_LIB_FILE} DIRECTORY) +endif() + +find_library(TE_LIB NAMES transformer_engine PATHS "${TE_LIB_PATH}/.." 
${TE_LIB_PATH} ENV TE_LIB_PATH REQUIRED) + +message(STATUS "Found transformer_engine library: ${TE_LIB}") +include_directories(../../transformer_engine/common/include) +include_directories(../../transformer_engine/common) +include_directories(../../transformer_engine) +include_directories(${CMAKE_SOURCE_DIR}) + +find_package(CUDAToolkit REQUIRED) + +add_executable(test_comm_gemm + test_comm_gemm.cu + ../cpp/test_common.cu) + +find_package(OpenMP REQUIRED) +find_package(MPI REQUIRED) +find_library(NCCL_LIB + NAMES nccl libnccl + PATH_SUFFIXES lib + REQUIRED) +target_include_directories(test_comm_gemm PRIVATE ${MPI_CXX_INCLUDE_PATH} $ENV{CUBLASMP_HOME}/include) +target_link_libraries(test_comm_gemm PUBLIC CUDA::cuda_driver CUDA::cudart GTest::gtest ${TE_LIB} CUDA::nvrtc MPI::MPI_CXX ${NCCL_LIB} OpenMP::OpenMP_CXX) + +include(GoogleTest) +gtest_discover_tests(test_comm_gemm DISCOVERY_TIMEOUT 600) diff --git a/tests/cpp/comm_gemm/test_comm_gemm.cu b/tests/cpp_distributed/test_comm_gemm.cu similarity index 99% rename from tests/cpp/comm_gemm/test_comm_gemm.cu rename to tests/cpp_distributed/test_comm_gemm.cu index b34d4db4b..8355d5f96 100644 --- a/tests/cpp/comm_gemm/test_comm_gemm.cu +++ b/tests/cpp_distributed/test_comm_gemm.cu @@ -19,7 +19,7 @@ #include #include -#include "../test_common.h" +#include "../cpp/test_common.h" #include "common.h" using transformer_engine::DType; From 483d9594fb070f62966f6a12ed6c90942310b48e Mon Sep 17 00:00:00 2001 From: jomitchellnv <148147880+jomitchellnv@users.noreply.github.com> Date: Wed, 10 Sep 2025 10:54:43 -0700 Subject: [PATCH 135/153] Adds context parallelism utilities: moving cp shards to diff ranks and pad sequence to divisibility factory (#2129) * test - adds unit test for cp utilities and the utilites Signed-off-by: Jonathan Mitchell * assert line change Signed-off-by: Jonathan Mitchell * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Jonathan Mitchell Co-authored-by: Jonathan Mitchell Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Sudhakar Singh --- qa/L1_pytorch_distributed_unittest/test.sh | 1 + tests/pytorch/attention/test_cp_utils.py | 715 ++++++++++++++++++ .../dot_product_attention/context_parallel.py | 211 +++++- 3 files changed, 926 insertions(+), 1 deletion(-) create mode 100644 tests/pytorch/attention/test_cp_utils.py diff --git a/qa/L1_pytorch_distributed_unittest/test.sh b/qa/L1_pytorch_distributed_unittest/test.sh index e5b4b5861..7f061d222 100644 --- a/qa/L1_pytorch_distributed_unittest/test.sh +++ b/qa/L1_pytorch_distributed_unittest/test.sh @@ -35,6 +35,7 @@ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_torch_fsdp2.xml $TE_ python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_comm_gemm_overlap.xml $TE_PATH/tests/pytorch/distributed/test_comm_gemm_overlap.py || test_fail "test_comm_gemm_overlap.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_fusible_ops_with_userbuffers.xml $TE_PATH/tests/pytorch/distributed/test_fusible_ops_with_userbuffers.py || test_fail "test_fusible_ops_with_userbuffers.py" python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_attention_with_cp.xml $TE_PATH/tests/pytorch/attention/test_attention_with_cp.py || test_fail "test_attention_with_cp.py" +python3 -m pytest -v -s --junitxml=$XML_LOG_DIR/pytest_test_cp_utils.xml $TE_PATH/tests/pytorch/attention/test_cp_utils.py || test_fail "test_cp_utils.py" python3 -m pytest -v -s 
--junitxml=$XML_LOG_DIR/pytest_test_cast_master_weights_to_fp8.xml $TE_PATH/tests/pytorch/distributed/test_cast_master_weights_to_fp8.py || test_fail "test_cast_master_weights_to_fp8.py" diff --git a/tests/pytorch/attention/test_cp_utils.py b/tests/pytorch/attention/test_cp_utils.py new file mode 100644 index 000000000..00200c62d --- /dev/null +++ b/tests/pytorch/attention/test_cp_utils.py @@ -0,0 +1,715 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +"""Unit tests for context parallel utils.""" +import torch +import unittest +from typing import Tuple +from transformer_engine.pytorch.attention.dot_product_attention.context_parallel import ( + get_batch_on_this_cp_rank, + pad_thd_sequences_for_cp, + generate_positional_ids_for_cp, +) + + +class TestSequencePadding(unittest.TestCase): + def test_padding_with_custom_padding_values_sequences_shorter_than_divisibility_factor(self): + """Test with custom padding values for all tensors.""" + # Setup + + input_ids = torch.tensor([1, 1, 1, 2, 2, 3, 3, 3, 3]) + cu_seqlens = torch.tensor([0, 3, 5, 9]) + labels = torch.tensor([-100, -100, -100, -100, -100, -100, -100, 13, -100]) + positional_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3]) + divisibility_factor = 8 + + pid = 777 + label_pad = -200 + + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Sequence: [ a a a p p p p p b b pppppp ccccpppp] + print("input_ids_padded: ", input_ids_padded) + print("labels_padded: ", labels_padded) + print("positional_ids_padded: ", positional_ids_padded) + print("cu_seqlens_padded: ", cu_seqlens_padded) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + pid, + pid, + pid, + pid, + pid, + 2, + 2, + pid, + pid, + pid, + pid, + pid, + pid, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + pid, + ] + ) + expected_cu_seqlens_padded = torch.tensor([0, 8, 16, 24]) + expected_labels_padded = torch.tensor( + [ + -100, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + -100, + -100, + 13, + -100, + label_pad, + label_pad, + label_pad, + label_pad, + ] + ) + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + assert torch.equal(input_ids_padded, expected_input_ids) + assert torch.equal(labels_padded, expected_labels_padded) + assert torch.equal(positional_ids_padded, expected_positional_ids) + assert torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded) + + def test_mixed_sequence_lengths_with_divisibility_factor(self): + """Test with sequences both shorter and longer than divisibility factor.""" + # Setup - divisibility factor 6 + # Seq 1: length 2 (shorter than 6, needs 4 padding) + # Seq 2: length 7 (longer than 6, needs 5 padding to reach 12) + # Seq 3: length 4 (shorter than 6, needs 2 padding) + # Seq 4: length 10 (longer than 6, needs 2 padding to reach 12) + + input_ids = torch.tensor( + [1, 1, 2, 2, 2, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4] + ) + labels = torch.tensor( + [ + 10, + 11, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + 30, + 31, + 32, + 33, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 
47, + 48, + 49, + ] + ) + positional_ids = torch.tensor( + [0, 1, 0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9] + ) + cu_seqlens = torch.tensor([0, 2, 9, 13, 23]) + divisibility_factor = 6 + + pid = 999 + label_pad = -300 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: [1,1] + 4 pads = 6 total + # Seq 2: [2,2,2,2,2,2,2] + 5 pads = 12 total + # Seq 3: [3,3,3,3] + 2 pads = 6 total + # Seq 4: [4,4,4,4,4,4,4,4,4,4] + 2 pads = 12 total + + expected_input_ids = torch.tensor( + [ + 1, + 1, + pid, + pid, + pid, + pid, # Seq 1: 2 + 4 padding + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + pid, + pid, + pid, + pid, # Seq 2: 7 + 5 padding + 3, + 3, + 3, + 3, + pid, + pid, # Seq 3: 4 + 2 padding + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + 4, + pid, + pid, # Seq 4: 10 + 2 padding + ] + ) + + expected_labels = torch.tensor( + [ + 10, + 11, + label_pad, + label_pad, + label_pad, + label_pad, + 20, + 21, + 22, + 23, + 24, + 25, + 26, + label_pad, + label_pad, + label_pad, + label_pad, + label_pad, + 30, + 31, + 32, + 33, + label_pad, + label_pad, + 40, + 41, + 42, + 43, + 44, + 45, + 46, + 47, + 48, + 49, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [ + 0, + 1, + 2, + 3, + 4, + 5, # Seq 1 positions continue through padding + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 2 positions continue + 0, + 1, + 2, + 3, + 4, + 5, # Seq 3 positions continue + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, # Seq 4 positions continue + ] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 6, 18, 24, 36]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + def test_sequences_longer_than_divisibility_factor(self): + """Test with all sequences longer than the divisibility factor.""" + # Setup - divisibility factor 4, all sequences longer than 4 + # Seq 1: length 7 (needs 1 padding to reach 8) + # Seq 2: length 11 (needs 1 padding to reach 12) + # Seq 3: length 5 (needs 3 padding to reach 8) + + input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, # 7 tokens + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, # 11 tokens + 3, + 3, + 3, + 3, + 3, # 5 tokens + ] + ) + labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + 300, + 301, + 302, + 303, + 304, + ] + ) + positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 0, 1, 2, 3, 4] + ) + cu_seqlens = torch.tensor([0, 7, 18, 23]) + divisibility_factor = 4 + + pid = 888 + label_pad = -400 + + # Execute + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=pid, + padding_label_id=label_pad, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Assert + # Seq 1: 7 + 1 pad = 8 (divisible by 4) + # Seq 2: 11 + 1 
pad = 12 (divisible by 4) + # Seq 3: 5 + 3 pads = 8 (divisible by 4) + + expected_input_ids = torch.tensor( + [ + 1, + 1, + 1, + 1, + 1, + 1, + 1, + pid, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + 2, + pid, + 3, + 3, + 3, + 3, + 3, + pid, + pid, + pid, + ] + ) + + expected_labels = torch.tensor( + [ + 100, + 101, + 102, + 103, + 104, + 105, + 106, + label_pad, + 200, + 201, + 202, + 203, + 204, + 205, + 206, + 207, + 208, + 209, + 210, + label_pad, + 300, + 301, + 302, + 303, + 304, + label_pad, + label_pad, + label_pad, + ] + ) + + expected_positional_ids = torch.tensor( + [0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 0, 1, 2, 3, 4, 5, 6, 7] + ) + + expected_cu_seqlens_padded = torch.tensor([0, 8, 20, 28]) + + self.assertTrue(torch.equal(input_ids_padded, expected_input_ids)) + self.assertTrue(torch.equal(labels_padded, expected_labels)) + self.assertTrue(torch.equal(positional_ids_padded, expected_positional_ids)) + self.assertTrue(torch.equal(cu_seqlens_padded, expected_cu_seqlens_padded)) + + +class TestContextParallelUtils(unittest.TestCase): + """Test utilities for context parallel functionality.""" + + def setUp(self): + """Set up mock distributed environment.""" + # Mock torch.distributed functions + self.original_get_world_size = torch.distributed.get_world_size + self.original_get_rank = torch.distributed.get_rank + + def tearDown(self): + """Restore original torch.distributed functions.""" + torch.distributed.get_world_size = self.original_get_world_size + torch.distributed.get_rank = self.original_get_rank + + def _mock_distributed_env(self, cp_size, cp_rank): + """Mock the distributed environment for testing.""" + + def mock_get_world_size(group=None): + return cp_size + + def mock_get_rank(group=None): + return cp_rank + + torch.distributed.get_world_size = mock_get_world_size + torch.distributed.get_rank = mock_get_rank + + def test_cp_rank_slicing_simple_case(self): + """Test CP rank slicing with a simple 2-rank, single sequence case.""" + # Setup: Single sequence of length 8, CP size = 2 + # Each sequence gets divided into 2*cp_size = 4 slices of size 2 each + # Rank 0 gets slices [0,1] and [6,7] (first and last) + # Rank 1 gets slices [2,3] and [4,5] (second and second-to-last) + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) # Shape: (1, 8) - batch first + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) # Shape: (8,) - 1D as expected + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([[3, 4, 5, 6]]) + expected_labels_r1 = torch.tensor([[30, 40, 50, 60]]) + expected_pos_ids_r1 = torch.tensor([2, 3, 4, 5]) + + 
self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_cp_rank_slicing_multiple_sequences(self): + """Test CP rank slicing with multiple sequences.""" + # Setup: Two sequences of length 8 each, CP size = 2 + # Total sequence length = 16, cu_seqlens = [0, 8, 16] + + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8, 11, 12, 13, 14, 15, 16, 17, 18]]) + labels = torch.tensor( + [[10, 20, 30, 40, 50, 60, 70, 80, 110, 120, 130, 140, 150, 160, 170, 180]] + ) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7, 0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8, 16]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # For each sequence, rank 0 gets first and last slices + # Seq 1: indices [0,1] and [6,7] -> values [1,2] and [7,8] + # Seq 2: indices [8,9] and [14,15] -> values [11,12] and [17,18] + expected_input_ids_r0 = torch.tensor([[1, 2, 7, 8, 11, 12, 17, 18]]) + expected_labels_r0 = torch.tensor([[10, 20, 70, 80, 110, 120, 170, 180]]) + expected_pos_ids_r0 = torch.tensor([0, 1, 6, 7, 0, 1, 6, 7]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_with_cp_size_1(self): + """Test that CP size = 1 returns original tensors unchanged.""" + input_ids = torch.tensor([[1, 2, 3, 4, 5, 6, 7, 8]]) + labels = torch.tensor([[10, 20, 30, 40, 50, 60, 70, 80]]) + position_ids = torch.tensor([0, 1, 2, 3, 4, 5, 6, 7]) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=1, cp_rank=0) + input_ids_result, labels_result, pos_ids_result = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # With CP size = 1, should return original tensors + self.assertTrue(torch.equal(input_ids_result, input_ids)) + self.assertTrue(torch.equal(labels_result, labels)) + self.assertTrue(torch.equal(pos_ids_result, position_ids)) + + def test_cp_rank_slicing_sequence_dim_detection(self): + """Test that the function correctly detects sequence dimension.""" + # Test with sequence dimension = 0 (sequence_length, batch_size) + input_ids = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) # (8, 2) + labels = torch.tensor( + [[1, 10], [2, 20], [3, 30], [4, 40], [5, 50], [6, 60], [7, 70], [8, 80]] + ) + position_ids = torch.tensor( + [[0, 0], [1, 1], [2, 2], [3, 3], [4, 4], [5, 5], [6, 6], [7, 7]] + ) + cu_seqlens = torch.tensor([0, 8]) + + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Should get indices [0,1] and [6,7] along dimension 0 + expected_input_ids_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_labels_r0 = torch.tensor([[1, 10], [2, 20], [7, 70], [8, 80]]) + expected_pos_ids_r0 = torch.tensor([[0, 0], [1, 1], [6, 6], [7, 7]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + def test_cp_rank_slicing_mixed_dimensions(self): + """Test CP rank slicing where input_ids/labels are 1D but 
position_ids has batch dimension.""" + # Setup: Single sequence of length 8, CP size = 2 + # This tests the opposite case from the simple test: + # - input_ids and labels: 1D (no batch dimension) + # - position_ids: 2D (has batch dimension) + + input_ids = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]) # Shape: (8,) - 1D + labels = torch.tensor([10, 20, 30, 40, 50, 60, 70, 80]) # Shape: (8,) - 1D + position_ids = torch.tensor([[0, 1, 2, 3, 4, 5, 6, 7]]) # Shape: (1, 8) - 2D with batch + cu_seqlens = torch.tensor([0, 8]) + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 0 should get indices [0,1] and [6,7] + expected_input_ids_r0 = torch.tensor([1, 2, 7, 8]) # 1D result + expected_labels_r0 = torch.tensor([10, 20, 70, 80]) # 1D result + expected_pos_ids_r0 = torch.tensor([[0, 1, 6, 7]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + self.assertTrue(torch.equal(labels_r0, expected_labels_r0)) + self.assertTrue(torch.equal(pos_ids_r0, expected_pos_ids_r0)) + + # Test rank 1 + self._mock_distributed_env(cp_size=2, cp_rank=1) + input_ids_r1, labels_r1, pos_ids_r1 = get_batch_on_this_cp_rank( + cu_seqlens, input_ids, labels, position_ids + ) + + # Rank 1 should get indices [2,3] and [4,5] + expected_input_ids_r1 = torch.tensor([3, 4, 5, 6]) # 1D result + expected_labels_r1 = torch.tensor([30, 40, 50, 60]) # 1D result + expected_pos_ids_r1 = torch.tensor([[2, 3, 4, 5]]) # 2D result (preserves batch dim) + + self.assertTrue(torch.equal(input_ids_r1, expected_input_ids_r1)) + self.assertTrue(torch.equal(labels_r1, expected_labels_r1)) + self.assertTrue(torch.equal(pos_ids_r1, expected_pos_ids_r1)) + + def test_integration_with_padding_and_cp_slicing(self): + """Integration test: pad sequences then slice for CP ranks.""" + # Start with unpadded sequences + input_ids = torch.tensor([1, 1, 2, 2, 2]) # Two sequences: [1,1] and [2,2,2] + labels = torch.tensor([10, 11, 20, 21, 22]) + positional_ids = torch.tensor([0, 1, 0, 1, 2]) + cu_seqlens = torch.tensor([0, 2, 5]) + divisibility_factor = 4 # Will pad to lengths 4 and 4 + + # First, pad sequences + input_ids_padded, labels_padded, cu_seqlens_padded = pad_thd_sequences_for_cp( + input_ids.unsqueeze(0), + labels.unsqueeze(0), + cu_seqlens, + divisibility_factor, + padding_token_id=0, + padding_label_id=-100, + ) + + positional_ids_padded = generate_positional_ids_for_cp( + cu_seqlens, + divisibility_factor, + ) + + # Expected after padding: [1,1,0,0,2,2,2,0] with cu_seqlens [0,4,8] + expected_padded = torch.tensor([1, 1, 0, 0, 2, 2, 2, 0]) + self.assertTrue(torch.equal(input_ids_padded, expected_padded)) + + # Now test CP slicing with cp_size=2 + + # Test rank 0 + self._mock_distributed_env(cp_size=2, cp_rank=0) + input_ids_r0, labels_r0, pos_ids_r0 = get_batch_on_this_cp_rank( + cu_seqlens_padded, + input_ids_padded.unsqueeze(0), + labels_padded.unsqueeze(0), + positional_ids_padded, + ) + + # Each sequence of length 4 gets divided into 4 slices of size 1 + # Rank 0 gets slices [0] and [3] from each sequence + # Seq 1: indices [0] and [3] -> values [1] and [0] + # Seq 2: indices [4] and [7] -> values [2] and [0] + expected_input_ids_r0 = torch.tensor([[1, 0, 2, 0]]) + + self.assertTrue(torch.equal(input_ids_r0, expected_input_ids_r0)) + + +if __name__ == "__main__": + unittest.main() diff --git 
a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index c6f4647c0..f00bd573f 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -4,7 +4,7 @@ """Context Parallelism.""" import os -from typing import List, Union +from typing import List, Union, Tuple import torch import transformer_engine_torch as tex @@ -3927,3 +3927,212 @@ def attn_forward_func_with_cp( raise ValueError(f"Unsupported communication type: {cp_comm_type}!") return out + + +def pad_thd_sequences_for_cp( + input_ids: torch.Tensor, + labels: torch.Tensor, + cu_seqlens: torch.Tensor, + divisibility_factor: int, + padding_token_id: int = 0, + padding_label_id: int = -100, +) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Pads sequences to be divisible by the divisibility factor. + + Args: + input_ids: Tensor of shape (1, N) or (N,) containing concatenated sequences + labels: Tensor of shape (1, N) or (N,) containing labels for each token + cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths + divisibility_factor: Each sequence length must be divisible by this factor + padding_token_id: Token ID to use for padding (default: 0) + padding_label_id: Label ID to use for padding (default: -100) + + Returns: + Tuple of: + - input_ids_padded: Padded input_ids tensor + - labels_padded: Padded labels tensor + - cu_seqlens_padded: Cumulative sequence lengths accounting for padding + """ + # Flatten input_ids and labels if needed + if input_ids.dim() == 2: + input_ids = input_ids.squeeze(0) + if labels.dim() == 2: + labels = labels.squeeze(0) + + # Compute the sequence lengths from cu_seqlens + seqlens = cu_seqlens[1:] - cu_seqlens[:-1] + + # List: amount of padding needed for each sequence (make length a multiple of divisibility_factor) + padding_amounts = [ + ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor + - l.item() + for l in seqlens + ] + + # Extract sequences and labels for each batch item + batch_sequences = [ + input_ids[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + batch_labels = [ + labels[start.item() : end.item()] for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:]) + ] + + # Pad sequences and labels to required length + input_ids_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_token_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_sequences, padding_amounts) + ] + ) + labels_padded = torch.cat( + [ + ( + torch.cat([seq, torch.full((pad,), padding_label_id, dtype=seq.dtype)]) + if pad > 0 + else seq + ) + for seq, pad in zip(batch_labels, padding_amounts) + ] + ) + + # Compute cumulative padded sequence lengths, starting from 0 + padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype) + cu_seqlens_padded = torch.cumsum( + torch.cat([torch.tensor([0], dtype=cu_seqlens.dtype), padded_lengths]), dim=0 + ) + + return input_ids_padded, labels_padded, cu_seqlens_padded + + +def generate_positional_ids_for_cp( + cu_seqlens: torch.Tensor, + divisibility_factor: int, + dtype: torch.dtype = torch.long, +) -> torch.Tensor: + """Generate positional IDs for sequences padded to be divisible by divisibility_factor. 
+
+    Args:
+        cu_seqlens: Tensor of shape (M,) containing cumulative sequence lengths
+        divisibility_factor: Each sequence length must be divisible by this factor
+        dtype: Data type for the generated positional IDs (default: torch.long)
+
+    Returns:
+        Generated positional_ids tensor where each sequence starts from 0 and continues through padding
+    """
+    # Compute the sequence lengths from cu_seqlens
+    seqlens = cu_seqlens[1:] - cu_seqlens[:-1]
+
+    # List: amount of padding needed for each sequence
+    padding_amounts = [
+        ((l.item() + divisibility_factor - 1) // divisibility_factor) * divisibility_factor
+        - l.item()
+        for l in seqlens
+    ]
+
+    # Generate positional IDs for each padded sequence (each starts from 0)
+    padded_lengths = seqlens + torch.tensor(padding_amounts, dtype=seqlens.dtype)
+    positional_ids = torch.cat(
+        [torch.arange(0, int(length), dtype=dtype) for length in padded_lengths]
+    )
+
+    return positional_ids
+
+
+def get_batch_on_this_cp_rank(
+    cu_seqlens_padded: torch.Tensor,
+    input_ids_padded: torch.Tensor,
+    labels_padded: torch.Tensor,
+    position_ids_padded: torch.Tensor,
+    cp_group: torch.distributed.ProcessGroup = None,
+    qvk_format: str = "thd",
+):
+    """Slice batch input along the sequence dimension into multiple chunks for THD format.
+
+    The chunks of each sequence are parallelized across GPUs in a context parallel group.
+    This function is intended for use in self-attention; it will not work for cross-attention
+    because it does not handle the case where the query and key sequence lengths differ.
+
+    This version works with variable-length sequences using cumulative sequence lengths.
+    """
+    if qvk_format not in ["thd", "bshd", "sbhd"]:
+        raise ValueError(f"Unsupported qvk_format: {qvk_format}!")
+    if qvk_format == "thd":
+        # Get context parallel size and rank
+        cp_size = torch.distributed.get_world_size(group=cp_group)
+        if cp_size > 1:
+            cp_rank = torch.distributed.get_rank(group=cp_group)
+
+            # Calculate the chunk sizes for each sequence
+            total_slices_of_any_sequence = 2 * cp_size
+            slice_sizes = (
+                cu_seqlens_padded[1:] - cu_seqlens_padded[:-1]
+            ) // total_slices_of_any_sequence
+
+            # Process each tensor directly instead of using keys_to_change loop
+            def process_tensor(val):
+                if val is None:
+                    return val
+                # Determine which dimension is the sequence dimension
+                # Ensure cu_seqlens_padded[-1] is a Python int, not a 0-dim tensor
+                if isinstance(cu_seqlens_padded[-1], torch.Tensor):
+                    seq_len_val = cu_seqlens_padded[-1].item()
+                else:
+                    seq_len_val = cu_seqlens_padded[-1]
+
+                # Handle 1D tensors (like position_ids that don't have batch dimension)
+                if val.ndim == 1:
+                    if val.shape[0] == seq_len_val:
+                        current_seq_dim = 0
+                    else:
+                        raise ValueError(
+                            "1D tensor shape doesn't match expected sequence length. Make sure the"
+                            " inputs are in THD format and padded correctly."
+                        )
+                elif val.ndim >= 2:
+                    if val.shape[1] == seq_len_val:
+                        current_seq_dim = 1
+                    elif val.shape[0] == seq_len_val:
+                        current_seq_dim = 0
+                    else:
+                        raise ValueError(
+                            "Make sure the inputs are in THD format and padded correctly."
+                        )
+                else:
+                    raise ValueError("Tensor must be at least 1D")
+
+                # On this particular rank, for each sequence, get two slices, one from the beginning
+                # and one from the end.
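+                # Pairing chunk `cp_rank` with chunk `2 * cp_size - 1 - cp_rank` is the usual
+                # zig-zag split for causal attention, so every rank receives a comparable amount
+                # of attention work across its two chunks.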
+ cp_rank_slices = [] + for slice_size, seq_start in zip(slice_sizes, cu_seqlens_padded[:-1]): + # 1st segment + cp_rank_slices.append( + torch.arange( + seq_start + (cp_rank * slice_size), + seq_start + ((cp_rank + 1) * slice_size), + device=val.device, + ) + ) + + # 2nd segment + cp_rank_slices.append( + torch.arange( + seq_start + ((total_slices_of_any_sequence - cp_rank - 1) * slice_size), + seq_start + ((total_slices_of_any_sequence - cp_rank) * slice_size), + device=val.device, + ) + ) + + return val.index_select(current_seq_dim, torch.cat(cp_rank_slices)) + + # Process each tensor directly + input_ids_padded = process_tensor(input_ids_padded) + labels_padded = process_tensor(labels_padded) + position_ids_padded = process_tensor(position_ids_padded) + else: + raise ValueError(f"Support not implemented yet for qvk_format: {qvk_format}!") + + return input_ids_padded, labels_padded, position_ids_padded From 405d474b39d0975a5c2a732d68e3ba9cfe28313b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Mon, 15 Sep 2025 09:29:20 +0200 Subject: [PATCH 136/153] [PyTorch Debug] Fix issue with negative underflow% stat. (#2107) * fix underflows log issue Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Pawel Gadzinski Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/debug/test_api_features.py | 12 +++++---- tests/pytorch/debug/test_log.py | 10 +++---- .../debug/features/utils/stats_computation.py | 26 +++++++++++++------ 3 files changed, 29 insertions(+), 19 deletions(-) diff --git a/tests/pytorch/debug/test_api_features.py b/tests/pytorch/debug/test_api_features.py index 974772599..d28db1647 100644 --- a/tests/pytorch/debug/test_api_features.py +++ b/tests/pytorch/debug/test_api_features.py @@ -268,7 +268,7 @@ def assert_empty(): )[0] expected_underflows = ( - ((tensor_fp8._data == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) + ((tensor_fp8.dequantize() == 0).sum() - (tensor == 0).sum()) * 100 / (100 * 100 * 5) ) assert debug_api.transformer_engine.inspect_tensor_enabled( @@ -302,7 +302,7 @@ def assert_empty(): )[0] # Second config in same yaml - tensor = torch.rand((100, 100, 5)) + tensor = torch.rand((100, 100, 5)).cuda() debug_api.transformer_engine.inspect_tensor( "decoder.6.mlp.fc1", tensor_name="activation", @@ -316,7 +316,9 @@ def assert_empty(): stats = log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in ["cur_amax", "dynamic_range", "mean", "std", "l1_norm"]) - assert stats[("decoder.6.mlp.fc1", "activation", "mean", 200)] == tensor.mean() + torch.testing.assert_close( + stats[("decoder.6.mlp.fc1", "activation", "mean", 200)], tensor.mean() + ) debug_api.transformer_engine.inspect_tensor( "decoder.7.mlp.fc1", @@ -331,7 +333,7 @@ def assert_empty(): stats = 
log() stats_names = [x[3] for x in stats.keys()] all(s in stats_names for s in ["mean", "std", "l1_norm", "min", "max"]) - assert stats[("decoder.7.mlp.fc1", "weight", "max", 200)] == tensor.max() + torch.testing.assert_close(stats[("decoder.7.mlp.fc1", "weight", "max", 200)], tensor.max()) assert not debug_api.transformer_engine.inspect_tensor_enabled( "decoder.7.mlp.fc1", tensor_name="weight", iteration=201 @@ -377,7 +379,7 @@ def fp8_tensor(t): return quantizer(t.cuda()) shape = [1024, 1024] - tensors = [torch.randn(shape) for _ in range(2)] + tensors = [torch.randn(shape).cuda() for _ in range(2)] tensors_fp8 = [fp8_tensor(tensors[i]) for i in range(2)] feed(tensors[0], tensors_fp8[0], quantizer) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index ca8e10ad6..dcc9861c8 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -167,8 +167,8 @@ def test_numerics(fp8_recipe, feature_dirs): num_quantizers=3, ) - tensor = torch.zeros(1024, 1024).cuda() - tensor[0, :] = 1000 + tensor = torch.randn(1024, 1024).cuda() + tensor[0, 100:200] = -0.0 quantizer = recipe_state.make_quantizers()[0] quantized_tensor = quantizer(tensor) @@ -191,15 +191,13 @@ def test_numerics(fp8_recipe, feature_dirs): if "underflows%" in line: underflows = float(line.split("value=")[1]) expected = ( - ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) - / dequantized_tensor.numel() - * 100 + ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100 ) assert underflows == pytest.approx(expected.cpu(), abs=1e-4) if "mse" in line: mse = float(line.split("value=")[1]) expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean") - assert mse == pytest.approx(expected.cpu(), abs=1e-6) + assert mse == pytest.approx(expected.cpu(), abs=1e-4) if "overflows%" in line: overflows = float(line.split("value=")[1]) expected = ( diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 3842ab1c5..2fa6985ac 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -199,6 +199,15 @@ def _get(buffers, stat_name): ), } +FP8_NEGATIVE_ZERO = 128 # represnts -0.0 in fp8 + + +def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor: + """Count the number of non-zero elements in the fp8 data.""" + fp8_data = fp8_data.view(dtype=torch.uint8) + zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8) + return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum() + def add_underflows_stats(recipe_name: str, columnwise: bool = False): """Register *both* underflow stats (num and %) for the given recipe.""" @@ -212,22 +221,23 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False): stats_to_num[stat_pct] = len(stats_to_num) STATS[stat_num] = ( - lambda x, aux_dict: ( + lambda x, aux_dict: x.count_nonzero() + - count_nonzero_fp8( aux_dict[recipe_name].get_data_tensors( rowwise_data=not columnwise, columnwise_data=columnwise ) - == 0 - ).sum() - - (x == 0).sum(), + ), lambda buffers, _sn=stat_num: sum(_get(buffers, _sn)), ) STATS[stat_pct] = ( lambda x, aux_dict: ( - aux_dict[recipe_name].get_data_tensors( - rowwise_data=not columnwise, columnwise_data=columnwise + x.count_nonzero() + - count_nonzero_fp8( + aux_dict[recipe_name].get_data_tensors( + rowwise_data=not columnwise, columnwise_data=columnwise + ) ) - == 0 
- ).sum() + ) / aux_dict[recipe_name].numel() * 100, lambda buffers, _sn_num=stat_num: 100 From cd2034f3f28ec07ef4feb18d469a393a9cd2596f Mon Sep 17 00:00:00 2001 From: Ming-Xu Huang Date: Mon, 15 Sep 2025 17:08:28 -0400 Subject: [PATCH 137/153] Lower precision gated-act to accelerate FP8 current-scaling. (#2153) * Applying the original precision as Norm outputs' and activation compuations. Signed-off-by: Ming Huang * Adding knob to control norm output precision. Signed-off-by: Ming Huang * Removing the knob and applying lower-precision norm with current-scaling only. Signed-off-by: Ming Huang * Fix the error when quantizer==None Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang --- tests/jax/test_custom_call_compute.py | 13 +++++++++++-- transformer_engine/jax/cpp_extensions/activation.py | 6 +++--- .../jax/cpp_extensions/normalization.py | 4 ++++ 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 11f07d913..9e39b84c0 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -465,14 +465,23 @@ def _test_norm_forward( x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_mu, ref_rsigma = _jax_layernorm( - x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + beta, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) else: output, rsigma = tex.rmsnorm_fwd( x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer ) ref_out, ref_rsigma = _jax_rmsnorm( - x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer + x, + gamma, + zero_centered_gamma, + epsilon, + quantizer=ref_quantizer, ) ref_mu = None diff --git a/transformer_engine/jax/cpp_extensions/activation.py b/transformer_engine/jax/cpp_extensions/activation.py index d3c7d2b08..cdda20166 100644 --- a/transformer_engine/jax/cpp_extensions/activation.py +++ b/transformer_engine/jax/cpp_extensions/activation.py @@ -1045,7 +1045,7 @@ def act_lu( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. out = act_lu( - x=x.astype(jnp.float32), + x=x, activation_type=activation_type, quantizer=None, ) @@ -1178,8 +1178,8 @@ def quantize_dact_dbias( if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: # Current scaling does not support fused operations. Perform dact in higher precision then quantize after. 
out = dact_lu( - dz=dz.astype(jnp.float32), - x=x.astype(jnp.float32), + dz=dz, + x=x, activation_type=activation_type, quantizer=None, ) diff --git a/transformer_engine/jax/cpp_extensions/normalization.py b/transformer_engine/jax/cpp_extensions/normalization.py index de1877de5..7a978c1b7 100644 --- a/transformer_engine/jax/cpp_extensions/normalization.py +++ b/transformer_engine/jax/cpp_extensions/normalization.py @@ -842,6 +842,8 @@ def _jax_layernorm(x, gamma, beta, zero_centered_gamma, epsilon, quantizer=None) output = normed_input * gamma + beta if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) @@ -867,6 +869,8 @@ def _jax_rmsnorm(x, gamma, zero_centered_gamma, epsilon, quantizer=None): output = normed_input * gamma if quantizer: + if quantizer.scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING: + output = output.astype(x.dtype) ln_out = quantizer.quantize(output, dq_dtype=x.dtype) else: ln_out = jnp.asarray(output).astype(x.dtype) From 59130cc9d0bd7cc66457556373d731ca0744cf9b Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Mon, 15 Sep 2025 16:12:12 -0700 Subject: [PATCH 138/153] [PyTorch] Support activation CPU offloading in fusible ops (#2158) * Add CPU offloading logic to ops. Fix test to compute dgrad. Signed-off-by: Tim Moon * Make sure grads are contiguous in op backwards Signed-off-by: Tim Moon * Add op-based MLP to CPU offloading tests Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Handle different weight cache behavior on Hopper/Blackwell Add MXFP8 to CPU offload tests. Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove MXFP8 test Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --------- Signed-off-by: Tim Moon Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_cpu_offloading.py | 196 +++++++++++------- transformer_engine/pytorch/ops/_common.py | 4 +- .../pytorch/ops/basic/activation.py | 3 + .../pytorch/ops/basic/basic_linear.py | 3 + .../pytorch/ops/basic/dropout.py | 3 + .../pytorch/ops/basic/l2normalization.py | 9 +- .../pytorch/ops/basic/layer_norm.py | 7 +- .../pytorch/ops/basic/rmsnorm.py | 7 +- .../fused/forward_linear_bias_activation.py | 13 +- .../ops/fused/forward_linear_bias_add.py | 15 +- .../ops/fused/forward_linear_scale_add.py | 5 +- .../ops/fused/userbuffers_forward_linear.py | 3 + 12 files changed, 174 insertions(+), 94 deletions(-) diff --git a/tests/pytorch/test_cpu_offloading.py b/tests/pytorch/test_cpu_offloading.py index 0b0732dfa..0e01f0b04 100644 --- a/tests/pytorch/test_cpu_offloading.py +++ b/tests/pytorch/test_cpu_offloading.py @@ -2,8 +2,11 @@ # # See LICENSE for license information. 
+import contextlib +import gc import os -from contextlib import nullcontext +from typing import Iterable, Optional + import pytest import torch @@ -11,15 +14,16 @@ from transformer_engine.common import recipe from transformer_engine.pytorch.fp8 import FP8GlobalStateManager from transformer_engine.pytorch.attention.dot_product_attention import _attention_backends +from transformer_engine.pytorch.utils import is_non_tn_fp8_gemm_supported from utils import ModelConfig, get_available_attention_backends -# Check if FP8 is supported +# Check supported quantization schemes fp8_available, _ = FP8GlobalStateManager.is_fp8_available() +mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() -fp8_recipes = [None] +quantization_recipes: Optional[recipe.Recipe] = [None] if fp8_available: - fp8_recipes.append(recipe.Float8CurrentScaling()) - fp8_recipes.append(recipe.DelayedScaling()) + quantization_recipes.extend((recipe.Float8CurrentScaling(), recipe.DelayedScaling())) model_config = { "small": ModelConfig(8, 512, 8, 64, num_layers=5, eps=0.1), @@ -48,85 +52,139 @@ "transformer_layer": lambda: te.TransformerLayer( SIZE, SIZE, NUM_HEADS, params_dtype=torch.bfloat16, hidden_dropout=0.0 ), + "linear_op": lambda: te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + "layernorm_mlp_ops": lambda: te.ops.Sequential( + te.ops.LayerNorm(SIZE, dtype=torch.bfloat16), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + te.ops.GELU(), + te.ops.Linear(SIZE, SIZE, dtype=torch.bfloat16), + ), } -def _get_input(): - return torch.empty((128, SIZE, SIZE), dtype=torch.bfloat16).cuda() +def _make_input() -> torch.Tensor: + """Generate random input tensor.""" + return torch.randn( + (128, SIZE, SIZE), + dtype=torch.bfloat16, + device="cuda", + requires_grad=True, + ) -def _get_fp8_weight_cache_size(models, fp8_recipe): - """ - Calculate the total FP8 weight cache size (in MB) for a list of models. 
- """ - if fp8_recipe is None: +def _warmup_model( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> None: + """Perform forward and backward pass""" + tensor = _make_input() + for module in modules: + with te.fp8_autocast( + enabled=quantization_recipe is not None, + fp8_recipe=quantization_recipe, + ): + tensor = module(tensor) + tensor.sum().backward() + + +def _estimate_cached_weight_size( + model_name: str, + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], +) -> float: + """Calculate the memory (in MiB) needed for weight caching.""" + + # The weight params are cached directly for unquantized compute + if quantization_recipe is None: return 0 - params_bytes = 0 - for model in models: - for name, param in model.named_parameters(): - if "weight" in name: - params_bytes += param.numel() + # Count number of weight param elements + param_elements = 0 + for module in modules: + for param in module.parameters(): + if param.dim() == 2: + param_elements += param.numel() + + # FP8 tensor-scaling caches one byte per element + if quantization_recipe.delayed() or quantization_recipe.float8_current_scaling(): + if not is_non_tn_fp8_gemm_supported() and model_name not in ( + "linear_op", + "layernorm_mlp_ops", + ): + # Modules do not deallocate FP8 transpose for weights + return 2 * param_elements / 1024**2 + return param_elements / 1024**2 + + # MXFP8 caches one data byte per element and one scale byte per 32 + # elements + if quantization_recipe.mxfp8(): + if model_name not in ("linear_op", "layernorm_mlp_ops"): + # Modules do not deallocate column-wise MXFP8 data for weights + return 2 * param_elements * (1 + 1 / 32) / 1024**2 + return param_elements * (1 + 1 / 32) / 1024**2 + + raise NotImplementedError(f"Unrecognized recipe ({quantization_recipe})") + + +def _measure_cached_memory( + modules: Iterable[torch.nn.Module], + quantization_recipe: Optional[recipe.Recipe], + cpu_offload: bool, +) -> float: + """Measure the growth in allocated GPU memory in MiB after a model forward pass. + + Memory measurement excludes the input and output tensors. 
- # One byte for columnwise and one byte for rowwise, - # hence multiply by 2 and convert to MB - # there is 1 byte of scale per 32 elements in mxFP8 - factor_for_scale_inv_tensor = (1 + 1 / 32) if fp8_recipe.mxfp8() else 1 - return (2 * params_bytes * factor_for_scale_inv_tensor) / (1024**2) + """ + # Reset memory + gc.collect() + torch.cuda.empty_cache() -def _measure_memory_between_forward_and_backward(models, fp8_recipe, cpu_offload): - tensor = _get_input() + # Context and sync function for CPU offloading if cpu_offload: offload_context, sync_function = te.get_cpu_offload_context( enabled=True, - num_layers=len(models) - 1, - model_layers=len(models), + num_layers=len(modules), + model_layers=len(modules) + 1, offload_activations=True, offload_weights=False, ) else: - offload_context = nullcontext() + offload_context = contextlib.nullcontext() sync_function = lambda x: x - for model in models: + # Forward pass, with dummy step to trigger offload for last module + inp = _make_input() + tensor = inp + memory_before_forward = torch.cuda.memory_allocated() / (1024**2) + for module in modules: with te.fp8_autocast( - enabled=fp8_recipe is not None, fp8_recipe=fp8_recipe + enabled=quantization_recipe is not None, fp8_recipe=quantization_recipe ), offload_context: - tensor = model(tensor) + tensor = module(tensor) tensor = sync_function(tensor) + with offload_context: + tensor = tensor.clone() + tensor = sync_function(tensor) + memory_after_forward = (torch.cuda.memory_allocated() - tensor.nbytes) / (1024**2) - max_mem_used = torch.cuda.memory_allocated() / (1024**2) - torch.cuda.synchronize() - + # Backward pass tensor.sum().backward() + torch.cuda.synchronize() - return max_mem_used + # Memory usage in MiB + return memory_after_forward - memory_before_forward -@pytest.mark.parametrize("fp8_recipe", fp8_recipes) -@pytest.mark.parametrize("model_key", model_types.keys()) -def test_cpu_offload(fp8_recipe, model_key) -> None: - """ - We run three configurations: - (1) No offloading: All activations remain on the GPU between forward and backward passes. - (2) No offloading (one layer): Only the first layer's activations remain on the GPU between - forward and backward passes. - (3) With offloading (all layers): Only the last layer's activations remain on the GPU - between forward and backward passes, while all other layers are offloaded to the CPU. - - We expect the memory consumption of configurations (2) and (3) to be similar, with - the difference being the size of the FP8 cache that is not offloaded to the CPU. - We also expect this memory consumption to be smaller than in scenario (1). 
- """ - import gc +@pytest.mark.parametrize("quantization_recipe", quantization_recipes) +@pytest.mark.parametrize("model_name", model_types.keys()) +def test_cpu_offload(quantization_recipe: Optional[recipe.Recipe], model_name: str) -> None: + """Check that CPU offloading runs and has expected memory usage.""" - gc.collect() - - model_cls = model_types[model_key] - models_list = [model_cls() for _ in range(NUM_LAYERS)] - - if model_key in ["multihead_attention", "transformer_layer"]: + # Construct model + modules_list = [model_types[model_name]() for _ in range(NUM_LAYERS)] + if model_name in ["multihead_attention", "transformer_layer"]: available_backends, *_ = get_available_attention_backends( model_config["small"], qkv_dtype=torch.bfloat16, @@ -138,20 +196,18 @@ def test_cpu_offload(fp8_recipe, model_key) -> None: os.environ["NVTE_FLASH_ATTN"] = "0" _attention_backends["backend_selection_requires_update"] = True - without_offloading = _measure_memory_between_forward_and_backward( - models_list, fp8_recipe, False - ) - without_offloading_one_layer = _measure_memory_between_forward_and_backward( - models_list[:1], fp8_recipe, False - ) - with_offloading = _measure_memory_between_forward_and_backward(models_list, fp8_recipe, True) + # Warmup + _warmup_model(modules_list, quantization_recipe) - assert with_offloading < without_offloading + # Measure cached memory after forward pass + memory_without_offload = _measure_cached_memory(modules_list, quantization_recipe, False) + memory_with_offload = _measure_cached_memory(modules_list, quantization_recipe, True) - # The only difference between the memory consumption of with_offloading - # and without_offloading_one_layer should be the size of the FP8 weights cache, - # which is not offloaded to the CPU. 
- memory_consumption_diff = abs(with_offloading - without_offloading_one_layer) - assert ( - memory_consumption_diff < _get_fp8_weight_cache_size(models_list[1:], fp8_recipe) + EPSILON + # Check for expected memory usage + assert memory_with_offload < memory_without_offload + memory_from_cached_weights = _estimate_cached_weight_size( + model_name, + modules_list, + quantization_recipe, ) + assert abs(memory_with_offload - memory_from_cached_weights) < EPSILON diff --git a/transformer_engine/pytorch/ops/_common.py b/transformer_engine/pytorch/ops/_common.py index 8e997428f..99bbc34c4 100644 --- a/transformer_engine/pytorch/ops/_common.py +++ b/transformer_engine/pytorch/ops/_common.py @@ -29,7 +29,9 @@ def maybe_dequantize( if is_quantized_tensor(tensor): return tensor.dequantize(dtype=dtype) if dtype is not None and tensor.dtype != dtype: - return tensor.to(dtype) + tensor = tensor.to(dtype) + if not tensor.is_contiguous(): + tensor = tensor.contiguous() return tensor diff --git a/transformer_engine/pytorch/ops/basic/activation.py b/transformer_engine/pytorch/ops/basic/activation.py index 5ef421bc1..22779b601 100644 --- a/transformer_engine/pytorch/ops/basic/activation.py +++ b/transformer_engine/pytorch/ops/basic/activation.py @@ -11,6 +11,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor.float8_tensor import Float8CurrentScalingQuantizer, Quantizer from ...utils import clear_tensor_data from ..op import BasicOperation, OperationContext @@ -110,6 +111,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x) ctx.save_for_backward(x) ctx.dtype = dtype ctx.prev_op_grad_output_quantizer = prev_op_grad_output_quantizer diff --git a/transformer_engine/pytorch/ops/basic/basic_linear.py b/transformer_engine/pytorch/ops/basic/basic_linear.py index 833633055..70c70c54d 100644 --- a/transformer_engine/pytorch/ops/basic/basic_linear.py +++ b/transformer_engine/pytorch/ops/basic/basic_linear.py @@ -13,6 +13,7 @@ import torch from ...cpp_extensions import general_gemm +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...distributed import ( CudaRNGStatesTracker, gather_along_first_dim, @@ -964,6 +965,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) ctx.save_for_backward(x_local, w) ctx.with_quantized_compute = with_quantized_compute ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/basic/dropout.py b/transformer_engine/pytorch/ops/basic/dropout.py index f0f55322c..30ccf5ebc 100644 --- a/transformer_engine/pytorch/ops/basic/dropout.py +++ b/transformer_engine/pytorch/ops/basic/dropout.py @@ -9,6 +9,7 @@ import torch import transformer_engine_torch as tex +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...tensor import Quantizer from ...tensor._internal.float8_tensor_base import Float8TensorBase from .._common import maybe_autocast_dtype, maybe_dequantize @@ -70,6 +71,8 @@ def op_forward( # Save context for backward if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(mask) ctx.save_for_backward(mask) ctx.impl = impl ctx.dropout_probability = self.dropout_probability diff --git a/transformer_engine/pytorch/ops/basic/l2normalization.py b/transformer_engine/pytorch/ops/basic/l2normalization.py index a340e7d42..440fee34d 
100644 --- a/transformer_engine/pytorch/ops/basic/l2normalization.py +++ b/transformer_engine/pytorch/ops/basic/l2normalization.py @@ -10,10 +10,8 @@ import torch -from ...utils import clear_tensor_data from ... import torch_version -from .._common import maybe_dequantize -from ..op import BasicOperation, OperationContext +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...jit import ( l2normalization_fused, l2normalization_fwd_fused, @@ -22,6 +20,9 @@ warmup_jit_l2normalization_all_dtypes, ) from ...tensor import Quantizer +from ...utils import clear_tensor_data +from ..op import BasicOperation, OperationContext +from .._common import maybe_dequantize class L2Normalization(BasicOperation): @@ -101,6 +102,8 @@ def op_forward( # Save state for backward pass if requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rsqrt_norm) ctx.save_for_backward(x, rsqrt_norm) return y diff --git a/transformer_engine/pytorch/ops/basic/layer_norm.py b/transformer_engine/pytorch/ops/basic/layer_norm.py index 3d8862e99..91e6de07d 100644 --- a/transformer_engine/pytorch/ops/basic/layer_norm.py +++ b/transformer_engine/pytorch/ops/basic/layer_norm.py @@ -14,6 +14,9 @@ from transformer_engine_torch import layernorm_bwd, layernorm_fwd from ...constants import TE_DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -22,8 +25,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class LayerNorm(BasicOperation): @@ -215,6 +216,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, means, rstdevs) ctx.save_for_backward(x, means, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/basic/rmsnorm.py b/transformer_engine/pytorch/ops/basic/rmsnorm.py index 42d3fc101..8c3f02974 100644 --- a/transformer_engine/pytorch/ops/basic/rmsnorm.py +++ b/transformer_engine/pytorch/ops/basic/rmsnorm.py @@ -14,6 +14,9 @@ from transformer_engine_torch import rmsnorm_bwd, rmsnorm_fwd from ...constants import TE_DType +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...export import is_in_onnx_export_mode +from ...tensor import Quantizer from ...utils import ( canonicalize_device, canonicalize_dtype, @@ -22,8 +25,6 @@ ) from ..op import BasicOperation, OperationContext from .._common import maybe_autocast_dtype, maybe_dequantize -from ...export import is_in_onnx_export_mode -from ...tensor import Quantizer class RMSNorm(BasicOperation): @@ -196,6 +197,8 @@ def op_forward( # Save state for backward pass if ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x, rstdevs) ctx.save_for_backward(x, rstdevs) ctx.dtype = dtype diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py index b87b12f84..02bcfee0a 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_activation.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import BasicLinear, Bias 
-from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager from ...tensor import Quantizer +from ..basic import BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasActivation(FusedOperation): @@ -121,6 +118,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py index dd59e602f..15cc081c1 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_bias_add.py @@ -10,14 +10,11 @@ import torch -from transformer_engine.pytorch.fp8 import FP8GlobalStateManager -from transformer_engine.pytorch.ops.basic import AddExtraInput, BasicLinear, Bias -from transformer_engine.pytorch.ops.op import ( - FusedOperation, - FusibleOperation, - OperationContext, -) -from transformer_engine.pytorch.tensor import Quantizer +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload +from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer +from ..basic import AddExtraInput, BasicLinear, Bias +from ..op import FusedOperation, FusibleOperation, OperationContext class ForwardLinearBiasAdd(FusedOperation): @@ -118,6 +115,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py index 448f72763..21190d4fc 100644 --- a/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py +++ b/transformer_engine/pytorch/ops/fused/forward_linear_scale_add.py @@ -10,14 +10,15 @@ import torch +from ...cpu_offload import is_cpu_offload_enabled, mark_activation_offload from ...fp8 import FP8GlobalStateManager +from ...tensor import Quantizer from ..basic import AddExtraInput, BasicLinear, ConstantScale from ..op import ( FusedOperation, FusibleOperation, OperationContext, ) -from ...tensor import Quantizer class ForwardLinearScaleAdd(FusedOperation): @@ -95,6 +96,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer diff --git a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py index 574642794..a604e57dc 100644 --- a/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py +++ b/transformer_engine/pytorch/ops/fused/userbuffers_forward_linear.py @@ -12,6 +12,7 @@ from transformer_engine_torch import CommOverlapType from ...cpp_extensions import general_gemm +from ...cpu_offload import 
is_cpu_offload_enabled, mark_activation_offload from ...distributed import get_distributed_world_size from ...fp8 import FP8GlobalStateManager from ...module.base import ( @@ -353,6 +354,8 @@ def fuser_forward( # Save state for backward pass if linear_op_ctx.requires_grad: + if is_cpu_offload_enabled(): + mark_activation_offload(x_local) linear_op_ctx.save_for_backward(x_local, w) linear_op_ctx.with_quantized_compute = with_quantized_compute linear_op_ctx.input_quantizer = input_quantizer From 258d084237dccef6d862d20eb2fd63c77315cb36 Mon Sep 17 00:00:00 2001 From: Jan Bielak Date: Tue, 16 Sep 2025 11:29:04 -0700 Subject: [PATCH 139/153] Do not use normalization forward + amax fusion if cuDNN backend is requested (#2174) * Do not use norm fwd + amax fusion if cudnn backend is requested Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Read envirornment vairable directly to avoid include error Signed-off-by: Jan Bielak --------- Signed-off-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- .../common/normalization/layernorm/ln_api.cpp | 3 ++- .../common/normalization/rmsnorm/rmsnorm_api.cpp | 3 ++- .../pytorch/csrc/extensions/normalization.cpp | 12 ++++++++---- 3 files changed, 12 insertions(+), 6 deletions(-) diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index af19300a9..398c0acbd 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -66,7 +66,8 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { - cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); } bool gamma_in_weight_dtype = false; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 1aae72e15..82e360ed6 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -52,7 +52,8 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { - cudnn_backend = false; // cuDNN does not currently support amax output for non quantized output + NVTE_CHECK(!cudnn_backend, + "cuDNN does not currently support amax output for non quantized output"); } bool training = diff --git a/transformer_engine/pytorch/csrc/extensions/normalization.cpp b/transformer_engine/pytorch/csrc/extensions/normalization.cpp index 59bac8fe5..c63f892ce 100644 --- a/transformer_engine/pytorch/csrc/extensions/normalization.cpp +++ b/transformer_engine/pytorch/csrc/extensions/normalization.cpp @@ -110,7 +110,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + 
!transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); std::tie(unquantized_out_cu, unquantized_out) = my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); @@ -145,7 +146,8 @@ std::vector layernorm_fwd(py::handle input, py::handle weight, Maybe // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); } else { @@ -290,7 +292,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w TensorWrapper unquantized_out_cu; py::object unquantized_out; if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); std::tie(unquantized_out_cu, unquantized_out) = my_quantizer_cs->create_hp_tensor_with_amax(size, out_dtype); @@ -325,7 +328,8 @@ std::vector rmsnorm_fwd(const py::handle &input, const py::handle &w // Quantize output if using unfused kernel if (force_unfused_kernel) { - if (IsFloat8CurrentScalingQuantizers(quantizer.ptr())) { + if (IsFloat8CurrentScalingQuantizers(quantizer.ptr()) && + !transformer_engine::getenv("NVTE_NORM_FWD_USE_CUDNN")) { auto my_quantizer_cs = dynamic_cast(my_quantizer.get()); my_quantizer_cs->quantize_with_amax(unquantized_out_cu, out_cu); } else { From c221909dba98182dcac7bd438edad30871639b33 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Wed, 17 Sep 2025 06:32:54 +1200 Subject: [PATCH 140/153] Fix unjoined comm stream in UB communicator (#2160) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index d90dd3abc..087493495 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -612,12 +612,16 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, recv_stream); + // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf for (auto stream : {send_stream, recv_stream}) { NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, stream)); - NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); - // We sync with the comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf NVTE_CHECK_CUDA(cudaStreamWaitEvent(_stream_comm, _stop_comm, 0)); } + + // Next we sync with the main stream + // We have to recapture an event off the comm stream to enable cuda graph capture otherwise the comm stream will be never be joined in the graph + NVTE_CHECK_CUDA(cudaEventRecord(_stop_comm, _stream_comm)); + NVTE_CHECK_CUDA(cudaStreamWaitEvent(stream_main, _stop_comm, 0)); } /*************************************************************************************************** 
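
The change to bulk_overlap_external_ag above follows a general CUDA pattern: record an event on each worker stream and make the internal comm stream wait on it, then record the event once more off the comm stream and make the main stream wait on that, so every side stream is eventually joined back into the main stream and a CUDA graph captured on the main stream sees no dangling streams. The sketch below illustrates only that ordering; it is a standalone toy program, not Transformer Engine code, and the stream and event names are invented for the example.

// Minimal sketch of the event-based stream-join pattern (illustrative only --
// not the Userbuffers implementation; all names are made up). Worker streams
// are joined into a comm stream, and the comm stream is then joined back into
// the main stream.
#include <cuda_runtime.h>
#include <cstdio>
#include <cstdlib>
#include <initializer_list>

#define CHECK_CUDA(call)                                                  \
  do {                                                                    \
    cudaError_t err_ = (call);                                            \
    if (err_ != cudaSuccess) {                                            \
      std::fprintf(stderr, "CUDA error: %s\n", cudaGetErrorString(err_)); \
      std::exit(EXIT_FAILURE);                                            \
    }                                                                     \
  } while (0)

int main() {
  cudaStream_t main_stream, send_stream, recv_stream, comm_stream;
  cudaEvent_t stop_comm;
  for (cudaStream_t *s : {&main_stream, &send_stream, &recv_stream, &comm_stream}) {
    CHECK_CUDA(cudaStreamCreate(s));
  }
  CHECK_CUDA(cudaEventCreateWithFlags(&stop_comm, cudaEventDisableTiming));

  // ... send/recv work would be enqueued on send_stream and recv_stream here ...

  // Step 1: make the comm stream wait for both worker streams.
  for (cudaStream_t s : {send_stream, recv_stream}) {
    CHECK_CUDA(cudaEventRecord(stop_comm, s));
    CHECK_CUDA(cudaStreamWaitEvent(comm_stream, stop_comm, 0));
  }

  // Step 2: re-record the event off the comm stream and make the main stream
  // wait on it, so the comm stream itself is joined back into the main stream.
  CHECK_CUDA(cudaEventRecord(stop_comm, comm_stream));
  CHECK_CUDA(cudaStreamWaitEvent(main_stream, stop_comm, 0));

  CHECK_CUDA(cudaStreamSynchronize(main_stream));

  for (cudaStream_t s : {send_stream, recv_stream, comm_stream, main_stream}) {
    CHECK_CUDA(cudaStreamDestroy(s));
  }
  CHECK_CUDA(cudaEventDestroy(stop_comm));
  return 0;
}

Without step 2 the main stream never waits on the comm stream, which is exactly the unjoined-stream condition the patch above fixes for graph capture.
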
From ba37529c273182c2ef192e7198ceac1ecfa78e20 Mon Sep 17 00:00:00 2001 From: vthumbe1503 Date: Tue, 16 Sep 2025 17:10:39 -0700 Subject: [PATCH 141/153] FP8 Output Quantization for GEMM (#2123) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Test working as I think it should work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * revert accidental change Signed-off-by: Varun Thumbe Restrict the number of cases for unfused quantization, some fp8->fp8 cases are handled by cublas Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe fix merge conflict Signed-off-by: Varun Thumbe bug: missed a } in the code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... 
Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak Signed-off-by: Varun Thumbe FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. 
Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> Signed-off-by: Varun Thumbe [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Signed-off-by: Varun Thumbe [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Signed-off-by: Varun Thumbe Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov Signed-off-by: Varun Thumbe [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: 
Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani Signed-off-by: Varun Thumbe feat: Add support for multiple quantization modes in the UB communicators (#2043) Signed-off-by: Varun Thumbe [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Varun Thumbe Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. 
Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Signed-off-by: Varun Thumbe Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> Signed-off-by: Varun Thumbe mxfp8 unfused quant support, refined unit test, remove unecessary quantization code Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe missed a quant code removal Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe minor bug fix Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Add cuBLASMp-backed GEMM-like API to TE common (#1824) * Pick up cuBLASMp during build Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Change lib order to fix link error Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Context creation, incomplete... Signed-off-by: Vladimir Cherepanov * Test fixure Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * A sanity AgGemm test, failing... Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Fix axes Signed-off-by: Vladimir Cherepanov * Take care of uneven distribution Signed-off-by: Vladimir Cherepanov * Use MPI to get position of local matrices Signed-off-by: Vladimir Cherepanov * Refactor Signed-off-by: Vladimir Cherepanov * Refactor & fixes Signed-off-by: Vladimir Cherepanov * Saving... Signed-off-by: Vladimir Cherepanov * Gemm-RS Signed-off-by: Vladimir Cherepanov * Gemm-AR, not working... Signed-off-by: Vladimir Cherepanov * Fixes Signed-off-by: Vladimir Cherepanov * Setting all-reduce epilogue for gemm-ar Signed-off-by: Vladimir Cherepanov * Use supported shapes for GEMM-AR Signed-off-by: Vladimir Cherepanov * Tweak tolerance Signed-off-by: Vladimir Cherepanov * First shot at fp8 Signed-off-by: Vladimir Cherepanov * Use TensorHolder in tests Signed-off-by: Vladimir Cherepanov * More test configs Signed-off-by: Vladimir Cherepanov * Support comm_sm_count Signed-off-by: Vladimir Cherepanov * Parametrize dtypes for A, B and D separately Signed-off-by: Vladimir Cherepanov * Tweak scaling Signed-off-by: Vladimir Cherepanov * Amax ptr Signed-off-by: Vladimir Cherepanov * Flags parity with cublas_gemm, saving... 
Signed-off-by: Vladimir Cherepanov * Cleanup Signed-off-by: Vladimir Cherepanov * Bias tests Signed-off-by: Vladimir Cherepanov * Fix bias test Signed-off-by: Vladimir Cherepanov * Aux, saving... Signed-off-by: Vladimir Cherepanov * aux_ld Signed-off-by: Vladimir Cherepanov * A fix Signed-off-by: Vladimir Cherepanov * Use test::Tensor Signed-off-by: Vladimir Cherepanov * Set scale inv Signed-off-by: Vladimir Cherepanov * Remove unsupported test configs Signed-off-by: Vladimir Cherepanov * Tweak tests Signed-off-by: Vladimir Cherepanov * Replace libcal with NCCL Signed-off-by: Vladimir Cherepanov * Add NVTX markers to API functions Signed-off-by: Vladimir Cherepanov * Tweak GemmAr tests Signed-off-by: Vladimir Cherepanov * More test config Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Vladimir Cherepanov * Fix merge fallout Signed-off-by: Vladimir Cherepanov * Remove MPI dependency, comment API, add algo parameter Signed-off-by: Vladimir Cherepanov * Fix nvshmem dependency Signed-off-by: Vladimir Cherepanov * Fix nvshmem build Signed-off-by: Vladimir Cherepanov * Excluse CommGemm tests from L0_cppunittest Signed-off-by: Vladimir Cherepanov * Add cpp_distributed sh file for CI Signed-off-by: Vladimir Cherepanov * Adapt tp TensorAllocator Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Skip GemmAr test on unsupported HW Signed-off-by: Vladimir Cherepanov * Oversibscribe is needed on some clusters Signed-off-by: Vladimir Cherepanov * Fix incomplete libcal removal Signed-off-by: Vladimir Cherepanov * Move CI tests to L1 Signed-off-by: Vladimir Cherepanov * Rename context to include NVTE prefix Signed-off-by: Vladimir Cherepanov * Remove leftover code Signed-off-by: Vladimir Cherepanov * NVTE_WITH_CUBLASMP off by default Signed-off-by: Vladimir Cherepanov * More detailed NVTE_CHECK diag Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Comment API Signed-off-by: Vladimir Cherepanov * Include stdbool header for legacy C compilers Signed-off-by: Vladimir Cherepanov * Remove now unused argument Signed-off-by: Vladimir Cherepanov * Abstract away cuBLASMp algo behind our own enum Signed-off-by: Vladimir Cherepanov * More detailed shape diag messages Signed-off-by: Vladimir Cherepanov * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/common/include/transformer_engine/comm_gemm.h Co-authored-by: Przemyslaw Tredak Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> * Add license Signed-off-by: Vladimir Cherepanov --------- Signed-off-by: Vladimir Cherepanov Signed-off-by: Vladimir Cherepanov <56651474+mk-61@users.noreply.github.com> Co-authored-by: Vladimir Cherepanov Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Przemyslaw Tredak FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086) * FP8 AllGather in FP8 GroupedGEMM 1. Support current scaling FP8 quantation with a given amax. 2. Support FP8 AG in fwd and BF16 RS in bwd. 3. The workflow is AR-max -> FP8 Quant -> FP8 AG -> FP8 GroupedGEMM. Signed-off-by: Ming Huang * Slightly refactor Signed-off-by: Ming Huang * Adding documents of new args. Signed-off-by: Ming Huang * Adding unit-tests. 
Signed-off-by: Ming Huang * Adding license. Signed-off-by: Ming Huang * Move unit-tests to L1. Signed-off-by: Ming Huang * Move quantizaer store/reset into FP8 only. Signed-off-by: Ming Huang * Adding all layout support for Blackwell+ Signed-off-by: Ming Huang * Adopt the feedback from code-review. Signed-off-by: Ming Huang * Fixed the wrong stream used by d2d in groupedGEMM FFI. Signed-off-by: Ming Huang --------- Signed-off-by: Ming Huang Co-authored-by: Phuong Nguyen [JAX] Delay MeshResource validation until first usage (#2124) Delay MeshResource validation until first usage Signed-off-by: Jeremy Berchtold Co-authored-by: Phuong Nguyen [JAX] Decouple Recipe and ScalingMode (#1728) * Decouple recipe and scaling mode Signed-off-by: Jeremy Berchtold * Expose global QuantizeConfig instance as a getter Signed-off-by: Jeremy Berchtold * Format and lint Signed-off-by: Jeremy Berchtold * Merge branch 'main' into dev/jberchtold/jax-scaling-mode-and-recipe-decoupling Signed-off-by: Jeremy Berchtold * Rename UsageType to TensorSource Signed-off-by: Jeremy Berchtold * Update test_layer.py Signed-off-by: Jeremy Berchtold --------- Signed-off-by: Jeremy Berchtold Signed-off-by: jberchtold-nvidia <158520091+jberchtold-nvidia@users.noreply.github.com> [JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128) * add dot_1_output sharding constraint + use AXIS_IS_UNSHARDED Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen [JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118) * add amax input to DBiasQuantizePrimitive and FFI Signed-off-by: Phuong Nguyen * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make sure amax is init with zero Signed-off-by: Phuong Nguyen * fix sharding rule Signed-off-by: Phuong Nguyen --------- Signed-off-by: Phuong Nguyen Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Further relax constraints to cuDNN 9.13 for disabling fused attn for kv caching (#2121) Signed-off-by: Kshitij Lakhani Temporarily remove comm_gemm tests (#2133) Signed-off-by: Vladimir Cherepanov [PyTorch] Disable determinism for sm100 (#2130) * disable determinism for sm100+ and cudnn<9.14 Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * fix remaining CI failures Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * revert some changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * revert more changes Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * remove sm100 from determinism table Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --------- Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch] ONNX export of FP8 Current Scaling (#2068) * Compute amax in normalization forward in current scaling in untuned kernels Signed-off-by: Jan Bielak * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * code drop Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from 
pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix Signed-off-by: Pawel Gadzinski * apply tims suggestions Signed-off-by: Pawel Gadzinski --------- Signed-off-by: Jan Bielak Signed-off-by: Pawel Gadzinski Co-authored-by: Jan Bielak Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> [PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for experts receiving zero tokens (#2134) use torch empty for empty shape instead of from_blob Signed-off-by: zhongboz Co-authored-by: Kirthi Shankar Sivamani build: pull cached wheels (#2127) * build: pull cached wheels Signed-off-by: oliver könig * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update setup.py Signed-off-by: oliver könig --------- Signed-off-by: oliver könig Co-authored-by: Kirthi Shankar Sivamani feat: Add support for multiple quantization modes in the UB communicators (#2043) [Common] Add checks to CUDA kernel launch and CUDA API calls (#2074) * add checks to cuda kernel launch and cuda API calls Signed-off-by: Xin Yao * Remove exceptions from destructors Signed-off-by: Tim Moon * fix weired dispatch in ln/rmsnorm Signed-off-by: Xin Yao --------- Signed-off-by: Xin Yao Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> [PyTorch] Support bf16+fp8 cudagraph (#2098) * support bf16+fp8 model Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update Signed-off-by: Robin Zhang --------- Signed-off-by: Robin Zhang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Dropout with 8-bit RNG (#2014) * Add dropout kernel with 8-bit RNG Co-authored-by: Vasudevan Rengasamy Co-authored-by: Tim Moon Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix license Signed-off-by: Tim Moon * Avoid ambiguous types Signed-off-by: Tim Moon * Do not enforce dropout prob is representable in 8 bits Signed-off-by: Tim Moon * Expand error message Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix small statistical bug from using less-equal instead of less-than Refactor kernel implementations and add comments. Interpret masks as bytes rather than 16-bit uints. 
Signed-off-by: Tim Moon * Fix linter warning Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Remove unnecessary helper function in PyTorch extensions Signed-off-by: Tim Moon * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Tim Moon Co-authored-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Create GPU reload buffers on main stream (#2131) * Create GPU relaod buffers on main stream Signed-off-by: Selvaraj Anandaraj * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fixed typo Signed-off-by: Selvaraj Anandaraj * Fixed typo Signed-off-by: Selvaraj Anandaraj --------- Signed-off-by: Selvaraj Anandaraj Signed-off-by: Selvaraj Anandaraj Co-authored-by: Selvaraj Anandaraj Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Selvaraj Anandaraj Co-authored-by: PaweÅ‚ GadziÅ„ski <62263673+pggPL@users.noreply.github.com> minor code cleanup Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor cosmetics Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Address review comment Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci minor comment update Signed-off-by: Varun Thumbe Fix CI failures for UB overlap changes (#2149) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> minor bug: quantizer should not be none for unfused quantization Signed-off-by: Varun Thumbe [JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100 (#2135) * Fix failing tests for dropout=0.1 and bias for fused attn for blackwell Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Fix the skip message Signed-off-by: Kshitij Lakhani * Assert in fused attn bwd pass for sm100 Signed-off-by: Kshitij Lakhani Add check for sm100 Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Add support to get all devs in the process for jax Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Code clean up Signed-off-by: Kshitij Lakhani * Make get_all_device_compute_capability more pythonic, thereby avoiding unnecessary type conversion Signed-off-by: Kshitij Lakhani * Represent attn bias using enum instead of string Signed-off-by: Kshitij Lakhani --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> fix linting error Signed-off-by: Varun Thumbe [PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Graph (#2119) * add noop to comp amax Signed-off-by: zhongboz * fix for fp8 blockwise recipe Signed-off-by: zhongboz * resolve comments Signed-off-by: zhongboz * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: zhongboz Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon 
<4406448+timmoon10@users.noreply.github.com> address review comments Signed-off-by: Varun Thumbe * Update test_multi_process_distributed_grouped_gemm.py change accidentally added while merging Signed-off-by: vthumbe1503 * Update dense.py change accidentally added while merging Signed-off-by: vthumbe1503 * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address review comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * address revie comments Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Bug solved: delayed scaling quantization with mxfp8 inputs didnt work Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the unit test error Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * just to trigger ci Signed-off-by: Varun Thumbe * address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe * fix merge conflict Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe address review comments: quantization inside gemm and outside both should exactly match for fp32 accumulation [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci Signed-off-by: Varun Thumbe [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Varun Thumbe Signed-off-by: vthumbe1503 Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- tests/pytorch/test_numerics.py | 76 ++++++++++++++++++- .../quantize_transpose_vector_blockwise.cu | 11 ++- .../pytorch/csrc/extensions/gemm.cpp | 53 ++++++++++--- transformer_engine/pytorch/csrc/quantizer.cpp | 49 +----------- 4 files changed, 125 insertions(+), 64 deletions(-) diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index e72067367..a50b3fbca 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -39,16 +39,21 @@ from transformer_engine.pytorch.distributed import checkpoint as te_checkpoint from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm from transformer_engine.pytorch.cpp_extensions.fused_attn import FusedAttnBackend -from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer +from transformer_engine.pytorch.tensor.float8_tensor import ( + Float8Quantizer, + Float8CurrentScalingQuantizer, +) +from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer from transformer_engine.pytorch.module.base import get_multi_stream_cublas_workspace, get_workspace from transformer_engine.pytorch.utils import get_device_compute_capability from transformer_engine.common import recipe import 
transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states, get_available_attention_backends + # Only run FP8 tests on supported devices. fp8_available, reason_for_no_fp8 = FP8GlobalStateManager.is_fp8_available() -mxfp8_available, _ = FP8GlobalStateManager.is_mxfp8_available() +mxfp8_available, reason_for_no_mxfp8 = FP8GlobalStateManager.is_mxfp8_available() fp8_block_scaling_available, _ = FP8GlobalStateManager.is_fp8_block_scaling_available() sm_80plus = get_device_compute_capability() >= (8, 0) @@ -2607,6 +2612,73 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): torch.testing.assert_close(o, o_ref, rtol=0, atol=0) +@pytest.mark.parametrize("N", [32]) +@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16]) +@pytest.mark.parametrize( + "input_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + ], +) +@pytest.mark.parametrize( + "out_quantizer", + [ + Float8CurrentScalingQuantizer(fp8_dtype=tex.DType.kFloat8E4M3, device="cuda"), + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3), + Float8Quantizer( + torch.ones(1).cuda().squeeze(), torch.ones(1).cuda().squeeze(), tex.DType.kFloat8E4M3 + ), + ], +) +def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_quantizer): + # For MXFP8 and CurrentScaling, below unfused quantization should happen + # FP8 input --> cublas GEMM --> BF16 output --> Quantize to FP8 --> fp8 Output + # Skip invalid configurations + is_mxfp8_needed = isinstance(input_quantizer, MXFP8Quantizer) or isinstance( + out_quantizer, MXFP8Quantizer + ) + if not fp8_available: + pytest.skip(reason_for_no_fp8) + if is_mxfp8_needed and not mxfp8_available: + pytest.skip(reason_for_no_mxfp8) + inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) + outp_type = torch.float32 + quantized_out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=out_quantizer, + bias=None, + use_split_accumulator=False, + ) + + out, *_ = general_gemm( + weight_fp8, + inp_fp8, + get_workspace(), + outp_type, + quantization_params=None, + bias=None, + use_split_accumulator=False, + ) + expected_quantized_out = out_quantizer(out) + + # Match results again Pytorch GEMM and allow for quantization tolerance + pytorch_out = torch.matmul( + inp_fp8.dequantize().to(torch.float64), + torch.transpose(weight_fp8.dequantize().to(torch.float64), 0, 1), + ) + fp8_tols = dict(rtol=0.125, atol=0.0675) + torch.testing.assert_close( + pytorch_out.to(outp_type), expected_quantized_out.dequantize(), **fp8_tols + ) + # Match results between quantization happening inside vs outside general_gemm + torch.testing.assert_close(expected_quantized_out.dequantize(), quantized_out.dequantize()) + + @pytest.mark.parametrize( "shape", [ diff --git a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu index 4c82b8c81..d38bf7996 100644 --- a/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu +++ b/transformer_engine/common/transpose/quantize_transpose_vector_blockwise.cu @@ -579,14 +579,19 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor "Input and output_t must have the same shape for columnwise non-transpose case."); } } - - NVTE_CHECK(output.dtype == output_t.dtype, 
"output and output_t need to have the same dtype."); + if (rowwise_option != FP8BlockwiseRowwiseOption::NONE) { + // output may not be defined if rowwise quantization is not needed. + NVTE_CHECK(output.dtype == output_t.dtype, + "output and output_t need to have the same dtype."); + } NVTE_CHECK(scale_inv_t.shape.size() == 2, "Scale_t dimension must be 2."); bool columnwise_compact = columnwise_option == FP8BlockwiseColumnwiseOption::COLUMNWISE_COMPACT; size_t scale_t_k = scale_inv_t.shape[1]; scale_t_stride_x = columnwise_compact ? 1 : scale_t_k; scale_t_stride_y = columnwise_compact ? scale_t_k : 1; } + auto output_dtype = + rowwise_option != FP8BlockwiseRowwiseOption::NONE ? output.dtype : output_t.dtype; const size_t num_blocks_x = DIVUP(row_length, (size_t)kTileDim); const size_t num_blocks_y = DIVUP(num_rows, (size_t)kTileDim); @@ -597,7 +602,7 @@ void quantize_transpose_vector_blockwise(const SimpleTensor& input, SimpleTensor input.dtype, InputType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( - output.dtype, OutputType, + output_dtype, OutputType, dim3 grid(num_blocks_x, num_blocks_y, 1); diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index f4768bb9b..485d67055 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -93,6 +93,8 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans bool use_split_accumulator, CommOverlapCore* comm_overlap, std::optional comm_type, MaybeTensor extra_output, bool bulk_overlap, float alpha, std::optional beta) { + using namespace transformer_engine::pytorch::detail; + // Input tensors NVTE_CHECK(!A.is_none(), "Tensor A has not been provided"); NVTE_CHECK(!B.is_none(), "Tensor B has not been provided"); @@ -123,10 +125,10 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans "into D tensor. Beta has nothing to be applied to."); } + DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); // Output tensor TensorWrapper D_tensor; if (D.is_none()) { - DType output_dtype = out_dtype ? *out_dtype : A_tensor.dtype(); std::tie(D_tensor, D) = createOutputTensor(D_shape, output_dtype, quantizer); } else { D_tensor = makeTransformerEngineTensor(D, quantizer); @@ -139,12 +141,35 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } + // maintain unquantized tensor in case we need unfused quantization support. + TensorWrapper unquantized_D_tensor; + py::object unquantized_out; + // Unfused quantization is needed in the following cases + // 1. Inputs: BF16, Output: FP8 (GEMM output has to be BF16, so FP8 quantization needed after that) + // 2. Inputs: FP8, Output: FP8 (For any quantization apart from delayed scaling, + // GEMM Output needs to be in BF16, to allow for unfused quantization) + bool unfused_quantization_needed = !quantizer.is_none(); + if (low_precision) { + // At the moment, only use-case for fused GEMM: + // Delayed scaling quantizer with per-tensor scaling inputs + bool is_per_tensor_scaling_input = IsFloat8Tensor(A.ptr()) || IsFloat8Tensor(B.ptr()); + if (IsFloat8Quantizers(quantizer.ptr()) && is_per_tensor_scaling_input) + unfused_quantization_needed = false; + } + + if (unfused_quantization_needed) { + NoneQuantizer q{none}; + std::tie(unquantized_D_tensor, unquantized_out) = q.create_tensor(D_shape, output_dtype); + } + TensorWrapper& out_tensor = unfused_quantization_needed ? 
unquantized_D_tensor : D_tensor; + // Bias tensor TensorWrapper bias_tensor; MaybeTensor bias_grad = std::nullopt; if (bias.has_value()) { if (grad) { - auto opts = torch::TensorOptions().dtype(GetATenDType(D_tensor.dtype())).device(torch::kCUDA); + auto opts = + torch::TensorOptions().dtype(GetATenDType(out_tensor.dtype())).device(torch::kCUDA); bias_grad = at::empty({static_cast(B_shape.data[B_shape.ndim - 1])}, opts); bias_tensor = makeTransformerEngineTensor(*bias_grad); } else { @@ -157,7 +182,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Activation input tensor MaybeTensor pre_gelu_out = std::nullopt; - DType gelu_type = low_precision ? bias_type : D_tensor.dtype(); + DType gelu_type = low_precision ? bias_type : out_tensor.dtype(); if (gelu) { if (!grad) { auto dtype = GetATenDType(gelu_type); @@ -210,7 +235,7 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans // Direct GEMM call to the correct overlap if (bulk_overlap) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, D_tensor, bias_tensor, + comm_overlap->bulk_overlap(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, comm_type.value(), extra_output_tensor, main_stream); @@ -218,14 +243,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else if (comm_type.value() == CommOverlapType::AG) { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_ag(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -234,14 +259,14 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { if (comm_overlap->is_atomic_gemm()) { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->atomic_gemm_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); }); } else { NVTE_SCOPED_GIL_RELEASE({ - comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, D_tensor, + comm_overlap->split_overlap_rs(A_tensor, transa, B_tensor, transb, out_tensor, bias_tensor, te_pre_gelu_out, te_workspace, grad, accumulate, use_split_accumulator, extra_output_tensor, main_stream); @@ -251,15 +276,15 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } else { // Launch GEMM NVTE_SCOPED_GIL_RELEASE({ - nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), D_tensor.data(), + nvte_cublas_gemm_scaled(A_tensor.data(), B_tensor.data(), out_tensor.data(), bias_tensor.data(), te_pre_gelu_out.data(), transa, transb, grad, te_workspace.data(), alpha, *beta, use_split_accumulator, num_math_sms, main_stream); }); } } else { - if (D_tensor.numel() != 0 && !accumulate) { - D_tensor.zero_(main_stream); + if (out_tensor.numel() != 0 && !accumulate) { + out_tensor.zero_(main_stream); } if 
(bias.has_value()) { if (bias->numel() != 0 && grad) { @@ -267,7 +292,11 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans } } } - + if (unfused_quantization_needed) { + // Quantize the output + std::unique_ptr my_quantizer = convert_quantizer(quantizer); + my_quantizer->quantize(unquantized_D_tensor, D_tensor); + } // Pack outputs std::vector out; out.emplace_back(std::move(D)); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index c690cd522..cd7e70fec 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -96,16 +96,6 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const { at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8Quantizer::create_tensor( @@ -318,17 +308,6 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso at::TensorOptions opts = opts.dtype(torch::kFloat32).device(torch::kCUDA); tensor->set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), getTensorShape(amax)); - // quantize output and its transpose - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); } std::pair Float8CurrentScalingQuantizer::create_tensor( @@ -562,20 +541,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti this->all_gather_usage = quantizer.attr("all_gather_usage").cast(); } -void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const { - // Change the rowwise and columnwise_data to the configured dtype. - // May be a switch between E5M2 and E4M3. 
- auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair Float8BlockQuantizer::create_tensor( const std::vector& shape, DType dtype) const { @@ -917,18 +883,7 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize this->dtype = quantizer.attr("dtype").cast(); } -void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const { - auto rowwise_data = tensor->get_rowwise_data(); - rowwise_data.dtype = static_cast(dtype); - - auto columnwise_data = tensor->get_columnwise_data(); - columnwise_data.dtype = static_cast(dtype); - - tensor->set_rowwise_data(rowwise_data.data_ptr, static_cast(rowwise_data.dtype), - rowwise_data.shape); - tensor->set_columnwise_data(columnwise_data.data_ptr, static_cast(columnwise_data.dtype), - columnwise_data.shape); -} +void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {} std::pair MXFP8Quantizer::create_tensor(const std::vector& shape, DType dtype) const { From 7042d7ae6daab0624e3bf7412e276d61be8283f6 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Tue, 16 Sep 2025 22:30:24 -0700 Subject: [PATCH 142/153] TE Gemma tutorial attempt#2 (#1839) * add tutorial files and other local changes Signed-off-by: Sudhakar Singh * remove extraneous code for easy debu Signed-off-by: Sudhakar Singh * make cuda graphs work with non-paged and paged attention Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * perf imp for kv cache ops Signed-off-by: Sudhakar Singh * add code for calibration Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * optimize kv_cache reindex and copy kernels Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * changes to make quantizers work with fp8_calibration Signed-off-by: Sudhakar Singh * avoid reindexing from python side Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * rename variable from previous commit Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * minor fix Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> * use quantizer only if needed Signed-off-by: Sudhakar Singh * functionality of the tutorial tested and perf checked Signed-off-by: Sudhakar Singh * remove files and update headers/licenses Signed-off-by: Sudhakar Singh * update header/license Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * update tutorial for review Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make weights downloadable on the fly; remove extra print statements Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com 
hooks for more information, see https://pre-commit.ci * fix lint and update comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add comma back, typo Signed-off-by: Sudhakar Singh * sequence_start_positions should be None for training Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add paged attention numberes and update requirements.txt file Signed-off-by: Sudhakar Singh * more fixes Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * make tutorial work on blackwell Signed-off-by: Sudhakar Singh * remove gemma FT tutorial for now Signed-off-by: Sudhakar Singh * fixing the headings placement and rewording attention -> kv caching Signed-off-by: Sudhakar Singh * fixes from comments Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * fix the images Signed-off-by: Sudhakar Singh * misc fixes Signed-off-by: Sudhakar Singh * add more comments to te_gemma.py and cleanup utils.py Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add more information about the hierarchy of the classes used in the tutorial Signed-off-by: Sudhakar Singh * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * add better cuda graphs picture Signed-off-by: Sudhakar Singh * addd updated cuda graphs pictures Signed-off-by: Sudhakar Singh * add illustrated cuda graphs Signed-off-by: Sudhakar Singh * fix Signed-off-by: Sudhakar Singh * small fixes in documentation Signed-off-by: Sudhakar Singh * add torch.no_grad() to force reduced memory usage Signed-off-by: Sudhakar Singh * some fixes from recent comments Signed-off-by: Sudhakar Singh * more fixes from remaining comments Signed-off-by: Sudhakar Singh * add te_rope_emb to class desc Signed-off-by: Sudhakar Singh * fix tutorial wording; add calibration fix to grouped_linear.py Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh Signed-off-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com> --- docs/examples/te_gemma/media/calibration.svg | 620 ++++++++++++ .../te_gemma/media/calibration_1_half.svg | 415 ++++++++ .../te_gemma/media/calibration_2_half.svg | 401 ++++++++ .../te_gemma/media/fp8_model_init.svg | 500 ++++++++++ .../te_gemma/media/fp8_model_init_1_half.svg | 358 +++++++ .../te_gemma/media/fp8_model_init_2_half.svg | 371 +++++++ .../te_gemma/media/generation_animation.gif | Bin 0 -> 135280 bytes docs/examples/te_gemma/media/graphs.svg | 232 +++++ .../media/transformer_cuda_graphed.png | Bin 0 -> 369694 bytes docs/examples/te_gemma/requirements.txt | 4 + docs/examples/te_gemma/te_gemma.py | 703 +++++++++++++ .../te_gemma/te_gemma_loading_weights.py | 189 ++++ .../tutorial_generation_gemma_with_te.ipynb | 941 ++++++++++++++++++ docs/examples/te_gemma/utils.py | 370 +++++++ ...tutorial_accelerate_hf_llama_with_te.ipynb | 2 +- docs/index.rst | 1 + .../pytorch/attention/inference.py | 28 +- .../pytorch/attention/multi_head_attention.py | 24 +- .../pytorch/csrc/extensions/apply_rope.cpp | 3 +- 
 .../pytorch/module/grouped_linear.py          |  2 +-
 .../pytorch/module/layernorm_linear.py        |  2 +-
 .../pytorch/module/layernorm_mlp.py           | 17 +-
 transformer_engine/pytorch/module/linear.py   |  2 +-
 23 files changed, 5152 insertions(+), 33 deletions(-)
 create mode 100644 docs/examples/te_gemma/media/calibration.svg
 create mode 100755 docs/examples/te_gemma/media/calibration_1_half.svg
 create mode 100644 docs/examples/te_gemma/media/calibration_2_half.svg
 create mode 100644 docs/examples/te_gemma/media/fp8_model_init.svg
 create mode 100644 docs/examples/te_gemma/media/fp8_model_init_1_half.svg
 create mode 100644 docs/examples/te_gemma/media/fp8_model_init_2_half.svg
 create mode 100644 docs/examples/te_gemma/media/generation_animation.gif
 create mode 100644 docs/examples/te_gemma/media/graphs.svg
 create mode 100644 docs/examples/te_gemma/media/transformer_cuda_graphed.png
 create mode 100755 docs/examples/te_gemma/requirements.txt
 create mode 100755 docs/examples/te_gemma/te_gemma.py
 create mode 100755 docs/examples/te_gemma/te_gemma_loading_weights.py
 create mode 100755 docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb
 create mode 100755 docs/examples/te_gemma/utils.py

diff --git a/docs/examples/te_gemma/media/calibration.svg b/docs/examples/te_gemma/media/calibration.svg
new file mode 100644
index 000000000..16e1a4314
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration.svg
@@ -0,0 +1,620 @@
[SVG diagram, panels: "FP8 with initial scaling factors", "Calibration", "FP8 with calibrated scaling factors"; block labels: High precision weight, Initial/Calibrated FP8 scaling factors, FP8 Weight, FP8 Input, High precision input, FP8 GEMM, High precision GEMM]
diff --git a/docs/examples/te_gemma/media/calibration_1_half.svg b/docs/examples/te_gemma/media/calibration_1_half.svg
new file mode 100755
index 000000000..478604d41
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration_1_half.svg
@@ -0,0 +1,415 @@
[SVG diagram, panels: "FP8 with initial scaling factors", "Calibration"; block labels: High precision weight, Initial FP8 scaling factors, FP8 Weight, FP8 Input, High precision input, FP8 GEMM, High precision GEMM]
diff --git a/docs/examples/te_gemma/media/calibration_2_half.svg b/docs/examples/te_gemma/media/calibration_2_half.svg
new file mode 100644
index 000000000..439f4c16f
--- /dev/null
+++ b/docs/examples/te_gemma/media/calibration_2_half.svg
@@ -0,0 +1,401 @@
[SVG diagram, panels: "Calibration", "FP8 with calibrated scaling factors"; block labels: High precision weight, FP8 scaling factors, Calibrated FP8 scaling factors, FP8 Weight, FP8 Input, High precision input, FP8 GEMM, High precision GEMM]
diff --git a/docs/examples/te_gemma/media/fp8_model_init.svg b/docs/examples/te_gemma/media/fp8_model_init.svg
new file mode 100644
index 000000000..57af23dc3
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init.svg
@@ -0,0 +1,500 @@
[SVG diagram, panels: "FP32/BF16", "FP8", "FP8 with fp8_model_init()"; block labels: High precision weight, FP8 weight, FP8 Weight, FP8 Input, High precision input, High precision GEMM, FP8 GEMM]
diff --git a/docs/examples/te_gemma/media/fp8_model_init_1_half.svg b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg
new file mode 100644
index 000000000..d86751e07
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init_1_half.svg
@@ -0,0 +1,358 @@
[SVG diagram, panels: "FP32/BF16", "FP8"; block labels: High precision weight, High precision input, High precision GEMM, FP8 Weight, FP8 Input, FP8 GEMM]
diff --git a/docs/examples/te_gemma/media/fp8_model_init_2_half.svg b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg
new file mode 100644
index 000000000..c3e4146ba
--- /dev/null
+++ b/docs/examples/te_gemma/media/fp8_model_init_2_half.svg
@@ -0,0 +1,371 @@
[SVG diagram, panels: "FP8", "FP8 with fp8_model_init()"; block labels: FP8 weight, High precision weight, FP8 Weight, FP8 Input, High precision input, FP8 GEMM]
diff --git a/docs/examples/te_gemma/media/generation_animation.gif b/docs/examples/te_gemma/media/generation_animation.gif
new file mode 100644
index 0000000000000000000000000000000000000000..25150cb9b64162084b017442a3905c57127c6713
GIT binary patch
literal 135280
[base85-encoded binary GIF data omitted]
z=6p%=K_@$-Q?>F_Udv4};St!V$eEPEo<>qs8eeCrBtK#JMQ#BWDIbG|fIKtynwbJ10D=I0L4H zFcC-Y(sF$mfVU`R=x(G=qti*)ldhDwFlN(EHW1O!^rHbd$spn_I+GNYmHR$x1DsvP znT@-N?8#4h8vcwfL^*v$DW^X?%P~9yFl2*^v)Bo^7*f-(b7is2BENuh-)&?HHE{}i zWQcBNVR`7^w9Yy+W0fh;l&Y_ty$a1%-po~v;Jg=seAbkqpzJxfp&=EK z&61R#UY-9!C*R5=*O*Sv)FYqPy5KQgCN_cUgW2q1bb%7I&=_23##NY;8c^`)G!LgF z_s#9hH)46u35xtT^F$E2u1z^19*@K5kdpmH(aMEf5}8j26EaUylQx0zYtxie&H#d9 zB%PjZM4>fZVnkDBeo6kZaOvsG(zP$~aLdx#&G@F-(i7I4+SJl-JsAOwMQ?OC2{(+e z!SW47WkmIN&n+W+v8=8*$|p+h-UgfXNfd?`MA^Go*K$M@h(vx&ExImW!PF>NiYk{& zeP+|JTA?&$vpZ|^ZPR97$L5Db<)L*YTE}#IlS@0=iZN2daNE9!<))dy77#Bn9@ zSrsY$8$c(b(0xNC=J|#O`G&6f4a3|Urn5IJx=e>OLO;Z+#csv2c~&dIs*h-@-Nj)E ztM>(L426bDg>1y}5H%^Q)fEBpWzCiqYZ6A}C*|v@6^%4;vMRQz{WbpNcUk34ZsFFk ziPSwKtdoyaeuAufT3Yvl-b7OpTTi#nxU|ktw=4o($udyu8*j^lg~7*Ju|Fkf&z<9l ztg_pR=LNfLmDJlHgYi+p9AQ2N{XuppKc>6E;U2|aNOsF2p8(H*2ge?dXo^3Tpdvll zqlbchB~eydC^Z^ihtu?4k4D${Mt>E5*c^~e@e|j|DSFftkiXS%2KU7iYsSS7F7-r} z+e9Q2HoKTK#mxm5mwH|HHHO&~hio)EW45p!H`Ycrd*XX1B2fdSEsmWn)+x=m6k!|u zuF{^d6&iDzoy*x&MbK$`vy-l@)fT@MlP7pdrIl>B)8bR7d)V6s`i=x-3F&;BnABUh zb#J!$#!Hn>!Sk-5&8GXJ-DWvm^4GcyQQZ!xw%5=ub$=AoJ?z&qDrTE? zRoV2Ti5m5=*rH_hAZg?2ew&(ZHFt=y)4C1OlUcueEEj$C|USu z399xZ&rsBAd4`+F1JAMh$T42r3MZSmS&1=Y!6hVVE2wLlI z^^TKE*HJ`|hg=t>X?;g*XJTp7i#fuES87Hs{@%{=J$|G+-cS{fUGxp|i}wPI6N1tc zcl9SkyeGucCnU&QC!`i8WG*J;7$@&bPd?P2lqdI|R7{^#Zk<$Jm{hx%RA>C~j9mJ| z3;hoo-XFBmKj^f6&|COmz;^M$h;hn9dg_%lD_9-FjFFEjH+C5JofaJUSkuyH)gU^t{kO>%52Q{OHBJ*TnpU9WNzl0eWj; zI(lI?T9j&hVR|29$y@X~%i`4i#h&QJ4>^mTE{j6j(!4v23*UH`ydTW|xZq*#5XF|8 zIr++S)+)+qguW_Y^8C7l)5e3*hQ`iV9(q3$5WRed*oTMI0DZlU2f_G}x`Kz+{v&-1 zPkAev{PJTb`6p->`nC@bTZ||j48&>A!|fwlD|GRxi+lwnuyU7)MLB%g9+ka^d_%dAnMOO{9_4P}0 z=a=uit4)GSb9Sp128$rjlC=*vp~31M2YYU(7|{#qHOJAV?#ruyjY8w?MD+w_hr8yfRqZ`F0E4f6QWj8lp zs%_>~uN1~?=HzS&l`(C``K**)Zf1U(lULg+=)|xcT|lXEJ#Sm6S=>VIZ!I%!cV=*Q z+b;tBdied-j@E77VFOOf7*TtsubF~hSM9!L2Yj`w_?p%E^~26rB<9Yi9nL^v&@F?V zWWjapy$#+?AK>@(-3rccJ39wzoZ%TeN5q`bKA#SOf<4m~l6WuPWaCn8FUe#NBDBYj zf5nlpxR;Q#;Ys{0`sOBs+_$(-n>4$7Z$?+>D>>+{zD4b?3=@6jyLrXVkcH9UvmX(^ zZD+sFGse#6uyDthgKun0*nwlzU|T$vgP>|lVClfU^Xtdkyx&9a@5mc+C}w^Stp2Xb z%)ty=-M#q!*mvm#KgV{(u4XLzkB4nP{K$88c7M3WuV3*Uy1w5fWInv_hhgn-XdG}D zZ*a(}juF3eXqU4WpRt2t+;NL#BQeBw+WF)?#+JIXgvs&)BnOfR+-1vAn+3x}?b?rx zZ*B%$&E4Za(a+h$2Q5(hoj6plILUm=$Yc|#5UmzDwKv&{&&4PrVQ=CSr&e4#wSI3# z0|POTZn?VvuN7>R>YsCKFq#e7FgsQgT&CZWu;Ke1Yk8b&H@#l>-Bf38GhF54`|ah{@3+-pe|tWkG#P3yDwfR{v$6ibx0+u=??!(UO*d)Xhi=x zFYu>&{rowDw7z$`+#mV+OD0)E-^}(u^7Rl>Utk0icoQI`X-Nk9 z^}H{zjl0woJ68bxKRa6Gd+^?vpRzNTaa>VxR|PlbVg{U?7{oaUG3G|sOEw%G95 zVD8T5Bgl{Bd*`&Tf5{8{lCS?zs@)(@)KOyuq9$;uhj`PwKSu-o`d|A3Kl}A zLkQ7b=r4VNVoi3L4yfDLIbEHfeSzb>lt#FyZ`&W2`J0_KyPK_Iy#BvrO&Gx>t=3tU=rW^ znXhNQ?#GM$@@Kw2_JjmUg5ROi%iucylp4e9^Ur)e&tM;O=+At;1y?mtnfD@Bj!Ur> z9OM!Q^7T2Ki0)tV^$J5kEh&+~@DcsY3nb-CJO|?S zFBPBLQSwY_346Z;>h&*QPwKiZd>GS%UQlX@coQYhm^3p?&Adu_89!x~?(L@VDm$Gg zdG5+plAj7(EI*@buELXa*PF%-ayNr-XUy_XDPz!!0=u+&H zjFSks1CvKcXl#ZDG_&#+)0t1#GprCpgETNI z)d*L-sDRLoG^q8|2ww?ONR~_*Y&UB}z^2__cDIf+Bt^&gniWm?^(oTOC#lB7(6AiF zL((wDCS%g@qFiPwcCv6Voe4R@vYb_fEP~3zgrc}V>y`#tqZt8e_}$z zHV~F_Cz&kTTilcm%>n1ECyPnYHf7jYPZFIXiw!}TGM(TjNglG3!M>lGvVg6UrK!l{ zu;^diq~%O{AVMB5q5bMMCn8ZngFN95;uRa4I5<(ojXaS9{fa{gn(!o)sRFf)y}0Q&AS%20nj=n2pR7p)4^K0~SBHCW|#FOC2EA8jU)^*Vz zsO{@vbnr=PHoUw)-G0Su>5aoM0-~nL!j=~hB57In;$ob6+hgk(v;j6Iz;fxNmUjpd zZZRNs7i(k2brR!clrqtYnCYJ9af;9lH=)vW>EV~pk2JeBLNhIvc$cx)NzkFiRPAu4 zFA3Wj#wq^l=7VQ_3Oo66NooDJ8AWN~sv&|mGNry~DWeCy%bb(7!_D|s-VGVXn5HoH zTvHH)BcgTXg*e7qRs~DYBM4iUw72*c64dBX26uVqj33*tenyg9aPLxHNv+Ef%#J_i zd7X0-#%L*j%jsQgMQ$!DW4EH_>@d9db*gadhN%1O#H&2U*F`q?*3aCXCJRHdi(f|f 
zyl9XJE(>v`FqF0-A7q@JQY~|>kVM%Sn2t}kR~S_VGv3pqrVs0u7r}~UY~5zynnUB* zy470Y+gW2d&Q9BB*VU!V+Pfotmbh`<&kKUmo#2T|t1|A26XiRu&C?24jz`Ub^bXJo z`nfr85#Vzo72E4(vV?=cKWzi ze38s4)|cV)4FIz%88|0oGJN4A)jS#0UUY`H1KY*$lVK|ZG8g0$!tzNvBVz8W&wFeHpinsn2Ex z3`pH-G8w;GX`aoZ7u_1#8FxsW;Ait22JX#EjJwe3vqjRAMfdhA#yx-P^JTChq>Gt^ z>07ww`6t>XNS`dzejJ?S{4=K^bjXnD05N^OD!xPt9rI=So=ttRu4L#jk;(Lfs95u2 zQ+vr{x}E9p4V>g++se>$Zi(rrar$Bhx3+48l?e;lV(JB;z0wQjF{B7}PKf`f&_q5>m`t>X_^> zOzv?^eo`zU5_K$57?$KXmMkgu19fZ#Vi>l{IQA1#oagE|nlKy^-Eka4Qe0DYTniYk z%{VTJ11X-1I-WZW&vP8lmlU5UP#r%Mh95bOA4>{O0IP%HFtF)uAZh#SM8xfXNEv{U zG3!1Cs2MZ%!@F}SSL>PQ%Q+NxBxc&G0P_N{ zHUNDCusHzv0WbBVf^in;Ks}SKPe;d z|4l^vzln&yr(^yPqzucyN*QaE9epgnlw(5YzYSo4e@Yn*+xtU2+)ww1??$#f91$Dy zav7CEmUaK;WKi?o9aMadaqwQXLA3*LG8Vp1s_@_Y@qt%2`p1+C#qw|0IYw9Pb|Mtq z^jW`L=eW&1c*rY?p=vv_KM%}53NJp}dMv!8p1^y&Y@BHaS>R&M`sFqUvf_j#(><4{ zS>g6Mp3wMY)sYb6cp|}tPh!=}H@ah^sIcO6vzU4B$H%H$Z}wLJALDp1bIks1Cn_Q4 zh_nUI*KgMai~oG5d+6fFw`O1^>fpipsMqeh(zND5dDV8YooT`>>96xxd>2Pip@vP< zpCxyszK_LLN|UXQwK^XTYw~xjVGkQVyf|E`Z2#65nE6AdLFPtYrQfz>+F1d6$5}Xu z=h+G~_!GMkP|SIHYeuhp~PJTUm#sDIGPuq=6LM^mvbGC|gMEVLqKUSyw!k zaun8kUN}ZK4ism!I`h)U1Me%|n7qX6puznX0=Z>E6b*TZ*R~fR^7do)9m~h`#JLN4eWLyv*3CCX7H9qX^UW)W0w(^xXAnW8wDPF@Pn4dyT}KSmu1SwI_8QkarElMT_31s{m`6YY;zM4(U?EzC&&LJP_jzqz zOT1MT_qEB+YrT61#bfTucQXiC*J%4JM}yQRGXsZLAMqwX?f)p`yN`WKTHIsYJC!1j z@x!_bZ}>CSlGphZHk?0XlqX)_e35n|pY55;=Cd5!io$Yi4t0LE$sRk3*gIx<51a;H zl!5UX0e*E$;+4XFO5@QAnZS6$=7W>v*DUX$VBJNFl=exkx<_+7lt2r z**v|m`LxjFC4#m5w!w!+&62|(p1va`3~S1o>2%5E`ad?!sJW#gMcj<77#CGAwlrh{ z$g@R*xwW4cYZWEGtjzS*vb;5?QJyMNlhewhm1M1%5%VMS4bi6QlBmu=NPM*uudOv6 za%Rz41d>6t^VVEQW@X>L`iNGIUPwtqD~_<0C#}-sbtP9QX3J zNBk=}+y(*VJQ&9#l|Dy9OBaiGnx}<#q1ZzIpwmU!HFtGr7fY>_&yP1eHLd)9@z_74cq;7(@~ldW8@Ewa>hZ zc^T&QiVcYvke4}#|5mt)lh-(6&m@?J_F5I~LXTQd31QL5`AHu^$NY$Ev+zbNWCF#9 zP1Ih0OUzoAe-QmH#{N6`P2L+D4|Gon5;gAT2}QRnxin3{iR+3K9&bKFo=z62)s<+D zWNvCSpMI#4xj$yk2tvF|wUBHU9(69aN#!-}?!a{``xgAEfHQkWb%L(?OHYUA^zI9H?f%a`K-*~PuWU1W>#7wfUm5`)0$ zy3O5%4HxKCab9ZyLUd{E938Ox4-guv9QR8Lc-S-gimMqE9ROHG&*b?45njG znE2Wy50+k0xAcalx^(T%H&%>ruP_rIcF8VopYcd8zC`}WcUW+ZT=E*1^gU@GlHPmG z>}Ag#JM&KV zj651@hd%UQ@>n~Qd7HxT^P~I;-!}J)vuwE3XE3ujdMsmY%vWy?HbBCAAp80vXsvC( zX4mJKZPD#O!FSZe*M!>lq|j$c0jg!;>Hf$&J;?*pk{7#32TTKfW8tT00X<{aD5Uly zL-AY})1oZWKJC-l6!lsU@{L~cXHMa{v2O5QGhlqu2lv?L~e6 zHE)i7|D8TZwZgy;oq>K0-uV4~(qcjEtUe@=phQJ6UgcvCRfzvF+`o?6t2;ChXu-U5 z4wg&t(Gd#}DGF9kaW~=!nZF;he&qRhJ@{NR@N&&>JSot=pU06V&}GKEb1mp<%>_8M z8GIjlp$Rd=_p#;(gY>%vTZZ*J8!{&wKhq3P#y3r|G^VK!zYxZ}L1mlM9}s`c-NR;D z(P!IBW#*M9`ylbp7|hJ8P@`scbMppp2Ls1XO3aNYVM(7>0^%JD-YUC ziP@To37)ceT@$srZ-o(xeW>NA0ER1S#=Qg}+3YA1;_ z3NRr1Vvhwl21(l5NlzgTMresVRAR1)vyOPIWtch!+}R>L*>)od3U@YDO3pTMe%9yA zz2?MG2>-|qk5hyf2ZXs5g*j;%5h{j0)C!cxPieZDBBPk%iiZdk_kNZVYV6aWQuZmu z>Lf6@*pGcZED0PA4@Z>bAXaY&Y91lNXoKw(Q|nDq{l9oy_6NVWFbY(ZblFHv#YE=+ zfU2duW($MF7C~~0kp-Q|`Y*^*Z9inOTNl{R*DB2w;q|sStrhLrUhMzD%HvVL(L2uc z%z!iuHY6Gy>~|czFyPt)0cYgg%y^-YF{I>?g-+c*;o1>*--Bkj9;YNKW&S{O9ihFZ z8vTzNLY}%H_m9%A8j+LYfzPH9R4JL*gK4-Po|lbT`RJ@-L1>0!DT2AqVtF?S0u!u|V@hF4C3(r6dC9@4V(uwogQ@%_SsGlNTC*>7=z{!@^S=Ga ze@R!6AXq>qk;T)LIfBl#9xUi9&eo-Kk4(vDrYl^7hrKRIQM4}1|5U)VnWwLu&Bs-= z*`!pIXHs-uCzqZpDAFS@db8+jc%eC0p^Hwzf@>~|L{5}WPD*OAjD(|5Xz*QvLXp`p z%ZRK^tb+U!jznPa#5EiUzNDs<(rdW3%$717#dqN5)^Kq&DMsF>jq24h?beDK6NzeB zFB|iSQh_HbYnNp>myORpY85k^oGl|;EnlK5|HxIL_Pl&bM{H)6bL=<@Vs4HuGT+s) za4s-^Enj(mHR@n9+Ur*34-bnguIS7a3v3ljTuIA)epE$GC+ul>?6t_`v*Rkl0T_WN z4+-HLvY{#*1nm0I8%n^*ps&6uS$$i#8W1ubaUiR?nyY!{s`<{U%c|p9O>G2igzoCr zu;j)IpV@St#tR@h?jaI{pb7Kxwdg6~4UckJ8;*Olk&iZN6~J|)A&E-drI9LUY^%LV zHZIB8xpiTW$si(yZo1tNJCF+Bj&*)OS5vV}S;*Af5<;umC68Wfs 
zhU7@Llp#c%QeeDIo{MBrHX%nYH-bj7>1jk$hizSx@B5~L=HM#NCQ+@z+Aa3_t>UQu z#&D0U3}jZ=Tyst{>TP8Dn}L!9kAg1ll4RwQ{vr1E(#Em5oOG4OacIj2UAE~nl#fnU zQfd?Yv=vI;7N1&JLDx#u-})k=%_FC6bf~S9u+d7Q#LkMlUHg9f@K)Ost+tw>cDGOM zY07zi5{=Z#9nK~lzFeNQh>mJPgKK&ngElE1ba@EwHY8!|F<~h;eiLOGE3ubv4Wikl zq@zW$lgR6B1{&o@&_Xlc#;n)!X0w&2g_U(UjeWa`Q|j$wCwLdHR|y{cf3WvfQEe~k z-fjYI2@u?ZYjG%2+-WIR3KS`>Z7FWWEx3DecXxMpcehg9T?=12JJ(!u?zP7_=X_^h z?#tZdA{iqYnT}scjeCr{;U>4kBu3C}^R$?^ML>qTueTXGlE0 ztK$ypQcmv*4{Q+BNIPn1)k|!A{Z`{mVH@pYM@?Y&4PVn;L-$+JZUNdJ4Ydw!lScdf z8fLHy+HUP$L!OmX4_RYIx_csZ)olS;c1~ELb`p1A6u2*KwnEIhCpd^|p0{{0zi=}^ zaVb9_m8gG%w`fbsS!tjz$%HBe=6X@;&mJuN5#L`4)=w|&uVz#(a_%Z=q>_B|xkF>{ z<0Vgb;$WxqU~l7K|NWrz^x$aWptX}}b8Vkp+}-@Gn=YF#j4UYEbLE-K8- zT$eFO(ilP07!l^u7|G8uGU9OxfpIF$acY-wTFRtx`lfNlrEz9rLaV#+v-Su!P1E@| z#3*tTA6P$)kplo>6H2_AIKzy@=l}o~`lPaD=m0qoFgi&Fos=vJ?PCSXpaAdyQyXsL4-jl0Bn}fav3;covAMZ3rcHq zxoJR>uo=9qxHV4Y02Kiz^*}@<(?l3HcLA*3xG@mEGH0`nkxFMKtT3R%XDm+U!Mbv;xw6DWaGkVr%e0&_w({_EWh)84 z+huV_29AJb)d*uH7-Kdkcoo3%gwz7)IkftPb;;jy0dPE7EU<=THUlVK)q*6ip@u-* zVV@xLf-B0*^W{|P>@uE+*)&$`52bfMSeNmLmZwPgmp7~|HyqODlw2l1YT(ts0vDzo?9j5JX-VYex9}b}% zB{}e6JD7Y&JMDVl5pgh&O1nt5eBgSzyZW$yDR5Yzb+~!ByPbSk*?hR~x_9_+SWj|9 zdb+%Ku65MvdUP$gf7^U?!g%;VatI(jc0N6Lf=2U<+U?jZ;uvX!8oY9Bbb5>)N{uOW z^7;Jcc)6s4z(E2qqlr_8dET_LCPc$`b-oeNExoojqQ7eKt2VZQiT zz{Z7Ub`)}q-KCt~WwqbsPBBb1aQS`nav$|disEYg&6T9) z$ho%1)mZkG{>YUA+U2Vkm*_%EteU4@Dd#>yC&|RsD&McPo?Sl(&R9iWz?EL`#hr+y zPCZRNjU&BG7`aI*2~E+y417EfmXN)*+lvYyz4aElEi`|UYI|#(a;wyGTOs=-{PLo7 z)j7lQI^L&B}4ApckT z;S%qe<7Bq52%`S)h>U*%8R``iXaA!_#@~PpoBt#-(lP;m5*Zy{SaiyN5E=4t+_vY} z7#}4<34Re7zkm$A*VSKdEI<7QGV)3?KuLTsAY-m#UFLW1;iYO1+#f*3q@i4Y0mVSD2npaPy&zLWn$myr|LMNm_L9F-ZU|{=MqH(wkC5GhC;eb z^OmMdFd&24`OeC0XgXzMIEuI^R zi{nLW?~035(;q~JZq=1hva9RQyBn_Z7uQV>e*_=yu0t8_!7w7@8Sx}sTQ{6DjL2YW zl=BCryhraq6}Ug|^v%MP>HEk-47(L_^f1E!v;MfQ*4CL5tHKRPNSvi5aWGh(WfYM$i+a z+9#==J31g2ve=+-Q;vEdgn3NME{smND<5*=63E&rx0#@eOWr);wnHKX-A z_wcLsV1sk&Px(zuZ|-w=E{>Y0c~_OVe$nsRgYqnA#!uWumh3ELakC*9j`4NE<>jL7 z(vsfdE4xbQf75!H)kcAAjLRc0YsTs_e3z-z)f3*ESKw;ZrX} z`e>8YAJL)o_?xkxlH)exNsGrfROo8~YiQWR_}d94f0Q1I@wZY{4afVC!;}bi<3f_- zVK7ECmys&S!ph@A#RG3vK^IAaNqAfIh9@$w&eeW(mdorx?M5;oM(|EEjKw%BRX5k8 zxFv)=l)NQLj+E{7nV)oMqY{ttu)0U*cdCQd~I})2CSmZ(aUi12Uuv^&mdX#o^%a1rW zlG}^K(B#|QiUh{x+x;pb*h10t*j#i!Sy*^A0M0txYH9rm#D_n9)~6;~1gF zmFv={v*Is}imlxj{lU~fRU~NAN5lNV6cjjfy+mx|;o*wpRL@oW(BF?o#7&Y>bIkRD ztZ2p~(_WL&imCPkBF4Yvb&}F6&GlcEjYpNeCT0AjI&d^L9$i;S!c1*5H?Vaw9@F-c zgvDERa2aDFwy%44e36ma0s@Jk3P>92F5~)tpSDTpT; zMB)yZfOl$cII?Ua>Et7XpJ_vN#Nq(a=jI%X{9tax^3d}A12*k8fMC>}q|)+@A}vOg zx16sab~;AX_G>xYT9IQjBXyNY>g zN(EXt1*%iK%6N72N+qs1g}OL!E2bCI<&1@&^&U*Yh1%QvhP#Z zI>SoFYuZsda-3Vxx@yL2K8ZSV-oaGYzFWp?xtTh0y;;zHe86i3V&NV`5jAw+Ng-{B z%Eun~un|}Y6Vi?mee6Z2p^K~y>A;^p_Tj-?)J1cHbduwq_(^K$VWmL2=rEN}0#q0E z@LC|lcc zh5_FS{)of$S@O-IfzTuVs5=_&c`Bmj7g#iU%wPFD9Uo)qi?|TMczE=ACY`3Cv^K#6 zdffDRHqVlwtQ)~(8tz4|B$}p?LJGlDp7KS$>XOli7BqtCvgnIKLrr7#6@r<%X^e|v zhb3d}M}pZl+{;pbO%r`IQo^}D<;(K8B@+W!>vs$-`m!=l)6`g-aA9uxvbqj!$<*A9 zaB&Uysc7d1Utk96-^q-xIjH|6 zGr+2UBQsF{Kr{X#Gf@6SGk%j9D1V_DzsL;KzoQu>C*Z%M8Q_?|p&6*Z*o=Rp8Q|Y+ z#&0wO6I{!5M$D87P0k8Nb;K@W0@UUu*{W z-*Cpw3IYty(5~|M7a&3YPe1}jN5Ciu7##uAkiw7#7}x+~9$@eT410j_4>0lphCTdN zQ^05k7#9JfBVY&w41|DD6L9FbFo^{mCO*t#0rOJ81QjsvD9k|$laazK7cfo5ZwKjr zp&I@NU&5>vVgGBb_+LQce*uaAAAkhTKLH85{gYpS1lE7Fby#|^|7%<4S6T<_pKYB# z(mGiG#bNqOT!Hxyhv^@X1m-^+rhh0*G5+o_{X=1j`45Nb9|}{9KifM0PU~Ra{w5{< zmexU^F}pjc{aacG{l)!p=ik#h7{A&&|48eg|LHLOOJR!k+hO{b!W4bx4~OZ$6{Z-e zfL{*NzbQ=7|4UkD5aXX|9avch{hw(amLQCOq;-CkbI%q?! 
zkaS3;z_&j4)_$D*?O9_^>?uv2tjT;u%F;cJa82ikyH#H5v!Zc;-{`P6wPM;TOjAKh z@Z8+UI^j9?r}vx&{$oMR{_!Q}+2U!H(z+pvQvI*DB_qF$N4=AbV;0MkB-E!@2p|8J zS>L!L&6lqboL`y4h(jwE^%|N+fO?TDA(Q_;rhhV<_aaXYH8@`}M`|ix?IK@gC|?<+ zrhlr~@uENnr9h=e3U#_X{-W@UOo7@=|8#Xd%0-byaDm3w@2(EYA7Pz8yE-VSg*ph* zv+V?zrG7GndU$jUz2A7JV}oo)^l6MIdb54YW0P_}y-1krGqV*)y4W}1A)Jp6!4=4| zq1iW-Pf&|3p5`yIIWX?SQtus(uCBV+Fttor=s%q1YZ}@y@4(Vnp+L3c>)P0{EQwr9 zVUFP)!r8UX!O|Rh$Hh0*#B5hypm|_HT+=Xn!)#wfxqKqZRloU(*>OTW>9ib=_t2%p z`a3a$_U33!bAR0n*YJ#$p9DEA56@Ad@QjH-EC}~ARKZdYyusCHeAjL0uuup#&D%93 z_1Z6p&Bw-Ug=^CmwH*aF#4LUal;1H@U=mYi6qv*`9;br)C9~^^fga=f@F;O@cS0R& z2z$mlv37QkU@!wr7v4LR>edxXf|1EJv?ETa147h+8nQNhp|ODh6=sK- z-q0^Xf%p`8AytLa#sI)9FRSd>tudc~nv&`)Rsm~@aQL8y^KV;cP`T~k*q0cd zT@wKImC2eMF3i+lLEZ1Br{E8nfY6v*4bCs#hoFFKpB+nhp#fqnp;5ugcU54nehi%Yn$nO(le$)-Q0% zfMI|(4m819$D^tu9LZBnpew!gPQEaw95x0KQrTuK<$N^&yFAm?L6RO<2~g^p)B0c7kk$$%_@C9 zEkBn~nQqmuJWuiHZt!fVu%Th!-JXAhhNRn2XF`L^?cHafS9pwIPFnZ-9>xqjcidRJ zAvl|xYImRUSJ5fe^+Tb)33(aCytAFmzv8 zzj*EKEb0mvbQ&x zDS@-AwzYMl^(+4F3x)WZ)iQX206zA9!MO|p60e+9Xy_pRTOSCZHGT{U+xT?X?kOq+ z=Ue_6wfJB%hnMyNS*(u1Jpq}lo&^YjX{?Ss5`oF#jun)Cfyv~~HNAley56m`_yb%) zDJRZ?UxMP~gZg{%duM|Z+k$9DWAR~0owziw=`Z+3AfLG}_z0~oOZNDyF)nKm{Erad z4G8|ZfiKlg?MF4gkb`Xq2of5`3O$wx4P*_y-u{?xAA08pxxEezNb{wF%Y&eQ33z&g z_arXNH{2c67v{X?jvfc0RPn%)grwPr`>uJA5*R|tXT!bSJ;@{?0^AXv?!nYMcNn+@ zZ#HFN@9^=WAW@Qrwu+K?3Jy^gCw}p{Q6}M0R9bVmagb<>G5?P{xcCmyrt$$seYkns z#tpO425SLmbTJwy;iR^4xYe;ShG~J${j?G2}x#SmqA) zV|~~}s!7uBF|%>l6m?0G@)67a*m3yD!XmNcKW?zr9g+oflXv^DcIJ|K+mestuy}b= zMAssd3{!Z>zmeY}V!xS75ls6An8#Y^4S{#W>UKzd35-TE!qS*aM;!Gv$5K;Ednsid2`UiQ*DtjI^uG0 z-E*Ah(fcTGa;RCBQ`)6@NO(O5lR_UB@xC57#xFY)A|0+UIj=h36$^N`0f z5(&`X-{t|<^3r$F($w;Q3g_q6qg{~IuA^YN$Uc z2}-|PpinxMPqIKY^U7ySbEudbQ0fRO=EK|{#Ve+c3%K{dI1Uwj3@A&t6|<|kqJb#v zwZ-IOged3nmE$a)NCTCFrFnA5C?~vC9Xh4Td|;)7s#a*#OATaj-c_ZxI=G0py0aBX z5C)vXtNyGGCbBSo#Zq9(2fljtwVJo4U9N_lrNSN=j1RSHRG6=6Kdy<)sqj!oCRjB_ zTJ_zv17HA>d>4@X)m`&x%@B@l*6`};S?Z{2Z86ZT!gA{poXo`x0QtT4dJOf&$JLnw z_L&I)rh@vi*7_nRb96u*Z+Re6rFsJ*Aft{CxxvXCaTHK{*HEwCfF4%geTUSB+|bp~ zn2QdmHclNiMjjuqFH>)-Tm=HcfFq~{P5G_B)qe?0!7zjAUjkF4WsFopR`k^dzyh51Xte+o?h@RuU}5}5w>mm>cmF#YW>MfxQ${p~MBh6zmn zN z;ZavtH$FZd78dsE)vK>xzjk$Xv9hwRuC6L5C?qE*fBg87oSeL{un>z?$h-s|cFSX1 z2S8?KfnD^#s?P@5&tTzaOirf%Y(M`Gu2}!GZ-tNq1$cORdHeYK`3D3B1&4%&g-1kw zgBfOH5KC7V z&Q<)s>PBNRFl!CgmaX@MQA#Jjtt;Oc_#glbZM5!JzQ7HH&rc*QFuTuAzh%zcp)QR&+|NIy z)e*gM=9mu-9az<#_cJz5|F(-y{6&`N@n9)$Ld$Y+OKXUxiXd(}wvh+td!Y zZ7Rhs`!sB?ez!)};U4!T7fny+r@h5X^M4e6=8qQQ=L}iqN2`z3XO^tr21p9LE7#|L zjbykHb^lrwUQIV-? 
zps0gGqnhOWp!h{F`p091e3A|A;H4+AJ3L0rX1D@aPI;Op& zFC-rekviY~(4d?K5;_B3Fo>H)_bVl`4WSMuL7R z_n?Z+sorXKGHq8vGFOrGzZ3eU;IX{g|P*@WpVVUuV{G;&P?O$Kzvb1e4yBSUAcE z;5aW@JK%{7GMsxL@9`eP$jDSbycY-N`-;(irGAbg$4$PM)zf)R`vac{BY+({RdgVD zoN{i7X9hwDBIG9|<28*I_dmiV={=&prb~P2yawwiFXw=s$?#>;55Ot-YBm=LS zf|QL+GIB(~umNBuengS+_;TxEcgqVTv>pMSK)qahR&ByR@1;@X;x;#OeMRPebeKNR zWZa_VF-5BFc6enFpG>x_@c!=eOM(7_1LwXxrHxL6GP zdlg#ob!GM93q~8F-?WMWM$kPyiOIK>$)A2EbK90*9ZB}277GirBwi$QoW>F*rlwJ) zdrd&;tU|sIB=d76d=2o8c?D`>a7A7uJG82S(i+M70Jh143qff~hDQMS!Z09`1c-`$ z%cWzU4EH&#x?@D)P`5Xp#0COf+#m`?FOc6SzoP%V78KMUf$GD1%=S{y>JQYJs^f3B!wCl8Jgd1NV^S*Uk2!k+>|znc z1O&W=9Wg*to*ajG8ajsnp>>ihAtrG!>FPi`BB?a*sa&ITNbBoMIw_Jf{!NWf*}ZhH zxgKGqoo}_eMAEixJGvz?UacEEV_beRuA>HMuQIyptHp(I^!$^@lZP>F$4GY$%Tj~UrknLS;`NeF>+Z>J}=uS*Rrhy32mM_|yX35unP>xk|}EDO%l68KnV zB2DE0sF0v1it<@c#X4~tVg}Puk)n})i~z5@IAREJvQ(!sifjp5;nRlm=U?yfBHJgS znN|gzKFQVCst0ioON11#Za-J5qu*RXZ?uWSF}3~2J#pWKlLvfwMgn_WcpDL-rHI-h$l%xCU-bA2 zTV^hG(h&5Zs0fJPNY^uXmY2lI;mPydi=sq9gOoTZ32xhXK!ri|{B2+|dxa_aXGf6; zx~qb(C?)n!9cPBC^K1Zqcy7nm?VI^kl|B|nj!6Tmk@a7KAQRu8E;-FzJpE9_Ydt19 z)1$Q_m%yO5Cw0MtN30qGJfpvZPx3L(>w8xp={dwx#!GZ?3M)zdxjp?a zLi`%OPOm3Vwi6%X+2eo(Pv&ETE32ZOQWnlY2;T}E1-~E`0WAR8=zB0fo45r2(56JQ z^dzCCclgxU=7$fd@OCkazPMPgx`os)dLBC<$LK?H>}+;Xicnt+HsT{V4AedHorY>$ zO<@V)_2o@y#gyKCr?|6kc6fFgjBp?;Oh{3u9l{iYC9}@eF*C*zeUu2y^TYCa5zPw% zfMX}f*|pV10_3O^A`2C%>+SmtLUyBNV3IGlTM3DQ^K6A7{X9MVFHlZ)voi1!oOJk@ zdi;69LE|>phNGt+4yPv?nejXQ>Mrt2_M#5YEi|m#d2U;Fyp#Oh0NoebjGqdoDrH!t zO4NEsHb!#s1|-@^zQ4LZJuFTxG?I8ST7J%IWCpZN^hel<=0N`4s3x6xw zGJNr%n6N?rh{hj%S+vTGc;FJ*u}sZASMvEm8Q|hPs9Y^HB>B)+{ud8qg z6gHvBg7&$;sr_J#8{=O5S<4XbvpDmd0GL; ze9JODAod`EUR{8f_cI+z2!fVGET22b*QeD70uoSuqoU1CsnTgB#_jd(3(+-#W@@C+ za)m;?xWMZRzwI@^z1YWw9Cz_2{b@1PmGjP)-; zwmy8HY#fyQc1~#2wtO}}n>uxfy2A@FyyEYNd_{S05kUPee^$UMmR;kC*J|h6nMlTK znv|iZr|NSpSV%wTa*7QW|GOpFEi{qC4N*Yn`S*s)hzHEO~ zZ)fRAm(HCuOUj`+&3)njoxwLc>=YRS-YUS&-BZ0!qq84iQ(y3+9N~|7%)BrH9ZgOJ(bbT)j+ji&#NW(1X+2Y z(+eZg?)ZV8*S!EsGGWqvp?MvNrOFYJYkFWG#TgA$Qm}+%&`7=%v9e=d%O)g7k7n?W z`VtrY*h+l(Lk5M$32qf|mBC5ikEo61N~xgGWXWOg&*GRDZ41SZC299Tu?iu6<>q1f zo5WH4q;a04N!6rjhosrKr1`p}#kr*Ao1|6zX(YF4WCZCHyy;YG>C}$twDIZm_34cB>CCt3ECd;BycsVk z)iO98Gq~e3cfAKB1@-j>urr_L z+k#et!uH*Q+Wx%uI=GJb!v6Zg!TG}B+rqMZa4%BPq*~FmW6^AUQHD`r`%WH1D;!oU z+$wMJnp*J>$KuWS;_dq4-TC7E+u}oll4IVIQ?(MsbH|d)_>$}TlH2)```eNSf>J<3 z@p62@3PCYELn%^1DY&5&b)gjft`w8744bbEm#_HgKv5}A5qv~hZ)+Lc_c8*)_+GU_ zAO(;?r@Tk5oETQop}dPHBP3%?0J9blv)`4c(v{D-m(v@Ac>{_011qTZ^5Cs17^ExV z^eROeVx^EPnWZae7s>@1DxsVp8q!1}$W;o)M7j%A?s-++Csl3G z>aMiv$It5iW`snaoGSIziLCFcwE`8CH!mk-qA((gRPfXZ za3_o?BoJ(;4)%3IA+-pDDPq`C1qn+*NMS@oqaY+QI7$F1kuZQM3~34jJyN}PpSHH4 zRTB;%U8; zZ1zyt7lE)mFrmiF^FUBKCVT)Ch!YE9ek<^p&~#ne@kmc(|I``T9{vscQ(m0!wQz4h zOr^lAJNh+3!3hj_iquXRlg=g#nnfm`R5oWKgYRbmkdnd6E0Ap6>yYf$3W+|0>LJ921L?ji8k)ej((}YxP)$I9#lsDF6)bdqR~c*txlFJ zaZN@LD|+Y72{2Ec?L{m6;5|`Nf%1rZLxI;&ls><(3cNrr@E{tf`<4tYk`J~HW|vMofy5rMMZGp4mh;Q(ekMF-ptz?Uq(%u*rK9EWpu z?mZs^)<}$_yeB&fm%LRchK!mWrNJjWgI`I97a0WHB2T{10QC<5Bv{yV`TGFSe$-J( z-ThvJ&##B|ntBTQy5+j&j+^5GrvS&}zED7)e9Yk`Xr&!6vIvyX1QCuhvxETC$TCOb}<6q!)3G{($c&gMbsZl1u}($NPQc03+*0~2z6B2#f}>iNn6 zi(}6VG(a*$b@0`18Y|+~S(#A%;2zV!00}_-Zp|BZoO%nouT?kdAy)YcCk9P9_Ac1fk<((&Y#a<#1V?;o`bzW)L{K{B{poegU89HI z>yfU7_OOML)~Q(?+Q5mg@|ADcEGDB*=_A9&L_fbCL#7N|@10r;?_=#7H122BKz-u6 zl#TU0Gga{H8xdyGr;s9bwDXm3&LBi0!ET%t#m`NPX+tlYh2*nWX!ymJGFs%>I~uX3 zN}_0{MB>ixD{&S(mUL*-(^tJcR~he5nJOFPNTNYIaPT5p+AiOhsGm|kg(Lgn0^%qQ z{;sn`@$SPjW@MM~NFGSc_Vif23i%7cHMV9;{^@S_CUM3t*mo7d{7f&2t3w?Q{%(Qs z;(0?4oM4;i;Qqjm*3!-C(nZ52&@Lrod&jav_D-1GBoMkS5Ytf|N%=J0mm7jg^_aX# 
z8o8S`+9_Dn!cl6NKK|0nOo_ZbHZ8fIS!;{bjta%pcs6qVdBr;NXC+^`JqBxN<|0?V zB3rRuKX2iWuM2(iKcH1i{Kdh}DSgC9`=CB8C3>3;SnpzC_TvxeWGcb<|3d zjSfPWo!M67uuF4tliKDMm*l?3Z~#01(&U86o1uU_-xvL5Bj5NLsx=MA z&ijvkb5E6sgHX4WAGVwzQcY~z7aR;3xMx`Uw0uqC4yQ~7GsxS?$SB%i$YUe7WF#n1 zfv#%_1Z)BYW(>V1-I6j7yd5G2eBVAeK4(~YgEoRhx_I7bJ+HDeJkzv^IK8DCgGa{au2&__thQ9*JUsVKh zr`(#e*_pr$O|}MEOVJs%5QvMIKc#tV1^vSC1y%6~VjwBMjplWR>rLK038lr=T;vsd z@$%=A(LM5g$`{unM@JV*=77*ef@bBB4QRVD@)tMcvKL@4L5AoSQK7FVA*+lsEmHnP z_rnzZJn4e8D*FXb$PP_eT3j~AA8tLXFY`;61BA)_hxqJFHp|59h#!eAp#U;&Zb}w* z+Am}}P+AZr89FoFog2O zT9TDQ{fy|fcxKx#J9RX2_|RoV@YfG1{3l+gILo${jnL;8vR%^x`T=M90%}Ir%#6wi8Yi< zlZEOHt|)xfkNb0#u>Z|qzM67_a=s3n?hznGk1;S)(KpxA+{PF2it!z&{{3beH~CUk z;tg2$UVx#2qP7=zYufmHC z2=~5z4)a-D$(9F77x;!D+n&9aTO!dZq8Fcfb$v-(L3(}^ArgB2)Mndi;4lzZkom*M z)K@a5QB{i!=$(44ACp`0mVyACt5ICyCpHge%jN%pBchC{p_+p7sciD>F9 zbwv_Rc#pbkzW``iQ8ak~HGM@2&8ByDrsPCxs;53Ock8PNQ6-1x^yyzqp9}DE`#w)- zl*rw%@-|4KLy7Za?*EbzB8X&|8S#!D!}Xh10G-(9&CiC}$*%cpBB>!#MtRxE0Y>?8 zX8A@1S@cGv1;xX=bcL160mjAkhyBJSNta0bapm)u#--f>-1}wy%mXGB!-55QxufsY z6N(b0q)n?!RSQgOmR*s}YIATI%<6ETOB2>@787FB?>8g+m^YpdJB2r#FE@Z|YUTpX zTiUnx%v<5m-m)~@;}RXWmE;Aov?FnWEjuw~`FA_8TaYbGakU%Gx{2JtR=pI~y~n-B zW{Cv-)W!v;{md<3>p{dSM)nGl318D8;u5)-VZlf6*U^qq#+N0p-4w0HMoEgkPP`K$ zela0e<5*$xUN*^SO2yoM$Yy4V`k-ozM|;6~wqMKCcHXEYm~P%gA)mvRrGhVEJs);bxj&!QPANT~UGQzxm@3D2y}18kEASWwW1j9qh9UTz^5km9YmYxzMMbOP@N>) zi6T%uUNfe*(bC4i)608HzgH&31FuTS9PhGl-A;diZz8o`neY%0Pl2K@xYluUO7NPsyj;kvoE( zVnTt!j8F6t&p>am_A$Uup2%Z~B8f5u-b+(F%L_?p6k}mTiqLt0my(!-mreqASv~PY zwg5e23sv-`V+-pS1AXT{1588m$i;cOSz`@|s96WT4C;TjfP|6x0$`ZoXQ{KeiJ;HX z?|nMwKOJ)5@{1r!t;$%qvM2r^%Hiok?0HE)5!9!#^>krp*z}ox=y@Rxi7%wr`OU6K z5cX?ZbVPs!OU^g6`VCKILvkApOn`^(J7gLlh~5-kefP29us$8W-b1RyQ{|G#sb;ozGljo z4aK8LOzCr~(5?``T;YX$d)e$3Q7{iux zYhs*XwcC&I{8XgUS&oWhLxry_pIuQ+~_+%VRGr}JJ|x%anAP2T!A7ZjFX`0!}F@Tup$%v zvjLAEwqHa9^3C0h;pX91@mrag-!mAaBSP%z+Qf=!{00`kBjEvra`nFk7EBKpO%^RS zp;%b}m+{s^haRWptMtd3CJ!BR=XK+i+x`ztr(!*>>ut*y z(=*LyO(oDXKCR?Kqna@ZBT9Y-(C8aTdqEV5vFU&;LJM{W&r~Td&pjC$JbY&^F_!vS zhmMUFUM+sbYJeMqo_y_fmw2$Hp_BCbI=v0L?O>Vw-mLiIeObfn;XscMwjbL~KLiaY zb$`{D{(i>KV<8`H6{#z|LH|mxUTv7pEXE`Tf&LI8I95s9<|k$a$8s#9uk|KB(dDfP9Q&Vf6w4&7zjNXp9^kaKVURQ6o54sae>|)2 zt(tz=&#-5V!_#w2ky^5G67w3No^5TCoAv17 zux#zC9RN*j*In~zOO3{_?SJ~{;G)y*<36drbAfGhPzAi^csB`e_}JsJtv=z;ZCgp+ zH)JDwZD$R;PBHa~GTNrRdYVHt!nI*_#n-b6ohNu1F8RXRLFMr92Y7Z|o9vIqtH}Ci zGeN0KRdu!%k%INkct1;$+%`AZxi)S1MW$Lc)qwt zQ%H_W;3~Tg77}R{cR4?;sb9?rxpgTZS(qsiZm)YpgJJ<)G1#G5=}zkg#*GH9WZjlR zG0rvBP{^ev*#|ozF{{xUr|}w%4PSSB6}Kz{(*iq-Szm)AJSdF{_~|SQy$Y*67e>Sf zhbBBbR){<1wfPR0-hqK3wTdyn9}j7-^I*1ncbEIMom-B930sVQ(FdT2En%{+2Mw%7 zTcKrZ#ZW5_Mo=+osc{|q;90HY%I&9stzt5A1_gn z_mHRcPONcJ=~Wl>GvlK5X}q+43GoXSc0M$4x9jnFUF$;}s}TZmw4?<~#p?Uc@Rs`e z#z9b^w!{n_@F^0i2D!Y5Aiw5KZ&W{5BcC4EuOECqS^2WFdyP-|ML@ugcwXYMjvQ1T z>>qqNa{OCzITfzyoFOi*_QqZ+ettR*V3mO8nt+cg=Kg-pVSb+iX#;C>-Lxcv{e=Br zbbGz)HHrlVt#7|<-=^zS3ADm<7}^fXi48LGwT#d4^!glhgX^|%ZClo>XSQu;;pb&F z>$dnN9vg5n`{sxnxeyXutP)z25z4|D_?R19x$S;U z3tGC?i}v$-;vdnL8V0P0h)MS*0(xTXn4y8}EX5<)dV|pH{fhkbzy9@P2Y3{gt(nfA^Icw+1=;v|o#+p+Y z1{(E7|GW+tM+EC-eS0nu6XhRsWfd2m9`j5$NLL~p9TZDA?X6`PuEP`KXXh@a8cc!c zZc|tnYZ~iiuIk@ZYvZFDha3}k_8q|6N})#?K+zN9F8P(~#=Cq9nk^AOYHwG4$aP;L&Dp*gvF{jbx2}5PrM&)yx~n8pF_AnUt(!)G!b1= zeqH=2k7=P~#LPwT<4n>c-A$4iPIB;ZgvxXb`cBA3oW_o1Y{qQz#$^H_FlpvG_6vT* zd0glb{ZG3BxijIs>_vW*Eu;dv9=<037LJbB|%Tm8dN3>3_+ zz;G`!MOHIhZgYfhbGZ90o#}IZ`ZL_)bDgDft!i?(9dpC#b0c`ZY|;o|%A z5_a>Fkn)r3^WI5?KFafku|V??!*g&eQv*-4T-Ou9W%;U-c>MYWH30>6js=zZ1x>pJ zUoOpC=nET=GWzKF~z0?t)K+9-A-iDm(+9h=c8kV0U1lnm|NjUkjqIK%E_h66U9ne 
z-jp^*mjmkaPazqFtQlkD`88uD7zri!0Ts7u6$E@>4yO{{gbL1v3T1APAaW&FLkZtN zMLSt}>PF=|S>&qZ4LHDR*==5F{A`h*Di}L~9ygFycNc>Z07&4d{)nt=U|g+N06xvH zG~KJFVE~)lRa>i9bdc424XoMcH9720F2$;HBdoG>tQuD#EZ&=qn|GwwVqk;}a42op zh9cL66V}np)wrrxLkVk%Fe^L>B~+11x5nWqw0l?0S8fT|+*yNHSkCK^n+w9Mr}@Q0By{a~GIB5SL%j*oe$Wsg^#u8dgUL z_9koelLNJ+VlS{_l1ZsfCfGMDgk?Hef8T4=@@>@r*3_ZXgd7e4S~Lx&*4D!}Z%a2j zT#D3h!D5z>EN*l2O80;F8aS5#Cbt7WgCX8(yf_ zU){C-h^$J=6YvHAQMup?$l83?0B)RUdf#!+ECAxpfM>#u4S=rqZ@XmKkl*SdBeTG} zvSPWMKy+sCpYo${@>B3K3L(FZ<|jhA$9ke-l%>JjV-uA1?j@dOW0q}WkJWw;8ZFTA zZIA0>j|;fhW3b0lqnG5hxQHT9LB1L;J7Gi!h3SOZtM$IF1U6`+(eq#!rrx5snQM!ODA!i%o%pz9j)C3;|%YoCXfS zBF2`9*_ImFheW*xj3Y;kBhbWbl^fw#f*0D zzXLbDn+zD73g{ac9NbCtzFF|?Bx`*mg4WhbvwRE^q-Z5?!eKuddjEc4Onx{|1LaE) zBP8kd$nI$W%ZVPviMCGH>CB0#&WTB?dI8$t3|in?r;^xDnN=!{r5{SmUzLUk@x4^6%8O|!Kh#SRu|9(s$u$2pAIj_B74k9sVdCX7T;G?WFL zVX+*4#{;T{V$Z#Wk+*zN89+$Sax5j4taCylkSurz#IogQT!ZaJQQe$Zc%^nciT>!-E?e&0HH zaA_${p)D>2DmbN-;#O#JE2X$Yakm7L;10pvT>}IU?(Xj1;_~vGGw(g;&fNcC|Frka zXRq&C?f-IsLgL-%D@@0kh!=qkO@~^_URuiYS{Yc{NjYULgoLE~T5X5ryK7iC6nGf0arrVuq4AfWp|P)xrG{ z23<6KqIldo*t{4giCyP>0>><`CWVslPh*b;iy|mg{o`}ypIF|%P^^(7i>1o!^`SEk zYp#ntALA+LW2nStEkDL9Q0rw|Bqeh#^`D1aH5c>1!!Fc;b=LTQm>2=1c+*g#Y{X!# z$ecZ_y_|5~s(E@6VCGkA$R76TV0>9w*t#_sXk@uEfV3Y{NU@gA3nBD*rY~%roBfY! zmf9u`M&){I>be3iextwkFxyDIh!1SXZxzR-M1M`93ba}3^C#IZCK&o-4X!L7#akU? zjG8wkYWzL83ff=s77QkBNv+6SZh9PDwRMi0O^6#e&(I^y@xkTTxPqZyXXDZD$`8*V_K5kW*S+mO+hvRfv#bW1qLvqF(L3VeeAMv3E1 zkD}?1w(W{jZ>}Hz%LL59e$~<&+N0yobt9fn+HfCb}KZl-ZjVvLsV=pFhsF>7QZ=MZ3$J zdlu~&rv?ou9ZFN5^DW0kN)Me(@+TV7ovzg$P8u$8TGMc4oB{wOX`aXQ%l_PN33;h@ zHt((Ol~U2twlUHqu?!L>U?(LFo3JHaeKDU{B~ZK=4W#+}v1NT0eU*FLn#L&QwP6nc z&ifMG?|7nWXN==)Z8xYr?2_qi-dmx~QR_TW!I6LH=O5Ct2agW$)Ba6tE^;uhX&78o zPMsq$=Pn}A&`xCZR7o6{8C!{VIF1JHN-u}P(!YNqGbV<}eL4(QJ`XcLhYW9y&?n#J znJ}GP@mT(gTyB!_Y4jZCFf{$!h<-N8ch!*^mt_D$5r?Y|zsU{%R2Y8o=kaX#Vdfd$ zt*$K`JPiBu9bNDk(%5PmeQZ1y{vnn4U;5*5X4=EEhDW&8^LnmJsCj;nCVI?BMwC#%X8C zzZ>RD$R~30i%ASkzaV~vIU;slRCz`Rh~%|giVW{}3Hf_=_2+itzTM#fZtv%^Z0ka} zdpwDfk`Ai<5hcaFfx(rd1Ifc;nF_dPZj=3A@Y^_skvD)|*;%uR1+keE4i5*wQ(}Z%pzb~k!k>EqJ zFs;Pw#*D1>(0^z-5OjOXQM5(GEIwsgH`c z9|Go6Gk2E;oA_-vO%SrZ8 z=hCk+-zfM|h20-pe|ved-k;WLF2KKhb&j8V+$;8|H2+~Eo_7RhGQ;#1i;~izS94n7 z{YOF8aTzS<4=a)@2|da|i(h}Fy1k6jU<5pVxAS+~_sUuE;jG2$oR&@sb(eC@nZg81X+IDRg~Y1H{s7{?Z=;)BXW4xWDFE+C7(q(As_ z(rLZChx_*6mz%!nMpttxcs~9_sbTA6YVmsPIb+yK&WJmTb5fFpIJeq|qaFP-RuN9s zfF70d?Ztx2GA=ug>+;Kt)zGlJ5x)W*+@2kl+J+OjsXOP6zWxg@qk+7%PY;(YnIBAa zzdWzK@03QOj}19vf5$}M4iwg^CbXdnFl6>+Dqv?KGXdHJ-8d6q9a*`##-?W4MWMQ9qtVl=o-8Ajf5S8--437SgZ7%+@Q{ z_*En>2ExoH7;na~%_$ftuPw#ary54aMf)s}U8A>8z4FZeyf0mo_L@ShJJO&Qu}SZ^ zo2-1s_X)|i#sHF4C^obWot4Pui@*LLr9GPXmG9-acSxF^sw4jtf;n&}Us01x-uv~q z9v2_a8^&ETbO7a;;JCRGW2~bL)RjeORd(2IE*gFnWBU59*0ASNRoaa)rr>ge(x0iS zbO|FC5vs^M+D())2`pcdMM&PB?>zG<3ike4J_h!Vxb=&FQv5#Aiq>M5*=wKK?hcsy zB}`&o%_Qe5nD*s$FrDWJ4DJ?vEF7N?a-u20xF-z@-c!k>{vgdSXMs&kI`LBQgN(}W z&#DxSU_ZH<3{AR1^_P8_CpAMvE? 
zWu1U5&6U?RpB(PG*6}Z2Blug&PmgjmAqJGYb*+G48!uD-m8%7Qr1PKG{)0`6Si%Bj zSYdjB%Z@8JO}n||n`J?p=gUuq{mN%Hx;Rn`v&7LwZCeudtlAarjnxTVC3Y2vYphD9MY8C=m3Q)SWLnxi249w!$W>hFhO{?O8Ni z1o^c2OsCVSBkaN|8}b50Mt&Y1_mHL7s<5TRGJ+;#2f(oeZ;{6} zlIPjHD4MpPUp8Xa7n^}E?jJN&)G3l*tStnz*R{hMh7_C@Jp;pcjK<1t&w+7es% zx(xrT@%S+3Ld3zY#cP47Rcc4O0TJBa{2o)wrGM%ZMPaKoF>~Z0Qn*Sg{(tI{7mvdD z5!?4i*6b3Cf8#EjE=!L8%-=4R?6|bx3>`na^Ke0-vTL*!>56vC%b_`%`=KyXenv z?FOT7hw)9DCyH1vW}dY9>+bt5&os6LHoLfLE`en@9?p3l#IDXJmmbg+?{-r!6%5}Z z2A>OQpPM?*`vAv()BJ8dK2v>p!qwX|kjIHS4)B>jUbi4e83Y zV#NbWYkZs019KqGmd@?~DB0mXl?iZ4FSM2bWM7#f_c@^O=M2;T8v5Qqb*3<2on0G6l4 zHCIah8AcxxBWMM{)eAsFMxm2L14sZ*yzI7X6rFuQE=fcgZv?O*0GvGRjz0*@)e3trt5H>THgrAs}tH zP{Y~4gzG>iO|(fz{|D0`Q>`ST=r~vdE>uQ@DxJVi7RZSku7e5SV0&V2f!kL}0Av&| zZcL^L@Err7QVTD57W(e-UC88$$LkQgP`6-9R|hpJ&)q3cTHMHFbB4kls#&&3#nZ&C zZ|HNHvAKe&^m!23v6L9sz{r)b&Z>AX^QUKp09KN=U6cQ~0nj4SZrs8?ti)ft0{eetdOHjDnlQaPC`;r zdD2O5JQxW+E+F8%P7N$inoxn*v!%1hWp$dT7H7tU?fT*iL7@#PuVga8Jc(Uj;PvF8 zHGiRSD%kU0H$5#Fm_Z?5BaTr!X9<)Q;*vhm7RR!e`8G6Dr#c=-JmJVJCKM3!aWBnW z2oIAzqckf=$Ll6bX;)05FVSv~;Lz0OYl7hvZjN&!TwW+LX*UmK=Z=k`J@+S8y)oa| zFG)NUmsuOg9tu2Viw2a(i)sN&uEDQs!_3@)X%=`SNDPt_3mg9wdkg3vJZG2Cyw0oq zgPD-8h4}`OMWFm_^1kcOm4Z`6gX7zvV8 zOe+mUlEOz_czTt#1kUydC>$Q~T1KS{3qd~%-NSs->3U>{t7i4Qn(1*8_1%r@p? zXJ;ul6vBEcvVpNQmZii2rBxOwU7#Z0q#}Pg7=k)3Z!NT=u#As`ceOA$%{Q7mF{-*Z z3Ljv>DVXb|RgP{5*f%T3SplGtq|~}qq+3)3h2;4ACuFUcl(OfI%w~O@PhTq(f79>m zDMw0^M5TjUcFtZl7ZR$vmys=4W!2Af{eh-DB@PW6KexIey00AYPoNUyTSZ({EzOZ- zF;`7VgJ-R6ReMwR{g$B&uP6vmZdX+Jil&HAs}c*Z_OLJGc&^f?Fe9F$w)mH*9y8cl zxbe)MMyQ_P50|PEwbxidlMQY9TDe!VXOVbDUAr@q{ViW;ydG$dU-eS1 zv3sBXwRV$6V`J4=V|5Y2_O@QTKWjX>X4AQeY}F;owSIP_p*guhX|5;<-`_;1>;{lZ z6QZ!-fjAHjP`L$;_IoXaB97%*RFmI^$UjrxPhYoaE-y+aW{1Z$wd{Xr1!Yto;UWI? zRIYOP7oyOEK^x+S@l)|V`SP&ef(7lfczj!>pMKq(CJf(xgz70^^n0ur7TH)EIP?K{~$o*|wA0@~msmV*I<+pt8 z(=5&{LNp=K;R_@I-*^ily0jw^k@y@qBcz|8te>rzvd}9lY1$tx+jSw$q<$T0@<^@F z$`lAP0l%Y-#!H>^xu3v0_JJ;6L*Ij%iza9t zHN>Rf`20Gi_j7qKTXk>3ovUkLZyKjW%0X|ERbPf~NUm;Qc1mB_LT{nMuVUK%&rbc~ zlKr-rZJ4W)lPv^akq43=Q}X?E2N>Eqte`CNMxuLrOq@l}*$Y5Du@Py~!D1rHm z$e5NTvu4j#j9`QR{TKyhZ;FVE~SXI1te zsyhe6tlhV(eWRUL|yO$E*q0-sL*jp}THqmhoQT313K%<=TgFqGF z4<%7OAmq=4RBYFp>}ZPuq4$f!4gs?f!qM->KPG=+sjI^$je;h%7bm~x0KXqj{(yHb zAv&pQ?S1?woPLifE>6`UNh$SkjTR@rHme2Dvh)|LS=CH?!qt~Ovtx_3%}b`_4tx9# zr;c>%&D^G8*`v?k(}MvsHSDu7gb~ntg?Ku}u)`rcXe+3s1FkoldFap)JIm{U%DN{^ zwk9pI9*BU?brsH#y`Co{oPYMWN84(i(r>%+DAk8E(!WZC;37Tv!CpO?iR(VBGV~Vx9MkYjnc*h2;cOi_dizWxz{^MN8}V zOB#WTD+J4hgG(!1lWU@je+eOG&xT1%=6Euff9x%>nJgQc8e2>KdTa6vPsvPteMPWy zMT~h>A(sQFv?>|ADiZ8KxU|Z&vC25HLUW|0Q~O&h_&2NfZ;cnfGc0Fr@0ZS+d3Yb* zT{V;Bf0Xl#Ter$x7bxWrepnYE+VHPk7o*>h65EjWUiVku@K-ncda?0Ai57im(PG+u z0N3UfbQ7>`6&PTS`O`cRd&{8IyykulCAv-lALh5()Zn(!I@;19+74l`a$=;naz3(j zAzH~5UpcT?f%R_<5bXF3ZQ6#Z5)+P&b`P06wk((tP3k>gYJR;_T)1g%v$H;>XspN7 zK-l?%8x1qh4)4-+^kg@?#cq**k3eJZ#(58y!nTyiwrt3jX~c`k$iCKlpVXxc1F^Hb zxU;ynkGbelB0lphNU4|`t&{s(cbG#*>H$g30mmiEaj9muf4=R3#2yqnYg0Q<)U|lj zSybP1!Csu%zYi_l#>6~`quH!q!yD39OO50X+C0F1b$rbIuAjWMo;7izrJ0$=$C??Y7S6pN{`uP$@966?^4A>LNXhv8 zJsJ}WUFe8V7$Z9V2UhM;O*wZ?`*(ghqNcqS@c4K={^N?%XdzTH_|;4PJf92J;nv8v zNH*H55S-gTt&A%=|JFGEZQys|era)o%Ov(`ewE?Z z)n(Uou9wNmk8-H%{INCP<@^5-w=$Y$uHc0HdeqF$6_AQeP*%&_ErgA=JaKmS8kZWc zfsMh!eXijZqhDxYUs6G1XAhkTRV4L&U=0xy}>6Y9BO0ck>T?t5NaNa$I6 z!MvL1JO$wQgf^9n9Hcg-gULF*c#%$;fNV)J}t;5U5kHkFbPslAk^&6SJ z^!A`QGCt4zFq%%AV{RB^X4mMISHk;O?y;O3Z^J>7NzXspk8p*s@#g-OBBt$?;Q3{e zEb9m#v=zXeTre-$X7CWNfR|76jrL_LBnx@reJGqLz@k%arXFNh_oGeQYoCZU;x#Al z;(P7Fm#v~`KF>@j{EwqE?dxL2--0Teni{~)i_Lus*5#L{f3?4`lOIiWL1~1&2?QP9 zltgg|C#@O=|Mm=Bjw`{udY3ce!g9(q^@3`7!aSOLaqil11#W3NPoOu%^I@UU1;Q*@ 
z?SgozW?(B8CerYDwmHE2&4x(yrnEw7Q#+Ik<<4~Ar2+rJ;`fi|IJ7hQuP!yai+}B} zsUgb*7u)N`>$&7S8vBWVW4E`D#D>0)A}oe}XlFen{?LfHQQH94dZ5h~TGd=&Z+)$V zyS|{AipO6>yp9*GGlh}oYxkKjub^FZe{U5kQ$r}BdiwJZK9du|TF19}=13McUaEz@m&jbJb6M{m;)8&G~)C+C`J%-hjk?p9q=?f8sWuiQ2T_1QzSMVWvU#P}iC9_CNL zt>-Q;wk`*>7Ql+(OVA#OkF0MS(GK8ALL;%@$w8+XK=Vg1KnZ%VUxU%xuYkzXwyQaR z9()20Xa$Y5tS`?HI>4G|tXsUQk%Z$t(<@i)@5br)Q~lfE>yL0U==ZOK%57V~JEnLL zSrEaA0j%_DM^x2K(SE^UcyNoCC67R%;0wk&AD|ffoE<<3{8CiXwo?3#=fmCJyOA~& zfp2juCkK>h81Gp}b-OMBptjA7Q(<`-8{-TZh+%k?rldpo4-tcsr)xIrc@UaXylD8- zcv~q(ilEEoq68U%9PN;h)Mjv2@a+1DxAHil*WbsnxAC8`xs8$j__O@Pz86bXD$NWt zK4td|8>Y%dSM!!>4`thRtq&Kw;V)aSkMSM6?hZA|TCYzH%398^k3CjzCm-vjKV2Lu z`<&nY_&3MTR!%7CAC!>9#0jJp<=u7G?x*v6fnEnX9eUXj9dhC;oxDysV8x+;i68Tn zWs|fB952csv;}w2zCgJ&JHneta^vr##+gI% zl}beQf|TA$hdz6WnN3s9F=?WBi_lGmY3|<6IqE&;M>;fNc?(Ap`w2Z_^J`%zv0n)< z*|+$@uX|LwpON}u;+=mtAN1UFOuJJUr+Pe92F#4TuOCojato1EsLhV}P5qotEuT@u zD9bf#<8}M<({_{Ya`aBWT%lXs$FYb|r)=FjvrjF!Zx~dgXcq^1xdnmJ21I??kJQ`8 z>OV1`_Vc`_{and~IxzJi=1F)jp%m^qkcs{yUx;Z>B|v*3%jS$+?vI69#?AL!xyjr& zh3x8OeG@rSKY%hhd7|G{(5eeRlvwE%>C6kEyhmfQol1)IjvI#>_|CF0VkHe;+b0oH zpGMK&Gzv_ zY*~eoc-Rt^)uJ z`eKZF6SYuzYE6kP%yI)IG~LMjPtk2LR}aHurjz?)__380YC?OWw|2BLnefm|-UBY1 z`mxAsP}W%Imc4zcwyu@%&?eya>*MCf+MdA#(=3n0gWB5qc`su{b(*JKDSz`Z%$86YjUsb;?Gl_ngy#dWp6B6+@rK#Y(GCOeV1nWb^*WjB>APcl8wQ?<#$iiux(3k;>&+4IG8so@h#_NFFiI{ zQg0%VjVBG^_a_vzchYrY;x%tw&mlhdUoO3mdXM!l-r`=LJ(fLotbTI(x1RW5KHYl% z_~^Mq;d<4h3bT)$3$uUcq z&=^_$-vD)(1^kQb{Mi%ynNa?$z<{@`0o*bH9GU@KW&wiW0Ab1i{+R%ws{j#!0C9mp ziJgG2l!3C~KzY|dMVUY)vp{9mATCzj##mjF1o?(o|Mp2Wu^=^rouDUk=PzKeF>CN| zaIk59un8*Id?wg(C)geca=Z%u!wPbi0lDOZ+)yCb9gxQr$SWcEM}835j2{SS`<=-c zf&#;Kz~R7M_+6*=ehRx^07E!R}9oQ-`Y@PL2uMA{c)3GN%Y#;Sx?hHEyhM%&AzYV7K zneyZd4!gc`{EIS;nF@d02}i??z+j8Ol8wMIkH8Iyz%Pg(?2RDajUdI1e8v_@E*nW< z9!V7vd7mFHixxqL8^y>L^;$NH$vlcBB#NydilaA*Yd7jGZZt1jH1|8%=nv-6{2|eT z1<}I2(W1N2Pa)LLY%yPCW4@ZlNQcD87R1Q+#whN_DC5R{<7A6fla2jh9;@+GMJ#IBJl-WF-mM_sqc`4b zH{J&~!H+E=KsF)BJOLDv5W-xL0O?JD?IwidCPuO)M$0C~nkPP`P!kIhlY0~4yNPMI zNf~TOf3jqga?F$RLXrv!l8SnhN_LaVaFZ+8lB;BsYZ4PXM}poYBpc=@7mXxy)hGYF zO77@Q?qW;nrb_AEO-41O^z5ee;ie4Y!aJ{1#?0ZPA@B)7_&_3jp#VO84PVrPuQkBe zsZu-5lXrKYPV7?;cTKs9*I!Obtr$0MML%T`C z6iUakfZ!#i6S}7p+@upVrlYfG0I4&ug)%6TGA6S!UNmNq7pBwWWs(VHQfp_voXcRu z%b?lIc$Jj-<|dPwI*YY0lU*+3-CpMVB(%}4G|{B+G7{6z?Ac%BvcFnnONVAZrBLPj zvK9BTmGN@Eg=QaZXZ>LJnU641EzHsB%hB7*F~H0H$(}1|0amxjO^IhDGlZ;k>0%k$XF^TNybVbAxI%MY-~4+_l(73POf=i$WVn&K6J zf!XO&1+f+d@u3BYg$1t3eFgBnf;7Cs4EDk-xx$<|bE=PM{p1Bja(?1_=zzcCD^@QKGnDv)g~0x?&UVl6(QJ*kfBAbjYVB@#XZ!;-ATnLq2m6%;z7KUmYd=+>XLT3 zl4*;Q*}~$<#*&%2lDWR(CF;@@p^^>ll1e#AE=A(8_NKeOZ z9u;`~73ee-t%d*bD3CXmGc;Aany+BgDd$M8yx9{!-9z8&^F6ev%E`;UxiMeCDB~6{ zJJl8~>oYGatPcYQYcU*^OvM*-i zE_$+OE()m+)~OLmt_K%YL+9&5J&GddtLnpb%3?UG0wE2F_zf|Z4axlt@%s%Bj>gE_ zCu3)0LULo4a6`U)V|r6#?rmc_$CJ^ah~}!UfyTZ}yxwrWK1seQm&F2sY#D2>qr1iR}^|rtDe!q3I-+WHK6+^+3?4%V3io`8O;twFNIm%95 z<&mVGZWr@Na;r89Xd4xxwIZ~M?4WJ=qLiAm{k1|nlU2J!G8)}o8waQLbIP~Y#GY`H-kcO$8`{2t;=Z<9@d&G@`PM2G(O0F=AC}yEhTq>%te>yj z->lH@3F&X-91yC-Yq1)rxf7Qc8R#dFM|chl>-N^}^bIHs<}>$?KnMLr`o?t!7dg#_ zQU+IPd!C3ejU}t0*MA1riib=ViWWSFjy->^AcoEdy8i7Aoj`}1bO-(p4BLA4UlWW> zDfHeejQsW-MuUww#|~i)j)>h2;}VXlG7l3fjtY5>kitgqsz=BNN8jF!e5WED+y5|1 zr#SY;bCeM_Hi#Hz8XV)f8)YLLclj{Jr8rLLImQbc2Um}N7#t7t92X><`1yWZRB__| zcKkDJLa=)L>)-^D=Y%ZbZgut~?Ni64WLJGWz6gj10BlX{9%>mHLo zVN*(9x_=K&9o$Zu6HXhxpR!h*PVktrfBFDcPW~C3ZuOXUBb>?oIO3%^({(%T2b;mJ zoDNc)v6Y9FV0mBQsxuR>ow0+m&^;lpQ&q}t(0z_Z&Q>?Je-e)&$q)CMo#Ct1}Bj4 zg<-hjK=W)5;S+I7)z7skdblt(xY#1SFgrLsvADPnGxj4`GHzbnC|N4}vlPU&B(qPr zOp}>=s5rArw|q&5SF5=E{(kA|aM_V>`BrcG;(i%R%mZD2`rdj4Utettw(?bRg@Ah% 
zn!ZfDG>Rv-N_8aJY`v;WxB8WuXbtCTk#1>%l75XP)e}u=jecpCwRDxdWsSgkjaz9w zG;fV}Y2u~+IswtTz|!!$mi5mMz?sAKyW(|;hhE3&zDSL$vv$;!q@ppUW-m}GG`w)@6+SHAt z)crl(eZ<2|{loqo?Lm9kOr`z-#Pi@2ifE>@<-irOKSZ?LM}O$Qa4q@%ka9Se zIyJR)=z4d!LO-=2c66Y7w3Rx!mU?81I69!8+<7?C6FK~JX0v#te{4c~d=)nFx8+#P z^Z4Om;fDT1Z{Zj-Z5-fpq9SsF_jUn$`9v+{goJpU;M1wp-N}oxd9t)q72VTx+LvRL zkEf!D(^t!L^af{Av}dg0qiBk_C;iXIXWmYoUtJeS`5`Z%?T2du0EAp=s-KBCXuEWcJ zZI8tlvexG(a-5pPf0L2_M4NDvM$3qvSG6X;#~wjDyWx;Qq9vCDWDKuj3@$`HE^t#W zV8p{##~0CW0incY$%cS9xdsNqTU@*PxoAE#yPK-;tM=BbQiGN7$HK_R3s}=lnfrCp za#`{`TbeTBr0>1#ay?qad#AS;GAfND4!pP|=&u|B$j3t76#%>vzz|Wt@`-ERhkNT7 zeW&cF|(7z4<|%>!gbM+eeTadVSkgV2R|p)oPMSoje{B(ZQ@#zb^{ zG+Z7K8YC7sAt?o(i=Gev4}M!iM#cd^G`F-O+uA!iySjT&z5j*Z4v&nEJ&sRIPEF6u z{=dYxGaLOWhd3=P5?zL&SHOWx0iBPLOJ;;^3xy>@ zbovXk{EK`zHdu8|x~%XA1(n00L(l)=Z~rgtt?sI<%%2)=F6=9twy3YY2G{@q+M}ZD zB?)Gm#yevnwMj!^bZ_~3X!;C^N{mHUH= z3>FxKPg2D({rH^85F1FwLwiMY`t6J9YNS4$1T_W#noJ!zqB!zR+LSUwdGtg6Hw*wp z+TD*i`RP(OIqlU#gW;Cu|Hf~XTU+jLucrG%hhOeD%P?ZHuK12{kZ8(_Pp|hf-`ex3mw2f>wv~Ytvay^k_eH>42dN%q*UWBFOS!c zcz!~sPWf449GsYL7b)}NdmITS$Pk6{VFYZwiIV&AiWXan2=ljCfaL75o7QgBa=c+M z+g5^cocUIwS#H5r(je-`Qk?o~(0Ga`+WT!u&7m_ez;~7@ib%8VjX!|$dnHQ{grJ)C zC)FESJ79VYEp>ioCeNOkg95poOxpWY|8;5bPx5vskF}i)dI&RxkkMvsfTk$yf}$o=&8vRT>fe_yJJ;_=9bVvqbumpVB6dn9{iBL9|&&gsPgwM<;-vOtE`d#~!D*?fTfeo## z@Bx7@W>m^4s%PkMsIS>!2ls)w4ompJNEN)9vDP#n`li_ln;%CQNI(*J;yAAU7+U0j zPMOk{|1+?WwV)?7eHfu1;Pz^g$psNBT4%2RMRv{iubE66rN6BLbRQNhMjHUI5EqqefCuc zY>bA{5JPmP$%~B%`1V!J;2W7hP2o=_FHfWAuaA3b40?Sd!UER-J@39!Z~Z1keV*h* zVPgtlc4Z0`8X03aYyz*+^<6V#r@)eq!wyY%+!w|vpA&W5id^Cj9CNbTP@ z1{DSyLGa%cNC0M@6+#os)n=6UHVZx{G@Hhdg+RX2f zSkycUbFF3KmgkPA%xnqag?lwt%Ik$xZdSm!*&e9)&iun>0eIuLAk?9lX+t>82 z0*7#7i-ombDJQ8~tO;fYm9e%}G6ph?YB_K8l^@yc*J4=Ad{LX`FG=XMs?-IUqVYE% z&*!On*!tuoDTKeK@ujn&t&qG@L63`$=m_XI#NfqX%Eu>|8~Ok9Ol|Ka{GBV@r_>nV z-=lYHV=|+8kbQkJOrT>|V`6jNDB1L+1eDNe`}t4CBv}zt!Td~3`Hyz+WZx~j^j+jHx!1k}4@}QRe9;2T%n>C| z_-#3?#L5&B2EtYgV;SI@XWxhyG#yqB&Ylp+W@Ht4_Mu?1Pkkkf?_Vo{Q`Q@9r%a*HH zYmI>GA3}JN#pgG=Hm3D0M1lKk*yT#1Pp^5LYU8KV!M;v`=`Wa8R>jSu35N31)1Hix zE_8M$pY3Ec|J;Q6Mr{_bQ$|cFeg}eaM?SNfh#MVVJ!4oU!((&0if?jufR+rFMSv7G zUeQp&YJVSsEjE{NwWTsaO`(QiQ+cTw*1fE>d7pmry;^y-6-cSq8RmOMQQdpc1ppdy z73+p9hiaJuDhS_VXGMHh_}e{oi~qsp8I~ftj$Mm1QT2a%OcD^Q*1O+)l%&$A9xp2+ z@-fEGwbCY{NO93mi+GGty8-Jg z#bz_ceM!ZRnfw=dSvTt0RiC^@a==#H`s^ZQu0 zoaYMP@2Y5M^|lh8&}`cebG~hy+Qt4W;+V3c5Z*Y`qjXW3RrK zaf*UDNW55FH>n?PTDly%yi!=)(5tKKYvsCZ?+`iJV0NFm9lGikemLC~zb!0jvGJ6>b!3F{)iubX`|sW zCa-sYrlEWDOyqL<7Up@}^mu<_P6=c z?706*($4rtYk0n>cdF6wf3F$vaYpLg)S)ZSpL-_2ciOCJ%Yo4~KvFYMYKEUh)Aw(; zGv`i#5|9OcX=i`wE3Fx%?i!@3;kOO)>;LGXd==C@_2;vu{TFbMiEFS~0<)}uz&BKo zHBPXdCPn9^^Eidg_k`el*94F|>W4PU8*@2U}DGTgkXk zQ3eED1@80&sDZ)Igpj2Akb{VTzuSMd)Irgdpye}ACZ$aRDx|PJw78xy>{_q`>8pnhT!m8Xc=*ovc4`i83ULQembZ7P*rU=Dq{fjrXT!2|ICxIBU2*sYowe!$W3Z zV`@Gv6yf8*(A6EQZE!fIRs<6&tee>zupQnB_P*Tq#W0T`6O07phlc~h&QTFHnxSkU z=sTH62CAqJ;D|$%^Ygq&!fRNtM)<4QDDGXa7Y*TX3NUL!WD!M#DtYAFhG>yq*Efj~ z%-6vm1VhkVqI7zqL_-o|H0$}&Wm2} zBcHstxDqwJiN0pw?pc2qEvLTOlqc=&;Ce#sZ0xvTm{V}FX94_4dyDv+qI+pKb`8hX zgloINx1OxGi9zjdF&jOp(NvyBf@y!Aw6`)dAeWxkks(j(E!W9z+LQIRyB+>VFdgeA zz1cY(@5y>A$Gl^n+yTs>?t4k5{U7P=zqWLp0zaCk0e3uu;Xl$_dWxioRP#(8ZT{wl z%=g+^rl2HAenZ~ItREF=AD^7Jd-Pm;nHIfSQrdR!sIyI;q_@Y6nXmU!r7beW`m(=k z=NR#3gEX?gxo5LltH|QznkD5P`sQ@krv7fs&@0S!5X$Rh$mPDt z)e*{ZYRvP*BecSEL}yQR-HWxK%L`7*b?!?H=*xPN->Qb)hi2zd7uC<@6=@gM z%;h)S6p0HJA>@k6UDaq(VJ&s_0K_ToP6!mhpI3HIb(>JoV0e)8T zg7LnRwZ??$n-WN2$$C;LioIZKu2h4xbdS2sSgZ6%J9kHsz3j@OHKQdRI=9=p0n=JfnpkuE`e2U;YvFp= zy*gRYe0>f}eJCU6n_|Ntu$wXR0 zZjlEC$Y&{S)oiUV?$Sy3irW~hjBys)l)-I`p6zuT$TthA%+Pl36e$NrG${tm&m$Oo 
zPt$G0|5TR#C$jXv8%uYcQdS*O3}~c09SsanRm7n^i$HPmpg{)EpUO&M1E{E{7@cl!jAw6LN^b(9H))|a z<*qlCwl7_zFY^h#_3X<{>B~p-6)yA@-}PlG^oFFMf_3{(uzMp?K%9(*{fJ__`h|WZ z?LfQ8K&S3Nx90#VWuWg#?YJ=TgyD|R(+-Y_3{L0{PI(T_qzuj>1{W3vm+l5vXouEB zhBoMQhqgS2c2b7+5JLwGLq~T*C$z(7BEx@mhc7&bX|Ga-ZxF+G3&Rh0!%y5Ty66a| z-UznW2oOGkN7Fn)usA|=KSDw`N+vq`TyOM+*C-`?l)8D8mU?lN{(h8!ZtRul*c-hu zX0I_;_!xWh80X>`nEQT=hi;rtbo{;E_(!jC0r>bK1KQ^Ze6@jbF*v@cSBJRb#FzVV zX|9RSdJ{6u6S9L7a@G?{qLZq6lZw`p-@GQZ;ghK0I zbjs3t>NkAK>VC=^jxPio_4k~1g-^RTPkSy-+vrYH98UX-&IIbs1bfZ=93q%`;nf8| zq9P7v!U<+0q5V;dsBp#E1h3g7y4l3R*<{h#H0#+^(YbWRxlFyee6P7ez1iZ$*@DHn z48r+buKBXVxgzWNYOnbU*nIuqOsL|h^Es-SVF4k&V7`UwfGu>DEc7(@_a80{!WTxY z7bXT5CPf#g;EOZOi(`w6!~b6ZUl^e0?cLG+-QpeI`=gr>f?cVFn z-R~{m@$KC${M*L70`TCAzK!1z4G)B@iTCZ_!z>s4ZK(uq&;Wkm0iNImZQ#8);6M)E z%L^Xi2tMHg4dJ|4;TXQq9KMSZ-r&&Wa&^;&n{ot(f90KF~0(ED0b3 z$Rz+HAOa*XXe<97<1g&u%VH{b@&U~u4FvE4FHjCgp5%Em<39e+PF^e-KpY1V03^@> zJm6N4KmgJJN#)?>dDG)he!@VGERYlxz)(#BU;#z`3}8MEZ+AHn6$V(5f!&#Z1Hrz8r-peiE3=;5^In*sneFzLOJKrz5N zhcE&y;0wDh08IW1eZ>pKJ`$9kT^>>E%`)h%F2QA9EL(E{#y~Y|ZtF?|4F7@bxeo3I zfB_b8NEsVI0?5(rq8(=Qrc+ z;m!;2-s^|J=HM;>i_Qz_)$UF6?gh^aP~sR8KP${4?fPE8{5~c;(dNX!6AjNo>h25Q zF7bQ<3_LOKz7Xq!#0z@?@g?6b1YcUeaPrI2@f+XHIv*wnKMW4i3j=@i6aVux9}Jl= z0=@9?Dj)F(pYpx{3kCoKC8O~>-@nz)CLezcD=+j;Z}LSy@dhvUV-NK)-}EzY^vK0O zO!O>0fA#77_Foe80DlMq|Mk6q^JNe5P`~g?|LNeK_GE7l;}r=kzbsY{_wXzC`SBn3 zP7K&$QBx4i71I;7C_=-=?q<K`|Pay_OVQzPV^c9?h6nM$Wi=H5B(0W6t@2IzOY@)4-A>$0KmW<$DS;SZ~ftm z{q~_zhu;fw;$FT09e~gQFbL4XA%I9Y$bgukI5>l7;7FLUfG9|KX+b&IF@iYr$SBY` z=+OAcg9!8J`U)E>J4;(@dke54v4Y#{`wJW_JWO0{e2ko|yv*F}{0to}JxyJ0eT|*1 zz0KY2{S6+Tt;;KJUTJ_R;1B=d*s5e>OX5 zHm%yVY}>kh3pcLZxpeE=eTnuiumlihjO52rR)BsBbQopoDua%ZUb_NFus9UR#5ykV z3*gA{BgKprC1CJ$24=m!BATfdGSuj%Z`m~%h&{JZ$^M)_KbR9=Z?mRfF^)|6b5=H-`Sj!9;jX8IE5nQ*14W}9xl z31^Z>#>pC+bXMMpXP$auXy=|bS?On>f(}Zkf_)aM-F%2H%4nmGVt44HC`Br1rIucb z=}&WJirkW%ehO-+qT)p9sMr8%YO1QP%IY%e2&9i*2?RTB~h<&2|fJxZ<7|?6}NuYi_#g zt~(vN;iAfJyzZ%oKv|3+XC>x3NOs?Zr?T>3%n3d zOmW3Gy(&Qv67;Zf#~gp`amXN#JWIwRpFA?k9+#|f$}X$yvMeUQOf$?I(=xNoHOCC* z!xsMxbkHXGJd9~ZAB}X25(kef#b|YcMfAeSf?2ugp(>?e@EZfBvxVpCQwr4}AdimIMZvKHojB zZtrWL0|!<;|NW1F+LK`K6!nSx{!W63ORsGu__88d0 z4T>&?7TjDB?0C;EIa0A}Dz%%1+7hl)UVvq{1i+P38}cmn7yB zIax?UelnSH%w@sq=*wtI(@&BtCLpc3OJULvnT9;%F2}~qT7J_l&YY$>&xuNC_ClNA zGo~st*-A35@s>DTW;f-i&L8Fzo&4Mfp|yyif&Nzq}>GohhGBu8<%P_vK{q9iRTf~t8@c*0Vi7+q;aWeG=dR*Rk^Eay#Z zYSNtQbVMk9r!3gHJ$D8)7BN-mH;dZQhO%X(J8h~{(F0IQ3e=|&%jh=iMoXjOjHVyN zX-$~=)Ub*bQVIl*RoiJ%s=@-O*W>0vJ$g^CE-0j9t*c!L)YG97^{hcnt6uNNRRr4e ztx8>sTr@UhnG;799e z%hT3Wv(GGSXI~53>_`+CY%OhMLF-tTY80lFB_Lc0OE=fXcDTfC4r6h(kbsUY0uDl1P&%e~vOx4Z70Y<(*$UYYhxy8LahgWrjN%ikIK?bpaRE~&pBBq_#WY6oiwDt@64!XgGq$mR z2`K>~2qDNr7BZ2KY~&;td6`HqvXh$(=tWC<(uQvIqh%3kN^2U?mbS_xCIISCi+a?gF14s>Vd_+?deo@-aG+le z>ts~9ShJ3`tpi=_U*-DNyvFiuDYjbcUJDz3Z>P3< z&hP%R+uXgPd!Off;3mcPI{Egvgx?sPpli71?ri6UZ%E(-ued|kTT!1Rj{?Etcm+HT z@{WuA;~l?o=N#_vblXSe7H>ImBdk%4JNKPV$lGeCIgFXU$1OahD4{ni!9{ z)SfF#`GS<$2MYSopKeNn|3i7^0`_Mk<+0Lc03k+xhh4vL8HqemWO@q zPlfqayIyCr&m8Rm*L*VD?=DuiqqU`SH*nv{Wp}&_ey38u1=YcKXK->ow|@tE;0F(R zlIp#$ddgo7{x6j))^yB3%<~h81{=^>jwC82v z5e8ty!yU&P2H~BNx^#QbzV^g_r0N6AdIkTk_cj%ij8&LcrGJW^Yf3Vxv;_%?#z2a%kc+vxY`{=L#^zF`m+o6&CtbD!i=}#cg zXMY12J>RE);U|Ar$A2JVegx=%>|%iWgMgy9+4ETdV$b<^RffJ@iAjo_q^m|j-Lra(*DAgy;BAqym<2Y@sh%MDWNRH_UC+KJ??5K|KC@qOs zG42SD^_Wh(I4|O8kNcP@^caig$d3W}O8`kN{wRYy^pr4JjxF*(mzg=#Udh zO$-Sw5J{04IV0XSZykAVv5+zmAORvtk|Sx7C5e(I`H`?dZ{Q}9E9nX?S#B`d4KfLC zG}#RpsgXM=eC}sN5$AqD8I(ZzZ~PW-9Z7Hf#$`i;k38v=gCdD%rj&9sm2Xp&a8{LU zla*;>4ifp4UpXd{lyC}{cs((ek7I}kcb15AV`jOQY?*jMsZuD1auVk%773Pli6vwA 
zfuZ$BLrH{?wR=bOmr_HQDfdc?=P*sFmyPKr!B}AfC53@$Ta#%`iAb4McXaRFDncp>>(J7h0 zS(pSCg~SPm708*$xt-nVo<-7}OvGH&d3DlRXZG2D)(M+#)Sk@JpYI8v%;|;nX+ZeN zd%o#{;t5}Qn0=kupzIl-5gI4)iAf5IngvRr_nDzcsE0h*p@Np2wGt|#9`c>_Ii40; zp5_^xDC$ba8HVrYnM^sNGfJ8Rx}X?(qtltA2Wp^VIGa@1hN$(QxNw^^YNQoHp+Y5D zk-4Mb)T7slU_n}fGJ2#{x|cS3p*otS8rr1!=bGqweQy}1;GmOMYNo_dqBy#uEGkcy zd7@0ZreX?%atczPX$@t1rg{1sN&1|-MV`n7o*TMlKYE2y8eT&hA0euzi&`IiI-gDI zq=M?DIXHn2I(i?v4S2cAsGABKj{2tq2B%AEqL2!uzGbM`Nu`{ssyw2QXwsUL>Za$( zrWrS@M(3dE$)L4lr`4dTtLm!@whWWXhMx+ZDk`dKdaP1JsAO2FxY`Xx`m53U9iCdN zWLK?wcde#Lr$P#*b4sR)Dy`w#E?AnZT3W7Lx}}jys{83p-Rg!ADz5RWgu6O*uR1}o z8mFP!qS-2#?#ivVs!;Sg4ZSL_1uIhI3aEj)t_ZrC{JMIZiKz#=iPr$F&jxF;Xo|3) zs<7#*uIVbQb_lTEs-6)$4&ZvRCtD&JTdex3v7_p-_6oAwI;A1|uHTBND2uZ{(y9l? zuo8%`Dax|6YOMQ8v;Qiyh?uH5tF-z-twYqK(JiPwu`w8Ml*d8TV~sGGXb_>7a7wKp5Ov8xxal8m!Ri<(x9vrxOYyJ@*A3%kp^l;*pz0KCB~ zX~Sz^s_P1?#=NPPw6Y7mN&>qD(!ABnYH1d|**mTRyD*Haz2A!r-m51S3%=z$SJQir zT5G=RyAtA?C+<|szVW+$kjpUeE5G+US?GH$48~zB#$jx7rZL86jK*F(#1M>3DJ;$F zOwHZgQP*rJ_?*uZ%+0_|&m5YzJ}AxujmnYSfln&R8)eUdywD82$?;r$mAuSKXS65s z%oZ)h`&`Xg$Iq>d!XS;(vgWncI?|cf(y&+2WEInFDecC+%%g$|qX~`B%S@sIP1E>` z(Fr=wBCD(uO(7rs(*(WJeOc1h+|(qU$w@8M+G);mJgiUc(*9htjXc#|9l2Ev(aXAs zu4mK}lFeSds6pMZL%r2SJ)NBEpg4) z*=y4sd#=J9v89^T?OfQ84P%~c)iLY07Rb@;Od?1P+3YJxL0|-)4cebA+M!L_qiqDe zDP%_=1gg#2tL@sY4co5WmZp)~uWj40joXTQ(wkk_szVOD0NlSV+`&!Uoy)q(d)$uA z+kAc4J)yh?^xSn8-AXvq%PqgJf&@s=0^7~q-R<4q{oUD(1hXLC+1&t0aNg;S-s`R2 z?Y-XOEeqw%-t+C=_5I%2t$EcoZQYzr*uc`?`^~m{eI^19-~oEr@j~DQ{+7DNBpAPDwF6x~w>50zhLFn8@qcIvIl&$XSuMX?6F6*=I>NE_%H#lzYgrAi|d4(-QYa|+noU|knGE@?9I;X&+hEe4(-z}?bS~0*KY0Ej_up7 z?cL7p-|p?<4({VF?&VJI=Wg!lPVLPO-YqcR!@kQb0Ppio?-=!N@Ar=H`L6H#&hP#1 z@Ba?)0Wa_aPw*H8Z}10?@CmQ*3(xQk@9+-~@ewca6HoCN6>squkMS9=@f*+a9q;iU z5Aq=|@*_|37A0@;Cy(+euktI;@-6T3FAwuEFY_}OPxCcz^EZ$4Ij{3O&+|R+^FI&t zK`-6>kssaWD6CPxp0i_jix?d9U|-&-WF5@ArQX_<=9@ agHQN{Z}^9g_=&Ih2#e47jsNz60029}9H + + + + + + + + + + Without CUDA Graphs + With CUDA Graphs + + Launch 1 + + Kernel 1 + + Launch 2 + + Kernel 2 + + Launch 3 + + Kernel 3 + + + Launch Graph 1 + + Kernel 1 + + Kernel 2 + + Kernel 3 + + + diff --git a/docs/examples/te_gemma/media/transformer_cuda_graphed.png b/docs/examples/te_gemma/media/transformer_cuda_graphed.png new file mode 100644 index 0000000000000000000000000000000000000000..cf22822baff5b8c19a377c5d4e0d23d3225b0c8b GIT binary patch literal 369694 zcmeFZXH-*L*9NL58bO1iD7}daBE1OGQL!M^3W9{*#DMe?0TGlADhSdM6jVy2M!IyQ zL+GI>NDV~@36hX-SMZ$Aa?bmnZ`?8NxPQJq1|vH=JA2JF*E65lRtVPBzQC}XefPF) z+Zfa@s$JQ(jUKXX+YVJaD7d0K)sF)&+nujmII}IcopTnv*-4s5d69cIOm~9! z^mi}bbl$d2@&oN}`*8~c?6z&2uhi8}U-N*^Q{0~)x}H|aQ004XvtnsU_Q;!$a@I%R zq~3q-%r5odV*BYcm=Tw#(+3|Xy41u;dA7&hwQAsGjDGpYgB>5`E}Y&utvFV6d+kx% z(DjP5{fncYD^a-0y3s>7aty8)x#d{vQB*ef0U3ThUNW3f14paw_MfhI??tNa*bfeg zt2=(Wd{JS3`+nNpb(;Hsy3D)qZkHM(?RKeiKV80GSj)|Ldh6i_e}A~z)_t>&My8L! zejZ^qM2XidGGXh}mx^vFvEqzBJ@KUw@qUn!{U5YTk5>k~;WetX(?5^LxdRYB;(L19 zTh(UsHBWlTPs6H~n(Yf3nBd>KSupbd2L3-){~x%%o0Iwnfs}hrBg&=Z#j*wUZQTbj z#BeFA*1Ns;ppS2N=5D{VtE@wVmF`b7qrG$=LaNqr%c`H4&xaD6KlSWD%ufoW*$?+k z99qfC6uLe1^aJdMz*SFPZ#wAF;fTmIX-vZUMjiY@8n&cD)La;A(ZLaEMSkGQ)>6v! 
z(5N!?qi(VO#SAxwW03xjU+P8$9&XwApCi8H+LH6>wf79UY!nz|utoGE;XT8_A)TGO ziwf^z_t{(S3S!2zn=@SUu#K5$%)2cn`+i={6xG72!C05vFQy^%_)1TS*b{QK&eJ@j zvmZaWO3LM}slWJXM%!!yx0vVXA-I37&)#d;7VSE(VJLk?%u-QXLpM9Cn~;isBBPkk z!((?svHS~@>v^qd@A@rW6a+j>I}+&31$7)96*>aq=a1TzJ$hQ5L&;yd#Dh#ZNb?nc z9gp3%7)rg{G`7FPsf=NCkmT3zRsJqhDrPD*&2U|#K=E<>!9y;1jK^7LqElG#ooQle zSYCj`yeE^Ty3&gVI0_-`diCj1BpAXEXW?YF6m{k&j@w4#ZQ8oX4?_pSd|+Yc@ke*5 z+%?p>>>ohPvKnXGf$p~sL$?nTPhXx4>^{b}rdQ}kf7MM2(b>7)W_=z#@lhtZ*j}vm zd6QQ`zh&5@K6;u##94X>e@-%4KAagXW0%n6yS16m8`#0HRb%#lTbx$Z6j_h3m;*gcV#n|8ymor zE{}H8rF$&YTHJ1wfm;%@BKSmjo<>^OqvFCUbavDH_Fu=A2Sg4Z6~mlunG74JiL zY6-hW(}lc^wYbx-D|xiVlO?zDKiFHe2SX6EweA4?#@)_#glb0zftA87?p6!8cI?C5 z5cqT8dC46j?aS8=+E+i(@3ZcRodM7}q;ZphSRQ;F)n6jpm+$qtc%u_>F{YR^#3(s+ zIobHg$TNO-+cS|m`_H(7j# z%vz)(8^a7=7>X=C4!{@;&-5M*+lh8|$WQbgX`a%H8y<3c>HXyfw|SbQp@ZanikiuP zSQLX{=+L-G6*h|bOrA=CN+dIypvKNA{k~1FXWI)Ugd+2GQFN($g7T>kHad28QzEyK z?@;gYpc`e^2gQ7@jU})z8)F zx|79XkGreDZ7vvzCD!@hv=$~eqHyMtCJu)G$a*61!@(ZV?0L%ph?_mc)hD}=we3?f zHr41deh|};ZnW#g``$pM3d6V~wqi%%!_6$A%kx-XVtz=OBFb&?oFiJv#^a$)5V(T1^0y9Q@HThYlwa*QW-g#);@MF ziKD#vQ0O_4s^i0b2y|##af=3$t?TKXN%{jF@@M$C!pe(bZ*XeGWPtk>?mrl_E6mlG30}>3OdS=ludh|`3=aZ+t!Gmt zywwpKCTb}%XSM><4?#zm*WU?6NDoY&WJY^4Kdi{r$2BFRrJL=-Sden(QQnq1dWnPv z4O=vB+$2`6kjj^Iwt}esa^0-Uj_cg#VwFqy(&#&rf{XHJ!>u=u?jaWcqXW1yAv3^? zlX=fzGUUOKR{y~8W;$1=vyk~^NIS7`S^wH#d2xMKvRTn|-b|4IIN^G68@Dd-y=$=x%hXZ}rFqxJ2TBfUEjU+LeSxt#V%i=T z@hW$N_d2$yTGaVn?#NtkAA-!+X00qsu(Z=*u~hG!SUQI= zR-aF?s2+{@M+bAe9NKHf_A-B(@D_rM*YBT@VPsh_foJ36od$>S9R83fd*+8&-WP^! z{*x|-$?FOIig|lW*+UF&yexIXCU>p?8LPSwa0JywDcXYm5S=e}M zLf8QI^6|gu9lWXznQr9?>f!yo7QzrWVmTag%N}A@e=|}7*gVAiCb~L#Y)0lAD}%kO zDaS(oNO2o$3(@mgqn&Z-<(^c#Laf=l>2Oyc4wMY{9(w|l#>hby z%%NSVmDnCqsK3#aS}YbGn0!_Q~8 zINI%d(cr~RG=LXcuEeX1t_zQ5SH<-;Ghy42`(+4~1*aZ8ZuIHz`aBO74@PC*7|75af1~GbD3=x|b^tyZ zoKXp$n$)oem0$P3Q0b7hZ~X(kb#p}3D?b3|{sbDwe@vF`nmqYFDIts&J%tfG@}?s# zDKu4^iy)7~UD>j_qN|eyhSu-2Di$x(H2(B8J`w8^Y*peDn|X?0-oYdtgV!iE>G$he z0UiO?-_QE;eIK@(FXZ#(~Xkz>qnlI9e-;17?T4Jxf-E6 zKgpi{#u>?NUQW0ZgtkAkjNFnn{Y^`EgQ#;Ft_`l2O|YPG2m3r>XE+p9ykE$E+E0H@ zguA>Mo@hRTP#BkpmmI*WDH!uDzs1<7_jrnqPBti{u++cdL3#Se=f@DKVuJGs*`XwJ zT50Cbqwbn4qDB>pQ~=`H4`R+MT%4Bqpx5)EKq*J!#xOKE)F6`s5JuWu=JDY<-XW56 ztzD|*0eixgM5h2Vv3^#qBrKkFBFFk{xhqr17+Hw9yqQPKC`QIDj>*V@SZbCdutnk~ zD6tsV(xMjqV%Z?9Q<})x@}p@+Wl0@CXneSmP9c8X@1~CKRpZkY#VQ5|diM^;aA=hb zPYG#zmJi)3^yoz)?)pYwLms!CUpV=<3FOqM+Cz-skOF&u7dU39PLg-JIN1X*_MnUy z4H|oi>qSPV&E8;iH(s-d4aab`$qfB2wDA?3*j}2@_DG`!{h@0Lu(ZAk)GTc>;xB<6}fT9;p zgrC+%dYNYv#vd9cT_^|#7V3c;WONt_7+<4?972v9tDt(&D8mU0;{ERA1sc(zirKW{ zDyA5q7tDI$+G*$!6BOz_4)(}8t=FzBy*jy3KM{ZJaa!{s)E12O48sG)CWFzOGVqq* z2AhZli^kE6U)=r`W(B%vH&b5Jc2V?pZQ$^C%?v}=LAozmdhLX>^m~MF zwkvB$e`RQZQiVjH*K$oxd-~w{U-2AbXw*>30BmoRW_uu9Ss0`lAK`XFSM*?x7JrSJ z8LH)Z6RCLlZLQ6_X$j^v#YxvhmVO(@tbHoPKs|-X@4~`8M8w!bF*h!8APs-b(QeN3 z>JB~hgAt2a7yN9|+vSn*7vD3am59_PH?Z!r_s_P+3GGJun3qGm>|OQQ{S%3+4f?qH zJ!ttzD~<)vu;spj!=t72^-GjPLiVmVhrNP#7I0&2-nAq&dEtjT@(Qg1m}B}s_2q_j z^`Mk=(!yR*&FZs^Vh_RzwD7&bhM2$sCosN^Yl!ZjDdI+%^BN4#)kYF|97;Asb;>7$ z0&eCC+DVhUPzXV-`GsE{p=F-Gsx2GTiQ`_uF8>i|`FJ(_Q3WTUN(GOp8&=3NKGP0| zRrcucfE27%ko#jy@+hlag~#U`(_MGc0Z{{ft}~Q<@vWlDtR6u*Uq!rr*&|0G-rKAr z>bvwkPm~<^XrMkMeE?2Ax}V>PTv2=^#`^~6a39YCaR|?rF2{doxS7!)Mt`mb#~K7^ zTFsE>-2T6We`a#32_BiSTo~8gF#724IEgV$wSl(Gcg%OB*NX@rkQbjKdJ03*Uwq3G zM*5VM=Y5jVr77K~CzNGb6r0@0>Y&6io|B-vF=`BQwyKxSpNbcN3WE(f#QFa~Ii1H$ zJs7~VGhJ_lFG$9(4B?LsZ6V~fJUXx6E-Zrz(jWj@dJZI8{mmG?gb-%5c~hkK=i*_T z=1Cc2$L6cUq~JiC2d9BQq5hM!y#5tR4V`}KMI6nJ)#ob;<@Q zRY@doH}uCAHyTDh^4KLu5xsDn=oH@7Q}RRGNAr`@>T?)4(%)dlX&gs#&XP5EBxbU} 
z2}~`eAessu>R68H&r-S&FJ=Hpo@uaQJSTAc%CJ}Abg1PhuMZCo9D7j2esVdTQYIoB z?^M#Jm$-5e?GvJ5TjW}m*JAv4`j7$iv5kcjCOTK3A1p0yy;*U}|GvHd-Mkp!zFC)e zlMR)at;u}7uUPCSgRClT-o@I6{pkrqJ#hR7(OxIosUf@2l1EE9;jSzTc@x*CY^Ec3 zqRof5;tjPPkawQH3P;Hk_Mp8q)wp&n4<)PRVm=j(+T=}3HFzbjl;u_oT`zp!tM1s$ zR0q$HhWn zyX@sp#on(Znr2TLmtzmrA;@Whk=4nW?l#VPTa_gscx_< zORGnWv7jl)XM*tyt7VFC%isVdZ)v>8=VGm{o^z6>u~oW1-YX)X~Eh=HFShd z!R#UCRuuaef%>})&N>>xvmtkh8}))wST~kT(a<65CxEyb3}j{O30w=rLHr@KqP?X= z!vo+MLGrv`dbD`vAwX);B8W>b!$HVxSC%&|ODahT1wOlP;$&5dtL?j3hp-}1^A|@! z-j-ESmoJJO)t{CHL3BZwZuy*qLHU%9T3p)A3~YH9rvTi8OXbPC%t$M4v5nD;?hm5A zzoY|#pQwQK!jt=>aTQMB)g??5cS|Wi>*1AkE z3Cj>IFYa|HAJXSrHn1*qWfFnof(f_mJ9u>D7S<(KLul?|l=7umE3dc$L^}TV{qo zgd^%%b(98L#2d=mPPY=z-45;Q^HWNZNL~mF(5v7#Gddk|+ubKyq1z|u5O5cTWQia{ z&+)@ppaPkOiha3FFPTwPz)E?Y;A8yg1uq7YhTNEMzd;%%l@HJU7Uy!F32N3DgPTVsCGO6e3x_*<_-(lJE0L zV`1{gPeqB>YX8GhRF8^c$6?J6$7McQ8{Qg*@|$yABeFN`5pB6=k9&xfVaZgkXlCAH z?|M*J?=semX18}@^=^j2U5&1>4Tl{qnYv!7kIQ+FJ-Fai-h7+N zJPacFE@J5I=s+ruZYmFr-ZB)x!#gUmFa< zZj8qlYE|L>dZ$C#8N+s>wwI>*z+y@mzSLi(UUWTxC>e>l+OzuP>)Tp=$1H)euwtS# zadr9A4V7h9&euho_F=n_Z0Fe8wmCX356M&!>(+0cElQb?3p^Lv zFY&y|6}mhGuL`&185n2@arcB1ID0^Kr(fcrV*ik@?DWUo2B)F-z1J0) zt&y;LkqdV>NjGKf$!#qj54JB~87^I+)eRX?l*_j$$FnM0)mS?`8?#1^m4*-Eq)3-L zx}G+<84Z^f0O#QzS3lC;>22m~1u{1`=3B)nmSHhz%}o2?)5$i=@hb2MgxIU_3a)=t zEfk4;z-+DB&VW9xqJqO9c%_n0d-}6|a5hW^b*(SO73Z{g=uxMeJsgFM3rP~4+(+ZR z+%WJeN{IwKEr+8F7R*27#uo0gcviMGKRu`MVtjrBQW!hXLd1;sWL9LS1~+6WJv-`C zYzO@AS>T{U%jP|byB1CrWlEe6#Y(Vw899*^*%u_i8`9Iy^0PN`^@ZDIjzWmUUn_L^ z1E*8N_CnTAe&epQd)$sXErs^$8@E3z-ubC`{-UKjho<`}KN789 zvki_zXG$^0eyUtMVOS3NgUb{J6vb|HfMllsR?@;a-dT zCKl{kw%@^U+0Hn?3{|T??5c*$6?(gX#jvN>BYew@V$%fEyRyQ%4j??!>T&uZf%>S( z>U64LDL%SOcBGk6n8Rb@hQmjF_j5%BPwnjXpzcKTp-X{FwtK$zPiI@likZM6mY__5nn|uMKy3H^!i?D=&BEJ@Vs~ z%EQu`^8Vx%FS^e2GHzYDxQ7mMTKc~|Ai$58eO~6wFO9@2Tncn*TUWaOs8MBPaN{`q z%YJ}25nMa?wyxN2H{}h2mW4Aye`*chzt*_-kNvbO_Y3cfGphNIoCkXNA)o47^Id*h zS5ysDS?SWM?T-Zb{rGqrZ!ZY8x32Jt@WOU(Le%yrTh@2)gZ=Z_T>5Pb z!zEXZy$R3o0lZ}A-)^t-D3oGjWTP2z)2{hParE(woo}vY*8DFfW|69dB--HmaunI< ze$$sXrrko-V!m~sDKF1E4y#hT8dc=lVlQz4OTWcPKZ@;qX%0hiu%NJ{^~10)5VzXa z9iCdOKYdI)CFLZ2yx%ks$5W%KY$Dt0(!ay6a(#*eG5r3Cz`s0B)4k2EW-qtmO!LV8-W;$P-HTC@)j`V(c-fXU?)&N*P!uoUy{$B@ArQflC$JQ>@pQKbN`UPo= z+$uh|cJ1ZtOjbfB`eIXlZo$x|-;nXL^M;U3}r(O7ipn)kOHU;$Lby)dbMTtN2#>thX4y zO-<$f?_c?Ox9pa4tZ(>X9p4tqmZH_EUu*Ztd#%2vO|T_aEbzb66n&$`Q~qr()D6!h zW$&SNG)3+=tyq@cO4RxLTrA58@b&@%X`2zTCkhB z`z*4m!OXVEPd`!t>C?)&P;4Sikwlh0?H1L@kKVIi;t&kJDfRgtTX~?afwzcCc+uFS zAbo@L^3y{YGJ^lsSlUT_0pi$xisY5ex@2z1(8R4EeW#(ubcybg)t+~Y2{%N`*oq0U|6mS3{|gbeUy##;74qm_ujvP)LtdAxwK*|)=8Z1IFPAIw=dvFK zwz_|rk>FDCQOtgEBqpXYk_S^++laEuw2voJsKh9F?~T`Z(kY}I(7%J&T;~&Wl1HO@ zCVam1l6#wsZ^?9kHn+-LxprqRkEVSa)0lSRsTUKC5rX8}uN01O|FM$PHnR3iU0!;r z%j}f_`Fm^Rr6hcN+{6U}vG3|-9(OBpt6TiSq`G_Di@}bJ6xY=OiE!#C5>b4<0v}&S z%5G0S^%j$<6MHcsXv%|wwCE{f=-VLJA17sf>E*4Upf_Kt#C$VdcZt63SLwj*AB(9!!jv#+Yh>|xtbpt_FFs&~MG9P|k8ejaiQ7hAR@yliFl!Q&u zTS)%weF|^IGfgt&95H!kvRJzQuDF6?j1Ols<5E!sSX*H=UGmCoO=Lz36V}OpqBr# zR&~sGGpRkXD%CqMau$`Kb8bEedJGornQ@!dW}PHedUxe@dr~-6;{XWYFS62te=|Pt zIp+arjoKyFRNo+^KL=clabcZl)ty6JtOShR`K=hg$P^^eU0jj zh-qqT?N+Gdr@05hro|h?t4q>gJuX8J*$=(sx;1<08dZ+H#@!!SBNb{?-9804Cmn$We6)BA#`PbSLsi`YZ-ZXnadZT(Fg0o9@ z>|QO#ulrVwYMwfHpzd1^jZ@Vc&d!TGhT5}H4%#ytXs~LYKz~8ceQ1(_*NahyVXM;s zjI%xay^+*JCWCh!S#i9kyqYgn?rkRRO)5=&HkR0#kdop!6z46>-j?U{waGC_=o06)nNj~dl!ZEj4ie`9Zoxb7vhtuR+LwvE{w*{a=tjes|tDAd#p5-X}V3awtpfMh-yF!+A>lCw$e%d?(V&iI7=~o z+3jqsS;D-AT~RaBd}_0oU}5Z!U%0vG-rTtDSK%yAr#(J+3!^pao{#hjsAhA zI7m-99}VB_S0Bv(32mBOuS)X*nD}h%yX8!IV+r2`O?ks^&~n^a0n@EqH@L8{R`opI 
zs6OD}=$QBAO{CmLc|852M}8SOidjqAJYDNj0#ZbE^%+!#N(u2p{p&@cxN4mVuuvd{ z=6|BiIzwJ^86S3UU#ma7Ll%p_ENZT**gdfD-HkC_0b{`RG1u5K>s-;;EfpjFxC<7ug_Y-;&|1~T4)5vpx3#IXHP)|HePLc>#SV-* z@Ou;;#|)KRiGSBhva0=%- z)suk4J%AgxtQ}dsP@}pGgq5J7^&E{FenPvg>V-(|-dv-yMBT>}xs1sFG|kok@yDJ! zo6*x9bAgyZLK#D5FWET2k8!`oC1;gZ*nMFmcZ0?bBPZz4gK4Q6)nnE+tW7SCr*Mn0Zu;%Y>P3fQNKnN0vx%Hds62s%!Q<7l^WGq z2gh1fe1BrsKI@2?p5SzO4d;Ke!Y< zbPdt=#~7oh{w{Hbt&k_ruFATB1wWc&xgEdXG__@HHYuqe`dZT$xv&@|@7ukKrr{+f z0Dwc!S%vcTGB9ygdPqGms=6-{0F-PkGM`h_Qxbcpyc?5D=uA^ri%*%RK50MIET?tT z^9FF4(Ul6`*ESDXvo_t^v4*RHZZ=$)6bC@ zmCa$r!^f%OkK-DJTW&dQzB@)D5C2HeRe?>1&ARv>8UYKI@Pl-Kc|-v?0ImVRorwgY z(Qxn)`IISB-gIX?F$dr7zK`~4aICoQtZj{|)!pXu6H@v-0e+o2V4j(?4k>a7(%6M0 zvLFC~2F;w$TXR~AnWk!rvqB|xSRw>q7Mv@Ko|N%D_z=|iL%%;J|nVJBlL`#FdI#O{4fh>^qK!r zopj@vgRCzuEHq9&4@bXWL%T+mRo%JDq7KQjx>)=ntCAVGcDwJE86*Q+1wp5L0@ReZ zmGghrGM8R}Q`zlIU<4;-fbP#_1+>U#Mq*=!N=~=9a0(xz871ie^n?DOnoDF zXpP-&Quyo|5Ohok(0aqCvl@U@RlWy2YD`ec@C}Cqo!L)MCgkr)6KYkh9;-y1<%#1n zO&y>S?D5`{va0+(PKUFPumXw_AN<+x>L2)dFq#np!eb2ww7TniMI?BMCfu{S1U^WXxdA5e3)ANv zJqo+#MGN0RG(-#E_-O*s!neJD&V~bE(|-Yz^o9$Zs4|{TT?MH5lf%LHTtGks^hJv* zz^nR2*4L|GHo-AJF{CuJENrz!tbJz8I&w}d!0##d_lZb~fCWXhCn;oJ&D4I@WWpAc zyv5hLsUpw!l+CKaOKGYCPn*5-#5C2t!;TtXQ0hVg5h!g@o}U15YlV1cy+c*+!>9{M z0{r{J?n1ziO>h3=j6#{9P9(TphYguJL8_m&<|_ug{wnl`chcD@@AM^zR>Vo(`xiU; zZEM=RBlk+jJ-HKv1;-TW9#Ai_(qRVEKwu>=YuGB+r2{ep+Rl*%(m7 zA*iJ6Yn|k>nR|NQY8x=^s67jI+J0Y-+w z7sjeUpc(%?+Bb?4*S$?!!<#>cV#f~u!)OQZ)%trbT)n(@Ny1xrb}Nic5`YXMo{H;^ zf3IbJ`m&tP1{6&TY&sLO+tP>mlD^SHb{>Vf0QP-j3xe8Kxy6k)6Ql@n-Sf}yn5HgO zl1FWQGq^y&sTQU~bJ!r{I0JwP+?QryfZsTnve^o&4lA9M4%3F@#Cv&PHI2R`dOIN> z5WVI5pg$M-5b9^&l)SEi@`}@$)5PiL%G9|~ zDX$|TL@n0iyr!vkz+$wd1Yy!7e>`b@+IqE7%R0$V5Lxovhu@I3sjy7tjy$gLEQWg_ zKLF@xf0xmJlJFucU3%A-{Z`&yH#XNN*|hW>Meo3xI>aA4el6>-;+~%yqlX;h{BGut z>h3y;oPfEI>MkcKxC&>#5J;FRVzzEibbO#uZg8yP1Br8!^{L;P$u%_jj>ZBrqThq*V=ur2JPFlZU-fv7 z*0_{co&&`EL!&f#T?df-FP%qSHyUq^wI?gi?Rc<32k`^}$EgKPmR9%Oh0NmfWe zn~y%SLgNokE}#EJe+n>{xLMVC-=LIqBIolzpF?72fzDt9R z+{{&@+7B!;{_oUm6X>VO_+faLR8?+qrdGhLhj%E*N1odN3*8Ey07H;o&0_gWeSxX` zsuK%nrjv>`2t$^!fNe{hr#-av6hUR7t^FC0WMNupjp3)+Poq)9dVLJchX?}jF{QbP= zl+T`%J}&TdmG{)8Dc&XxKM?h64S*D=&=YU9wX}et8VwNAPx*J?Pnz;F_$ZXyOSZjO zO6&)5rJ%JqOnRl_t2cYr)YL6?V1Vn5qGyk`xoIDUMN7mNOPZcbt+GTvOr07gBpiT2 z&LPX=#MFIhuw=CohaH^X^?Wk}Yx0I3Qc;FC8*qm?BQExeKdul73i_DKHeXz;N|X+k zrs04c26@=m_!YRg{ieLD>SZcyAG-qlIQ>E0fTuq9h??-two}RUkfc*R8EOF zY07QD0QKJ)AT+k}s&IUJS{D_mTlPXAD^}EKH<=?N)(s1OyL)T|uAO1quby$ce^16E zqK|L{F`wy`+|mFp;vCZRdF2p!v`ypK0ROxzi=}!DC`>*pByXSn=3O`1kV;>`8tdfy;-QqS4zGawUwKA$ppXB%>KBaZ9 zszX`C<^XIM{=I?my)W`w#ISf~WwOQe-^%ZrojV@T{qb8dJuT&@DBi7B)4ghwicC`R zg=gxdxKROv1_TIF=E#k!KKLcFFhpw|8d8^HBt6Y^!}*6Z8e+V3~qBh zf?V=0O`exPX8PZX<20eQSZBU1PYj_QctG~>>+yS2g%?GC_SZYqs(z|ngZ|lKJa7vN zZklE8OKvo`@ZW0Zv~J!>6upA;R2UPT{+ zmb+hRnEL*%)v#Z?PjR2>Ls_;;%Kz3w%P6bv4*;S0!~THZ!TNinQHysA`TyGaT3j3R zMd7DAaI3SJ_Ns-BAFBL*HqrGP{*rgQeEVvO7^_Um1X1cQ$KQH`T6CZS>-)cGIJRX~ zgR1zXB3F~*v8I1(WM-7{!(2ALIf9lC_?H>KNC4WYzb)(IzpHBOH@j}b+^4nELBIE_ z#lP6dR-`2W4({YR|JOC( z_xN11Gf4^A=V|$z!uH$zf9a0VI?h^jO^Sz`sQ=#crJMQS2N}w@49&e&a)9_R#L~J$ z%&7E>&E~%}eXSnw)442<*T4I(%`;ltJN9quWXAVf^WFz^GXHv$2ks2h4+a z`0slHG%I~e`M}Q_aAhuH|`hx%7d;Rh;s##hmMLDOyx&zdge_N@acT)IylZk>+ z!&Sajw61nIhwL0s?|)|u&?i+n9wm)}Dmab@I5WooOW$~_=ZK++@3XPo_}^9hm9jQA z>$~ZFlQh=yH1|EZD%N#!Wvz5|#GQnW>!X%;`fidh>7Uh+(K@g+kY_$rU>!-*ou=3k z^wJ0GSTSE783%skdf*bob?T;rt<7g5^(zE%xHtUXFw!Pt(|xvRyK_?b+ty9w;Gl9r z5>Zl|;))pCv^&%O3X|z5qGMB%8?9Z!xc}-oywXvR(T%Z92E~%TWd7a?r^&!~b^(KZ zOI2b?%CS81?+Vou>Gh7{8 z@<}%$>&OXUeRBqK&Q)PwTKaCx$;WOk3wf^$>Z$vl?TB6iGzzl+iw2`M^>*=0gh4zD z0Tt-$dl;E}~eIxKV!X$2;_U3-d1 
zdp)rbu7c=G86A2dR;7GpeDaK%<{{KA>S?T7Qs;xQJDpF*d|#<-OlT~qn6xL%;+1*> z&h$GdXYZ^U^dOA-4fa)4kzZ-2$R(_?n4%<{VspQy+^3&M- z`pjqDn4{=(mV?Wrq#y5Uv<@rk~`?pRDv>FL=8+!sjV7Mm#?Z|6Z5R3l)1d%l51ZglN6VR(>DpmXt=wml=<~3 zg#l`F)CHuHhH=DRkj=j1- zByZg6XYxLGT4G90em++EBLQDj)nXC)mW!tv&;XXX+;e7sK39$7s z5%Po;Lqh2pZ~O?#4la){8E=fwyH~N7M?ud__F}82^+d5byNW}Rv$ABrQQ4gr@eAzo zt`9MT6)uAJKpp`9MtmxDU*4s8H$m{um{+c91~cmWCt}eo5ktpNY1OV*_g_+ z)`?`L7t;k%nPN6XvC5uL>O5KKMx}!Bfr76-PVI>ug(v+ahJi@6poWVqXvYJoB9;~?Dy?d zCY3(a<~P%Bko~Xk+n*mC~U~Ys!WOm%NvZIpp+IL+5hT z^vSh;j>OZ$X*dyzGJ{{`#z6}z}=7{$HoOFWA%^3KbJV<+chrI?f2v(psgXz3$)w zc7yyXao{} zqN-Nl^=vlg-B$_4BP5A%IG0(CDu4En61f(j%xe|sjrR~!!j8LL;Md6$FCY){-Cy{fTzW3A1`v10kdhftN}Jn@t91RaC-+GF!_Y{Za2 zkYV^iwMk#*PQG)f-p!CedNE2hLG<`v3_539nXn{x*8mff>mDJCo{;lU7H=%~Jiw|V zyl1txjmhtP+pS69-m+-!ts4K7t%|&$Qmo*85Wgykhq606j8^zMA%q=113@~L?xmYv zRdI~juPl-gT21}3YMFu)TUxlbnAmGNy1g%7OFMOB>{MERtSZ7xVv ztTeu9VkpugGSl0E%Hu&ZyMl~={&C=nGE0SHNyi5@T<&@`8 zoxR!pq8gXu7o@Rv#aOs#M#@yf`syp35*CeLdvJ(^$0{WH=^=fp!igSx?p4-!T3V~q z_DDU01^DgFBEOI9$uxPkVqKLbGRq1jc|-zQ{We$OKe{t(b?&MA6;as1dpWqyH+gGV zP)^lwu4I?+eubYvf0FAj_d4a!J{LrbLvTE&93LuE#!b%qxSxCxdOAwar~(>~oU6_B z20d?Y%m(XBsK^Cro2Q;;4qc6fa#L*>|JJKM3zHV{8IiZ)66TBPt3q@rA<&I#^Az+Q zWuNBW_;{5i5xMfD-Lg9Gwv;jNT@Gc2j99Cb$i%!$JMn!sGd<7Ft(xkcGJWB5c<^E6 zPN8QCPDVU%Ip*uH$bt~@_I%zdU4)8nh*N)TIna8zO@;B>rq)octito~f z#~lEjJNoUf&dB9?>`k;8vhz?N<|B#Zp~3~52IUgoyUyg7_~VLYpmvokw}kteR{%P} zt(%nP{Cw-2E5;z)u&ioG^PKeeymBM0R&~m1hSU~Fs2LcDR47OqEKqiZOPGQBDlkh) z)E~~3pGpgNnRg6Fu1=-MYJ@C;3Vqn52To$~EakH?_WRzemrT&1yZTB_$~VU7Q%?A{ zBf*c&Nz;u;f?0dW7SCsZ+WUr$1+IOfeeFEiATa28YrNj^Q>fj09{E*Pyrk~rd88G4 zye%tll;VPl_8T+OVzv1it%HI?$U3)P?Z>F}VV9VoXM9~qQO2BM)XMIF$VAoDibPG! zaEb?B$5?Lda?KpwevKZ#DCKVVhYmS2&&a0-rtOIdCdu<=B*`aB8fIPt4i7w?I(BJa zY^2VY&hAz&7cJ{_7(X%DcJoDlOl9Wi1d4J|ki9nMr0FxTky8b!mw3f*c(hZN9$teR=X@Z zh-J{!khIgI^#{0-O61hNY+35J>%jr`e>5hn^G$*Oc*YM~CHvBvp@kPW-+ewEpTP0$#r5${}4 zxpM{Jp_Zob0s4ogU?&cMy=h7y8*A30)Mxz`ZV-x~l91G^-uzn9;|@8++cIS_8wgZ|(W zal^JVy4Q4XGU+Bzc^If?c7tLI*x5e&iXQS1Z7$Y-sw$NiW>KTMq(ImYzYsSbugZG- z%l@I~?9f_zia>mOT!pcD0oczrHD#0CXtWueBL7(c?CLTFUFOPn1Bk;m9m9v(%ZfDk zT$m`!#plIuKQOB8{e%hforLrrL*Ms?F5Pe&}g%8jS`j@_j|KS?V2 z0AEIoog!$l4ueS8ZuJZ?kc|t7 zwL8GCW+*Qf(ja@aWFIwWuwFgn&**g8+ z=U9*X{RuLUA^CHYtAsh(>Z7kzFrREw?7kKf`sFN!N~7r@%b<_fRSuEM=0eJOSKwa- zgDoo%?@v7!^tmiUF#TDL@ArW=e`@vJnh%0@HEVX7xq2Om`!0#YT$0tNJ7oO=uTe@L7w@Id+8>HcPTXy##&3p_=Qrn5KL{ z%^M>Db5X6>-AdrQyY%5Z&L~?ro8rDe+?8vZ`p%Z=Gmu<)F&+G}lp(=!L79~L21i(F z)gdqg)BR||`WhMiLy^%Rmf7HXLHd$#7HIQ2?}~(N!`nTzs{DofiEL4n!nExg=&NE{WEmO)ch;z8D9_6` zvs&MK!2&wHFM$$)^fV>rN)zpGUjG2)&3yzC}V6O!M_yxfO>2sb3bn;-pU zs8W@8eZE=*`jIAmS(OogNKdzLPbjy~hgFIr>u$^1Q^^o!^r3=uBI2QX&}IzW(@DE|1u`NN;Q1YieA z=+l@=-Rg}FY3F5yxvTF^ytRAP72w-?Qp&~1>ZTlU@Dbm)8)SUGF!4^W^4cc7khD!R zJ3~_5>9f}NJdMZ(tKhVghzcj|ba| z!Oj3J)FY;hoPAi7#Cmh=F?OiIbbz07$ohLJEae zN{`nwW`LY?&XJD)s$+7qpX3O1SfDpoPy`ei{9Tc?@WE6YPG#qGlFqj}PwBpQdNPlE zsm-x898S_UA<&9=EwNHd8aDi3pOc#cp}Wul|H^g(G(MFYyk~j>jV{Ec{XcBIcT|&U z*FF5s;NU1A<1h#Yfthgx1rZbi0tQeN3mp=A5s?;}ln@9_q}hNO5s=1NV_UHPU%Gs z={$*xQkC@Y(yov~N0)|Q6+asd4l~YR|5CZap^-AJA97Q|#LqS@yAGj9%b<4u_Is`P zk9?Ar<4-_7Sxd;cyg2zrkguVE_2;(H_KuCT>$YB}O=^Ktp6~Y{&;G4%C{}k&=Uq7% zkPcm+9dZ7=^+#c|&S}SnpiXfO=falyYM`sxE;+(S*xq@#dWh)0EepgRKxP?!)S~Y1 zsPq7dIlKLlGhIgf-uDv!oTQI~aKGk8gO1+&s)VpPyP0eIs`p#;4DDXG9O2_;FL``e z7wwCQ{JyidAWGb255}&1$npkJeRfhNOb;Se7+2c78Bjm(`o-mY{O<#375+9|N z`0r>0p^!1r+vvD^CtRld+?LXxmLW%>3$>PeZU-U=dzIaXTo@jn90J)X&QI{sDS zl-JBs3z?*c9(J{LQwVjunUsACySEM|s7L%-2Yhv}pI7MIYEu@Ry3{bPv0>$J$;w95 zDFdO`iYJ3pC>-ii)zV=Dloa8sNb$37H~x9ANpVp!<$j#f0WQ#P&2XLaJ~^9rS)l+d 
z7_*tY^f!1H&e}=G%rvs=AziEWyKQD94Cb@#DIXYB^@Vu_x`v_`RQikHUeOo4lptSJapmBB$;sQ zir@^TKSvWAR6rB_=-y;MSSxP{pE@-e1j>yoSW7D|0a^0TlO@sOjlPBSH&F*#mrT{^ zEBDii&vpg@-6@z5=jD@v-z+k!%ep_1&q(n<4XjU}#9How>G0#n7c`vmNJeG*_X7&& zJPkAUZ)4)3v|t5$=BoA$=)2WW6j1FRr7C{x|NAfvBS+)~?;B_Sm>qf^6x6P^4k5mT z3>0{~m(p!jg4RK9=}`lc)LqtJtGE9-P4K4#mZl#eK>H(6_WF7C1tpWB3hX~Q5^9{_ zm(EYqr^)aeSxW+K%Hh;t@?YtW>WrG>BnNh=P-e}I)}TplDJ{*EqgnHOO=AR8`Wv$4|?+s%nqE`taD+4x@yOg-eDnPS&?h zoZznn`irg&cDGboeIT78hnn zJC}?48+a=^_9toYiLav#6c0|g6r8X(|0GZ%z>~6)vbSHe&zgsYtgAM>S_Q!L7Zq_M zan*o;frx|~H2ympf@}A>WG{j}A5EZ(Ni-4%ISjIg{HjkPcu;m_2#U)fS|&SR;x}SI zN(U(bq*#7~;Q}X;4`T`(bHnPKshb38-N|8;IM{A}S@5Omcv`i`hYlguAVmkg=kZ!HMdvQdX{mT15z0C| zKeyXoPIVhAJwWV|U-b0qs|^PO8i59_F+@S|%_ZF45vRO|KwZBfLHGSK(-kgImt)p4 zIPSg{CSo>PY>KUyOlQqYLZH5sN-+-YZ7feQ5P!$X(acSlNCal@th;L)uOX%3h{E2pR;Xkm#q#j3}Cv?M|^1KZ|zhIKa67esO|9A2r~^Gu~?o znF+rg<5K_|GpOkhc?Li(N&oK+ts+prhQsbeyL!dz*zWu*{_eLy zyO)MD+72S#;=_oq{vF5h{J%)$bUq5acfTed*Vfl)FL^j3KWoe5wfEyA57tItkkoxQ z6w=gTe{sLlti~0#_VtwlByFnI;A=!0c)TOZ0#quS;<~GtuzwyDROcB7{7jnLqDj{X z08a;R*RTH+({KT;-v!~~{5I>n1eP{afKfq&_Z$iC*u15W&5@=UIU7FsV+!((b1We}68zY( z;<4rk;?111J=wxZh?$#}^sR~B+L`oNRw}P1PT{eub@3#qK1shzB}Y}T5Ugxo&T31> zi`nj9Yy-TyYzR3GCG03IU);>|g|-~_CCOUk`qo!0yH9>Yvitd?-a}}Ehq7-BU;Who z@{A+M@_WhsdMJY6`SwHZB;O=g>#bGBTopO}4TbrJzk7uE7isa!%Nw~RZsP7-KcRFF{`21-3< z|1#nts_*e3kipbi!NI(&A3#XF`jYV-VdTusmvvLG{ZA~=U#U;fUdB03FVG(!i9|3! zfzyrEz>UL4`6H3J&x@F_vKpGD4~}`G^M!J;f)gFa8`>(qbobU&`s_x z#Is9neJ$cLn&f{%`a-?lBEyOoFg{{EAa)rtD{y|q{8D&S4ew>ufS2tfiZ5|6$B?Cb z|EUHl;*MeFc*#>yHoMdX$C7oy&&e_&@_1G!jV z91U4qMUG5aAn;bk%r*qwxDC!QZtFL(pa7aA#Go+K;* zDBXX`cZq#ws}G)1_7JHy!3|>}WT$Xx$B@B=h&Fq_L5YR3->O1O5h-xr_nf{;#xV&#<(_kIY9zPwuoLQl~I&Q4fXbI z4sUdIS0jjthC}#kDH(Ff2y<-(#&uErc&Q$t4(G$g$k{fQfd5xmNgu{`bvm!jufd)W z+Ao78IKZEGS#gVuvn}<~ODQiwd|G7yQq+-^&hzQt??mSLTPMKroey9eLx06cl4{KN z%0s0k%T=>O%R0T2>E7FxM$&(H4gJzP9{}jq)x@~&k*Sa|)!(E<_#q-MoW!B%w2}(~ zV2)$pAYBn|Z`*$ibn6-M;O{9%PlOr$Ct8(SRCdmu(v0-|Rd8*{vg`DGOUHP4_VmMB zu}XiQ8=q*7Cx`Qb>>IN}9Iq+g{%<1$NTTeXXEpL!f8!yQ;W6b6XP_RP1s&z7Rea}C zxCEg@=q1=sZJF#ksGcV$K@YPP`%*ttTr)C>-(b{g#VPN{)=Uq^u1w_B1K{~lM|}2q zTz$27(%#f`Zp#7fuI8|M?O&u_}AWr0orHki7!SQ+HxqSipL;Dtd#b-QaYz^L2*yt`GnMu zf`Pn+N6l>#YKABllrAKS#2Q!_!=UXkyY?jti6GWT=DX>u>A7?JF9{{NsUC$TXhr(F z>l;6uIAzxL>U#urO<2cjTz%R}u^A6T2-Mme)ABhb;}A!aETpvRzV0eFmt-y_C2NPcy5nJKe99Okx78 zACwOr;p-f9_3c)93n-TTX=r(%uH!`;J7&>uwHci@8xTttLnlx zbpWB@#c&m;U+c0l9z&FqV51`N8Z1ED?{EBfF?1pVL@g53_b+3hz5@vF&?f&`spl&9 z)V{}mfZwlzZ>|+)EAC7zykG0&HyVe7lzgi-Q$1(faO)px4nSwNGDM?Y@EC{Dry|I= zO9ChmhIOhU*;Dm8$8!OploKKkD5W#S0vjG3^M;smeeg*7Oj01)Wy3ea2v=5Z4u#r$ z_}MSF5*%gTn4-Ed9%q7WfS@tAL96xKM>+e*PxVc*jDY0iS)Y$io1^#cycegzSm+Z?AZ7=P%;3&@(qyBdn^#X zynlu7t~>yOFCR=uC1Ap%w`ZsXhaYm0o8|V9v@&n&!0RC)%JkQ!5{R-dx2zF7RyqS6 zXKZOF#$mi=sFhh#X5c0FU@6_KveXS~{`)5!fme)k37O>@T-$SaqzD3c@m6>DxQLd$ zpIxv{p`$Inm4hty2N~UtYlPhU>d(!s>Yj6+IZ$@}jxT;k9iu|8omMtpXw+ zsf#Gvt&dPW^QLjbW8@KmrTvGakQ~Eos||&~ZmrgaZqBqh_}^03KZYDBh=x!$U$z0*j)G+YpTxlutwzb-vOKZ(B}$d`cG&HQ<;dvoGd zNGA04S4nQ${&ii4(e$^IsAW^8u}z68!Pm5=jf>D?hrl#Y$$B^I_5L7%)-!!M7Ic;O ze;J{C-1+)3ChA%@d)lG>4KovP&AB)I2ql9(cZVjq?HR9U^}|6P%>UQI1po)9 zD4U=wy|j4u6lGQiMC?W}h*BbStP-VC#Gf46Xf?^>?}1cqd6rLn=nV4Kpp4l zu;Nx9VXzKG-z+{ro&{1pZ?`m6-l9o(r#!vvX^(fAttld^-k&vnmQ5S3FdZ53Q7wx` z+s-l%lA2+H_gN!3W4@hRl^ZcAOemB1T!@c(zR^P3RVjFit?O`TUS@_RPl9TvEJF+qAC2T?s17qvZ?UQ^dxDo=cNOwNrX z!{w}p0Y+{Hk~Q1-?ptA7dphK2en|O@;jO`(3qFQDZb@;*A-dl3jDyDrZ{R+|lz%Ow zC5gKkQ)gedC_0ehHPEIh$b=LGIo(I1LVjvhw=J7jjM3V!Z?h;XT@m%Mcy>6|^K0c~ z8+rJBL35k!TDsI*a$!~ZQ+h@jQ{Pi$*x_E^!8X=2m|@wpo!WyHBv{$X88iEN_hh8X 
zLf)Z7#+g=x(pCbyS;t(})8K>W^%c<7^qh3WS#}LWCjO40L(K-_XV>brT;E?cjC6kH zV7;j*l(l{17e-r?OXoN{CTiZr(C!V`zkc*DJ1LbZ_YP}+tLQAH9t-AqRatyfphv#? zv-MICM=uxuqm6yS0k#W$;J*jW7ki9Pe?FBi#lPJW6SbmsZvIaz?wf(4wxqL+i<>pK zUrC#BUNpx>QrANgZ&wbtE6qH`(@q6#jjV6aWygdKsO%xTd=;65p8P^LGvtHsNwFdZ z-;a!uD#YFS0Jn5$=C``@H|DNaA}+|MPTlkK0CS*YIi{6PdA-V~rF~26rnxTE zzt=V0#@lf1H^@fUYT9N$%4-Tk=d zbANA~%`=@1%O5YE?n|?2%n-mhEiV^PE*cBTF;`d zYB%!9+_R22S)SX&73)u;Iq;1~lfqA64Q}+IAA}HV7b2;wlES1j@s0CCWpRTb2EkFPKo2P3oBD3aAC~MzPIhygtqA?6NG*0+Upt518fWJ z^f)h^AyBTuD9YSBibD({uoU{m7vUtD9Ks$lNDvWiTkbj6S5o z_jXdu8w?~G0Y@*-SVTLt&Wvgw!2qv-d;o;Mp0s^huNCV_C4`|j8;4?{=C|oT=Cjd0S(7;8T~wsnIJ;_5Xm}7PeiuEqFKfp@&1Z^$jga7TkfxvtrpS)v&Pvm2hk}WW}F|Od$~}tqtKMXQD$E+9rh9u2nRm*4S^3 z9?3^SK_7<^fwC)DLo?5WJxAuD7*^B;BUlPPB-dwdPcb26^_A!D)OpQ(sU$Q4sJmVw z8e(zUD~>rRGsv-5!QtX(mG(|)*_+VfI-}(c#N9jU?PD1OOLM14n*EO=lbFFm_HpWa zS4wJd8_TSD##guZd#Jx!g6bU>sEXQ!w`mdi)!v$AcA^2N8uD7*yclm?T8}ckNwQne z-Oqp%#isd>1xZa5C1sM1$=p|7w|${#U;TZhfH27@C?uS-RalK|L#Q4^#7f|sn^BJlmN5VV|s*; z>rx6;b<6QVj+-Ko)h5NmGsv&_&=ku3=(xnsv;>M)A&cu{#>Ffks@{7eLuXWAX;kY; zLs|K^&x=c6X787P3(4?u3@_}RPRD*_5z1M&IkE@QL*qgy)Sq>UO!6y~WB3$rOUBQp zd$^@gKjic~28c<6no$7yD;csk=*m^R^DIH9Xb~@o{!3-?Fz6W6!Pf=_*5k>sck7^P z6@a-ss1!`lQi>c5y~QHWg97==*~g~XS9~@aRSI3<@>!CETy4!^f>Zsz?e<~t({7|b znxddZos1;B$|rP&XUfGGV_jOkU|O!(yoQ&>vQVEdi>RA3T{w6{O{LFghCp(){5UuC z%TpCU`slG)SMNyEd$WkOKAql?~Ue)Kl;?`#U*0?b4 zIh;P^4JR)*AnB93hE`IIi+t^xxaBHKtnEfVlJRDe%iEvW!uSnoFf*bKdp3>~%c}US zJ_bYF`6NzG9hFuJ5@U%TjJ4d?%+i$u{ZXgS)QdQ!A;jB*$x)L!BupJFb%ci5*6W1MfZ zkczWvNSCwlCEkJB9!@pI$0Em532ri+0BzNdUpv<#XDe>jVO_$jvNGWfK8NXEBj%JV z`}*8mZpuoBV#UH1(YEC`X4gnjT15X%g63E`tIOZVA|Xz(Y?nrbQ>{y(K{oE&9y4&Dp>H z%C)CveY}~nI^>2~c^?YLy2>Mvn{Rv;dHF33eg-uc*c}&1B%2k*(cVo9w;HGfnhu*x z)CezHe%5F~jdDhiFaijiU#?HYb`oTHRisV8RSXPmLRV=8 z!Moy}>kAg$Guv3dXbAzzV$0G3e%J__F@_(G^YUAa{4o3WhdL;xD<~nO%D!-h+v&FS-PoM)wFm^R*DS>*h$g8^ z?fIa#ft+86%1YSU!Z-N5}kVQA2M@*<%mGF^GYlN-F^gdoyw5RfZn z>)m_y&f^bAGg&uos_Z2sp(5)-5qKlbQU&zM8d;@&GN{Z+*#&w~TdY5Uhz#kzB($}4 zdl8`bkq@ocHs`AKl#^7ec5F4pm&#kA^Y@QLu9Rw$SgIq zdvq%xK5j|u*k5Xcwzmm$%yIxLC8xA z9F3mK!$C7o1j%(oW?r5PSzSCz_O)zp-7o5W-_`rgS?<+-lvM4^aRkxIl3A(2s=g|} zwo*`dAv;O8&Gxl@G^U}fQk0DkE6(g&>pBCmgCG7Z3udQ8;A8QRQ3Nowx|0m1#wq9b z^4hi)Z`Q-K%9ZQxdy2u@BhBWGHn$Q4o_bGi#>Ll4ugTZb*M-rZSPM{?hh);?q> zj5O=C8_9?6n2ml}*Gs{h=)3sj;I!1W8GLAT5dY6jR6bR63qm4t=N*Z=vcf+6CTlUd zW|huqv4~nsmUP8Ma{Ncptpwcka`+{DkSsKG6AfX8RKX$)cc161-yEgS9YYZA@!!aZ zV;FZ@Aot$+_J!(1B9x4=*znAZAfAp$sx&zjZ55)Ab~o7%gL;k_S-VEYFGQM;F9c3b zM$%10E+0|QcM_jpKjzKHqjp|KNvwLpmX{iE*7(;jYpz{|g(iKW*5Gt%O?cX|RF69T z^-cs$&l}bbuUR)l?tU0o={w|cZCcGK@3`{}yW|8EQF2XzULTqx6^hg8gitJgSj9a> z0#H*N1%dJQK(EbdfNBm6tef;XZ6Is&3q)D>>w z(N}qzh;ro*j-Tm*gX@(MgUDx5-tIU#A!(EN_?fvzTlR7Or&kY54U{T?|4{d4lTItt zr$rOC&pGQ$SaviGKlUvsBl*q5?^iYTDx6{w> zVz#I{Y2TvRela_9fp<+&YfKkhrDZRlsm>$}ZC*qz>|^yb8Ewsw{b{L&J^~fehbukW zRUPY?o4v&zDGJ(EeS0Xd06nskQcEG9i*MNJY1ENKuvJV)OC3Px;hYj3I)isnDi5^J z?OI#@J`&%*b`o|uIC~{6ee#3XxDPiNSI6ovUu+1UK6jO?F_(JP(Rw%Oo;$JO>Owcs zC;&d0MlG$g4Kfp5{i6gT%+9|pv-WdVOZ~Pdep~1iHp3c)Abv?e`Y)d1LKn`?$M@ZQ zN%2aSTKsYw0UqH;hyvYXs`1ae@VU&lN7~PW=>$gYYemlu?@a8S1|n;z zi3h29vlJZeOx+gluv{HN!N>40F=WmqX1}zNE0{{3_($zKAT-h*2#h-%MXG@@fP+Zt zfjd>_C#O7X)tvB_V33-3TP! 
zCpCv%V5!3bW)S~zCCNK5^N`?ru1!UE<5IVd>yll&NP2jBs9mQIgwQ37w#3I#%51VF zvfsy)($ZDbLO@PJj0>WH{yeS058%je4c|Y9- zHD8h;r0GD&)p?75(#DvZ`k7y&V7a}~v|Eh7%!yDd=SY6A&j_*WkL2|Yg`IQ-C`Iu?OqLEc;HjW%=)~Ok2-hNCHQ5a!c>XQ|P zt<4TJrz%3V%Cir(t$kY!r{}reN@vBTo_oD-m@d~>dz+iP!$ytYAjU=^v34n6-|606 zDtj7&d9n2|VpJ7523xwHH*zCX<=$E_oiNgK!~#*?dj|m-C(5AL7?Q?!c?r{$e-6FF4vH-w`oB*6T0A!>C89g+-*vkOZ-gV&U9(EH_@i`xe#m3djXk;9u51OcO}-`L|!L7O60W73;ET@v(011Fmc={39ES)MK~BC0quM;8eg zv5U|0T;!l1Q-P-!T`9&IlXb9|x!H;ZR<$@&UQ5&052umqJ>sa4`_FwMz3OvLQQmXs z{_NH8d5=Ve`aRdsenA+PK84NnskdYrzA8Y5(%tF$YU<%p$7}0mB9T4#P$IFt-oncr z7#&%chWxr?S4F_7&E`q?vWZ>72^O@VBs?-(0x>uj&#+9F>7FGzc?wW`T!&?6ev-AwxEo+10=Y;!JCrAu+h|Efs+q;gb;6nN z$K^EFonD%bM#j5f&{s#ubL3JYtH?1q~8lin8Kz$aQ;uL~KWkdSm}RAmd!4E+Xk zUeMlRE^`1#e^P8tyU?jOpqV4Yp%Q8D?L9TEq87ACixl)o#fU!D@|u#vFULf%iSn=jH6JXN5Q> z#^E$laB{|21+Xc+EIK=IUOoSbawKV3Bha+wDGE`&_bW_!%QK-$gErkz?Ch|M*MJhi zK}+CV*<#$OHvvZ?Z@dC*;{Q#)aA7`4i@lX>47Sw-T4sHd*(nl?)XLMu?3#lKR_|u< z&FsZnA^J%kpJvb0Ryf%r%7?Ct;wa|3nrX<*FCtxvmv(_l%zST!i(&Qp;H=;8Jgs{o z*N-tlukWR`WPE*5V>n7}`+F*X(l7kDdpYx8teAK*oEBO#yLFqFB}>(4 zxABKhmh9Zl_VJPYF%9m6j|0GPe5Gme;rqp8K#kE|M$4u*I!!3&4ikA=>Oo%?13h^b zpkQ8S%9G6}jRI7D-3+fBkV5ROB(T=)9O)Mv)l9UZ97G*wpD-bicy=GTo7MAr{xg#H zwn0vGFXPG;!N6OSX_Abuzx{yhEh}k5sPgtBPzzh=gzEMEo{K-=tSPaTIH<+PS_(Ft zsz^GvX)nfqsa*+WPOgEVI(unbnoyr*zXUXE>#e`zRy^k|x8xmZ1V2(Xq|l0IRQ6nr z(MjhQYD28Xk&m|DADYs(8=Bf|imGbCS&^RkaHtbV&7;D6>5fGdvq?>XwT>7%L?a+N zu3;rM?gUXSw4d5n$lY+YB#q>K@IQEmBGPRWk%IGf39T7z5u?1NA2W8uWUU;_94s4ah(Y8h%sRum{LXY zw-z}jV@(81qnj@R*MCb?j$vBdAI`nibwTu%?4fkoyBTw-*}X+as_3t8+s@=fT9CiP zMC?Vqb2Hj3ddo;|YJB&_Syw;R3^p#$SyC zN=!Pb7bp!f7_qQ|b>n@*-BREFB}+Ar4>2j2iQXKXDL!2e?m=ur`Swc|uFH&?JN+!h z@1Mw!F>)kZ)a79APT3^b1ml*1Vc17Cp4{1JVH}1~5~XBCcEHGN?$WPxdhVUzvIotX zl`O1Xf+n4lPId2)h+6b#^n2uVRo&38V>=*7Qa}2Qs@|&$7)=5TiL7C&y~gx^Nky4&!y(}PI*vG zhN~sv#HZ?s4m~0?$~m`mb=994IU>L^w4>=C0=6zHG~MVw%L0cU{cY}iGF<%A6e4d9Lps&0S&=V+-C5f{plg8ArFx{;g%RqZHl-_(vmrO=+5uL?`N5R(9(iSS&z+(Z$?>t?7p~;HLL9 zJ}NZ_zPF}5muN(pi@O74T5XcG;Y+d*=WAO1&Xb*;?epjK`|3-MTvjom%Lva!Ez5%R z3Hx;5{7TTu?@$u`?CbJ^CY_YdJYxYwo!gVR z4t4Q$m5h;(Jg7!wHN?`h>?neT$KXr3x87qxQUvnYuv{-JYlDiN2oOH8R||W4Uioa5 zu_9;(n{nXt-+|HHY4%CR_0{Fy+#{Z72Q~TjZIz_Tnb$liQ=PhZjw#M?>iwLN9za9upv#pFIcJJ-FZ0Ch-4NY8qjpJ;MEXO>y9O!i>9Y%MW z_T2*~{>Nx+#Tladpd14ABCa4L zFja5*00zBU&m{SA_jB$t;3*o3jMZdxJXJ3c%#*3>(nxb^ipaGxyA7F9E!Kkx`K)8| ziwUC;rLdPb-y6P)`NJ(!ronH1+av81 zBc>+OwDp7}{By_(!WkuAU-eocy;Ny|wshnfz2uHU3Sr=yP_04#vw(8T%5+tgRqa6Y zXU9O5jFRgQw;rA0m>NxZ`ke;WqvJu}ekOa5j?EoT!>>Q&bPtyM=@S^=XZ|_#ND?y# zMu3Qls5bg+fratT)xvAdxAl#Y-(ITGOA?OoK~1)vj;?LKQE02?eLnZc|IJijd(ga5 zH~yH%aJ6@BnEGmCj4LqryQj=jGGSumhX1gIxp4$}D?gRBM?;{yz+>>XeG1;rF~t_e z*d44Rj=@1s?Ql9S5m6CIcgpKaQTnL_>_Y!^oCeG`fwf>;=;T{|;asAGL_fUdDH$=*j|EZ~DP~=S@5OPy zAnY(7A=@#;0=fCnN8+uiNQWXZT!2r~bWcKJZAUpAUYB+9Rv8H++<=X$O0tT|ye1~y z|29N(Z7N$R+w!W|Nq!vM(gkviwwg>i$JnA{8_J_0%53k^!TP8+gh$yO!`z_SZ{dgiy}QI71At~|?Pe)L>_ zzU|SC{ohIrkDb)ACq#Q!Iv@A`^3sDdPRAM*>ZIGO_a>80q1yuOfw(BBVEI>=dPLRA zNZXqSRQ~l$V&aC7W$WRVS>kt*H22V+ChKuOL#w4~&RiH(F7xW2PqX`6K4228@9!#| z#VPgtmJqr67NgBD#mT|6L{a^sj6;B4>;l-vt|*0Mz#*0yga zIQkp>;7jb&&6YAFKNW)K`#VDy)^l3UD>=!As+V}@^t_r|WRKYbVcCiQ+uVZsU{}Sy zQWzQ+Q8J>oAiI5trQ!BD4KwAs`b&!p&Wa~PUf_MCd-`*}m_sz@x@#v>^#J)_Ek_Qu zHWr<)bg06>`9iOdkYvNJ)hR=Q13&r0p8vvsiP7eanZwE=TCZFs-3dp8FBwBL_aAy% zucj{2|Evh;YqOKD%D^v&lknYaSz>JR3$S*=deGN%9A~e+6!++`N5MrLTfB-Ne6d{m+Wfk#P z4KHpn#!XgHO{nK2pj`yT#q-TxGg3kl``Wea#8L!3=80V9uc3hP8MZ-m&~gl57OVn^ z-2IeogQ0JSa+-&A<}#LmCUHWKz@txm+<3_o;62gOqa}@%lyctMBRb7facbr?jy%c8 z?fZa1WlPq?1w)kX$6DM#K|Dvg3h_fbDLWt_Ew}IWo-TrVRT=2yZgo1UY&;+(phbbB 
zXDj;%cM*C$f!%Mqb%%^t?s9=>gi`f?>WE2RXEU+;`I)%Pcb3X@_NziM*X;Ur3!wv& z*6d#nvh(~Lms+p~Uh>?McpaU04|Hp1Rh^IWL8Xd)Xw~Ohi}@yiK5o}K8*;c=2e$U^ z?I^=%?1Fm!?W(=9#05?KH6U8GrR3(M*V<1K9o`9_<&mkg{c+QJD+xGSqfWz#oMSwc zO98M=t^MD&sb^EbepCuIEZQNMn2SsU>OfI8mT699e65)f`s28sxjB!=<7N4Geu%7>V32yu=^CTFBIYv>Ngir zm3Q+9&ZQrcZK+>3MtPt35xo5CF&qdt;ih<5PMa#EWJV@(VJqB!O+FHl29!o_sKtTo zj|cscW49y0Eho%Kw8~&Nu%~l$zhDAjwe>f#iMFq8jmH;&Vv1C}Y1geitPgEtAO_1_ zk+cOxNaNHCnR-97t# zZa+L5?}bTn%tcq(TXz-$d9w>`)Un#Q=VodRv)s934lzFRJptg?ITJ&cN5srlHG zs(o8~jgkDPDH3IN{o%E0Zpkef?@v<}?ELKtOB;*pd~m?~nE!tv5@~*1=_OQ7Uv{)s z?_0IR8oMgtb8gEsnAsBWL9;Xah~sgd-;I8}33UPk$ZfaovTIv5a^|~ED9NR$mh%Gq z&pO!pjM*sL4t4>McLbqfP-T^@pTSLWcJkw2vKJQ3L}xOF9MG#L;f!YCyPT=)ZsCh(oM4$?_rC1-87M8;XgJg~tQ< zCZ6vwWx+uH_{@O#e#=S2%`J7qU|PLY+t5rgAVtgTw``&JI~LcjZ!ox@^aW7e%I5UF zu07;nBOYg|9G;wk8>>!}QWyvuCI6vBozq?DdH7pMZG~yTZqc=H1eS5b(D!6n;zs9} zXjlo0!`sH~DkNcp>{qu(Givj%YQt1H0*en!&sqQMl;=?v7TsVYRnWo?aZaY z3ad4CgWbaazB%CX{~_(uEu(>7jUHGAoP;`gTmVzX^A#YKRjxc+VC`9<>BQb!%*r9w zbd;xx>%>G;7cC+|T-^*R2Q`ZwlBDl^rl@d)HYD<&fKlfEWr2P<1~_xj5Jr79Ockl0 zD^G0LFFvlH7%vRBA1vjuI*wdg6i z1bfc1y&Mgg#h+8E((0QO)+&?e@D!2LE+0NzPrB&)R_Mn=Yl3{~mPhKh6TyhnVYlyO zS*@)NUBH;IUr73E$E;+}ToRik-K}>{nCYxX^}JEG#YYre97A@&Fd~f8Wsa z5C2gNdhVR-m6p+Q-6r@Up+McO66E%*1Y~J@`5i4dHOpkINFCXqmmp{1U52DkvB#g? z149r`W=t${j3ra*Ymwh3!W!C7yJ084U*^W^fVXB29VVu9hO+Z7W)fosUR(q>NC^4P zBXWKC`tow-qWdie%WdR_fqlkO-}m4WJqjk`LnM|l)rkq&J8l>}8M9J-9*m!U+{-1K}GHPW$R(u8wb8KSTapn@kj zmEh|X>ZV$PL!eo5gwr+0QuxxUKt_Eq-XR#9%VKzyJVmguUd@dd3YUxv^=Ph_Lwzr4F*!~ zf4L1D%;CY#eo2_jj|X??{d>Cg0(hw(`+SmVm{HoVFbrzs5H6^-8ye^{fp0qDLtoth ztd5)^ytwyttt^^)v|fXCee zIGXBHHgHSSQ^lgD9I|~o9k+ekS`j}NcYak4w=HK|S7ozyO9wdCe-2$VPrq~VCAQ!Q z?bW4GV(M~S)Cp<$*f)EK#=08j?ae5k#<|;ceE)UP2piM=aiNs8$d1TJv#a}y#&`|V z_av*r+si1?PWT*Ld zsK~+ro)s1z|L_Dv7nAX&@x=Y7cErW^Lu`bqj|e?2D9oUL_J2rJg8&3|7Q6Ll20y-WQ|AFlR5k#+`r`rq=TBX=$y*zg~WWq%gp zQ_#RYI)h)+DXl}0k+2l}sJ^lBaIxQX+*-LddwM$okw27xqx%$WUsz$y%j$vFPyT4CAkx&S96SC)Vrxk=lHcVCjl~GTZ)(ESK_Ko32_vB#Irg$X)F82@O%)x zarxxv7f-UxdBrfptVO51%WZwz>ZtjumUHG_Ao&l3SVhMtSg>UjxFX6uR{14u(0YWS zFS}&jJ`}qE<3qLU7SgvG80;w4NU*MAbyWmv)QLks1Kq;iSiYM<+M+kNAS4o8;UvR=6 z#0&+77@!~&c(Ffpjpk1Wzfg#Z_0ihpzd~4AsCbd;TA!0=fGf*F=W|lVUe)d(il(4R zx(drROBgatK@id_H57W=-*=~;>T5R#m|MtG!x@_56Lwpe$)R5-X?^~doT22B2^iR#{*__)a)YXoZTq-dD*O-~keAbZ)?=S$K|BW)`EazJ1uknycjfY~dKPmNi zF0v+g9GcKH)Q_BFk~R~G6>Htqh3V}pJc>kYT=c~+z8vks%Y}g!;9-V&N$%XbScIr< z`dE!k$v~6N;g_%pMs0=~TW{`J@0Jkhz{*3#3y9t>@#O@J^^Q_t zd2KSm<|9fyQ)+!BoeEvB_!Ta1uJ^Y{v!siW=t0$C>TGO9uKTBs@HX2Bn~*JmzFJhB zPMw#kkz(VDkkhZZ{~a3}B!V>|rh+!{FXw4o_^MuQe%NVf#Eux`@?uttaaQpO!yygh z7;I3hBJ^IAq`{NlfhR$Q4Y)udk6&*5Zl09%`iwX+MIO3F-)$3ckbR*d_{)#|z_m;W z0RI^`#Ou4HxUNkd#pl81+MyA(oLXuoOU)}ED7Uc&A1#nOUw`LD?mljbcUaU1SFgj`xJ+I~q{`PK>L)yFOus|-H-S3XFN_4@YecC}0O z53a697du_ii){25UHMCy+x>)lxBK03-637Bql63=3wHri`;^QuV=w+n-uvzg#65IE zZAWiapyFx%>`}6(2_q%u9 z|5yvw%;KE$#NN-|=ftL&Z(vAa+K(!5uoqSg^v!lq{FDOai~i7H*zyHCssb&cJ=%dF znKteJ7olVJl%^>8*xc3#FWWRJfG=uI|10+abAA#k-7z zbJlGQ2W8RV)AhFH#otbl@bBX9{KK=Txsb8H#n0R#eH9OSc^Md8wMztw)LC6MxLW3b|Rz=e|jPY+of21+>%4d;ZJq@<}+vp zQO>NX}pO zOVWhh0{*GvRp1#yY`6}^3+r8DSrb!-ur%&JQt<3riM5i)b?%LD9z!8CC*&pvHpVV z&NCu&c8}hc89g>$W^u)Mc}({)g&fPq>+DOrtouDy6ZJH-%5 z59VO%b?YC5Ie31`T;`6SRGL~O5toHfwOcf9Z_$G7oguCVbPcA_kD|(7bRD|)12f_d{Yy7Hlf?72rlm+!RPIiA z{ic*?DKD;cvzJ4O__s^bt~HMdeTKx>bU(YvIh9+aNo)vr42xgBz%vn>m^iU-1xqIs zl;bGDeS?*_kextyUXU**PDNpcpai-^#BKrGyI9uZWI9dULXwmIe7(fa;KCo4xdpIg zWrK_Gs(|qN;yOCW?sKR|xPRFlDwJKc_m7@SUWq+ODavqf!dJDBJP<4j_QXH^NufIJ)?Qz z@tvVDwi6&czW7~uPK@X8;X~h5UB)2{N&UiEoD4pmp~zo(#2b2047&-$&AWE7gE1O@ 
zd2~r$AMXr(5e(x*2!3d!MmC~X+T>{TIkla{)2&<>(&@+lcEXs}AEOS5T=?cH67b?T#YMq(jb^h3(Bh;wv8$D{x zxW@BBqSi*tdQk<>k|r`)*BAy(zHRjg(ghK0=MYyVN}RNQCf4uMf> zbEP6ck9fjjh9L5f^bKpGnh3VH-redF7FnqvHr?e3q>n;F;wY<3N_;dBSZ1tG(76Ei z{<~JNYoaC1FfTvNPkKN1c4_r@tZxjB=EbWUrJnN>yV7a7X2qr4;G3MHDS1>|B6rAE z8dURKF&!LUwqMgi(<{xc0_Vv@7C^1r2GvXlvt*7(D36oJJU4fY2_9l^(F3aUZWj56 z@!xpgH|Q&?d}y`GEnE=B9B&X)d*Mk6#p?*`bMhjTZ)j=(7fK9S=5DsOVR}gnHR0mD zeOV8UW7~=z(b&C?yaf_AB~DlY(y+LVaHp)&mLJ2cn5={g`+~4sF8Jp?d5xAPH@XFy zYn4{pFFlaZ*aCd7?FH>iC@Gh=^~9t=*jg09Z5=msPqg1{9l3C3S;7oq!7kZC_kDfYcN8B^tJt3v+B2REV_KZX>)(9dJ)h(S9>Yv#B&bXW(B2~+kwE6dTtj0@M3 zx;d)fci2OdXFRDf`g6=f3tqW`;ifO$MYb7Dbf$6+x$7~fU!jB+c6TQ+kP4@bvIII# zQt>R^dvA!e1o*@xEv1~HZnp&D-Q~7aw4mQx$ti><+iHhqlyGesU z;)@d`-l2r1Ct8u!BIEHZ^zCv(b51zyJKRB^F_ejfdVc*yPner4To)PXQCXEBpjJj2 zK2GitjMc)@=~?O{Z@rKCa3&nw=S_Y{`d6gr7as#L^|wFr+Y?iDfoIJ;q5aDB58k!e4B*mVqAMm$jPP#|!F3ij=F*uu!(91f5G(%pC1{daGBWQVf zX6rsSW|O|rc_NuKa@(f2e3{dIu7}0D5WKI++>DwfJwmbiFxBe3qft2AfJLS}9`{)f zB#fr2y>pNeZ0h#|&S=)o$e{ewQ1vQqQZuN=(T5NaBAc?0p<3$Qk4ba%5|#&1#1 zXBi5%?Mi1)U#QacoQS^f{P;-HAa%84w$M0^*Y#j`<3c}7-+`3E|9(QiDL7A$f&W+J z)b9$`|xO{i7!;*iHFE8@*ft5jR{6uwDiG z*2uIaku=ly�sn(_(N3K4mfwsSlZ`ljqRYRX>Z-KIY-cOGX4%cgpI@$gC}Ri0i7T zLSShD+U+W!j)R^HR0vDycJb7X$~pX<4Eqe%c_4p-$d4Vf^^`7py{8tBGrT?F#p-w4 z;It_BkBCyj%%&F_&dbQK^b`6wjVP5f1D>m*S7(kZSH{J9#uwYV6!WJAU4n`iUF$oJW^JsPbX%fRC#FWMa=YL3jo|{R{C93o9_G*AG!T}`_Kq^OtJfw0 zI-Hg?pH?>pW1SxF-I!1##$9A_7?2|_>>SmewxMcGVClU{dcSlN8r~wJV<==u^aVwe z2aU{sGB{q-<4xLoZf4}7CDSabvj)~5-0=jdNLUo4yO6f-5kC?orA*V;1IUE*mGlK~ zhuezJco~FEXe>%h$8O&IMM|I=kE^*5r1)~^6!U6-_awW9o1$;cJ!JLtBhX5JQc{7H zw^W=OJ|J%)@=_3uR@s^?9)1!x!&WpuB$$xned))rBs)7|PxIO5UiCrMO#bC{^i;%a zN*p;>pzCd2$4S_Kk(;}B8J>i=Zt9d&+FCy5S6Y_RJQ6LRf#TUWy%Lf0esbio7(t1e zYO*ft9gTrYoX6N4M)tjMqr-kZ;|DzBz-+tsS!#6wjc@&grr=H#I8T!5r*fm|yi^6Y zALc%gwA16z=6wcz-RjL7jw%UP;SOe0#*70?Le*RibWGUY1~vMMG%ec8=nc$`9Y>u< zRbPi~QmGQJOG|1vzr%1SKgv@z$4^8Ewf5$7#zWJ2-ZCN2hg!JJFc!XKwX5S_l9Jq8 z<=Hn)mBoZdfB_w_bMtiI^lbcu2^kAq&qW$nWG%lH%gt>vP6Y*C?uMi!si(_gxOM%7Q@$2ve&VJjO)P`MP%&JS zC`ZDXvw!e&QGDkwa1>VO$0eg!Z(}*a6rXH5-dAWFh69r0;o3r{J)_OjurC2D9yvGh z`aZdJ3|I_r2*oK1!{4gYl34(<0K0 ze+d+k{Sl`FRWx=k0{H`en5jZx}!g!Ok)St%J#wKtjwz;(oF0%R4}t zmjJc)4ZKbI;l@heaE6h9Pk^Z=YQ6^k>J1a3+*pn-8wnjj9#nELa4 zPtb?QYW=nCo^Y@Ep;DTjr}*jwAIZEa&pi)Scn zi0w32CxrbI>SFIALzB|>X^(iaVX?*p5pDbjakh9+~b)N2y4tH*}@jRbU>G)y+Q19}sfGVVNN zZNwKEm2`4T?-LrO>L_r`X8KBbpl%-#HLnc+jC=H_7M2#S2k9Pe#k>vWpWw@iEj198 zo8mTRX}r@FK77&0X@F_LC%2m+)H}gQP0R2(ypm#<|KrziE1Gx0;Z^ZrKeTpoETL-V zSJ?g6;K@@w)~yO%dTwO+6Ii}&GuwyM6*skJ%D@;T^mS8ce^x}wd(DK$aoe)JY6sm# zjyr8JDy*{;qdesi2cMWhFXC^^)V)wX0kqYM=0qs?ilu5Ck1FdKvkw}7iM46Nv+uQ) zHOZR9>k`g-(^7p_PcIZ3U(ytw*;@|n%of`WjVwRIjg@yu#Wi}Y(A9HA9rl~S7em`G z7+EM;Z%*jDK!LOw(^!O=FQ!QQq?EX`_=H^tpG#|=c}alGV*<~8@?ve}K?wm>)4Gal zmF`I|{bQ<+cC@~*7n%K9F@wgvF3%i@1iJK1slz`XAht*`mil_)s6&ATqz;B&6odil z`^T3w4abf8L*n(yb4{!Mk+W4Mb$>~6cA_X@AcCEZmE1o^%z|T9xozCPnFU9t3gqrz zipof;V=NgO<|!-o3HO;vJrOfZt7op)mv=?a{fXuIv66r-PSNTx5Lj4Nj!fIh-*PQF zb?_8iij!z+>VRuRzpwPVca9G4;43vv!(f9)+>|F)4Ww&e3h-%l)uT|jjM287bzc54 zDm-}|TEwow<0b(vmhp|a^xj*YS(P>4^Dj3c{)7<;t!6kzO0fN=Nh>O840K4|bcuun zU*EI5nSB!X`ex$niEc$>W(9R>IlywyF$7JKp+I)Y2+ADUHc@F;SIicjup9nEYxq(n z1Z4kD&yT`{+z=45du{;x+7T95tU<2<5_%-yC2ndRz#*_dDQy66G7B=no0g@#H>@1= zsnM}CTx={&Ak^P38p4muMg_+9tBwxs7LZnP7(?Ys3?J;)h6JRG>zD(anKk-QOqEIrkiQ+i_e7g}nDvuM=AzVD_&gUX}6pYAzw4s}^F_vhpwJ}+a z*_wvuT1XnC3ikHF=}BN?w4WQdZ2beH0CGv+3n!zA@!Rq;m2AOLp3Oc@Ihe?W;H=?h z^Y{Ws{aqy;R0zKS`*7~|TiV&$oZ2;&v%7jGHC@G%%4|pR3}rXn$!8vDLr8W!;izFR z5BJp@24UN>V7Wu%JqucSDds;?1iUlu&!a|-c-`eueIs*feUvMG!s}bROpABGWaJYw 
zmMT0h7~T6xE^)B$SXfniB!Tn^Yvv*(3$ig2e6@k`HY_P6CSlTgt(O_XFuT@>-9FH$ z3FRjY_OELuOH4Ylwp!tjy45T=8YIl!%-879xsw$MgCYqLr z#$=M7lxQcGM8w?8bnO94=25@^8>WrZ3+P z3inAI^hhJ*u{3xD4*3{HqMIg8rmtmuVpnGjAz$)*{$w(hLidH6!F>>(R?FPcM7AjQ z` zFX0+w?$ag6s>_W!p}0VoKY&x|(jnKv0at`jcz2hxCvX%Crg?98LRWE3_dGD)wAkhu zrwF$v$6sPPJ5;IpSDHH7)B}Nb(`L2`Zkc?e%XHN=k91I?{Sp3K)Pe4K>o-vY%{^9L zRE^2#N6%@?!{Uez1KT!GcBel@Fm3SiU1is2-FuoY*hRI$Dq5m)f1{IW=#$NLd&`1+ z*}J6p?Xf0iL?p=*@Z!mmQC!${cFBdsXGhR91pE-X?v9cfVaJ>uYfOKhJO8#7Bx#i; z@f`Q(+wRYl^-gq9kyr>FbNr@jpnBbYDW}?f7&SO#d2?fHdR%qPrOu{ewo!X`>*y!E zn=WKA8cbHMaw5Huy)%@{PYzFPykJcDbZxORN=ybhShgd62@F)|mp4X_;jahSco7ebjooq*wI-a_TAWqKohIKWj+JRwk|K@c~82 zP4j!Id|2?+9d<*xdNyv_8%mS5r!NtfjAmeaaBFHkYS#CQz3s0sthD3P&}$417A_Fs z+Ha#sH3E#bb;YAw4PD*<3?=1K>YeaO75_BifKL@j4D3~VKRe&IM*n8PSfVPfc5mNU z;&9*qE}>^FzYiN?c_f==@|0Sgz!3sj@G{;LJfI^+9CQM#kzij1u~)CAcBs!uTXs}k z+~B9NRbjK`rb=`5Sl#_pFq^H^?Yjm1!yKHdkjb>xJ)LmvDek1OY%k48L6pLPh`Cz! zj?SjUckeNvahbrSx11qumAQpySix|yMVqI6oCBssXC+E)m0na}AaG9W?%xc4iG+4k z`UNv-au1cl!2L7!7)xeVM=1#uPr_-l?0K{=3IxW>%?@yAsfbN0ny*S*Qzd5@1PPzk zw%4D-r4(%i`^=@DtQhJ>43bi=n(#9I3MeP`l;?FSdWJ=3Ra*yG~ox({WY}P99^#{$C)0O&Z=c|W*})s^fdcrvmqi;{Id>OAlV6v;f;-~{4Sx+ddn{%GS|7AxiyPH=foN!= z9{7RL-#3qCwzD0?OETt={DR6tzhVFf)ab%1l3*Ek>~4uE&j3*-JK=wELNO<Wm@ z$TP-_LTHWFI|di!I9i(DXIhZ0ABp1xACh(phxXCYmM0Ai!j#GeQB(;$IdZKiLbX2c z#Bm&#)VMZ>+^&oDE)Ihh>J=h4oJT*|Ct@975U8`Q?QOG#%ecquX1Kid3aC9MS;u@r z=ayq%+aTC!NXrm6k6&*PLrW1A?E1EZk(gWM*b7>St|;z#cSz~!?FjtWdDSi^=-dO{~!;+*zFFk|RUratGA92Oi=LO}yH7sWg3AT**jAyFlOR<9Q zd-$C|{B=C9*3HN3&72mvH|41ognICSPuI&F_)++J%WNzY?``IiS<5CVEN&n|TQppu zyzjQg+Toz}Muf0ByRlN^`q18ymZ=GBaaECLzmN1dslwBE5qcGzOeXAzwN&yWq+wbq z&ET~}s&+{JFBGMtyaF-gl^V|Qs*c4CQH}lgbnhP=dFKb73)NSA`piE1M!VNag(h>G z*i`RaZ!dU>lR;ofO1*#Z)R_|357|5Aa*@yMpb?31XWQ&J-NDnxG3aXUGns`5iTAzr zEq7!~+U|`vG;R4 zBHF8It{DTv+$?Xc+J@aqP-4~jxwW(jkv2v}OK92`&$RT(&Cm2ohO-v_&ddx`XC(K2 z%BcwTgSMPk@Eguj`(DxqRToG&AC2pGan64g7l@;zY-rs2Z8g;PDz5M-p!fMn1er|P zonCq0Xf-k{|Cc{eWy$a(2}im(`;XV(Oh-{XLD||uB+7;SVv-+B!Ib5L=x|&M_9*h4 z*xw0b#_Jwvb`7L>ctjX{9lR4$891?9UDBIlVX)+`nEjFSRr9Z$6)&aF%lV>?YYSmDl!|pu|5`L7u>X16SOA5XVoTZF;@`MqAs^EICrZx{;*1; zQZBSQ4adF-HSMK9W&^VIkJFF`UFtMFFY(nma(yAf-|dAE+WINAYIF9oB4j_cE6CcKxq35m)RA(}qo*@UKZe+U3|E$PEPGNAHlM&J=n>mok>&AFEkj z!xoXeNj82HyGht0|1_54kY94>rGZdQ%H6wx8Jz1)z(x%_uRu-K7DaklWb$F(&$F0AQCH=x z-<@j~!8+&$?~81-zfDdiR`gh!3PK2_V7A0ztJoqJ?;)_30&yNh*g~`O4r>+3VLJyi zGOLg5Cx%~3rqH-Qj*K%)67G+JnVZmw+!y9MVc*4(KQL&meq>{qU6S-P#xhsr_8?~S zO0a+bB?K&e=5u>WyRf*}g7>nA_Wzw>IYV|C|4hYJU+%_>Ud7_&=%P2(aP|DI@Z;d$ zG?pgl!0NNp^Fiiq>FU=M#=8DhvNorzD<1vPyWbwta1&3N5k5&`k!Pdyokz$*U7A_l zzHfFpsBpJR_q7fG%5)u~GY_)j5jiL{%vkvAZP+1xBUcGGCp_&`N-2UM8?MkVh~i+# z{wFJ*QFfkV@G_2Y$hNrvkGj=AW0NNp@e}WJCZ36t)>_MT*k*c$6gRSb40}>rXYRu$ zG;W(F5kbIK#kz^A*NMI8Z#HC=+c`qf&;I&`hItLT-+6)<~X%TOy%$dpmfdVMS~CvIm*3qO*%u|t*$+e#z$X%Q~JmxIA7KbDl0 z;ibSCW~{Sx5^#RHHzVf$tXi}pV<#E(PS{~@518zAh|N0|mT_3Ww3um@pO~#nDUQuL z-Rk7<13Q@U6_BqOd3R03}704e*J>Ec4Nr^$GFH2&`AIjQkE@&d{xbK_Cg78~&! zZ)6Nc6Z7uCWTpJPhpD1kGli{j;!viZi=Q(dA0v$LNJ4YOpFbMKo`o8~DqCXa(p;PM zgz9eVe0cKDC0~mVetX_(Jxo1(C4>pk;70u;=Upqs2l zv$ucrn2_JjIkdAIwScAVhWxuidjtdBzNfl}ZR|#dIN8vcQJaWJd>e49uCFH2HsK)) za+27IQ**-_)}ji>Tnc>)a-S%XSHyB`hR?6QT9NFcj&oCRvC0kj!=1Qf>1Jwv#aTFO z+)ax154@6DLMIxdNRkJVo52h~OxaR~3~=`C6@vBg8Cd^T>GD|6$?oY%0>YwE@O+rG zeyQ@BH`~Qv&4+dh1Np&kzR%wvtI$QSpkdCgh^sG471|tqyLBSWvw|(!#e`1ZchR^0 zY5X#l;DlBBYv(sL9tPhsP*t;VVy(KwQQ#8npX8j6-1=_K35|?XQ!XD% zEyzgZ8fo4Ntm|MUl&ORPcwhBnkj^VfTVkp+w+4YYU3W6YRJgTyI*$FIN2*R9 zm1;s76}W=m%>B1npW=g>h<{=ulOk(qP!L5ub}6UwmexXJ)`MdIK>&`{EFK}{(28F9 z4>4=8V4H^NmPLQ0y?y&~?ii7BA0Xh5ImSu1a>tC@O(4e~3;02Zxd;A~yZPgZCE! 
zdN}ZZ>0ul;%CF+wbl%+qNye7QA`!o4ed09zlz~~`dKk6;x0}^~s?(Fqc+_F>b>Ah5 zq^`%Q7BE91fByLdhHbMw*_>N0aG;w9Q=u=IKKYwn?Qxb2hI;r-(lZTx#pPDdhc zCS<<>D;k`zUi@Cm+(w;nmX=S=Uo(ar9Uyto2^Z-U!y|IxTpukD$QPS7Wg%VNNkrVC zYPDs*_3K!t&}SH^4ANxdVatba%2ggICyTNilEs|-hI1tjzfz#9-Vs}mFCk4&Ryy$H z-1B=|*z^uCW3eZW(Y6MsR5w`#J>JjM@pArYPv|7+-(w?iRyDo&gx4J768a22-nY(& z)+j2do9z6oMrNnl>dpbEA>H5HpIJ*A?_#o-UO!vcc*oPdtQDtO9#jK;1QXczcQL-1 zoeTWi4Mt;c1a~1G)o=SIYJi412Sd|fg*F+~-0v$Rpe_4r&pS;S}lM5}=ox?{ni zTTi@w%~xn=hqb|LOxBe9JcDnHrezJvUpWVm_v~gNWqVra?8ihR6l)M>qMW|gE3*&v z4d&|^m7VNh5x$x?4cK-f>$DxS=ls$oj|3On9HABL%DxzP%ck{wluYox(TW_Jmu-n~ zC-$d~W1n4d!+~ezSzYAUfVNUj_Ojt-m+^zKW<<1}4t~xJYRXu8@J;;>)7=-rMXXek zzDb)+Pu~WF_qi=D?vZpq1M&dv$hAaYq?(s{w9zKM-n55GlmOElSQGgjUtigR_NCx* zMBE%fa}004PeDmbKpPHwGJ?f${*igsSjN^@aEXO-+U~>RIjyAC+Rr5! zg>JO8T`0u+Ub{Lo1!;BNeeGng!x(1ppeONt4x7Zqc}>TqZSWsoYx1rZthp%Suj0ki zaSa7JLuA<|;9-IyC&u)-&3|KmqSi-A?yEpL`BxtCX*Z;!H{`K5Zk#+LzNqa%UDA?| zRC{guFTBNG^!7MamUEUk^5eXC;)E4N_=nhI=p(k!Q|>v!N@E}^O&A0FqZcr>DQ$x0 zh@`J|bKBI2Np;AO-+UN)$x#4NCGmdxZ7qAa!X1)}HFt1ydF9L6ptaqa4{7a?x<5}* z6Paa>U)ZbYTlU$W)3ft>WnyJ~D3E?sMi(|Aq>(m6Q#w$ zq6+}D2Uv$*;T!YxV}}94`0;l(9L7XQJG>jFbi{RT(GI3-)g*i*v*wB`dVQLh@S|)q zbO7h|DKs0XQJ7*H3LNcCf?%9~7@68s5ZT#`9K{0`_HPYzga(NHFu9rS^RRh@DLS3u zNLul!h+TTPlRqAPMwvBKt)*UsuCwxqMGNJjKK0j6B8Z2bqWlPhgQnS zyX&>GYVGwF8|!p^qJUx_DIpI?mgy({!lB99vv+ zWJ|S4m+BCj!g=Y}A^)JpPuXh@$Z!jtsI%cd?fHwU&Y+*Q+%>MEY8y^!6*oT;~Rz zf;?zMxIo>AHP}rHFH+jf>Q~kpbb^NZe*vCas5a`zA2TmzJR@HO#0hhUUHF$Qi0e!HY?MNjW@liN?+o(2LB&WSp`>ijo5FWP&+%V*}Jx!%= zvf=O^hjC{s*E7jj(t#Db(EZbxhieF-4i{!gz6SYU_D?&Txi=I56e9$nn88fte^e^Z z0)}bd1l%3rjff8>rM3x(|5vJVb^Vm@tGMWNbl*EMBJP(g@JUBah5diw*xx&WVR3GJ z++MCAdHrJ<`u1IZKq!BEo?*Gzv!mA&rdw$ zDfNxB-?wMD2SV!P`ZZ2|WaZd)ZRmTcHEe6oku^N#lvKKl3DRj2pFqRnhhDBUmF(Ew zqE*44zUvAmYOmFVAbx?BFa?8tk~5BHJ4Itih#YGuP#zFYKVF~zVc$Gys;b}fbrHJx z4PR&QH&&>`e|cD7eMEKw}R&q|+l_EesTyl}13bYDnXDmy&! zZqd10wcJvq1Z_pA*nN&y3`hwT0EWhtWhCzdL*kLa#OJFS9}tKdW=b23e%Vz(Ziw_VL5Qami4*=1ya4c_y@m*PFSKmw8y!x}0`mGvkovye&6el$DYx)o%O?`iOtD|l-xpA_HZ{mi!mY~pl z-Im2<>d-W`btO-OUaVCM!&9Wc;{BVsj^Sly`?fDXJ>qZ`VO7m9%rM}hDXI^cik;ZG zaXPHrH(l7dY$8^Zupzpa~F>X<&RH|$h;vwJS#t2a^N~WsG$Wr z7Vk$wjyvtD#!XVXxLJ=>U5lDtcCin1QFw31qgFWBH?L5ll32?ym+fp$hcbNoQ+&Y! 
zUy6)t*G}G7pe@Id;7EAj#AL5)O#Ut&ZkF-yUBOB%K`%`cmIBa&4ABxylt zxDzuop#hJI4Lu`oRsSneP!4J}1bJhux=Q2|l*AW{M-z2jN_mc2(KJXgG51xEHLurc z5sx7MJR<{)3t5c;`z9>~|K&x3#}du`bA2fiJly|Fi}38c6umEZfvxwieR#c(0`x*U z#&qS?RJ>NRaV7!2mLuYvs`DZv130_JzYIV^Qzz~eJMfm!&3g^qE`IbBZ|#XhPNymY zU@q&7_!TVez*39>CGUDC<$!^$F#Yk>mg3n7epWc=0;=jYY6WN1(YEP(sNWu2XH3Do zp_Fs)+w!!o7u)yz6FNfyjy=!3L6y}vtlwtTM>5u{(LiQVQb4-ubmz&vVXV;2ABN&7 zKaK2vYd$Z7RkVJyq8a;0%Z*$*C5ka7)MrIg3a&rcP;{USQPyRbFI)Se#gLp6c+#kr zHZ>yd&47F)Yun4&s9&^x6@(GNl!~9C9aB@^Un1!@D-gQDXs=juI5rvuP zTj}bN+QRWH#z2aF4a1EP0a|ILjz=FLva?E9am#BN#80cq8D?0&=*8U+FqF2{Bp8ST zTZ%)~1Ov(6cLkj*Jm$DdlO6^$SLid1RryISB8~SSZn`MdYM}eX0crm2hxB3bw=z3E zKj<0C;=f6q!1NZ`8~?m(0w?+C7qzORmnpN6T!Go^LIINpfLr?hzqkb=^6~~H|I?Q^ ztdOsRcjx>Kafoohs~AuFtfk!z*A6wgajNk`JUGz6)u zS~^R*Vtju3G_I}1f(vc{{=_`syOHa#6&JD{o)U;nbYPRh$=Gxv*ZBA|Z7SlP%~3P^ zd7UcY@9ACpkVG>i&W)%4Ap}mHtP|iMJtFR7kX`L__^vm}J>Pui*;O)PD(13ndBDp+ zC{E1}zw_;n(`@Q!yGTMFBdc{@xEa(Me*5YVmR)3aS-7fW;Eij*=-jzn@HA<8e(bXC z$2Y~dMEo2%XhwuSz(Dz1pCF1o=6lStUo3sG1uYXOYhDiEuRT$yypQ7JcFqG6oN1_a+rD z(?fR8aQ0MY;eBa)Oy9bcZj&$X$}AwyUU}Kfxg0dnngLV0&FE(C%*VzD-oYLF^E4Ck z+*q-!nzu|!cix(ez{CsH>~lchp?{MUsx?Ph|2E}28@A9Nb|IE$<{i zpIAEzKltC$9M~AjNX|oU9rYvMf@UW8$#8|MFi~c@GR_sGNl*IQL!89u#Re@F@(Ie7 zyMlEY*KkOZnCTk71KPKaop>mL=a=Y3KhffHWUpwOw+r>8-7O`WU)lf7V3j}#RrD<7 zQRy1MG5N3h31?*Dt2hZ|hiL)kxEsLr+RAgg#UncZKKb}w{17liCtv&55exzyfk7N1 z+Zney%#p{kNRU)K$$9Ur0H4 zQ3N_00M?lOvi7%~I733f3>3hWZ=W|Hr-73KAd`$#LuC}*&N`)ImPzt9^=&c)&~Wd!7fhGMO$!2uN3Neod1)p&Nlcn2*?vK$muo(MBFvD`A z>|ARx6>IX&f3z5u(73M~3NPjFMN$!k&QOz=A_^LLe)2G2XNY{%3L2TQ^5A4LqKYkjb;E{||s(EfoDQ$`sRAhg}g z@3~v+g=iZG)#1|fl;ZO-1Iy1p!Jidc!|orp>!D{{YLp+Rs%5o%!6$NYdQ^hPG9Nb7 zQ`gbeq{xj0RnzlWWG}9#^1=pE_Ks%GouVulOR@ELu%RyG7qF{`{|$a9vleaRAym>O zmbUgZ$`)5~Pc)<`^`|5#|5n8N5voU%$3gHrf`1BTCqTixlJ1`NU0ey2|19r2 zrd=WHM28CmCK5l45idhZ^p7=oQ_cDKJ8neWU>XS9rAa?$2+RfW#J%k&B|d*%a0m@_ zRsl?LM0$U=&t>m~Xv#wVPd{el9~0M(70%Wf*xp+3Pd|1>MteuvHh`1=Y4N8n$b0NQ z`Vrpb6pEFKuj7XCEFP6iY=+o7RGX|wnEb@Pt-H@0k~vrvQnUfK-P4y;T*VXuX-_m| z;*|7L)6}*a$$L_corbOG0Pd`Hdo*PJ0Xo4@ZdKb#aFW`JNDQDoM;#`%`OnS*YPV;J z2u-^jz|SYnggp5gGrA;v+D=#99UuSh3b+jo%y$sM^Q+-wV;*PyRfQ>(i}L$47hR2{ z@^MISd_f7=wR+qGl2JA$IAEFkW&HwkCk)LMm)bP5*EMZXCG>NByZYar2|K_ZlDV_h zKs~{SJ1!O>KI|_|Qh)ul!qqagFX#ol?9wm}w@n@qXMPwpfx=$W1MZ=e!mq$`{a=~K z*KmIJ70TXavzQbrAdq=-yg-}vmuK}et7YVwCCU8qkrtt^Zu|% zzo7|mcxnDP<}MW@1@;Ys9|0ur7IJ6iz>v_9{N|0@ok8R_jQ9-&w1#b))Sv zM5U#*gefl6!I8a?r-_FVfumgD91UL~V^dhji+$o&U7FN|uymIDP>Q5$-*P%`LGf(9 z?Ir+_j*obO*$_kk{9LQ5@V&%bA6r;qoF3W)dU0RusX9wpaLJhme@!AgX28R@>(Zy) z!1{z8RsdoFWW?>?EnR*A^-8GhgTJTvO*LAgtbut_efvAe>)vnv#7yve!u&EXuOr;c z(cTLKAmr`(0ISznowv%0QLfgJ%YkY^{hlJqPW|QQE=y_B?>+nP2jA)@d?oz<=h4qO z#?|%QCayJVc;1=Zq9uS{1&)3FOUGpG?qbI`qZ%LQA4|9Zu!MGMH&a?%$Z%zmIjwm+ zI@go`t+&!^dLcUt+zMRa3b5-Ieq%A_dA_9K3ex}4bOdRyxzP>9P2Imwdef+^@>Ssf#n$x$ITvz& zvimIL(n(Hzp4HbW7idp2pCK%BQ(SQ87efUXQ1D3XkShwHI^DshRT)08X$83A}+>6y3CT8|4f%8L7BgXMG zf3JlsXgl;!jwnJLbJk@AMdYHoorLM5jDy%!=*_jD_W1B$ilfRn4UbR5;u7=B|BBrg zIz8RxSq6YJ&p7Os|*BfG~M zY6T#srLlynNq0j++G)NwJ=0^ZKETxpl?EQ@Z70s@ZH%%25Y&g)l#D=EF~x^s*+kd= z`;GHEoai@@-M=Uh7M`$uEVILp#|?%RRy5cr+|ev!v|l%&0A(Ae=GBn#38ydBNY=IH z)*D9-n^knsGS^4Jk|x9+?U}d9)F;Mt4&?q6*VuyTL8^!%=9YhoLO>R!^?wng0{c(I zsB+O3OLqR43T;;R9t^MA`0x`A*aA4|*{k!RfBLliEQh|GFkDifl-Ut?2|VrQm}ycV zCeBrOAS}KQ!vgJ1<22z7PHoE};XM~amBOsah?p`>9fJ4}8{I6{6Dc*_J>tO+=U~7; zbMe1c$RKFF0#TL~_gwX##Bi2EZc-lty}7Udgj@s@U4I`PLjrc9IA9?4p}~%~t|SKP zxIj8!Lw@VcQW#&X=9+??*gd^bz1^>I+YQn#N7M0AUxMAikj??)7U8vaFg$3^K9KWyj|5n>~ zP_!W#@nv;x)tv$`akul?M$v1%1?3N818Wo!eHZuNLwn##e5sS=7d$MEkLnk#=q!sOV}qaW 
zV;F}xH&}_2(+7M@CtShAnKIgNtdaCiY10`e0Q1QrILeRRDMOO=orpMr%<}WUwLHD= zZL~Iu4Yk=d6`YufsrGl$KWX}_xkys(4<{=fX4Js*g!t($FnQ7AXl)1iH7MS&xPdxO z^`)8;y>cFLryW`|xG)sx)F#>!6@W@^8)9;gpiM%&sJl@O_)fKK<2;W6boA(e5s2e& zlB98*1pDc~KF-6d;M^63+RxqXznCu(!6n)ajQWDN8|d2ZX%?zt#38%8ez!J_uF-(e zJw>3b23&ad$Ou?PjqAPR$Z*e+`1-RQIzJ0YuY`S6^4_^g?y}4Tc=x;NbNzUIe-uue1wcm_{@>sS2(LWy61+^e;CYX= z%RplMnhqJbWv)(@@!Dj8T(4+lW>b(QO+PThDXA8`L;BbH;T46MliOjI{Lb)3ip>EI z;xcd$X$3ifMf{r^GpIHqe)+WL2!LiKE?~k^42(V{U3|t@7pG!PZ}7g0e@zltt?#`B zx~hNP*^f$~^F&HH*fFc64wd-Buf>1N)x)IET*{A`Bu1lS zMEuh%4yYupFn1WLrrBLuDHJ!w(7DUr? zKP-;9L!8^P--|x-=5)YGeGIlvBv~6$u_oyf!!EdrJuH5H{M>ISUC&P`4grk;ZUn`w zX9-RA!7Q}`8j1fmS(>#l=aNefFh9g#BGOL(5Wb^D#`n$O`9Z_*kedLk66aiIJ#}g` zIQYl#gRVDAar}1mD=RS*2n?H6khb*lwPbB9`;~{OK>5}TjgeTb%DDsdok_$);@%$OdLuXLGhx8Bgfh4Mzy9> z-a3*qZ~Z@XePvXYUAwi4($XNCE&=K81_?=NMH&PIq`RcMQ$V^Dq@+{2B&54jq+`>3 zYvbedzUPed{c_0IDBSm2SIle9d4*BEo-}xRKjvBvSy^B~UJt&fY;IRf`aKF;VMJDi z8R?bGJ+;hYoD9nz3H`NHi5Jg6;Jf@A_;eD!h{Ov1^zXuq_DcPVMn%oAmjo^@PBV>S z+G1xop&T+6Oj3gNLXmbArYRwgb9lJJowlNzOb2&;BsJ@@pS)B3VvAgXQl8xJV~5ni z>N)q>M&vJc-WE*iQCq^-m@yu5pukHEy2^Gwc(-p%h9$t{DD+%ueaSPCzix20Q{8c4 z*_xLUy_f!DW&T2v;sf@`MQ$%*f~`E^weaPWts6oMF->qb+m=k(LTQbJhoHL>eLVLZEE< zjC!2XY)`?Wa<&-)!Y{umNo40@dh`jD%xQL{ScrbmD|3p1%E?8MJmBZ zd!KtlgcVog`bJ}ElI}7nK3f;5X-Zwn07#pr{+(jgq_5@w(<>CF?d+%@C|`S^*1ynW z%vj^HhfBDM8<>kUlriK>qf1^{?w>2Qu*HSN+siljkuyJ+BIxba+38mgt%r=_%xDPbub;7hY-wz?<`OHN3l@UDK0pAj&V9B@9y>t-2J^@;>q z8Jg35Wn%~&Ks(;c~!KnaZ+o;ErK_fxz4dgDN)TYZJk>raetGcg9CaHNpjLSv_QlTr= zUhZC!?)Nki?#tu$e&W$FpiR4p@%=5dCwLViYfkWBEnW+%LtgiRuTk)b+rek6QM+nG zK&SEMy)?DD+A}L0#n9kbP7>Q4NWrN#VC9(x~QxZC|r->d{0SMt$BK*8xu`+YLC`Y)R{gH2)Hbf6GC^t zdWhM%G?`tNm8>z52;3mY{EGeUY7BeL<)^D9x6{EigH57i%*>@XHCOD3YxzKBS7Unu zdBBF|RM}|i-P2tA$OA5UG5FkruAu0u>jeB>SKF zt&!+&QKi49#H{PuZ^-^3kYDo6tag?(J_p_8Ox)#ceNKaRMm*sR$=_s-Fb=^KGW@%4 zMD8I4s{+xf9(j=_6i!Ikme5R%=2hcm9>BuXnl^~EdgMkHUM0gfCKfMv2FJ{qOle>M z^EO#Dvd=}2NUuIWR!sNe&x!rg7|;DH4#9{>!qx>v%klEOFcUKlc0A_Jn!RrQRyIiV zT5~SBy6s}t7Z#s_TE*4(%OMkzJ}Pew(BDL_@F0c+7iV0BPNVa zo(QsUr#KPmZUx6}+{xD2ND&R4>Jqp^NmJh&`wG&r;v9lJTCRn4=f7{4G)+ij24&di zHANSp9n157DE-ZR0KLw?Yb^@gJ+4onXWK1W(q9FF7a5k?h%yCs#qC>IEv()dV=XGV zpXD(r@gHw}vnsocw6u8R->K2Q#|{eQ{h`a?^=;(jL0ZHR4r;H9WkaK}Pj3xR%1@E5 ztxR|Ija1WHTU4N$>A&;6Zn}lyzMhXguNs@=dfMFaQWs}BR@ZCL(ZS_(SnKx6+;2k{ULNt_*3_2K58VhuBCw-* z8Mm&J)_7GFh+g*en#G~Ov) z`7+1zQ#h;C?u9#jLMTWB@xowsHS;p-a%GR9|9NWjR&(?DZlF^FP5VJ~QRgzB8SXiF1YF4R<07~ViOC6U>zh0hrhl3NMOhJW} zXTArp+ucY>xPX@^<1@`P8h|+{f5MdrJ*oB`Nmy!di>suGZ>(P1kS*{iWJuj>x=izk zlXrOP|E!q)k}!d1;WS{fXHjb5Z;~zAznbJO52L?vVqFd!SZXkDcTt5bH!pz-F^^0E zA;Oiyv-nz+Xk3XRqAvyGmQl(MW5eI+6Zq{hD4Wtiy5bT1>F1PJ^Bx}wO4mdRem6)& z(1Nx>`vf}t^1>)b`&P?{KYLQp5oFNLgYr+?71Wnaw^l9;;ttvu0XwyU(A%%vsz|c9 z?F#V17k?uNhRKg;G)wH$#TK}V%E~3(TV+cKGpzUlUHSj@M`{zs|2Ps=eZvw)TK+cw z_owgv3Dk!MG`^%Be<$s;H8ya`onwU>KY#uN>^v*F)Rq6w{{{Q*bKam&sGi3j|BU($ zRw9MPV$Zhn)A!E~45=Ufuc zS_0%6z%3vcOcyq}WfU5OyaM^v->&?sAM6m9TRqa(|&6T@T{*`cy=r5gJrd&`nFBYi)O4KIvq43R zXbe6n{l;g_-GR?()ArE**XcRy629@SG9UcudX5?Ksj)<)&p(i7;o#w*u4kW4l}52( zTeLbKRK(kydk_8cz93!BUrgzj2qls0%iH}nVa+tbuT);7@#}ji#EJ*YC)IBdW5mM0 zf@-O2cx|;Kkl|Ci8TW98n}EXsIVj{ixEDNcJ6qTW_9e#`w`&eyRhqiG-9Yj|xZ@jGJWRFnYMqx|KI;*Oj@cYN*YH&(!ZTT=aGT$m_&i6wx(71HqQwk0 ztf_E#8^EU6@VVh-!$75AOQV!-8QQc`XDGM_qmKyM-H(sWiTFH7r3MAQf0T7PSYF#E z@lE-KHYTqsMJyzL;>+?Lu%FAWnVPOo`74V^H_sc~pYAYg4~5-efv$~tvr%G&m!tFG zguAzkXY=~<8hZPEt#Hu#B4#;hYQ}c?3U0jFSAp@2d1*#PhJC%qEhnJ{ucyL+HfWXM z5t2*lZLcAwp095jB`g!Hny&9B@}-xTxxUg3t+dRsbK!7$?!k6Z?{#xVSv=?ZnT4qb z0tuRlEr)W2vkujK8|thZ?ChDb8hSu>8@lIph7ggK*}O;M#pAFbK 
z7>P3fQbeB>QQks&coNfGol9zT?$))(ys{KqqUPxxsakL#LEMFYlTXb0sH;@oR*8I0 zlEOanJ|qxB#Tm8vMi#+s;;~a%^a#IhT8-~DGzT4q|0c{ zGf1)htWn{5P>GtxhxGlPn-opV>*sfR3eS^;{o^f^%5ATgUcnEGwY@CFk53JYgwON- zAaf|2jjPPOK}kN@XKVR8&Id1S!hD5%&PNg_qBU(Fh<*$KFcaK~in|Ir8O1H_7cZCl z+`CPqlkwIC=u&3ng>c+dJ5qX`I#-r^G9W14D%Yfeh|7|>)WQkd#~v>uyw(lF7Ca_3`usz7HX?qWcdN^!DSezxdo1Ovd%Z#hgZ&yyiLauS=)gK;4W2s`A!t3{WetCoX z`1bme=FO(FPZqHWW|~u6xUkARVvkT*+#&Ya)uv^G`!(@+mlHqgHa3(};L?sr*aZ_x z>g7AFt1Uq_7XP_-L(lYwVCC||CIQbtkA_C|fxwJX3Ok{$1C>Oz_g%sNol7vIY%shb zftuM-I^TN{fAIVu5N|%0rK#dhe=gFg51_>@QP446U7!aBG@?p;T6n^&vv_lTc?XIa z>h@_DBs8O!T{Uhf?i6JyRo~z8_94)s61loy{#1muLq8`;U|(`SMtom>DpGi-u?4;} z;062=D(Yyku*(u=t1bGs0Wq&5D-9CqxkyFm0#VF;EZ2phUblWy-RK1V4(H%Z=L=ah z*88iY9bN8f?2o+Y zqsWXoP0;PLx(73sx2g3-n<8J`h3s-&BnW^3*vDzp$>pS^UY?&kjF|~n$VAsd4V}Q? z^Mlxe0a@nFy=%`gW~RlTBUT*cdoRq06Wl5a(NKH&d^UFw%Ak6My|L0n11V8o5mj$z zWQ#noa*W={aLunV^L_x6E#q@#qKG_Etd`0d=IEDpAs%tuntP%k;}K1{J!4rRJ5A3a zHt~zRlzE*~c0NB%Q(oJ!aI-_Fkm?mwv%DhrLEVwj#OV8UY|rIOXo@V4#}?}$ffnxk zO!*Z~vIiot6?jheDYmK+pgd%?$M!_qUJj~>kx}>6wwWx^=VifX@N;O36-;vRre_0vrt<&Tl9EeVe z26JVNXb526RJ@5`6rlbjG4{h<^i$VmWRuu~OFFqqd&1zMMw-Ox#i>3k+@{nOR4Z6m z%lh@?i5e5Mk$lvN(W2a>)Su;ba!f*L7!%L@9%z(ao%EqtP`o5!GFb&x#x=IF+W<7V z`Lj4yvfab<-_0%5IvrlZLrU{%xeY(hOJ^x3-LBB~(_13{90JwASj}3_9SooHNXkN6 z(rMsKcsv!&zc0@}%jt-bFSQClO|6(J>8<9xOLVbmNxuV<_|zpFkLCIQqX!JX(>Ez0;*|u z$GNRyH69MgyX+YI@5BD#`Luc9XM(tcWN3>J-ZBz+VTMZsOaOkKSMDqK_9a_L=(3@WtFBw%ihwE!qByi@qc>5pKC<=Jb|_- z1K3hAEVB(Q{*~+btt$sgaGFki-EUYy?BmZAalWyEK_5->NpD4fnb6UGN`^7K89nwd ziQu;3ho;t-&_m^>z z^h282qlPX6svHkx_q;$_$hpj2W9?Y7l_ZJNP&Rcdnx*gW=SaqI? zE}_~ortNHHep?(1R$srQNkX+hV09mfB;(0{^h)hgDhKr#O6jze=vi^swHBSr0?=wR zJG!u1r_?jMrWjPosY|F|S6yxoH0IcbptmAM^O*kI@BsW3=dnH{nITcRdJm(j;?y?g zg8SDm=v|WqP6O~lOZE1ij?HHh@Y79xw-eiaMJb1f6PRr_bk-htJ2KGuvi%R*yXlkN z8o|We{M>q--2bQ-z~w>Y7Jb(9g5@S8o$W*&M{JXvwgBnfMKV$)7F`>-mP5tTy`{!! 
z?dA238}>B?#@+?a7s-TR-UEo6zbpTdyxmse>5E%h6pHOrFkSudg<8x(rphFU`ARuJt=*51*oSx{6zNd>Ay9V-82Qm3n49zs50~$rut<+LcSNkMB%VbIQ_ua=%%ZbYfr^vV zhI)+T37;RovS}@+yQ&zTfXDoR9UTzbpj66zXf_bjsJ|Ue>ye@;r^2^Z$wy=cm(Lzi zpeTF+s?>lj?gjU@D&!Tdd*K#hcAL)`0b9w20N`eehhW0~l1~XXoXE3I=HvtOLeHqw zs5g@1T#qY))4>qO!zAzpJfGVdYgLXCWhG!pa8v^x2EoHI`GbzpsVzY-t5_mpWr%d; zp(BCkmK3=9pj&)_VU$5LnP3BmNBPm(9`|-MtipL9n@8Eu#q=rIEayuV+fe^TLo@jB zXxPYPvg2Zv|>}GPs?K zYh$hm3I$34nN;|#0a6+V^#ZhBE~yZ$m&_)J>hXcgK)3j0F+Gd#Lzd~!((X-!F*<8= z{-swV-Cr`l_Cm?YVuh4{(R=k|Fm`~^#3#vj=lcH*p-Qxv6-emQ&sat36k|U6iY$Sb zrj7y$%HTOv^9Fb{rS`phx5qLPqN+kvaL>SWg4V+L6EzcDqmi#f90_0tu~xKlhxdv6 z3X)Sm>=XXH3gorw7P6$SfLyQ^u^Zbbk}RGG3JLV5=iY#xz_rSzw6OLS*Aok1$~;67 zY#<}+2YQOeYz)V2^{ zS{^b`u7I8BEe+4tJuFTh&f%Z+4Wy)QV=_(&Y8pQ);jb_@!7B?{VIDuWAM)%>A*vajw_LKF0!LD!0Pcxtp4^}wmwp5n<&N^_9*$tW4*7wDNtHm@~(~& ziMee^gSjplS;`+Wfg0cv$nPE#DL zn+!8pOe}`NGsT5af#KeVBrZO&FUUFy*$5qgUFa(S+MVsz@8R5hVlRhzB`d5AYoDPdBee7HpsBSOZokfgd2UwHVsyM6JzyD8l|{zF6>P zK4j2PtJCJsF~lk$w!Zkw#kfL@_r3{cju4j6Qz(2xxGvD^^jkv%$tKiu@#N^z6`VSF-Xm(Q0v4og6--#;@>Z_9<-5(IWwWMc4-wCrHxPf1ygj61llG^64!|ya5BpzK{ADn0TU)#!C#vSq0lO_ z#@>1qz)_4=arz<)ubf*S^KVZwNINSVA{lf#;=*;aqKAU13NB3dKrsFX{#J8_hY&p(_9p2)pxg*t{6C-*UiJ=ZX!AZ#Z^$^_K^S^U#Tmsx6 z#`>_KHy0Gws0OknWL3wu=*usdKl`Int#px&nG8iA9f80a(@=!kyKpBH>1Cn^bRHNr z(G#}%LApBCTJ&FPBJq~fKe~gVr*OEss@*7I<$Z{~yqw2u7+=wzMT8Q?mB5w1{wMcl zIErkqZ6Nm4e<)a{?=Ix$6En`}LhQ-EFQEUM>`TpKTz>S$t#41>jm_`IcuOzun*3Yi z|5Afjw!Ic<>(r#8f}0x$G{ejWKr_r+MUJ5#fCWE2aQ3#O{@mUSS@>Swj!u>KEf_lh z4Y1}S*?@(YG5C=PvR4TJc|Og>`KISRq@Y@WB3@^WIAAk|zT49Dx!H&g;4A!Q_v#6H ztN#;WR$=2lHJMfIPceKWcHV9#bt1H8i1IeX6Eh{3>xbvtX7u7fq3{&;UJ`ns;4hO#r!$}WHsmwI13%6qp^;2$y+PU=kW-=`Ll1|osbf^m86 zJMbzdF#y{^VDn?3m9W1^j}WLhcCH(>T|bW5*%;4RmLv^3u5zXVR&4$SV8t3}YK}tp zAnKAQ&;F}yLFntI!_fU1JT!}KF;!U;g4S`gW_ed_lL)y74Gbt;9mlRah;btQy}T%F zLx9MiOZ@f^i=rvZ7o0xemP{tDb#-T!Y#EaYbPHO$S+Y!k?F?J-!J6B+x)*}>a3c+3 z3n?f$lw5npa*6Y=01X=HgQ{x%2uPW`3rSnpltf`R%U|Z%hN3~p3t&Hp2WZsuH{t?7 zSW+9%w$!~TBBt@N_6we|rzBl9=mX9sy%tYfRX~q*au#5DvfBBeYHn_^Q3rAYwFJ zy34&nI3xfLy}Qp~>BEnV8V-0D`NzKmQ-Ob{sTpF~$Vl-TmVVMu9|R8?|p*!%tBl&H)b0oAr>1YTQqXBbUoQt8c<9H?Hjz? 
z3n?MgK{O-^1V;$Q0(aE~yA|Tv+gUTw!swerM!%rkH7=l?w28l?Me4B)btO*nAUD!< zY~^sN>O42-e@vyOYjR|VyusU+S}0vm7m<^k0^B>Eu2ILlGa4>CyWsvA?Dx(_vpP!E z4lY=&_3G>L0xG&|w@8e*M}o{F#-Ei;FbY&~QVv`B4f7BY>7 zgvv&Pz7mS&ErAz=Pzow&#S4dA47}jy<+a1RrEKP}|C_Cf6hD!H9IVET0g_#UwrozK z+E^ut(pnB0+fXEEtzhF9W}ODTyYUO}Uy()%^-3-3F&5}mG+l#6q*XH977{%1BLM-> zv3vF5#jpbG+(ECuXhk@xY0~7q>Gw-H-7*p)#nJe3trozT`u%^Ppc^%k`f{M4YGh# zY4~;mx0uv)6MD_A2#SWdHzR@(3Wa64W3Y8QfD)eItxzy3lo>1!ggR_P1Ch++9Jd05 z)foA)uSE`+gae006&F)%@L#S?L`&=u95|%vcR!{Yn-{Y^F_u^(`aWJGY~XRcSLz)^ zvE|GxR;4Ry+iyu#7$uB7_l7PL=^l3<8L&ViOiVtes#21)McdVy121-s_^rMcQXUg>>C|I2TDOkuEEN8f7 z#<)RJR&jk0bPV1-_?a+)+xdc(RQa8J2n|aP|KC zAq^iY>M*UXh$^W+pOOp{kWhSGy05%#y8PlG*-2EnpSaz~3pw3p9SE|KC?oY4X2nt10X5(x`p zwk0PgkEt{kMBH2Si_j+{dSdZfT8H54UUxbx_C4G)gS2)3Al$XHV`2VPg2>by8mFJM z&;e5eDsyQN12NCp;VHSfpKcJ0KQLwLdP zTw;lU(+wKUquU$P?fJbgA zA1#!o(U<8tX?w&qAMYsvHV3s+Q{mpeCCgku@8tZ_a;% zgfz%Jd%E{QpJyOqa-$|MLm)U5j*Zf_Nz6I)ncdn;ugpioE84cFhfi{glFmYXa1=eNMPC&VWHRKkf?eMX_ zSY_&UJ&lV3(!Hv|t1;pPvZ3Bf45YE@Xl*Z068I19i=)e&aKm5}$^pZg~EiN^J=<0+0lIthuo!>nK zENg??ttUZD*}?XAl8bjxeNgmLbzQ*H*66&X^7$hMXq7h*f!(>7Ua(O;rd*uodDm;1 zqgqM+!wK%D`K9Dp36xl$w+x4zpOjo}X*!c8cUNfH?l$BV_7Bz<3s{nA=TxugL9Xv-Hh zX8{ffTW_Od9SoZ$jIrdO0>5LeGMjBz{LSqp58s}oic`w=Q9+|4d*q2`Mu*6#=;TON zys9%ROYDaHZ@A{YCC_xARipEwuK~KGRsG<$8H&1pJ8$RT_M6%d-BBxJmz+ul!UC}j zhuX`rz7W2*ol83Zn1~|~Ses2Xw}jbXA1t`y4RIa#)k-1el`wbs?Ca`TcE2xrt?T#k z$ep#*g5fy4C^MHGZiOmGA`RE z>hGD_fADJQ2~9&TB)Nvh?E<$unK5*kBWCKK-+59Q+KmBEs;J+Fh1ICf-PU(%!dHKB zJ@kWbne4{c+Xf(4ODQVtJ2^|4$DR&Z3VTeya%DZ=F(u-zdVMp}EH8F@80k{C8jY*6bb(v-ceFRX`5dk^f{uKmtoA61_78_ZQy#}nrwF8w z|Ma$cY)W0g^s*-0ukLbko&$8YpnQeXtcBvqc-`S#l}kzO;YD9|#1>K$q=;Jm(qxFFtiaz#7lNb9L`gf4)>I>ao}DIVQquUGl15 z?=|M4-rOnyfZh_)byw(f3{SKtmKo%c!9n;KT7J|h813TUhjzYH*`@rtOF#h!2Lvi^ zB?NvuE&=xb=7DJyc{I|KH&7C`p^zV1ZB0Kw*CqwJwz&1>Z3F^WUHJ9Xmid)x4t4k! 
zR=*S?;{qXIJ;+3G&DKH1b9> zjJ@owXKd_!dK)lFPHEG6!ZJ8_D{wgzZ^Xa`E?c%C5(zXA0k@FYhO66#JGJTHo*Srm zINchPNrUZJXFidv9j&Q7{G9{aWi01-Y)}Eo--EmHMne z)HoZzO6~TE?zZ(WpiNiq4+r%iDq?t~-(lS)FIB%)Q&)A?Ii&o^spVTB>Yw!|ZBJa+ ze|7+8iswXjIB^RoGGbJ4J zsG6#hH=eUnRHaM0KSvl&A)!mNYJka8JNC9l@@rgc%jNmy*4y8@Dh4{bKp&NXh6sTb~iee-H%GZkRs#r<*iUBDbjQ z1}HjVs`BF#pF2A`Bbyl@IhV=xNCZwmnk`SSKz#ut34%?88+=WY7GoYc90HR1U=tyq zWCp7JI_L72`k4c;htMsepbXVMYr;{>b%j#evFZ128;cIPITXc+_PESQNf^9)r@o8V zy%)-d0+R6cr}Rd9F8)7=t?Ew6hnst0!e@TZa6O_+Bao5kdbAE5w> zujIJG+F)afPFl0OGxDrSOz^(k=X8d8XrQ!}Ys3%jWqVH*9p%o!P!&G6Q(dUL4vu1N zg>o9nv0OU29dtXbS;oa-16vxW_{lPD{cY-8Bqk}z14~N&O#%hKl^?`G=bKWiy}gp( z&aBq-4&0vBO(b^Dp5ttV)2-*{1^%YLV6gc(b&eB5tQf7i|7cyV;zL3cXINLDq479z zEgQQ)=E0|VBeFW(gS#A{wEGGShgOd)->rxnap?e0dXw*A+uOB*qHi^7N}{Ju z6fzxAD^&`bJo|r5*<+dG-ss#^UutN)aBRXTzFxb{(sN_{P243q*na;aFM6`oQ;Q&< zyGlw;bQi?tV+wZe}qH61$z?Y%=$x_o|BJ*QuKc3RM8rm*ys&5_mx zEgZdmM!NY7p=u+prFR_eA}@l4Bhq()WQv_2=rY@JPF~`vjNXlIz*))5brwR>NJh1V zUcte%3E9&(ZPQcV^xmxl2X-{;BjFK`rY+WC`a!+x;VFD$Q#j(0^Q4L`E9r}U-eKHU z*AK_qie(N{a!E=2VqO&}4C*(PUgw@sNPRoc;C^am0dtQ{Z7CEDP%48}rnJ5FuSce7ttpIzaQHIC_JXt?qC@Lokhb8!Fw_GkvY< ze|VA9p#nTJca)zFGJAB{8NIx`P87~6J?6$2_N229`%H9hAm1TqbhAvjq!Z{eqXrpG z3u6Lfo8v}hDbaJmD11pRx1J5F*K^-Y6$61Ic81Ze3h(ENX%DWot~C8ME^V(yU{oc*7DNFzYIChtZM^qfW+D-6K) zaxmhFbpQF|b*6dviD%o3nG>9Gqy1L(#KgTgDsjQZWp*W`Cr>|jAzO!llTOH zHj*kPuM5cLRe_7Y58`K&M{?z^?kIg{_*{C;Qywby6j z`#+77_{|9g@~xszqla1Sp2N8xy(deg_D_~@Fg~)&!M!H#Q-9QIeE2#{d;G)q)|Bqw zx#@AQ-3SVwSOv^f$nU2TMjCH%M^%d3gayE-d?O^KWXijbqHgJ%&)pFy-SWfIOGLAO zqT8Z|js}j=U-61_^*4@Z;>kkQEjk|4Q6PPX`PzHGctMc?1Zb3UDT1Ab$@jLt#JTK^ zFw3OIx7-gxc8@pV9Oo$g^{v8qU9C1*sEZ)=N_~gr<^eUUI;S@ z;zBT0;qS}Cz;%9%Tr0nW82qejc4~$tLePr>m;L%Ss5E)Id;U3_++efuVDlMDQIz0q zuERcaL+C6zSoyO$p;;ponp+Z;Sk-D`3?AMyWeLedyESF zIuzu0E|mctX3H14xgPb~7bCrfb)*6whoxk{P}(`?pL5B_(B;Is zU+I$7?=n${RT+*WlUu1Y9jAWuuU$%~%-ixJe|D{`03Y(bmjPMBqsu=q2%47jg~9-T2_1N4i}WBDFrbC)VEj_!8?Hs07mt;kPSJSvaNG*fyKmJ2 zzrGZou-u9a+C~$ldWG3nwfcR8U!>qR-zoq){K8^jS0cjw)7=*i8Buz_I0Qmxl0(Wq4NPU+kIuc*>EwEqCOZm0lRx>KH`V0 zV{UM>jaQ4jHe6x-U9G@k6jjo#<>fTfK`L(W^BxUc&u_Q0Qnphv7P413TG+QCABxne zxag|d9-4OK7YYI&R(Ywq!PrqX4>T=x<~CY5dQ&ZiPSx9M>6z>X589~%&Z~)w@2UWJ zV!V>xBHF8m86>Mw`@=k$UVLt3tf1+bc!@ujSk}*s#T0AC^uovZ^G^5mK7{R^P}}qe zfiNgi0KfUIRqL_8zh=54(ntREo|6d@^2SrT2iKHKYz~xf8Vy;vk7N#7TQ5rZ3O9hDGx9-J!h|}^yn-M!=>C-CK z{`JYfR^RKn{k7$IM519v(NUT*J$+9~bd#KIOGj@uEoq`{*Asu2Z4T$h& zv=9HS$MF`6LBM@!#|vZY*qG8^oE}l_2J$U`(66`(GCz^x|EerMYBZUC^ zY(i?)3^7%R9^Z~)gtd=8hz~cCY0_JX%5YT%A0!laNxhbE+g9Y`v0-~ClB%O*$ za47XQEL@q)ER307tsWCoKEY3x|KPvP=3HbUX|zLc)H@N+uS3oVo!C!a^nq5&?=QSZ znoP38eZ?P#ZVlFtwf`;y23w_mX(tSFp}W7aoa-hJ2)5$Ay2_txEt#^DbgnxNaaP5j zPNnLb3KCbXaD8{{Gtm^D(5XeOIO0@hPiFZ!_bO<@kHYW$Ej^dT>;n^MFbkXA))WFq zIwA`6aZ;U2Nvp&|6#7PjwBpfpjutOPCMlyO-O~4Sl03C9HG3~&Q&{H`sX1L!MqN=H z8^l~sM2mEG(P;udvy@r;r7?IMf6~+Ad^}DPx>V^vnw3Qfb{iuxe(SNGtRdyC_ha6u zcM1C;7#S5+j;%A2*fKnh-AP?FG*<1$$rWP^$TywQIEb*xCKG z1lF4w!N*On(%Y=GFy08#v`nTBe>}v_4gbo){I&Ir(@XS4NjXl(7Wp#x@fztBxQ4bu zV(C%8HtO-e$94C4jk6qAc0}L)RKu_#7Be(z0h8)y__wu?fZLJUQ7dZrSC(Ys^t ze4V>2RVO|x6)X(07UF!C>{(GdVjpFmU2v|;=|1K+cf$XOwAFv;mx&T)@%Fp^I~+WV z^QU-<)c{2l298JIi1~tZ=1C$f1yAA4-RR7W=Nj4O_L}<4YD1|=ZYnftU(-1(9Tgiz zV{@2}5uQuzlc1{r!yCLdhn_C!I2kS~n^EoWX;0Y{D?Dk$>;jo{n$nvf|73y`F(sS^|LHPpq2UG4Z;4=qZZkxw)ZPh-jfTkU>C2f zx3o@IN@;&KU%rq1g+5=(@VRYZMR87Zt*_;Fb<1!!e)ciL6kggz9vEqyifKs=KA5Iv zvPP-Pc%z^ZT~TM4v~BV1DxyhkunYLn>%6P$aPNW5 z1AJ1)M^lXz{g;ovHq;!|evg=!AX|P`?)5X|`AVa{#jkrF*j8sb+S@{kgUy-5EiFqd z*m4&AVb!oITi@c$Pf9wEE%o$%d9y0mE%8{P8`K#=rJ^|ANNJiPhE;Te z#Z0)4?0p!+XYq%;U)DFepPBl_ekHU=Ey^ftM(&d}osBcf&4V+k(up%Ef5_@|62UPgbu*c 
zQt50;uOA7_&T4jv584rT)}DX=@k8BIg6I7Ev zBNOF=m{>$8QiLd*n_6veJEVe&U2`U-b9)=AuxG{@J|tu^+83?2j6gBUT?uOlMZlo9 z3C3lK0WL5=8WAP{Pyj4ISM9zjh4a*kSu|CM#jP4*t?)@`vDf6W@a_t8DP_$QieWsK z`<-ddg{#1xjB&nHSnRD!?YnvdKOfsAss&bFb>`}+>{fDjiv+Sx^aHmaQ(;ne1G7hj zwdN_hWF=R2Khm3t0;o^L7Jf|PO~b2~+8+m09$DZ?C10^t9Q*j=B0cE^Rx8r|v)(-` z0+SLUX_s`)&Of~UYl%bitFpO^Qcc>XdnBh3s4_dz0s!hlc8yl91>FqF>f z=FpUzdw@Q1UP^8gusbzL5>SP*L}LA*R$KP~{05P%cE6lW2`9w{ji0>{e9}7IO2X%A zm|-kToG@Ugw_|~#y|Qc1RjL*CN@l_>=bL+KYpdPj(q``=QeW)FNUlvX0%7vKxQ?Fp ztO5_TT-%=+ZkV^{j_POi;cw#evwbuR>zBZfNfM7um4hz_MrWo1EWFi7B#yCqrteU+ z;ml1x^0io8ANl)B{cXN&l!uOS7vgyWCCkIXztW3OuM zH;dTE{q`9lY)AcjPoO&I^amMw`bAyqP|N&bj}oUvu;lQRi;2OVSiD#Kn?AcCy#P^z zzNj)uy#@k)9fH2jGVT^thI^WGa6_xEVe9HgmYgfSU)_iI=IG;}&~d-MFw8JqMKmll zW2v_HNyAL^Rt4N?7|Vj`4g7`6)f))y*E3m$G|oq_RZaN&14qt`ORMeuBAR$BEb{e! zOuCdpw(kjIzq|YyLUcOxZNGIiiM~6o$n1J@NNyQS;W{56642m(`7rxY2j^ zBeyL39a&Y{f@0RD=oM&I=bP|1Dj|l%?ipQFNj5AFZyuS9<7q}dGf%-#w$0zJrUKjC zlcFA(Y{(R$|`I{ zwO@lcJ+J7Jw4BA#rYC2hhW;NsVzG`S2+60kw8c+8l&lIp|g|Dp=n+%+gJ|~ zqIH;X4}k%s2MDMpRBd#!kjVL#e<{!trX=;?<|1}E&t}GVr#ovz2@8#UC+nC z@jF4xwTKHDM3eaZEG?pgKM0Epg`St&_XT+aejVAiuj82|YG7qFqrE(FMz$~{R((%W zv&AKIP<=iEFnjkOqkiBF=05g!(q)T^EgCH>(iV{r6P*3@+wYP+ z4N)rnQ&7ywedB8x;43zfUXeuLYm5Xq!PGm73F%VKZ#ZX(EOW1WdaTnf?`N&Gy7F(* z?-wGXO8W@*kQ@_UU@w+dklddcq!djSD$7kQW8@&JlFPI8x!*(n!YCt5TbEL(5-oNv zm!5e`zP%l{A92nT0%&egn)BXgF|zw8X!=W=DLmi5+aS3wo2}7I`DLV{QY8+&UAZbF z`$bD3YrN%-^z(B)>wG+Q0WAZ%recxU8G>N-6Fr+3pIarkLt=^TkE!*CzkCu)GqQml1LlR*b31?*+=GuMud5;P_=S-J^l&=VrZ4>b_-)}7v%OgMx*_+onKUsY`0vMfUa zRu%EEqigp6@YlUjqf&LGC_ALXqAwwTT@FWronNL_aQt{jkEQVcvGvwrQO4ccs0s`< zbPnA}4-6?C(xD(FIW$rtA&vCVB}kWqba%HPjdTe}cQ>5Jy|4E>d!O_DLkw`iJUsJT zYu)P(>{z1Dv7LP+N3?mtgTGuD5{#sNWVX`;K~%mxoGT&&GybOyJ|O-~>Vx-HGlS`j z00Ip=xRpae{c;SJ5~=$ zqRe9doif|o*dwGvm;tI`<2G2hJ7l_1$zsD|L7vu>6p5qoMk(S?5C!JSH2eco14ON{L3 zy)KkQaIf9iR7K$UkL0j%TGsUG{I^-jXzOU-s>U7=-nC1HUwlNzkw;pGo#CHrN*2hM zcTf^la4VdrdZUxDE5GnUQy@lvI?RroU+NdYdI4#y^7x`k>FWJqwZY8gZuEcLMCZ9m z>sWJKhURzzv;WdQo9CB3xZGLexo%zQ2y^kg%rkC31D$u zF@T9_C)d$8Mf#fUOlPP+x73KWBye!>@1fpojTy_(kOEpw=k*wyyiHa^Arq4M{w?wN z+kuacS#=zucRR=Pr+7JG^G$ceuL~;d&zOLt73MxL+fA*+@HwXS3=fF)rY+oM1+G1< zXf-+&m4*)W>ltJEM2k#GuM4cif_R%Fmi!qg5YYF(yUaKsh3G$&>pwhip0>Y%%}lTk zcFfVZ8b`EFvIn*Q04D2YwCt?Ui9xPIef3F%p;)~uG~rl2BCjtqv`W(IK^8Uo1=%=W zYa=uYHR5HtC!A}On8r$xMBh&;QnmalM3;<4v6%pz)d;w0Ap<~j2?%9nselIn#5d88 z>q&TXGy}(u>^CFc&m%5G8Bv{mmLq|#I7{mzu}M4%Edu1KVlZv;df8@Lj!Yi%!{%3| z!g{iDcCk2me4Tf2QkkGjyrC~8x58WnlR_aE1abSNPirhCEu_O=^)2-=HxHQu?JI61 zE!p#_TFN2OI%Ahk?Fq3ge8^8Awo(HVPu*EX!S@tIx2_JimVZ#ti-1rgYJ-eE2VL;X zyS7NY?_^F|=S%jMEd4J2e4YQ13Xe}MA$OF6A(x~pDse`e38`|codk1>BanaLgDD-8 z?SstCS$T7moec?Q@ohu~S~E(!bX*IVi3kp{cKnPZ#$a z^N{b3?&QOw0%(Iw_n#|`>Lc@t%gSmm1}T(9P;$k@yaNl62yBxrNGyDWY5g^lE!yv< zv0mgMv!`Pc#GU`@8_CoKm!;HG2-nis<|dxI#}~S3UOnVSZ?X+4m`h4I5|6wF>jh^lJfLCcVNPI%Z8zV(4-_j)%i6m^)*9U0%n1*gW?A+8X73R zMGYA;VB}~1KToLzKuqOn@zn+`kvUUNw))42y+`2J7FHnU$-KCdR9B^6WB_JXb=7HV zcGwjB2MvAy90>b`QXLligd&2qMYTsAJi>;mjJhW6({O`dGt}|E@suEwHR5@5hWYm| zUw5y|3ulEy?&JU=mT$4@ToF^wGYO+H?^`mLNQ&7ww2Wns;n`(hr`vz`4GoulNnaO~ zWaYu?&7b#N+DgdPZqNRhl1WoshSjMjpk$t9{Za3p*2K*2`OO|Mo}{js9kHOz95$jA zHL=aqeXW%|q~ob1%b2dQOCqW<@UP^pr4*`h(VKZ_m2s5?7?IN}iSY<;dYhrx&AeqWn6#5o8eAQ!i$D9ao`!)u)qdEhsWS;P)AseZrZR-`W+!;wSJbT$ zPs~6y?pW;#3ae=eCynp7BH0iy|6_Dyn^5D zOb_lirI*KOSD>!kZLJ9mrz|(`8Vm# zPhzid#8ylf*k!T9`I|BP7;(NjER3C#7EWnbe9EWPE8U0$;vv#zgH2tg%{RO|@+75@1nNpJaANFzT1W_)y2n*u%rpt6x05r+r2iI~l=-Qc ztt4}l-fX`q8Zm3qpgV3B!&f7+R>@gRpljRA+PgQo>tJ%AZ^?9mFwdgu0*m^N0!K&o zRz7=4hfAn=HR-W)i9$Y{TMAyn(zPf>mYHCu=n?op1j_wXV;M~k~nMPTo; 
zAiaKRn@tnJ7Knn}G2Itz$5J3%jpsOZl@OsVCfhfnhmhwCtZrkh0gL=;$-<_+;4Qo+ z)TOR5+1S^FeosI|*crF3sUeDEsWZ2Onf&50#QOI`iN_`Y9&!;-Hh`T+Comw2N7pC% zCGl2ch{5Gznv`0k5HZs`DwdUQ{>%j2V0LB`ub+8TY)DLI55$I4rvCuXNo#!Dq4pIQ z!-h>16X^iBg@RsSWYF2P>8I5zfbDDm*{w&P*6IoqzpZwq&8}-!;iT5;8%aeA{}n_Z zA1|_6<&4^;W5CiH+>=uRNK*Rd1)@spw3(a z8c&QIZP%{#U2%pBxu-qjU4`2v@AepLF#jI6z^>rL-%PzTOfv-!zzV-fi-wFLC4;Vc zm3O*)?gfE2`F3>v6sa{I4IIS>D;X<|-UA+nJx?+TOn;-Wk%<SVt*#W%z$HgHp@lvhb6&bP6Btf?XEfjyvvlf;D>VY{}@0gY;B6X0A|5cR2 z?0aQ4uuHqPdeJ33F`Uvl!oYYHFPdZOQ+l-I{|(>YWk(YI7AEkvV!ArGu9H--^q#M) z*zoc`rx`Pj!k$))hj@fwro;XHON$%o*Oz4lNX$s%Ge0hvD?DyMq>~*L&YD@XeOrU; zzt7RErw(*BzNdh5%}GLQh3w?D>6!s;E45`ZxMp{lC&nAhzxnxmK2w?X2G){p#xn*Q zQfs`q1QN&h_EgC4jdRp?U{ve;Uce%8IRk?AY^Cp29v~(%JFpMQbl@$yC;@ox=mkXW zQf_>9LbUCy06~vu$3>vhzKp;*@&IlnKyAoPxx}M`1Nlj_^h`+ZrWW_3zzj)eqlTF{tP?iA|4LXHEEN4W*b;+H1P>x?MA!JoX($_= z!kdj$4bfFpBK&jt0$+&Fx4Xpp%vfG3)(^*mFAzRwqNVSjQjPZ2+P~ycHtkY1vMQ4;)R;w8KOYh z-^L?dZf;yky*;{;|A5BApwx=qv`|r7Z=OF!?t0c#^ehZwP690=1r;y&ZH{OnAi@_m6ANMB(O0V%Q-n=#i zh~76YD5lw$(VUPzFza=8AK|X;8PP~92#cBaJLM*(_pI&w-YmGOFV*?VG-CuvS%6@) z`~L}WFbqhD-yb8j-5m=5mo@so-5cQcoFh4iGYTJG)PxGu1x00(v=ID7T|A`c7A4U# z+f7vq?v4{bH?hUR{vGoNPC^lO5U$_aBm`vkV zjPMq&bF?c1B?Ibfif|VcrAV~_HZX$8dQhypOUNnfHQUWLS9e;{cX{>O+9uGT1n-WP zJ1N^Y_P&=)N*p~AS*`N>WIfus+7N_Ye=Uv@3A-75XM4Ec8JK(DE`>Hs{6c8{!7vFQ z1yVb2ltB}(xeJLL83NSb!7Qqne~8X4W!|;Ys=h9$bi5H^n%sUr z8hlqr&#o%24Z>%7Nkx;l3Bn&-@OZS>WsVEm7TfCh`SbDsSgil~`7QMibP>Ur%zETFTs)>y!*%IXykbO&@`J)D-0irAYtAi-&tQhwl+ryXE z9ib|=r!UsP8pqfQ#h_jvS!*%T#_|GqrtTcKq>}|BG4>bD_-FVybVWG^*(DeUyRJQ( z)E!|ru)xxXeQmg`8JkUi^$I{1`~Tk30A@$n{@=z3Na7EDqwZy(!0T1kif25HTP3hgNIR`1w|@1J24v??9vD3JVTtR9_9#VSwC?f>o8HoCy& zs&OIdWG0RP4l{Qh4$Q;BK?*C;Ri;S6s_x2S)sGriNdXv7knoA`x)Np zNR4_?q!DWwyAS#FB)ZZi*WqbsjkYU6r47UM2^K0YEm6J~G)2;11#o&I8LB3H;^V@7*7$ zHnz3JqHGqlW%8Ez??F`oeubFI?0xW>3*@BBpakX!3gXVmBf(Fuq7Vka zQ9x!vCPG_OHA)G({Dd5?(QclVFDL(VDL{9fdEHwSK-b?(e~u6yO1J2;*tgYb9tkMz zr<-ja_G`>T1xUl7H=s{n{0hmHOZ(WCzwxjeWFJw5^@eK4+^q_Tnq`%73?-=GSG5cP z2qMxqzkgGcK0pOOBgN*yRUn&F*#hz^pAB6MAGh7SABwh-+#UIcp5gs!wd{7Tp>!Ue zHh%pa9p`lQN&{r#k{9Rk7V=KpLl*9?C|niN4DR`d;(a!fwz6f+SJ6RdzCx+^lz~B@ zWJyXIM`{=Pw&@aDxdP6NX{X8A6yxb}c>tmi-d*v0O^cVV+9MCNH^hKp%56_#NIHqh zgzf8}5ik(>_~{;Qj%7QOHA9>TgkzeEzkQaN#QbYU(%m%}8D#!9BP7NH{~e$k!2YLD z8`Rf8eIbnxFLa1{`<=w3>4`*s`2l0DS-8Sy=Y0gk4u#`;kWP#bI*aIGwm46Ws7-(B z`=a?emk_^S7S}TArU!19p1{e@6O-~E#|#9_%^`vivfWSZ+XGM-BfoJORsA=I@xLnU zJ5Wv@(Ur)Tr}>MDii&{4r_9!;sHpGq3s0*nnMa>VKE4_092}6WOLH0#ved2OVB1P# z)DzlD3qZ2xxS#qqbDE%-RgD(Bjj8LZO*ze!( zK5M%yX}Y-74tTL+v~#}N{Lvtz%>9{Y#JfO6jkeD4fnob&rv=SY(6GdiPVleS;7~T_*TyeuV2lK<-$6gN$BWa3myJK~{X=jQ!Rro0qO%%Qf2g1sf zU5|ScXKh+wF)w(Kh*}~J$OXcdKAe|)Hq{)nK~Vll!LkDFx~n9Ko^#B=4}+C$FIc)a zzcZ4(>|;9e*bl!WRp7%3UIxp6_}=DUB-NkPB$~5!<(p#7Qh(O_YF%NG)v%~tvG@xV zuL87JLgZ0SI3~+}w!jaH-fI~34wQVxRX>HSIr&59 zYf+k_NDZV%nf>GWyshR;>LD5t&eavM12aFPu&)%1>$b?s2A;ev$RwDAAL)TtHV7i$ zvboEK9AwcwIe;Ax&NrD~wnt8Qg0Rfh9y@0unZi7ihPb5PBs+7;W5Qx4Y`Va9IL5<% zQnr^`lLCr(-WwW%j<+xdmfN1UA3!;7@B6gnb7!vk9}!=>a2!Pnk;pGYOvmp&K0zb_t~zZgYoD4Doeih@JutRwd*0IGRaB-j4VvvH zggkQ>L+8t+WoX2uy1fJlhW-yZq!fcsY;g&d0DgllC$E&hG<@M8BaDvuGON&+quH+X3cwk|Io`K z3$e&5a)?pCCv&x5!uF>3s9oS;KY!%t+WXDLiT98SuyYm1U)wG$UPNQj#m&&rm;!$*Di! 
ze27w_T>0sW{NoYZx{)$Tz_OB^`YwJp{seFBR#)E`Hl_w#lmGWN4FSOG8gV-AJB1Sc z7PMQC7Bu-=@b^X|y6Zsd{r#QV6vHISkB~}lLp)h4D3Ydc^lw;Xl8|YeTSB3q1{9lQ z5%`_dQ|IBkF;@wq8x)ZiSs{qRYjyU=eoe1s8{0w;L{ExCto{?dI(ge`uS?i$%>k*c zEii4GV1Bn|uHZYQXJBy^rIYOdlJ6XG@OIQp*T+q4#pV$Yct>=>;D-+r7)9-`rJQ+e z53B{K`7E-&nke(t1Z0GwQhdwrY`#`2Do5K}rN!Sf>QL;TgJ!qbel})H_4_*32WDty z3ccYn-MiHso2(B8PzTq1?hTEI#u4b~6Lj=vz z{xV0|JyH`LGkg3!rLHi_;jT`bAE^=|-_Lk)UjeZ~Bh3=~gQuKny z`W$e}H$AYJc_Dqkkt+iSg>l{~Z!}?pfUqZP*qFH^bzGu3O^%YdDQrwo3ivY1lYSoO z%UwI;%V9hLBz&9fVPZu4^Z5?Gt`=deV$w0KNpt(BMN0~PkZ z@N>OMmNA%fK>wjVa^gh%XwglNrl^GNoZY^DAr#53tLiw~qtEW7c7>j)kj&{W$!}Rs zKXR88TDGesQ<&r@Cab7s=quwaWlaqmYm30SrBstGO7c5A|9df+ZR12B7ai0!@|*a! z@9?9Jk>5HBPE77)O02?>)4l;43<-h(8w-~&{Jtl(bqm6##ONTrX>?r}$9PgtL1+&b zpn9-2AL=o@G{9+ktUULqM;R>t>`Cru;-|%ho9cM4@$pbt7DzO6{cEP~Bm3LxDG#0{tlm6ylpUjS3s!ZS>#_kWIiD4qoH zIEej!Q+%g$;9&|)uljpMA3M4?TVn!ax61CZ0$JPv!D~UcUTZdOuis6F>%c~3?Lnvg zmw!j!ZM~9(@U~B8LSCv4-ygn_ZWmu3mOxj$3Ys0flQ#j5&cK(ce!Qb>OO~)0iXaGr zsz|+lIuDx%soz381dGjzMAmZ~9X3{MqU|~5vo5w)w44b1OJ1u-{;ijJdU%d~D2^_y zq@-Lk8>up`2Jaq>tl+KhHtQ#)V7qxY{BZjmpUq*lAr4}-8KoNd@>Y3woz9viq z{UNgSdU6tpv&v#XZhxM>MM9gp|^phdF6*C z*0NFfb_mB>M~v)&muV}Tjl+bqCJvsq8&R{spzFfhBtH|f-=zYg5>iVXp^$Bt+8z92 zhwd4!=c-!7NmgxmY^p9M3xEq9UwxO1p6|Z!gI~%Ce)^so?qWaqo!_&K%l8jVbq9j?Ib@l zDwK0xSk(;$r|mIbl|vQUYLWG0LQS)5&p7MV@yj8J z3D*MjSE~F-k<^41G(0Qvd;TTMK7Ls$HaH`m z$M0X6I_)q`sc2?4aXFdG z@`|1*s#_*(=*4U}-}t#A$5c%~(1;o^cx^TaOO5tm6WDz=vAfnols;(P+h?9@8qv31 z(S$o*EJ(c-WUYH6{+eWsB3UaGzJ{-&zmE&q( z6Jw)hN=t0u0JcWVkE%1`Or-=$Tcd;)6$6*cvxMG6+%*WA;qPlsGLs{3V8Fey2`1>* z)FE`P9=*Vc`Wa9V-lBZST*0xLIKRi`qxW@3g#ueFS|X1pR+5FG3AtLdq?c18UgUd) z_y!W(Ic^hA!xgW%pt&0{xL!elDSbp&t2(UHvMStTu4S@XG6&7q_JnZsO;iZr=}M~Y zu}=8-HB?38ZLedA63CIJ#6W($cvM-TX&XKA@w;w96?E@Gbv>VV)bLCiRqY%1n&Z@* z;OPrP3A{ra^?Dv+YAB&|aV7sF*EC^k{QwnJ*j>ypR4|TNiEEfXdlwT~j;O2KJ?uxn zMYSqi55F}|@}lLQw+;7#AvldL3)Kw4!gy$#t zyl=tikIpF($7VvCF>vxHL$+VNs9Mo>YS)ZLq*rD1-flQ?C%ZywmO_c1m(v7Cbg}NV zOW1Dc%?kDt^lfqiI?700=4DBf+FggoVytUARmodP_d(&;I_6XsdnvB zG6P#5PM`46iH^Fj%+N}(?6NHzNMZH$^hpuerz_Fr&C%LIOx)LFYy9 zUA!3&T22_dCNle1XT#BscGxyDwS6Q122u#V@(d&G`6=@3PU)y<^UNvxvF@+vHPtM$ zq2~u<8DS(odU;0ThUlK)UpnGWoDt{4+zP&A`IPMp#5bC!CASzXc@8E|Bk4&pTZ3kl z#8jT3)eFT+*Qpu#h~Y1#-Oi-R&0RgLgdty}pDd*sXwC$XZ-}T)g_E=3ZDTXCWFeRc#Q|Iq(zk=m09Rete1}n!!SF zIiA}#+ED2d-0TpWU<>&==54K7QzjYG80l+oa27{_`dy0OyBV)ry6GYqCnoI8y8+56 zx8N)TeT?O4B47fmm|L+?IY8Eo>wcJYnmSpLyR=yqfD-jaN`}R z*ofVXb2@niN9V9Dqw}ELydk~w9@ksGoI^t1?Q+f9OOkKQ@blm3Mr?Js=tA!AYGy)? 
z#Uam0Jmo4MEnWps^*t$*X6BmX?i5qgu-*|-L|s!#6+OPRE?=QxEG5p!aOw?s0LjW# zi0*m!{Lr7tmT-J%*K~Slm!N54Ln>8ESPwv$g$o1~00 zcD`CF6O#(Fa0uLJQ$=P{CU~&%9-IwX(%BY4@nE`yCIeL{M_g7a;lup)>7D$LLvX>Q zEw}p*hriFHoje>4ux*EZ5VM~YmJ_aU9XD|VopVUZ)m+zQMmT!jv7WYuG^AX8Jbouy zM^)TY7g!(jCi^;l6j#%ksQ*E!a-bf9s!>Llf2%ZgI_B`zcf^8*@--argMF0w*TNSe z_2U~+Ny5gE+?WmHRKP_nYQBhdGuT*zB+$sSHx{dtm|a{_sirsjHR199>*arEs9LoA zvY-iulk%a##!9L2(tJtY=4GxsPbE}ZgFMbW0|^%ol7$iM5Bdee5^@BF-+JvfjOZ%Q zF2v$kt{^+@*U+$J?#()U3qO4?GRzlmfwN1(3y z0k0*GuA*uOchF}oD_n0rkdn`ax9o5Bo2cD4SMiK9h&||CanSJySNfc&&+VlTr5W$H z;3-KvT>d4}xZtB(XXEzV1jE@=_P3)c6^`hdEbsGVZY6zCAGmzE${)fw*}-yueLGZ+Oi3Qn(pCM5fTZAm&^o}5hZ)MB-5 zSBq!s9NC6c?6T#yJnR|xd)oq%$Ogm9o}?Lmj(&ENW#%st3uL);;o&Z8(^kkW-4b`|ZEL7`5;mEDT$Xf4^!4~5EHroV27 zTg_8QB2^q`aEsOR9dI+2qb`cEe-73Z6UtSYnTM6$f+XMe9-meagYrT%VrNsmH1^Rr zIn>?_@Y$xUKlAoTAUim2eu<4*EhekqeI=i>QgZkF#GHE7>>Z;Pr``)T8&rUMERi`R zk^K$N@G@w^aameY^35c^3A09(mz zldA>n?Y`I_xF-!Y%P(|)8O0Y%T>IgW7MhQ*v}qq$x~Ir4;)>0DwsD3#fUN2D@LpyQUXzz;Z-ass5fWpy))gJ0x-5ENY(Rb!zx+YZ7!cboKSLF{EKWRgi4L-%jpV1>+_b zH_qbO(ggW2CTqXNipPFrd!SX`4e_ma=GXe;#l6fZBiNhQNB2VMzj{t_fNl^B=>ngR zkgU=lE$x-8_s>UQUvyg#0;fnp=uVtndwI?lC zq@8bwY?DN#lawOtHjOt=iV@n2%?T{+(YX^_mG&FNq~zm=?olSF^4gu%*Z>C(FGJ4XkYN`NjcNl0?#0j(~uCG@4E zt(kJH_#1sE>gRYHM<;j3jk5OHZ}Ai>k#hDwb`MRf28h zA)D8O;&Shn*=Fl(shCKGq((nrr)h6=3*1rRlWmi|NxS63-}8~M9p~F$*uwGLvxf3~ z4$K(FwpG&J{rsz4LV3qiaQ^_SdZ2C$mI1OUm#|f8;v6!sN;-OuZ__6C=_ceD=T&5K znMN6h=2qoST?%e19*pjnEuS0;mEnnh+jy6 z@ZbG5Ra+$YHhE^ZV+vAq(>1qK3^OTB8Ir&w{Jc*}o> zeg2=bqTTNrKh{@E#?oiY-a z`=-x5(_?o!BppK&=(2<{@7+3_4N;_WcYcYt}aajPLzgJt7B^bvm z5Uv=q5d&Y?K?lZ&Tr=*Vz7LQR(Ye6%=Rva{Uy_Au9fm~miL$SvNv9&|kX&J#c=QPa zfBmh-H7Y~*(g;;NAjq(6L>|F>QSs3TdzzB9<2ImDr|O0sC!COh0*n)U`G?}U4HUirLi>sG-X!HX{DP6p~a{X*c6FPna9O1Hkm#m zUBKA-XJkf;2`e=nPZ>qEWt?9>aXh6OTy8?xq*7vjMef_uYqz@nDMOEZY^dW>7!h0` z;q(>*k2@$y$U9|7hrv+jG6Sf_a|VIwV$r>2u2msQi=RzhdP;rz;saj0mPvmPtL!jm z1xEnS;_r)v9Z#r(H}dr1XkV~2icS?Hu@f;zK&aoEzHoPkE3N*Ocz92dr|3=ZJ0Cv; zY)OaP@_EF03HRI>3Nh>wet7$_WXUKzz2JRC_~`;6dihPbv58d>IBGjmBw-yS`=cDQ=U1M#utta%_<>nevC}X zg56{ORLJkc^0L3-8Y+UamRn2)4`~?k$Y3RQy*60N`y?IS2CuXzN4ihDIRs<83`p5t z+SyKRp(*?f*&-B}i_(zQfWw{yi5v{GDL9*Qt!B?q?dDaDHTX-_RU%th&ogg9k~)Z)Yfe zmh7o=pO%4d%lyq(Z@)nQtxETGY@PRw9Ulj!hQw}dkV(c(QupCs2nwR$VF%7hvqN7l zUij{hX3dZ+O6#929A9Gh5YFFp&SToAhW&}36kf?rT7P2$!oM=teRS-d`8dQ(>PnZX z`*A5L+Vs?G`y#57dnbv13X>_33Yt+znQF2^g2~C9SC2>R4YH3qH2G$^aPxdN`4)XG z!6(zuh&N+cqH<0DCxZ0>jBZ|kypQSdZ@R_z7?$`}lcat7by(sP@bOHqSJ-Z&I%EF5 zl8pHy_Fm#hC&T>ksH<>dx(D7{O|zkxeCG8&M%yG6;o*jkSqpxP+qTea443u`mbDM! 
zL06LEg1QSG9cd;wp32UVz2p0O2rZZ|rAqlWt2c>QKN8&BDqtWu{7+*oieF7Yw)(4G zwt)3d(4{_<+Zy`Nao$-@!BECj(~8xKlJdnmE^^G78~84-8tHcO+P8cdx1Bbde>jmh zqa%)^ITw6t9|f8>6QwKIWiKg6Og+>;AJD5qblx=Ye%X1(V)9E|Rzmr?2zyDk!zMuK zCrYY2;hm-!?46)bKOSjdxht(21rqFDX>ncp`y~fJcT*Y<`xl4DXFoQ`Q!GSz&*`*s zKBJoa2C$Qe<1a*vo7pktK6&1;K(27LZ$%PEkNwemVZ^ExMub6mYuAbuOn=qY?oTFf z2rKi>i+s>OKj#qhCR5=s$^9~pbZDx33 zuZqq{$tw32AVeM*rbkcx{n|K~l07^6oTxWTM|+CsYiIFhqlTAvbm~q>MXa>n0@8c( zxsFe7d7>;>dXwouSxtl!1V|0GKZ8Ekse^G-Lm`uYxjTNB&JsFFQz0!%>L~LsQi{re zjbWoW!s_A>XM;D~NY)MC2BwR+28Jc5FZy7EF_JH@@LkLF*cZH~JoqfAud(QCwriUw zLjvuI6{D@x*HkQcVDrI35b!pN>7@= z!6PNc(PIriv29@I@S{G!$otA!w~HB=r)65U=S+&AY+r^VrO3uJyqBZjJVDs8!yhxFmw4d8Gu-;kI!u0RMgXEo%^ zVk97Wm^Z1(uB+(h=R2Vl9$@-5Z>1MNTH3y9qtrYNg(9I#M~hsBvEP6TY=AXTJ>PWs za?2qaET)Q@uIew9`e6AgPpV(Arurxj@J`~lAW+tuphz3b+t%;A#kd;36Ijr&uKc4V zG=DzA(1q$_Y z82>w)Y-D;+mFK3s@;o@$A_>Yz;fcR|gszNZmx8K+iSaQFZfqMq_gG~^u5pYVbubrZ zG;tMC$~|BYBfc?^G+`7S`0*a#AsSl1Vz6_d%z-36C7%pXk*55B9zL#4SY^zy3SxDi z1x1Rzk$wMf+6L%~|FtwLSIhnHof)MWsL5jw6?t=6*6pO=XGx0)XjtN@i@MA!Bx`Z` zF+8fa;xI=w)z4#{Kc%KM^3T^Z9X?-l<%l8fSTH}x9_pBcV^aWu)>aXNGD2@Jq0X2c zbrsrpZ&k`Y)pub^;&RHX!n5^z$R+Ii+;eVWNhO)KtJnm^JJ8yiPBU*O<+1wUiwQ7$ z#s(NkgD0S*E9cAAZB(t|5pCP`E7i1!&Uc<9+mQKCN=qhSCD8MUOu2NCTVWKerd;QL zQ1Y?leG~32PzfF9zvOk!c%J4k=hB<}g7!P5>|bbN`^?5(^XCjs^_1jg(y)gc9&FGU z9+lvG=hZP_)g^5auxUxa@+pdGm&h-Ww6fduOso+H=7z_!F)E-Obt zC;0;6O^;@VLeW#jb=^H(*aR1)P=ytYK4L(m-hewjJ=y!D#$XyXK&Pyy9Ti|*k88x4 zs{MWG8C|?yg;^Oh6ZyQ!#MkawTu1b`L|5;A8J?okVLX0GnAA6$S@A(3zHVm|5;%R8 z?1f#dZ66To?vSkfLs=CQZkP9tX|gw|t?fm~g!ncKL1ZZsJ;v&WJ&LHbp7MD_$8yh} z$sz|wv|#W={YT?Fu_V>HE;Cb9(y^S}P$V$m=wbFD6g^(~SZO`^Scf`dE~6N(s`iq+ z(d0062KT*}+63_-9qSX;zCk^5^xi_~`fmXjTu@FZXGkUMc94&8sxy zm9JUOcYf&ub24Ri18zrrj(PVE7aug@;nc{1V1DR@NGOLRb!e~5ju3>mMq1}R!RAOt zI5U_MZen@DNXJbqH=&{bMg9KjKhDk&3`F%`eJk4i*&EM9KD+0;LyePh<{c`PkQAy z%I$4_(%}O(U6}Z5=aUe$Q!zw;;o|vw_u|({KHD1^|21U)Wm}&CpGQAQD(fEAMzObr zUr$_}?1%qjD1d?>KZSSVw9XW5L2=no0JaWms^xTvafGq1*Xz`qVKcoIhgVbnx;K6=L}3sA9-qCoNclb_nD4slPUWpUr5ag zj)Vj?>Zuu7^O}=y9*Ku&G}dblk{O(^_7GwZ1YRYsDq^JnbseD>id zq_%i^uJ>)U*OZPJaL7rf2Xmwy2i~#RhI$lS)ei!6np#ltCP!Ry5!TLZ30$@hx&Nl3 z>SI7G&<~R7tf3QV&kjX)w*`_Tiqs>kbAE|S{u%FHOVAHnpR7Lp5_=#rYk!x(0J;zT z>J$894ac!doZw7kv+HA3iEX?x1^k0J2|uSHRG>Z}qm7a2ZL(dW$#T>g8Xe4^w>(Y_ zPvexE4xhZzbps>LMcFYwR^Q950sd3%z-((*@|XJycRpn+YoGIYFE#D;EUgH9n<4u7 zCVaOu{KJesNa_P0vtMWgRE%9##e_x>nz= z;&5?M)j*;K++^c<0Flc3l`By0wQ1lr zMJItQVdlK_=qGFu>=uS%0U{hc+?0@HRat<(^P+VCRAFSqOoc#X&r}8VcMZ)`-sZpi zYzuF?-2Y;pV?*Vk?DT(^7?Z9s1L2!u+^o=r9j!SJy8w97FE>5Yhuz`5Iw%s$ z6LD`F%LXt}S|$C5fuie7K98$;AKK6v{MYcv`rV)m(aJPB47ZhOU(NTH0LZ!4J6V|+ zO|C@x$3tlVcqr77zLCQc0CE*3NbNnMHs>7}l-Id5jvcKA?K7?OnN9zM73<6ulkp2b z1?{U92fU;+$I->-l*K@GUNcJQu`|oKCWP(n???kn3ut*eBXleOU>gk4He-cLkgB7>6#pV4m{Y!RT^b%v#5v2U9!$K=W2@_o7}1#TSksn z;|*B;I;BJ5+^S{>$J0&PG>PtKCSBzwr-rldI*35xmP6e&B2KPlG`8m)da7c`5#O14 zf3->IUuYxBJp`?K_VU8S5+77g3J@F~IDY7qvo#NdX`+i+%;Mq%OrBjP!8nOjtbMW9 z^%ZsrY+&W|Y=`P+oSr_+O!Ie%Hv~OjeZ$7~1wQ_6Csq1Z5<;2ofHW9RJk1W@fKpSj zX;cnx>81@f1vFGGi{vP_|1*9dc6+s2$sBR4bY`GKeoIDYO{@IL9&y~VVwxPD_tUd# zLw`u@rf&Cw#JnkdtYpUV{2g8*r*1gQhz3;QeZHpkhhKySm(QNl&LtgYMfYa{S7z}= zCEKCF32$?aNZ3mFpf47xss;S9>jGMt;TKxdz&A$$>YH5U5eM(7iG+UX5~%d-#@5+R zHSUs;-sUR&;(IBV^LW1LB1iGrYF9(2_5-Rd$1v3IG{rbT!9IJI)WzUa%emP-Gn%qD z6>3zT>0w$>#ixoyBPC2_o(jLn1+t;}HwaO`a$1I`vfFh4-)ZnyIu_Lp*E~z??i-i2 zSevZwV(3Px&?Aj<@frPNw(n43e&c+<_6@Y=!a)9AJZ5V6luL6rt*$3wMf%-|K@YJ} zUzCclalXTse&sFG%@a&C_FN&XqUuZ<=u6URljk$om*L+(0xU-v5?}wMHF7~$I?$!; z?R8$~JAf!2_<8vt;|6=qYv$L7VF~9pq8Vt@HOT9J?T@kYja!G0hFH|*;Atppwns{| zHMV^c?>N|E65H5v@D71Hvu7H}G-cdW3&|L(57}9y*%y3xZldJj3yjp~_ip@136~?8 
za)jRg*AwLb%5nr1?9Cea`(d{6E+(MS9hUG(J_>Wk+;LZzdk2syb47zR&z6prcB@rp z;FU*C2pZEcMUM|>JajLs>VQ_Ch8x#d3X1g2rh4I8{!iGPJ&>!AumjR znf9P#6YIt7KLkC5tyo8<&npNmvR)@u=Al2H7f~kE{MKHUG#;#;JxUJ~s^01bXze{o z1Wc)z0mY?vuiWg|`FAGW>%C<<aPdX^a*p#OmT zB7wQ-C#op4I??DXZ2Mhp)~MqPOsen3HQ(qO%R{;(+cZuK+vt}}$G0ejT4yHT*}B}H zA+}|enb&#kNma%Q+bauUzTR-Kwh)~#p4bwpWHpV(25Z%GjGr$z@ptp5~c;6XMwLzG21-b??w}X z&Z${kbcMoKRyw>$EA}Rbm)b03j$j%y(eb5-Jny|M^46c~FXkIDyP`LiajEF5e~vx8 zzujP`-0Xx2|J-6uwg!drh*TM?%E+1Ye!M^6`XD&U!sPR%897r%F>|=>*B|!VH4yd* zbED*W3E)|&v*VEMkk{|o{ybEGLR>W0M-k{b>cXA(h4=SqsX%>Y=Z`B?xk8UG z5%#FM)#N*}LTd_$`iGxJH5cRe_(fmMZ8b!JQecKYVdJ;7(Jk%af+FVq&OhS7YOhia zuL6duXw$A-LAP0nSKq$BKqN4i5fB0&ca)dGOcA4D6fT|%{oyYqk7+QqPjeJfeDBjy z>>_g1q2&xd$M70|4x9nXBJ^h$sncOqahq3DkTPCJS{7`$eY5cDBd2c4a)D0jQ_@#{vLey+y;nFm!c(-+xaujkk# z621IFX3BVCkNAk9+-1GJz0r1m^xkJWy{{}mYP2Rh9oUT_sfuHgG)M@t#X1IlmHN+s z-;=X7zvg+3+`W|LKfgg3!jMm0Ot}_4AiwjS2mhMz?1S5FxHk8Nz#zuG$wgebGwVip z?z5~gwtE(O0nxIlnaPPI(q5_UHB8H{efP9ejSWSfqh2WpnURNWSbfvI|Ew(}uX^xH z5AGROloI}EWoW3QdxqM}Y0?E^*b`JK{y>C%doqr!orfcrcHzrtd{=ayB_v9z6yE+6 z;2+46Blb%-pqR4~yAM@|s!!*Z`H5gQr%zg+?1(5!n7)hiZrEiV@_S}E{cvw1zWZ=b z-UG>~He7k}Vo&j$5j9XKl){1eG1(fIG2L8eYfJ3j(ac@$#B*-+?+cqB9v;3P`@|jn zI`I_maVo9@!{!}S$?My<&&DL(2i;8Jx;#`uRP(iE#@^zdk`Gy*+Aq0s^v5e@xS}mm zw!qZ5<_qQ^PX_9iPHD4d8eZm7XSpOEe26S=cX1%!u_B6cRrx9WNF(2K+ipRPLBDK% zL2EPpKK*uvea<}Tu6ed4D>vm9LdQdp0pIty(OZx}{B_HALhEHr+S+$t&B-X)mAdMD zFB$}AB*o`G@+6+>YG30~Zp~k8zkfL2T>@Ii@L;;9=5uPJ!LU_gnJdeKLmmR}+ejt` z5-k1KN9l0X+=o?5#SKCXDCMC1lbwsZJu3p0&dTg8ytO7~sc-H4C{aq>&<*4EH&XIa zEmcD|%}z$|2eJ; z2952czuyzzb;oIEeM(SNdbNxP=lX^HH*%lIj$!5{@l{Uogc%IJmSUp#o_&3&4Ahxo z{o>msMvPKqpnsH@*)x6gP=Qoa;dMs#ZM{d@rz8pm%^?)SZCjxkfG`+WLlF+`LWyP0 ztJuAwg__pI?iRE_fcRJn`67qI?B3lPqPj7OuCQum)mJ1;fbg~uPuQU{pY;==m&u!4 zLLTgsgpk|Ks+C8fw`(PUS2spwW;1U8iF@&0z{>0<;o#9rLX5Lr3YIgMgX41#B1NBu z>th|Khqpu9IxxVlB{N0fmt*^6C)drd+w43 zFAekpFcY&*Mh(nNdQ)=K@5gl&sMQDHzTAXd%@rBDkZibkWLKNPKkg~!E~|`%t}Vp7 zA>@3U)e^MnB{R`9CY&+P=$z<0!C*2zK3S|fK9A42qPnAM_-S$PAs@C5;_|HlW8!uAh;H5KGsc`~ihE;x&6t1&8>dIe`X&m5)Bx zf`XTOVqAA49sVko{RTM_i0V|ekg?T!mHFj}dS;;}<5O{eS^mZ%h0(}{j4sWUUXN5r zel59x{7YFOP0_nzHMORqKQFZH1#{z*ys4Nmy$a*1-iCnAAPzV=Ne#aAno=#a(CQR(* zhRHh#TrGwqgo7r5WHR<4-CZAMssqsjK}pH8Et65O3Ek#K(j)L{%yJ>LjM&deUXjSZ zJdUuvdt~}po=moj=?eYv{JnQ47d5Pj3}|iyY&BmFNm%AZQx^*@h$f5p!f(o#VysH$k*;+sMQf3h-TDa< zKljKXztZ$$x_B*9rFhAEPel2w^n21gWp2RG-)F0@oIyU6S@kbZuch}#_=#JvgpNN><6~VF_DZg*;y^E?=`1R9 z7uF*5b(Lo<)t9OCDI9(nJ5X|z(V$$K95=Y`EyI@-VZn}${}9{jxfWas_+81MVJ>Fa z2GPA`{!}~_^DFK*$P;d`^UVesIctn33ywklaISOM z1pbbkJ!bJv-;^F&8!}sx+$bIEU9WL)ZV;5q=kjQsHH>wR zLlK}%o6G3hLV}-8I7sG+#S;y>L-(+F(0Ebnu&eYqYi9ybN=PYGbP$_JHhbueA@VSZ zFweCeLAS4!ws(o!WZfTlO{lAx+>+2Ic?PN!z50c5IZ60tw=PG39kK-Wdbi{pxG|3z zcy#m9byLdltsj?iM2)yI)3~B}Kdw#@H78pJ9Zheem$9?p5yz$a8HJ9NwpFoEQnFJg z%EHkdgOT%6jBhp%lBHVwSpvDU=eL+!ifFSa3U;mZY!Zm2K5cLVI}PcsnLe=70#@c9 zA(Avn(HS$lpJsPM5{?UhnP};+OqA!F<77eof*?p+DE}*aVK_~Azb99c$p$1a2Ei0k zbFw9Sp9@0x0j|!E@r+O@^Su~+X&uE2|IL%wM$IXRxOq0uCWDoqD2zW6sW=ZbMA(6( z!QHjys^3Y2ST2NEe1~`2Ogw9dz;wcBha_C3vD_4rzoCMSR7N%8ORL_Lk*B``y^)|T z(mE#9w-uNGCm3%p9*^`I62_ZnuQ7R1MpaV>eZ4}zSI4|riBGeWDAH})H@G4uK0=?T z_~ZF;vp`DgDZ063zsNJg^+>mLdWRpiKd}Xd@hz&Y9QnD<+d|%rd|!gy@$ymRI`h8# zLXwj5wBys*z14w>YRU|%p9NN_`zYi6Wm8@(_y&h3s7-bXx(oh+A9WVuoslp))EcmU z)k%%icKG5MMmw^jGs5WxkhaSSR9!*H=og9RPKMb;DS}!ju@v6{KsQHa){{+=0;Y&P z-Y6F?+g(pUu$)Wl%|`m>%bH6}ap}>hTZ^@h2CtBC1_)mm(sVdlJ$g;$l}$xOh@zzD zia@=bUj2N?p*SM7ck~@!VBg)5X8hb}d>R9Djif?l$$AY&c7(tm2AXG{NBg(QU(J8s z&zl=!FPseM_Zw~JKTVqVda|APjOnv>Th~Y+lDDXrV8}*zZZ+qsKtSmaK2N2L`(M;R 
zQ_4nL)Q?!-VB?v&L9qitVfU5qUn<;DsYfa;$2A}2xb1?r%U7g8dMD|DqJ0**%`@ZcgSGDQrT@h9I-$;G=Yv1n_9iu0)htY*+d=Z~qQ4 z-ye&b=Ax;`nHZvmWVygPVxIlEp6{OqgY1R3-kUCPc_Kv*o8NMF+$QTs?{AwL0eTzI zjswEyr=(;pFCgi@Z1zWe#6#7a_4t;a(a9vl+~{$RD83eRa4XZ6;FxAKjMNLH7xPVM ztCFhE7izH`kLyA;J-^^s*!56FJv8_{*KuKFw2dsrPQ9+{ZroG&djF`x`UBo7qThQT z_gP3!VH0h?<9bZLO9?lE%qS|7aTMhuY~L0!_aeL*4btBdn9}a@tI`d42K9%swbJL@ zTnHeC%Kegjn~^n(TtG1oH%$GDr{lxVw2@j*yHVpKezJTs#x~#O3toteMlU?bz#u;n zxLdNZR=Z5#=Q#f3Y1n=Iy z+mc)DO?w~1?JM_Nr5^g~lz!n3#y@SPuN15iQ0%To48#^v%kCEhWqRbH8GV?%Y1do1 zIOt_|I9!##yqKTn)R5%<$5q?ya9{>)dnvw_qU%9DPYIZZs=;5ml~XY9yq{D5qfsyA zD6A2AiB5X5zZ3uCImKu(D5;W)9@s6oZYF|uyL;wnGjqiT7f%d9V-<)?5(MC{cQx>& z!kPSP=B3=fl>uFP)Ft+|&-JE(gyy0U74(i`FL!Lihj+IqJ>Mrl{&|EBVTPnX##Rr8xqqxLYvptmRDLPbA~vh}w|>E`0neVJ^^4UxZC zO5SuQJ&;HSb>aWloz#_s!m%O3tu;LK$~3ZcaP8@@ax;7(kjqYMDo4G^Mys$PeB~S1 zzfXKI+z9?}zW-60Z)J(zC3z;mg>P)XKkVA!sK`1Hr^1m$y%w z06yA@_6P9+}c_k+;&zP;yf8_9J4{(Y=~;6KYzN?He7jZwPH`+u<7Jb106&@!YGDT6gfO7hqLijIz!txE^dxU@M5MbLb8&gb^$r+-pccx z5wCZz^=|djzBhHX_5D80K1HlZ^`dnm!cNTjRC-1Z#8q!pS3;*V#n&T4pN=Z^cD|wIT&E4BQdl;r&Kso+M~^|o9YaSrdWKqdOgGE+l54+(6XRy~=Hi$yH%s-HX=p1~wYL&91MDYo*KX2--8o+mHY1m5 zs!rSg)zXJtfueS{fhf1T@~|h4Ge9uQqNUaLFjt-}DAV(b&H3Dp{}0m3NNmo(%B((s&nPsZ&byr}AN-H#>EPj{8o2lbC zL^>@4ZC2JQweme6MdTm|Jjk#oGh`nM2?vOlqF=3r~v_x1_nCqq94f_9Z3z$jx! z7cn!X7@7+16Q%=+xth7IoGBoYfZ;!Uru0;5Ik3C|Dvl4#3ADhcpM?(P$G3QC=13ga zhjOOE{q0=-&u@bl&Hv7V#UJr*4zvi77ZPrN|Nb5HJAymh*qYfq3q7zZ2Zl5Gw>_Q~ z5ox|~+%}(Wxzl_lpU3HZ*(Uu}XZ{0ezLqmQQ;m^Ec&o7@jzra^=43iEN-}&zrY(%m zke9zfutYz)eYp^JZ=@hUdGts8!Y=<46my>JCMfEg{P7@HKsweZcQ1LER^OeYqyIzb15z?{m3K3iXi_ zvgktfF2NliTZT7=iLt&wd}N-Tf(t!U5&+OF<9T#K4xou1_;gY29Gi2}nMIG-ja{iv z?iVnJ-RwbH^(`p)$>?!a1NNvG2{blT8R#zQdpOQ1LdfnxZXZ{yJ=+nlKEzFZ@0 zSc?O}%h%*<^6xUid45cQIwbRI3$51<>~t@ge&+d7lkAh}@@T2=OzBFIQGNf3A(wiAV>H}=uN;1CD%*9er# zc4l>R#N5e16g8z_)+Qb&ajp5hsf87zF0;^g*JaN&-R!P zKNMaQg8_kDxggB|K!*Vck$F2VZHt!&ShcUffCr(H+2LP>QbV9jez(orgSwdwPniBOIUi|ye+iNnZ)t!l%njOQaBBfa;uk`TdJX#(eJ+AK>E2ZO;P|G)uiHD|R6HHxB|IU)sreKQj27tFvo zlrFR}?m)8^}Z~CVbFrK+d1cUmscZ$6Jhr3m>Oye_&hqQmj~ir%;io&%9rnO|>95l@44sUV>CJw?JCGmOYrG_t#>{@4-SkNJA@Dp<)csZT$vDzgZI|I!5pDry%gMl6*_*KREgiOuL%u7)Etms}O zZ7>}qg;rOI+;W(WVGH}Wm96J(xaCG^eF4guQhn31nkZM5^VdRv^rZaJS7tn$MABC< z+nb+N=08!2AjOhoL?RFdJRiY5z1m`2*}1@YD|@o)b!@DOuaV-`t)rxLZUwErDKmc? z!@RMV$d@Vx#3^GSNwR_6iESXMH9^RedXT1dsl2K-UB+vIc)s&-cMpnd7;#?1CKP>y z?2cA*4CCY2a-MGk8S-V9%N3sY#Lu(`L3VjpG;JvruyC)G2Goiuu>2>I+Ec3 zrQvdXyH=$=|Vqx7>^=b51EvbTBx5OCX!M^OuZK!KP zi7rosF0mc%)%t(CWvfE+y%SeoL~=#fcwpbOCe5jDN`ypFa@9c8d|W-?t(*Rpmu{;; zoC;in6s;rl$7GDs9$Trdp`>_J(0M@uZmh}#ch&OOs>i%*5_;zBL+@n&geB_x4rH1` zmSvci#cSL>F%5Baja&q-SO>|GGBWfx&3ubrq`YZv0D(U=zbn01b4RNeS>SB{4HNJh&|zeko5BR`|L| zhly5AfmiVQNONzXzslJfBl2ux$hq}k4&k(*wS1&%*oQRK3xn3oerc518xKRZjB7h` z^KdZ**P#$pLva;Is96aD##+YL!D>0$LD)Y}Fp;SuD_L;!WC2ySZ60T;L}6~#7CUe1 zSE<;CxOstEM7rmFIyyfDPe*g_ma{hI1e7y;XSQhF+j;Sf?=*ab2nk=!pxG0xd&y7@ zUl&5jL6YrYXZ79L^J+T_DrPz2v85F}z$qYHwvJ%MO3!mhdBdrQP?aNL6Xz2YH_zN@ z9PC5=xd+TkONq>TXq;ye+?I1c;s@KLUTt6#g?czl?f3(?P|CqD7xq!J%A3|Wp$)!G zdb%x_6`~mk!_A_ZI2eG#xhm&eX-Hn2D7ebl>X^UC4xy&!vtO7+Yn3z~*EJ@Y(PZOR zfNhxe^tx^w;O^xN{-{p2AE9smba+HMC%OGWx2E;goAcZ+TWZt2uKOcvUYF>CrZt1! 
zzW?#RveyhQAl3J3zYq3Kk)9zDg=_}GuPrab?+acL5HQVPN%}T-*-T*{^P`Hl4Fv-g zk)5-QmIn$sS1!)CsvJ&s^o@|?yaytP*J*<-49~Dw&+He(g<$FRN+?ec3DPo~D5^I6~N48v)9X@(De znw2A{F3T60oFpGy&m5m%4R5s@V88+p#ctzQkr%;R4QD^adqtaVsEAJ&e8XOk34@MOlw15H5hhOm@&W>U!<3xR=FxgsimSvp@DQ>$&dU zdx0B|5Fgss0SeiR9IGLTC+WLzzX?MxvaVQ;Od=XP6TBcdtS_rP2BKQWTkGDdCTEom zUb_vo3&%&$k0CuC#NF}gYkl6R3Al_Scg!A6zlcI{1tQ-$n=|M3H1hGB-se zS9&rn%?T{FkOZ~FGJg_UIdR=SaQ1}t8!C@GBZw3kVoNGH?=$y1P-7W((c63Z+P7ZI zkqX^i!A?2%X!bq}ns9hK7XNwVpFluq1gXz?#+1TA1tFT6>)NUq(B_d-xsjQA`S*yY zr(xjN>Wv<|pF;KH$CCCu+Gudy?Wj`QtS_E#AX!?? zu;1?{z2;)dQc^*Ze;L2O3$)DM03!?#-m7i@s=d$D^GnP=S4eOe^Y~%=I^x*Q#+>)s zoJ%7-_T0Nd2+R2mm4s9i>d;m2x%Q@g$$VmxJU(wF-$^TO_tWVMJ#$OtvEA=kUco0n zU>BPgXfKzS?f0N!-v;<7w5buT?Q3n&Q{zcIuZl33Z>)q$VFzx37~v{pkHd2@z3V-} z*6ZbH6YIBhI7KAzkltP8YqE6iY?Bt}3%Sdcs*Y>fiN!aSgxly(2$@lc8d&=V7hheu zHD6v|`79Z3SU~bZZ5mqJkz87#S=R?HbzO2z*YJ$hz8&vy7a_d!4_OWNf}Ix)Z5W6> z`>kR$k&y{&gYzHGZW#)mN^UzFrDP%3z2AkO;<8qWZ!3NfXuVj}G?ouu{p!qmOu)3R z7xQ>MJXmO6a)CLh%*JF&l}iL+5QZXv`2MXlMda6`Jfqo+|YKiFQ#O$JJ%#steGh%fX~qX&p?MRF-4pB z@VGv~&FsO>vUvEu(CJCrb;-epJNj;$>93K_Pul7nw^^I|vo{;kcB3_mDK4Ue4!OoI z_m)WSjazk;Ffa=JtiXf~(4t^i4_WUlm?^s|9H|ui&Hn z77E4JnyvOxP=^PwdzlFhXM7*|gIB6!ku- zNHHO1yCdMbcZ<~f^f8!@X(Wje7OJqM;K|?Ww(~TR1E|J7KkP|4)$a~YQyS|nV`X`P zuLXVznQJRQv0|~$#Y7P&EIQXqh5yZz-)tniaDA>X=J%7^E+z2%^RrkX><6B4WR6T_ z{(|v^7B@%y`bnwoY8Vu%Vx!U!4%Z$z`u&X&5xx_6G+ki7TS~QGZ7V&6d~XDUiXx)i zKi-qwd~4)>R|cO8Kny$N(6W)=qh2;9cOm1GaN?gcZcbuP%C4?*;%5l#8W`CH@l^1j z2$wl~TlLWNd?kX76FkYuVYQQjjnr6)&9Bd`{$$kAx8v9|vu!^JUSB%SG&+d=ggm)= z3(RIIf^3>Y*!#iK{J*c?zTL=^SpAN7PPT-BFB4n!<)d!6;F&Y94DtBgJXjK^ zgxbGA`%i{&Os7B@-IM{H|NPS!1^5r zOFW~Ey6zy1`Z%f#2t&x9oXQyL~izv}vC#?|NI`@9-_0QjpEC`vq9f8Y$D-GzW4 zKX*8l^83K&wV?$(QN~?tBj=Tn^aT=uzu}vY?N}~Wg?LjyU^m+oczrd0t$ua`1AjjZ zt2gbRk^e&*6T{~FVGH*dej`-<&#Am4yj2k#jO+)}f3aXY&)CgLi*46eRY!viI5q(= z?N>FjZ*n2wP~&PX2Ap3XCh}=r?Z`C5e|}e1<=yP6OlcQ)Fc(KTQdNyOgl0JVIeDAi z*T%OtjP5Fh6S<&-_#H{he}iJk@cLZS^RS*{eDefz+?zuG-rCcUP)Vh@KBh*aC8{&nqT z83g$$`i_52WPpuOW#-l){d1(PZ#QP1*nh#8e;q`tx#4GnKC;(+#?_Ka2>0lI>X`o@ zC>SQR9z{A9m>4?mbqf53N0K{AZQ-Ztu|Z4_EKx*=WNsvm_}`m?!H|Qlr+YcSpnR7S zMS(}j=D2ws-z>n%L10NNa=VHek^w*6Gx41JiKr(4V*bQ2wTyc_CQh>d4U1u$_SVY-x!)&C z6QsP{8~6`1G3ev|-*tl&MowjrwZw zX|-1QVjLB>(BI72L?w@TfNVJ&7AgD(G6*)=nq;&KDPf!joI|{X=|3cs|A!N=F9)v| zK&wdnFMOqND#f&ayz?K_$ziv8?9JA|LNu z0Up#ilWwTW_5otl0aMJ(7~iIh8vX2*S6Q z{vOCS7L~t|Di!<}9J7(oNS^cW?=6bWd~ihKwx7egX1q}*i;I$4DYDOV-~3w{4u26| z)-{`Z%cuE|Dv6*Q;gKaC*=6Z-6`y8mjEpe&7jzrS3DC~J&%=Z4!Me|L3&4;Oafp;X zdlP|w5$?!6tHX0-lC_%O>}6HuYv7T;&>fCjU}#d^m4QeQzE>71mG!TK>8SuZQ0fqo zmG1vVr`3`lIU;Sb{{ve-D@;R&7f|D`U|ZSZYqs+zFw=J&e`*_~by$(gZo=lLfchhL z+pZ0$Z=`1kbK9j#Gu_W%Qg5XjosxxA$J#^n*=aI9F zPFwzvhvxP9WSi`J*B`m(ird2OkfVm{v#je_DpKxqz_0#omn+OO*KEJE?>Y=l7 znp0xYSxeL|_wBT?jh1w(I;AyX{+HrL7Qj=M@Lo-|7#be-JvexkkzG^s=;_m^(sFWQ z>gtI-Ymkmh{^CgI*RPkBpy9?WiU1u%+3(IrH4VK2=Py^eEp2R6OLSFwdV0XL^MJ#( z!H529H(F8Vy~mSWcSTs%v~O)z4*cMoMHY5^`LevJpZiut!_fG6U|Jf1!&Uxd!f%Z} zj6!2j)OfX3)s__GWulVTf8S|$PJsXHN4%w--D+bR{-a_NbxqCi{pmz{O|ZiRN?_-8 zGV8+IQ2&XgZ)h0nQSKu-nWrD^GI~ye?rfeA5E4S052&*|*{oL+wHf5&>M_ht)!oII zo4AJr{C2QwaX6={lmDV@GArnWM(vcgre<(V4Eo)B_ZFwBO?W%@m%A6&)`|{Xf{eTc zTU%T2Lm=X=t~>)7e8Jxx9Oy(s*%*>UO<5S3m?U4md{-@`7F6-QnIVeW~np&T{zYZ+Dize*N0+Bv$!T>Z;U_d)^b%$jB%c zuF4malo}O&F3*-ff-15s!I7;XEsY+I9d?+zasKD|d6;3#G?fL786@?aJ3C21J|>a~ zwjI|rUG+!P`tB^zT|zQ4GRFL!4xH>*G@BTala@1*+czF@l7fL4wh^!H+@ z$!v`ZE4epiW%0Wm#0n)o`+gLUZ<%e9Jy?95@9osQzLIPh)Yj%RxGy*O;J5vm=Rp&3 z78@IDaxjj3gX>&kZrVx(dHEGU{y4=37^e5#;b&_F>M>9PwqdV`XBM8!CMo+;y)|ZZ z`@h-!Wp;{YA(9#?teGvmKn7u3i#N_BZlB!!Jpz962FbqaGWNf?zo(+YvWA+KRhhw@ 
z4l%6BujVir$a3F$#;Qy{O-dT&AgBtCZ8}Yb)O1`%Mx=np@d5}n+)y(eBO^ASvGP@N zlJaes$e53c5U+vw#|I;@f~_Y4q`$p}h#gVPw}OI#CIb)-f=J;f?zW@ws|zF|dzQYz zz%JXPyGK>tFnTL&)9RpS94@)8WD!Ow`~LNkWE&?zn1Fd*=Dny})T_1fKm#*@_Kk@& z1%@RI>Wu!|@0*#X(jP%|zgmtW<|FxlUtC-)-Mp2#*;X5YTs)r-k3S#Y&z-;8ap8IOBcxGbE$!CbTiJEPsaI`hQoH?2ME3fQ*#Ewfn zOV91Ct?;5ucUhf#+U1W`pFhtAM)Le1p82J)B&pRiN{RD~*ypu<^zXIOp4c+?ybCfg zMVxN&IQjpeBn*dDRYgUdB0F0@Ihzf$mW=)iU^genWSlPXO@KeaKavhPhDH3SMo?zCWm8WH}dW;$P-AXbY5eZ-uUm@{M9>`4)T>%Vdea(EP*%uI z3!lJ#-B}5!0_2Ae)q9HHoXfkP+wz0cIgJ3#FY(F@BK=|Nk4l%7Tjpf)m|H#kvJ~ht zMu-HK&7WC>a0wcSg}jLr1<&$-(+(TR=}|d5*!mi!e9qa(G4dOcZsS&oNOCj7u^}5y zORJrwS%B3q#)3X%#|e)AzCSVw3iggy+8Y3J+~JiP_CCwP)iK({H{z*M0=_*2835qW zQ9S>^XFRoL%XuWV%FppIFdoAbZy+vUT1X6*-niE!L^S?I^;m6Npcp_riJb-y%8G7BgtGd|XoYtM@ z*|Et>oJO_BPg?9P{}o)ZTA9!=MI=8FFhp70H}-k-x!5l)lU2!Sqju+!nHGcd zpX(RY);f))_>xmlNRU!6=H1!!FTN&w1RrFDxIW`X2Ah@-4Xf8+kKv*c~1ony{nWXszt-Mwgc}WZWovNxm=Z6pNMH$n@Wao76r}INIj0yLmyQgn9Pc zw?wU3tbKSAqXf4_tgtpnS57rptNP5zCM**RiRHCBC<1}vtRs>CJNOc(VP`90PWEb7 zRZ0d$1PGbYPa5aHtTc)&u6QGDQXGQ#5zETUMa>F!pfM13-eTcEOQ{w$HR(~7=(orP zHXKu1s>*?JeU@)|y~R7Fq@sv$u6y7YSdSip^$kWO7=ZNR0M7t@2gd)e)pn8fTm=y& za!}j_l$XH^1&?|c)+g-d=EnQs#&Z1xsTItA1IC(_McLC!owfq~4Z84cYOS+b`z(nY zBJTZC69^SBOVR(<6FQxJ0|QBlf||v`;r!#AoW&(2BF@g(rF80}{B*^Ok3fu91H2RJ zmY_`O_t+}UJ7-IY@!HNAR@{Udog}NKK|4NNDw3&bMn|u#uFtsnl<@^9nHz{fr0ChB zZ`VQy#c9M+G=bn)cjrMNScHW1h_NGVblKkV=Y#tocStR_0fd}X8x0iWN@Pv@XydAr zJP&h|T`gKdpr*B;R5m%P%7>0V0p+0?jcUC4vpl^bDeV{KTOlb5GiY6sx*Mc!eKOr8 zY@f3&O@KcbeRnhY#|f{4SALdFc1`_liN}PC)_KtnFsY3W@?}zYYs+G>%iWLTsLIk|iu?j80<3thMt3pTd=ln{dgd16JwdIX z#d@hI9Nd)z@Dm28NybS(RSC@UuT39~gAsfq*Bw2?@(7l)Teeah} zsE~5MPKx@kXu~hPG&3_Z7sx|yTE#1h@YUh zEiNk~1e;+_PhDAmsoz;@AW?rSgUa|L&m$=_^}saEJn5VV2eN9DV_}X=fFuqG^wn}j zS9s2gn#)C>4UJ8V=W0*`d|O*@DJS3rEhD)+?~dB$!AtXNui3{Yr88z@vm44!K5|^8 z7T(TF!x8UQR;J|Ekr|D^upDWqA;Y-(E zWrf)5T5i;zv$U1*ih7x{X<*l=jkwGYRM!tW9ep!*l9HJA4$}NTIIbDbUH0GaQ6Y8= zX?>?4xIbobDLsTh1z2OWvMN7V#@fcrEi}-opqY-oOqt?b&vwVC=H8>&=?*ChQhN@mpZ9wQFj`E-wgNwzGn`JUQj zc4;9|!96Q!2ukPT)kU@AB-7974+uz#0SnI@@JnkO?a??n>>l2~|NKVA0}0vQ-hO)m zCt{!$lE06lwTy57tF*8r@ycn19-f8zsFnZv`-1(l zY1j8I{gDxQ#sZ?0kz5cvJK5LU*p8}Kwc1Zdreu5AQBlb=UU;Sl&(z}GCFY~*lK<$k z-M-U;gpQsb#KvfxapNWZ62#a7Ex7W>@;Ze=AhrJDm!z1ze37qfm|f+Et)cH&|gW9pZd*=(vv%vTrdc{nhXL}@f$ zzZSlFffq)+$^c^L_Q@z<&e3gj-^S&# zQ@ha)t3~wMd19uBxkZBcFj+znFQ|6eos$v-mZ*}b=$*~2tsARu2jjObDu2Yu!BPAc z&3SA>Bd__%1M2!0P>tb}k{a8K*FV2EpixI5Jr)!ol`y3%u@@G*qF7(HE-|#ig1xIQ z!GIY*R?%-FP*eXwe);^mdp!Ay=iTmI(PcIzHI=B+=_V;hPcd;ZN5Y!RNu#C}=Dl9S z%R5PBmoEz&5n*{x1b>q1JhzY0H_Xj3xPlSK=%}rAFQAXl+J3KIy9ITOw8VSXtfV=w z8`On32CDfazw_40?>VkW?Be!8|*ud!Mkkd1yApG zlf8e3l~|XcaPCF66jmTy1Sm}ORoMjPSh5qfpf#9D8;$~%Z%#tGR z910|dL5inVYFrZ~{dtwG7$1~_t1Y^5Ud%SB+td%c6~cYKqy+! 
zTXFI5zWg#6R_|M)O2E8T!kf+UuA5c80P^{Eg{7EfrJlUds45S74#-&q9n z4@`L@Z>ZNeBv_xfj!kgqve6TCqiv&C_au2e=FwMq%pO!GBjjBy=4J9!q5JN_r&UH{ zZ|jSR{;B2Ukjn$>F%&NEGFwo;!(@CF{k$9Jj9>`-8*bd=t_FcEh_M1Kwzn>A=3Lk; zQVT}%{2u1Sj5XMd8~jBKi&}0iXRI3S;y6S2j@T3%3Eama4p-(xiZtR9&bxZP|80wd(F=^OBTvMcVV0KV;s!mp`%ZnL7puET z{b7CT_j_!m&Pi-JvNmdA?gP6wA^+ zv7fQ9m({GnfsyoFuH8A=H4ji=)+5mf`Mx1+O)S+t?zOn!>Zs$`4NwJc8LS?9}0HP3-#G>rwSs6ATsSc|9|6 z6 z!A=>L+M`cVQ?eYKB{JVhITLB(lXDU>R};J|BkG-B)rzhLmA+Djk39OOfXKPIDI-Rq zgS_<40?II@mDVv}E7*)_`ASD#hq$OU2NW3gE?S%2@#kQIw+`6BcCYcK@n|jSu~|N=>U-$VZMW&#F2vQ2knl&KCD9aiZLFH`oFZ?A}CK2LQ#!~+@A_~kRsi7()r|>=M2__DnE-)zXs7m1xaNy3Ic-h z=R)a!kdhnA+D`6i=)*EV) zYsyv9UP*0d|7rK{%wv-{MGNaeZYjoAh=dl*L=kl2W|4*S791icyQO}2iFCGaVP44$ zakKfuusoh+2@)Hy7j<~SWPV_22*hcncG-%Eh};Bg$}FSaEP0JzXbTXe)5T)9A&qu6 zZ!@Chkw)sbH+de*9`zW$J2)^;`3JXLq)c8RZ=|Sta(V!_l#brgE|u zd$sMI)RrigY*YgKa#Q6GCA{&W>jzG-x?%=cH<-7Y5}JX8T2PmBMqU0vT{GYK@4DRmk?yd>$gS)$17~Gu@+!-vmL+}70xwH3I_g0ApRv<@|DmBzm{O!FeiAuUHp2Y5c!g`9*4amtzH>ATl_}LBdLl51AWiQQBJT zpSBpXXe)i>S5m1UL{?6O%*uk*YQt~r%xZ4qg68DJv!S{bxllmN%cn3OVrGy1jeZk> z#ZCn)eTv>SA1l)+`fVa?P~h1?+RKP34d^z4YNeHvM{M!@Oj*G0hL_%0cm7`5O5|eW zAULoKgG0X*`3|eRbrzF|j6AYhL2+o0{2QXOKyJA1%1R$4lDzzUx$%;uxABS!+W#3V zp7JwXA@o=%#LW@!a$;m?s67`WOib*gp?j8WQ+2}G`7OLXiN;v$#Om*#?gKm~uhZGl zKj+I_tK)FbzH6MTkduhoxXJjz;4Ffpreki!9e(#pL`(5ct6da6U)@SwzwYMiy# zPS+Jfvu|J-imQBXY3Yd0tb}82yy9|Fp+q(^Y8MFwlQFO$=JPj<=J7eTSd@RBTiB~O zmo{7#CIXukl4dTv*2yXq2Nu^p`2ZkNSLWnutZFc z;WX5}UT~r%*)u0iR@LG|qA0`L8(~hG#lnYb7-<>ksJMjC>^y%gO3Naj`{l|3tek?( zcotH5MsOKb)+x|tuWRCWqBdRmIRykFHbFb`O!fF5^2Oe0X!cD^lJR@G%3iz?e{ljK z)v#e7p!1woXQC_)gYa=Jv7feXa8@uOXDYm5mwmH{Q;c$r6~)f>&eB>TLUWO~liwWp zIk|EmNxDiKB0-iXGjKVdA?;8-LPHag_gvA}j{$5`FXVU0b;l zw@4v``CB8QZ>auF`arw{e2df=5#uX1^M+=O4~42<`h75^c6I76IB#ym1nX4@79ZGt z`sXSW(ji_O!aYMSe}h>|-b=CxAg$g*;1SsLCsd2`F2o?1WsS9^MBB{QMA>en8f-c` zgIJX3;kniN)~RKpwFw-q%^`I{;cSXnIUl1o@Vldx7t1}2pXG8Mp1#6Mt4^Lln=cxE zz}(k}D7cx#BrYE(%DD`_5(9*Z8&LVEJs(UGAB6ayk%&ifIB!2w!Z8%_^6)%%7R=T7 zXAzDU8c;)vQ>+MvvUilRGo{Q509$@55~o+oiG!}aMv>XSB6*Dbu7YB7!9(2ni6lp&vy z!@;onp2!6G*GI)o@{i%^{sIHU*#2fw-<%06^YG&x20J;i^GPH`Xy z&81%Y;%@dAh%0?IY_Tsa7pclM8K6No{wE&9Ta6pPil5ce0{S`%>i#9R^EVv=0PDG1 zztA6B!Lc}r2zD=B#;(q)s7!(!*xfx#4$~n&=W?md^fOID=dHmZD$3Q{NttOR? 
z^z)2{rYF$b^vEK=NJkRx(bH8^^Yo&s(CE+G4|wwaM`^g=UC4>vo8an$aF~qzS-S$( zyrqpi7=N8khc{)CXr3AXOOH4;lDdUAXf(| zC#TCALQv{=1|cB>RO(6_ZHW*1r{cp|ROzjm&i3|N$Oy{Iw_aYx`RRx9jJd4^d`vA$ zivTUbW;IVrYHC>GjI#9-tYk1WHMD)IEB^L%qIrz~k?ovU%lKzBog_AU-Sn4~7j` zbb1MFto=cR`qQ4#xe?M6S@;(B{N(bu)w<&aC_`|MI_3amzxt?y4)+x=$=(%$74u5S z1gD|1;>c?0RXM{B6Qu-~WkgKts&v#)E+7GlJ%w})jjZMl7I0o(N*lc#VP+O0!XTl) zC-@**=cw*DD0;QEh}LuXv3B12HLg@&Odn6}rO1;v#iKX&F)KCYp-6_A{niBnA;dFrpPe>Xwq3nLK6AR#cq@oB;f@2s%3dtfG;2efVB##@wU)wu90eJ)GeJLe6BCoO ziwkLyTx{ItXe)A|Sr;;S5kZ_ePo4$R+Xl&1CE`h)@#_gvziuM?H`5B8mH5z#`CHzV zGRD83@I?$Pzm*cP7WTCSPclW)RZ;UXAP5;4Oerc#djBMEpyEl$gi}VH`6=x7In9YS zfN#B&_knh;3p^yhM`()=%o*JO1Z~BWl9d-Xv;18ii1hU(Z__>B5}k)8L^*c=>93Ga z+TScxex0eL?Wk^F{2XU%c$Vw{{Wg(L``S(0q(?TrS3Hro@o%ElKc-q@bzW&$fzATDP5PhI2%HE-{|eF*<&93S=@v-wD$`b1}s2uc|l$ahIjhIdy8W$7WgB%&xfY#&FTr%Qf^@aBcV z=kcx|w=zTIF`e!@FAq(_h22bq!A!58){f{rJU9)anaOoMp$9Rc4zfm8RDhSJF;+hY z!CAxP6ds!zrb zl0!S;(4{;hlP3pKN*>Hv6(QB}x>1EGa&uQ^ZgY;cxreX)!LHNWHtgsVI`MtN<>mjG zok6Y(5ASS`KS%ME+d8qV=<11#Fw>1<_$%*RVG8sm10 zi&X_T)lw0#OiR2W8mcTG%4mkCTLO)1uPm8zJ9~i-B_s-HWgX1dsMZchN+e1-o2ug{ zKSg=lQ{Rp91{1TM*;lG5Zg+^LLRc)3>le0}`)w~$$n|jYw3i+uzz@|ozpd^6_P))p zo;Io&7%eX9#AT%CbB)S(@GKM*Qccy~NE_AjHo}7q7xcBF|0L4V+8iY`-WzFO1pGTD zHtv3x%!}cF@+C}xrzXv4foI`{iDTGaUExpxb#^8w+{uakhs<+ytmEvL|M4G%_Tuwo zK2dB(?(Nx`x)8;My@QLijt6Ge@<>z^k?~2Rtn|3sT|J7O$8$r0*Qxu>A2{!JzS25c zynO_Gi|BH?85J~(ePNQAGL}Wo*r1+A@w(Z5?FuIcIX?sH=jI}G9K#mAqL|UR0Lw7- z?c&VA<)y^83sllyps(FgaL!^U2NI4bp$|vhyN4HHz*`0H_Vp544(K4e?%}M&PgN7{ z#^i}*^a^H89oy5B<%WL;YRCsqZXaWKDGYMK!t-T}mJIkkxkD>sus24d>iVm!_?D}o zdUR$?5eh7F$Ak!7bMf;eml{MJt@FE1bBZKo&dq&YtvesJKJ0`CJ z+PKEIN%2*+TEWejCOaQ{9D%^rN<0^)p#3XOxV1I@7C$i1I{uW_CZ&>0y6_4k^Pz%1 z3HieW!@(4$fale%omab@(;QEEmd)JaCb7HzBXw_w>ho#HgR85XjEe?8_p)Y6o(@pf zWvhRhOYw);rw_6O-p}W!5wclwN=69W=*lrWxPgw#@Xz^&i@U4zuVUwZ@6>GT=F?Tr z0vQjQ3e{9@F3V3ZTaFrH;PE3mWQtaFVf%^^V zlhUb2X+4>W4bd+nKD-WoT~3~kuUm8a+6j7&-+%0aBOwjg@kP9*&Z<+VVsO-1#v9U+~ zE74oS*VWXh_jW9-nZ}Ipy0Ml77;($6?nQN_xq91^TslbvQY>um!f)=j&*|-iD5wxK zIX>uR%b|cp>M>hA=nHT~U+P*adA9)%LB|nM?Ts5mvC4-NYdju;t@?fZ`cFZ13r=5I(t%b_ zs3>=JtRy5(`dj*=)7!9Ge&Z#eXvl?sZj`64herhw_jytdA72GiccN4?+Yx7ZSBi#9 z4Is<&ja#f&s;(o=2s!x|uFWzAlDN6Qa!#ZE%o3w_-AVfy7dFi9@n9-zutdi5==UbZpR~9Gp;M`PENgz#F+ASZAzzAi3yVLjEB(9f|zEPgl zGzblm&0f~7`faaVrgI_@NB`J{!SB{h%2@-=5P{h&;mS`$xYt8But)^&m^hZ-+nZ@l z=F8;zVmbz!YwTdcBiX%^yiX0zj@Y`StezeM7;OFeJ+vN_VQa|=9*h^g4R-ZL_3mE4LJ^x>P%@I!bK4Z&3L03i z@=V-0yKpAh(=&7B1GSg8)%U)T#y@OC&H0YsCZZ1R_UO=ILOgOyvXBOX2mqgtG7cz( z+i*t0!NI}uUi-gBkF3nhB$5AmnG({L9!h~7&$cTW(dJPjsJwP?;V5&Z^(JdySfoNc z*1KK3c{8Y<7f8D{4HFn@NNumA9NQW=QcjwcQF2+q;Kj zk4Oz^-6Xxp0DD2uHx#tHAg&!(g?4SmXwKhwJkuyO@;WwMcLeO|I>Wtq{hPH7jkbj8KL>laSZdM z;cXF(u?W{>eus}y4}u&(Nqo0<$GY@EO)Z91$x{WDV5N0+x<>4fW|Mq-Bkw=BFweDm z+NSO6Sr@Tafkn5DGCCW?Ik90No|8s2nep4Rf?^ySjV2zp#%sZs6}Mko)x{Ibg8|4p z9|b?iYmD!I*WPkzlF}#e=?u^?*#)f|6wpYO3BL4Cf|nMoljQ*uCN$O5=h(0+L z2=U?pJ9xv~U}AR8yCkj+LqEn}bzcWiJ(A1ctv|mAzLA~#y=6!y_MiQwPT1dK3^59F z>*E;Q)piTCdTaW7+t23x_iS6MNc-je^XJDckO*|HV9$Byct955SWkdFCFxz-x%oNh z+im4N-i!o{N>*}IoY#L$M|aVQ=-$N~fFZ|jqNy3!`iOXq z^#&o^G=R(C`**X$PPNYw)n8sR!0(p%)V3OrN+AcIDhVaWpz8K0QvROo!kH^YFE7oB{(zyA8t0!OpQEk(Q(DKL zbJRCF)9&YR@c43>G|sn^4qn(aiI|#&yo1qu5}Y*mx_*M!6y1 z@b4tqzFuUYC>Xx2Lt=2xZui5kfvun(kTx%?EWjU~ zd+a@PcmFj#0gW|Lg)>Brx{kPZFEvz65uqnv2`j{2s9-ipTAatWYiM&58NZ4?r$Yq) z*hldd)Lv3bkNuy@G@^(7Y^6gyhk-nYy7>tU4#?lvhu(K$4OLJ;MntSQq>dy>0fB3I z5q(Wpx1}XFW=ZE`kV#^AvrF)GLrHw#W7RG7ln24}8NcNlC~~%AOx#NzCdn-(VUfFe z#;39gIDHNebIfGa=ipK&rQ-pG9N4isC!?XkfJd2Yoh|q1Dg=-g7Uy z>%b87BqTRIxe8Rxt2lR+3Grb^SI~~BHhvnVfuq?3V(WzGVeX5x$hgqMqoRJn> zd_3-CPD1z}YHeMCvIuumYyCFn-l8H 
zTXeC(%?=gFTYI6yGV1Yp*s>P#h)uDvHJX7XBF6v6(;FDrLny8I-+_^P87RktsWK)% zA0rGFskucTo+gK!ok(1hp-rfRSYVr4S-lWkP}No(oQx6*f^$OlQ^2+35{kX&S+Icc z7O}dHY3v{9%Jl$Tg7UHf1492^x)7wDLnC4*b=eO?Ni0RJh!&r?nAia2r2ZBfMsmA& zZP zQ=F?8I}+}B*OSN&t_-Jt$9?VWEUA*lH##~!rX2Nyn}MXiKj#SBaabxPXi(1i5jhf{ z6TUIo*&z_8`wopy;Op|Q<@AW)xXflEnp(g~;h<~;uuklQgbOZp|8(4Qb(9U1=|AV? zh68Y;UBk)?XHczATcD_-zWEClAAEnq2k|=?Fw|#URB+!b+}*4NYG*4RZZWTPg1#SY zg&s^z)Rh%E=^3xm@@#o`mb5k*k9oxVSGTmlu0FAy(bHsknqJS=uvg`9Qosq>rFdI& z&5X@S0CU=pII*XXPWvjLeW+Dkv&StPpBOoqB}~~@qX4fJ^>5tBT8TLuF0w86eV>HIW9ecIbZ&gjvmxkI5@FMLt|qgxXQOc zWUDf8fKS51v#k&9lmMT`xnAVTC*2xSp&puC?xn3^N|`Ko zoRLvdS496;=hg#43f4Bu)&?hPs_RCa0Nc)}N2@eL$5$czevL#Zuwj+}SEr;@XGf@l z%LL`(hmv|uenNq9kJ(~H%i~KeT1Q!BE3u4sk=MP3v2H{a`l>HLwzBC4op&v(ykXv- zlA^M5)!z9M`!Ij^U3u@YTyKeK!W{0H&%X4~$qAs6`{(!l1H?&$mYX{ zZc62w#>M9`)W-+KSi0l8Ldvq}aU*))cRI4E`W%Q@ucFZ(;KN@|eF=-7x8O8TZN8Rg zCAi}Wb-J?iz9r(`RF7S)+dVQjfxMQrrca^6ATmitcIMawSc{|M<#~A?cYHi$l5gk>3?%WHbnKiwywTn3=+eilhgSwFhr^375KVp@m3QsGKH~r2 z5LE6A0vh%2s37_p+$XlBpjvoov#qbwOVkEt`C74x@gi@aHeD=)1<{lSGsmn+s9F1z zF({d+${GjzOrmdmAfHVD`)vPe?OcDRV_uNfbPRsVe4&()l2~2lL+T_Q-xHgqcq-}v z^J&$*s2=-0a)LTt@u{PRz&xM2fwbx40?fFtj360BB%3KcCl60V+@$6elA?@)M(LF( zfnMBtBzn#sC*bh%*DrU7u{Bx;GPyT3>1xe0O*jD*s??H+?7{J3@(o>-Svd~}zC7@IE=jduM?Dr09O!Jk^KpC8I6LBnNTVf*5-n`z`v9KfKghI!7*4CjO zyL1r<>R=%k)%+JS9->JqXTeoRED|YM*WE!XbAQOfST9Q>W-TGY#jjghBhlKS9z`FI zG4B&A1S~6Rk1~D*10rTBz;-UBFl;`VX|2*oiY%9l;9D?&Va~KhHl8ClEYaMV8{<9M zb+RmgyoH-5>h-Zvo>dFDvno~9w059>Wp?$+%F0TAej~atpTO!_8zl#YM_rVeB;!OG za^q?fMEUTlX61*e9v2c|P+&BfEU6 zJ%~L)58|pKI<9;{BjP<_mYIr%>$fe@imItZ$T)WR-_Qv{L`-4H6?$FFkC&b2b8s>lG<>(R9+p8fL+WIX|auI^ron{B)eFIH5I0J<8t>gsts z8)^hU8c(TptGqUnZR2XozZVhmO(F~^X%&O3i&0*Mo4-uODVBG(_irxZq=<04o~lY4 z2hRI=I_KZqs!?AS{jb71&_^6Toz>huUgt3~p3z6z@1^u*jf>Lr) z(g&-Ip8VAGHaXK3A`gTfeC5pz!)->|0{%`DZsnv?QG$Z3j|zYt<5LSs!=*&BMxB5p z+h?S`J;SpUWt#ZEx@F-|UAG@$|L0 zK43-CpTx^_wh{C!E@~%#<3U0c?L(@!RUPG>9vy`tCr@axUZ#Sjr!4~wXwA0YjBP%& z%fH+@!B01TW|-lDJolgNCyP5gX^XeRbAhVvo&|A@)tVZMzW65hm#UH>I*kGr}GU+Zlf(0BYvpX9{^RT(#G?ces|l$oa2-NoqS%T{r{@wPIGD_>BW}5%Q){ zQi(#a-`CiOf_qe5SLkicC$`ZxZVr<(wyAkjH5&GUqtP8O%OJt+`e4Xv#3LvU6zVW) z51Cqn+L(ZouOgUG3OtrH;K!EgYjwDGR?cU?EZ%{fnpS%(ezJSHU;tJfv9QNqRu$dv16!Lr z!07^a=X?hYIRb&MUg6>0Ie6VQ(2yfakR1V5d6%$=<$mk3y}fI8W+@~+UAV>>UeR$V zfIGYfvQABQCaRZ|a?GAX?W@`{exOBnP|FWqtgdCI2)!iTKRZI5*9pQb+KHUzxfARN zVLAzfoQ6#-ub0UZP*Z>Pmo&0TA@4PvlAqqFz+0o_*N0jTkL;)SqN1Em2NEDc6=i3O z2ERugd#L)tW)g71^GHT!$Zj-v#^@*28^6H%c@CGKnD1{{%JMF_@zmZ{XNbm@xIZ9< zP3LnhRz+XN=*}m47Q|y~_5HiE#hcD6Ti$}`UN}#6ht-|89hL~2U;!=a)KXWyl=8QZ zQtr<%SQfh*-InMlkB_+^_dfWo=Xa{V-c%QcA`6?6I8C@ebg7&_l>Hf(gcn(y8^&3B zoT~+g#&^(~RDn4?!3ZOsMmDzOKF3$Aaz8!&T5Y;AGf|!kTnS`%4}Lwh@CbzECXr;_ zs7Q|r1zKtcCaTB_bz!&d6#fLPk8P6kuknBL+}Ie-y~s)6+3v?bR5uhOzNg8&sPOvd zvm{rm%3N68bmNP)M-N9a0kmo3$`4IHFw*1jXWNa22`d%U(vRwxQ}vd=1Q@a0tFjT? z71W=KHlQ(&X_;g)3X>u3g*b6~Rb`9rdW;rqDzzS}Q>L{|kF#(GHX9H@>jg|%sp)kI zs-}u{JGvw>`2yX2{9@|UVXB(NW4owdg>KI#x}$kX^~Eu?#$WJ9IkHj3>ei*keiatl z*Ro2>UNK#5_961vbw7gwI=j7E96JR@%01N$SHrv%T?S%+?4b2bYTJLtg2F<>9)D#? z)f&gT*DQ}{1FT+7@N)QF;J3G}pT6vew{LvmCbmO%S`?r1Z@9clXGGnJg)8#@*x7e> zgCTi+ZE5vND`p@NR4Gna4GZX!h^9F{q_fg9;&D`(qh$Yuxz?N;RlB@x!`0b~AezNG zyhalyugol?Vx?CcWb8)7Uc$VCVT3O{@_iA?>t(#Gqypcd6>6g`kiJT-vTF+lLU1e? 
zV@LC8IVc~62`Twf*4~73dOv<|i%3k8LEIcb<}oSQzv@CCmf;&2+I#7Q9E-KswBI2) z*UE?SLi)V|_?ZM5TM7hoK0YO-J+SBJ8ozcr@mZ?XqIzpOc|&vJkAHNIim(zn9DDz* zbW!#V&Q#dd|2&VR3OPSt1c~S>SEmy1N9Qt&a%3k3e2CX#vYvLU&&a7RG_bhU$R>CB z+Bls${)C<$L~Z`)IFxtCzGnufUSGrM1%fjn8LkiSiA;Qwhh5P}A2|X#7VyGfd3Wr%~9U=}8`{y|g`8h<1-p0B~#dB2Cl<>%ij46-xI2vZ+CW)^034UGg@3)IZ2 z^vY}Ty|AT~bbSjuR2(ZD!Divaqi?<%*@5>?Zn7aJFnP7W!t85V`kR0x9rkZ5`Jk7vEou~9>N&GBneng z%}O|B%u3kU+h?+|*H_)f09^~*#9{W5lsI?B;n?R;+*$jvRi;B+oN&IWvEzi){g{#2 zAT*3*5MfczIZw6<(j8s)(p4kn^n2#M62(qDDnJIs{#pG~wX|(yVu61EgJF4zy}6x# z_HWzP>UoJwkQA4`IOy_X`C$o@+?#MKJJ2Ev)2*nw&iIlWuzTm zX*B-md1<@&lAc)hC09=>gPR~V%(7`d@u&L;T($?SNJShnqQ!MELITzU`=l#7>&kJs6 z_;!Lqz29*37rUf<>V%RD1+H8e?7Tr@#K~%9c*yCe=X3OA)t^ORq>XD+kQqTd;yXxPQB9p_eyMgy`@o5aYZMr00H#M@IH_&ObU9I5~=aaQ1k!=Y2eYHH89TQe# zg22hHb{O!%S0PhFiKS%(psLjRrAam7PlZUJP@-N{()ETU&Q9DSrz-DOS(r$RPgb;t zC}poNNXzBG*!vDW8>Ntf_{wmdj=M7{IIJ2o`MYl;VntubUz8j+4$c4Rnf$&iGbI)l zFL&>Igl%&N`Vj@%3hXTht z*%)U_rc_sCxumJZ{rl<>Q3xcMwIs0gB0GlR^qV5s6@a!UkmZG{}CjsB42PVo0){SWpo=%89&(*g*_Yim3tu^G`4qYk+ zdSuVqPkp~IW5=wAHJq`9O=hBz0dM_Cc=ue0k?w#5_Xp+3Rc^1eGydh0Qz(Uu~j zb;oyY-Ij6{_Y-(2QslI9|G|DDN>#5t!o2rK#K@FTtfq_-{k4%-`<0oG!jD`9&?^<9 zNZAM%>!me z>0|FRy}@UVcw&SVEsjOQH1teOWUMEgWj7nw4_(AKK-kL2mOpo+g=!~HNdMtb^_=%F z|KP|U<*A{a6m1-iDcO|K@~p7Q1*c{1kx^!F$qeDfO_9u%$Q?Ho0q{B}F zafx6S5eJS-uU1U1aTWi;3&F53;4DpzTJrt{X;CCFK-lHtLHvxUNzC3ZX3<3Bt92$c zwT0mCK6iLWdAi`p=$LShEXm4g34_B=^y{;s=z?#%zDZ7mH3<_sOCOf^?)9A*x8%4q zQK(&(SidWh<3}*DiJ`;Pr+>C?^z{-h3;j7MOR)ZT#-dxCNktQ#_=r5EUF63RV@!Aa zCeI#0%M2%zoOSnl=DkDpv3%TOb$pGpJ*^^@2!Mv-z=eh?{!kq|5hR!9#Ym!$K0M=qk=NFtNWBEbO!`Wd0Q7>zJf*-|~ z;k9?$qe){Kd3naN>E6X1{Rj{t-I%DO5hDOFnzGt^Lw*Z8+hxA4d}v`srHPBo?>ho$ zrY?BXqLmdigDsVg1}&*RYSU5iYfcQ5Q((ZP3PR?n`f=!labyvhlwUqI{@;gB#gunu zrCFgwHSwBWT)O_9jO;eSkx0$TV=nA6l8>W7I)No?1^?WRUUYte9o=$-uU-68Q5S*{ z)9ZS5^sU-dcT3>1{ma#uz7;lsjXZhMv_PKMTHZsXW`DQV3t1ug^soMC86B||%L@ZM zO}OC5xI^_tKFRec+f=#>XEmD3|u~B8UHSkL7V%6mUB!i-YU> zd&~TV4^`TI{tGf9q=eIgvUp133q1UrbH1#R2KJZ*@5^MFjrM4nWx@8O#rRbOE*#*i1xmpM!Jh$VoGHuva?cp`n+%#`KA*FOiy`Cr4 zILtrcsxvO|7A5AdFYJ4fVYXjjQKi*OP*e|zcm4JNuJPEOMAsSttla3ca0l-9Y>1m- z2~!Q4kA0?5$FRE9ul5K%$ES=~QIv^mmt&URsyWc)$ z2d*026{a129{Xa6Eb-F2yY6ek&f9&#uWrNqUik3D z7_cylyFqRg(c=oDjZBh`u$#}L!-r8GL2mbKkuhHx(j;n?$J@bngr?k{{AjnQtpT1W zYKj0%z;(h8!10cq{j^%7w0uxNQ^=CjPR^n`Lel4XKWZw)i8rGhZr2xv zeQ2y_q(KTo$57!$R(pxkBS`}^jmuJU^?(;u=(c|@YMNwaLXmMjxjADg=LRG8ZD`&bPsw3n`fIZ|D`QY8 zk)ucE|8@q3G?ks zyVYWOGXNHgX}uxj$UK#jS?IDNL;!4NC8y$EeBZn$ zab755@6eYtaY;AzLpfZL`Kvr*Bx5`3YP*038x?A>4Qkp0vc{K7{5&JW1!Gg-w-MZ$ zF-!R|YI&N&MR|+Mc|!G2R#e^<10bV#)Nrz_%sgU8Fg1*WnI3gC(e$x;g&nmPRW!9Q zQSP`+4J6$_>4A-R1&@Y-4;{JotB@nLr_r6qYPgsga_B9l?vn`UNLqp&Pm)rdlCK)} zc-5urr3GQ<>8XZo8eq%|)Zk0T;OJdkeuq`N-IG)>gdWr6qAlaa<=>TC zKb$-Z*@zI(y1aSRA@#niHZSRiAr+ZaFUCZJQ^LWIsPhx%AhgHY84r%7^<>OYWVC_p zox;j4iU@uI8$H{T5~C2Mo*BGva5el~oHedF(u|D)FdlsK=K=i}4PIHx8h zT3MB<{vmXzPf2M1}!6YEEKiOBYuIa+@`aVn1ItxDaDf|&V`yCzk z$59F*9HtU=YN}p*z5vBfNjnB%BS&?ElnBAfL@0S`4gS2{;P5b_L^OTEz;LCq*6%v@fp3J#fq}oAf5Pa3I=%ugA?5;S@!l}~cwef@6;~tWY+9J=A64W2 zLHtU$;jHO}JQ!?A$5b2gDcTVgR>WDbdii zm|Qd(=V&#%32ToFx(V8#FYNZG#U;WD(}~5F4z$Z|^~4k<<@GpZd@E|T>RbqUhzHwS z{zThLlWrpR<ey=d4;-n6cnSrFyPi3I z4waf3tEwPbXBuRMBfB8zxM8=sm_74j*73cJqKV)S6B2?JBbX=H6hN}FDu zl6EU!l-M8uuV!F53^<4Gt*uwMu^S1^MGMD9XmHoj-+g+>l}2;`4GSI{Za5u9wqzTw%hMt-%zF_f|J~-vsq$ zH+$ZLF}t9PZA0I=?0os;S>Il4pLbvr?m;AP#%^~^buvZii-kbX+i9!I_FeMSM*=M( z4BU@q#z9#4vi@w7h}I^}v91u5x()3koz;@cHWalY(4nabp6e04By(j1D)toK`eJ{> z=koQdyj6c6EhO>bqf1U{KAD+(svlE zUsWiNPCA40_NM;M&AJ^2d{6mt_9=;QX{49XsM5Qk=-HifVMyhBu8_in^KCAJ*V5B_ 
z^DhNCFU?PcHGjqG(_MT%;URg((rU))+%aAdht2xv{_AB$mHt&3;6us=r!dn=iTxe3 zFfBR$g4x^GyR5t%X;nvNN)2ccirK1@T?Xn+`yVY4zamKY>KvJ0I zW4YJ0_=?Wry>b2%ak4ltB@1)>Q}A3TCX8IDoV3P@eE8dgLJR~Lsg>OMzBu8l|BTf4HBUq2@eshM-~I| zA1a;3{?9v`So!YgpxYA*Z+C{LvrX@wy(*&7!h7D6 z1wTCm5lE4KJbA=N|kxWKaFr(!?JV5mY?J zlhrB9HWrkSH7Y;8?^nF_^7L}freyF_T~btwJ0E?1<}slsXv8SkB`J%t%CV8tcKZB+ zs%k2)LlP_EamI0vOCp3h`eP>w@6IiYU7Ih_VN>Sp(cG1YFVO@gB3$(89$={Jt4|jw zWmJ%$L;475+DLACA{&!tPCL->xljKOPT+GHE<0bec34iCoIcIzu|oboVM2LX)$Xxi zZ&YQ=*)w==#>Bei=%iND3_F=`DyfvBSZEw}<{J;H4nIc*43c)^qt1P3 z{5&(0hC{9l1@3x!p*@(nBB9B($1yku90oTee{sY%V;`_d5;E!!R6|hKm%aN-bIK)2 z@%>Ou+K8}>)J1#Z%{$M!?iRl;goh;~UK^&|^qi8VvOwfZx2WZ|A|5XD+Ges>Bx~z* ziL%&<_BjWAV^dD9Whvk<9f7dtc-~*%1@ViL=f0zDpkw6R{Y`8;4zhrp{3_CCRp&u_ z31?Ut(KJbNt*{s2c=cf-dNIxvb!xkk8fl`WKY6}=`DWu;rjw2Rz9I=%qif$`SgMR3 zN-BH|>BFdxi38!+&39i6gI@g1O^YS`Z1|~c2z6HHad$BdjRmQ zkrGk*xU0{Wq064os9I!x$}GP(OyXu3JwJVi%oLz2 zes@Ql6rOLlW1U(|5h+d`QE~?Dk&RH_%;P8eIh6?KKx3Fbb(A(~)MMGC{zqUW%)I_x zJ%ydO`2WM!TZTpTNB^S0z|dVo3=fJn#CNSA`r4T5y{&>fOWhXM+c(%q;u(jC&> zedqT-=RD87_c<@;&BWfb*IJ)iUstZ8TBh9Hl8HfWy!R60-iEVYZ>wL9M{{9A(4@)E z?DH%LH~A02UC@rrehLu6@-;$tY);Ph&byNa*Oz1knmGsc9gtq z?Molt_pvdh=QmWQPx_Yhb~&WM?Oi6gouIPYL14QOo?I1!y!FTauLXC3kgtOERTY!c-~Y-q>|2iO5C0}i zW+swlN|g3IQ!u`@(kH;v7#RYE`HOfN~^dA@RLFmo#;rxHi|#kQvj+!I@t{@|XiV6-3> z{Q5;i3I1(V3g_&#^8UC+D1wLCPgvA_b+X8{SMe7EIcM^_YYOV^-2%$@3)7aYHO%N5 zL&@)aA*5d|@}jIt)O4B16)L;iZx`9!z2! z_E=5sKZpG|Sxyp%+ot(CVi}vgBY%(Ww?zlm_Bk&O9z$jIXIPzi^9F)hjo&lVsqN_( zchuV7n=t-Y(HD#k7uHRyyhkVt0iZ$O1?6;^!*ML4|ynq1Z#XKW$93?=qb0-Pyv;Zh3pzkXG93!P>8jn8qfZF50`s-s_5@{cAR98hJ^{hrPyAvo%CYU2h(-jjk z#pBO_nd)?ZMHpb3FE>T_!6NT+w=4;8aMKymCp4^-l7y;C& zi+)rtnpWt$1911n6$@^Gh6O8hpTL?cQuu@_)Gsox+9_MX$%8|MreBAw(9Z#U74U4U z|NQNLgt5Rt$^THm7ZoDl-8FXy3DEWM!-kIwnV7`as3--Wj(%42enADl-^7$W!AkUs zsy;$r>%lh%#*rceC_+NV5L(kyD`Ju_+Jj8ma=qjmbItoxtBU7GPsC4SYPlhlt^V%B z-qP3S9&vZw@S&zgp{CFve-ZP54Nk=tp4kfn{ihQ$p`yaZQfv}FZO8cT9nbV!1frK; zshFequXt}GJTG-2CyyBWyo3fR;8O{KE$L zh(cb5hfH=*C;#Ge9KmdTaJjdMpT6%Bja z_GlQwaKxxIkDNcUO`k4$lV(ZZb4gd{AzH6eIMGRW(z7$Fg4kvFf^E~j zdk>sMN!^qljdbxEIa5&5E%W)(yyC-A`a&X?)*<^TL*^!M(Cn*+DnqKJB zmc+9_eXh*`WYzg@85@s`YAJ3P;7cg!aWlfVTj=hSc+No2S|Cnhx7DV<^0+snQM__g zW>UbzO2Ch%TKN6MvF8tps3$KkD9SV~Do)v5eedM*W=wIV;K{hB_Oqefq}?k9LZnV) z>gpZLqJF$-nvVkk?bMBG+;1vH7JK|+>JCI2@*PFUby1= zXrV`pU8jfDzx{2OcPj_{_zOy8;0uAwINZqWM*F+BSlp9+<~P@u<5lL@4Cm@PEkewFCsQ^xHiE-y8zmDoE$c%Jjd>MP5A2PuZV$5L18To3)bhv z|ER~mrd$0LEfq%GW&<7Nb09GX1N%A)OHRr8$N5U(w+qFgr5R*diT(&b~CAYlolsoh=P| zJsc#3C%2LXSYQORnI@Hj*2Add_I9%_+pg}dY%YoUccz3$MRU~EdHCeCG1kw%oxWEO z4dork2%090L>SJlc)?$slg*?W&YvE9j?8TMSkZGm-_h|vj+_BFkTbwBK43h%1a9g1 zl5S+@lsTWSDn>EwQxp`lQYrXjiG7v#1BcAQm zRGmF_#mVc0qt~gV{TaCooD3gD>PH#GLOb>gnY{tLpq?KME8$epR zoFl!VhSlum!~;2Uu}01W`~JAU0LqC6>;94>hm+X{iqPykO3mW{kA7^y!~az8ul&vO$pjMiTsnW8PKgw{zwL-q&^_+fvAG^Wy6Gu*`$}d5`tMO*BwJ9)-u_1e*#c1eyqrzyA8FNO9eGe@8iWb#;|1dTJZ!Jj4gQVu)6f84sm82$B#VuyvNi>lGxnbYJm~ z{cMTxJwuZ`vhW^aIxk_Z-u$~0J_06HpeCg+4Aki@Jeej^4Wi8- zQc+}3A$&p*i-a>s?X~e>E+g(P_e#{GHN zyK|(EEi6gSSaxQNC3{=+d?)a`cPh9i&bnq$I91ZvYCK&12Lo76B^jaK6+1HIq5sO{ zv@H*b%Yccq$SJI&|776W8>x1@^?9a)+}(}4m7wh^Q}vP+zJn_0zM(R%ILIkyL}JHe zP)H`4K*!sGH4*#c<#tkBaFeDM8L$a0TAsIng$>&I)1mF52l1uddh;7wHv!?d{kF`> zMjZjI=UJ)z80}nbrms7?i#atLvqF7e3h`(7=IUcWG--%bE(oUzB`;?-m_ z>9$U9e9tJDHn>WB3bwl7`B9<>FVmziTLZgWqg6I_v}6;2fTP`TP7{LB@ig^)p~cJu z&=-)l3pxj8LzN#1UG5}D&cS?J8)>1%C@!WXadHbq!Va|S+E^_Orrw>knH$=zPEE& z>7+OHj2Hj4(PUN5_x+=_lkap+PSyz z(q-LV#*lxE;Uw}JLhXb=S6pfP7dEE(CkC{kiRx2qkrbhNxc(luDd^WeDhtm+^Vt)XpW!50cy8?$TJZ)F4> zgT^*u?~VNwBVlou`2w%wG_`nqBbzlp=ixOaCC<)!jpfR@t!kIXEk9U7^+VMm&@OMI?pA59FV4Z*WoTq 
z^8*cB_1`N|+(wz^cbBZ70<_w$Oy3>Ls-dAOlYQaPX0O`xi3$&rFwfK&Gp+z(L$t?5DtMeV909?j~SR#!`qc2dVMd}d5t*E~#eKmS_sj9BT zTj7p?c34s&>c&)L`%HD)uw=bWZV^W^fhbm&N_qcpszZ@>jE`^Y`n4?SmHj@&9Nliz zwuh5~qWsa6Q?C<( zm`qG!B@G_UpfA|u>35G@9MYR%kQB3V?qC9Qh>Db|N^cjVUXgA9+4~8T*_kk2C4`l~ z!!vm5$~Fr0D9}Nk*NuJHLT%ZrW2?nro~2;QvXvJ;j07sFi9(_n%o-V|my4RKNV(qd zxiQF-nosRthXF(_<^H#oXgji1o&BrzkLblH z6%Q=cVa<*~Zx!$`J%fEPqQIXgw5eh>R!!xD@r=R&Ye`mH;Y zn9+~8!EU#gHksGpAO}3UN7VfXw<${%`qlG%wty&~_>J|4%ODy5G$JSDW$rp_V@7&U z?$bLk@ix1=3IH^ugf40$q3}UP|1E&3ErGLt*P@0NQ5`@?cu<-mJar;t%v)@~9ngyW z2~1obab(aykW5cE5rtAWyBY@-ZeNldGeb)u9bw5In?hhV0x`(;$shkZLSVbRZ3>fS zXE+w|3p`Y^*M<>EDPnRhwc#8<)DUA%PjAyygZ8NM?jWf)Tm70_LasRmpj&8e&u4xe zW4q0`>8;n{B(1eR8j0|XtJLIR=*htwkG}P8y4`y4JOZ!kOGOd_62~VeLr1nJVR^;g zkEBQvEbWSTs-#brfkO84Vz5g^O03)_1-~mDZIh|KPw1W!KZ{WquFXHkrXzt zrcz3wJ_@QlM4a&^I`M1JjiTYjgwrL&ks(rhI~qNz{Ss9HO}I4A-!#A60Afmjl|R-} zYgYIus=Wx)oNqhL0|J_`fa3%Ok{{!S&?k)sCcix5dKk?-@p<+;F-m^kx0n9~;QGjr z{d5EXU^BWp@87Z>WKT}F0VIIGG|cipZf6DIcsA1F$+2=B>ic>@F!&`^$g^~>$<9-$ zZx^qc_1#NRF12P~g0?25||P@Ydy$&OBwVSfasib`KD#i_o8^L4qsZ=u2I(~B155i$*X-<`QR zDQb|`LGtm&Q&SX97FL4CWiOQ5E@a2q;^%6U_b>xNTO z^*-ywZ3Ov&ls*MN{4>)ZeM>`_PyzbDsZjVa6v0oY+I0|tU~3NZDWjpX^h+5`PSr^q zk9BYwc<#M2C7;+3UNA5`2bI(LQDwy@svkt>C9+L?dX~_ygWi)JQG9JQ%NKC}0|NBG zEw^Ka<*c4DJb$Pt(}Q-j{u@`HtIj6P=s%LKbW{4@p3Nao9}wU#y*50VyC8l&RGgP| zlp9-Ozf5d|RN`x;C6?n2*$-{LC{^rq7Ow?Q>evxBzKCwQW!L<8RHiIH8h2~S<8ml! zYPciQ3`ZUaoKXB=0BUg#yvM9x(WPfr=-2k{Ylv%kf#KZ+|8td_aF#(i{0dp(bdMzt zL_X+Kmr56g05qW_Dq|Aw!_K1$nS*G6uKFe8v?_jl*wBwP<9sb%)#`317(YJU4=Tl( zYPmC?a%Z~S1CWk%D0Y~ONNGTuOW-~7mG&U{BV2J_Y^fw8;2pu zhd}|A%q~X`t6J^Zx2jOmi4$kU6*#v8*(5d%uw)(gG0Q(CIy>KB%qMyb#=WU>@0xfR z7|Sj@mo)5u+1ivAtcf6+i1M$y;O%+3B~r`-XC6KEB_kH=+4-a%3>S2Vjyvvm-?htN zMg-)!1TLZzj+Hwkt47IJ~gp%)_of`+by^mBq3fR=J z3p7#w{-o5e71nQbzcSb+P<3*Y=?qT8>MhYNvUf%a;5AvdqTdY)FBzUmBC0U$$Jerw zQWbRNzz9awa^RXbq!P{Vc@xJIVs5MOQbDk_&%c1UB~j@A<;D5p0ZDRR!YAI#5w% zGgavHCMY2#$&xhj{5k%Ad|)3*VG4aBUv-&nX|P~dt^KR92|%*qQhJgz;&d-Uo&9YQ zr>AUxPI(X&YbcVJ@Ynd?=CyA4J?G9y+~!1>yh0H&&(736-_?$d8|sxrt?t%=?h`*+ zWqKPIiu*6nRe6@hfSCJBls5;Hb-nrg76!9FruP~qU?@$qY(1F>lIRVIj*XRl{~qV; zyt3%)o#2tUf4trAnQ_&(fO_a7yXVu>E7^=L#-BYm@Eqy8_byG9?k4G|_}&gVIB>i`CL>)mj>lmNyzH z@P)Q*!}n{}o$B5E8hGD=exrd zN}y|J)fvc#EXD*XuN?@M+-pycusKFiwAV-`SJ{o+(q!j68Cl1cefEk#TseE^_iWaR zGKJ$Z5GiRc3P!yQG2yZ01#`jhR@Subq8XEor@}Sv4NZ`COpL?KWpR@{_d^l0=>WML zZIpI;l20>!RdBneD?*q2Jh+V&P8b7pbJ;IZZ9vM5=9}XCY+7?OviHkQ)fXt}&TpqQ zv<4}Ws-mp=qoRI0aL0t=HZq&pbLa1VCJkLbCUr@jl>0HUiFi=V^0aoZ7ETqCIOto_ z7sUj+&JFADahtMt_kQ#y6Ksuq!o^g(y#frG-~z_qd1Ga~jo0&XC1NCugp(m|THoL~ zi8Qjo13MVz1`1N#T*-?E=w-Mfk{}U8F^ehLA*J5zH~%G9^q^&pIE0RA^dGg=Co)@K z^iNRSt(-fwvB~%RaKRtl*CCc3`z%Z{BoEFUb33Vcc6usqBmluY57H}5o;znYj2zUi z3MGZ8a8Mx(zbvODvLgC+b8rr6Z3Z@n|cnyf@qfQTIvFdPT~?zMD|c)a>=0wdw- zz(VdrN>tI2r|%O?XD!KeGV4@v@_$-qzVmZp5Q`8C8#2wmtliDIy1F);K&U3P}uwc4U6Ob8g#+3_nfv-WV~<_7lw*#vl44;BL0 zh~0>-%_o0sTu%21Qp{Hdz9GyPW{yD6^AEkeuosb1N50$NyW2xtbb!yTT5EW)`$mb7 z?VvldxKp40lL6V{6S_&jH1Cj~Mt61`(Am6&b^k-cf7z(}FyS>XhCnpj1<`(2oUOQvGTP~$p&?Tlq5;dFp}naMYC ztXs1Zv)xX1Lsjb!By+{u1Ft1aT8%(#@)LIhkiD8j!Kd6RDH;Q7v?Xyj{yax0NP$P0 zSRr5Yf5 z4VlspczFYUE`2nDa89EHW^K=o~*i^A(uQt`ltS;gE{pJTvcdonOt3B7=c>a~Q$7^W-MGZ$z_7C@=~ zIB{2jV*Ygnhk|e%rFF|VyU>H`+8`+;zVC?+duZ$7yn$2U;jr7#BqAR z`(_?-*w`xCy1dx4J&HLNFCpVogQC|rXZu`*w%=v*U`D0j@s4ZL9M#-0vAezJu%&)^ zn*JSAga8YM(2Bzb;fk&ykMi5CU^^>nr~78-mT$CFoTr z&p3EUx7-W!R&FYp80rF5Ypq0D)DzH0FVsC#2lJl%nSA>}5 z2EN$c&@yUl>Gxc3#8u7yphlD&)1ky{J1@@o(2VBmN1|gvT1_clIft;gOQ&C%@gf7bcd<7db<*v;zdLx9* zyg>6M{7ci>KN|f5n?tLXH~5kPJF<%R9}Gta<+$VlB`dxalniG6SW8!O$j959Bn%O< 
zrI%^~0BBv9PH4^WU0zWgtyCy#;cOcadHdr!wk4l)6M+~Au>ehmn6nMvPOFPh_z@`n zuW-K6C3IH>#clA4fm-xJEWe;Y5jb0qoK>P&GOo|lc9BPF_`T-W4=hg!w`v?A#l9yc zNVGBvKANw7IUirHpXkIu$qo2j1$hLB)e<6+r>#cQSiC!x4ezT#hu+f56Po)90PR~v z&7hBK3;>B(K`+R&SSOsPkswFk&mAz1bRI+}d)t$}p|}kqh9Cw7$v*{mH`{oFI|YB{q-~Wr-Ttzg4E6cXV5(AO_~Rzf zep#k0C3NfmZc6N3rTR`_k@zdAfM~selVO!JHR|_oS@{p#jxPoB+ncDkKU|M5>&BO7 zDp}VduT{0Knb>Tu4}yko4GWhP+VmjVUpH`bT*NH?ZEy>3`y0)Ufk5Y)RgxuaY=Ig` zGqkBJr%+N5Twg5)wp#hQ)z%Y5&bu-EJaQ6_Mkeza0yz(P+n_CeM z<8>8^mEr!lV+h$O@Iw34d`Ibu9VF7|rT;}xm>^9fq;3VcWy#8e?1hT%e8Bo^ksj3({ z#D8;GsdcUWdn3$9RoOA1s| zFB2(D1yYu)t>6Q8F|`8G09zZk*7+}o?u9UgJ-Jr3Ja|;TaBQ;X?>_D!J#?36ckFn& z^Mp%Xw*+5JZ~*v1{EVEK01wkHv7ZYfGAmydM;`94JK>mlYwvtwChBaJn6<5i$-<1g zjys+M6Ab0qGzKroqcvk1_pkX{n?e~)Zq``4;Uqu+I}ue{LU;G?{f4>W9$2;e`ye-v zZ8Ny6j)9UBskFwtTG)2lX}%YpYLni@#?y_W+PZ?i|S?gy5)+sWlt?vH%VM5Is!((9#>F&uYONww&H*|5};&)tx ze+ZDErluYnn~tcjf3sM|Ax#_q8-fB*Th<5-o+u;6o03HJ>%2btv0;J^K#rZxJ7Z?+ zE-%+#`-h#&NBwC{u&K2^fEYs22(aV76!8jAIhhBpV&}SF=9iSvEE$cSJ}gnIPLJLU zK>R@8G~i4kX`48@9YQ83RO1cK!Hf$O5uz%@+?2trWDG1NXaE%gT3UaA=3P2W*kOfx zyHc8o3wrWTSoFsbG7n8mZGvDWH8_y#IR0sqcUojhPwAOs&DB<>VAzNSwd0C$B$V;H z5jpXHl9CG)#3MsTjHCwzfuh`bJ{#ypC5+tgEEqU0pgXX_l5krHURWEyVdT}0TL~mY zHDOUyRToqq%GEa+J%}E@J}PQ!qG-ZSrB*@~t;PRSxKIJahBX!p;9KCtCfb81GZ|Vq z$?5HFwiDuCNIJzogT@uz@VHW?!7cPOBs48Iz9%SUrQ7i3g81j6mkG7NzDV^;5mJr? zNZ=Bqp`s^SdAGgqPS6uj00P>6L?|8Pu9&x51pfXlZM=YfAca(33rar{a4U$NWxmiE zBxPi|76b|zI{W*EO~joeN2m=*9TL)?xwRg0bfLFES`&Szeex2}(Yk)sr zN8if^u*s*{V67LoUk$7F|BdU=%V%n0S^e_3P|b94mh;H>Fo+kr?CYQaLKk$+Ddw#W zZ#tgt%RlTr*VWzZ8nmCa(G&fj$?WuTxs_QzP!0Jg|Nk6K|MBYwB8RioM{*H+UST+C zO3ED(Ep4Nxj-C|n+vu!Qu={hs!i0&Zed34+(n4_}yk7SDmIy0V_kRo$65x_%o7 zGD)_&6CA0HFtCYYFjk`(lf1U^{|)38a})Xz;9nAi2EbbyV9E$vgxb;4z%HLDYnfC-98& zil2Y6EGa1vl9w>5vE37f_tTF|bQx0BB#tL%f6GGw(StI7c4rbHV`o=V8#VQ{F&%;{ zDDabzq$a6^W2I~tUJjFZCJ>BXZD@ob93cCj2-s2-p-okQl-TnpUpacSy9(3E(&PVAmSaVn@dxCoe=+&0nO@oH9%DG#!i$*k5ZHP7s8_A-$ zfa%uGrJkaLp00yxJsvuq@AGc9XBZD4bn9c1(JgPQHHV6mIL^jP2J%FrZGd#5Yc8eM*2v%a*G~HlKMwSCePKW1{Mp7O&z=#6OTtg#gT1Ta)*jhoa+h6G>~$Y89e;JTFTV*^V{s(wu7 z%#wN$B%J;gjI0k4V6$t-?}n@Xk3l8o3|VL>uW|&-gBlYBwvMq#Ig&s604frVS9=y| z>&4qKN=RFf8!{YdWk0ymlHqY+qXUw;@3HP98rCqCaHGmCcJ z#S{CWn`FCg&xc1ACWL8z*ME$Dm6t=esSmBLrC;(ukv}}X2*6suO$Rj@iPS8XQGd;w zZWy$h^cWF}FrK0G>4Gy#kpuUTNi6J7jV~-k-Pm>w>VrN?0nn0>93fi#kO_MbWs7SS zdTfd2CS=8Fgz=d@lKzPR83!Kwj;s{2vHx%6Dg{P53A^mHEvnvxbBgn&x_q+rt&VCd zk~$ehLN9OU&YUO44X&-vKNIqg3}8eb&^$e6>5gUOeK}-qT}lH5!lLym*rG@ovMCWU zlx{F#|7^|?u>RL01_T~W4OVfQMEBkJGv)Jx>mR}X?V%Z|r_LGPywgwPON-~vcz74c zsg20UZ^u>U#Vt{IG&BlB97!XNX_|x<{$W$ZU904k0HDRM}r`n+}*IrUGW*dt|tYWb(Em<*Ko8C=>Pq}{QncKkTF>s zB!5n}AST?~?n1%c#F3t1#PJ&1;6~tC+2nL=i@CSU%J%wTMQr<#1i(g`(`@6v0ZT7PzWIEO8L zL7m!EkfkSq6KTZ@af%+mK%L=_KL-Zm@?WYWN-q{b4Ya0tms!tkR|)Mh=hbN<%cG%) zH*{M@Q2sF<0?A|39}PI4kn2mtz7jT;(f1V6j=%Z>nw1otq2FCx>dEbk`ne_lXLQue z(-^;HW9nod6F@GBg9kVSuMY}G*V7DXD!-!qCD+qw;7<$g8yp-I;g6|~CiI^3Af-x{ z!Rp>*vkeG+9XV?E3&}oq@csKSu+?8oT0iH>l>EYk0#R(gq%ZMsjC;AT@#3yx26R6M z)9lL&LtJ^9jVG>o#$m}+VW(az83W~ zy|loL*+jz%UsEdG>8gu>D?R(=MoorR~5ZW0-;cm>Z(o${lJHnd?c^~QvDq;(7T6iKm}(~teY{tjtzmS9I~g-I5E zj2XN2Bp-QX1fm6uo{P8~ledOLM z!s8;&!F0wc@FJNM#P$~#_&JPwwIQgTD8Cfk7QX{y=sm--{43sdcQk{GXgf}aKW?5k z1V6fsAGb3g=Q*4ZI3gBF#AJ?K8)c_;CL{hxp#b00%yjyq;-z}oV5JwqR}_&6zW^MT z2|(t^Pr#%l&HAT9Zj#+JjVhOij?+YyeHP;Yb0PC`c4`KH_EfE*QXk>ioKW7I<4O~a z!)U*3iPKmM94}m*f@0mILILqizRo~mybIUMa=HsvCKz6^gx>YmW{A^GZ^Xv)amjhp zaYr7~?)E=XgFw*EIr4UI8o+)Ob-Ww8r66&2X$Zo&9|&5_Gm*T%rAQctXr-xYbt`#| z_NDxUUzJJF9`5oX-`d4LWFK|gUEMVw8~a%a;prlsPg@&vXQxf|HL;tjmib|- z0{5Np_Pg}Thn4E;!0Xm>*#G+APQG1c%H3~L@PWsc*|bn*eetcrwy~i3r|n8EI?#^4 
ziN@+6ydT*ic>h1^A_mJ8OO9>%mI@{XSl4@h{En#zuSEtdUQ(&FRkrA7x2R_}*f?jEanl9wuH=ipyg~==8P(h^rR@C^bvCpA3Q4P8!=fBa z(((oi;npUz;bTfQ)ADZrtkOY;9IuFYnf{G0;FHEH(2t3_5f{%C*a(h`7y%29i1#4C ztQpFk)`DICiW8T2+w({XzGFLmdo{jpn!i4v@&|jzMDOUXO_o#=ImuFF>D{;W?urK6 z+t3u{Fe@cRF_8TtCI^QP?qCgmfnt%+@YR`hj{KMrtEJ(a|AY zJ@(r=3UV*)yZ;L9<+Mt|30Jaxz{>gC^>cJLY^nH?|D$F8df(pdA!8;oK3X&TP`NN? zmh2K2dpz0+=e2RiJ2`DjPw@~x3qgqf8x+~SrO9#s8lNSEcR{V9!2jlUc^l28*?D!g zb6})K2UO|@dXFi@vqN2C5BZFUK|aKVJM4n5urJ#v+x7vN=_||UAqYAP^^H!tHeTxO z5vz+$#ajF@LyGusGw+43HS#>_9o7Faxz2KNl?cQtd;Ak+-PU)f=^UAkOtqQ!KHdoa z@>@)dB0)jnkqk+D&{CVvj9UjfI`aHarU17KE@-eSvT7h>aeG{4#|X0Z_5F;yM4O+> zROdKxcZ(cRF0Zd^uQwMUh~6gqNa+9y;GT(?VxqI_`QINFG#fd6$OINWX>37U8GI%E z@egfNO9STq9o7-$_fFC%`gk_@!YcnkYRqq&ALoBkTBfP$Y{nNk!?~cSTfEcy^FLC2 z+`J$I3l@!DXvSoP_OAHp0V_>^D>Ojc&MU~pD12+qqt|m8l%FfPKWiAn#$MJ1Xy(;k zX=iy>-Bb}JyT$;E;4*8nT%toLdw;*C_#JPY42H~=;19+n$;g$^{)U(ZF7#*PIJ8v_ z7fGZPPHTyNL0|N(43E`rzxh#8XH5q-3#h`rJ(Y<y0prl%C1mf8%%lcAle* zpW3cLveVPu`QLG_FF2sythEKL3_TSF+5yz7E?-ZqS|x`_8`hNSr%EpLGKhS6m{w)I zN+2-H3JCxDi)3@#8rrZ6c4xUCYTk99HTBlPAIHYK0DU*PT01*j=waMh5n>cG{#7Ut z!{1V6zA%s>)n1kAgdCNEo=ATT6W9IwPQMx{o9=~qoX&-(DG9dWIZ&%-xI+iN?-GzU zwAov2x;nF%u+D_LeI$co@2b4DrFr4P?uWM5K z3mn~WH^A+0HOsH(+u;@v$~CUF7c$$h61ic8X+i|CyU*M5+z2DT5(^-UlF>QKtXHE` zOt!J-B#ZIi{+cKrZ7cp?$BEs}WcoAAk=tueqQ>bZj32*zN@3V|5-88DuG;TT2*j^+ zVs19ecMuTI3w_M4Z@#MAxsE{OR@}#L|Hh(fDmiA&LmGEVcY~Jb{-}gD=6pv%?llnD zqi$9dTBX0~KWrD9+*dnj8orTh{7P66*Xd;8veE5lZaw z1G?!7rzd|FRX1!XV*y@CkiO91QCaJIVRtd~m>h7eYR1wEBCrr+XcOHArit4>969TD zje#m;01<|d0|49(Zy(%EjsgilGb@EG3RnoGkkUGB@mtG*wYBzVG2|k-E=iV!7ewDb z!j)l~#~(+9JgSAS3C!r)B3qdV@!Sb(ft-`4k3XzVlE`UvG&Ixl;ofY7A)C;;*39sp zb{gmMppG%`C1wJc?<)@bQO*pE4-7Bw$uj zR;%1^p*;HM!gjo#Q7|2`&NiLuq?V}Rugb{^^nR@FrRlnJ8LXZytT=nji;0K$?1CL_ zV(H1wx-=^K&vy?~_>S)vNecpn?D<@U@MVx~Wu6b!zSBvnO4x3=u6$koE$xjtw&nO+ zTmSEU<7*;rw|PW@6zX?F4MDr$W}Rc@*ARSwk>$gahESNF zX@2Y#O?IR6i4T|;>4J>O$bwa;#dA#lh)@!72X>o}ex%w0Y9R(>d`tvD@ep*jOE%U;|IrfuVH>qW4JF=~Jn9 zwQItU03QZ}WSVvq*e+odW{+pRH_G=S(cck4vbqFF zN%c$zSvAR)c$1vEe%kc?`@@chjCzC+dr}-%`ZwVGERm*B*0y&8G-8*fenN zUOQ=?J~R-XBFKkE(|ZsAeOupOx3D%GFNAvTiF{i}M<)FTx!8M#TUC2ie2DrNd%rV- z5Aaw{zu^u$eKr&5on*jZho;1gy|l0rEz0BB73vs&nzvW4*%>!^byqIQYwd;@QZN}J%sS|UJc)h? 
zvZ8!|R9T-bdx=zcsMu?fE%YeLSr)GSRhg0xNARXI7UUB{@FRsqcU9gs{PljloyJmb zg|%Q^*;X@&`u&vKQ5#(Xc^J{!F3F`w@LiDXLAt!VWs_O{4c_9dt$4K>@y^C~x7(P3 zr)L|>>($ITLgZJ%?PtQ?Z7lIaU2*9<)OmUh&Y<7#2>qWpqz=zK;Unu2;_2Hu0%4nQ zk|QMSvUYOLa`zn<$?i|sd_Nc~QY1XF;t$zc=*zf8fYXWZfdBsP?(d}l!1B^B$#BHCl}^iQjK=hq<)#kY`+l|~ za%TRhm^i-TI#<{jx4-^mSK!wGf*salMk9s@O{Q>;dK7R9^(@#@X>nsYGM0>o(yi0h z#@CsforJ#t1^O1}oq|BsoY$)9<7V*R{I1xu$wBzmXvQb#u8E%5L0$Fj)IJ7MuW1`v z56W2@iQBu=aQ~Kx$1HO)VSnQPK8t;=ea5vWCemjUmcj?x>59W`5zFm3MDDZc+e`r# zpN0`(btn6I5!t7Y9u(Ied*d;wH)r%jODs??OR>qP!m${naJGyf)n;y8uPh?^vnDY- zOm2~q$_;3r0CUE~rWbWPmps;KUQ-E{gtX89GX7Kg;h;{UCBq(5C9#t64Nrs|i@d`K zxYgukm-6G_GP&Yse%Z%dQi3Jwh}?_Sq3_rFWADsAY|2j1G}2ATgeIx_O#_Pr)|2GF26wm7E0itWKO zSnvZnKuuk}=JeP=F>$Q<)&6!66i@EF#3=}!WFUdOIbJf=*ED~TzjgUyoH0`pUUFDo z1rmZS8YE`3g`kEdeKXfus>Pek9i^6bF17;5=FAW?kOV_mW&S;N}!mrI(NWx6R=_7z4jZ$p+Z0`QvZC;GM%VFKmUa}i7^x{Fresu?v z^=D&}@o~44=<}y4Xo%rITV94nUN%K(MllfMBzf<@oObs-@V+d$GQy7)kWS87=XY8w znm+gQJ$3{qslWb_xTylghP8oTyF)3~QlPjMcZXoXT?!O;cXxLv?oNsZmqJMK$@_gXbIn{c zzaS^iNzPe&@3rscaQQXj>7y?=%iF6AmAg>m`F`_T@izHT>x26=eZ4lclnh_QpDA9- z)>QCUJWCG+N}!wXKK4D1J4yG~SnrB=X1%EE;)`=xlxsKsD7`WwX+ETNmqaoQxYP!} zg2TbdbHWB?7c@+>7UX8MLL|(zm2l==2iGovIBInbGouH`8QeZ^TiiCo-ujFOv^4b5 zs1_*@lP~7x72Mqr{RIxCzLpBf-EU+5-6R&*J_48Is$)pSU{j)es5Ek z)8KfIos#wf5+D~)yrN~3LF-5l8lZ5W z&QRZ>@ptmPKz}?FPXab-Lfw&nxovbdb%mvsz_v+ml9ze+Um9h6uMc{@k}&_h=kk2M zH!%o2&(OI(b%aZif%{kRrFdC#FlEE{C;C-!4142kGY6jSUb_lo&3#{Y)H7Gp0z(M>IFD;H_)`SH}x<@|DpYb)A7 z5R?9s-B9z))CH_OrW5yHjh-F>m$Su&(aXL2u7_red{h7ue^v-Z`v_nAwsyEN+3P=> zyV-(KH4+^+0zIzz93jIBl%iL3U#GLzeTJIf)tq1I)r)M#WUXuBMxfKovJcg?P1b9? zZ^0;3d038#^I-ezOX==Q$1dPH<_ZLhcH zx-W-2L8toX2e;>RZ{j-X!*euCSQ+9%F3V9tfm!@{AP&Wx3+4G9A!f)i%=6zzd<#S- z=A;x*xXbNF17Po=B%OVNO;{Kt8BO~tpzx;mENI<2$fzu2#!PY>G@oAhCUBa+Om5Ps zQQ2};MR0N@iI3bleoe_2oCh;JUkD1(pj?yXromXM1@BVpT20)3%1(hscR3I2-RfFO zmACqi!hwj&?V>UoNvmGA+-8qJShs}2HA|u&_xdT{XDf2ID-1S0h`a+3aHXG+y9C;4 zQ-)_QQWK(4Ys=A;opM_?H#1%!)%UEMuBR&+KZ)PPO`9*8gv7bOi=ZUhr?@OZr=eft z7r^3wshSEHbKfEW?0q~J;LzUT&u3Pixbu$^BD_e})9rRgt1>0$*YiZL zhS<;Y17BZszqvt<>S;jCzJ4|d4MQ@7P)sF;;r=8rqm2Wry z!3)IYJ_(upKylel0v0Z6Ji4C*3rD%riX)%Fj0`>(qNPm#j3?1DHm@v_k3C*vyGHPRBvt6nvz&ivpcAKjTpMz)20Nz~vXS#<@NkQ_p83cyV@gT2`j zz;?#HLtxJh?U)k+s&+aAmc51fJLy^buYr7g|4s~gWmf;)uR_P;&v4xg*DskXK_tn| zuj?FTA|GE~DJXS5NR73Sek|Kx$OYb2^=;Hrun+99^n!BQG|(GN9{9sds58m`S5+2t z04V*xAw@Dsh9@Tc?e$3a!dAOC=>4ov@7MK@M-n#hMR^Yz@L!k78jOp>@gjGYl2t z{t(?vdl!FZ((3$T!F@UxxTt!i$XHPBMfD`-f#S`$>@ij{wi(Wj9Hhixrc`M(VA6=# zzCwrHp3OMb*E>`)gPUB)84^oDjRx`VYs9MB2IrjU+?;A-r{L2d;-l@kRoz@K^J(+j zTDP?otlBn|d>=Z`8Y`bv5B>aGF8`FURleF~8bZy>{c}@Ssq%Wohwqq}DwG;d zvhK8ir0C_5_v8z!ZxOU>)P(G}nwU#Ak0yEK6vx*~<-T2`+6j=8t6QC^#0z3b^f<~( zBjzo$l)Hrb|9e1ae{|9da46u)H0;mvc8b~>$U%_nl)tY_WIr~ASe zoaR8N#Dt*l&XWmydvMQJs9uV%PgRjzIu6$8w>fuJ;Y=zx#%@+?333oI=IyC{!?M*A zy|%Rs6VK}Ph%nqQ^YflueSAz7n77QOE&d-0tUXMFTA=@@K_eSqyZWUF2ng(-pH}AS z?l|WEx8DbWaF>PsCqlEfV=o}Y?KHU3^mw!)J%~L7M!um62-8Q50)oHl`5^F=SS3n2 zsQutD+2UAvX}1sB(_L#fU4kdf0%3-;s#Sz~bP`OI z-KgfWEt~(Pz!Wp0L2ps;P()QrX>54>M)SrklObiu~_L?BZ1$wa&)`=u6uJB>Y zB%fQ}S*s)^A9BM?w4AlJVx7NuT3@q1Jb28u#WpPynw_9gK!Ucq!QE&Uv}_7JMG&1j z4bU24?I-w&ARl(#2>8dCwWgGlyvnE6TjERom;Edevoo zZZ#_U%wWnAb5JSwe#bwd;s?J#8~>d3(%M3`+3UUvA4n-e>myJ@rVQ;t$=)x7V=*cB{S7JSbd+hl@_Mt_A6VgTvX`CYr48*bZQ0_uXV;f3sC+1f*L)a@YA zYeI5QPt|{Sm4rU5aEQAUxMCHpW@9du2tq5lG@Q{f+C`B{@|)v~smEXZD=h1|ElO4Hj*n%|oN8*IanH$p$fP;!`R%1^LVRwISTIuCjS$@Y-;Kh9~ zg>Htp%+ja3+MP1kHWhuZnL%q?7y z?8vZM`Ob8WjmM@U!0kF&QHbA3LU~B(X_0>Z&Eht5HXVEMi*Wm3=cKQZ(r-2Y+dGUh zeUqzyw@b%Wa#JmF!&~&P1#Re!d?IVIMgW+F5xiQkzIW863N@ 
z5rRh42gv1Yr6j(S^-4-E4@%D;UFu5T5NLF|KM;ZDn%V7Q=|vaO5QR#UB*p%071W*6 zHxW1~SUz2gdU6}F4j!@kywLW8J0ss5kf&}h7NGuwmPJb5)p=}sp0w;SKM8_@mP9`$3Y>U7^fvj-GHURg;cbm zP+SGe{mZKF$CO=0x3j`BvkO%>yJ|d$i14gGUG`s5W8wN<`eigf<)+F7gN;Cr5{6;s zZ{01Aq9Y1woODZWA;&$cFH55o5y-6(`@OtbHe~8o|3&CxL)1|UR+3eLJE=$tkoS-J zcAAUMD3mWlbBVl+e<%oal>*;1|MMz8dcx%3y|{(ny+5&4Sbx?JF}7(1z7V?y=9An3 z))OnZa|$bYtm#&41%euXwf5M3Yc}Pu;Du1MpL0CTFb2C_yZ|Yek7US6o0&X?D>(=T zwRoZy?+;3tMpDvfr$cOwmj<^g&hg*%ZDPEXPix4$9}UqZiwF)k>V`D*R0wY0I((zC z0;q5f6rr_U#MysU;D4;7CCTUb$%31zA0IKasJrup7f`6z13X-Jx$|1i=Zd1(ru97j z#QtgI1r6&nm0|FDN9#wY|B=T^ZG6F99Txle!D(N@FOyWyg%P@SGOtwGNKQMa?S)jg zUvvw+_H&yY1UMw(lCjFjmTM@z#k)>hcX2IB2JAry&3sFRtu4e~Z5P~L&i2`wor+}h z|8LcNS0>eqG!_q6m25bU0~5zLVOS*TywZWBAo!ZM1O56<)2&QJC_iZ^BH@IVQnyh8 znBZ1-T?E3xu7o~~oqrW%}-%aW(6lTApYd|M9dbpZYCCE;_VZPt2Eh=z(Z0KxORyFK-<{GnpKxLMzBQ3V$67LDy6y z>d9#8LvuWn;BkI6N$?uGR*!*5pMSH$v^y=s$6!%w_PPNSHDT+V_o-LwHe0$Jx#OCwEj&&5m^ieWmXj;PvQE<`z-xPkZiv3t(Vw2$tij zQ2&a=%SeffpH~SW)^85)bf^_z`-~i|rwZWb4gr#ZY;Hhpy;$%*o3C%xYdKHpCsQ)n zpZsppEdX$IOT13ki6`yQImw212$#UH^g-TlBXtsDp!~<@gytfm;haDa)yQj}S47(qUN|JSif|IHeAJS7d~e0=pg5 z+YH|v^c*4qPV7BLX+*k}5;LmM6IE}h@+56xHdHigDD-gGb1O0~Ty9E{Sv`L3&fcKU zlb|i@a{u*0A;q{Wd}@g~a%hk`MUJ|%9BnAZBJt0qsetu7810ZWBXlsr?eoc<|F1Gm z$}T-jSQmXE3t^6Y8m<-+nUQ1C=rH@o#z8=f7LWdma9KUm(*Aa1xWn`8^XO5KjRXG0 zVMpYzsD2!*2zC)l`PJI}`-{zXkUNTCT^e;9h5r-@LDU+#1BsCU7ZUV9!UA>4vGew^ zGmc(elG{8yP@JQF&8uE?GcI;msZ>{`M37slOx(Y&sJBhHAOjGhPesC9gRaiBwC1t) zTzeRL><=rff}WB^&n+aOg5>92`Xm6Y2(8?AW)&orYkmf7HS!LsLEI-C-@d_8ng=E| zIYc{-xa$5mJi$gHdx{0kx8mB<%_sZ0^4J}5#6KdneN6I=suyT?b&S7+PNxSR0)hW9 zkuwbh0gwjL6-v?O7*376U!kZSs5ClA8jMy2<<}x)j@v^$N229yY_gsTv|5GE*WO>4 zdDw0&yh#;ey6VQQ1-F5B)}&wmF(jdtpX&ecC{>LP@F&ap!Gr28^DCFG>XoLnjXC=s zJ3deFiRHu2-ZmqTnwqLA(*M=JwZ{BUJEQhKh}v(UB9ndiV6_iwi>E4p#I4M}@sS)K zPNsc!xxYw|GHc1m{|<;hgP&pM)M22m%E$Jg5-%)8=3lKVqUR+4UZI&7ZZtrci>T=m z#cX>rc21O|#*1PwJFOjkXEwSpW>r4*DKz#&v~w?o3(GolRk@Uk3dZH~azQ&axw0$F z>=d0Jt0wK}Cb@!Tkdnj;`AVANLRh8F%P}TaTL)&&C-U?y7)T;`U zue25i0vqR_rDF^u>OOk-6fAyyY#|k+CW-t4tG^3=I78sKgC9*$Jh$<41-7KjLMaC4 zs&HAqChBX&hKD0>+1W_@uVYz!Igl_C@`h(9{CjroXBLt(*qBu7;ws5&kB#u_!zvAn z9vqEuln)PmDdC@u7QLmFmHg`J_unhHH6|<~zbMdtVu%>VLg27tukTp>T0kDz)%Jf*j5;6IUD~R7tSY!Dbs+S6c+dIpUu;qJ{LLn(;s%s z;%yhiB`=CDWXujzY98lOJ{;2+(x9GNx_5RL2#!oqxQw_mrLj_90$4uTv1%^Og}t)y zdzSY!pRDvs0)h0fX@Z?;WA;nw_(1*32|{KJZ$AF*00w~!VP2G7NlCjK82b_yDNtSf zt|G}SH2g2kw_)0fN)%q(n6y_VvxpydcrP%95li~G5M_RiKm^$=%+4}tHFnGpMiFja zV+31gox6pJVi~us`03Dx0jd;BdZ|$i>2#^E&}IrPE$zL_?Jn!88hKL*Adi><8i6!v za(p2mL?erN>g|L=QR?^O7)pE6K{)WQ9(QZ?eqjQl_kEP^MSVH(b+-VrADsN#d#QBY z_}pDpq3qsH(UXzMV~*(X>;>ZAsGcz+HKx3t2>z1VB3LhIFVpU3+NjdH%!rlN#YYFU zb(QGf_`}9$XXS0PqlGh48(bnSl7I(FGj_v?Ge{9nlyk~^%LA6!$oTk}cki0ce3986 z#M8t#(X&hmKj=rpc%ndJBJ2iw{ECcKO>OO%&Q7wFl=pOW%4ik3VWG@~DeT)615&qe ztHrm&r)y!C=98o()<30sc|lQx3e)JhtHaVW;HF5zN*Km}+ff!A0TEpo~aLC z@ujakz16Pet+;!$aj6k>j+-lt!x!fObE)Rf#fyxi<9k#*V_%#PI?QIo(A@ExKD%Yp z>j3Tr^VGypLnPy0L9x$%@0V8!Yc4+OWwL7K<^Y9K-_z)z;i|}g-=ro|b>GXo+eSzOsYYSa5NywxYW}z2#3J>+P%kJt z&F*c6#DDO~p-D!sjblUh`&1hKPg5(nrWWj%SSO|e1j7Ymk5wXaO5^LS!8|V_ony^Y z>xa;6xBr|0a0oANpFTa`rvVQnNf8V8qA(AfVr~gtwLb1De-|*5JFrit7U{rD6|{T4 zppf|T?HllMrHsUwYf##ge27F(Wg)-#K`AlxsDS?%^%q2JvV82I2U86PMz>g@c8BZ3 zSwOh+4lH|zc8LnK@fxeI5d;{j2*}g@3g2)M7eL3y!zu7Uvsr0Q-Jou^V<2T}@U$H^~glGd3F6E;gi)ajO+IRQ*KS8NPzB z**;z&40%_fNky(kYm2-kK#_&;*n%d;oXSg%tVDrR?{oGN3E_E42F4drz0aFfN?0-X=|*^fB`W}FS+RT=((X8xRtf;l zAt9L4B7To+!SF~AM1v&d_&J$TTaB(MtK0o8c0S{u<3kK!nF%CvzuzbFk}!w-psY`S z>BE_kI19DemCB0EgIEaP z`ofd??fUH0H^=-iGL`Rkrb45!M-rWYjo7dp##XGRgx=3?Nv1z2pek;-EDYiqC_bM> zqU(&uE=n1KWsl^w{zmF81$q0aXV9V!&4;M`r 
z{ld$@B?6Liva8tApci#OWle>|(vNni0IA1T?*f&{4g=+4b7Z`|aIqsvx%qFO#sq56 zV7y9PK`kp0Y2<(uM~LoPDXP~FlzYg2(X^g5lk&9S1-pb(jnywmT@+BCH;-hqZ8R{M z8Jw)eJJAL*H&oHcd?KqN^_^zZntKgj|816(>(y~pSXqhhFe^D=asaEgwka-uWW$n4 zeZ2bB!G8(ifqqXXv#1xmWg?BZ#Jl0Z8S-RA!xN&L(w3k_7bKbCRLTBh*0TK#iY{Gm zAA|`!_crS5bVNuVEw zFd8gwQVEaBmGI5u5WP0sSlU_%H58xF%#C&K_FkCmXCXB<(>s|+yICDmv9EPsxw9Oc z)3vW3clcpk+&5<$o~yh4#nzm|JxI^%omaEz>fWOB2Qs$5NhC#~p-@p78JYf~MFlac zL77@w1A19^dn!i(h^f8um|9BKw!it3&GorR3qM7{0*Qs!XI`fKt0pmj8#mkQN>-## z(j?xyrl_3$a5Xz5fAsP|nvpXV#N{(uKy`n7qF^?>Z{9um;S+neGgFqxh&E}KbF1zI z4>v9vpN8=8|B8BxNyzz0*qNux#jmv+Xl?n%?}j;)Ze{j7P1=V`Hvtt#Xi(9xhQ9g0 zM}L?ZXmL|}QYy%@Rr~_$wrY)Cf9YO0eW~M}k6NZIFGJBtel~7*Jw$(HE5J-p5d71u z50lLna5xL0e>782<>~cCDeCpb0-O3Tw}ze78TJCZ2IJ-vokQnIAe0FbFqlBFvk!aa z(-Ql6*--oVg0E6wUocyrissTivWkLe^<@A}#MEiBS!>0?%>%BLr)kt^tFb>zWe{Ey zXI3K11Eo&wGxkFZc5Dqv#vOpdT-fBZG%jZ+t`~Rx9M%YvY9dP+39;iP)|KzpSJl`oKATT=G z=R_$kbH->0g}F_8yl;Y8W_n-nl%|Lfme>+Bof%T5=ewP$Wa@lR@OF1myIhe9`i$Y! z*=5~jcz(uIUSC6mh*(CFrga2wI{KDJ?_!sC;kLRoBb%)WL!a5~|BUX(d7$hIGy2gvdnK>lP{HqdR+2%> z5nIF}@xA~NOR+S6w+4B}?GivrLl@!s1Ha~b-|p{a+-Itc5PgJ6;$xyq(7LC|mci%~ z3XB5rNhU=mXf6?iOEnXsOD&~qjarL792q$wr7X{6q;AjEzjvN_fajW6XZqJgJeq(2 zj}A=7e&F@|+NBRdBbK55O50LQs=jHPl*c@n%GO6yc*;Pu#Hc%SY~=2HMnJVpL`S2C z#3D)2_wp(lfP*0`#QB~Eo+`8v)nes7zv+hT%Qskn@1tm;t+A>6z8fd6{LP}dPxA*= z`H9GcgdqCtT5y9BvtI!0u47Qg#7fIJa`=L#%(6%eUpB1;Y;R6%_|1y6647N`4o|QC zi65}ON>%;<`cvpNYtV*$Ra$B+`MCXF0~cpF1J?r)4AaaD3m0 zs9Cg;(B6Jm1WbZy#OM)sWmpQ$+%(t!!PCV!zN{=txWczGxPeNgfW^N9tUc|45Xv&_ zt1MSMI!tNnal&DN-l5W8r@Px9EOjxn(oZ=Q-9zE=C04;YwMHXv(JUh&IA?1NlymkW zUwI=dQ&26pa4_P=w$ZRK8I29oc)~!Cq~KOk^L3@2M;$aM)efn%|p0H zfUfBVTOFGsX7=EPM_n;ImyEO#8osI zWGrF5d8=8s9a-3?9S`$sg(k`onVW38s2OLhD_h(?qKZ_-B;pgGW#|S~^({iB+|HaZ zkAKQDQa2TnavFIRB*K>e8bkfuRrS@5T;b_|p^zm}LPTZij0T*(JEZH)^N`HFb!;`E zwKkQP-8?#yx|$Uh{0!e^jSFLg{qgtdLYac)2t&3ijk+BtWD#?Gv2henUozaFrNGdm z;|u8YxY{2RYExH?(Pfd=_-q4nZrP^4a4ku1PR~T6D0yS%^E;?S7HS};Xw<*mRP0H~ zgJaZg%KY;wHW<-8XBDfX4}C(g?QLkF4i9X-9524 zie8#YsiPDarUeE6IK(Jmh#nP8`Qxki;#}xx-+jt$&xW2Ef$1ynK#uPUt{ecAPUL3J z{Pgd2!nm2OAKn@EJArhjxK38CyT2d>kkm+`TWnE=On6RM87qv>ml^g~9Z`fD_55CO zmY(6joN@(tR{0a(_8>(jyPh>F(aR5sa)8~rPwN+SCi*+C>$$GY+~$PZJqIK1Di!gkZ`_kWO=SMtwLH4qkFEX?o!PmR98}>8ewm@9H)krnkS?{UnUp*SYXzLN)fKw((yU$^N)~Sz0Bz8C{r;BS|==-B0*$tL94L)RYU~;hrc6UsK_ey z&-F4~b4N#(QXu$s5V7JV&8=6P6V08dABGs5YbQ~L+Sn1nqKE?33Y}=~Ek~#!18DG@ z0alq0!+RMUL^C(XLg+T{(x{@89EzE@+1Bx<(3?Jt4(}js?0tT`z!u_o<{14F*G0F{h|D-{OOkiv=o#VL#tjs8;awQ#^{_}g9cKmNK zP`39pD_(u_uhb~b7|GuAy1?#sU1oYy*p~nKt{+vc5teki7I3XyOlv0~YF?h>ycOo{ z-N-s!o;?28{(EPdP@d;g-c%k)A_hT}z#;Nm6rS2uoN>Hhcz*za+L(WMR#e|E>xBcp z)84}{W5_4JItFesunTAE?2%yjsoc^QU5CKB=#ynYQJH47~3=?|p2xPN!VA5W0i zU!j|yVK0j;+Dwj}%SXi^?pA;C?=-mCk#pC^hbTqICayd9xKmJf`@YdFJ#Mb@$(e(l zVi#!tgIMW!8HBeggcz5*2{>yjoBK&J^UA2f05e|zhuacft92=0klPX|2FAcuB=ev7 zRTn;Y{6F$e+dJaCdFwpR+SPuwN{klmNB^+?Ck0?!aA|9JE`$%R#|;e&3nosBDD^r6 zu6|FBT4#I1_B=E}BVr|%SHZfB2ynRRmntTInjz1g+vg)J z{Qrz&8;|Pf?EH-^mV9xSrGvO!RasU+o2W6}f|mCd|7U94#_NytV%x!PS@+mX_3S10cw zjbROA5T(>`8o9*yqgNYi-AK8&vsqN$^Q?mJFQE-U$L1x%f_i3zlarIr(b4Fdt1EGW zR}!YG4wko()#O0u8OacY4-}EIu`^vT(zTE!zd$uI z&q<5mM$O($j4LT>Ym#n6@Yl~H&(sGOD7Gb~{fu5p@#7OG9jzRIZ_*)OMaj}gOpBYU zWc<#EhLPtMeZ65sRrvPtk@Is2G#Y$ArVaoun>#!$uN%-}Pz(+d1nK$e#f#`p{ zA|k(IZ+kuNP;2a(174X1q~4wa-<&Lui~ifAf*UCwTiLMVD=o>ZQnawJENpKl!E&ha zzBfg8o6KUD6so}h*Ss;Aan(O$P%CdQz>hxqx^xww3U@S1SW%19A$%_>VUOE8UL$LD z*FJI=d+TM?omQxx@spdqAZMPMmeaOG=2?h#sR7`V)p_WbWsD3hxZlT~|NP`0x8S=6 z=)1yACf!&T7OPN|*3%>VqH;R_lB9+X+p6*%HVMiYq;slj;1LV8qJ&(w&?oCdZ(Ut# z8jxm@85`QCN>W9der4k!h>Blu*_%BK0mOPduozZPa(+FmwlQ#FIrU$zDm)A&~m`j8{1h_mq9$Dr+@^ydW(7 
zDyHS0)vyXQ9bqz{s%qK3;n)_m+*~Jf%18gJt zj9HoofV@i;E=DP)E!ssZ)oCg7eP?LCHn>(rwe`_0koY8HoYj8#1&18=em4Y7#w)9} z2P{-`nC(Giw;;<-7#NyfZ3J~YGv*G-iwLvNMs-8yi=PAncp5P@1FjHSnc3T%L^B+# zM*a;U8_X)p`s$H6x+h52K>FhEuzP&boP4ufnj#;?2r_Ke$14Ity>9i|u&sfryutUs z@hzI1+A!5E+w!4+S|E_b9e$y?HF#anzdf)F^8v}I)6oUb_7b|cJtl8&|GnL5+>0G| ztv-p;5n1%$nE~q`TAw@=t|AEDTuf{zyS_D`ncoHpRln0 z??)DQl{;xr3c zz~86=cdIlaku-E%Q`N{<1XdcIkghY^d&j9_Ocz)fd^I<1r}Xe)QJVQUaM7Ef`vX8w z2VTfj&@8&!T+<~RtID-ZNlR-2F9_x)<%VPji)>x`0Fyd|%RQQFVfT|kiX8s;J+eZl zBQ-wG+ut)e>kBZ2 zo9|1EWT9(h&Bf$Cxh0#Ucl)lFwR8W^~msk4uF z0Y!1F)pUUq1w-Q@dhe~u=y>U{glj$&%;yFaP0cK)%&d2yPtM=j(p=y6Doznu{~DeD zniIU#%tsGu4*{F^K4zyKOYCZXqssbvTd99@KO8;;7g03pNW&>iD9mtp`~qvOw_%Oh z8jWo}32LUNhTr|v6MJYz?x_nPkg^6T*yR47T~`74#sb95*(o9p?b5u0qhUwDgD^X< zC<+9DEn5xN&9qF9%StD|7z;Vld;gxurBfOC!i{x)085F9dryzchkc#(el zujH9Wwp8reTALXNNBy{bmR))vlq|=KZR-N&lhA7Q+~Gb50@-bmJDMs;&E>1X@yr5{s zF@BD|!3(>GDN8NSE{sKeN6J-;Aj=bkOQJGbezp%kpn0Z`@RZR14csY!VYOZ*U~Q^d zX-tq$2tj6-&UI-?=JYo3|JSjrHg$OTg9H=V6Mssld4r?6uG-|V>j&YP{su2}FSZuh znm?B~lN(U#Xs_1St_=S1 zii0+}uWcBBH(EM_mtl`m54c(k`r9jK#-`$P=hFmI!ud zY)lt=-}Q&W6(*_aY*&3e7-gwfB`U8(1X@2j=9%ipweX)}bOs=h^~_vl63pAycO3OX zH;`LfJ-Ch$Aly+zryjL=Iy$TM9jg8(5U&vIP5mt0Tk=3+j5yKL7*#Cstjm(2){U0{ zZS97F((Wazv%>9-KLLt;#yU?gFHD9P7zwa{!`Dj%@z%DLhD%5-^WON; zFJxZ?(%x~?y*GVv2Py%b4zRuBz%l=Qg<PkhxrcwqGVs#H;Ol=5*O)6_YwluH zmjA3nAN5g@H+FXJ$yL(uGL!$4V>}jEI5@oKzy7z!Sd;f=!!_4rHM<7*dy&vz7;;JO zi`aHlD_b)!B@(+_zu}uFy zMwWG})6m%%58>Ee_c3m}X!%tGs9|~wJ{Gg1s)3woG(ougJ%FLv^n6Cj#wK=Xt0>JY zH65$?JcY;xDt=)rlssyc9Uf6<2n)$eb8}HnO-RVwL;@??4c(2wEm&CYB97ZPJ;bC` zs*eu#wRLX(@>gM}%^f2p?S&0;o#^_!X$`F%SQkEzNdx`zupUDGy6%|BBG#72=34NI z$sf!h@NqMfOnR5YKkRos?Q11VFSHu_E;1`Ge3di|lyD zu1^{V>qO0QJJRqBfJJ~{hakVal`m1&Ffp2pylHfmDfKkX?<=sKK>qd${VSH0Pi$T-w=kETrWKj9aSnk#EA1=xT(ku8 zOO>q;i##2*-`gld?*YH+5<`H6I`e=6r#XPyeEVWu`j%S?pyx{M!TCF z`IF6Pd-#H^GhT<-SDOuaq(~*$JcKwi5{#ruxl=Y(m*0@|wr z8B{WMy};f9$RZ@xJ(zcP{-mMjWM$p@W48uQzLU+%Fi}%)$Ml_a-`QFV)n-=w$*G0t zU%y4FP4bgi;+jpwL_K!IgXG$?AL_+d7MtRAWLQ(RLr3SSgqG@ewaI-scA`Na?_LIv zJ|!4+w8HUk5@#c7jWa)J)S(o)m{D=>>||z`}%)+ zTS_1teIhxL^XQYmaW~k0!lss%1jts)#50Ok2NqW&>$HAgL4gn1O`P*N|A@9kCIs!T z==F^xl2hHF#&DrFUr#Y41$O5SaVYS5%Jn~-?Rc7PRj!nXC;D9`&yA(e%+UqLCt!RAmOXNrNSUa z>CTmK05%h-U)X%7Nx}8X_F>r4wbe*}&ADzl!2Auq;LK(k@-Nlhm*WgW$EPE2=aB7y z;JWIpkJ>)(#K@~6kdq=A8Npp-Ix1eJl}Lno_n~XPJZ!@OXDz=c7b*dKypG!>49yxkh84-d07zAuOf+!#__Wq^ zXLSY!#I3gIm7aIhF_g?268G6<9H(=yef8p@{%Y*|I`atm4Dw6R2xU^2oGS&6GC&29 z7c?Y&m$NXzOnv+dApf&dLnhat&N_bQRo~U8bI2XV%=}PJra0lz+}#rQ+|>!-=2Rs| zNAk;dg;RT>l!;*EISu}xF37iS%8^#xE}|3-Q7e1SqQ{I1SB<9=?~ zCFj&VJ-<>If2gUW6ANE?A?%K~?PC=E2zsPC1hZ5--IA4=`EPdYOfcmD{9oMcHVy5{ zaub;PT4$$11a8kRkY;_J&T^i}yk4H3z$=}GZTz5_Rhid}gu0rJ3uQaX)%}K9@zUq) z%g(<0v$Cf~w}b{a@r@_D8fKi|*o~5NuSPA{KRx|wF6!DK-o9^sU~mC{{c49LyKfEQ zQs?Jh$2hP}ZM=Ul6r7xS12Vr&5WP|fy^wp2TQ__{GtJs;@VEE(&#}9@UtRSuZit7* z|LZaF6tu2?x+ypw$l|jjnERbNOvDM4BSKc;f{fvkSq5zQBOZGuv)r-e;+UnH{ zHo{@Z+7FZt%InJ-aZD|5c1P(4FikDOW4XEhhVP5}mKs%j40QEbgt~fn@ZJ;L(6#`* zco>rm6m!}2G<96-c?no}NRQsJ+6M%#`>(J%+EvB7EK7tpIa7B2v73bAIkhQ1bo>rH zTIa$oN{ImeRc`$ExcRU#ooBgm;|5Z%*}o8Vs_V(oZRfBvtUcS38N0X>Yjf5<`>H5F zc?UkhlKdNex1iH2(6g(4uTHLZ0^J&XCbJ5<(3jjS?x?d1@{hXr?SNnqUOZ)yuSBJ7 zegkxEOC&U1+f{9p+3P%=*m#JSm6H?bi+zZZ4mAKjLTWqq7PB3lPtE-y1WnieA=m2-Tl6mb0SiJB% zKD9TRQjbbM3!lf1v^Z?GuC&tRoyBg({oSqsJ#-P}1huT2Z2H@-hTiHNCkUWS}r zEmse3vvz-{vcVXo`c%lj134vnf1RcfTkP{-y04i=v|!JSjjes`bDBBj(s-2^0n7fR z=U@+n--tqa33y!p0C4b~_VU+xyu5myxCZCB1<0X!rhJKuOQrbK62zz`f$4it)|%3B<`? 
[GIT binary patch payload (base85-encoded binary data, not human-readable) omitted]
zw=8%rm7Y?dwFZWfsJoff+qAKtikeBuhYj~7Sm5rA<-4t04mmLOS6o(O3pF7glYNXH zCpx`-$$jPwkVB&xt&@CM(@k3P?Qvsk+s9aDSLw%|7)2CLWO{C$Xf!WlB7;fKlJu@i znEP{az0pVwHvejPv&;pL!4y^a?Z=qo#gA4N{N~OX&16SZ_q-X&p6@=Ja~bC`EM24P zaRu-hf^Mw#;SMU*sLIZST$fJBkM8;UqRc2l&DftrjzR-W9Xp7*t3T3KPkLSPlqmzlyXLo;yRI=s#qeUK#4&NV@&T00TbdnyyBEQ%z&6v- z6%e~Wsq{;P?FsqE3`hWZ=9M2TW^+R>gf?ox*4yL~3oS_&FOYHEy&=lVsrJQO7)=s3w876`f^|>aR%NhNF3R7Zytp>C9V78yIDDW%`T>7b1($<2QtXj1rgcsTvjav z?K$}nY04=F=hC10A0Qt$+?wB0;WiU^-h^)zwBHiz{4hZKhzrvAXXXfhj_5>N82u8= zGP8y^J`QLt_K?9d*PeXBv}lF7$m-tDXN0}jrsPW?EdjS@+9OIr*rN&nIO7pl8PON3 zgG|2|CY^pfI7^!g!sU_=mDHEMf)HjlAgg!cMS5|ly;5^v6PRI!KpGw{oOp9eNyR#R0!$GaH)GBix+f2zdip zNR zg`WRJmg^=X>T8LnD#(35N@~3@&O^u@ns-kX*z(#i&MaaVcII(X^^5}*%1w&iw%ztD zb)hRCG0K)VIo(hqwqR9`l05U-?CQAInO9)!)YS^7kSoo}>8qIyO!7Hi>KVG=eBSFy z?2|Vse6!gt&j8XGuF*rO$sx4**uU<=SGCoA^Uey4QXIV=0gUd;ef zqu+C6zWIkC&R_ct zEQ*j9PI}Qc#-Hp^^a@75i+g~4&c*^Sz1*TC#0)oBs?nwnsbtv+X4lj999qVaw2*G6 zXaE^JC}Y|b`8cDixQ{mniZWSogv1Ec=jxAFK7VJ=67uS#`seBLBQ0CaFC2e_s4CJo)LgIH( zI?mP;^4xSR;8p9oIAIp+iT?J&6(jk{w(C>e^zy14l&JFhm-VLel6idM#s{4|JWl^% zT>Wj3GQS zs@Sr!zUAIH{AX%&6XQ=(cYD&u`DxR)2POgQ7l*V~JRrw;RiKYuS8z>Q`mH_H^LPTx znYWUp)j$yXD;Fk<-NQ!RE#Zx%++ATu765wz1z~12;_r`6zS(rp^m1EpHv^BkwzP>u zfr=2{Cz?M5n~Y@E*Qhl{#GwP~b5VxP=;NJy{QKKwzliz20B9H(%w^`~vpS^W<9|Nj zq)qF`H(V+Z_uN?$2+IMo7G|9M)r>H-o~;v3%QC8K0}O-eS}!rfb-vCsS31Ikj%=?7 z;7{A* z&(%Pc+Vd5u@l=GQ!ZR2hJ2epJ@zkOn7V{b=%Mcp*@JGsf8+iqXSx=7|K$f(1K}|TF z{c#AxU9N4R!F`IkWL$UOi_WT~lQ>F<*vHtqK-*-09q4_*M4edleMj{E^U!jN^i|m^ z1+VW9^B;8a;%wxhK%PHQ%}c<_?Ca{^`v8V%krsC}$&0E4_p1ih3dc#<9zG|w)WN+6 z1odTRBZGWq;#7atiTYOIv5H@jpYg0ssrJ(selKNcDn5E3`=}r{xE^xZc#*7$0nJ6r z_KoKRc9|2-3Jc8Ilre6?G{(2XYx(pI|Yus7U6h6LM^>x%7Jrvt#s^%dh4DO;ZnJhOzXxVc{_U+y)Dfe?Q zC?6ZQ&3>c+xUqq9EBt+0-rx8%_L{wqxS7OHE$7ssNUl^T69aQ%Bn7~^g_u%V%Ov|faWnq?br4l>tc1Lfmj zKz+Z-VQ&7#)SH5kAWz-FZk#!*=}%5y|kpEC)}Rk7GvVYrjS&G>v&oT-0o zHNOSjI!=*`6ZY|W;jK^W3x`9;tq5O`ODxXh}8Ki`#5A5 zS6iH~DG}wZf-)9VFs= z1{|~zklNyCWOXd*-`GM#Fw|wgmpL;-895>A;I}?r44BPa(CVg?(NuCIF;W6_fYvV4RpmA_v31(S;4(xoSM*A|80z< ze{$E#%^F`8$TCFi;p~4HhT#pbIS}Ur$kTdzqm4}04wZ~6T%hM7>7Jvk3c(cLJRz*X z54}om${J|l`Xx5!jlUo#Fc1bwFa5P_DBsWSOhdysWD0%3m5DbF|MaiX88@IfJp(p& zQXbriZxcFI3h_j%`G*}_rZQ6hqvXkMpBPR)-{t}{0q2(A`vJE-N%U+7otyP&(3*T3 zZM}h5dFJrMt74PBOoy7LUjY}^%8mg<|78-j%I`qQNQ?w3EdX8j?7jUyMAm}8HZ)C5 z)o$~-kS#BMGN`$>zpsYnrX2W_j34=mp!S@fn$RKf>mcti<>3Z}XxJr)4VGtqILk*LUT zTFtdI1)+U8$XSePfHh=Llltd>UT+`C#L|)f?z86KA@w-Uut2wDu%w<1M^NGc5@5)m z1D&r8!e=|eqxAKH$51tYZ;MNvpvdh4aNz{if{>zBY<(18pIM3bPisi7A5hdw<*9Dp z(??}>kroDmw@U!$de4K)v;?vGmiu!|8**=VyD*iEb@r!@v&A*UPm4ZHJ1G3+$CHa} zpNjC)qxI8i134Sp>|xvC-^^BBpqM?reXWQLNwYfugE1d=wp-B5cFUo{79xfL5d+5G zBz3p_OJBH}NE0pv%Xy*9KSakzelQ1T`(f7r;7#|*A+sLHiAz--o4%Nnf3QO1rQg+% z!<+8%65_ei1shI}D1j6}5xJiByP`O>$)I++8F5m6inQq9qam1FPKC*3PPZ3q8(o(2 zXqa3^o7Z;uVKBDtOF(9hj=iS{bt78E93R>UnFc0_@jTul^3_0s>my_T9emc8c6P?>_lgHphTO~HCyXGA1 ze}n!AWsu-`Y}a9=M7(OGf$dv6iu?1EKTiV({?^fc1&_y*ttO}v9fsZ85xSiu`H4wP zPRQojO}y7jsyTm08_R$Qvxm zV`))i804p}e9ThxhgIL&8puxSWv^zZ%A-beVaV5(wKNR?Rc^1*S?>eASsu(YeL zhp(Tx64ToCdeDU&R}ovla+af5e#)_@68lNh$yAEt<89pT6v4aX(9} z`o({(KF38+n;dK(ULxl6;ph&P>mQG1&{43qMDKE`)idzWzA2qJI6)j&$hwTDtf3{v zsRl(Nq`3PV6SFB;$nLzqw_vmH2=@o{>8mbK*JOzU67)cKz2FCysp=O|G>cbwR6Mur zd-J*O@CGiIWKU)nv`EbdwgmQl>XdGy_k zx>h&N0JZx`B05rZaVDsuetDBLxV0LjG82q`szEIi66}1D{Vw5>jMF<12P&TsTD=8Y zUAR2rq42_x?AKUR@B45rR}}p8p&`*zyu~@V@)|$vBZFCR;8sf4zFo=w!lkK2n{WT3 zJwnKVSg>AWb;J#(!rv5=c9PIR+F>{reRBNf2gD)11=yDMT7xmWzVLnYFVOL5dkr-` zA%%dp<1tHO-Zm6GB1Ti|2YhdDC1JQEZ0p2vd*vzNF!#sn``GL;d!fi-Zlu8x$OJh% z14`(>8!2TyEf#<1FRnh6UBVTx+=y67>%ugi302PyfikHA*f#4oCEU3SS!3Eu)tX3A 
zF8W0N@Pj%n%pC*&g>UmBvk}#ftXF+i()Lr(w>%vt&&1S`=*tXEXJu~Y>bM~Bu{Cef zCV9`bTS6@AhU}Y~29tf4`1hyZlXJQNw9*P;)bVCx96ZHuQ+_;#kyDZX!iZ($rX%w| zwxkAU8Bv{@5u~v|&PxFkGQOwzg%wvMapXWb6m=Btd|8xSow!SMzlT59=$v1ckU=)x zb>}93S!;EZl9Ex88Ex1cc7bTpJ8Fs{Nv#HB_9DTLntvDJgY<-+2C1z%zTnTP#va6!Hng?ZJO|JUsU_~(RmJ61E1o28EGw(|uEVqX*iG4)IG=cuTh z4@CK9&dy0M$$fVj35_m(!Y|1hFx{LpknA%V?mcffC65_=-B5}G=t6VUgES%ltLGW9(ucc7+8+I8c zS7CFAtV^+ituHm zKKx%8T%3x4DL!Qu2ajrWbmd^ir5!d7FZ`yNy7Mg=qF=h{Had4(^QcpfE?EChIjH;q zw$Kc>xl8ejB)z{ikt0xEGNhokba6$o`_G z9sZP1VBwp^snM6$%BI286)N3}St`W8`;537ORGInP94zAQP~ee~ zyh%wGg)pa2#w6LKH>bC+85t)c&;bHF*sV}UtvbtS^4VwxU+am`!YzC9ykUIJ-NIKM zFAvw^Ary~&$k1Dz#ZtuzY`D&jZmc};+n+2G3|RYHYzO>w|DT%HCV)S*Lu3NLe6be> z|L2P>CH~ty2LIUO7xI4sT55U3j?~)Kz=SigC5aStdTTf4&=VKoz%!r(jmN84hFR7i za9GH9mcKu!XNT}G#`m-G{2hs#_(w4g%-=NUU>WV^RF3?B)tp}z>S-5<-4|z9 zs+MystF_f5BH+SximB*2=A5eU33Dos3$+h;|F!h57Ll3Oi6qd6HmGfCxr?c0Ni>p+ z5b?I2rs);uZc8XSNicd`x-O0)5t!>kwA(CB9#nVO%BSNaph z>r53vMkcV9P5YSez8g!@QF|xMX?+QA(aO@NwAKEzl$gn``1__+n<0!;VCm4wuCp~U zzV5C-jl@!0^U24n5@O{f63z=OMR9K!dDjR{iCdvs6+j*wiC!aE3xPqt>_@1r?@~b4 zy)2lW?h8d%eTSsWwagI-*&mSfJCxR=(Spt2)tuTKeBjT>(1U)lU}sP@LU2;cdgm6h z8^*4u>m&PZaN4Sinn6nguL*lk8QxpXhGr__h!fNBMaEBPY>EORP1b0Crrc;*9v;uw zP1qZY;VYMW5c^q3x;cU8E5oQo&SsDM)-KDut?{6qIPQrGPUhBA@-06>JBwS2*<+=(=0@BN%ezpX5@%FZ&`g zAM1buFi{A+c#|f2noyuB3!xgU5d@BO*0h_|Z4{&8vW(LNnw>g1c zVZUa{r&!xJWGQX|u!(Y}IF9Nx&*2GDN>VnqiJy8oIrF$<>G%m)4`{`^QM1AO%A2a9 z62v~iAc#~+p$z?v*hL?lB%(5R9uz||$O~IUq<*3Q(CVBseqQ-C%tpHlXt$&HvbLyQ z6Nb2m^ zBY+fh&wzYIweZP1-ez!OXCW(4JXvC-~se9>=T!;G$ET)QzQM5++=$r*3 z>A4REmgeBg{$O30Jdm0nJ53&~V(v8?O6WVnN60@6LbIMkvN((0q>d75W#4?Gj1-lp zB50i-xudF!-}D02Fc<%Q6OT+9l!MBkT#J}-Oni!JD6aY2r7moGikN*t$2L_=(0O%b zjo}zmjGoROAKpb0U7Pb;tPW0zwNGM`;kEr+TjLzF(5ojO;a=|IVM+;ChIS`GXaFsp zFsL3Q`66zDid{c5UIEpC_()KOoXFI1KX`041!OmO>g5!(D^g0(d4xKl5cdf+pAD=!nx+cEO>e=6eAL%Qjbk8# z*Y>Vt1gpv)z|9z^-WZx@u9f3*LKS^1m6MtMnc>C8d@j0y`|p?@KL)XWB&PO(aDLLr z-*_kR+Ym|7Syp2EuQh3*E3MV&_YDkl2!K(uLz+A$b2Hn79K&-muwstO$xD*!Gd}6DdfwkR-W6igBtYd>MB4>x%!PqsqpjqY7n~s5@MIx*>n~Y3XndWPw zyOdv&MFUVm1mq7%_8{CNB6rJdGWe@sd6FWhREH{EeRB}ByoLhT6J=qGPx&fEUszZd zINdJ|Kn`cc>`EzZE$3zfQzXRJXKx&$wg_j`l(qOgXSA@?bvlF} zKJRVT#pERi1KCWvl~)~_B?c$Fs7L7YWb*2q+O*f`IfUx4C5>SQE{10xB)K{tvP!!C zWbtAQpsI4=ZVB(MOoI2+xNwZ}_eiasEh7CoDxXu>V3gV~is3$z&RBO1S1!c!RC!|a zKeCY&n*$kce1$A8F~N4iRz`AvCW6BIrR)e)>qMClnUds(u-;KZ`X{yoRkS9mqr5Z8vFIB40m1#p;Wd#dsO{t67&$xa}4R z*7609a2hmTBj-|>o$}#=x#`V}R0qv!j6~YNSkOgFjF&Y3f?cA%@a~amuu~XuMQ6{J zn?-V;VE|)5r-{?cesipm>kSe$Q+fDBS+9?tdH_~tm5-vsYoIb2+YiaP`>W4a)LaCY z?$*6uhnAE$qyC87kLTkAn~l?v>L}$GxU8VsN2!FTTWnM8=ne65d7K^_rI_cq(&0!3?R{@bj_R^-1a_We+3 z9=sgw(wp}&S>_clPveHthG53@Xm*>Jj?l?F%B6QXUD6W_Q;*24q1+7uCO^_8?|Z^? 
zvhEv9kG!|}kD39ZW=Fh4Nym2zfwYh3PwyWCi*;*56Br+ZtqbuVQGc1)9>xX3OaB%u zCSZ+CKo4@62|x?zLI+K629HraweC5UfahM!WM6^NuDpRD&B;cT+}GsQaFdVX26a@l z^wkSUt{A6Kq)ra|g2qBm^BM;G0`+z5wqjPYzj$n)$g7 z;D!uV%-o0@#8;z-)k}$^Mxfu2ul(*2kOxBzuLsSkXc1fy6byBug1KNhYQ59jYMAdT zhR;~y%OtM_Kj0i2K9dsh1nUSKKpZt26MqU-ZkpEc~Tk~5oW4dCW7Bz>I;vH zZ^L&{xJ<_-`{1RkhQZS^>|}`7P!QGX*jE?Dy6JRnHbpMQ>TMkY#hPyNLZ1ugs>8Zn z$1y*nA&>YIEUULtk{@*CH~W=UELBg z!NX&Oikk?7%MDj9btQ3$xNXPMo=GU-=8MWMgIeUK!!E(lfmkRicyw@c`aS%QT%dI# zB^&KA%sxqT)wY3mpnlM@0Q?Xk(MQWza~4of-34=DjB1%XC$>3LWb4hkC45~t;SVDQ z-9L^!ChNE%xG=-i;uL}NB?;?JcE-ZFHPnwrWRw7``o}J+UX0z-uE@1AUGA8^G1nw4 zg&+`oh#WcRqqn9E$*>zMAxMsYZuB45(0d&r{Q{B3ZhQdJ&FcK!+t_edB@f=UvxArz zgb(^nGt`tC`Dk4wPnek1o(XXd&}=U{X>Q@0acH z2kgqhVx3d4dE<<2?Jl01_72qns^dOn*Nju4Uhf21E5+vB zFhh5LYbh%{pj`6m;`EufW>^qg{6SgiaH2NS>Emb1p=0P#w$`9;M{2zC^VqmX)2KEX zVNx@|C_t5YMhUACjSkt5T}OgBiZ{Jyo_^^@gShlx~blbT!d-c75BG9 z5kXjQhdQX9}!V`~W{S!+a)6LzuF1BB?x%Xki^# ztU~K|ts7-bgz2P9hCyr*Z-Q6?mON7cI?wi{Pm3CG(3JiGu<2!CY$Ym~7kqrebCVbF zof^KNS2GVd1X?r$hTLr#EBk(w$yf3?`YzC>Z*pw4$69%M)YGDUE38@~E z-GeYy91wHcxI|1Sh=7&n9@P24f&J`#!mQ6uMt)0&RpPSDYpn=orFgkJ7t06p(-_4# zln-I?-1~qf<{-V!k8(p$BRy?Tg38a?oFY>n`rhIZ*`N!Y!skWJG%-3)_vIV&?^Lfv zerEUePY4#9S{;W;T2Nv-$Qfd$%BNbbNMgJ@R->W5wVcXn!L9SPTp%cW!^0M{V$~_5 zJjWg!TPP4p(dqf1Nn3jnRmqLCh8eNm%|BcMp{awL*s2C`{0t>o_%rZhPPLgRRJq&i zPkwOpp_O`rtdN&M<_FQj+AW@rR}u|8ShQn1`aJ{RL(4VN)Uh>gaqL>meP17_?(-<&I4kLs_XCf=-rqh@YsspK29IRBY&1&>pRBsO9 zmAikF=Y^Nv=F`g0y;&EgULiMn57-L4jga?!&x{(|ucryu5oX<)m=m}oo2IG!jX0Il zi{C4`hks_zO1=Tz7B3`r1B=V~@6)Sn&N?Z;Ry7dDoG*}#sm}SN#iu9fcn|}}XTWPW zmd;nC>@q?cJ|2;kKPTHFLBF}Wli;LYx&i|yof>D=^puKVUmh@)p;~^A7j!Gwk-n= z>e`)m%cJ5Emv}_|(*YR|XGJDFjPIY=5R~hOce=e2wfrbZuKXDfeEuU|;fI@GHV@Fk z({77fGMd{=x(z>T`rXm*m^WenHVcxlBtNw+ zdTuGRl2jl;sH=>SCvSTStbx?ZS^ULsU$CG`r6hgM6duSV#8$Q&ncoQo*|%>KutFrERV@hO7B0H77+z6izMvGG%(RKQg<=zBY*7@I* ze4nF=hVS{@!a~=TF^nFFm(yS_<3>*eHul3Tc4avyd2_&iwyM6$xw${p>9_3=$xS6- z&sE%XPaPFKEi%UqCY$JQU@L-M3-}!Nzw2MJMRcs2eGO-nX(<~=JS9J1XbMOjlJqZBK^?nfK zk3HHxI6=qLL>9nO9r`}uR#aPHHEA12v};4I>JWUdp}a6DUliF`M6rh~R0spd)iNEs z`h%HxAZMx=GArZtDQW8AoIy_NGeM^f2zV)XG}7b&>}oLONy3)HLxQ!8Ma(&d>$&S? z8gOgr(Iw-I*O8q!_z9*#G6hY)%5$aZBz;h(P)lbWW2y;GhXH3RGfI+2X? 
zy>2$lU}K0m*HjpUuQ?m_Un+8n!K*VY%nt@9qeRM%}L*ageVxJEXCvt^3phi6PWx|c+<88sAZhew3Yr!4P?AT~ zjSgk3;Ar~2KU6V#Wc3>;^J8MO{@Q1Pf|Q|$zU~jnb5P}@1I%o!aJk42b1<0v@S)T0 z{#wfFhFxMOub=$Si;ITnWatQX%d|=K7mhhN!(8lJ=??W-VA$48(A&ky2BPHsOIF>I z7H%RRts+3Z;PmK!$L;qb2hL_(0P@A$SVw+u0|wmSuVOWiewFkjCYBp@%U?e3vw;(9 zx)jl<+8Qp}w}H7!*a-R5v*}Ww3!PEaphNBDjfF2L8rgd@buE<7$p=G##%E9cQ#;~a zgLTP=X_P?&>Ss1dvO3zm1-o1-ZC`=&f8n$tv%f317xznY2A zzp#hwzB7P5ffIt?H)f={)Sy2%;vFSeZ|5YGr}L<~>qmYTj7*e_4Z^>C*5oJ^(MVI; zL7T5`qly~AOrvPPU>=Q#w|%2{ntpAdmNA2vu;6) z=%Z2YFurP!5;LGI8&7h&A9Cv%m5~ody@?1?zHJ&4n^xV9R`?t6D%V_A575Q+miUzS zcBhAFG1O5N{~?IGavhRD=hs6RHRWJ$##+IsJcH4JJGZiUaykDmi+boKQVe~?-<&gX!1rT_&HlldcITJ+5M)U zid$G@CnMHl*@FXWH51BJkb~7s{{|_6`pj#=8hnk>M>GcT62U z{_0;f+9>v4(iT`kacs@#DQj3+3){SWZhm;>9tJCz^&gh$OyEyV_e4cV@pR86l`nlP zZa8%bQyz&&;>q6|2!H1t8$tcLH0D{u))V)_p6fBd$@3axUl-@;V$xWLvi+VeZthz^ zY@@TmMKD(oPcN(xv;tZ_;R^U9;O?qbip@IiX#N&$B)@LDr)vCDoUzn4R zI6P4%zRgr4^$gO}>GnmW9>-!+pD3xANbb*jVMtjyK(D^jCR3_@fiGPW3oi6LBPEL> z+|)7)(B+#Q;f_+BiM)UJJvb)}6TkT=^vt2)1&|>I25jV7b0f!N$_S%l0?RNLqu%{f zB?E1V|IF`|`vT9~9qkh6Sk${9HT;v8NT|=@8j-G~WMF(z9c844QO8WWY4LtmfU(1y zh{mXPxbd=eu&E)#Q|WwIShFit_n#+B7UYfOnAf$+F(W10Ay&MW#4EuV8rx1rME{qo zr6NUN;7{M`ViX{z`2V1L3m%hYiInvxb1K|#?#7IDR6UPX5B$}a71;2l{5pd-Or~z1 zC?^LcR;B+LuGBjD@uQupk)vh}F4q_pk1n+D(m!ISb;V|>fI#^{pnW{j1gJ9t{kT(9EMLyezaMC$(Ag9 zl_C%qwuU}og?H|)9Yr#Kq)JtI6*cV?RSc>S3~l~mW3Kz?hpBRqXn4tNK8Z!s`-f*< zp`cKu?-6(YNR6u3U3_Qa18pUQfs$t{CWjWYJW1~j_Qns?jSz;g#dGQXgfR9e^8($( zO;>3g-49%1Oy`y4X_T^x3&YdTu?NDDB{9+(9mUc8Us%f5wv45!#^hNBN0V683I%L)V_zZb9nuP+l=pv>Q~`*8 zZI6+TsxR^kIr>oBEfj$;hjb$E&nuX%JwIO}e=FE?M~ur_*>hMks%6|fU2xM`oMrcn zQ?JptVCQkY{g1>VujC2 z`vf>y!0cMR#!yK<2qPk;Nc1>*+u1oc49Ct`-jG&B%=+9*HY!s_Sjt12zrIS9y5*AQ z`?@tF;oAE+9Z+Y#&68W!y{69{gb()cu!T?sUzwDvl<;XZo9&QQ^!%eQSpq%E>@)R7 z_9|4(WcgPW4B)9ud*PT6H*0J%Xag3HQz|y-b@F9cQh8g%6uzbJg0Q!|#m!ku#X7FZ zswI7@AJeqgzLzmO5zs|+Rsg?$JiOg%)V9)^pZCF1-*UgcWx_q;VA-|J4?dF&(7!eN zEbjjfReRoTwoWg`E3I7mPAhbg-?NN%$s-8~ljEM)N|J2Ov1L)C%^w6jU(e>uZ?)5uKIeRwE$IhoZt@QED z^j4zkO0q(F9I5!&6M>~ zfmsX>%pYXd_1)@XTOkuoowX?vTUjT9b3`T*uJi$9Kf2}=MVDR*`;2g2J&IaS@%NLT zupC4gIAUI6SW|ZrbCar`IG)bG>G+YOmRgVjWHm$u-MXIRG3FvcTMB(fB3~b3q(L>J zp*+;>D`h^Ig4tUm-^LD|@5$?~0(=L#rkIw4KGX(eD(=fVn!rwp`RUKRQzaFBtvYVf znP{h4lc0oX*L;o~$pe{J2>r9djqtbD8(4JajGCj8d^->_KTY_44wh9z=y1>tkI&L7 zwBx3f>Sn^reQmtTbl+^6Ii-8OBZ#{sgC~*=yo0M(zd(TucPqNWjbFdQY1k&v_U<%H z`>B5xh!_&8gt%+hmA;)fX#JbOLPfSpeuv6z?SEe_R#TyZPoubfo3{<7`Z z$Y+YRIzSwEvAJ3htI0sK@;=Tul{XGD*%EgJBXud_^<(RsH|tki72J7*mY-Ow{@Sy4 zo%p#=GDjosJ+8TT(r#J%XVk5n8_&i$ROAYW^W z|8e5lxsOPE_afXC#vSuq&04fd!GwK-Cl!WAvZ@3cYA-A3AEZS37c^Q!ui70ml?3!;OWE2yLkggn!Y<97&tTV z)NQBJ+lj>_buBP+ZKdu?RcA)_E$7sCYT+{b8boNBe)atAFWI+C_l~fe5L@r);(sru zr{CgA=_)jbj=LipPZZ3M)hTPuk!$sGe~JrR|Dt$xfe?c=vOQ2-ppGTafMk1sY@UXg zqEDiZ7cQ!{Yu_{uomRe_ZHmRYd#-NoVKh)aQboqc14H4~DoLvqQ~Ix-)Ue;p5Q;!Y zf~2VyugT)j#pk=XtTL6t%Qi>KdK)sk(%xFH z!&`~h%jp|d%hUIB?ntN>+18tkc?^MFo%JJIDHVsboo!2{2j!?m%`g@}Q#r8WhP-iq z&#EXeen_?Far2i8VrG1Kk_4D8+(8xk;sx}0m=aq{qz!BL^4*rT?2W$fDj^Y`m~dNC zpX-||;=JuXWVg5n_3D48l-<&-UcUzgAV;UjZc%TVxcv`pZvho$*S!y`2q;p5bc2L| zAl)s3bR#XHG$@@jB1lQMGy>8fEjdbeH^M004MR-)Z}cgj_xHT(dDr^BwbTXU%suDq zeVuE^IcIOj@e75;aIN9VAp!UuuBBVLmErl+g(uL5d7Lum9rZ9;sV|$sSmb!RWw%*9 zX9%;u5r4#WW_7snCB9~Nq8DnMbuZ(68Bm_$#!1t&YD7kk!>UQgSS)- zgtCa)D0b)-Pg%r?p2fl67VwrU+sleo4`aH*lSW7%ZDeb2G_4JGJLJJ*!Gr0vv-_0X zv-52^d-q-(ZWuJzpR#xM5!kki+bIF}*W%>Tpi;iZaszsI zLth%=A;Q10RwPH5`MPkYTEQL`-SEIJXs^Ft+mOyixQJC^q5&=~mV@N5QyuV(0==2T z)yaeKO?^O5VuHMyf?R@6~!6;O$Up~J*1bl{2Qle zUC=2mSmr+JhPUKq;##Sw%bPI=%itRM1QLA-3thxZ347XP8XaZbsV2aU^zngBV<>D* 
zU?ijC4Vtn-^O^5bSm}R9@5XpvURdL!lde`15}wyKrPZVy*UwWJ(H=jP3UWXg6q;qW zY*2|s2<6tzbRMnP9BvqE0>vFK43vPT(~q$qp1=86TY5{(zqHwp7tflTOHF6-Wjv_a z@=sSVN5E+8rz<#9oqKpLTOO<2I^q6sHGcX9={O^oG`5A)^^fA2k?nM}TDsU;)4^3M z6hF&6s&(#VVWk_uFh-KQ%L~IYN(BnQk3Ee7LAl#Pp!)PLW12o2RVwrxdn!*y(no0;?~L-UoZ7aLzq_PvuW;P^b9{jdDV>=_hig($AsHwj*6AdrwE5GwO- zG;tsFDwu_7btI0q)~?mCySSGSvVzfpi5PDynrWwqmR_5KMVoAXN6r~($^j{PBG-mI z?-M=<3+I82$z|yeEs2TEnpX=%+SNta4rbAss8YdtRb)!D3191M$(S|i8w`yf_O5@Q!wF%7L@v~06)r2)>()YcYmC+d9d`1< zN5`Ss>dLB&>C`rm^<{*MdK1v~I&pPa+oE>2{cFLkZx$Ovc}^l-cdyB0LuI71@3uv! zXGac}c|JOZt@-Nvo`q9~Uc^l8(k5_BUzrr2+{F*z@RllsE?GTTI>Yv)8E?$4&(^^> zkD*3@-Cugimq|c|Tcezu(5F*OTA%FRcw#W(20Z{WkxZua%1cL^LSFClT#gQvYQ=K0 z9n1)OkJ=^gVMdQC+~OP?(LRfq%SsMJim;+o7G7IE7@IesM9=CQrT%yjP~9utm;Z_G zIK*r4;o#^YPXPnU2yg${v+7xpe47BRy&v4uWoNty+PlsqR{iOI-t80v90*`K{rh~r zbJbt#-P{dMH?&b9Agy*uJ_yld?xzaJ z?3e`V`j^^=k}9yXf~9LCSViCG=r9yo2xibVZmUK#hLN48lHY+adhgRFup2`;MOf_4 z6oajE%9T|IzqVGo%~Buj%{q)`b!S$}vsg<%&Db)VDd@s1m2oQ1QOQ%@k_15pR(JN( zwAWe1?2Wq?VbRpHHfPTH?dv8!%e66}N$0e+`$kQRM1XR?c%ZUc0(rmYJ#L7O>un=+4NNIBFTp^1z4}^x6k-quOxjp0-G6#KBN`>>2jrAShR`LN+~5?E6)= zPN5_nrZe?87~h?K9@Y^J?~~H`*{F({6zK5ObAk8pz@u`&0*GBD#FDS#P&GKiE}S$N zpv}sc+8P>}tQoO?BAzf7&6550;@*K_EOX&t&dFAw6OmquNjvGF+7rEF`#viN9`smw zi(C)lYRjD+ugS?S%T!Id&Uq`u!ahr(&1UWPUHEe0TTAnHgx5Cb?uNIUA5vgV2;}aW zy~NE;5&6_ei6FIIz3}#4aPgO!lN3ZMn`TH^s6P2*RMRq*5*atOTa|y2g3t!_O!f0S z9nYx=--oB~GDl|9KpFsU*u{_)9#f(E(lh?}1~On#_aZ z4Ma#*7c@)2lMbcq{_=v@RcTQ!_XJipENafx@nGJi53dzrWjllnw2Ij40kXBZBn8H%39(0PhB%( zrR3s(jk=EP5qF(m>R8F!HfDZL1==PR-4-cRwBS#ZCBHXEf$JLVhe3Un9dFIt?Z$h8 z^-w>Jlv1{lO}sc`(|Dl6T%KE`;D91E?8_qk$Yvzn^M2N(k(vsAx))e1Ej=}xU21KJ zx=#Up%5o3D)^zlNoU(iQ+Bdj@mv4}#N30jf1ss^XLqbTjkXVJTwdvWoD~#pII@GX4 zS%4?CH%}0+4(Mx7o&nVrYwi8CIt7Yw9V#~Gx%7?l;F&(dO+g7&9fzT(IoOk-eq#sR z)%({fk_YpH{d&9*JAH^1dT##oww_ZaY_j@Kj@xq|uPaUMahPr1J@#}^zx61bt-Ym@ zLEcf@7L$WM`>j^17LGFsGtUKeOC;PX4Z>4B}| zyh%#rRE2%MEI%r2Xh!Jx!$^x(l$G^RXCNq+mQr39!|ztlUs<}FFpH?)^59iM~; z?~gw*REV(joCzzJcNBj6X~FV=*;6G zp@2Ig8cyDv1W|!m&0YlC?^tOhJN2}Z*uwIQth_c zFF-JPb3X7c` z^l>&BNf+O{K6q>U7zb`2S){$Q>xwX&=O|aPa&S+_av)E}^HS zLg@rU-GL?nG;`zFAf4&Z`6cn=@lGOWE^O#F^#LyprB%wpS^z=7x$x(<7BV!%q`AgH zUf26<8m$GCk6Lr@;RwD;XtrkDI(Nt2-1G?gZQEHl>=3@ZX!yldOZ++Og5VxflTC+O7ks?WT9HS_+jBH3E!a&xS%s>QDH-#wRkCdcMh+qa*s} zUaO%gP+XCx!E`VABh`LNDH4FOWwhZatgfQ1kWjufHt+0dj0Q6mQGrL0@z8e$@VA+e zpb*T=vUN-i?>f0>?4*&e3+;g#hAvH+HbL`-_A33o^43EcMY$fb@!?)(nVABmtxL2| zn<24sK>Z=feKU3?bF*K8tAHTG$-ARM>$Ku#Ea86?NSF^j_#o)b)#LyH(yn}3ivmGc zD-q||!Z*}I?sBQohq6*ceSM+F*@To{Ct4#FZ~0c>shpa~9(f@-^nzfYXcnJL924Yk z**eL1W8%d4Kyh+yk)Q26T^m1ED7(Ih)js?%Jnuh}b!i2NlV3!4s+LI5h4^*p z8H2AD99te|3dc=*XaH9f3RpIe8+as0SnZSOf_VoRd>60wy%R)_s2O-r4b){b$G1z% zYhmNP+R9`}Yq8`c=!xrB5?hwU@}SnOLlP#~*DdjCDJXx=BJ7QV9`AKvsHORh4RkCc z<{+IVXsg=Nx|rk;fBW#hY?>J@#5F1I6W+W2j&j@cH=;B+vgy>$;HcFGx~@?#E1{sU znMk1Ag-!MaJGBos8C$xn*WA@K?X?)vt|yBW3e2_VdPJBnD4mKj_-B02@j?nSI`*<| zeWL+3UrMt+zNDqY-^N}-8$@@!Y1w$-E+N2r48^~*F7~-AV-COV)QC}e1@eFuIt{$W zB}P9eBNLqma}4S6rg%-?Bu=Y3icdXOqk&8fq*op68lu$JQC?Lckp`lr8Aor;z7TD< za*vvoa7=dxrN~Nod(nub7M+Iy{edQ@N~eODEukwt#dA;Bxvu(8Qh7Y+=)#jmAoSEc z-DK2qI^pnq9}?gP*rC+6j$8j`Px(`_5Wkv}DTk3`A3BT?A=&&D#&*N|?XhqDuFZdn z@WZ^SIkWVzwG0of0Em+G=lsNvgDA1_oci(J`i0?ZY^c}@7Pwp+kEYl1Hli;<-qqL` zplv!C#pN0*!$o7dI3|06|DFKwOISCe4Z)A2jlr2mSge3#HzuPD-HmU<{4$Z(7ZmyF zyIQDL+{8at?Z19g!rW(v_WeQQ?ah<{25lt8ir*Vy0stMOL}+4>OZB%xKS9`=fIJ6U z&n*IZS6nZ7>?)a!0PTg|Ea7>eX{7wF)Vj+(j6#vz;0%5DMj7m!6YOr+yy)Q<()Tb} zIWG~0YR3=!Cig@fjFU>`+m=wXGBMV1@ug>;Cc&*4!2RpGYD}g3so(Ia-=3Xi(1wIb zd+KAiKInt+xe$A~q#9yZk-V#7kz^EaeE{>X?8-mOcQy^u!M3o-X!no&Vb8ZG$x;}z 
zsv5_v7AeR@R&>O&D-WZ9BCs3x?oN`eiX(!EgSJ|pdP_WUJ9TY*>ack85U49&7~g40 z8fZ$oVj8fLUtt53~K zo*q~;Lxk@Fef8QIN={x`H7vn)ZsG=F)~y7!OFp{cG8WMG@(p2djoW4;N3@nVR(Ew? zlul)t4E7NU_o~16_-Zs7e!RRccH9$pDL}9Cg-w~3r&`fpG!d*ZHiN7S;`n1!GSb!z zk30lJL(DW4Rg z0gZ28i@~+u_A@=n^)!Y7hnu@wf`t_gNoJlDInVxEF zc?aK<2)%o@OBQIF6!dzmqjw5FCFm)qE%t-zR>wwNGikV7mlbb>QR@RGN5Sx;WIUG+ zrx@UjCQZ*L)Ie8fDA`Rir8JhB?vmyIe zsUK82`QI}Y!`(3#Fqd8Cia!M~h21I|7BwhX=~cUDuE?)^4U3KcWcE_yac?ad!Dm3D z^E%8BP5A)ULOQuY_{F4hYMi0!W*Qw8T$ci+B;DGRP=zY zHrP!SNr8Z7$bpvm5#ch0#%g`zF^{AO&>qv%jBuWA*F``nuV8>?x>0#d7=>qAJg8rL z`^gsyWqnVQeh~Gf|8-@6- z6>EFKmy#KPDnr$l^2I`aN9Wld+Oj}BI(JKCk#xuE&AZw(Qa=(w@oKf^#)sj}(}T zHqhHiV`fxGSl2@Rd|-y1Wq50w6!M3(Bxt!hVEVH|tsN-UpWI8(&VkVxAfwMgDxWVt z{j!O{xL^IQ4vn0`cqh=6aYij1<&5vKMCglzfiqBkumVCDHnE(CG$GSJ|4D z)GhzKgJ*zRHQIbX3YPlA$q!A?S;2KgX$FJZ@$Ga)2YPe~v447{7TQRirSO{3zM?6n z1EBi|jt$QNQ1iCQp$eAiQcr-!sH2&ZMit_Exo+Bmvw9=+tsO4M<`Ue;W2=Z9Yk zCi&EK$#M0sFE9hoa)E*SjTORNI_`zAu$FPx4=#L>0~&3uJr9!wWi|nwVqQmg2=Kb~ zmHJdHrnp2P!G_m7L`U7o!*crt$v!-oCql?*Z9Zti!143`54?zMYB zGm!8%7b7Y!|7de5pD)ZRUo2dVkKLx5w)UxXkn654Xp9cyJOk6u1pIywHHYW1+swfl ztAl+<*Yy!&Ni(6YtG@ZEOQ;_=tBJJF_+weaKrVhQjJkB1Ce^K^Yw*8IV)TGC6%fj0VfsUFpJZ-vEvNqdvz~S7-Wze+C z>0e6V0<4#Vrz&h_jGmyD|qgA^$9(swb-UM35ts$?5$mz*Ywlz6ph_wzHC zQfsPk&btMeRLAw74OK6Lw8`rGoe!`NpGrHQlNkm39kBAl1Xto#Y--o~$Diwn{tz1* z37Y3_9mf%N{W->nXji*rw>HgNu5$Ri*q1MRw3G0Uq+K`YfUPjvtc6VJ2l%=#ix-8X z4PCFe+LtlZz)53R#U;ZZVIM*9uAw5`&_=$4cQG4}+9a!kO9TrNjlo$*2R$O%Lt0-( zdrcl1w!i#MbWS}$zTvIR^Rv<5BEhWSi(dM&S+dZ2Ur1(sk{&V)*njr1Et|a-n#X|= zDx-f^^oMacz8k4Jo2Zk(2%LHCF_+I5Etd0HHqDu!1X`-A!xt)I03&h9!IytvL7zgi z!%A4ov*>XRk zCM`#Zi;oNctk-0{Bf*?B(DOl-=fmXL#7J}?Ez(%h=@whMn$X)OxeU>^LT912H+;{A zaq!ib3U#6R4Yj?w1fjX2K@Na|m-xfqrCJKlq;2+rNMqkI0Oc%tS&E*+@MC(QGRoG) z_-Ke)w9lwo)eO|}y@!mz&ENQO`Va6=^i4>~_2qb1a`70_HP5yL8tt}cq~xsK1!6azyY{65f_=X>Qi9_a%e(SWuC7?ByR z`7+sk;a*m-;rA;aP=l;sY_8?bJCL7X_EZCKSpsLg6t^biGIcONzM!{w)(er3gCq?2};!(8Ems(z@3vRIVbZ8XLN zFl|TMFLx;dK2@N-KCMu_G%2lAlgC2AIl#U~+=$wJ)(+Hk-XQAq4&bha`q|)Xx+wvR zS!R8mTpaMcPejdnBUl!=!^*P`@&NgYoqdj%cs4h%_Q8RIq~~oIcO7Yd(~np^MH^F~ z;29^ywirtqN4w;+3!?qSUU8tm)$qgDcS$$y*1)gC6p3q-pdBB~0aN?2kn1Gc5~!q| zrc)HShGZ;@ir6XdiR2JO3I?}Tz~X=#Aj^7o)9M!S+Zuwe032;`wZ&0s!+(BuMG7dv zWBm;9ctA5MmzRU~+2;nIAe{hm$heg(r^M!M%tf5z#!+jh8M)15Gx9Ag#Bzc!6h~4nmvUnli*H3XVP0d;z1gpqTP(}*d?CWd+R=d0W1&TR-XX0JzV`)>pb{VUpT}?8!9o-PQEjj6R5voNK>vEas z;RE*iBYbW}8X{;whflpZq0wVlDDZHQ2-YS5a*mI15FmHe=PjC>h=Ed4~gQrNkEv7{)`W3Jo6ss1$aAohzDNQ zlm_q2yNgjcL5Mc|BzIR4yovYfND!DFyA5pSWFJ8__LtE*BS*W-OG`vSS z14<~zd*QptL0(v?u9wV``tWgX8C|_a27ObrIA;AX>36Y<2|On?sUN%g2!!_Z zs2LMP)+iHsnCY^SZ#b!|lSihG@Mn*^_Tcp-9*kIs)>qz)+_oclgem5%zdlM&BeDn% zI2vINIBY0pCmhbj#Me9zp&ri7pe;A!1db0$AoqeChbHtR4q_h0B`gCM9hRy8V<L zx-Ys|=L&gLvWy(CqO{I8WGsq`!ih%fDcL#_e1iK%(g3BMr|~v>7C=28K_WRlZ=F3m z&SwdcQ`Eej?$M|XS2UM67ctu1_5EoyLDVNVA+9RGo{)ECezDeR!-dVL%Xx42q!RML z=^ex;b{M`4XjuYJ!~31V4C_pJ$GHy|@z*N#WRe_Nk(-Q0wJ9@oGc6UY(}wmu3ySmW zR(SWQt?yWU5wGsNQ&uWS97_WYzq*SeLv1e7prg!pRbO~-3s=~Ns%x{hXS-Pc$mAgY zQUK)j(&@Sja_IUeHv+@WZ_%B@NCEADu?}*;$)NE*rsIn0^@e6SbP{VNSf$Ax9G$=A zS3cvRxDdzAeq`zC{jt|%mB-Dfc9#Uv&v4J-r77jf7V`c;nwTHe?1s9&-4?&B!rL~< zAiVANee;WBH%yJUYzv8&W-{F%Q~9+DKaA%@G(bfvDDphxbkARm<^IH6y*3E^`+@<9 zA9dc7Knbo(eTfqN1odRf)z-KZX_+w{dbq0ct*}J1cc9nZ^qn)Dar4T<&}T4EfZ&$shh91Jvga3M2#=3M_R_J7rf?9 z*Y^@@{lk@zs@)p}qSDEXFKnC=1!kE!nT7OJ_9+|crRuGGlgm#xDI7z2j*z>?zQ{r2 z&%V1RLr5e@6t=_^7cURt;*Z-QY~Ww$uQmR>qP%RzZ9*Mnx&v%q&6Bwun)TGm`gBY`+#XujC76-8)? 
z^n{)F{ca$Sf>t`NHJ)`GdYrtoXJPX^wL5MOBeGsTqwr2|JsjNTJxF3lUuJwf)mT74 z!kJ274pAgg&2B0g`NnnkhpW=(e`Q)BMPG4$B8Br?cW1bO+@tlasrgs4wPiO`)=$e- z+ap3dYWe}VgE;aLGgF4I2|MwE-L_Zyi>At(nkToJ+GmfgDwBe~ju?+W^OoOF_pCQM zy~@u|vKbQKk_CnA!Bw`V5uzK`Fur)fCci&R)hKHdZ`eBTwJoN+*jS>2@3lhrCm~28 zIo#dW>`4CXQYM48=&n12->Q@+9f504w^mgbJHbda2_(X?5FUHlDY16}1;33H|H{Dn z)KVW|>1-f0@4Kq)gE&q+T{n^VKG9#~aU7pFG-6mB-20&8LrFEZW78QyGyH%DzG;GL znQIW(Kert;Ggl0u-{E|tKgYxsUf}?xVc}#^-I=eWJ6J2?lxUrGY+{e-kLlDlcQ!zu z_lEKL9QNy;eh;GAtA!wNQ1DfWmy74m8sQ^^W>E~N@GLcPHIQ10+?foF7KwyGzAuIN zd-&)Ub;;TLIDsV29#W`e2-ByB98|Mi{WNnGX+ez1#r zUICp+upWnra1ad`ej5y5+%1PK)_!zV#IY%g-$zjul=bhJ6;Hhv*0C!dBz#{0i0FHU z1dgiD9(`}c zoBzBEMul4?mYFo%<-d`6qE8VCzdBnx{(V$+z7+;VJr>2YrjsWe9*fHRX%i=!Z9o0 z$7;2^A`;&mytnbpazT!Fc{f~~Rw46ysSz=6EpZ(B$lt_P%RDS1?-vXB~>|YhwtAwY5enA=)7>INCjwekHxlI2y>OX$v{!jdX;DVsj@0jhb|DdiM*l^-zv2Lk;0gj*F&)#Dmj8PA zBB%c>N!MSH;0|8}Bbgnb&HbGLTz{bD*B`q6Y#seO5yH<7{|g#^B|@sZH_id)B=>(> zS}Xl2|N0AvX3SeSjOyF|?F{C={PR2Wem?omfiH7@a^I!?X9fIf&;J)y5nS^{Az*^88+ZQw)2Rqg&Q>%2 z{}RnY+s%J1noBcVRMx!kZ_Vru?d|`~j$cB#Gm@rL0=bX(?@b{<3X>+_f6bZy1LIu6 zcsN_`-@_OQkn6(#;)z^}NCI`H`oDRbKBo`<6@FUqB;Z4~fPDres=oivzW85Un8{M3 zO(NYgVz5ADoymXl^r)Zz#pDf^ZDI-2 zhanN88OX7H=s&Sh*q>#80y047iQh6%@Y@pby~v1b3Mj;4q=cK{?h-TnbW}C@RDuPo zhWsw|IWXf8AlQ5iTtJsuSFhK9IRZr?(U51Vi_jp`F|`B;L?UXr4CXp}s%ET~zP1ug zXYVH(DKdL!aRNx%&54fm7lC^YdWKF^6Txu*gL%QhzTfCk2^!ULuby8lT$5VdU@ofs z{fuB&k|z_5h|O(td4>qY&zXa8h3zT2-MGg%=ZPc|A*sL~>#Fp;@kThh_pn3HYvJTjrd!qld2L6Ca6<`N%*NMd|ws(83_cU<{}g#BitAw zwUdjBk@IbUo*3qr=jRMO4#QW=pE<|VbF`AC0C};!FBc8_fae-=0Ba|K?VT8Kn%jDZ z|4b=}yK3kG*))UMqWzN;&hKvrcSA+bS5nIMzF&C%*r}6p_1W!oT!2APHQ9%Zt4C?G z0v|15j==MC7X$*O&m_;J?5UR{Yy9&r(nP!|A->O3E>j6MsC3jeZmI-Wf%FQH8p6{C z!S`u^kf+MOiDoYYb!JH^TtyPOYa^wo8NDF%JE+b)fi%#yvZjN(4u+60FyhuJ@`uXtrjI74$Q_Jj|2^ZHkqlh-)29NI?4%aHPoA z#i_k$eYVu?Nqjelver(TX20ELgg(>+6a-%$-2Vy$rl@87?d9~<7q9{4;dJZO)n>$r zt+>!M{ z_>E+*o^A6O0Ej3vpPiVwaH0uRmth4Lcx-g@t0UcZtbAyDGa4Nx-PCJ$F@al9DkuJU z#+|XMcKj@nRS@vOzK+kywEgh&G(pcB_9;C$3VF;d_iYZ{AV>rFaT%D1KD`#|^>`_z zrhhcIkRAuewa#O)UpGqM`O3QM#j$+RY-DPO>V8K<&^{1ZEZ~7paR((}EQPEi;F`uC z$AN&*9Xp|cxjzeg>kc~%;E{tlKho=wm{ypzw z7gjCgUGobmcJSpHuv^j0%LU1z3b6-DjK;cZ@R;}48xa$4Kl!Y)Q_oHZjF@>bw*ptg zb0Sobr{!pZzURR&0^UBp7Jo6+^ya(TyYS=o3^SR1W3dU6U0LZfISqA3A4OuUn@|i@ymP)e*2O&d;!%IqmWA1=3Er?KeG;b!nfR&e)B8 z0P}+ijG+5zO{lL2Z0GF1jcu|n7UC{}q^fk&g3n6jQg;1+@}&Pv*~r2}(Q;dcQbTcn zffG??AS+&z(v(&-K6)oE_kH2r+trNwX?HraKRm4G4EN0SxM%cXlJHNLbwOcX&am6N zi)D^6RZ=5MMaol8ctkN{#4u5w)tkh%WRRiel*%IBLDtopQiHs3pXLGuK0$H|m$Dg< z1VUiVuqNPdstP-6e4&Dx+4HJrTvOcscHIfjhbT30>@(^|M57ms>I@{qIW z+*NCgKfe!Bk<4$Q1IexQ6n9o~v>(J3e{L>iE^Xc#uFk}2vpGGb%=hRkIT%N~dOW{G zrBf12sHg`W)(JgNcImS5a;cfH7PT_$de`J=f~P5Xhxs0s%}M1;N=MXHEZdU@%%1l4 zbTPKLKy55VqPvqmA${kA`!gb#dz)o=TvL<5^&3iKV+pjiEA96((lGI;MLqVueenl= zs0=#VfT+8CTRl-Nn)t4S1uCmAO|zVCkUmo*2@Q8K6&3_kL%E&*QQ?ksK0QAb!8W1nD+du zi7^8#QJ@gGzUaEbxe9!HBr}?dal(WPm(awt@qnFF55YsH(ux|l5s&xuItF?MWyD4H za?g)v%@3kxZp&pA)@XoOm@w~aTR*8;GPQEd<~JmVCa0Vkm}5qlf+L4jj9 zM6~pA&4cN-jmZow`0B94Mk-A@c8{oOOV@+cCOIemRp&Dr|it6;|x8yGbY;F5lE5k5;u0DY(T8mFfxN%)3Kgo zJp8G>A8Isri_&n$*W+`Hn7)@YJDT+dEM|nLl>118 z2h>MFa9t$cveMqL>jTh{_3^oYH?xaItewPpA`+7SJkGZaN%1=vBI~dY3`VntUi$K& zl>Voa_NeSCpzI;rqUvjje zw^laf-K%p3sZ42%LIfLdaowtF`HReC=9#XcNWKC_p|9IOzeInvQKqZ?Achq?mQ98x z0~(Ifr7|=aTQaWoeJb#kwH%Z}MD=Ml2VN_vHnFZ!Ekc-uVaMsaU$=I})Q-}3$Da;2 z;o=73p?MiuvS|8ci%e^Ak-sb-(dU@=s)M4mFCv)8)=X4%Q3`KMJmeKcTBIz-K$(w$ zk?AYW%3&UBK51~@N@JM?$yR&2xKaHfa^|ZC0;?0v#g8yHQp=mW&w4v~a0||RO(xz! zQZE-%`W<#|FB4(B^B% zPLf610U0;%x2$COCo<0oi%h&(_cG*A`OY8@?kHSwYu6yskxg8yC_jzOX6lRdsPR3b z>XQBuK1^5h*~@9kmzaNLyW?FdntvvT2@5=oiP*ZZ4M%CHPGB! 
z`4W-B-xt8+?Wi6M%_p$0sVU%O+XhK<`$1wr-W%z<++M-hyn^<6pG)>{hNaZ&x}H#u z4L-D>jrf*5Bi^)A#Kox+MdUTlKKGtvDsE34Zvk zjSWniFFK-Oef5z^%AM7DBR1N?6qRH3j%r%9CaqWdx+Gn_qcvdR9ge?qoahRxd|pqH zx9l6>Jk+d^`CQ0hVkAwJ?!>zTm$l!R2tAZxWxOfJ^p?{*zw-}qDMwAkZcqDkazR)d z#zU8qcIC!rI$)*jZO2>3zXSR{@kA2<{g`o_AN*2L z9}<=_`HO(Fk>5zkSem|5Snx*KQObF>-d^ITY1!PR9->84hE03^W&tiagiC~08muA~ z->AWehS*%+h%{r1p$YC`EQ;}S;d^c&M=~Csn50vK2;DsCml8?&-GG>zMXAK=!jb{7 zmD+aH6Os^8NYIvTaDwd#j3O@Iw8F#^Az$w6Y7k{NI8xtyZ<)!Lvxq|L?Ko1_pG|U` zbwgjfOEPY33G9Oo$fRQ*S{3EjMa-1)1TmxOhY~%POFL|h^>f+lwt^?qRqlVp0Oqql zTfR#qZ;7Iob!%m|0>}V)LH7*`(P~EtUlGiY_8TErwK@_YR0BUdF0{ye9hewB@1c}^ zACI~oqhCU=2^ecW5$()^K|%S*&{i_CL`^U%P`4?7HKA&E3cb6KEq z06vx4{Lb&7zLacJ#ZXC!f-D%mT&DWIcOT&V&ac}z_$RS(Z&f03%)c_3ZUx-CwN zf{KdX{GSrT`MA~ixLquuQQJ?x?iB~3B_Sa>-`@#YS^i-$ zm9R$&u*jimjA-NTgHNDKW@d|+OWKHKN2R}rn5CvHDiJGcTTe&g!84`sX3i2C4FI*; z0BR3v-9I}4Ps-xr6Ok~f-g?6`Oco|*gSv~;L6mGBRdb7n+gA0-T= z`=7rH#*Q0@mpr=IS}E)7lAPHAo=AEy8?kWg&f6_H>g)Eq8M-uReZ(rXW$O&QEYU!l zc5>I9%A^Z50Jv2%0+ax9-(Zl*>z?3DM=;_HhBpzSub~vl)6@Lj_Fqslee{BDic&SU0OmiyLuvn)qRFO55++^e< z-D41eu9Or8J@WM^0N`Cr?K7M<&n|5Sv}JzZ?*H%>m>ZTqQ9^C-F)H*`SbfXF2)k^N z*ObR80Vj?I%R9yr>-T7Ss-Si}P^qS)I@Klrv8?In z?C-dB>*WeR5xOr2h&Q!SDC>1`0wIgl1#y`*-G5N5!71)pKMlaB8!1VMtm%_TY*&Nw zrvwyl;gTjcQGPDZE-^cD+Ni@FRlKj4W)$>0%Pu~C<?CL%^1*M#k4F_1EHZiH_#!IGH64p^ zsyf)JNCn1B)DIhe7qZIc&G3`D5@egO%K^|1Z?AyI0NyED`vQq*KwS^?oi-A*M;2QT6!;yt!XDn;BpM0Ib>N zDE0v0+|k>g_g=9c;a=SVoP$6&i&_()*1)3wIh&&v0~#O1Z>y&Acd=UgPxI_ZKBwqY zOy$v1yny0HR@AuOBEL1`(ixUcPLH+JKiB)Ka-E`DMn%Wro3yCrL}zEkD%;ZABaV1u^+Lh!g_5Dk~a99GK){KB%~ zcfK-+u4H8K`V}T_wb4W21|Z`75s~&d87s+Q}uaU%LWMCu5P@?b`lt`>ofHG>QIl^}1zbn$C9aPwQf;eX=-NLNx@ zM#M>&!}^D7KosuJdaLo%^8KZ%j7h6r|De+7(U-TMDSHF%NX1Gb@Qc)??OHF{p1u-} zZeI_Wt`A~>Yw(pk>|E0yq!4l8v_Lv2i0R?-2(5{}GTemn{}k=8fM# zlkM96lzD{EyW^Jr@M6fD{jjhe~q8**Fk8cB0HTTg9ZInP(Q z&s4lJSVSn~sm6z;OhDG{N=7}D9%A%t3T6+BL;O=$oVcSB%KNqj-@Y%6)&Dqs`}1809BA-_&x7-2=y15Y3PA8&MlV^8F`n;RDW2slbqyU5&NQjr|U>-)zfu zbV;6@pfO=;(+OgG6hvP-e>Us8`%(yP-hvoEi;`X1vdx$uz5q@#c)9-qhzUj#bfYF+ z^?r*91i#Eoh*_~K7#NQN&b=-W;N2Dxrd$BOEuJTxLMo0Jc7J4Yvfdx zQHsmAiEQ6VPU<*z0&nU5GC@#cqA!z|-gmyde?6=4!Wx9iA94jp-n93=GG3Rb@hOff zHhQ5NzT+Npi^T?|&^;}0k(nS-Ra#1wTw{g|sG&deYt{i)Igp_MW}R6q3zW?+KI7c- zK89WjfR2<>ez=qAgLnju#>?=HWrx`6O+K3wSUuY@U{Dw zaf-S6%y~Uv%iep^P>3AY%zQ3+Y2SfUa*P)Z%>|TP;Xu!gU!s0|@}|#v05Fx^-SE;| zEbBtq@z}g`mZ+wz7n81bh6^_d%lWXwb3caoMy7rGsmArJd>H=-kL~wGmKIT8u7@D> zCq`bxj13&tRpU4;d$5}^6$a%)?m_`(-lXrQJ|5eVY5CdPTUNmXJ9i0lywjx}X330U zsf>4a>izm!u3oQC{$`D)c|)v*;lnpxi>AtIkf$h_;lry!kuxgvi8r_v0>}muytSI- z>Lj!ea7t9^i>#2(YhK)bV+1v`G$y$w^O`n6!IY)SFS?uv+aB?BhX8Z3%ZGjb1e<@E zTI$gv;L^e8U)*e`!UQx1)-v)n6Adkq2W2_}zt)pi!gp+0Y(3TAq6x-F2dn`=>Z4+T z3p3Nm5bq_WL}4~*FNq;w_m-+-O|+U27QmTwR-(*uFuv)j=pmgIa!36ouBSPIF#51K z;@nrsO7yV%-NlXi!Af6eA8ggi%&4K(pRK3E|*|EGk z?NdDiSCT9txKj5_-dCv#G*3&=cRAj-y?Aw2E|kUuKnB4R5&aT*u#a|7h&;dp;Eb}D zbPApjqG7vel`O@}-@0!PEwAY38P)tG_UZ(SDCN9%0(=9D%4#dKpGFolSximIWa7~? 
zA|X#9j>!?VVP4_f{_b`gSRs>xm9Gqgj^aUWVY%WP?1!Qli>nLk1o6HuS#eJ_o{YVVDQE;h+GGw}E*Rn7rn^3g&H-_B=}fgN@w z6lm;tw)KrW)JxC`lxJ)Bh|WmvYWOY~0wnf;t83(BYp!>Wz77GwyR@%8iA^!!M3%Hy zObCp1NnZT$(=lSmF8v&dM^Z4paHuGCsfs)X>sE9AaC5LwzYs9?LG3#yLtfN!QiF_L zZfWhdyFR&qUCK{6sYK1Ijd>H^?1j?;1?nX&FdJOK=6Rvxv-{t;+#l4!@)nv!zW>u9YRGQ^6S#-l_4K1sU_*ywySbGG1vak$Gjlby{OU3M< z?b)0oP8|miCKMXm$=hi*slB9)!eL|3ZJBWwF&l|^{|VN;b{(a_veERWZ^4xP=z)Lw zXDs?M3-)`$y_Cxv_MUKYwXuG5r&#BD{Iy!j6Ep7{!4Is>(9p6Ni_`e&d@d-aMx2tL z6=Kd!zxI zcc!t=Q!x(v#}=z4yciO5X1Co3h*+$IpQz|D#Aeh@(IBxfQD1Cw9q)k5hf=yoCi)Y; zQW;qLg^p>9i#>J6pX2lwLgnJHvl)Ss-k!dv8{p33OG>nt%7VMdNC!Z5r&QF>6?Qid2W>u$b`SdnhvNLN8NQ zsC7uSpA-E=$J%4UF5=~wRuiwby&)H z^!NKnKr0jdg<#S-15CqQnwTm5PoX*Z&vpJdzTv#2&C@4gho~@`qOuA>coI-cV$^OC zp)&>|my+b)RmEI)6#@BV6x{MP@EA8eQLy5~%W>5RS@p$C53$zF%mPJ{;)G6dr>PKr z{I*XDG2=D&SEsMiOs9pVH!jR~7xRtZd@)8G4aI-h5#Re;wbjX1_#`KWfAO|P_3p_; zl#+<)_tgb72Mb}}8HLxpNdj<4d#?XKoLmIJ-EUOh#y~|3G?MvM=K-!;kGK6B*?_Pd zRyjbW3O2p2$7J$P?cB3&^}+EZ6EkolqmzfNRNgyawCPP>yh|ZAul5T5kIq z#4+<`DrWLRusAF+kiB2+*TbnsG0@A0a_=Az!Qq;kNblMo>BIs~i0aTte% z47J(RK(TxwLumxjR6yyMncb(&hkhcUZc|vRtSJU)NR#hlIuBj^0?(yp+(LO_%{dE7 z-G|X>drGJUS1a;ewm;oIk=F$2|4|2Y0ldue%I637KX_}&@^PFWE5NAZ{OTA2>zA49 z3M-Eg6*`j=L7@rBkmb=@6dJ&DadN2PA^L1A{zrMO{t75A*5{YabM3#Kj)-IK<{nTm zGueiOECunu@D0wp2`=&kqph&AoO^{a(f{u4SE>2f2%o}Z!9 z_F}vDUU|!&WoDwx8*g&715m$!k6_orb|E4v1IzfkCyCOa*`EGfxB`I5qUZn4^7czy z9tKe8*-~(SuGH^yKt1F<7AW!jx_g$pR-o{FNl8`VmViwC#Wd{u0iN2b4*(^U)MtZ{ zVwoh(?aikl`z1KhVweNBcgc|2LiG69{9rrIOWADSAKC_`u{~$1v@as)^BNe|xecEu zj7Dg%y=NJioSEI%rbh3*zI$U|xx5e~IOuFMND|>Au$ceuKwx6LENM~gQ6=RRcl+yP zm64DZI%LA2XAppD{nlyPCzSueufDMTP&;DkF;r~=gBHfZ20y2BU}yqH0s&-0Y$hen zc%%%^CKMGw6Wv71O3`9kW|htg4`fjgv{`9%7ug zvAEkQgySl8`1g6F=V2api)ke|6i0Kj4#H5wEy*9-i6ftfnI=AsPSHoCl`tVoC$Zz| zOHQ+5IgIy-@6=q%>Z&b!5w1u>UAd!%Ls(CjJpFiZg~QMB9i;q>t`D_DwztAs(*7ub z2>9Mg4QO`>cHL)YW{vnFZuL~b$A^~p{KBrcA|>sly-E4#_-n4WTDud6J+W0y`Ie&` zKbS`a61p!|zUl(65`>L_a7w}5bht=9s?&+)EjqPMNG&XHhQ^tNUE|cAluE}jU0;~} z8n|&pATHE%Odu5r{bPOtF={w-rdfI9_iPFMS&l8*+gtbw-whPno>gHueZ_Wb=7W%2 zTSwAtQZj%6ITv{OgP#EOuO0j(D{z-3Dt^`(HE=kM-UTL8C0hHTb_`i%!?U}3F=1bm zaKaa-r+>?^uSu(v=q&*Kb8MJnUuaZsg$A4PW#+hlt7cRsHET(I{YofP3P#dGq0~P0 z$VxJ8L-85=CmL{tkKsnX`-9fq&A;-nedkXOyRkvv#(|;&k@%QJo};W6$zT=el-j>K z^Ay!UW&Du&jnbn(8=4!8MRGhEvah2rn4vwd0{eCBu=JA25Rhem8$qNNEF^9)T9I!( z*0GC(bM%Q90?4fOYHw)v>HqpnWMfi;I~*Q|34)<xge}ep7U+t@vfjocB>rJTJbDj%@0WXk5^^u1#iqa52 z*4TU}KlTr>J^XxO)&%ePg%k%RSVm6`{oRWGho;{Syj%d(iewuFpwl7UZOHhVGfb3v z77s!CMddV>-!pr}*tWcp6jwg_YqbhhH2pxoj#TKXY%wyEc1;#(602OfhU~8 z3ckVI!Dz_gg3%4#5Uk(+flm0IHuXh5Xf`-e)(78Pq$ zx7hkft00I<`O#U}W0Nq}(QGFxVc|O(zPxm^3u$tEYG@Z2lt>AlL-rfiLHDmP(;tO7 zq}O#Bw@eCORQ`Jk?{O23fP^SGI5dldd8*g`2an43=cmsC%pqSg=#J*omj~7GMeF64aLftGKxtFw zLE^EH|4V|MjzR^VhVrAuK5hM)K%)`bO|-`y^KbDtA2t175W2dm$G=7o zeVH(f-Ql0!n*Ssq-Wf@mjCmguq@6WJ$68Zh!nsUNdA&Fnwp@EL$Hxt6LJ}nfnC-yC zDq=g9aDEI{r#Ok2!|?*FsD_cSD zGED{Mg_gh)gx4egMd%Y>`(|#=OLl?O zHej>iY=yEOaA$Mc1EK#Y?x?T~84KcmbmSr{{dg4{j0H2;M%FWJ7Zw&@_EXu_{UHkC zg#Rq>No=>7P-Pqu1G@=j^Fh49%nx-^ljxYJZe;fB9{v2&dS_JX)Slc5c5+3cQ->S- zW5^II{A^0?Rr$-cAru`GSxgx#duGb!F0#iG*AUNUSdHq^W66hm#c1QI`4I}o@tWx5 zG_;8Hj|m(MjEBX>0yMb$;NekL0&||E-WkGY6$E%?(S$K^lBZ99W5~Cp zN9d}p3J-q}oaO*J+G|L6b>A!2#?yEKwVz92mIY3z`tI+k!f_eD{w~4{5bW;cK z_%G?f&y1AS&l3H_dx8zE_r!6}bq=ln>iH!qqak3np1S}K{r`E7b+)OY_@UB0c^BfA z>YMMaj0BbU`u}C@0!a@Qfpo8+a*7gCqQEvUo0X5-!uOOvNe>YFziPf6Xm!Xv zIIWEs7)nEpq<~@QOJEJf%e6WMO#}6q^4)`|l&21T#tvu6z|pA>Jj%gxegysOTvT(8 zRb%=c^EQFurz`q8?xpzZ@y5z8n?tU;lF}FI*;c10(sA}(lqy-M)-+97!?nedblLD(0S^Pm&&=2+@KkW-Gy~E}E!v9gW7{v8K=bNeINnz66=Errh 
zlPj7tCDzbyT4M_7C;p$z>~d|iIN$~L>z$*9Isx}mtTh&&FCYNnlWB25LxjM3Dn3_{iSw}S zo>W#5yquh;q2vlvD8^uJK=_g>AQ48Nc~$UtoNY1fkp|E?PF8!QaH*6>jP* zEqfvO+A$dQ8}FjV)*|Obe-d{&=}MffV?)OyLIV|XQ;i^giSByu!ysAlG+7IY<_Aq(ufAN3iyHlcDdxF;ZXnXVI0x<;w zw^+bmw|-P-OMf1ViP6H&Mv9vMhWN2+M6LT2Vl3*SDKP1QFL6sXvNwmGZYgDLA?odwze|)*De!6!S?1X^&7k<^_zVMqx8;s-c=n`YgVm}m< zIcn%TLw4W0=!x1VH}zM*$-rb-jN6Sca*I9ZMBiIM(54>-ja0w-HVXPJ(2S)};4p|#x)_u*q zkMK_Ty!;!@g4pjyaXE}a8xoR1f<&FLMVGWzPtxdy?M=dl-rh|mY{6fYq2IY(z z33+r%yZ8qBXUJ}2(R&;y-%>5Abk(!J{FIZDiowYCm9QW|d&H@y=(+aqsQYm~vigz@ z!%lK+T(dY6%7~u!%o68+{Kcyfsy1e!7z%^JFVm}#4nFQA%PS#En>fqConr>uJ?*Fc z!l)R@rL8MWCE0({U?1NI8~B7!K|Da_L7ug|@ZkTA*Dw^#2_frpgT-i|wwMB^i`xI< z`na1d$pei3%rU58aUvP=c8iP-YJtnEldt@QK+C^{;gF51P=IzpXFcV=0%b27tMnayYWotxEHM3i+?6B!B@| zpY#7$C13pPH4c2c7t`OzDn|zT9|2K9;kn+S^|EUK5*BMzUSg%+y4I@jMC9cXW&dow zdzK8-!=hqs(0s6ho?6;i5lF5C`L9dKtMDjtvQo=-__Fd$Huh9`?#o&owDoT-K*B;Q zX2WKh8{ccEpX%uro}THEp;VPYB}ntI!;?tpj97yYmp@mUZKI9}ib$^i&PcE|w<3mq z%odj0dKo-B=JX)srC0-2Kdtdfn%cryHu5qH6E;{*_aKZ&&@Krx$y`kKUOcfO0;P99 z^IW{4g8vR?G38IGx$lX!7V_=hTNmU>qm%#ksSp`0b`0+7pvZgR{(13)UA2?7l}*Od z!0cEa6?ClQc8<)D!htuy0DH(pXgxJ%2xt%AuT%uX5fi;Wr;!YdPRd0$#w>nknjq%2 zqprEQ_Xk#IHLO8MWb@genIB>k{R>aWt zH8}F0P5TLuj{wdKM`{i=&7A*;62Dt5cBor-mD3;cX8`~{)@$Bs`Qq@;zN~Z%mbO{? zLW15H{wlD4RdC?-QO<9sBF6B$L>(N=VBqNH%3OF&4x8QlU zL#D2|1jnlzj31vz&dw8QxSIt&>8Nd7vfZV%ded#>3}T?7hNmHAZ$9K`Pf4V@IhR;< zJiU`jU-d#qAmS0Xu>A??H|%}7opnXv{v0gzG^11>j8!)Hn3Bd49^cG_th`PR4F3kN z<4823+?{)%8{U(|**I+&ON-m#I4YfWxP8Bhwq6!Fv!%nFskQ4a8FftqCB+kN4yB^I zOmKRKf!DDx$5828ql-Kd0LqXkfw*Y@>RsL*JG+6_z)3;sy~u+eQ)7Ni3By${1QGBm4?V=X3>_-qh*YlHQh+}SU)5gSbk%lj zly)2Pe+O0;IMUw^JbPjT7CgwTYy#PmGd3)x3@3R6pilVjU(oqP@nLBf+1ZUkuO?Je zc#%3`p+JO4apaAz88#WpBe0aKRc&+Jwn~C%lx~5{y1feRAEtmhoCmX^&U5p@*H`<> zwai^dYk&z^+^&GiXFI=IsN#F$B_U6}5sV_zjjVxrzA2VBm+AaPGI4ie>Q`?*!}fs) z>rJ=ll~)sL9?T9~0f9N@kQ1@!lm-fQdAPX9tzp&t1XQc6#N}$N=1M4G^OX8Fhwhr) zb|Xbe?bA&lGny7(J6-7t!zi$|u#MR8+0r8!2#I_~96Cf8o^WVSz%hwTR}Jp)LBO@^rbk^{rL?LiT&!@FStxY4E^h`CQhsOSwni z&BLH*ap%3Zl&qA5sqIftXlq~edJme~yWCO%WSl=?P)PNxJq?E%=AxI0I6_%&!PLFh z=28jZag|tIiBzwe06|2r-tg2PWdA}D?L~l_(h8`=T(~VKNkqa`mG=HYKq{j48uy%z zWUQ-iX2Sz{GcN^M3b9hg-~H?Os2J9U_az65ZXD%z7gXg-UgrqX zp9&^Vu%-6DnpQ?|lRw5;R{`DzfDsl=t4b3v0

|M*`rAtr5>|G^a~xfaGXx?m;_g6>k2O+ET7yyWdjU3&J^% zmPdQWymrt@Obg8v30E>c6_0x%@QwMpDgxXLLYZpvkhrBeXF)i?ge?!o+R#uU`qJ56 zAScY>xDogl!}H+;IK6!eh9A-)mroSU$y>s_)DU=DUh_boxz~b|Z-jY4=~^FjnE=~w z!mOy=J7iSdt!pf%dkl@7r=aHLL~_9K2GES;sDo zmCr`*^AoH~NqaQ}YWj)RE_|ahOe^+2@NQo_L3K{K6;GttOKO2GV_RD?PgmNdc&&qT zv&d1DaPT#{mhY|8YMp-}#^zuAh>_ny;vV`axyX@ehc)oO(wa|BA+|rmNs_l%SQUCL zRH$*B1aU!~J;P(vX{CnCk>@wS`@{Nn1!q5RD8X3;Kal>+oLn=rCyxZ#L3a=SZVZ@% z#fPuq z*N^m+Qo>B3Y)6@Sr{|Ur@!6B>B78}Gh65fAB1>Ar5ij!RLLd7;i;SUPd66ObTQCBJ z3NAbL${?5rmejCMj2oi@+o?sQjYnzmftB%rcW)Rb|f~MTtX{uDe41s`6I#SH5TNunl^FPyxXjvFwdcNwh zm<04tL=3m%%fb%t-o(j!<0)+wJuNQ|j0r ztOU{;Yff7114#xiJ8+dw_#s^{0l13n95=FRQ|kMV-|DX~)kQwzE57#d85*8JtBaA% zxwTyCrem zTV9beb^_u1#BpvQUYYZ+8%Nt|x*8O^`L7$N>+7L0qg5fyrKx!UtaT8S8|nP0SXrP= z`}fNKj~b{uCX|KmLj&CwzhJrFg%Qpjj#u9p4705Rffj>!7l<{g>IMNsW+3qi60cd` zQW8fs_?K!!Wb73EM3n1iNZa5K%$q-lZZb@w=-R`~%CI@+QFvja?7QW1H&7+7xaVs! zOr~mkV@9m1V=AlDJe6Jau&;4K&C6eE&lISm?2A0*tvDCNz6+Wz%cwvpy@YhnNv7&L zfOcPFE0)^xt6uu_IOjtELqZ2)y5#IS5Cn3P(hv)irOIj@Xab;yQ^=cD;FZsr5t$w- zxu}U@DT(k2WvR$_U`jpn#GS`K?t7Y8&?L9}`zs-DNT?slI}@)>6_QtEXCQF$c!hv$ zk#y!sWpMrBMnG#G#T~fbY5T zglZ-E$o=n-tP+RSFkNegA&eiZdj$L*uK(&OfFcxXj$~{od?dzjU3>dJO#`I zptIJBTxGM8X~!qKdtL_i>+4Dfw@I;6Vq2a#5W-5DtoOQLvKmCN;w6*xe({wV{^BaE z`E}u94qW!0lT1+&|GWqc(l)&l@tdOTlVp8a4$PBqUZTNH(-3*4YOO&H846=c-f6h@ z!2O_kA-hB;{0Ehv(SMddlOXPayfzR;-PnFL?claPt&S;ps!kJJ z*4IWXSvTBW>VPSKKpST#(wb+sRCaCl5*KipY1ehQt&H9D#3XA^sB{cn`-Mq~nWG{9 zIW<^TpNuQGzWx1?{Rjhk@JiG=wCv4^#VsoV*7m;~#)q-)v zPs^WtOupz%6Rg!wQ>pqot^@@V!nByq`O_FHzO^ADuZyW4HC}&L;4VIS_cj&_>Iy(< zE0BRGgZ3(v2xzp>#r2c0W>S({z(C6 z!-uB?51Zu`ri_c1+~v3Kfea7!vm^B-o_99;x)+iSZ9q=RiTrx9|K`S_k}#uBkuq&Y zR61%5Ty*;fnT`i9hYl!Y9z~0@Fp6S^MWw zd+%?{z5Sj!I#1oy#bf>2*xq^O?%_mqPXRXWdUkz4Lz8>lu(s2?mywkE7K(HIY&(za z7&CY*thPr`SLRn4rlA3w3ZJ+2ps2+-qi75ZJ^U&3k77}e6&q|P*g7f(ktk~z%&-3a zj*~C$vx%Qf9XBoo;?LdU+rK-JemHFq25(3v*|D6H!FWozE&1;A1V_JKOrzsvArn~| z;>XpPk?2I;XW%V&tD7txbOlR)l7U#ytI_XQoNgn&UN};#@>2aeMO?}Zg(dWYHE5jv zND45f1)lHYUS3}TQo~MuvJzQ`py>&r(ZX8`NKdr9>0PgjV+!zE8^Xz zMLHbI09HD#>hSRQz6LBuG)6JpZ)(U=Vo@Czi zC*sVLy~AE)5ED8=gDY_)M`~K@kh+f_@IiI9-MH#cWRxXu)b826DjueWWLh$yF-Ie< z1*{WaK^3=mrgC6$;U$3E-on1W!JY{EfaTLkQIWW&LjoOdZ#1T5aEvfcFJBIv&SbyE($nRhVbd*rwFfUX$mUGft6k z=U@j$atz<1AeQ@pPEEj1_w+6E64)5VTN{$WEjzV@bRqd^M4QF3#E;B6u)rLY#iTS* zRhlX~VEw3;nKf>Wz~K=Iu7@Ka3{06$n(JA~U+4Kkrm5Jc788WcAWCwz6S|AC69wL{O; zgOH>4<(Ja6a2$tse+*dXQ(elk%gD2efKt`wg_|p!_25#9{wx$3lgDv9i8+B*jPm!2 z2L4Q{X95G($jjkJwQgOzOWj3^6%p}HjEvj)2UPyQa^_2a@;MG4%aC=uHf21A{Pfn~ z2G&2}P`aa028YN6{eE+OG;BOVUW|dW_5&Yx z@3f{w1fp_LZ^{^~j)k=@_~Y(l&-GePy!P)lCGR(1=qrS6?rgAebOtvj(=ccl=!p!y zFq=Sto%bd5%twVa^_(M7Z>S$2n!hE-c zg!tHVnH~Xxn5Xf+?D9S%Wju>s`Kjp<3u}@+5u<03sxqtzuY`>PSKt0S};U~=~&f@;;XNloqm5~Z(-Jy7pI z1oKUN|JO_V>7g?7R2BAr^w~$}t`LI}xKw}+9wvzu^gMXA05+chEQ}apn?M`Dr5pLM zRCPZVb=wMgNOC zla!u`lyd$I$-vL=dTR_#g(ZlOZz!>~45<z!oyZ=EQgJY#?Q-ehk{&)wi!tPOJP1nK1 z^}J17T~ogpz6{bp2u5R*OPE;c`HhrI(GVN$VkFwSKql@+=9E4E&dT_>E~cP>kV`ZC z*?>H8_7EB(?~Ef;Le=y`S|o~GW#1@dBH>a2?y;W@1`!7tEiLX0XQsr~ z5b5v)hIpa!MqFZJK8T@aSL6WFo^je2udFnMyDxIVfp>6qMdMwC8DOMHs%%C|MrPmjIU>G1gKY^~zMJ|0Ky(n~ty`$SH>-bR8U^yBa2kMTAk7YQ;Vi6T#u%NZT@JWb!kx{)=40Fv^A`pgp1Gjo=EG?)Zu~MeyOnBc%Jw)N{ z5nE0u+Q}j8swI^FPPvCX@$CaQRaVNcKF!w#wZ;#PWQD_vP2D5cYq zZ+m8B=GXb4ct4*sm#5 zI<3dF!MiZtdT>V|iG}ZfHq|=t|!;KAZbmRshw9_-^?ANL01ghU}iQ!fvFkLEUIgZ$%9rXtF3FRmJ)y{J< ztug=QO*fFBdgY(<87FL?o@7^ljC@=*e5$@Q z-xCsAk%--t_5?a1PqL0u*6s*(>J&oCO0b_D&*U-B$PxZ~%nt90V0W4D;u) z0;VC(QPN@#DxIe%xjk|8s=+HN#VRaiKMZfkSss4v33YarkFMasXJBJ%Z(QCu{Isj( zh?>K}L`-QyA$#SQf7|}Bq9YpCUljPY8$=EJ7;H~PkM-!wlFnPI8hy;j-~zfK&D5Vj 
zZjkIJl?^C&DKwIT$=qUR^jvuO5d90>puapGCR36JO+(Zr`>eJ4o1h+E8;SE`5<4l8 zE}5LrEIR0wcy^7=mTYM!y0Z2pBrT0h7S^44p!`trCM4?lHfVZG#KokJ9Ex9}`-~f$ z?o|{TuD5#vUXj7gzgqXzxcxHX3U?5Q2TDtYmTan^7HRNwk-Cv;-E9#Uubo?N<;Gf6 zMY`sKA&PaRMkChb(PhMwxjnEGrqYvrG&5XdtK59?u>uc#fF%6U0@5{g0?77d53`h$nxVx9*nII}cr< zGM6yXFT<1^w@hGf%@EDEAoxf1<|ECanG;n>PN7=5D8qqh@TM0TPC^V9(aa|%`(`qD zE;eq!6j80yak+*{j*dV#S6S92G@rPK{&XqanUjozDg#?2&;Hg#I zI~-v-PieBry6^-~Rg*#Aj}7je5t}*9fZ;a1AZS(n1EtPhJF+VL&lnKp=`<6Vu#ke# zhi-Ho`rD;-IIY9!+EI>Wj*g@qygvkKQ$(InG(1*YPSeI0?nebH@jD{VW+7;V!7^Q> zR%Vcv(|LL#RKhXMIJ~t6NL6wPz7xq-BNzlYK_oZPji&^Io+qX$D_)A)OH^I9fXLDk z9MPV5;(E0`oOs?fU!V!!y>w@(_^S-|KwJ&aa+vgT`fNzYmk0 zWnghXb`!G%l;p?M?iT{!N2+x5ZW~Z%dGsK}XNXAvCOq6>k#p?kzEC9cjp;0Z(?dsx zq{0Gv9H`AR&hV{DR*$5A=9{Jf_X%uM6uAX&lYsV!G|7P^Zwhf=?(Mm}5xM>w>6l&6 zH&7^t_H%pq>}GsSa%&{%+{hVu)0i|R$?j~*M<4agV8l^*`IKY%Lb`^Kesk^tBvubU znzhFZs=nuX_t#XpUa=#>atPDobC~pM}q5e@1fKTNP^3XV@suv0RXQoZ91qHe_2{9iTMm z8uP}Y>q;6L@#>yiyvxuHqAb)GrYkj?=1AcVN<=C8EXX9Vt;m;Z;im=S#*(HR?G#bF ziH*%A{0Zm23!Ck>!-(gYin}BE$jSM>%Bd_aq79CW!vSYm;>r)K&;aa-0{WKzZr^wK zDbb{22pS^aI*VgCEdYnx8${Ua?Pas#uFD}Cpa-g**w5Cp=H!_ zbjJZ)e;USF%O8qO%M;&FdWhFrpo^Ri~pSFM=yK-YGg1;iPl=`ZLd^ zuiUOf^j6sHvSLCa%+GB9$c(a$p6>`^rkS(x5X5I7=+yvi7^yB&7NC)_p`F0x6?a8e z&hP7k>4iOKf;D$Oks&AI*i4#_(zAcugGgQ?1`Q7HNYZD;BchfXEDm7hNQMImo@sg5 zABht!cHlh_X>V2^mI5Q;tDf>7Iod@J#Kc)lduF zG3=_5s4ZBEOrAAR2U`7;k)GsC(6-ez+A62hC zsM>in={c==0Jm2kFh4`TTRrh#ljJTZ{Vfr55-`l;SY+&FD)Fh46y|>newqGLnu|(s z%!B0mS7cO7w1mD5Uqyu!J*B~DIf(@393Jg1qA`{mS^o>VPqrT4kL=mIFR9^@aTVw8 zpc!|0hDsm2jPvQik;OcxTnTxu;_5mFBNybirIdmu99WBL>k|adCEC)4ITUveH8+}? zFS9#POw+XIs3x5cGc;B>ml7#mborQigE0$g4-w#5B#YG{{oj5__-o9P+YiYg#YcH@ zU{@q4Qy+@WUp8CMPm=OJ44i98hBUcu20UCIBeLLx)01hn^w^si6?2QnYo0Zmub28= zRl5CIOzKf=DXnwRckamE3PQtK#bSJwrGcZ>yvhW|f-YOHy?HA;ThnYKpFqE~oA0Mk zbO5fIRv$3wFCUo;pMv<)i@Qlm1@@Q1>_S4{)Ib2R`$XNX$XNwS@Ev!%)f9;JE&I|gg z3#seC5ks8qkvGdT=f~CP8YyNq4mmYjXi^jVVE!loOmVWjweN5{Ky{noOh7)7mqF(5 zn_v)(CsLw6;OdGJMwXsRnYa`a$jH$oV55WzOhg%%%pP#EYQw*{o~B^27Cb90&!#KI zx^qTXU)7clwqTk_gB}b#j|Y?a7Bbk_Ag6Y;rKlUhxybUuFAJ7Bcv#Xjkw00(K~W(n zK^df^7=2KLbE#w2T&~pNKp$=0O?pHp>VS`;^|aW6kiq25PH`-2OkI{FZ?fEH=mG<< ztyU$4nq$V-(G(9x5=50y3u)eb=xb$dGcdY}>?h&RbK&V?xyRZXn!k+K8vj7cKdD!HWA1M-Js^kb| zHIibc^wu7dR>p-I$f$T~rKf^Ax^* z@DQfo!)XKaiy?LrKN!4o#iGN49Kw)DMkI;0@FxH6qqzoir7$Z2OV}Cpx_Z>WM;-#B zy7I!E@i*x=_Z%U2qeZFd^fGVCrmDMGplDf*xsTfXoq$F0AUQ`%L{Wh9ufQ1Qb^Q-t z4r6EBI3;jx$otd61UqS=-r~}Syc(K)Z1(6SwD)Nio7(z zS_&RIC-dDFMDs{IHLWhP&?!`Og&vBcj!!C`_;_dJ5+CMlm~+iPio$(kteYf<=RuXB zWvh(H`9^2s$Ywi~+07~iHz(9$QK5U%ARQ*@kQuTDrP3Z`uozC3&geINwkeUiAFhtsfHP_ije!4b$n)*?w!WJ$DNjs};$i!5S|Ex-+`EC~dh_m`LZ&NWy1Xw#7_wTGfD%l378vfk*~cbA4hOAbe7k|Nn8zec18N7LJo7!^-Y z4SM##d>67qwTj09=6rn_6d1@f&5b2CV57Um&9T|_R2^qckh+Gi-Slxl=6u@DlSB1U zTYvbdQ>O=gz5HDEcrj6H%1KB20^RzJUDO~7U6cqHe~}vMEo@za5em$IFoBy%Wuc8D zbe5N`gp*=Cun&=k>uL2;T?9f$I>qJQF#XSk_{m~wW_+iRO`y=>mc<9O zT9}1d)0D2_Y^@mBQMPhqxw}eR??x!^GC8N~*LmVt9cxN`^8fET^GkhtXrU zN(f4j`!8dd5F|yh8z<^e*qMu(R2ldDq3qKP8vW%9a{OC21&2?R*jeaIvEhU~|Dy9Y z)6i!=k4K7o3EN*#i9y2@Z8(b6W`mB>j1YR!dHyZSammn4Xww$58VI?gQ%yG!%LCf^ z49U&WIQ}8ARl>PhXy&pb)4=`vi?};7^_dr)p!)-@R1B@~q|!$?yA7tpLhM^dT}!uN zbL{s|5rqt3fOcEn7mLXpb9x%ZhYZ;?TAsq{<>ZJ@OBFe|55bqW#V}vDP!qn#x|FCD zNdjBy_+pkbh|E}vHN;J>^#0nE-m6?)$OcI$|4Ro#nbFg(biyzJx#rIr7Nbfjfj-CnksPG_Kj~)1hvYiL5faR_#YI@n+)~)gX19|VkVX9O)EjX))+6f0 z@ZNJ>CxJ*mHjyZbnTLmwFmio}rQfiiaqSY14#cQMM_C$}hZn`iV#I+chFu^XF@gRt z3gx(j)*5>WrC(X$bdW%TNwl3N;3qut zo=4{X6Ux3$;=tBw?up=Vruvo9UlFFjAK$y;<9^-BP1c5#K%Crcy23jQst90g$r4D= zuPyL`?Wt!NUO6Yq$)7lOJgS8DB(=PHBVu4M*eOWFh!a!pra`+qpE;}wdqt_l 
zD7VO1cr27(0=1CUXeTdI)?}vdS(?7WSEMXSuV|gDlo#p#!oLfQCRu9}ceF(b9l}f6 zIhP}RrL*tq%S}g?YO3_RGIU(chE)8-q$jbT7Man6R7tG9*?=bUV8V|( zBX?yHb6l`XzsyfKD@WnFyezWVril=kv=IOmo!B#f>jYar&f5OXE}Yo*CoQ(;HNWv^ zCVPj7RKBEI9m%zZ#q1iUsNU~fKcv{2AfT?m9h6Zz&~Y-`?Zg0e*r@SW3&}jH@w)^2 z+79}E8I7TSh9~y6kl*B`FF6x6sQc!QrkV{*5^6%rUqlRaLcm>G|GG^YYve`1FNj$r z8{wkqqW8I#%4uuj&8*srToAw0S$z99V+2h5o~n%1plj>LIi@dl=N}|P;Ux-)YZ!$q z><7_-Syg0^P*i%uN>ix$9~Se`=SWAbh(<@@Bk&5gN5wdSb1C9A!BcN-w$~9Ud?u(A zhug9!87m=Xow*GKyOCqIO78GjaI7IiVLbWXi^y8JYsXw(3hU-$otP5K(VSKvZB9YY zQA(p1)#0yDI89N#Cod{X(MeUH>`VITV3d|p#X_2qEr)YCS&@Hg%aiqUMtU_CqP$`_@0H^ z^JyKAq46g4rIm`%<7>btgmGT6>V>Qwfs1kW+pv+bEInNIn;u|*)y%N-v_kqNNJt8w{1KKurWr?#nd$Df$zixZ z-}}0L|K0b${&Ag8y zhL;3BpGCX}CgQ;553c?nu#R7nt^7GipWwCP?8$_dn9}#}t1lqP&?8sk`V^7gfhFo4 z0^)jt<)^1e&%cV?5uN_5YNbnM&YUTIR{Nh&#RABfE(JOmoA*<>og5O!uX!rv4lDHV zRc2=s(q4=!u0YGLfAFSxpGv&uBf^lbV<;|wnCakL$dwrR_xHJDh`u$kuujO?|L7o&~qJ2dEw(X{AAymOeljGor~ z6g8g)&!%USe!0Kqe(LRuK`$Zpj(kMcQTcrd)24Id{9#Y=p7_BJn$W&?eZw}>jz%(5ukBEx z@jLNgGq+3%gF{ruh;&$OOHA8x0SL^^9wF;EPegjW`0kSus(o*Udhh7)w|g2bLdO!B zJqtewpIgE9hc$K7-5Rs)p}MVaqjTB9GzZq8T)<0 z>v8j9*&9(nMEYFL<4@!q8@i6wr3qH=loNSQF&`088zc{}-%=S9lqzGu>%JUb(ebWg zp~l6Oj()80xl?Ft>N7^cNuy}zFw1QkoYOTIY@|w5*fs8)$PJvSA>t{&pF_2Cb1EV- zD5QEfP|nva9A+I&Jx0@Kd6WsFb{>W(l*3)^Q;l0`f1#IcVvVZ6sBJ0i*Jd;!U#Gp1v0 z63T!_blM%?w9gj%wQfRvh)Am#j}5{O5lyBk`W(u4PBgOr6Elfu25#s;g{K;f|HBuk#gH-jy6Zx}dV!z{Kb!5>3tjhMcM zG=A@92yzlEnh|DrlMa5!)8-{x>^FKzuXF0jsRj)!!&LP?QuH{t)!VbG1}=G1iv#`- zfBy5(>_Cy~urIwIYl^qseq$rDnRn&*gc?F_Hl!H8?v|A`NCjXqr+d|Bitt9Ujr&)= z6L4nK+1jja?~2$8;8h)LfBj`kY=aBsM+kMd;{yIWFDVdXO>Z$+&O8 ziz?FaELw24wIM4GEWk9I?(`5-@r0Tpg!hk{tlAG|b2`6v#6LFgWJzk(e-^veq*>{GNAG0XyM%%a)ZesNE%>fA6(mD|z*;>ex6dZRzE z%$Qeg^e|@cponnt*gq_cuJ<2|{`T~-k$ms^WE`9)6J7645ZRWFhfemd3gBbtArWs$V~n#Mr0)`uoF4W?Uh29KF}IE{>iXkN5wB3e+OgQ+T_xK zEHx|{%^H2m`0IqF;*tmU_-o2xC1=Bkn(D2|_rawg9xKGl&zPdwaNOWY4eVs|k4j-} zjj!f#kD*De{e6PuhvF>IJ$34|`wc2ScC&9&&CNzlm6S4g#TINH;mHGZI+^d$x!!); z_=_nwWMJj+porQlRYD69ss}d?PPK~M^EJ_cZh?3hARv=ONgH9)LK|@+)8BHMX4SS4pPctYoBtqi z_CLT>v)-3k`wq)z$b{6w=ZVhosMmLUbGyi~W-Q>x`hXDOK2ND-$M#wOxB2+wt2kL7 zhPLemvU#dExGpV4WlOuf0Mph6O9TV%xc;c{l9vp@M!v#espHKu_K0eY zQy+ukP)qA9zHdp7S$>ooocUE>1Sf>48gcqVM3?0pe{={heElh0^JYsl@MiF-?bEr({ zhpAkGC2`eJ_+&`l2BlE{vXI~h^@j!J=)Id zm6?ZfUP!Ph$ApTU^;euOYUTX8p^ff3m-rx9@j$6a3 zmZb`>lz|DhiDfyZ>K!2E_dD7~=0Z>J*Lz4N44-F&e#a%8z-yGaRq1~I(bZ4knN{sa3d~n!&-L{Gp_^bG>=euCIt?hb1mr%Xgjbuz3p#*rP zza`!549-R8w)8&4)d zeb7q9ZOM#nY=YozB+L!l3Q|^mc}$d8e^E!7TyJ%O+gO$N?Lb`+8S{29r8L{k)949f z*`{2dWFw^|^~bDyTlpP-)Ul4#`x{nhwjYW3=T{9`-$tk!L_CXg0^~N^2+XX{yNM5N zhV+70gXv5%uw*pNeBsWtXeNo2W9DhY0T9W;MayFnfs@ht-x1qYnc8CiLTJg@5B^e8 ze@Q)=Lw<7_sqgV`F!i|zYxBoN8-CkVOYvgS+gdSf>Nda>2u=_;qG{LDzVD=n>odw8+{HTA1*`3z^mqP}MhoZGi& zV=cFXmh}YOLsfmnsQm&G_k+{;M&)I>+k|)`{$QGY=>{X68Ztslu}Ow9(0($3%>l+@ z4~}QSF)l4DnX9TdW9Kfk#hD{Ak%w81ch?g-^5pI3r$K)vES!!#y$a})nj|0z-X65} zH$lv_NT9$*&jVi!-Tp6oV>8tQsJ0OrV!WXvr`ce+F-lY0hPHtC2$_bk%45ak&@oLJI+LYqOd5LikBhRS_n|8CZPKwOQ#+v-&T0>c*36q8*#j zG=^j^bgQcU3&B8P2|glGy|Ygbp-@87y=bL%bL=ljf)dXqndOIk?dCg0z0*5FJ$yRF zqsZ|iJ!&>l*=t%OPQo=Jwh9km<#&8fDB#%@{Z0GFWO>u~P-bk6TWtA>ds@M-!8X+6 zWB8(0bW%y?r>7Nl)e)sN&z^w=lm{bUf6;D;((wIPu|K_Z2@W8$LHh*#`}Q9gmFT4T#%U8 z8#?ss3fB_g<=LcN!**e|uNPL?2PgWLRlSQ<{KCHjw| zZyAl@IVe8#yoI`fbC*a}{Wbi|V?wiH`B9T>_7aBTKqt|?QI267mOj&!6} z`H1Cs?QySv`hY_oU&rYDlsF$Xc^%Sb{ZxpM0}?y}EQH$XN!c}$TE8){AM7>Fer0RUutEVEVP;&%JJ{t!)F{VVZp+;W#W57>fp=F#Kp;_iG9ld++V#bC=X za{J3NeR6LsY48{^wIhK|hz!r(n2t8!o21pdFjsIc!NW7fW82%$iJNQWuPEDHxM!@I z`}-KT)F@^&-|B9RIhF%h3!vW-mq{2{a#S&C{vDy$-FBnjYH(52 zH6G{dh75=!8tFUi`PJ+?%42@gFT$SrW%InY2~59ptfWPM>wgc5 
zL0bzy9r}v`e4nF_u8iiv=${#+w!U-w zwPLO<=z+iZvyq{YQ`O1@`^`*hivN!wS_rcjr0tKONDuoqAtE zO5h3;^(G$e>D?hqzlwRJyz$?EAiulhsi2dHaA^>?a2+NNP#;@bqkLqu8Oau&9+n=1 zY~b==i-fauhsDuV-31ZcT<~E0G8%XO@Yr2cd(BHSQxi#-Yeb6KZ#{pG@?WRVk%T8i zD#su0!y~{07m8Tk8zanR5b8tM&!^chm+lCsuj%_2$T=(+N}E+7c{3Ki{k|E;HXRdR zEA}2~wCOx{6DtWzh`r|KzYpuX293=6IM7z3UaeGWxLhT>HlpJW! zcwy}O)p9#vF4%l=E1{VDdt!$O^7`4&;l#^z^JuM+Ul`hQubf;TW|K4|-O!?=#VvW5 z%h^jC@?yYpTg0$Co=C%g=E*;XN&bsPLk3^zlD7^ z%x?7{)PJ?THST~K4q_Iq#3Q;nV?T2gUvoy~vZjmF${HQGr%7HD>V5vXLa6=ukp=Z? zRUwT~^b#p&iLZP8p0g$ZNrj;-n+qOjeB0J5rWV<(Zm89}vI#F^`ITq}*YP0Ep;cyC zQyJEp!LfcvFH*ZFXEp1hck-BSQg`=28qa8Mxn0!EYb_Q|d%@{rp?~9!ng)pg1;OGq z0qhw!qHP-T+)DMO%SXhw{d~rn6DJ#ip!l&AA7CjMK1#yu9yP~Jsm3{ zU3tEk>+ojGNF6VxrXuVKf})yZGkW^swqiL3g0hh{?V%$-%(w42#(N)S9e|B}2L)=k zfIRQ&RF;B+w8S_6bF4!OLJes@Y&*$&y%Lj^14oq~;j7L5Bljv6PR`Q@`PjkYM0+NA zu_g5yZNHfM&^{1*(bh4t2fX;%c3DDAlzUIEB9OCjt<}8Cxsh3eNEfvwX$YH`(M;mk zreMgG`B+mi9t0hG0Z663aKahWvA7fz=nsoQKWTR+Dz)L}gR( zB2Wi5%W`jVSf`%f5p!I*a5?-!0=e`0l!0l0{i1;AJ$JCokYCbE5W(xWt5TD5BG{UB zH@sHHu0j^R;|^)ZciA;ADWsd-cTe^E>ljf80%@%)p|`Dde)z`yQZMczegf9w_Bi|( zs?c=$KGGPMwOnuG6|CgsEwQ^jiL@P?dgo+WUBavdeG8W3wga(*?=l1V)T{?0h_Uv5 zRK+7zwt7CS*#`TC_KSL6me)PBR)^z8X_w5n&2NC(ZlQOg!Es-~hAoufzrfX~i?%y`sLUGYslHK@_v7NGh6~iR9Pob$#=N)bA-u4tVnpEgMf02^k zaQK8|rn=`P0v}yNb*;UTh(X5eeM}=uqw@~k8Q3xHCS9c?Th{(7Rqt><@Hf)omb2w+ zZS5a;=5ybulLdFb)TC!y( zpP(Crm-d-dRPI;SPhD?R@WL05-sHC|-_1NkFKBzIP(+xI9e?XgBpuvmZD}SHbu+IkoZQS+iEyH7Y zJ#~CCEqJVdz4^hGENLRzcI%dkdiHH(DM0lYAV`W7w$%MeJFUiV@<7T0N|rQug`+E` zr!cSjk~&Exdv1EMA!X|=Y#?u@j$iHAs=oT_RP4`3rxXMYRDclSWFt)Sk-N~e_T%i) z9TvMc?a)HA!&tdYdsktyc#h#ITqV)NxX-_W(ii##(Y(%LaLdKPOe^H=s%G~rt7b3- zx3Pg`f=L*m@x~i<#6YH zFzAbxfetRtjT-CE3LiWjwogU3f>Ex8_1om=^Ck^}ebesD2GblnyPhcNt;~+7oD>=K z){*;^oxzPd~KC_C9)OWyzwi{)(s zc*`_qeB@Rny@NBOw{I})rU`>M-KhBuF}j{;IIVQykv{rDBCSYD^(2WwVc6W*NPOk$ z<~8fFc<3+03b3dkee={siJoc}(P*RZmT@R;HTzyJb7i^~R0E)cAZjI`Y#IAbDjD(e zV$fypgh?TY-17A=$p5z=(7IW3(Zhdz4+6{@$wyCrmQFDE%+iiG^=*+I1d`+B+(Qmc zp)KKtQx3IVtWNI$&00QmnPylD?q5=~Xm1E&MXq@Hy*KcP!@Y%A+k_yfI8LL!zU^r!U3O zjUHj6lTD~_PoPED27a?%ohBx*R-gi^d*I9gq(0&5P86S0&(WgkX<+S1eAR7(#!22= z;1-wB5N6x&L?8P3D!|lKZSAY=%%J4WEi43jT5L2D z<9*gBbYD%ZijhozLMa-SJKR<;P`H-HT!QCOUwdMw7hjP$x{1(m-?gLYogYrmtr8~U zQLyQKkrr2B=U06Fm;V531ztgz&Q$Oul7eq?5|p;X#E->e7zusa;BwuZ=f~V==jsR> z@CY;6fMbsRl`NLUgq0`vd%R}}j>WumFA{wb05QJNuP=2cS?n8o2g)$Kk8}Jcf-fbl zzldc-N}5oP?#)Yv)#GwQXA}3J+uxukdCS56!CY1kB}<7@Q)ovVTi-OHr3MKCWKTJoqhb)zNVk zvE#m2xD4t4^4mFUU{Jgffw0nlWMTgb7`%;5X4f1TAklQZ*Q`Amh+y_)^sNH-RWA~g z%?#O|ks|bdwjJZG;p_FYWO^;rYbpvXKllB1=JWNI0O2VGfuY3QIEX~VK)g5u^UxV* z5`qk>BRdtD<<6(I$LnFFWw2-#NND@@sy{I&(KBUn8Y|-ueX*n7C?L#9jy$zGT)-eA z{V|(1qE_8Eg<0yibYKJ3eWsLE7>ISbby4O6tE;J@(n=K{nXRQ_1hw(ZCvskrT8o9S zLy#AWK1vkFG77UU>NhG+S099lw;Doz(#uo)p1b10^-rJ5{F#^h5v!q;g`=6ZX>4eQ z_&EFtpzTT}n^2)5X4A6Dj+MiYq8k5kj#_B*0!%iwnhCG{aru%uIF(f12rj9W@bY>` z!EHENydT5ax9=E@C+~$h6OEseTB}kjSrO;mLU75M`xJr{^wsGD0{WcYS-K)Rw7C-e7K9ZDEHTm-Qzpe*`hK1;&l$b*{Q=e> z^UFZP^eI}b^xU_O@D#i0S(bU-Ypv{n!zU4A49AI8WuU5^pY@l{PI15m zJu4R~s-_0ZjuBVLO8@*Dj`P6$fG4oq*?DcTksY_!_s*VDm5^s>5;RRfuV^JpQj*p- zGJrO-?xkdZt1OuS6sC`yB(taCxk&7^JU8=4!;)j2pNOa+Ij@&i!q;#&O==5|J zts-06Ly84wd6z4YWokb6D$O00KB74%(aZN%BI!LSavK*zP3aV?#6Li=EH?T<;Rh+3gh!wJc&SukFSv{X~AmbtW zVpvz@)+Q-Na(F>&vf>agffvdG^5qPCV{ie*#C+vL(osy)V zus=s1mae3>9=v}>W?FJDO@+;Iz3qnlKJc#sn4%8$%!pFH*|`9ka;>SCJEADeUpOz( z8=bhJf?p`F{3v)~fhhqUVUNAEJ=sDKq>~AM0QjTObGV8mwl%*IhM_h5K5+Wz=Bctc z-@1)m@5L+Y*?<9q9ECO^`cEqPgjzrE(a>@_Eeqa`jB6v9FWB7Fwt2^2u62}9XZJ5L zFvPD);Uu}tU30a#t`I~M-;i`ze6A-}j(MTya0jZ2nt+LxQVNw_JaAf`q{irzOBMK? 
z5~5JGQJKYTc5u;?o{f&T+LG~S7P$W?fM}pXLLy7XS>|!;`ih;FvY_Q)a^Hw=Nhv4J zJNO*=ua#D%WN`i}hOUe{&+O(DUrCqP1uWGL64Ltv%#ofwXqY&%i1`DLQCkb901atu}lVDUS)puT2^NgK5R#Pw6{(%j$LvPqq11B;B= z0Ul{f4X*9B7pq`n@OS1?T}oYGD!|WU|DHNIugp{#GNg~KBo%!tv2c=&cT%qNL{Uxy zS6D=@#Zpz+_Gj{!Y~Z6@UDI&($1@_@kW>e&ZS9Mr^cjK5T`hx)9f`9vHr&qQ+36-l#gc+bMMj6 zc)^FKbeIA5q}TVClxwnMWzusT7$6{^mXr?S!PxnAmm)Q~S2ssY{UjR0Ha|NCyoLGW zAH&w^FCp&>N!Q}mL%7^m?cbv~&sz$AtBXr$gle&oDj){kTC4p)r4t^;DSCIsV{gAC z>o4I2i>p#*%}HzJX6_1o>9D3e`gpG>}n^nCxcFfGTY(*gt4Y$7xc4xOb4 zj6BoZ8&Mfdst3BtKdPc6cTW{yz=|J{o0Q0XXLjkCMPk1~rYDieYmP@D0thFW znh<%01O~(3M}8JN?Y{+{t4u>Q&1blC^Zh**5ZHN_>{HuIJ=8Fom#8v<0N0`J$DHcv z=PdN30vQtJ4PGuxjM~^^APZ8-m?a6R&w<*u|Jf_6&#_51gnJiUUK|CqSg-GN7rswi zx~Y=IY(x~gqX*1(gfH1O*(K`Kz9Z zFKH6&O1%O%NW7&g_PUd>S$>w3^An5EUn(P74qE2?q>H!t%3_|URTmk%_IXpVMA&@N z@6|)1+J7aPQ3ri&#=CcuK3*&>&3W)jq1VVzoYxcm6Ko->m#!oTrFqdbuzkY(k<|wq zD1+#)%pRXjEY_b@@Z_WJ_@=Q}?-z`ErR_2P{3EHZX_suV1B}yHxVVRFhZ;{!*yIpi zXF8UfoBSHwK4FL+OrcMYF&7%Rt>3R&?LU(@kHTChvf;qX5xQ7rk$yCQ{&eYU#gY$x z^rxA=g-}LIrA&#axbG&`<@p>6kDzQ5L%5?llNpB*XCOpWqN=oSMDYaXvZ;b(Jl8wejBFfgm9;+s={jjrP!{QAmyEW5J01Fd`#Gmu?Nd&7 z>*-MO`uSium%& z%<6HfhY}FoPDEU+NPVMgG1&cBE0`fIIiz;Kk(OaGvA|6(FT(qDRa@8(@5YzQXm5QO zV=R0piJKE{dQVuP^MmlTG0=8*hD>Dcm9*v_c#+&cc}ix!XZtu03SnhX$ZN~##Tjf# zb~8I8(eY2?r+$p9%J|a2&V){4aA2xYJ<=JWl5V#-5=K7L5lW=V{V`(Top^>pfRrnt zB!KSh3^O-stCwGA;3O^W>XrS=)Df4{zHGGD9g<2|rj4!fK^V(!LTpw*~2NtkSNEG|veiB_k?q zfL|l$9ZRI=hA(J6`|$%9m=l1`*z@IoXPJAD_MJUHQ(ck&^`JV}z(t<`J5z1%x4q#_ z0R&U#hwrk%)4{Y3IHPUgUpr}?Wq1%q(<`XD`JnO9QX_fG$&32)gs8Vihpc*Y)79gt zFp&FUS5rGZUH}M6q&=fY*jo31tzNzM6@KuFzR-3khOv~gjQ-F+8 zh2XxHW*Oe%#^pHV4mM(2c241__e$u-Z0xtsv}bU_U-?)|83<%}2f)8+j>xAu(o8eZ z?qN=OJBNnvOP5&hKnw3gbOt@46T@BwfD59u5aa?!UG+PPO}dnUFns;u_oXTDR2sSr zz7nK0ap@~#Eri)@R|ma>D4h$d{6BQS#5Q3UUsiqJK>1Q=V(S|$(Xg0DW8A5n-~i&p z`wu-}+dV02ETW)@>)1dmbiz*=MH&GN&$S?0LEx1Vz`zUx6L^vf{Sr8)cb~X#X3RB_UYvZ{%&rel=b! 
zGOZlTLk^!MI?x5>E7V^qU$zuf{bmAw?1~5v^F%&5gT#p6yRe&9T*q=bths>wpyAl* zDF9+c<-c$RxL>lN%dBh#9|y1J22Go#jBKw$2&L~cLBMmD{S8`RBLdnsu?)oCv1IbEgOUD;oYuuSk;EFTVu-@}LI>vvz39njMMIP0NxV*ief92IHF;>9e*S=jkR z5S2J~3gG4reYW{vAEezY^V`%sUo7($dz-#ZVG}B)E)<5_LT=2V*3vQxqcE6r#%d>1 zO%4%^XwV^6E#Z$HHa_F3`P$y7w|3M>HA}irQ-{WL%B*ijwft7o(+3_O5H};fsd{e` zYqH_rFJ5ow@6MLUb}Gl#iDud7R%=DG6cKz+aXw&Tu`zM*-5n$=8$e}bb92^_Em(cE zK&z^~7ZDq2_A+S+)^b80{EaWz-?n+5ARAH107q0XO4u2Fi4{U^^o2_Ie!r~fZ7XN3 zoQk)%rYR=))@#B;S-*7-he!Wdv}bc=ZrFe^QP)mV1PNx+q>4l8&Muw>8z|DK5oe|? zy%Dd6H>pSWYnUe7s9NH!nBmzSb?>z=u#L8;b0N7<$Ox5+G&vR6xK;N1nyO)GPffMmey4xc-p%f{Bdf3_dY+0xOP95R(gDRA2jlWXf!QvIl;#JDNcCF*eO%SCQU1(^fskOCGk#QYcD-v_3 zc72?70AlB(Zfq_cgVx@4NA;+;{HSp?H8thQ%I;|ygSm-?%j`s3Nmrym<)%Wiyt?{awkAdM+Cy0V;qYJzdL2bTKv2(aPB33( zLD^gb!<%+?-)+s+Y20YB?$KkkrfgqZN@@+Um@brvaF*38=q8fI-ml%NlrViVAG&Yj zDC2`yyufJHsW!PtCup-jZ>;UbY z%2VXFhb1swjy8tj4jsj-@2;9UY)nI(K}Ju`w~IL7Mm)8#o1GDlQb}0YBxxFzQ(E3# zopHb(4p^3p=Z6mI!;Qk*+BTS2jy`yFPfi9ivB=QyG^RAIueV9M+KJAi5Fwrk2{f$M z#O zLrvTdG|F*taSStGwyV6$aV}o=Q1vymT8>RJ z9LynLK2Ofu86g-F3cG?lS!oQDR&8x>ZUv|V*8}(j6HS6o^x(M#G+Dtw@l% z{(y%?P(X&*U}9Z{O*+%k28c+wis%jL<)dOX^@K2sn7xF#`}^Cu zQ%m>vRu6v+kB^5O!0jNMyuwnOi$$>Tq#Pgj`Bq^qqF7?6jh&sIebH{`EW|^_Gj`2U@zWO!&ZE>*o>0Sf~0CRj7UbPD!B@Ok*O8 zN@z|06oKzrO91K`N7waG>=n+T73YGtJw#&eUOFVNSLI!7tT`>>;iw@3f)$H7*KEhG zJgl@jEyCVK7K+;`FYX>9SaIE31Fy+5c)i6wDTjJZN|}vtx;)J9!t^-W)O`233I(t1 z*1k$RoO{RFs6g=?&uR@e&$F|$E1F@X0Ib0A$2UM@r{k*pOFtF?aOaO0?o`oR@tmZ; z`UoT!APAW)Q>Glh%`YoiAi^J2PW&1va@gKVQb2Bgek29{Tr%F(-hSJt`gxn}m*2ev z=LWz%M@%GEskOqcGvbdAx2g!#Hm@egX8m*~B6bntUJv>Q``o_yn?*G+)Pw+@|5O)2`3m^Vk{2WjOIW$y6Sa|Xe zCxe&bEqp*kVpipjW??uENbL!7bKs1PVn9a5naJSypYu~o+x+)~lFQKV$z F|9=lMs;U40 literal 0 HcmV?d00001 diff --git a/docs/examples/te_gemma/requirements.txt b/docs/examples/te_gemma/requirements.txt new file mode 100755 index 000000000..a4eaeea43 --- /dev/null +++ b/docs/examples/te_gemma/requirements.txt @@ -0,0 +1,4 @@ +transformers==4.55.0 +accelerate==1.10.0 +datasets==4.0.0 +sentencepiece==0.2.1 diff --git a/docs/examples/te_gemma/te_gemma.py b/docs/examples/te_gemma/te_gemma.py new file mode 100755 index 000000000..6285fea1a --- /dev/null +++ b/docs/examples/te_gemma/te_gemma.py @@ -0,0 +1,703 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +from contextlib import contextmanager + +from typing import Optional +from functools import partial +from collections import OrderedDict + +import torch +from torch.amp import autocast + +import transformer_engine as te +from transformer_engine.pytorch.attention import InferenceParams, RotaryPositionEmbedding +from transformer_engine.common.recipe import Format, DelayedScaling +from transformer_engine.pytorch.fp8 import get_default_fp8_recipe +import transformers +from transformers.models.gemma.modeling_gemma import GemmaForCausalLM, GemmaConfig, GemmaModel + +import torch.nn.functional as F + +""" +Top level description of the classes used in the tutorial from this file. +---------------------------------------------------------------------- + +HuggingFace Gemma Model implementation hierarchy: +---------------------------------- +GemmaDecoderLayer: +├── self_attn: +│ ├── norm: (nn.LayerNorm) +│ ├── qkv_proj: (nn.Linear) +│ ├── attention: (SDPA, FlashAttention, etc.) 
+│ └── o_proj: (nn.Linear) +├── ffn: +│ ├── norm: (nn.LayerNorm) +│ ├── gate_proj: (nn.Linear) +│ ├── up_proj: (nn.Linear) +│ └── down_proj: (nn.Linear) + +GemmaModel: +├── embed_tokens : Token embedding layer +├── layers : GemmaDecoderLayer × N +├── norm : GemmaRMSNorm +└── rotary_emb : GemmaRotaryEmbedding + +GemmaForCausalLM: +├── model : instance of GemmaModel +├── lm_head : (nn.Linear) hidden states to vocabulary logits for generation +└── generate : generate method (input prompt -> GemmaForCausalLM -> next tokens) + +How `generate()` works in HF's GemmaForCausalLM: + 1. prefill (input prompt -> model -> lm_head -> logits -> next token) + 2. loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: Notice how "prefill" and "loop until next tokens" are just part of the `generate()` method. + This is a common pattern in HF models. + + +TransformerEngine's Gemma Model Hierarchy: +---------------------------------------- +HF's `GemmaDecoderLayer` is monkey-patched with `TEGemmaDecoderLayer` before `GemmaForCausalLM` is initialized. This way, +while the model is downloaded from HuggingFace and most of the code runs from HF's `GemmaForCausalLM`, the underlying +blocks of "transformer layer" are actually from TransformerEngine. + +TEGemmaDecoderLayer (inherits from te.TransformerLayer): +├── te.MultiHeadAttention: +│ ├── linear_qkv: (te.LayerNormLinear) +│ ├── attention: (te.DotProductAttention) +│ └── out_proj: (te.LayerNormLinear) +├── te.LayerNormMLP: +│ ├── fc1: (te.LayerNormLinear) +│ ├── fc2: (te.Linear) +│ └── activation: (te.GeGLU) + +To be able to use `model.generate()`, an entry point is needed. `TEGemmaForCausalLM` is the entry point which +subclasses HF's `GemmaForCausalLM` and adds a few attributes and methods. + +TEGemmaForCausalLM (inherits from HF's GemmaForCausalLM) +├─ model : inherited from HF's GemmaForCausalLM but with monkey-patched TEGemmaDecoderLayer × N +├─ lm_head : directly inherited from HF's GemmaForCausalLM +├─ te_rope_emb : RotaryPositionEmbedding (reusing the same for all layers for CUDA graphs compatibility) +├─ hidden_states_buffer : shape [b, max_ctx, h] (static) +├─ generation_buffer : shape [b, 1, h] (view of `hidden_states_buffer`) (static) +├─ inference_params : TransformerEngine KV cache +├─ model_context_phase : GemmaModelWrapper → uses (model, lm_head, inference_params) for full-sequence prefill +├─ model_generation_phase : GemmaGenerationWrapper → uses (model, lm_head, inference_params) for single-token decode +└─ generate : generate method (input prompt -> TEGemmaForCausalLM -> next tokens) + +Notice how "prefill" and "loop until next tokens" are specialized to wrapper subroutines - "model_context_phase" and +"model_generation_phase" respectively which makes it easier to use CUDA Graphs. Just one more abstraction is needed: + +TEGemmaForCausalLMCudaGraphs (inherits from TEGemmaForCausalLM) +├─ model : unchanged (HF's GemmaModel with monkey-patched TEGemmaDecoderLayer × N) +├─ lm_head : unchanged +├─ hidden_states_buffer : unchanged +├─ generation_buffer : unchanged +├─ inference_params : unchanged +├─ record : utility function to record the graphed callable +├─ model_context_phase : GraphedCallable(for Context/prefill) replaced by `record` +├─ model_generation_phase : GraphedCallable(for Generation) replaced by `record` +└─ generate : unchanged + +How `generate()` works in TEGemmaForCausalLM/TEGemmaForCausalLMCudaGraphs: + 1. 
model_context_phase (input prompt -> model -> lm_head -> logits -> next token) + 2. model_generation_phase: + - loop until max_new_tokens: + - next token -> model -> lm_head -> logits -> next token + 3. return all tokens + +NOTE: In the tutorial, `record` is called when initializing the model. + +Additional notes and clarifications +----------------------------------- +- Wrappers, not submodules: + `model_context_phase` and `model_generation_phase` are convenience wrappers over the same + `model` (GemmaModel) and `lm_head`. They own no parameters; they standardize buffer usage, + masks (context uses "padding_causal", generation uses "padding"), rotary embeddings, and + KV-cache (`InferenceParams`) flow for TE-optimized inference. + +- Buffer relationship: + `hidden_states_buffer` has shape [b, max_ctx, h]. `generation_buffer` is a contiguous view + of size [b, 1, h] carved from its start to avoid non-contiguous indexing. Generation updates + `generation_buffer` in-place with next-token embeddings. + +- Padding policy: + Inputs may arrive left-padded (HF-style). Before TE execution, padding is shifted to the end + to match TE attention mask expectations and to keep shapes contiguous for capture/replay. + +- CUDA Graphs specifics: + `record()` captures two separate callables (context/prefill and generation) with fixed shapes and + stable pointers, then replaces the wrappers with these GraphedCallables. Under graphs, the + functional behavior is identical; only allocation/pointer churn and CPU overhead are removed. +""" + + +class TEGemmaDecoderLayer(te.pytorch.TransformerLayer): + """ + Wrapper class over TE's `TransformerLayer`. This makes the wrapper very + similar to HF's `GemmaDecoderLayer` and easier to replace it in the code. + + Args: + config: GemmaConfig + args: positional args (for compatibility with `GemmaDecoderLayer`) + kwargs: keyword args (for compatibility with `GemmaDecoderLayer`) + """ + + def __init__(self, config: GemmaConfig, layer_idx: int, *args, **kwargs): + + self.gemma_config = config + + super().__init__( + hidden_size=config.hidden_size, + ffn_hidden_size=config.intermediate_size, + num_attention_heads=config.num_attention_heads, + bias=False, + layernorm_epsilon=config.rms_norm_eps, + hidden_dropout=0, + attention_dropout=0, + fuse_qkv_params=config.fuse_qkv_params, + normalization="RMSNorm", + activation="geglu", + attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, + kv_channels=self.gemma_config.head_dim, + layer_number=( + layer_idx + 1 + ), # Layer numbers in TE starts from 1, not 0 like in the HF. + zero_centered_gamma=True, + ) + + def forward(self, *args, **kwargs): # We need to additionally pass positional encoding. + + # filter out HF specific args + keys_to_remove = [ + "position_ids", + "past_key_value", + "output_attentions", + "use_cache", + "cache_position", + ] + for key in keys_to_remove: + kwargs.pop(key, None) + + rope_emb = kwargs.pop("rope_emb", None) + + # Return tuple to be compatible with HF. + return (super().forward(*args, rotary_pos_emb=rope_emb, **kwargs),) + + +class GemmaModelWrapper(torch.nn.Module): + """ + Encapsulates the HuggingFace GemmaModel class as a wrapper whose + forward pass is compatible with CUDA Graphs. 
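+
+    The wrapper owns no parameters of its own: it reuses the wrapped model's
+    decoder layers, final norm and `lm_head`, scales the incoming embeddings by
+    the Gemma normalizer (square root of the hidden size), and writes results
+    back into the caller's hidden-states buffer in place so that tensor
+    pointers stay stable across CUDA graph replays.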
+ """ + + def __init__( + self, + model: GemmaModel, + dtype: torch.dtype, + lm_head: torch.nn.Module, + ): + super().__init__() + self.model = model + self.normalizer = torch.tensor(self.model.config.hidden_size**0.5, dtype=dtype) + self.lm_head = lm_head + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + + def forward( + self, + hidden_states: torch.Tensor, + attention_mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + with torch.no_grad(): + # static operation - for CUDA graphs + hidden_states.data[:] = hidden_states.data[:] * self.normalizer + + for i, decoder_layer in enumerate(self.model.layers): + hidden_states.data[:] = decoder_layer( + hidden_states, + attention_mask=attention_mask, + self_attn_mask_type=self.mask if attn_mask_type is None else attn_mask_type, + inference_params=self.inference_params, + rope_emb=rope_emb, + )[ + 0 + ] # static copy - for CUDA graphs + + hidden_states.copy_(self.model.norm(hidden_states)) # static copy - for CUDA graphs + logits = self.lm_head(hidden_states) + + # This is not needed for generation but is needed for training + # or finetuning. + if self.training: + logits = logits.float() + + return logits + + +class GemmaGenerationWrapper(torch.nn.Module): + """ + Gets token embeddings for a batch of single tokens, runs forward pass, and + returns the batch ofnext tokens. Also compatible with CUDA graphs. Not a + subclass of `GemmaModel` since the model layers are simply reused here. + """ + + def __init__( + self, + model: GemmaModel, + lm_head: torch.nn.Module, + dtype: torch.dtype, + ): + super().__init__() + self.model = model + self.gemma_layers = GemmaModelWrapper(model, dtype, lm_head) + + def set_inference_params(self, inference_params): + self.inference_params = inference_params + self.gemma_layers.set_inference_params(inference_params) + + def forward( + self, + hidden_states: torch.Tensor, + mask: torch.Tensor = None, + attn_mask_type: str = "arbitrary", + rope_emb: torch.Tensor = None, + ): + logits = self.gemma_layers( + hidden_states, attention_mask=mask, attn_mask_type=attn_mask_type, rope_emb=rope_emb + ) + + assert logits.shape[0] == hidden_states.shape[0] # b + assert logits.shape[1] == hidden_states.shape[1] # seq_len + + # Fetch the logits for the last token + logits = logits[:, -1, :] + next_tokens = torch.argmax(logits, dim=1) + + # static copy for CUDA graphs + hidden_states.copy_(self.model.embed_tokens(next_tokens).unsqueeze(1)) + + return next_tokens + + +@contextmanager +def replace_decoder(te_decoder_cls): + """ + Monkey-patches `GemmaDecoderLayer` with the custom `TEGemmaDecoderLayer` + class. + """ + original_gemma_decoder_cls = transformers.models.gemma.modeling_gemma.GemmaDecoderLayer + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = te_decoder_cls + try: + yield + finally: + transformers.models.gemma.modeling_gemma.GemmaDecoderLayer = original_gemma_decoder_cls + + +class TEGemmaForCausalLM(GemmaForCausalLM): + """ + Causal LM created with `GemmaModel`. The underlying `GemmaDecoderLayer` + class is monkey-patched with `TEGemmaDecoderLayer` class before + initializing the causal LM with `GemmaForCausalLM`. + + Args: + config: Gemma model config that HF uses to initialize the model. 
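+
+    Note: besides the standard HuggingFace fields, the config is expected to
+    already carry the tutorial-specific attributes that are read below, e.g.
+    `fuse_qkv_params`, `fp8`, `is_paged`, `max_seq_length`,
+    `generation_cuda_graphs` and the `cuda_graphs_static_*` sizes.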
+ """ + + def __init__(self, config: GemmaConfig): + + dtype = torch.bfloat16 + with replace_decoder(te_decoder_cls=TEGemmaDecoderLayer): + super().__init__(config) + + self.config = config + self.to(dtype).cuda() + self.hidden_size = config.hidden_size + + self._model_context_phase = GemmaModelWrapper(self.model, dtype, self.lm_head) + + self._model_generation_phase = GemmaGenerationWrapper( + lm_head=self.lm_head, + model=self.model, + dtype=dtype, + ) + + if self.config.fp8: + self.fp8_recipe = get_default_fp8_recipe() + + # Rotary position embedding remains the same for all the layers and so + # created here. This makes it compatible with CUDA Graphs too. + self.te_rope_emb = RotaryPositionEmbedding(self.config.head_dim)( + max_seq_len=self.config.max_position_embeddings + ).cuda() + + @staticmethod + def _padding_to_end(inputs, lengths, max_seq_len=None): + """ + Gets the tensor with sequence padded from the beginning and + updates it inplace to be padded from its end. + + Parameters + ---------- + inputs : Tensor, tensor with shape [b, s] containing token numbers. + It's padded from the beggining. + lengths: Tensor, tensor with shape [s] with lengths of the sequences. + + """ + max_seq_len = torch.max(lengths) if max_seq_len is None else max_seq_len + batch_size, max_seq_len = inputs.shape + new_input_ids = inputs.clone() + for i in range(batch_size): + new_input_ids[i, : lengths[i]] = inputs[i, (max_seq_len - lengths[i]) : max_seq_len] + new_input_ids[i, lengths[i] :] = inputs[i, 0 : (max_seq_len - lengths[i])] + + # Trim the inputs to no extra padding i.e. fix the max seq len to + # the longest sequence in the batch + actual_max_seq_len = max_seq_len + inputs.data = new_input_ids[:, :actual_max_seq_len] + + def _create_or_fetch_hidden_states_buffer(self, input_ids: torch.Tensor): + """ + Returns a tensor of shape [b, s, hd] where `b` is the batch size, + `s` is the sequence length, and `hd` is the hidden size. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. + """ + + tensor = torch.empty( + (input_ids.shape[0], input_ids.shape[1], self.hidden_size), + device="cuda", + dtype=torch.float32, + ) + return tensor + + def _create_or_fetch_inference_params(self, *args, **kwargs): + """ + Creates an InferenceParams object. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. + """ + + infer_params = InferenceParams(*args, **kwargs) + return infer_params + + def _get_generation_buffer(self, hidden_states_buffer, data_to_copy=None): + """ + Returns a tensor of shape [b, 1, hd] where `b` is the batch size, + `hd` is the hidden size. + + The buffer for generation is some part (beginning) of hidden states buffer. + This function returns pointer to it and also copies there data if provided. + """ + # hidden_states_buffer has shape [b, s, hd] + # generation_buffer will have shape [b, 1, hd] + # Notice that `hidden_states_buffer[:, 0, :].unsqueeze(1)` will return + # uncontiguous buffer, which we want to avoid. + output = hidden_states_buffer.view(-1)[ + : hidden_states_buffer.shape[0] * hidden_states_buffer.shape[2] + ] + if data_to_copy is not None: + output.copy_(data_to_copy.reshape(-1)) + generation_buffer = output.view( + (hidden_states_buffer.shape[0], 1, hidden_states_buffer.shape[2]) + ) + return generation_buffer + + def setup_and_run_context_phase( + self, input_ids: torch.Tensor, inference_params: InferenceParams + ): + """ + Runs the context or prefill phase of the model. + + This function is overriden in TEGemmaForCausalLMCudaGraphs. 
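+
+        Concretely: the (end-padded) prompt is embedded into the hidden-states
+        buffer, `inference_params.pre_step` is advanced with the true prompt
+        lengths, all layers run once with the "padding_causal" mask to fill the
+        KV cache, and the method returns the [b, 1, h] generation buffer
+        (holding the embedding of the first generated token) together with
+        that token.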
+ """ + + hidden_states = self._create_or_fetch_hidden_states_buffer(input_ids) + hidden_states.copy_(self.model.embed_tokens(input_ids)) + + # Update offsets before every forward pass (including context/prefill + # phase) to make cache work properly. + lengths = input_ids.ne(0).sum(dim=1) + inference_params.pre_step(OrderedDict(zip(list(range(len(lengths))), lengths.tolist()))) + + logits = self._model_context_phase( + hidden_states, + attention_mask=None, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + logits = logits[torch.arange(logits.size(0)), lengths - 1, :] + next_tokens = torch.argmax(logits, dim=1) + + # `self.hidden_states` has shape [b, s, hd]. + # Return hidden state for the last token - output has shape [b, 1, hd]. + hidden_states = self._get_generation_buffer( + hidden_states, self.model.embed_tokens(next_tokens) + ) + return hidden_states, next_tokens + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.Tensor] = None, + pad_token_id: int = 0, + max_new_tokens: int = 0, + *args, + **kwargs, + ): + """ + Generates next tokens auto-regressively for a batch of input tokens. + """ + self.eval() + + # Both autocasts are needed: FP8 for operations that can run in lower + # precision and BF16 for those that cannot. + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False), te.pytorch.fp8_autocast( + enabled=self.config.fp8, fp8_recipe=self.fp8_recipe if self.config.fp8 else None + ): + lengths = torch.sum(input_ids.ne(pad_token_id), dim=-1).squeeze() + # If padding is at the beginning, then shift it to the end + TEGemmaForCausalLM._padding_to_end( + input_ids, + lengths, + max_seq_len=( + self.config.cuda_graphs_static_max_context_len + if self.config.generation_cuda_graphs + else None + ), + ) + + batch_size = input_ids.shape[0] + # For benchmark generation run, this is being set explicitly. + max_input_sequence_len = self.config.max_seq_length + + # InferenceParams is a cache, where keys and values of previous + # tokens are stored. Moreover it stores the current running lengths + # of the sequences in the current batch. + # A helper function is used to create the inference params object + # because this `generate` method is common for TEGemmaForCausalLM + # and TEGemmaForCausalLMCudaGraphs. In case of CudaGraphs, this + # function is overriden to simply return the inference params object + # that is already created in TEGemmaForCausalLMCudaGraphs' + # constructor. + inference_params = self._create_or_fetch_inference_params( + max_batch_size=batch_size, + max_sequence_length=max_input_sequence_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=batch_size * max_input_sequence_len // 16, + ) + + # Set the inference params for both the context/prefill phase and + # generation phase objects. + self._model_context_phase.set_inference_params(inference_params) + self._model_generation_phase.set_inference_params(inference_params) + + # Context/prefill phase. + hidden_states, next_tokens = self.setup_and_run_context_phase( + input_ids, inference_params + ) + + # Generation phase. 
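+            # The decode steps below reuse the same static tensors throughout:
+            # `hidden_states` is the [b, 1, h] generation buffer, each
+            # `pre_step` call advances every sequence's KV-cache offset by one,
+            # and the static `next_tokens` output is cloned before being
+            # collected.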
+ lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + output_tokens = [next_tokens] + + for _ in range(max_new_tokens): + next_tokens = self._model_generation_phase( + hidden_states, + mask=None, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) + + # Increase sequence offsets by one because we generated one token + # for every sequence. + lengths_tensor = torch.ones((next_tokens.shape[0],), dtype=int) + inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths_tensor))), lengths_tensor.tolist())) + ) + + # `next_tokens` is a static output tensor, so we need to clone + # it because it gets changed every iteration. + output_tokens.append(next_tokens.clone()) + + result = torch.cat((input_ids, torch.stack(output_tokens).permute([1, 0])), dim=1) + return result + + def forward(self, *args, **kwargs): + """ + Forward pass for the model. This is used in calibration step when + forward pass is needed to generate FP8 calibration data. + """ + + self._model_context_phase.set_inference_params(None) + hidden_states = self.model.embed_tokens(kwargs["input_ids"]) + logits = self._model_context_phase( + hidden_states, + attention_mask=( + kwargs["input_ids"] == 0 + ), # Hardcoded, this only applies to bshd/sbhd layouts. + attn_mask_type="padding_causal", + ) + return logits + + +class TEGemmaForCausalLMCudaGraphs(TEGemmaForCausalLM): + """ + TEGemmaForCausalLMCudaGraphs is a wrapper over the class TEGemmaForCausalLM + and uses CUDA Graphs to speed up the generation process. We need to make one + trade-off - batch_size, max_seq_len and max_context_seq_len need to + be static. It is necessary to run generation without changing the pointer + to the variables that are recorded in the graph. + """ + + def __init__(self, config: GemmaConfig): + super().__init__(config) + + self.config = config + + # Preparation of the static buffer to hold the hidden states that are + # passed from one layer to the next. + self.hidden_states_buffer = torch.empty( + ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + self.config.hidden_size, + ) + ).cuda() + + # This is in fact part of the buffer for hidden_states. Refer to the + # `_get_generation_buffer` function for more details. + self.generation_buffer = self._get_generation_buffer( + self.hidden_states_buffer, + ) + + # InferenceParams contains the keys and values cache. Refer to the + # original call in TEGemmaForCausalLM's `generate` method for more + # details. + self.inference_params = InferenceParams( + max_batch_size=self.config.cuda_graphs_static_batch_size, + max_sequence_length=self.config.cuda_graphs_static_max_context_len, + num_heads_kv=self.config.num_key_value_heads, + head_dim_v=self.config.head_dim, + head_dim_k=self.config.head_dim, + dtype=torch.bfloat16, + is_paged=self.config.is_paged, + page_size=16, + total_num_pages=self.config.cuda_graphs_static_batch_size + * self.config.cuda_graphs_static_max_context_len + // 16, + ) + + self._model_generation_phase.set_inference_params(self.inference_params) + self._model_context_phase.set_inference_params(self.inference_params) + + def record(self): + """ + Here "the trick" happens. `_model_context_phase` and + `_model_generation_phase` from TEGemmaForCausalLM are replaced with + their recorded version. Once the graphs are recorded, they can be + replayed with minimal usage of CPU and that leads to speedup. 
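+
+        As the module docstring notes, `record()` is intended to be called
+        once, right after the model is constructed and before `generate()`, so
+        that both phases are captured against the static buffers allocated in
+        the constructor.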
+ """ + # Record the model with training=False, because it will be used in + # generation. + self.eval() + + # Setup the recording for context/prefill phase. + input_shape = ( + self.config.cuda_graphs_static_batch_size, + self.config.cuda_graphs_static_max_context_len, + ) + + # Hardcoded value for the context length. + lengths = torch.tensor([9] * self.config.cuda_graphs_static_batch_size).to( + device="cuda", dtype=torch.int32 + ) + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # Record the graph for context/prefill phase. + self._model_context_phase = self.record_graph( + self._model_context_phase, + self.hidden_states_buffer, + attn_mask_type="padding_causal", + rope_emb=self.te_rope_emb, + ) + + # Setup the recording for generation phase. + input_shape = (self.config.cuda_graphs_static_batch_size, 1) + lengths = torch.tensor(input_shape[0] * [1], device="cuda", dtype=torch.int32) + self.inference_params.pre_step( + OrderedDict(zip(list(range(len(lengths))), lengths.tolist())) + ) + + # Record the graph for generation phase. + self._model_generation_phase = self.record_graph( + self._model_generation_phase, + self.generation_buffer, + attn_mask_type="padding", + rope_emb=self.te_rope_emb, + ) + + def _create_or_fetch_hidden_states_buffer(self, *args, **kwargs): + """ + Overriden to make `hidden_states` static i.e. not change its pointer + in memory between every invocation. + + Returns the static buffer for `hidden states` which is already created + in the constructor. This is the same buffer as used in the + context/prefill phase. + """ + return self.hidden_states_buffer + + def _create_or_fetch_inference_params(self, *args, **kwargs): + """ + Overriden to make `inference_params` static i.e. not change its pointer + in memory between every invocation. + + Returns the static buffer for `inference_params` which is already created + in the constructor. + """ + self.inference_params.reset() + return self.inference_params + + @torch.no_grad() + def record_graph(self, function, input_tensor, **sample_kwargs): + """ + Records the graph for the given function. The function is invoked on + argument (self.hidden_states,) and all kernels are recorded. + It then returns the captured callable, which can be run later while + minimizing CPU usage. + """ + fp8_recipe = get_default_fp8_recipe() + + # We need both autocasts: FP8 for operations that can run in lower + # precision and BF16 for those that cannot. + with autocast("cuda", dtype=torch.bfloat16, cache_enabled=False): + graphed_function = te.pytorch.make_graphed_callables( + function, + (input_tensor,), + fp8_enabled=self.config.fp8, + fp8_recipe=fp8_recipe, + allow_unused_input=True, + num_warmup_iters=5, + sample_kwargs=sample_kwargs, + ) + return graphed_function diff --git a/docs/examples/te_gemma/te_gemma_loading_weights.py b/docs/examples/te_gemma/te_gemma_loading_weights.py new file mode 100755 index 000000000..d0df9edc5 --- /dev/null +++ b/docs/examples/te_gemma/te_gemma_loading_weights.py @@ -0,0 +1,189 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
+ +import os +import re +import gc +import torch + +from typing import List + +from transformer_engine.pytorch.fp8 import fp8_model_init + +from transformers.modeling_utils import load_state_dict +from transformers.utils.hub import get_checkpoint_shard_files + +""" + This file contains logic of mapping the HuggingFace GemmaModel parameters + with TransformerEngine TransformerLayer. When we have initialized Transformer models + both with HF and with TE, we can copy parameters from the first to the second. +""" + + +def _load_weights_for_fp8_model(vanilla_model, hyperparams): + """ + Loads weights and FP8 metadata from a calibrated weights file. + + The weights are in BF16 precision, but the state dict also contains + fp8 metadata computed by the calibration procedure. + """ + + fp8_metadata_sd = torch.load(hyperparams.fp8_model_weights_filename) + + # A hack to remove the extra state from the fp8_metadata_sd + # that contains the extra state from the core_attention module. + fp8_metadata_sd = { + k: v for k, v in fp8_metadata_sd.items() if "core_attention._extra_state" not in k + } + vanilla_model.load_state_dict( + fp8_metadata_sd, + strict=False, + # Because some parameters have multiple pointers to the same weight + # vanilla_model._model_context_phase.model and + # vanilla_model._model_generation_phase.model we need to load the + # weights in a non-strict manner. + ) + + +def _load_weights_for_standard_model(vanilla_model, config): + """ + Loads weights from the HuggingFace checkpoint. + """ + + archive_file = os.path.join(config.weights_cache_dir, "model.safetensors.index.json") + resolved_archive_file, _ = get_checkpoint_shard_files(config.weights_cache_dir, archive_file) + total_dict = {} + for shard_file in resolved_archive_file: + state_dict = load_state_dict(shard_file) + total_dict.update(state_dict) + + replace_params( + total_dict, + vanilla_model.state_dict(), + config, + qkv_fused_and_interleaved=config.fuse_qkv_params, + ) + # Copy remaining parameters like embedding. + vanilla_model.load_state_dict(total_dict, strict=False) + + # Force mem release. Taken from huggingface code. + del total_dict + gc.collect() + + +def load_te_model(cls, config): + """ + Loads the TE model with proper weights. + """ + + # Force the dtype to bfloat16 while loading the model. + old_dtype = torch.get_default_dtype() + torch.set_default_dtype(torch.bfloat16) + """ + Custom method adapted from `from_pretrained` method in HuggingFace + Transformers repo: + https://github.com/huggingface/transformers/blob/f497f564bb76697edab09184a252fc1b1a326d1e/src/transformers/modeling_utils.py#L2579 + """ + config.use_cache = False # To make TransformerLayer compatible with GemmaModel + + # Loading model with FP8 only weights needs both the following context managers. + # 1. fp8_model_init(config.fp8_model_init) to tell TE to use FP8 only weights. + # 2. torch.no_grad() during TE modules' initilization so that they respect + # the `fp8_model_init` context manager. + with torch.no_grad(), fp8_model_init(config.fp8_model_init): + # Just create a model with random weights. + vanilla_model = cls(config).cuda() + + # Copy proper weights into the model. If loading weights with FP8 metadata, + # then the source weights are basically the same as the weights in the model. + # If not, then we need to load the weights from the HuggingFace checkpoint + # and do mapping of the weight names from HF to the TE model. 
+ if config.fp8_model_weights_filename is not None: + _load_weights_for_fp8_model(vanilla_model, config) + else: + _load_weights_for_standard_model(vanilla_model, config) + + # Restore the original dtype. + torch.set_default_dtype(old_dtype) + return vanilla_model + + +def _get_all_layer_prefixes_to_update(hf_state_dict): + """ + There are many parameters in hf_state_dict, whose name start with "model.layers.[number]." + This function extracts all strings like "model.layers.[number]." + that are starting strings of keys in hf_state_dict. + """ + all_layer_prefixes = set() + for param_key in hf_state_dict.keys(): + layer_prefix_pat = "model.layers.\d+." + m = re.match(layer_prefix_pat, param_key) + if m is not None: + all_layer_prefixes.add(m.group()) + return all_layer_prefixes + + +def replace_params(hf_state_dict, te_state_dict, config, qkv_fused_and_interleaved=False): + """ + Replaces params from TE TransformerLayer state_dict with corresponding parameters + from HuggingFace GemmaModel state_dict. + """ + all_layer_prefixes: List[str] = _get_all_layer_prefixes_to_update(hf_state_dict) + + for layer_prefix in all_layer_prefixes: + + def copy_from_ht_to_te(te_name, hf_name, start=None, end=None): + te_state_dict[layer_prefix + te_name].data[start:end].copy_( + hf_state_dict[layer_prefix + hf_name] + ) + + copy_from_ht_to_te( + "self_attention.layernorm_qkv.layer_norm_weight", "input_layernorm.weight" + ) + copy_from_ht_to_te("self_attention.proj.weight", "self_attn.o_proj.weight") + copy_from_ht_to_te("layernorm_mlp.layer_norm_weight", "post_attention_layernorm.weight") + copy_from_ht_to_te("layernorm_mlp.fc2_weight", "mlp.down_proj.weight") + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.gate_proj.weight", end=config.intermediate_size + ) + copy_from_ht_to_te( + "layernorm_mlp.fc1_weight", "mlp.up_proj.weight", start=config.intermediate_size + ) + + if qkv_fused_and_interleaved: + """ + When qkv_fused_and_interleaved=True, key, query and value layers are on one tensor + in TE TransformerLayer. Moreover they are interleaved within each head. + Let q_i, k_i and v_i be query, key and value layers for i-th head respectively. + Then TE stores weight tensor in the form: + [q1 k1 v1 q2 k2 v2 ...] + This is done to maximally optimize performance time. 
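+
+            Concretely, for head index i and head_dim d, the query weights of
+            that head occupy rows [3*i*d, 3*i*d + d) of the fused tensor, the
+            key weights rows [3*i*d + d, 3*i*d + 2*d) and the value weights
+            rows [3*i*d + 2*d, 3*i*d + 3*d), which is exactly the slicing
+            performed by `copy_interleave` below.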
+ """ + te_qkv_layer = te_state_dict[layer_prefix + "self_attention.layernorm_qkv.weight"] + + def copy_interleave(hf_name, idx): + src = hf_state_dict[layer_prefix + hf_name] + for head_nr in range(config.num_attention_heads): + dst_offset = head_nr * config.head_dim * 3 + dst_slice = slice( + dst_offset + idx * config.head_dim, dst_offset + (idx + 1) * config.head_dim + ) + src_slice = slice( + head_nr * config.head_dim, head_nr * config.head_dim + config.head_dim + ) + te_qkv_layer[dst_slice, :] = src[src_slice, :] + + copy_interleave("self_attn.q_proj.weight", 0) + copy_interleave("self_attn.k_proj.weight", 1) + copy_interleave("self_attn.v_proj.weight", 2) + else: + copy_from_ht_to_te( + "self_attention.layernorm_qkv.query_weight", "self_attn.q_proj.weight" + ) + copy_from_ht_to_te("self_attention.layernorm_qkv.key_weight", "self_attn.k_proj.weight") + copy_from_ht_to_te( + "self_attention.layernorm_qkv.value_weight", "self_attn.v_proj.weight" + ) + + return all_layer_prefixes diff --git a/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb new file mode 100755 index 000000000..cc8675cfd --- /dev/null +++ b/docs/examples/te_gemma/tutorial_generation_gemma_with_te.ipynb @@ -0,0 +1,941 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "87e8360b-8d08-44bc-9333-79ba949afe8c", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "# Accelerating Hugging Face Gemma Inference with Transformer Engine" + ] + }, + { + "cell_type": "markdown", + "id": "2da33092-eef5-46a4-b222-0188cc6e5079", + "metadata": { + "editable": true, + "slideshow": { + "slide_type": "" + }, + "tags": [] + }, + "source": [ + "## Introduction\n", + "\n", + "Generative AI has made remarkable strides in recent years, with Large Language Models (LLMs) like ChatGPT at the forefront. These models have revolutionized how we interact with machine-generated content, providing capabilities that range from writing assistance to complex decision support. The core functionality of these models is the generation process, which involves predicting the next token in a sequence based on the preceding text. This task is critical for applications such as automated content creation, translation, and more, emphasizing the importance of efficient implementation.\n", + "\n", + "

\n", + "\"\"\n", + "
\n", + "Animation 1: Hugging Face Gemma model token generation.\n", + "
\n", + "
\n", + "\n", + "For those seeking a deeper understanding of text generation mechanisms in Transformers, it is recommended to check out the [HuggingFace generation tutorial](https://huggingface.co/docs/transformers/llm_tutorial).\n", + "\n", + "In a previous tutorial on [Llama](../te_llama/tutorial_accelerate_hf_llama_finetuning_with_te.ipynb), it was demonstrated how finetuning of an open-source Llama model can be accelerated using Transformer Engine's `TransformerLayer`. Building on that foundation, this tutorial showcases how to accelerate the token generation from the open-source Hugging Face Gemma 7B model.\n", + "\n", + "This tutorial introduces several features of the Transformer Engine library that contribute towards this goal. A brief explanation is as follows:\n", + "\n", + "### 1. From vanilla KV-caching to Paged Attention for inference in Transformer Engine\n", + "\n", + "The original [Attention mechanism](https://arxiv.org/pdf/1706.03762) ushered in an era of Large Language Models, but the same attention mechanism, if used for deployment in inference scenarios, can be computationally wasteful. It is primarily due to a lot of redundant computation that happens in attention when the Transformer models are used autoregressively to compute the next token. Several tutorials on the internet explain in detail how KV Caching helps to reduce that redundant computation, e.g., [tutorial 1](https://magazine.sebastianraschka.com/p/coding-the-kv-cache-in-llms), [tutorial 2](https://medium.com/@joaolages/kv-caching-explained-276520203249), etc.\n", + "\n", + "\n", + "Further, even though the performance benefit of KV Cache is immense, it comes at the cost of increased memory usage, which becomes a problem especially for longer context lengths. The major problems are: \n", + "\n", + "1. Internal fragmentation\n", + "2. External Fragmentation\n", + "\n", + "More information can be found in the [Paged Attention](https://arxiv.org/pdf/2309.06180) paper. The authors solve the above problems by treating the KV cache as a virtual memory with the actual physical blocks being much smaller than the overall cache size. This makes it easier to swap them in and out of GPU HBM as needed - very similar to how Operating Systems implement virtual memory to swap the individual pages in and out of the CPU RAM.\n", + "\n", + "\n", + "Transformer Engine allows users to use both \"Non-paged\" and \"Paged\" forms of KV Caching, and the results in this tutorial are posted for both use cases.\n", + "\n", + "\n", + "### 2. CUDA Graphs API\n", + "\n", + "The speed of GPUs is increasing at a rapid pace. It turns out that sometimes the runtime of kernels is shorter than the time it takes for the CPU to finish processing and then launch the kernels, which can lead to significant overhead. CUDA Graphs can address this issue. When such blocks of computation are executed repeatedly, CUDA Graphs allow us to record and replay them with less CPU involvement. This becomes particularly useful in applications like token generation, where multiple \"Transformer/Decoder Layers\" are run for every token that needs to be generated.\n", + "\n", + "One can read more about CUDA Graphs [here](https://developer.nvidia.com/blog/cuda-graphs/).\n", + "\n", + "PyTorch exposes graphs via a raw `torch.cuda.CUDAGraph` class and two convenience wrappers: `torch.cuda.graph` and `torch.cuda.make_graphed_callables`. 
More information about the CUDA graphs in Pytorch can be found [here](https://pytorch.org/blog/accelerating-pytorch-with-cuda-graphs/).\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 1: CUDA Graphs reduce the overhead generated by the long time it takes to launch a single kernel. It enables the recording and replaying of subsequent launches, thus reducing the total time used by the CPU.\n", + "
\n", + "
\n", + "\n", + "### 3. FP8 Scaling Factors Calibration\n", + "\n", + "This tutorial uses the `DelayedScaling` recipe for FP8 precision, which relies on the correct calculation of \"scaling factors\".\n", + "\n", + "If a model is trained in BF16/FP32, obtaining correct FP8 scaling factors becomes important when it is then run under `fp8_autocast()` context manager. The value of these scaling factors defaults to their initial values, which do not capture the distribution of higher precision weights and input tensors and can cause numerical errors upon usage. Calibration involves capturing an appropriate distribution of higher precision weights and input tensor values and, in turn, calculating appropriate FP8 scaling factors from those. Once these factors are computed, the model becomes numerically stable.\n", + "\n", + "It is highly recommended to familiarize oneself with the [tutorial](../../examples/fp8_primer.ipynb) on FP8 precision to understand the importance of proper scaling factors.\n", + "\n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 2:\n", + "Assuming that the model is trained in FP32/BF16 precision and the goal is to execute it in FP8 precision, the process isn't straightforward due to the absence of appropriate FP8 scaling factors. In this scenario, FP8 calibration becomes essential. By conducting several forward passes on sample data, the FP8 scaling parameters can be computed. This calibration allows the model to operate correctly in FP8 precision.\n", + "
\n", + "
\n", + "\n", + "### 4. FP8 Model Weights\n", + "\n", + "The typical approach is to store weights in higher precision and then cast them to FP8 before operations. This may prevent accuracy drops in training. However, for inference, this level of precision is not necessary.\n", + "\n", + "The Transformer Engine includes a wrapper `fp8_model_init`, which allows for the creation of models that store only the FP8 copy of the weights. This eliminates the need to cast model weights from higher precision to FP8 every time, thus saving time in the forward pass during token generation. \n", + "\n", + "
\n", + "\"\"\n", + "
\n", + "Figure 3: Model under fp8_autocast() stores weights in high precision by default, and casts them if needed. If used without consideration, it could potentially not provide the expected speedup and also end up unnecessarily increasing overall GPU memory usage. Using fp8_model_init() results in storing model weights in FP8 by default, which can help with these potential issues.\n", + "
\n", + "
\n", + "\n", + "### Benchmarking\n", + "\n", + "We'll evaluate the generation time across one benchmark: token generation with context/prefill phase max sequence length = 20, batch size = 64, and number of generated tokens = 492 on random texts with random lengths. This is a purely synthetic benchmark.\n", + "\n", + "
\n", + "Note\n", + " \n", + "This tutorial focuses on showcasing the mentioned features of the Transformer Engine in the context of token generation. It's important to note, however, that NVIDIA provides [TensorRT-LLM](https://docs.nvidia.com/tensorrt-llm/index.html), which is optimized for inference tasks and should be considered for such use cases.\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "b18f91a9", + "metadata": {}, + "source": [ + "## Dependencies for this tutorial" + ] + }, + { + "cell_type": "markdown", + "id": "e5201d77", + "metadata": {}, + "source": [ + "The following files and media are necessary to effectively run this tutorial:\n", + "\n", + "1. `te_gemma.py`\n", + " - This file contains the code to load a Hugging Face Gemma checkpoint weights in Transformer Engine's `TransformerLayer` instead of Hugging Face's `GemmaDecoderLayer`. Further, it contains necessary abstractions like a subclass of `GemmaForCausalLM` - `TEGemmaForCausalLM` that is used for generation with Transformer Engine's `TransformerLayer`, CUDA Graphs, and FP8 calibration for generation in FP8 precision.\n", + "2. `te_gemma_loading_weights.py`\n", + " - This file contains the logic of mapping the parameters from `GemmaDecoderLayer` into the `TransformerLayer`.\n", + "3. `utils.py`\n", + " - This file contains the code related to dataloading, hyperparameters, setting up model/optimizers/accelerator, model training, and other miscellaneous tasks like restarting the Jupyter notebook from within the cell. \n", + "4. `requirements.txt`\n", + " - This file contains the necessary Python packages for this tutorial.\n", + "5. `media/`\n", + " - This directory contains the images and other artefacts used in this tutorial." + ] + }, + { + "cell_type": "markdown", + "id": "36767694-a1c5-4a00-a075-7addc55d8307", + "metadata": {}, + "source": [ + "### Setup and checks" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "1de3351b-fa21-4b95-bb9e-d01ac8bb7edf", + "metadata": {}, + "outputs": [], + "source": [ + "# Uncomment and run this cell when running the tutorial for the first time\n", + "# %pip install -r requirements.txt" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "c756ebbd-24c9-4a54-a381-e7c02c555206", + "metadata": {}, + "outputs": [], + "source": [ + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "\n", + "import torch\n", + "cudnn_version = torch.backends.cudnn.version()\n", + "assert cudnn_version >= 90100, \"cuDNN version >= 9.1.0 is needed to run this tutorial.\"" + ] + }, + { + "cell_type": "markdown", + "id": "e8dfabbf", + "metadata": {}, + "source": [ + "## [Baseline] Running Hugging Face generation with Gemma model" + ] + }, + { + "cell_type": "markdown", + "id": "59560bff", + "metadata": {}, + "source": [ + "HuggingFace Transformers library offers generation API. \n", + "HuggingFace generation for the Gemma model will be used as a baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "2803e0ec", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why GPUs are so good at graphics. 
The\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds advanced computer graphics and video processing chips for the PC and video game console markets.\n", + "* The company is a leading provider of graphics processing units (GPUs) for the PC and video game\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 46.60 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "\n", + "model = init_baseline_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b3698dc6", + "metadata": {}, + "source": [ + "Let's put this time into the table for later comparison.\n", + "\n", + "| Models | Time | Speedup | \n", + "|-------------------------------------------------------------|---------------------------------------|--------------------------------------|\n", + "| HF (baseline) | 46.6 s | - |" + ] + }, + { + "cell_type": "markdown", + "id": "8bb40f45", + "metadata": {}, + "source": [ + "## [Optimization 1] Accelerating generation with Transformer Engine " + ] + }, + { + "cell_type": "markdown", + "id": "263b40f2", + "metadata": {}, + "source": [ + "Similar to the [Llama](../te_llama/tutorial_accelerate_hf_llama_with_te.ipynb) finetuning tutorial, a `GemmaDecoderLayer` is substituted by a tuned `TransformerLayer` from the Transformer Engine library. Let's run it and compare the time with the baseline." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "9dceef93", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 12.25 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.batch_size = 64\n", + "run_config.max_seq_length = 512\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "b5d40836", + "metadata": {}, + "source": [ + "With just using Transformer Engine with default (non-paged) KV cache, a speedup of **3.8x** was obtained. Neat!" + ] + }, + { + "cell_type": "markdown", + "id": "006d18e8", + "metadata": {}, + "source": [ + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of `GemmaDecoderLayer` with `te.TransformerLayer`) | 12.25 s | 3.8x | 12.24 s | 3.8x |" + ] + }, + { + "cell_type": "markdown", + "id": "21a89d9c", + "metadata": {}, + "source": [ + "## [Optimization 2] More acceleration with CUDA Graphs" + ] + }, + { + "cell_type": "markdown", + "id": "e2d53e7b", + "metadata": {}, + "source": [ + "Transformer Engine includes a function `transformer_engine.pytorch.make_graphed_callables`, which behaves similarly to the corresponding feature in PyTorch. It is capable of recording any modules from the Transformer Engine. Below is a code excerpt from [te_gemma.py](./te_gemma.py) from class `TEGemmaForCausalLMCudaGraphs`:\n", + "```python\n", + " def __init__(self, config : GemmaConfig):\n", + " \"\"\"\n", + " Here \"the trick\" happens. `_model_context_phase` and\n", + " `_model_generation_phase` from TEGemmaForCausalLM are replaced with\n", + " their recorded version. Once the graphs are recorded, they can be\n", + " replayed with minimal usage of CPU and that leads to speedup.\n", + " \"\"\"\n", + " (...)\n", + " # Record the graph for context/prefill phase.\n", + " self._model_context_phase = \n", + " self.record_graph(self._model_context_phase, self.hidden_states_buffer)\n", + "\n", + " (...) 
\n", + " # Record the graph for generation phase.\n", + " self._model_generation_phase = \n", + " self.record_graph(self._model_generation_phase, self.generation_buffer)\n", + "\n", + " @torch.no_grad()\n", + " def record_graph(self, function, input_tensor):\n", + " \"\"\"\n", + " Records the graph for the given function. The function is invoked on\n", + " argument (self.hidden_states,) and all kernels are recorded.\n", + " It then returns the captured callable, which can be run later while\n", + " minimizing CPU usage.\n", + " \"\"\"\n", + " fp8_recipe = get_default_fp8_recipe()\n", + "\n", + " # We need both autocasts: FP8 for operations that can run in lower\n", + " # precision and BF16 for those that cannot.\n", + " with autocast(\"cuda\", dtype=torch.bfloat16, cache_enabled=False):\n", + " graphed_function = te.pytorch.make_graphed_callables(\n", + " function,\n", + " (input_tensor,),\n", + " fp8_enabled=self.config.fp8,\n", + " fp8_recipe=fp8_recipe,\n", + " allow_unused_input=True,\n", + " num_warmup_iters=5,\n", + " sample_kwargs=sample_kwargs,\n", + " )\n", + " return graphed_function\n", + "```\n", + "\n", + "It is strongly recommended to review the entire code of the class `TEGemmaForCausalLMCudaGraphs`. Let's now proceed to evaluate the performance improvement offered by CUDA Graphs.\n", + "\n", + "*Note the usage of static buffers and corresponding configuration in the following cell, which is necessary for CUDA Graphs to function.*" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "31a3a8a3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing a lot of the same thing at the same time.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "The first fact is why they are so good at graphics. 
The second\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and builds the world’s most advanced computer chips and systems for the AI era.\n", + "* NVIDIA is the world leader in AI computing.\n", + "* NVIDIA is the world leader in graphics processing units (GP\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 6.39 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.max_seq_length = 512\n", + "run_config.batch_size = 64\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# It is necessary to preallocate a static buffer.\n", + "# CUDA graphs require static input tensors for every kernel.\n", + "# This approach may result in a slight increase in memory consumption;\n", + "# however, the substantial speedup achieved makes it worthwhile.\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "53bb430f", + "metadata": {}, + "source": [ + "A speed up of **7.2x** was obtained by using CUDA Graphs with TE's `TransformerLayer`.\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "0a11b75c", + "metadata": {}, + "source": [ + "Let's profile the code from one of the cells above, which runs generation with the Gemma model, and examine the resulting traces in [NVIDIA Nsight Systems](https://developer.nvidia.com/nsight-systems) to understand the performance characteristics and sources of speedup. A few things to recap:\n", + "\n", + "1. For the TE Gemma model implementation, `model.generate()` internally calls `model_context_phase` and `model_generation_phase`.\n", + "2. They are just wrappers around the Gemma model's layers, and they are graphed separately when CUDA graphs are enabled.\n", + "3. 
So, for each token generated (after the first token), a single invocation of `model_generation_phase` happens as a complete CUDA graph. \n", + "4. The following illustration zooms in on a single `TransformerLayer` layer forward pass (within the larger `model_generation_phase` graphed callable) for clarity.\n", + "\n", + "(For details, refer to the implementation in [te_gemma.py](./te_gemma.py))\n", + "\n", + "
\n", + "\n", + "
\n", + " \n", + "Figure 4: (Without CUDA graphs) Blue blobs in the top figure are GPU kernels, and whitespace b/w those indicates that GPUs are idle waiting for the CPU to finish processing and then launch kernels. (With CUDA graphs) The whitespace gets virtually eliminated because all the GPU kernels are bundled into a single highly optimized unit of work with no CPU time in between. (Note that for reference, the kernels are mapped across both cases, and the sizes of those kernels only seem different because of the presence of large voids in the former case, but the sizes are actually the same.)\n", + "
\n", + "
\n" + ] + }, + { + "cell_type": "markdown", + "id": "e6b171a0", + "metadata": {}, + "source": [ + "## [Optimization 3] Even more acceleration with FP8 precision " + ] + }, + { + "cell_type": "markdown", + "id": "1a80288b", + "metadata": {}, + "source": [ + "### Calibrating FP8 scaling factors for correctness\n", + "\n", + "Implementing token generation in FP8 precision with the Gemma model is not straightforward because this model was initially trained using BF16 precision, and the necessary FP8 scaling factors are missing when used with `fp8_autocast` context manager. As Figure 5 shows, scaling factors are needed for two types of tensors for this tutorial:\n", + "\n", + "1. Model weight tensors\n", + "2. Input tensors\n", + "\n", + "If the model is run in FP8 precision with incorrect scaling factors, the resulting FP8-cast model weights and FP8-cast inputs (both converted from BF16 precision) will be significantly misaligned, potentially leading to large errors and inaccurate results.\n", + "\n", + "To address this issue, \"calibration\" is used. This involves running several forward iterations in BF16 precision within the context `te.fp8_autocast(enabled=False, calibration=True)`. This setup allows the forward pass to operate at higher precision, while simultaneously collecting `amax_history` and other parameters related to the FP8 precision, which are essential for calculating the \"scaling factors\" that are then used to cast higher precision tensors to FP8 precision more accurately. Calibration in the forward passes calculates the scaling factors for weight and input tensors.\n", + "\n", + "*Note that other tensors might need calibration in specific use-cases, but for the generation process in this tutorial, calibrating only the input and weight tensors is needed, and so only the forward pass is considered.*\n", + " \n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 5: The default FP8 scaling factors are incorrect, and so the BF16 to FP8 conversion, as is, can lead to numerical errors. Calibration allows for collecting statistics/metadata about the input and weight tensors in higher precision during the forward pass.\n", + "
\n", + "
\n", + "\n", + "\n", + "The code below outlines the steps to initialize the BF16 model and conduct several forward iterations within the specified context. After these iterations, the model is saved, and these weights will be utilized in subsequent steps." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "aecee0e1", + "metadata": {}, + "outputs": [], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "import transformer_engine.pytorch as te\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "run_config.fuse_qkv_params = True\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "# Calibration\n", + "with te.fp8_autocast(enabled=False, calibrating=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " model.train()\n", + " run_forward_pass(model, run_config, num_iters=64)\n", + "\n", + "# Compute scale_fwd with enabled fp8 autocast\n", + "with te.fp8_autocast(enabled=True), torch.autocast(\n", + " device_type=\"cuda\", dtype=torch.bfloat16\n", + "):\n", + " run_forward_pass(model, run_config, 1)\n", + "\n", + "# Some parameters are in pointing to the same tensors, double save is avoided here.\n", + "dict_to_save = {\n", + " k: v\n", + " for k, v in model.state_dict().items()\n", + " if (\"_context_phase\" not in k and \"_generation_phase\" not in k)\n", + "}\n", + "torch.save(\n", + " dict_to_save, \"calibrated_weights.pth\"\n", + ") # <-- Add path to save calibrated weights." + ] + }, + { + "cell_type": "markdown", + "id": "b6dcd135", + "metadata": {}, + "source": [ + "### Generation with better FP8 scaling factors\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 6: After the calibration process, FP8 scaling factors are correct and prevent numerical errors.\n", + "
\n", + "
\n", + "\n", + "Now that the calibration has produced correct scaling factors, FP8 inference is ready to be run." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "a913f54d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 8.73 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8\n", + "run_config.fp8 = True\n", + "# Calibrated fp8 weights are loaded directly from the file.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "8cdbb56c", + "metadata": {}, + "source": [ + "One can observe that the outputs are coherent; however, the generation time has increased. Why is this the case?\n", + "\n", + "### Use of FP8-only model weights\n", + "\n", + "Running the model in FP8 precision does not imply that the weights are stored in FP8. 
By default, they are stored in higher precision and are cast to FP8, using saved scaling factors before GEMM operations (matrix multiplications).\n", + "\n", + "This approach is appropriate during training since gradients during the backward pass are produced in higher precision, and therefore, having higher precision copies of model weights helps, as they have enough dynamic range to encompass incoming information from the gradients. During the forward pass, the higher precision model weights and the batch inputs are cast to FP8, and the GEMMs occur in FP8 precision, which helps save training time overall if the time saved from running GEMM in FP8 precision (than in higher precision) is more than the extra time spent during the cast operation.\n", + "\n", + "
\n", + "\n", + "
\n", + " Figure 7: Running the model at higher precision involves only one operation - GEMM. However, when the model operates in FP8, it requires casting inputs to the GEMM - namely, model weights and batch inputs from higher precision to FP8, which involves extra kernels in addition to the low-precision GEMM kernel.\n", + "
\n", + "
" + ] + }, + { + "cell_type": "markdown", + "id": "626aefa1-d5c4-4d8f-88d9-7d7943afde0d", + "metadata": {}, + "source": [ + "However, things change during inference. Since the weights need no update and remain frozen, higher precision copies of weights could be avoided completely. It is possible to cast the higher precision weights only once to FP8 precision while initializing the model with appropriate scaling factors and then use those FP8-only copies of weights during the entirety of token generation. This provides two-fold benefits:\n", + "\n", + "1. Lower memory usage - since the model weights are stored in FP8 precision only (compared to training, where both BF16 and FP8 copies end up being present in the memory during peak usage).\n", + "2. Faster forward pass - since there is no cast kernel to cast higher precision weights to FP8 every time before a GEMM operation. (Unless the inputs are in FP8 precision already, there's still one cast kernel to cast inputs to FP8 precision.) \n", + "\n", + "\n", + "Transformer Engine supports maintaining FP8-only weights with the `fp8_model_init` context manager. Let's see a small example:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "4562ee82-8c95-4736-8815-cd386078a485", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Memory required for 16384x16384 linear layer: \n", + "FP32 - 1024.0 MB, \n", + "BF16 - 512.0 MB, \n", + "FP8 - 256.0 MB, \n", + "\n", + "Actual GPU memory usage with a TE FP32 linear layer: 1024.06 MB\n", + "Actual GPU memory usage with a TE BF16 linear layer: 512.03 MB\n", + "Actual GPU memory usage with a TE FP8 linear layer: 256.08 MB\n" + ] + } + ], + "source": [ + "import torch\n", + "import transformer_engine.pytorch as te\n", + "\n", + "H = 2**14\n", + "D = 2**14\n", + "print(f\"Memory required for {H}x{D} linear layer: \\n\"\n", + " f\"FP32 - {H*D*4/1024**2} MB, \\n\"\n", + " f\"BF16 - {H*D*2/1024**2} MB, \\n\"\n", + " f\"FP8 - {H*D*1/1024**2} MB, \\n\")\n", + "\n", + "linear_fp32 = te.Linear(H, D, params_dtype=torch.float32) \n", + "print(f\"Actual GPU memory usage with a TE FP32 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp32\n", + "\n", + "linear_bf16 = te.Linear(H, D, params_dtype=torch.bfloat16)\n", + "print(f\"Actual GPU memory usage with a TE BF16 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_bf16\n", + "\n", + "# Initialize model weights in FP8 precision\n", + "with torch.no_grad(), te.fp8_model_init(enabled=True):\n", + " linear_fp8 = te.Linear(H, D)\n", + "print(f\"Actual GPU memory usage with a TE FP8 linear layer: {torch.cuda.memory_allocated()/1024**2:.2f} MB\")\n", + "del linear_fp8" + ] + }, + { + "cell_type": "markdown", + "id": "2a26aba9-f3ba-42c4-b4c3-9e845502ae1b", + "metadata": {}, + "source": [ + "\n", + "
\n", + "\n", + "
\n", + " Figure 8: Using fp8_model_init stores the weights directly in FP8 format, which reduces both time and memory usage. Note that the inputs still need a cast kernel.\n", + "
\n", + "
\n", + "\n", + "Let's run the code with `fp8_model_init`:" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "96264b9c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "============================== Generation example 1 ==============================\n", + "Prompt: \"Here are the two facts about GPUs:\"\n", + "Generated text: \"\n", + "\n", + "1. They are very good at doing the same thing over and over again.\n", + "2. They are very bad at doing different things at the same time.\n", + "\n", + "This is why GPUs are so good at rendering graphics. The GPU is very good at\"\n", + "============================== Generation example 2 ==============================\n", + "Prompt: \"Some facts about NVIDIA:\"\n", + "Generated text: \"\n", + "\n", + "* NVIDIA is a global technology company that designs and develops high-performance computer graphics and video processing chips.\n", + "* NVIDIA is a leading provider of graphics processing units (GPUs) for the gaming and professional markets.\n", + "* NVIDIA is a key player\"\n", + "\n", + "================================================================================\n", + "Benchmarking for batch_size = 64, prefill tokens = 20 and max new tokens = 492\n", + "Time: 4.99 s.\n" + ] + } + ], + "source": [ + "# Restart the notebook (to flush the GPU memory)\n", + "from utils import restart_jupyter_notebook\n", + "restart_jupyter_notebook()\n", + "\n", + "# Import necessary packages and methods\n", + "from utils import *\n", + "\n", + "# Provide Huggingface Access Token\n", + "run_config.hf_access_token = \"\"\n", + "assert run_config.hf_access_token, \"Provide a HF API Access Token!\"\n", + "run_config.model_name = \"google/gemma-7b\"\n", + "\n", + "# Provide a directory to cache weights in to avoid downloading them every time.\n", + "# (By default, weights are cached in `~/.cache/huggingface/hub/models`)\n", + "run_config.weights_cache_dir = \"\"\n", + "\n", + "# Set specific hyperparameters\n", + "# (Default run_config are defined in `utils.py` in class `Hyperparameters`)\n", + "run_config.fuse_qkv_params = True # This is needed by the last improvement.\n", + "run_config.is_paged = False # <-- Toggle this to `True` to run generation with `Paged Attention`\n", + "\n", + "# CUDA Graphs related config\n", + "run_config.generation_cuda_graphs = True\n", + "run_config.cuda_graphs_static_batch_size = 64\n", + "run_config.cuda_graphs_static_max_seq_len = 512\n", + "run_config.cuda_graphs_static_max_context_len = 512\n", + "\n", + "# Enable FP8 math and FP8 model weights\n", + "run_config.fp8 = True\n", + "run_config.fp8_model_init = True # This will result in storing only fp8 weights.\n", + "run_config.fp8_model_weights_filename = (\n", + " \"calibrated_weights.pth\" # <-- Add calibrated weights location here.\n", + ")\n", + "\n", + "model = init_te_gemma_model(run_config)\n", + "\n", + "print_sample_of_generated_texts(model, run_config)\n", + "benchmark_generation(model, run_config)" + ] + }, + { + "cell_type": "markdown", + "id": "3e30ca5a", + "metadata": {}, + "source": [ + "The final speedup is **9.3x**. 
\n", + "\n", + "| Models | Time (non-paged kv cache) | Speedup (non-paged kv cache) | Time (paged kv cache) | Speedup (paged kv cache) |\n", + "|---|---|---|---|---|\n", + "| HF (baseline) | 46.6 s | - | - | - |\n", + "| TE (subsitution of GemmaDecoderLayer with te.TransformerLayer) | 12.25 s | 3.8x | 12.24 s | 3.8x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs | 6.39 s | 7.2x | 6.47 s | 7.2x |\n", + "| TE (te.TransformerLayer) + CUDA Graphs + FP8 (with `fp8_model_init`) | 4.99 s | 9.3x | 5.05 s | 9.2x |" + ] + }, + { + "cell_type": "markdown", + "id": "c6e87275", + "metadata": {}, + "source": [ + "## Conclusions" + ] + }, + { + "cell_type": "markdown", + "id": "7bb2452d", + "metadata": {}, + "source": [ + "This tutorial focuses primarily on making the token generation faster with an off-the-shelf model downloaded from Hugging Face using the following features of the Transformer Engine:\n", + "\n", + "1. Support for KV Caching (both non-paged and paged),\n", + "2. Integration with CUDA Graphs,\n", + "3. FP8 scaling factors calibration,\n", + "4. Keeping model parameters in FP8 precision.\n", + "\n", + "It's worth noting that these features in TE are also readily applicable to other use-cases which haven't been extensively talked about in the tutorial: \n", + "\n", + "1. Longer context lengths (with paged KV cache) \n", + "2. Using less memory during generation (by storing weights in FP8 precision using `fp8_model_init`)\n", + "\n", + "Readers are encouraged to explore these use cases by playing around with this tutorial, especially with larger models." + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.12.3" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/examples/te_gemma/utils.py b/docs/examples/te_gemma/utils.py new file mode 100755 index 000000000..cc31afc65 --- /dev/null +++ b/docs/examples/te_gemma/utils.py @@ -0,0 +1,370 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. 
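+
+"""
+    Helper utilities for the TE Gemma tutorial notebook: the run configuration,
+    dataloading used for calibration, baseline (HF) and TE model initialization,
+    forward passes for warmup/calibration, and simple generation benchmarks.
+"""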
+ +import sys +import IPython +import random +import string + +from te_gemma_loading_weights import load_te_model +import torch +from torch.utils.data import DataLoader + +from transformers import ( + AutoModelForCausalLM, + AutoTokenizer, + AutoConfig, +) +from transformers import DataCollatorForLanguageModeling +from datasets import load_dataset + + +from te_gemma import TEGemmaForCausalLM, TEGemmaForCausalLMCudaGraphs + +random.seed(42) +torch.manual_seed(42) + + +class RunConfiguration: + def __init__(self): + self.mixed_precision = "bf16" + self.model_name = None + + # FP8 precision settings + self.fp8 = False + self.fp8_model_weights_filename = None + self.fp8_model_init = False + + # Cuda graphs + self.generation_cuda_graphs = False + self.cuda_graphs_static_batch_size = 64 + self.cuda_graphs_static_max_seq_len = 512 + self.cuda_graphs_static_max_context_len = 512 + + # Finetuning/calibration/generation settings + self.dataset_name = "timdettmers/openassistant-guanaco" + self.dataset_text_field = "text" + self.learning_rate = 1.41e-5 + self.batch_size = 64 + self.max_seq_length = 512 + self.gradient_accumulation_steps = 1 + self.num_warmup_steps = 5 + self.num_training_steps = 10 + + # Coalesced QKV params or not + self.fuse_qkv_params = False + + # Attention + self.is_paged = False + + # This is either provided by the user or it will be set when the + # model weights are downloaded. + self.weights_cache_dir = "" + + +# Global variable for the run configuration so that it can be easily accessed +# throughout the jupyter notebook with an `import * from utils` statement +run_config = RunConfiguration() + + +def get_dataloaders(run_config): + """ + Returns a basic dataloader for the dataset which contains tokenized batches + of text. + """ + dataset = load_dataset(run_config.dataset_name, split="train") + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + + def tokenize(element): + outputs = tokenizer( + element["text"], + truncation=True, + padding=False, + max_length=run_config.max_seq_length, + return_overflowing_tokens=False, + return_length=False, + ) + return {"input_ids": outputs["input_ids"], "attention_mask": outputs["attention_mask"]} + + # Tokenize the dataset + dataset = dataset.map(tokenize, batched=True, remove_columns=dataset.column_names) + + # Simply pad to the multiple of 16 for both FP8 and BF16 precision + pad_to_multiple_of = 16 + data_collator = DataCollatorForLanguageModeling( + tokenizer=tokenizer, + mlm=False, + pad_to_multiple_of=pad_to_multiple_of, + ) + + dataloader_params = { + "batch_size": run_config.batch_size, + "collate_fn": data_collator, + "drop_last": True, + } + train_dataloader = DataLoader(dataset, **dataloader_params) + return train_dataloader + + +def ensure_model_is_downloaded(run_config): + """ + Downloads and caches the model weights if not already downloaded. A valid + Huggingface Access Token is required to download the model weights. + """ + assert run_config.model_name in [ + "google/gemma-7b", + ], "Only Gemma 7B model is supported!" + + # Login using Huggingface Hub API + from huggingface_hub import login + + try: + login(run_config.hf_access_token) + except Exception as e: + if "Invalid token passed!" in str(e): + print( + "Please pass a valid HF Access Token! More info at" + " https://huggingface.co/docs/hub/en/security-tokens." 
+ ) + else: + print(f"Exception is {e}") + + # Download the model if it doesn't exist + from huggingface_hub import snapshot_download + + supplied_cache_dir = ( + run_config.weights_cache_dir if run_config.weights_cache_dir != "" else None + ) + run_config.weights_cache_dir = snapshot_download( + repo_id=run_config.model_name, cache_dir=supplied_cache_dir + ) + + +def init_baseline_model(run_config): + """ + Initializes a baseline HF Gemma model with the model name provided in + the run_config. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + # Init the model + config = AutoConfig.from_pretrained(run_config.model_name) + + # Make sure to use flash_attention to do iso comparison with TEGemmaModel + config._attn_implementation = "flash_attention_2" + model = AutoModelForCausalLM.from_pretrained( + run_config.model_name, + config=config, + torch_dtype=torch.bfloat16, + ).cuda() + + return model + + +def init_te_gemma_model(run_config): + """ + Initializes a Gemma model with `GemmaDecoderLayer`s swapped with + `TransformerLayer`s from TransformerEngine. In case CUDA Graphs are enabled, + the model is initialized from `TEGemmaForCausalLMCudaGraphs` class. + """ + + # Download and cache the weights if not already downloaded + ensure_model_is_downloaded(run_config) + + cls = TEGemmaForCausalLMCudaGraphs if run_config.generation_cuda_graphs else TEGemmaForCausalLM + config = AutoConfig.from_pretrained(run_config.model_name) + + # Inject all fields from the `run_config` to the model `config` to make the + # code simpler. + for key, value in run_config.__dict__.items(): + setattr(config, key, value) + + # Initialize the model and move it to the GPU. + model = load_te_model(cls, config).cuda() + + # Record the model if CUDA Graphs are enabled. + if run_config.generation_cuda_graphs: + model.record() + + return model + + +def restart_jupyter_notebook(): + # Try restarting the Jupyter kernel + IPython.Application.instance().kernel.do_shutdown(True) + + # Check whether the device memory has been flushed + if torch.cuda.memory_allocated() != 0: + import warnings + + warnings.warn("The device memory hasn't been flushed, trying with a second method!") + + # Try restarting the Jupyter kernel another way + # Restart the kernel + from IPython.core.display import HTML + + HTML("") + + if torch.cuda.memory_allocated() != 0: + print( + "The device memory hasn't been flushed, try manually restarting the Jupyter kernel!" + ) + + # Suppress the warnings + if not sys.warnoptions: + import warnings + + warnings.simplefilter("ignore") + torch.set_warn_always(False) + + +@torch.no_grad() +def run_forward_pass(model, run_config, num_iters): + """ + Runs the forward pass of the model with sample data. Intended to use for + warmup and/or calibration. + """ + train_dataloader = get_dataloaders(run_config) + + model.train() + train_dataloader = enumerate(train_dataloader) + + for _ in range(num_iters): + _, batch = next(train_dataloader) + batch["input_ids"] = batch["input_ids"].cuda() + batch["attention_mask"] = batch["attention_mask"].cuda() + model(input_ids=batch["input_ids"], attention_mask=batch["attention_mask"]) + + +############################################################################### +# Benchmarking and example generation functions. +############################################################################### + + +def print_sample_of_generated_texts(model, run_config): + """ + Prints a sample of generated texts from the input model. 
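+
+    The prompts are repeated to fill run_config.batch_size, tokenized with left
+    padding, and 50 new tokens are generated for each prompt; the completions of
+    the first two prompts are printed.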
+ """ + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + if getattr(tokenizer, "pad_token", None) is None: + tokenizer.pad_token = tokenizer.eos_token + prompts = [ + "Here are the two facts about GPUs:", + "Some facts about NVIDIA:", + "The fundamental theorem of calculus for the layman:", + "A fact about AI:", + ] + + # Repeat prompts to match batch size + prompts *= run_config.batch_size // len(prompts) + inputs = tokenizer(prompts, return_tensors="pt", padding=True) + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + + max_length = inputs["input_ids"].size(1) + new_length = ((max_length + 63) // 64) * max_total_tokens + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], (new_length - max_length, 0), value=tokenizer.pad_token_id + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (new_length - max_length, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + outputs = model.generate(**inputs, max_new_tokens=50) + generated_texts = tokenizer.batch_decode(outputs, skip_special_tokens=True) + + def print_output(prompts, generated_texts, idx): + print("=" * 30 + f" Generation example {idx+1} " + "=" * 30) + print(f'Prompt: "{generated_texts[idx][: len(prompts[idx])]}"') + print(f'Generated text: "{generated_texts[idx][len(prompts[idx]) :]}"') + + # Print the output from first two prompts + for i in range(2): + print_output(prompts, generated_texts, i) + + +def _generate_random_words(num_words, max_word_length): + """ + Generates random words for the benchmark. + """ + + words = [] + for _ in range(num_words): + word_length = random.randint(1, max_word_length) + word = "".join(random.choices(string.ascii_lowercase, k=word_length)) + words.append(word) + return words + + +def benchmark_generation(model, run_config, context_length=20): + """ + Benchmarks the generation time for a random input to the model. 
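+
+    The prompts are random lowercase words, left-padded to the maximum total
+    sequence length; the time to generate the remaining new tokens is then
+    measured with CUDA events and reported in seconds.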
+ """ + + batch_size = run_config.batch_size + + max_total_tokens = ( + run_config.max_seq_length + if not run_config.generation_cuda_graphs + else run_config.cuda_graphs_static_max_seq_len + ) + max_new_tokens = max_total_tokens - context_length + + print("\n" + "=" * 80) + print( + f"Benchmarking for batch_size = {batch_size}, prefill tokens =" + f" {context_length} and max new tokens = {max_new_tokens}" + ) + + input_str = _generate_random_words(batch_size, context_length) + + tokenizer = AutoTokenizer.from_pretrained(run_config.model_name) + inputs = tokenizer(input_str, return_tensors="pt", padding=True) + + max_context_tokens = inputs["input_ids"].size(1) + + # Add padding to the left + inputs["input_ids"] = torch.nn.functional.pad( + inputs["input_ids"], + (max_total_tokens - max_context_tokens, 0), + value=tokenizer.pad_token_id, + ) + + # Add padding to the left (only intended for baseline generation with HF + # which expects padding to the left) + inputs["attention_mask"] = torch.nn.functional.pad( + inputs["attention_mask"], (max_total_tokens - max_context_tokens, 0), value=0 + ) + + inputs["input_ids"] = inputs["input_ids"].cuda() + inputs["attention_mask"] = inputs["attention_mask"].cuda() + + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + + model.generate(inputs["input_ids"].cuda(), max_new_tokens=max_new_tokens) + torch.cuda.synchronize() + end.record() + + print(f"Time: {start.elapsed_time(end)/1000:.2f} s.") diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 7013e85ec..00499cff5 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -5,7 +5,7 @@ "id": "6a5b2993", "metadata": {}, "source": [ - "# Accelerating a Hugging Face Llama 2 and Llama 3 models with Transformer Engine\n", + "# Accelerating Hugging Face Llama 2 and 3 Fine-Tuning with Transformer Engine\n", "\n", "
\n", "\n", diff --git a/docs/index.rst b/docs/index.rst index e678b1d46..2c04810f4 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -46,6 +46,7 @@ Transformer Engine documentation examples/fp8_primer.ipynb examples/advanced_optimizations.ipynb examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb + examples/te_gemma/tutorial_generation_gemma_with_te.ipynb examples/onnx/onnx_export.ipynb .. toctree:: diff --git a/transformer_engine/pytorch/attention/inference.py b/transformer_engine/pytorch/attention/inference.py index 8d5417a45..f0ef8d0bd 100644 --- a/transformer_engine/pytorch/attention/inference.py +++ b/transformer_engine/pytorch/attention/inference.py @@ -215,6 +215,17 @@ def __init__( device=torch.cuda.current_device(), ) + # This internal buffer holds the running length of each + # unfinished sequence in the batch and is updated in `pre_step()` + # method. One use of this buffer is applying RoPE to q and k tensors + # during inference by slicing ROPE Embeddings according to the + # current sequence length window. + self.pre_step_seqlens = torch.zeros( + self.max_batch_size, + dtype=torch.int32, + device=torch.cuda.current_device(), + ) + def reset(self): """Reset InferenceParams state""" self.sequences = OrderedDict() @@ -266,6 +277,15 @@ def pre_step( for k, v in self.sequences.items(): self.sequences_pre_step[k] = v - step_dict[k] + pre_step_seqlens_temp = torch.Tensor(list(self.sequences_pre_step.values())).to( + dtype=torch.int32, device="cpu" + ) + + # Copy the pre-step seqlens to the device in CUDA Graphs safe manner. + self.pre_step_seqlens[: len(pre_step_seqlens_temp)].copy_( + pre_step_seqlens_temp, non_blocking=False + ) + seqlens_q = list(step_dict.values()) cu_seqlens_q = [0] + [sum(seqlens_q[:i]) for i in range(1, self.batch_size + 1)] cu_seqlens_q = cu_seqlens_q + [cu_seqlens_q[-1]] * (self.max_batch_size - self.batch_size) @@ -280,9 +300,7 @@ def pre_step( def get_seqlens_pre_step(self): """Get cached sequence lengths before the stepping""" - return torch.Tensor(list(self.sequences_pre_step.values())).to( - dtype=torch.int32, device="cpu" - ) + return self.pre_step_seqlens def convert_paged_to_nonpaged(self, layer_number: int): """ @@ -458,14 +476,14 @@ def pre_step( finished_seqs = self.sequences.keys() - unfinished_seqs unfinished_indices = [i for i, j in enumerate(self.sequences) if j in unfinished_seqs] finished_indices = [i for i, j in enumerate(self.sequences) if j in finished_seqs] - self.batch_indices.copy_( + self.batch_indices.data[:].copy_( torch.Tensor( ( unfinished_indices + finished_indices + list(range(prev_batch_size, self.max_batch_size)) ) - ).to(dtype=torch.int32, device="cpu") + ) ) # Advance unfinished sequences diff --git a/transformer_engine/pytorch/attention/multi_head_attention.py b/transformer_engine/pytorch/attention/multi_head_attention.py index 9c82442af..5fd16bf1a 100644 --- a/transformer_engine/pytorch/attention/multi_head_attention.py +++ b/transformer_engine/pytorch/attention/multi_head_attention.py @@ -889,23 +889,11 @@ def forward( q_pos_emb, k_pos_emb = rotary_pos_emb - # adjust key and value for inference - if inference_params is not None: - if self.qkv_format == "sbhd": - sequence_length = key_layer.size(0) - elif self.qkv_format == "bshd": - sequence_length = key_layer.size(1) - else: - raise ValueError( - f"qkv_format={self.qkv_format} not supported for KV caching and RoPE." 
- ) - - sequence_start = inference_params.get_seqlens_pre_step() - # sequence_start = inference_params.seqlens[0] - sequence_end = sequence_start + sequence_length - - q_pos_emb = q_pos_emb[sequence_start:sequence_end, ...] - k_pos_emb = k_pos_emb[sequence_start:sequence_end, ...] + # Applyig RoPE for inference needs start positions of sequences + # for each iteration. + sequence_start_positions = ( + inference_params.get_seqlens_pre_step() if inference_params is not None else None + ) if pad_between_seqs: rotary_pos_cu_seq_lens_q = cu_seqlens_q_padded @@ -922,6 +910,7 @@ def forward( cu_seqlens=rotary_pos_cu_seq_lens_q, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) key_layer = apply_rotary_pos_emb( @@ -932,6 +921,7 @@ def forward( cu_seqlens=rotary_pos_cu_seq_lens_kv, cp_size=self.cp_size, cp_rank=self.cp_rank, + start_positions=sequence_start_positions, interleaved=self.rotary_pos_interleaved, ) diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index d1ba1a351..064da8a67 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -28,9 +28,10 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, auto freqs_cu = makeTransformerEngineTensor(freqs); auto output_cu = makeTransformerEngineTensor(output); - auto start_positions_cu = TensorWrapper(); // empty cu_seqlens tensor + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor if (start_positions) { start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); } if (qkv_format == NVTE_QKV_Format::NVTE_THD) { diff --git a/transformer_engine/pytorch/module/grouped_linear.py b/transformer_engine/pytorch/module/grouped_linear.py index e9189ccc5..5749d96c9 100644 --- a/transformer_engine/pytorch/module/grouped_linear.py +++ b/transformer_engine/pytorch/module/grouped_linear.py @@ -883,7 +883,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] * self.num_gemms weight_quantizers = [ self.quantizers["scaling_fwd"][ diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index cd02f3113..ee24dc33f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1767,7 +1767,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index a6c55ceb7..9f799c553 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -445,14 +445,19 @@ def forward( act_out = activation_func(fc1_out, None) act_out = tex.quantize(act_out, fc2_input_quantizer) 
else: - act_out = activation_func(fc1_out, fc2_input_quantizer) + if fp8_calibration: + act_out = activation_func(fc1_out, None) + else: + act_out = activation_func(fc1_out, fc2_input_quantizer) if not is_grad_enabled: clear_tensor_data(fc1_out) - if fp8_calibration: - fc2_input_quantizer.calibrate(act_out) - fc2_weight_quantizer.calibrate(fc2_weight) + if not fp8 and fp8_calibration: + if fc2_input_quantizer is not None: + fc2_input_quantizer.calibrate(act_out) + if fc2_weight_quantizer is not None: + fc2_weight_quantizer.calibrate(fc2_weight) # Configure Userbuffers reduce-scatter if needed ub_obj_fc2out = None @@ -1897,7 +1902,7 @@ def _get_quantizers(self, fp8_output): fc2_grad_output_quantizer, ) = [None] * 10 fc1_weight_quantizer, fc2_weight_quantizer = self._get_weight_quantizers() - if self.fp8: + if self.fp8 or self.fp8_calibration: fc1_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_INPUT] fc1_input_quantizer.internal = True fc2_input_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM2_INPUT] @@ -2114,7 +2119,7 @@ def _get_weight_tensors(self) -> List[Union[torch.Tensor, QuantizedTensorBase]]: def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None, None] fc1_weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] fc1_weight_quantizer.internal = True diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 2ce6fb4c1..3bc807413 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -1643,7 +1643,7 @@ def _customize_quantizers_float8_current_scaling(self, fwd: bool, recipe: Recipe def _get_weight_quantizers(self) -> List[Quantizer]: """Get the weight quantizers of the module.""" - if not self.fp8: + if not self.fp8 and not self.fp8_calibration: return [None] weight_quantizer = self.quantizers["scaling_fwd"][tex.FP8FwdTensors.GEMM1_WEIGHT] weight_quantizer.internal = True From 93a67af81a98f6542ecb2e414360bd0a74ca4367 Mon Sep 17 00:00:00 2001 From: yuzhongw-nvidia Date: Wed, 17 Sep 2025 13:59:52 +0800 Subject: [PATCH 143/153] Fix memory overhead of linear layer when all gather from sequence parallel (#2125) * fix memory overhead of all gather from sequence parallel Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * quick fix the errors when for UB buffers Signed-off-by: Yuzhong Wang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Update transformer_engine/pytorch/module/linear.py Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> * Avoid deallocating FP8 scale-invs since they are reused Signed-off-by: Tim Moon --------- Signed-off-by: Yuzhong Wang Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: Tim Moon Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Co-authored-by: Tim Moon --- .../pytorch/module/layernorm_linear.py | 23 +++++++++++++++---- transformer_engine/pytorch/module/linear.py | 16 ++++++++++++- 
.../_internal/float8_blockwise_tensor_base.py | 5 ++++ .../tensor/_internal/float8_tensor_base.py | 9 ++++++-- 4 files changed, 46 insertions(+), 7 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ee24dc33f..4d30be414 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -353,8 +353,11 @@ def forward( # Deallocate GEMM input tensor if no longer needed if not weight.requires_grad and not return_layernorm_output: - ln_out = ln_out_total = None clear_tensor_data(ln_out, ln_out_total) + ln_out = ln_out_total = None + elif with_input_all_gather and not return_layernorm_output_gathered: + clear_tensor_data(ln_out_total) + ln_out_total = None # ------------------------------------------------------ # Prepare output tensor @@ -891,9 +894,19 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted - if not ctx.return_layernorm_output: + # Deallocate input tensors if permitted + if not ctx.return_layernorm_output and not ctx.return_layernorm_output_gathered: + # Input tensors have not been exposed externally + clear_tensor_data(ln_out) + elif ctx.ln_out_needs_gather and ctx.return_layernorm_output_gathered: + # Non-gathered input has not been exposed externally + clear_tensor_data(ln_out) + if ctx.ln_out_needs_gather: + # Gathered input is internal clear_tensor_data(ln_out_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: @@ -1169,7 +1182,9 @@ def __init__( self.return_bias = return_bias self.apply_bias = self.use_bias and not return_bias self.return_layernorm_output = return_layernorm_output - self.return_layernorm_output_gathered = return_layernorm_output_gathered + self.return_layernorm_output_gathered = ( + return_layernorm_output_gathered if return_layernorm_output else False + ) self.zero_centered_gamma = zero_centered_gamma self.symmetric_ar_type = symmetric_ar_type diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 3bc807413..7e526245c 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -317,6 +317,13 @@ def forward( # Finished forward GEMM... # ------------------------------------------------------ + # Deallocate GEMM input tensor if no longer needed + # TODO(yuzhongw, tmoon): Figure out why inputmat_total is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. 
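+ # Releasing the all-gathered copy here only affects the forward pass: backward
+ # re-gathers the input when it needs it for wgrad, so peak memory from the
+ # sequence-parallel gather drops without changing numerics.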
+ if with_input_all_gather_nccl: + clear_tensor_data(inputmat_total) + inputmat_total = None + # ------------------------------------------------------ # Prepare output tensor # Note: Perform tensor-parallel communication @@ -878,9 +885,16 @@ def wgrad_gemm( grad_bias = grad_bias_ del grad_bias_ - # Deallocate input tensor if permitted + # Deallocate tensors if permitted if ctx.owns_input: + # Input tensor is internal + clear_tensor_data(inputmat_total) + elif ctx.backward_input_needs_gather: + # Gathered input tensor is internal clear_tensor_data(inputmat_total) + if ctx.parallel_mode == "row" and ctx.sequence_parallel: + # Gathered grad output tensor is internal + clear_tensor_data(grad_output) # Update grad input if overlapping reduce-scatter with wgrad GEMM if ctx.ub_bulk_wgrad: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py index adffe7c58..da0220eb7 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_blockwise_tensor_base.py @@ -349,9 +349,14 @@ def _create_columnwise(self): def _transpose_columnwise_data(self): """Plainly transpose the columnwise data and scale inv.""" if self._columnwise_data is not None: + # TODO(yuzhongw, tmoon): Figure out why _old_data is not automatically + # deallocated by GC. Manually deallocating is a temporary hack. + _old_data = self._columnwise_data self._columnwise_data = tex.fp8_transpose( self._columnwise_data, self._fp8_dtype, out=None ) + _old_data.data = _empty_tensor() + del _old_data def __repr__(self): if self._rowwise_data is not None: diff --git a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py index 61edc999a..6d4822344 100644 --- a/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py +++ b/transformer_engine/pytorch/tensor/_internal/float8_tensor_base.py @@ -95,8 +95,13 @@ def __new__( return instance def clear(self): - """Deallocate this tensor's memory. Typically not needed and must be used carefully.""" - for t in (self._data, self._transpose, self._scale_inv): + """Deallocate this tensor's memory. Typically not needed and must be used carefully. + + Scale-inv tensor is not deallocated because it's often shared + between multiple FP8 tensors. 
+ + """ + for t in (self._data, self._transpose): if t is not None: t.data = _empty_tensor() self._transpose_invalid = True From eb69fad7b865f034015ad149b898a031fbd810d3 Mon Sep 17 00:00:00 2001 From: Daniel Stokes <40156487+djns99@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:15:20 +1200 Subject: [PATCH 144/153] Fix incorrect TP rank calculation when using data parallel (#2179) Signed-off-by: djns99 <40156487+djns99@users.noreply.github.com> --- .../common/comm_gemm_overlap/comm_gemm_overlap.cpp | 8 ++++---- .../comm_gemm_overlap/userbuffers/userbuffers.cu | 14 ++++++++------ .../comm_gemm_overlap/userbuffers/userbuffers.h | 4 ++-- 3 files changed, 14 insertions(+), 12 deletions(-) diff --git a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp index 087493495..ec29e6e12 100644 --- a/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp +++ b/transformer_engine/common/comm_gemm_overlap/comm_gemm_overlap.cpp @@ -607,10 +607,10 @@ void CommOverlapBase::bulk_overlap_external_ag(cudaStream_t send_stream, cudaStr int comm_bytes_per_rank = comm_bytes / _tp_size; // We use the reference to the overlap_gemm to get the stream to send an receive on to ensure the kernels don't finish until the previous gemm is flush - userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, - send_stream); - userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _ub_comm, - recv_stream); + userbuffers_send_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, send_stream); + userbuffers_recv_all(_ub_reg, 0, _ub_reg, 0, comm_bytes_per_rank, _tp_id, _tp_size, _rank, + _ub_comm, recv_stream); // We sync with the internal comm stream so the destructor can wait for the comm stream to finish before freeing the ubuf for (auto stream : {send_stream, recv_stream}) { diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu index 17f3cf658..1dcd54d0d 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.cu @@ -2542,25 +2542,27 @@ void userbuffers_recv(const int srchandler, const size_t srcoffset, const int ds void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream) { + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; for (int j = 1; j < tp_size; j++) { int i = (tp_rank + j) % tp_size; int send_offset = srcoffset + bytes_per_slice * tp_rank; int recv_offset = dstoffset + bytes_per_slice * tp_rank; - userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, - stream); + userbuffers_send(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream) { + int tp_size, int world_rank, communicator *comm, cudaStream_t stream) { + int rank_round_tp = (world_rank / tp_size) * tp_size; for (int j = tp_size - 1; j > 
0; j--) { int i = (tp_rank + j) % tp_size; int send_offset = srcoffset + bytes_per_slice * i; int recv_offset = dstoffset + bytes_per_slice * i; - userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, i, - stream); + userbuffers_recv(srchandler, send_offset, dsthandler, recv_offset, bytes_per_slice, comm, + rank_round_tp + i, stream); } } diff --git a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h index 8077f90be..4d52fbb64 100644 --- a/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h +++ b/transformer_engine/common/comm_gemm_overlap/userbuffers/userbuffers.h @@ -306,10 +306,10 @@ void reduce_bf16(void *input, void *output, int num_inputs, int input_size, cuda void userbuffers_send_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream); + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); void userbuffers_recv_all(const int srchandler, const size_t srcoffset, const int dsthandler, const size_t dstoffset, const size_t bytes_per_slice, int tp_rank, - int tp_size, communicator *comm, cudaStream_t stream); + int tp_size, int world_rank, communicator *comm, cudaStream_t stream); #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ From 8aee1bb774998556e8fcc1234e7bb137bd5d0c43 Mon Sep 17 00:00:00 2001 From: alan yang <89962857+cassiewilliam@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:23:15 +0800 Subject: [PATCH 145/153] [Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model (#2045) * feat: add cutlass group gemm support Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor multi tensor gemm interface Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * refactor: refactor nvte_multi_stream_cublas_gemm func and add license info Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add unit test for cutlass group gemm Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: add cutlass support type protect Signed-off-by: Min Yang * add tests and fix lint Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: fix unit tests error Signed-off-by: Min Yang * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * feat: refactor host workspace malloc Signed-off-by: Min Yang * update cutlass Signed-off-by: Xin Yao * update cutlass Signed-off-by: Xin Yao * further relex threshold and add a env var to warn fall back Signed-off-by: Xin Yao * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Min Yang Signed-off-by: Xin Yao Signed-off-by: alan yang <89962857+cassiewilliam@users.noreply.github.com> Co-authored-by: Min Yang Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Xin Yao Co-authored-by: Phuong Nguyen --- .gitmodules | 3 + 3rdparty/cutlass | 1 + tests/pytorch/test_numerics.py | 68 +++- 
transformer_engine/common/CMakeLists.txt | 22 +- .../common/gemm/cublaslt_gemm.cu | 119 +++++- .../common/gemm/cutlass_grouped_gemm.cu | 77 ++++ .../common/gemm/cutlass_grouped_gemm.cuh | 348 ++++++++++++++++++ .../common/include/transformer_engine/gemm.h | 11 +- .../jax/csrc/extensions/gemm.cpp | 8 +- .../pytorch/csrc/extensions/gemm.cpp | 9 +- 10 files changed, 633 insertions(+), 33 deletions(-) create mode 160000 3rdparty/cutlass create mode 100644 transformer_engine/common/gemm/cutlass_grouped_gemm.cu create mode 100644 transformer_engine/common/gemm/cutlass_grouped_gemm.cuh diff --git a/.gitmodules b/.gitmodules index 21492db5e..4b188d6bb 100644 --- a/.gitmodules +++ b/.gitmodules @@ -4,3 +4,6 @@ [submodule "3rdparty/cudnn-frontend"] path = 3rdparty/cudnn-frontend url = https://github.com/NVIDIA/cudnn-frontend.git +[submodule "3rdparty/cutlass"] + path = 3rdparty/cutlass + url = https://github.com/NVIDIA/cutlass.git diff --git a/3rdparty/cutlass b/3rdparty/cutlass new file mode 160000 index 000000000..57e3cfb47 --- /dev/null +++ b/3rdparty/cutlass @@ -0,0 +1 @@ +Subproject commit 57e3cfb47a2d9e0d46eb6335c3dc411498efa198 diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index a50b3fbca..a0e285b91 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -125,6 +125,11 @@ fp8_recipes.append(recipe.Float8CurrentScaling()) fp8_recipes.append(recipe.DelayedScaling()) +use_cutlass_grouped_gemm = [False] +# Only enable cutlass grouped gemm on Hopper +if torch.cuda.get_device_capability() == (9, 0): + use_cutlass_grouped_gemm.append(True) + def is_fused_attn_available( config: ModelConfig, @@ -1805,6 +1810,7 @@ def test_grouped_linear_accuracy( bias, delay_wgrad_compute, parallel_mode=None, + use_cutlass=False, ): fp8 = recipe is not None if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: @@ -1876,9 +1882,47 @@ def test_grouped_linear_accuracy( delay_wgrad_compute, ) - # Shoule be bit-wise match - for i, (o, o_ref) in enumerate(zip(outputs, outputs_ref)): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + for o, o_ref in zip(outputs, outputs_ref): + if use_cutlass: + torch.testing.assert_close(o, o_ref, rtol=1e-3, atol=1e-3) + else: + # cuBLAS implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + + +@pytest.mark.skipif( + torch.cuda.get_device_capability() != (9, 0), + reason="Only enable CUTLASS grouped gemm on Hopper", +) +@pytest.mark.parametrize("dtype", param_types, ids=str) +@pytest.mark.parametrize("num_gemms", [3, 6]) +@pytest.mark.parametrize("bs", batch_sizes) +@pytest.mark.parametrize("model", ["126m"]) +@pytest.mark.parametrize("fuse_wgrad_accumulation", all_boolean) +@pytest.mark.parametrize("delay_wgrad_compute", all_boolean) +def test_grouped_linear_accuracy_cutlass( + dtype, + num_gemms, + bs, + model, + fuse_wgrad_accumulation, + delay_wgrad_compute, +): + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + test_grouped_linear_accuracy( + dtype, + num_gemms, + bs, + model, + None, + False, + fuse_wgrad_accumulation, + False, + delay_wgrad_compute, + None, + use_cutlass=True, + ) + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) @pytest.mark.parametrize("dtype", param_types, ids=str) @@ -2542,10 +2586,11 @@ def test_transformer_layer_hidden_states_format(dtype, bs, model): (16, 10027, 128, 512), ], ) -@pytest.mark.parametrize("dtype", param_types) +@pytest.mark.parametrize("dtype", param_types, ids=str) @pytest.mark.parametrize("layout", ["TN", "NN", "NT"]) 
@pytest.mark.parametrize("accumulate", [False, True]) -def test_grouped_gemm(shape, dtype, layout, accumulate): +@pytest.mark.parametrize("use_cutlass", use_cutlass_grouped_gemm) +def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass): torch.manual_seed(0) z, m, k, n = shape @@ -2580,6 +2625,9 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): grad = True single_output = False + if use_cutlass: + os.environ["NVTE_USE_CUTLASS_GROUPED_GEMM"] = "1" + for i in range(z): general_gemm( A[i], @@ -2607,9 +2655,15 @@ def test_grouped_gemm(shape, dtype, layout, accumulate): single_output=single_output, ) - # should be bit-wise match for o, o_ref in zip(out, out_ref): - torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + if not use_cutlass: + # cublas implementation should be bit-wise match + torch.testing.assert_close(o, o_ref, rtol=0, atol=0) + else: + torch.testing.assert_close(o, o_ref, rtol=1.5e-2, atol=1.5e-2) + + if use_cutlass: + os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None) @pytest.mark.parametrize("N", [32]) diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index cb9f13b89..08e876404 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -45,6 +45,11 @@ if(NOT EXISTS "${CUDNN_FRONTEND_INCLUDE_DIR}") endif() include(${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cudnn-frontend/cmake/cuDNN.cmake) +set(CUTLASS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/include") +set(CUTLASS_TOOLS_INCLUDE_DIR + "${CMAKE_CURRENT_SOURCE_DIR}/../../3rdparty/cutlass/tools/util/include") + # Python find_package(Python COMPONENTS Interpreter Development.Module REQUIRED) @@ -81,6 +86,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn.cpp fused_attn/utils.cu gemm/cublaslt_gemm.cu + gemm/cutlass_grouped_gemm.cu normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -121,18 +127,30 @@ add_library(transformer_engine SHARED ${transformer_engine_SOURCES}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") - +if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) + set_source_files_properties( + "gemm/cutlass_grouped_gemm.cu" + PROPERTIES + COMPILE_FLAGS + "-gencode arch=compute_90a,code=sm_90a") +else() + message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") +endif() # Configure dependencies target_link_libraries(transformer_engine PUBLIC CUDA::cublas CUDA::cudart CUDNN::cudnn_all) + target_include_directories(transformer_engine PRIVATE - ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}) target_include_directories(transformer_engine SYSTEM PRIVATE ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES}/cccl) target_include_directories(transformer_engine PRIVATE "${CUDNN_FRONTEND_INCLUDE_DIR}") +target_include_directories(transformer_engine PRIVATE + ${CUTLASS_INCLUDE_DIR} + ${CUTLASS_TOOLS_INCLUDE_DIR}) # Compiling Userbuffers with native MPI bootstrapping requires linking against MPI option(NVTE_UB_WITH_MPI "Bootstrap Userbuffers with MPI" OFF) diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 9e6c5417b..f287072bc 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -19,6 +19,7 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" 
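+// CUTLASS grouped-GEMM fast path (Hopper only), used by nvte_multi_tensor_gemm below.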
+#include "cutlass_grouped_gemm.cuh" namespace { @@ -650,9 +651,10 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor CUBLAS_VERSION); #endif NVTE_CHECK( - cuda::cudart_version() >= 12020 && cuda::cudart_version() < 13000, + transformer_engine::cuda::cudart_version() >= 12020 && + transformer_engine::cuda::cudart_version() < 13000, "Atomic GEMM requires CUDA version >=12.2.0 and <13.0.0, but run-time CUDA version is ", - cuda::cudart_version()); + transformer_engine::cuda::cudart_version()); NVTE_CHECK( cublas_version() >= 120205 && cublas_version() < 130000, "Atomic GEMM requires cuBLAS version >=12.2.5 and <13.0.0, but run-time cuBLAS version is ", @@ -675,13 +677,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor n_split, gemm_producer, inputCounter, stream); } -void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, - const NVTETensor *bias, NVTETensor *pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor *workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream) { - NVTE_API_CALL(nvte_multi_stream_cublas_gemm); +void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor *workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { using namespace transformer_engine; int num_streams = nvte_get_num_compute_streams(); @@ -711,6 +711,25 @@ void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVT } } +void nvte_multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor *D, + const NVTETensor *bias, NVTETensor *pre_gelu_out, + const int num_gemms, bool transa, bool transb, bool grad, + NVTETensor *workspace, bool accumulate, + bool use_split_accumulator, int math_sm_count, + cudaStream_t stream) { + NVTE_API_CALL(nvte_multi_stream_cublas_gemm); + using namespace transformer_engine; + + // Deprecation warning + NVTE_WARN( + "nvte_multi_stream_cublas_gemm is deprecated and will be removed in a future release. 
" + "Please migrate to nvte_multi_tensor_gemm (with CUTLASS Grouped GEMM support when " + "applicable)."); + + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, workspace, + accumulate, use_split_accumulator, math_sm_count, stream); +} + namespace transformer_engine { using cublasHandleManager = detail::HandleManager; @@ -718,3 +737,85 @@ using cublasHandleManager = detail::HandleManager("NVTE_USE_CUTLASS_GROUPED_GEMM", false); + const bool warn_fallback = + transformer_engine::getenv("NVTE_CUTLASS_GROUPED_GEMM_WARN_FALLBACK", false); + + auto cublas_path = [&]() { + multi_stream_cublas_gemm(A, B, D, bias, pre_gelu_out, num_gemms, transa, transb, grad, + workspace, accumulate, use_split_accumulator, math_sm_count, stream); + }; + + // Currently only support cutlass group gemm on Hopper Arch + if (!(is_hopper && use_cutlass)) { + cublas_path(); + return; + } + + auto is_empty_arr = [&](const NVTETensor *p) -> bool { + if (p == nullptr) return true; + for (int i = 0; i < num_gemms; ++i) { + if (transformer_engine::convertNVTETensor(p[i])->has_data()) return false; + } + return true; + }; + + auto all_groups_uniform_k128 = [&](const NVTETensor *p, bool trans) -> bool { + int64_t ref_k = -1; + for (size_t i = 0; i < num_gemms; i++) { + const auto tensor = transformer_engine::convertNVTETensorCheck(p[i]); + const int k = trans ? tensor->data.shape[0] : tensor->data.shape[1]; + + if ((k & 127) != 0) return false; + + if (ref_k < 0) + ref_k = k; + else if (k != ref_k) + return false; + } + + return true; + }; + + auto is_supported_dtype = [&]() -> bool { + auto *inputA = transformer_engine::convertNVTETensorCheck(A[0]); + auto *inputB = transformer_engine::convertNVTETensorCheck(B[0]); + auto *OutputD = transformer_engine::convertNVTETensorCheck(D[0]); + auto A_type = get_cuda_dtype(inputA->data.dtype); + auto B_type = get_cuda_dtype(inputB->data.dtype); + auto D_type = get_cuda_dtype(OutputD->data.dtype); + + return (A_type == B_type) && (A_type == D_type) && + ((A_type == CUDA_R_16BF) || (A_type == CUDA_R_16F)); + }; + + // CUTLASS Grouped GEMM fast path (SM90/TMA) + // Conditions: + // - No fused epilogue: both bias and pre_gelu_out are empty. + // - Supported dtypes only: FP16/BF16 (FP32 accumulate). + // - Uniform K across groups and K % 128 == 0. + // - use_split_accumulator is ignored for FP16/BF16. + // - grad is irrelevant when bias/pre_gelu_out are empty. + // + // Otherwise, fall back to cuBLAS. + if (is_empty_arr(bias) && is_empty_arr(pre_gelu_out) && is_supported_dtype() && + all_groups_uniform_k128(B, transb)) { + cutlass_grouped_gemm(A, B, D, num_gemms, transa, transb, grad, workspace, accumulate, + current_device, math_sm_count, stream); + } else { + if (warn_fallback) { + NVTE_WARN("Fallback to cuBLAS grouped GEMM."); + } + cublas_path(); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cu b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu new file mode 100644 index 000000000..18736c4f5 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cu @@ -0,0 +1,77 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. 
+ **************************************************************************************************/ + +#include "cutlass/bfloat16.h" +#include "cutlass/cutlass.h" +#include "cutlass_grouped_gemm.cuh" + +namespace transformer_engine { +namespace grouped_gemm { + +// Explicit template instantiation to match the template declarations in the .cuh +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, const NVTETensor*, + NVTETensor*, NVTETensor*, float, + float, int, cudaStream_t, int, int); + +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); +template void CutlassGroupedGemm(const NVTETensor*, + const NVTETensor*, NVTETensor*, + NVTETensor*, float, float, int, + cudaStream_t, int, int); + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream) { + using namespace transformer_engine; + auto* inputA = convertNVTETensorCheck(A[0]); + auto* inputB = convertNVTETensorCheck(B[0]); + + float one = 1.0; + float zero = 0.0; + float alpha = one; + float beta = (accumulate) ? one : zero; + + auto dispatch = [&](auto tag) { + using T = decltype(tag); + if (!transa && !transb) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (!transb && transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else if (transb && !transa) { + grouped_gemm::CutlassGroupedGemm(B, A, D, workspace, alpha, beta, num_gemms, + stream, device, math_sm_count); + } else { + NVTE_ERROR("Layout 'TT' is not supported by cutlass_grouped_gemm."); + } + }; + + if (inputA->data.dtype == DType::kBFloat16) { + dispatch(cutlass::bfloat16_t{}); + } else if (inputA->data.dtype == DType::kFloat16) { + dispatch(cutlass::half_t{}); + } else { + NVTE_ERROR("Unsupported dtype: only BF16(FP16) are supported."); + } +} diff --git a/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh new file mode 100644 index 000000000..1add57132 --- /dev/null +++ b/transformer_engine/common/gemm/cutlass_grouped_gemm.cuh @@ -0,0 +1,348 @@ +/*************************************************************************************************** + * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * + * See LICENSE for license information. + **************************************************************************************************/ + +// +// Copyright (c) 2025 Shopee Inc. All Rights Reserved. +// + +/** + * @file: cutlass_grouped_gemm.cuh + * @author: min.yang@shopee.com, yangfan.bai@shopee.com, finch.li@shopee.com + * @date: 2025-08-08 16:20:00 + * @brief: cutlass group gemm kernel. 
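+ * @note: Grouped GEMM for FP16/BF16 inputs with FP32 accumulation on SM90, built on
+ *        CUTLASS ptr-array TMA warp-specialized kernels. Per-group problem shapes,
+ *        pointers and strides are staged in a reusable host scratch buffer and then
+ *        copied into the caller-provided NVTE workspace tensor before launch.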
+ **/ + +#pragma once + +#include + +#include +#include + +#include "../common.h" +#include "../util/logging.h" +#include "common/util/system.h" +#include "cute/tensor.hpp" +#include "cutlass/bfloat16.h" +#include "cutlass/complex.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" + +namespace transformer_engine { +namespace grouped_gemm { + +template +using GroupedGemmInputALayout = + std::conditional_t; + +template +using GroupedGemmInputBLayout = + std::conditional_t; + +using ProblemShapeType = cute::Shape; +using ProblemShape = cutlass::gemm::GroupProblemShape; // per group +template +struct GemmGivenSchedule { + using ElementA = typename ScheduleConfig::DataType; // Element type for A matrix operand + using ElementB = typename ScheduleConfig::DataType; // Element type for B matrix operand + using ElementC = typename ScheduleConfig::DataType; // Element type for C and D matrix operands + + // A matrix configuration + using LayoutA = typename ScheduleConfig::LayoutA; // Layout type for A matrix operand + static constexpr int AlignmentA = + 128 / cutlass::sizeof_bits< + ElementA>::value; // Alignment of A matrix in units of elements (up to 16 bytes) + + // B matrix configuration + using LayoutB = typename ScheduleConfig::LayoutB; // Layout type for B matrix operand + static constexpr int AlignmentB = + 128 / cutlass::sizeof_bits< + ElementB>::value; // Alignment of B matrix in units of elements (up to 16 bytes) + + // C/D matrix configuration + using LayoutC = typename ScheduleConfig::LayoutC; // Layout type for C and D matrix operands + static constexpr int AlignmentC = + 128 / cutlass::sizeof_bits< + ElementC>::value; // Alignment of C matrix in units of elements (up to 16 bytes) + + // Core kernel configurations + using ElementAccumulator = float; // Element type for internal accumulation + using ArchTag = + cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature + using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag + using StageCountType = + cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = + typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, ElementAccumulator, ElementAccumulator, + ElementC, LayoutC*, AlignmentC, ElementC, LayoutC*, AlignmentC, EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination>::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, ElementA, 
LayoutA*, AlignmentA, ElementB, LayoutB*, AlignmentB, + ElementAccumulator, TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout( + sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule>::CollectiveOp; + + using GemmKernel = + cutlass::gemm::kernel::GemmUniversal; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +template +struct ScheduleConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = cute::Shape; + using ClusterShape = cute::Shape; + + // TODO(Alan): Add tuning for different scenarios to select the optimal configuration, + // as the current configuration may not be the best. + + // using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + // using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + // using TileShape = Shape; + // using ClusterShape = Shape; + + using LayoutA = GroupedGemmInputALayout; + using LayoutB = GroupedGemmInputBLayout; + using LayoutC = cutlass::layout::RowMajor; + using DataType = DataType_; +}; + +template +using GemmGrouped = typename GemmGivenSchedule>::Gemm; + +template +typename GemmT::Arguments MakeArguments(int num_experts, void* problem_sizes_host, + void* problem_sizes, const ElementA** ptr_A, + StrideA* stride_A, const ElementB** ptr_B, + StrideB* stride_B, ElementC** ptr_C, StrideC* stride_C, + float alpha, float beta, int device, int math_sm_count) { + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + + cutlass::KernelHardwareInfo kernel_hw_info = + cutlass::KernelHardwareInfo::make_kernel_hardware_info( + device, math_sm_count); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + fusion_args.alpha = alpha; + fusion_args.beta = beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + + arguments = + typename GemmT::Arguments{cutlass::gemm::GemmUniversalMode::kGrouped, + {num_experts, reinterpret_cast(problem_sizes), + reinterpret_cast(problem_sizes_host)}, + {ptr_A, stride_A, ptr_B, stride_B}, + { + fusion_args, + (beta > 0.0) ? 
(const ElementC**)ptr_C : nullptr, // NOLINT(*) + stride_C, + ptr_C, + stride_C, + }, + kernel_hw_info}; + + return arguments; +} + +template +inline __device__ __host__ T ROUND_UP(T m, T n) { + return (m + n - 1) / n * n; +} + +template +void debug_type() { + std::cout << typeid(T).name() << std::endl; +} + +int64_t inline getGemmCoordSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(ProblemShapeType), 128UL)); +} + +int64_t inline getPtrSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(half*), 128UL)); +} + +int64_t inline getLddSize(int64_t num_gemms) { + return (int64_t)(ROUND_UP(num_gemms * sizeof(int64_t), 128UL)); +} + +// cpu workspace size is 4MB +static constexpr size_t kCPUWorkSpaceSize = 4 * 1024 * 1024; + +static char* getHostWorkspace() { + static std::once_flag flag; + static std::shared_ptr workspace; + + std::call_once(flag, [&]() { + workspace = + std::shared_ptr(reinterpret_cast(std::malloc(kCPUWorkSpaceSize)), [](char* p) { + if (p) std::free(p); + }); + + if (!workspace) { + throw std::bad_alloc(); + } + }); + + return workspace.get(); +} + +template +void CutlassGroupedGemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + NVTETensor* workspace, float alpha, float beta, int num_gemms, + cudaStream_t stream, int device, int math_sm_count) { + using Gemm = GemmGrouped; + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + + using StrideA = typename Gemm::GemmKernel::InternalStrideA; + using StrideB = typename Gemm::GemmKernel::InternalStrideB; + using StrideC = typename Gemm::GemmKernel::InternalStrideC; + + typename Gemm::Arguments arguments; + size_t kernel_workspace_size = Gemm::get_workspace_size(arguments); + auto gemm_coord_size = getGemmCoordSize(num_gemms); + auto ptr_size = getPtrSize(num_gemms); + auto ldd_size = getLddSize(num_gemms); + auto param_workspace_size = 3 * ptr_size + 3 * ldd_size + gemm_coord_size; + + NVTE_CHECK( + param_workspace_size < kCPUWorkSpaceSize, + "Insufficient kCPUWorkSpaceSize size: required=", static_cast(param_workspace_size), + ", available=", static_cast(kCPUWorkSpaceSize), " for CUTLASS grouped GEMM."); + + auto total_workspace_size = param_workspace_size + kernel_workspace_size; + transformer_engine::Tensor* wspace = transformer_engine::convertNVTETensor(workspace[0]); + + NVTE_CHECK(total_workspace_size < wspace->numel(), "Insufficient workspace[0] size: required=", + static_cast(total_workspace_size), + ", available=", static_cast(wspace->numel()), " for CUTLASS grouped GEMM."); + + char* workspace_ptr = reinterpret_cast(wspace->data.dptr); + + char* kernel_workspace_ptr = nullptr; + + char* host_workspace = getHostWorkspace(); + + ProblemShapeType* problem_sizes_host = reinterpret_cast(host_workspace); + + ElementA** ptr_A_host = reinterpret_cast(host_workspace + gemm_coord_size); + ElementB** ptr_B_host = reinterpret_cast(host_workspace + gemm_coord_size + ptr_size); + ElementC** ptr_C_host = + reinterpret_cast(host_workspace + gemm_coord_size + 2 * ptr_size); + int64_t* lda_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + int64_t* ldb_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + int64_t* ldc_host = + reinterpret_cast(host_workspace + gemm_coord_size + 3 * ptr_size + 2 * 
ldd_size); + + for (size_t i = 0; i < num_gemms; i++) { + const transformer_engine::Tensor* inputA = transformer_engine::convertNVTETensorCheck(A[i]); + const transformer_engine::Tensor* inputB = transformer_engine::convertNVTETensorCheck(B[i]); + transformer_engine::Tensor* outputD = transformer_engine::convertNVTETensor(D[i]); + + const int m = trans_a ? inputA->data.shape[1] : inputA->data.shape[0]; + const int k = trans_a ? inputA->data.shape[0] : inputA->data.shape[1]; + const int n = trans_b ? inputB->data.shape[0] : inputB->data.shape[1]; + + auto problem = ProblemShapeType(m, n, k); + problem_sizes_host[i] = problem; + + ptr_A_host[i] = reinterpret_cast(inputA->data.dptr); + ptr_B_host[i] = reinterpret_cast(inputB->data.dptr); + ptr_C_host[i] = reinterpret_cast(outputD->data.dptr); + + lda_host[i] = LayoutA::packed({m, k}).stride(0); + ldb_host[i] = LayoutB::packed({k, n}).stride(0); + ldc_host[i] = LayoutC::packed({m, n}).stride(0); + } + + cudaMemcpyAsync(workspace_ptr, host_workspace, param_workspace_size, cudaMemcpyHostToDevice, + stream); + + char* param_workspace_ptr = workspace_ptr; + ProblemShapeType* problem_sizes_device = reinterpret_cast(param_workspace_ptr); + const ElementA** ptr_A = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size); + const ElementB** ptr_B = reinterpret_cast( + reinterpret_cast(param_workspace_ptr) + gemm_coord_size + 1 * ptr_size); + ElementC** ptr_C = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 2 * ptr_size); + + StrideA* lda = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 0 * ldd_size); + StrideB* ldb = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 1 * ldd_size); + StrideC* ldc = reinterpret_cast(reinterpret_cast(param_workspace_ptr) + + gemm_coord_size + 3 * ptr_size + 2 * ldd_size); + + kernel_workspace_ptr = workspace_ptr + param_workspace_size; + + arguments = MakeArguments( + num_gemms, problem_sizes_host, problem_sizes_device, ptr_A, lda, ptr_B, ldb, ptr_C, ldc, + alpha, beta, device, math_sm_count); + + Gemm gemm; + + // Check can implement the kernel. + if (gemm.can_implement(arguments) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to implement CUTLASS Grouped GEMM"); + } + + // Initialize the kernel. + if (gemm.initialize(arguments, kernel_workspace_ptr) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to initialize CUTLASS Grouped GEMM"); + } + + // Execute the kernel in the current stream. + if (gemm.run(stream) != cutlass::Status::kSuccess) { + NVTE_CHECK(false, "Failed to run CUTLASS Grouped GEMM"); + } +} + +} // namespace grouped_gemm +} // namespace transformer_engine + +void cutlass_grouped_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, int device, int math_sm_count, cudaStream_t stream); diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 50b33909f..0c358328b 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -133,12 +133,11 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor * \param[in] math_sm_count Number of GPU SMs to use (default=0: use cuBLAS heuristics) * \param[in] stream CUDA stream to wait on. 
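+ * \note On Hopper, FP16/BF16 grouped GEMMs with no fused bias/GELU epilogue and a
+ *       uniform K divisible by 128 may be dispatched to a CUTLASS grouped kernel when
+ *       NVTE_USE_CUTLASS_GROUPED_GEMM=1; all other cases fall back to multi-stream cuBLAS.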
*/ -void nvte_multi_stream_cublas_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, - const NVTETensor* bias, NVTETensor* pre_gelu_out, - const int num_gemms, bool transa, bool transb, bool grad, - NVTETensor* workspace, bool accumulate, - bool use_split_accumulator, int math_sm_count, - cudaStream_t stream); +void nvte_multi_tensor_gemm(const NVTETensor* A, const NVTETensor* B, NVTETensor* D, + const NVTETensor* bias, NVTETensor* pre_gelu_out, const int num_gemms, + bool transa, bool transb, bool grad, NVTETensor* workspace, + bool accumulate, bool use_split_accumulator, int math_sm_count, + cudaStream_t stream); #ifdef __cplusplus } // extern "C" #endif diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 113072131..06dded1d8 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -526,10 +526,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK_CUDA(cudaMemsetAsync(dptr, 0, count, stream_i)); } - nvte_multi_stream_cublas_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), - pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, - lhs_is_trans, grad, workspace_list.data(), accumulate, - use_split_accumulator, num_math_sm, stream); + nvte_multi_tensor_gemm(rhs_list.data(), lhs_list.data(), out_list.data(), bias_list.data(), + pre_gelu_list.data(), num_non_empty_gemms, rhs_is_trans, lhs_is_trans, + grad, workspace_list.data(), accumulate, use_split_accumulator, + num_math_sm, stream); return ffi_with_cuda_error_check(); } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index 485d67055..0d18a5ec5 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -477,11 +477,10 @@ std::optional> te_general_grouped_gemm( // For now, we only have multi-stream cublas backend. NVTE_SCOPED_GIL_RELEASE({ - nvte_multi_stream_cublas_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), - te_bias_vector.data(), te_pre_gelu_out_vector.data(), - te_A_vector.size(), transa, transb, grad, - te_workspace_vector.data(), accumulate, use_split_accumulator, - math_sm_count, at::cuda::getCurrentCUDAStream()); + nvte_multi_tensor_gemm(te_A_vector.data(), te_B_vector.data(), te_D_vector.data(), + te_bias_vector.data(), te_pre_gelu_out_vector.data(), te_A_vector.size(), + transa, transb, grad, te_workspace_vector.data(), accumulate, + use_split_accumulator, math_sm_count, at::cuda::getCurrentCUDAStream()); }); return bias; } From c334fc46bb166187b5b1da90b18ee16ddc4b2462 Mon Sep 17 00:00:00 2001 From: zhujian Date: Thu, 18 Sep 2025 11:14:08 +0800 Subject: [PATCH 146/153] [PyTorch] Support FA3 for MLA and with CP (#1907) feature(FA3,MLA,CP): 1. Update FA3 to commit-id 3ba6f82 (tag 2.8.0.post2 with compile error fixed), PR-1604 support hdimQK != hdimV backward 2. Update get_attention_backend method because FA3 support MLA now 3. Add CP MLA support for FA3 4. Add unit tests for FA3 MLA CP 5. 
Update attention doc Signed-off-by: zhujian --- docs/examples/attention/attention.ipynb | 2 +- .../attention/test_attention_with_cp.py | 8 + .../dot_product_attention/context_parallel.py | 249 +++++++++++------- .../attention/dot_product_attention/utils.py | 52 +++- 4 files changed, 205 insertions(+), 106 deletions(-) diff --git a/docs/examples/attention/attention.ipynb b/docs/examples/attention/attention.ipynb index 6cd56d23d..61a6ad949 100644 --- a/docs/examples/attention/attention.ipynb +++ b/docs/examples/attention/attention.ipynb @@ -390,7 +390,7 @@ "| Attention Backend | Precision | Architecture | Sliding Window Attention | MQA/GQA | Multi-Latent Attention | Context Parallelism | Determinism Possible |\n", "| :---------------- | :-------- | :----------- | :----------------------- | :------ | :--------------------- | :------------------ | :------------ |\n", "| cuDNN attention (all frameworks) | BF16, FP16, FP8 (PyTorch only) | sm80+ | No | Yes | Yes | Yes (`bshd`,`sbhd`, `thd`) | Yes |\n", - "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | No | Yes (`bshd`,`thd`) | Yes |\n", + "| flash-attention (PyTorch) | BF16, FP16 | sm80+ | Yes | Yes | Yes | Yes (`bshd`,`thd`) | Yes |\n", "| Framework-native attention | BF16, FP16, FP32 | Any | No, unless used as a mask | Yes | Yes (PyTorch only) | No | Yes |\n", "\n", "Some unit tests are provided to serve as a starting point for integrating such features into users' models. For example,\n", diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 0e8501abf..7078cb69d 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -36,6 +36,12 @@ 2, 4096, 12, 128, num_gqa_groups=2, attn_mask_type="causal", window_size=(512, 0) ), # GQA "cp_2_3": ModelConfig(2, 4096, 12, 128, num_gqa_groups=2, window_size=(512, 512)), # GQA + "cp_3_0": ModelConfig(2, 4096, 12, 192, attn_mask_type="causal", head_dim_v=128), # MLA + "cp_3_1": ModelConfig(2, 4096, 12, 192, head_dim_v=128), # MLA + "cp_3_2": ModelConfig( + 2, 4096, 12, 192, attn_mask_type="causal", window_size=(512, 0), head_dim_v=128 + ), # MLA + "cp_3_3": ModelConfig(2, 4096, 12, 192, window_size=(512, 512), head_dim_v=128), # MLA } @@ -81,6 +87,8 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): f"CP implementation with QKVO A2A requires num_heads ({config.num_heads}) and" f" num_gqa_groups ({config.num_gqa_groups}) to be divisible by cp_size (2)!" 
) + if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: + pytest.skip("MLA CP currently only support KV P2P!") subprocess.run( get_bash_arguments( diff --git a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py index f00bd573f..09384217c 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py @@ -358,7 +358,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] return [ *[None] @@ -366,7 +366,7 @@ def get_fa_args( max_seqlen_q, max_seqlen_kv, *[None] - * 8, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, q_descale, k_descale, v_descale + * 9, # page_table, kv_batch_idx, leftpad_k, rotary_cos, rotary_sin, seqlens_rotary, q_descale, k_descale, v_descale ] if qkv_format == "thd": return [ @@ -829,6 +829,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -838,19 +851,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=True, **fa_forward_kwargs, @@ -985,6 +989,22 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + else: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
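+ # (In the MLA case, k_part and v_part are already separate tensors at this point.)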
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1001,19 +1021,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1144,6 +1155,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. + k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1160,19 +1184,10 @@ def forward( elif fa_utils.v2_7_0_plus: fa_forward_kwargs["window_size_left"] = -1 fa_forward_kwargs["window_size_right"] = -1 - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q_inputs[i % 2], - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1269,6 +1284,19 @@ def forward( softmax_lse_per_step[i], rng_states[i], *rest = aux_ctx_tensors attn_biases[i] = rest[0] if len(rest) > 0 else None else: + if not enable_mla: + # If MHA, then split the KV into k_part and v_part. + # Otherwise (MHA), k_part and v_part have already been split. 
+ k_part = ( + kv_inputs[i % 2][..., 0, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][0] + ) + v_part = ( + kv_inputs[i % 2][..., 1, :, :] + if qkv_format in ["bshd", "sbhd"] + else kv_inputs[i % 2][1] + ) fa_forward_args_thd = get_fa_args( True, use_flash_attn_3, @@ -1278,19 +1306,10 @@ def forward( max_seqlen_q=max_seqlen_q, max_seqlen_kv=max_seqlen_kv, ) - # Need to add MLA support once Flash Attention supports MLA fa_outputs = flash_attn_fwd( q, - ( - kv_inputs[i % 2][..., 0, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][0] - ), - ( - kv_inputs[i % 2][..., 1, :, :] - if qkv_format in ["bshd", "sbhd"] - else kv_inputs[i % 2][1] - ), + k_part, + v_part, *fa_forward_args_thd, causal=False, **fa_forward_kwargs, @@ -1865,7 +1884,27 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -1875,16 +1914,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -1895,12 +1926,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = 0 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2016,7 +2046,29 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + k_part = k_part.contiguous() + v_part = v_part.contiguous() + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2026,16 +2078,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv // 2, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2046,12 +2090,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not 
ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse, *fa_backward_args_thd, @@ -2160,7 +2203,27 @@ def backward(ctx, dout): dv_ = dv_._data else: dq_ = torch.empty_like(q_) - dkv_ = torch.empty_like(kv_) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = ( + kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0] + ) + v_part = ( + kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1] + ) + dkv_ = torch.empty_like(kv_) + dk_ = ( + dkv_[..., 0, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[0] + ) + dv_ = ( + dkv_[..., 1, :, :] + if ctx.qkv_format in ["bshd", "sbhd"] + else dkv_[1] + ) fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2170,16 +2233,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q // 2, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=( - dkv_[..., 0, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[0] - ), - dv=( - dkv_[..., 1, :, :] - if ctx.qkv_format in ["bshd", "sbhd"] - else dkv_[1] - ), + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or ( fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus @@ -2190,12 +2245,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout_, q_, - kv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[0], - kv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv_[1], + k_part, + v_part, out_, softmax_lse_, *fa_backward_args_thd, @@ -2267,7 +2321,15 @@ def backward(ctx, dout): else: dq_ = torch.empty_like(q) - dkv_ = torch.empty_like(kv) + if ctx.enable_mla: + dk_ = torch.empty_like(k_part) + dv_ = torch.empty_like(v_part) + else: + k_part = kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0] + v_part = kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1] + dkv_ = torch.empty_like(kv) + dk_ = dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0] + dv_ = dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1] fa_backward_args_thd = get_fa_args( False, ctx.use_flash_attn_3, @@ -2277,8 +2339,8 @@ def backward(ctx, dout): max_seqlen_q=ctx.max_seqlen_q, max_seqlen_kv=ctx.max_seqlen_kv, dq=dq_, - dk=dkv_[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[0], - dv=dkv_[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else dkv_[1], + dk=dk_, + dv=dv_, ) if ctx.use_flash_attn_3 or (fa_utils.v2_3_plus and not fa_utils.v2_7_0_plus): fa_backward_kwargs["window_size"] = (-1, -1) @@ -2287,12 +2349,11 @@ def backward(ctx, dout): fa_backward_kwargs["window_size_right"] = -1 if not ctx.use_flash_attn_3: fa_backward_kwargs["rng_state"] = rng_states[cp_size - i - 1] - # Need to add MLA support once Flash Attention supports MLA flash_attn_bwd( dout, q, - kv[..., 0, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[0], - kv[..., 1, :, :] if ctx.qkv_format in ["bshd", "sbhd"] else kv[1], + k_part, + v_part, out, softmax_lse, *fa_backward_args_thd, diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py 
b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index 7097f4ba0..fffda8136 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -126,10 +126,10 @@ class FlashAttentionUtils: # Please follow these instructions to install FA3 v3_installation_steps = """\ (1) git clone https://github.com/Dao-AILab/flash-attention.git -(2) cd flash-attention/ && git checkout 27f501d && cd hopper/ && python setup.py install +(2) cd flash-attention/ && git checkout 3ba6f82 && git submodule update --init && cd hopper/ && python setup.py install (3) python_path=`python -c "import site; print(site.getsitepackages()[0])"` (4) mkdir -p $python_path/flash_attn_3 -(5) wget -P $python_path/flash_attn_3 https://raw.githubusercontent.com/Dao-AILab/flash-attention/27f501dbe011f4371bff938fe7e09311ab3002fa/hopper/flash_attn_interface.py""" +(5) cp flash_attn_interface.py $python_path/flash_attn_3/flash_attn_interface.py""" v3_warning_printed = False @staticmethod @@ -477,11 +477,10 @@ def get_attention_backend( # Filter: Head dimension if head_dim_qk != head_dim_v: - if (use_flash_attention_2 and FlashAttentionUtils.is_installed) or ( - use_flash_attention_3 and FlashAttentionUtils.v3_is_installed - ): - logger.debug("Disabling FlashAttention as it does not support MLA.") - use_flash_attention = False + if use_flash_attention_2 and FlashAttentionUtils.is_installed: + logger.debug("Disabling FlashAttention 2 as it does not support MLA.") + use_flash_attention_2 = False + qkv_layout_group = qkv_layout.replace("b", "").replace("s", "").replace("t", "") if use_fused_attention and qkv_layout_group != "hd_hd_hd": logger.debug( @@ -508,10 +507,41 @@ def get_attention_backend( ".".join([str(i) for i in device_compute_capability]), ) use_flash_attention_2 = False - if use_flash_attention_3 and (head_dim_qk > 128 or head_dim_v > 128): - if FlashAttentionUtils.v3_is_installed: - logger.debug("Disabling FlashAttention 3 for head_dim > 128") - use_flash_attention_3 = False + if use_flash_attention_3: + + def _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if head_dim_qk > 256 or num_heads % num_gqa_groups != 0: + return False + if head_dim_qk != head_dim_v: + cond1 = 128 < head_dim_qk <= 192 + cond2 = 96 < head_dim_v <= 128 + cond3 = head_dim_qk <= 64 and head_dim_v <= 512 + if not ((cond1 and cond2) or cond3): + return False + if head_dim_v > 256 and qkv_dtype not in (torch.bfloat16, torch.float16): + return False + return True + + if not _is_fa3_supported(num_heads, num_gqa_groups, head_dim_qk, head_dim_v, qkv_dtype): + if FlashAttentionUtils.v3_is_installed: + logger.debug( + "Disabling FlashAttention 3 due to unsupported num_heads, num_gqa_groups, " + "head_dim_qk, head_dim_v or qkv_dtype. " + "Supported: head_dim_qk <= 256, and num_heads %% num_gqa_groups = 0, and " + "if head_dim_qk is different from head_dim_v, then " + "(head_dim_qk must in (128, 192] and head_dim_v in (96, 128]) or " + "(head_dim_qk <= 64 and head_dim_v <= 512), and " + "if head_dim_qk is different from head_dim_v and head_dim_v > 256, then " + "qkv_dtype requires fp16 and bf16 data type. 
" + "Found: num_heads = %s, num_gqa_groups = %s, " + "head_dim_qk = %s, head_dim_v = %s and qkv_dtype = %s.", + num_heads, + num_gqa_groups, + head_dim_qk, + head_dim_v, + qkv_dtype, + ) + use_flash_attention_3 = False # Filter: QKV layout if qkv_format == "thd": From 7f77127cbe5dfc37d5ce02c2e7ba388cfa2a83d4 Mon Sep 17 00:00:00 2001 From: Kshitij Lakhani <33047503+KshitijLakhani@users.noreply.github.com> Date: Thu, 18 Sep 2025 10:22:11 -0700 Subject: [PATCH 147/153] Fix cuDNN version checks when getting backend and for sm89 kv cache (#2185) * Fix cudnn version checks for kv cache for sm89. Add cudnn version check in preparation for 9.14 when getting backend Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * Minor fix for cuDNN version condition check Signed-off-by: Kshitij Lakhani * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Signed-off-by: Kshitij Lakhani Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- transformer_engine/common/fused_attn/fused_attn.cpp | 10 +++++----- .../pytorch/attention/dot_product_attention/utils.py | 6 ++++-- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/transformer_engine/common/fused_attn/fused_attn.cpp b/transformer_engine/common/fused_attn/fused_attn.cpp index 60b10862e..795697635 100644 --- a/transformer_engine/common/fused_attn/fused_attn.cpp +++ b/transformer_engine/common/fused_attn/fused_attn.cpp @@ -251,11 +251,11 @@ NVTE_Fused_Attn_Backend nvte_get_fused_attn_backend( // 9.11: d_qk = 192, d_v = 128 + Blackwell + bprop + non-paged (head_dim_qk == 192 && head_dim_v == 128 && is_training && sm_arch_ >= 100 && cudnn_runtime_version >= 91100)) && - // 9.11/9.12 bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA - (!((cudnn_runtime_version == 91100 || cudnn_runtime_version == 91200 || - cudnn_runtime_version == 91300) && - is_training && sm_arch_ == 90 && head_dim_qk >= 128 && head_dim_v >= 128 && - !(head_dim_qk == 192 && head_dim_v == 128) && head_dim_qk != head_dim_v))) && + // 9.11+ bug: 128 < d_qk <= 256, 128 < d_v <= 256 + Hopper + bprop + MLA + // Conditional to temporarily use blanket cudnn_runtime_version >= 9.11 until fixed + (!((cudnn_runtime_version >= 91100) && is_training && sm_arch_ == 90 && + head_dim_qk >= 128 && head_dim_v >= 128 && !(head_dim_qk == 192 && head_dim_v == 128) && + head_dim_qk != head_dim_v))) && // bias type ((cudnn_runtime_version < 8906 && bias_type == NVTE_Bias_Type::NVTE_NO_BIAS) || (cudnn_runtime_version >= 8906 && diff --git a/transformer_engine/pytorch/attention/dot_product_attention/utils.py b/transformer_engine/pytorch/attention/dot_product_attention/utils.py index fffda8136..9b2b9a1ac 100644 --- a/transformer_engine/pytorch/attention/dot_product_attention/utils.py +++ b/transformer_engine/pytorch/attention/dot_product_attention/utils.py @@ -434,8 +434,10 @@ def get_attention_backend( # | FP8 | non-paged/paged | sm90 | thd | >= 1 # Unfused | FP32/FP16/BF16 | non-paged/paged | all | bshd,sbhd,thd | >= 1 if inference_params is not None: - if device_compute_capability == (8, 9) and cudnn_version <= (9, 13, 0): - logger.debug("Disabling FusedAttention for KV caching for sm89 and cuDNN <= 9.13") + # Temporarily disabling fused attention for kv caching for sm89 irrespective of cuDNN version + # until the cuDNN bug is resolved + if device_compute_capability == (8, 9): + logger.debug("Disabling 
FusedAttention for KV caching for sm89") use_fused_attention = False if context_parallel: logger.debug("Disabling all backends for KV caching with context parallelism") From b95717e67202194c6abf64f8a671edcab466a0ff Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Mon, 12 Jan 2026 23:47:51 -0500 Subject: [PATCH 148/153] Merged conflits resolution and restore ROCm functionality Fix build ant UT issues Upcoming ROCm and JAX 0.8 support - cherry-pick: 8e25035 03525d3 (#403) --- ci/jax.sh | 18 ++- ci/pytorch.sh | 10 +- hipify_custom_map.json | 3 +- tests/cpp/CMakeLists.txt | 2 +- .../cpp/operator/test_cast_current_scaling.cu | 2 +- tests/cpp/operator/test_cast_mxfp8.cu | 45 ++++++-- .../operator/test_cast_mxfp8_gated_swiglu.cu | 37 +++++- tests/cpp/test_common.cu | 108 +++++++----------- tests/cpp/test_common.h | 16 +-- tests/jax/distributed_test_base.py | 5 +- tests/jax/test_distributed_layernorm_mlp.py | 2 +- tests/jax/test_distributed_softmax.py | 2 +- tests/jax/test_fused_attn.py | 10 +- tests/pytorch/attention/test_attention.py | 58 +++++----- tests/pytorch/attention/test_kv_cache.py | 8 ++ tests/pytorch/test_numerics.py | 12 +- tests/pytorch/test_sanity.py | 5 +- tests/pytorch/utils.py | 21 +++- transformer_engine/common/CMakeLists.txt | 7 +- transformer_engine/common/common.cu | 4 +- transformer_engine/common/common.h | 25 +++- transformer_engine/common/dropout/dropout.cu | 8 ++ .../common/gemm/cublaslt_gemm.cu | 89 ++++++--------- transformer_engine/common/gemm/rocm_gemm.cu | 60 ++++------ .../include/transformer_engine/recipe.h | 33 +++--- .../common/normalization/layernorm/ln_api.cpp | 2 +- .../layernorm/ln_bwd_semi_cuda_kernel.cu | 4 +- .../normalization/rmsnorm/rmsnorm_api.cpp | 4 +- .../common/recipe/current_scaling.cu | 56 +++++---- transformer_engine/common/swizzle/swizzle.cu | 13 +++ .../common/util/cast_gated_kernels.cuh | 89 ++++++++++----- .../common/util/cast_kernels.cuh | 26 ++++- .../common/util/cuda_runtime.cpp | 27 +++++ transformer_engine/common/util/cuda_runtime.h | 6 +- transformer_engine/common/util/logging.h | 3 +- transformer_engine/common/util/ptx.cuh | 5 + .../common/util/rocm_cast_gated_kernels.cuh | 61 ++++++---- .../common/util/rocm_cast_kernels.cuh | 43 ++++--- .../common/util/rocm_dequantize_kernels.cuh | 20 ++-- transformer_engine/common/util/rtc.cpp | 2 +- transformer_engine/common/utils.cuh | 1 - transformer_engine/jax/cpp_extensions/gemm.py | 2 + .../pytorch/csrc/extensions/activation.cpp | 2 +- .../pytorch/csrc/extensions/bias.cpp | 2 +- .../pytorch/csrc/extensions/dropout.cpp | 5 +- .../pytorch/csrc/extensions/pybind.cpp | 6 +- .../pytorch/csrc/extensions/recipe.cpp | 2 +- transformer_engine/pytorch/csrc/quantizer.cpp | 10 ++ .../pytorch/module/layernorm_linear.py | 3 +- .../pytorch/triton_kernels/cast.py | 3 +- 50 files changed, 606 insertions(+), 381 deletions(-) diff --git a/ci/jax.sh b/ci/jax.sh index 82dd97928..0f2aef2e8 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -21,6 +21,12 @@ install_prerequisites() { script_error "Failed to install Flax and dependencies" return $rc fi + pip install pytest-timeout + rc=$? 
+ if [ $rc -ne 0 ]; then + script_error "Failed to install test prerequisites" + exit $rc + fi } TEST_DIR=${TE_PATH}tests/jax @@ -55,22 +61,26 @@ run_test_config() { run_default_fa 1 test_helper.py run_default_fa 1 test_layer.py #it effectevly always uses unfused attention run_default_fa 1 test_sanity_import.py - run_default_fa 1 test_sharding.py run_default_fa 1 test_softmax.py } run_test_config_mgpu() { echo ==== Run mGPU with Fused attention backend: $_fus_attn ==== configure_omp_threads 8 + + # Mitigate distributed tests hang by adding 5min timeout + _timeout_args="--timeout 300 --timeout-method thread" + # Workaround for some distributed tests hang/abotrion + export XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" + if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then _dfa_level=2 else _dfa_level=3 fi - # Workaround for distributed tests hang with xla_flag - XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run $_dfa_level test_distributed_fused_attn.py + run $_dfa_level test_distributed_fused_attn.py $_timeout_args run_default_fa 3 test_distributed_layernorm.py - XLA_FLAGS="--xla_gpu_enable_nccl_comm_splitting=false" run_default_fa 2 test_distributed_layernorm_mlp.py + run_default_fa 2 test_distributed_layernorm_mlp.py $_timeout_args run_default_fa 3 test_distributed_softmax.py run_default_fa 3 test_sanity_import.py diff --git a/ci/pytorch.sh b/ci/pytorch.sh index 1b3aefd36..32791d100 100755 --- a/ci/pytorch.sh +++ b/ci/pytorch.sh @@ -1,5 +1,5 @@ #!/bin/sh -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. @@ -65,7 +65,9 @@ run_test_config(){ run_default_fa 1 test_recipe.py run 1 test_sanity.py run_default_fa 1 test_sanity_import.py - run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test + run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test + run_default_fa 1 attention/test_cp_utils.py + run_default_fa 1 attention/test_kv_cache.py run_default_fa 1 triton_kernels/test_cast.py run_default_fa 1 triton_kernels/test_cast_mxfp8.py run_default_fa 1 triton_kernels/test_norm_common.py @@ -88,8 +90,8 @@ run_test_config_mgpu(){ run_default_fa 2 distributed/test_numerics.py run_default_fa 1 distributed/test_torch_fsdp2.py run_default_fa 2 distributed/test_torch_fsdp2_fp8.py - run_default_fa_lbl "flash" 3 fused_attn/test_fused_attn_with_cp.py -k "with_flash" - run_default_fa_lbl "fused" 2 fused_attn/test_fused_attn_with_cp.py -k "with_fused" + run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash" + run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused" } run_benchmark() { diff --git a/hipify_custom_map.json b/hipify_custom_map.json index 8773c233e..97824bbdb 100644 --- a/hipify_custom_map.json +++ b/hipify_custom_map.json @@ -5,7 +5,8 @@ "util/cuda_runtime.h" : "util/hip_runtime.h", "ATen/cudnn/Handle.h" : "ATen/miopen/Handle.h", "CUfunc_cache" : "hipFuncCache_t", - "" : "" + "" : "", + "cudaFuncSetAttribute(" : "hipFuncSetAttribute((const void*)" } } diff --git a/tests/cpp/CMakeLists.txt b/tests/cpp/CMakeLists.txt index a9776d3a9..b71addebf 100644 --- a/tests/cpp/CMakeLists.txt +++ b/tests/cpp/CMakeLists.txt @@ -64,7 +64,7 @@ else() project(transformer_engine_tests LANGUAGES HIP CXX) # Ask hcc to generate device code during compilation so we can use # host linker to link. 
- set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result -ftemplate-backtrace-limit=0") + set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -fno-gpu-rdc -Wno-defaulted-function-deleted -Wno-unused-result -Wno-unused-value -ftemplate-backtrace-limit=0") endif() add_subdirectory(../../3rdparty/googletest ${PROJECT_BINARY_DIR}/googletest) diff --git a/tests/cpp/operator/test_cast_current_scaling.cu b/tests/cpp/operator/test_cast_current_scaling.cu index 856c24cfc..89ae79f9a 100644 --- a/tests/cpp/operator/test_cast_current_scaling.cu +++ b/tests/cpp/operator/test_cast_current_scaling.cu @@ -219,7 +219,7 @@ TEST(AmaxConsistencyTest, AtomicVsWorkspace) { // Path 2: two-stage amax using workspace std::vector ws_shape{N}; Tensor workspace("workspace", ws_shape, DType::kFloat32); - nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), 0); + nvte_compute_amax_with_workspace(input.data(), out_ws.data(), workspace.data(), nullptr, 0); cudaDeviceSynchronize(); auto err = cudaGetLastError(); diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 1b4983ca5..9e4e12bf8 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -55,7 +55,7 @@ void compute_ref(const ProcessingMethod processing_method, const size_t scales_stride_rowwise, const size_t scales_stride_colwise) { -#ifdef __HIP_PLATFORM_AMD__//PIV TODO: Check isnan isnanf isinf isinff availability +#ifdef __HIP_PLATFORM_AMD__ using std::isnan, std::isinf; #endif const size_t tile_size_Y = 32; @@ -311,17 +311,26 @@ void performTest_x1(const ProcessingMethod processing_method, : output_c.columnwise_cpu_scale_inv_ptr(); const size_t scale_diff_abs_tolerance = 0; - const double abs_tolerable_mismatches_limit = 0.0; - const double rel_tolerable_mismatches_limit = 0.0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices; size_t mismatches_scales = 0; compare_e8m0_scaling_factors("scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, - mismatches_scales, + mismatches_scales_indices, mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr, + ref_output_scales.get(), scales_stride, rows, cols, rowwise, + ref_output_c.get(), otype); + mismatches_scales = 0; +#endif + const size_t mismatches_elts = 32 * mismatches_scales; auto [atol, rtol] = getTolerances(otype); compareResults("output_c", output_c, ref_output_c.get(), rowwise, atol, rtol, true, mismatches_elts); @@ -465,7 +474,7 @@ void performTest_x2(const ProcessingMethod processing_method, auto err = cudaGetLastError(); ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); - compute_ref(processing_method,//PIV TODO: AMD path + compute_ref(processing_method, OP, true, true, @@ -482,27 
+491,43 @@ void performTest_x2(const ProcessingMethod processing_method, scales_stride_colwise); const size_t scale_diff_abs_tolerance = 0; - const double abs_tolerable_mismatches_limit = 0.0; - const double rel_tolerable_mismatches_limit = 0.0; + const double abs_tolerable_mismatches_limit = 1.0; + const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_rowwise, + mismatches_scales_indices_rowwise, mismatches_scales_rowwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + std::vector mismatches_scales_indices_colwise; size_t mismatches_scales_colwise = 0; compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_colwise, + mismatches_scales_indices_colwise, mismatches_scales_colwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, + output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), scales_stride_rowwise, rows, cols, + true, ref_output_c_rowwise.get(), otype); + mismatches_scales_rowwise = 0; + adjust_ref_for_e8m0_scale_error("scales_colwise", mismatches_scales_indices_colwise, + output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), scales_stride_colwise, rows, cols, + false, ref_output_c_colwise.get(), otype); + mismatches_scales_colwise = 0; +#endif + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; diff --git a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu index a32e2ace8..52180786d 100644 --- a/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu +++ b/tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
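For readers of these MXFP8 tests: the reference adjustment introduced above compensates for the case where the kernel's E8M0 block scale lands one exponent step away from the CPU reference, by rescaling that block of the reference output by a factor of two before the element-wise comparison. A minimal NumPy sketch of the idea, assuming row-wise 1x32 blocks and a plain float reference array (names and layout here are illustrative, not the tests' API):

import numpy as np

def adjust_ref_for_scale_error(ref_data, ref_scales, gpu_scales, block=32):
    # ref_data:   (rows, cols) reference values, quantized in 1 x `block` tiles
    # ref_scales: (rows, cols // block) reference E8M0 exponents (uint8)
    # gpu_scales: same shape, exponents produced by the kernel under test
    diff = ref_scales.astype(np.int32) - gpu_scales.astype(np.int32)
    if np.any(np.abs(diff) > 1):
        raise ValueError("scale mismatch larger than one exponent step")
    for i, j in zip(*np.nonzero(diff)):
        # A +1 exponent difference is compensated by doubling the block, -1 by halving,
        # mirroring adjust_ref_for_e8m0_scale_error above.
        factor = 2.0 if diff[i, j] == 1 else 0.5
        ref_data[i, j * block:(j + 1) * block] *= factor
    return ref_data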
@@ -249,7 +249,7 @@ void performTest_x1(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref(grad.rowwise_cpu_dptr(),//PIV TODO: AMD path + compute_ref(grad.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), ref_output.get(), ref_output.get(), @@ -264,6 +264,7 @@ void performTest_x1(const size_t rows, rowwise, colwise); + std::vector mismatches_scales_indices; size_t mismatches_scales = 0; const size_t scale_diff_abs_tolerance = 0; const double abs_tolerable_mismatches_limit = 1.0; @@ -275,6 +276,7 @@ void performTest_x1(const size_t rows, if (rowwise) { compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales_indices, mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, @@ -282,12 +284,21 @@ void performTest_x1(const size_t rows, } else { compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(), unpadded_blocks_Y, unpadded_blocks_X, scales_stride, + mismatches_scales_indices, mismatches_scales, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); } +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales", mismatches_scales_indices, gpu_scales_ptr, + ref_output_scales.get(), scales_stride, rows, output_cols, + rowwise, ref_output.get(), otype); + mismatches_scales = 0; +#endif + const size_t mismatches_elts = 32 * mismatches_scales; auto [atol, rtol] = getTolerances(otype); compareResults("output", output, ref_output.get(), rowwise, atol, rtol, true, mismatches_elts); @@ -364,7 +375,7 @@ void performTest_x2(const size_t rows, ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err); float ref_amax = 0; - compute_ref(grad.rowwise_cpu_dptr(),//PIV TODO: AMD path + compute_ref(grad.rowwise_cpu_dptr(), input.rowwise_cpu_dptr(), ref_output_rowwise.get(), ref_output_colwise.get(), @@ -383,23 +394,39 @@ void performTest_x2(const size_t rows, const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; + std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr(), ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise, unpadded_blocks_X_rowwise, scales_stride_rowwise, - mismatches_scales_rowwise, + mismatches_scales_indices_rowwise, mismatches_scales_rowwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); + std::vector mismatches_scales_indices_colwise; size_t mismatches_scales_colwise = 0; compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr(), ref_scales_colwise.get(), unpadded_blocks_Y_colwise, unpadded_blocks_X_colwise, scales_stride_colwise, - mismatches_scales_colwise, + mismatches_scales_indices_colwise, mismatches_scales_colwise, scale_diff_abs_tolerance, abs_tolerable_mismatches_limit, rel_tolerable_mismatches_limit); +#ifdef __HIP_PLATFORM_AMD__ + if (::testing::Test::HasFatalFailure()) return; + adjust_ref_for_e8m0_scale_error("scales_rowwise", mismatches_scales_indices_rowwise, + output.rowwise_cpu_scale_inv_ptr(), + ref_scales_rowwise.get(), scales_stride_rowwise, rows, + output_cols, true, ref_output_rowwise.get(), otype); + mismatches_scales_rowwise = 0; + adjust_ref_for_e8m0_scale_error("scales_colwise", mismatches_scales_indices_colwise, + 
output.columnwise_cpu_scale_inv_ptr(), + ref_scales_colwise.get(), scales_stride_colwise, rows, + output_cols, false, ref_output_colwise.get(), otype); + mismatches_scales_colwise = 0; +#endif + const size_t mismatches_elts_rowwise = 32 * mismatches_scales_rowwise; const size_t mismatches_elts_colwise = 32 * mismatches_scales_colwise; diff --git a/tests/cpp/test_common.cu b/tests/cpp/test_common.cu index 5f5a48f50..9f926d07b 100644 --- a/tests/cpp/test_common.cu +++ b/tests/cpp/test_common.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -539,13 +539,8 @@ void compareResults_sequential(const std::string &name, const Tensor &test, const T *test_data = rowwise ? test.rowwise_cpu_dptr() : test.columnwise_cpu_dptr(); const T *ref_data = reinterpret_cast(ref); for (size_t i = 0; i < N; ++i) { -#ifndef __HIP_PLATFORM_AMD__ double t = static_cast(test_data[i]); double r = static_cast(ref_data[i]); -#else - double t = static_cast(static_cast(test_data[i])); - double r = static_cast(static_cast(ref_data[i])); -#endif bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); /* For Float32 the floating point comparison is enough to error out */ bool assertion = mismatch && test.dtype() == DType::kFloat32; @@ -593,7 +588,7 @@ static size_t getFirstMismatchIdx(const DType data_type, const T* test_data, con size_t thread_mismatches = 0; #pragma omp for schedule(static) for (size_t i = 0; i < N; ++i) { - double t = static_cast(test_data[i]);//PIV TODO: static_cast(static_cast(test_data[i]) + double t = static_cast(test_data[i]); double r = static_cast(ref_data[i]); bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol); @@ -693,6 +688,7 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, + std::vector &mismatch_indices, size_t& mismatches_num, const size_t atol, const double abs_tolerable_mismatches_limit, const double rel_tolerable_mismatches_limit) @@ -701,7 +697,6 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const size_t tolerable_mismatches_limit = std::min(abs_tolerable_mismatches_limit, std::floor(N * rel_tolerable_mismatches_limit)); mismatches_num = 0; - std::vector mismatch_indices; for (int i = 0; i < row_blocks; ++i) { for (int j = 0; j < col_blocks; ++j) { @@ -728,71 +723,48 @@ void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, } } -#ifdef __HIP_PLATFORM_AMD__ //PIV TODO: merge with upstream -void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - double tol, bool rowwise, std::vector> &mismatch_idx) { - const uint8_t *const test = rowwise ? 
output.rowwise_cpu_scale_inv_ptr() - : output.columnwise_cpu_scale_inv_ptr(); - - const double scale_tol = std::max(1., row_blocks * col_blocks * tol); - - for (int i = 0; i < row_blocks; i++) { - for (int j = 0; j < col_blocks; j++) { - const int idx = i * stride + j; - if (test[idx] != ref[idx]) { - int t_scale = static_cast(test[idx]); - int r_scale = static_cast(ref[idx]); - if (std::abs(t_scale - r_scale) == 1) { - mismatch_idx.emplace_back(i, j, r_scale-t_scale); - } else { - GTEST_FAIL() << "Error in " << name << std::endl - << "Mismatch: " << t_scale << " vs " - << r_scale << " at index " << idx; - } - } - } - } - const size_t scale_mismatches = mismatch_idx.size(); - ASSERT_FALSE(scale_mismatches > scale_tol) - << "Error in " << name << std::endl << std::setprecision(4) - << "Total scale mismatches: " << scale_mismatches << " (" << 100.*(double)scale_mismatches/(double)(row_blocks*col_blocks) - << "%) Exceeds tolerance of " << scale_tol << " (" << 100.*tol << "%) mismatches"; - - if (scale_mismatches) { - std::cout << "\x1b[33mWARNING:\x1b[0m " << scale_mismatches - << " scale mismatches were found. This does not imply an accuracy issue." << std::endl; - } -} - -void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks, - const size_t col_blocks, const size_t rows, const size_t cols, DType otype) { - TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY( otype, T, - T *ref_data = reinterpret_cast(ref); +#ifdef __HIP_PLATFORM_AMD__ +void adjust_ref_for_e8m0_scale_error(const std::string &name, + const std::vector &mismatch_idx, + const uint8_t *test_scale, const uint8_t *ref_scale, + const size_t scale_stride, const size_t rows, + const size_t cols, bool rowwise, void *ref_ptr, DType otype) { + if (mismatch_idx.size() == 0) { + return; + } + const size_t col_blocks_size = rowwise ? 32 : 1; + const size_t row_blocks_size = rowwise ? 1 : 32; + GTEST_LOG_(INFO) << "Adjusting reference data for " << mismatch_idx.size() + << " scale mismatches in tensor " << name << " " + << (rowwise ? "rowwise" : "colwise") << " direction." 
<< std::endl; + for (const auto scale_idx : mismatch_idx) { + const int scale_diff = ref_scale[scale_idx] - test_scale[scale_idx]; double scale_val; - const size_t col_blocks_size = cols / col_blocks; - const size_t row_blocks_size = rows / row_blocks; - for (const auto &[i, j, scale_diff] : mismatch_idx) { - if (scale_diff == 1) { - scale_val = 2.; - } else if (scale_diff == -1) { - scale_val = .5; - } else { // Shouldn't ever reach this - GTEST_FAIL() << "Error in adjust_ref, |scale_diff| > 1"; - } - size_t ii_min = i * row_blocks_size; - const size_t ii_max = std::min(ii_min + row_blocks_size, rows); - for (; ii_min < ii_max; ii_min++) { - size_t jj_min = j * col_blocks_size; - const size_t jj_max = std::min(jj_min + col_blocks_size, cols); - for (; jj_min < jj_max; jj_min++) { - const size_t data_idx = ii_min * cols + jj_min; + if (scale_diff == 1) { + scale_val = 2.; + } else if (scale_diff == -1) { + scale_val = .5; + } else { + GTEST_FAIL() << "Error in " << name << ": mismatch " << test_scale[scale_idx] << " vs " + << ref_scale[scale_idx] << " at index " << scale_idx; + } + const int i = scale_idx / scale_stride; + const int j = scale_idx % scale_stride; + size_t ii_min = i * row_blocks_size; + const size_t ii_max = std::min(ii_min + row_blocks_size, rows); + for (; ii_min < ii_max; ii_min++) { + size_t jj_min = j * col_blocks_size; + const size_t jj_max = std::min(jj_min + col_blocks_size, cols); + for (; jj_min < jj_max; jj_min++) { + const size_t data_idx = ii_min * cols + jj_min; + TRANSFORMER_ENGINE_TYPE_SWITCH_FP8_ONLY(otype, T, { + T *ref_data = reinterpret_cast(ref_ptr); ref_data[data_idx] = static_cast(static_cast(ref_data[data_idx]) * scale_val); - } + }); // NOLINT(*) } } - ); // NOLINT(*) + } } #endif // #ifdef __HIP_PLATFORM_AMD__ diff --git a/tests/cpp/test_common.h b/tests/cpp/test_common.h index 716cbc65a..9a84995cc 100644 --- a/tests/cpp/test_common.h +++ b/tests/cpp/test_common.h @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
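The comparison helper implemented above (and declared in the header below) tolerates a small budget of scale mismatches rather than none at all. Ignoring the 2-D stride and padding handling, the policy reduces to the following sketch over a flat array of exponents (the function name is ours):

import math

def compare_e8m0_scales(test, ref, atol=0, abs_limit=1.0, rel_limit=1.0e-4):
    # A mismatch is any pair of exponents differing by more than `atol` steps;
    # at most min(abs_limit, floor(N * rel_limit)) mismatches are tolerated.
    n = len(test)
    budget = min(abs_limit, math.floor(n * rel_limit))
    mismatches = [i for i in range(n) if abs(int(test[i]) - int(ref[i])) > atol]
    if len(mismatches) > budget:
        raise AssertionError(f"{len(mismatches)} scale mismatches exceed the tolerated {budget}")
    return mismatches  # indices later used to adjust the reference data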
@@ -484,17 +484,17 @@ void compareResults(const std::string &name, const uint8_t *test, const uint8_t size_t N, float mismatch_rate_tol = 0.); void compare_e8m0_scaling_factors(const std::string &name, const uint8_t *test, const uint8_t *ref, const size_t row_blocks, const size_t col_blocks, const size_t stride, - size_t& mismatches_num, + std::vector &mismatch_indices, size_t& mismatches_num, const size_t scale_diff_abs_tolerance = 0, const double abs_tolerable_mismatches_limit = 0, const double rel_tolerable_mismatches_limit = 0); -#ifdef USE_ROCM -void compare_e8m0_scaling_factors(const std::string &name, Tensor &output, const uint8_t *ref, - const size_t row_blocks, const size_t col_blocks, const size_t stride, - double tol, bool rowwise, std::vector> &mismatch_idx); -void adjust_ref(std::vector> mismatch_idx, void *ref, const size_t row_blocks, - const size_t col_blocks, const size_t rows, const size_t cols, DType otype); +#ifdef USE_ROCM +void adjust_ref_for_e8m0_scale_error(const std::string &name, + const std::vector &mismatch_idx, + const uint8_t *test_scale, const uint8_t *ref_scale, + const size_t scale_stride, const size_t rows, + const size_t cols, bool rowwise, void *ref_ptr, DType otype); #endif std::array get_scale_tensor_dims(const size_t rows, const size_t cols, diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index 7c08539c3..3f3b5db84 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -8,7 +10,8 @@ import pytest import jax -from jax.experimental.pjit import pjit, _UNSPECIFIED +from jax._src.pjit import pjit +from jax._src.sharding_impls import UNSPECIFIED as _UNSPECIFIED from transformer_engine.jax.sharding import MeshResource diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index f697c51ff..8d11cbaed 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -231,7 +231,7 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) - # TODO: skip cases with single fwd as nan/inf//PIV TODO: is it AMD path? 
+ # TODO: skip cases with single fwd as nan/inf if jnp.any(jnp.isnan(single_fwd)) or jnp.any(jnp.isinf(single_fwd)): pytest.skip("skip tests with nan/inf single fwd.") diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index 720d10d50..d3a82b2d3 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -135,6 +135,7 @@ def impl_test_softmax( f"{str(w)}" ) + @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]]) @pytest.mark.parametrize( @@ -176,7 +177,6 @@ def test_softmax( @pytest.mark.parametrize("softmax_type", [SoftmaxType.SCALED, SoftmaxType.SCALED_MASKED]) @pytest.mark.parametrize("bad_sharding", [False, True]) @pytest.mark.parametrize("broadcast_batch_mask", [False, True]) - @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0")//PIV TODO: move to shary test def test_softmax_gspmd( self, device_count, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 959ef1a00..4d7718cd0 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -397,8 +397,12 @@ def _check_configs(self): self.head_dim_v, (-1, -1) if self.window_size is None else self.window_size, ).get_fused_attn_backend() - if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - pytest.skip("Unsupported inputs combination or device compute capability.") + if is_hip_extension(): + if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: + pytest.skip("Unsupported inputs combination or device compute capability.") + else: + if self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: + pytest.skip("Unsupported inputs combination or device compute capability.") if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index dbd7c5936..d2fe65604 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
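The JAX-version gate added to the shardy softmax test above can be kept as a reusable marker so the same condition is not repeated per test. A small sketch, assuming `version` comes from the `packaging` package as in the test file (the marker and test names are ours):

import jax
import pytest
from packaging import version

# Skip tests that rely on shardy-style sharding on older JAX releases.
requires_shardy_jax = pytest.mark.skipif(
    version.parse(jax.__version__) < version.parse("0.5.0"),
    reason="shardy sharding requires JAX 0.5.0",
)

@requires_shardy_jax
def test_softmax_sharded_example():
    ...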
@@ -95,8 +95,6 @@ def reset_attn_backend(): "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"]) yield -#PIV TODO: _get_attention_backends is moved to attention_utils.py - model_configs_base = { # test: b, h, hg, d, sq, skv, p, mask, bias @@ -130,10 +128,11 @@ def test_dot_product_mem_calc(): if not is_bf16_compatible(): pytest.skip("This test requires bf16 support.") dtype = torch.bfloat16 - config = ModelConfig(16, 128, 8, 128, 8192, 8192, 0.0, "causal", "no_bias") + # b, sq, q, dqk + config = ModelConfig(16, 8192, 128, 128, num_gqa_groups=8, attn_mask_type="causal") is_training = config.head_dim_qk <= 128 and config.head_dim_v <= 128 qkv_layout = "sbhd_sbhd_sbhd" - _, _, fused_attn_backends = _get_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -376,12 +375,10 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference - "mla_4_0": ModelConfig( #PIV TODO - 10, 16, 16, 192, 4096, 4096, 0.0, "causal", "no_bias", head_dim_v=128 - ), - "mla_4_1": ModelConfig( - 10, 16, 16, 192, 4096, 4096, 0.0, "no_mask", "no_bias", head_dim_v=128 - ), + #"mla_4_0": ModelConfig(#PIV TODO: do cross 0 and cross 1 cover it + # 10, 4096, 16, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 + #), + #"mla_4_1": ModelConfig(10, 4096, 16, 192, max_seqlen_kv=4096, head_dim_v=128), } @@ -815,15 +812,17 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between if (pad_between_seqs==False and get_cudnn_version() < (9, 3, 0)): pytest.skip("cuDNN 9.3.0+ is required to run pad_between_seqs = False"); - _, _, fused_attn_backends = _get_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - window_size=config.window_size, - pad_between_seqs=pad_between_seqs, - ) - if share_cu_seqlens_ref and FusedAttnBackend["CK"] not in fused_attn_backends: - pytest.skip("This test is only required for the CK fused attention backend.") + if share_cu_seqlens_ref: #ROCm specific config + _, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, + ) + if FusedAttnBackend["CK"] not in fused_attn_backends: + pytest.skip("This test is only required for the CK fused attention backend.") + test_dot_product_attention( dtype, model_configs, model, False, True, qkv_layout, False, pad_between_seqs, share_cu_seqlens_ref ) @@ -837,15 +836,16 @@ def test_dpa_qkv_layout_thd(dtype, model_configs, model, qkv_layout, pad_between def test_dpa_qkv_layout_thd_mqa_gqa(dtype, model_configs, model, qkv_layout, pad_between_seqs, share_cu_seqlens_ref): config = model_configs[model] - _, _, fused_attn_backends = _get_attention_backends( - config, - qkv_dtype=dtype, - qkv_layout=qkv_layout, - window_size=config.window_size, - pad_between_seqs=pad_between_seqs, - ) - if share_cu_seqlens_ref and FusedAttnBackend["CK"] not in fused_attn_backends: - pytest.skip("This test is only required for the CK fused attention backend.") + if share_cu_seqlens_ref: + _, _, fused_attn_backends = get_available_attention_backends( + config, + qkv_dtype=dtype, + qkv_layout=qkv_layout, + window_size=config.window_size, + pad_between_seqs=pad_between_seqs, + ) + if 
FusedAttnBackend["CK"] not in fused_attn_backends: + pytest.skip("This test is only required for the CK fused attention backend.") def find_factors(x): f = [] diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index 288c5382e..bded2a82a 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -14,6 +16,7 @@ import torch from torch.distributions import Exponential +from torch.utils.cpp_extension import IS_HIP_EXTENSION from transformer_engine.pytorch import make_graphed_callables from transformer_engine.common import recipe from transformer_engine.pytorch import fp8_autocast, fp8_model_init @@ -401,6 +404,11 @@ def get_tols(config, module, backend, dtype): @pytest.mark.parametrize("is_cuda_graph", [False, True]) @pytest.mark.parametrize("is_fp8", [False, True]) def test_kv_cache(dtype, model, qkv_format, is_paged, backend, module, is_cuda_graph, is_fp8): + if IS_HIP_EXTENSION: + if is_paged and backend == "FusedAttention": + pytest.skip("Paged KV cache is not supported for FusedAttention on ROCm") + if qkv_format == "thd" and backend == "FusedAttention": + pytest.skip("THD KV cache is not supported for FusedAttention on ROCm") reset_rng_states() logger = logging.getLogger("test_kv_cache") fp8_recipe = recipe.DelayedScaling( diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 70fcfe272..7340bbecf 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
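The diff context above truncates the body of find_factors in the THD MQA/GQA test (only `f = []` is visible). For orientation, a plausible stand-alone equivalent is sketched below; this is an assumption about its behavior, not the file's actual implementation:

def find_factors(x: int) -> list[int]:
    # Assumed behavior: enumerate all positive divisors of x
    # (e.g. to pick num_gqa_groups values that divide num_heads).
    return [i for i in range(1, x + 1) if x % i == 0]

assert find_factors(24) == [1, 2, 3, 4, 6, 8, 12, 24]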
@@ -163,6 +163,8 @@ def is_fused_attn_available( is_training=is_training, deterministic=deterministic, ) + if IS_HIP_EXTENSION: + return fused_attn_backends != [] return FusedAttnBackend["F16_arbitrary_seqlen"] in fused_attn_backends @@ -1180,7 +1182,7 @@ def _test_granular_accuracy_with_fp8(block, bs, dtype, config): reset_rng_states() inp_hidden_states = torch.randn( - (config.seq_len, bs, config.hidden_size), + (config.max_seqlen_q, bs, config.hidden_size), dtype=dtype, device="cuda", requires_grad=True, @@ -2046,7 +2048,7 @@ def test_grouped_linear_accuracy( if IS_HIP_EXTENSION: if dtype not in (torch.float32,) and fuse_wgrad_accumulation and not fp8: pytest.skip(f"Rocm does not support fused wgrad accumulation for {dtype}.") - if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED://PIV TODO FP8 support check + if fp8 and fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") config = model_configs[model] @@ -2323,8 +2325,8 @@ def _generate_random_numbers(n, total_sum): breaks = sorted(random.sample(range(1, total_sum), n - 1)) random_numbers = ( [breaks[0]] - [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]#PIV TODO: fix changes - [total_sum - breaks[-1]] + + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)] + + [total_sum - breaks[-1]] ) return random_numbers diff --git a/tests/pytorch/test_sanity.py b/tests/pytorch/test_sanity.py index 453618e19..a7d762c3d 100644 --- a/tests/pytorch/test_sanity.py +++ b/tests/pytorch/test_sanity.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -566,9 +566,6 @@ def test_sanity_gpt( normalization, parallel_attention_mlp, ): - if IS_HIP_EXTENSION and cpu_offload: - pytest.skip("cpu_offloading not supported in rocm TE") - config = model_configs[model] if fp8_recipe is not None: diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 6ee10b5a8..f1f443bdc 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
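The _generate_random_numbers hunk above restores the `+` operators that concatenate the three pieces of the partition. Spelled out as a stand-alone helper (renamed here, same construction): pick n-1 distinct breakpoints in (0, total_sum) and take consecutive differences, which yields n positive integers summing to total_sum (valid for n >= 2):

import random

def generate_random_partition(n: int, total_sum: int) -> list[int]:
    breaks = sorted(random.sample(range(1, total_sum), n - 1))
    parts = (
        [breaks[0]]
        + [breaks[i] - breaks[i - 1] for i in range(1, n - 1)]
        + [total_sum - breaks[-1]]
    )
    assert sum(parts) == total_sum and all(p > 0 for p in parts)
    return parts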
@@ -12,6 +12,7 @@ import pytest import torch +from torch.utils.cpp_extension import IS_HIP_EXTENSION import transformer_engine import transformer_engine.common.recipe @@ -287,6 +288,24 @@ def test(): _attention_backends["backend_selection_requires_update"] = False return available_backends, flash_attention_backend, fused_attention_backend + if IS_HIP_EXTENSION: + backends = {"AOTriton": "AOTRITON", "CK": "CK"} + if AttentionLogging._is_logging_setup is False: + AttentionLogging.setup_logging() + with logging_context(highest_level=AttentionLogging._log_level): + for i in backends.keys(): + for k in backends.keys(): + os.environ["NVTE_FUSED_ATTN_"+backends[k]] = "0" + os.environ["NVTE_FUSED_ATTN_"+backends[i]] = "1" + _attention_backends["backend_selection_requires_update"] = True + available_backends, flash_attention_backend, fused_attention_backend = test() + if fused_attention_backend == FusedAttnBackend[i]: + fused_attn_backends.append(fused_attention_backend) + for i in backends.keys(): + del os.environ["NVTE_FUSED_ATTN_"+backends[i]] + available_backends[1] = len(fused_attn_backends) > 0 + return available_backends, flash_attention_backend, fused_attn_backends + backends = {0: "F16_max512_seqlen", 1: "F16_arbitrary_seqlen", 2: "FP8"} if AttentionLogging._is_logging_setup is False: AttentionLogging.setup_logging() diff --git a/transformer_engine/common/CMakeLists.txt b/transformer_engine/common/CMakeLists.txt index 5d20f3b0a..24bae4c1d 100644 --- a/transformer_engine/common/CMakeLists.txt +++ b/transformer_engine/common/CMakeLists.txt @@ -146,7 +146,6 @@ list(APPEND transformer_engine_SOURCES activation/relu.cu activation/swiglu.cu gemm/cublaslt_gemm.cu - gemm/cutlass_grouped_gemm.cu normalization/common.cpp normalization/layernorm/ln_api.cpp normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -184,6 +183,7 @@ list(APPEND transformer_engine_SOURCES fused_attn/fused_attn_fp8.cu fused_attn/fused_attn.cpp fused_attn/utils.cu + gemm/cutlass_grouped_gemm.cu util/cuda_nvml.cpp comm_gemm_overlap/userbuffers/ipcsocket.cc comm_gemm_overlap/userbuffers/userbuffers-host.cpp @@ -237,9 +237,11 @@ else() add_library(transformer_engine SHARED ${te_hip_sources}) target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}") endif() + target_include_directories(transformer_engine PUBLIC "${CMAKE_CURRENT_SOURCE_DIR}/include") +if (USE_CUDA) if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) set_source_files_properties( "gemm/cutlass_grouped_gemm.cu" @@ -249,6 +251,7 @@ if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0) else() message(FATAL_ERROR "cutlass gemm/cutlass_grouped_gemm.cu kernel required sm 90a") endif() +endif() #USE_CUDA # Configure dependencies if (USE_CUDA) @@ -393,6 +396,8 @@ else() string_code_transpose_rtc_cast_transpose_cu) make_string_header_from_file(transpose/rtc/transpose.hip string_code_transpose_rtc_transpose_cu) + make_string_header_from_file(transpose/rtc/swap_first_dims.hip + string_code_transpose_rtc_swap_first_dims_cu) make_string_header_from_file(amd_detail/hip_float8.h string_code_amd_detail_hip_float8_h) make_string_header_from_file(amd_detail/hip_f8_impl.h diff --git a/transformer_engine/common/common.cu b/transformer_engine/common/common.cu index 1d40d026f..e67694c38 100644 --- a/transformer_engine/common/common.cu +++ b/transformer_engine/common/common.cu @@ -28,6 +28,7 @@ __global__ void __launch_bounds__(1) } // namespace +#ifndef __HIP_PLATFORM_AMD__ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { using namespace 
transformer_engine; switch (t) { @@ -45,6 +46,7 @@ cudaDataType_t get_cuda_dtype(const transformer_engine::DType t) { NVTE_ERROR("Invalid type"); } } +#endif void update_tensor_scale_inv(Tensor *t, cudaStream_t stream) { if (is_fp8_dtype(t->data.dtype) && is_tensor_scaling(t->scaling_mode)) { @@ -116,6 +118,7 @@ void nvte_memset(void *ptr, int value, size_t size_in_bytes, cudaStream_t stream } } // extern "C" +#ifndef __HIP_PLATFORM_AMD__ void checkCuDriverContext(CUstream stream) { // Ensure the thread's "current" CUDA context is set. cuda_driver::ensure_context_exists(); @@ -140,7 +143,6 @@ void checkCuDriverContext(CUstream stream) { } } -#ifndef __HIP_PLATFORM_AMD__ CUtensorMapDataType get_CUtensorMapDataType(DType dtype) { static const std::unordered_map dtypeMapping = []() { std::unordered_map typeMapping = { diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h index a42f691f3..ce510334b 100644 --- a/transformer_engine/common/common.h +++ b/transformer_engine/common/common.h @@ -278,6 +278,7 @@ struct QuantizationConfig { cudaDataType_t get_cuda_dtype(const transformer_engine::DType t); +#ifndef __HIP_PLATFORM_AMD__ template constexpr T DIVUP(const T &x, const T &y) { return (((x) + ((y)-1)) / (y)); @@ -289,6 +290,22 @@ constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(const T "Integral type required."); return DIVUP(static_cast(N), static_cast(M)) * M; } +#else +// DIVUP is called with integral types only for which passing by value is preferred. +// It also allows using of constexpr arguments w/o needing to create storage for references. +template +constexpr T DIVUP(T x, T y) { + static_assert(std::is_integral::value, "Integral type required."); + return (((x) + ((y)-1)) / (y)); +} + +template +constexpr __device__ __host__ __forceinline__ uint64_t DIVUP_TO_MULTIPLE(T1 N, T2 M) { + static_assert(std::is_integral::value && std::is_integral::value, + "Integral type required."); + return DIVUP(static_cast(N), static_cast(M)) * M; +} +#endif //__HIP_PLATFORM_AMD__ using byte = uint8_t; using int16 = int16_t; @@ -704,9 +721,11 @@ constexpr size_t scale_tensor_alignment_Y_rowwise = 128; constexpr size_t scale_tensor_alignment_X_colwise = 128; constexpr size_t scale_tensor_alignment_Y_colwise = 4; +#ifndef __HIP_PLATFORM_AMD__ // Alignment requirements for the Tensor Memory Accelerator (TMA) constexpr size_t TMA_GMEM_ALIGNMENT = 16; // global memory address alignment constexpr size_t TMA_SHMEM_ALIGNMENT = 128; // shared memory address alignment +#endif inline bool is_aligned_ptr(const void *ptr, size_t alignment) { return reinterpret_cast(ptr) % alignment == 0; @@ -737,9 +756,11 @@ void update_tensor_scale_inv(Tensor *t, cudaStream_t stream); #define NVTE_API_CALL(api_name) \ transformer_engine::nvtx::NVTXWrapper _##api_name##_nvtx_wrapper(#api_name); +#ifdef __HIP_PLATFORM_AMD__ +#define checkCuDriverContext(stream) {} +#else void checkCuDriverContext(CUstream stream); -#ifndef __HIP_PLATFORM_AMD__ CUtensorMapDataType get_CUtensorMapDataType(DType dtype); // Set up parameters to create TMA descriptor. 
@@ -747,7 +768,7 @@ void create_2D_tensor_map(CUtensorMap &tensorMap, const SimpleTensor &tensor, const uint64_t globalY, const uint64_t globalX, const uint32_t shmemY, const uint32_t shmemX, const uint32_t stride_elems, const uint32_t offset_elems, const size_t type_num_bits); -#endif //#ifndef __HIP_PLATFORM_AMD__ +#endif //#ifdef __HIP_PLATFORM_AMD__ bool is_supported_by_CC_100(); diff --git a/transformer_engine/common/dropout/dropout.cu b/transformer_engine/common/dropout/dropout.cu index bab349161..c7a5555e2 100644 --- a/transformer_engine/common/dropout/dropout.cu +++ b/transformer_engine/common/dropout/dropout.cu @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -41,9 +43,15 @@ __device__ __forceinline__ uint32_t bytewise_less_than(uint32_t a, uint32_t b) { // MSBs are 0 if the low bits of a are less than the low bits of b. uint32_t result = (a | 0x80808080) - (b & 0x7F7F7F7F); +#ifndef __HIP_PLATFORM_AMD__ // Bitwise logical op to get answer in MSBs // Equivalent logic: result = (a == b) ? !result : b asm("lop3.b32 %0, %1, %2, %3, 0x4D;\n\t" : "=r"(result) : "r"(a), "r"(b), "r"(result)); +#else + // AMD GPU: Use bitwise ops to get answer in MSBs + uint32_t mask = (a ^ b); + result = (mask & b) | ~(mask | result); +#endif // Mask out everything except MSBs and return result &= 0x80808080; diff --git a/transformer_engine/common/gemm/cublaslt_gemm.cu b/transformer_engine/common/gemm/cublaslt_gemm.cu index 777af085c..9c2ca9b4c 100644 --- a/transformer_engine/common/gemm/cublaslt_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_gemm.cu @@ -11,7 +11,7 @@ #include #include #endif // #ifndef __HIP_PLATFORM_AMD__ -#include + #include #include #include @@ -24,7 +24,9 @@ #include "../util/logging.h" #include "../util/multi_stream.h" #include "common/util/cuda_runtime.h" +#ifndef __HIP_PLATFORM_AMD__ #include "cutlass_grouped_gemm.cuh" +#endif #ifndef __HIP_PLATFORM_AMD__ namespace { @@ -227,10 +229,11 @@ namespace transformer_engine { #ifdef __HIP_PLATFORM_AMD__ //Forward declaration. 
The implementation is in rocm_gemm.cu void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, bool transa, bool transb, bool grad, - void* workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset = -1); + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void* workspace, size_t workspaceSize, + float alpha, float beta, bool use_split_accumulator, int math_sm_count, + int m_split, int n_split, bool gemm_producer, const Tensor *inputCounter, + hipStream_t stream, int compute_stream_offset = -1); #else // Use cublasLt using cublasHandleManager = detail::HandleManager; @@ -612,33 +615,6 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, } // namespace transformer_engine -// compute_stream_offset = -1 means the stream from outer rather than compute_streams -static void cublas_gemm_ex(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, - NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, - NVTETensor workspace, bool accumulate, bool use_split_accumulator, - int math_sm_count, cudaStream_t stream, int compute_stream_offset = -1) { - using namespace transformer_engine; - const Tensor *inputA = convertNVTETensorCheck(A); - const Tensor *inputB = convertNVTETensorCheck(B); - Tensor *outputD = convertNVTETensorCheck(D); - const Tensor *biasTensor = convertNVTETensorCheck(bias); - Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out); - Tensor *wspace = convertNVTETensorCheck(workspace); - - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, -#ifdef __HIP_PLATFORM_AMD__ - transa, transb, -#else - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, -#endif //__HIP_PLATFORM_AMD__ - grad, wspace->data.dptr, wspace->data.shape[0], - accumulate, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream -#ifdef __HIP_PLATFORM_AMD__ - , compute_stream_offset -#endif //__HIP_PLATFORM_AMD__ - ); -} - void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, const NVTETensor bias, NVTETensor pre_gelu_out, bool transa, bool transb, bool grad, NVTETensor workspace, bool accumulate, bool use_split_accumulator, @@ -652,14 +628,8 @@ void nvte_cublas_gemm(const NVTETensor A, const NVTETensor B, NVTETensor D, cons Tensor *outputGelu = convertNVTETensor(pre_gelu_out); Tensor *wspace = convertNVTETensor(workspace); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, -#ifdef __HIP_PLATFORM_AMD__ - transa, transb, -#else - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, - (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, -#endif - grad, wspace->data.dptr, wspace->data.shape[0], + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 
1.0f : 0.0f, use_split_accumulator, math_sm_count, 0, 0, false, nullptr, stream); } @@ -724,13 +694,8 @@ void nvte_cublas_atomic_gemm(const NVTETensor A, const NVTETensor B, NVTETensor NVTE_CHECK(is_delayed_tensor_scaling(inputA->scaling_mode) && is_delayed_tensor_scaling(inputB->scaling_mode), "Atomic GEMM only supports delayed scaling."); - cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, -#ifdef __HIP_PLATFORM_AMD__ - transa, transb, -#else - (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, -#endif //__HIP_PLATFORM_AMD__ - grad, wspace->data.dptr, wspace->data.shape[0], + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, + (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 1.0f : 0.0f, use_split_accumulator, math_sm_count, m_split, n_split, gemm_producer, inputCounter, stream); } @@ -753,9 +718,26 @@ void multi_stream_cublas_gemm(const NVTETensor *A, const NVTETensor *B, NVTETens } for (int i = 0; i < num_gemms; i++) { - cublas_gemm_ex(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, - workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, - detail::get_compute_stream(i % num_streams), i % num_streams); +#ifdef __HIP_PLATFORM_AMD__ + { + const Tensor *inputA = convertNVTETensorCheck(A[i]); + const Tensor *inputB = convertNVTETensorCheck(B[i]); + Tensor *outputD = convertNVTETensorCheck(D[i]); + const Tensor *biasTensor = convertNVTETensorCheck(bias[i]); + Tensor *outputGelu = convertNVTETensorCheck(pre_gelu_out[i]); + Tensor *wspace = convertNVTETensorCheck(workspace[i % num_streams]); + + cublas_gemm(inputA, inputB, outputD, biasTensor, outputGelu, + (transa) ? CUBLAS_OP_T : CUBLAS_OP_N, (transb) ? CUBLAS_OP_T : CUBLAS_OP_N, grad, + wspace->data.dptr, wspace->data.shape[0], 1.0f, (accumulate) ? 1.0f : 0.0f, + use_split_accumulator, math_sm_count, 0, 0, false, nullptr, + detail::get_compute_stream(i % num_streams), i % num_streams); + } +#else + nvte_cublas_gemm(A[i], B[i], D[i], bias[i], pre_gelu_out[i], transa, transb, grad, + workspace[i % num_streams], accumulate, use_split_accumulator, math_sm_count, + detail::get_compute_stream(i % num_streams)); +#endif } // record events on compute streams @@ -796,6 +778,7 @@ using cublasHandleManager = detail::HandleManager("NVTE_USE_CUTLASS_GROUPED_GEMM", false); @@ -877,5 +864,5 @@ void nvte_multi_tensor_gemm(const NVTETensor *A, const NVTETensor *B, NVTETensor } cublas_path(); } +#endif // __HIP_PLATFORM_AMD__ } -#endif diff --git a/transformer_engine/common/gemm/rocm_gemm.cu b/transformer_engine/common/gemm/rocm_gemm.cu index 50710ee1b..bbc53db35 100644 --- a/transformer_engine/common/gemm/rocm_gemm.cu +++ b/transformer_engine/common/gemm/rocm_gemm.cu @@ -952,13 +952,9 @@ void hipblaslt_gemm(const Tensor *inputA, bool grad, void* workspace, size_t workspaceSize, - bool accumulate, + float alpha, float beta, bool use_split_accumulator, int math_sm_count, - int m_split, - int n_split, - bool gemm_producer, - const Tensor *inputCounter, hipStream_t stream, hipblasLtHandle_t handle ) { @@ -992,7 +988,7 @@ void hipblaslt_gemm(const Tensor *inputA, << " gelu=" << (outputPreGelu->data.dptr != nullptr) << " use_fp8=" << use_fp8 << " scale_mode=" << (a_tensor ? "tensor" : a_block ? 
"mxfp8" : "unsupported") - << " accumulate=" << accumulate + << " alpha=" << alpha << " beta=" << beta << std::endl; } @@ -1039,13 +1035,6 @@ void hipblaslt_gemm(const Tensor *inputA, NVTE_CHECK(!gelu, "fp8 gemm + gelu fusion is unavailable right now!"); } #endif - if (is_fp8_dtype(outputD->data.dtype)) { - NVTE_CHECK(!accumulate, "Accumulation mode not supported with FP8 GEMM output!"); - } - - float one = 1.0; - float zero = 0.0; - float beta = (accumulate) ? one : zero; int device_id; NVTE_CHECK_CUDA(hipGetDevice(&device_id)); @@ -1217,7 +1206,7 @@ void hipblaslt_gemm(const Tensor *inputA, if (HIPBLAS_STATUS_SUCCESS == hipblaslt_ext::matmulIsAlgoSupported( handle, operationDesc, - static_cast(&one), + static_cast(&alpha), Adesc, Bdesc, static_cast(&beta), @@ -1297,7 +1286,7 @@ void hipblaslt_gemm(const Tensor *inputA, // Warm-up call NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ @@ -1319,7 +1308,7 @@ void hipblaslt_gemm(const Tensor *inputA, { NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ @@ -1380,7 +1369,7 @@ void hipblaslt_gemm(const Tensor *inputA, // D = alpha * (A * B) + beta * C NVTE_CHECK_HIPBLASLT(hipblasLtMatmul(handle, operationDesc, - static_cast(&one), /* alpha */ + static_cast(&alpha), /* alpha */ param.A, /* A */ Adesc, param.B, /* B */ @@ -1523,10 +1512,12 @@ void release_service_stream(hipStream_t stream, struct ServiceStreamCtl &ctl) void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, - const Tensor *inputBias, Tensor *outputPreGelu, bool transa, bool transb, bool grad, - void *workspace, size_t workspaceSize, bool accumulate, bool use_split_accumulator, - int math_sm_count, int m_split, int n_split, bool gemm_producer, - const Tensor *inputCounter, hipStream_t stream, int compute_stream_offset) + const Tensor *inputBias, Tensor *outputPreGelu, cublasOperation_t transa, + cublasOperation_t transb, bool grad, void *workspace, size_t workspaceSize, + float alpha, float beta, bool use_split_accumulator, int math_sm_count, + [[maybe_unused]] int m_split, [[maybe_unused]] int n_split, + [[maybe_unused]] bool gemm_producer, [[maybe_unused]] const Tensor *inputCounter, + hipStream_t stream, int compute_stream_offset) { // Tensor dims in row-major order const int A0 = inputA->flat_first_dim(); @@ -1534,19 +1525,21 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, const int B0 = inputB->flat_first_dim(); const int B1 = inputB->flat_last_dim(); + const bool is_transa = transa == CUBLAS_OP_T; + const bool is_transb = transb == CUBLAS_OP_T; + // GEMM dims in column-major order - const int m = transa ? A0 : A1; - const int n = transb ? B1 : B0; - const int k = transa ? A1 : A0; - NVTE_CHECK((transb ? B0 : B1) == k, + const int m = is_transa ? A0 : A1; + const int n = is_transb ? B1 : B0; + const int k = is_transa ? A1 : A0; + NVTE_CHECK((is_transb ? B0 : B1) == k, "GEMM inputs have incompatible dimensions (A is ", A0, "x", A1, ", B is ", B0, "x", B1, ")"); - const int lda = transa ? k : m; - const int ldb = transb ? n : k; + const int lda = is_transa ? k : m; + const int ldb = is_transb ? n : k; const int ldd = m; - ServiceStreamCtl ss_ctl; bool use_service_stream = (math_sm_count != 0) ? 
get_service_stream(math_sm_count, stream, ss_ctl) : false; @@ -1564,14 +1557,9 @@ void cublas_gemm(const Tensor *inputA, const Tensor *inputB, Tensor *outputD, handle = hipblaslt_handles[compute_stream_offset]; } - hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, - m, n, k, lda, ldb, ldd, - (transa) ? HIPBLAS_OP_T : HIPBLAS_OP_N, - (transb) ? HIPBLAS_OP_T : HIPBLAS_OP_N, - grad, - workspace, workspaceSize, accumulate, use_split_accumulator, - math_sm_count, m_split, n_split, gemm_producer, - inputCounter, use_service_stream ? ss_ctl.stream : stream, handle); + hipblaslt_gemm(inputA, inputB, outputD, inputBias, outputPreGelu, m, n, k, lda, ldb, ldd, transa, + transb, grad, workspace, workspaceSize, alpha, beta, use_split_accumulator, + math_sm_count, use_service_stream ? ss_ctl.stream : stream, handle); if (use_service_stream) { diff --git a/transformer_engine/common/include/transformer_engine/recipe.h b/transformer_engine/common/include/transformer_engine/recipe.h index d39d8f98b..89515108a 100644 --- a/transformer_engine/common/include/transformer_engine/recipe.h +++ b/transformer_engine/common/include/transformer_engine/recipe.h @@ -92,38 +92,41 @@ constexpr int amax_kernel_threads = 512; */ void nvte_compute_amax(const NVTETensor input, NVTETensor output, cudaStream_t stream); -#ifdef __HIP_PLATFORM_AMD__ - -size_t nvte_amax_workspace_num_blocks(size_t N); - -/*! \brief Compute an FP8 tensor's amax. +/*! \brief Compute an FP8 tensor's amax with quantization config. * * The amax (maximum absolute value) of the input tensor is computed - * and written to the amax buffer of the output tensor. + * and written to the amax buffer of the output tensor, using the provided + * quantization configuration. + * One useful config is the noop tensor, which is needed by cuda graph. * * \param[in] input Input tensor. Must be unquantized. * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. - * \param[out] workspace Output tensor. Must be FP32. + * \param[in] config Quantization configuration. * \param[in] stream CUDA stream used for the operation. */ -void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, NVTETensor workspace, cudaStream_t stream); +void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, + const NVTEQuantizationConfig config, cudaStream_t stream); -#endif +#ifdef __HIP_PLATFORM_AMD__ -/*! \brief Compute an FP8 tensor's amax with quantization config. +size_t nvte_amax_workspace_num_blocks(size_t N); + +/*! \brief Compute an FP8 tensor's amax. * * The amax (maximum absolute value) of the input tensor is computed - * and written to the amax buffer of the output tensor, using the provided - * quantization configuration. - * One useful config is the noop tensor, which is needed by cuda graph. + * and written to the amax buffer of the output tensor. * * \param[in] input Input tensor. Must be unquantized. * \param[in,out] output Output tensor. Must be an FP8 tensor with per-tensor scaling. + * \param[out] workspace Output tensor. Must be FP32. * \param[in] config Quantization configuration. * \param[in] stream CUDA stream used for the operation. */ -void nvte_compute_amax_with_config(const NVTETensor input, NVTETensor output, - const NVTEQuantizationConfig config, cudaStream_t stream); +void nvte_compute_amax_with_workspace(const NVTETensor input, NVTETensor output, + NVTETensor workspace, const NVTEQuantizationConfig config, + cudaStream_t stream); + +#endif /*! 
\brief Update an FP8 tensor's scale based on its amax. * diff --git a/transformer_engine/common/normalization/layernorm/ln_api.cpp b/transformer_engine/common/normalization/layernorm/ln_api.cpp index 55fcb99e0..8c6fccfb5 100644 --- a/transformer_engine/common/normalization/layernorm/ln_api.cpp +++ b/transformer_engine/common/normalization/layernorm/ln_api.cpp @@ -67,12 +67,12 @@ void layernorm_fwd(const Tensor& x, // BxSxhidden_size bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); -#endif //__HIP_PLATFORM_AMD__ if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, "cuDNN does not currently support amax output for non quantized output"); } +#endif //__HIP_PLATFORM_AMD__ bool gamma_in_weight_dtype = false; #ifndef __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu index 45c82dbb4..757e8f900 100644 --- a/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu +++ b/transformer_engine/common/normalization/layernorm/ln_bwd_semi_cuda_kernel.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2022-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -17,7 +17,7 @@ template void launch_ln_bwd_tuned_(LaunchParams &launch_params, - const bool configure_params) { // NOLINT(*)//PIV TODO: static + const bool configure_params) { // NOLINT(*) using Kernel_traits = Kernel_traits; auto kernel = &ln_bwd_tuned_kernel; diff --git a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp index 5ee5a4908..6c85cc432 100644 --- a/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp +++ b/transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp @@ -53,12 +53,12 @@ void rmsnorm_fwd(const Tensor &x, const Tensor &gamma, const float epsilon, Tens bool is_aligned = true; #ifndef __HIP_PLATFORM_AMD__ bool cudnn_backend = use_cudnn_norm_fwd() || is_mxfp_scaling(z->scaling_mode); -#endif if (!is_fp8_dtype(z->data.dtype) && z->amax.dptr != nullptr) { NVTE_CHECK(!cudnn_backend, "cuDNN does not currently support amax output for non quantized output"); } +#endif bool training = is_delayed_tensor_scaling(z->scaling_mode) || (z->columnwise_data).dptr != nullptr; @@ -216,9 +216,11 @@ void rmsnorm_bwd_add(const Tensor &dz, const Tensor &x, const Tensor &add, const // cuDNN does not currently support fused backward+add NVTE_Norm_Backend norm_backend = NVTE_Norm_Backend::Te; +#ifndef __HIP_PLATFORM_AMD__ // TE backend does not currently support zero_centered_gamma_in_weight_dtype NVTE_CHECK(!use_zero_centered_gamma_in_weight_dtype(), "zero_centered_gamma_in_weight_dtype is currently not supported for rmsnorm_bwd_add"); +#endif bool is_aligned = is_ptr_aligned(x.data.dptr, gamma.data.dptr, rsigma.data.dptr, dx->data.dptr, dz.data.dptr, dgamma->data.dptr, add.data.dptr); diff --git a/transformer_engine/common/recipe/current_scaling.cu b/transformer_engine/common/recipe/current_scaling.cu index d3c92732f..f3c6b7952 100644 --- 
a/transformer_engine/common/recipe/current_scaling.cu +++ b/transformer_engine/common/recipe/current_scaling.cu @@ -53,17 +53,17 @@ __global__ void amax_final_reduce(const float* __restrict__ block_amax, template __launch_bounds__(amax_kernel_threads) __global__ + void amax_kernel(const InputType *input, float *amax, #ifdef __HIP_PLATFORM_AMD__ - void amax_kernel(const InputType *input, float *amax, float* __restrict__ block_amax, const size_t N, - const size_t num_aligned_elements) { + float* __restrict__ block_amax, #else - void amax_kernel(const InputType *input, float *amax, const size_t N, - const size_t num_aligned_elements, const float *noop_ptr) { + [[maybe_unused]] void* __restrict__ block_amax, +#endif + const size_t N, const size_t num_aligned_elements, const float *noop_ptr) { if (noop_ptr != nullptr && noop_ptr[0] == 1.0f) { return; } -#endif //PIV TODO: noop_ptr for ROCm kernel VectorizedLoader loader(input, N); InputType max{0.f}; const int warp_id = threadIdx.x / THREADS_PER_WARP; @@ -112,9 +112,11 @@ __launch_bounds__(amax_kernel_threads) __global__ } template -void launch_amax_kernel(const InputType *input, float *amax, const size_t N, float *block_amax, - size_t block_capacity, const float *noop_ptr, - cudaStream_t stream) {//PIV TODO: CUDA vs ROCm differences +void launch_amax_kernel(const InputType *input, float *amax, const size_t N, +#ifdef __HIP_PLATFORM_AMD__ + float *block_amax, size_t block_capacity, +#endif + const float *noop_ptr, cudaStream_t stream) { // Zero out amax so we can update with atomic max NVTE_CHECK_CUDA(cudaMemsetAsync(amax, 0, sizeof(float), stream)); @@ -133,7 +135,7 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo size_t num_blocks = DIVUP(num_aligned_elements, threads); constexpr size_t max_blocks = 65535; num_blocks = std::min(num_blocks, max_blocks); - + constexpr void* block_amax = nullptr; #else constexpr size_t threads = amax_kernel_threads; size_t num_blocks = nvte_amax_workspace_num_blocks(num_aligned_elements); @@ -144,32 +146,18 @@ void launch_amax_kernel(const InputType *input, float *amax, const size_t N, flo // Launch kernel switch (align) { case Alignment::SAME_ALIGNED: -#ifdef __HIP_PLATFORM_AMD__ amax_kernel - <<>>(input, amax, block_amax, N, num_aligned_elements); -#else - amax_kernel - <<>>(input, amax, N, num_aligned_elements, noop_ptr); -#endif + <<>>(input, amax, block_amax, N, num_aligned_elements, noop_ptr); break; case Alignment::SAME_UNALIGNED: -#ifdef __HIP_PLATFORM_AMD__ amax_kernel - <<>>(input, amax, block_amax, N, num_aligned_elements); -#else - amax_kernel - <<>>(input, amax, N, num_aligned_elements, noop_ptr); -#endif + <<>>(input, amax, block_amax, N, num_aligned_elements, noop_ptr); break; case Alignment::DIFFERENT: { // This case is a logic error, since there is only one pointer (input) // in the alignment check. Still safe to process without vectorization. 
-#ifdef __HIP_PLATFORM_AMD__ - amax_kernel<1, true, InputType><<>>(input, amax, block_amax, N, N); -#else amax_kernel<1, true, InputType> - <<>>(input, amax, N, N, noop_ptr); -#endif + <<>>(input, amax, block_amax, N, N, noop_ptr); break; } } @@ -208,6 +196,11 @@ size_t nvte_amax_workspace_num_blocks(size_t N) { namespace { void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream, +#ifdef __HIP_PLATFORM_AMD__ + const NVTETensor workspace_, +#else + [[maybe_unused]] const NVTETensor workspace_, +#endif const NVTEQuantizationConfig config_) { using namespace transformer_engine; @@ -285,18 +278,21 @@ void compute_amax_impl(const NVTETensor input_, const NVTETensor output_, cudaSt void nvte_compute_amax(const NVTETensor input_, const NVTETensor output_, cudaStream_t stream) { NVTE_API_CALL(nvte_compute_amax); - compute_amax_impl(input_, output_, stream, nullptr); + compute_amax_impl(input_, output_, stream, nullptr, nullptr); } void nvte_compute_amax_with_config(const NVTETensor input_, const NVTETensor output_, const NVTEQuantizationConfig config_, cudaStream_t stream) { NVTE_API_CALL(nvte_compute_amax_with_config); - compute_amax_impl(input_, output_, stream, config_); + compute_amax_impl(input_, output_, stream, nullptr, config_); } #ifdef __HIP_PLATFORM_AMD__ -void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, const NVTETensor workspace_, cudaStream_t stream) { - compute_amax_impl(input_, output_, /*workspace=*/nullptr, stream); //PIV TODO: proper parameters +void nvte_compute_amax_with_workspace(const NVTETensor input_, const NVTETensor output_, + NVTETensor workspace_, const NVTEQuantizationConfig config_, + cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_amax_with_workspace); + compute_amax_impl(input_, output_, stream, workspace_, config_); } #endif diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 141a2e991..12cb4ea31 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -24,6 +24,7 @@ namespace { #define __ldg(x) (*(x)) #endif +#ifndef __HIP_PLATFORM_AMD__ constexpr __device__ __host__ int MXFP8_BLOCK_SIZE = 32; constexpr __device__ __host__ int TB_DIM = 32; constexpr __device__ __host__ int NEW_SF_TILE_DIM_K = 16; @@ -32,6 +33,18 @@ constexpr __device__ __host__ int N_SF_PER_TD_PER_TILE = 4; // output is in ~K-major interleaved blocks constexpr __device__ __host__ int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; constexpr __device__ __host__ int NEW_SF_TILE_DIM_M_I32 = 32; +#else +// HIPCC does not support __host__ qualifier for variables +// and constexpr values do not need __device__ qualifier because they are compile-time constants +constexpr int MXFP8_BLOCK_SIZE = 32; +constexpr int TB_DIM = 32; +constexpr int NEW_SF_TILE_DIM_K = 16; +constexpr int N_SF_PER_TD_PER_TILE = 4; + +// output is in ~K-major interleaved blocks +constexpr int NEW_SF_TILE_DIM_K_I32 = NEW_SF_TILE_DIM_K / 4; +constexpr int NEW_SF_TILE_DIM_M_I32 = 32; +#endif template __device__ inline void regs_shuffle_with_bit_shifts(LType* regs_vec) { diff --git a/transformer_engine/common/util/cast_gated_kernels.cuh b/transformer_engine/common/util/cast_gated_kernels.cuh index 6e02091a5..dcb3aa42d 100644 --- a/transformer_engine/common/util/cast_gated_kernels.cuh +++ b/transformer_engine/common/util/cast_gated_kernels.cuh @@ -1,6 +1,6 @@ /************************************************************************* * This file 
was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -988,6 +988,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK(output->columnwise_scale_inv.dptr != nullptr, "Scaling tensor must be allocated."); } +#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (USE_ROWWISE_SCALING && (!USE_COLWISE_SCALING)) { scaling_type = ScalingType::ROWWISE; @@ -996,11 +997,23 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out } else if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { scaling_type = ScalingType::BIDIMENSIONAL; } +#endif const size_t rows = gated_input.flat_first_dim(); const size_t cols = gated_input.flat_last_dim() / 2; const size_t output_cols = (IS_DGATED ? 2 : 1) * cols; +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t TMA_SHMEM_ALIGNMENT = ALIGNMENT_SIZE; + + constexpr size_t BUFF_DIM_Y = BUFFER_DIM_Y; + constexpr size_t BUFF_DIM_X = BUFFER_DIM_X; + constexpr size_t BUFFS_NUM = BUFFERS_NUM; + + const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); +#else + constexpr size_t BUFF_DIM_Y = mxfp8_kernel::BUFF_DIM_Y; constexpr size_t BUFF_DIM_X = mxfp8_kernel::BUFF_DIM_X; constexpr size_t BUFFS_NUM = mxfp8_kernel::BUFFS_NUM; @@ -1013,6 +1026,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t THREADS_PER_CHUNK = (scaling_type == ScalingType::COLWISE) ? THREADS_PER_CHUNK_COLWISE : THREADS_PER_CHUNK_NON_COLWISE; +#endif const dim3 grid(blocks_X, blocks_Y); const dim3 block_size(THREADS_PER_CHUNK); @@ -1029,16 +1043,18 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out gated_input.dtype(), IType, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, + #ifdef __HIP_PLATFORM_AMD__ - TRANSFORMER_ENGINE_SWITCH_CONDITION(//PIV TODO - !(cols % (32 * sizeof(IType))), IS_ALIGNED, - const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; - const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); - const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; - OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; - OType *tensor_map_output_gate_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) + cols : nullptr; - OType *tensor_map_output_act_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) : nullptr; - OType *tensor_map_output_gate_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; + const IType *tensor_map_grad = IS_DGATED ? reinterpret_cast(grad.data.dptr) : nullptr; + const IType *tensor_map_input_act = reinterpret_cast(gated_input.data.dptr); + const IType *tensor_map_input_gate = reinterpret_cast(gated_input.data.dptr) + cols; + OType *tensor_map_output_act_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) : nullptr; + OType *tensor_map_output_gate_rowwise = USE_ROWWISE_SCALING ? reinterpret_cast(output->data.dptr) + cols : nullptr; + OType *tensor_map_output_act_colwise = USE_COLWISE_SCALING ? 
reinterpret_cast(output->columnwise_data.dptr) : nullptr; + OType *tensor_map_output_gate_colwise = USE_COLWISE_SCALING ? reinterpret_cast(output->columnwise_data.dptr) + cols : nullptr; + + constexpr size_t input_type_bit_size = TypeInfo::size; + constexpr size_t output_type_bit_size = TypeInfo::size; #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_grad{}; alignas(64) CUtensorMap tensor_map_input_act{}; @@ -1095,12 +1111,38 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out const size_t in_mem = grad_mem + in_act_mem + in_gate_mem; const size_t out_act_mem = buff_size_aligned_out; +#ifdef __HIP_PLATFORM_AMD__ + const size_t out_gate_mem = buff_size_aligned_out; +#else const size_t out_gate_mem = (IS_DGATED ? buff_size_aligned_out : 0); +#endif size_t out_mem = out_act_mem + out_gate_mem; if (USE_ROWWISE_SCALING && USE_COLWISE_SCALING) { out_mem *= 2; } const size_t shmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT; +#ifdef __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_COLWISE_SCALING ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (USE_ROWWISE_SCALING ? 32 : 1), SCALE_DIM_X, + TRANSFORMER_ENGINE_SWITCH_CONDITION(!(cols % (32 * sizeof(IType))), IS_ALIGNED, { + NVTE_CHECK_CUDA(cudaFuncSetAttribute( + cast_mxfp8_gated_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); + + cast_mxfp8_gated_kernel + <<>>( + tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, + tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, + tensor_map_output_act_colwise, tensor_map_output_gate_colwise, + scales_rowwise_ptr, scales_colwise_ptr, rows, cols, scale_stride_rowwise, + scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + }))); // NOLINT(*) +#else switch (scaling_type) { case ScalingType::ROWWISE: NVTE_CHECK_CUDA(cudaFuncSetAttribute( @@ -1110,7 +1152,7 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); mxfp8_kernel::cast_mxfp8_gated_kernel//PIV TODO: is_aligned + true, false, THREADS_PER_CHUNK_NON_COLWISE> <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, @@ -1120,22 +1162,14 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out NVTE_CHECK_CUDA(cudaGetLastError()); break; case ScalingType::COLWISE: - NVTE_CHECK_CUDA(NVTE_CHECK_CUDA(cudaFuncSetAttribute( - mxfp8_kernel::cast_mxfp8_gated_kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size))); + THREADS_PER_CHUNK_COLWISE>, + cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size)); mxfp8_kernel::cast_mxfp8_gated_kernel + false, true, THREADS_PER_CHUNK_COLWISE> <<>>( tensor_map_grad, tensor_map_input_act, tensor_map_input_gate, tensor_map_output_act_rowwise, tensor_map_output_gate_rowwise, @@ -1161,11 +1195,10 @@ void cast_mxfp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *out scale_stride_colwise); NVTE_CHECK_CUDA(cudaGetLastError()); break; - }); // NOLINT(*) - ); // NOLINT(*) -#ifdef __HIP_PLATFORM_AMD__ - ); // NOLINT(*) + } #endif + ); // NOLINT(*) + ); // NOLINT(*) } template diff --git a/transformer_engine/common/util/cast_kernels.cuh b/transformer_engine/common/util/cast_kernels.cuh index 479f082f1..b7c4cf837 100644 --- a/transformer_engine/common/util/cast_kernels.cuh +++ b/transformer_engine/common/util/cast_kernels.cuh @@ -1,6 +1,6 @@ 
/************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -1014,8 +1014,10 @@ template has_data(); bool use_colwise_scaling = output->has_columnwise_data(); @@ -1034,6 +1036,11 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t rows = input.flat_first_dim(); const size_t cols = input.flat_last_dim(); +#ifdef __HIP_PLATFORM_AMD__ + constexpr size_t CHUNK_DIM_Y = MXFP8_CHUNK_DIM_Y; + constexpr size_t CHUNK_DIM_X = MXFP8_CHUNK_DIM_X; + constexpr size_t THREADS_PER_CHUNK = MXFP8_THREADS_PER_CHUNK; +#else constexpr bool CAST_DBIAS_ONLY = IS_DBIAS && (!IS_DACT) && (!IS_ACT); constexpr size_t CHUNK_DIM_Y = CAST_DBIAS_ONLY ? 128 : 64; @@ -1044,6 +1051,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, constexpr size_t THREADS_Y = THREADS_PER_CHUNK / THREADS_X; constexpr size_t BUFF_DIM_Y = THREADS_Y; constexpr size_t BUFF_DIM_X = CHUNK_DIM_X; +#endif const size_t blocks_Y = DIVUP(rows, CHUNK_DIM_Y); const size_t blocks_X = DIVUP(cols, CHUNK_DIM_X); @@ -1061,6 +1069,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const size_t dbias_rows = blocks_Y; const size_t dbias_cols = cols; +#ifndef __HIP_PLATFORM_AMD__ ScalingType scaling_type; if (use_rowwise_scaling && (!use_colwise_scaling)) { scaling_type = ScalingType::ROWWISE; @@ -1069,6 +1078,7 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, } else if (use_rowwise_scaling && use_colwise_scaling) { scaling_type = ScalingType::BIDIMENSIONAL; } +#endif if constexpr (IS_DBIAS) { NVTE_CHECK(dbias->data.dtype == input.dtype(), "DBias must have the same type as input."); @@ -1091,10 +1101,15 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY( output->dtype(), OType, #ifdef __HIP_PLATFORM_AMD__ + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_colwise_scaling ? 32 : 1), SCALE_DIM_Y, + TRANSFORMER_ENGINE_MX_SCALE_DIM_SWITCH( + (use_rowwise_scaling ? 32 : 1), SCALE_DIM_X, TRANSFORMER_ENGINE_SWITCH_CONDITION( !(cols % (32 * sizeof(IType))), IS_ALIGNED, - cast_mxfp8_2D_kernel<<>>( + cast_mxfp8_2D_kernel + <<>>( reinterpret_cast(input.data.dptr), (IS_DACT) ? 
reinterpret_cast(act_input->data.dptr) : nullptr, reinterpret_cast(output->data.dptr), @@ -1102,6 +1117,8 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, scales_rowwise_ptr, scales_colwise_ptr, reinterpret_cast(noop->data.dptr), workspace_ptr, amax_ptr, rows, cols, scale_stride_rowwise, scale_stride_colwise); + NVTE_CHECK_CUDA(cudaGetLastError()); + ))); // NOLINT(*) #else // #ifdef __HIP_PLATFORM_AMD__ alignas(64) CUtensorMap tensor_map_input{}; @@ -1202,9 +1219,6 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input, reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); }); // NOLINT(*) ); // NOLINT(*) -#ifdef __HIP_PLATFORM_AMD__ - ); // NOLINT(*) -#endif } namespace detail { diff --git a/transformer_engine/common/util/cuda_runtime.cpp b/transformer_engine/common/util/cuda_runtime.cpp index 896f09e50..6f3f117d4 100644 --- a/transformer_engine/common/util/cuda_runtime.cpp +++ b/transformer_engine/common/util/cuda_runtime.cpp @@ -157,6 +157,8 @@ bool supports_multicast(int device_id) { return false; #endif } +#endif // __HIP_PLATFORM_AMD__ + const std::string &include_directory(bool required) { static std::string path; @@ -169,16 +171,28 @@ const std::string &include_directory(bool required) { if (need_to_check_env) { // Search for CUDA headers in common paths using Path = std::filesystem::path; +#ifdef __HIP_PLATFORM_AMD__ + std::vector> search_paths = {{"ROCM_PATH", ""}, + {"HIP_PATH", ""}, + {"", "/opt/rocm"}}; +#else std::vector> search_paths = {{"NVTE_CUDA_INCLUDE_DIR", ""}, {"CUDA_HOME", ""}, {"CUDA_DIR", ""}, {"", string_path_cuda_include}, {"", "/usr/local/cuda"}}; +#endif for (auto &[env, p] : search_paths) { if (p.empty()) { p = getenv(env.c_str()); } if (!p.empty()) { +#ifdef __HIP_PLATFORM_AMD__ + if (file_exists(p / "include" / "hip" / "hip_runtime.h")) { + path = p / "include"; + break; + } +#else if (file_exists(p / "cuda_runtime.h")) { path = p; break; @@ -187,6 +201,7 @@ const std::string &include_directory(bool required) { path = p / "include"; break; } +#endif } } @@ -194,7 +209,11 @@ const std::string &include_directory(bool required) { if (path.empty() && required) { std::string message; message.reserve(2048); +#ifdef __HIP_PLATFORM_AMD__ + message += "Could not find hip/hip_runtime.h in"; +#else message += "Could not find cuda_runtime.h in"; +#endif bool is_first = true; for (const auto &[env, p] : search_paths) { message += is_first ? " " : ", "; @@ -209,11 +228,18 @@ const std::string &include_directory(bool required) { message += p; } } +#ifdef __HIP_PLATFORM_AMD__ + message += + (". " + "Specify path to ROCM headers with ROCM_PATH " + "or disable NVRTC support with NVTE_DISABLE_NVRTC=1."); +#else message += (". 
" "Specify path to CUDA Toolkit headers " "with NVTE_CUDA_INCLUDE_DIR " "or disable NVRTC support with NVTE_DISABLE_NVRTC=1."); +#endif NVTE_ERROR(message); } need_to_check_env = false; @@ -223,6 +249,7 @@ const std::string &include_directory(bool required) { return path; } +#ifndef __HIP_PLATFORM_AMD__ int cudart_version() { auto get_version = []() -> int { int version; diff --git a/transformer_engine/common/util/cuda_runtime.h b/transformer_engine/common/util/cuda_runtime.h index 58712c9d9..069981347 100644 --- a/transformer_engine/common/util/cuda_runtime.h +++ b/transformer_engine/common/util/cuda_runtime.h @@ -68,10 +68,11 @@ void stream_priority_range(int *low_priority, int *high_priority, int device_id * \return CUDA multicast support flag */ bool supports_multicast(int device_id = -1); +#endif -/* \brief Path to CUDA Toolkit headers +/* \brief Path to CUDA/ROCm Toolkit headers * - * The path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the + * On CUDA platform the path can be configured by setting NVTE_CUDA_INCLUDE_DIR in the * environment. Otherwise searches in common install paths. * * \param[in] required Whether to throw exception if not found @@ -80,6 +81,7 @@ bool supports_multicast(int device_id = -1); */ const std::string &include_directory(bool required = false); +#ifndef __HIP_PLATFORM_AMD__ /* \brief CUDA Runtime version number at run-time * * Versions may differ between compile-time and run-time. diff --git a/transformer_engine/common/util/logging.h b/transformer_engine/common/util/logging.h index 297972475..6ab5eb958 100644 --- a/transformer_engine/common/util/logging.h +++ b/transformer_engine/common/util/logging.h @@ -6,7 +6,8 @@ * See LICENSE for license information. ************************************************************************/ -#pragma once +#ifndef TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ +#define TRANSFORMER_ENGINE_COMMON_UTIL_LOGGING_H_ #include #ifdef __HIP_PLATFORM_AMD__ diff --git a/transformer_engine/common/util/ptx.cuh b/transformer_engine/common/util/ptx.cuh index 581de9f9f..7c38a337b 100644 --- a/transformer_engine/common/util/ptx.cuh +++ b/transformer_engine/common/util/ptx.cuh @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2022-2025, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -118,6 +120,9 @@ __device__ __forceinline__ float exp2f(e8m0_t biased_exp) { } __device__ __forceinline__ e8m0_t float_to_e8m0(float val) { +#ifdef __HIP_PLATFORM_AMD__ +#define __CUDA_ARCH_HAS_FEATURE__(x) 0 +#endif #if ((__CUDA_ARCH_HAS_FEATURE__(SM100_ALL)) || (__CUDA_ARCH_HAS_FEATURE__(SM101_ALL)) || \ (__CUDA_ARCH_HAS_FEATURE__(SM120_ALL))) uint16_t out; diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index b8fee6862..94e246e3f 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -1,23 +1,23 @@ /************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ #pragma once +#include #include #include -#include -#include - -#include -#include "../common.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" +#include "common.h" #include "math.h" -#include "../util/rocm_vectorized_2d.cuh" +#include "ptx.cuh" +#include "rocm_vectorized_2d.cuh" +#include "transformer_engine/activation.h" +#include "transformer_engine/cast.h" +#include "vectorized_pointwise.h" +#include "utils.cuh" namespace transformer_engine { namespace gated_kernels { @@ -170,6 +170,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh[shmem_idx]); float gate_elt = static_cast(in_gate_sh[shmem_idx]); + float after_act_elt; + float after_gate_elt; if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_idx]); @@ -185,20 +187,31 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) act_x = ActOP(x, {}); dact_x = DActOP(x, {}); } - after_dact_reg[stage] = dact_x * grad_elt * gate_elt; - after_dgate_reg[stage] = act_x * grad_elt; + after_act_elt = dact_x * grad_elt * gate_elt; + after_gate_elt = act_x * grad_elt; + after_dact_reg[stage] = after_act_elt; + after_dgate_reg[stage] = after_gate_elt; } else { - after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; + after_act_elt = ActOP(act_elt, {}) * gate_elt; + after_dact_reg[stage] = after_act_elt; + } + + // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 + if constexpr (!std::is_same_v) { + after_act_elt = static_cast(static_cast(after_act_elt)); + if constexpr (IS_DGATED) { + after_gate_elt = static_cast(static_cast(after_gate_elt)); + } } if constexpr (USE_ROWWISE_SCALING) { if constexpr (IS_DGATED) { // dgate - float amax = fabsf(after_dgate_reg[stage]); + float amax = fabsf(after_gate_elt); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); out_gate_rowwise_sh[shmem_idx] = static_cast(scale_reciprocal_X * after_dgate_reg[stage]); @@ -214,11 +227,11 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exponent_X; } } - float amax = fabsf(after_dact_reg[stage]); + float amax = fabsf(after_act_elt); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = - float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal_X = exp2f_rcp(biased_exponent_X); + ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal_X = ptx::exp2f_rcp(biased_exponent_X); out_act_rowwise_sh[shmem_idx] = static_cast(scale_reciprocal_X * after_dact_reg[stage]); @@ -237,10 +250,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (USE_COLWISE_SCALING) { __builtin_assume(thread_Y_mx_block_amax >= 0); __builtin_assume(thread_Y_mx_block_amax_gate >= 0); - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_act_elt)); if constexpr (IS_DGATED) { thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); + 
fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_gate_elt)); } } } @@ -273,8 +286,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); // Only single thread writes the computed scaling factor // Also assuming one iteration covers exactly 32 rows @@ -319,8 +332,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) } const e8m0_t biased_exponent = - float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); - const float scale_reciprocal = exp2f_rcp(biased_exponent); + ptx::float_to_e8m0(mx_block_Y_amax * Quantized_Limits::max_norm_rcp); + const float scale_reciprocal = ptx::exp2f_rcp(biased_exponent); // Only single thread writes the computed scaling factor // Also assuming one iteration covers exactly 32 rows diff --git a/transformer_engine/common/util/rocm_cast_kernels.cuh b/transformer_engine/common/util/rocm_cast_kernels.cuh index d62350e0a..52a77733a 100644 --- a/transformer_engine/common/util/rocm_cast_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_kernels.cuh @@ -1,26 +1,32 @@ /************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. See LICENSE for more information ************************************************************************/ #pragma once +#include #include #include -#include - -#include -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" +#include "common.h" #include "math.h" -#include "transformer_engine/transformer_engine.h" -#include "../util/rocm_vectorized_2d.cuh" +#include "ptx.cuh" +#include "rocm_vectorized_2d.cuh" +#include "transformer_engine/cast.h" +#include "transpose/cast_transpose.h" +#include "vectorized_pointwise.h" +#include "utils.cuh" namespace transformer_engine { +// Forward declaration, definition is in cast_kernels.cuh +template +void mxfp8_quantize(const Tensor &input, const Tensor *act_input, const Tensor *noop, + Tensor *output, Tensor *dbias, Tensor *workspace, cudaStream_t stream); + + constexpr size_t MXFP8_CHUNK_DIM_Y = 64; constexpr size_t MXFP8_CHUNK_DIM_X = 64; constexpr size_t MXFP8_CHUNKS_PER_BLOCK_Y = 1; @@ -209,6 +215,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) partial_dbias_rowwise[chunk_X].data.elt[j] += elt; } } + // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } in_compute[j] = elt; if (!out_of_bounds) { thread_amax = fmaxf(thread_amax, fabsf(elt)); @@ -221,7 +231,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) const float subwarp_amax = subwarp_reduce_max_broadcast(thread_amax); const e8m0_t biased_exponent = - float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation + ptx::float_to_e8m0(subwarp_amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 
1 : 0); // Normalization requires a +1 scale to avoid saturation // Only single thread writes the computed scaling factor if (tid_rowwise_X % THREADS_PER_SCALE_X_ROWWISE == 0) { @@ -234,7 +244,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exponent; } - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll for (int j = 0; j < ELEMS_PER_THREAD; j++) { @@ -268,6 +278,10 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) partial_dbias_colwise[chunk_X] += elt; } } + // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 + if constexpr (!std::is_same_v) { + elt = static_cast(static_cast(elt)); + } in_compute[i] = elt; if (!out_of_bounds) { amax = fmaxf(amax, fabsf(elt)); @@ -278,7 +292,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) __builtin_assume(amax >= 0); block_amax = fmaxf(block_amax, amax); - const e8m0_t biased_exponent = float_to_e8m0(amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation + const e8m0_t biased_exponent = ptx::float_to_e8m0(amax * Quantized_Limits::max_norm_rcp) + (IS_NORM ? 1 : 0); // Normalization requires a +1 scale to avoid saturation const int global_scales_offset_Y = scales_colwise_chunk_offset_Y + iter; const int global_scales_offset_X = scales_colwise_chunk_offset_X + tid_colwise_X; @@ -286,7 +300,7 @@ __global__ void __launch_bounds__(MXFP8_THREADS_PER_CHUNK) global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X; scales_colwise[scale_idx] = biased_exponent; - const float block_scale_inverse = exp2f_rcp(biased_exponent); + const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent); #pragma unroll for (int i = 0; i < SCALE_DIM_Y; i++) { out_colwise_sh[i][tid_colwise_X] = @@ -540,4 +554,5 @@ void fp8_quantize_rocm(const Tensor &input, const Tensor *act_input, const Tenso } } + } // namespace transformer_engine diff --git a/transformer_engine/common/util/rocm_dequantize_kernels.cuh b/transformer_engine/common/util/rocm_dequantize_kernels.cuh index ae5cb4bbd..398e4c0ad 100644 --- a/transformer_engine/common/util/rocm_dequantize_kernels.cuh +++ b/transformer_engine/common/util/rocm_dequantize_kernels.cuh @@ -1,26 +1,26 @@ /************************************************************************* - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * * License for AMD contributions = MIT. 
See LICENSE for more information ************************************************************************/ #pragma once +#include #include #include -#include - -#include #include -#include "../common.h" -#include "../transpose/cast_transpose.h" -#include "../util/vectorized_pointwise.h" -#include "../utils.cuh" +#include "common.h" #include "math.h" +#include "ptx.cuh" +#include "rocm_vectorized_2d.cuh" #include "transformer_engine/activation.h" +#include "transformer_engine/cast.h" +#include "transpose/cast_transpose.h" #include "transformer_engine/transpose.h" -#include "../util/rocm_vectorized_2d.cuh" +#include "utils.cuh" +#include "vectorized_pointwise.h" namespace transformer_engine { @@ -102,7 +102,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) const int scale_idx = scale_offset_Y * scales_stride + scale_offset_X; const e8m0_t biased_exponent = scales_ptr[scale_idx]; - const float block_scale = exp2f(static_cast(biased_exponent) - FP32_EXPONENT_BIAS); + const float block_scale = ptx::exp2f(biased_exponent); if constexpr (USE_ROWWISE_SCALING) { Vec in; diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index 054531169..82d0d9048 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -175,8 +175,8 @@ void KernelManager::compile(const std::string& kernel_label, const std::string& } else { opts.push_back(concat_strings("--gpu-architecture=sm_", compile_sm_arch)); } - opts.push_back(concat_strings("-I", cuda::include_directory(true))); #endif //__HIP_PLATFORM_AMD__ + opts.push_back(concat_strings("-I", cuda::include_directory(true))); std::vector opts_ptrs; for (const auto& opt : opts) { diff --git a/transformer_engine/common/utils.cuh b/transformer_engine/common/utils.cuh index 758db6171..4208f3511 100644 --- a/transformer_engine/common/utils.cuh +++ b/transformer_engine/common/utils.cuh @@ -1041,7 +1041,6 @@ struct Quantized_Limits { static constexpr float max_norm_rcp = 1.0 / max_norm; #endif // TE_DYNAMIC_HIP_FP8_TYPE }; -//PIV TODO: code moved to ptx.cuh } // namespace transformer_engine diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 9701955cd..4ba581c66 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
diff --git a/transformer_engine/pytorch/csrc/extensions/activation.cpp b/transformer_engine/pytorch/csrc/extensions/activation.cpp index 2a41d639f..8b0607c9e 100644 --- a/transformer_engine/pytorch/csrc/extensions/activation.cpp +++ b/transformer_engine/pytorch/csrc/extensions/activation.cpp @@ -12,7 +12,7 @@ namespace transformer_engine::pytorch { -template //PIV TODO: amax moved somewhere else +template py::object activation_helper(const at::Tensor& input, py::handle quantizer, int shape_divisor = 1) { init_extension(); diff --git a/transformer_engine/pytorch/csrc/extensions/bias.cpp b/transformer_engine/pytorch/csrc/extensions/bias.cpp index 5e82b591c..f65614d07 100644 --- a/transformer_engine/pytorch/csrc/extensions/bias.cpp +++ b/transformer_engine/pytorch/csrc/extensions/bias.cpp @@ -92,7 +92,7 @@ std::vector bgrad_quantize(const at::Tensor &grad_output, py::handle namespace { -std::vector dact_dbias(//PIV TODO: amax +std::vector dact_dbias( void (*dact_dbias_func)(const NVTETensor, const NVTETensor, NVTETensor, NVTETensor, NVTETensor, cudaStream_t), void (*dact_func)(const NVTETensor, const NVTETensor, NVTETensor, cudaStream_t), diff --git a/transformer_engine/pytorch/csrc/extensions/dropout.cpp b/transformer_engine/pytorch/csrc/extensions/dropout.cpp index e6f29d0da..d009cf2f3 100644 --- a/transformer_engine/pytorch/csrc/extensions/dropout.cpp +++ b/transformer_engine/pytorch/csrc/extensions/dropout.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -29,7 +31,8 @@ std::vector dropout_fwd(const py::handle &input, float dropout_proba // Allocate output tensor if needed if (!out) { at::ScalarType dtype = GetATenDType(input_nvte.dtype()); - if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2) { + if (dtype == at::kFloat8_e4m3fn || dtype == at::kFloat8_e5m2 || + dtype == at::kFloat8_e4m3fnuz || dtype == at::kFloat8_e5m2fnuz) { dtype = input.attr("dtype").cast(); } const auto shape_uint64 = convertShape(input_nvte.shape()); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 05893bf0b..55b1d179e 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2023-2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2023-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -428,12 +428,16 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::multi_tensor_compute_scale_and_scale_inv_cuda, "Fused compute scale and scale_inv from amax", py::call_guard()); +#ifndef USE_ROCM // Comm+GEMM Overlap m.def("bulk_overlap_ag_with_external_gemm", &transformer_engine::pytorch::bulk_overlap_ag_with_external_gemm, "Bulk overlap All-Gather with a GEMM operation launched by another communicator", py::call_guard(), py::arg("allgather_communicator"), py::arg("send_stream"), py::arg("recv_stream")); +#else + m.def("bulk_overlap_ag_with_external_gemm", &transformer_engine::pytorch::placeholder, "Dummy function for python side annotations"); +#endif // Data structures py::class_(m, "FP8TensorMeta") diff --git a/transformer_engine/pytorch/csrc/extensions/recipe.cpp b/transformer_engine/pytorch/csrc/extensions/recipe.cpp index 0ede63d77..10a889b56 100644 --- a/transformer_engine/pytorch/csrc/extensions/recipe.cpp +++ b/transformer_engine/pytorch/csrc/extensions/recipe.cpp @@ -31,7 +31,7 @@ void compute_amax(const at::Tensor& tensor, at::Tensor& amax) { at::Tensor ws = allocate_amax_workspace(te_input); TensorWrapper tw = makeTransformerEngineTensor(ws); nvte_compute_amax_with_workspace(te_input.data(), fake_te_output.data(), - tw.data(), + tw.data(), nullptr, at::cuda::getCurrentCUDAStream()); #else nvte_compute_amax(te_input.data(), fake_te_output.data(), at::cuda::getCurrentCUDAStream()); diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index cd7e70fec..37c13362c 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. @@ -497,8 +499,16 @@ void Float8CurrentScalingQuantizer::quantize_impl(const TensorWrapper& input, Te // Compute amax if (compute_amax) { +#ifdef __HIP_PLATFORM_AMD__ + at::Tensor ws = allocate_amax_workspace(input); + TensorWrapper tw = makeTransformerEngineTensor(ws); + NVTE_SCOPED_GIL_RELEASE({ + nvte_compute_amax_with_workspace(input.data(), out.data(), tw.data(), quant_config, stream); + }); +#else NVTE_SCOPED_GIL_RELEASE( { nvte_compute_amax_with_config(input.data(), out.data(), quant_config, stream); }); +#endif } // Perform amax reduction if needed diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index ea0f8d25b..d1aeebc9f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -67,6 +67,7 @@ ) from ...debug.pytorch.debug_state import TEDebugState from ..tensor.float8_blockwise_tensor import Float8BlockQuantizer +from ..tensor.float8_tensor import Float8CurrentScalingQuantizer from ..tensor.mxfp8_tensor import MXFP8Quantizer from ..tensor._internal.mxfp8_tensor_base import MXFP8TensorBase from ..tensor._internal.float8_blockwise_tensor_base import Float8BlockwiseQTensorBase diff --git a/transformer_engine/pytorch/triton_kernels/cast.py b/transformer_engine/pytorch/triton_kernels/cast.py index 3ba81118f..f0f8563f1 100644 --- a/transformer_engine/pytorch/triton_kernels/cast.py +++ b/transformer_engine/pytorch/triton_kernels/cast.py @@ -1,4 +1,4 @@ -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # License for AMD contributions = MIT. See LICENSE for more information """Python interface for cast extensions""" @@ -118,6 +118,7 @@ def te_quantize_triton( ) else: + out.remove_caches() #Make sure to remove transpose if it is marked as invalid out = tex.quantize(input_tensor, quantizer, out, noop_flag) elif isinstance(out, MXFP8TensorBase): te_cast_transpose_mxfp8_triton(input_tensor, out) From 92ce3757dc75e6261643d9bf339f8ac9a5f96dea Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Tue, 13 Jan 2026 00:22:51 -0500 Subject: [PATCH 149/153] Fix JAX and Pytorch UT; code cleanup; ROCm 7.2 w/a (#404) --- ci/jax.sh | 5 ++++- tests/pytorch/attention/test_attention.py | 6 +----- tests/pytorch/attention/test_attention_with_cp.py | 5 ++++- tests/pytorch/attention/test_kv_cache.py | 6 ++++++ tests/pytorch/test_numerics.py | 15 +++++++++++++++ transformer_engine/pytorch/tensor/mxfp8_tensor.py | 8 ++++++-- 6 files changed, 36 insertions(+), 9 deletions(-) diff --git a/ci/jax.sh b/ci/jax.sh index 0f2aef2e8..b72a72e52 100755 --- a/ci/jax.sh +++ b/ci/jax.sh @@ -1,5 +1,5 @@ #!/bin/sh -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # # See LICENSE for license information. 
@@ -54,6 +54,7 @@ run_default_fa_lbl() { run_test_config() { echo ==== Run with Fused attention backend: $_fus_attn ==== + export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests run_default_fa 1 test_custom_call_compute.py run_default_fa 1 test_functions.py run 1 test_fused_attn.py @@ -75,8 +76,10 @@ run_test_config_mgpu() { if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then _dfa_level=2 + export NVTE_JAX_UNITTEST_LEVEL=L1 else _dfa_level=3 + export NVTE_JAX_UNITTEST_LEVEL=L2 fi run $_dfa_level test_distributed_fused_attn.py $_timeout_args run_default_fa 3 test_distributed_layernorm.py diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index d2fe65604..07ada9c3c 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -193,7 +193,7 @@ def test_dot_product_attention( config.window_size = [2, 2] config.window_size = check_set_window_size(config.attn_mask_type, config.window_size) - is_training = True #PIV TODO: config.head_dim_qk <= 192 and config.head_dim_v <= 128 + is_training = True available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, @@ -375,10 +375,6 @@ def test_dpa_checkpoint(dtype, model_configs, model): "mla_3_2": ModelConfig(8, 1, 16, 192, max_seqlen_kv=2048, head_dim_v=128), # inference "mla_3_3": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=128), # inference "mla_3_4": ModelConfig(8, 1, 16, 160, max_seqlen_kv=2048, head_dim_v=160), # inference - #"mla_4_0": ModelConfig(#PIV TODO: do cross 0 and cross 1 cover it - # 10, 4096, 16, 192, max_seqlen_kv=4096, attn_mask_type="causal", head_dim_v=128 - #), - #"mla_4_1": ModelConfig(10, 4096, 16, 192, max_seqlen_kv=4096, head_dim_v=128), } diff --git a/tests/pytorch/attention/test_attention_with_cp.py b/tests/pytorch/attention/test_attention_with_cp.py index 093ccbcac..ece5a37de 100644 --- a/tests/pytorch/attention/test_attention_with_cp.py +++ b/tests/pytorch/attention/test_attention_with_cp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -92,6 +92,9 @@ def test_cp_with_flash_attention(dtype, model, qkv_format, cp_comm_type): ) if "p2p" not in cp_comm_type and config.head_dim_qk != config.head_dim_v: pytest.skip("MLA CP currently only support KV P2P!") + if IS_HIP_EXTENSION: + if config.head_dim_qk != config.head_dim_v and not FlashAttentionUtils.v3_is_installed: + pytest.skip("MLA FlashAttention requires v3+!") subprocess.run( get_bash_arguments( diff --git a/tests/pytorch/attention/test_kv_cache.py b/tests/pytorch/attention/test_kv_cache.py index bded2a82a..af71866f3 100644 --- a/tests/pytorch/attention/test_kv_cache.py +++ b/tests/pytorch/attention/test_kv_cache.py @@ -386,6 +386,12 @@ def get_tols(config, module, backend, dtype): torch.half: (1e-2, 1e-2), torch.bfloat16: (8e-2, 7e-2), } + # With FA on ROCm it may not fit default tolerance + if IS_HIP_EXTENSION and backend == "FlashAttention": + tols = { + torch.half: (1e-2, 1e-2), + torch.bfloat16: (1e-1, 1e-1), + } if module == "DotProductAttention": tols = { torch.half: (1e-3, 1e-3), diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 9df465511..8e3731198 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -656,6 +656,9 @@ def _test_e2e_selective_recompute( def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and + dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): + pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") config = model_configs[model] @@ -775,6 +778,8 @@ def test_gpt_full_activation_recompute( and recipe.float8_per_tensor_scaling() ): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") + if (dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): + pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow @@ -926,6 +931,10 @@ def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] if not is_fused_attn_available(config, dtype, deterministic=True): pytest.skip("No attention backend available.") + if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and + dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): + pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") + outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -2685,6 +2694,9 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") + if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and + dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): + pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") config = model_configs[model] @@ -2972,6 +2984,9 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua 
pytest.skip(reason_for_no_fp8) if is_mxfp8_needed and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) + if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): + if isinstance(out_quantizer, Float8Quantizer): + pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) weight_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) outp_type = torch.float32 diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index 8d40f4403..16b1568cb 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -1,3 +1,5 @@ +# This file was modified for portability to AMDGPU +# Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -110,7 +112,8 @@ def make_empty( # Allocate FP8 data data = torch.empty(shape, dtype=torch.uint8, device=device) - scale_inv = torch.empty( + # ROCm TE does not implement fuse padding zeros so use zero tensor here + scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]), 128), round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4), dtype=torch.uint8, @@ -122,7 +125,8 @@ def make_empty( columnwise_scale_inv = None if self.columnwise_usage: columnwise_data = torch.empty_like(data) - columnwise_scale_inv = torch.empty( + # ROCm TE does not implement fuse padding zeros so use zero tensor here + columnwise_scale_inv = torch.zeros( round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4), round_up_to_nearest_multiple(shape[-1], 128), dtype=torch.uint8, From a406914a9d17889fcc64d38e2cfed563a2b1a7b5 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Thu, 15 Jan 2026 17:42:37 -0500 Subject: [PATCH 150/153] Review comments --- tests/cpp/operator/test_cast_mxfp8.cu | 10 +++++++ tests/jax/test_distributed_layernorm_mlp.py | 11 ++++++-- tests/jax/test_distributed_softmax.py | 2 -- tests/pytorch/attention/test_attention.py | 29 ++++++-------------- tests/pytorch/test_numerics.py | 27 +++--------------- tests/pytorch/utils.py | 26 ++++++++++++++++-- transformer_engine/common/swizzle/swizzle.cu | 2 +- 7 files changed, 54 insertions(+), 53 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8.cu b/tests/cpp/operator/test_cast_mxfp8.cu index 9e4e12bf8..b635dc00b 100644 --- a/tests/cpp/operator/test_cast_mxfp8.cu +++ b/tests/cpp/operator/test_cast_mxfp8.cu @@ -311,8 +311,13 @@ void performTest_x1(const ProcessingMethod processing_method, : output_c.columnwise_cpu_scale_inv_ptr(); const size_t scale_diff_abs_tolerance = 0; +#ifdef __HIP_PLATFORM_AMD__ const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; +#else + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; +#endif std::vector mismatches_scales_indices; size_t mismatches_scales = 0; @@ -491,8 +496,13 @@ void performTest_x2(const ProcessingMethod processing_method, scales_stride_colwise); const size_t scale_diff_abs_tolerance = 0; +#ifdef __HIP_PLATFORM_AMD__ const double abs_tolerable_mismatches_limit = 1.0; const double rel_tolerable_mismatches_limit = 1.0e-4; +#else + const double abs_tolerable_mismatches_limit = 0.0; + const double rel_tolerable_mismatches_limit = 0.0; 
+#endif std::vector mismatches_scales_indices_rowwise; size_t mismatches_scales_rowwise = 0; diff --git a/tests/jax/test_distributed_layernorm_mlp.py b/tests/jax/test_distributed_layernorm_mlp.py index 8d11cbaed..0af10d050 100644 --- a/tests/jax/test_distributed_layernorm_mlp.py +++ b/tests/jax/test_distributed_layernorm_mlp.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. @@ -38,7 +38,11 @@ from transformer_engine.jax.quantize import QuantizerFactory from transformer_engine.jax.cpp_extensions.misc import get_min_device_compute_capability -from transformer_engine.jax.util import get_jnp_float8_e4m3_type, get_jnp_float8_e5m2_type +from transformer_engine.jax.util import ( + is_hip_extension, + get_jnp_float8_e4m3_type, + get_jnp_float8_e5m2_type, +) jnp_float8_e4m3_type = get_jnp_float8_e4m3_type() jnp_float8_e5m2_type = get_jnp_float8_e5m2_type() @@ -232,7 +236,8 @@ def _test_layernorm_mlp_grad( multi_fwd, multi_grads = multi_jitter(*multi_inputs, *static_inputs, True) # TODO: skip cases with single fwd as nan/inf - if jnp.any(jnp.isnan(single_fwd)) or jnp.any(jnp.isinf(single_fwd)): + if is_hip_extension() and (jnp.any(jnp.isnan(single_fwd)) or + jnp.any(jnp.isinf(single_fwd))): pytest.skip("skip tests with nan/inf single fwd.") fwd_test_type = dtype if fp8_recipe is None else jnp_float8_e4m3_type diff --git a/tests/jax/test_distributed_softmax.py b/tests/jax/test_distributed_softmax.py index d3a82b2d3..d9eaf314a 100644 --- a/tests/jax/test_distributed_softmax.py +++ b/tests/jax/test_distributed_softmax.py @@ -5,7 +5,6 @@ import warnings from functools import partial import pytest -from packaging import version import jax import jax.numpy as jnp @@ -135,7 +134,6 @@ def impl_test_softmax( f"{str(w)}" ) - @pytest.mark.skipif(version.parse(jax.__version__) < version.parse("0.5.0"), reason="shardy sharding requires JAX 0.5.0") @pytest.mark.parametrize("device_count,mesh_shape,mesh_axes,mesh_resource", generate_configs()) @pytest.mark.parametrize("data_shape", [[32, 12, 128, 128], [8, 8, 1024, 1024]]) @pytest.mark.parametrize( diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 07ada9c3c..3e4867508 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -73,27 +73,14 @@ def reset_global_fp8_state(): fp8.FP8GlobalStateManager.reset() -class EnvVarCleaner: - def __init__(self, envs_): - self.envs = envs_ - self.flags = {} - for env in self.envs: - if env in os.environ: - self.flags[env] = os.environ[env] - def __del__(self): - for env in self.envs: - if env in self.flags: - os.environ[env] = self.flags[env] - else: - os.environ.pop(env, None) - - -@pytest.fixture(autouse=True) -def reset_attn_backend(): - env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", - "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", - "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3"]) - yield +if IS_HIP_EXTENSION: + from utils import EnvVarCleaner + @pytest.fixture(autouse=True) + def reset_attn_backend(): + env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_CK", "NVTE_FUSED_ATTN_AOTRITON", + "NVTE_CK_USES_FWD_V3", "NVTE_CK_USES_BWD_V3", "NVTE_FP8_DPA_BWD"]) + 
yield model_configs_base = { diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 8e3731198..77b82ec0b 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -56,6 +56,8 @@ from transformer_engine.common import recipe import transformer_engine_torch as tex from utils import ModelConfig, reset_rng_states, get_available_attention_backends +if IS_HIP_EXTENSION: + from utils import EnvVarCleaner # Only run FP8 tests on supported devices. @@ -236,28 +238,6 @@ def reset_global_fp8_state(): FP8GlobalStateManager.reset() -class EnvVarCleaner: - def __init__(self, envs_): - self.envs = envs_ - self.flags = {} - for env in self.envs: - if env in os.environ: - self.flags[env] = os.environ[env] - def __del__(self): - for env in self.envs: - if env in self.flags: - os.environ[env] = self.flags[env] - else: - os.environ.pop(env, None) - - -@pytest.fixture -def reset_test_envs(): - env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", - "NVTE_BIAS_GELU_NVFUSION"]) - yield - - class TorchScaledMaskedSoftmax(nn.Module): def __init__(self) -> None: super().__init__() @@ -765,7 +745,6 @@ def _test_e2e_full_recompute( @pytest.mark.parametrize("recipe", fp8_recipes) @pytest.mark.parametrize("fp8_model_params", all_boolean) @pytest.mark.parametrize("use_reentrant", all_boolean) -@pytest.mark.usefixtures("reset_test_envs") def test_gpt_full_activation_recompute( dtype, bs, model, fp8, recipe, fp8_model_params, use_reentrant ): @@ -785,6 +764,8 @@ def test_gpt_full_activation_recompute( torch.compiler.reset() # avoid cache size limit overflow if not use_reentrant: + if IS_HIP_EXTENSION: + env = EnvVarCleaner(["NVTE_BIAS_GELU_NVFUSION"]) # Non-reentrant checkpoint becomes non-deterministic with bias+GELU fusion os.environ["NVTE_BIAS_GELU_NVFUSION"] = "0" diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index f1f443bdc..5674535fa 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
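The EnvVarCleaner helper added to tests/pytorch/utils.py in the next hunk snapshots a list of environment variables on construction and, when it is garbage-collected, either restores the saved values or pops variables that were not previously set, so tests that flip NVTE_* backend switches leave the environment as they found it. The same save-and-restore idea written as a context manager, purely as an illustrative sketch and not what the patch adds:

    import os
    from contextlib import contextmanager

    @contextmanager
    def preserve_env(names):
        # Snapshot only the variables the caller intends to touch.
        saved = {n: os.environ[n] for n in names if n in os.environ}
        try:
            yield
        finally:
            for n in names:
                if n in saved:
                    os.environ[n] = saved[n]
                else:
                    os.environ.pop(n, None)

    # Anything set inside the block is undone on exit.
    with preserve_env(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN"]):
        os.environ["NVTE_FLASH_ATTN"] = "1"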
@@ -200,6 +200,25 @@ def logging_context(highest_level=logging.WARNING): logging.disable(previous_level) +if IS_HIP_EXTENSION: + class EnvVarCleaner: + def __init__(self, envs_): + print("PIV create envs:", envs_) + self.envs = envs_ + self.flags = {} + for env in self.envs: + if env in os.environ: + self.flags[env] = os.environ[env] + + def __del__(self): + print("PIV destroty envs:", self.envs) + for env in self.envs: + if env in self.flags: + os.environ[env] = self.flags[env] + else: + os.environ.pop(env, None) + + def get_available_attention_backends( config: ModelConfig, qkv_dtype: torch.dtype, @@ -214,6 +233,9 @@ def get_available_attention_backends( inference_params: Optional[InferenceParams] = None, ) -> Tuple[List, List]: """Check for all available attention backends that support a model configuration""" + if IS_HIP_EXTENSION: + env = EnvVarCleaner(["NVTE_FLASH_ATTN", "NVTE_FUSED_ATTN", "NVTE_UNFUSED_ATTN", + "NVTE_FUSED_ATTN_AOTRITON", "NVTE_FUSED_ATTN_CK"]) os.environ["NVTE_FLASH_ATTN"] = "1" os.environ["NVTE_FUSED_ATTN"] = "1" @@ -301,8 +323,6 @@ def test(): available_backends, flash_attention_backend, fused_attention_backend = test() if fused_attention_backend == FusedAttnBackend[i]: fused_attn_backends.append(fused_attention_backend) - for i in backends.keys(): - del os.environ["NVTE_FUSED_ATTN_"+backends[i]] available_backends[1] = len(fused_attn_backends) > 0 return available_backends, flash_attention_backend, fused_attn_backends diff --git a/transformer_engine/common/swizzle/swizzle.cu b/transformer_engine/common/swizzle/swizzle.cu index 12cb4ea31..499f7bcff 100644 --- a/transformer_engine/common/swizzle/swizzle.cu +++ b/transformer_engine/common/swizzle/swizzle.cu @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. From cfe3fc3400f8cbcfe1943b54467523b699b834f9 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Fri, 16 Jan 2026 21:37:58 -0500 Subject: [PATCH 151/153] Remove not needed intermetdiate var in cast kernel. Update tests. Disable Pytorch MXFP8 scale swizzling --- tests/pytorch/attention/test_attention.py | 5 ++-- tests/pytorch/test_numerics.py | 15 +++--------- .../common/util/rocm_cast_gated_kernels.cuh | 23 ++++++++----------- .../pytorch/csrc/extensions/gemm.cpp | 8 ++++++- transformer_engine/pytorch/csrc/util.cpp | 6 +++++ transformer_engine/pytorch/csrc/util.h | 6 +++++ 6 files changed, 34 insertions(+), 29 deletions(-) diff --git a/tests/pytorch/attention/test_attention.py b/tests/pytorch/attention/test_attention.py index 9596a7000..a5128653e 100644 --- a/tests/pytorch/attention/test_attention.py +++ b/tests/pytorch/attention/test_attention.py @@ -113,10 +113,11 @@ def test_gqa_mla_thd(): Explicitly test dk_or_dv_reduce_thd as part of TE's CK integration post-processing for BWD FA with native padding support. 
""" - config = ModelConfig(8, 16, 4, 128, 128, 128, 0.0, "padding", "no_bias", head_dim_v=64) + # b, sq, h, dqk + config = ModelConfig(8, 128, 16, 128, num_gqa_groups= 4, head_dim_v=64, attn_mask_type="padding") qkv_layout = "thd_thd_thd" dtype = torch.float16 - _, _, fused_attn_backends = _get_attention_backends( + _, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/tests/pytorch/test_numerics.py b/tests/pytorch/test_numerics.py index 77b82ec0b..a4dfd64ba 100644 --- a/tests/pytorch/test_numerics.py +++ b/tests/pytorch/test_numerics.py @@ -636,9 +636,6 @@ def _test_e2e_selective_recompute( def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params): if fp8_model_params and NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and - dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): - pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") config = model_configs[model] @@ -757,8 +754,6 @@ def test_gpt_full_activation_recompute( and recipe.float8_per_tensor_scaling() ): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") - if (dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): - pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") config = model_configs[model] torch.compiler.reset() # avoid cache size limit overflow @@ -912,9 +907,6 @@ def test_gpt_checkpointing(dtype, bs, model): config = model_configs[model] if not is_fused_attn_available(config, dtype, deterministic=True): pytest.skip("No attention backend available.") - if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and - dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): - pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") outputs = _test_e2e_checkpointing(bs, dtype, config, checkpoint=False) outputs_checkpoint = _test_e2e_checkpointing(bs, dtype, config, checkpoint=True) @@ -2675,9 +2667,6 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe): def test_gpt_fp8_parameters(dtype, bs, model, recipe): if NVTE_TEST_NVINSPECT_ENABLED: pytest.skip("FP8 parameters are not supported in debug mode.") - if (IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5) and - dtype in (torch.float16, torch.bfloat16) and rocm_attn_backend()[2]): - pytest.skip("Test is not supported on GFX950 with current parameters and CK fused attention backend and non-zero dropout.") config = model_configs[model] @@ -2966,7 +2955,9 @@ def test_fp8gemm_with_unfused_quantization(N, datatype, input_quantizer, out_qua if is_mxfp8_needed and not mxfp8_available: pytest.skip(reason_for_no_mxfp8) if IS_HIP_EXTENSION and get_device_compute_capability() == (9, 5): - if isinstance(out_quantizer, Float8Quantizer): + if isinstance(input_quantizer, MXFP8Quantizer): + N = math.ceil(N / 128) * 128 #hipBlasLt supports K which is multiple of 128 for MXFP8 + if not is_mxfp8_needed and isinstance(out_quantizer, Float8Quantizer): pytest.skip("hipBLASLt does not provide suitable algorithms on GFX950 for this config.") inp_fp8 = input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) weight_fp8 = 
input_quantizer(torch.randn(N, N, device="cuda", dtype=datatype)) diff --git a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh index 94e246e3f..a53fd51c5 100644 --- a/transformer_engine/common/util/rocm_cast_gated_kernels.cuh +++ b/transformer_engine/common/util/rocm_cast_gated_kernels.cuh @@ -170,8 +170,6 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) float act_elt = static_cast(in_act_sh[shmem_idx]); float gate_elt = static_cast(in_gate_sh[shmem_idx]); - float after_act_elt; - float after_gate_elt; if constexpr (IS_DGATED) { float grad_elt = static_cast(in_grad_sh[shmem_idx]); @@ -187,27 +185,24 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) act_x = ActOP(x, {}); dact_x = DActOP(x, {}); } - after_act_elt = dact_x * grad_elt * gate_elt; - after_gate_elt = act_x * grad_elt; - after_dact_reg[stage] = after_act_elt; - after_dgate_reg[stage] = after_gate_elt; + after_dact_reg[stage] = dact_x * grad_elt * gate_elt; + after_dgate_reg[stage] = act_x * grad_elt; } else { - after_act_elt = ActOP(act_elt, {}) * gate_elt; - after_dact_reg[stage] = after_act_elt; + after_dact_reg[stage] = ActOP(act_elt, {}) * gate_elt; } // Numerical truncation: downcast to IType (BF16/FP16) and upcast back to FP32 if constexpr (!std::is_same_v) { - after_act_elt = static_cast(static_cast(after_act_elt)); + after_dact_reg[stage] = static_cast(static_cast(after_dact_reg[stage])); if constexpr (IS_DGATED) { - after_gate_elt = static_cast(static_cast(after_gate_elt)); + after_dgate_reg[stage] = static_cast(static_cast(after_dgate_reg[stage])); } } if constexpr (USE_ROWWISE_SCALING) { if constexpr (IS_DGATED) { // dgate - float amax = fabsf(after_gate_elt); + float amax = fabsf(after_dgate_reg[stage]); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); @@ -227,7 +222,7 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) scales_rowwise[scale_idx] = biased_exponent_X; } } - float amax = fabsf(after_act_elt); + float amax = fabsf(after_dact_reg[stage]); const float mx_block_X_amax = warp_reduce_max_broadcast(amax); const e8m0_t biased_exponent_X = ptx::float_to_e8m0(mx_block_X_amax * Quantized_Limits::max_norm_rcp); @@ -250,10 +245,10 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) if constexpr (USE_COLWISE_SCALING) { __builtin_assume(thread_Y_mx_block_amax >= 0); __builtin_assume(thread_Y_mx_block_amax_gate >= 0); - thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_act_elt)); + thread_Y_mx_block_amax = fmaxf(thread_Y_mx_block_amax, fabsf(after_dact_reg[stage])); if constexpr (IS_DGATED) { thread_Y_mx_block_amax_gate = - fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_gate_elt)); + fmaxf(thread_Y_mx_block_amax_gate, fabsf(after_dgate_reg[stage])); } } } diff --git a/transformer_engine/pytorch/csrc/extensions/gemm.cpp b/transformer_engine/pytorch/csrc/extensions/gemm.cpp index c4af3e9db..b637d49c7 100644 --- a/transformer_engine/pytorch/csrc/extensions/gemm.cpp +++ b/transformer_engine/pytorch/csrc/extensions/gemm.cpp @@ -1,6 +1,6 @@ /************************************************************************* * This file was modified for portability to AMDGPU - * Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. + * Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* * See LICENSE for license information. @@ -215,14 +215,18 @@ std::vector gemm(py::handle A, bool transa, py::handle B, bool trans const int sm_count = transformer_engine::cuda::sm_count(device_id); int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); +#ifndef USE_ROCM // Keep the swizzled scaling factor tensors alive during the GEMM. std::vector> swizzled_scale_inverses_list; +#endif auto main_stream = at::cuda::getCurrentCUDAStream(); if (A_tensor.numel() != 0 && B_tensor.numel() != 0) { +#ifndef USE_ROCM // Optionally swizzle the scaling factors swizzled_scale_inverses_list.emplace_back(std::move(swizzle_scaling_factors(A_tensor, transa))); swizzled_scale_inverses_list.emplace_back( std::move(swizzle_scaling_factors(B_tensor, !transb))); +#endif if (comm_overlap) { #ifndef USE_ROCM @@ -469,10 +473,12 @@ std::optional> te_general_grouped_gemm( wrappers.emplace_back(std::move(te_pre_gelu_out)); } +#ifndef USE_ROCM // Optionally swizzle the scaling factors // Keep the swizzled scaling factor tensors alive during the GEMMs. auto swizzled_scale_inv_A = multi_tensor_swizzle_scaling_factors(te_A_wrappers, transa); auto swizzled_scale_inv_B = multi_tensor_swizzle_scaling_factors(te_B_wrappers, !transb); +#endif for (size_t i = 0; i < workspace.size(); i++) { auto wsp = makeTransformerEngineTensor(workspace[i].data_ptr(), diff --git a/transformer_engine/pytorch/csrc/util.cpp b/transformer_engine/pytorch/csrc/util.cpp index 92f2d3a50..44b636930 100644 --- a/transformer_engine/pytorch/csrc/util.cpp +++ b/transformer_engine/pytorch/csrc/util.cpp @@ -1,9 +1,13 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. ************************************************************************/ +#ifndef USE_ROCM + #include "util.h" #include "common.h" @@ -170,3 +174,5 @@ std::optional multi_tensor_swizzle_scaling_factors( return buffer; } + +#endif //!USE_ROCM diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 4b2686096..621cc1db8 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -1,4 +1,6 @@ /************************************************************************* + * This file was modified for portability to AMDGPU + * Copyright (c) 2026, Advanced Micro Devices, Inc. All rights reserved. * Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. * * See LICENSE for license information. 
@@ -7,6 +9,8 @@ #ifndef TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ #define TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ +#ifndef USE_ROCM + #include #include @@ -27,4 +31,6 @@ std::optional swizzle_scaling_factors(transformer_engine::TensorWrap std::optional multi_tensor_swizzle_scaling_factors( std::vector &inputs, bool rowwise); +#endif //!USE_ROCM + #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ From 08db27ebf56c75fd8797f1f2cf1dfc1339f93989 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Sat, 17 Jan 2026 00:00:41 -0500 Subject: [PATCH 152/153] Resolve automerge error --- transformer_engine/pytorch/csrc/util.h | 2 -- 1 file changed, 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/util.h b/transformer_engine/pytorch/csrc/util.h index 6c0f76e83..621cc1db8 100644 --- a/transformer_engine/pytorch/csrc/util.h +++ b/transformer_engine/pytorch/csrc/util.h @@ -33,6 +33,4 @@ std::optional multi_tensor_swizzle_scaling_factors( #endif //!USE_ROCM -#endif //!USE_ROCM - #endif // TRANSFORMER_ENGINE_PYTORCH_CSRC_UTIL_H_ From 842736238bfcf581b230e5c844a437a7fcc6ddd4 Mon Sep 17 00:00:00 2001 From: Ilya Panfilov Date: Sat, 17 Jan 2026 21:46:43 -0500 Subject: [PATCH 153/153] Fix benchmark script. Remove not needed debug messages --- .../attention/benchmark_attention_rocm.py | 28 +++++++++---------- tests/pytorch/utils.py | 2 -- 2 files changed, 14 insertions(+), 16 deletions(-) diff --git a/benchmarks/attention/benchmark_attention_rocm.py b/benchmarks/attention/benchmark_attention_rocm.py index b126fb022..0c37696ac 100644 --- a/benchmarks/attention/benchmark_attention_rocm.py +++ b/benchmarks/attention/benchmark_attention_rocm.py @@ -1,5 +1,5 @@ # This file was modified for portability to AMDGPU -# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved. +# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved. # Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. # # See LICENSE for license information. 
@@ -13,17 +13,17 @@ import transformer_engine from transformer_engine_torch import NVTE_Fused_Attn_Backend -# Add test_fused_attn to the sys path +# Add TE repo root to the sys path tests_path = os.path.abspath( - os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn") + os.path.join(os.path.dirname(__file__), "../../") ) sys.path.append(tests_path) -from test_fused_attn import ( +from tests.pytorch.utils import ( ModelConfig, - _get_attention_backends, - _run_dot_product_attention, + get_available_attention_backends, ) +from tests.pytorch.attention.test_attention import _run_dot_product_attention pd.set_option("display.precision", 4) @@ -46,12 +46,12 @@ is_training = True model_configs = { - # test: b, h, hg, d, sq, skv, p, mask, bias - "test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq - "test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask - "test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias - "test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA - "test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias") + # b, sq, h, dqk + "test_0": ModelConfig(2, 512, 16, 64), # short seq + "test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # longer seq, mask + "test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"), # bias + "test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA + "test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=8, attn_mask_type="causal_bottom_right") } # DataFrame indices and columns for results @@ -303,7 +303,7 @@ def sanity_checks( } for model, cfg in model_configs.items(): - avail, _, fused_bes = _get_attention_backends( + avail, _, fused_bes = get_available_attention_backends( cfg, qkv_dtype=dtype, qkv_layout=qkv_layout, @@ -364,7 +364,7 @@ def main(args): # Benchmarking starts.. for model in model_configs.keys(): config = model_configs[model] - available_backends, _, fused_attn_backends = _get_attention_backends( + available_backends, _, fused_attn_backends = get_available_attention_backends( config, qkv_dtype=dtype, qkv_layout=qkv_layout, diff --git a/tests/pytorch/utils.py b/tests/pytorch/utils.py index 5674535fa..684c15737 100644 --- a/tests/pytorch/utils.py +++ b/tests/pytorch/utils.py @@ -203,7 +203,6 @@ def logging_context(highest_level=logging.WARNING): if IS_HIP_EXTENSION: class EnvVarCleaner: def __init__(self, envs_): - print("PIV create envs:", envs_) self.envs = envs_ self.flags = {} for env in self.envs: @@ -211,7 +210,6 @@ def __init__(self, envs_): self.flags[env] = os.environ[env] def __del__(self): - print("PIV destroty envs:", self.envs) for env in self.envs: if env in self.flags: os.environ[env] = self.flags[env]
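As a usage note on the benchmark hunk above: after this change the model_configs entries pass batch size, sequence length, head count and head_dim_qk positionally (the "# b, sq, h, dqk" comment) and everything else, such as num_gqa_groups, head_dim_v, attn_mask_type and attn_bias_type, by keyword. A hypothetical extra entry in the same style; the "test_5" name and its sizes are invented for illustration:

    # b, sq, h, dqk positionally, the rest by keyword
    model_configs["test_5"] = ModelConfig(
        2, 4096, 16, 128,
        num_gqa_groups=4,
        attn_mask_type="causal",
        attn_bias_type="post_scale_bias",
    )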