Problem Description
Even on real-world Llama 2 70B training shapes, TE Linear in FP8 is 1.5 to 2x slower than BF16 AMP Linear. Do you have any suggestions or magic env flags for improving performance? On H100, TE Linear in FP8 is much faster than BF16 AMP Linear.
I have attached a reprod script, all the relevant versions, and the installation steps below.
cc: @hliuca
python3 ./gemm.py
Benchmark results for Realistic GEMM shapes with warmup=30 and repeats=200
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| Shape (M, N, K) | bf16 torch.matmul | bf16 F.linear (with bias) | bf16 F.linear (with bias & amp) | TE Linear (FP8 autocast) |
+=====================+=====================+=============================+===================================+============================+
| (16384, 8192, 1280) | 493.0 TFLOPS | 491.7 TFLOPS | 420.2 TFLOPS | 206.8 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 1024, 8192) | 546.4 TFLOPS | 470.0 TFLOPS | 288.6 TFLOPS | 137.1 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 8192, 7168) | 567.0 TFLOPS | 566.3 TFLOPS | 504.0 TFLOPS | 465.3 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (16384, 3584, 8192) | 610.0 TFLOPS | 545.0 TFLOPS | 430.1 TFLOPS | 325.8 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
| (8192, 8192, 8192) | 588.3 TFLOPS | 504.3 TFLOPS | 443.2 TFLOPS | 372.9 TFLOPS |
+---------------------+---------------------+-----------------------------+-----------------------------------+----------------------------+
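One extra data point that might help narrow this down: benchmarking the same te.Linear with the FP8 path switched off separates TE module overhead from the FP8 GEMM itself. A minimal sketch reusing the setup from the reprod below (params_dtype is the upstream TE argument name; I'm assuming the ROCm fork matches):
# Hedged sketch: same te.Linear with the FP8 path on vs. off, to separate TE module
# overhead from the FP8 GEMM itself. Reuses the setup from the reprod script below.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
from triton.testing import do_bench

m, n, k = 16384, 8192, 1280  # one of the Llama 70B TP8 shapes from the table above
x = torch.randn(m, k, device="cuda", dtype=torch.bfloat16)
layer = te.Linear(k, n, bias=True, params_dtype=torch.bfloat16).to("cuda")
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

ms_bf16 = do_bench(lambda: layer(x), warmup=30, rep=200)  # plain bf16 through the TE module
with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    ms_fp8 = do_bench(lambda: layer(x), warmup=30, rep=200)  # FP8 path

flops = 2 * m * n * k + m * n
print(f"TE bf16: {flops / ms_bf16 * 1e-9:.1f} TFLOPS | TE fp8: {flops / ms_fp8 * 1e-9:.1f} TFLOPS")
If the FP8-off number already matches bf16 F.linear, the gap above is coming from the FP8 path itself (input/weight casts, amax/scale bookkeeping, or the FP8 GEMM kernel) rather than generic module overhead.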
Steps to Reproduce
Versions
root@NODENAME:/workspace/llm-train-bench# pip list | grep torch
pytorch-triton-rocm    3.1.0+cf34004b8a
torch                  2.6.0.dev20241012+rocm6.2
torchvision            0.18.0a0+68ba7ec
root@NODENAME:/workspace/llm-train-bench# pip list | grep transformer
transformer_engine     1.8.0.dev0+691dc23
Install Instruction
FROM rocm/pytorch:rocm6.2_ubuntu22.04_py3.10_pytorch_release_2.3.0
RUN apt install nano
RUN pip install uv
RUN uv pip install --system ipython pytest fire pydantic pybind11
RUN pip3 uninstall -y torch
RUN pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/rocm6.2
WORKDIR /workspace/
RUN git clone --recursive https://github.com/ROCm/TransformerEngine.git
ENV NVTE_FRAMEWORK=pytorch
ENV PYTORCH_ROCM_ARCH=gfx942
RUN cd TransformerEngine && pip install .
WORKDIR /workspace/llm-train-bench/
CMD ["/usr/bin/bash"]
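For completeness, a quick smoke test after the build confirms the nightly ROCm torch wheel is the one actually in use and that a small te.Linear forward runs under fp8_autocast. A minimal sketch using only the same calls as the reprod script below; the shapes are arbitrary:
# Hedged post-build smoke test: check the torch build and run one tiny FP8 forward.
import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

print(torch.__version__, torch.version.hip)  # expect 2.6.0.dev...+rocm6.2 and a HIP version
layer = te.Linear(128, 128, bias=True).to("cuda")
x = torch.randn(16, 128, device="cuda")
with te.fp8_autocast(enabled=True, fp8_recipe=DelayedScaling(fp8_format=Format.HYBRID)):
    y = layer(x)
print("FP8 forward OK:", tuple(y.shape))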
Reprod Script
import time
import torch
import tabulate
from triton.testing import do_bench
import torch.nn.functional as F
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling
torch.manual_seed(0)
repeats = 200
warmup = 30
dtype = torch.bfloat16
device = 'cuda'
verbose = False
shapes = [
    (16384, 8192, 1280),  # Llama 70B TP8 shape
    (16384, 1024, 8192),  # Llama 70B TP8 shape
    (16384, 8192, 7168),  # Llama 70B TP8 shape
    (16384, 3584, 8192),  # Llama 70B TP8 shape
    (8192, 8192, 8192)    # Square shape
]
results = []
for (m, n, k) in shapes:
    # Matmul benchmark
    a = torch.randn(m, k, device=device, dtype=dtype)
    b = torch.randn(n, k, device=device, dtype=dtype).transpose(-1, -2)
    nFLOPS = 2 * m * n * k
    ms = do_bench(lambda: torch.matmul(a, b), warmup=warmup, rep=repeats)
    tflops_matmul = nFLOPS / ms * 1e-9
    time.sleep(3)
    nFLOPS_with_bias = 2 * m * n * k + m * n  # FLOPs for matmul and addition
    # Linear (with bias) benchmark using F.linear
    weight_with_bias = torch.randn(n, k, device=device, dtype=dtype)
    bias = torch.randn(n, device=device, dtype=dtype)
    input_tensor = torch.randn(m, k, device=device, dtype=dtype)
    ms_linear_with_bias = do_bench(lambda: F.linear(input_tensor, weight_with_bias, bias=bias), warmup=warmup, rep=repeats)
    tflops_linear_with_bias = nFLOPS_with_bias / ms_linear_with_bias * 1e-9
    time.sleep(0.25)
    # F.linear with autocast bf16, with a, b, and c being fp32
    a = torch.randn(m, k, device=device, dtype=torch.float32)
    b = torch.randn(n, k, device=device, dtype=torch.float32)
    c = torch.randn(n, device=device, dtype=torch.float32)
    with torch.autocast(dtype=dtype, device_type=device):
        ms_autocast = do_bench(lambda: F.linear(a, b, bias=c), warmup=warmup, rep=repeats)
    tflops_autocast = nFLOPS_with_bias / ms_autocast * 1e-9
    time.sleep(0.25)
    # TE Linear (with FP8 autocast) benchmark
    fp8_format = Format.HYBRID
    fp8_recipe = DelayedScaling(fp8_format=fp8_format, amax_history_len=16, amax_compute_algo="max")
    input_tensor = torch.randn(m, k, device=device)
    linear_layer = te.Linear(k, n, bias=True).to(device)
    with te.fp8_autocast(enabled=True, fp8_recipe=fp8_recipe):
        ms_te_linear = do_bench(lambda: linear_layer(input_tensor), warmup=warmup, rep=repeats)
    tflops_te_linear = nFLOPS_with_bias / ms_te_linear * 1e-9
    time.sleep(0.25)
    # Append the results to the list
    results.append([
        f"({m}, {n}, {k})",
        f"{tflops_matmul:.1f} TFLOPS",
        f"{tflops_linear_with_bias:.1f} TFLOPS",
        f"{tflops_autocast:.1f} TFLOPS",
        f"{tflops_te_linear:.1f} TFLOPS"
    ])
# Print results using tabulate
headers = [
    "Shape (M, N, K)",
    "bf16 torch.matmul",
    "bf16 F.linear (with bias)",
    "bf16 F.linear (with bias & amp)",
    "TE Linear (FP8 autocast)"
]
print(f"Benchmark results for Realistic GEMM shapes with {warmup=} and {repeats=}")
print(tabulate.tabulate(results, headers=headers, tablefmt="grid"))
Operating System
Ubuntu
CPU
AMD CPU
GPU
AMD Instinct MI300X
ROCm Version
ROCm 6.2.0
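Additional context, in case it helps triage: a torch.profiler trace of a single te.Linear forward under fp8_autocast should show which kernels the FP8 path actually dispatches on MI300X. A minimal sketch using only standard torch.profiler, no TE internals assumed:
# Hedged sketch: profile one te.Linear forward under fp8_autocast for a slow shape
# and print the heaviest GPU kernels from the key_averages table.
import torch
from torch.profiler import profile, ProfilerActivity
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Format, DelayedScaling

m, n, k = 16384, 8192, 1280
x = torch.randn(m, k, device="cuda")
layer = te.Linear(k, n, bias=True).to("cuda")
recipe = DelayedScaling(fp8_format=Format.HYBRID, amax_history_len=16, amax_compute_algo="max")

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    for _ in range(10):  # warm up so amax history and any autotuning settle first
        layer(x)
    torch.cuda.synchronize()
    with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
        layer(x)
        torch.cuda.synchronize()

print(prof.key_averages().table(sort_by="cuda_time_total", row_limit=20))
The sort key and row limit are only there to surface the heaviest kernels first.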