diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 5beeeb8879f..7ebcbf740c6 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,7 +17,7 @@ jobs: uses: actions/checkout@v3 - name: 'Install dependencies' run: | - pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 + pip install sphinx==8.1.3 sphinx_rtd_theme==3.0.1 nbsphinx==0.9.5 IPython ipython_genutils==0.2.0 ipywidgets==8.0.2 astroid==3.3.2 sphinx-tabs==3.4.7 pip install breathe==4.35.0 sphinx-autoapi==3.3.2 sudo apt-get install -y pandoc graphviz doxygen export GIT_SHA=$(git show-ref --hash HEAD) diff --git a/docs/_static/css/output-style.css b/docs/_static/css/output-style.css new file mode 100644 index 00000000000..864d8587a32 --- /dev/null +++ b/docs/_static/css/output-style.css @@ -0,0 +1,60 @@ +/* Custom styling for program output blocks */ + +.program-output { + background-color: #f8f9fa; + padding: 0; /* No padding at all */ + margin: 0; /* No margins at all */ + border-radius: 0; /* No rounded corners */ + font-family: 'Courier New', monospace; + font-size: 14px; + line-height: 1.5; + width: 100%; + max-width: 100%; +} + +.program-output pre { + margin: 0; + padding: 0; + background: transparent !important; + border: none !important; + color: #2c3e50; + width: 100%; +} + +.program-output .highlight { + background: transparent !important; + margin: 0; + width: 100%; +} + +/* Alternative lighter style */ +.output-block { + background-color: #fafbfc; + border: 1px solid #e1e4e8; + padding: 10px 14px; + margin: 10px 0; + border-radius: 3px; + font-family: 'SF Mono', 'Consolas', monospace; + font-size: 13px; + color: #24292e; +} + +/* Console-like output style */ +.console-output { + background-color: #1e1e1e; + border-left: 3px solid #76b900; + padding: 14px 18px; + margin: 12px 0; + border-radius: 5px; + font-family: 'Fira Code', 'Consolas', monospace; + font-size: 13px; + color: #d4d4d4; + box-shadow: 0 2px 4px rgba(0,0,0,0.1); +} + +.console-output pre { + margin: 0; + color: #d4d4d4; + background: transparent !important; +} + diff --git a/docs/_static/css/rtabs.css b/docs/_static/css/rtabs.css new file mode 100644 index 00000000000..7f4213ef99b --- /dev/null +++ b/docs/_static/css/rtabs.css @@ -0,0 +1,43 @@ +/* Custom styling for sphinx-tabs */ + +.sphinx-tabs { + margin-bottom: 1rem; +} + +.sphinx-tabs-tab { + background-color: #f4f4f4; + border: 1px solid #ccc; + border-bottom: none; + padding: 0.5rem 1rem; + margin-right: 0.5rem; + cursor: pointer; + font-weight: 500; + transition: background-color 0.2s; +} + +.sphinx-tabs-tab:hover { + background-color: #e0e0e0; +} + +.sphinx-tabs-tab[aria-selected="true"] { + background-color: #76b900; /* NVIDIA green */ + color: white; + border-color: #76b900; + margin-right: 0.5rem; +} + +.sphinx-tabs-panel { + border: 1px solid #ccc; + padding: 1rem; + background-color: #f9f9f9; +} + +/* Dark mode support for RTD theme */ +.rst-content .sphinx-tabs-tab { + color: #333; +} + +.rst-content .sphinx-tabs-tab[aria-selected="true"] { + color: white; +} + diff --git a/docs/conf.py b/docs/conf.py index 479c1f8948f..920b487fd32 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -58,6 +58,7 @@ "nbsphinx", "breathe", "autoapi.extension", + "sphinx_tabs.tabs", ] templates_path = ["_templates"] @@ -83,6 +84,8 @@ html_css_files = [ "css/nvidia_font.css", "css/nvidia_footer.css", + "css/rtabs.css", + "css/output-style.css", ] html_theme_options = { diff --git a/docs/examples/advanced_optimizations.ipynb b/docs/examples/advanced_optimizations.ipynb index 7c08bb65866..1b0694a05f8 100644 --- a/docs/examples/advanced_optimizations.ipynb +++ b/docs/examples/advanced_optimizations.ipynb @@ -13,7 +13,7 @@ "id": "6dcbf25a", "metadata": {}, "source": [ - "This guide is a follow-up to the discussion in the [quickstart guide](quickstart.ipynb). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). " + "This guide is a follow-up to the discussion in the [Getting Started guide](../getting_started/index.rst). We will focus on techniques to achieve maximum performance when training a basic GPT encoder layer. For convenience, we use some helper functions defined in [quickstart_utils.py](quickstart_utils.py). " ] }, { diff --git a/docs/examples/quickstart.ipynb b/docs/examples/quickstart.ipynb deleted file mode 100644 index 0ad2f4fee8c..00000000000 --- a/docs/examples/quickstart.ipynb +++ /dev/null @@ -1,606 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "da9fd6a8", - "metadata": {}, - "source": [ - "# Getting Started\n", - "\n", - "## Overview\n", - "\n", - "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your PyTorch code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n", - "\n", - "## Let's build a Transformer layer!\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We build a basic Transformer layer using regular PyTorch modules. This will be our baseline for later comparisons with Transformer Engine.\n", - "\n", - "
\n", - "\n", - "Let's start with creating a GPT encoder layer using plain PyTorch. Figure 1 shows the overall structure.\n", - "\n", - "
\n", - "\n", - "
Figure 1: Structure of a GPT encoder layer.
\n", - "
\n", - "\n", - "We construct the components as follows:\n", - "\n", - "- `LayerNorm`: `torch.nn.LayerNorm`\n", - "- `QKV Projection`: `torch.nn.Linear` (conceptually three `Linear` layers for Q, K, and V separately, but we fuse into a single `Linear` layer that is three times larger)\n", - "- `DotProductAttention`: `DotProductAttention` from [quickstart_utils.py](quickstart_utils.py)\n", - "- `Projection`: `torch.nn.Linear`\n", - "- `Dropout`: `torch.nn.Dropout`\n", - "- `MLP`: `BasicMLP` from [quickstart_utils.py](quickstart_utils.py)\n", - "\n", - "Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_utils.py](quickstart_utils.py). Putting it all together:" - ] - }, - { - "cell_type": "code", - "execution_count": 1, - "id": "2be43d64", - "metadata": {}, - "outputs": [], - "source": [ - "import torch\n", - "import quickstart_utils as utils\n", - "\n", - "class BasicTransformerLayer(torch.nn.Module):\n", - " def __init__(\n", - " self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int,\n", - " num_attention_heads: int,\n", - " layernorm_eps: int = 1e-5,\n", - " attention_dropout: float = 0.1,\n", - " hidden_dropout: float = 0.1,\n", - " ):\n", - " super().__init__()\n", - " self.num_attention_heads = num_attention_heads\n", - " self.kv_channels = hidden_size // num_attention_heads\n", - " self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True)\n", - " self.attention = utils.DotProductAttention(\n", - " num_attention_heads=num_attention_heads,\n", - " kv_channels=self.kv_channels,\n", - " attention_dropout=attention_dropout,\n", - " )\n", - " self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True)\n", - " self.dropout = torch.nn.Dropout(hidden_dropout)\n", - " self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.mlp = utils.BasicMLP(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " ) \n", - " \n", - " def forward(\n", - " self, \n", - " x: torch.Tensor, \n", - " attention_mask: torch.Tensor\n", - " ) -> torch.Tensor:\n", - " res = x\n", - " x = self.ln1(x)\n", - " \n", - " # Fused QKV projection\n", - " qkv = self.qkv_projection(x)\n", - " qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", - " \n", - " x = self.attention(q, k, v, attention_mask)\n", - " x = self.projection(x)\n", - " x = self.dropout(x)\n", - " x = res + x\n", - " res = x\n", - " x = self.ln2(x)\n", - " x = self.mlp(x)\n", - " \n", - " return x + res" - ] - }, - { - "cell_type": "markdown", - "id": "40724d1d", - "metadata": {}, - "source": [ - "That's it! We now have a simple Transformer layer. We can test it:" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a786f0ea", - "metadata": {}, - "outputs": [], - "source": [ - "# Layer configuration\n", - "hidden_size = 4096\n", - "sequence_length = 2048\n", - "batch_size = 4\n", - "ffn_hidden_size = 16384\n", - "num_attention_heads = 32\n", - "dtype = torch.float16\n", - "\n", - "# Synthetic data\n", - "x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)\n", - "dy = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype)" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "ffdbfb7a", - "metadata": {}, - "outputs": [ - { - "data": { - "text/plain": [ - "BasicTransformerLayer(\n", - " (ln1): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", - " (qkv_projection): Linear(in_features=4096, out_features=12288, bias=True)\n", - " (attention): DotProductAttention(\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " )\n", - " (projection): Linear(in_features=4096, out_features=4096, bias=True)\n", - " (dropout): Dropout(p=0.1, inplace=False)\n", - " (ln2): LayerNorm((4096,), eps=1e-05, elementwise_affine=True)\n", - " (mlp): BasicMLP(\n", - " (linear1): Linear(in_features=4096, out_features=16384, bias=True)\n", - " (linear2): Linear(in_features=16384, out_features=4096, bias=True)\n", - " )\n", - ")" - ] - }, - "execution_count": 3, - "metadata": {}, - "output_type": "execute_result" - } - ], - "source": [ - "basic_transformer = BasicTransformerLayer(\n", - " hidden_size,\n", - " ffn_hidden_size,\n", - " num_attention_heads,\n", - ")\n", - "basic_transformer.to(dtype=dtype).cuda()" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "0162ad40", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = basic_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "65ae6dd6", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 43.0663916015625 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " basic_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "43717e36", - "metadata": {}, - "source": [ - "## Meet Transformer Engine\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We modify the example Transformer layer to include the simplest TE modules: `Linear` and `LayerNorm`.\n", - "\n", - "
\n", - "\n", - "Now that we have a basic Transformer layer, let's use Transformer Engine to speed up the training. " - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "004d3c92", - "metadata": {}, - "outputs": [], - "source": [ - "import transformer_engine.pytorch as te" - ] - }, - { - "cell_type": "markdown", - "id": "1931f911", - "metadata": {}, - "source": [ - "TE provides a set of PyTorch modules that can be used to build Transformer layers. The simplest of the provided modules are the `Linear` and `LayerNorm` layers, which we can use instead of `torch.nn.Linear` and `torch.nn.LayerNorm`. Let's modify `BasicTransformerLayer`:" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "1f44db50", - "metadata": {}, - "outputs": [], - "source": [ - "class BasicTEMLP(torch.nn.Module):\n", - " def __init__(self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int) -> None:\n", - " super().__init__()\n", - " self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True)\n", - " self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True)\n", - "\n", - " def forward(self, x):\n", - " x = self.linear1(x)\n", - " x = torch.nn.functional.gelu(x, approximate='tanh')\n", - " x = self.linear2(x)\n", - " return x \n", - " \n", - "class BasicTETransformerLayer(torch.nn.Module):\n", - " def __init__(self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int,\n", - " num_attention_heads: int,\n", - " layernorm_eps: int = 1e-5,\n", - " attention_dropout: float = 0.1,\n", - " hidden_dropout: float = 0.1):\n", - " super().__init__()\n", - " self.num_attention_heads = num_attention_heads\n", - " self.kv_channels = hidden_size // num_attention_heads\n", - " self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True)\n", - " self.attention = utils.DotProductAttention(\n", - " num_attention_heads=num_attention_heads,\n", - " kv_channels=self.kv_channels,\n", - " attention_dropout=attention_dropout,\n", - " )\n", - " self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n", - " self.dropout = torch.nn.Dropout(hidden_dropout)\n", - " self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps)\n", - " self.mlp = BasicTEMLP(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " )\n", - " \n", - " def forward(self, \n", - " x: torch.Tensor, \n", - " attention_mask: torch.Tensor):\n", - " res = x\n", - " x = self.ln1(x)\n", - " \n", - " # Fused QKV projection\n", - " qkv = self.qkv_projection(x)\n", - " qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", - " \n", - " x = self.attention(q, k, v, attention_mask)\n", - " x = self.projection(x)\n", - " x = self.dropout(x)\n", - " x = res + x\n", - " res = x\n", - " x = self.ln2(x)\n", - " x = self.mlp(x)\n", - " \n", - " return x + res" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "916531e8", - "metadata": {}, - "outputs": [], - "source": [ - "basic_te_transformer = BasicTETransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads,\n", - ")\n", - "basic_te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_basic_te_model(basic_te_transformer, basic_transformer)" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "3643fa54", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = basic_te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "10b92894", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 43.1413232421875 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " basic_te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "3f990226", - "metadata": {}, - "source": [ - "## Fused TE Modules\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We optimize the example Transformer layer with TE modules for fused operations.\n", - "\n", - "
\n", - "\n", - "The `Linear` layer is enough to build any Transformer model and it enables usage of Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations like kernel fusion, increasing the achievable speedup.\n", - "\n", - "Transformer Engine therefore provides coarser modules that span multiple layers:\n", - "\n", - "* `LayerNormLinear`\n", - "* `LayerNormMLP`\n", - "* `TransformerLayer`\n", - "\n", - "Building a third iteration of our Transformer layer with `LayerNormLinear` and `LayerNormMLP`:" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "c55eae1f", - "metadata": {}, - "outputs": [], - "source": [ - "class FusedTETransformerLayer(torch.nn.Module):\n", - " def __init__(self,\n", - " hidden_size: int,\n", - " ffn_hidden_size: int,\n", - " num_attention_heads: int,\n", - " layernorm_eps: int = 1e-5,\n", - " attention_dropout: float = 0.1,\n", - " hidden_dropout: float = 0.1):\n", - " super().__init__()\n", - " self.num_attention_heads = num_attention_heads\n", - " self.kv_channels = hidden_size // num_attention_heads\n", - " self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True)\n", - " self.attention = utils.DotProductAttention(\n", - " num_attention_heads=num_attention_heads,\n", - " kv_channels=self.kv_channels,\n", - " attention_dropout=attention_dropout,\n", - " )\n", - " self.projection = te.Linear(hidden_size, hidden_size, bias=True)\n", - " self.dropout = torch.nn.Dropout(hidden_dropout)\n", - " self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True)\n", - " \n", - " \n", - " def forward(self, \n", - " x: torch.Tensor, \n", - " attention_mask: torch.Tensor):\n", - " res = x\n", - " qkv = self.ln_qkv(x)\n", - " \n", - " # Split qkv into query, key and value\n", - " qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3)\n", - " \n", - " x = self.attention(q, k, v, attention_mask)\n", - " x = self.projection(x)\n", - " x = self.dropout(x)\n", - " x = res + x\n", - " res = x\n", - " x = self.ln_mlp(x)\n", - " \n", - " return x + res" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "85949421", - "metadata": {}, - "outputs": [], - "source": [ - "fused_te_transformer = FusedTETransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", - "fused_te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_fused_te_model(fused_te_transformer, basic_transformer)" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "2c263e71", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = fused_te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "24e101bc", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 43.1981201171875 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " fused_te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "33f13c26", - "metadata": {}, - "source": [ - "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures and it provides the highest degree of performance optimization:" - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "ec8c3685", - "metadata": {}, - "outputs": [], - "source": [ - "te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", - "te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)" - ] - }, - { - "cell_type": "code", - "execution_count": 16, - "id": "e48cd590", - "metadata": {}, - "outputs": [], - "source": [ - "torch.manual_seed(1234)\n", - "y = te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 17, - "id": "3ec3707d-e63f-4899-8308-b11c55b5caa4", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 39.99169921875 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "4034c3eb-8958-49f2-85f6-30c94977d884", - "metadata": {}, - "source": [ - "## Enabling FP8\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We configure a TE module to perform compute in FP8.\n", - "\n", - "
\n", - "\n", - "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/pytorch.rst#transformer_engine.pytorch.autocast) context manager. Note that autocast should only be used to wrap the forward pass and must exit before starting a backward pass. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options." - ] - }, - { - "cell_type": "code", - "execution_count": 18, - "id": "31256aa7-3d5e-425c-91ab-502b1326a748", - "metadata": {}, - "outputs": [], - "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling\n", - "\n", - "te_transformer = te.TransformerLayer(hidden_size, ffn_hidden_size, num_attention_heads)\n", - "te_transformer.to(dtype=dtype).cuda()\n", - "utils.share_parameters_with_transformerlayer_te_model(te_transformer, basic_transformer)\n", - "\n", - "fp8_format = Format.HYBRID\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")\n", - "torch.manual_seed(1234)\n", - "with te.autocast(enabled=True, fp8_recipe=fp8_recipe):\n", - " y = te_transformer(x, attention_mask=None)" - ] - }, - { - "cell_type": "code", - "execution_count": 19, - "id": "793ebd2d-b84b-47bc-811a-7991df8500aa", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 28.61394775390625 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " te_transformer,\n", - " x,\n", - " dy,\n", - " forward_kwargs = { \"attention_mask\": None },\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", - ")" - ] - } - ], - "metadata": { - "language_info": { - "name": "python" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/examples/quickstart_jax.ipynb b/docs/examples/quickstart_jax.ipynb deleted file mode 100644 index 5b4e439c004..00000000000 --- a/docs/examples/quickstart_jax.ipynb +++ /dev/null @@ -1,833 +0,0 @@ -{ - "cells": [ - { - "cell_type": "markdown", - "id": "962d87bb", - "metadata": {}, - "source": [ - "\n", - "\n", - "# Getting Started\n", - "\n", - "## Overview\n", - "\n", - "Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, providing better performance with lower memory utilization in both training and inference. It provides support for 8-bit floating point (FP8) precision on Hopper, Ada, as well as 8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs, implements a collection of highly optimized building blocks for popular Transformer architectures, and exposes an automatic-mixed-precision-like API that can be used seamlessly with your JAX code. It also includes a framework-agnostic C++ API that can be integrated with other deep learning libraries to enable FP8 support for Transformers.\n", - "\n", - "This guide shows how to start using Transformer Engine with JAX. Similar tutorial for pyTorch is available [here](quickstart.ipynb).\n", - "We recommend you to try understanding the basics of JAX first, using these resources:\n", - "\n", - "- Thinking in JAX: https://docs.jax.dev/en/latest/notebooks/thinking_in_jax.html\n", - "- JAX 101: https://docs.jax.dev/en/latest/jax-101.html\n", - "- Key concepts in JAX: https://docs.jax.dev/en/latest/key-concepts.html#jax-arrays-jax-array\n", - "- Flax 101: https://flax-linen.readthedocs.io/en/latest/guides/flax_fundamentals/index.html\n", - "\n", - "## Let's build a Transformer decoder layer!\n", - "_This is based upon the GPT decoder layer with causal masking, which prevents each position from attending to future positions._\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We build a basic Transformer layer using regular Flax modules. This will be our baseline for later comparisons with Transformer Engine.\n", - "\n", - "
\n", - "\n", - "Let's start with creating the transformer layer using plain [FLAX Linen](https://flax.readthedocs.io/en/stable/) . Figure 1 shows the overall structure.\n", - "\n", - "
\n", - "\n", - "
Figure 1: Structure of a GPT decoder layer.
\n", - "
\n", - "\n", - "We construct the components as follows:\n", - "\n", - "- `LayerNorm`: `nn.LayerNorm` (Flax)\n", - "- `QKV Projection`: `nn.Dense` (conceptually there are three seperate `Dense` layers for Q, K, and V separately, but we fuse them together into a single `Dense` layer that is three times larger)\n", - "- `DotProductAttention`: `nn.MuliheadDotProductAttention` (Flax)\n", - "- `Projection`: `nn.Dense` (Flax)\n", - "- `Dropout`: `nn.Dropout` (Flax)\n", - "- `MLP`: `FlaxMLP` implemented using `nn.Dense` and `nn.gelu`\n", - "\n", - "Over the course of this tutorial we will use a few modules and helper functions defined in [quickstart_jax_utils.py](quickstart_jax_utils.py). Putting it all together: \n" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "d5284a38", - "metadata": {}, - "outputs": [], - "source": [ - "import jax\n", - "import jax.numpy as jnp\n", - "from flax import linen as nn\n", - "import quickstart_jax_utils as utils\n", - "from typing import Optional" - ] - }, - { - "cell_type": "code", - "execution_count": 2, - "id": "a4d1cfdc", - "metadata": {}, - "outputs": [], - "source": [ - "class FlaxMLP(nn.Module):\n", - " \"\"\"Feed-forward network in Transformer layer\n", - " Built with plain Flax modules.\n", - " \"\"\"\n", - " hidden_size: int\n", - " ffn_hidden_size: int\n", - "\n", - " @nn.compact\n", - " def __call__(self, x: jnp.ndarray) -> jnp.ndarray:\n", - " x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x)\n", - " x = nn.gelu(x, approximate=True) # equivalent to tanh approximation\n", - " x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n", - " return x\n", - "\n", - "class FlaxTransformerLayer(nn.Module):\n", - " \"\"\"Basic Transformer layer using plain Flax modules\"\"\"\n", - " hidden_size: int\n", - " ffn_hidden_size: int\n", - " num_attention_heads: int\n", - " layernorm_eps: float = 1e-5\n", - " attention_dropout: float = 0.1\n", - " \n", - " def setup(self):\n", - " self.kv_channels = self.hidden_size // self.num_attention_heads\n", - "\n", - " @nn.compact\n", - " def __call__(\n", - " self, \n", - " x: jnp.ndarray, \n", - " attention_mask: Optional[jnp.ndarray] = None,\n", - " deterministic: bool = False\n", - " ) -> jnp.ndarray:\n", - " # Create causal mask if not provided\n", - " if attention_mask is None:\n", - " attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n", - " \n", - " res = x\n", - " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - " \n", - " # Fused QKV projection\n", - " qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x)\n", - " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = jnp.split(qkv, 3, axis=3)\n", - " \n", - " # q, k, v now have shape [batch, seq_len, num_heads, kv_channels]\n", - " # which is the correct format for dot_product_attention\n", - " \n", - " # Apply dot product attention\n", - " # Note: dot_product_attention expects mask to be broadcastable to \n", - " # [batch, num_heads, q_length, kv_length], but attention_mask from \n", - " # nn.make_causal_mask has shape [batch, 1, seq_len, seq_len]\n", - " \n", - " # Generate dropout RNG key when needed (not deterministic and dropout_rate > 0)\n", - " dropout_rng = None\n", - " if not deterministic and self.attention_dropout > 0:\n", - " dropout_rng = self.make_rng('dropout')\n", - " \n", - " x = nn.dot_product_attention(\n", - " query=q,\n", - " key=k,\n", - " value=v,\n", - " mask=attention_mask,\n", - " dropout_rng=dropout_rng,\n", - " dropout_rate=self.attention_dropout,\n", - " deterministic=deterministic,\n", - " broadcast_dropout=True,\n", - " )\n", - " \n", - " # Reshape output from [batch, seq_len, num_heads, kv_channels] to [batch, seq_len, hidden_size]\n", - " x = x.reshape(x.shape[0], x.shape[1], self.hidden_size)\n", - "\n", - " # Output projection\n", - " x = nn.Dense(features=self.hidden_size, use_bias=True)(x)\n", - " \n", - " x = res + x\n", - " \n", - " # Second residual connection\n", - " res = x\n", - " x = nn.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - " \n", - " # MLP\n", - " mlp = FlaxMLP(\n", - " hidden_size=self.hidden_size,\n", - " ffn_hidden_size=self.ffn_hidden_size,\n", - " )\n", - " x = mlp(x)\n", - " \n", - " return x + res\n" - ] - }, - { - "cell_type": "markdown", - "id": "fbc3510b", - "metadata": {}, - "source": [ - "## Testing Performance\n", - "\n", - "Now let's test the performance of our FlaxTransformerLayer:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 3, - "id": "8b44649d", - "metadata": {}, - "outputs": [], - "source": [ - "# Layer configuration\n", - "hidden_size = 4096\n", - "sequence_length = 2048\n", - "batch_size = 4\n", - "ffn_hidden_size = 16384\n", - "num_attention_heads = 32\n", - "dtype = jnp.bfloat16\n", - "\n", - "# Synthetic data\n", - "key, dropout_key = jax.random.split(jax.random.PRNGKey(42))\n", - "x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n", - "dy = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype)\n" - ] - }, - { - "cell_type": "code", - "execution_count": 4, - "id": "e44ed26d", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Pure Flax FlaxTransformerLayer initialized successfully!\n", - "Parameter shapes: {'params': {'Dense_0': {'bias': (12288,), 'kernel': (4096, 12288)}, 'Dense_1': {'bias': (4096,), 'kernel': (4096, 4096)}, 'FlaxMLP_0': {'Dense_0': {'bias': (16384,), 'kernel': (4096, 16384)}, 'Dense_1': {'bias': (4096,), 'kernel': (16384, 4096)}}, 'LayerNorm_0': {'bias': (4096,), 'scale': (4096,)}, 'LayerNorm_1': {'bias': (4096,), 'scale': (4096,)}}}\n" - ] - } - ], - "source": [ - "# Initialize the FlaxTransformerLayer\n", - "flax_transformer = FlaxTransformerLayer(\n", - " hidden_size=hidden_size,\n", - " ffn_hidden_size=ffn_hidden_size,\n", - " num_attention_heads=num_attention_heads,\n", - ")\n", - "\n", - "# Initialize parameters\n", - "params = flax_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "print(\"Pure Flax FlaxTransformerLayer initialized successfully!\")\n", - "print(f\"Parameter shapes: {jax.tree_util.tree_map(lambda x: x.shape, params)}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 5, - "id": "de91af7a", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Input shape: (4, 2048, 4096)\n", - "Output shape: (4, 2048, 4096)\n", - "Output dtype: float32\n", - "Forward pass completed successfully!\n" - ] - } - ], - "source": [ - "# Example usage of forward pass\n", - "y = flax_transformer.apply(params, x, attention_mask=None, deterministic=True)\n", - "print(f\"Input shape: {x.shape}\")\n", - "print(f\"Output shape: {y.shape}\")\n", - "print(f\"Output dtype: {y.dtype}\")\n", - "print(\"Forward pass completed successfully!\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 6, - "id": "037bc8d9", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 19.258604049682617 ms\n" - ] - } - ], - "source": [ - "import importlib\n", - "import quickstart_jax_utils\n", - "importlib.reload(quickstart_jax_utils)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=flax_transformer.apply,\n", - " variables=params,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "ccb16f31", - "metadata": {}, - "source": [ - "## Meet Transformer Engine\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "Now that we have a basic Transformer layer in Flax, let's use Transformer Engine to speed up the training. The following examples show how to use TE modules.\n", - "\n", - "
\n", - "\n", - "As a reminder, the FlaxTransformerLayer above used:\n", - "\n", - "- `nn.LayerNorm`: Flax LayerNorm\n", - "- `nn.Dense`: Flax Dense layer for QKV projection \n", - "- `nn.MultiheadDotProductAttention`: Flax MultiheadDotProductAttention\n", - "- `nn.Dense`: Flax Dense layer for projection\n", - "- `nn.Dropout`: Flax Dropout\n", - "- `FlaxMLP`: Custom MLP implemented from `nn.Dense`\n", - "\n", - "Below we show how to use Transformer Engine Flax modules for better performance:\n" - ] - }, - { - "cell_type": "code", - "execution_count": 7, - "id": "bed20d6b", - "metadata": {}, - "outputs": [], - "source": [ - "import transformer_engine.jax as te\n", - "import transformer_engine.jax.flax as te_flax" - ] - }, - { - "cell_type": "markdown", - "id": "f28cb444", - "metadata": {}, - "source": [ - "TE provides a set of Flax Linen modules that can be used to build Transformer layers. The simplest of the provided modules are the `DenseGeneral ` and `LayerNorm` layers, which we can use instead of `flax.linen.Dense` and ` flax.linen.LayerNorm`. Let's modify our `FlaxTransformerLayer`:" - ] - }, - { - "cell_type": "code", - "execution_count": 8, - "id": "56105579", - "metadata": {}, - "outputs": [], - "source": [ - "class TEUnfusedMLP(nn.Module):\n", - " hidden_size : int\n", - " ffn_hidden_size: int\n", - "\n", - " @nn.compact\n", - " def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray:\n", - " x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True) (x)\n", - " x = x.reshape(*x.shape[:-1], 1, x.shape[-1])\n", - " x = te.activation.activation(x, activation_type=('gelu',))\n", - " x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True) (x)\n", - " return x\n", - "\n", - "class TEUnfusedTransformerLayer(nn.Module):\n", - " hidden_size: int\n", - " ffn_hidden_size: int \n", - " num_attention_heads: int \n", - " layernorm_eps: float = 1e-5\n", - " attention_dropout: float = 0.1 \n", - " use_te_attention: bool = True # True for TE attention, False for Flax attention\n", - "\n", - " def setup(self):\n", - " self.kv_channels = self.hidden_size // self.num_attention_heads\n", - "\n", - " @nn.compact\n", - " def __call__(\n", - " self, \n", - " x: jnp.ndarray,\n", - " attention_mask: Optional[jnp.ndarray] = None,\n", - " deterministic: bool = False\n", - " ) -> jnp.ndarray: \n", - " res = x\n", - " x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - "\n", - " # Fused QKV projection\n", - " qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x)\n", - " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = jnp.split(qkv, 3, axis=3)\n", - "\n", - " # Attention - either TE or Flax implementation\n", - " if self.use_te_attention:\n", - " # Use TE's DotProductAttention\n", - " attention = te_flax.DotProductAttention(\n", - " head_dim=self.kv_channels,\n", - " num_attention_heads=self.num_attention_heads,\n", - " num_gqa_groups=self.num_attention_heads, # No GQA\n", - " attention_dropout=self.attention_dropout,\n", - " attn_mask_type='causal',\n", - " )\n", - " x = attention(\n", - " q, k, v,\n", - " # Causal mask does not need an explicit instatiated mask as specialized kernels exist to handle it\n", - " sequence_descriptor=None, \n", - " deterministic=deterministic\n", - " )\n", - " # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n", - " x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n", - " x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n", - " x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n", - " else:\n", - " # Use Flax's MultiHeadDotProductAttention\n", - " q_reshaped = q.reshape(q.shape[0], q.shape[1], self.hidden_size)\n", - " k_reshaped = k.reshape(k.shape[0], k.shape[1], self.hidden_size)\n", - " v_reshaped = v.reshape(v.shape[0], v.shape[1], self.hidden_size)\n", - "\n", - " # Create causal mask if not provided\n", - " if attention_mask is None:\n", - " attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_)\n", - " \n", - " attention = nn.MultiHeadDotProductAttention(\n", - " num_heads=self.num_attention_heads,\n", - " qkv_features=self.kv_channels,\n", - " dropout_rate=self.attention_dropout,\n", - " )\n", - " x = attention(q_reshaped, k_reshaped, v_reshaped, mask=attention_mask, deterministic=deterministic)\n", - "\n", - " x = res + x\n", - "\n", - " # Second residual connection\n", - " res = x\n", - " x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x)\n", - "\n", - " # MLP\n", - " mlp = TEUnfusedMLP(\n", - " hidden_size=self.hidden_size,\n", - " ffn_hidden_size=self.ffn_hidden_size\n", - " )\n", - "\n", - " x = mlp(x, deterministic=deterministic)\n", - "\n", - " return x + res" - ] - }, - { - "cell_type": "markdown", - "id": "a76911ac", - "metadata": {}, - "source": [ - "Testing performance of the model, using `DenseGeneral`, `LayerNorm` and activation from TE, while keeping Flax's `MultiHeadDotProductAttention` the same as the first simple Transformer in JAX implementation. To read more about this implementation from Flax, you can refer to this documentation: https://flax.readthedocs.io/en/latest/api_reference/flax.nnx/nn/attention.html" - ] - }, - { - "cell_type": "code", - "execution_count": 9, - "id": "4b67511f", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 16.003193855285645 ms\n" - ] - } - ], - "source": [ - "te_unfused_transformer_with_flax_MHA = TEUnfusedTransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads,\n", - " use_te_attention=False\n", - ")\n", - "\n", - "te_params = te_unfused_transformer_with_flax_MHA.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_unfused_transformer_with_flax_MHA.apply,\n", - " variables=te_params, # Ensure the correct `params` is passed\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs={\"dropout\": dropout_key},\n", - ")\n" - ] - }, - { - "cell_type": "markdown", - "id": "0b230058", - "metadata": {}, - "source": [ - "Now, we move on to also replace the attention sub-layer with TE's `DotProductAttention` implementation" - ] - }, - { - "cell_type": "code", - "execution_count": 10, - "id": "5146cd99", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 8.897695541381836 ms\n" - ] - } - ], - "source": [ - "te_unfused_transformer = TEUnfusedTransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads,\n", - ")\n", - "\n", - "te_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_unfused_transformer.apply,\n", - " variables=te_params, # Ensure the correct `params` is passed\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "c9a101d3", - "metadata": {}, - "source": [ - "## Enabling Quantization (FP8 or FP4)\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We configure a TE module to perform compute in FP8.\n", - "\n", - "
\n", - "\n", - "Enabling FP8 support is very simple in Transformer Engine. We just need to wrap the modules within an [autocast](../api/jax.rst#transformer_engine.jax.fp8_autocast) context manager. See the [FP8 tutorial](fp8_primer.ipynb) for a detailed explanation of FP8 recipes and the supported options.\n", - "\n", - "
\n", - "\n", - "Important: FP8 Metadata Initialization\n", - "\n", - "When using FP8, the model **must be initialized within the `autocast` context**. This creates a special collection called `fp8_metas` that contains scaling factors and other metadata required for FP8 computation. If you initialize a model outside of `autocast` and then try to use it with FP8, you will get a `ScopeCollectionNotFound` error because the `fp8_metas` collection was never created.\n", - "\n", - "
" - ] - }, - { - "cell_type": "code", - "execution_count": 11, - "id": "c2eee376", - "metadata": {}, - "outputs": [], - "source": [ - "from transformer_engine.common.recipe import Format, DelayedScaling\n", - "fp8_format = Format.HYBRID\n", - "fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo=\"max\")" - ] - }, - { - "cell_type": "code", - "execution_count": 12, - "id": "de96827c", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 5.651178359985352 ms\n" - ] - } - ], - "source": [ - "with te.autocast(enabled=True, recipe=fp8_recipe):\n", - " te_unfused_params = te_unfused_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - "\n", - " # Example usage of forward \n", - " y = te_unfused_transformer.apply(te_unfused_params, x, attention_mask=None, deterministic=True)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_unfused_transformer.apply,\n", - " variables=te_unfused_params, # Ensure the correct `params` is passed\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "3801b201", - "metadata": {}, - "source": [ - "\n", - "## Fused TE Modules\n", - "\n", - "
\n", - "\n", - "Summary\n", - " \n", - "We optimize the example Transformer layer with TE modules for fused operations.\n", - "\n", - "
\n", - "\n", - "The `DenseGeneral` layer is enough to build any Transformer model and it enables usage of the Transformer Engine even for very custom Transformers. However, having more knowledge about the model allows for additional optimizations such as kernel fusions in mixed-precision recipes, increasing the achievable speedup.\n", - "\n", - "Transformer Engine therefore provides coarser modules that span multiple layers:\n", - "\n", - "* `LayerNormDenseGeneral`\n", - "* `LayerNormMLP`\n", - "* `TransformerLayer`\n", - "\n", - "To see a complete list of all the functions TE Flax support, you can view it here: https://docs.nvidia.com/deeplearning/transformer-engine/user-guide/api/jax.html#modules\n", - "\n", - "Building a third iteration of our Transformer layer with `LayerNormDenseGeneral` and `LayerNormMLP`:" - ] - }, - { - "cell_type": "code", - "execution_count": 13, - "id": "11203785", - "metadata": {}, - "outputs": [], - "source": [ - "class TEFusedTransformerLayer(nn.Module):\n", - " hidden_size: int\n", - " ffn_hidden_size: int \n", - " num_attention_heads: int \n", - " layernorm_eps: float = 1e-5\n", - " attention_dropout: float = 0.1\n", - "\n", - " def setup(self):\n", - " self.kv_channels = self.hidden_size // self.num_attention_heads\n", - "\n", - " @nn.compact\n", - " def __call__(\n", - " self, \n", - " x: jnp.ndarray,\n", - " attention_mask: Optional[jnp.ndarray] = None,\n", - " deterministic: bool = False\n", - " ) -> jnp.ndarray:\n", - " res = x\n", - "\n", - " # Fused QKV projection\n", - " qkv,_ = te_flax.LayerNormDenseGeneral(features=3 * self.hidden_size, \n", - " epsilon=self.layernorm_eps, \n", - " use_bias=True, \n", - " return_layernorm_output=False)(x)\n", - " qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels)\n", - " q, k, v = jnp.split(qkv, 3, axis=3)\n", - "\n", - " # Attention using TE's DotProductAttention\n", - " attention = te_flax.DotProductAttention(\n", - " head_dim=self.kv_channels,\n", - " num_attention_heads=self.num_attention_heads,\n", - " num_gqa_groups=self.num_attention_heads, \n", - " attention_dropout=self.attention_dropout,\n", - " attn_mask_type='causal',\n", - " )\n", - " x = attention(q, k, v, sequence_descriptor=None, deterministic=deterministic)\n", - " # Reshape from [batch, seq_len, num_heads, head_dim] to [batch, seq_len, hidden_size]\n", - " x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3]))\n", - " x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x)\n", - " x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic)\n", - "\n", - " x = res + x\n", - "\n", - " # Second residual connection\n", - " res = x\n", - " x,_ = te_flax.LayerNormMLP(intermediate_dim=self.ffn_hidden_size, \n", - " epsilon=self.layernorm_eps,\n", - " use_bias=True,\n", - " activations=('gelu',),\n", - " intermediate_dropout_rate=0.0,\n", - " return_layernorm_output=False\n", - " )(x, deterministic=deterministic)\n", - "\n", - " return x + res" - ] - }, - { - "cell_type": "markdown", - "id": "334cff59", - "metadata": {}, - "source": [ - "Similar to the unnfused model, we also compare the performance of fused model when using Flax's MultiheadDotProductAttention implementation and TE's." - ] - }, - { - "cell_type": "code", - "execution_count": 14, - "id": "6b0c705e", - "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n", - "/mnt/jberchtold/ptyche-lustre-home/transformerengine/transformer_engine/jax/flax/transformer.py:626: UserWarning: transpose_batch_sequence defaults to False in DotProductAttention starting TransformerEngine v2.10\n", - " warnings.warn(\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 5.493879318237305 ms\n" - ] - } - ], - "source": [ - "te_fused_transformer = TEFusedTransformerLayer(\n", - " hidden_size, \n", - " ffn_hidden_size, \n", - " num_attention_heads\n", - ")\n", - "\n", - "with te.autocast(enabled=True, recipe=fp8_recipe):\n", - " te_fused_params = te_fused_transformer.init(key, x, attention_mask=None, deterministic=False)\n", - " # Example usage of forward \n", - " y = te_fused_transformer.apply(te_fused_params, x, attention_mask=None, deterministic=True)\n", - "\n", - "utils.speedometer(\n", - " model_apply_fn=te_fused_transformer.apply,\n", - " variables=te_fused_params,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe},\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - }, - { - "cell_type": "markdown", - "id": "a45c12c8", - "metadata": {}, - "source": [ - "Finally, the `TransformerLayer` module is convenient for creating standard Transformer architectures." - ] - }, - { - "cell_type": "code", - "execution_count": 15, - "id": "b2aaa8ef", - "metadata": {}, - "outputs": [], - "source": [ - "te_transformer = te_flax.TransformerLayer(\n", - " hidden_size=hidden_size,\n", - " mlp_hidden_size=ffn_hidden_size, \n", - " num_attention_heads=num_attention_heads,\n", - " mlp_activations=(\"gelu\",),\n", - " self_attn_mask_type='causal',\n", - " layernorm_epsilon=1e-5,\n", - " use_bias=True,\n", - " intermediate_dropout=0.0,\n", - " enable_relative_embedding=False,\n", - " self_attn_bias_type='no_bias',\n", - " hidden_dropout=0.0,\n", - ")\n", - "\n", - "with te.autocast(enabled=True, recipe=fp8_recipe):\n", - " te_transformer_params = te_transformer.init(key, x, deterministic=False)\n", - " y = te_transformer.apply(te_transformer_params, x, attention_mask=None, deterministic=True)" - ] - }, - { - "cell_type": "code", - "execution_count": 20, - "id": "b9cdbf22", - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Mean time: 5.334172248840332 ms\n" - ] - } - ], - "source": [ - "utils.speedometer(\n", - " model_apply_fn=te_transformer.apply,\n", - " model_init_fn=te_transformer.init,\n", - " variables=te_transformer_params,\n", - " input=x,\n", - " output_grad=dy,\n", - " forward_kwargs={\"attention_mask\": None, \"deterministic\": False},\n", - " autocast_kwargs = { \"enabled\": True, \"recipe\": fp8_recipe },\n", - " rngs={\"dropout\": dropout_key},\n", - ")" - ] - } - ], - "metadata": { - "kernelspec": { - "display_name": "Python 3 (ipykernel)", - "language": "python", - "name": "python3" - }, - "language_info": { - "codemirror_mode": { - "name": "ipython", - "version": 3 - }, - "file_extension": ".py", - "mimetype": "text/x-python", - "name": "python", - "nbconvert_exporter": "python", - "pygments_lexer": "ipython3", - "version": "3.12.3" - } - }, - "nbformat": 4, - "nbformat_minor": 5 -} diff --git a/docs/examples/quickstart_jax_utils.py b/docs/examples/quickstart_jax_utils.py index f1ff9b7d995..9a145a511b1 100644 --- a/docs/examples/quickstart_jax_utils.py +++ b/docs/examples/quickstart_jax_utils.py @@ -5,13 +5,9 @@ import jax import jax.numpy as jnp import time -import math from typing import Callable, Any, Dict, Optional, Tuple -from flax import linen as nn import transformer_engine.jax as te -import transformer_engine.jax.flax as te_flax -from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention def speedometer( diff --git a/docs/examples/te_jax_integration.ipynb b/docs/examples/te_jax_integration.ipynb index 70647e421a0..66d16ed52f4 100644 --- a/docs/examples/te_jax_integration.ipynb +++ b/docs/examples/te_jax_integration.ipynb @@ -264,7 +264,7 @@ "id": "5e9310c9", "metadata": {}, "source": [ - "# Transformer Engine" + "## Transformer Engine" ] }, { diff --git a/docs/getting_started.rst b/docs/getting_started.rst deleted file mode 100644 index 2e8047763a8..00000000000 --- a/docs/getting_started.rst +++ /dev/null @@ -1,16 +0,0 @@ -.. - Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. - - See LICENSE for license information. - -Getting Started -=============== - -Choose your framework to get started with Transformer Engine: - -.. toctree:: - :maxdepth: 1 - - PyTorch - JAX - diff --git a/docs/getting_started/getting_started_jax.out b/docs/getting_started/getting_started_jax.out new file mode 100644 index 00000000000..c11f3b1965f --- /dev/null +++ b/docs/getting_started/getting_started_jax.out @@ -0,0 +1,34 @@ +pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel +pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-jax-py3-devel +# BENCHMARK_BASELINE_OUTPUT_START +Baseline Flax: +Mean time: 86.580 ms +# BENCHMARK_BASELINE_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_OUTPUT_START +TE Unfused: +Mean time: 42.252 ms +# BENCHMARK_TE_UNFUSED_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START +TE Unfused + TE Attention: +Mean time: 35.054 ms +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START +TE Unfused + TE Attention + FP8: +Mean time: 22.638 ms +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_FUSED_FP8_OUTPUT_START +TE Fused + TE Attention + FP8: +Mean time: 23.703 ms +# BENCHMARK_TE_FUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START +TE TransformerLayer + FP8: +Mean time: 22.812 ms +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + +Summary written to getting_started_jax_summary.csv diff --git a/docs/getting_started/getting_started_jax.py b/docs/getting_started/getting_started_jax.py new file mode 100644 index 00000000000..a2ce3c0ec7b --- /dev/null +++ b/docs/getting_started/getting_started_jax.py @@ -0,0 +1,523 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Getting Started with Transformer Engine - JAX Example +====================================================== + +This example shows how to build a Transformer decoder layer using JAX/Flax +and how to optimize it with Transformer Engine. +""" + +import jax +import jax.numpy as jnp +from flax import linen as nn +from typing import Optional + +import transformer_engine.jax as te +import transformer_engine.jax.flax as te_flax +from transformer_engine.jax.sharding import MeshResource +from transformer_engine.common.recipe import Format, DelayedScaling + +from getting_started_utils_jax import speedometer + + +# Configuration +hidden_size = 4096 +sequence_length = 2048 +batch_size = 8 +ffn_hidden_size = 16384 +num_attention_heads = 32 +dtype = jnp.bfloat16 + +# Create synthetic data +key = jax.random.PRNGKey(42) +x = jax.random.normal(key, (batch_size, sequence_length, hidden_size)).astype(dtype) +mesh_resource = MeshResource() + + +# ============================================================================= +# Baseline: Pure Flax Implementation +# ============================================================================= + + +# BASELINE_MLP_START +class FlaxMLP(nn.Module): + """Feed-forward network in Transformer layer. + Built with plain Flax modules. + """ + + hidden_size: int + ffn_hidden_size: int + + @nn.compact + def __call__(self, x: jnp.ndarray) -> jnp.ndarray: + x = nn.Dense(features=self.ffn_hidden_size, use_bias=True)(x) + x = nn.gelu(x, approximate=True) + x = nn.Dense(features=self.hidden_size, use_bias=True)(x) + return x + + +# BASELINE_MLP_END + + +# BASELINE_LAYER_START +class FlaxTransformerLayer(nn.Module): + """Basic Transformer layer using plain Flax modules.""" + + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False, + ) -> jnp.ndarray: + if attention_mask is None: + attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_) + + res = x + x = nn.LayerNorm(epsilon=self.layernorm_eps)(x) + + # Fused QKV projection + qkv = nn.Dense(features=3 * self.hidden_size, use_bias=True)(x) + qkv = qkv.reshape( + qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels + ) + q, k, v = jnp.split(qkv, 3, axis=3) + + dropout_rng = None + if not deterministic and self.attention_dropout > 0: + dropout_rng = self.make_rng("dropout") + + x = nn.dot_product_attention( + query=q, + key=k, + value=v, + mask=attention_mask, + dropout_rng=dropout_rng, + dropout_rate=self.attention_dropout, + deterministic=deterministic, + broadcast_dropout=True, + ) + + x = x.reshape(x.shape[0], x.shape[1], self.hidden_size) + x = nn.Dense(features=self.hidden_size, use_bias=True)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + x = res + x + + res = x + x = nn.LayerNorm(epsilon=self.layernorm_eps)(x) + mlp = FlaxMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size) + x = mlp(x) + + return x + res + + +# BASELINE_LAYER_END + + +print("# BENCHMARK_BASELINE_OUTPUT_START") +# BENCHMARK_BASELINE_START +baseline = FlaxTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) +params = baseline.init(key, x, deterministic=False) + +print("Baseline Flax:") +time_baseline = speedometer( + baseline.apply, params, x, forward_kwargs={"deterministic": True}, label="baseline" +) +# BENCHMARK_BASELINE_END +print("# BENCHMARK_BASELINE_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused: Basic TE Modules +# ============================================================================= + + +# TE_UNFUSED_MLP_START +class TEUnfusedMLP(nn.Module): + """MLP using TE modules.""" + + hidden_size: int + ffn_hidden_size: int + + @nn.compact + def __call__(self, x: jnp.ndarray, deterministic: bool) -> jnp.ndarray: + x = te_flax.DenseGeneral(features=self.ffn_hidden_size, use_bias=True)(x) + x = x.reshape(*x.shape[:-1], 1, x.shape[-1]) + x = te.activation.activation(x, activation_type=("gelu",)) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x) + return x + + +# TE_UNFUSED_MLP_END + + +# TE_UNFUSED_LAYER_START +class TEUnfusedTransformerLayer(nn.Module): + """Transformer layer using basic TE modules (without TE attention).""" + + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False, + ) -> jnp.ndarray: + if attention_mask is None: + attention_mask = nn.make_causal_mask(x[..., 0], dtype=jnp.bool_) + + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x) + + qkv = te_flax.DenseGeneral(features=3 * self.hidden_size, use_bias=True)(x) + qkv = qkv.reshape( + qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels + ) + q, k, v = jnp.split(qkv, 3, axis=3) + + dropout_rng = None + if not deterministic and self.attention_dropout > 0: + dropout_rng = self.make_rng("dropout") + + x = nn.dot_product_attention( + query=q, + key=k, + value=v, + mask=attention_mask, + dropout_rng=dropout_rng, + dropout_rate=self.attention_dropout, + deterministic=deterministic, + broadcast_dropout=True, + ) + + x = x.reshape(x.shape[0], x.shape[1], self.hidden_size) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + x = res + x + + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x) + mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size) + x = mlp(x, deterministic=deterministic) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + return x + res + + +# TE_UNFUSED_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_START +te_unfused = TEUnfusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) +params = te_unfused.init(key, x, deterministic=False) + +print("TE Unfused:") +time_te_unfused = speedometer( + te_unfused.apply, params, x, forward_kwargs={"deterministic": True}, label="te_unfused" +) +# BENCHMARK_TE_UNFUSED_END +print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + TE Attention +# ============================================================================= + + +# TE_UNFUSED_ATTN_LAYER_START +class TEUnfusedAttnTransformerLayer(nn.Module): + """Transformer layer using TE modules including TE DotProductAttention.""" + + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False, + ) -> jnp.ndarray: + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps, dtype=jnp.bfloat16)(x) + + qkv = te_flax.DenseGeneral( + features=3 * self.hidden_size, use_bias=True, dtype=jnp.bfloat16 + )(x) + qkv = qkv.reshape( + qkv.shape[0], qkv.shape[1], self.num_attention_heads, 3 * self.kv_channels + ) + q, k, v = jnp.split(qkv, 3, axis=3) + + attention = te_flax.DotProductAttention( + head_dim=self.kv_channels, + num_attention_heads=self.num_attention_heads, + num_gqa_groups=self.num_attention_heads, + attention_dropout=self.attention_dropout, + attn_mask_type="causal", + transpose_batch_sequence=False, + ) + x = attention(q, k, v, deterministic=deterministic) + x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True, dtype=jnp.bfloat16)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + x = res + x + + res = x + x = te_flax.LayerNorm(epsilon=self.layernorm_eps)(x) + mlp = TEUnfusedMLP(hidden_size=self.hidden_size, ffn_hidden_size=self.ffn_hidden_size) + x = mlp(x, deterministic=deterministic) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + return x + res + + +# TE_UNFUSED_ATTN_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_ATTN_START +te_unfused_attn = TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) + +with te.autocast(enabled=False, mesh_resource=mesh_resource): + params = te_unfused_attn.init(key, x, deterministic=False) + +print("TE Unfused + TE Attention:") +time_te_unfused_attn = speedometer( + te_unfused_attn.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": False, "mesh_resource": mesh_resource}, + label="te_unfused_attn", +) +# BENCHMARK_TE_UNFUSED_ATTN_END +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + FP8 +# ============================================================================= + +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_FP8_START +recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max") + +te_unfused_fp8 = TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) + +with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource): + params = te_unfused_fp8.init(key, x, deterministic=False) + +print("TE Unfused + TE Attention + FP8:") +time_te_unfused_fp8 = speedometer( + te_unfused_fp8.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, + label="te_unfused_fp8", +) +# BENCHMARK_TE_UNFUSED_FP8_END +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE Fused + FP8: Optimized Modules with FP8 +# ============================================================================= + + +# TE_FUSED_LAYER_START +class TEFusedTransformerLayer(nn.Module): + """Transformer layer using fused TE modules for better performance.""" + + hidden_size: int + ffn_hidden_size: int + num_attention_heads: int + layernorm_eps: float = 1e-5 + attention_dropout: float = 0.1 + + def setup(self): + self.kv_channels = self.hidden_size // self.num_attention_heads + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + attention_mask: Optional[jnp.ndarray] = None, + deterministic: bool = False, + ) -> jnp.ndarray: + res = x + + # Fused LayerNorm + QKV projection + qkv, _ = te_flax.LayerNormDenseGeneral( + features=3 * self.hidden_size, + epsilon=self.layernorm_eps, + use_bias=True, + return_layernorm_output=False, + )(x) + qkv = qkv.reshape(qkv.shape[0], qkv.shape[1], 3, self.num_attention_heads, self.kv_channels) + q, k, v = qkv[:, :, 0, :, :], qkv[:, :, 1, :, :], qkv[:, :, 2, :, :] + + attention = te_flax.DotProductAttention( + head_dim=self.kv_channels, + num_attention_heads=self.num_attention_heads, + num_gqa_groups=self.num_attention_heads, + attention_dropout=self.attention_dropout, + attn_mask_type="causal", + qkv_layout="bshd_bshd_bshd", + transpose_batch_sequence=False, + ) + x = attention(q, k, v, deterministic=deterministic) + x = x.reshape((x.shape[0], x.shape[1], x.shape[2] * x.shape[3])) + x = te_flax.DenseGeneral(features=self.hidden_size, use_bias=True)(x) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + x = res + x + + res = x + # Fused LayerNorm + MLP + x, _ = te_flax.LayerNormMLP( + intermediate_dim=self.ffn_hidden_size, + epsilon=self.layernorm_eps, + use_bias=True, + activations=("gelu",), + intermediate_dropout_rate=0.0, + return_layernorm_output=False, + )(x, deterministic=deterministic) + x = nn.Dropout(rate=self.attention_dropout)(x, deterministic=deterministic) + + return x + res + + +# TE_FUSED_LAYER_END + + +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_FUSED_FP8_START +te_fused_fp8 = TEFusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, +) + +with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource): + params = te_fused_fp8.init(key, x, deterministic=False) + +print("TE Fused + TE Attention + FP8:") +time_te_fused_fp8 = speedometer( + te_fused_fp8.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, + label="te_fused_fp8", +) +# BENCHMARK_TE_FUSED_FP8_END +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE TransformerLayer + FP8: Ready-to-use Module +# ============================================================================= + +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START") +# BENCHMARK_TE_TRANSFORMER_LAYER_START +te_transformer_layer = te_flax.TransformerLayer( + hidden_size=hidden_size, + mlp_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + mlp_activations=("gelu",), + self_attn_mask_type="causal", + layernorm_epsilon=1e-5, + use_bias=True, + attention_dropout=0.0, + intermediate_dropout=0.0, + hidden_dropout=0.0, + enable_relative_embedding=False, + self_attn_bias_type="no_bias", + dtype=jnp.bfloat16, + transpose_batch_sequence=False, +) + +with te.autocast(enabled=True, recipe=recipe, mesh_resource=mesh_resource): + params = te_transformer_layer.init(key, x, deterministic=False) + +print("TE TransformerLayer + FP8:") +time_te_transformer_layer = speedometer( + te_transformer_layer.apply, + params, + x, + forward_kwargs={"deterministic": True}, + autocast_kwargs={"enabled": True, "recipe": recipe, "mesh_resource": mesh_resource}, + label="te_transformer_layer", +) +# BENCHMARK_TE_TRANSFORMER_LAYER_END +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n") + +# Write summary CSV for RST documentation +with open("getting_started_jax_summary.csv", "w") as f: + f.write("Implementation,Time (ms),Speedup\n") + f.write(f"Baseline Flax,{time_baseline:.2f},1.00x\n") + f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n") + f.write( + "TE Unfused + TE" + f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n" + ) + f.write( + "TE Unfused + TE Attention +" + f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n" + ) + f.write( + "TE Fused + TE Attention +" + f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n" + ) + f.write( + "TE TransformerLayer +" + f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n" + ) +print("\nSummary written to getting_started_jax_summary.csv") diff --git a/docs/getting_started/getting_started_jax_summary.csv b/docs/getting_started/getting_started_jax_summary.csv new file mode 100644 index 00000000000..5b6a4249b3d --- /dev/null +++ b/docs/getting_started/getting_started_jax_summary.csv @@ -0,0 +1,7 @@ +Implementation,Time (ms),Speedup +Baseline Flax,86.58,1.00x +TE Unfused,42.25,2.05x +TE Unfused + TE Attention,35.05,2.47x +TE Unfused + TE Attention + FP8,22.64,3.82x +TE Fused + TE Attention + FP8,23.70,3.65x +TE TransformerLayer + FP8,22.81,3.80x diff --git a/docs/getting_started/getting_started_pytorch.out b/docs/getting_started/getting_started_pytorch.out new file mode 100644 index 00000000000..9b9387a8b22 --- /dev/null +++ b/docs/getting_started/getting_started_pytorch.out @@ -0,0 +1,42 @@ +pyxis: importing docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64 +pyxis: imported docker image: gitlab-master.nvidia.com/dl/transformerengine/transformerengine:main-pytorch-py3-devel-amd64 +/usr/local/lib/python3.12/dist-packages/torch/library.py:357: UserWarning: Warning only once for all operators, other operators may also be overridden. + Overriding a previously registered kernel for the same operator and the same dispatch key + operator: flash_attn::_flash_attn_backward(Tensor dout, Tensor q, Tensor k, Tensor v, Tensor out, Tensor softmax_lse, Tensor(a6!)? dq, Tensor(a7!)? dk, Tensor(a8!)? dv, float dropout_p, float softmax_scale, bool causal, SymInt window_size_left, SymInt window_size_right, float softcap, Tensor? alibi_slopes, bool deterministic, Tensor? rng_state=None) -> Tensor + registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 + dispatch key: ADInplaceOrView + previous kernel: no debug info + new kernel: registered at /usr/local/lib/python3.12/dist-packages/torch/_library/custom_ops.py:926 (Triggered internally at /opt/pytorch/pytorch/aten/src/ATen/core/dispatch/OperatorEntry.cpp:208.) + self.m.impl( +# BENCHMARK_BASELINE_OUTPUT_START +Baseline PyTorch: +Mean time: 48.280 ms +# BENCHMARK_BASELINE_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_OUTPUT_START +TE Unfused: +Mean time: 49.342 ms +# BENCHMARK_TE_UNFUSED_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START +TE Unfused + TE Attention: +Mean time: 35.709 ms +# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START +TE Unfused + TE Attention + FP8: +Mean time: 23.406 ms +# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_FUSED_FP8_OUTPUT_START +TE Fused + TE Attention + FP8: +Mean time: 22.964 ms +# BENCHMARK_TE_FUSED_FP8_OUTPUT_END + +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START +TE TransformerLayer + FP8: +Mean time: 21.670 ms +# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + +Summary written to getting_started_pytorch_summary.csv diff --git a/docs/getting_started/getting_started_pytorch.py b/docs/getting_started/getting_started_pytorch.py new file mode 100644 index 00000000000..eed21ee0aad --- /dev/null +++ b/docs/getting_started/getting_started_pytorch.py @@ -0,0 +1,497 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Getting Started with Transformer Engine - PyTorch Example +========================================================== + +This example shows how to build a Transformer layer using PyTorch +and how to optimize it with Transformer Engine. +""" + +from typing import Optional +import torch +import transformer_engine.pytorch as te +from transformer_engine.common.recipe import Format, DelayedScaling + +from getting_started_utils_pytorch import DotProductAttention, speedometer + + +# Configuration +hidden_size = 4096 +sequence_length = 2048 +batch_size = 8 +ffn_hidden_size = 16384 +num_attention_heads = 32 +dtype = torch.bfloat16 + +# Create synthetic data +x = torch.rand(sequence_length, batch_size, hidden_size).cuda().to(dtype=dtype) + + +# ============================================================================= +# Baseline: Pure PyTorch Implementation +# ============================================================================= + + +# BASELINE_MLP_START +class PyTorchMLP(torch.nn.Module): + """Feed-forward network in Transformer layer. + Built with plain PyTorch modules. + """ + + hidden_size: int + ffn_hidden_size: int + + def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None: + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.linear1 = torch.nn.Linear(hidden_size, ffn_hidden_size, bias=True) + self.linear2 = torch.nn.Linear(ffn_hidden_size, hidden_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = torch.nn.functional.gelu(x, approximate="tanh") + x = self.linear2(x) + return x + + +# BASELINE_MLP_END + + +# BASELINE_LAYER_START +class PyTorchTransformerLayer(torch.nn.Module): + """Basic Transformer layer using plain PyTorch modules.""" + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + self.ln1 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps) + self.qkv_projection = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=True) + self.attention = DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + ) + self.projection = torch.nn.Linear(hidden_size, hidden_size, bias=True) + self.dropout = torch.nn.Dropout(hidden_dropout) + self.ln2 = torch.nn.LayerNorm(hidden_size, eps=layernorm_eps) + self.mlp = PyTorchMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) + + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + res = x + x = self.ln1(x) + + # Fused QKV projection + qkv = self.qkv_projection(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout(x) + x = res + x + + # Second residual connection + res = x + x = self.ln2(x) + x = self.mlp(x) + + return x + res + + +# BASELINE_LAYER_END + + +print("# BENCHMARK_BASELINE_OUTPUT_START") +# BENCHMARK_BASELINE_START +baseline = ( + PyTorchTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) + +print("Baseline PyTorch:") +time_baseline = speedometer(baseline, x, forward_kwargs={"attention_mask": None}, label="baseline") +# BENCHMARK_BASELINE_END +print("# BENCHMARK_BASELINE_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused: Basic TE Modules +# ============================================================================= + + +# TE_UNFUSED_MLP_START +class TEUnfusedMLP(torch.nn.Module): + """MLP using TE modules.""" + + hidden_size: int + ffn_hidden_size: int + + def __init__(self, hidden_size: int, ffn_hidden_size: int) -> None: + super().__init__() + self.hidden_size = hidden_size + self.ffn_hidden_size = ffn_hidden_size + self.linear1 = te.Linear(hidden_size, ffn_hidden_size, bias=True) + self.linear2 = te.Linear(ffn_hidden_size, hidden_size, bias=True) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.linear1(x) + x = torch.nn.functional.gelu(x, approximate="tanh") + x = self.linear2(x) + return x + + +# TE_UNFUSED_MLP_END + + +# TE_UNFUSED_LAYER_START +class TEUnfusedTransformerLayer(torch.nn.Module): + """Transformer layer using basic TE modules.""" + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True) + self.attention = DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + ) + self.projection = te.Linear(hidden_size, hidden_size, bias=True) + self.dropout1 = torch.nn.Dropout(hidden_dropout) + self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) + self.dropout2 = torch.nn.Dropout(hidden_dropout) + + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + res = x + x = self.ln1(x) + + # Fused QKV projection + qkv = self.qkv_projection(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout1(x) + x = res + x + + # Second residual connection + res = x + x = self.ln2(x) + x = self.mlp(x) + x = self.dropout2(x) + + return x + res + + +# TE_UNFUSED_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_START +te_unfused = ( + TEUnfusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) + +print("TE Unfused:") +time_te_unfused = speedometer( + te_unfused, x, forward_kwargs={"attention_mask": None}, label="te_unfused" +) +# BENCHMARK_TE_UNFUSED_END +print("# BENCHMARK_TE_UNFUSED_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + TE Attention +# ============================================================================= + + +# TE_UNFUSED_ATTN_LAYER_START +class TEUnfusedAttnTransformerLayer(torch.nn.Module): + """Transformer layer using TE modules including TE DotProductAttention.""" + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + self.ln1 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.qkv_projection = te.Linear(hidden_size, 3 * hidden_size, bias=True) + self.attention = te.DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + attn_mask_type="causal", + ) + self.projection = te.Linear(hidden_size, hidden_size, bias=True) + self.dropout1 = torch.nn.Dropout(hidden_dropout) + self.ln2 = te.LayerNorm(hidden_size, eps=layernorm_eps) + self.mlp = TEUnfusedMLP(hidden_size=hidden_size, ffn_hidden_size=ffn_hidden_size) + self.dropout2 = torch.nn.Dropout(hidden_dropout) + + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + res = x + x = self.ln1(x) + + # Fused QKV projection + qkv = self.qkv_projection(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout1(x) + x = res + x + + # Second residual connection + res = x + x = self.ln2(x) + x = self.mlp(x) + x = self.dropout2(x) + + return x + res + + +# TE_UNFUSED_ATTN_LAYER_END + + +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_ATTN_START +te_unfused_attn = ( + TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) + +print("TE Unfused + TE Attention:") +time_te_unfused_attn = speedometer( + te_unfused_attn, x, forward_kwargs={"attention_mask": None}, label="te_unfused_attn" +) +# BENCHMARK_TE_UNFUSED_ATTN_END +print("# BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END\n") + + +# ============================================================================= +# TE Unfused + FP8 +# ============================================================================= + +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_UNFUSED_FP8_START +recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max") + +te_unfused_fp8 = ( + TEUnfusedAttnTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) + +print("TE Unfused + TE Attention + FP8:") +time_te_unfused_fp8 = speedometer( + te_unfused_fp8, + x, + forward_kwargs={"attention_mask": None}, + autocast_kwargs={"enabled": True, "recipe": recipe}, + label="te_unfused_fp8", +) +# BENCHMARK_TE_UNFUSED_FP8_END +print("# BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE Fused + FP8: Optimized Modules with FP8 +# ============================================================================= + + +# TE_FUSED_LAYER_START +class TEFusedTransformerLayer(torch.nn.Module): + """Transformer layer using fused TE modules for better performance.""" + + def __init__( + self, + hidden_size: int, + ffn_hidden_size: int, + num_attention_heads: int, + layernorm_eps: float = 1e-5, + attention_dropout: float = 0.1, + hidden_dropout: float = 0.1, + ): + super().__init__() + self.num_attention_heads = num_attention_heads + self.kv_channels = hidden_size // num_attention_heads + + # Fused LayerNorm + QKV projection + self.ln_qkv = te.LayerNormLinear(hidden_size, 3 * hidden_size, eps=layernorm_eps, bias=True) + self.attention = te.DotProductAttention( + num_attention_heads=num_attention_heads, + kv_channels=self.kv_channels, + attention_dropout=attention_dropout, + attn_mask_type="causal", + ) + self.projection = te.Linear(hidden_size, hidden_size, bias=True) + self.dropout1 = torch.nn.Dropout(hidden_dropout) + + # Fused LayerNorm + MLP + self.ln_mlp = te.LayerNormMLP(hidden_size, ffn_hidden_size, eps=layernorm_eps, bias=True) + self.dropout2 = torch.nn.Dropout(hidden_dropout) + + def forward( + self, x: torch.Tensor, attention_mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + res = x + + # Fused LayerNorm + QKV projection + qkv = self.ln_qkv(x) + qkv = qkv.view(qkv.size(0), qkv.size(1), self.num_attention_heads, 3 * self.kv_channels) + q, k, v = torch.split(qkv, qkv.size(3) // 3, dim=3) + + x = self.attention(q, k, v, attention_mask) + x = self.projection(x) + x = self.dropout1(x) + x = res + x + + # Fused LayerNorm + MLP + res = x + x = self.ln_mlp(x) + x = self.dropout2(x) + + return x + res + + +# TE_FUSED_LAYER_END + + +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_START") +# BENCHMARK_TE_FUSED_FP8_START +te_fused_fp8 = ( + TEFusedTransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + ) + .to(dtype=dtype) + .cuda() +) + +print("TE Fused + TE Attention + FP8:") +time_te_fused_fp8 = speedometer( + te_fused_fp8, + x, + forward_kwargs={"attention_mask": None}, + autocast_kwargs={"enabled": True, "recipe": recipe}, + label="te_fused_fp8", +) +# BENCHMARK_TE_FUSED_FP8_END +print("# BENCHMARK_TE_FUSED_FP8_OUTPUT_END\n") + + +# ============================================================================= +# TE TransformerLayer + FP8: Ready-to-use Module +# ============================================================================= + +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START") +# BENCHMARK_TE_TRANSFORMER_LAYER_START +te_transformer_layer = ( + te.TransformerLayer( + hidden_size=hidden_size, + ffn_hidden_size=ffn_hidden_size, + num_attention_heads=num_attention_heads, + self_attn_mask_type="causal", + layernorm_epsilon=1e-5, + bias=True, + hidden_dropout=0.0, + attention_dropout=0.0, + ) + .to(dtype=dtype) + .cuda() +) + +print("TE TransformerLayer + FP8:") +time_te_transformer_layer = speedometer( + te_transformer_layer, + x, + forward_kwargs={"attention_mask": None}, + autocast_kwargs={"enabled": True, "recipe": recipe}, + label="te_transformer_layer", +) +# BENCHMARK_TE_TRANSFORMER_LAYER_END +print("# BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END\n") + + +# Write summary CSV for RST documentation +with open("getting_started_pytorch_summary.csv", "w") as f: + f.write("Implementation,Time (ms),Speedup\n") + f.write(f"Baseline PyTorch,{time_baseline:.2f},1.00x\n") + f.write(f"TE Unfused,{time_te_unfused:.2f},{time_baseline/time_te_unfused:.2f}x\n") + f.write( + "TE Unfused + TE" + f" Attention,{time_te_unfused_attn:.2f},{time_baseline/time_te_unfused_attn:.2f}x\n" + ) + f.write( + "TE Unfused + TE Attention +" + f" FP8,{time_te_unfused_fp8:.2f},{time_baseline/time_te_unfused_fp8:.2f}x\n" + ) + f.write( + "TE Fused + TE Attention +" + f" FP8,{time_te_fused_fp8:.2f},{time_baseline/time_te_fused_fp8:.2f}x\n" + ) + f.write( + "TE TransformerLayer +" + f" FP8,{time_te_transformer_layer:.2f},{time_baseline/time_te_transformer_layer:.2f}x\n" + ) +print("\nSummary written to getting_started_pytorch_summary.csv") diff --git a/docs/getting_started/getting_started_pytorch_summary.csv b/docs/getting_started/getting_started_pytorch_summary.csv new file mode 100644 index 00000000000..b3a5d7330ed --- /dev/null +++ b/docs/getting_started/getting_started_pytorch_summary.csv @@ -0,0 +1,7 @@ +Implementation,Time (ms),Speedup +Baseline PyTorch,48.28,1.00x +TE Unfused,49.34,0.98x +TE Unfused + TE Attention,35.71,1.35x +TE Unfused + TE Attention + FP8,23.41,2.06x +TE Fused + TE Attention + FP8,22.96,2.10x +TE TransformerLayer + FP8,21.67,2.23x diff --git a/docs/getting_started/getting_started_utils_jax.py b/docs/getting_started/getting_started_utils_jax.py new file mode 100644 index 00000000000..6184b6565b8 --- /dev/null +++ b/docs/getting_started/getting_started_utils_jax.py @@ -0,0 +1,76 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Utility functions for Getting Started with Transformer Engine - JAX +==================================================================== + +Helper classes and functions for the getting started examples. +""" + +import time +from typing import Callable, Any, Optional + +import jax +import jax.numpy as jnp +from flax import linen as nn +import transformer_engine.jax as te +from transformer_engine.jax.sharding import MeshResource + + +def speedometer( + apply_fn: Callable, + params: Any, + x: jnp.ndarray, + forward_kwargs: dict = {}, + autocast_kwargs: Optional[dict] = None, + timing_iters: int = 100, + warmup_iters: int = 10, + label: str = "benchmark", +) -> float: + """Measure average forward + backward pass time for a JAX module. + + Args: + apply_fn: JIT-compiled apply function + params: Model parameters + x: Input tensor + forward_kwargs: Additional kwargs for forward pass + autocast_kwargs: Kwargs for te.autocast context + timing_iters: Number of timing iterations + warmup_iters: Number of warmup iterations + label: Optional label for logging + + Returns: + Average time per iteration in milliseconds + """ + if autocast_kwargs is None: + autocast_kwargs = {"enabled": False} + else: + autocast_kwargs = dict(autocast_kwargs) + autocast_kwargs.setdefault("mesh_resource", MeshResource()) + + def loss_fn(params, x): + y = apply_fn(params, x, **forward_kwargs) + return jnp.sum(y) + + # JIT compile within autocast context + with te.autocast(**autocast_kwargs): + grad_fn = jax.jit(jax.value_and_grad(loss_fn)) + + # Warmup runs + for _ in range(warmup_iters): + loss, grads = grad_fn(params, x) + jax.block_until_ready((loss, grads)) + + # Timing runs + times = [] + for _ in range(timing_iters): + start = time.perf_counter() + loss, grads = grad_fn(params, x) + jax.block_until_ready((loss, grads)) + times.append(time.perf_counter() - start) + + avg_time = sum(times) / len(times) * 1000 + print(f"Mean time: {avg_time:.3f} ms") + return avg_time diff --git a/docs/getting_started/getting_started_utils_pytorch.py b/docs/getting_started/getting_started_utils_pytorch.py new file mode 100644 index 00000000000..307d3d13b46 --- /dev/null +++ b/docs/getting_started/getting_started_utils_pytorch.py @@ -0,0 +1,124 @@ +# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# See LICENSE for license information. + +""" +Utility functions for Getting Started with Transformer Engine - PyTorch +======================================================================== + +Helper classes and functions for the getting started examples. +""" + +import math +from typing import Optional +import torch +import transformer_engine.pytorch as te + + +def speedometer( + module: torch.nn.Module, + x: torch.Tensor, + forward_kwargs: dict = {}, + autocast_kwargs: Optional[dict] = None, + timing_iters: int = 100, + warmup_iters: int = 10, + label: str = "benchmark", +) -> float: + """Measure average forward + backward pass time for a PyTorch module. + + Args: + module: PyTorch module to benchmark + x: Input tensor + forward_kwargs: Additional kwargs for forward pass + autocast_kwargs: Kwargs for te.autocast context + timing_iters: Number of timing iterations + warmup_iters: Number of warmup iterations + label: Optional label for logging + + Returns: + Average time per iteration in milliseconds + """ + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + if autocast_kwargs is None: + autocast_kwargs = {"enabled": False} + + # Warmup runs + torch.cuda.synchronize() + for _ in range(warmup_iters): + with te.autocast(**autocast_kwargs): + y = module(x, **forward_kwargs) + loss = y.sum() + loss.backward() + torch.cuda.synchronize() + + # Timing runs + start.record() + for _ in range(timing_iters): + with te.autocast(**autocast_kwargs): + y = module(x, **forward_kwargs) + loss = y.sum() + loss.backward() + end.record() + torch.cuda.synchronize() + + avg_time = start.elapsed_time(end) / timing_iters + print(f"Mean time: {avg_time:.3f} ms") + return avg_time + + +class DotProductAttention(torch.nn.Module): + """Attention operation in Transformer layer. + + Built with plain PyTorch modules. + """ + + def __init__( + self, + num_attention_heads: int, + kv_channels: int, + attention_dropout: float, + ) -> None: + super().__init__() + self.projection_size = kv_channels * num_attention_heads + self.hidden_size_per_attention_head = kv_channels + self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) + self.dropout = torch.nn.Dropout(attention_dropout) + + def masked_softmax(self, inp: torch.Tensor, mask: Optional[torch.Tensor]) -> torch.Tensor: + if mask is not None: + inp.masked_fill_(mask, -10000.0) + return torch.nn.Softmax(dim=-1)(inp) + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + b = query.size(1) + np = query.size(2) + sq = query.size(0) + sk = key.size(0) + hn = value.size(3) + + query = query.view(sq, b * np, -1) + key = key.view(sk, b * np, -1) + + bmm1 = ( + torch.bmm(query.transpose(0, 1), key.transpose(0, 1).transpose(1, 2)) / self.norm_factor + ) + + attention_scores = bmm1.view(b, np, sq, sk) + attention_probs = self.masked_softmax(attention_scores, attention_mask) + attention_probs = self.dropout(attention_probs) + + value = value.view(sk, b * np, -1) + attention_probs = attention_probs.view(b * np, sq, -1) + context = torch.bmm(attention_probs, value.transpose(0, 1)) + context = context.view(b, np, sq, hn) + context = context.permute(2, 0, 1, 3).contiguous() + context = context.view(sq, b, self.projection_size) + + return context diff --git a/docs/getting_started/index.rst b/docs/getting_started/index.rst new file mode 100644 index 00000000000..30dba93c6e1 --- /dev/null +++ b/docs/getting_started/index.rst @@ -0,0 +1,566 @@ +.. + Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved. + + See LICENSE for license information. + +Getting Started +=============== + +Overview +-------- + +Transformer Engine (TE) is a library for accelerating Transformer models on NVIDIA GPUs, +providing better performance with lower memory utilization in both training and inference. +It provides support for 8-bit floating point (FP8) precision on Hopper and Ada GPUs, as well as +8-bit and 4-bit floating point (NVFP4) precision on Blackwell GPUs. + +TE implements a collection of highly optimized building blocks for popular Transformer +architectures and exposes an automatic-mixed-precision-like API that can be used seamlessly +with your deep learning code. + + +Currently two frameworks are supported: PyTorch and JAX. + +.. tabs:: + + .. tab:: PyTorch + + Basic knowledge of PyTorch is recommended: + + - `PyTorch Tutorials `_ + - `PyTorch Documentation `_ + + .. tab:: JAX + + We recommend understanding the basics of JAX first: + + - `Thinking in JAX `_ + - `JAX 101 `_ + - `Key concepts in JAX `_ + - `Flax 101 `_ + + +Baseline: Pure Framework Implementation +--------------------------------------- + +Let's build a Transformer decoder layer! + +We'll create a basic GPT-style layer with causal masking, +which prevents each position from attending to future positions. This will be our baseline +for later comparisons with Transformer Engine. + +.. raw:: html + :file: transformer_layer.svg + +.. raw:: html + +

Structure of a GPT decoder layer

+ +We construct the components as follows: + +.. tabs:: + + .. tab:: PyTorch + + * **LayerNorm**: ``torch.nn.LayerNorm`` + * **QKV Projection**: ``torch.nn.Linear`` (fused Q, K, V into single layer 3x larger) + * **DotProductAttention**: Custom implementation using ``torch.bmm`` + * **Projection**: ``torch.nn.Linear`` + * **Dropout**: ``torch.nn.Dropout`` + * **MLP**: Two ``torch.nn.Linear`` layers with ``torch.nn.functional.gelu`` activation + + .. tab:: JAX + + * **LayerNorm**: ``nn.LayerNorm`` + * **QKV Projection**: ``nn.Dense`` (fused Q, K, V into single layer 3x larger) + * **DotProductAttention**: ``nn.dot_product_attention`` + * **Projection**: ``nn.Dense`` + * **Dropout**: ``nn.Dropout`` + * **MLP**: Two ``nn.Dense`` layers with ``nn.gelu`` activation + +Putting it all together: + +.. tabs:: + + .. tab:: PyTorch + + First, define the MLP block: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BASELINE_MLP_START + :end-before: # BASELINE_MLP_END + + Now, putting it all together into a GPT decoder layer: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BASELINE_LAYER_START + :end-before: # BASELINE_LAYER_END + + Benchmark the baseline implementation: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_BASELINE_START + :end-before: # BENCHMARK_BASELINE_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_BASELINE_OUTPUT_START + :end-before: # BENCHMARK_BASELINE_OUTPUT_END + + .. tab:: JAX + + First, define the MLP block: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BASELINE_MLP_START + :end-before: # BASELINE_MLP_END + + Now, putting it all together into a GPT decoder layer: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BASELINE_LAYER_START + :end-before: # BASELINE_LAYER_END + + Benchmark the baseline implementation: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_BASELINE_START + :end-before: # BENCHMARK_BASELINE_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_BASELINE_OUTPUT_START + :end-before: # BENCHMARK_BASELINE_OUTPUT_END + + +TE Unfused: Basic TE Modules +---------------------------- + +Now let's replace the standard framework modules with TE equivalents. +This is the simplest way to start using Transformer Engine. + +.. tabs:: + + .. tab:: PyTorch + + Replace PyTorch modules with TE equivalents: + + .. code-block:: python + + import transformer_engine.pytorch as te + + Mapping: + + * ``torch.nn.Linear`` → ``te.Linear`` + * ``torch.nn.LayerNorm`` → ``te.LayerNorm`` + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_UNFUSED_MLP_START + :end-before: # TE_UNFUSED_MLP_END + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_UNFUSED_LAYER_START + :end-before: # TE_UNFUSED_LAYER_END + + Benchmark the TE unfused implementation: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_START + :end-before: # BENCHMARK_TE_UNFUSED_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END + + .. tab:: JAX + + Replace Flax modules with TE equivalents: + + .. code-block:: python + + import transformer_engine.jax as te + import transformer_engine.jax.flax as te_flax + + Mapping: + + * ``nn.Dense`` → ``te_flax.DenseGeneral`` + * ``nn.LayerNorm`` → ``te_flax.LayerNorm`` + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_UNFUSED_MLP_START + :end-before: # TE_UNFUSED_MLP_END + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_UNFUSED_LAYER_START + :end-before: # TE_UNFUSED_LAYER_END + + Benchmark the TE unfused implementation: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_START + :end-before: # BENCHMARK_TE_UNFUSED_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_OUTPUT_END + + +TE Unfused + TE Attention +------------------------- + +Now let's also replace the attention mechanism with TE's optimized ``DotProductAttention``. +TE's attention automatically selects the best available backend — for example, FlashAttention or cuDNN fused attention — based on your hardware and input configuration, +delivering optimal performance without manual tuning. + +.. tabs:: + + .. tab:: PyTorch + + Replace the custom attention with TE's optimized implementation: + + * Custom ``DotProductAttention`` → ``te.DotProductAttention`` + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_UNFUSED_ATTN_LAYER_START + :end-before: # TE_UNFUSED_ATTN_LAYER_END + + Benchmark TE Unfused with TE Attention: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + + .. tab:: JAX + + Replace Flax's attention with TE's optimized implementation: + + * ``nn.dot_product_attention`` → ``te_flax.DotProductAttention`` + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_UNFUSED_ATTN_LAYER_START + :end-before: # TE_UNFUSED_ATTN_LAYER_END + + Benchmark TE Unfused with TE Attention: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_ATTN_OUTPUT_END + + +TE Unfused + TE Attention + FP8 +------------------------------- + +Now let's combine TE modules with TE Attention and enable FP8 precision. +Wrap your code within an ``autocast`` context manager to enable FP8. +This provides significant speedups on supported hardware (Hopper, Ada, Blackwell GPUs). + +.. tabs:: + + .. tab:: PyTorch + + .. code-block:: python + + from transformer_engine.common.recipe import Format, DelayedScaling + + recipe = DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max" + ) + + with te.autocast(enabled=True, recipe=recipe): + y = te_unfused(x, attention_mask=None) + + .. note:: + + The ``autocast`` should only wrap the forward pass and must exit before + starting a backward pass. + + Benchmark TE Unfused with FP8: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_FP8_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + + .. tab:: JAX + + .. code-block:: python + + from transformer_engine.common.recipe import Format, DelayedScaling + + recipe = DelayedScaling( + fp8_format=Format.HYBRID, + amax_history_len=16, + amax_compute_algo="max" + ) + + with te.autocast(enabled=True, recipe=recipe): + params = te_unfused.init(key, x, deterministic=False) + y = te_unfused.apply(params, x, deterministic=True) + + .. important:: + + When using FP8 in JAX, the model **must be initialized within the autocast context** + to create the ``fp8_metas`` collection. + + Benchmark TE Unfused with FP8: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_UNFUSED_FP8_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_UNFUSED_FP8_OUTPUT_END + + +TE Fused + TE Attention + FP8: Optimized Modules +------------------------------------------------ + +Fused modules use kernel fusion to combine multiple operations. +While speedups are modest on a single GPU, they scale better in multi-GPU setups. +Combined with TE Attention and FP8, this delivers peak performance. + +.. tabs:: + + .. tab:: PyTorch + + Fused modules available: + + * ``te.LayerNormLinear`` - fuses LayerNorm + Linear + * ``te.LayerNormMLP`` - fuses LayerNorm + MLP + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # TE_FUSED_LAYER_START + :end-before: # TE_FUSED_LAYER_END + + Benchmark TE Fused with FP8: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_FUSED_FP8_START + :end-before: # BENCHMARK_TE_FUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END + + .. tab:: JAX + + Fused modules available: + + * ``te_flax.LayerNormDenseGeneral`` - fuses LayerNorm + Dense + * ``te_flax.LayerNormMLP`` - fuses LayerNorm + MLP + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # TE_FUSED_LAYER_START + :end-before: # TE_FUSED_LAYER_END + + Benchmark TE Fused with FP8: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_FUSED_FP8_START + :end-before: # BENCHMARK_TE_FUSED_FP8_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_FUSED_FP8_OUTPUT_START + :end-before: # BENCHMARK_TE_FUSED_FP8_OUTPUT_END + + +TE TransformerLayer + FP8: Ready-to-use Module +---------------------------------------------- + +For the simplest integration, Transformer Engine provides a ready-to-use ``TransformerLayer`` +module that includes all optimizations out of the box. + +.. tabs:: + + .. tab:: PyTorch + + Just use ``te.TransformerLayer`` - it handles everything for you: + + .. literalinclude:: getting_started_pytorch.py + :language: python + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_pytorch.out + :language: text + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + .. tab:: JAX + + Just use ``te_flax.TransformerLayer`` - it handles everything for you: + + .. literalinclude:: getting_started_jax.py + :language: python + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_END + + .. raw:: html + +
+ Output: +
+ + .. container:: program-output + + .. literalinclude:: getting_started_jax.out + :language: text + :start-after: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_START + :end-before: # BENCHMARK_TE_TRANSFORMER_LAYER_OUTPUT_END + + +Benchmark Summary +----------------- + +The table below summarizes the performance improvements achieved with Transformer Engine +on an NVIDIA H100 GPU. Results may vary depending on hardware and configuration. While this +tutorial focuses on a simple single-GPU scenario, features like fused layers can provide +additional benefits in more complex setups such as multi-GPU training. + +.. tabs:: + + .. tab:: PyTorch + + .. csv-table:: + :header-rows: 1 + :widths: 40, 20, 20 + :file: getting_started_pytorch_summary.csv + + .. tab:: JAX + + .. csv-table:: + :header-rows: 1 + :widths: 40, 20, 20 + :file: getting_started_jax_summary.csv diff --git a/docs/getting_started/transformer_layer.svg b/docs/getting_started/transformer_layer.svg new file mode 100644 index 00000000000..28ba3dd3869 --- /dev/null +++ b/docs/getting_started/transformer_layer.svg @@ -0,0 +1,82 @@ + + + + + + + + + + + + + + + + + + LayerNorm + + + + + + QKV Projection + + + + + + Dot Product + Attention + + + + + + Projection + + + + + + Dropout + + + + + + + + + + + + + + + + + LayerNorm + + + + + + MLP + + + + + + + + + + + + diff --git a/docs/index.rst b/docs/index.rst index 7a3ab9f6fd5..d9f7b03859e 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -29,7 +29,7 @@ Transformer Engine documentation :caption: Getting Started installation - getting_started + getting_started/index faq .. toctree::