diff --git a/auto_round/__main__.py b/auto_round/__main__.py index d946556aa..d127cf571 100644 --- a/auto_round/__main__.py +++ b/auto_round/__main__.py @@ -138,6 +138,7 @@ def __init__(self, *args, **kwargs): basic.add_argument("--low_cpu_mem_usage", action="store_true", help="Lower CPU memory mode. Defaults to False.") basic.add_argument( "--format", + "--formats", default="auto_round", type=str, help="Output format for the quantized model." @@ -466,7 +467,7 @@ def list_item(): args = argparse.ArgumentParser() args.add_argument("item", type=str, help="item to list, e.g., format") args = args.parse_args() - if args.item == "format": + if args.item == "format" or args.item == "formats": from auto_round.formats import OutputFormat print("AutoRound supported output formats and quantization scheme:") diff --git a/auto_round/compressors/base.py b/auto_round/compressors/base.py index ccb12cf42..e94fc21d5 100644 --- a/auto_round/compressors/base.py +++ b/auto_round/compressors/base.py @@ -110,6 +110,39 @@ ) from auto_round.wrapper import WrapperLinear, WrapperMultiblock, unwrapper_block, unwrapper_layer, wrapper_block +SERIALIZATION_KEYS = ( + "bits", + "act_bits", + "data_type", + "act_data_type", + "group_size", + "act_group_size", + "sym", + "act_sym", + "act_dynamic", + "amp", + "batch_size", + "enable_minmax_tuning", + "enable_norm_bias_tuning", + "enable_quanted_input", + "gradient_accumulate_steps", + "iters", + "lr", + "low_gpu_mem_usage", + "minmax_lr", + "nsamples", + "quant_block_list", + "regex_config", + "scale_dtype", + "seqlen", + "supported_types", + "static_attention_dtype", + "static_kv_dtype", + "super_bits", + "super_group_size", + "to_quant_block_names", +) + class BaseCompressor(object): """Base compressor for LLM quantization @@ -1105,35 +1138,17 @@ def _quantize_layer_via_rtn(self, name: str, dtype: torch.dtype = None, to_cpu=T def _immediate_pack(self, name: str): if not self.immediate_packing: return - m = get_module(self.model, name) - if not check_to_quantized(m): - return - from auto_round.export import PACKING_LAYER_WITH_FORMAT - - target_backend = self.formats[0].output_format - has_gguf = any(fmt.is_gguf() for fmt in self.formats) - - if has_gguf: - from auto_round.export.export_to_gguf.export import pack_gguf_layer - - output_dir = self._get_save_folder_name(self.formats[0]) - model_type = ModelType.MMPROJ if self.mllm else ModelType.TEXT - pack_gguf_layer( - name, - self.model, - self.formats[0].get_backend_name(), - output_dir, - self.layer_config, - self.tokenizer, - processor=self.processor if hasattr(self, "processor") else None, - image_processor=self.image_processor if hasattr(self, "image_processor") else None, - model_type=model_type, - device=self.device, - ) - else: - PACKING_LAYER_WITH_FORMAT[target_backend]( - name, self.model, self.formats[0].get_backend_name(), device=self.device - ) + self.formats[0].immediate_pack( + name=name, + model=self.model, + device=self.device, + output_dir=self._get_save_folder_name(self.formats[0]), + mllm=self.mllm, + layer_config=self.layer_config, + tokenizer=self.tokenizer, + processor=self.processor if hasattr(self, "processor") else None, + image_processor=self.image_processor if hasattr(self, "image_processor") else None, + ) @torch.inference_mode() def _quantize_rtn(self) -> tuple[torch.nn.Module, dict[str, Any]]: @@ -2931,98 +2946,28 @@ def save_quantized( folders = [] for format in formats: save_folder = self._get_save_folder_name(format) - if format.is_fake(): # TODO fix act quantization later - self.model = 
self.model.to("cpu") - self.model.save_pretrained(output_dir) - if self.tokenizer is not None and hasattr(self.tokenizer, "save_pretrained"): - self.tokenizer.save_pretrained(output_dir) - processor = kwargs.get("processor", None) - if processor is not None: - processor.save_pretrained(output_dir) - try: - copy_python_files_from_model_cache(self.model, output_dir) - except Exception as e: - logger.warning("Skipping source model Python file copy due to error: %s", e) - compressed_model = self.model - continue if self.act_bits <= 8 and format.is_fake(): logger.warning( "Support for exporting activation quantization is limited. " "Please ensure that your configuration is supported." ) - from auto_round.export import EXPORT_FORMAT - - backend = format.get_backend_name() - output_format = format.output_format - if output_format not in EXPORT_FORMAT: - raise ValueError(f"export format only supports {EXPORT_FORMAT.keys()}, but got {output_format}") - save_quantized_as_format = EXPORT_FORMAT.get(output_format) - serialization_keys = [ - "bits", - "group_size", - "sym", - "data_type", - "enable_quanted_input", - "enable_minmax_tuning", - "seqlen", - "batch_size", - "scale_dtype", - "lr", - "minmax_lr", - "gradient_accumulate_steps", - "iters", - "amp", - "nsamples", - "low_gpu_mem_usage", - "to_quant_block_names", - "enable_norm_bias_tuning", - "act_bits", - "act_group_size", - "act_sym", - "act_dynamic", - "act_data_type", - "super_bits", - "super_group_size", - "regex_config", - "static_kv_dtype", - "static_attention_dtype", - ] - if isinstance(self.dataset, str): - serialization_keys.append("dataset") + serialization_dict = {} - for key in serialization_keys: + for key in SERIALIZATION_KEYS: serialization_dict[key] = getattr(self, key) from auto_round.version import __version__ serialization_dict["autoround_version"] = __version__ if "scale_dtype" in serialization_dict.keys(): serialization_dict["scale_dtype"] = str(serialization_dict["scale_dtype"]) - compressed_model = save_quantized_as_format( # TODO refine the code + compressed_model = format.save_quantized( save_folder, model=self.model, layer_config=self.layer_config, inplace=inplace, - bits=self.bits, - act_bits=self.act_bits, - group_size=self.group_size, - sym=self.sym, - iters=self.iters, - lr=self.lr, - minmax_lr=self.minmax_lr, - enable_minmax_tuning=self.enable_minmax_tuning, - enable_quanted_input=self.enable_quanted_input, - scale_dtype=self.scale_dtype, tokenizer=self.tokenizer, - supported_types=self.supported_types, - data_type=self.data_type, - act_data_type=self.act_data_type, - serialization_dict=serialization_dict, - backend=backend, - to_quant_block_names=self.to_quant_block_names, - quant_block_list=self.quant_block_list, device=self.device, - static_kv_dtype=self.static_kv_dtype, - static_attention_dtype=self.static_attention_dtype, + serialization_dict=serialization_dict, **kwargs, ) folders.append(save_folder) diff --git a/auto_round/compressors/mllm/README.md b/auto_round/compressors/mllm/README.md index 58d56edd2..7963ef8e0 100644 --- a/auto_round/compressors/mllm/README.md +++ b/auto_round/compressors/mllm/README.md @@ -7,28 +7,19 @@ adjustments to default parameters ### API Usage (Gaudi2/CPU/GPU) Recommended -By default, AutoRoundMLLM only quantizes the text module of VLMs and uses `NeelNanda/pile-10k` for calibration. To +By default, AutoRound only quantizes the text module of VLMs and uses `NeelNanda/pile-10k` for calibration. 
To quantize the entire model, you can enable `quant_nontext_module` by setting it to True, though support for this feature is limited. ```python -from auto_round import AutoRoundMLLM -from transformers import Qwen2VLForConditionalGeneration, AutoProcessor, AutoTokenizer +from auto_round import AutoRound # same as llm, AutoRound can determine mllm automatically -## load the model model_name = "Qwen/Qwen2-VL-2B-Instruct" -model = Qwen2VLForConditionalGeneration.from_pretrained(model_name, trust_remote_code=True) -tokenizer = AutoTokenizer.from_pretrained(model_name) -processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True) ## quantize the model -bits, group_size, sym = 4, 128, True -autoround = AutoRoundMLLM(model, tokenizer, processor, bits=bits, group_size=group_size, sym=sym) -autoround.quantize() - -# save the quantized model, set format='auto_gptq' to use AutoGPTQ format +autoround = AutoRound(model_name, scheme="W4A16", dataset="NeelNanda/pile-10k", quant_nontext_module=False) output_dir = "./tmp_autoround" -autoround.save_quantized(output_dir, format="auto_round", inplace=True) +autoround.quantize_and_save(output_dir, format="auto_round") ``` - `dataset`: the dataset for quantization training. Currently only support NeelNanda/pile-10k, llava_conv_58k, @@ -41,15 +32,14 @@ refer [Homepage Detailed Hyperparameters](../../README.md#api-usage-gaudi2cpugpu ### Basic Usage -A user guide detailing the full list of supported arguments is provided by calling ```auto-round-mllm -h``` on the +A user guide detailing the full list of supported arguments is provided by calling ```auto-round -h``` on the terminal. Set the format you want in `format` and multiple formats exporting has been supported. **Only five model families are supported now. ```bash -auto-round-mllm \ +auto-round \ --model Qwen/Qwen2-VL-2B-Instruct \ - --bits 4 \ - --group_size 128 \ + --scheme w4a16 \ --format "auto_round" \ --output_dir ./tmp_autoround ``` @@ -95,55 +85,9 @@ liuhaotian/llava_instruct_80k", "liuhaotian/llava_instruct_150k" or a file path - -
Nontext Module Quantization -### Support Matrix - -For most VLMs, we typically support the default quantization configuration, which involves quantizing only the language -component while excluding the visual component. Besides, we also support quantizing non-text modules of models that -follow the Hugging Face standard, i.e., those with a typical processor, though inference may have some issues due to -model architecture or kernel limitations. - - -| Model | calibration dataset | quant nontext module | Quantized Model Link | -|--------------------------------|---------------------|----------------------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| allenai/Molmo | pile | X | [Molmo-7B-D-0924-int4-sym](https://huggingface.co/OPEA/Molmo-7B-D-0924-int4-sym-inc), [Molmo-72B-0924-int4-sym-gptq](https://huggingface.co/OPEA/Molmo-72B-0924-int4-sym-gptq-inc), [Molmo-72B-0924-int4-sym](https://huggingface.co/OPEA/Molmo-72B-0924-int4-sym-inc) | -| deepseek-ai/deepseek-vl2 | pile/llava | √ | [deepseek-vl2-int4-sym-gptq](https://huggingface.co/OPEA/deepseek-vl2-int4-sym-gptq-inc) | -| google/gemma-3 | pile/llava | √ | [gemma-3-12b-it-AutoRound-gguf-q4-0](https://huggingface.co/OPEA/gemma-3-12b-it-AutoRound-gguf-q4-0), [gemma-3-27b-it-AutoRound-gguf-q4-0](https://huggingface.co/OPEA/gemma-3-27b-it-AutoRound-gguf-q4-0), [gemma-3-12b-it-int4-AutoRound](https://huggingface.co/OPEA/gemma-3-12b-it-int4-AutoRound), [gemma-3-27b-it-int4-AutoRound](https://huggingface.co/OPEA/gemma-3-27b-it-int4-AutoRound) | -| HuggingFaceTB/SmolVLM | pile/llava | √ | [SmolVLM-Instruct-int4-sym](https://huggingface.co/OPEA/SmolVLM-Instruct-int4-sym-inc) | -| ibm-granite/granite-vision-3.2 | pile/llava | - | | -| liuhaotian/Llava-v1.5 | pile/llava | X | [llava-v1.5-7b-int4-sym](https://huggingface.co/OPEA/llava-v1.5-7b-int4-sym-inc) | -| meta-llama/Llama-3.2-Vision | llava | √ | [Llama-3.2V-11B-cot-int4-sym](https://huggingface.co/OPEA/Llama-3.2V-11B-cot-int4-sym-inc), [Llama-3.2-11B-Vision-Instruct-qvision-int4-sym](https://huggingface.co/OPEA/Llama-3.2-11B-Vision-Instruct-qvision-int4-sym-inc), [Llama-3.2-90B-Vision-Instruct-int4-sym](https://huggingface.co/OPEA/Llama-3.2-90B-Vision-Instruct-int4-sym-inc), [Llama-3.2-11B-Vision-Instruct-int4-sym](https://huggingface.co/OPEA/Llama-3.2-11B-Vision-Instruct-int4-sym-inc) | -| microsoft/Phi3.5-Vision | pile/llava | √ | [Phi-3.5-vision-instruct-int4-sym](https://huggingface.co/OPEA/Phi-3.5-vision-instruct-int4-sym-inc), [Phi-3.5-vision-instruct-qvision-int4-sym](https://huggingface.co/OPEA/Phi-3.5-vision-instruct-qvision-int4-sym-inc) | -| mistralai/Mistral-Small-3.1 | pile/llava | X | [Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-gptq-sym](https://huggingface.co/OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-gptq-sym), [Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym](https://huggingface.co/OPEA/Mistral-Small-3.1-24B-Instruct-2503-int4-AutoRound-awq-sym) | -| moonshotai/Kimi-VL | pile/llava | √ | | -| Qwen/Qwen2-VL | pile/llava | - | [Qwen2-VL-7B-Instruct-int4-sym](https://huggingface.co/OPEA/Qwen2-VL-7B-Instruct-int4-sym-inc), 
[Qwen2-VL-72B-Instruct-int4-sym](https://huggingface.co/OPEA/Qwen2-VL-72B-Instruct-int4-sym-inc), [Qwen2-VL-72B-Instruct-int2-sym](https://huggingface.co/OPEA/Qwen2-VL-72B-Instruct-int2-sym-inc) | -| Qwen/Qwen2.5-VL | pile/llava | √ | | -| rhymes-ai/Aria | pile/llava | √ | | -| THUDM/CogVLM2 | pile/llava | √ | [cogvlm2-llama3-chat-19B-int4-sym](https://huggingface.co/OPEA/cogvlm2-llama3-chat-19B-int4-sym-inc), [cogvlm2-llama3-chat-19B-qvision-int4-sym](https://huggingface.co/OPEA/cogvlm2-llama3-chat-19B-qvision-int4-sym-inc) | -| THUDM/glm-4v | pile | X | [glm-4v-9b-int4-sym](https://huggingface.co/OPEA/glm-4v-9b-int4-sym-inc) | - -√ means support, - means support to export but cannot infer, X means not support. - -
-Calibration Dataset - -For mllm, we used **text-only** calibration dataset (NeelNanda/pile-10k) as our default. If the model type does not -support plain text calibration(e.g. Llama-3.2-vision), it will also automatically switch to llava dataset and adjust the -hyperparameters. - -Through argument --dataset(text file), user can use other datasets such as "liuhaotian/llava_conv_58k" " -liuhaotian/llava_instruct_80k", "liuhaotian/llava_instruct_150k" or a file path to use local file. - -
- - - -
-Nontext Module Quantization ### New Models Support #### Template diff --git a/auto_round/export/__init__.py b/auto_round/export/__init__.py index db20c94d8..8989ae9d7 100644 --- a/auto_round/export/__init__.py +++ b/auto_round/export/__init__.py @@ -11,82 +11,3 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from auto_round.export.register import EXPORT_FORMAT, PACKING_LAYER_WITH_FORMAT, register_format, register_layer_packing - - -@register_format("auto_gptq") -def _save_quantized_as_autogptq(*args, **kwargs): - from auto_round.export.export_to_autogptq.export import save_quantized_as_autogptq - - return save_quantized_as_autogptq(*args, **kwargs) - - -@register_format("itrex") -def _save_quantized_as_itrex(*args, **kwargs): - from auto_round.export.export_to_itrex.export import save_quantized_as_itrex - - return save_quantized_as_itrex(*args, **kwargs) - - -@register_format("itrex_xpu") -def _save_quantized_as_itrex_xpu(*args, **kwargs): - from auto_round.export.export_to_itrex.export import save_quantized_as_itrex_xpu - - return save_quantized_as_itrex_xpu(*args, **kwargs) - - -@register_format("auto_round") -def _save_quantized_as_autoround(*args, **kwargs): - from auto_round.export.export_to_autoround.export import save_quantized_as_autoround - - return save_quantized_as_autoround(*args, **kwargs) - - -@register_format("auto_awq") -def _save_quantized_as_autoawq(*args, **kwargs): - from auto_round.export.export_to_awq.export import save_quantized_as_autoawq - - return save_quantized_as_autoawq(*args, **kwargs) - - -@register_format("gguf") -def _save_quantized_as_gguf(*args, **kwargs): - from auto_round.export.export_to_gguf.export import save_quantized_as_gguf - - return save_quantized_as_gguf(*args, **kwargs) - - -@register_layer_packing("auto_round") -def _packing_layer_with_autoround(*args, **kwargs): - from auto_round.export.export_to_autoround.export import pack_layer - - return pack_layer(*args, **kwargs) - - -@register_layer_packing("auto_gptq") -def _packing_layer_with_autogptq(*args, **kwargs): - from auto_round.export.export_to_autogptq.export import pack_layer - - return pack_layer(*args, **kwargs) - - -@register_layer_packing("auto_awq") -def _packing_layer_with_autoawq(*args, **kwargs): - from auto_round.export.export_to_awq.export import pack_layer - - return pack_layer(*args, **kwargs) - - -@register_format("llm_compressor") -def _save_quantized_as_llmcompressor(*args, **kwargs): - from auto_round.export.export_to_llmcompressor.export import save_quantized_as_llmcompressor - - return save_quantized_as_llmcompressor(*args, **kwargs) - - -@register_layer_packing("llm_compressor") -def _packing_layer_with_llmcompressor(*args, **kwargs): - from auto_round.export.export_to_llmcompressor.export import pack_layer - - return pack_layer(*args, **kwargs) diff --git a/auto_round/export/export_to_autogptq/export.py b/auto_round/export/export_to_autogptq/export.py index 1f99d1ff7..d16a28f3b 100644 --- a/auto_round/export/export_to_autogptq/export.py +++ b/auto_round/export/export_to_autogptq/export.py @@ -18,7 +18,7 @@ import os from concurrent.futures import ThreadPoolExecutor from dataclasses import fields -from typing import Any, Dict +from typing import Any, Callable, Dict, Union import threadpoolctl as tctl @@ -190,18 +190,24 @@ def pack_layer(name, model, backend, device=None): release_layer_safely(layer) -def 
save_quantized_as_autogptq(output_dir, inplace=True, backend="auto_gptq:exllamav2", **kwargs): +def save_quantized_as_autogptq( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + backend: str = "auto_gptq:exllamav2", + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """Export the model to autogptq format to easily leverage cuda kernel.""" # --- 1️⃣ Extract inputs & configs --- - model = kwargs["model"] - quantization_config = kwargs["serialization_dict"] - layer_config = kwargs["layer_config"] - quant_block_list = kwargs.get("quant_block_list", get_block_names(model)) - tokenizer = kwargs.get("tokenizer") + quantization_config = serialization_dict + quant_block_list = serialization_dict.get("quant_block_list", get_block_names(model)) processor = kwargs.get("processor") image_processor = kwargs.get("image_processor") - device = kwargs.get("device") safe_serialization = kwargs.get("safe_serialization", True) # --- Save metadata (tokenizer, processor, etc.) --- diff --git a/auto_round/export/export_to_autoround/export.py b/auto_round/export/export_to_autoround/export.py index 2dfac16ef..a5176c1db 100644 --- a/auto_round/export/export_to_autoround/export.py +++ b/auto_round/export/export_to_autoround/export.py @@ -20,6 +20,7 @@ from concurrent.futures import ThreadPoolExecutor from dataclasses import fields from enum import Enum +from typing import Callable, Union import threadpoolctl as tctl import torch @@ -155,24 +156,6 @@ def pack_layer(layer_name, model, backend, device=None): Returns: None: The function modifies the model in place. """ - if is_nv_fp(backend) or is_mx_fp(backend): - from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import pack_layer - - return pack_layer(layer_name, model, backend, device) - - if ( - backend == f"auto_round:{AutoRoundExportFormat.FP8.value}" - or backend == f"auto_round:{AutoRoundExportFormat.FP8_STATIC.value}" - ): - from auto_round.export.export_to_autoround.export_to_fp8 import pack_layer - - return pack_layer(layer_name, model, backend, device) - - if backend in ["auto_round:llm_compressor", f"auto_round:llm_compressor:{AutoRoundExportFormat.FP8_STATIC.value}"]: - from auto_round.export.export_to_llmcompressor.export_to_static_fp import pack_layer - - return pack_layer(layer_name, model, backend, device) - layer = get_module(model, layer_name) if hasattr(layer, "orig_layer"): layer = layer.orig_layer @@ -207,50 +190,45 @@ def pack_layer(layer_name, model, backend, device=None): out_features = layer.weight.shape[1] bias = layer.bias is not None - if "awq" not in backend: - new_layer = QuantLinear( ##pylint: disable=E1123 - bits, group_size, in_features, out_features, bias=bias, weight_dtype=layer.weight.dtype - ) - new_layer.device = orig_device - set_module(model, layer_name, new_layer) - qlayer = new_layer - import auto_round_extension.torch.qlinear_torch - - if ( - sym - and isinstance(zp, torch.Tensor) - and isinstance(QuantLinear, (auto_round_extension.torch.qlinear_torch.QuantLinear)) - ): - zp = int(zp.flatten()[0]) - - qlayer.to("cpu") - # Force to float32 to be compatible with torch 2.0 - sig = inspect.signature(qlayer.pack) - param_count = len(sig.parameters) - if param_count == 2: - qlayer.pack(layer, scale, device=device) - else: - qlayer.pack(layer, scale, zp, None, device=device) - qlayer.to(orig_device) + new_layer = QuantLinear( ##pylint: disable=E1123 + bits, group_size, 
in_features, out_features, bias=bias, weight_dtype=layer.weight.dtype + ) + new_layer.device = orig_device + set_module(model, layer_name, new_layer) + qlayer = new_layer + import auto_round_extension.torch.qlinear_torch + + if ( + sym + and isinstance(zp, torch.Tensor) + and isinstance(QuantLinear, (auto_round_extension.torch.qlinear_torch.QuantLinear)) + ): + zp = int(zp.flatten()[0]) + + qlayer.to("cpu") + # Force to float32 to be compatible with torch 2.0 + sig = inspect.signature(qlayer.pack) + param_count = len(sig.parameters) + if param_count == 2: + qlayer.pack(layer, scale, device=device) else: - scale = scale.to(torch.float32).t().contiguous() - if isinstance(zp, torch.Tensor): - zp = zp.to(torch.float32).t().contiguous() - if sym: - zp = int(zp.flatten()[0]) - - if bits != 4: - logger.error("AutoAWQ format only supports 4-bits quantization.") - qlayer = QuantLinear.from_linear( - linear=layer, w_bit=bits, group_size=group_size, init_only=False, scales=scale, zeros=zp, device=device - ) - qlayer.to(orig_device) - set_module(model, layer_name, qlayer) + qlayer.pack(layer, scale, zp, None, device=device) + qlayer.to(orig_device) # Note: release weight and bias explicitly, in case they are referenced elsewhere release_layer_safely(layer) -def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:exllamav2", **kwargs): +def save_quantized_as_autoround( + output_dir: str, + model: torch.nn.Module, + tokenizer: Callable = None, + layer_config: dict = None, + inplace=True, + backend="auto_round:exllamav2", + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, +): """ Saves a quantized model in the auto-round format. @@ -272,44 +250,26 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round:ex Raises: ValueError: If the backend is not supported. 
""" - data_type = kwargs.get("data_type", None) - if is_nv_fp(data_type) or is_mx_fp(data_type): ## detect nvfp & mxfp first - from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import save_quantized_as_fp - - return save_quantized_as_fp(output_dir, inplace=inplace, backend="auto_round:llm_compressor", **kwargs) - - if backend in ["auto_round:llm_compressor", f"auto_round:llm_compressor:{AutoRoundExportFormat.FP8_STATIC.value}"]: - from auto_round.export.export_to_llmcompressor.export_to_static_fp import save_quantized_as_static_fp - - return save_quantized_as_static_fp(output_dir, inplace=inplace, backend="auto_round:llm_compressor", **kwargs) - - if kwargs.get("data_type", "int") == "fp" and kwargs.get("bits", 16) == 8 and kwargs.get("act_bits", 16) >= 16: - from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround - - return save_quantized_as_autoround(output_dir, inplace=inplace, backend="auto_round", **kwargs) - # IF using sym, we change to gptq sym kernel to avoid compiling from auto_round source if ( - (kwargs.get("sym") is None or kwargs.get("sym")) + (serialization_dict.get("sym") is None or serialization_dict.get("sym")) and ("gptq" not in backend and "awq" not in backend) and (AutoRoundExportFormat.FP8_STATIC.value not in backend) ): backend = backend.replace("auto_round", "auto_round:auto_gptq") - model = kwargs["model"] safe_serialization = True if "safe_serialization" not in kwargs.keys() else kwargs["safe_serialization"] if not inplace: model = copy.deepcopy(model.to("cpu")) - layer_config = kwargs["layer_config"] - quantization_config = kwargs["serialization_dict"] + quantization_config = serialization_dict quantization_config["block_name_to_quantize"] = quantization_config.pop("to_quant_block_names", None) quantization_config["quant_method"] = "auto-round" quantization_config["packing_format"] = backend - device = kwargs.get("device", None) - tokenizer = kwargs.get("tokenizer", None) + processor = kwargs.get("processor", None) image_processor = kwargs.get("image_processor", None) + extra_config = {} block_name_to_quantize = quantization_config["block_name_to_quantize"] if isinstance(block_name_to_quantize, str): diff --git a/auto_round/export/export_to_autoround/export_to_fp8.py b/auto_round/export/export_to_autoround/export_to_fp8.py index ccd823b12..09cb3d788 100644 --- a/auto_round/export/export_to_autoround/export_to_fp8.py +++ b/auto_round/export/export_to_autoround/export_to_fp8.py @@ -17,6 +17,7 @@ import os from concurrent.futures import ThreadPoolExecutor from dataclasses import fields +from typing import Callable, Union import threadpoolctl as tctl import torch @@ -145,24 +146,29 @@ def pack_layer(layer_name, model, data_type, device=None): release_layer_safely(layer) -def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round", **kwargs): - model = kwargs["model"] +def save_quantized_as_autoround( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, +): safe_serialization = True if "safe_serialization" not in kwargs.keys() else kwargs["safe_serialization"] if not inplace: model = copy.deepcopy(model.to("cpu")) - layer_config = kwargs["layer_config"] - quantization_config = kwargs["serialization_dict"] + quantization_config = serialization_dict quantization_config["block_name_to_quantize"] = 
quantization_config.pop("to_quant_block_names", None) quantization_config["quant_method"] = "auto-round" - if "e5m2" in kwargs.get("data_type", "fp8"): + if "e5m2" in serialization_dict.get("data_type", "fp8"): quantization_config["fmt"] = "e5m2" else: quantization_config["fmt"] = "e4m3" quantization_config["activation_scheme"] = "dynamic" if quantization_config["act_dynamic"] else "static" - tokenizer = kwargs.get("tokenizer", None) processor = kwargs.get("processor", None) - device = kwargs.get("device", None) image_processor = kwargs.get("image_processor", None) extra_config = {} block_name_to_quantize = quantization_config["block_name_to_quantize"] @@ -198,7 +204,7 @@ def save_quantized_as_autoround(output_dir, inplace=True, backend="auto_round", def wrapper(name): pbar.set_description(f"packing {name}") with tctl.threadpool_limits(limits=1): - pack_layer(name, model, kwargs.get("data_type", "fp8"), device) + pack_layer(name, model, serialization_dict.get("data_type", "fp8"), device) pbar.update(1) for _ in executor.map(wrapper, names): diff --git a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py index de3991c7f..7265941d3 100644 --- a/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py +++ b/auto_round/export/export_to_autoround/export_to_nvfp_mxfp.py @@ -18,6 +18,7 @@ import os from concurrent.futures import ThreadPoolExecutor from dataclasses import fields +from typing import Callable, Union import threadpoolctl as tctl import torch @@ -118,7 +119,17 @@ def pack_layer(name, model, backend, device=None): release_layer_safely(layer) -def save_quantized_as_fp(output_dir, inplace=True, **kwargs): +def save_quantized_as_fp( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + backend: str = "autoround:exllamav2", + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """ Saves a quantized model of mxfp/nvfp data_type in the auto-round format. @@ -140,24 +151,20 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): Raises: ValueError: If the backend is not supported. 
""" - model = kwargs["model"] - device = kwargs.get("device", None) - backend = kwargs.get("backend", None) - bits = kwargs.get("bits", None) - data_type = kwargs.get("data_type", None) - act_bits = kwargs.get("act_bits", None) - act_data_type = kwargs.get("act_data_type", None) + bits = serialization_dict.get("bits", None) + data_type = serialization_dict.get("data_type", None) + act_bits = serialization_dict.get("act_bits", None) + act_data_type = serialization_dict.get("act_data_type", None) safe_serialization = True if "safe_serialization" not in kwargs.keys() else kwargs["safe_serialization"] if not inplace: model = copy.deepcopy(model.to("cpu")) - layer_config = kwargs["layer_config"] - quantization_config = kwargs["serialization_dict"] + quantization_config = serialization_dict quantization_config["block_name_to_quantize"] = quantization_config.pop("to_quant_block_names", None) quantization_config["quant_method"] = "auto-round" quantization_config["packing_format"] = backend - tokenizer = kwargs.get("tokenizer", None) processor = kwargs.get("processor", None) + image_processor = kwargs.get("image_processor", None) extra_config = {} if act_bits <= 8: @@ -247,6 +254,8 @@ def wrapper(name): if processor is not None: processor.save_pretrained(output_dir) + if image_processor is not None: + image_processor.save_pretrained(output_dir) dtype = None save_model(model, output_dir, safe_serialization=safe_serialization, dtype=dtype) diff --git a/auto_round/export/export_to_awq/export.py b/auto_round/export/export_to_awq/export.py index 0fa922833..d1328cd22 100644 --- a/auto_round/export/export_to_awq/export.py +++ b/auto_round/export/export_to_awq/export.py @@ -24,6 +24,7 @@ import json import os from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Union import threadpoolctl as tctl import torch @@ -79,13 +80,18 @@ def pack_layer(name, model, backend, device=None): release_layer_safely(layer) -def save_quantized_as_autoawq(output_dir, inplace=True, **kwargs): +def save_quantized_as_autoawq( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """Export the model to autogptq format to easily leverage cuda kernel.""" - model = kwargs["model"] - layer_config = kwargs["layer_config"] - to_quant_block_names = kwargs.get("to_quant_block_names", None) - device = kwargs.get("device", None) - tokenizer = kwargs.get("tokenizer", None) + to_quant_block_names = serialization_dict.get("to_quant_block_names", None) processor = kwargs.get("processor", None) image_processor = kwargs.get("image_processor", None) modules_to_not_convert = [] @@ -136,13 +142,12 @@ def wrapper(name): if output_dir is None: return model - quantization_config = kwargs["serialization_dict"] + quantization_config = serialization_dict regex_config = quantization_config.pop("regex_config", {}) # awq do not support mixed bits config saving if output_dir is None: return compressed_model - layer_config = kwargs["layer_config"] for key in layer_config.keys(): if not check_to_quantized(layer_config[key]) and not any(name in key for name in modules_to_not_convert): modules_to_not_convert.append(key) diff --git a/auto_round/export/export_to_gguf/export.py b/auto_round/export/export_to_gguf/export.py index 3263ef63b..df7264dd4 100644 --- a/auto_round/export/export_to_gguf/export.py +++ 
b/auto_round/export/export_to_gguf/export.py @@ -209,19 +209,20 @@ def pack_gguf_layer( @torch.inference_mode() -def save_quantized_as_gguf(output_dir, backend="gguf:q4_0", layer_config=None, vlm=False, device="cpu", **kwargs): +def save_quantized_as_gguf( + output_dir, model=None, backend="gguf:q4_0", layer_config=None, mllm=False, device="cpu", **kwargs +): """Export the model to gguf format.""" st = time.time() global gguf_model_instance_global - model = kwargs["model"] if "gguf_model_instance_global" not in globals(): gguf_model_instance_global = [ create_model_class( output_dir, model, layer_config, backend, model_type=convert_hf_to_gguf.ModelType.TEXT, device=device ) ] - if vlm: + if mllm: gguf_model_instance_global.append( create_model_class( output_dir, diff --git a/auto_round/export/export_to_itrex/export.py b/auto_round/export/export_to_itrex/export.py index 3254e4630..280a83561 100644 --- a/auto_round/export/export_to_itrex/export.py +++ b/auto_round/export/export_to_itrex/export.py @@ -15,12 +15,11 @@ import copy import json import os -from typing import Dict, List, Optional, Union +from typing import Callable, Dict, List, Optional, Union import torch import transformers -from auto_round.export.register import register_format from auto_round.logger import logger from auto_round.utils import check_to_quantized, detect_device, get_module, set_module @@ -61,23 +60,32 @@ def quant_weight_w_scale(weight, scale, zp, group_size=-1, device="cpu"): return int_weight -def save_quantized_as_itrex(output_dir, inplace=True, **kwargs): +def save_quantized_as_itrex( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """Save configure file and weights for CPU backend inference.""" - model = kwargs["model"] - layer_config = kwargs["layer_config"] - sym = kwargs["sym"] - bits = kwargs["bits"] - group_size = kwargs["group_size"] - iters = kwargs["iters"] - lr = kwargs["lr"] - minmax_lr = kwargs["minmax_lr"] - enable_minmax_tuning = kwargs["enable_minmax_tuning"] - enable_quanted_input = kwargs["enable_quanted_input"] - scale_dtype = kwargs["scale_dtype"] - tokenizer = kwargs["tokenizer"] + sym = serialization_dict["sym"] + bits = serialization_dict["bits"] + group_size = serialization_dict["group_size"] + iters = serialization_dict["iters"] + lr = serialization_dict["lr"] + minmax_lr = serialization_dict["minmax_lr"] + enable_minmax_tuning = serialization_dict["enable_minmax_tuning"] + enable_quanted_input = serialization_dict["enable_quanted_input"] + scale_dtype = serialization_dict["scale_dtype"] safe_serialization = True if "safe_serialization" not in kwargs.keys() else kwargs["safe_serialization"] - compressed_model = pack_model(model, layer_config, inplace=inplace) + processor = kwargs.get("processor", None) + image_processor = kwargs.get("image_processor", None) + + compressed_model = pack_model(model, layer_config, inplace=inplace, device=device) if output_dir is None: return compressed_model quantize_config = QuantConfig( @@ -101,27 +109,39 @@ def save_quantized_as_itrex(output_dir, inplace=True, **kwargs): compressed_model.save_pretrained(output_dir, safe_serialization=safe_serialization) if tokenizer is not None and hasattr(tokenizer, "save_pretrained"): tokenizer.save_pretrained(output_dir) + if processor is not None: + processor.save_pretrained(output_dir) + if image_processor is not 
None: + image_processor.save_pretrained(output_dir) logger.info("Saved config file and weights of quantized model to {}.".format(output_dir)) except IOError as e: # pragma: no cover logger.error("Fail to save configure file and weights due to {}.".format(e)) return compressed_model -def save_quantized_as_itrex_xpu(output_dir, inplace=True, **kwargs): +def save_quantized_as_itrex_xpu( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """Save configure file and weights for XPU backend inference.""" - model = kwargs["model"] - layer_config = kwargs["layer_config"] - sym = kwargs["sym"] - bits = kwargs["bits"] - group_size = kwargs["group_size"] - iters = kwargs["iters"] - lr = kwargs["lr"] - minmax_lr = kwargs["minmax_lr"] - enable_minmax_tuning = kwargs["enable_minmax_tuning"] - enable_quanted_input = kwargs["enable_quanted_input"] - scale_dtype = kwargs["scale_dtype"] - tokenizer = kwargs.get("tokenizer", None) + sym = serialization_dict["sym"] + bits = serialization_dict["bits"] + group_size = serialization_dict["group_size"] + iters = serialization_dict["iters"] + lr = serialization_dict["lr"] + minmax_lr = serialization_dict["minmax_lr"] + enable_minmax_tuning = serialization_dict["enable_minmax_tuning"] + enable_quanted_input = serialization_dict["enable_quanted_input"] + scale_dtype = serialization_dict["scale_dtype"] + processor = kwargs.get("processor", None) + image_processor = kwargs.get("image_processor", None) compressed_model = pack_model(inplace=inplace, **kwargs) if output_dir is None: @@ -151,6 +171,8 @@ def save_quantized_as_itrex_xpu(output_dir, inplace=True, **kwargs): tokenizer.save_pretrained(output_dir) if processor is not None: processor.save_pretrained(output_dir) + if image_processor is not None: + image_processor.save_pretrained(output_dir) logger.info("Saved config file and weights of quantized model to {}.".format(output_dir)) except IOError as e: # pragma: no cover logger.error("Fail to save configure file and weights due to {}.".format(e)) diff --git a/auto_round/export/export_to_llmcompressor/export.py b/auto_round/export/export_to_llmcompressor/export.py index 41cd9b71c..ab7eae5d2 100644 --- a/auto_round/export/export_to_llmcompressor/export.py +++ b/auto_round/export/export_to_llmcompressor/export.py @@ -14,6 +14,7 @@ import copy import os +from typing import Callable, Union import torch @@ -46,40 +47,17 @@ def recover_qweight(qdq_weight, scale): return (qdq_weight / scale).to(torch.int8) -def pack_layer(layer_name, model, backend, device=None): - """ - Packs a model layer for quantization based on its type and configuration. - - This function retrieves the specified layer from the model, checks its - compatibility for quantization, and replaces it with a quantized version - if applicable. The quantization process depends on the layer's bit-width, - group size, symmetry, and activation bits. - - Args: - layer_name (str): The name of the layer to be packed. - model (torch.nn.Module): The model containing the layer. - backend (str): The backend framework to be used for quantization. - - Returns: - None: The function modifies the model in place. 
- """ - if is_nv_fp(backend) or is_mx_fp(backend): - from auto_round.export.export_to_llmcompressor.export_to_fp import pack_layer - - return pack_layer(layer_name, model, backend, device) - - if is_static_wfp8afp8(backend): - from auto_round.export.export_to_llmcompressor.export_to_static_fp import pack_layer - - return pack_layer(layer_name, model, backend, device) - - ## passed as no other llm_compressor format is supported yet - logger.warning("No other llm_compressor packing format(except NVFP&MXFP) is supported yet, skip packing") - return - - @torch.no_grad() -def save_quantized_as_llmcompressor(output_dir: str, inplace: bool = True, **kwargs) -> torch.nn.Module: +def save_quantized_as_llmcompressor( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """ Save a quantized model in the LLM-Compressor format. @@ -104,16 +82,7 @@ def save_quantized_as_llmcompressor(output_dir: str, inplace: bool = True, **kwa torch.nn.Module: The quantized model that was saved. """ - backend = kwargs.get("backend", None) - if is_nv_fp(backend) or is_mx_fp(backend): - return save_quantized_as_fp(output_dir, inplace=inplace, **kwargs) - - if is_static_wfp8afp8(backend): - return save_quantized_as_static_fp(output_dir, **kwargs) - - model = kwargs.get("model", None) safe_serialization = kwargs.get("safe_serialization", True) - tokenizer = kwargs.get("tokenizer", None) processor = kwargs.get("processor", None) if output_dir is not None and os.path.exists(output_dir): logger.warning(f"{output_dir} already exists, this may cause model conflict") @@ -127,7 +96,7 @@ def save_quantized_as_llmcompressor(output_dir: str, inplace: bool = True, **kwa processor.save_pretrained(output_dir) # generate q_weight - device = detect_device() + device = detect_device(device) for n, m in model.named_modules(): if isinstance(m, WrapperWALayer): m = m.orig_layer diff --git a/auto_round/export/export_to_llmcompressor/export_to_fp.py b/auto_round/export/export_to_llmcompressor/export_to_fp.py index 377d10f2c..e7835309f 100644 --- a/auto_round/export/export_to_llmcompressor/export_to_fp.py +++ b/auto_round/export/export_to_llmcompressor/export_to_fp.py @@ -17,6 +17,7 @@ import json import os from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Union import threadpoolctl as tctl import torch @@ -50,7 +51,7 @@ ] -def pack_layer(name, model, backend, device=None): +def pack_layer(name, model, device=None): layer = get_module(model, name) if type(layer) not in SUPPORTED_LAYER_TYPES and not isinstance(layer, WrapperWALayer): ##already packed return @@ -118,7 +119,17 @@ def pack_layer(name, model, backend, device=None): release_layer_safely(layer) -def save_quantized_as_fp(output_dir, inplace=True, **kwargs): +def save_quantized_as_fp( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + backend: str = None, + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """ Saves a quantized model of mxfp/nvfp data_type in the llm-compressor format. @@ -140,22 +151,15 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): Raises: ValueError: If the backend is not supported. 
""" - model = kwargs["model"] - backend = kwargs.get("backend", None) - bits = kwargs.get("bits", None) - data_type = kwargs.get("data_type", None) - act_bits = kwargs.get("act_bits", None) - act_data_type = kwargs.get("act_data_type", None) + bits = serialization_dict.get("bits", None) + data_type = serialization_dict.get("data_type", None) + act_bits = serialization_dict.get("act_bits", None) + act_data_type = serialization_dict.get("act_data_type", None) safe_serialization = True if "safe_serialization" not in kwargs.keys() else kwargs["safe_serialization"] if not inplace: model = copy.deepcopy(model.to("cpu")) - layer_config = kwargs["layer_config"] - device = kwargs.get("device", None) - tokenizer = kwargs.get("tokenizer", None) processor = kwargs.get("processor", None) - ar_quantization_config = kwargs["serialization_dict"] - regex_config = ar_quantization_config.pop("regex_config") - layer_config = kwargs["layer_config"] + regex_config = serialization_dict.pop("regex_config") extra_config = {} if act_bits <= 8: @@ -195,7 +199,7 @@ def save_quantized_as_fp(output_dir, inplace=True, **kwargs): def wrapper(name): pbar.set_description(f"packing {name}") with tctl.threadpool_limits(limits=1): - pack_layer(name, model, backend, device) + pack_layer(name, model, device) pbar.update(1) for _ in executor.map(wrapper, names): diff --git a/auto_round/export/export_to_llmcompressor/export_to_static_fp.py b/auto_round/export/export_to_llmcompressor/export_to_static_fp.py index fd22a8579..c03403d3b 100644 --- a/auto_round/export/export_to_llmcompressor/export_to_static_fp.py +++ b/auto_round/export/export_to_llmcompressor/export_to_static_fp.py @@ -17,6 +17,7 @@ import os import sys from concurrent.futures import ThreadPoolExecutor +from typing import Callable, Union import threadpoolctl as tctl import torch @@ -144,7 +145,16 @@ def _use_fp8_kv(static_kv_dtype: str | None) -> bool: return False -def save_quantized_as_static_fp(output_dir: str, inplace: bool = True, **kwargs) -> torch.nn.Module: +def save_quantized_as_static_fp( + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, +) -> torch.nn.Module: """ Saves a quantized model of FP8_STATIC scheme in the llm-compressor format. @@ -166,15 +176,11 @@ def save_quantized_as_static_fp(output_dir: str, inplace: bool = True, **kwargs) Raises: ValueError: If the backend is not supported. 
""" - model = kwargs["model"] safe_serialization = True if "safe_serialization" not in kwargs.keys() else kwargs["safe_serialization"] if not inplace: model = copy.deepcopy(model.to("cpu")) - layer_config = kwargs["layer_config"] - tokenizer = kwargs.get("tokenizer", None) processor = kwargs.get("processor", None) - device = kwargs.get("device", None) image_processor = kwargs.get("image_processor", None) names = list(layer_config.keys()) @@ -188,7 +194,7 @@ def save_quantized_as_static_fp(output_dir: str, inplace: bool = True, **kwargs) def wrapper(name): pbar.set_description(f"packing {name}") with tctl.threadpool_limits(limits=1): - pack_layer(name, model, kwargs.get("data_type", "fp8"), device) + pack_layer(name, model, serialization_dict.get("data_type", "fp8"), device) pbar.update(1) for _ in executor.map(wrapper, names): @@ -205,17 +211,17 @@ def wrapper(name): QuantizationType, ) - group_size = kwargs["serialization_dict"]["group_size"] + group_size = serialization_dict["group_size"] if group_size == -1: strategy = QuantizationStrategy.CHANNEL elif group_size == 0: strategy = QuantizationStrategy.TENSOR else: strategy = QuantizationStrategy.GROUP - if kwargs["serialization_dict"]["act_group_size"] != 0: + if serialization_dict["act_group_size"] != 0: logger.error( f"scheme FP8_STATIC export to llm_compressor format only support for act_group_size 0," - f" but got {kwargs['serialization_dict']['act_group_size']}, please check." + f" but got {serialization_dict['act_group_size']}, please check." ) sys.exit(-1) scheme_args = dict( @@ -244,7 +250,9 @@ def wrapper(name): config_groups["group_0"] = scheme quantization_config = QuantizationConfig( config_groups=config_groups, - kv_cache_scheme=_construct_kv_scheme() if _use_fp8_kv(kwargs.get("static_kv_dtype", None)) else None, + kv_cache_scheme=( + _construct_kv_scheme() if _use_fp8_kv(serialization_dict.get("static_kv_dtype", None)) else None + ), quantization_status=QuantizationStatus.COMPRESSED, ignore=ignore, ) diff --git a/auto_round/export/register.py b/auto_round/export/register.py deleted file mode 100644 index 7e186f5b2..000000000 --- a/auto_round/export/register.py +++ /dev/null @@ -1,59 +0,0 @@ -# Copyright (c) 2023 Intel Corporation -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - - -EXPORT_FORMAT = {} - - -def register_format(name): - """Class decorator to register a EXPORT subclass to the registry. - - Decorator function used before a Pattern subclass. - - Args: - cls (class): The subclass of register. - name: A string. Define the export type. - - Returns: - cls: The class of register. - """ - - def register(format): - EXPORT_FORMAT[name] = format - return format - - return register - - -PACKING_LAYER_WITH_FORMAT = {} - - -def register_layer_packing(name): - """Class decorator to register a EXPORT subclass to the registry. - - Decorator function used before a Pattern subclass. - - Args: - cls (class): The subclass of register. - name: A string. Define the export type. - - Returns: - cls: The class of register. 
- """ - - def register(format): - PACKING_LAYER_WITH_FORMAT[name] = format - return format - - return register diff --git a/auto_round/export/utils.py b/auto_round/export/utils.py index 2a1787f7f..6f9490dad 100644 --- a/auto_round/export/utils.py +++ b/auto_round/export/utils.py @@ -193,6 +193,17 @@ def filter_quantization_config(quantization_config): quantization_config.pop("act_sym", None) quantization_config.pop("act_group_size", None) + clean_list = ("supported_types", "quant_block_list") + for key in list(quantization_config.keys()): + if callable(key): + quantization_config.pop(key) + elif isinstance(quantization_config[key], (list, tuple)): + if any([callable(item) for item in quantization_config[key]]): + quantization_config.pop(key) + if key in clean_list and key in quantization_config: + quantization_config.pop(key) + return quantization_config + def release_layer_safely(layer: nn.Module): """ diff --git a/auto_round/formats.py b/auto_round/formats.py index 77dcce458..6254f8e39 100644 --- a/auto_round/formats.py +++ b/auto_round/formats.py @@ -18,6 +18,7 @@ import os import re import sys +from abc import ABC, abstractmethod from dataclasses import asdict from enum import Enum from typing import TYPE_CHECKING, Callable, Union @@ -38,7 +39,13 @@ QuantizationScheme, get_gguf_scheme, ) -from auto_round.utils import SUPPORTED_FORMATS, logger +from auto_round.utils import ( + SUPPORTED_FORMATS, + check_to_quantized, + copy_python_files_from_model_cache, + get_module, + logger, +) class AutoRoundExportFormat(str, Enum): @@ -142,11 +149,11 @@ def _check_divisible_by_32(ar): logger.warning_once(f"{n} skipped quantization (shape not divisible by 32).") -class OutputFormat: +class OutputFormat(ABC): """ "Base class for different output formats. format: determines which method from export module to use for exporting. - For example, auto_round, gguf, llmcompressor etc. + For example, auto_round, gguf, llm_compressor etc. backend: determines the specific export process within the format. For example, auto_round:fp8_static, auto_round:auto_awq etc. 
""" @@ -181,12 +188,16 @@ def func(output_format: OutputFormat) -> OutputFormat: @classmethod def get_support_matrix(cls: OutputFormat) -> str: output_str = "" - for k, v in cls._format_list.items(): + for k, v in sorted(cls._format_list.items()): if k == "fake": - support_scheme = "All schemes" + support_schemes = "All schemes" else: - support_scheme = ", ".join(v.support_schemes).rstrip(",") - output_str += f"\x1b[31;1m{k}\x1b[0m support scheme:\n\t{support_scheme}\n" + if ":" in k and k.split(":")[1] in cls._format_list: + support_schemes = cls._format_list[k.split(":")[1]].support_schemes + else: + support_schemes = v.support_schemes + support_schemes = ", ".join(support_schemes).rstrip(",") + output_str += f"\x1b[31;1m{k}\x1b[0m support scheme:\n\t{support_schemes}\n" return output_str def get_backend_name(self) -> str: @@ -196,10 +207,7 @@ def get_backend_name(self) -> str: # auto_round:llm_compressor:fp8_static if self.backend.backend is not None: return f"{self.output_format}:{self.backend.get_backend_name()}" - # auto_round:auto_awq, auto_round:auto_gptq - elif self.backend.get_backend_name() in self._format_list: - return f"{self.output_format}:{self.backend.get_backend_name()}" - # auto_round:fp8_static, llm_compressor:fp8_static + # auto_round:fp8_static, llm_compressor:fp8_static, auto_round:auto_awq else: return self.backend.get_backend_name() @@ -237,6 +245,21 @@ def check_and_reset_format(self, ar: BaseCompressor) -> str: return None + @abstractmethod + def pack_layer(self, *args, **kwargs): + pass + + @abstractmethod + def save_quantized(self, *args, **kwargs): + pass + + def immediate_pack(self, name: str, model: torch.nn.Module, device: torch.device, **kwargs): + m = get_module(model, name) + if not check_to_quantized(m): + return + + self.pack_layer(name, model, device=device) + def is_gguf(self) -> bool: return "gguf" in self.output_format @@ -261,8 +284,36 @@ class FakeFormat(OutputFormat): def check_and_reset_format(self, ar: BaseCompressor) -> str: return None - -@OutputFormat.register("llm_compressor", "llmcompressor") + # fake format will not execute pack_layer. + def pack_layer(self, *args, **kwargs): + pass + + def save_quantized( + self, + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, + ): + model = model.to("cpu") + model.save_pretrained(output_dir) + if tokenizer is not None and hasattr(tokenizer, "save_pretrained"): + tokenizer.save_pretrained(output_dir) + processor = kwargs.get("processor", None) + if processor is not None: + processor.save_pretrained(output_dir) + try: + copy_python_files_from_model_cache(model, output_dir) + except Exception as e: + logger.warning("Skipping source model Python file copy due to error: %s", e) + return model + + +@OutputFormat.register("llm_compressor") class LLMCompressorFormat(OutputFormat): support_schemes = ["MXFP4", "MXFP8", "NVFP4", "FPW8A16", "FP8_STATIC"] format_name = "llm_compressor" @@ -274,7 +325,8 @@ def __init__(self, format, ar): f"but got scheme {ar.scheme}, please change to fake or auto_round etc." 
) exit(-1) - if format.startswith("llm_compressor"): + # if format.startswith("llm_compressor"): + if re.search("^(auto_round:)?llm_compressor", format): self.output_format = format self.backend = None if is_nv_fp(ar.data_type) or is_mx_fp(ar.data_type): @@ -326,6 +378,58 @@ def check_and_reset_format(self, ar: BaseCompressor) -> str: return None return None + def pack_layer(self, layer_name, model, device=None, **kwargs): + if self.backend is not None: + return self.backend.pack_layer(layer_name, model, device=device, **kwargs) + if re.search(f"{AutoRoundExportFormat.MX_FP.value}|{AutoRoundExportFormat.NV_FP.value}", self.output_format): + from auto_round.export.export_to_llmcompressor.export_to_fp import pack_layer + + return pack_layer(layer_name, model, device=device) + elif re.search(f"{AutoRoundExportFormat.FP8_STATIC.value}", self.output_format): + from auto_round.export.export_to_llmcompressor.export_to_static_fp import pack_layer + + return pack_layer(layer_name, model, self.get_backend_name(), device=device) + + ## passed as no other llm_compressor format is supported yet + logger.warning("No other llm_compressor packing format(except NVFP&MXFP) is supported yet, skip packing") + return + + def save_quantized( + self, + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, + ) -> torch.nn.Module: + backend = self.get_backend_name() + if re.search(f"{AutoRoundExportFormat.MX_FP.value}|{AutoRoundExportFormat.NV_FP.value}", backend): + from auto_round.export.export_to_llmcompressor.export_to_fp import save_quantized_as_fp + + export_func = save_quantized_as_fp + elif re.search(f"{AutoRoundExportFormat.FP8_STATIC.value}", backend): + from auto_round.export.export_to_llmcompressor.export_to_static_fp import save_quantized_as_static_fp + + export_func = save_quantized_as_static_fp + else: + from auto_round.export.export_to_llmcompressor.export import save_quantized_as_llmcompressor + + export_func = save_quantized_as_llmcompressor + return export_func( + output_dir=output_dir, + model=model, + tokenizer=tokenizer, + layer_config=layer_config, + inplace=inplace, + device=device, + backend=backend, + serialization_dict=serialization_dict, + **kwargs, + ) + @OutputFormat.register("auto_gptq", "gptqmodel") class AutoGPTQFormat(OutputFormat): @@ -344,6 +448,48 @@ def check_and_reset_format(self, ar): _check_divisible_by_32(ar) return super().check_and_reset_format(ar) + def pack_layer(self, layer_name, model, device=None, **kwargs): + if self.output_format.startswith("auto_round"): + from auto_round.export.export_to_autoround.export import pack_layer + + pack_layer(layer_name, model, backend=self.output_format, device=device) + else: + from auto_round.export.export_to_autogptq.export import pack_layer + + pack_layer(layer_name, model, backend=self.output_format, device=device) + + def save_quantized( + self, + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, + ) -> torch.nn.Module: + backend = self.get_backend_name() + if backend == "auto_round:auto_gptq" or backend == "auto_round:gptqmodel": + from auto_round.export.export_to_autoround.export import save_quantized_as_autoround + + export_func = save_quantized_as_autoround + else: + from 
auto_round.export.export_to_autogptq.export import save_quantized_as_autogptq + + export_func = save_quantized_as_autogptq + return export_func( + output_dir=output_dir, + model=model, + tokenizer=tokenizer, + layer_config=layer_config, + inplace=inplace, + device=device, + backend=backend, + serialization_dict=serialization_dict, + **kwargs, + ) + @OutputFormat.register("auto_awq") class AutoAWQFormat(OutputFormat): @@ -406,6 +552,44 @@ def check_and_reset_format(self, ar): return super().check_and_reset_format(ar) + def pack_layer(self, layer_name, model, device=None, **kwargs): + from auto_round.export.export_to_awq.export import pack_layer + + pack_layer(layer_name, model, backend=self.output_format, device=device) + + def save_quantized( + self, + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, + ) -> torch.nn.Module: + backend = self.get_backend_name() + if backend == "auto_round:auto_awq": + from auto_round.export.export_to_autoround.export import save_quantized_as_autoround + + export_func = save_quantized_as_autoround + else: + from auto_round.export.export_to_awq.export import save_quantized_as_autoawq + + export_func = save_quantized_as_autoawq + + return export_func( + output_dir=output_dir, + model=model, + tokenizer=tokenizer, + layer_config=layer_config, + inplace=inplace, + backend=backend, + device=device, + serialization_dict=serialization_dict, + **kwargs, + ) + @OutputFormat.register("itrex") @OutputFormat.register("itrex_xpu") @@ -413,6 +597,41 @@ class ITREXFormat(OutputFormat): support_schemes = ["W4A16", "W2A16", "W3A16", "W8A16", "BF16", "W2A16G64", "W2A16G32"] format_name = "itrex" + def pack_layer(self, *args, **kwargs): + pass + + def save_quantized( + self, + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, + ) -> torch.nn.Module: + backend = self.get_backend_name() + if backend == "itrex": + from auto_round.export.export_to_itrex.export import save_quantized_as_itrex + + export_func = save_quantized_as_itrex + else: + from auto_round.export.export_to_itrex.export import save_quantized_as_itrex_xpu + + export_func = save_quantized_as_itrex_xpu + return export_func( + output_dir=output_dir, + model=model, + tokenizer=tokenizer, + layer_config=layer_config, + inplace=inplace, + device=device, + backend=backend, + serialization_dict=serialization_dict, + **kwargs, + ) + @OutputFormat.register("gguf") class GGUFFormat(OutputFormat): @@ -447,6 +666,7 @@ def __init__(self, format: str, ar: BaseCompressor): else: self.output_format = f"gguf:{format}" self.backend = None + self.mllm = ar.mllm def check_and_reset_format(self, ar): if ar.iters != 0 and ar.bits != 3 and not ar.enable_alg_ext: @@ -461,6 +681,59 @@ def check_and_reset_format(self, ar): return super().check_and_reset_format(ar) + def pack_layer( + self, + name, + model, + backend, + output_dir, + layer_config, + tokenizer, + processor=None, + image_processor=None, + model_type=ModelType.TEXT, + device="cpu", + ): + from auto_round.export.export_to_gguf.export import pack_gguf_layer + + pack_gguf_layer( + name, + model, + backend, + output_dir, + layer_config, + tokenizer, + processor, + image_processor, + model_type, + device, + ) + + def save_quantized( + self, + 
output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, + ) -> torch.nn.Module: + from auto_round.export.export_to_gguf.export import save_quantized_as_gguf + + backend = self.get_backend_name() + return save_quantized_as_gguf( + output_dir=output_dir, + model=model, + backend=backend, + layer_config=layer_config, + mllm=self.mllm, + device=device, + serialization_dict=serialization_dict, + **kwargs, + ) + @staticmethod def gguf_args_check(args_or_ar, formats: Union[str, list[str]] = None, model_type=ModelType.TEXT): import argparse @@ -574,6 +847,36 @@ def gguf_args_check(args_or_ar, formats: Union[str, list[str]] = None, model_typ # Removed obsolete commented-out block for improved readability and maintainability. return args_or_ar + def immediate_pack( + self, + name: str, + model: torch.nn.Module, + device: torch.device, + output_dir: str = None, + mllm: bool = False, + layer_config: dict = None, + tokenizer=None, + processor=None, + image_processor=None, + **kwargs, + ): + m = get_module(model, name) + if not check_to_quantized(m): + return + model_type = ModelType.MMPROJ if mllm else ModelType.TEXT + self.pack_layer( + name, + model, + self.get_backend_name(), + output_dir, + layer_config=layer_config, + tokenizer=tokenizer, + processor=processor, + image_processor=image_processor, + model_type=model_type, + device=device, + ) + @OutputFormat.register("auto_round") @OutputFormat.register("auto_round:auto_awq") @@ -602,7 +905,7 @@ def __init__(self, format: str, ar: BaseCompressor): if format == "auto_round": if ar.sym and "int" in ar.data_type: - self.backend = AutoGPTQFormat("auto_gptq", ar) + self.backend = AutoGPTQFormat("auto_round:auto_gptq", ar) elif ar.bits == 4 and not ar.sym and "int" in ar.data_type: if ar.layer_config is None: enable_awq = True @@ -611,7 +914,7 @@ def __init__(self, format: str, ar: BaseCompressor): config["bits"] == ar.bits or config["bits"] >= 16 for config in ar.layer_config.values() ) if enable_awq: - self.backend = AutoAWQFormat("auto_awq", ar) + self.backend = AutoAWQFormat("auto_round:auto_awq", ar) elif is_nv_fp(ar.data_type) or is_mx_fp(ar.data_type): self.backend = AutoRoundFormat(ar.data_type, ar) elif is_static_wfp8afp8(ar): # static wfp8afp8 @@ -624,6 +927,7 @@ def __init__(self, format: str, ar: BaseCompressor): "for the current quantization configuration, " "please change to `fake` format for research purpose" ) + # for auto_round:fp8_static, auto_round:nv_fp etc. 
elif not format.startswith("auto_round"): if format.upper() not in list(AutoRoundExportFormat.__members__.keys()): raise KeyError(f"Unsupported backend format auto_round:{format}, please check") @@ -631,7 +935,7 @@ def __init__(self, format: str, ar: BaseCompressor): self.backend = None else: backend = format.split(":")[1] if ":" in format else None - self.backend = self._format_list.get(backend)(backend, ar) if backend else None + self.backend = self._format_list.get(backend)(format, ar) if backend else None if self.backend is not None: self.support_schemes = self.backend.support_schemes @@ -655,3 +959,84 @@ def check_and_reset_format(self, ar): if self.backend is None: _check_divisible_by_32(ar) return None + + def pack_layer(self, layer_name, model, device=None, **kwargs): + if self.backend is not None: + return self.backend.pack_layer(layer_name, model, device=device, **kwargs) + + backend = self.get_backend_name() + + if self.output_format in [ + f"auto_round:{AutoRoundExportFormat.NV_FP.value}", + f"auto_round:{AutoRoundExportFormat.MX_FP.value}", + f"auto_round:{AutoRoundExportFormat.MX_FP_RCEIL.value}", + f"auto_round:{AutoRoundExportFormat.NV_FP4_WITH_STATIC_GS.value}", + ]: + from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import pack_layer + + pack_func = pack_layer + elif self.output_format in [ + f"auto_round:{AutoRoundExportFormat.FP8.value}", + f"auto_round:{AutoRoundExportFormat.FP8_STATIC.value}", + ]: + from auto_round.export.export_to_autoround.export_to_fp8 import pack_layer + + pack_func = pack_layer + else: + from auto_round.export.export_to_autoround.export import pack_layer + + pack_func = pack_layer + return pack_func(layer_name, model, backend, device) + + def save_quantized( + self, + output_dir: str, + model: torch.nn.Module = None, + tokenizer: Callable = None, + layer_config: dict = None, + inplace: bool = True, + device: Union[str, torch.device] = "cpu", + serialization_dict: dict = None, + **kwargs, + ) -> torch.nn.Module: + if self.backend is not None: + return self.backend.save_quantized( + output_dir=output_dir, + model=model, + tokenizer=tokenizer, + layer_config=layer_config, + inplace=inplace, + device=device, + serialization_dict=serialization_dict, + **kwargs, + ) + backend = self.get_backend_name() + if re.search(f"{AutoRoundExportFormat.MX_FP.value}|{AutoRoundExportFormat.NV_FP.value}", backend): + from auto_round.export.export_to_autoround.export_to_nvfp_mxfp import save_quantized_as_fp + + backend = "auto_round:llm_compressor" + export_func = save_quantized_as_fp + elif ( + serialization_dict.get("data_type", "int") == "fp" + and serialization_dict.get("bits", 16) == 8 + and serialization_dict.get("act_bits", 16) >= 16 + ): + from auto_round.export.export_to_autoround.export_to_fp8 import save_quantized_as_autoround + + backend = "auto_round" + export_func = save_quantized_as_autoround + else: + from auto_round.export.export_to_autoround.export import save_quantized_as_autoround + + export_func = save_quantized_as_autoround + return export_func( + output_dir=output_dir, + model=model, + tokenizer=tokenizer, + layer_config=layer_config, + inplace=inplace, + device=device, + backend=backend, + serialization_dict=serialization_dict, + **kwargs, + ) diff --git a/docs/step_by_step.md b/docs/step_by_step.md index 17079cb80..8163ac394 100644 --- a/docs/step_by_step.md +++ b/docs/step_by_step.md @@ -130,6 +130,7 @@ AutoRound supports several Schemes: Besides, you could modify the 
`group_size`, `bits`, `sym` and many other configs you want, though there are maybe no real kernels.
 ### Supported export Formats
+You can use the command `auto_round list format` to show all supported formats and their supported schemes.
 **AutoRound Format**: This format is well-suited for CPU, Intel GPU, CUDA and HPU devices, 2 bits, as well as mixed-precision inference. **[2,3,4,8] bits are supported**. Please set `--format auto_round`
@@ -147,6 +148,16 @@ adopted within the community, **only 4-bits quantization is supported**. Please
 **LLM-Compressor Format**: **NVFP4, MXFP4(kernel in WIP), MXFP8 are supported**. Please set `--format llm_compressor`
+#### Format and scheme support matrix
+| Export format | Supported schemes |
+|--------------|------------------|
+|**auto_round** | W4A16, W2A16, W3A16, W8A16, MXFP4, MXFP8, NVFP4, FPW8A16, W2A16G64, W2A16G32, FP8_STATIC, BF16|
+|**auto_awq / auto_round:auto_awq** | W4A16, W2A16, W3A16, W8A16, BF16, W2A16G64, W2A16G32 |
+|**auto_gptq / auto_round:auto_gptq / auto_round:gptqmodel**|W4A16, W2A16, W3A16, W8A16, BF16, W2A16G64, W2A16G32|
+|**llm_compressor / auto_round:llm_compressor** | MXFP4, MXFP8, NVFP4, FPW8A16, FP8_STATIC |
+|**gguf** | GGUF:Q4_0, GGUF:Q4_1, GGUF:Q5_0, GGUF:Q5_1, GGUF:Q2_K_S, GGUF:Q3_K_S, GGUF:Q3_K_M, GGUF:Q3_K_L, GGUF:Q4_K_S, GGUF:Q4_K_M, GGUF:Q5_K_S, GGUF:Q5_K_M, GGUF:Q6_K, GGUF:Q8_0 |
+|**itrex / itrex_xpu** | W4A16, W2A16, W3A16, W8A16, BF16, W2A16G64, W2A16G32 |
+|**fake** | All schemes |
 ### Hardware Compatibility
 CPU, Intel GPU, HPU and CUDA for both quantization and inference.
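For reviewers skimming the refactor, below is a minimal, self-contained sketch of the registry-plus-dispatch pattern the new per-format `pack_layer`/`save_quantized` methods follow: each format class registers itself under one or more names and owns its packing and export logic, so callers such as `_immediate_pack` no longer branch on format strings. All names in the sketch (`OutputFormatSketch`, `ToyGPTQFormat`, `toy_gptq`, ...) are hypothetical stand-ins for illustration, not the actual implementation in `auto_round/formats.py`.

```python
from typing import Callable, Dict, Type


class OutputFormatSketch:
    """Toy registry mirroring the OutputFormat dispatch pattern (simplified)."""

    # Maps format names to handler classes, analogous to OutputFormat._format_list.
    _format_list: Dict[str, Type["OutputFormatSketch"]] = {}

    def __init__(self, fmt: str):
        self.output_format = fmt

    @classmethod
    def register(cls, *names: str) -> Callable:
        # Same idea as @OutputFormat.register("auto_gptq", "gptqmodel"):
        # one subclass can be reachable under several aliases.
        def wrapper(subclass: Type["OutputFormatSketch"]) -> Type["OutputFormatSketch"]:
            for name in names:
                cls._format_list[name] = subclass
            return subclass

        return wrapper

    def pack_layer(self, layer_name: str, model=None, device=None, **kwargs):
        raise NotImplementedError

    def save_quantized(self, output_dir: str, model=None, **kwargs):
        raise NotImplementedError


@OutputFormatSketch.register("toy_gptq", "toy_gptqmodel")
class ToyGPTQFormat(OutputFormatSketch):
    def pack_layer(self, layer_name, model=None, device=None, **kwargs):
        print(f"packing {layer_name} for {self.output_format} on {device}")

    def save_quantized(self, output_dir, model=None, **kwargs):
        print(f"saving {self.output_format} checkpoint to {output_dir}")


# Callers resolve the handler once through the registry and then dispatch,
# instead of branching on format strings at every call site.
fmt_cls = OutputFormatSketch._format_list["toy_gptqmodel"]
fmt = fmt_cls("toy_gptq")
fmt.pack_layer("model.layers.0.self_attn.q_proj", device="cpu")
fmt.save_quantized("./quantized_model")
```

One design note reflected in the sketch: the real handlers keep their format-specific exporter imports lazy inside `pack_layer`/`save_quantized` (as the deferred `from auto_round.export...` imports in the diff show), so importing the registry does not pull in every export backend.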