diff --git a/.DS_Store b/.DS_Store new file mode 100644 index 000000000..c673260c3 Binary files /dev/null and b/.DS_Store differ diff --git a/CMakeLists.txt b/CMakeLists.txt index 3f28f8bd0..fe3aebe9b 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -239,12 +239,32 @@ elseif(BUILD_MPS) add_compile_definitions(BUILD_MPS) file(MAKE_DIRECTORY "build") add_custom_command(OUTPUT "bitsandbytes/bitsandbytes.metallib" - COMMAND xcrun metal -c -o "build/bitsandbytes.air" ${METAL_FILES} + COMMAND xcrun metal -c -g -frecord-sources -gline-tables-only -o "build/bitsandbytes.air" ${METAL_FILES} COMMAND xcrun metallib "build/bitsandbytes.air" -o "bitsandbytes/bitsandbytes.metallib" DEPENDS "${METAL_FILES}" COMMENT "Compiling Metal kernels" VERBATIM) add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib") + if(NOT Torch_DIR) + find_package(Python3 COMPONENTS Interpreter) + if(Python3_EXECUTABLE) + execute_process( + COMMAND "${Python3_EXECUTABLE}" -c "import torch; import sys; sys.stdout.write(torch.utils.cmake_prefix_path)" + OUTPUT_VARIABLE TORCH_CMAKE_PREFIX_PATH + ERROR_VARIABLE TORCH_DETECT_ERROR + RESULT_VARIABLE TORCH_DETECT_RESULT + ) + if(TORCH_DETECT_RESULT EQUAL 0 AND TORCH_CMAKE_PREFIX_PATH) + list(APPEND CMAKE_PREFIX_PATH "${TORCH_CMAKE_PREFIX_PATH}") + endif() + endif() + endif() + find_package(Torch REQUIRED) + if(TORCH_CXX_FLAGS) + string(APPEND CMAKE_CXX_FLAGS " ${TORCH_CXX_FLAGS}") + endif() + set(BNB_TORCH_INCLUDE_DIRS ${TORCH_INCLUDE_DIRS}) + set(BNB_TORCH_LIBRARIES ${TORCH_LIBRARIES}) elseif(BUILD_XPU) list(APPEND SRC_FILES ${XPU_FILES}) string(APPEND BNB_OUTPUT_NAME "_xpu") @@ -351,7 +371,13 @@ if(BUILD_HIP) endif() if(BUILD_MPS) add_dependencies(bitsandbytes metallib) - target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") + target_link_libraries(bitsandbytes PRIVATE objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph") + if(BNB_TORCH_INCLUDE_DIRS) + target_include_directories(bitsandbytes PRIVATE ${BNB_TORCH_INCLUDE_DIRS}) + endif() + if(BNB_TORCH_LIBRARIES) + target_link_libraries(bitsandbytes PRIVATE ${BNB_TORCH_LIBRARIES}) + endif() endif() if(BUILD_XPU) set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'") diff --git a/__init__.py b/__init__.py new file mode 100644 index 000000000..c41923d0b --- /dev/null +++ b/__init__.py @@ -0,0 +1,30 @@ +"""Dispatcher shimming the editable layout. + +When this repository is used via ``pip install -e .`` the real Python +package lives under ``bitsandbytes/bitsandbytes``. Importing from the +workspace root (e.g. running scripts from ``.../ai/kernels``) would +otherwise resolve to this outer directory, yielding a namespace module +with no attributes. Import the inner package eagerly and mirror its +symbols so ``import bitsandbytes`` always behaves the same as the +installed wheel. +""" + +from __future__ import annotations + +import importlib +from types import ModuleType + +_inner: ModuleType = importlib.import_module(".bitsandbytes", __name__) + +# Copy dunder metadata expected by consumers. 
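+# ``__path__`` is the important one: forwarding it to the inner package keeps
+# submodule imports such as ``bitsandbytes.nn`` resolving against the real
+# package rather than this shim.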
+for _name in ("__all__", "__doc__", "__file__", "__loader__", "__path__", "__spec__", "__version__"): + if hasattr(_inner, _name): + globals()[_name] = getattr(_inner, _name) + +# Re-export public symbols while leaving dunders alone. +for _name, _value in vars(_inner).items(): + if not _name.startswith("__"): + globals()[_name] = _value + +del _inner, _name, _value, ModuleType, importlib + diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py index 8bea82fb3..8d4fa4c7e 100644 --- a/bitsandbytes/__init__.py +++ b/bitsandbytes/__init__.py @@ -9,7 +9,7 @@ import torch -from . import _ops, research, utils +from . import _ops, nn, research, utils from .autograd._functions import ( MatmulLtState, matmul, @@ -38,6 +38,9 @@ if hasattr(torch, "xpu") and torch.xpu.is_available(): from .backends.xpu import ops as xpu_ops +if hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + from .backends.mps import ops as mps_ops + if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"): # In case not automatically imported import habana_frameworks.torch diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py index a0f0d2a34..9ab44a7e4 100644 --- a/bitsandbytes/backends/default/ops.py +++ b/bitsandbytes/backends/default/ops.py @@ -189,8 +189,7 @@ def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, return out -@register_kernel("bitsandbytes::quantize_4bit", "default") -def _( +def _quantize_4bit_impl( A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype ) -> tuple[torch.Tensor, torch.Tensor]: torch._check_is_size(blocksize) @@ -232,6 +231,13 @@ def _( return packed, absmax.float() +@register_kernel("bitsandbytes::quantize_4bit", "default") +def _( + A: torch.Tensor, blocksize: int, quant_type: str, quant_storage: torch.dtype +) -> tuple[torch.Tensor, torch.Tensor]: + return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage) + + def _dequantize_4bit_impl( A: torch.Tensor, absmax: torch.Tensor, @@ -243,7 +249,6 @@ def _dequantize_4bit_impl( # Enable non uint8 dtype if A.dtype != torch.uint8: A = A.view(torch.uint8) - A = A.reshape(-1) # Map nf4 to [-1, 1] out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) @@ -290,7 +295,6 @@ def _( dtype in [torch.bfloat16, torch.float16, torch.float32], lambda: f"Blockwise 4bit dequantization only supports 16/32-bit floats, but got {dtype}", ) - return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) diff --git a/bitsandbytes/backends/mps/__init__.py b/bitsandbytes/backends/mps/__init__.py new file mode 100644 index 000000000..662a206e0 --- /dev/null +++ b/bitsandbytes/backends/mps/__init__.py @@ -0,0 +1,2 @@ +# MPS backend registrations are defined in ops.py + diff --git a/bitsandbytes/backends/mps/ops.py b/bitsandbytes/backends/mps/ops.py new file mode 100644 index 000000000..b74106ad2 --- /dev/null +++ b/bitsandbytes/backends/mps/ops.py @@ -0,0 +1,202 @@ +from collections.abc import Sequence +from typing import Optional + +import ctypes as ct +from ctypes import _CFuncPtr +import torch + +from ..._ops import register_kernel +from ...cextension import lib +from ..default.ops import _dequantize_4bit_impl, _quantize_4bit_impl +from ..utils import CODE +from .shim import MPSTensorShim#, configure_mps_blockwise_kernel + + +def _check_mps_device(tensor: torch.Tensor, name: str) -> None: + torch._check( + tensor.device.type == "mps", + lambda: f"{name} must live on an MPS 
device for the MPS backend, got {tensor.device.type}", + ) + + +def _supports_dtype(dtype: torch.dtype) -> bool: + return dtype in (torch.float16, torch.float32, torch.bfloat16) + + +def _kernel_dtype(dtype: torch.dtype) -> torch.dtype: + if dtype == torch.bfloat16: + return torch.float32 + return dtype + + +def _resolve_quant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]: + try: + if dtype == torch.float16: + fn = getattr( + lib, + "cquantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp16_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + if dtype == torch.float32: + fn = getattr( + lib, + "cquantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cquantize_blockwise_fp32_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + except AttributeError: + return None + return None + + +def _resolve_dequant_fn(dtype: torch.dtype, quant_type: str) -> Optional[_CFuncPtr]: + try: + if dtype == torch.float16: + fn = getattr( + lib, + "cdequantize_blockwise_fp16_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp16_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + if dtype == torch.float32: + fn = getattr( + lib, + "cdequantize_blockwise_fp32_fp4" if quant_type == "fp4" else "cdequantize_blockwise_fp32_nf4", + ) + # configure_mps_blockwise_kernel(fn) + return fn + except AttributeError: + return None + return None + + +def _quantize_4bit_native( + A: torch.Tensor, + blocksize: int, + quant_type: str, + quant_storage: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor] | None: + if quant_storage != torch.uint8 or not _supports_dtype(A.dtype): + return None + + kernel_dtype = _kernel_dtype(A.dtype) + fn = _resolve_quant_fn(kernel_dtype, quant_type) + if fn is None: + return None + + if kernel_dtype != A.dtype: + A_kernel = A.to(kernel_dtype) + else: + A_kernel = A + + n = A.numel() + blocks = -(n // -blocksize) + absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32) + out = torch.empty(((n + 1) // (quant_storage.itemsize * 2), 1), device=A.device, dtype=quant_storage) + + input_shim = MPSTensorShim.from_tensor(A_kernel) + absmax_shim = MPSTensorShim.from_tensor(absmax) + out_shim = MPSTensorShim.from_tensor(out) + + fn( + input_shim.struct, + absmax_shim.struct, + out_shim.struct, + ct.c_int32(blocksize), + ct.c_int32(n), + ) + return out, absmax + + +def _dequantize_4bit_native( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + dtype: torch.dtype, + out: torch.Tensor, +) -> bool: + if A.dtype != torch.uint8 or not _supports_dtype(dtype): + return False + + _check_mps_device(absmax, "absmax") + kernel_dtype = _kernel_dtype(dtype) + fn = _resolve_dequant_fn(kernel_dtype, quant_type) + if fn is None: + return False + + packed_shim = MPSTensorShim.from_tensor(A) + absmax_shim = MPSTensorShim.from_tensor(absmax) + if kernel_dtype != dtype: + work_out = torch.empty_like(out, dtype=kernel_dtype) + else: + work_out = out + out_shim = MPSTensorShim.from_tensor(work_out) + + fn( + packed_shim.struct, + absmax_shim.struct, + out_shim.struct, + ct.c_int32(blocksize), + ct.c_int32(out.numel()), + ) + + if work_out is not out: + out.copy_(work_out.to(dtype)) + + return True + + +@register_kernel("bitsandbytes::quantize_4bit", "mps") +def _( + A: torch.Tensor, + blocksize: int, + quant_type: str, + quant_storage: torch.dtype, +) -> tuple[torch.Tensor, torch.Tensor]: + _check_mps_device(A, "A") + # result = _quantize_4bit_native(A, blocksize, quant_type, quant_storage) + # if result 
is not None: + # return result + return _quantize_4bit_impl(A, blocksize, quant_type, quant_storage) + + +@register_kernel("bitsandbytes::dequantize_4bit", "mps") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, +) -> torch.Tensor: + _check_mps_device(A, "A") + _check_mps_device(absmax, "absmax") + out = torch.empty(shape, dtype=dtype, device=A.device) + if _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out): + return out + else: + raise RuntimeError("Failed to dequantize 4bit on MPS") + return _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) + + +@register_kernel("bitsandbytes::dequantize_4bit.out", "mps") +def _( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: Sequence[int], + dtype: torch.dtype, + out: torch.Tensor, +) -> None: + _check_mps_device(A, "A") + _check_mps_device(out, "out") + _check_mps_device(absmax, "absmax") + torch._check(out.shape == tuple(shape), lambda: f"Expected out.shape == {tuple(shape)}, got {out.shape}") + torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}") + + if not _dequantize_4bit_native(A, absmax, blocksize, quant_type, dtype, out): + result = _dequantize_4bit_impl(A, absmax, blocksize, quant_type, shape, dtype) + out.copy_(result) \ No newline at end of file diff --git a/bitsandbytes/backends/mps/shim.py b/bitsandbytes/backends/mps/shim.py new file mode 100644 index 000000000..53c18fd92 --- /dev/null +++ b/bitsandbytes/backends/mps/shim.py @@ -0,0 +1,63 @@ +import ctypes as ct +from dataclasses import dataclass +from typing import Callable + +import torch + + +class _BNBMPSTensor(ct.Structure): + _fields_ = [ + ("storage", ct.c_void_p), + ("byte_offset", ct.c_size_t), + ("nbytes", ct.c_size_t), + ] + + +@dataclass(slots=True) +class MPSTensorShim: + """ + Lightweight wrapper that keeps a Tensor alive while exposing its Metal storage. + + PyTorch stores an ``id`` inside the tensor's untyped storage data + pointer on MPS. We capture that pointer once and forward the storage offset + so native kernels can bind the correct buffer without any host copies. + """ + + tensor: torch.Tensor + struct: _BNBMPSTensor + + @classmethod + def from_tensor(cls, tensor: torch.Tensor) -> "MPSTensorShim": + if hasattr(tensor, "untyped_storage"): + storage = tensor.untyped_storage() + else: + storage = tensor.storage() + + storage_ptr = storage.data_ptr() + byte_offset = tensor.storage_offset() * tensor.element_size() + nbytes = tensor.nbytes + + struct = _BNBMPSTensor( + ct.c_void_p(storage_ptr), + ct.c_size_t(byte_offset), + ct.c_size_t(nbytes), + ) + return cls(tensor=tensor, struct=struct) + + +# def configure_mps_blockwise_kernel(fn: Callable[[object], None]) -> None: +# """ +# Ensure ctypes knows the function expects our tensor shim structs by value. 
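+#     Currently left disabled: ops.py imports this helper behind a comment, so
+#     the native functions are called without pre-declared ctypes signatures.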
+# """ + +# try: +# argtypes = getattr(fn, "argtypes") +# except AttributeError: +# argtypes = None + +# desired = [_BNBMPSTensor, _BNBMPSTensor, _BNBMPSTensor, ct.c_int32, ct.c_int32] +# if argtypes != desired: +# fn.argtypes = desired +# if getattr(fn, "restype", None) is not None: +# fn.restype = None + diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 188576225..48d933384 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -278,19 +278,35 @@ def get_native_library() -> BNBNativeLibrary: """ Load CUDA library XOR CPU, as the latter contains a subset of symbols of the former. """ - cuda_specs = get_cuda_specs() - binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" - - if cuda_specs: - cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) - - if not cuda_binary_path.exists(): - raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {cuda_binary_path}") - - binary_path = cuda_binary_path - - if torch._C._has_xpu: + cpu_binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" + binary_path = cpu_binary_path + + if BNB_BACKEND in {"CUDA", "ROCm"}: + cuda_specs = get_cuda_specs() + if cuda_specs: + candidate = get_cuda_bnb_library_path(cuda_specs) + if not candidate.exists(): + raise RuntimeError(f"Configured {BNB_BACKEND} binary not found at {candidate}") + binary_path = candidate + else: + logger.warning( + "bitsandbytes: CUDA/ROCm backend requested but PyTorch did not expose runtime specs; " + "falling back to CPU implementation." + ) + elif BNB_BACKEND == "XPU": binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}" + elif BNB_BACKEND == "MPS": + binary_path = PACKAGE_DIR / f"libbitsandbytes_mps{DYNAMIC_LIBRARY_SUFFIX}" + + if not binary_path.exists(): + if BNB_BACKEND == "MPS": + logger.warning( + "bitsandbytes: libbitsandbytes_mps was not found. Falling back to CPU kernels; " + "MPS-specific optimizations will be unavailable." 
+ ) + binary_path = cpu_binary_path + else: + raise RuntimeError(f"bitsandbytes: native library not found at {binary_path}") logger.debug(f"Loading bitsandbytes native library from: {binary_path}") @@ -313,6 +329,8 @@ def get_native_library() -> BNBNativeLibrary: BNB_BACKEND = "ROCm" elif torch.cuda.is_available(): BNB_BACKEND = "CUDA" +elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available(): + BNB_BACKEND = "MPS" elif torch._C._has_xpu: BNB_BACKEND = "XPU" diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py index d3332acfe..e8561a893 100644 --- a/bitsandbytes/nn/modules.py +++ b/bitsandbytes/nn/modules.py @@ -527,7 +527,6 @@ def _save_to_state_dict(self, destination, prefix, keep_vars): def forward(self, x: torch.Tensor): fix_4bit_weight_quant_state_from_module(self) quant_state = self.weight.quant_state - if ( not getattr(quant_state, "packing_format_for_cpu", False) and x.device.type == "cpu" diff --git a/csrc/mps_kernels.metal b/csrc/mps_kernels.metal index 63b3bf78c..45bcc2d0e 100644 --- a/csrc/mps_kernels.metal +++ b/csrc/mps_kernels.metal @@ -1,117 +1,498 @@ #include +#include using namespace metal; -#define HLF_MAX 65504 -#define TH 1024 -#define NUM 4 -#define NUM_BLOCK 4096 - -template -static unsigned char quantize_scalar( - float rand, - device float* code, - float x) -{ - int pivot = 127; - int upper_pivot = 255; - int lower_pivot = 0; - - float lower = -1.0f; - float upper = 1.0f; - - float val = code[pivot]; - // i>>=1 = {32, 16, 8, 4, 2, 1} - for(int i = 64; i > 0; i>>=1) - { - if(x > val) - { - lower_pivot = pivot; - lower = val; - pivot+=i; +namespace { + +constant uint kQuantThreadsCapacity = 512; + +constant float NF4_CODE[16] = { + -1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, + -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, + 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, + 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0 +}; + +constant float FP4_CODE[16] = { + 0.0, 0.0052, 0.6667, 1.0, 0.3333, 0.5, 0.1667, 0.25, + 0.0, -0.0052, -0.6667, -1.0, -0.3333, -0.5, -0.1667, -0.25 +}; + +template +inline uchar encode_value(float value, constant float* code_table) { + float best = fabs(value - code_table[0]); + uchar index = 0; + for (uchar i = 1; i < 16; ++i) { + float diff = fabs(value - code_table[i]); + if (diff < best) { + best = diff; + index = i; } - else - { - upper_pivot = pivot; - upper = val; - pivot-=i; + } + return index; +} + +template +inline void quantize_block( + device const scalar_t* input, + device float* absmax, + device uchar* packed, + uint n, + uint blocksize, + uint block_index, + uint thread_idx, + uint threadgroup_size, + constant float* code_table, + threadgroup float* shared_thread_max, + threadgroup float& shared_scale, + uint simd_lane_id, + uint simd_group_id +) { + uint start = block_index * blocksize; + if (start >= n) { + return; + } + + uint end = min(start + blocksize, n); + float local_max = 0.0f; + for (uint i = start + thread_idx; i < end; i += threadgroup_size) { + float current = fabs((float)input[i]); + local_max = max(local_max, current); + } + + // SIMD reduction + local_max = simd_max(local_max); + + // Store SIMD group max to shared memory + if (simd_lane_id == 0) { + shared_thread_max[simd_group_id] = local_max; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + if (thread_idx == 0) { + float max_val = 0.0f; + uint num_simd_groups = (threadgroup_size + 31) / 32; + for (uint i = 0; i < 
num_simd_groups; ++i) { + max_val = max(max_val, shared_thread_max[i]); } - val = code[pivot]; - } - - if(upper_pivot == 255) - upper = code[upper_pivot]; - if(lower_pivot == 0) - lower = code[lower_pivot]; - - if(!STOCHASTIC) - { - if(x > val) - { - float midpoint = (upper+val)*0.5f; - if(x > midpoint) - { - return upper_pivot; + shared_scale = max_val; + } + threadgroup_barrier(mem_flags::mem_threadgroup); + + float max_val = shared_scale; + absmax[block_index] = max_val; + float inv = max_val > 0.0f ? 1.0f / max_val : 0.0f; + + uint pairs_in_block = (end - start + 1) >> 1; + uint out_byte = block_index * ((blocksize + 1) >> 1); + + for (uint pair = thread_idx; pair < pairs_in_block; pair += threadgroup_size) { + uint value_index0 = start + pair * 2; + float normalized0 = (max_val > 0.0f) ? clamp((float)input[value_index0] * inv, -1.0f, 1.0f) : 0.0f; + uchar nibble0 = encode_value(normalized0, code_table) & 0xF; + + uint value_index1 = value_index0 + 1; + uchar nibble1 = 0; + if (value_index1 < end) { + float normalized1 = (max_val > 0.0f) ? clamp((float)input[value_index1] * inv, -1.0f, 1.0f) : 0.0f; + nibble1 = encode_value(normalized1, code_table) & 0xF; } - else - return pivot; - } - else - { - float midpoint = (lower+val)*0.5f; - if(x < midpoint) - return lower_pivot; - else - return pivot; - } - } - else - { - if(x > val) - { - float dist_to_upper = fabs(upper-x); - float dist_full = upper-val; - if(rand >= dist_to_upper/dist_full) return upper_pivot; - else return pivot; - } - else - { - float dist_to_lower = fabs(lower-x); - float dist_full = val-lower; - if(rand >= dist_to_lower/dist_full) return lower_pivot; - else return pivot; - } + packed[out_byte + pair] = (nibble0 << 4) | nibble1; } + } -kernel void quantize(device float* code [[buffer(0)]], - device float* A [[buffer(1)]], - device uchar* out [[buffer(2)]], - constant uint& n [[buffer(3)]], - uint id [[thread_position_in_grid]]) { - const uint n_full = (NUM_BLOCK * (n / NUM_BLOCK)) + (n % NUM_BLOCK == 0 ? 0 : NUM_BLOCK); - uint valid_items = (id / NUM_BLOCK + 1 == (n + NUM_BLOCK - 1) / NUM_BLOCK) ? n - (id / NUM_BLOCK * NUM_BLOCK) : NUM_BLOCK; - const uint base_idx = (id / NUM_BLOCK * NUM_BLOCK); +template +inline void dequantize_block( + device const uchar* packed, + device const float* absmax, + device scalar_t* output, + uint n, + uint blocksize, + uint block_index, + uint thread_idx, + uint threadgroup_size, + constant float* code_table +) { + uint block_start = block_index * blocksize; + if (block_start >= n) { + return; + } + uint block_end; + if (block_start + blocksize < n) { + block_end = block_start + blocksize; + } else { + block_end = n; + } + uint pairs_in_block = (block_end - block_start + 1) >> 1; - float vals[NUM]; - uchar qvals[NUM]; + float scale = absmax[block_index]; - for (uint i = base_idx; i < n_full; i += ((n + NUM_BLOCK - 1) / NUM_BLOCK) * NUM_BLOCK) { - valid_items = n - i > NUM_BLOCK ? 
NUM_BLOCK : n - i; + // Precompute scaled table in registers - avoids threadgroup bank conflicts + // and constant memory is broadcast-optimized so initial loads are fast + float scaled_table[16]; + for (uint i = 0; i < 16; i++) { + scaled_table[i] = code_table[i] * scale; + } - threadgroup_barrier(mem_flags::mem_threadgroup); + for (uint pair = thread_idx; pair < pairs_in_block; pair += threadgroup_size) { + uint value_index0 = block_start + pair * 2; + if (value_index0 >= block_end) { + break; + } + + uint byte_index0 = value_index0 >> 1; + uchar byte_val0 = packed[byte_index0]; + // High nibble -> even index, low nibble -> odd index (matches Python ref) + uchar nibble0 = (byte_val0 >> 4) & 0xF; + uchar nibble1 = byte_val0 & 0xF; + float decoded0 = scaled_table[nibble0]; + float decoded1 = scaled_table[nibble1]; + // value_index0 is already the output index (block_start + pair*2) + output[value_index0] = scalar_t(decoded0); + + // Bounds check for odd-length blocks + if (value_index0 + 1 < block_end) { + output[value_index0 + 1] = scalar_t(decoded1); + } + } +} + +// template +// inline void dequantize_block( +// device const uchar* packed, +// device const float* absmax, +// device scalar_t* output, +// uint n, +// uint blocksize, +// uint block_index, +// uint thread_idx, +// uint threadgroup_size, +// constant float* code_table +// ) { +// const uint block_start = block_index * blocksize; +// if (block_start >= n) return; + +// const uint block_end = min(block_start + blocksize, n); +// const uint num_values = block_end - block_start; + +// const float scale = absmax[block_index]; + +// // Precompute scaled code table +// float scaled_table[16]; +// for (uint i = 0; i < 16; ++i) +// scaled_table[i] = code_table[i] * scale; + +// device const uchar* packed_ptr = packed + (block_start >> 1); +// device scalar_t* output_ptr = output + block_start; + +// // Each thread processes multiple *bytes* at a stride +// const uint bytes_in_block = (num_values + 1) >> 1; + +// for (uint byte_idx = thread_idx; byte_idx < bytes_in_block; byte_idx += threadgroup_size) { +// uchar byte_val = packed_ptr[byte_idx]; + +// // Decode upper and lower nibbles +// uchar upper_nib = (byte_val >> 4) & 0xF; +// uchar lower_nib = byte_val & 0xF; + +// // Compute global value index +// uint val_idx = byte_idx << 1; // byte_idx * 2 - for (uint j = 0; j < valid_items; j++) { - vals[j] = A[i + j]; +// // Write both values if in bounds +// if (val_idx < num_values) output_ptr[val_idx] = scalar_t(scaled_table[upper_nib]); +// if (val_idx + 1 < num_values) output_ptr[val_idx + 1] = scalar_t(scaled_table[lower_nib]); +// } +// } + +// template +// inline void dequantize_block( +// device const uchar* packed, +// device const float* absmax, +// device scalar_t* output, +// uint n, +// uint blocksize, +// uint block_index, +// uint thread_idx, +// uint threadgroup_size, +// constant float* code_table +// ) { +// const uint block_start = block_index * blocksize; +// if (block_start >= n) return; + +// const uint block_end = min(block_start + blocksize, n); +// const uint num_values = block_end - block_start; + +// const float scale = absmax[block_index]; + +// // Precompute scaled code table +// float scaled_table[16]; +// for (uint i = 0; i < 16; ++i) +// scaled_table[i] = code_table[i] * scale; + +// device const uchar* packed_ptr = packed + (block_start >> 1); +// device scalar_t* output_ptr = output + block_start; + +// // Each thread processes multiple uchar4 (4 bytes = 8 values) +// const uint num_bytes = (num_values + 
1) >> 1; // total bytes in block +// const uint num_blocks = (num_bytes + 3) >> 2; // number of uchar4 blocks + +// for (uint block_idx = thread_idx; block_idx < num_blocks; block_idx += threadgroup_size) { +// uint byte_offset = block_idx * 4; // starting byte in packed array +// uchar4 b = uchar4(0); // default zero + +// // Load safely (handle tail) +// if (byte_offset + 3 < num_bytes) { +// b = *((device uchar4*)(packed_ptr + byte_offset)); +// } else { +// // Tail case: read remaining bytes safely +// uchar temp[4] = {0, 0, 0, 0}; +// for (uint i = 0; i < num_bytes - byte_offset; ++i) { +// temp[i] = packed_ptr[byte_offset + i]; +// } +// b = uchar4(temp[0], temp[1], temp[2], temp[3]); +// } + +// // Decode 8 nibbles into 8 values +// uchar nibbles[8] = { +// uchar((b.x >> 4) & 0xF), uchar(b.x & 0xF), +// uchar((b.y >> 4) & 0xF), uchar(b.y & 0xF), +// uchar((b.z >> 4) & 0xF), uchar(b.z & 0xF), +// uchar((b.w >> 4) & 0xF), uchar(b.w & 0xF) +// }; + +// // Compute global value indices and write outputs +// uint val_idx = byte_offset << 1; // byte_offset * 2 +// for (uint i = 0; i < 8; ++i) { +// if (val_idx + i < num_values) +// output_ptr[val_idx + i] = scalar_t(scaled_table[nibbles[i]]); +// } +// } +// } + +// template +// inline void dequantize_block( +// device const uchar* packed, +// device const float* absmax, +// device scalar_t* output, +// uint n, +// uint blocksize, +// uint block_index, +// uint thread_idx, +// uint threadgroup_size, +// constant float* code_table +// ) { +// const uint block_start = block_index * blocksize; +// if (block_start >= n) return; + +// const uint block_end = min(block_start + blocksize, n); +// const uint num_values = block_end - block_start; + +// const float scale = absmax[block_index]; + +// // Precompute scaled code table +// float scaled_table[16]; +// for (uint i = 0; i < 16; ++i) +// scaled_table[i] = code_table[i] * scale; + +// device const uchar* packed_ptr = packed + (block_start >> 1); +// device scalar_t* output_ptr = output + block_start; + +// const uint num_bytes = (num_values + 1) >> 1; // total bytes in block +// const uint num_uchar4 = (num_bytes + 3) >> 2; // total uchar4 blocks + +// // Each thread handles one or two uchar4 blocks +// uint block_pos = thread_idx; +// if (block_pos >= num_uchar4) return; + +// // Compute byte offset +// uint byte_offset = block_pos * 4; +// uchar4 b = uchar4(0, 0, 0, 0); + +// // Safe load +// if (byte_offset + 3 < num_bytes) { +// b = *((device uchar4*)(packed_ptr + byte_offset)); +// } else { +// uchar temp[4] = {0, 0, 0, 0}; +// for (uint i = 0; i < num_bytes - byte_offset; ++i) +// temp[i] = packed_ptr[byte_offset + i]; +// b = uchar4(temp[0], temp[1], temp[2], temp[3]); +// } + +// // Decode 8 nibbles +// uchar nibbles[8] = { +// uchar((b.x >> 4) & 0xF), uchar(b.x & 0xF), +// uchar((b.y >> 4) & 0xF), uchar(b.y & 0xF), +// uchar((b.z >> 4) & 0xF), uchar(b.z & 0xF), +// uchar((b.w >> 4) & 0xF), uchar(b.w & 0xF) +// }; + +// // Compute global value index +// uint val_idx = byte_offset << 1; // byte_offset * 2 + +// // Fully unrolled writes (branch-free for main values) +// if (val_idx + 0 < num_values) output_ptr[val_idx + 0] = scalar_t(scaled_table[nibbles[0]]); +// if (val_idx + 1 < num_values) output_ptr[val_idx + 1] = scalar_t(scaled_table[nibbles[1]]); +// if (val_idx + 2 < num_values) output_ptr[val_idx + 2] = scalar_t(scaled_table[nibbles[2]]); +// if (val_idx + 3 < num_values) output_ptr[val_idx + 3] = scalar_t(scaled_table[nibbles[3]]); +// if (val_idx + 4 < num_values) 
output_ptr[val_idx + 4] = scalar_t(scaled_table[nibbles[4]]); +// if (val_idx + 5 < num_values) output_ptr[val_idx + 5] = scalar_t(scaled_table[nibbles[5]]); +// if (val_idx + 6 < num_values) output_ptr[val_idx + 6] = scalar_t(scaled_table[nibbles[6]]); +// if (val_idx + 7 < num_values) output_ptr[val_idx + 7] = scalar_t(scaled_table[nibbles[7]]); +// } + +} // namespace + +// Quantization kernels +kernel void quantize_4bit_fp16_fp4( + device const half* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]] +) { + if (tgid >= blocks || threadgroup_size > kQuantThreadsCapacity) { + return; } + threadgroup float shared_thread_max[kQuantThreadsCapacity]; + threadgroup float shared_scale; + quantize_block(input, absmax, packed, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_thread_max, shared_scale, simd_lane_id, simd_group_id); +} - for (uint j = 0; j < valid_items; j++) { - qvals[j] = quantize_scalar(0.0f, code, vals[j]); +kernel void quantize_4bit_fp16_nf4( + device const half* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]] +) { + if (tgid >= blocks || threadgroup_size > kQuantThreadsCapacity) { + return; } + threadgroup float shared_thread_max[kQuantThreadsCapacity]; + threadgroup float shared_scale; + quantize_block(input, absmax, packed, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_thread_max, shared_scale, simd_lane_id, simd_group_id); +} - threadgroup_barrier(mem_flags::mem_threadgroup); +kernel void quantize_4bit_fp32_fp4( + device const float* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]] +) { + if (tgid >= blocks || threadgroup_size > kQuantThreadsCapacity) { + return; + } + threadgroup float shared_thread_max[kQuantThreadsCapacity]; + threadgroup float shared_scale; + quantize_block(input, absmax, packed, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE, shared_thread_max, shared_scale, simd_lane_id, simd_group_id); +} - for (uint j = 0; j < valid_items; j++) { - out[i + j] = qvals[j]; +kernel void quantize_4bit_fp32_nf4( + device const float* input [[buffer(0)]], + device float* absmax [[buffer(1)]], + device uchar* packed [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid 
[[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]], + uint simd_lane_id [[thread_index_in_simdgroup]], + uint simd_group_id [[simdgroup_index_in_threadgroup]] +) { + if (tgid >= blocks || threadgroup_size > kQuantThreadsCapacity) { + return; } - } + threadgroup float shared_thread_max[kQuantThreadsCapacity]; + threadgroup float shared_scale; + quantize_block(input, absmax, packed, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE, shared_thread_max, shared_scale, simd_lane_id, simd_group_id); } + +// Dequantization kernels +kernel void dequantize_4bit_fp16_fp4( + device const uchar* packed [[buffer(0)]], + device const float* absmax [[buffer(1)]], + device half* output [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]] +) { + if (tgid >= blocks) { + return; + } + dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE); +} + +kernel void dequantize_4bit_fp16_nf4( + device const uchar* packed [[buffer(0)]], + device const float* absmax [[buffer(1)]], + device half* output [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]] +) { + if (tgid >= blocks) { + return; + } + dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE); +} + +kernel void dequantize_4bit_fp32_fp4( + device const uchar* packed [[buffer(0)]], + device const float* absmax [[buffer(1)]], + device float* output [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]] +) { + if (tgid >= blocks) { + return; + } + dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, FP4_CODE); +} + +kernel void dequantize_4bit_fp32_nf4( + device const uchar* packed [[buffer(0)]], + device const float* absmax [[buffer(1)]], + device float* output [[buffer(2)]], + constant uint& n [[buffer(3)]], + constant uint& blocksize [[buffer(4)]], + constant uint& blocks [[buffer(5)]], + uint tgid [[threadgroup_position_in_grid]], + uint tid [[thread_index_in_threadgroup]], + uint threadgroup_size [[threads_per_threadgroup]] +) { + // if (tgid >= blocks) { + // return; + // } + dequantize_block(packed, absmax, output, n, blocksize, tgid, tid, threadgroup_size, NF4_CODE); +} \ No newline at end of file diff --git a/csrc/mps_ops.mm b/csrc/mps_ops.mm index 85ed1b1e4..cfd9b5376 100644 --- a/csrc/mps_ops.mm +++ b/csrc/mps_ops.mm @@ -1,62 +1,244 @@ -#import +#import +#import -#define HLF_MAX 65504 -#define TH 1024 -#define NUM 4 -#define NUM_BLOCK 4096 +#include +#include +#include +#include +#include -static inline MPSGraph* get_graph() { - static MPSGraph* cur = nil; - if (!cur) { - cur = [[MPSGraph alloc] init]; +#include +#include + +namespace { + +typedef struct { + void* storage; + size_t byte_offset; + size_t nbytes; +} BNBMPSTensor; + +static constexpr NSUInteger kMaxThreadsPerThreadgroup = 512; + +static inline at::mps::MPSStream* get_default_stream() { + 
at::mps::MPSStream* stream = at::mps::getCurrentMPSStream(); + if (!stream) { + NSLog(@"bitsandbytes: PyTorch MPS stream is unavailable"); + abort(); } - return cur; + return stream; } static inline id get_device() { - NSError* error = nil; - static id device = nil; - if (!device) { - device = MTLCreateSystemDefaultDevice(); - } - if (!device) { - NSLog(@"Failed to get MPS device"); + return get_default_stream()->device(); +} + +static inline NSURL* metallib_url() { + Dl_info info; + if (dladdr(reinterpret_cast(&metallib_url), &info) == 0) { + NSLog(@"bitsandbytes: dladdr failed to resolve metallib path"); abort(); } - return device; + NSString* dylibPath = [NSString stringWithUTF8String:info.dli_fname]; + NSString* directory = [dylibPath stringByDeletingLastPathComponent]; + NSString* metallibPath = [directory stringByAppendingPathComponent:@"bitsandbytes.metallib"]; + return [NSURL fileURLWithPath:metallibPath]; } static inline id get_library() { - NSError* error = nil; static id library = nil; + static dispatch_once_t onceToken; + dispatch_once(&onceToken, ^{ + NSError* error = nil; + library = [get_device() newLibraryWithURL:metallib_url() error:&error]; if (!library) { - library = [get_device() newLibraryWithURL:[NSURL fileURLWithPath:@"bitsandbytes.metallib"] error:&error]; + NSLog(@"bitsandbytes: failed to load bitsandbytes.metallib (%@)", error); + abort(); + } + }); + return library; +} + +static inline id get_pipeline(NSString* functionName) { + static NSMutableDictionary>* cache = nil; + static dispatch_once_t onceToken; + dispatch_once(&onceToken, ^{ + cache = [[NSMutableDictionary alloc] init]; + }); + + @synchronized(cache) { + id pipeline = cache[functionName]; + if (pipeline) { + return pipeline; + } } - if (!library) { - NSLog(@"Failed to load bitsandbytes.metallib"); + + NSError* error = nil; + id function = [get_library() newFunctionWithName:functionName]; + if (!function) { + NSLog(@"bitsandbytes: missing Metal kernel %@", functionName); abort(); } - return library; + + id pipeline = [get_device() newComputePipelineStateWithFunction:function error:&error]; + [function release]; + + if (!pipeline) { + NSLog(@"bitsandbytes: failed to create pipeline for %@ (%@)", functionName, error); + abort(); + } + + @synchronized(cache) { + cache[functionName] = pipeline; + } + return pipeline; } -/*MPSGraphTensor* dequantize_mps(MPSGraphTensor* code, MPSGraphTensor* A, int n) -{ - id out = [get_graph() dequantizeTensor:(MPSGraphTensor*)A scaleTensor:(MPSGraphTensor*)code zeroPoint:0.0 -dataType:MPSDataTypeInt8 axis:0 name:@"out"]; return out; -}*/ - -// MPSGraph function for quantize -extern "C" MPSGraphTensor* quantize_mps(MPSGraph* graph, MPSGraphTensor* code, MPSGraphTensor* A, int n) { - id device = get_device(); - id library = get_library(); - static id kernel = nil; - if (!kernel) { - kernel = [library newFunctionWithName:@"quantize"]; - if (!kernel) { - NSLog(@"Failed to load bitsandbytes.metallib"); - abort(); - } +struct TensorView { + id buffer; + NSUInteger offset; +}; + +static inline TensorView make_tensor_view(const BNBMPSTensor& tensor, const char* label) { + TensorView view; + view.buffer = __builtin_bit_cast(id, tensor.storage); + view.offset = static_cast(tensor.byte_offset); + if (!view.buffer && tensor.nbytes > 0) { + NSLog(@"bitsandbytes: missing MTLBuffer for %s tensor (storage=%p, bytes=%zu)", label, tensor.storage, tensor.nbytes); + abort(); + } + return view; +} + +static inline void dispatch_quant_kernel( + NSString* name, + const BNBMPSTensor& input, + 
const BNBMPSTensor& absmax, + const BNBMPSTensor& out, + uint32_t blocksize, + uint32_t n +) { + if (n == 0) { + return; } - NSLog(@"Not implemented"); - return nil; + uint32_t blocks = (n + blocksize - 1) / blocksize; + TensorView inputView = make_tensor_view(input, "input"); + TensorView absmaxView = make_tensor_view(absmax, "absmax"); + TensorView outView = make_tensor_view(out, "out"); + + at::mps::MPSStream* stream = get_default_stream(); + // stream->endKernelCoalescing(); + id command_buffer_obj = stream->commandBuffer(); + + id pipeline_state_obj = (id) get_pipeline(name); + + id command_encoder_obj = stream->commandEncoder(); + + // Set kernel arguments + [command_encoder_obj setComputePipelineState:pipeline_state_obj]; + [command_encoder_obj setBuffer:inputView.buffer offset:inputView.offset atIndex:0]; + [command_encoder_obj setBuffer:absmaxView.buffer offset:absmaxView.offset atIndex:1]; + [command_encoder_obj setBuffer:outView.buffer offset:outView.offset atIndex:2]; + [command_encoder_obj setBytes:&n length:sizeof(uint32_t) atIndex:3]; + [command_encoder_obj setBytes:&blocksize length:sizeof(uint32_t) atIndex:4]; + [command_encoder_obj setBytes:&blocks length:sizeof(uint32_t) atIndex:5]; + NSUInteger threadsPerThreadgroup = pipeline_state_obj.threadExecutionWidth; + if (threadsPerThreadgroup == 0) { + threadsPerThreadgroup = 1; + } + MTLSize threads = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize grid = MTLSizeMake(blocks, 1, 1); + [command_encoder_obj dispatchThreads:grid threadsPerThreadgroup:threads]; + // [command_encoder_obj endEncoding]; + stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE); +} + +static inline void dispatch_dequant_kernel( + NSString* name, + const BNBMPSTensor& packed, + const BNBMPSTensor& absmax, + const BNBMPSTensor& output, + uint32_t blocksize, + uint32_t n +) { + // NSLog(@"bitsandbytes: dispatching dequant kernel %@ with blocksize=%d, n=%d", name, blocksize, n); + if (n == 0) { + return; + } + uint32_t blocks = (n + blocksize - 1) / blocksize; + TensorView inputView = make_tensor_view(packed, "packed"); + TensorView absmaxView = make_tensor_view(absmax, "absmax"); + TensorView outView = make_tensor_view(output, "output"); + + at::mps::MPSStream* stream = get_default_stream(); + // stream->endKernelCoalescing(); + id command_buffer_obj = stream->commandBuffer(); + + id pipeline_state_obj = (id) get_pipeline(name); + + id command_encoder_obj = stream->commandEncoder(); + + // Set kernel arguments + [command_encoder_obj setComputePipelineState:pipeline_state_obj]; + [command_encoder_obj setBuffer:inputView.buffer offset:inputView.offset atIndex:0]; + [command_encoder_obj setBuffer:absmaxView.buffer offset:absmaxView.offset atIndex:1]; + [command_encoder_obj setBuffer:outView.buffer offset:outView.offset atIndex:2]; + [command_encoder_obj setBytes:&n length:sizeof(uint32_t) atIndex:3]; + [command_encoder_obj setBytes:&blocksize length:sizeof(uint32_t) atIndex:4]; + [command_encoder_obj setBytes:&blocks length:sizeof(uint32_t) atIndex:5]; + NSUInteger maxThreadsPerTG = pipeline_state_obj.maxTotalThreadsPerThreadgroup; + NSUInteger desiredThreads = (blocksize + 1) / 2; + if (desiredThreads == 0) { + desiredThreads = 1; + } + NSUInteger threadsPerThreadgroup = + std::min(maxThreadsPerTG, std::max(1, desiredThreads)); + if (threadsPerThreadgroup < pipeline_state_obj.threadExecutionWidth) { + threadsPerThreadgroup = std::min(pipeline_state_obj.threadExecutionWidth, maxThreadsPerTG); + } + + NSUInteger totalThreads = threadsPerThreadgroup * 
blocks; + MTLSize threads = MTLSizeMake(threadsPerThreadgroup, 1, 1); + MTLSize grid = MTLSizeMake(totalThreads, 1, 1); + [command_encoder_obj dispatchThreads:grid threadsPerThreadgroup:threads]; + // [command_encoder_obj endEncoding]; + stream->synchronize(at::mps::SyncType::COMMIT_AND_CONTINUE); +} + +} // namespace + +extern "C" { + +void cquantize_blockwise_fp16_fp4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_quant_kernel(@"quantize_4bit_fp16_fp4", input, absmax, out, blocksize, n); } + +void cquantize_blockwise_fp16_nf4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_quant_kernel(@"quantize_4bit_fp16_nf4", input, absmax, out, blocksize, n); +} + +void cquantize_blockwise_fp32_fp4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_quant_kernel(@"quantize_4bit_fp32_fp4", input, absmax, out, blocksize, n); +} + +void cquantize_blockwise_fp32_nf4(BNBMPSTensor input, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_quant_kernel(@"quantize_4bit_fp32_nf4", input, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_fp16_fp4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_dequant_kernel(@"dequantize_4bit_fp16_fp4", packed, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_fp16_nf4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_dequant_kernel(@"dequantize_4bit_fp16_nf4", packed, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_fp32_fp4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_dequant_kernel(@"dequantize_4bit_fp32_fp4", packed, absmax, out, blocksize, n); +} + +void cdequantize_blockwise_fp32_nf4(BNBMPSTensor packed, BNBMPSTensor absmax, BNBMPSTensor out, int blocksize, const int n) { + dispatch_dequant_kernel(@"dequantize_4bit_fp32_nf4", packed, absmax, out, blocksize, n); +} + +} // extern "C" \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 65f9314c5..8d35ecda3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,6 +74,7 @@ test = [ [tool.setuptools] package-data = { "*" = ["libbitsandbytes*.*", "py.typed"] } +# package-data = { "*" = ["libbitsandbytes*.*", "bitsandbytes.metallib", "py.typed"] } [tool.setuptools.packages.find] include = ["bitsandbytes*"] diff --git a/script.sh b/script.sh new file mode 100755 index 000000000..dec3e8f6f --- /dev/null +++ b/script.sh @@ -0,0 +1,4 @@ +#!/bin/bash + +PYTHON_PATH=/Users/medmekk/miniforge3/envs/gpt/bin/python +$PYTHON_PATH ./test_bnb_mac.py \ No newline at end of file diff --git a/test_bnb_mac.py b/test_bnb_mac.py new file mode 100644 index 000000000..038e3aea7 --- /dev/null +++ b/test_bnb_mac.py @@ -0,0 +1,188 @@ +# from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig +# import torch +# tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1b-Instruct") +# quantization_config = BitsAndBytesConfig(load_in_4bit=True) +# model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1b-Instruct", device_map="mps", quantization_config=quantization_config, dtype=torch.float16) +# print("model.device:", model.device) +# prompt = "Hello, how are you?" 
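+# (The Llama checkpoint above is gated on the Hub; any small causal LM works for this smoke test.)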
+# inputs = tokenizer(prompt, return_tensors="pt").to(model.device) +# outputs = model.generate(**inputs, max_new_tokens=20) +# print(tokenizer.decode(outputs[0], skip_special_tokens=True)) # or whatever entry function you have + +import torch +import bitsandbytes as bnb +from torch.profiler import profile, ProfilerActivity +from torch.mps.profiler import metal_capture + +_NF4_QUANT_TABLE = torch.tensor( + [ + -1.0, + -0.6961928009986877, + -0.5250730514526367, + -0.39491748809814453, + -0.28444138169288635, + -0.18477343022823334, + -0.09105003625154495, + 0.0, + 0.07958029955625534, + 0.16093020141124725, + 0.24611230194568634, + 0.33791524171829224, + 0.44070982933044434, + 0.5626170039176941, + 0.7229568362236023, + 1.0, + ], + dtype=torch.float32, + device="xpu" + if hasattr(torch, "xpu") and torch.xpu.is_available() + else "cpu", # Only cpu/xpu use this table for now. +) +_FP4_QUANT_TABLE = torch.tensor( + [ + 0.0000, + 0.0052, + 0.6667, + 1.0000, + 0.3333, + 0.5000, + 0.1667, + 0.2500, + 0.0000, + -0.0052, + -0.6667, + -1.0000, + -0.3333, + -0.5000, + -0.1667, + -0.2500, + ], + dtype=torch.float32, + device="xpu" + if hasattr(torch, "xpu") and torch.xpu.is_available() + else "cpu", # Only cpu/xpu use this table for now. +) +CODE = {"nf4": _NF4_QUANT_TABLE, "fp4": _FP4_QUANT_TABLE} + +def _dequantize_4bit_impl( + A: torch.Tensor, + absmax: torch.Tensor, + blocksize: int, + quant_type: str, + shape: "Sequence[int]", + dtype: torch.dtype, +) -> torch.Tensor: + # Enable non uint8 dtype + if A.dtype != torch.uint8: + A = A.view(torch.uint8) + A = A.reshape(-1) + # Map nf4 to [-1, 1] + out_dq = torch.empty(A.size(0) * 2, dtype=torch.int32, device=A.device) + n = out_dq.numel() + out_dq[1::2] = A & 0xF + out_dq[::2] = A >> 4 + # code is fp32, cast to dtype to avoid the mismatch issue + code = CODE[quant_type].to(dtype).to(A.device) + out_dq = code[out_dq] + + # Apply scales + if out_dq.numel() != n: + assert out_dq.numel() == n + 1 + out_dq = torch.narrow(out_dq, 0, 0, n) + blocks = n // blocksize + blocks += 1 if n % blocksize > 0 else 0 + rem = n % blocksize + has_rem = rem > 0 + + out = torch.empty(shape, dtype=dtype, device=A.device).reshape(-1) + if has_rem: + out[: n - rem] = (out_dq[: n - rem].view(-1, blocksize) * absmax[: blocks - has_rem].view(-1, 1)).reshape(-1) + out[n - rem :] = out_dq[n - rem :] * absmax[-1] + else: + out = out_dq.view(-1, blocksize) * absmax.view(-1, 1) + + out = out.reshape(-1, *shape[1:]).to(dtype) + + return out +def run_once(): + # A = torch.randn(2048, device="mps", dtype=torch.float16) + # q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8) + out = torch.empty(2048*2, device="mps", dtype=torch.float32) + q = torch.randint(0, 255, (2048,), device="mps", dtype=torch.uint8) + absmax = torch.randn(64, device="mps", dtype=torch.float32) + print("q.shape:", q.shape, q.dtype) + print("absmax.shape:", absmax.shape, absmax.dtype) + B = torch.ops.bitsandbytes.dequantize_4bit(q, absmax, 64, "nf4", out.shape, out.dtype) + # B_ref = _dequantize_4bit_impl(q, absmax, 64, "nf4", out.shape, out.dtype) + # print("ok", float((B - B_ref).abs().max())) + # torch.mps.synchronize() + # print("B.shape:", B.shape, B.dtype) + # print("ok", float((A - B).abs().max())) + +run_once() +trace_path = "bnb_mps_capture_11.gputrace" + +with metal_capture(trace_path): + with profile( + activities=[], + record_shapes=True, + with_stack=True, + ) as prof: + for i in range(10): + run_once() + torch.mps.synchronize() + print(f"iteration {i} done") + 
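+# The resulting .gputrace bundle can be opened in Xcode's Metal debugger to
+# inspect the dispatched bitsandbytes kernels alongside the profiler table below.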
+print(prof.key_averages().table(sort_by="self_cpu_time_total")) +print(f"Metal capture saved to: {trace_path}") + +# import torch, bitsandbytes as bnb + +# torch.manual_seed(0) +# A = torch.randn(256, device="mps", dtype=torch.float16) + +# q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8) +# B_native = torch.ops.bitsandbytes.dequantize_4bit(q, absmax, 64, "nf4", A.shape, A.dtype) + +# # CPU reference (uses the default implementation, then move back to MPS) +# B_ref = torch.ops.bitsandbytes.dequantize_4bit.default( +# q.cpu(), absmax.cpu(), 64, "nf4", A.shape, A.dtype +# ).to("mps") + +# print("A[:8] ", A[:8].cpu()) +# print("B_native[:8]", B_native[:8].cpu()) +# print("B_ref[:8] ", B_ref[:8].cpu()) +# print("max |A-B_native|:", float((A - B_native).abs().max())) +# print("max |A-B_ref| :", float((A - B_ref).abs().max())) + +# diff = (B_native - B_ref).cpu() +# print("B_native shape:", B_native.shape) +# print("B_ref shape:", B_ref.shape) +# print("max |B_native - B_ref|:", float(diff.abs().max())) +# print("first 16 diffs:", diff[:16]) + +# q_cpu, absmax_cpu = torch.ops.bitsandbytes.quantize_4bit.default( +# A.cpu(), 64, "nf4", torch.uint8 +# ) + +# print("q identical? ", torch.equal(q.cpu(), q_cpu)) +# print("absmax max diff:", float((absmax.cpu() - absmax_cpu).abs().max())) +# print("q_mps[:8]:", q.view(-1)[:8].cpu()) +# print("q_cpu[:8]:", q_cpu.view(-1)[:8]) +# print("absmax_mps[:4]:", absmax[:4].cpu()) +# print("absmax_cpu[:4]:", absmax_cpu[:4]) + +# import torch, bitsandbytes as bnb, time + +# torch.manual_seed(0) +# A = torch.randn(4096 * 4096, device="mps", dtype=torch.float16) +# blocksize = 64 + +# q, absmax = torch.ops.bitsandbytes.quantize_4bit(A, blocksize, "nf4", torch.uint8) + +# torch.mps.synchronize() +# t0 = time.perf_counter() +# torch.ops.bitsandbytes.dequantize_4bit(q, absmax, blocksize, "nf4", A.shape, A.dtype) +# torch.mps.synchronize() +# dt = time.perf_counter() - t0 +# print(f"Dequant time: {dt*1000:.2f} ms for {A.numel()/1e6:.1f}M elements") \ No newline at end of file
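
For reviewers who want to exercise the new MPS path end to end, below is a minimal round-trip sketch modelled on the commented-out checks in test_bnb_mac.py. It assumes an editable install on Apple Silicon with the Metal artifacts built (libbitsandbytes_mps plus bitsandbytes.metallib next to the package):

import torch
import bitsandbytes  # noqa: F401 -- importing registers the custom ops, including the MPS backend

torch.manual_seed(0)
A = torch.randn(4096, device="mps", dtype=torch.float16)

# Quantize (currently routed through the shared default implementation) and
# dequantize through the MPS kernel registration added in this PR.
packed, absmax = torch.ops.bitsandbytes.quantize_4bit(A, 64, "nf4", torch.uint8)
B = torch.ops.bitsandbytes.dequantize_4bit(packed, absmax, 64, "nf4", A.shape, A.dtype)

torch.mps.synchronize()
print("packed:", tuple(packed.shape), packed.dtype, "absmax:", tuple(absmax.shape))
print("max |A - B|:", float((A - B).abs().max()))  # nf4 is lossy; expect a small but nonzero error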