From 802c747612e831d1345862e97c9bea0d01642182 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 01:12:35 +0100 Subject: [PATCH 1/5] docs: Add comprehensive Getting Started guide with benchmarks - Add new Getting Started documentation with PyTorch and JAX tutorials - Include benchmark scripts demonstrating TE performance benefits - Add CSS styling for code output and tabs - Replace old quickstart notebooks with improved documentation - Add transformer layer diagram (SVG) - Update docs configuration and workflow Signed-off-by: Pawel Gadzinski --- .github/workflows/docs.yml | 2 +- docs/_static/css/output-style.css | 60 ++ docs/_static/css/rtabs.css | 43 + docs/conf.py | 3 + docs/examples/advanced_optimizations.ipynb | 2 +- docs/examples/quickstart.ipynb | 606 ------------- docs/examples/quickstart_jax.ipynb | 833 ------------------ docs/getting_started.rst | 16 - docs/getting_started/getting_started_jax.out | 46 + docs/getting_started/getting_started_jax.py | 481 ++++++++++ .../getting_started_jax_summary.csv | 7 + .../getting_started_pytorch.out | 42 + .../getting_started_pytorch.py | 435 +++++++++ .../getting_started_pytorch_summary.csv | 7 + .../getting_started_utils_jax.py | 77 ++ .../getting_started_utils_pytorch.py | 125 +++ docs/getting_started/index.rst | 564 ++++++++++++ docs/getting_started/transformer_layer.svg | 82 ++ docs/index.rst | 2 +- 19 files changed, 1975 insertions(+), 1458 deletions(-) create mode 100644 docs/_static/css/output-style.css create mode 100644 docs/_static/css/rtabs.css delete mode 100644 docs/examples/quickstart.ipynb delete mode 100644 docs/examples/quickstart_jax.ipynb delete mode 100644 docs/getting_started.rst create mode 100644 docs/getting_started/getting_started_jax.out create mode 100644 docs/getting_started/getting_started_jax.py create mode 100644 docs/getting_started/getting_started_jax_summary.csv create mode 100644 docs/getting_started/getting_started_pytorch.out create mode 100644 docs/getting_started/getting_started_pytorch.py create mode 100644 docs/getting_started/getting_started_pytorch_summary.csv create mode 100644 docs/getting_started/getting_started_utils_jax.py create mode 100644 docs/getting_started/getting_started_utils_pytorch.py create mode 100644 docs/getting_started/index.rst create mode 100644 docs/getting_started/transformer_layer.svg diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5beeeb8879f..7ebcbf740c6 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,7 +17,7 @@ jobs: uses: actions/checkout@v3 - name: 'Install dependencies' run: | - pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 + pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 sphinx-tabs==3.4.7 pip install breathe==4.35.0 sphinx-autoapi==3.3.2 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) diff --git a/docs/_static/css/output-style.css b/docs/_static/css/output-style.css new file mode 100644 index 00000000000..864d8587a32 --- /dev/null +++ b/docs/_static/css/output-style.css @@ -0,0 +1,60 @@ +/* Custom styling for program output blocks */ + +.program-output { + background-color: #f8f9fa; + padding: 0; /* No padding at all */ + margin: 0; /* No margins at all */ + border-radius: 0; /* No rounded corners */ + font-family: 'Courier New', monospace; + font-size: 14px; + line-height: 1.5; + 
width: 100%; + max-width: 100%; +} + +.program-output pre { + margin: 0; + padding: 0; + background: transparent !important; + border: none !important; + color: #2c3e50; + width: 100%; +} + +.program-output .highlight { + background: transparent !important; + margin: 0; + width: 100%; +} + +/* Alternative lighter style */ +.output-block { + background-color: #fafbfc; + border: 1px solid #e1e4e8; + padding: 10px 14px; + margin: 10px 0; + border-radius: 3px; + font-family: 'SF Mono', 'Consolas', monospace; + font-size: 13px; + color: #24292e; +} + +/* Console-like output style */ +.console-output { + background-color: #1e1e1e; + border-left: 3px solid #76b900; + padding: 14px 18px; + margin: 12px 0; + border-radius: 5px; + font-family: 'Fira Code', 'Consolas', monospace; + font-size: 13px; + color: #d4d4d4; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); +} + +.console-output pre { + margin: 0; + color: #d4d4d4; + background: transparent !important; +} + diff --git a/docs/_static/css/rtabs.css b/docs/_static/css/rtabs.css new file mode 100644 index 00000000000..7f4213ef99b --- /dev/null +++ b/docs/_static/css/rtabs.css @@ -0,0 +1,43 @@ +/* Custom styling for sphinx-tabs */ + +.sphinx-tabs { + margin-bottom: 1rem; +} + +.sphinx-tabs-tab { + background-color: #f4f4f4; + border: 1px solid #ccc; + border-bottom: none; + padding: 0.5rem 1rem; + margin-right: 0.5rem; + cursor: pointer; + font-weight: 500; + transition: background-color 0.2s; +} + +.sphinx-tabs-tab:hover { + background-color: #e0e0e0; +} + +.sphinx-tabs-tab[aria-selected="true"] { + background-color: #76b900; /* NVIDIA green */ + color: white; + border-color: #76b900; + margin-right: 0.5rem; +} + +.sphinx-tabs-panel { + border: 1px solid #ccc; + padding: 1rem; + background-color: #f9f9f9; +} + +/* Dark mode support for RTD theme */ +.rst-content .sphinx-tabs-tab { + color: #333; +} + +.rst-content .sphinx-tabs-tab[aria-selected="true"] { + color: white; +} + diff --git a/docs/conf.py b/docs/conf.py index 479c1f8948f..920b487fd32 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -58,6 +58,7 @@ "nbsphinx", "breathe", "autoapi.extension", + "sphinx_tabs.tabs", ] templates_path = ["_templates"] @@ -83,6 +84,8 @@ html_css_files = [ "css/nvidia_font.css", "css/nvidia_footer.css", + "css/rtabs.css", + "css/output-style.css", ] html_theme_options = { diff --git a/docs/examples/advanced_optimizations.ipynb b/docs/examples/advanced_optimizations.ipynb index 7c08bb65866..1b0694a05f8 100644 --- a/docs/examples/advanced_optimizations.ipynb +++ b/docs/examples/advanced_optimizations.ipynb @@ -13,7 +13,7 @@ "id": "6dcbf25a", "metadata": {}, "source": [ - "This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). " + "This guide is a follow-up to the discussion in the [Getting Started guide](../getting_started/index.rst). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). 
" ] }, { diff --git a/docs/examples/quickstart.ipynb b/docs/examples/quickstart.ipynb deleted file mode 100644 index 0ad2f4fee8c..00000000000 --- a/docs/examples/quickstart.ipynb +++ /dev/null @@ -1,606 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "da9fd6a8", - "metadata": {}, - "source": [ - "# Getting Started\n", - "\n", - "## Overview\n", - "\n", - "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your PyTorch code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n", - "\n", - "## Let's build a Transformer layer!\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We build a basic Transformer layer using regular PyTorch modules. This will be our baseline for later comparisons with Transformer Engine.\n", - "\n", - "
\n", - "\n", - "Let's start with creating a GPT encoder layer using plain PyTorch. Figure 1 shows the overall structure.\n", - "\n", - "
\n", - "\n", - "
Figure 1: Structure of a GPT encoder layer.
\n", - "
\n", - "\n", - "We construct the components as follows:\n", - "\n", - "- `LayerNorm`: `torch.nn.LayerNorm`\n", - "- `QKV Projection`: `torch.nn.Linear` (conceptually three `Linear` layers for Q, K, and V separately, but we fuse into a single `Linear` layer that is three times larger)\n", - "- `DotProductAttention`: `DotProductAttention` from [quickstart_utils.py](quickstart_utils.py)\n", - "- `Projection`: `torch.nn.Linear`\n", - "- `Dropout`: `torch.nn.Dropout`\n", - "- `MLP`: `BasicMLP` from [quickstart_utils.py](quickstart_utils.py)\n", - "\n", - "Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_utils.py](quickstart_utils.py). Putting it all together:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "2be43d64", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import quickstart_utils as utils\n", - "\n", - "class BasicTransformerLayer(torch.nn.Module):\n", - " def __init__(\n", - " self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int,\n", - " num_attention_heads: int,\n", - " layernorm_eps: int = 1e-5,\n", - " attention_dropout: float = 0.1,\n", - " hidden_dropout: float = 0.1,\n", - " ):\n", - " super().__init__()\n", - " self.num_attention_heads = num_attention_heads\n", - " self.kv_channels = hidden_size // num_attention_heads\n", - " self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)\n", - " self.attention = utils.DotProductAttention(\n", - " num_attention_heads=num_attention_heads,\n", - " kv_channels=self.kv_channels,\n", - " attention_dropout=attention_dropout,\n", - " )\n", - " self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)\n", - " self.dropout = torch.nn.Dropout(hidden_dropout)\n", - " self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.mlp = utils.BasicMLP(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " ) \n", - " \n", - " def forward(\n", - " self, \n", - " x: torch.Tensor, \n", - " attention_mask: torch.Tensor\n", - " ) -> torch.Tensor:\n", - " res = x\n", - " x = self.ln1(x)\n", - " \n", - " # Fused QKV projection\n", - " qkv = self.qkv_projection(x)\n", - " qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", - " \n", - " x = self.attention(q, k, v, attention_mask)\n", - " x = self.projection(x)\n", - " x = self.dropout(x)\n", - " x = res + x\n", - " res = x\n", - " x = self.ln2(x)\n", - " x = self.mlp(x)\n", - " \n", - " return x + res" - ] - }, - { - "cell_type": "markdown", - "id": "40724d1d", - "metadata": {}, - "source": [ - "That's it! We now have a simple Transformer layer. 
We can test it:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a786f0ea", - "metadata": {}, - "outputs": [], - "source": [ - "# Layer configuration\n", - "hidden_size = 4096\n", - "sequence_length = 2048\n", - "batch_size = 4\n", - "ffn_hidden_size = 16384\n", - "num_attention_heads = 32\n", - "dtype = torch.float16\n", - "\n", - "# Synthetic data\n", - "x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)\n", - "dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ffdbfb7a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "BasicTransformerLayer(\n", - " (ln1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", - " (qkv_projection): Linear(in_features=4096, out_features=12288, bias=True)\n", - " (attention): DotProductAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (projection): Linear(in_features=4096, out_features=4096, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", - " (mlp): BasicMLP(\n", - " (linear1): Linear(in_features=4096, out_features=16384, bias=True)\n", - " (linear2): Linear(in_features=16384, out_features=4096, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "basic_transformer = BasicTransformerLayer(\n", - " hidden_size,\n", - " ffn_hidden_size,\n", - " num_attention_heads,\n", - ")\n", - "basic_transformer.to(dtype=dtype).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0162ad40", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = basic_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "65ae6dd6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 43.0663916015625 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " basic_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "43717e36", - "metadata": {}, - "source": [ - "## Meet Transformer Engine\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We modify the example Transformer layer to include the simplest TE modules: `Linear` and `LayerNorm`.\n", - "\n", - "
\n", - "\n", - "Now that we have a basic Transformer layer, let's use Transformer Engine to speed up the training. " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "004d3c92", - "metadata": {}, - "outputs": [], - "source": [ - "import transformer_engine.pytorch as te" - ] - }, - { - "cell_type": "markdown", - "id": "1931f911", - "metadata": {}, - "source": [ - "TE provides a set of PyTorch modules that can be used to build Transformer layers. The simplest of the provided modules are the `Linear` and `LayerNorm` layers, which we can use instead of `torch.nn.Linear` and `torch.nn.LayerNorm`. Let's modify `BasicTransformerLayer`:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "1f44db50", - "metadata": {}, - "outputs": [], - "source": [ - "class BasicTEMLP(torch.nn.Module):\n", - " def __init__(self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int) -> None:\n", - " super().__init__()\n", - " self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)\n", - " self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)\n", - "\n", - " def forward(self, x):\n", - " x = self.linear1(x)\n", - " x = torch.nn.functional.gelu(x, approximate='tanh')\n", - " x = self.linear2(x)\n", - " return x \n", - " \n", - "class BasicTETransformerLayer(torch.nn.Module):\n", - " def __init__(self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int,\n", - " num_attention_heads: int,\n", - " layernorm_eps: int = 1e-5,\n", - " attention_dropout: float = 0.1,\n", - " hidden_dropout: float = 0.1):\n", - " super().__init__()\n", - " self.num_attention_heads = num_attention_heads\n", - " self.kv_channels = hidden_size // num_attention_heads\n", - " self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)\n", - " self.attention = utils.DotProductAttention(\n", - " num_attention_heads=num_attention_heads,\n", - " kv_channels=self.kv_channels,\n", - " attention_dropout=attention_dropout,\n", - " )\n", - " self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n", - " self.dropout = torch.nn.Dropout(hidden_dropout)\n", - " self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.mlp = BasicTEMLP(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " )\n", - " \n", - " def forward(self, \n", - " x: torch.Tensor, \n", - " attention_mask: torch.Tensor):\n", - " res = x\n", - " x = self.ln1(x)\n", - " \n", - " # Fused QKV projection\n", - " qkv = self.qkv_projection(x)\n", - " qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", - " \n", - " x = self.attention(q, k, v, attention_mask)\n", - " x = self.projection(x)\n", - " x = self.dropout(x)\n", - " x = res + x\n", - " res = x\n", - " x = self.ln2(x)\n", - " x = self.mlp(x)\n", - " \n", - " return x + res" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "916531e8", - "metadata": {}, - "outputs": [], - "source": [ - "basic_te_transformer = BasicTETransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads,\n", - ")\n", - "basic_te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_basic_te_model(basic_te_transformer, basic_transformer)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "3643fa54", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = 
basic_te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "10b92894", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 43.1413232421875 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " basic_te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "3f990226", - "metadata": {}, - "source": [ - "## Fused TE Modules\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We optimize the example Transformer layer with TE modules for fused operations.\n", - "\n", - "
\n", - "\n", - "The `Linear` layer is enough to build any Transformer model and it enables usage of Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations like kernel fusion, increasing the achievable speedup.\n", - "\n", - "Transformer Engine therefore provides coarser modules that span multiple layers:\n", - "\n", - "* `LayerNormLinear`\n", - "* `LayerNormMLP`\n", - "* `TransformerLayer`\n", - "\n", - "Building a third iteration of our Transformer layer with `LayerNormLinear` and `LayerNormMLP`:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "c55eae1f", - "metadata": {}, - "outputs": [], - "source": [ - "class FusedTETransformerLayer(torch.nn.Module):\n", - " def __init__(self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int,\n", - " num_attention_heads: int,\n", - " layernorm_eps: int = 1e-5,\n", - " attention_dropout: float = 0.1,\n", - " hidden_dropout: float = 0.1):\n", - " super().__init__()\n", - " self.num_attention_heads = num_attention_heads\n", - " self.kv_channels = hidden_size // num_attention_heads\n", - " self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)\n", - " self.attention = utils.DotProductAttention(\n", - " num_attention_heads=num_attention_heads,\n", - " kv_channels=self.kv_channels,\n", - " attention_dropout=attention_dropout,\n", - " )\n", - " self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n", - " self.dropout = torch.nn.Dropout(hidden_dropout)\n", - " self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)\n", - " \n", - " \n", - " def forward(self, \n", - " x: torch.Tensor, \n", - " attention_mask: torch.Tensor):\n", - " res = x\n", - " qkv = self.ln_qkv(x)\n", - " \n", - " # Split qkv into query, key and value\n", - " qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", - " \n", - " x = self.attention(q, k, v, attention_mask)\n", - " x = self.projection(x)\n", - " x = self.dropout(x)\n", - " x = res + x\n", - " res = x\n", - " x = self.ln_mlp(x)\n", - " \n", - " return x + res" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "85949421", - "metadata": {}, - "outputs": [], - "source": [ - "fused_te_transformer = FusedTETransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", - "fused_te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_fused_te_model(fused_te_transformer, basic_transformer)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "2c263e71", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = fused_te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "24e101bc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 43.1981201171875 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " fused_te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "33f13c26", - "metadata": {}, - "source": [ - "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": 
"ec8c3685", - "metadata": {}, - "outputs": [], - "source": [ - "te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", - "te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "e48cd590", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "3ec3707d-e63f-4899-8308-b11c55b5caa4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 39.99169921875 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "4034c3eb-8958-49f2-85f6-30c94977d884", - "metadata": {}, - "source": [ - "## Enabling FP8\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We configure a TE module to perform compute in FP8.\n", - "\n", - "
\n", - "\n", - "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "31256aa7-3d5e-425c-91ab-502b1326a748", - "metadata": {}, - "outputs": [], - "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling\n", - "\n", - "te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", - "te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)\n", - "\n", - "fp8_format = Format.HYBRID\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", - "torch.manual_seed(1234)\n", - "with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n", - " y = te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "793ebd2d-b84b-47bc-811a-7991df8500aa", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 28.61394775390625 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", - ")" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/examples/quickstart_jax.ipynb b/docs/examples/quickstart_jax.ipynb deleted file mode 100644 index 5b4e439c004..00000000000 --- a/docs/examples/quickstart_jax.ipynb +++ /dev/null @@ -1,833 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "962d87bb", - "metadata": {}, - "source": [ - "\n", - "\n", - "# Getting Started\n", - "\n", - "## Overview\n", - "\n", - "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n", - "\n", - "This guide shows how to start using Transformer Engine with JAX. 
Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n", - "We recommend you to try understanding the basics of JAX first, using these resources:\n", - "\n", - "- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n", - "- JAX 101: https://docs.jax.dev/en/latest/jax-101.html\n", - "- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array\n", - "- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n", - "\n", - "## Let's build a Transformer decoder layer!\n", - "_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n", - "\n", - "
\n", - "\n", - "Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n", - "\n", - "
\n", - "\n", - "
Figure 1: Structure of a GPT decoder layer.
\n", - "
\n", - "\n", - "We construct the components as follows:\n", - "\n", - "- `LayerNorm`: `nn.LayerNorm` (Flax)\n", - "- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n", - "- `DotProductAttention`: `nn.MuliheadDotProductAttention` (Flax)\n", - "- `Projection`: `nn.Dense` (Flax)\n", - "- `Dropout`: `nn.Dropout` (Flax)\n", - "- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n", - "\n", - "Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together: \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d5284a38", - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from flax import linen as nn\n", - "import quickstart_jax_utils as utils\n", - "from typing import Optional" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a4d1cfdc", - "metadata": {}, - "outputs": [], - "source": [ - "class FlaxMLP(nn.Module):\n", - " \"\"\"Feed-forward network in Transformer layer\n", - " Built with plain Flax modules.\n", - " \"\"\"\n", - " hidden_size: int\n", - " ffn_hidden_size: int\n", - "\n", - " @nn.compact\n", - " def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n", - " x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)\n", - " x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n", - " x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n", - " return x\n", - "\n", - "class FlaxTransformerLayer(nn.Module):\n", - " \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n", - " hidden_size: int\n", - " ffn_hidden_size: int\n", - " num_attention_heads: int\n", - " layernorm_eps: float = 1e-5\n", - " attention_dropout: float = 0.1\n", - " \n", - " def setup(self):\n", - " self.kv_channels = self.hidden_size // self.num_attention_heads\n", - "\n", - " @nn.compact\n", - " def __call__(\n", - " self, \n", - " x: jnp.ndarray, \n", - " attention_mask: Optional[jnp.ndarray] = None,\n", - " deterministic: bool = False\n", - " ) -> jnp.ndarray:\n", - " # Create causal mask if not provided\n", - " if attention_mask is None:\n", - " attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n", - " \n", - " res = x\n", - " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - " \n", - " # Fused QKV projection\n", - " qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)\n", - " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = jnp.split(qkv, 3, axis=3)\n", - " \n", - " # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n", - " # which is the correct format for dot_product_attention\n", - " \n", - " # Apply dot product attention\n", - " # Note: dot_product_attention expects mask to be broadcastable to \n", - " # [batch, num_heads, q_length, kv_length], but attention_mask from \n", - " # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n", - " \n", - " # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n", - " dropout_rng = None\n", - " if not deterministic and self.attention_dropout > 0:\n", - " dropout_rng = self.make_rng('dropout')\n", - " \n", - " x = nn.dot_product_attention(\n", - " query=q,\n", - " key=k,\n", - " value=v,\n", - " mask=attention_mask,\n", - " dropout_rng=dropout_rng,\n", - " 
dropout_rate=self.attention_dropout,\n", - " deterministic=deterministic,\n", - " broadcast_dropout=True,\n", - " )\n", - " \n", - " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n", - " x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n", - "\n", - " # Output projection\n", - " x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n", - " \n", - " x = res + x\n", - " \n", - " # Second residual connection\n", - " res = x\n", - " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - " \n", - " # MLP\n", - " mlp = FlaxMLP(\n", - " hidden_size=self.hidden_size,\n", - " ffn_hidden_size=self.ffn_hidden_size,\n", - " )\n", - " x = mlp(x)\n", - " \n", - " return x + res\n" - ] - }, - { - "cell_type": "markdown", - "id": "fbc3510b", - "metadata": {}, - "source": [ - "## Testing Performance\n", - "\n", - "Now let's test the performance of our FlaxTransformerLayer:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8b44649d", - "metadata": {}, - "outputs": [], - "source": [ - "# Layer configuration\n", - "hidden_size = 4096\n", - "sequence_length = 2048\n", - "batch_size = 4\n", - "ffn_hidden_size = 16384\n", - "num_attention_heads = 32\n", - "dtype = jnp.bfloat16\n", - "\n", - "# Synthetic data\n", - "key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n", - "x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n", - "dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "e44ed26d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Pure Flax FlaxTransformerLayer initialized successfully!\n", - "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n" - ] - } - ], - "source": [ - "# Initialize the FlaxTransformerLayer\n", - "flax_transformer = FlaxTransformerLayer(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " num_attention_heads=num_attention_heads,\n", - ")\n", - "\n", - "# Initialize parameters\n", - "params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n", - "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "de91af7a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input shape: (4, 2048, 4096)\n", - "Output shape: (4, 2048, 4096)\n", - "Output dtype: float32\n", - "Forward pass completed successfully!\n" - ] - } - ], - "source": [ - "# Example usage of forward pass\n", - "y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n", - "print(f\"Input shape: {x.shape}\")\n", - "print(f\"Output shape: {y.shape}\")\n", - "print(f\"Output dtype: {y.dtype}\")\n", - "print(\"Forward pass completed successfully!\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "037bc8d9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 
19.258604049682617 ms\n" - ] - } - ], - "source": [ - "import importlib\n", - "import quickstart_jax_utils\n", - "importlib.reload(quickstart_jax_utils)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=flax_transformer.apply,\n", - " variables=params,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "ccb16f31", - "metadata": {}, - "source": [ - "## Meet Transformer Engine\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n", - "\n", - "
\n", - "\n", - "As a reminder, the FlaxTransformerLayer above used:\n", - "\n", - "- `nn.LayerNorm`: Flax LayerNorm\n", - "- `nn.Dense`: Flax Dense layer for QKV projection \n", - "- `nn.MultiheadDotProductAttention`: Flax MultiheadDotProductAttention\n", - "- `nn.Dense`: Flax Dense layer for projection\n", - "- `nn.Dropout`: Flax Dropout\n", - "- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n", - "\n", - "Below we show how to use Transformer Engine Flax modules for better performance:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "bed20d6b", - "metadata": {}, - "outputs": [], - "source": [ - "import transformer_engine.jax as te\n", - "import transformer_engine.jax.flax as te_flax" - ] - }, - { - "cell_type": "markdown", - "id": "f28cb444", - "metadata": {}, - "source": [ - "TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "56105579", - "metadata": {}, - "outputs": [], - "source": [ - "class TEUnfusedMLP(nn.Module):\n", - " hidden_size : int\n", - " ffn_hidden_size: int\n", - "\n", - " @nn.compact\n", - " def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:\n", - " x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True) (x)\n", - " x = x.reshape(*x.shape[:-1], 1, x.shape[-1])\n", - " x = te.activation.activation(x, activation_type=('gelu',))\n", - " x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True) (x)\n", - " return x\n", - "\n", - "class TEUnfusedTransformerLayer(nn.Module):\n", - " hidden_size: int\n", - " ffn_hidden_size: int \n", - " num_attention_heads: int \n", - " layernorm_eps: float = 1e-5\n", - " attention_dropout: float = 0.1 \n", - " use_te_attention: bool = True # True for TE attention, False for Flax attention\n", - "\n", - " def setup(self):\n", - " self.kv_channels = self.hidden_size // self.num_attention_heads\n", - "\n", - " @nn.compact\n", - " def __call__(\n", - " self, \n", - " x: jnp.ndarray,\n", - " attention_mask: Optional[jnp.ndarray] = None,\n", - " deterministic: bool = False\n", - " ) -> jnp.ndarray: \n", - " res = x\n", - " x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - "\n", - " # Fused QKV projection\n", - " qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)\n", - " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = jnp.split(qkv, 3, axis=3)\n", - "\n", - " # Attention - either TE or Flax implementation\n", - " if self.use_te_attention:\n", - " # Use TE's DotProductAttention\n", - " attention = te_flax.DotProductAttention(\n", - " head_dim=self.kv_channels,\n", - " num_attention_heads=self.num_attention_heads,\n", - " num_gqa_groups=self.num_attention_heads, # No GQA\n", - " attention_dropout=self.attention_dropout,\n", - " attn_mask_type='causal',\n", - " )\n", - " x = attention(\n", - " q, k, v,\n", - " # Causal mask does not need an explicit instatiated mask as specialized kernels exist to handle it\n", - " sequence_descriptor=None, \n", - " deterministic=deterministic\n", - " )\n", - " # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n", - " x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n", - " x = 
te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n", - " x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n", - " else:\n", - " # Use Flax's MultiHeadDotProductAttention\n", - " q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n", - " k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n", - " v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n", - "\n", - " # Create causal mask if not provided\n", - " if attention_mask is None:\n", - " attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n", - " \n", - " attention = nn.MultiHeadDotProductAttention(\n", - " num_heads=self.num_attention_heads,\n", - " qkv_features=self.kv_channels,\n", - " dropout_rate=self.attention_dropout,\n", - " )\n", - " x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n", - "\n", - " x = res + x\n", - "\n", - " # Second residual connection\n", - " res = x\n", - " x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - "\n", - " # MLP\n", - " mlp = TEUnfusedMLP(\n", - " hidden_size=self.hidden_size,\n", - " ffn_hidden_size=self.ffn_hidden_size\n", - " )\n", - "\n", - " x = mlp(x, deterministic=deterministic)\n", - "\n", - " return x + res" - ] - }, - { - "cell_type": "markdown", - "id": "a76911ac", - "metadata": {}, - "source": [ - "Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "4b67511f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 16.003193855285645 ms\n" - ] - } - ], - "source": [ - "te_unfused_transformer_with_flax_MHA = TEUnfusedTransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads,\n", - " use_te_attention=False\n", - ")\n", - "\n", - "te_params = te_unfused_transformer_with_flax_MHA.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_unfused_transformer_with_flax_MHA.apply,\n", - " variables=te_params, # Ensure the correct `params` is passed\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs={\"dropout\": dropout_key},\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": "0b230058", - "metadata": {}, - "source": [ - "Now, we move on to also replace the attention sub-layer with TE's `DotProductAttention` implementation" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "5146cd99", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": 
"stream", - "text": [ - "Mean time: 8.897695541381836 ms\n" - ] - } - ], - "source": [ - "te_unfused_transformer = TEUnfusedTransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads,\n", - ")\n", - "\n", - "te_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_unfused_transformer.apply,\n", - " variables=te_params, # Ensure the correct `params` is passed\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "c9a101d3", - "metadata": {}, - "source": [ - "## Enabling Quantization (FP8 or FP4)\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We configure a TE module to perform compute in FP8.\n", - "\n", - "
\n", - "\n", - "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n", - "\n", - "
\n", - "\n", - "Important: FP8 Metadata Initialization\n", - "\n", - "When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n", - "\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "c2eee376", - "metadata": {}, - "outputs": [], - "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling\n", - "fp8_format = Format.HYBRID\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "de96827c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 5.651178359985352 ms\n" - ] - } - ], - "source": [ - "with te.autocast(enabled=True, recipe=fp8_recipe):\n", - " te_unfused_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - " # Example usage of forward \n", - " y = te_unfused_transformer.apply(te_unfused_params, x, attention_mask=None, deterministic=True)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_unfused_transformer.apply,\n", - " variables=te_unfused_params, # Ensure the correct `params` is passed\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "3801b201", - "metadata": {}, - "source": [ - "\n", - "## Fused TE Modules\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We optimize the example Transformer layer with TE modules for fused operations.\n", - "\n", - "
\n", - "\n", - "The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n", - "\n", - "Transformer Engine therefore provides coarser modules that span multiple layers:\n", - "\n", - "* `LayerNormDenseGeneral`\n", - "* `LayerNormMLP`\n", - "* `TransformerLayer`\n", - "\n", - "To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n", - "\n", - "Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "11203785", - "metadata": {}, - "outputs": [], - "source": [ - "class TEFusedTransformerLayer(nn.Module):\n", - " hidden_size: int\n", - " ffn_hidden_size: int \n", - " num_attention_heads: int \n", - " layernorm_eps: float = 1e-5\n", - " attention_dropout: float = 0.1\n", - "\n", - " def setup(self):\n", - " self.kv_channels = self.hidden_size // self.num_attention_heads\n", - "\n", - " @nn.compact\n", - " def __call__(\n", - " self, \n", - " x: jnp.ndarray,\n", - " attention_mask: Optional[jnp.ndarray] = None,\n", - " deterministic: bool = False\n", - " ) -> jnp.ndarray:\n", - " res = x\n", - "\n", - " # Fused QKV projection\n", - " qkv,_ = te_flax.LayerNormDenseGeneral(features=3 * self.hidden_size, \n", - " epsilon=self.layernorm_eps, \n", - " use_bias=True, \n", - " return_layernorm_output=False)(x)\n", - " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = jnp.split(qkv, 3, axis=3)\n", - "\n", - " # Attention using TE's DotProductAttention\n", - " attention = te_flax.DotProductAttention(\n", - " head_dim=self.kv_channels,\n", - " num_attention_heads=self.num_attention_heads,\n", - " num_gqa_groups=self.num_attention_heads, \n", - " attention_dropout=self.attention_dropout,\n", - " attn_mask_type='causal',\n", - " )\n", - " x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n", - " # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n", - " x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n", - " x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n", - " x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n", - "\n", - " x = res + x\n", - "\n", - " # Second residual connection\n", - " res = x\n", - " x,_ = te_flax.LayerNormMLP(intermediate_dim=self.ffn_hidden_size, \n", - " epsilon=self.layernorm_eps,\n", - " use_bias=True,\n", - " activations=('gelu',),\n", - " intermediate_dropout_rate=0.0,\n", - " return_layernorm_output=False\n", - " )(x, deterministic=deterministic)\n", - "\n", - " return x + res" - ] - }, - { - "cell_type": "markdown", - "id": "334cff59", - "metadata": {}, - "source": [ - "Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's." 
- ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "6b0c705e", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 5.493879318237305 ms\n" - ] - } - ], - "source": [ - "te_fused_transformer = TEFusedTransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads\n", - ")\n", - "\n", - "with te.autocast(enabled=True, recipe=fp8_recipe):\n", - " te_fused_params = te_fused_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - " # Example usage of forward \n", - " y = te_fused_transformer.apply(te_fused_params, x, attention_mask=None, deterministic=True)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_fused_transformer.apply,\n", - " variables=te_fused_params,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a45c12c8", - "metadata": {}, - "source": [ - "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures." 
- ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "b2aaa8ef", - "metadata": {}, - "outputs": [], - "source": [ - "te_transformer = te_flax.TransformerLayer(\n", - " hidden_size=hidden_size,\n", - " mlp_hidden_size=ffn_hidden_size, \n", - " num_attention_heads=num_attention_heads,\n", - " mlp_activations=(\"gelu\",),\n", - " self_attn_mask_type='causal',\n", - " layernorm_epsilon=1e-5,\n", - " use_bias=True,\n", - " intermediate_dropout=0.0,\n", - " enable_relative_embedding=False,\n", - " self_attn_bias_type='no_bias',\n", - " hidden_dropout=0.0,\n", - ")\n", - "\n", - "with te.autocast(enabled=True, recipe=fp8_recipe):\n", - " te_transformer_params = te_transformer.init(key, x, deterministic=False)\n", - " y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "b9cdbf22", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 5.334172248840332 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " model_apply_fn=te_transformer.apply,\n", - " model_init_fn=te_transformer.init,\n", - " variables=te_transformer_params,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/getting_started.rst b/docs/getting_started.rst deleted file mode 100644 index 2e8047763a8..00000000000 --- a/docs/getting_started.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - - See LICENSE for license information. - -Getting Started -=============== - -Choose your framework to get started with Transformer Engine: - -.. 
toctree:: - :maxdepth: 1 - - PyTorch - JAX - diff --git a/docs/getting_started/getting_started_jax.out b/docs/getting_started/getting_started_jax.out new file mode 100644 index 00000000000..3455cbfdaff --- /dev/null +++ b/docs/getting_started/getting_started_jax.out @@ -0,0 +1,46 @@ +pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel +pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel +/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 + warnings.warn( +/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 + warnings.warn( +/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 + warnings.warn( +/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 + warnings.warn( +/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 + warnings.warn( +/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 + warnings.warn( +# BENCHMARK_BASELINE_OUTPUT_START +Baseline Flax: +Mean time: 81.035 ms +# BENCHMARK_BASELINE_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_OUTPUT_START +TE Unfused: +Mean time: 42.570 ms +# BENCHMARK_TE_UNFUSED_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START +TE Unfused + TE Attention: +Mean time: 35.017 ms +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START +TE Unfused + TE Attention + FP8: +Mean time: 22.778 ms +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_FUSED_FP8_OUTPUT_START +TE Fused + TE Attention + FP8: +Mean time: 24.007 ms +# BENCHMARK_TE_FUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START +TE TransformerLayer + FP8: +Mean time: 23.004 ms +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + +Summary written to getting_started_jax_summary.csv diff --git a/docs/getting_started/getting_started_jax.py b/docs/getting_started/getting_started_jax.py new file mode 100644 index 00000000000..dc905737d4f --- /dev/null +++ b/docs/getting_started/getting_started_jax.py @@ -0,0 +1,481 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Getting Started with Transformer Engine - JAX Example +====================================================== + +This example shows how to build a Transformer decoder layer using JAX/Flax +and how to optimize it with Transformer Engine. 
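+
+Each layer variant defined below is initialized and then timed with the
+``speedometer`` helper imported from ``getting_started_utils_jax.py``; the
+measured mean times are also written to ``getting_started_jax_summary.csv``.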
+""" + +import jax +import jax.numpy as jnp +from flax import linen as nn +from typing import Optional + +import transformer_engine.jax as te +import transformer_engine.jax.flax as te_flax +from transformer_engine.jax.sharding import MeshResource +from transformer_engine.common.recipe import Format, DelayedScaling + +from getting_started_utils_jax import speedometer + + +# Configuration +hidden_size = 4096 +sequence_length = 2048 +batch_size = 8 +ffn_hidden_size = 16384 +num_attention_heads = 32 +dtype = jnp.bfloat16 + +# Create synthetic data +key = jax.random.PRNGKey(42) +x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype) +mesh_resource = MeshResource() + + +# ============================================================================= +# Baseline: Pure Flax Implementation +# ============================================================================= + +# BASELINE_MLP_START +class FlaxMLP(nn.Module): + """Feed-forward network in Transformer layer. + Built with plain Flax modules. + """ + hidden_size: int + ffn_hidden_size: int + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x) + x = nn.gelu(x, approximate=True) + x = nn.Dense(features=self.hidden_size, use_bias=True)(x) + return x +# BASELINE_MLP_END + + +# BASELINE_LAYER_START +class FlaxTransformerLayer(nn.Module): + """Basic Transformer layer using plain Flax modules.""" + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False + ) -> jnp.ndarray: + if attention_mask is None: + attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_) + + res = x + x = nn.LayerNorm(epsilon=self.layernorm_eps)(x) + + # Fused QKV projection + qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x) + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels) + q, k, v = jnp.split(qkv, 3, axis=3) + + dropout_rng = None + if not deterministic and self.attention_dropout > 0: + dropout_rng = self.make_rng('dropout') + + x = nn.dot_product_attention( + query=q, + key=k, + value=v, + mask=attention_mask, + dropout_rng=dropout_rng, + dropout_rate=self.attention_dropout, + deterministic=deterministic, + broadcast_dropout=True, + ) + + x = x.reshape(x.shape[0], x.shape[1], self.hidden_size) + x = nn.Dense(features=self.hidden_size, use_bias=True)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + x = res + x + + res = x + x = nn.LayerNorm(epsilon=self.layernorm_eps)(x) + mlp = FlaxMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size) + x = mlp(x) + + return x + res +# BASELINE_LAYER_END + + +print("# BENCHMARK_BASELINE_OUTPUT_START") +# BENCHMARK_BASELINE_START +baseline = FlaxTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) +params = baseline.init(key, x, deterministic=False) + +print("Baseline Flax:") +time_baseline = speedometer(baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline") +# BENCHMARK_BASELINE_END +print("# BENCHMARK_BASELINE_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused: 
Basic TE Modules +# ============================================================================= + +# TE_UNFUSED_MLP_START +class TEUnfusedMLP(nn.Module): + """MLP using TE modules.""" + hidden_size: int + ffn_hidden_size: int + + @nn.compact + def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray: + x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x) + x = x.reshape(*x.shape[:-1], 1, x.shape[-1]) + x = te.activation.activation(x, activation_type=('gelu',)) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x) + return x +# TE_UNFUSED_MLP_END + + +# TE_UNFUSED_LAYER_START +class TEUnfusedTransformerLayer(nn.Module): + """Transformer layer using basic TE modules (without TE attention).""" + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False + ) -> jnp.ndarray: + if attention_mask is None: + attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_) + + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x) + + qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x) + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels) + q, k, v = jnp.split(qkv, 3, axis=3) + + dropout_rng = None + if not deterministic and self.attention_dropout > 0: + dropout_rng = self.make_rng('dropout') + + x = nn.dot_product_attention( + query=q, + key=k, + value=v, + mask=attention_mask, + dropout_rng=dropout_rng, + dropout_rate=self.attention_dropout, + deterministic=deterministic, + broadcast_dropout=True, + ) + + x = x.reshape(x.shape[0], x.shape[1], self.hidden_size) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + x = res + x + + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x) + mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size) + x = mlp(x, deterministic=deterministic) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + return x + res +# TE_UNFUSED_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_START +te_unfused = TEUnfusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) +params = te_unfused.init(key, x, deterministic=False) + +print("TE Unfused:") +time_te_unfused = speedometer(te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused") +# BENCHMARK_TE_UNFUSED_END +print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + TE Attention +# ============================================================================= + +# TE_UNFUSED_ATTN_LAYER_START +class TEUnfusedAttnTransformerLayer(nn.Module): + """Transformer layer using TE modules including TE DotProductAttention.""" + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: 
Optional[jnp.ndarray] = None, + deterministic: bool = False + ) -> jnp.ndarray: + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps, dtype=jnp.bfloat16)(x) + + qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x) + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels) + q, k, v = jnp.split(qkv, 3, axis=3) + + attention = te_flax.DotProductAttention( + head_dim=self.kv_channels, + num_attention_heads=self.num_attention_heads, + num_gqa_groups=self.num_attention_heads, + attention_dropout=self.attention_dropout, + attn_mask_type='causal', + transpose_batch_sequence=False, + ) + x = attention(q, k, v, deterministic=deterministic) + x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + x = res + x + + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x) + mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size) + x = mlp(x, deterministic=deterministic) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + return x + res +# TE_UNFUSED_ATTN_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_ATTN_START +te_unfused_attn = TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) + +with te.autocast(enabled=False, mesh_resource=mesh_resource): + params = te_unfused_attn.init(key, x, deterministic=False) + +print("TE Unfused + TE Attention:") +time_te_unfused_attn = speedometer( + te_unfused_attn.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": False, "mesh_resource": mesh_resource}, + label="te_unfused_attn" +) +# BENCHMARK_TE_UNFUSED_ATTN_END +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + FP8 +# ============================================================================= + +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_FP8_START +recipe = DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max" +) + +te_unfused_fp8 = TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) + +with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource): + params = te_unfused_fp8.init(key, x, deterministic=False) + +print("TE Unfused + TE Attention + FP8:") +time_te_unfused_fp8 = speedometer( + te_unfused_fp8.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, + label="te_unfused_fp8" +) +# BENCHMARK_TE_UNFUSED_FP8_END +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE Fused + FP8: Optimized Modules with FP8 +# ============================================================================= + +# TE_FUSED_LAYER_START +class TEFusedTransformerLayer(nn.Module): + """Transformer layer using fused TE modules for better performance.""" + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def 
setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False + ) -> jnp.ndarray: + res = x + + # Fused LayerNorm + QKV projection + qkv, _ = te_flax.LayerNormDenseGeneral( + features=3 * self.hidden_size, + epsilon=self.layernorm_eps, + use_bias=True, + return_layernorm_output=False + )(x) + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.num_attention_heads, self.kv_channels) + q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :] + + attention = te_flax.DotProductAttention( + head_dim=self.kv_channels, + num_attention_heads=self.num_attention_heads, + num_gqa_groups=self.num_attention_heads, + attention_dropout=self.attention_dropout, + attn_mask_type='causal', + qkv_layout='bshd_bshd_bshd', + transpose_batch_sequence=False, + ) + x = attention(q, k, v, deterministic=deterministic) + x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + x = res + x + + res = x + # Fused LayerNorm + MLP + x, _ = te_flax.LayerNormMLP( + intermediate_dim=self.ffn_hidden_size, + epsilon=self.layernorm_eps, + use_bias=True, + activations=('gelu',), + intermediate_dropout_rate=0.0, + return_layernorm_output=False + )(x, deterministic=deterministic) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + return x + res +# TE_FUSED_LAYER_END + + +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_FUSED_FP8_START +te_fused_fp8 = TEFusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) + +with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource): + params = te_fused_fp8.init(key, x, deterministic=False) + +print("TE Fused + TE Attention + FP8:") +time_te_fused_fp8 = speedometer( + te_fused_fp8.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, + label="te_fused_fp8" +) +# BENCHMARK_TE_FUSED_FP8_END +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE TransformerLayer + FP8: Ready-to-use Module +# ============================================================================= + +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START") +# BENCHMARK_TE_TRANSFORMER_LAYER_START +te_transformer_layer = te_flax.TransformerLayer( + hidden_size=hidden_size, + mlp_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + mlp_activations=("gelu",), + self_attn_mask_type='causal', + layernorm_epsilon=1e-5, + use_bias=True, + attention_dropout=0.0, + intermediate_dropout=0.0, + hidden_dropout=0.0, + enable_relative_embedding=False, + self_attn_bias_type='no_bias', + dtype=jnp.bfloat16, + transpose_batch_sequence=False, +) + +with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource): + params = te_transformer_layer.init(key, x, deterministic=False) + +print("TE TransformerLayer + FP8:") +time_te_transformer_layer = speedometer( + te_transformer_layer.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, + label="te_transformer_layer" +) +# 
BENCHMARK_TE_TRANSFORMER_LAYER_END +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n") + +# Write summary CSV for RST documentation +with open("getting_started_jax_summary.csv", "w") as f: + f.write("Implementation,Time (ms),Speedup\n") + f.write(f"Baseline Flax,{time_baseline:.2f},1.00x\n") + f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n") + f.write(f"TE Unfused + TE Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n") + f.write(f"TE Unfused + TE Attention + FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n") + f.write(f"TE Fused + TE Attention + FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n") + f.write(f"TE TransformerLayer + FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n") +print("\nSummary written to getting_started_jax_summary.csv") diff --git a/docs/getting_started/getting_started_jax_summary.csv b/docs/getting_started/getting_started_jax_summary.csv new file mode 100644 index 00000000000..5c8c7f84fdc --- /dev/null +++ b/docs/getting_started/getting_started_jax_summary.csv @@ -0,0 +1,7 @@ +Implementation,Time (ms),Speedup +Baseline Flax,81.04,1.00x +TE Unfused,42.57,1.90x +TE Unfused + TE Attention,35.02,2.31x +TE Unfused + TE Attention + FP8,22.78,3.56x +TE Fused + TE Attention + FP8,24.01,3.38x +TE TransformerLayer + FP8,23.00,3.52x diff --git a/docs/getting_started/getting_started_pytorch.out b/docs/getting_started/getting_started_pytorch.out new file mode 100644 index 00000000000..41675f03ceb --- /dev/null +++ b/docs/getting_started/getting_started_pytorch.out @@ -0,0 +1,42 @@ +pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64 +pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64 +/usr/local/lib/python3.12/dist-packages/torch/library.py:357: UserWarning: Warning only once for all operators, other operators may also be overridden. + Overriding a previously registered kernel for the same operator and the same dispatch key + operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor + registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 + dispatch key: ADInplaceOrView + previous kernel: no debug info + new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) 
+ self.m.impl( +# BENCHMARK_BASELINE_OUTPUT_START +Baseline PyTorch: +Mean time: 48.507 ms +# BENCHMARK_BASELINE_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_OUTPUT_START +TE Unfused: +Mean time: 49.451 ms +# BENCHMARK_TE_UNFUSED_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START +TE Unfused + TE Attention: +Mean time: 35.776 ms +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START +TE Unfused + TE Attention + FP8: +Mean time: 23.460 ms +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_FUSED_FP8_OUTPUT_START +TE Fused + TE Attention + FP8: +Mean time: 23.037 ms +# BENCHMARK_TE_FUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START +TE TransformerLayer + FP8: +Mean time: 21.844 ms +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + +Summary written to getting_started_pytorch_summary.csv diff --git a/docs/getting_started/getting_started_pytorch.py b/docs/getting_started/getting_started_pytorch.py new file mode 100644 index 00000000000..23264cc0d78 --- /dev/null +++ b/docs/getting_started/getting_started_pytorch.py @@ -0,0 +1,435 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Getting Started with Transformer Engine - PyTorch Example +========================================================== + +This example shows how to build a Transformer layer using PyTorch +and how to optimize it with Transformer Engine. +""" + +from typing import Optional +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +from getting_started_utils_pytorch import DotProductAttention, speedometer + + +# Configuration +hidden_size = 4096 +sequence_length = 2048 +batch_size = 8 +ffn_hidden_size = 16384 +num_attention_heads = 32 +dtype = torch.bfloat16 + +# Create synthetic data +x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype) + + +# ============================================================================= +# Baseline: Pure PyTorch Implementation +# ============================================================================= + +# BASELINE_MLP_START +class PyTorchMLP(torch.nn.Module): + """Feed-forward network in Transformer layer. + Built with plain PyTorch modules. 
+ """ + hidden_size: int + ffn_hidden_size: int + + def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None: + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True) + self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = torch.nn.functional.gelu(x, approximate="tanh") + x = self.linear2(x) + return x +# BASELINE_MLP_END + + +# BASELINE_LAYER_START +class PyTorchTransformerLayer(torch.nn.Module): + """Basic Transformer layer using plain PyTorch modules.""" + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps) + self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True) + self.attention = DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + ) + self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True) + self.dropout = torch.nn.Dropout(hidden_dropout) + self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps) + self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) + + def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + res = x + x = self.ln1(x) + + # Fused QKV projection + qkv = self.qkv_projection(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout(x) + x = res + x + + # Second residual connection + res = x + x = self.ln2(x) + x = self.mlp(x) + + return x + res +# BASELINE_LAYER_END + + +print("# BENCHMARK_BASELINE_OUTPUT_START") +# BENCHMARK_BASELINE_START +baseline = PyTorchTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +).to(dtype=dtype).cuda() + +print("Baseline PyTorch:") +time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline") +# BENCHMARK_BASELINE_END +print("# BENCHMARK_BASELINE_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused: Basic TE Modules +# ============================================================================= + +# TE_UNFUSED_MLP_START +class TEUnfusedMLP(torch.nn.Module): + """MLP using TE modules.""" + hidden_size: int + ffn_hidden_size: int + + def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None: + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True) + self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = torch.nn.functional.gelu(x, approximate='tanh') + x = self.linear2(x) + return x +# TE_UNFUSED_MLP_END + + +# TE_UNFUSED_LAYER_START +class TEUnfusedTransformerLayer(torch.nn.Module): + """Transformer layer using basic TE modules.""" + + def 
__init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1 + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True) + self.attention = DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + ) + self.projection = te.Linear(hidden_size, hidden_size, bias=True) + self.dropout1 = torch.nn.Dropout(hidden_dropout) + self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) + self.dropout2 = torch.nn.Dropout(hidden_dropout) + + def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + res = x + x = self.ln1(x) + + # Fused QKV projection + qkv = self.qkv_projection(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout1(x) + x = res + x + + # Second residual connection + res = x + x = self.ln2(x) + x = self.mlp(x) + x = self.dropout2(x) + + return x + res +# TE_UNFUSED_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_START +te_unfused = TEUnfusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +).to(dtype=dtype).cuda() + +print("TE Unfused:") +time_te_unfused = speedometer(te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused") +# BENCHMARK_TE_UNFUSED_END +print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + TE Attention +# ============================================================================= + +# TE_UNFUSED_ATTN_LAYER_START +class TEUnfusedAttnTransformerLayer(torch.nn.Module): + """Transformer layer using TE modules including TE DotProductAttention.""" + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1 + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True) + self.attention = te.DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + attn_mask_type='causal', + ) + self.projection = te.Linear(hidden_size, hidden_size, bias=True) + self.dropout1 = torch.nn.Dropout(hidden_dropout) + self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) + self.dropout2 = torch.nn.Dropout(hidden_dropout) + + def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + res = x + x = self.ln1(x) + + # Fused QKV projection + qkv = self.qkv_projection(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * 
self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout1(x) + x = res + x + + # Second residual connection + res = x + x = self.ln2(x) + x = self.mlp(x) + x = self.dropout2(x) + + return x + res +# TE_UNFUSED_ATTN_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_ATTN_START +te_unfused_attn = TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +).to(dtype=dtype).cuda() + +print("TE Unfused + TE Attention:") +time_te_unfused_attn = speedometer(te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn") +# BENCHMARK_TE_UNFUSED_ATTN_END +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + FP8 +# ============================================================================= + +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_FP8_START +recipe = DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max" +) + +te_unfused_fp8 = TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +).to(dtype=dtype).cuda() + +print("TE Unfused + TE Attention + FP8:") +time_te_unfused_fp8 = speedometer( + te_unfused_fp8, + x, + forward_kwargs={"attention_mask": None}, + autocast_kwargs={"enabled": True, "recipe": recipe}, + label="te_unfused_fp8" +) +# BENCHMARK_TE_UNFUSED_FP8_END +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE Fused + FP8: Optimized Modules with FP8 +# ============================================================================= + +# TE_FUSED_LAYER_START +class TEFusedTransformerLayer(torch.nn.Module): + """Transformer layer using fused TE modules for better performance.""" + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1 + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + + # Fused LayerNorm + QKV projection + self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True) + self.attention = te.DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + attn_mask_type='causal', + ) + self.projection = te.Linear(hidden_size, hidden_size, bias=True) + self.dropout1 = torch.nn.Dropout(hidden_dropout) + + # Fused LayerNorm + MLP + self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True) + self.dropout2 = torch.nn.Dropout(hidden_dropout) + + def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + res = x + + # Fused LayerNorm + QKV projection + qkv = self.ln_qkv(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout1(x) + x = res + x + + # Fused LayerNorm + MLP + res = x + x = self.ln_mlp(x) + x = self.dropout2(x) + + return x + res +# 
TE_FUSED_LAYER_END + + +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_FUSED_FP8_START +te_fused_fp8 = TEFusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +).to(dtype=dtype).cuda() + +print("TE Fused + TE Attention + FP8:") +time_te_fused_fp8 = speedometer( + te_fused_fp8, + x, + forward_kwargs={"attention_mask": None}, + autocast_kwargs={"enabled": True, "recipe": recipe}, + label="te_fused_fp8" +) +# BENCHMARK_TE_FUSED_FP8_END +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE TransformerLayer + FP8: Ready-to-use Module +# ============================================================================= + +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START") +# BENCHMARK_TE_TRANSFORMER_LAYER_START +te_transformer_layer = te.TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + self_attn_mask_type='causal', + layernorm_epsilon=1e-5, + bias=True, + hidden_dropout=0.0, + attention_dropout=0.0, +).to(dtype=dtype).cuda() + +print("TE TransformerLayer + FP8:") +time_te_transformer_layer = speedometer( + te_transformer_layer, + x, + forward_kwargs={"attention_mask": None}, + autocast_kwargs={"enabled": True, "recipe": recipe}, + label="te_transformer_layer" +) +# BENCHMARK_TE_TRANSFORMER_LAYER_END +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n") + + +# Write summary CSV for RST documentation +with open("getting_started_pytorch_summary.csv", "w") as f: + f.write("Implementation,Time (ms),Speedup\n") + f.write(f"Baseline PyTorch,{time_baseline:.2f},1.00x\n") + f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n") + f.write(f"TE Unfused + TE Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n") + f.write(f"TE Unfused + TE Attention + FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n") + f.write(f"TE Fused + TE Attention + FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n") + f.write(f"TE TransformerLayer + FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n") +print("\nSummary written to getting_started_pytorch_summary.csv") diff --git a/docs/getting_started/getting_started_pytorch_summary.csv b/docs/getting_started/getting_started_pytorch_summary.csv new file mode 100644 index 00000000000..89d4a076ecf --- /dev/null +++ b/docs/getting_started/getting_started_pytorch_summary.csv @@ -0,0 +1,7 @@ +Implementation,Time (ms),Speedup +Baseline PyTorch,48.51,1.00x +TE Unfused,49.45,0.98x +TE Unfused + TE Attention,35.78,1.36x +TE Unfused + TE Attention + FP8,23.46,2.07x +TE Fused + TE Attention + FP8,23.04,2.11x +TE TransformerLayer + FP8,21.84,2.22x diff --git a/docs/getting_started/getting_started_utils_jax.py b/docs/getting_started/getting_started_utils_jax.py new file mode 100644 index 00000000000..edd9771b898 --- /dev/null +++ b/docs/getting_started/getting_started_utils_jax.py @@ -0,0 +1,77 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Utility functions for Getting Started with Transformer Engine - JAX +==================================================================== + +Helper classes and functions for the getting started examples. 
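+
+The main helper is `speedometer`, which times a jitted forward-and-backward pass and
+reports the mean time per iteration in milliseconds.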
+""" + +import time +from typing import Callable, Any, Optional + +import jax +import jax.numpy as jnp +from flax import linen as nn +import transformer_engine.jax as te +from transformer_engine.jax.sharding import MeshResource + + +def speedometer( + apply_fn: Callable, + params: Any, + x: jnp.ndarray, + forward_kwargs: dict = {}, + autocast_kwargs: Optional[dict] = None, + timing_iters: int = 100, + warmup_iters: int = 10, + label: str = "benchmark", +) -> float: + """Measure average forward + backward pass time for a JAX module. + + Args: + apply_fn: JIT-compiled apply function + params: Model parameters + x: Input tensor + forward_kwargs: Additional kwargs for forward pass + autocast_kwargs: Kwargs for te.autocast context + timing_iters: Number of timing iterations + warmup_iters: Number of warmup iterations + label: Optional label for logging + + Returns: + Average time per iteration in milliseconds + """ + if autocast_kwargs is None: + autocast_kwargs = {"enabled": False} + else: + autocast_kwargs = dict(autocast_kwargs) + autocast_kwargs.setdefault("mesh_resource", MeshResource()) + + def loss_fn(params, x): + y = apply_fn(params, x, **forward_kwargs) + return jnp.sum(y) + + # JIT compile within autocast context + with te.autocast(**autocast_kwargs): + grad_fn = jax.jit(jax.value_and_grad(loss_fn)) + + # Warmup runs + for _ in range(warmup_iters): + loss, grads = grad_fn(params, x) + jax.block_until_ready((loss, grads)) + + # Timing runs + times = [] + for _ in range(timing_iters): + start = time.perf_counter() + loss, grads = grad_fn(params, x) + jax.block_until_ready((loss, grads)) + times.append(time.perf_counter() - start) + + avg_time = sum(times) / len(times) * 1000 + print(f"Mean time: {avg_time:.3f} ms") + return avg_time + diff --git a/docs/getting_started/getting_started_utils_pytorch.py b/docs/getting_started/getting_started_utils_pytorch.py new file mode 100644 index 00000000000..0b2a9cd3e5b --- /dev/null +++ b/docs/getting_started/getting_started_utils_pytorch.py @@ -0,0 +1,125 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Utility functions for Getting Started with Transformer Engine - PyTorch +======================================================================== + +Helper classes and functions for the getting started examples. +""" + +import math +from typing import Optional +import torch +import transformer_engine.pytorch as te + + +def speedometer( + module: torch.nn.Module, + x: torch.Tensor, + forward_kwargs: dict = {}, + autocast_kwargs: Optional[dict] = None, + timing_iters: int = 100, + warmup_iters: int = 10, + label: str = "benchmark", +) -> float: + """Measure average forward + backward pass time for a PyTorch module. 
+ + Args: + module: PyTorch module to benchmark + x: Input tensor + forward_kwargs: Additional kwargs for forward pass + autocast_kwargs: Kwargs for te.autocast context + timing_iters: Number of timing iterations + warmup_iters: Number of warmup iterations + label: Optional label for logging + + Returns: + Average time per iteration in milliseconds + """ + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + if autocast_kwargs is None: + autocast_kwargs = {"enabled": False} + + # Warmup runs + torch.cuda.synchronize() + for _ in range(warmup_iters): + with te.autocast(**autocast_kwargs): + y = module(x, **forward_kwargs) + loss = y.sum() + loss.backward() + torch.cuda.synchronize() + + # Timing runs + start.record() + for _ in range(timing_iters): + with te.autocast(**autocast_kwargs): + y = module(x, **forward_kwargs) + loss = y.sum() + loss.backward() + end.record() + torch.cuda.synchronize() + + avg_time = start.elapsed_time(end) / timing_iters + print(f"Mean time: {avg_time:.3f} ms") + return avg_time + + +class DotProductAttention(torch.nn.Module): + """Attention operation in Transformer layer. + + Built with plain PyTorch modules. + """ + + def __init__( + self, + num_attention_heads: int, + kv_channels: int, + attention_dropout: float, + ) -> None: + super().__init__() + self.projection_size = kv_channels * num_attention_heads + self.hidden_size_per_attention_head = kv_channels + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.dropout = torch.nn.Dropout(attention_dropout) + + def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: + if mask is not None: + inp.masked_fill_(mask, -10000.0) + return torch.nn.Softmax(dim=-1)(inp) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + b = query.size(1) + np = query.size(2) + sq = query.size(0) + sk = key.size(0) + hn = value.size(3) + + query = query.view(sq, b * np, -1) + key = key.view(sk, b * np, -1) + + bmm1 = ( + torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor + ) + + attention_scores = bmm1.view(b, np, sq, sk) + attention_probs = self.masked_softmax(attention_scores, attention_mask) + attention_probs = self.dropout(attention_probs) + + value = value.view(sk, b * np, -1) + attention_probs = attention_probs.view(b * np, sq, -1) + context = torch.bmm(attention_probs, value.transpose(0, 1)) + context = context.view(b, np, sq, hn) + context = context.permute(2, 0, 1, 3).contiguous() + context = context.view(sq, b, self.projection_size) + + return context + diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst new file mode 100644 index 00000000000..e704497c29e --- /dev/null +++ b/docs/getting_started/index.rst @@ -0,0 +1,564 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Getting Started +=============== + +Overview +-------- + +Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, +providing better performance with lower memory utilization in both training and inference. +It provides support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, as well as +8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs. 
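+
+As a first taste, the snippet below is a minimal sketch of what this looks like in
+PyTorch (illustrative sizes; it assumes an FP8-capable GPU and a working Transformer
+Engine installation). The rest of this guide builds up a full Transformer layer step
+by step:
+
+.. code-block:: python
+
+    import torch
+    import transformer_engine.pytorch as te
+    from transformer_engine.common.recipe import Format, DelayedScaling
+
+    # A single TE layer used as a drop-in replacement for torch.nn.Linear.
+    layer = te.Linear(768, 768, bias=True).cuda()
+    x = torch.randn(2048, 8, 768, device="cuda")
+
+    # FP8 execution is opted into with a recipe and an autocast context.
+    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
+    with te.autocast(enabled=True, recipe=recipe):
+        y = layer(x)
+
+    # The backward pass runs outside the autocast context.
+    y.sum().backward()
+
+The ``autocast`` context only changes how supported TE modules execute; the surrounding
+training loop stays ordinary PyTorch code.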
+ +TE implements a collection of highly optimized building blocks for popular Transformer +architectures and exposes an automatic-mixed-precision-like API that can be used seamlessly +with your deep learning code. + + +Currently two frameworks are supported: PyTorch and JAX. + +.. tabs:: + + .. tab:: PyTorch + + Basic knowledge of PyTorch is recommended: + + - `PyTorch Tutorials `_ + - `PyTorch Documentation `_ + + .. tab:: JAX + + We recommend understanding the basics of JAX first: + + - `Thinking in JAX `_ + - `JAX 101 `_ + - `Key concepts in JAX `_ + - `Flax 101 `_ + + +Baseline: Pure Framework Implementation +--------------------------------------- + +Let's build a Transformer decoder layer! + +We'll create a basic GPT-style layer with causal masking, +which prevents each position from attending to future positions. This will be our baseline +for later comparisons with Transformer Engine. + +.. raw:: html + :file: transformer_layer.svg + +.. raw:: html + +

+   <p style="text-align: center; font-style: italic;">Structure of a GPT decoder layer</p>

+ +We construct the components as follows: + +.. tabs:: + + .. tab:: PyTorch + + * **LayerNorm**: ``torch.nn.LayerNorm`` + * **QKV Projection**: ``torch.nn.Linear`` (fused Q, K, V into single layer 3x larger) + * **DotProductAttention**: Custom implementation using ``torch.bmm`` + * **Projection**: ``torch.nn.Linear`` + * **Dropout**: ``torch.nn.Dropout`` + * **MLP**: Two ``torch.nn.Linear`` layers with ``torch.nn.functional.gelu`` activation + + .. tab:: JAX + + * **LayerNorm**: ``nn.LayerNorm`` + * **QKV Projection**: ``nn.Dense`` (fused Q, K, V into single layer 3x larger) + * **DotProductAttention**: ``nn.dot_product_attention`` + * **Projection**: ``nn.Dense`` + * **Dropout**: ``nn.Dropout`` + * **MLP**: Two ``nn.Dense`` layers with ``nn.gelu`` activation + +Putting it all together: + +.. tabs:: + + .. tab:: PyTorch + First, define the MLP block: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BASELINE_MLP_START + :end-before: # BASELINE_MLP_END + + Now, putting it all together into a GPT decoder layer: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BASELINE_LAYER_START + :end-before: # BASELINE_LAYER_END + + Benchmark the baseline implementation: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_BASELINE_START + :end-before: # BENCHMARK_BASELINE_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_BASELINE_OUTPUT_START + :end-before: # BENCHMARK_BASELINE_OUTPUT_END + + .. tab:: JAX + + First, define the MLP block: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BASELINE_MLP_START + :end-before: # BASELINE_MLP_END + + Now, putting it all together into a GPT decoder layer: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BASELINE_LAYER_START + :end-before: # BASELINE_LAYER_END + + Benchmark the baseline implementation: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_BASELINE_START + :end-before: # BENCHMARK_BASELINE_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_BASELINE_OUTPUT_START + :end-before: # BENCHMARK_BASELINE_OUTPUT_END + + +TE Unfused: Basic TE Modules +---------------------------- + +Now let's replace the standard framework modules with TE equivalents. +This is the simplest way to start using Transformer Engine. + +.. tabs:: + + .. tab:: PyTorch + + Replace PyTorch modules with TE equivalents: + + .. code-block:: python + + import transformer_engine.pytorch as te + + Mapping: + + * ``torch.nn.Linear`` → ``te.Linear`` + * ``torch.nn.LayerNorm`` → ``te.LayerNorm`` + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_UNFUSED_MLP_START + :end-before: # TE_UNFUSED_MLP_END + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_UNFUSED_LAYER_START + :end-before: # TE_UNFUSED_LAYER_END + + Benchmark the TE unfused implementation: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_START + :end-before: # BENCHMARK_TE_UNFUSED_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END + + .. tab:: JAX + + Replace Flax modules with TE equivalents: + + .. code-block:: python + + import transformer_engine.jax as te + import transformer_engine.jax.flax as te_flax + + Mapping: + + * ``nn.Dense`` → ``te_flax.DenseGeneral`` + * ``nn.LayerNorm`` → ``te_flax.LayerNorm`` + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_UNFUSED_MLP_START + :end-before: # TE_UNFUSED_MLP_END + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_UNFUSED_LAYER_START + :end-before: # TE_UNFUSED_LAYER_END + + Benchmark the TE unfused implementation: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_START + :end-before: # BENCHMARK_TE_UNFUSED_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END + + +TE Unfused + TE Attention +------------------------- + +Now let's also replace the attention mechanism with TE's optimized ``DotProductAttention``. +This provides Flash Attention and other optimizations. + +.. tabs:: + + .. tab:: PyTorch + + Replace the custom attention with TE's optimized implementation: + + * Custom ``DotProductAttention`` → ``te.DotProductAttention`` + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_UNFUSED_ATTN_LAYER_START + :end-before: # TE_UNFUSED_ATTN_LAYER_END + + Benchmark TE Unfused with TE Attention: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + + .. tab:: JAX + + Replace Flax's attention with TE's optimized implementation: + + * ``nn.dot_product_attention`` → ``TEDotProductAttention`` + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_UNFUSED_ATTN_LAYER_START + :end-before: # TE_UNFUSED_ATTN_LAYER_END + + Benchmark TE Unfused with TE Attention: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + + +TE Unfused + TE Attention + FP8 +------------------------------- + +Now let's combine TE modules with TE Attention and enable FP8 precision. +Wrap your code within an ``autocast`` context manager to enable FP8. +This provides significant speedups on supported hardware (Hopper, Ada, Blackwell GPUs). + +.. tabs:: + + .. tab:: PyTorch + + .. code-block:: python + + from transformer_engine.common.recipe import Format, DelayedScaling + + recipe = DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max" + ) + + with te.autocast(enabled=True, recipe=recipe): + y = te_unfused(x, attention_mask=None) + + .. note:: + + The ``autocast`` should only wrap the forward pass and must exit before + starting a backward pass. + + Benchmark TE Unfused with FP8: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_FP8_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + + .. tab:: JAX + + .. code-block:: python + + from transformer_engine.common.recipe import Format, DelayedScaling + + recipe = DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max" + ) + + with te.autocast(enabled=True, recipe=recipe): + params = te_unfused.init(key, x, deterministic=False) + y = te_unfused.apply(params, x, deterministic=True) + + .. important:: + + When using FP8 in JAX, the model **must be initialized within the autocast context** + to create the ``fp8_metas`` collection. + + Benchmark TE Unfused with FP8: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_FP8_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + + +TE Fused + TE Attention + FP8: Optimized Modules +------------------------------------------------ + +Fused modules use kernel fusion to combine multiple operations. +While speedups are modest on a single GPU, they scale better in multi-GPU setups. +Combined with TE Attention and FP8, this delivers peak performance. + +.. tabs:: + + .. tab:: PyTorch + + Fused modules available: + + * ``te.LayerNormLinear`` - fuses LayerNorm + Linear + * ``te.LayerNormMLP`` - fuses LayerNorm + MLP + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_FUSED_LAYER_START + :end-before: # TE_FUSED_LAYER_END + + Benchmark TE Fused with FP8: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_FUSED_FP8_START + :end-before: # BENCHMARK_TE_FUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END + + .. tab:: JAX + + Fused modules available: + + * ``te_flax.LayerNormDenseGeneral`` - fuses LayerNorm + Dense + * ``te_flax.LayerNormMLP`` - fuses LayerNorm + MLP + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_FUSED_LAYER_START + :end-before: # TE_FUSED_LAYER_END + + Benchmark TE Fused with FP8: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_FUSED_FP8_START + :end-before: # BENCHMARK_TE_FUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END + + +TE TransformerLayer + FP8: Ready-to-use Module +---------------------------------------------- + +For the simplest integration, Transformer Engine provides a ready-to-use ``TransformerLayer`` +module that includes all optimizations out of the box. + +.. tabs:: + + .. tab:: PyTorch + + Just use ``te.TransformerLayer`` - it handles everything for you: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + .. tab:: JAX + + Just use ``te_flax.TransformerLayer`` - it handles everything for you: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + +Benchmark Summary +----------------- + +The table below summarizes the performance improvements achieved with Transformer Engine +on an NVIDIA H100 GPU. Results may vary depending on hardware and configuration. While this +tutorial focuses on a simple single-GPU scenario, features like fused layers can provide +additional benefits in more complex setups such as multi-GPU training. + +.. tabs:: + + .. tab:: PyTorch + + .. csv-table:: + :header-rows: 1 + :widths: 40, 20, 20 + :file: getting_started_pytorch_summary.csv + + .. tab:: JAX + + .. csv-table:: + :header-rows: 1 + :widths: 40, 20, 20 + :file: getting_started_jax_summary.csv diff --git a/docs/getting_started/transformer_layer.svg b/docs/getting_started/transformer_layer.svg new file mode 100644 index 00000000000..28ba3dd3869 --- /dev/null +++ b/docs/getting_started/transformer_layer.svg @@ -0,0 +1,82 @@ + + + + + + + + + + + + + + + + + + LayerNorm + + + + + + QKV Projection + + + + + + Dot Product + Attention + + + + + + Projection + + + + + + Dropout + + + + + + + + + + + + + + + + + LayerNorm + + + + + + MLP + + + + + + + + + + + + diff --git a/docs/index.rst b/docs/index.rst index 7a3ab9f6fd5..d9f7b03859e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,7 @@ Transformer Engine documentation :caption: Getting Started installation - getting_started + getting_started/index faq .. toctree:: From 10fee292daa476da446a128f0d4285377375e64d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 00:15:39 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/getting_started/getting_started_jax.py | 110 +++++++---- .../getting_started_pytorch.py | 180 ++++++++++++------ .../getting_started_utils_jax.py | 1 - .../getting_started_utils_pytorch.py | 1 - 4 files changed, 197 insertions(+), 95 deletions(-) diff --git a/docs/getting_started/getting_started_jax.py b/docs/getting_started/getting_started_jax.py index dc905737d4f..a2ce3c0ec7b 100644 --- a/docs/getting_started/getting_started_jax.py +++ b/docs/getting_started/getting_started_jax.py @@ -41,11 +41,13 @@ # Baseline: Pure Flax Implementation # ============================================================================= + # BASELINE_MLP_START class FlaxMLP(nn.Module): """Feed-forward network in Transformer layer. Built with plain Flax modules. 
""" + hidden_size: int ffn_hidden_size: int @@ -55,12 +57,15 @@ def __call__(self, x: jnp.ndarray) -> jnp.ndarray: x = nn.gelu(x, approximate=True) x = nn.Dense(features=self.hidden_size, use_bias=True)(x) return x + + # BASELINE_MLP_END # BASELINE_LAYER_START class FlaxTransformerLayer(nn.Module): """Basic Transformer layer using plain Flax modules.""" + hidden_size: int ffn_hidden_size: int num_attention_heads: int @@ -75,7 +80,7 @@ def __call__( self, x: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = False + deterministic: bool = False, ) -> jnp.ndarray: if attention_mask is None: attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_) @@ -85,12 +90,14 @@ def __call__( # Fused QKV projection qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x) - qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels) + qkv = qkv.reshape( + qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels + ) q, k, v = jnp.split(qkv, 3, axis=3) dropout_rng = None if not deterministic and self.attention_dropout > 0: - dropout_rng = self.make_rng('dropout') + dropout_rng = self.make_rng("dropout") x = nn.dot_product_attention( query=q, @@ -114,6 +121,8 @@ def __call__( x = mlp(x) return x + res + + # BASELINE_LAYER_END @@ -127,7 +136,9 @@ def __call__( params = baseline.init(key, x, deterministic=False) print("Baseline Flax:") -time_baseline = speedometer(baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline") +time_baseline = speedometer( + baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline" +) # BENCHMARK_BASELINE_END print("# BENCHMARK_BASELINE_OUTPUT_END\n") @@ -136,9 +147,11 @@ def __call__( # TE Unfused: Basic TE Modules # ============================================================================= + # TE_UNFUSED_MLP_START class TEUnfusedMLP(nn.Module): """MLP using TE modules.""" + hidden_size: int ffn_hidden_size: int @@ -146,15 +159,18 @@ class TEUnfusedMLP(nn.Module): def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray: x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x) x = x.reshape(*x.shape[:-1], 1, x.shape[-1]) - x = te.activation.activation(x, activation_type=('gelu',)) + x = te.activation.activation(x, activation_type=("gelu",)) x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x) return x + + # TE_UNFUSED_MLP_END # TE_UNFUSED_LAYER_START class TEUnfusedTransformerLayer(nn.Module): """Transformer layer using basic TE modules (without TE attention).""" + hidden_size: int ffn_hidden_size: int num_attention_heads: int @@ -169,7 +185,7 @@ def __call__( self, x: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = False + deterministic: bool = False, ) -> jnp.ndarray: if attention_mask is None: attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_) @@ -178,12 +194,14 @@ def __call__( x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x) qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x) - qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels) + qkv = qkv.reshape( + qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels + ) q, k, v = jnp.split(qkv, 3, axis=3) dropout_rng = None if not deterministic and self.attention_dropout > 0: - dropout_rng = self.make_rng('dropout') + dropout_rng = self.make_rng("dropout") x = nn.dot_product_attention( query=q, @@ 
-209,6 +227,8 @@ def __call__( x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) return x + res + + # TE_UNFUSED_LAYER_END @@ -222,7 +242,9 @@ def __call__( params = te_unfused.init(key, x, deterministic=False) print("TE Unfused:") -time_te_unfused = speedometer(te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused") +time_te_unfused = speedometer( + te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused" +) # BENCHMARK_TE_UNFUSED_END print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n") @@ -231,9 +253,11 @@ def __call__( # TE Unfused + TE Attention # ============================================================================= + # TE_UNFUSED_ATTN_LAYER_START class TEUnfusedAttnTransformerLayer(nn.Module): """Transformer layer using TE modules including TE DotProductAttention.""" + hidden_size: int ffn_hidden_size: int num_attention_heads: int @@ -248,13 +272,17 @@ def __call__( self, x: jnp.ndarray, attention_mask: Optional[jnp.ndarray] = None, - deterministic: bool = False + deterministic: bool = False, ) -> jnp.ndarray: res = x x = te_flax.LayerNorm(epsilon=self.layernorm_eps, dtype=jnp.bfloat16)(x) - qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x) - qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels) + qkv = te_flax.DenseGeneral( + features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16 + )(x) + qkv = qkv.reshape( + qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels + ) q, k, v = jnp.split(qkv, 3, axis=3) attention = te_flax.DotProductAttention( @@ -262,7 +290,7 @@ def __call__( num_attention_heads=self.num_attention_heads, num_gqa_groups=self.num_attention_heads, attention_dropout=self.attention_dropout, - attn_mask_type='causal', + attn_mask_type="causal", transpose_batch_sequence=False, ) x = attention(q, k, v, deterministic=deterministic) @@ -279,6 +307,8 @@ def __call__( x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) return x + res + + # TE_UNFUSED_ATTN_LAYER_END @@ -300,7 +330,7 @@ def __call__( x, forward_kwargs={"deterministic": True}, autocast_kwargs={"enabled": False, "mesh_resource": mesh_resource}, - label="te_unfused_attn" + label="te_unfused_attn", ) # BENCHMARK_TE_UNFUSED_ATTN_END print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n") @@ -312,11 +342,7 @@ def __call__( print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START") # BENCHMARK_TE_UNFUSED_FP8_START -recipe = DelayedScaling( - fp8_format=Format.HYBRID, - amax_history_len=16, - amax_compute_algo="max" -) +recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max") te_unfused_fp8 = TEUnfusedAttnTransformerLayer( hidden_size=hidden_size, @@ -334,7 +360,7 @@ def __call__( x, forward_kwargs={"deterministic": True}, autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, - label="te_unfused_fp8" + label="te_unfused_fp8", ) # BENCHMARK_TE_UNFUSED_FP8_END print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n") @@ -344,9 +370,11 @@ def __call__( # TE Fused + FP8: Optimized Modules with FP8 # ============================================================================= + # TE_FUSED_LAYER_START class TEFusedTransformerLayer(nn.Module): """Transformer layer using fused TE modules for better performance.""" + hidden_size: int ffn_hidden_size: int num_attention_heads: int @@ -361,7 +389,7 @@ def __call__( self, x: jnp.ndarray, attention_mask: 
Optional[jnp.ndarray] = None, - deterministic: bool = False + deterministic: bool = False, ) -> jnp.ndarray: res = x @@ -370,7 +398,7 @@ def __call__( features=3 * self.hidden_size, epsilon=self.layernorm_eps, use_bias=True, - return_layernorm_output=False + return_layernorm_output=False, )(x) qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.num_attention_heads, self.kv_channels) q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :] @@ -380,8 +408,8 @@ def __call__( num_attention_heads=self.num_attention_heads, num_gqa_groups=self.num_attention_heads, attention_dropout=self.attention_dropout, - attn_mask_type='causal', - qkv_layout='bshd_bshd_bshd', + attn_mask_type="causal", + qkv_layout="bshd_bshd_bshd", transpose_batch_sequence=False, ) x = attention(q, k, v, deterministic=deterministic) @@ -397,13 +425,15 @@ def __call__( intermediate_dim=self.ffn_hidden_size, epsilon=self.layernorm_eps, use_bias=True, - activations=('gelu',), + activations=("gelu",), intermediate_dropout_rate=0.0, - return_layernorm_output=False + return_layernorm_output=False, )(x, deterministic=deterministic) x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) return x + res + + # TE_FUSED_LAYER_END @@ -425,7 +455,7 @@ def __call__( x, forward_kwargs={"deterministic": True}, autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, - label="te_fused_fp8" + label="te_fused_fp8", ) # BENCHMARK_TE_FUSED_FP8_END print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n") @@ -442,14 +472,14 @@ def __call__( mlp_hidden_size=ffn_hidden_size, num_attention_heads=num_attention_heads, mlp_activations=("gelu",), - self_attn_mask_type='causal', + self_attn_mask_type="causal", layernorm_epsilon=1e-5, use_bias=True, attention_dropout=0.0, intermediate_dropout=0.0, hidden_dropout=0.0, enable_relative_embedding=False, - self_attn_bias_type='no_bias', + self_attn_bias_type="no_bias", dtype=jnp.bfloat16, transpose_batch_sequence=False, ) @@ -464,7 +494,7 @@ def __call__( x, forward_kwargs={"deterministic": True}, autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, - label="te_transformer_layer" + label="te_transformer_layer", ) # BENCHMARK_TE_TRANSFORMER_LAYER_END print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n") @@ -474,8 +504,20 @@ def __call__( f.write("Implementation,Time (ms),Speedup\n") f.write(f"Baseline Flax,{time_baseline:.2f},1.00x\n") f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n") - f.write(f"TE Unfused + TE Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n") - f.write(f"TE Unfused + TE Attention + FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n") - f.write(f"TE Fused + TE Attention + FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n") - f.write(f"TE TransformerLayer + FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n") + f.write( + "TE Unfused + TE" + f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n" + ) + f.write( + "TE Unfused + TE Attention +" + f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n" + ) + f.write( + "TE Fused + TE Attention +" + f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n" + ) + f.write( + "TE TransformerLayer +" + f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n" + ) print("\nSummary written to getting_started_jax_summary.csv") diff --git 
a/docs/getting_started/getting_started_pytorch.py b/docs/getting_started/getting_started_pytorch.py index 23264cc0d78..eed21ee0aad 100644 --- a/docs/getting_started/getting_started_pytorch.py +++ b/docs/getting_started/getting_started_pytorch.py @@ -34,11 +34,13 @@ # Baseline: Pure PyTorch Implementation # ============================================================================= + # BASELINE_MLP_START class PyTorchMLP(torch.nn.Module): """Feed-forward network in Transformer layer. Built with plain PyTorch modules. """ + hidden_size: int ffn_hidden_size: int @@ -54,6 +56,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: x = torch.nn.functional.gelu(x, approximate="tanh") x = self.linear2(x) return x + + # BASELINE_MLP_END @@ -85,7 +89,9 @@ def __init__( self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps) self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) - def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: res = x x = self.ln1(x) @@ -105,16 +111,22 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None x = self.mlp(x) return x + res + + # BASELINE_LAYER_END print("# BENCHMARK_BASELINE_OUTPUT_START") # BENCHMARK_BASELINE_START -baseline = PyTorchTransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - num_attention_heads=num_attention_heads, -).to(dtype=dtype).cuda() +baseline = ( + PyTorchTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) print("Baseline PyTorch:") time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline") @@ -126,9 +138,11 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None # TE Unfused: Basic TE Modules # ============================================================================= + # TE_UNFUSED_MLP_START class TEUnfusedMLP(torch.nn.Module): """MLP using TE modules.""" + hidden_size: int ffn_hidden_size: int @@ -141,9 +155,11 @@ def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None: def forward(self, x: torch.Tensor) -> torch.Tensor: x = self.linear1(x) - x = torch.nn.functional.gelu(x, approximate='tanh') + x = torch.nn.functional.gelu(x, approximate="tanh") x = self.linear2(x) return x + + # TE_UNFUSED_MLP_END @@ -158,7 +174,7 @@ def __init__( num_attention_heads: int, layernorm_eps: float = 1e-5, attention_dropout: float = 0.1, - hidden_dropout: float = 0.1 + hidden_dropout: float = 0.1, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -176,7 +192,9 @@ def __init__( self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) self.dropout2 = torch.nn.Dropout(hidden_dropout) - def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: res = x x = self.ln1(x) @@ -197,19 +215,27 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None x = self.dropout2(x) return x + res + + # TE_UNFUSED_LAYER_END print("# BENCHMARK_TE_UNFUSED_OUTPUT_START") # BENCHMARK_TE_UNFUSED_START -te_unfused = TEUnfusedTransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - num_attention_heads=num_attention_heads, 
-).to(dtype=dtype).cuda() +te_unfused = ( + TEUnfusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) print("TE Unfused:") -time_te_unfused = speedometer(te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused") +time_te_unfused = speedometer( + te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused" +) # BENCHMARK_TE_UNFUSED_END print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n") @@ -218,6 +244,7 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None # TE Unfused + TE Attention # ============================================================================= + # TE_UNFUSED_ATTN_LAYER_START class TEUnfusedAttnTransformerLayer(torch.nn.Module): """Transformer layer using TE modules including TE DotProductAttention.""" @@ -229,7 +256,7 @@ def __init__( num_attention_heads: int, layernorm_eps: float = 1e-5, attention_dropout: float = 0.1, - hidden_dropout: float = 0.1 + hidden_dropout: float = 0.1, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -240,7 +267,7 @@ def __init__( num_attention_heads=num_attention_heads, kv_channels=self.kv_channels, attention_dropout=attention_dropout, - attn_mask_type='causal', + attn_mask_type="causal", ) self.projection = te.Linear(hidden_size, hidden_size, bias=True) self.dropout1 = torch.nn.Dropout(hidden_dropout) @@ -248,7 +275,9 @@ def __init__( self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) self.dropout2 = torch.nn.Dropout(hidden_dropout) - def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: res = x x = self.ln1(x) @@ -269,19 +298,27 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None x = self.dropout2(x) return x + res + + # TE_UNFUSED_ATTN_LAYER_END print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START") # BENCHMARK_TE_UNFUSED_ATTN_START -te_unfused_attn = TEUnfusedAttnTransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - num_attention_heads=num_attention_heads, -).to(dtype=dtype).cuda() +te_unfused_attn = ( + TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) print("TE Unfused + TE Attention:") -time_te_unfused_attn = speedometer(te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn") +time_te_unfused_attn = speedometer( + te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn" +) # BENCHMARK_TE_UNFUSED_ATTN_END print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n") @@ -292,25 +329,25 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START") # BENCHMARK_TE_UNFUSED_FP8_START -recipe = DelayedScaling( - fp8_format=Format.HYBRID, - amax_history_len=16, - amax_compute_algo="max" +recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max") + +te_unfused_fp8 = ( + TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() ) -te_unfused_fp8 = TEUnfusedAttnTransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - 
num_attention_heads=num_attention_heads, -).to(dtype=dtype).cuda() - print("TE Unfused + TE Attention + FP8:") time_te_unfused_fp8 = speedometer( te_unfused_fp8, x, forward_kwargs={"attention_mask": None}, autocast_kwargs={"enabled": True, "recipe": recipe}, - label="te_unfused_fp8" + label="te_unfused_fp8", ) # BENCHMARK_TE_UNFUSED_FP8_END print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n") @@ -320,6 +357,7 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None # TE Fused + FP8: Optimized Modules with FP8 # ============================================================================= + # TE_FUSED_LAYER_START class TEFusedTransformerLayer(torch.nn.Module): """Transformer layer using fused TE modules for better performance.""" @@ -331,7 +369,7 @@ def __init__( num_attention_heads: int, layernorm_eps: float = 1e-5, attention_dropout: float = 0.1, - hidden_dropout: float = 0.1 + hidden_dropout: float = 0.1, ): super().__init__() self.num_attention_heads = num_attention_heads @@ -343,7 +381,7 @@ def __init__( num_attention_heads=num_attention_heads, kv_channels=self.kv_channels, attention_dropout=attention_dropout, - attn_mask_type='causal', + attn_mask_type="causal", ) self.projection = te.Linear(hidden_size, hidden_size, bias=True) self.dropout1 = torch.nn.Dropout(hidden_dropout) @@ -352,7 +390,9 @@ def __init__( self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True) self.dropout2 = torch.nn.Dropout(hidden_dropout) - def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor: + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: res = x # Fused LayerNorm + QKV projection @@ -371,16 +411,22 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None x = self.dropout2(x) return x + res + + # TE_FUSED_LAYER_END print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START") # BENCHMARK_TE_FUSED_FP8_START -te_fused_fp8 = TEFusedTransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - num_attention_heads=num_attention_heads, -).to(dtype=dtype).cuda() +te_fused_fp8 = ( + TEFusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) print("TE Fused + TE Attention + FP8:") time_te_fused_fp8 = speedometer( @@ -388,7 +434,7 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None x, forward_kwargs={"attention_mask": None}, autocast_kwargs={"enabled": True, "recipe": recipe}, - label="te_fused_fp8" + label="te_fused_fp8", ) # BENCHMARK_TE_FUSED_FP8_END print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n") @@ -400,16 +446,20 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START") # BENCHMARK_TE_TRANSFORMER_LAYER_START -te_transformer_layer = te.TransformerLayer( - hidden_size=hidden_size, - ffn_hidden_size=ffn_hidden_size, - num_attention_heads=num_attention_heads, - self_attn_mask_type='causal', - layernorm_epsilon=1e-5, - bias=True, - hidden_dropout=0.0, - attention_dropout=0.0, -).to(dtype=dtype).cuda() +te_transformer_layer = ( + te.TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + self_attn_mask_type="causal", + layernorm_epsilon=1e-5, + bias=True, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + .to(dtype=dtype) + .cuda() +) print("TE 
TransformerLayer + FP8:") time_te_transformer_layer = speedometer( @@ -417,7 +467,7 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None x, forward_kwargs={"attention_mask": None}, autocast_kwargs={"enabled": True, "recipe": recipe}, - label="te_transformer_layer" + label="te_transformer_layer", ) # BENCHMARK_TE_TRANSFORMER_LAYER_END print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n") @@ -428,8 +478,20 @@ def forward(self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None f.write("Implementation,Time (ms),Speedup\n") f.write(f"Baseline PyTorch,{time_baseline:.2f},1.00x\n") f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n") - f.write(f"TE Unfused + TE Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n") - f.write(f"TE Unfused + TE Attention + FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n") - f.write(f"TE Fused + TE Attention + FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n") - f.write(f"TE TransformerLayer + FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n") + f.write( + "TE Unfused + TE" + f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n" + ) + f.write( + "TE Unfused + TE Attention +" + f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n" + ) + f.write( + "TE Fused + TE Attention +" + f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n" + ) + f.write( + "TE TransformerLayer +" + f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n" + ) print("\nSummary written to getting_started_pytorch_summary.csv") diff --git a/docs/getting_started/getting_started_utils_jax.py b/docs/getting_started/getting_started_utils_jax.py index edd9771b898..6184b6565b8 100644 --- a/docs/getting_started/getting_started_utils_jax.py +++ b/docs/getting_started/getting_started_utils_jax.py @@ -74,4 +74,3 @@ def loss_fn(params, x): avg_time = sum(times) / len(times) * 1000 print(f"Mean time: {avg_time:.3f} ms") return avg_time - diff --git a/docs/getting_started/getting_started_utils_pytorch.py b/docs/getting_started/getting_started_utils_pytorch.py index 0b2a9cd3e5b..307d3d13b46 100644 --- a/docs/getting_started/getting_started_utils_pytorch.py +++ b/docs/getting_started/getting_started_utils_pytorch.py @@ -122,4 +122,3 @@ def forward( context = context.view(sq, b, self.projection_size) return context - From e29060d226fd8bee2c1ba552d0e31107b244cb92 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 09:23:05 +0100 Subject: [PATCH 3/5] fix Signed-off-by: Pawel Gadzinski --- docs/examples/quickstart_jax_utils.py | 4 ---- docs/examples/quickstart_utils.py | 1 + docs/examples/te_jax_integration.ipynb | 2 +- docs/getting_started/index.rst | 8 +++++--- 4 files changed, 7 insertions(+), 8 deletions(-) diff --git a/docs/examples/quickstart_jax_utils.py b/docs/examples/quickstart_jax_utils.py index f1ff9b7d995..9a145a511b1 100644 --- a/docs/examples/quickstart_jax_utils.py +++ b/docs/examples/quickstart_jax_utils.py @@ -5,13 +5,9 @@ import jax import jax.numpy as jnp import time -import math from typing import Callable, Any, Dict, Optional, Tuple -from flax import linen as nn import transformer_engine.jax as te -import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention def speedometer( diff --git 
a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py index 473fce7fe74..eaff56ebff3 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -213,3 +213,4 @@ def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"): ret = quantizer(inp) ret = ret.dequantize() return ret + diff --git a/docs/examples/te_jax_integration.ipynb b/docs/examples/te_jax_integration.ipynb index 70647e421a0..66d16ed52f4 100644 --- a/docs/examples/te_jax_integration.ipynb +++ b/docs/examples/te_jax_integration.ipynb @@ -264,7 +264,7 @@ "id": "5e9310c9", "metadata": {}, "source": [ - "# Transformer Engine" + "## Transformer Engine" ] }, { diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst index e704497c29e..30dba93c6e1 100644 --- a/docs/getting_started/index.rst +++ b/docs/getting_started/index.rst @@ -43,7 +43,7 @@ Currently two frameworks are supported: PyTorch and JAX. Baseline: Pure Framework Implementation --------------------------------------- -Let's build a Transformer decoder layer! +Let's build a Transformer decoder layer! We'll create a basic GPT-style layer with causal masking, which prevents each position from attending to future positions. This will be our baseline @@ -83,6 +83,7 @@ Putting it all together: .. tabs:: .. tab:: PyTorch + First, define the MLP block: .. literalinclude:: getting_started_pytorch.py @@ -254,7 +255,8 @@ TE Unfused + TE Attention ------------------------- Now let's also replace the attention mechanism with TE's optimized ``DotProductAttention``. -This provides Flash Attention and other optimizations. +TE's attention automatically selects the best available backend — for example, FlashAttention or cuDNN fused attention — based on your hardware and input configuration, +delivering optimal performance without manual tuning. .. tabs:: @@ -293,7 +295,7 @@ This provides Flash Attention and other optimizations. Replace Flax's attention with TE's optimized implementation: - * ``nn.dot_product_attention`` → ``TEDotProductAttention`` + * ``nn.dot_product_attention`` → ``te_flax.DotProductAttention`` .. 
literalinclude:: getting_started_jax.py :language: python From 98dfe0c2c96b7b6140686affd5ff59d8cde6ebac Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 19 Dec 2025 08:24:06 +0000 Subject: [PATCH 4/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- docs/examples/quickstart_utils.py | 1 - 1 file changed, 1 deletion(-) diff --git a/docs/examples/quickstart_utils.py b/docs/examples/quickstart_utils.py index eaff56ebff3..473fce7fe74 100644 --- a/docs/examples/quickstart_utils.py +++ b/docs/examples/quickstart_utils.py @@ -213,4 +213,3 @@ def cast_to_representable(inp, scale=1.0, fp8_format="e4m3"): ret = quantizer(inp) ret = ret.dequantize() return ret - From 04e03e9999595345a05360645af35c543294d0d7 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 19 Dec 2025 00:31:26 -0800 Subject: [PATCH 5/5] fix Signed-off-by: Pawel Gadzinski --- docs/getting_started/getting_started_jax.out | 24 +++++-------------- .../getting_started_jax_summary.csv | 12 +++++----- .../getting_started_pytorch.out | 12 +++++----- .../getting_started_pytorch_summary.csv | 12 +++++----- 4 files changed, 24 insertions(+), 36 deletions(-) diff --git a/docs/getting_started/getting_started_jax.out b/docs/getting_started/getting_started_jax.out index 3455cbfdaff..c11f3b1965f 100644 --- a/docs/getting_started/getting_started_jax.out +++ b/docs/getting_started/getting_started_jax.out @@ -1,45 +1,33 @@ pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel -/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 - warnings.warn( -/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 - warnings.warn( -/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 - warnings.warn( -/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 - warnings.warn( -/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 - warnings.warn( -/usr/local/lib/python3.12/dist-packages/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10 - warnings.warn( # BENCHMARK_BASELINE_OUTPUT_START Baseline Flax: -Mean time: 81.035 ms +Mean time: 86.580 ms # BENCHMARK_BASELINE_OUTPUT_END # BENCHMARK_TE_UNFUSED_OUTPUT_START TE Unfused: -Mean time: 42.570 ms +Mean time: 42.252 ms # BENCHMARK_TE_UNFUSED_OUTPUT_END # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START TE Unfused + TE Attention: -Mean time: 35.017 ms +Mean time: 35.054 ms # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START TE Unfused + TE Attention + FP8: -Mean time: 22.778 ms +Mean 
time: 22.638 ms # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END # BENCHMARK_TE_FUSED_FP8_OUTPUT_START TE Fused + TE Attention + FP8: -Mean time: 24.007 ms +Mean time: 23.703 ms # BENCHMARK_TE_FUSED_FP8_OUTPUT_END # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START TE TransformerLayer + FP8: -Mean time: 23.004 ms +Mean time: 22.812 ms # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END diff --git a/docs/getting_started/getting_started_jax_summary.csv b/docs/getting_started/getting_started_jax_summary.csv index 5c8c7f84fdc..5b6a4249b3d 100644 --- a/docs/getting_started/getting_started_jax_summary.csv +++ b/docs/getting_started/getting_started_jax_summary.csv @@ -1,7 +1,7 @@ Implementation,Time (ms),Speedup -Baseline Flax,81.04,1.00x -TE Unfused,42.57,1.90x -TE Unfused + TE Attention,35.02,2.31x -TE Unfused + TE Attention + FP8,22.78,3.56x -TE Fused + TE Attention + FP8,24.01,3.38x -TE TransformerLayer + FP8,23.00,3.52x +Baseline Flax,86.58,1.00x +TE Unfused,42.25,2.05x +TE Unfused + TE Attention,35.05,2.47x +TE Unfused + TE Attention + FP8,22.64,3.82x +TE Fused + TE Attention + FP8,23.70,3.65x +TE TransformerLayer + FP8,22.81,3.80x diff --git a/docs/getting_started/getting_started_pytorch.out b/docs/getting_started/getting_started_pytorch.out index 41675f03ceb..9b9387a8b22 100644 --- a/docs/getting_started/getting_started_pytorch.out +++ b/docs/getting_started/getting_started_pytorch.out @@ -10,32 +10,32 @@ pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/tran self.m.impl( # BENCHMARK_BASELINE_OUTPUT_START Baseline PyTorch: -Mean time: 48.507 ms +Mean time: 48.280 ms # BENCHMARK_BASELINE_OUTPUT_END # BENCHMARK_TE_UNFUSED_OUTPUT_START TE Unfused: -Mean time: 49.451 ms +Mean time: 49.342 ms # BENCHMARK_TE_UNFUSED_OUTPUT_END # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START TE Unfused + TE Attention: -Mean time: 35.776 ms +Mean time: 35.709 ms # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START TE Unfused + TE Attention + FP8: -Mean time: 23.460 ms +Mean time: 23.406 ms # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END # BENCHMARK_TE_FUSED_FP8_OUTPUT_START TE Fused + TE Attention + FP8: -Mean time: 23.037 ms +Mean time: 22.964 ms # BENCHMARK_TE_FUSED_FP8_OUTPUT_END # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START TE TransformerLayer + FP8: -Mean time: 21.844 ms +Mean time: 21.670 ms # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END diff --git a/docs/getting_started/getting_started_pytorch_summary.csv b/docs/getting_started/getting_started_pytorch_summary.csv index 89d4a076ecf..b3a5d7330ed 100644 --- a/docs/getting_started/getting_started_pytorch_summary.csv +++ b/docs/getting_started/getting_started_pytorch_summary.csv @@ -1,7 +1,7 @@ Implementation,Time (ms),Speedup -Baseline PyTorch,48.51,1.00x -TE Unfused,49.45,0.98x -TE Unfused + TE Attention,35.78,1.36x -TE Unfused + TE Attention + FP8,23.46,2.07x -TE Fused + TE Attention + FP8,23.04,2.11x -TE TransformerLayer + FP8,21.84,2.22x +Baseline PyTorch,48.28,1.00x +TE Unfused,49.34,0.98x +TE Unfused + TE Attention,35.71,1.35x +TE Unfused + TE Attention + FP8,23.41,2.06x +TE Fused + TE Attention + FP8,22.96,2.10x +TE TransformerLayer + FP8,21.67,2.23x
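
For readers who want to try the FP8 path above outside the benchmark harness, the following is a minimal standalone sketch of the pattern the PyTorch script exercises. It assumes transformer_engine.pytorch is installed and a GPU with FP8 support (e.g. Hopper); the layer sizes, sequence length, and batch size are illustrative placeholders rather than the values used in the scripts, and the speedometer helper presumably applies te.fp8_autocast with the same recipe shown in the diffs.

    # Hypothetical standalone sketch; sizes and shapes are placeholders, not the script's values.
    import torch
    import transformer_engine.pytorch as te
    from transformer_engine.common.recipe import DelayedScaling, Format

    hidden_size, ffn_hidden_size, num_heads = 1024, 4096, 16   # illustrative sizes
    seq_len, batch = 512, 4                                     # illustrative shapes
    dtype = torch.bfloat16

    layer = (
        te.TransformerLayer(
            hidden_size=hidden_size,
            ffn_hidden_size=ffn_hidden_size,
            num_attention_heads=num_heads,
            self_attn_mask_type="causal",
        )
        .to(dtype=dtype)
        .cuda()
    )

    # te.TransformerLayer expects sequence-first input by default: [seq, batch, hidden].
    x = torch.randn(seq_len, batch, hidden_size, dtype=dtype, device="cuda")

    # Same DelayedScaling recipe as in the benchmark scripts.
    recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        y = layer(x, attention_mask=None)
    y.sum().backward()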
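
The "Mean time" figures in the .out files are produced by the speedometer helpers, whose bodies are not part of these hunks. A generic warm-up-then-synchronize timing loop of the kind such helpers typically use is sketched below; this is a hypothetical illustration, not the implementation in getting_started_utils_pytorch.py, which may differ in iteration counts and in how it drives the backward pass.

    import time
    import torch

    def mean_step_time_ms(module, x, warmup=10, iters=50):
        """Average forward+backward time in milliseconds for a CUDA module."""
        for _ in range(warmup):                  # warm-up so autotuning/caching is excluded
            module.zero_grad(set_to_none=True)
            module(x).sum().backward()
        torch.cuda.synchronize()                 # finish warm-up work before timing
        start = time.perf_counter()
        for _ in range(iters):
            module.zero_grad(set_to_none=True)
            module(x).sum().backward()
        torch.cuda.synchronize()                 # wait for all timed kernels to complete
        return (time.perf_counter() - start) / iters * 1000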