4 changes: 2 additions & 2 deletions transformer_engine/common/__init__.py
@@ -174,7 +174,7 @@ def load_framework_extension(framework: str) -> None:
# PyPI. For this case we need to make sure that the metapackage, the core lib, and framework
# extension are all installed via PyPI and have matching versions.
if te_framework_installed:
-    assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+    # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
style: Commenting out this validation weakens installation safety checks. Users with mismatched installations (e.g., PyPI framework package + source-built core) may encounter runtime issues. Consider a more targeted fix that specifically allows the SGLang 0.5.5 scenario while keeping validation for other cases.
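A sketch of such a targeted fix, assuming te_framework_installed and te_installed_via_pypi are the flags computed in the surrounding function; the SGLang version probe and the warning text are illustrative assumptions, not part of this PR:

```python
from importlib.metadata import PackageNotFoundError, version
import warnings


def _is_sglang_055() -> bool:
    # Hypothetical probe: exempt only the known SGLang 0.5.5 scenario.
    try:
        return version("sglang") == "0.5.5"
    except PackageNotFoundError:
        return False


def check_framework_install(te_framework_installed: bool, te_installed_via_pypi: bool) -> None:
    # Keep the original assertion, but allow the one known-good exception.
    if not te_framework_installed:
        return
    if not te_installed_via_pypi:
        if _is_sglang_055():
            warnings.warn(
                "transformer-engine PyPI metapackage not found; continuing "
                "because SGLang 0.5.5 installs the extension this way."
            )
        else:
            raise AssertionError("Could not find `transformer-engine` PyPI package.")
```

This keeps the safety check for every other installation path while still unblocking the scenario the PR targets.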

assert te_core_installed, "Could not find TE core package `transformer-engine-cu*`."

assert version(module_name) == version("transformer-engine") == te_core_version, (
@@ -203,7 +203,7 @@ def sanity_checks_for_pypi_installation() -> None:

# If the core package is installed via PyPI.
if te_core_installed:
-    assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
+    # assert te_installed_via_pypi, "Could not find `transformer-engine` PyPI package."
assert version("transformer-engine") == te_core_version, (
"Transformer Engine package version mismatch. Found "
f"transformer-engine v{version('transformer-engine')} "
2 changes: 1 addition & 1 deletion transformer_engine/common/common.h
@@ -18,7 +18,6 @@
#endif

#include <cuda_runtime_api.h>
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <functional>
@@ -32,6 +31,7 @@
#include "./nvtx.h"
#include "./util/cuda_driver.h"
#include "./util/logging.h"
#include "transformer_engine/transformer_engine.h"

namespace transformer_engine {

@@ -15,6 +15,7 @@
#include <sstream>
#include <string>

#include "../util/cuda_driver.h"
#include "../util/logging.h"
#include "nvshmem_waitkernel.h"

@@ -125,15 +125,32 @@
flash_attn_with_kvcache_v3 = None
# pass # only print warning if use_flash_attention_3 = True in get_attention_backend
else:
-    from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
-    from flash_attn_3.flash_attn_interface import (
-        flash_attn_varlen_func as flash_attn_varlen_func_v3,
-    )
-    from flash_attn_3.flash_attn_interface import (
-        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
-    )
-    from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
-    from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    try:
+        from flash_attn_3.flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flash_attn_3.flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flash_attn_3.flash_attn_interface import (
+            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+        )
+        from flash_attn_3.flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+        from flash_attn_3.flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
+    except ModuleNotFoundError as e:
+        print(e)
+        print(
+            "Please install fa3 : pip install"
+            " git+https://github.com/Dao-AILab/flash-attention.git#subdirectory=hopper, usage :"
+            ' "from flash_attn_3.flash_attn_interface import flash_attn_func"'
+        )
+        from flash_attn_interface import flash_attn_func as flash_attn_func_v3
+        from flash_attn_interface import (
+            flash_attn_varlen_func as flash_attn_varlen_func_v3,
+        )
+        from flash_attn_interface import (
+            flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
+        )
+        from flash_attn_interface import _flash_attn_forward as _flash_attn_fwd_v3
+        from flash_attn_interface import _flash_attn_backward as _flash_attn_bwd_v3
Comment on lines +145 to +153
logic: This fallback will fail if a bare flash_attn_interface module doesn't exist. The flash-attn-3 package was detected, but the imports from flash_attn_3.flash_attn_interface failed. If the goal is to support an alternative location such as flash_attn.flash_attn_interface (from flash-attn v2), the fallback should import from flash_attn.flash_attn_interface instead of the bare flash_attn_interface. Without the proper module path, these imports will raise ModuleNotFoundError and the _v3 variables will remain undefined, causing failures later.
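If the intended fallback is flash-attn v2's module layout, a guarded variant might look like the sketch below. It assumes the v2 entry points are acceptable substitutes for the v3 ones, which this PR does not establish; the private _flash_attn_forward/_flash_attn_backward hooks are omitted because their signatures differ across versions.

```python
try:
    # Reviewer-suggested location: flash-attn v2 exposes these entry
    # points under flash_attn.flash_attn_interface.
    from flash_attn.flash_attn_interface import flash_attn_func as flash_attn_func_v3
    from flash_attn.flash_attn_interface import (
        flash_attn_varlen_func as flash_attn_varlen_func_v3,
    )
    from flash_attn.flash_attn_interface import (
        flash_attn_with_kvcache as flash_attn_with_kvcache_v3,
    )
except ModuleNotFoundError:
    # Disable FA3 instead of leaving the _v3 names undefined, so later
    # code can test `flash_attn_func_v3 is None` rather than crashing.
    flash_attn_func_v3 = None
    flash_attn_varlen_func_v3 = None
    flash_attn_with_kvcache_v3 = None
```

Either way, the fallback should fail loudly or degrade to None rather than leave the names unbound.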


fa_utils.set_flash_attention_3_params()
