
MI300X FP8 TE.Linear 2x Slower than AMP BF16 F.Linear #73

@functionstackx


Problem Description

Even on real-world Llama 2 70B training shapes, TE Linear FP8 is 1.5 to 2x slower than AMP BF16 F.linear. Do you have any suggestions or magic env flags to improve performance? On H100, TE Linear FP8 is much faster than AMP BF16 F.linear.

I have attached a reproduction script and all the relevant versions and installation scripts below.

cc: @hliuca


python3 ./gemm.py 
Benchmark results for Realistic GEMM shapes with warmup=30 and repeats=200
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| Shape (M, N, K)     | bf16 torch.matmul   | bf16 F.linear (with bias)   | bf16 F.linear (with bias & amp)   | TE Linear (FP8 autocast)   |
+=====================+=====================+=============================+===================================+============================+
| (16384, 8192, 1280) | 493.0 TFLOPS        | 491.7 TFLOPS                | 420.2 TFLOPS                      | 206.8 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 1024, 8192) | 546.4 TFLOPS        | 470.0 TFLOPS                | 288.6 TFLOPS                      | 137.1 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 8192, 7168) | 567.0 TFLOPS        | 566.3 TFLOPS                | 504.0 TFLOPS                      | 465.3 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 3584, 8192) | 610.0 TFLOPS        | 545.0 TFLOPS                | 430.1 TFLOPS                      | 325.8 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (8192, 8192, 8192)  | 588.3 TFLOPS        | 504.3 TFLOPS                | 443.2 TFLOPS                      | 372.9 TFLOPS               |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
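To put the gap in context, a quick back-of-the-envelope conversion of the first row from TFLOPS back into per-GEMM time (same FLOP formula as the script below; numbers taken straight from the table):

m, n, k = 16384, 8192, 1280
flops = 2 * m * n * k              # ~3.44e11 FLOPs per forward GEMM
print(flops / 493.0e12 * 1e3)      # ~0.70 ms for bf16 torch.matmul
print(flops / 206.8e12 * 1e3)      # ~1.66 ms for TE Linear FP8, i.e. ~2.4x slower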

Steps to Reproduce

Versions

root@NODENAME:/workspace/llm-train-bench# pip list | grep torch
pytorch-triton-rocm     3.1.0+cf34004b8a
torch                   2.6.0.dev20241012+rocm6.2
torchvision             0.18.0a0+68ba7ec
root@NODENAME:/workspace/llm-train-bench# pip list | grep transformer
transformer_engine      1.8.0.dev0+691dc23

Install Instructions

FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0

RUN apt install -y nano

RUN pip install uv

RUN uv pip install --system ipython pytest fire pydantic pybind11

RUN pip3 uninstall -y torch

RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2


WORKDIR /workspace/

RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git
ENV NVTE_FRAMEWORK=pytorch
ENV PYTORCH_ROCM_ARCH=gfx942

RUN cd TransformerEngine && pip install .

WORKDIR /workspace/llm-train-bench/

CMD ["/usr/bin/bash"]
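For completeness, a typical build/run sequence for the Dockerfile above (the image tag and bind-mount path are placeholders, not from the original setup; the device/group flags are the usual ROCm ones):

docker build -t te-rocm-bench .
docker run -it --rm --ipc=host \
    --device=/dev/kfd --device=/dev/dri --group-add video \
    -v "$(pwd)":/workspace/llm-train-bench \
    te-rocm-bench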

Reproduction Script

import time
import torch
import tabulate
from triton.testing import do_bench
import torch.nn.functional as F
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

torch.manual_seed(0)
repeats = 200
warmup = 30
dtype = torch.bfloat16
device = 'cuda'
verbose = False

shapes = [
    (16384, 8192, 1280), # Llama 70B TP8 shape
    (16384, 1024, 8192), # Llama 70B TP8 shape
    (16384, 8192, 7168), # Llama 70B TP8 shape
    (16384, 3584, 8192), # Llama 70B TP8 shape
    (8192, 8192, 8192) # Square shape
]

results = []

for (m, n, k) in shapes:
    # Matmul benchmark
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS = 2 * m * n * k
    ms = do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=repeats)
    tflops_matmul = nFLOPS / ms * 1e-9
    time.sleep(3)

    nFLOPS_with_bias = 2 * m * n * k + m * n  # FLOPs for the matmul plus the bias add

    # Linear (with bias) benchmark using F.linear
    weight_with_bias = torch.randn(n, k, device=device, dtype=dtype)
    bias = torch.randn(n, device=device, dtype=dtype)
    input_tensor = torch.randn(m, k, device=device, dtype=dtype)
    ms_linear_with_bias = do_bench(lambda: F.linear(input_tensor, weight_with_bias, bias=bias), warmup=warmup, rep=repeats)
    tflops_linear_with_bias = nFLOPS_with_bias / ms_linear_with_bias * 1e-9
    time.sleep(0.25)

    # F.linear benchmark under bf16 autocast, with a, b, and c stored in fp32
    a = torch.randn(m, k, device=device, dtype=torch.float32)
    b = torch.randn(n, k, device=device, dtype=torch.float32)
    c = torch.randn(n, device=device, dtype=torch.float32)
    with torch.autocast(dtype=dtype, device_type=device):
        ms_autocast = do_bench(lambda: F.linear(a, b, bias=c), warmup=warmup, rep=repeats)
    tflops_autocast = nFLOPS_with_bias / ms_autocast * 1e-9
    time.sleep(0.25)

    # TE Linear (with FP8 autocast) benchmark
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    input_tensor = torch.randn(m, k, device=device)
    linear_layer = te.Linear(k, n, bias=True).to(device)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        ms_te_linear = do_bench(lambda: linear_layer(input_tensor), warmup=warmup, rep=repeats)
    tflops_te_linear = nFLOPS_with_bias / ms_te_linear * 1e-9
    time.sleep(0.25)

    # Append the results to the list
    results.append([
        f"({m}, {n}, {k})",
        f"{tflops_matmul:.1f} TFLOPS",
        f"{tflops_linear_with_bias:.1f} TFLOPS",
        f"{tflops_autocast:.1f} TFLOPS",
        f"{tflops_te_linear:.1f} TFLOPS"
    ])

# Print results using tabulate
headers = [
    "Shape (M, N, K)",
    "bf16 torch.matmul",
    "bf16 F.linear (with bias)",
    "bf16 F.linear (with bias & amp)",
    "TE Linear (FP8 autocast)"
]
print(f"Benchmark results for Realistic GEMM shapes with {warmup=} and {repeats=}")
print(tabulate.tabulate(results, headers=headers, tablefmt="grid"))
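In case it helps triage, here is a rough profiling sketch (same shapes and recipe as above; kernel names will depend on the hipBLASLt/rocBLAS build) to check which GEMM kernels the FP8 path actually dispatches:

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

m, n, k = 16384, 8192, 1280
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")
x = torch.randn(m, k, device="cuda")
layer = te.Linear(k, n, bias=True).to("cuda")

# Warm up so the delayed-scaling amax history is populated before profiling
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    for _ in range(10):
        layer(x)
torch.cuda.synchronize()

# Profile a single forward pass and dump the top GPU kernels by time
with torch.profiler.profile(activities=[torch.profiler.ProfilerActivity.CUDA]) as prof:
    with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
        layer(x)
    torch.cuda.synchronize()
print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=10))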

Operating System

Ubuntu

CPU

AMD CPU

GPU

AMD Instinct MI300X

ROCm Version

ROCm 6.2.0
