From 91a1001f84925460a0b091128202bc6f670655c5 Mon Sep 17 00:00:00 2001
From: yiakwy-xpu-ml-framework-team <961186938@qq.com>
Date: Wed, 3 Dec 2025 06:50:13 +0800
Subject: [PATCH 1/2] fix transformer engine build

---
 transformer_engine/common/__init__.py        |  4 +--
 transformer_engine/common/common.h           |  2 +-
 .../common/nvshmem_api/nvshmem_waitkernel.cu |  1 +
 .../dot_product_attention/backends.py        | 31 +++++++++++++------
 4 files changed, 26 insertions(+), 12 deletions(-)

diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py
index 5e1318cf86..bbda2cb995 100644
--- a/transformer_engine/common/__init__.py
+++ b/transformer_engine/common/__init__.py
@@ -174,7 +174,7 @@ def load_framework_extension(framework: str) -> None:
     # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
     # extension are all installed via PyPI and have matching versions.
     if te_framework_installed:
-        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+        # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
 
         assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
         assert version(module_name) == version("transformer-engine") == te_core_version, (
@@ -203,7 +203,7 @@ def sanity_checks_for_pypi_installation() -> None:
 
     # If the core package is installed via PyPI.
     if te_core_installed:
-        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+        # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
         assert version("transformer-engine") == te_core_version, (
             "Transformer Engine package version mismatch. Found "
             f"transformer-engine v{version('transformer-engine')} "
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index bddd9bf194..4a01f306f6 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -18,7 +18,7 @@
 #endif
 
 #include 
-#include <transformer_engine/transformer_engine.h>
+#include "transformer_engine/transformer_engine.h"
 #include 
 #include 
diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
index d5f6aeecce..f122eea91d 100644
--- a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
+++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
@@ -16,6 +16,7 @@
 #include 
 
 #include "../util/logging.h"
+#include "../util/cuda_driver.h"
 #include "nvshmem_waitkernel.h"
 
 __global__ void __launch_bounds__(1)
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 95558e30da..48e11e4c58 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -125,15 +125,28 @@
     flash_attn_with_kvcache_v3 = None
     # pass # only print warning if use_flash_attention_3 = True in get_attention_backend
 else:
-    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flash_attn_3.flash_attn_interface import (
-        flash_attn_varlen_func as flash_attn_varlen_func_v3,
-    )
-    from flash_attn_3.flash_attn_interface import (
-        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
-    )
-    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    try:
+        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flash_attn_3.flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flash_attn_3.flash_attn_interface import (
+            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+        )
+        from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+        from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    except ModuleNotFoundError as e:
+        print(e)
+        print("Please install fa3 : pip install git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=hopper, usage : \"from flash_attn_3.flash_attn_interface import flash_attn_func\"")
+        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flash_attn_interface import (
+            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+        )
+        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
 
     fa_utils.set_flash_attention_3_params()
 

From 1d573875aeeb8bc7528b16d3a13db9dddec83a90 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 2 Dec 2025 22:58:18 +0000
Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 transformer_engine/common/common.h                          | 2 +-
 transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu | 2 +-
 .../pytorch/attention/dot_product_attention/backends.py     | 6 +++++-
 3 files changed, 7 insertions(+), 3 deletions(-)

diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 4a01f306f6..767ed5453c 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -18,7 +18,6 @@
 #endif
 
 #include 
-#include "transformer_engine/transformer_engine.h"
 #include 
 #include 
@@ -32,6 +31,7 @@
 #include "./nvtx.h"
 #include "./util/cuda_driver.h"
 #include "./util/logging.h"
+#include "transformer_engine/transformer_engine.h"
 
 namespace transformer_engine {
diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
index f122eea91d..aeb2978c3b 100644
--- a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
+++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
@@ -15,8 +15,8 @@
 #include 
 #include 
 
-#include "../util/logging.h"
 #include "../util/cuda_driver.h"
+#include "../util/logging.h"
 #include "nvshmem_waitkernel.h"
 
 __global__ void __launch_bounds__(1)
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 48e11e4c58..8b176a0a61 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -137,7 +137,11 @@
         from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
     except ModuleNotFoundError as e:
         print(e)
-        print("Please install fa3 : pip install git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=hopper, usage : \"from flash_attn_3.flash_attn_interface import flash_attn_func\"")
+        print(
+            "Please install fa3 : pip install"
+            " git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=hopper, usage :"
+            ' "from flash_attn_3.flash_attn_interface import flash_attn_func"'
+        )
         from flash_attn_interface import flash_attn_func as flash_attn_func_v3
         from flash_attn_interface import (
             flash_attn_varlen_func as flash_attn_varlen_func_v3,
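
A minimal standalone sketch of the probe-and-fall-back import that PATCH 1/2 adds to backends.py, for reference. The module names are taken from the diff (`flash_attn_3.flash_attn_interface` for the packaged FA3 wheel, `flash_attn_interface` for an FA3 build from flash-attention's `hopper` subdirectory); unlike the patched code, the fallback import here is also guarded so the snippet runs even when neither layout is installed.

    # Sketch only: mirrors the try/except fallback introduced in backends.py.
    # flash_attn_3.flash_attn_interface -> FA3 as packaged for Transformer Engine
    # flash_attn_interface              -> FA3 built from flash-attention's "hopper" subdirectory
    flash_attn_func_v3 = None

    try:
        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    except ModuleNotFoundError:
        try:
            from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        except ModuleNotFoundError:
            print(
                "flash-attention 3 not found; install it with: pip install"
                " git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=hopper"
            )

    if flash_attn_func_v3 is None:
        # No FA3 symbols were imported; attention backend selection will skip FlashAttention-3.
        print("FlashAttention-3 is unavailable in this environment.")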