From 4d88df2aa0cb003c992eb8520ffd3513ba6bacdd Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 24 Apr 2025 10:38:02 +0000 Subject: [PATCH 01/22] install the current repo in dockerfile Signed-off-by: jiqing-feng --- docker/Dockerfile | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/docker/Dockerfile b/docker/Dockerfile index 9803ff8..4e3703e 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -35,12 +35,11 @@ RUN pip install --no-cache-dir \ ruff # Then install PyTorch-dependent packages with constraint to use existing torch -RUN pip install --no-cache-dir \ - --extra-index-url https://download.pytorch.org/whl/xpu \ - -C torch==2.6.0+xpu \ - transformers \ - accelerate \ - bitsandbytes +RUN pip install transformers accelerate bitsandbytes + +# Copy the bitsandbytes-intel repository into /workspace/src/bnb and install it. +COPY .. ${WORKSPACE}/src/bnb +RUN cd ${WORKSPACE}/src/bnb && pip install . COPY --chmod=755 docker/entrypoint.sh /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] From 59f8f970ce11da30347d804713a12a1b49c012e3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Tue, 29 Apr 2025 15:39:10 +0000 Subject: [PATCH 02/22] register cpu/xpu ops Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/cpu_xpu_common.py | 190 ++++++++++- src/bitsandbytes_intel/ops.py | 392 +++++++++++++++++++++-- 2 files changed, 548 insertions(+), 34 deletions(-) diff --git a/src/bitsandbytes_intel/cpu_xpu_common.py b/src/bitsandbytes_intel/cpu_xpu_common.py index 13d20ee..d9af7fa 100644 --- a/src/bitsandbytes_intel/cpu_xpu_common.py +++ b/src/bitsandbytes_intel/cpu_xpu_common.py @@ -1,10 +1,11 @@ import subprocess -from typing import Optional +from typing import Optional, Tuple import warnings import torch import torch.nn.functional as F +from bitsandbytes.utils import QuantState from bitsandbytes.functional import ( QuantState, create_dynamic_map, @@ -57,6 +58,17 @@ def _ipex_xpu_version_prereq(major, minor): return False +str2optimizer8bit_blockwise = {} +if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7): + str2optimizer8bit_blockwise = { + "adam": ( + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16, + ), + } + + def _maybe_torch_compile(func): # torch.compile requires g++ and pytorch >= 2.0 if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu: @@ -77,8 +89,32 @@ def reverse_4bit_compress_format(weight): return out +def transform( + A: torch.Tensor, + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[Tuple[torch.Size, str]] = None, + ): + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For CPU/XPU, it returns the original tensor if transpose=False. + Otherwise, it returns the transpose of A + """ + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state + + @_maybe_torch_compile -def double_quant_impl(A, col_stats=None, row_stats=None, out_col=None, out_row=None, threshold=0.0): +def int8_double_quant_impl(A, threshold=0.0, col_stats=None, row_stats=None, out_col=None, out_row=None): """ Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. If threshold > 0.0, only values <= threshold are counted. 
All outliers are zeroed out in @@ -157,6 +193,26 @@ def quant_to_int8(A, stats): return out_row, out_col, row_stats.float(), col_stats.float(), outlier_cols +def int8_vectorwise_quant_impl(A: torch.Tensor, threshold=0.0): + # TODO: We can optimize this as we don't actually need column-wise quant. + out, _, stats, _, outlier_cols = int8_double_quant_impl(A, threshold=threshold) + return out, stats, outlier_cols + + +def int8_vectorwise_dequant_impl(A: torch.Tensor, stats: torch.Tensor): + """Dequantizes a tensor with dtype `torch.int8` to `torch.float32`. + + Args: + A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor. + stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization statistics. + + Returns: + `torch.Tensor` with dtype `torch.float32`: The dequantized tensor. + """ + # To dequantize we divide by 127, or multiply by the reciprocal. + return A * stats.view(-1, 1) * 7.874015718698502e-3 + + def int8_linear_matmul_impl( A: torch.Tensor, B: torch.Tensor, @@ -227,10 +283,10 @@ def int8_mm_dequant_impl( A: torch.Tensor, row_stats: torch.Tensor, col_stats: torch.Tensor, - out: Optional[torch.Tensor] = None, bias: Optional[torch.Tensor] = None, compute_dtype=torch.float32, output_dtype=torch.float32, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Dequant and add bias @@ -303,11 +359,11 @@ def int8_mm_dequant_impl( def quantize_4bit_impl( A: Tensor, absmax: Tensor = None, - out: Tensor = None, blocksize=64, compress_statistics=False, quant_type="nf4", quant_storage=torch.uint8, + out: Tensor = None, ) -> Tensor: """ Quantize tensor A in blocks of 4-bit values. @@ -443,9 +499,9 @@ def dequantize_4bit_impl( A: Tensor, quant_state=None, absmax: Tensor = None, - out: Tensor = None, blocksize: int = 64, quant_type="nf4", + out: Tensor = None, ) -> Tensor: """ Dequantizes 4-bit blockwise quantized values. @@ -471,6 +527,11 @@ def dequantize_4bit_impl( torch.Tensor: Dequantized tensor. """ + # For NF4, ipex have dequant kernel. + if quant_type == "nf4" and getattr(quant_state, "ipex", False): + out = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t() + return out + transpose = True if A.shape[0] == 1 else False A = A.reshape(-1) device = A.device @@ -545,10 +606,8 @@ def dequantize_4bit_impl( def gemm_4bit_impl( A: torch.Tensor, B: torch.Tensor, - out: Optional[torch.Tensor] = None, - transposed_A=False, - transposed_B=False, state: QuantState = None, + out: Optional[torch.Tensor] = None, ) -> torch.Tensor: """ Matrix-matrix multiplication with 4-bit quantization. 
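
As a quick aside (not part of the patch itself): ignoring the threshold/outlier handling, the row-wise int8 scheme behind int8_vectorwise_quant_impl and int8_vectorwise_dequant_impl added above reduces to scaling each row by its absmax. A minimal standalone sketch in plain PyTorch, with made-up shapes, assuming the 7.874015718698502e-3 factor is simply the reciprocal of 127 as the code comment states:

    import torch

    A = torch.randn(4, 8, dtype=torch.float32)
    stats = A.abs().amax(dim=1).float()          # per-row absmax ("row_stats")
    # quantize: scale each row into [-127, 127] and round to int8
    q = torch.clamp(torch.round(A * (127.0 / stats.view(-1, 1))), -128, 127).to(torch.int8)
    # dequantize: multiply back by stats / 127
    deq = q.float() * stats.view(-1, 1) * (1.0 / 127.0)
    assert torch.allclose(deq, A, atol=(stats.max() / 127).item())
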
@@ -598,3 +657,118 @@ def gemm_4bit_impl( else: out = output return out + + +def dequantize_blockwise( + A: torch.Tensor, + absmax: Optional[torch.Tensor] = None, + code: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize: int = 4096, +) -> torch.Tensor: + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + + # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) + if out.dtype == torch.float16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) + elif out.dtype == torch.bfloat16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) + elif out.dtype == torch.float32: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + + +def quantize_blockwise( + A: torch.Tensor, + code: Optional[torch.Tensor] = None, + absmax: Optional[torch.Tensor] = None, + out: Optional[torch.Tensor] = None, + blocksize=4096, + nested=False, +) -> Tuple[torch.Tensor, QuantState]: + raise NotImplementedError + +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + optim_func = None + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + optim_func( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + skip_zeros, + g.numel() + ) + + +def optimizer_update_32bit( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + beta3: float = 0.0, + alpha: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, +) -> None: + raise NotImplementedError diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index 8ffe13b..cdba9f8 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -1,51 +1,377 @@ from collections.abc import Sequence +from typing import Optional 
import math import torch -from .cpu_xpu_common import int8_linear_matmul_impl +from bitsandbytes.utils import QuantState +from .cpu_xpu_common import ( + int8_linear_matmul_impl, + int8_double_quant_impl, + int8_vectorwise_quant_impl, + int8_mm_dequant_impl, + quantize_4bit_impl, + dequantize_4bit_impl, + gemm_4bit_impl, + dequantize_blockwise, + optimizer_update_8bit_blockwise, + ipex_xpu, + ipex_cpu_only, +) print("Loading ops module") -def register_ops(): +def register_xpu_ops(): print("Registering XPU implementations") - # Check if the operator exists - if not hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"): - raise RuntimeError("bitsandbytes::int8_linear_matmul not found! Make sure bitsandbytes is installed") - + # Register the int8_linear_matmul implementation @torch.library.impl("bitsandbytes::int8_linear_matmul", "XPU") def int8_linear_matmul_xpu(A: torch.Tensor, B: torch.Tensor): - print("int8_linear_matmul_xpu called with tensors of shape:", A.shape, B.shape) - return int8_linear_matmul_impl(A, B) - + return int8_linear_matmul_impl(A, B) @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "XPU") def int8_linear_matmul_xpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - print("int8_linear_matmul_xpu_out called with tensors of shape:", A.shape, B.shape) return int8_linear_matmul_impl(A, B, out) - - @torch.library.impl("bitsandbytes::dequantize_4bit.out", "XPU") + + # Register the int8_double_quant implementation + @torch.library.impl("bitsandbytes::int8_double_quant", "XPU") + def int8_double_quant_xpu( + A: torch.Tensor, + threshold: float = 0.0, + col_stats: torch.Tensor = None, + row_stats: torch.Tensor = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return int8_double_quant_impl(A, threshold, col_stats, row_stats) + @torch.library.impl("bitsandbytes::int8_double_quant.out", "XPU") + def int8_double_quant_xpu_out( + A: torch.Tensor, + threshold: float = 0.0, + col_stats: torch.Tensor = None, + row_stats: torch.Tensor = None, + out_col: torch.Tensor = None, + out_row: torch.Tensor = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return int8_double_quant_impl(A, threshold, col_stats, row_stats, out_col, out_row) + + # Register the int8_vectorwise_quant implementation + @torch.library.impl("bitsandbytes::int8_vectorwise_quant", "XPU") + def int8_vectorwise_quant_xpu( + A: torch.Tensor, + threshold: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return int8_vectorwise_quant_impl(A, threshold) + + # Register the int8_mm_dequant implementation + @torch.library.impl("bitsandbytes::int8_mm_dequant", "XPU") + def int8_mm_dequant_xpu( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + bias: torch.Tensor = None, + compute_dtype=torch.float32, + output_dtype=torch.float32, + ) -> torch.Tensor: + return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype) + @torch.library.impl("bitsandbytes::int8_mm_dequant.out", "XPU") + def int8_mm_dequant_xpu_out( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + bias: torch.Tensor = None, + compute_dtype = torch.float32, + output_dtype = torch.float32, + out: torch.Tensor = None, + ) -> torch.Tensor: + return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out) + + # Register the quantize_4bit implementation + @torch.library.impl("bitsandbytes::quantize_4bit", "XPU") + def quantize_4bit_xpu( + A: 
torch.Tensor, + absmax: torch.Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", + quant_storage=torch.uint8, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_4bit_impl( + A, + absmax, + blocksize, + compress_statistics, + quant_type, + quant_storage, + ) + @torch.library.impl("bitsandbytes::quantize_4bit.out", "XPU") + def quantize_4bit_xpu_out( + A: torch.Tensor, + absmax: torch.Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", + quant_storage=torch.uint8, + out: torch.Tensor = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_4bit_impl( + A, + absmax, + blocksize, + compress_statistics, + quant_type, + quant_storage, + out, + ) + + # Register the dequantize_4bit implementation + @torch.library.impl("bitsandbytes::dequantize_4bit", "XPU") def dequantize_4bit_xpu( A: torch.Tensor, - absmax: torch.Tensor, - blocksize: int, - quant_type: str, - shape: Sequence[int], - dtype: torch.dtype, - out: torch.Tensor, + quant_state = None, + absmax: torch.Tensor = None, + blocksize: int = 64, + quant_type = "nf4", ) -> torch.Tensor: - # TODO - # if quant_type == "nf4" and getattr(quant_state, "ipex", False): - # output = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t() - # else: - # output = dequantize_4bit_impl(A, quant_state, absmax, out, blocksize, quant_type) - - # return output - raise NotImplementedError + return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type) + @torch.library.impl("bitsandbytes::dequantize_4bit.out", "XPU") + def dequantize_4bit_xpu_out( + A: torch.Tensor, + quant_state = None, + absmax: torch.Tensor = None, + blocksize: int = 64, + quant_type = "nf4", + out: torch.Tensor = None, + ) -> torch.Tensor: + return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out) + + # Register the gemv_4bit implementation + @torch.library.impl("bitsandbytes::gemv_4bit", "XPU") + def gemv_4bit_xpu( + A: torch.Tensor, + B: torch.Tensor, + state: QuantState = None, + ) -> torch.Tensor: + return gemm_4bit_impl(A, B, state=state) + @torch.library.impl("bitsandbytes::gemv_4bit.out", "XPU") + def gemv_4bit_xpu_out( + A: torch.Tensor, + B: torch.Tensor, + state: QuantState = None, + out: torch.Tensor = None, + ) -> torch.Tensor: + return gemm_4bit_impl(A, B, state=state, out=out) + + # Register the dequantize_blockwise implementation + @torch.library.impl("bitsandbytes::dequantize_blockwise", "XPU") + def dequantize_blockwise_xpu( + A: torch.Tensor, + absmax: torch.Tensor = None, + code: torch.Tensor = None, + out: torch.Tensor = None, + blocksize: int = 4096, + ) -> torch.Tensor: + return dequantize_blockwise(A, absmax, code, out, blocksize) + + # Register the optimizer_update_8bit_blockwise implementation + @torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "XPU") + def optimizer_update_8bit_blockwise_xpu( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, + ) -> None: + optimizer_update_8bit_blockwise( + optimizer_name, + g, + p, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + 
gnorm_scale, + skip_zeros, + ) print("Successfully registered XPU implementation") + +def register_cpu_ops(): + print("Registering CPU implementations") + + # Register the int8_linear_matmul implementation + @torch.library.impl("bitsandbytes::int8_linear_matmul", "CPU") + def int8_linear_matmul_cpu(A: torch.Tensor, B: torch.Tensor): + return int8_linear_matmul_impl(A, B) + @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "CPU") + def int8_linear_matmul_cpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + return int8_linear_matmul_impl(A, B, out) + + # Register the int8_double_quant implementation + @torch.library.impl("bitsandbytes::int8_double_quant", "CPU") + def int8_double_quant_cpu( + A: torch.Tensor, + threshold: float = 0.0, + col_stats: torch.Tensor = None, + row_stats: torch.Tensor = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return int8_double_quant_impl(A, threshold, col_stats, row_stats) + @torch.library.impl("bitsandbytes::int8_double_quant.out", "CPU") + def int8_double_quant_cpu_out( + A: torch.Tensor, + threshold: float = 0.0, + col_stats: torch.Tensor = None, + row_stats: torch.Tensor = None, + out_col: torch.Tensor = None, + out_row: torch.Tensor = None, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + return int8_double_quant_impl(A, threshold, col_stats, row_stats, out_col, out_row) + + # Register the int8_vectorwise_quant implementation + @torch.library.impl("bitsandbytes::int8_vectorwise_quant", "CPU") + def int8_vectorwise_quant_cpu( + A: torch.Tensor, + threshold: float = 0.0, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + return int8_vectorwise_quant_impl(A, threshold) + + # Register the int8_mm_dequant implementation + @torch.library.impl("bitsandbytes::int8_mm_dequant", "CPU") + def int8_mm_dequant_cpu( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + bias: torch.Tensor = None, + compute_dtype=torch.float32, + output_dtype=torch.float32, + ) -> torch.Tensor: + return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype) + @torch.library.impl("bitsandbytes::int8_mm_dequant.out", "CPU") + def int8_mm_dequant_cpu_out( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + bias: torch.Tensor = None, + compute_dtype = torch.float32, + output_dtype = torch.float32, + out: torch.Tensor = None, + ) -> torch.Tensor: + return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out) + + # Register the quantize_4bit implementation + @torch.library.impl("bitsandbytes::quantize_4bit", "CPU") + def quantize_4bit_cpu( + A: torch.Tensor, + absmax: torch.Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", + quant_storage=torch.uint8, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_4bit_impl( + A, + absmax, + blocksize, + compress_statistics, + quant_type, + quant_storage, + ) + @torch.library.impl("bitsandbytes::quantize_4bit.out", "CPU") + def quantize_4bit_cpu_out( + A: torch.Tensor, + absmax: torch.Tensor = None, + blocksize=64, + compress_statistics=False, + quant_type="nf4", + quant_storage=torch.uint8, + out: torch.Tensor = None, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_4bit_impl( + A, + absmax, + blocksize, + compress_statistics, + quant_type, + quant_storage, + out, + ) + + # Register the dequantize_4bit implementation + @torch.library.impl("bitsandbytes::dequantize_4bit", "CPU") + def 
dequantize_4bit_cpu( + A: torch.Tensor, + quant_state = None, + absmax: torch.Tensor = None, + blocksize: int = 64, + quant_type = "nf4", + ) -> torch.Tensor: + return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type) + @torch.library.impl("bitsandbytes::dequantize_4bit.out", "CPU") + def dequantize_4bit_cpu_out( + A: torch.Tensor, + quant_state = None, + absmax: torch.Tensor = None, + blocksize: int = 64, + quant_type = "nf4", + out: torch.Tensor = None, + ) -> torch.Tensor: + return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out) + + # Register the gemv_4bit implementation + @torch.library.impl("bitsandbytes::gemv_4bit", "CPU") + def gemv_4bit_cpu( + A: torch.Tensor, + B: torch.Tensor, + state: QuantState = None, + ) -> torch.Tensor: + return gemm_4bit_impl(A, B, state=state) + @torch.library.impl("bitsandbytes::gemv_4bit.out", "CPU") + def gemv_4bit_cpu_out( + A: torch.Tensor, + B: torch.Tensor, + state: QuantState = None, + out: torch.Tensor = None, + ) -> torch.Tensor: + return gemm_4bit_impl(A, B, state=state, out=out) + + # Register the dequantize_blockwise implementation + @torch.library.impl("bitsandbytes::dequantize_blockwise", "CPU") + def dequantize_blockwise_cpu( + A: torch.Tensor, + absmax: torch.Tensor = None, + code: torch.Tensor = None, + out: torch.Tensor = None, + blocksize: int = 4096, + ) -> torch.Tensor: + return dequantize_blockwise(A, absmax, code, out, blocksize) + + print("Successfully registered CPU implementation") + + +def register_hpu_ops(): print("Registering HPU implementations") @torch.library.impl("bitsandbytes::dequantize_4bit", "HPU") @@ -77,4 +403,18 @@ def quantize_4bit_hpu( print("Successfully registered HPU implementations") +def register_ops(): + # Check if the operator exists + if not hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"): + raise RuntimeError("bitsandbytes::int8_linear_matmul not found! Make sure bitsandbytes is installed") + + if ipex_xpu: + register_xpu_ops() + elif ipex_cpu_only: + register_cpu_ops() + # TODO: Need to check HPU + else: + register_hpu_ops() + + print("ops module loaded") From ae7670334290c75029c6e76c353f2a449796fbeb Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 7 May 2025 09:20:21 +0000 Subject: [PATCH 03/22] add de/quantize blockwise op Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/cpu_xpu_common.py | 182 +++++++++++++---------- src/bitsandbytes_intel/ops.py | 130 +++++++++------- 2 files changed, 177 insertions(+), 135 deletions(-) diff --git a/src/bitsandbytes_intel/cpu_xpu_common.py b/src/bitsandbytes_intel/cpu_xpu_common.py index d9af7fa..7ca9c02 100644 --- a/src/bitsandbytes_intel/cpu_xpu_common.py +++ b/src/bitsandbytes_intel/cpu_xpu_common.py @@ -5,7 +5,6 @@ import torch import torch.nn.functional as F -from bitsandbytes.utils import QuantState from bitsandbytes.functional import ( QuantState, create_dynamic_map, @@ -353,18 +352,105 @@ def int8_mm_dequant_impl( 0.8333333: 3, # 0b0011 } -INT8_QUANT_TABLE = create_dynamic_map().tolist() +# INT8_QUANT_TABLE = create_dynamic_map().tolist() + + +def quantize_blockwise_impl( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize tensor A in blocks of 8-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to int8. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + code : torch.Tensor + The quantization code. + blocksize : int + The blocksize used in quantization. 
+ + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + torch.Tensor: + The absmax. + """ + n = A.numel() + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + + if out is None: + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + map = torch.tensor(code, device=scaled_A.device) + diff = torch.abs(scaled_A.unsqueeze(-1) - map) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) + + return out_uint8, absmax + + +def dequantize_blockwise_impl( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor = None, +) -> torch.Tensor: + assert A.dtype == torch.uint8 + out = code[A.reshape(-1).int()] + blocks = out.shape[-1] // blocksize + res = out.shape[-1] % blocksize + if res != 0: + out = F.pad(out, (0, blocksize - res), mode="constant", value=0) + out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) + out = out[: blocks * blocksize + res] + out = out.reshape(A.shape) + return out + + +# def dequant_8bit(A, offset, quant_state): +# assert A.dtype == torch.uint8 +# absmax = quant_state.code[A.reshape(-1).int()] +# blocks = absmax.shape[-1] // 256 +# res = absmax.shape[-1] % 256 +# if res != 0: +# absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0) +# absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1) +# absmax = absmax[: blocks * 256 + res] +# absmax = absmax.reshape(A.shape) +# absmax += offset +# return absmax def quantize_4bit_impl( A: Tensor, - absmax: Tensor = None, blocksize=64, - compress_statistics=False, quant_type="nf4", quant_storage=torch.uint8, - out: Tensor = None, -) -> Tensor: +) -> tuple[torch.Tensor, torch.Tensor]: """ Quantize tensor A in blocks of 4-bit values. @@ -374,10 +460,6 @@ def quantize_4bit_impl( ---------- A : torch.Tensor The input tensor. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - The output tensor (8-bit). blocksize : int The blocksize used in quantization. quant_type : str @@ -389,8 +471,8 @@ def quantize_4bit_impl( ------- torch.Tensor: The 8-bit tensor with packed 4-bit values. - tuple(torch.Tensor, torch.Size, torch.dtype, int): - The quantization state to undo the quantization. + torch.Tensor: + The absmax. """ if quant_type not in ["nf4", "fp4", "int8"]: raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.") @@ -398,12 +480,9 @@ def quantize_4bit_impl( warnings.warn("fp4 quantization is currently slow on CPU/XPU. 
Please Use nf4 instead for better performance.") assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] n = A.numel() - input_shape = A.shape blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 - - if absmax is None: - absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) if out is None: out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) @@ -433,64 +512,16 @@ def quantize_4bit_impl( for key, val in FP4_QUANT_TABLE.items(): out_uint8[abs_scaled_A > key] = val out_uint8 += sign.to(torch.uint8) * 8 - elif quant_type == "int8": - map = torch.tensor(INT8_QUANT_TABLE, device=scaled_A.device) - diff = torch.abs(scaled_A.unsqueeze(-1) - map) - out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) - - if quant_type == "int8": - out = out_uint8 - code = torch.Tensor(INT8_QUANT_TABLE).to(A.device) - else: - if out_uint8.size(-1) % 2: - out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) - out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2]) - code = get_4bit_type(quant_type, device=A.device) - - if compress_statistics: - offset = absmax.mean() - absmax -= offset - qabsmax, state2 = quantize_4bit_impl(absmax, blocksize=256, quant_type="int8") - del absmax - state = QuantState( - absmax=qabsmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - offset=offset, - state2=state2, - ) - else: - state = QuantState( - absmax=absmax, - shape=input_shape, - dtype=A.dtype, - blocksize=blocksize, - code=code, - quant_type=quant_type, - ) + + if out_uint8.size(-1) % 2: + out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) + out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2]) if quant_storage != torch.uint8: bytes_value = out.cpu().numpy().tobytes() out = torch.frombuffer(bytes_value, dtype=quant_storage).to(A.device) - return out.reshape(-1, 1), state - - -def dequant_8bit(A, offset, quant_state): - assert A.dtype == torch.uint8 - absmax = quant_state.code[A.reshape(-1).int()] - blocks = absmax.shape[-1] // 256 - res = absmax.shape[-1] % 256 - if res != 0: - absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0) - absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1) - absmax = absmax[: blocks * 256 + res] - absmax = absmax.reshape(A.shape) - absmax += offset - return absmax + return out.reshape(-1, 1), absmax # Compile will fail in torch.frombuffer @@ -558,9 +589,6 @@ def dequantize_4bit_impl( f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." 
) - if quant_state.nested: - absmax = dequant_8bit(absmax, quant_state.offset, quant_state.state2) - if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False): ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2) A = reverse_4bit_compress_format(ipex_weight) @@ -678,17 +706,7 @@ def dequantize_blockwise( ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - -def quantize_blockwise( - A: torch.Tensor, - code: Optional[torch.Tensor] = None, - absmax: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize=4096, - nested=False, -) -> Tuple[torch.Tensor, QuantState]: - raise NotImplementedError def optimizer_update_8bit_blockwise( optimizer_name: str, diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index cdba9f8..c64252d 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -4,14 +4,16 @@ import torch -from bitsandbytes.utils import QuantState from .cpu_xpu_common import ( + QuantState, int8_linear_matmul_impl, int8_double_quant_impl, int8_vectorwise_quant_impl, int8_mm_dequant_impl, quantize_4bit_impl, dequantize_4bit_impl, + quantize_blockwise_impl, + dequantize_blockwise_impl, gemm_4bit_impl, dequantize_blockwise, optimizer_update_8bit_blockwise, @@ -32,7 +34,7 @@ def int8_linear_matmul_xpu(A: torch.Tensor, B: torch.Tensor): @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "XPU") def int8_linear_matmul_xpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): return int8_linear_matmul_impl(A, B, out) - + # Register the int8_double_quant implementation @torch.library.impl("bitsandbytes::int8_double_quant", "XPU") def int8_double_quant_xpu( @@ -52,7 +54,7 @@ def int8_double_quant_xpu_out( out_row: torch.Tensor = None, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: return int8_double_quant_impl(A, threshold, col_stats, row_stats, out_col, out_row) - + # Register the int8_vectorwise_quant implementation @torch.library.impl("bitsandbytes::int8_vectorwise_quant", "XPU") def int8_vectorwise_quant_xpu( @@ -60,7 +62,7 @@ def int8_vectorwise_quant_xpu( threshold: float = 0.0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return int8_vectorwise_quant_impl(A, threshold) - + # Register the int8_mm_dequant implementation @torch.library.impl("bitsandbytes::int8_mm_dequant", "XPU") def int8_mm_dequant_xpu( @@ -83,7 +85,7 @@ def int8_mm_dequant_xpu_out( out: torch.Tensor = None, ) -> torch.Tensor: return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out) - + # Register the quantize_4bit implementation @torch.library.impl("bitsandbytes::quantize_4bit", "XPU") def quantize_4bit_xpu( @@ -102,26 +104,7 @@ def quantize_4bit_xpu( quant_type, quant_storage, ) - @torch.library.impl("bitsandbytes::quantize_4bit.out", "XPU") - def quantize_4bit_xpu_out( - A: torch.Tensor, - absmax: torch.Tensor = None, - blocksize=64, - compress_statistics=False, - quant_type="nf4", - quant_storage=torch.uint8, - out: torch.Tensor = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - return quantize_4bit_impl( - A, - absmax, - blocksize, - compress_statistics, - quant_type, - quant_storage, - out, - ) - + # Register the dequantize_4bit implementation @torch.library.impl("bitsandbytes::dequantize_4bit", "XPU") def dequantize_4bit_xpu( @@ -142,7 +125,37 
@@ def dequantize_4bit_xpu_out( out: torch.Tensor = None, ) -> torch.Tensor: return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out) - + + # Register the quantize_blockwise implementation + @torch.library.impl("bitsandbytes::quantize_blockwise", "XPU") + def quantize_blockwise_xpu( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_blockwise_impl(A, code, blocksize) + + # Register the dequantize_blockwise implementation + @torch.library.impl("bitsandbytes::dequantize_blockwise", "XPU") + def dequantize_blockwise_xpu( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + ) -> torch.Tensor: + return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype) + @torch.library.impl("bitsandbytes::dequantize_blockwise.out", "XPU") + def dequantize_blockwise_xpu_out( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, + ) -> torch.Tensor: + return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out) + # Register the gemv_4bit implementation @torch.library.impl("bitsandbytes::gemv_4bit", "XPU") def gemv_4bit_xpu( @@ -159,7 +172,7 @@ def gemv_4bit_xpu_out( out: torch.Tensor = None, ) -> torch.Tensor: return gemm_4bit_impl(A, B, state=state, out=out) - + # Register the dequantize_blockwise implementation @torch.library.impl("bitsandbytes::dequantize_blockwise", "XPU") def dequantize_blockwise_xpu( @@ -170,7 +183,7 @@ def dequantize_blockwise_xpu( blocksize: int = 4096, ) -> torch.Tensor: return dequantize_blockwise(A, absmax, code, out, blocksize) - + # Register the optimizer_update_8bit_blockwise implementation @torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "XPU") def optimizer_update_8bit_blockwise_xpu( @@ -229,7 +242,7 @@ def int8_linear_matmul_cpu(A: torch.Tensor, B: torch.Tensor): @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "CPU") def int8_linear_matmul_cpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): return int8_linear_matmul_impl(A, B, out) - + # Register the int8_double_quant implementation @torch.library.impl("bitsandbytes::int8_double_quant", "CPU") def int8_double_quant_cpu( @@ -257,7 +270,7 @@ def int8_vectorwise_quant_cpu( threshold: float = 0.0, ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: return int8_vectorwise_quant_impl(A, threshold) - + # Register the int8_mm_dequant implementation @torch.library.impl("bitsandbytes::int8_mm_dequant", "CPU") def int8_mm_dequant_cpu( @@ -280,7 +293,7 @@ def int8_mm_dequant_cpu_out( out: torch.Tensor = None, ) -> torch.Tensor: return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out) - + # Register the quantize_4bit implementation @torch.library.impl("bitsandbytes::quantize_4bit", "CPU") def quantize_4bit_cpu( @@ -299,26 +312,7 @@ def quantize_4bit_cpu( quant_type, quant_storage, ) - @torch.library.impl("bitsandbytes::quantize_4bit.out", "CPU") - def quantize_4bit_cpu_out( - A: torch.Tensor, - absmax: torch.Tensor = None, - blocksize=64, - compress_statistics=False, - quant_type="nf4", - quant_storage=torch.uint8, - out: torch.Tensor = None, - ) -> tuple[torch.Tensor, torch.Tensor]: - return quantize_4bit_impl( - A, - absmax, - blocksize, - compress_statistics, - quant_type, - quant_storage, - out, - ) - + # Register the dequantize_4bit implementation @torch.library.impl("bitsandbytes::dequantize_4bit", "CPU") def 
dequantize_4bit_cpu( @@ -339,7 +333,37 @@ def dequantize_4bit_cpu_out( out: torch.Tensor = None, ) -> torch.Tensor: return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out) - + + # Register the quantize_blockwise implementation + @torch.library.impl("bitsandbytes::quantize_blockwise", "XPU") + def quantize_blockwise_xpu( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, + ) -> tuple[torch.Tensor, torch.Tensor]: + return quantize_blockwise_impl(A, code, blocksize) + + # Register the dequantize_blockwise implementation + @torch.library.impl("bitsandbytes::dequantize_blockwise", "CPU") + def dequantize_blockwise_cpu( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + ) -> torch.Tensor: + return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype) + @torch.library.impl("bitsandbytes::dequantize_blockwise.out", "CPU") + def dequantize_blockwise_cpu_out( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor, + ) -> torch.Tensor: + return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out) + # Register the gemv_4bit implementation @torch.library.impl("bitsandbytes::gemv_4bit", "CPU") def gemv_4bit_cpu( @@ -356,7 +380,7 @@ def gemv_4bit_cpu_out( out: torch.Tensor = None, ) -> torch.Tensor: return gemm_4bit_impl(A, B, state=state, out=out) - + # Register the dequantize_blockwise implementation @torch.library.impl("bitsandbytes::dequantize_blockwise", "CPU") def dequantize_blockwise_cpu( From 803b429ef2ea3bf67f4f6e862461a42c8163b46e Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 7 May 2025 09:44:28 +0000 Subject: [PATCH 04/22] fix dequantize blockwise register Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/cpu_xpu_common.py | 18 ++++++----- src/bitsandbytes_intel/ops.py | 38 +++++++----------------- 2 files changed, 21 insertions(+), 35 deletions(-) diff --git a/src/bitsandbytes_intel/cpu_xpu_common.py b/src/bitsandbytes_intel/cpu_xpu_common.py index 7ca9c02..4c04e1c 100644 --- a/src/bitsandbytes_intel/cpu_xpu_common.py +++ b/src/bitsandbytes_intel/cpu_xpu_common.py @@ -687,22 +687,24 @@ def gemm_4bit_impl( return out -def dequantize_blockwise( +# Currently only works for XPU +def dequantize_blockwise_ipex_impl( A: torch.Tensor, - absmax: Optional[torch.Tensor] = None, - code: Optional[torch.Tensor] = None, - out: Optional[torch.Tensor] = None, - blocksize: int = 4096, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, + out: torch.Tensor = None, ) -> torch.Tensor: if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) - if out.dtype == torch.float16: + if dtype == torch.float16: ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) - elif out.dtype == torch.bfloat16: + elif dtype == torch.bfloat16: ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) - elif out.dtype == torch.float32: + elif dtype == torch.float32: ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") diff --git 
a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index c64252d..f40794d 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -15,10 +15,11 @@ quantize_blockwise_impl, dequantize_blockwise_impl, gemm_4bit_impl, - dequantize_blockwise, + dequantize_blockwise_ipex_impl, optimizer_update_8bit_blockwise, ipex_xpu, ipex_cpu_only, + _ipex_xpu_version_prereq, ) print("Loading ops module") @@ -136,6 +137,11 @@ def quantize_blockwise_xpu( return quantize_blockwise_impl(A, code, blocksize) # Register the dequantize_blockwise implementation + if _ipex_xpu_version_prereq(2, 7): + dequantize_blockwise = dequantize_blockwise_ipex_impl + else: + dequantize_blockwise = dequantize_blockwise_impl + @torch.library.impl("bitsandbytes::dequantize_blockwise", "XPU") def dequantize_blockwise_xpu( A: torch.Tensor, @@ -144,7 +150,7 @@ def dequantize_blockwise_xpu( blocksize: int, dtype: torch.dtype, ) -> torch.Tensor: - return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype) + return dequantize_blockwise(A, absmax, code, blocksize, dtype) @torch.library.impl("bitsandbytes::dequantize_blockwise.out", "XPU") def dequantize_blockwise_xpu_out( A: torch.Tensor, @@ -154,7 +160,7 @@ def dequantize_blockwise_xpu_out( dtype: torch.dtype, out: torch.Tensor, ) -> torch.Tensor: - return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out) + return dequantize_blockwise(A, absmax, code, blocksize, dtype, out) # Register the gemv_4bit implementation @torch.library.impl("bitsandbytes::gemv_4bit", "XPU") @@ -173,17 +179,6 @@ def gemv_4bit_xpu_out( ) -> torch.Tensor: return gemm_4bit_impl(A, B, state=state, out=out) - # Register the dequantize_blockwise implementation - @torch.library.impl("bitsandbytes::dequantize_blockwise", "XPU") - def dequantize_blockwise_xpu( - A: torch.Tensor, - absmax: torch.Tensor = None, - code: torch.Tensor = None, - out: torch.Tensor = None, - blocksize: int = 4096, - ) -> torch.Tensor: - return dequantize_blockwise(A, absmax, code, out, blocksize) - # Register the optimizer_update_8bit_blockwise implementation @torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "XPU") def optimizer_update_8bit_blockwise_xpu( @@ -335,8 +330,8 @@ def dequantize_4bit_cpu_out( return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out) # Register the quantize_blockwise implementation - @torch.library.impl("bitsandbytes::quantize_blockwise", "XPU") - def quantize_blockwise_xpu( + @torch.library.impl("bitsandbytes::quantize_blockwise", "CPU") + def quantize_blockwise_cpu( A: torch.Tensor, code: torch.Tensor, blocksize: int, @@ -381,17 +376,6 @@ def gemv_4bit_cpu_out( ) -> torch.Tensor: return gemm_4bit_impl(A, B, state=state, out=out) - # Register the dequantize_blockwise implementation - @torch.library.impl("bitsandbytes::dequantize_blockwise", "CPU") - def dequantize_blockwise_cpu( - A: torch.Tensor, - absmax: torch.Tensor = None, - code: torch.Tensor = None, - out: torch.Tensor = None, - blocksize: int = 4096, - ) -> torch.Tensor: - return dequantize_blockwise(A, absmax, code, out, blocksize) - print("Successfully registered CPU implementation") From 0bc2a21b3fba7cc60eaad991c7c9cb6e86320ebf Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 7 May 2025 09:50:18 +0000 Subject: [PATCH 05/22] fix quantize 4bit Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/ops.py | 8 -------- 1 file changed, 8 deletions(-) diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index f40794d..5d78881 100644 --- 
a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -91,17 +91,13 @@ def int8_mm_dequant_xpu_out( @torch.library.impl("bitsandbytes::quantize_4bit", "XPU") def quantize_4bit_xpu( A: torch.Tensor, - absmax: torch.Tensor = None, blocksize=64, - compress_statistics=False, quant_type="nf4", quant_storage=torch.uint8, ) -> tuple[torch.Tensor, torch.Tensor]: return quantize_4bit_impl( A, - absmax, blocksize, - compress_statistics, quant_type, quant_storage, ) @@ -293,17 +289,13 @@ def int8_mm_dequant_cpu_out( @torch.library.impl("bitsandbytes::quantize_4bit", "CPU") def quantize_4bit_cpu( A: torch.Tensor, - absmax: torch.Tensor = None, blocksize=64, - compress_statistics=False, quant_type="nf4", quant_storage=torch.uint8, ) -> tuple[torch.Tensor, torch.Tensor]: return quantize_4bit_impl( A, - absmax, blocksize, - compress_statistics, quant_type, quant_storage, ) From bb22037c5e022836bc0e7d99c20317281ab008ca Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Wed, 7 May 2025 09:52:37 +0000 Subject: [PATCH 06/22] fix quantize 4bit out Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/cpu_xpu_common.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/bitsandbytes_intel/cpu_xpu_common.py b/src/bitsandbytes_intel/cpu_xpu_common.py index 4c04e1c..a91b17e 100644 --- a/src/bitsandbytes_intel/cpu_xpu_common.py +++ b/src/bitsandbytes_intel/cpu_xpu_common.py @@ -483,9 +483,7 @@ def quantize_4bit_impl( blocks = n // blocksize blocks += 1 if n % blocksize > 0 else 0 absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) - - if out is None: - out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) + out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) rem = n % blocksize has_rem = rem > 0 From 8b02dcf7e6ecf92c7a78ff501f1f160d6fbd7021 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 12:39:25 +0000 Subject: [PATCH 07/22] fix xpu ops and precommit Signed-off-by: jiqing-feng --- .github/workflows/package.yml | 16 +- docker/Dockerfile | 2 +- docker/entrypoint.sh | 2 +- src/bitsandbytes_intel/cpu_xpu_common.py | 792 ----------------------- src/bitsandbytes_intel/ops.py | 290 ++------- 5 files changed, 46 insertions(+), 1056 deletions(-) delete mode 100644 src/bitsandbytes_intel/cpu_xpu_common.py diff --git a/.github/workflows/package.yml b/.github/workflows/package.yml index e5b235a..389192c 100644 --- a/.github/workflows/package.yml +++ b/.github/workflows/package.yml @@ -53,42 +53,42 @@ jobs: # - "3.13" steps: - uses: actions/checkout@v4 - + - name: Checkout bitsandbytes uses: actions/checkout@v4 with: repository: bitsandbytes-foundation/bitsandbytes path: bitsandbytes - + - name: Set up Python ${{ matrix.python-version }} uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} cache: pip - + - name: Install build tools run: | sudo apt-get update sudo apt-get install -y build-essential cmake - + - name: Compile bitsandbytes with CPU backend run: | cd bitsandbytes cmake -DCOMPUTE_BACKEND=cpu -S . && make cd .. 
- + - name: Download build artifacts uses: actions/download-artifact@v4 with: name: dist path: dist/ - + - name: Install dependencies and built package run: | python -m pip install --upgrade pip pip install ./bitsandbytes pip install dist/*.whl - + - name: Test import works run: | python -c " @@ -98,7 +98,7 @@ jobs: print('✅ bitsandbytes_intel import successful') print('✅ All imports successful - no XPU operations tested, as for that we would need to configure the XPU runner..') " - + # - name: Test with pytest # run: pytest diff --git a/docker/Dockerfile b/docker/Dockerfile index 4e3703e..ec8e4b0 100644 --- a/docker/Dockerfile +++ b/docker/Dockerfile @@ -44,4 +44,4 @@ RUN cd ${WORKSPACE}/src/bnb && pip install . COPY --chmod=755 docker/entrypoint.sh /entrypoint.sh ENTRYPOINT ["/entrypoint.sh"] -CMD ["sleep", "infinity"] \ No newline at end of file +CMD ["sleep", "infinity"] diff --git a/docker/entrypoint.sh b/docker/entrypoint.sh index 8979d69..5c33986 100644 --- a/docker/entrypoint.sh +++ b/docker/entrypoint.sh @@ -4,4 +4,4 @@ set -euo pipefail pip install --no-deps -e /workspace/src/bnb pip install --no-deps -e /workspace/src/bnb_intel -exec "$@" \ No newline at end of file +exec "$@" diff --git a/src/bitsandbytes_intel/cpu_xpu_common.py b/src/bitsandbytes_intel/cpu_xpu_common.py deleted file mode 100644 index a91b17e..0000000 --- a/src/bitsandbytes_intel/cpu_xpu_common.py +++ /dev/null @@ -1,792 +0,0 @@ -import subprocess -from typing import Optional, Tuple -import warnings - -import torch -import torch.nn.functional as F - -from bitsandbytes.functional import ( - QuantState, - create_dynamic_map, - get_4bit_type, -) - -try: - # to support Intel CPU/GPU (XPU) backend - import intel_extension_for_pytorch as ipex - - ipex_cpu = ipex if ipex._C._has_cpu() else None - ipex_xpu = ipex if ipex._C._has_xpu() else None - ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu()) -except BaseException: - ipex_cpu = None - ipex_xpu = None - ipex_cpu_only = None - - -gxx_available = False -try: - subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output - gxx_available = True -except BaseException: - warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.") - - -Tensor = torch.Tensor - - -def _torch_version_prereq(major, minor): - ver_major = int(torch.__version__.split(".")[0]) - ver_minor = int(torch.__version__.split(".")[1]) - return ver_major * 32 + ver_minor >= major * 32 + minor - - -def _ipex_cpu_version_prereq(major, minor): - if ipex_cpu is not None: - ver_major = ipex_cpu.__version__.split(".")[0] - ver_minor = ipex_cpu.__version__.split(".")[1] - return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor - return False - - -def _ipex_xpu_version_prereq(major, minor): - if ipex_xpu is not None: - ver_major = ipex_xpu.__version__.split(".")[0] - ver_minor = ipex_xpu.__version__.split(".")[1] - return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor - return False - - -str2optimizer8bit_blockwise = {} -if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7): - str2optimizer8bit_blockwise = { - "adam": ( - ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32, - ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16, - ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16, - ), - } - - -def _maybe_torch_compile(func): - # torch.compile requires g++ and pytorch >= 2.0 - if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu: - options = {} - # fx_graph_cache requires pytorch >= 2.2 - if _torch_version_prereq(2, 2): - 
options.update({"fx_graph_cache": True}) - return torch.compile(func, dynamic=True, options=options) - return func - - -def reverse_4bit_compress_format(weight): - out_1 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device) - out_2 = torch.empty(weight.size(0), dtype=torch.int32, device=weight.device) - out_1 = (weight & 0xF0) >> 4 - out_2 = (weight & 0xF) << 4 - out = out_1 | out_2 - return out - - -def transform( - A: torch.Tensor, - out: Optional[torch.Tensor] = None, - transpose=False, - state: Optional[Tuple[torch.Size, str]] = None, - ): - """ - Transform tensor A to to_order. It is originally designed for CUDA. - For CPU/XPU, it returns the original tensor if transpose=False. - Otherwise, it returns the transpose of A - """ - if transpose: - if out is not None: - out.copy_(A.T) - else: - out = A.T - else: - if out is not None: - out.copy_(A) - else: - out = A - return out, state - - -@_maybe_torch_compile -def int8_double_quant_impl(A, threshold=0.0, col_stats=None, row_stats=None, out_col=None, out_row=None): - """ - Find absolute max values of each row/column of a tensor, and symmetrically quantize it to int8. - If threshold > 0.0, only values <= threshold are counted. All outliers are zeroed out in - the original tensor and they are kept in COO format: (rows, cols, values) - If threshold == 0.0, there are no outliers. - Args: - A The tensor to be analyzed and quantized. - col_stats Absolute max values of each column of A. If it is not None, use the values directly. - Otherwise, find the values. - row_stats Absolute max values of each row of A. If it is not None, use the values directly. - Otherwise, find the values. - out_col Output buffer for the result quantized per column if it is not None - out_row Output buffer for the result quantized per row if it is not None - threshold The threshold for finding outliers if it is > 0.0. Otherwise it has no effect. 
- Return: - A tuple of output quantized per row, output quantized per column, absolute max values of - each row of A, absolute max values of each column of A, outliers in COO format - """ - - cols = A.shape[-1] - if len(A.shape) == 3: - rows = A.shape[0] * A.shape[1] - else: - assert A.dim() == 2, f"double_quant: Input tensor should be 2d or 3d but got {A.dim()}d" - rows = A.shape[0] - A = A.reshape(rows, cols) - - def get_row_col_stats(A): - row_stats = torch.max(torch.abs(A), 1).values # absolute max of each row - col_stats = torch.max(torch.abs(A), 0).values # absolute max of each col - return row_stats, col_stats - - def quant_to_int8(A, stats): - return torch.clamp(torch.round(A * (127.0 / stats)), -128, 127).to(torch.int8) - - if threshold == 0.0: - if row_stats is None or col_stats is None: - row_stats, col_stats = get_row_col_stats(A) - outlier_cols = None - else: - outlier_indices = torch.abs(A) >= threshold # find outliers - outlier_cols = torch.argwhere(outlier_indices.any(dim=0)).view(-1) - outlier_values = A[outlier_indices].clone() - - # outlier_indices = torch.abs(A) >= threshold # find outliers - # outlier_coord = outlier_indices.nonzero() # get outlier coordinates - # outlier_rows = outlier_coord[:, 0] # outlier row for COO sparse tensor - # outlier_cols = outlier_coord[:, 1] # outlier column for COO sparse tensor - # outlier_values = A[outlier_indices] # outlier values for COO sparse tensor - # coo_tensor = COOSparseTensor( - # A.shape[0], A.shape[1], outlier_values.numel(), outlier_rows.int(), outlier_cols.int(), outlier_values - # ) - if row_stats is None or col_stats is None: - A[outlier_indices] = 0 # zero out outliers - row_stats, col_stats = get_row_col_stats(A) - - quant_by_row = quant_to_int8(A, row_stats.unsqueeze(-1)) - quant_by_col = quant_to_int8(A, col_stats.unsqueeze(0)) - - if outlier_cols is not None: - A[outlier_indices] = outlier_values # restore outliers for later use - - if rows > 1: - # zero out outlier columns for all rows - quant_by_row[:, outlier_cols] = 0 - - if out_row is not None: - out_row.copy_(quant_by_row) - else: - out_row = quant_by_row - if out_col is not None: - out_col.copy_(quant_by_col) - else: - out_col = quant_by_col - # Return float stats to align with CUDA impl - return out_row, out_col, row_stats.float(), col_stats.float(), outlier_cols - - -def int8_vectorwise_quant_impl(A: torch.Tensor, threshold=0.0): - # TODO: We can optimize this as we don't actually need column-wise quant. - out, _, stats, _, outlier_cols = int8_double_quant_impl(A, threshold=threshold) - return out, stats, outlier_cols - - -def int8_vectorwise_dequant_impl(A: torch.Tensor, stats: torch.Tensor): - """Dequantizes a tensor with dtype `torch.int8` to `torch.float32`. - - Args: - A (`torch.Tensor` with dtype `torch.int8`): The quantized int8 tensor. - stats (`torch.Tensor` with dtype `torch.float32`): The row-wise quantization statistics. - - Returns: - `torch.Tensor` with dtype `torch.float32`: The dequantized tensor. - """ - # To dequantize we divide by 127, or multiply by the reciprocal. - return A * stats.view(-1, 1) * 7.874015718698502e-3 - - -def int8_linear_matmul_impl( - A: torch.Tensor, - B: torch.Tensor, - out: Optional[torch.Tensor] = None, - dtype=torch.int32, -) -> torch.Tensor: - """ - Do GEMMM computation. Data type: int8 * int8 -> int32. 
- Args: - A Activation of linear, data type is int8 - B Weight of linear, data type is int8 - out Specified output tensor if it is not None - dtype Data type of output - Return: - A tuple of GEMM result in dtype and Sout - """ - - assert A.dtype == torch.int8 - assert B.dtype == torch.int8 - if out is not None: - assert out.dtype == dtype - - dimsA = A.ndim - dimsB = B.ndim - shapeA = A.shape - shapeB = B.shape - assert dimsA in [2, 3], "Only two or three dimensional matrices are supported for argument A" - assert dimsB == 2, "Only two dimensional matrices are supported for argument B" - - if dimsA == 2: - m = shapeA[0] - elif dimsA == 3: - m = shapeA[0] * shapeA[1] - n = shapeB[0] - k = shapeA[-1] - assert shapeA[-1] == shapeB[-1], f"Shapes of A and B do not match, got {shapeA} and {shapeB}" - - # if the tensor is empty, return a transformed empty tensor with the right dimensions - if shapeA[0] == 0 and dimsA == 2: - return torch.empty((0, n), device=A.device, dtype=A.dtype) - elif shapeA[1] == 0 and dimsA == 3: - return torch.empty(tuple(shapeA[:2] + [n]), device=A.device, dtype=A.dtype) - - A_reshaped = A.reshape(m, k) - - # torch._int_mm is available on CPU since torch 2.4, XPU since torch 2.6 - if ( - A.device.type == "cpu" and _torch_version_prereq(2, 4) - # or (A.device.type == "xpu" and _torch_version_prereq(2, 6) - ): - C = torch._int_mm(A_reshaped, B.T).to(dtype) - else: - C = torch.matmul(A_reshaped.float(), B.t().float()).to(dtype) - if C.ndim != dimsA: - assert dimsA == 3 - shapeOut = (shapeA[0], m // shapeA[0], C.shape[-1]) - C = C.reshape(shapeOut) - if out is not None: - out.copy_(C) - else: - out = C - - return out - - -@_maybe_torch_compile -def int8_mm_dequant_impl( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - bias: Optional[torch.Tensor] = None, - compute_dtype=torch.float32, - output_dtype=torch.float32, - out: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Dequant and add bias - out = A_int32 * (abs_max_A * abs_max_B) / 127 * 127 + bias - Args: - A The output of int8 gemm, whose dtype is int32 - row_stats Absolute max value of each row of input (A) of gemm - col_stats Absolute max value of each row of weight (B) of gemm - out Output buffer - bias Bias of linear - compute_dtype Data type for computation - output_dtype Data type for output - Return: - The result - """ - assert A.dtype == torch.int32 - out_shape = A.shape - if len(out_shape) == 3: - out_shape = (out_shape[0] * out_shape[1], out_shape[2]) - - if compute_dtype not in [torch.float32, torch.bfloat16]: - warnings.warn( - f"mm_dequant_{A.device}: compute_dtype {compute_dtype} is not supported, will use bfloat16 instead" - ) - compute_dtype = torch.bfloat16 - A_reshaped = A.reshape(out_shape).to(compute_dtype) - row_stats = row_stats.reshape(-1).unsqueeze(-1).to(compute_dtype) - col_stats = col_stats.reshape(-1).unsqueeze(0).to(compute_dtype) - out = A_reshaped * row_stats * col_stats / (127 * 127) - if bias is not None: - out = out + bias.to(compute_dtype) - out = out.to(output_dtype) - return out - - -NF4_QUANT_TABLE = [ - -1.0 - 1e-2, # 0b0000 - -0.8480964004993439, # 0b0001 - -0.6106329262256622, # 0b0010 - -0.4599952697753906, # 0b0011 - -0.33967943489551544, # 0b0100 - -0.23460740596055984, # 0b0101 - -0.13791173323988914, # 0b0110 - -0.045525018125772476, # 0b0111 - 0.03979014977812767, # 0b1000 - 0.1202552504837513, # 0b1001 - 0.2035212516784668, # 0b1010 - 0.2920137718319893, # 0b1011 - 0.3893125355243683, # 0b1100 - 0.5016634166240692, # 0b1101 - 
0.6427869200706482, # 0b1110 - 0.8614784181118011, # 0b1111 -] - - -FP4_QUANT_TABLE = { - 0 - 1e-2: 0, # 0b0000 - 0.00260417: 1, # 0b0001 - 0.0859375: 6, # 0b0110 - 0.20833333: 7, # 0b0111 - 0.29166667: 4, # 0b0100 - 0.4166667: 5, # 0b0101 - 0.583333: 2, # 0b0010 - 0.8333333: 3, # 0b0011 -} - -# INT8_QUANT_TABLE = create_dynamic_map().tolist() - - -def quantize_blockwise_impl( - A: torch.Tensor, - code: torch.Tensor, - blocksize: int, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantize tensor A in blocks of 8-bit values. - - Quantizes tensor A by dividing it into blocks which are independently quantized to int8. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - code : torch.Tensor - The quantization code. - blocksize : int - The blocksize used in quantization. - - Returns - ------- - torch.Tensor: - The 8-bit tensor with packed 4-bit values. - torch.Tensor: - The absmax. - """ - n = A.numel() - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) - - if out is None: - out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) - - rem = n % blocksize - has_rem = rem > 0 - - # Scale tensor to [-1, 1] - A_reshaped = A.reshape(n) - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - - map = torch.tensor(code, device=scaled_A.device) - diff = torch.abs(scaled_A.unsqueeze(-1) - map) - out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) - - return out_uint8, absmax - - -def dequantize_blockwise_impl( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor = None, -) -> torch.Tensor: - assert A.dtype == torch.uint8 - out = code[A.reshape(-1).int()] - blocks = out.shape[-1] // blocksize - res = out.shape[-1] % blocksize - if res != 0: - out = F.pad(out, (0, blocksize - res), mode="constant", value=0) - out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) - out = out[: blocks * blocksize + res] - out = out.reshape(A.shape) - return out - - -# def dequant_8bit(A, offset, quant_state): -# assert A.dtype == torch.uint8 -# absmax = quant_state.code[A.reshape(-1).int()] -# blocks = absmax.shape[-1] // 256 -# res = absmax.shape[-1] % 256 -# if res != 0: -# absmax = F.pad(absmax, (0, 256 - res), mode="constant", value=0) -# absmax = (absmax.view(-1, 256) * quant_state.absmax.view(-1, 1)).to(quant_state.dtype).reshape(-1) -# absmax = absmax[: blocks * 256 + res] -# absmax = absmax.reshape(A.shape) -# absmax += offset -# return absmax - - -def quantize_4bit_impl( - A: Tensor, - blocksize=64, - quant_type="nf4", - quant_storage=torch.uint8, -) -> tuple[torch.Tensor, torch.Tensor]: - """ - Quantize tensor A in blocks of 4-bit values. - - Quantizes tensor A by dividing it into blocks which are independently quantized to FP4. - - Parameters - ---------- - A : torch.Tensor - The input tensor. - blocksize : int - The blocksize used in quantization. 
- quant_type : str - The 4-bit quantization data type {fp4, nf4}, only nf4 is supported now - quant_storage: torch.dtype - We can use bytes to convert storage type. - - Returns - ------- - torch.Tensor: - The 8-bit tensor with packed 4-bit values. - torch.Tensor: - The absmax. - """ - if quant_type not in ["nf4", "fp4", "int8"]: - raise NotImplementedError(f"4-bit quantization data type {quant_type} is not implemented for CPU/XPU.") - if quant_type == "fp4": - warnings.warn("fp4 quantization is currently slow on CPU/XPU. Please Use nf4 instead for better performance.") - assert blocksize in [4096, 2048, 1024, 512, 256, 128, 64] - n = A.numel() - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) - out = torch.zeros(((n + 1) // 2), dtype=torch.uint8, device=A.device) - - rem = n % blocksize - has_rem = rem > 0 - - # Scale tensor to [-1, 1] - A_reshaped = A.reshape(n) - A_com = A_reshaped[: n - rem] - A_com_reshaped = A_com.reshape(n // blocksize, blocksize) - absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled_A = scaled_A.reshape(-1) - if has_rem: - absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() - scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) - scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) - # map [-1, 1] to nf4/fp4 - out_uint8 = torch.empty(scaled_A.shape, dtype=torch.uint8, device=A.device) - if quant_type == "nf4": - for i in range(len(NF4_QUANT_TABLE)): - out_uint8[scaled_A > NF4_QUANT_TABLE[i]] = i - elif quant_type == "fp4": - sign = scaled_A < 0 - abs_scaled_A = torch.abs(scaled_A) - for key, val in FP4_QUANT_TABLE.items(): - out_uint8[abs_scaled_A > key] = val - out_uint8 += sign.to(torch.uint8) * 8 - - if out_uint8.size(-1) % 2: - out_uint8 = torch.nn.functional.pad(out_uint8, (0, 1), value=0) - out[:] = out_uint8[::2].bitwise_left_shift(4).bitwise_or_(out_uint8[1::2]) - - if quant_storage != torch.uint8: - bytes_value = out.cpu().numpy().tobytes() - out = torch.frombuffer(bytes_value, dtype=quant_storage).to(A.device) - - return out.reshape(-1, 1), absmax - - -# Compile will fail in torch.frombuffer -# @_maybe_torch_compile -def dequantize_4bit_impl( - A: Tensor, - quant_state=None, - absmax: Tensor = None, - blocksize: int = 64, - quant_type="nf4", - out: Tensor = None, -) -> Tensor: - """ - Dequantizes 4-bit blockwise quantized values. - Dequantizes the tensor A with maximum absolute values absmax in blocks of size blocksize. - - Parameters - ---------- - A : torch.Tensor - The input 8-bit tensor (packed 4-bit values). - quant_state : QuantState - object with quantisation stats, incl. absmax values, original tensor shape and original dtype. - absmax : torch.Tensor - The absmax values. - out : torch.Tensor - Dequantized output tensor. - blocksize : int - The blocksize used in quantization. - quant_type : str - The 4-bit quantization data type {fp4, nf4} - - Returns - ------- - torch.Tensor: - Dequantized tensor. - """ - # For NF4, ipex have dequant kernel. 
- if quant_type == "nf4" and getattr(quant_state, "ipex", False): - out = torch.ops.torch_ipex.dequantize_4bit(A, "nf4", quant_state.shape, absmax, None, blocksize).t() - return out - - transpose = True if A.shape[0] == 1 else False - A = A.reshape(-1) - device = A.device - if A.dtype != torch.uint8: - bytes_value = A.cpu().numpy().tobytes() - A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device) - - if quant_state is None: - assert absmax is not None and out is not None - - quant_state = QuantState( - absmax=absmax, - shape=out.shape, - dtype=out.dtype, - blocksize=blocksize, - quant_type=quant_type, - ) - - else: - absmax = quant_state.absmax - - if quant_type not in ["nf4", "fp4"]: - raise NotImplementedError( - f"4-bit quantization data type {quant_state.quant_type} is not implemented for CPU/XPU." - ) - - if ipex_cpu_only and _ipex_cpu_version_prereq(2, 5) and getattr(quant_state, "ipex", False): - ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", quant_state.shape, 2) - A = reverse_4bit_compress_format(ipex_weight) - quant_state.ipex = False - - # Map nf4 to [-1, 1] - out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) - n = out_dq.numel() - out_dq[1::2] = A & 0xF - out_dq[::2] = A >> 4 - # quant_state.code is fp32, cast to quant_state dtype to avoid the mismatch issue - quant_state.code = quant_state.code.to(quant_state.dtype) - out_dq = quant_state.code[out_dq] - - # Apply scales - if out_dq.numel() != n: - assert out_dq.numel() == n + 1 - out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - rem = n % blocksize - has_rem = rem > 0 - - if has_rem: - if out is None: - out = torch.empty(quant_state.shape, dtype=quant_state.dtype, device=A.device) - out_reshaped = out.reshape(-1) - out_reshaped[: n - rem] = ( - out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1) - ).reshape(-1) - out_reshaped[n - rem :] = out_dq[n - rem :] * absmax[-1] - else: - out = (out_dq.view(-1, blocksize) * absmax.view(-1, 1)).reshape(quant_state.shape).to(quant_state.dtype) - - # take transpose here because weight is transposed (again) for computation - if transpose: - out = out.t() - - return out - - -# Do not need torch.compile here as we are calling torch/ipex kernel -def gemm_4bit_impl( - A: torch.Tensor, - B: torch.Tensor, - state: QuantState = None, - out: Optional[torch.Tensor] = None, -) -> torch.Tensor: - """ - Matrix-matrix multiplication with 4-bit quantization. - - Parameters - ---------- - A : torch.Tensor - The first input tensor. Usually the activation tensor. - B : torch.Tensor - The second input tensor. Usually the weight tensor. - out : torch.Tensor - The output tensor. - transposed_A : bool - Whether A is transposed - transposed_B : bool - Whether B is transposed - state : QuantState - Contains quantization info, such as blocksize and dtype - - Returns - ------- - torch.Tensor: - GEMM output tensor. 
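Outside the IPEX fast path, the body below reduces to "dequantize the packed weight, then run a dense matmul". A minimal standalone sketch of that pattern (not part of the patch); the codebook values, shapes, and helper name here are hypothetical stand-ins for the NF4 table and quant state:

import torch

# Toy 4-bit codebook (hypothetical values, stand-in for the NF4 table).
code = torch.linspace(-1.0, 1.0, 16)

def dequant_4bit(packed: torch.Tensor, absmax: torch.Tensor, blocksize: int) -> torch.Tensor:
    # Unpack two 4-bit indices per byte (high nibble first), look them up in the
    # codebook, then rescale each block by its absmax.
    idx = torch.stack((packed >> 4, packed & 0xF), dim=-1).reshape(-1).long()
    vals = code[idx].reshape(-1, blocksize)
    return (vals * absmax.view(-1, 1)).reshape(-1)

blocksize = 16
packed = torch.randint(0, 256, (16,), dtype=torch.uint8)   # 16 bytes -> 32 4-bit values
absmax = torch.tensor([0.5, 2.0])                          # one scale per block

W = dequant_4bit(packed, absmax, blocksize).reshape(4, 8)  # dequantized weight (out=4, in=8)
A = torch.randn(3, 8)                                      # activations
out = torch.nn.functional.linear(A, W.to(A.dtype))         # fallback: plain matmul on the dequantized weight
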
- """ - if getattr(state, "ipex", False): - # compute_dtype: 1 indicates fp16, 2 indicates bf16 - compute_dtype = 2 if A.dtype == torch.bfloat16 else 1 - output = torch.ops.torch_ipex.woq_linear( - A, - B, - "nf4", - state.shape, - state.new_scales, - state.new_zeros, - None, - None, - state.blocksize, - compute_dtype, - 1, - state.compensation, - ) - else: - dqB = dequantize_4bit_impl(B, state, blocksize=state.blocksize) - output = torch.matmul(A, dqB.to(A.dtype)) - if out is not None: - out.copy_(output) - else: - out = output - return out - - -# Currently only works for XPU -def dequantize_blockwise_ipex_impl( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor = None, -) -> torch.Tensor: - if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): - raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") - - # void cdequantize_blockwise_fp32(float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) - if dtype == torch.float16: - ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.bfloat16: - ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) - elif dtype == torch.float32: - ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) - else: - raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - - -def optimizer_update_8bit_blockwise( - optimizer_name: str, - g: torch.Tensor, - p: torch.Tensor, - state1: torch.Tensor, - state2: Optional[torch.Tensor], - beta1: float, - beta2: float, - beta3: float, - alpha: float, - eps: float, - step: int, - lr: float, - qmap1: torch.Tensor, - qmap2: Optional[torch.Tensor], - absmax1: torch.Tensor, - absmax2: Optional[torch.Tensor], - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - skip_zeros=False, -) -> None: - optim_func = None - if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): - raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") - - if g.dtype == torch.float32 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][0] - elif g.dtype == torch.float16 and state1.dtype == torch.uint8: - optim_func = str2optimizer8bit_blockwise[optimizer_name][1] - elif ( - g.dtype == torch.bfloat16 - and state1.dtype == torch.uint8 - and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 - ): - optim_func = str2optimizer8bit_blockwise[optimizer_name][2] - else: - raise ValueError( - f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", - ) - optim_func( - p, - g, - state1, - state2, - beta1, - beta2, - beta3, - alpha, - eps, - step, - lr, - qmap1, - qmap2, - absmax1, - absmax2, - weight_decay, - gnorm_scale, - skip_zeros, - g.numel() - ) - - -def optimizer_update_32bit( - optimizer_name: str, - g: torch.Tensor, - p: torch.Tensor, - state1: torch.Tensor, - beta1: float, - eps: float, - step: int, - lr: float, - state2: Optional[torch.Tensor] = None, - beta2: float = 0.0, - beta3: float = 0.0, - alpha: float = 0.0, - weight_decay: float = 0.0, - gnorm_scale: float = 1.0, - unorm_vec: Optional[torch.Tensor] = None, - max_unorm: float = 0.0, - skip_zeros=False, -) -> None: - raise NotImplementedError diff --git a/src/bitsandbytes_intel/ops.py 
b/src/bitsandbytes_intel/ops.py index 5d78881..9037911 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -1,25 +1,22 @@ from collections.abc import Sequence -from typing import Optional import math +from typing import Optional import torch -from .cpu_xpu_common import ( +from .xpu import ( QuantState, - int8_linear_matmul_impl, - int8_double_quant_impl, - int8_vectorwise_quant_impl, - int8_mm_dequant_impl, - quantize_4bit_impl, + _ipex_xpu_version_prereq, dequantize_4bit_impl, - quantize_blockwise_impl, dequantize_blockwise_impl, - gemm_4bit_impl, dequantize_blockwise_ipex_impl, - optimizer_update_8bit_blockwise, + gemv_4bit_impl, + int8_linear_matmul_impl, + int8_mm_dequant_impl, ipex_xpu, - ipex_cpu_only, - _ipex_xpu_version_prereq, + optimizer_update_8bit_blockwise, + quantize_4bit_impl, + quantize_blockwise_impl, ) print("Loading ops module") @@ -29,71 +26,32 @@ def register_xpu_ops(): print("Registering XPU implementations") # Register the int8_linear_matmul implementation - @torch.library.impl("bitsandbytes::int8_linear_matmul", "XPU") + @torch.library.impl("bitsandbytes::int8_linear_matmul", "xpu") def int8_linear_matmul_xpu(A: torch.Tensor, B: torch.Tensor): - return int8_linear_matmul_impl(A, B) - @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "XPU") - def int8_linear_matmul_xpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - return int8_linear_matmul_impl(A, B, out) - - # Register the int8_double_quant implementation - @torch.library.impl("bitsandbytes::int8_double_quant", "XPU") - def int8_double_quant_xpu( - A: torch.Tensor, - threshold: float = 0.0, - col_stats: torch.Tensor = None, - row_stats: torch.Tensor = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return int8_double_quant_impl(A, threshold, col_stats, row_stats) - @torch.library.impl("bitsandbytes::int8_double_quant.out", "XPU") - def int8_double_quant_xpu_out( - A: torch.Tensor, - threshold: float = 0.0, - col_stats: torch.Tensor = None, - row_stats: torch.Tensor = None, - out_col: torch.Tensor = None, - out_row: torch.Tensor = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return int8_double_quant_impl(A, threshold, col_stats, row_stats, out_col, out_row) + return int8_linear_matmul_impl(A, B) - # Register the int8_vectorwise_quant implementation - @torch.library.impl("bitsandbytes::int8_vectorwise_quant", "XPU") - def int8_vectorwise_quant_xpu( - A: torch.Tensor, - threshold: float = 0.0, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return int8_vectorwise_quant_impl(A, threshold) + @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "xpu") + def int8_linear_matmul_xpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): + return int8_linear_matmul_impl(A, B) # Register the int8_mm_dequant implementation - @torch.library.impl("bitsandbytes::int8_mm_dequant", "XPU") + @torch.library.impl("bitsandbytes::int8_mm_dequant", "xpu") def int8_mm_dequant_xpu( A: torch.Tensor, row_stats: torch.Tensor, col_stats: torch.Tensor, - bias: torch.Tensor = None, - compute_dtype=torch.float32, - output_dtype=torch.float32, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: - return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype) - @torch.library.impl("bitsandbytes::int8_mm_dequant.out", "XPU") - def int8_mm_dequant_xpu_out( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: 
torch.Tensor, - bias: torch.Tensor = None, - compute_dtype = torch.float32, - output_dtype = torch.float32, - out: torch.Tensor = None, - ) -> torch.Tensor: - return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out) + return int8_mm_dequant_impl(A, row_stats, col_stats, dtype, bias) # Register the quantize_4bit implementation - @torch.library.impl("bitsandbytes::quantize_4bit", "XPU") + @torch.library.impl("bitsandbytes::quantize_4bit", "xpu") def quantize_4bit_xpu( A: torch.Tensor, - blocksize=64, - quant_type="nf4", - quant_storage=torch.uint8, + blocksize: int, + quant_type: str, + quant_storage: torch.dtype, ) -> tuple[torch.Tensor, torch.Tensor]: return quantize_4bit_impl( A, @@ -103,28 +61,20 @@ def quantize_4bit_xpu( ) # Register the dequantize_4bit implementation - @torch.library.impl("bitsandbytes::dequantize_4bit", "XPU") + @torch.library.impl("bitsandbytes::dequantize_4bit", "xpu") def dequantize_4bit_xpu( A: torch.Tensor, - quant_state = None, - absmax: torch.Tensor = None, - blocksize: int = 64, - quant_type = "nf4", - ) -> torch.Tensor: - return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type) - @torch.library.impl("bitsandbytes::dequantize_4bit.out", "XPU") - def dequantize_4bit_xpu_out( - A: torch.Tensor, - quant_state = None, - absmax: torch.Tensor = None, - blocksize: int = 64, - quant_type = "nf4", - out: torch.Tensor = None, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, ) -> torch.Tensor: - return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out) + out = torch.empty(shape, dtype=dtype, device=A.device) + return dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype, out) # Register the quantize_blockwise implementation - @torch.library.impl("bitsandbytes::quantize_blockwise", "XPU") + @torch.library.impl("bitsandbytes::quantize_blockwise", "xpu") def quantize_blockwise_xpu( A: torch.Tensor, code: torch.Tensor, @@ -138,7 +88,7 @@ def quantize_blockwise_xpu( else: dequantize_blockwise = dequantize_blockwise_impl - @torch.library.impl("bitsandbytes::dequantize_blockwise", "XPU") + @torch.library.impl("bitsandbytes::dequantize_blockwise", "xpu") def dequantize_blockwise_xpu( A: torch.Tensor, absmax: torch.Tensor, @@ -147,36 +97,18 @@ def dequantize_blockwise_xpu( dtype: torch.dtype, ) -> torch.Tensor: return dequantize_blockwise(A, absmax, code, blocksize, dtype) - @torch.library.impl("bitsandbytes::dequantize_blockwise.out", "XPU") - def dequantize_blockwise_xpu_out( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, - ) -> torch.Tensor: - return dequantize_blockwise(A, absmax, code, blocksize, dtype, out) # Register the gemv_4bit implementation - @torch.library.impl("bitsandbytes::gemv_4bit", "XPU") + @torch.library.impl("bitsandbytes::gemv_4bit", "xpu") def gemv_4bit_xpu( A: torch.Tensor, B: torch.Tensor, state: QuantState = None, ) -> torch.Tensor: - return gemm_4bit_impl(A, B, state=state) - @torch.library.impl("bitsandbytes::gemv_4bit.out", "XPU") - def gemv_4bit_xpu_out( - A: torch.Tensor, - B: torch.Tensor, - state: QuantState = None, - out: torch.Tensor = None, - ) -> torch.Tensor: - return gemm_4bit_impl(A, B, state=state, out=out) + return gemv_4bit_impl(A, B, state=state) # Register the optimizer_update_8bit_blockwise implementation - @torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "XPU") + 
@torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "xpu") def optimizer_update_8bit_blockwise_xpu( optimizer_name: str, g: torch.Tensor, @@ -223,154 +155,6 @@ def optimizer_update_8bit_blockwise_xpu( print("Successfully registered XPU implementation") -def register_cpu_ops(): - print("Registering CPU implementations") - - # Register the int8_linear_matmul implementation - @torch.library.impl("bitsandbytes::int8_linear_matmul", "CPU") - def int8_linear_matmul_cpu(A: torch.Tensor, B: torch.Tensor): - return int8_linear_matmul_impl(A, B) - @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "CPU") - def int8_linear_matmul_cpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - return int8_linear_matmul_impl(A, B, out) - - # Register the int8_double_quant implementation - @torch.library.impl("bitsandbytes::int8_double_quant", "CPU") - def int8_double_quant_cpu( - A: torch.Tensor, - threshold: float = 0.0, - col_stats: torch.Tensor = None, - row_stats: torch.Tensor = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return int8_double_quant_impl(A, threshold, col_stats, row_stats) - @torch.library.impl("bitsandbytes::int8_double_quant.out", "CPU") - def int8_double_quant_cpu_out( - A: torch.Tensor, - threshold: float = 0.0, - col_stats: torch.Tensor = None, - row_stats: torch.Tensor = None, - out_col: torch.Tensor = None, - out_row: torch.Tensor = None, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - return int8_double_quant_impl(A, threshold, col_stats, row_stats, out_col, out_row) - - # Register the int8_vectorwise_quant implementation - @torch.library.impl("bitsandbytes::int8_vectorwise_quant", "CPU") - def int8_vectorwise_quant_cpu( - A: torch.Tensor, - threshold: float = 0.0, - ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return int8_vectorwise_quant_impl(A, threshold) - - # Register the int8_mm_dequant implementation - @torch.library.impl("bitsandbytes::int8_mm_dequant", "CPU") - def int8_mm_dequant_cpu( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - bias: torch.Tensor = None, - compute_dtype=torch.float32, - output_dtype=torch.float32, - ) -> torch.Tensor: - return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype) - @torch.library.impl("bitsandbytes::int8_mm_dequant.out", "CPU") - def int8_mm_dequant_cpu_out( - A: torch.Tensor, - row_stats: torch.Tensor, - col_stats: torch.Tensor, - bias: torch.Tensor = None, - compute_dtype = torch.float32, - output_dtype = torch.float32, - out: torch.Tensor = None, - ) -> torch.Tensor: - return int8_mm_dequant_impl(A, row_stats, col_stats, bias, compute_dtype, output_dtype, out) - - # Register the quantize_4bit implementation - @torch.library.impl("bitsandbytes::quantize_4bit", "CPU") - def quantize_4bit_cpu( - A: torch.Tensor, - blocksize=64, - quant_type="nf4", - quant_storage=torch.uint8, - ) -> tuple[torch.Tensor, torch.Tensor]: - return quantize_4bit_impl( - A, - blocksize, - quant_type, - quant_storage, - ) - - # Register the dequantize_4bit implementation - @torch.library.impl("bitsandbytes::dequantize_4bit", "CPU") - def dequantize_4bit_cpu( - A: torch.Tensor, - quant_state = None, - absmax: torch.Tensor = None, - blocksize: int = 64, - quant_type = "nf4", - ) -> torch.Tensor: - return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type) - @torch.library.impl("bitsandbytes::dequantize_4bit.out", "CPU") - def dequantize_4bit_cpu_out( - A: 
torch.Tensor, - quant_state = None, - absmax: torch.Tensor = None, - blocksize: int = 64, - quant_type = "nf4", - out: torch.Tensor = None, - ) -> torch.Tensor: - return dequantize_4bit_impl(A, quant_state, absmax, blocksize, quant_type, out) - - # Register the quantize_blockwise implementation - @torch.library.impl("bitsandbytes::quantize_blockwise", "CPU") - def quantize_blockwise_cpu( - A: torch.Tensor, - code: torch.Tensor, - blocksize: int, - ) -> tuple[torch.Tensor, torch.Tensor]: - return quantize_blockwise_impl(A, code, blocksize) - - # Register the dequantize_blockwise implementation - @torch.library.impl("bitsandbytes::dequantize_blockwise", "CPU") - def dequantize_blockwise_cpu( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - ) -> torch.Tensor: - return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype) - @torch.library.impl("bitsandbytes::dequantize_blockwise.out", "CPU") - def dequantize_blockwise_cpu_out( - A: torch.Tensor, - absmax: torch.Tensor, - code: torch.Tensor, - blocksize: int, - dtype: torch.dtype, - out: torch.Tensor, - ) -> torch.Tensor: - return dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out) - - # Register the gemv_4bit implementation - @torch.library.impl("bitsandbytes::gemv_4bit", "CPU") - def gemv_4bit_cpu( - A: torch.Tensor, - B: torch.Tensor, - state: QuantState = None, - ) -> torch.Tensor: - return gemm_4bit_impl(A, B, state=state) - @torch.library.impl("bitsandbytes::gemv_4bit.out", "CPU") - def gemv_4bit_cpu_out( - A: torch.Tensor, - B: torch.Tensor, - state: QuantState = None, - out: torch.Tensor = None, - ) -> torch.Tensor: - return gemm_4bit_impl(A, B, state=state, out=out) - - print("Successfully registered CPU implementation") - - def register_hpu_ops(): print("Registering HPU implementations") @@ -410,10 +194,8 @@ def register_ops(): if ipex_xpu: register_xpu_ops() - elif ipex_cpu_only: - register_cpu_ops() # TODO: Need to check HPU - else: + elif hasattr(torch.backends, "hpu") and torch.backends.hpu.is_available(): register_hpu_ops() From 51f260fd9d29eadfeb85dad4f3e04b3051f811e2 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 12:42:50 +0000 Subject: [PATCH 08/22] fix op name Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/ops.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index 9037911..c9604c9 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -8,8 +8,8 @@ QuantState, _ipex_xpu_version_prereq, dequantize_4bit_impl, - dequantize_blockwise_impl, dequantize_blockwise_ipex_impl, + dequantize_blockwise_torch_impl, gemv_4bit_impl, int8_linear_matmul_impl, int8_mm_dequant_impl, @@ -83,10 +83,9 @@ def quantize_blockwise_xpu( return quantize_blockwise_impl(A, code, blocksize) # Register the dequantize_blockwise implementation - if _ipex_xpu_version_prereq(2, 7): - dequantize_blockwise = dequantize_blockwise_ipex_impl - else: - dequantize_blockwise = dequantize_blockwise_impl + dequantize_blockwise_impl = ( + dequantize_blockwise_ipex_impl if _ipex_xpu_version_prereq(2, 7) else dequantize_blockwise_torch_impl + ) @torch.library.impl("bitsandbytes::dequantize_blockwise", "xpu") def dequantize_blockwise_xpu( @@ -96,7 +95,7 @@ def dequantize_blockwise_xpu( blocksize: int, dtype: torch.dtype, ) -> torch.Tensor: - return dequantize_blockwise(A, absmax, code, blocksize, dtype) + return dequantize_blockwise_impl(A, absmax, code, 
blocksize, dtype) # Register the gemv_4bit implementation @torch.library.impl("bitsandbytes::gemv_4bit", "xpu") From dfa8235180695984eeae76343436b0ccdf8f7468 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 13:08:54 +0000 Subject: [PATCH 09/22] add xpu ops Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 470 ++++++++++++++++++++++++++++++++++ 1 file changed, 470 insertions(+) create mode 100644 src/bitsandbytes_intel/xpu.py diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py new file mode 100644 index 0000000..bf6f33a --- /dev/null +++ b/src/bitsandbytes_intel/xpu.py @@ -0,0 +1,470 @@ +from collections.abc import Sequence +import subprocess +from typing import Optional, tuple +import warnings + +import torch +import torch.nn.functional as F + +try: + # to support Intel CPU/GPU (XPU) backend + import intel_extension_for_pytorch as ipex + + ipex_cpu = ipex if ipex._C._has_cpu() else None + ipex_xpu = ipex if ipex._C._has_xpu() else None + ipex_cpu_only = ipex._C._has_cpu() and (not ipex._C._has_xpu()) +except BaseException: + ipex_cpu = None + ipex_xpu = None + ipex_cpu_only = None + + +gxx_available = False +try: + subprocess.run(["g++", "--version"], capture_output=True) # hide terminal output + gxx_available = True +except BaseException: + warnings.warn("g++ not found, torch.compile disabled for CPU/XPU.") + + +Tensor = torch.Tensor + + +def _torch_version_prereq(major, minor): + ver_major = int(torch.__version__.split(".")[0]) + ver_minor = int(torch.__version__.split(".")[1]) + return ver_major * 32 + ver_minor >= major * 32 + minor + + +def _ipex_xpu_version_prereq(major, minor): + if ipex_xpu is not None: + ver_major = ipex_xpu.__version__.split(".")[0] + ver_minor = ipex_xpu.__version__.split(".")[1] + return int(ver_major) * 32 + int(ver_minor) >= major * 32 + minor + return False + + +str2optimizer8bit_blockwise = {} +if ipex_xpu is not None and _ipex_xpu_version_prereq(2, 7): + str2optimizer8bit_blockwise = { + "adam": ( + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp32, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_fp16, + ipex.xpu.bitsandbytes.cadam_8bit_blockwise_grad_bf16, + ), + } + + +def _maybe_torch_compile(func): + # torch.compile requires g++ and pytorch >= 2.0 + if gxx_available and _torch_version_prereq(2, 0) and not ipex_xpu: + options = {} + # fx_graph_cache requires pytorch >= 2.2 + if _torch_version_prereq(2, 2): + options.update({"fx_graph_cache": True}) + return torch.compile(func, dynamic=True, options=options) + return func + + +def transform( + A: torch.Tensor, + out: Optional[torch.Tensor] = None, + transpose=False, + state: Optional[tuple[torch.Size, str]] = None, +): + """ + Transform tensor A to to_order. It is originally designed for CUDA. + For CPU/XPU, it returns the original tensor if transpose=False. 
+ Otherwise, it returns the transpose of A + """ + if transpose: + if out is not None: + out.copy_(A.T) + else: + out = A.T + else: + if out is not None: + out.copy_(A) + else: + out = A + return out, state + + +# Applied from cpu int8_linear_matmul op +def int8_linear_matmul_impl(A: torch.Tensor, B: torch.Tensor): + return torch._int_mm( + A.reshape(-1, A.shape[-1]), + B.t(), + ).reshape(*A.shape[:-1], B.shape[0]) + + +@_maybe_torch_compile +def int8_mm_dequant_impl( + A: torch.Tensor, + row_stats: torch.Tensor, + col_stats: torch.Tensor, + dtype: Optional[torch.dtype] = None, + bias: Optional[torch.Tensor] = None, +) -> torch.Tensor: + torch._check(A.dtype == torch.int32, lambda: f"A must be int32, got {A.dtype}") + torch._check(row_stats.dtype == torch.float32, lambda: f"row_stats must be float32, got {row_stats.dtype}") + torch._check(col_stats.dtype == torch.float32, lambda: f"col_stats must be float32, got {col_stats.dtype}") + + A_calc = A.view(-1, A.shape[-1]) + row_stats = row_stats.reshape(-1).unsqueeze(-1) + col_stats = col_stats.reshape(-1).unsqueeze(0) + + out = A_calc * (row_stats * col_stats) * 6.200124e-05 + if bias is not None: + out += bias + + return out.to(dtype or torch.float16) + + +_NF4_QUANT_TABLE = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=torch.float32, + device="cpu", +) +_FP4_QUANT_TABLE = torch.tensor( + [ + 0.0000, + 0.0052, + 0.6667, + 1.0000, + 0.3333, + 0.5000, + 0.1667, + 0.2500, + 0.0000, + -0.0052, + -0.6667, + -1.0000, + -0.3333, + -0.5000, + -0.1667, + -0.2500, + ], + dtype=torch.float32, + device="cpu", +) +CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} + + +def quantize_blockwise_impl( + A: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Quantize tensor A in blocks of 8-bit values. + + Quantizes tensor A by dividing it into blocks which are independently quantized to int8. + + Parameters + ---------- + A : torch.Tensor + The input tensor. + code : torch.Tensor + The quantization code. + blocksize : int + The blocksize used in quantization. + + Returns + ------- + torch.Tensor: + The 8-bit tensor with packed 4-bit values. + torch.Tensor: + The absmax. 
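A minimal round-trip sketch for this blockwise kernel (standalone, not part of the patch), assuming the bitsandbytes_intel package from this series is installed; the uniform codebook below is only a stand-in for the dynamic map bitsandbytes normally passes in:

import torch

from bitsandbytes_intel.xpu import (
    dequantize_blockwise_torch_impl,
    quantize_blockwise_impl,
)

# Simple uniform 8-bit codebook over [-1, 1]; any 1-D code tensor works here.
code = torch.linspace(-1.0, 1.0, 256)

x = torch.randn(1024, dtype=torch.float32)
q, absmax = quantize_blockwise_impl(x, code, blocksize=256)
x_hat = dequantize_blockwise_torch_impl(q, absmax, code, blocksize=256, dtype=torch.float32)

# With a uniform code the per-block error is bounded by roughly absmax / 255.
assert (x - x_hat).abs().max() < 2 * absmax.max() / 255
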
+ """ + n = A.numel() + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + rem = n % blocksize + has_rem = rem > 0 + # Scale tensor to [-1, 1] + A_reshaped = A.reshape(n) + A_com = A_reshaped[: n - rem] + A_com_reshaped = A_com.reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + scaled_A = scaled_A.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) + + diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) + + return out_uint8, absmax + + +def dequantize_blockwise_torch_impl( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, +) -> torch.Tensor: + assert A.dtype == torch.uint8 + out = code[A.reshape(-1).int()] + blocks = out.shape[-1] // blocksize + res = out.shape[-1] % blocksize + if res != 0: + out = F.pad(out, (0, blocksize - res), mode="constant", value=0) + out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1) + out = out[: blocks * blocksize + res] + out = out.reshape(A.shape) + + return out + + +# Currently only works for XPU +def dequantize_blockwise_ipex_impl( + A: torch.Tensor, + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, + dtype: torch.dtype, +) -> torch.Tensor: + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + + out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) + # void cdequantize_blockwise_fp32( + # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) + if dtype == torch.float16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel()) + elif dtype == torch.bfloat16: + ipex.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel()) + elif dtype == torch.float32: + ipex.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel()) + else: + raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + + +# Copied from cpu quantize_4bit op +def quantize_4bit_impl( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}") + torch._check( + A.dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit quantization only supports 16/32-bit floats, but got {A.dtype}", + ) + + n = A.numel() + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + # Scale tensor to [-1, 1] + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + A_reshaped = A.reshape(n) + A_com_reshaped = A_reshaped[: n - rem].reshape(n // blocksize, blocksize) + absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] + scaled = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) + 
scaled = scaled.reshape(-1) + if has_rem: + absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() + scaled_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) + scaled = torch.cat([scaled, scaled_rem], dim=0) + # Quantize with the lookup table + quant_table = CODE[quant_type] + quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - quant_table), dim=-1, keepdim=True).to(torch.uint8) + + # Pack two quantized values per byte + packed = quantized[::2] << 4 | quantized[1::2] + + if quant_storage != torch.uint8: + packed = packed.squeeze().view(quant_storage).unsqueeze(1) + + return packed, absmax.float() + + +# Copied from cpu dequantize_4bit op +# Compile will fail in torch.frombuffer +# @_maybe_torch_compile +def dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> Tensor: + torch._check_is_size(blocksize) + torch._check(quant_type in ("nf4", "fp4"), lambda: f"quant_type must be nf4 or fp4 on CPU, got {quant_type}") + torch._check( + dtype in [torch.bfloat16, torch.float16, torch.float32], + lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", + ) + + # Enable non uint8 dtype + device = A.device + if A.dtype != torch.uint8: + bytes_value = A.cpu().numpy().tobytes() + A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device) + + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + out = out.reshape(-1, *shape[1:]).to(dtype) + + return out + + +# Copied from cpu gemv_4bit op +def gemv_4bit_impl( + A: torch.Tensor, + B: torch.Tensor, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, +) -> torch.Tensor: + # Applied from dequantize_4bit + B = B.view(-1, 1) + upper = (B >> 4).to(torch.int64) + lower = (B & 0x0F).to(torch.int64) + blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) + B_dq = code[blocks] * absmax[:, None] + B_dq = B_dq.reshape(-1, *shapeB[1:]).to(A.dtype) + + # User called gemv with B.t(), so we need to transpose it back. 
+ # if B.shape[0] == 1: + # B_dq = B_dq.t() + + return torch.nn.functional.linear( + A, + B_dq, + bias=None, + ) + + +def optimizer_update_8bit_blockwise( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + state2: Optional[torch.Tensor], + beta1: float, + beta2: float, + beta3: float, + alpha: float, + eps: float, + step: int, + lr: float, + qmap1: torch.Tensor, + qmap2: Optional[torch.Tensor], + absmax1: torch.Tensor, + absmax2: Optional[torch.Tensor], + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + skip_zeros=False, +) -> None: + optim_func = None + if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): + raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + + if g.dtype == torch.float32 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][0] + elif g.dtype == torch.float16 and state1.dtype == torch.uint8: + optim_func = str2optimizer8bit_blockwise[optimizer_name][1] + elif ( + g.dtype == torch.bfloat16 + and state1.dtype == torch.uint8 + and len(str2optimizer8bit_blockwise[optimizer_name]) == 3 + ): + optim_func = str2optimizer8bit_blockwise[optimizer_name][2] + else: + raise ValueError( + f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}", + ) + optim_func( + p, + g, + state1, + state2, + beta1, + beta2, + beta3, + alpha, + eps, + step, + lr, + qmap1, + qmap2, + absmax1, + absmax2, + weight_decay, + gnorm_scale, + skip_zeros, + g.numel(), + ) + + +def optimizer_update_32bit( + optimizer_name: str, + g: torch.Tensor, + p: torch.Tensor, + state1: torch.Tensor, + beta1: float, + eps: float, + step: int, + lr: float, + state2: Optional[torch.Tensor] = None, + beta2: float = 0.0, + beta3: float = 0.0, + alpha: float = 0.0, + weight_decay: float = 0.0, + gnorm_scale: float = 1.0, + unorm_vec: Optional[torch.Tensor] = None, + max_unorm: float = 0.0, + skip_zeros=False, +) -> None: + raise NotImplementedError From ae25c788f12fc211402d09f17499fcdd9b2e69ed Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 13:10:07 +0000 Subject: [PATCH 10/22] fix tuple Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index bf6f33a..a27dd85 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -1,6 +1,6 @@ from collections.abc import Sequence import subprocess -from typing import Optional, tuple +from typing import Optional import warnings import torch From 84bace74703b481f63adc1ab774c05b169d66e95 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 13:11:42 +0000 Subject: [PATCH 11/22] fix gemv 4bit Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/ops.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index c9604c9..295b795 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -5,7 +5,6 @@ import torch from .xpu import ( - QuantState, _ipex_xpu_version_prereq, dequantize_4bit_impl, dequantize_blockwise_ipex_impl, @@ -102,9 +101,12 @@ def dequantize_blockwise_xpu( def gemv_4bit_xpu( A: torch.Tensor, B: torch.Tensor, - state: QuantState = None, + shapeB: Sequence[int], + absmax: torch.Tensor, + code: torch.Tensor, + blocksize: int, ) -> torch.Tensor: - return gemv_4bit_impl(A, B, state=state) + return 
gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize) # Register the optimizer_update_8bit_blockwise implementation @torch.library.impl("bitsandbytes::optimizer_update_8bit_blockwise", "xpu") From e3fd20f55e7a449e7791de4ce229d5139fc4fd4b Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 13:13:44 +0000 Subject: [PATCH 12/22] fix quant table device Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index a27dd85..a67697e 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -141,7 +141,7 @@ def int8_mm_dequant_impl( 1.0, ], dtype=torch.float32, - device="cpu", + device="xpu", ) _FP4_QUANT_TABLE = torch.tensor( [ @@ -163,7 +163,7 @@ def int8_mm_dequant_impl( -0.2500, ], dtype=torch.float32, - device="cpu", + device="xpu", ) CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} @@ -291,7 +291,7 @@ def quantize_4bit_impl( scaled_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) scaled = torch.cat([scaled, scaled_rem], dim=0) # Quantize with the lookup table - quant_table = CODE[quant_type] + quant_table = CODE[quant_type].to(scaled.device) quantized = torch.argmin(torch.abs(scaled.view(-1, 1) - quant_table), dim=-1, keepdim=True).to(torch.uint8) # Pack two quantized values per byte @@ -334,7 +334,7 @@ def dequantize_4bit_impl( out_dq[1::2] = A & 0xF out_dq[::2] = A >> 4 # code is fp32, cast to dtype to avoid the mismatch issue - code = CODE[quant_type].to(dtype) + code = CODE[quant_type].to(out_dq.device).to(dtype) out_dq = code[out_dq] # Apply scales From a151b865a9c8109a3eb8bc19c38c92b263e439f3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 13:22:12 +0000 Subject: [PATCH 13/22] fix dequantize blockwise Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index a67697e..9820fb2 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -261,6 +261,8 @@ def dequantize_blockwise_ipex_impl( else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") + return out + # Copied from cpu quantize_4bit op def quantize_4bit_impl( From 7a2175fa1cea6228dc21195fd14435a842240cbe Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Thu, 8 May 2025 13:28:07 +0000 Subject: [PATCH 14/22] fix dequantize 4bit Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/ops.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index 295b795..c7a86eb 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -69,8 +69,7 @@ def dequantize_4bit_xpu( shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: - out = torch.empty(shape, dtype=dtype, device=A.device) - return dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype, out) + return dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) # Register the quantize_blockwise implementation @torch.library.impl("bitsandbytes::quantize_blockwise", "xpu") From 41359eab361ade3091db83fd3602b5732f72ad91 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 9 May 2025 10:33:36 +0000 Subject: [PATCH 15/22] register ipex op Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/ops.py | 48 ++++++++++++++++++++++++++++++++++- 1 file changed, 47 
insertions(+), 1 deletion(-) diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index c7a86eb..8e353fb 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -12,6 +12,7 @@ gemv_4bit_impl, int8_linear_matmul_impl, int8_mm_dequant_impl, + ipex_cpu, ipex_xpu, optimizer_update_8bit_blockwise, quantize_4bit_impl, @@ -187,16 +188,61 @@ def quantize_4bit_hpu( print("Successfully registered HPU implementations") +def register_ipex_ops(): + print("Registering IPEX implementations") + + # Register the dequantize_nf4_ipex implementation + if ipex_cpu: + from bitsandbytes.utils import _reverse_4bit_compress_format + + @torch.library.impl("bitsandbytes::dequantize_nf4_ipex", "cpu") + def dequantize_nf4_ipex_cpu( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2) + A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1) + return torch.ops.bitsandbytes.dequantize_4bit.default( + A, + absmax, + blocksize, + "nf4", + shape, + dtype, + ) + + if ipex_xpu: + + @torch.library.impl("bitsandbytes::dequantize_nf4_ipex", "xpu") + def dequantize_nf4_ipex_xpu( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype) + + print("Successfully registered IPEX implementation") + + def register_ops(): # Check if the operator exists if not hasattr(torch.ops.bitsandbytes, "int8_linear_matmul"): raise RuntimeError("bitsandbytes::int8_linear_matmul not found! Make sure bitsandbytes is installed") - if ipex_xpu: + if hasattr(torch, "xpu") and torch.xpu.is_available(): register_xpu_ops() # TODO: Need to check HPU elif hasattr(torch.backends, "hpu") and torch.backends.hpu.is_available(): register_hpu_ops() + if ipex_cpu or ipex_xpu: + register_ipex_ops() print("ops module loaded") From 20d4bdf9ebeee2eef4f002360baeabca0febd00a Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 9 May 2025 10:56:26 +0000 Subject: [PATCH 16/22] fix dequantize_nf4_ipex Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/ops.py | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index 8e353fb..4c12026 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -192,6 +192,24 @@ def register_ipex_ops(): print("Registering IPEX implementations") # Register the dequantize_nf4_ipex implementation + torch.library.define( + "bitsandbytes::dequantize_nf4_ipex", + "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor", + ) + + @torch.library.register_fake("bitsandbytes::dequantize_nf4_ipex") + def dequantize_nf4_ipex( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + shape: Sequence[int], + dtype: torch.dtype, + ) -> torch.Tensor: + raise NotImplementedError( + "bitsandbytes::dequantize_nf4_ipex is not implemented for default backend. " + "Please make sure you installed ipex to support Intel CPU or XPU." 
+ ) + if ipex_cpu: from bitsandbytes.utils import _reverse_4bit_compress_format @@ -200,7 +218,6 @@ def dequantize_nf4_ipex_cpu( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, - quant_type: str, shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: @@ -222,7 +239,6 @@ def dequantize_nf4_ipex_xpu( A: torch.Tensor, absmax: torch.Tensor, blocksize: int, - quant_type: str, shape: Sequence[int], dtype: torch.dtype, ) -> torch.Tensor: From 30fa5c0c058c5bb2dbac027a22368b95b6434287 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 9 May 2025 12:42:05 +0000 Subject: [PATCH 17/22] simplify gemv Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 18 ++++++------------ 1 file changed, 6 insertions(+), 12 deletions(-) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index 9820fb2..a6c7167 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -276,18 +276,16 @@ def quantize_4bit_impl( ) n = A.numel() - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 rem = n % blocksize has_rem = rem > 0 + blocks = n // blocksize + has_rem # Scale tensor to [-1, 1] absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) A_reshaped = A.reshape(n) A_com_reshaped = A_reshaped[: n - rem].reshape(n // blocksize, blocksize) absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] - scaled = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1) - scaled = scaled.reshape(-1) + scaled = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1).reshape(-1) if has_rem: absmax[-1] = torch.abs(A_reshaped[n - rem :]).max() scaled_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1) @@ -343,10 +341,10 @@ def dequantize_4bit_impl( if out_dq.numel() != n: assert out_dq.numel() == n + 1 out_dq = torch.narrow(out_dq, 0, 0, n) - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize has_rem = rem > 0 + blocks = n // blocksize + has_rem out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) if has_rem: @@ -370,12 +368,8 @@ def gemv_4bit_impl( blocksize: int, ) -> torch.Tensor: # Applied from dequantize_4bit - B = B.view(-1, 1) - upper = (B >> 4).to(torch.int64) - lower = (B & 0x0F).to(torch.int64) - blocks = torch.cat((upper, lower), dim=1).reshape(-1, blocksize) - B_dq = code[blocks] * absmax[:, None] - B_dq = B_dq.reshape(-1, *shapeB[1:]).to(A.dtype) + quant_type = "nf4" if code[1] > 0 else "fp4" + B_dq = dequantize_4bit_impl(B, absmax, blocksize, quant_type, shapeB, A.dtype) # User called gemv with B.t(), so we need to transpose it back. 
# if B.shape[0] == 1: From 9141674683a87990c40354ea0cfe9fdb3aa940c3 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 9 May 2025 13:04:19 +0000 Subject: [PATCH 18/22] check xpu Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index a6c7167..64a6b52 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -141,7 +141,7 @@ def int8_mm_dequant_impl( 1.0, ], dtype=torch.float32, - device="xpu", + device="xpu" if torch.xpu.is_available() else "cpu", ) _FP4_QUANT_TABLE = torch.tensor( [ @@ -163,7 +163,7 @@ def int8_mm_dequant_impl( -0.2500, ], dtype=torch.float32, - device="xpu", + device="xpu" if torch.xpu.is_available() else "cpu", ) CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} From 7eb98c4380f14d87a9c476671c10798172da04cd Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 9 May 2025 14:40:34 +0000 Subject: [PATCH 19/22] fix quantize blockwise output shape Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index 64a6b52..5d84e95 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -195,12 +195,10 @@ def quantize_blockwise_impl( The absmax. """ n = A.numel() - blocks = n // blocksize - blocks += 1 if n % blocksize > 0 else 0 - absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) rem = n % blocksize has_rem = rem > 0 - # Scale tensor to [-1, 1] + blocks = n // blocksize + has_rem + absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) A_reshaped = A.reshape(n) A_com = A_reshaped[: n - rem] A_com_reshaped = A_com.reshape(n // blocksize, blocksize) @@ -213,7 +211,7 @@ def quantize_blockwise_impl( scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0) diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device)) - out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device) + out_uint8 = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape) return out_uint8, absmax From a7d314e4f596be21b7ceaa0656f9810839631456 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Fri, 9 May 2025 16:01:14 +0000 Subject: [PATCH 20/22] fix quant_storage bf16 Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index 5d84e95..ff757b7 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -322,6 +322,9 @@ def dequantize_4bit_impl( # Enable non uint8 dtype device = A.device if A.dtype != torch.uint8: + if A.dtype == torch.bfloat16: + # Numpy does not support bfloat16 + A = A.view(torch.float16) bytes_value = A.cpu().numpy().tobytes() A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device) From 6b31eb3d886fc57f421b964651da727fc3bad7dc Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 12 May 2025 12:56:11 +0000 Subject: [PATCH 21/22] fix xpu ops Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/xpu.py | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/src/bitsandbytes_intel/xpu.py b/src/bitsandbytes_intel/xpu.py index ff757b7..2b35b59 100644 --- a/src/bitsandbytes_intel/xpu.py +++ b/src/bitsandbytes_intel/xpu.py @@ -198,7 +198,7 @@ def quantize_blockwise_impl( rem = n % blocksize has_rem = rem > 0 blocks = n // 
blocksize + has_rem - absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) A_reshaped = A.reshape(n) A_com = A_reshaped[: n - rem] A_com_reshaped = A_com.reshape(n // blocksize, blocksize) @@ -247,6 +247,7 @@ def dequantize_blockwise_ipex_impl( if ipex_xpu is None or not _ipex_xpu_version_prereq(2, 7): raise RuntimeError("Please install intel_extension_for_ipex >= 2.7 for 8bit optimizer backend on XPU device.") + shape = A.shape out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device) # void cdequantize_blockwise_fp32( # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream) @@ -259,7 +260,7 @@ def dequantize_blockwise_ipex_impl( else: raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}") - return out + return out.reshape(shape) # Copied from cpu quantize_4bit op @@ -279,7 +280,7 @@ def quantize_4bit_impl( blocks = n // blocksize + has_rem # Scale tensor to [-1, 1] - absmax = torch.zeros((blocks,), device=A.device, dtype=A.dtype) + absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32) A_reshaped = A.reshape(n) A_com_reshaped = A_reshaped[: n - rem].reshape(n // blocksize, blocksize) absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0] @@ -320,13 +321,8 @@ def dequantize_4bit_impl( ) # Enable non uint8 dtype - device = A.device if A.dtype != torch.uint8: - if A.dtype == torch.bfloat16: - # Numpy does not support bfloat16 - A = A.view(torch.float16) - bytes_value = A.cpu().numpy().tobytes() - A = torch.frombuffer(bytes_value, dtype=torch.uint8).to(device) + A = A.view(torch.uint8) A = A.reshape(-1) # Map nf4 to [-1, 1] @@ -369,7 +365,7 @@ def gemv_4bit_impl( blocksize: int, ) -> torch.Tensor: # Applied from dequantize_4bit - quant_type = "nf4" if code[1] > 0 else "fp4" + quant_type = "fp4" if code[1] > 0 else "nf4" B_dq = dequantize_4bit_impl(B, absmax, blocksize, quant_type, shapeB, A.dtype) # User called gemv with B.t(), so we need to transpose it back. From c23930c65745116ddb7474c415b1a01cb18caf46 Mon Sep 17 00:00:00 2001 From: jiqing-feng Date: Mon, 12 May 2025 13:28:13 +0000 Subject: [PATCH 22/22] fix xpu ops Signed-off-by: jiqing-feng --- src/bitsandbytes_intel/ops.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/bitsandbytes_intel/ops.py b/src/bitsandbytes_intel/ops.py index 4c12026..a631711 100644 --- a/src/bitsandbytes_intel/ops.py +++ b/src/bitsandbytes_intel/ops.py @@ -30,10 +30,6 @@ def register_xpu_ops(): def int8_linear_matmul_xpu(A: torch.Tensor, B: torch.Tensor): return int8_linear_matmul_impl(A, B) - @torch.library.impl("bitsandbytes::int8_linear_matmul.out", "xpu") - def int8_linear_matmul_xpu_out(A: torch.Tensor, B: torch.Tensor, out: torch.Tensor): - return int8_linear_matmul_impl(A, B) - # Register the int8_mm_dequant implementation @torch.library.impl("bitsandbytes::int8_mm_dequant", "xpu") def int8_mm_dequant_xpu(