From ce78cc682eac70d60da57c909bdafe7765a6acf6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 19 Nov 2025 16:03:02 +0000 Subject: [PATCH 1/5] fixed conflicts --- pyproject.toml | 2 +- src/pruna/algorithms/ipex_llm.py | 154 +++++++++ src/pruna/algorithms/x_fast.py | 462 +++++++++++++++++++++++++++ tests/algorithms/testers/ipex_llm.py | 17 + tests/algorithms/testers/x_fast.py | 13 + tests/common.py | 2 +- 6 files changed, 648 insertions(+), 2 deletions(-) create mode 100644 src/pruna/algorithms/ipex_llm.py create mode 100644 src/pruna/algorithms/x_fast.py create mode 100644 tests/algorithms/testers/ipex_llm.py create mode 100644 tests/algorithms/testers/x_fast.py diff --git a/pyproject.toml b/pyproject.toml index bc2ca641..99fc6323 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -137,7 +137,7 @@ dependencies = [ "aenum", "vbench-pruna; sys_platform != 'darwin'", "imageio-ffmpeg", - "jaxtyping" + "jaxtyping", ] [project.optional-dependencies] diff --git a/src/pruna/algorithms/ipex_llm.py b/src/pruna/algorithms/ipex_llm.py new file mode 100644 index 00000000..b5f79508 --- /dev/null +++ b/src/pruna/algorithms/ipex_llm.py @@ -0,0 +1,154 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +import re +from typing import Any, Dict + +import torch +from ConfigSpace import OrdinalHyperparameter +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.save import SAVE_FUNCTIONS +from pruna.logging.logger import pruna_logger + +class IPEXLLM(PrunaAlgorithmBase): + """ + Implement IPEX LLM compilation using the intel library. + + This compiler leverages advanced graph optimizations, quantization, and kernel fusion techniques to accelerate + PyTorch-based LLM inference on Intel CPUs. + + Note: After compilation, the model supports sequence lengths that are either ≤ 32, or even numbers. + """ + + algorithm_name: str = "ipex_llm" + group_tags: list[AlgorithmTag] = [AlgorithmTag.COMPILER] + references: dict[str, str] = {"Github": "https://github.com/intel/intel-extension-for-pytorch"} + tokenizer_required: bool = False + processor_required: bool = False + dataset_required: bool = False + save_fn = SAVE_FUNCTIONS.save_before_apply + runs_on: list[str] = ["cpu"] + compatible_before: list[str] = ["half"] + required_install = ( + "``pip install pruna[intel]`` " + "``--extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/``" + ) + + def get_hyperparameters(self) -> list: + """ + Get the hyperparameters for IPEX LLM compilation. + + Returns + ------- + list + The hyperparameters. + """ + return [ + OrdinalHyperparameter( + "weight_bits", + sequence=[8, 4], + default_value=8, + meta=dict(desc="The number of bits to use for weight quantization."), + ), + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is compatible with IPEX LLM compilation. 
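+        The check is heuristic: it scans the ``*_forward`` reference
+        implementations bundled with IPEX and accepts the model only if its
+        class name appears among them.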
+ + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + Whether the model is compatible with IPEX LLM compilation. + """ + imported_modules = self.import_algorithm_packages() + # Find the installation path of ipex + ipex_path = os.path.dirname(imported_modules["ipex"].__file__) + # Try to find the models.py file + transformers_path = os.path.join(ipex_path, "transformers") + # Find the full path of models.py if it exists + models_path = os.path.join(transformers_path, "models", "reference", "models.py") + if os.path.exists(models_path): + # Read the function names from the file + with open(models_path, "r") as f: + content = f.read() + # Simple regex to find function definitions + funcs = [f for f in re.findall(r"def\s+([A-Z][a-zA-Z0-9_]*)\s*\(", content) if f.endswith("_forward")] + compatible_list = [name.replace("_forward", "") for name in funcs] + return model.__class__.__name__ in compatible_list + else: + pruna_logger.warning("IPEX models.py file not found. Please check if IPEX is installed correctly.") + return False + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Compile the model with IPEX LLM. + + Parameters + ---------- + model : Any + The model to compile. + smash_config : SmashConfigPrefixWrapper + The configuration to use for compilation. + + Returns + ------- + Any + The compiled model. + """ + imported_modules = self.import_algorithm_packages() + ipex = imported_modules["ipex"] + woq_weight_dtype = imported_modules["WoqWeightDtype"] + + weight_dtype = woq_weight_dtype.INT8 if smash_config["weight_bits"] == 8 else woq_weight_dtype.INT4 + + lowp_mode = ipex.quantization.WoqLowpMode.INT8 + + qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(weight_dtype=weight_dtype, lowp_mode=lowp_mode) + + model = ipex.llm.optimize( + model.eval(), + dtype=getattr(torch, "float32"), + quantization_config=qconfig, + low_precision_checkpoint=None, + deployment_mode=True, + inplace=True, + ) + + return model + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Import the algorithm packages. + + Returns + ------- + Dict[str, Any] + The algorithm packages. + """ + # Import necessary modules here to avoid unnecessary imports and ensure they're available when needed + import intel_extension_for_pytorch as ipex + from intel_extension_for_pytorch.quantization import WoqWeightDtype + + return dict( + ipex=ipex, + WoqWeightDtype=WoqWeightDtype, + ) diff --git a/src/pruna/algorithms/x_fast.py b/src/pruna/algorithms/x_fast.py new file mode 100644 index 00000000..9de0d7b5 --- /dev/null +++ b/src/pruna/algorithms/x_fast.py @@ -0,0 +1,462 @@ +# Copyright 2025 - Pruna AI GmbH. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from __future__ import annotations + +import dataclasses +import inspect +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import torch +from ConfigSpace import Constant +from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase +from pruna.algorithms.base.tags import AlgorithmTag +from pruna.config.hyperparameters import Boolean +from pruna.config.smash_config import SmashConfigPrefixWrapper +from pruna.engine.save import SAVE_FUNCTIONS +from pruna.logging.logger import pruna_logger + +class XFast(PrunaAlgorithmBase): + """ + Implement X-Fast compilation using the sfast library. + + Based on stable_fast, this compiler speeds up inference latency for any model using a combination of xformers, + triton, cudnn, and torch tracing. + """ + + algorithm_name: str = "x_fast" + group_tags: list[AlgorithmTag] = [AlgorithmTag.COMPILER] + save_fn = SAVE_FUNCTIONS.save_before_apply + references: dict[str, str] = {} + tokenizer_required: bool = False + processor_required: bool = False + runs_on: list[str] = ["cuda"] + dataset_required: bool = False + compatible_before: list[str | AlgorithmTag] = [ + "quanto", + "half", + "text_to_text_lora", + "text_to_image_lora", + "text_to_image_perp", + "text_to_image_inplace_perp", + "text_to_text_perp", + "text_to_text_inplace_perp", + "text_to_image_distillation_inplace_perp", + "text_to_image_distillation_lora", + "text_to_image_distillation_perp", + ] + required_install: str = "``pip install pruna[stable-fast]``" + + def get_hyperparameters(self) -> list: + """ + Get the hyperparameters for the X-Fast compiler. + + Returns + ------- + list + The hyperparameters. + """ + return [ + Constant("fn_to_compile", value="forward"), + Boolean("xformers", default=True, meta=dict(desc="Whether to use xformers for faster inference.")), + ] + + def model_check_fn(self, model: Any) -> bool: + """ + Check if the model is a valid model for the X-Fast compiler. + + Parameters + ---------- + model : Any + The model to check. + + Returns + ------- + bool + True if the model is valid, False otherwise. + """ + if isinstance(model, torch.nn.Module): + return True + return any(isinstance(attr_value, torch.nn.Module) for _, attr_value in inspect.getmembers(model)) + + def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Compile the model using the X-Fast compiler. + + Parameters + ---------- + model : Any + The model to compile. + smash_config : SmashConfigPrefixWrapper + The configuration for the compilation process. + + Returns + ------- + Any + The compiled model. + """ + imported_modules = self.import_algorithm_packages() + + model.x_fast_compiler = XFastHelper(imported_modules) + + if smash_config["torch_dynamic"]: + model = torch.quantization.quantize_dynamic( + model, # Input model + {torch.nn.Linear}, + dtype=torch.qint8, + inplace=True, + ) + + # Second we need to compile the model and return it + smashed_model = model.x_fast_compiler.compile(model, smash_config) + return smashed_model + + def import_algorithm_packages(self) -> Dict[str, Any]: + """ + Import the necessary packages for the X-Fast compiler. + + Returns + ------- + Dict[str, Any] + The imported modules. 
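+
+        Notes
+        -----
+        As a side effect, ``sfast.cuda.graphs.get_cuda_device_from_tensors`` is
+        replaced by the module-level implementation defined below, which also
+        searches dataclass instances for a CUDA device.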
+ """ + import sfast + from sfast.compilers.diffusion_pipeline_compiler import ( + CompilationConfig, + _build_lazy_trace, + _enable_xformers, + ) + from sfast.cuda.graphs import ( + make_dynamic_graphed_callable, # apply_auto_graph_compiler, + ) + from sfast.utils.memory_format import apply_memory_format + + sfast.cuda.graphs.get_cuda_device_from_tensors = get_cuda_device_from_tensors + + return dict( + CompilationConfig=CompilationConfig, + _build_lazy_trace=_build_lazy_trace, + _enable_xformers=_enable_xformers, + make_dynamic_graphed_callable=make_dynamic_graphed_callable, + apply_memory_format=apply_memory_format, + ) + + +def get_cuda_device_from_tensors( + x: Union[torch.Tensor, List[Any], Tuple[Any], Dict[Any, Any], Any], +) -> Optional[torch.device | int]: + """ + Recursively searches for a CUDA device index in a tensor or a nested structure of tensors. + + Parameters + ---------- + x : Union[torch.Tensor, list, tuple, dict, Any] + A tensor or a nested structure (list, tuple, dictionary, or dataclass) containing tensors. + + Returns + ------- + Optional[torch.device | int] + The index of the CUDA device if a tensor is found on a CUDA device, otherwise None. + """ + device: Optional[torch.device | int] = None + if isinstance(x, torch.Tensor): + device = x.device + if device.type == "cuda": + return device.index + return None + elif isinstance(x, (list, tuple)): + for y in x: + device = get_cuda_device_from_tensors(y) + if device is not None: + return device + return None + elif dataclasses.is_dataclass(x): + for k in dataclasses.fields(x): + device = get_cuda_device_from_tensors(getattr(x, k.name)) + if device is not None: + return device + return None + elif isinstance(x, dict): + for v in x.values(): + device = get_cuda_device_from_tensors(v) + if device is not None: + return device + return None + else: + return None + + +def get_and_compile_nested_attribute(obj: Any, attr_path: str) -> Any: + """ + Get and compile a nested attribute of an object. + + Parameters + ---------- + obj : Any + The object to retrieve the nested attribute from. + attr_path : str + The path to the nested attribute, using dot notation. + + Returns + ------- + Any + The compiled nested attribute. + """ + current_attr = obj + attr_chain = attr_path.split(".") + + for attr in attr_chain[:-1]: + current_attr = getattr(current_attr, attr) + + # Get the final attribute (method) in the chain + final_attr = getattr(current_attr, attr_chain[-1]) + return final_attr + + +def apply_lazy_tracing_and_dynamic_graphing( + model: Any, + config: Any, + smash_config: SmashConfigPrefixWrapper, + enable_cuda_graph: bool, + imported_modules: Dict[str, Any], +) -> None: + """ + Apply lazy tracing and dynamic graphing to the given model. + + Parameters + ---------- + model : Any + The model to apply lazy tracing and dynamic graphing to. + config : Any + The configuration for lazy tracing. + smash_config : SmashConfigPrefixWrapper + The configuration for smashing (e.g., which functions to compile). + enable_cuda_graph : bool + Flag indicating whether to enable CUDA graph. + imported_modules : Dict[str, Any] + Dictionary containing the imported modules. 
+ """ + config.enable_cnn_optimization = False + lazy_trace_ = imported_modules["_build_lazy_trace"]( + config, + enable_triton_reshape=enable_cuda_graph, + enable_triton_layer_norm=enable_cuda_graph, + ) + + current_attribute = get_and_compile_nested_attribute(model, smash_config["fn_to_compile"]) + modified_attribute = lazy_trace_(current_attribute) + if enable_cuda_graph: + modified_attribute = imported_modules["make_dynamic_graphed_callable"](modified_attribute) + + attr_chain = smash_config["fn_to_compile"].split(".") + parent_attr = model + for attr in attr_chain[:-1]: + parent_attr = getattr(parent_attr, attr) + setattr(parent_attr, attr_chain[-1], modified_attribute) + + +def process_model( + model: Any, + config: Any, + smash_config: SmashConfigPrefixWrapper, + enable_cuda_graph: bool, + imported_modules: Dict[str, Any], +) -> None: + """ + Update the given model by applying various optimizations and transformations. + + This function applies lazy tracing, dynamic graphing, xformers, and memory format optimizations + to the model based on the provided configuration. It can handle both callable functions and + torch.nn.Module instances, and recursively processes nested model attributes. + + Parameters + ---------- + model : Any + The model or function to be processed. Can be a callable or a torch.nn.Module instance. + config : Any + Configuration object containing settings for various optimizations. + smash_config : SmashConfigPrefixWrapper + Configuration dictionary for model smashing, including the function to compile. + enable_cuda_graph : bool + Flag indicating whether to enable CUDA graph. + imported_modules : Dict[str, Any] + Dictionary containing the imported modules. + + Returns + ------- + None + The function modifies the model in-place and doesn't return a value. + + Notes + ----- + - For callable models that are not torch.nn.Module instances, it applies lazy tracing + and potentially dynamic graphing. + - For torch.nn.Module instances with the specified compile function, it applies + various optimizations including xformers and memory format changes. + - Recursively processes nested model attributes if the model doesn't match the above criteria. + - Silently returns if an exception occurs during recursive processing. + """ + if hasattr(model, smash_config["fn_to_compile"]) and isinstance(model, torch.nn.Module): + if hasattr(model, "eval"): + model.eval() + + if config.enable_xformers: + imported_modules["_enable_xformers"](model) + if config.memory_format is not None: + imported_modules["apply_memory_format"](model, memory_format=config.memory_format) + apply_lazy_tracing_and_dynamic_graphing(model, config, smash_config, enable_cuda_graph, imported_modules) + else: + # Recursively process model attributes + try: + for attribute_name, attribute_value in vars(model).items(): + process_model( + attribute_value, + config, + smash_config, + enable_cuda_graph, + imported_modules, + ) + except Exception: + # Model is not an object, cannot recurse + return + + +class XFastHelper: + """ + A compiler class to process models using various optimizations such as xformers, Triton, and CUDA graphing. + + Parameters + ---------- + imported_modules : Dict[str, Any] + Dictionary containing the imported modules. 
+ """ + + def __init__(self, imported_modules: Dict[str, Any]) -> None: + self.imported_modules = imported_modules + + def compile(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any: + """ + Compile the model by applying optimizations based on the smash configuration. + + Parameters + ---------- + model : Any + The model or function to compile. + smash_config : SmashConfigPrefixWrapper + The configuration for the compilation process. + + Returns + ------- + Any + The compiled model. + """ + config = self.imported_modules["CompilationConfig"].Default() + + try: + import xformers # noqa: F401 + + if smash_config["xformers"]: + config.enable_xformers = True + except ImportError: + pruna_logger.info("xformers not installed, skip") + + try: + import triton # noqa: F401 + + if "CausalLM" in type(model).__name__: + config.enable_triton = False + else: + config.enable_triton = True + except ImportError: + pruna_logger.info("Triton not installed, skip") + + config.enable_cuda_graph = True + + device = ( + model.device if hasattr(model, "device") else torch.device("cuda" if torch.cuda.is_available() else "cpu") + ) + + if hasattr(model, "to"): + model = model.to(device) + + enable_cuda_graph = config.enable_cuda_graph + + process_model(model, config, smash_config, enable_cuda_graph, self.imported_modules) + + return model + + @staticmethod + def process_function(function: Callable[..., Any]) -> Callable[..., Any] | None: + """ + For internal use only. Process the given function by applying various optimizations and transformations. + + Parameters + ---------- + function : Callable + The function to be processed. + + Returns + ------- + Callable + The processed function. + """ + try: + import sfast + from sfast.compilers.diffusion_pipeline_compiler import ( + CompilationConfig, + _build_lazy_trace, + ) + from sfast.cuda.graphs import ( + make_dynamic_graphed_callable, # apply_auto_graph_compiler, + ) + + sfast.cuda.graphs.get_cuda_device_from_tensors = get_cuda_device_from_tensors + except ImportError: + pruna_logger.error( + "You are trying to use XFast compiler, but sfast is not installed. " + "This is likely because you did not install the GPU version of Pruna." 
+ ) + raise + + config = CompilationConfig.Default() + + try: + import xformers # noqa: F401 + + config.enable_xformers = True + except ImportError: + pruna_logger.info("xformers not installed, skip") + + try: + import triton # noqa: F401 + + if "CausalLM" in type(function).__name__: + config.enable_triton = False + else: + config.enable_triton = True + except ImportError: + pruna_logger.info("Triton not installed, skip") + + config.enable_cuda_graph = True + + if callable(function) and not isinstance(function, torch.nn.Module): + lazy_trace_ = _build_lazy_trace( + config, + enable_triton_reshape=config.enable_cuda_graph, + enable_triton_layer_norm=config.enable_cuda_graph, + ) + function = lazy_trace_(function) + if config.enable_cuda_graph: + function = make_dynamic_graphed_callable(function) + return function + return None diff --git a/tests/algorithms/testers/ipex_llm.py b/tests/algorithms/testers/ipex_llm.py new file mode 100644 index 00000000..e3a40005 --- /dev/null +++ b/tests/algorithms/testers/ipex_llm.py @@ -0,0 +1,17 @@ +import pytest + +from pruna.algorithms.ipex_llm import IPEXLLM + +from .base_tester import AlgorithmTesterBase + + +# this prevents the test from running on GitHub Actions, which does not reliably provide Intel CPUs +@pytest.mark.high_cpu +class TestIPEXLLM(AlgorithmTesterBase): + """Test the IPEX LLM algorithm.""" + + models = ["opt_tiny_random"] + reject_models = ["sd_tiny_random"] + allow_pickle_files = False + algorithm_class = IPEXLLM + metrics = ["latency"] diff --git a/tests/algorithms/testers/x_fast.py b/tests/algorithms/testers/x_fast.py new file mode 100644 index 00000000..36670614 --- /dev/null +++ b/tests/algorithms/testers/x_fast.py @@ -0,0 +1,13 @@ +from pruna.algorithms.x_fast import XFast + +from .base_tester import AlgorithmTesterBase + + +class TestXFast(AlgorithmTesterBase): + """Test the X-Fast algorithm.""" + + models = ["opt_tiny_random"] + reject_models = ["dummy_lambda"] + allow_pickle_files = False + algorithm_class = XFast + metrics = ["latency"] diff --git a/tests/common.py b/tests/common.py index 3aa509f5..9da2d88b 100644 --- a/tests/common.py +++ b/tests/common.py @@ -344,4 +344,4 @@ def extract_code_blocks_from_node(node: Any, section_name: str) -> None: section_title = section_title_node.astext().replace(" ", "_").lower() extract_code_blocks_from_node(sec, section_title) - print(f"Code blocks extracted and written to {output_dir}") + print(f"Code blocks extracted and written to {output_dir}") \ No newline at end of file From b65f4aab4c135d8f38473af37b4ed1e7382ea00d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 19 Nov 2025 17:11:07 +0000 Subject: [PATCH 2/5] fixed linting errors --- src/pruna/algorithms/ipex_llm.py | 10 ++++++---- src/pruna/algorithms/x_fast.py | 1 + 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/src/pruna/algorithms/ipex_llm.py b/src/pruna/algorithms/ipex_llm.py index b5f79508..e68436f6 100644 --- a/src/pruna/algorithms/ipex_llm.py +++ b/src/pruna/algorithms/ipex_llm.py @@ -13,6 +13,7 @@ # limitations under the License. import os import re +from pathlib import Path from typing import Any, Dict import torch @@ -23,6 +24,7 @@ from pruna.engine.save import SAVE_FUNCTIONS from pruna.logging.logger import pruna_logger + class IPEXLLM(PrunaAlgorithmBase): """ Implement IPEX LLM compilation using the intel library. 
@@ -81,12 +83,12 @@ def model_check_fn(self, model: Any) -> bool: """ imported_modules = self.import_algorithm_packages() # Find the installation path of ipex - ipex_path = os.path.dirname(imported_modules["ipex"].__file__) + ipex_path = Path(imported_modules["ipex"].__file__).parent # Try to find the models.py file - transformers_path = os.path.join(ipex_path, "transformers") + transformers_path = ipex_path / "transformers" # Find the full path of models.py if it exists - models_path = os.path.join(transformers_path, "models", "reference", "models.py") - if os.path.exists(models_path): + models_path = transformers_path / "models" / "reference" / "models.py" + if models_path.exists(): # Read the function names from the file with open(models_path, "r") as f: content = f.read() diff --git a/src/pruna/algorithms/x_fast.py b/src/pruna/algorithms/x_fast.py index 9de0d7b5..0f912fa0 100644 --- a/src/pruna/algorithms/x_fast.py +++ b/src/pruna/algorithms/x_fast.py @@ -26,6 +26,7 @@ from pruna.engine.save import SAVE_FUNCTIONS from pruna.logging.logger import pruna_logger + class XFast(PrunaAlgorithmBase): """ Implement X-Fast compilation using the sfast library. From 2ea8107e9e348d5515fd7d4e20ee950ad475714f Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Wed, 19 Nov 2025 17:27:54 +0000 Subject: [PATCH 3/5] fixing more linting errors --- src/pruna/algorithms/ipex_llm.py | 2 +- src/pruna/algorithms/x_fast.py | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pruna/algorithms/ipex_llm.py b/src/pruna/algorithms/ipex_llm.py index e68436f6..51732d33 100644 --- a/src/pruna/algorithms/ipex_llm.py +++ b/src/pruna/algorithms/ipex_llm.py @@ -11,13 +11,13 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import os
 import re
 from pathlib import Path
 from typing import Any, Dict
 
 import torch
 from ConfigSpace import OrdinalHyperparameter
+
 from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
 from pruna.algorithms.base.tags import AlgorithmTag
 from pruna.config.smash_config import SmashConfigPrefixWrapper
diff --git a/src/pruna/algorithms/x_fast.py b/src/pruna/algorithms/x_fast.py
index 0f912fa0..b583f898 100644
--- a/src/pruna/algorithms/x_fast.py
+++ b/src/pruna/algorithms/x_fast.py
@@ -19,6 +19,7 @@
 
 import torch
 from ConfigSpace import Constant
+
 from pruna.algorithms.base.pruna_base import PrunaAlgorithmBase
 from pruna.algorithms.base.tags import AlgorithmTag
 from pruna.config.hyperparameters import Boolean

From 9c1572e14a4057cfc33787653cd8c6c25e7322c3 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Thu, 11 Dec 2025 14:54:44 +0000
Subject: [PATCH 4/5] removing uv.lock

---
 pyproject.toml                 | 12 +++++++++++-
 src/pruna/algorithms/x_fast.py | 11 +----------
 2 files changed, 12 insertions(+), 11 deletions(-)

diff --git a/pyproject.toml b/pyproject.toml
index 99fc6323..5c9d5d20 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -70,6 +70,13 @@ name = "pruna_internal"
 url = "https://prunaai.pythonanywhere.com/simple/" # Pruna Pythonanywhere
 default = true # default = True makes this index the lowest prio
 
+[[tool.uv.index]]
+name = "intel-pytorch-extension"
+url = "https://pytorch-extension.intel.com/release-whl/stable/cpu/cn/"
+
+[tool.uv]
+index-strategy = "unsafe-best-match"
+
 [tool.uv.sources]
 gptqmodel = [
     { index = "pruna_internal", marker = "sys_platform != 'darwin' or platform_machine != 'arm64'"},
@@ -180,12 +187,15 @@ dev = [
     "pytest-rerunfailures",
     "coverage",
     "docutils",
-    "ty==0.0.1a21",
+    "ty",
     "types-PyYAML",
     "logbar",
     "pytest-xdist>=3.8.0",
 ]
 cpu = []
+intel = [
+    "intel-extension-for-pytorch==2.7.0",
+]
 
 [build-system]
 requires = ["hatchling"]
diff --git a/src/pruna/algorithms/x_fast.py b/src/pruna/algorithms/x_fast.py
index b583f898..e1b8ac8f 100644
--- a/src/pruna/algorithms/x_fast.py
+++ b/src/pruna/algorithms/x_fast.py
@@ -46,16 +46,7 @@ class XFast(PrunaAlgorithmBase):
     dataset_required: bool = False
     compatible_before: list[str | AlgorithmTag] = [
         "quanto",
-        "half",
-        "text_to_text_lora",
-        "text_to_image_lora",
-        "text_to_image_perp",
-        "text_to_image_inplace_perp",
-        "text_to_text_perp",
-        "text_to_text_inplace_perp",
-        "text_to_image_distillation_inplace_perp",
-        "text_to_image_distillation_lora",
-        "text_to_image_distillation_perp",
+        "half"
     ]
     required_install: str = "``pip install pruna[stable-fast]``"

From e8fe1e327def10496ccc4feb5e9bbff97315c211 Mon Sep 17 00:00:00 2001
From: Ubuntu
Date: Wed, 17 Dec 2025 14:24:57 +0000
Subject: [PATCH 5/5] fixing linting error

---
 src/pruna/algorithms/ipex_llm.py | 12 ++++++------
 src/pruna/algorithms/x_fast.py   | 20 ++++++++++----------
 2 files changed, 16 insertions(+), 16 deletions(-)

diff --git a/src/pruna/algorithms/ipex_llm.py b/src/pruna/algorithms/ipex_llm.py
index 51732d33..ae716e42 100644
--- a/src/pruna/algorithms/ipex_llm.py
+++ b/src/pruna/algorithms/ipex_llm.py
@@ -83,7 +83,7 @@ def model_check_fn(self, model: Any) -> bool:
         """
         imported_modules = self.import_algorithm_packages()
         # Find the installation path of ipex
-        ipex_path = Path(imported_modules["ipex"].__file__).parent
+        ipex_path = Path(imported_modules["ipex"].__file__).parent  # type: ignore[attr-defined]
         # Try to find the models.py file
         transformers_path = ipex_path / "transformers"
         # Find the full path of models.py if it exists
@@ -122,11 +122,11 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
 
         weight_dtype = woq_weight_dtype.INT8 if smash_config["weight_bits"] == 8 else woq_weight_dtype.INT4
 
-        lowp_mode = ipex.quantization.WoqLowpMode.INT8
+        lowp_mode = ipex.quantization.WoqLowpMode.INT8  # type: ignore[attr-defined]
 
-        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(weight_dtype=weight_dtype, lowp_mode=lowp_mode)
+        qconfig = ipex.quantization.get_weight_only_quant_qconfig_mapping(weight_dtype=weight_dtype, lowp_mode=lowp_mode)  # type: ignore[attr-defined]
 
-        model = ipex.llm.optimize(
+        model = ipex.llm.optimize(  # type: ignore[attr-defined]
             model.eval(),
             dtype=getattr(torch, "float32"),
             quantization_config=qconfig,
             low_precision_checkpoint=None,
             deployment_mode=True,
             inplace=True,
@@ -147,8 +147,8 @@ def import_algorithm_packages(self) -> Dict[str, Any]:
             The algorithm packages.
         """
         # Import necessary modules here to avoid unnecessary imports and ensure they're available when needed
-        import intel_extension_for_pytorch as ipex
-        from intel_extension_for_pytorch.quantization import WoqWeightDtype
+        import intel_extension_for_pytorch as ipex  # type: ignore[import-untyped]
+        from intel_extension_for_pytorch.quantization import WoqWeightDtype  # type: ignore[import-untyped]
 
         return dict(
             ipex=ipex,
diff --git a/src/pruna/algorithms/x_fast.py b/src/pruna/algorithms/x_fast.py
index e1b8ac8f..d1b056a3 100644
--- a/src/pruna/algorithms/x_fast.py
+++ b/src/pruna/algorithms/x_fast.py
@@ -103,7 +103,7 @@ def _apply(self, model: Any, smash_config: SmashConfigPrefixWrapper) -> Any:
         model.x_fast_compiler = XFastHelper(imported_modules)
 
         if smash_config["torch_dynamic"]:
-            model = torch.quantization.quantize_dynamic(
+            model = torch.quantization.quantize_dynamic(  # type: ignore[attr-defined]
                 model,  # Input model
                 {torch.nn.Linear},
                 dtype=torch.qint8,
@@ -123,18 +123,18 @@ def import_algorithm_packages(self) -> Dict[str, Any]:
             Dict[str, Any]
             The imported modules.
         """
-        import sfast
-        from sfast.compilers.diffusion_pipeline_compiler import (
+        import sfast  # type: ignore[import-untyped]
+        from sfast.compilers.diffusion_pipeline_compiler import (  # type: ignore[import-untyped]
             CompilationConfig,
             _build_lazy_trace,
             _enable_xformers,
         )
-        from sfast.cuda.graphs import (
+        from sfast.cuda.graphs import (  # type: ignore[import-untyped]
             make_dynamic_graphed_callable,  # apply_auto_graph_compiler,
         )
-        from sfast.utils.memory_format import apply_memory_format
+        from sfast.utils.memory_format import apply_memory_format  # type: ignore[import-untyped]
 
-        sfast.cuda.graphs.get_cuda_device_from_tensors = get_cuda_device_from_tensors
+        sfast.cuda.graphs.get_cuda_device_from_tensors = get_cuda_device_from_tensors  # type: ignore[attr-defined]
 
         return dict(
             CompilationConfig=CompilationConfig,
@@ -404,16 +404,16 @@ def process_function(function: Callable[..., Any]) -> Callable[..., Any] | None:
             The processed function.
         """
         try:
-            import sfast
-            from sfast.compilers.diffusion_pipeline_compiler import (
+            import sfast  # type: ignore[import-untyped]
+            from sfast.compilers.diffusion_pipeline_compiler import (  # type: ignore[import-untyped]
                 CompilationConfig,
                 _build_lazy_trace,
             )
-            from sfast.cuda.graphs import (
+            from sfast.cuda.graphs import (  # type: ignore[import-untyped]
                 make_dynamic_graphed_callable,  # apply_auto_graph_compiler,
             )
 
-            sfast.cuda.graphs.get_cuda_device_from_tensors = get_cuda_device_from_tensors
+            sfast.cuda.graphs.get_cuda_device_from_tensors = get_cuda_device_from_tensors  # type: ignore[attr-defined]
         except ImportError:
             pruna_logger.error(
                 "You are trying to use XFast compiler, but sfast is not installed. "