diff --git a/transformer_engine/common/__init__.py b/transformer_engine/common/__init__.py
index 5e1318cf869..bbda2cb9951 100644
--- a/transformer_engine/common/__init__.py
+++ b/transformer_engine/common/__init__.py
@@ -174,7 +174,7 @@ def load_framework_extension(framework: str) -> None:
     # PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
     # extension are all installed via PyPI and have matching versions.
     if te_framework_installed:
-        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+        # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
         assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."
 
         assert version(module_name) == version("transformer-engine") == te_core_version, (
@@ -203,7 +203,7 @@ def sanity_checks_for_pypi_installation() -> None:
 
     # If the core package is installed via PyPI.
     if te_core_installed:
-        assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+        # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
         assert version("transformer-engine") == te_core_version, (
             "Transformer Engine package version mismatch. Found "
             f"transformer-engine v{version('transformer-engine')} "
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index bddd9bf194b..767ed5453cb 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -18,7 +18,6 @@
 #endif
 
 #include
-#include <transformer_engine/transformer_engine.h>
 
 #include
 #include
@@ -32,6 +31,7 @@
 #include "./nvtx.h"
 #include "./util/cuda_driver.h"
 #include "./util/logging.h"
+#include "transformer_engine/transformer_engine.h"
 
 namespace transformer_engine {
diff --git a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
index d5f6aeecced..aeb2978c3b8 100644
--- a/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
+++ b/transformer_engine/common/nvshmem_api/nvshmem_waitkernel.cu
@@ -15,6 +15,7 @@
 #include
 #include
 
+#include "../util/cuda_driver.h"
 #include "../util/logging.h"
 #include "nvshmem_waitkernel.h"
diff --git a/transformer_engine/pytorch/attention/dot_product_attention/backends.py b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
index 95558e30da3..8b176a0a61e 100644
--- a/transformer_engine/pytorch/attention/dot_product_attention/backends.py
+++ b/transformer_engine/pytorch/attention/dot_product_attention/backends.py
@@ -125,15 +125,32 @@
     flash_attn_with_kvcache_v3 = None
     # pass  # only print warning if use_flash_attention_3 = True in get_attention_backend
 else:
-    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flash_attn_3.flash_attn_interface import (
-        flash_attn_varlen_func as flash_attn_varlen_func_v3,
-    )
-    from flash_attn_3.flash_attn_interface import (
-        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
-    )
-    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    try:
+        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flash_attn_3.flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flash_attn_3.flash_attn_interface import (
+            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+        )
+        from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+        from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    except ModuleNotFoundError as e:
+        print(e)
+        print(
+            "Please install fa3 : pip install"
+            " git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=hopper, usage :"
+            ' "from flash_attn_3.flash_attn_interface import flash_attn_func"'
+        )
+        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flash_attn_interface import (
+            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+        )
+        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
 
     fa_utils.set_flash_attention_3_params()
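Note (not part of the diff): the backends.py change amounts to probing two possible FlashAttention-3 module layouts: the packaged `flash_attn_3.flash_attn_interface` path and the top-level `flash_attn_interface` module that a source build of the `hopper` subdirectory can expose. A minimal standalone sketch of that fallback, assuming only that at most one of the two layouts is importable:

    # Sketch only: resolve flash_attn_func from whichever FlashAttention-3
    # layout is importable; leave it as None if neither module is present.
    flash_attn_func_v3 = None
    try:
        # First try the packaged flash_attn_3 layout used above.
        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    except ModuleNotFoundError:
        try:
            # Fall back to the top-level flash_attn_interface layout.
            from flash_attn_interface import flash_attn_func as flash_attn_func_v3
        except ModuleNotFoundError:
            pass  # neither layout available; callers must check for None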