Changes from all commits
157 commits
7a9a082
Changed VERSION to 2.7.0.dev0 (#1973)
KshitijLakhani Jul 21, 2025
5ba7953
[PyTorch] Remove GH pinned deps (#1961)
ksivaman Jul 21, 2025
78a3821
[PyTorch] Reset FP8 weight workspace if usages are invalid (#1972)
timmoon10 Jul 21, 2025
ab5cc40
Fix the condition error when checking fp8 attn in `get_attention_back…
yuzhongw-nvidia Jul 21, 2025
0d80228
[Common] Skip cuDNN 9.10.0/9.10.1 due to bugs (#1937)
cyanguwa Jul 21, 2025
315b47d
[PyTorch] Debug linear layer when saving original input and using deb…
timmoon10 Jul 22, 2025
cb504cd
[Common] Improved performance of mxfp8 cast kernels (#1628)
Oleg-Goncharov Jul 22, 2025
e0204fb
Refactor `te.ops` (#1951)
janekb04 Jul 22, 2025
d1967d5
fix: Add stream synchronization before destroying MPI communicator (#…
djns99 Jul 22, 2025
fdb87af
[PyTorch] Reset recipe state in fusible operations when FP8 amax hist…
timmoon10 Jul 23, 2025
4296b7d
Fix the device for cuDNN/cuBLAS handles (#1974)
cyanguwa Jul 23, 2025
992ba01
[JAX] Fix current scaling test_helper.py and enable test_helper.py in…
jberchtold-nvidia Jul 23, 2025
2a29345
[JAX] Helper to disable TE custom calls + disable GemmPrimitive for n…
phu0ngng Jul 24, 2025
dab931a
[PyTorch] Improve L2Normalization basic op (#1964)
negvet Jul 24, 2025
fe27bf1
Fix runtime lib loading for cuDNN (#1989)
ksivaman Jul 24, 2025
ee84108
Add `in_place` kwarg to extra tensor ops (#1983)
janekb04 Jul 24, 2025
71b2dd4
Fix cudnn versioning support in PyTorch DPA and Fused attn (#1991)
KshitijLakhani Jul 24, 2025
a99c056
[Common] Fixed integer overflow issue in cast kernels (#1988)
Oleg-Goncharov Jul 24, 2025
25a8219
[JAX] Fixing GemmPrimitive partitioning rules to handle tensor-parall…
denera Jul 24, 2025
e950ceb
[PyTorch] Optimize cudagraph static_grad_outputs reuse (#1992)
buptzyb Jul 25, 2025
374849e
[PyTorch] Enable generic QK norm support (+ RMSNorm/LayerNorm) (#1966)
negvet Jul 25, 2025
1470116
[C][PyTorch] Remove deprecated `device_id` arg for multi tensor API (…
ksivaman Jul 25, 2025
38c26dd
Fixed double buffering issue for asymmetric layers (#1984)
sanandaraj5597 Jul 25, 2025
c6c1f50
[PyTorch] Add ops for dropout and constant scale (#1995)
timmoon10 Jul 25, 2025
aac7442
[PyTorch] Prune L0 unit test (#1999)
ksivaman Jul 29, 2025
5a495a3
Fix the use-after-free bug in unfused normalization (#2002)
ptrendx Jul 29, 2025
cb5013b
[PyTorch] Refactor C++ quantizer infrastructure (#1952)
timmoon10 Jul 29, 2025
f858dc3
Rename `do_not_clear` to `_do_not_clear` (#1977)
janekb04 Jul 29, 2025
feda5b5
Fuse amax computation into activation kernel (#2004)
janekb04 Jul 29, 2025
020428f
[PyTorch] Fix bug with clearing op outputs during backward (#2008)
timmoon10 Jul 30, 2025
11ac24c
Refactor normalization.cpp to use quantizer logic introduced in #1952…
janekb04 Jul 30, 2025
858755c
[JAX] TE GEMM checkpointing policies (#2003)
jberchtold-nvidia Jul 30, 2025
44a581c
[PyTorch Debug] Minor fix in docs. (#1947)
dupeljan Jul 31, 2025
51eb636
Fuse amax computation into normalization kernel for current scaling (…
janekb04 Jul 31, 2025
8dfdb91
[PyTorch] Tutorial for the ONNX export (#1586)
pggPL Jul 31, 2025
8e2d37e
[PyTorch] Fix corner case in router fusion (#2009)
Autumn1998 Aug 1, 2025
1258bbe
Manually launch wgrad accumulation and reduce in backward_dw() instea…
lhb8125 Aug 1, 2025
c444bf5
[PyTorch Debug] Fix debug tests (#2021)
pggPL Aug 1, 2025
13cae89
Tensor numel() return dtype to be size_t (#2022)
shangz-ai Aug 1, 2025
1f2df73
Fix JAX and PyTorch wheel builds for v2.6 (#2005)
jberchtold-nvidia Aug 1, 2025
c3f8a9f
[Core] Kernel that swaps first two tensor dimensions (#1998)
timmoon10 Aug 4, 2025
06947e8
[PyTorch] Fix cudagraph static_input and static_grad_input reuse (#2018)
buptzyb Aug 4, 2025
3e6859e
[JAX] Sharding specs for TE GEMM custom call operands (#2023)
phu0ngng Aug 5, 2025
6c97061
[JAX] Disable TE Norm Custom Calls (#1993)
phu0ngng Aug 5, 2025
7101f4b
[PyTorch] Fix zero initialization in permute kernel for padded slots …
xiaoxi-wangfj Aug 6, 2025
ed42b5a
[JAX] Remove `dot_1_output_axes` usage in LayerNormMLP (#2029)
phu0ngng Aug 6, 2025
6d178b4
[JAX] Reduce L1 tests/jax/test_distributed_softmax.py test runtime (#…
jberchtold-nvidia Aug 6, 2025
c0d2f1a
[PyTorch] Multi-tensor swizzle scaling factors for MXFP8 and fuse pad…
yaox12 Aug 6, 2025
de69ca0
[PyTorch] fix input_quantizer usage for save_original_input; fix bloc…
hxbai Aug 6, 2025
c5ee5fd
Revert "[JAX] Disable TE Norm Custom Calls" (#2035)
phu0ngng Aug 6, 2025
bfab8c6
[Common] PDL for Quantization Kernels (#2001)
yaox12 Aug 7, 2025
dd083bd
[PyTorch] Fix numeric overflow caused by int-type parameters and retu…
lvdunlin Aug 7, 2025
cae1c43
[JAX] TE Gemm custom call clean up (#2030)
phu0ngng Aug 7, 2025
9f9b481
[JAX] Remove cudaGraph compatible trait from GroupedGemmFFI and Group…
phu0ngng Aug 8, 2025
b6b3abc
[PyTorch debug] Improve precision debug tools performance (#1909)
pggPL Aug 8, 2025
235c8d0
[JAX] Enable TE GEMM custom call for all recipes (#2047)
phu0ngng Aug 8, 2025
077e26c
Use userbuffers for MXFP8 wgrad all-gather overlap (#1982)
djns99 Aug 9, 2025
de6afe2
[PyTorch] Fix high-precision dtype for MXFP8 AG (#2058)
ksivaman Aug 11, 2025
bfca2e3
[PyTorch] Update amax pointers when reallocating amax history in fusi…
timmoon10 Aug 12, 2025
f947e70
[PyTorch] Fix bug when deducing dtype in linear functional API (#2017)
timmoon10 Aug 12, 2025
6a4e871
[JAX] Support custom recipe and custom collection name when creating …
jberchtold-nvidia Aug 12, 2025
05d3b7b
[PyTorch] Fix normalization+amax forward CS fusion to work for untune…
janekb04 Aug 12, 2025
ec65ba3
[JAX] Add L2_jax_distributed_unittest (#2060)
jberchtold-nvidia Aug 13, 2025
ebca615
[Common] PDL for Blockwise Quantization (#2066)
yaox12 Aug 13, 2025
6afca29
[PyTorch Debug] More advanced stats for Quantized Tensors (#1897)
pggPL Aug 13, 2025
aa0659e
Remove if-else and torch.tensor to meet cudagraph requirement (#1997)
katec846 Aug 13, 2025
8dc2756
[JAX] Manual axis filter in `with_sharding_constraint` (#2069)
phu0ngng Aug 13, 2025
bbddcb9
[JAX] Cleanup the MLP warning for TE GEMM + TP (#2054)
phu0ngng Aug 13, 2025
44fbe9e
fix: update grad_output quant to avoid redundant work (#1736)
kshitij12345 Aug 14, 2025
c582f6b
[Common] Reduce CUDA driver calls (#2067)
yaox12 Aug 14, 2025
ccbc8cf
[PyTorch] Register weight and bias params in linear op (#2027)
timmoon10 Aug 14, 2025
26b4b71
[PyTorch] Avoid registering FP8 scale update in ops without backward …
timmoon10 Aug 14, 2025
a169e9e
[PyTorch] Disable fused dbias-quantize kernel for unsupported recipes…
timmoon10 Aug 14, 2025
12065ac
[Core] Add launch bounds to swizzle kernels (#2076)
ksivaman Aug 14, 2025
92f431b
[JAX] Trim dist fused attn tests in L1 (#2050)
KshitijLakhani Aug 15, 2025
c654e4f
Fuse linear+scale+add (#2042)
janekb04 Aug 15, 2025
6ba98d4
fix: fixes multi head attention for context parallel: rotary embeddin…
jomitchellnv Aug 16, 2025
757fd1c
[JAX] Fix Flax variable creation when quantizers are created directly…
jberchtold-nvidia Aug 18, 2025
988af0f
Update list of authorized CI users (#2078)
timmoon10 Aug 18, 2025
0e3e270
[PyTorch] Check if the given recipe is supported in `fp8_autocast` (#…
yaox12 Aug 18, 2025
3fc1e4b
[JAX] Fix for TE GEMM - Always AllGather RHS non-contracting dims wit…
phu0ngng Aug 18, 2025
734bced
Changed VERSION to 2.8.0.dev0
ptrendx Aug 18, 2025
1d075c0
Add user to TE CI (#2089)
ksivaman Aug 19, 2025
5b4d89c
Add backward RMSNorm+Add fusion (#2028)
janekb04 Aug 20, 2025
51f19fd
[PyTorch] Add test for TRT integration + fix for mxfp8 export (#2083)
pggPL Aug 20, 2025
bc99a88
[JAX] Error checking for mesh resource and update GemmPrimitive to us…
jberchtold-nvidia Aug 20, 2025
96944a8
[PyTorch] Avoid garbage collection when capturing a CUDA Graph (#2092)
timmoon10 Aug 20, 2025
406e2c9
Fix incorrect version checks for atomic GEMM (#2095)
timmoon10 Aug 20, 2025
f1b18ed
Update list of authorized CI users (#2081)
timmoon10 Aug 20, 2025
20be25a
[ TE-JAX ] Expose cp_strategy argument to DPA api (#2090)
kocchop Aug 21, 2025
40dde4d
Update NGC version to 25.08 (#2085)
phu0ngng Aug 22, 2025
d88137c
[PyTorch] Debug Mcore wgrad fusion with te.ops (#2097)
timmoon10 Aug 23, 2025
78e097f
[Jax] Fix narrowing conversions (#2094)
alexeldeib Aug 25, 2025
2e23ad7
[JAX] Add Shardy warning in GEMM custom call (#2101)
phu0ngng Aug 25, 2025
47ab4a7
[JAX] Add Transformer Layer tests for pre_scale_bias and post_scale_b…
KshitijLakhani Aug 25, 2025
ccc1abf
[Pytorch] Fix `UnboundLocalError` during build (#2116)
ksivaman Aug 26, 2025
07db17b
[PyTorch] Expose more activation functions (#2106)
yaox12 Aug 26, 2025
3d0ea80
[JAX] `ScaledTensor1x` to store `amax` (#2117)
phu0ngng Aug 26, 2025
d972e76
Revert "[Common] PDL for Quantization Kernels" (#2114)
jberchtold-nvidia Aug 26, 2025
d770886
[JAX] Add `tpsp_resource` in the `MeshResource` map (#2113)
phu0ngng Aug 26, 2025
54c0c85
Bump cuDNN FE to 1.14.0 (#2072)
vcherepanov-nv Aug 26, 2025
d370608
Revert "[Common] PDL for Blockwise Quantization" (#2115)
jberchtold-nvidia Aug 26, 2025
1398fa5
[PyTorch Debug] Skip log test on device if it does not support fp8. (…
pggPL Aug 26, 2025
8dba296
Add cuBLASMp-backed GEMM-like API to TE common (#1824)
mk-61 Aug 26, 2025
62a57dd
FP8 AllGather in FP8 GroupedGEMM + Fix Stream Usage Issue. (#2086)
mingxu1067 Aug 27, 2025
04add79
[JAX] Delay MeshResource validation until first usage (#2124)
jberchtold-nvidia Aug 27, 2025
c950800
[JAX] Decouple Recipe and ScalingMode (#1728)
jberchtold-nvidia Aug 27, 2025
a282136
[JAX] `dot_1_output` sharding constraint + use AXIS_IS_UNSHARDED (#2128)
phu0ngng Aug 27, 2025
1e2c68d
[JAX] Add amax input to DBiasQuantizePrimitive and FFI (#2118)
phu0ngng Aug 27, 2025
de81b7d
Further relax constraints to cuDNN 9.13 for disabling fused attn for …
KshitijLakhani Aug 27, 2025
c776141
Temporarily remove comm_gemm tests (#2133)
vcherepanov-nv Aug 28, 2025
a5c7987
[PyTorch] Disable determinism for sm100 (#2130)
cyanguwa Aug 28, 2025
06a38cc
[PyTorch] ONNX export of FP8 Current Scaling (#2068)
pggPL Aug 28, 2025
c449c6c
[PyTorch][MOE] Tentative Fix For Replacing from_blob with empty for e…
zhongbozhu Aug 28, 2025
f98e305
build: pull cached wheels (#2127)
ko3n1g Aug 29, 2025
715c3bb
feat: Add support for multiple quantization modes in the UB communica…
djns99 Aug 29, 2025
4285874
[Common] Add checks to CUDA kernel launch and CUDA API calls (#2074)
yaox12 Aug 29, 2025
607fcc4
[PyTorch] Support bf16+fp8 cudagraph (#2098)
buptzyb Aug 29, 2025
e0e3d12
Dropout with 8-bit RNG (#2014)
vasunvidia Aug 31, 2025
67fcc15
Create GPU reload buffers on main stream (#2131)
sanandaraj5597 Sep 2, 2025
3b4366b
Fix CI failures for UB overlap changes (#2149)
djns99 Sep 3, 2025
f378eaf
[JAX] Fix failing fused attn tests for dropout=0.1 and bias for sm100…
KshitijLakhani Sep 3, 2025
0f68f7b
[PyTorch][CUDA Graph] Fix FP8 Weight Quantization Cache under CUDA Gr…
zhongbozhu Sep 4, 2025
e9a5fa4
[PyTorch] fix cross entropy vanishing gradients (#2139)
casper-hansen Sep 4, 2025
11e9d66
Fix bug when enabling --overlap-grad-reduce in mcore (#2142)
lhb8125 Sep 5, 2025
b10f436
Fix CUDA version in setup.py (#2132)
vcherepanov-nv Sep 5, 2025
c47f329
[JAX] NoScaleTensor wrapper for non-quantized data (#2136)
jberchtold-nvidia Sep 5, 2025
5b3d65c
[JAX] Fix GroupedScaledTensor creation with keyword arg (#2154)
phu0ngng Sep 8, 2025
aa06107
Fixing few issues with multi-process launching. (#2155)
mingxu1067 Sep 8, 2025
603dbf7
Update list of authorized CI users (#2152)
timmoon10 Sep 8, 2025
84fa28d
Fused RoPE with combined QKV input. (#2122)
vasunvidia Sep 8, 2025
a26a7f1
Add bf16/fp32 token-per-expert to the MoE aux loss kernel (#2162)
Autumn1998 Sep 9, 2025
5f2b831
[JAX] Scale swizzling via JAX transpose op (#2163)
phu0ngng Sep 9, 2025
4903f94
Extract cpp distributed tests into a separate project (#2165)
vcherepanov-nv Sep 10, 2025
483d959
Adds context parallelism utilities: moving cp shards to diff ranks an…
jomitchellnv Sep 10, 2025
405d474
[PyTorch Debug] Fix issue with negative underflow% stat. (#2107)
pggPL Sep 15, 2025
cd2034f
Lower precision gated-act to accelerate FP8 current-scaling. (#2153)
mingxu1067 Sep 15, 2025
59130cc
[PyTorch] Support activation CPU offloading in fusible ops (#2158)
timmoon10 Sep 15, 2025
258d084
Do not use normalization forward + amax fusion if cuDNN backend is re…
janekb04 Sep 16, 2025
c221909
Fix unjoined comm stream in UB communicator (#2160)
djns99 Sep 16, 2025
ba37529
FP8 Output Quantization for GEMM (#2123)
vthumbe1503 Sep 17, 2025
7042d7a
TE Gemma tutorial attempt#2 (#1839)
sudhakarsingh27 Sep 17, 2025
93a67af
Fix memory overhead of linear layer when all gather from sequence par…
yuzhongw-nvidia Sep 17, 2025
eb69fad
Fix incorrect TP rank calculation when using data parallel (#2179)
djns99 Sep 17, 2025
8aee1bb
[Pytorch] Add Cutlass Grouped GEMM Support for fine-grained MoE Model…
cassiewilliam Sep 18, 2025
c334fc4
[PyTorch] Support FA3 for MLA and with CP (#1907)
zhujian19891203 Sep 18, 2025
7f77127
Fix cuDNN version checks when getting backend and for sm89 kv cache (…
KshitijLakhani Sep 18, 2025
7f4b020
Merge upstream 7f77127 v2.8.0.dev0 with TODOs
ipanfilo Dec 9, 2025
b95717e
Merge conflict resolution and restore ROCm functionality
ipanfilo Jan 13, 2026
7d2ed36
Merge dev f141f34
ipanfilo Jan 13, 2026
92ce375
Fix JAX and Pytorch UT; code cleanup; ROCm 7.2 w/a (#404)
ipanfilo Jan 13, 2026
a406914
Review comments
ipanfilo Jan 15, 2026
403527b
Merge branch 'dev' into IFU-dev-250918-v2.8
ipanfilo Jan 16, 2026
cfe3fc3
Remove unneeded intermediate var in cast kernel. Update tests. Dis…
ipanfilo Jan 17, 2026
c29c8fb
Merge branch 'dev' into IFU-dev-250918-v2.8
ipanfilo Jan 17, 2026
08db27e
Resolve automerge error
ipanfilo Jan 17, 2026
8427362
Fix benchmark script. Remove not needed debug messages
ipanfilo Jan 18, 2026
6 changes: 5 additions & 1 deletion .github/workflows/trigger-ci.yml
@@ -53,7 +53,11 @@ jobs:
|| github.actor == 'lhb8125'
|| github.actor == 'kunlunl'
|| github.actor == 'pstjohn'
|| github.actor == 'mk-61'
|| github.actor == 'vcherepanov-nv'
|| github.actor == 'tdophung'
|| github.actor == 'vthumbe1503'
|| github.actor == 'janekb04'
|| github.actor == 'shengfangd'
)
steps:
- name: Check if comment is issued by authorized person
3 changes: 3 additions & 0 deletions .gitmodules
@@ -20,3 +20,6 @@
[submodule "examples/pytorch/nanogpt"]
path = examples/pytorch/nanogpt
url = https://github.com/floraamd/nanoGPTwTE.git
[submodule "3rdparty/cutlass"]
path = 3rdparty/cutlass
url = https://github.com/NVIDIA/cutlass.git
2 changes: 1 addition & 1 deletion 3rdparty/cudnn-frontend
Submodule cudnn-frontend updated 190 files
1 change: 1 addition & 0 deletions 3rdparty/cutlass
Submodule cutlass added at 57e3cf
6 changes: 3 additions & 3 deletions README.rst
@@ -526,15 +526,15 @@ For example to use the NGC PyTorch container interactively,

.. code-block:: bash

docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.04-py3
docker run --gpus all -it --rm nvcr.io/nvidia/pytorch:25.08-py3

For example to use the NGC JAX container interactively,

.. code-block:: bash

docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.04-py3
docker run --gpus all -it --rm nvcr.io/nvidia/jax:25.08-py3

Where 25.04 (corresponding to April 2025 release) is the container version.
Where 25.08 (corresponding to August 2025 release) is the container version.

**Benefits of using NGC containers:**

8 changes: 4 additions & 4 deletions benchmarks/attention/benchmark_attention.py
@@ -9,11 +9,11 @@
import torch
import nvtx
import transformer_engine
from tests.pytorch.fused_attn.test_fused_attn import (
from tests.pytorch.utils import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention

pd.set_option("display.precision", 4)

@@ -197,7 +197,7 @@ def main():
)
for model in model_configs.keys():
config = model_configs[model]
available_backends, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
28 changes: 14 additions & 14 deletions benchmarks/attention/benchmark_attention_rocm.py
@@ -1,5 +1,5 @@
# This file was modified for portability to AMDGPU
# Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2025-2026, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
@@ -13,17 +13,17 @@
import transformer_engine
from transformer_engine_torch import NVTE_Fused_Attn_Backend

# Add test_fused_attn to the sys path
# Add TE repo root to the sys path
tests_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "../../tests/pytorch/fused_attn")
os.path.join(os.path.dirname(__file__), "../../")
)
sys.path.append(tests_path)

from test_fused_attn import (
from tests.pytorch.utils import (
ModelConfig,
_get_attention_backends,
_run_dot_product_attention,
get_available_attention_backends,
)
from tests.pytorch.attention.test_attention import _run_dot_product_attention

pd.set_option("display.precision", 4)

@@ -46,12 +46,12 @@
is_training = True

model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"test_0": ModelConfig(2, 16, 16, 64, 512, 512, 0.0, "no_mask", "no_bias"), # short seq
"test_1": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "no_bias"), # longer seq, mask
"test_2": ModelConfig(2, 16, 16, 128, 2048, 2048, 0.0, "causal", "post_scale_bias"), # bias
"test_3": ModelConfig(2, 32, 4, 128, 8192, 8192, 0.0, "causal", "no_bias"), # GQA
"test_4": ModelConfig(2, 128, 8, 128, 8192, 8192, 0.0, "causal_bottom_right", "no_bias")
# b, sq, h, dqk
"test_0": ModelConfig(2, 512, 16, 64), # short seq
"test_1": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal"), # longer seq, mask
"test_2": ModelConfig(2, 2048, 16, 128, attn_mask_type="causal", attn_bias_type="post_scale_bias"), # bias
"test_3": ModelConfig(2, 8192, 32, 128, num_gqa_groups=4, attn_mask_type="causal"), # GQA
"test_4": ModelConfig(2, 8192, 128, 128, num_gqa_groups=8, attn_mask_type="causal_bottom_right")
}

# DataFrame indices and columns for results
@@ -303,7 +303,7 @@ def sanity_checks(
}

for model, cfg in model_configs.items():
avail, _, fused_bes = _get_attention_backends(
avail, _, fused_bes = get_available_attention_backends(
cfg,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
@@ -364,7 +364,7 @@ def main(args):
# Benchmarking starts..
for model in model_configs.keys():
config = model_configs[model]
available_backends, _, fused_attn_backends = _get_attention_backends(
available_backends, _, fused_attn_backends = get_available_attention_backends(
config,
qkv_dtype=dtype,
qkv_layout=qkv_layout,
2 changes: 1 addition & 1 deletion benchmarks/linear/benchmark_grouped_linear.py
@@ -247,7 +247,7 @@ def run_benchmark_linear(mkns, recipe_name, use_bias, num_gemms=4):
num_gemms_list = [8]

if args.profile:
mkns = [(4096, 4096, 4096)]
mkns = [(4096 * 8, 4096, 4096)]
# in profile mode, only run one recipe specified in args.recipe
assert args.recipe != "all", (
"In profile mode, only one recipe can be specified, please specify the recipe as"
2 changes: 1 addition & 1 deletion build_tools/VERSION.txt
@@ -1 +1 @@
2.6.0.dev0
2.8.0.dev0
15 changes: 1 addition & 14 deletions build_tools/pytorch.py
@@ -27,20 +27,7 @@

def install_requirements() -> List[str]:
"""Install dependencies for TE/PyTorch extensions."""
reqs = ["einops"]
if not rocm_build():
reqs.append(
"nvdlfw-inspect @"
" git+https://github.com/NVIDIA/nvidia-dlfw-inspect.git@v0.1#egg=nvdlfw-inspect"
)
reqs.extend(
[
"torch>=2.1",
"onnx",
"onnxscript@git+https://github.com/microsoft/onnxscript.git@51ecf47523ef079c53b0e620c62d56d70cfd3871",
]
)
return reqs
return ["torch>=2.1", "einops", "onnxscript==0.3.1", "onnx"]


def test_requirements() -> List[str]:
4 changes: 2 additions & 2 deletions build_tools/utils.py
@@ -16,7 +16,7 @@
import subprocess
import sys
from pathlib import Path
from importlib.metadata import version
from importlib.metadata import version as get_version
from subprocess import CalledProcessError
from typing import List, Optional, Tuple, Union

@@ -340,7 +340,7 @@ def cuda_version() -> Tuple[int, ...]:
return tuple(int(v) for v in version)

try:
version_str = version("nvidia-cuda-runtime-cu12")
version_str = get_version("nvidia-cuda-runtime-cu12")
version_tuple = tuple(int(part) for part in version_str.split(".") if part.isdigit())
return version_tuple
except importlib.metadata.PackageNotFoundError:
6 changes: 4 additions & 2 deletions ci/jax.sh
@@ -1,5 +1,5 @@
#!/bin/sh
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

@@ -54,14 +54,14 @@ run_default_fa_lbl() {

run_test_config() {
echo ==== Run with Fused attention backend: $_fus_attn ====
export NVTE_JAX_UNITTEST_LEVEL=L0 # this env variable controls parameters set for some tests
run_default_fa 1 test_custom_call_compute.py
run_default_fa 1 test_functions.py
run 1 test_fused_attn.py
NVTE_CK_USES_FWD_V3=0 NVTE_CK_USES_BWD_V3=0 run_default_fa_lbl "v2" 3 test_fused_attn.py # Using FAv2 for forward and backward pass
run_default_fa 1 test_helper.py
run_default_fa 1 test_layer.py #it effectevly always uses unfused attention
run_default_fa 1 test_sanity_import.py
run_default_fa 1 test_sharding.py
run_default_fa 1 test_softmax.py
}

@@ -76,8 +76,10 @@ run_test_config_mgpu() {

if [ $_fus_attn = $_DEFAULT_FUSED_ATTN ]; then
_dfa_level=2
export NVTE_JAX_UNITTEST_LEVEL=L1
else
_dfa_level=3
export NVTE_JAX_UNITTEST_LEVEL=L2
fi
run $_dfa_level test_distributed_fused_attn.py $_timeout_args
run_default_fa 3 test_distributed_layernorm.py
24 changes: 12 additions & 12 deletions ci/pytorch.sh
@@ -1,5 +1,5 @@
#!/bin/sh
# Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
# Copyright (c) 2024-2026, Advanced Micro Devices, Inc. All rights reserved.
#
# See LICENSE for license information.

@@ -65,23 +65,23 @@ run_test_config(){
run_default_fa 1 test_recipe.py
run 1 test_sanity.py
run_default_fa 1 test_sanity_import.py
run_default_fa 1 fused_attn/test_fused_attn.py # Backend selection is controlled by the test
run_default_fa 1 attention/test_attention.py # Backend selection is controlled by the test
run_default_fa 1 attention/test_cp_utils.py
run_default_fa 1 attention/test_kv_cache.py
run_default_fa 1 triton_kernels/test_cast.py
run_default_fa 1 triton_kernels/test_cast_mxfp8.py
run_default_fa 1 triton_kernels/test_norm_common.py
run_default_fa 1 triton_kernels/test_norms.py
NVTE_TEST_TRITON_AUTOTUNE=1 run_default_fa_lbl "autotune" 3 triton_kernels/test_norms.py
run_default_fa 1 test_parallel_cross_entropy.py
NVTE_USE_DEQUANTIZE_TRITON=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 NVTE_USE_LAYERNORM_TRITON=1 run_default_fa_lbl "triton" 3 test_numerics.py
NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
NVTE_USE_CAST_TRANSPOSE_TRITON=1 NVTE_USE_RMSNORM_TRITON=1 run_default_fa_lbl "triton" 1 test_fusible_ops.py
NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "triton" 1 test_float8_current_scaling_exact.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=0 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=0 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa 3 triton_kernels/test_cast.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_numerics.py
NVTE_USE_ATOMIC_AMAX=1 NVTE_USE_CAST_TRANSPOSE_TRITON=1 run_default_fa_lbl "amax+triton" 3 test_fusible_ops.py
NVTE_USE_ATOMIC_AMAX=1 run_default_fa_lbl "amax" 3 triton_kernels/test_cast.py
}

run_test_config_mgpu(){
@@ -93,8 +93,8 @@ run_test_config_mgpu(){
run_default_fa 2 distributed/test_numerics.py
run_default_fa 1 distributed/test_torch_fsdp2.py
run_default_fa 2 distributed/test_torch_fsdp2_fp8.py
run_default_fa_lbl "flash" 3 fused_attn/test_fused_attn_with_cp.py -k "with_flash"
run_default_fa_lbl "fused" 2 fused_attn/test_fused_attn_with_cp.py -k "with_fused"
run_default_fa_lbl "flash" 3 attention/test_attention_with_cp.py -k "with_flash"
run_default_fa_lbl "fused" 2 attention/test_attention_with_cp.py -k "with_fused"
}

run_benchmark() {
4 changes: 4 additions & 0 deletions docs/api/jax.rst
@@ -19,6 +19,10 @@ Variables are available in `transformer_engine.jax.sharding`.
* JOINED_AXES: The logical axis of non-defined dimension. It is usually not sharded.


Checkpointing
------------------------------------
When using checkpointing with Transformer Engine JAX, please be aware of the checkpointing policy being applied to your model. Any JAX checkpointing policy using `dot`, such as `jax.checkpoint_policies.dots_with_no_batch_dims`, may not work with GEMMs provided by Transformer Engine as they do not always use the `jax.lax.dot_general` primitive. Instead, you can use `transformer_engine.jax.checkpoint_policies.dots_and_te_gemms_with_no_batch_dims` or similar policies that are designed to work with Transformer Engine's GEMMs and `jax.lax.dot_general` GEMMs. You may also use any JAX policies that do not filter by primitive, such as `jax.checkpoint_policies.save_only_these_names` or `jax.checkpoint_policies.everything_saveable`.
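
As an illustration only (not part of this diff), a minimal sketch of wrapping a layer in ``jax.checkpoint`` with a policy that does not filter by primitive; the toy ``layer`` function below stands in for a real Transformer Engine module:

.. code-block:: python

    import jax
    import jax.numpy as jnp

    # Toy stand-in for a layer whose GEMMs may not lower to jax.lax.dot_general.
    def layer(x, w):
        return jax.nn.gelu(x @ w)

    # A policy that does not filter by primitive, as recommended above.
    remat_layer = jax.checkpoint(
        layer, policy=jax.checkpoint_policies.everything_saveable
    )

    x = jnp.ones((4, 8))
    w = jnp.ones((8, 8))
    grads = jax.grad(lambda w: remat_layer(x, w).sum())(w)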

Modules
------------------------------------
.. autoapiclass:: transformer_engine.jax.flax.TransformerLayerType
5 changes: 4 additions & 1 deletion docs/api/pytorch.rst
@@ -49,7 +49,7 @@ pyTorch

.. autoapifunction:: transformer_engine.pytorch.moe_permute

.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs
.. autoapifunction:: transformer_engine.pytorch.moe_permute_with_probs

.. autoapifunction:: transformer_engine.pytorch.moe_unpermute

@@ -63,3 +63,6 @@

.. autoapifunction:: transformer_engine.pytorch.destroy_ub
.. autoapifunction:: transformer_engine.pytorch.moe_sort_chunks_by_index

.. autoapiclass:: transformer_engine.pytorch.UserBufferQuantizationMode
:members: FP8, NONE
6 changes: 3 additions & 3 deletions docs/debug/1_getting_started.rst
@@ -21,7 +21,7 @@ Transformer Engine provides a set of precision debug tools which allow you to ea
There are 4 things one needs to do to use Transformer Engine debug features:

1. Create a configuration YAML file to configure the desired features.
2. Import, and initialize the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool, which is installed as the dependency of the Transformer Engine.
2. Import, initialize, and install the `Nvidia-DL-Framework-Inspect <https://github.com/NVIDIA/nvidia-dlfw-inspect>`_ tool.
3. One can pass ``name="..."`` when creating TE layers to easier identify layer names. If this is not provided, names will be inferred automatically.
4. Invoke ``debug_api.step()`` at the end of one forward-backward pass.

@@ -141,7 +141,7 @@ Adjusting Python file
In the modified code above, the following changes were made:

1. Added an import for ``nvdlfw_inspect.api``.
2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory.
2. Initialized the Nvidia-DL-Framework-Inspect by calling ``debug_api.initialize()`` with appropriate configuration, specifying the path to the config file, feature directories, and log directory. The directory with Transformer Engine features is located `here <https://github.com/NVIDIA/TransformerEngine/tree/main/transformer_engine/debug/features>`_. The full parameters description could be found :doc:`here <3_api_debug_setup>`.
3. Added ``debug_api.step()`` after each of the forward-backward pass.
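
As an illustration only (not part of this diff), a minimal sketch of those three changes around a single TE layer; the keyword-argument names passed to ``debug_api.initialize()`` are assumptions based on the description above:

.. code-block:: python

    import torch
    import transformer_engine.pytorch as te
    import nvdlfw_inspect.api as debug_api  # change 1: the import

    # Change 2: initialize Nvidia-DL-Framework-Inspect with the config file,
    # feature directories, and log directory (argument names assumed here).
    debug_api.initialize(
        config_file="./debug_config.yaml",
        feature_dirs=["./transformer_engine/debug/features"],
        log_dir="./debug_logs",
    )

    layer = te.Linear(128, 128, name="example_linear")

    for _ in range(3):
        out = layer(torch.randn(32, 128, device="cuda"))
        out.sum().backward()
        debug_api.step()  # change 3: mark the end of each forward-backward pass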

Inspecting the logs
Expand Down Expand Up @@ -238,4 +238,4 @@ Let's run training and open TensorBoard by ``tensorboard --logdir=./tensorboard_
.. figure:: ./img/tensorboard.png
:align: center

Fig 2: TensorBoard with plotted stats.
Fig 2: TensorBoard with plotted stats.
16 changes: 5 additions & 11 deletions docs/debug/3_api_te_calls.rst
@@ -12,14 +12,7 @@ Let's look deeper into how Nvidia-DL-Framework-Inspect with Transformer Engine w

Fig 1: Example of Nvidia-DL-Framework-Inspect affecting training script with 1 Linear Layer. For tensors mentioned in ``config.yaml``, behavior of ``modify_tensor_enabled()`` and ``modify_tensor()`` calls are substituted with definitions from the feature class. Other calls return default values - in fact they do nothing.

In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed. The order of these calls is illustrated in the image below.

.. figure:: ./img/api_calls2.svg
:align: center

Fig 2: The calls to Nvidia-DL-Framework-Inspect done for Transformer Engine. There are 2 types of calls: GEMM calls and routing calls.


In this page, all calls from TransformerEngine to the Nvidia-DL-Framework-Inspect for each GEMM are listed.
There are 2 categories of API calls, each is used for different purposes:

- GEMM calls - invoked during every GEMM, used to process or quantize tensors and collect information about them,
@@ -32,14 +32,15 @@ if fusions happen. An important remark is that if no feature is used for the lay

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.modify_tensor_enabled

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.fp8_gemm_enabled

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_enabled

.. autoapifunction:: transformer_engine.debug.features.api.TEDefaultFeatures.inspect_tensor_postquantize_enabled