9 changes: 8 additions & 1 deletion src/pruna/algorithms/deepcache.py
@@ -43,7 +43,14 @@ class DeepCache(PrunaAlgorithmBase):
processor_required: bool = False
dataset_required: bool = False
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
compatible_before: Iterable[str] = ["qkv_diffusers", "half", "hqq_diffusers", "diffusers_int8", "quanto"]
compatible_before: Iterable[str] = [
"qkv_diffusers",
"half",
"hqq_diffusers",
"diffusers_int8",
"quanto",
"sage_attn",
]
compatible_after: Iterable[str] = ["stable_fast", "torch_compile"]

def get_hyperparameters(self) -> list:
2 changes: 1 addition & 1 deletion src/pruna/algorithms/fastercache.py
@@ -57,7 +57,7 @@ class FasterCache(PrunaAlgorithmBase):
processor_required: bool = False
dataset_required: bool = False
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8"]
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"]

def get_hyperparameters(self) -> list:
"""
9 changes: 8 additions & 1 deletion src/pruna/algorithms/fora.py
@@ -44,7 +44,14 @@ class FORA(PrunaAlgorithmBase):
processor_required: bool = False
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = ["qkv_diffusers", "diffusers_int8", "hqq_diffusers", "torchao", "flash_attn3"]
compatible_before: Iterable[str] = [
"qkv_diffusers",
"diffusers_int8",
"hqq_diffusers",
"torchao",
"flash_attn3",
"sage_attn"
]
compatible_after: Iterable[str] = ["stable_fast", "torch_compile"]

def get_hyperparameters(self) -> list:
2 changes: 1 addition & 1 deletion src/pruna/algorithms/gptq_model.py
@@ -45,7 +45,7 @@ class GPTQ(PrunaAlgorithmBase):
processor_required: bool = False
runs_on: list[str] = ["cuda"]
dataset_required: bool = True
compatible_after: Iterable[str] = ["torch_compile"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]
required_install: str = (
"You must first install the base package with ``pip install pruna`` "
"before installing the GPTQ extension with ``pip install pruna[gptq] --extra-index-url https://prunaai.pythonanywhere.com/``"
1 change: 1 addition & 0 deletions src/pruna/algorithms/half.py
@@ -51,6 +51,7 @@ class Half(PrunaAlgorithmBase):
"torch_compile",
"ifw",
"whisper_s2t",
"sage_attn",
]

def model_check_fn(self, model: Any) -> bool:
2 changes: 1 addition & 1 deletion src/pruna/algorithms/hqq.py
@@ -52,7 +52,7 @@ class HQQ(PrunaAlgorithmBase):
runs_on: list[str] = ["cuda"]
dataset_required: bool = False
compatible_before: Iterable[str] = ["torch_structured"]
compatible_after: Iterable[str] = ["torch_compile"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]

def get_hyperparameters(self) -> list:
"""
2 changes: 1 addition & 1 deletion src/pruna/algorithms/hqq_diffusers.py
@@ -54,7 +54,7 @@ class HQQDiffusers(PrunaAlgorithmBase):
runs_on: list[str] = ["cuda"]
dataset_required: bool = False
compatible_before: Iterable[str] = ["qkv_diffusers"]
compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile"]
compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"]

def get_hyperparameters(self) -> list:
"""
2 changes: 1 addition & 1 deletion src/pruna/algorithms/huggingface_diffusers_int8.py
@@ -60,7 +60,7 @@ class DiffusersInt8(PrunaAlgorithmBase):
runs_on: list[str] = ["cuda", "accelerate"]
save_fn: None = None
compatible_before: Iterable[str] = ["qkv_diffusers"]
compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile"]
compatible_after: Iterable[str] = ["deepcache", "fastercache", "fora", "pab", "torch_compile", "sage_attn"]

def get_hyperparameters(self) -> list:
"""
2 changes: 1 addition & 1 deletion src/pruna/algorithms/huggingface_llm_int8.py
@@ -56,7 +56,7 @@ class LLMInt8(PrunaAlgorithmBase):
dataset_required: bool = False
runs_on: list[str] = ["cuda", "accelerate"]
save_fn: None = None
compatible_after: Iterable[str] = ["torch_compile"]
compatible_after: Iterable[str] = ["torch_compile", "sage_attn"]

def get_hyperparameters(self) -> list:
"""
2 changes: 1 addition & 1 deletion src/pruna/algorithms/pab.py
@@ -53,7 +53,7 @@ class PAB(PrunaAlgorithmBase):
processor_required: bool = False
dataset_required: bool = False
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8"]
compatible_before: Iterable[str] = ["hqq_diffusers", "diffusers_int8", "sage_attn"]
compatible_after: Iterable[str] = []

def get_hyperparameters(self) -> list:
2 changes: 1 addition & 1 deletion src/pruna/algorithms/quanto.py
@@ -51,7 +51,7 @@ class Quanto(PrunaAlgorithmBase):
dataset_required: bool = False
runs_on: list[str] = ["cuda"]
compatible_before: Iterable[str] = ["qkv_diffusers"]
compatible_after: Iterable[str] = ["deepcache"]
compatible_after: Iterable[str] = ["deepcache", "sage_attn"]

def get_hyperparameters(self) -> list:
"""
184 changes: 184 additions & 0 deletions src/pruna/algorithms/sage_attn.py
@@ -0,0 +1,184 @@
# Copyright 2025 - Pruna AI GmbH. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

from collections.abc import Iterable
from typing import Any, List

import torch
from diffusers import DiffusionPipeline

from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
from pruna.algorithms.base.tags import AlgorithmTag as tags
from pruna.config.smash_config import SmashConfigPrefixWrapper
from pruna.config.target_modules import TARGET_MODULES_TYPE, TargetModules, map_targeted_nn_roots
from pruna.engine.save import SAVE_FUNCTIONS
from pruna.logging.logger import pruna_logger


class SageAttn(PrunaAlgorithmBase):
"""
    Replace ``torch.nn.functional.scaled_dot_product_attention`` with SageAttention.

SageAttention is a fast and memory-efficient attention mechanism. It applies the flash attention mechanism
in combination with quantization and smoothing to speed up attention computations.
"""

algorithm_name: str = "sage_attn"
group_tags: list[str] = [tags.KERNEL]
save_fn = SAVE_FUNCTIONS.reapply
references: dict[str, str] = {
"Paper (SA2++)": "https://arxiv.org/pdf/2505.21136v3",
"GitHub": "https://github.com/thu-ml/SageAttention",
"Kernel Hub": "https://huggingface.co/kernels-community/sage_attention",
}
tokenizer_required: bool = False
processor_required: bool = False
runs_on: list[str] = ["cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = [tags.QUANTIZER]
compatible_after: Iterable[str] = ["torch_compile", tags.CACHER]

def model_check_fn(self, model: Any) -> bool:
"""
Check if the model has an attention mechanism that can be replaced with sage_attn.

Parameters
----------
model : Any
The model to check.

Returns
-------
bool
True if the model is a valid model for the algorithm, False otherwise.
"""
if not isinstance(model, DiffusionPipeline) or not hasattr(model, "components"):
return False

return any(
hasattr(component, "set_attention_backend") and component.dtype in (torch.bfloat16, torch.float16)
for component in model.components.values()
)

def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
"""
Wrap the model to use SageAttention where possible.

Parameters
----------
model : Any
The model to wrap.
smash_config : SmashConfigPrefixWrapper
The configuration for the application of the algorithm.

Returns
-------
Any
The wrapped model.
"""
target_modules = smash_config["target_modules"]

        if target_modules is None:
            # Computed for consistency; these defaults are not used for filtering yet.
            target_modules = self.get_model_dependent_hyperparameter_defaults(model, smash_config)

def apply_sage_attn(
root_name: str | None,
root_nn_module: torch.nn.Module,
relative_target_paths: List[str],
) -> torch.nn.Module:
"""
Apply the SageAttention backend to targeted submodules of a root module.

For each relative submodule path, this function retrieves the corresponding
submodule from ``root_nn_module`` and applies
``set_attention_backend("sage_hub")`` if the method is available.

Parameters
----------
root_name : str or None
The attribute name of the root module within the model (used for identification).
May be ``None`` if the model itself is a ``torch.nn.Module``.
root_nn_module : torch.nn.Module
                The root ``torch.nn.Module`` containing the targeted submodules.
relative_target_paths : List[str]
Relative paths of submodules (with respect to ``root_nn_module``) to consider.

Returns
-------
torch.nn.Module
                The root ``torch.nn.Module`` with the SageAttention backend applied where supported.
"""
for rel_path in relative_target_paths:
try:
sub_module = root_nn_module.get_submodule(rel_path)
except AttributeError:
# safety net: should not happen,
# since the paths come from named_modules()
continue
if hasattr(sub_module, "set_attention_backend"):
sub_module.set_attention_backend("sage_hub")
else:
pruna_logger.warning(f"Module {root_name}.{rel_path} does not have a set_attention_backend method"
"and will not be replaced with SageAttention")
return root_nn_module

return map_targeted_nn_roots(apply_sage_attn, model, target_modules)

def get_hyperparameters(self) -> list:
"""
Get the list of configurable hyperparameters for this algorithm.

Returns
-------
list
A list of hyperparameter objects (e.g., Boolean, TargetModules) used by the
configuration system.
"""
return [
TargetModules(name="target_modules", default_value=None),
]

def get_model_dependent_hyperparameter_defaults(
self,
model: Any,
smash_config: SmashConfigPrefixWrapper,
) -> TARGET_MODULES_TYPE:
"""
Get model-dependent default hyperparameters for this algorithm.

Parameters
----------
model : Any
The model/pipeline instance for which defaults should be computed.
smash_config : SmashConfigPrefixWrapper
The configuration wrapper passed to the algorithm. It can be used to read other
algorithm settings when selecting defaults.

Returns
-------
TARGET_MODULES_TYPE
A dictionary with keys "include" and "exclude" defining which modules should be
targeted by default.
"""
        # For now, everything is included and nothing is excluded;
        # filtering happens in _apply via the set_attention_backend check.
include = ["*"]
exclude = []

return {"include": include, "exclude": exclude}
1 change: 1 addition & 0 deletions src/pruna/algorithms/torch_compile/torch_compile.py
@@ -69,6 +69,7 @@ class TorchCompile(PrunaAlgorithmBase):
"flash_attn3",
"deepcache",
"fora",
"sage_attn",
]

def get_hyperparameters(self) -> list:
2 changes: 1 addition & 1 deletion src/pruna/algorithms/torch_dynamic.py
@@ -43,7 +43,7 @@ class TorchDynamic(PrunaAlgorithmBase):
runs_on: list[str] = ["cpu", "cuda"]
dataset_required: bool = False
compatible_before: Iterable[str] = []
-    compatible_after: Iterable[str] = []
+    compatible_after: Iterable[str] = ["sage_attn"]

def get_hyperparameters(self) -> list:
"""
2 changes: 1 addition & 1 deletion src/pruna/algorithms/torchao.py
@@ -90,7 +90,7 @@ class Torchao(PrunaAlgorithmBase):
runs_on: list[str] = ["cpu", "cuda", "accelerate"]
dataset_required: bool = False
compatible_before: Iterable[str] = ["qkv_diffusers", "torch_structured"]
compatible_after: Iterable[str] = ["flash_attn3", "fora", "torch_compile"]
compatible_after: Iterable[str] = ["flash_attn3", "fora", "torch_compile", "sage_attn"]

def get_hyperparameters(self) -> list:
"""
16 changes: 16 additions & 0 deletions tests/algorithms/testers/sage_attn.py
@@ -0,0 +1,16 @@
import pytest

from pruna.algorithms.sage_attn import SageAttn

from .base_tester import AlgorithmTesterBase


@pytest.mark.high
class TestSageAttn(AlgorithmTesterBase):
"""Test the sage attention kernel."""

models = ["flux_tiny", "wan_tiny_random"]
reject_models = ["opt_tiny_random"]
allow_pickle_files = False
algorithm_class = SageAttn
metrics = ["latency"]