diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 000000000..8e5903655
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1 @@
+open_collective: bitsandbytes
diff --git a/.github/scripts/build-cuda.sh b/.github/scripts/build-cuda.sh
index 8985327f2..b13d9c92b 100644
--- a/.github/scripts/build-cuda.sh
+++ b/.github/scripts/build-cuda.sh
@@ -11,14 +11,14 @@ if [[ -v cuda_targets ]]; then
elif [ "${build_arch}" = "aarch64" ]; then
build_capability="75;80;90"
- # CUDA 12.8: Add sm100
- [[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;90;100"
+ # CUDA 12.8+: Add sm100/sm120
+ [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="75;80;90;100;120"
else
- # By default, target Maxwell through Hopper.
- build_capability="50;52;60;61;70;75;80;86;89;90"
+ # By default, target Pascal through Hopper.
+ build_capability="60;70;75;80;86;89;90"
- # CUDA 12.8: Add sm100 and sm120; remove < sm75 to align with PyTorch 2.7+cu128 minimum
- [[ "${cuda_version}" == 12.8.* ]] && build_capability="75;80;86;89;90;100;120"
+ # CUDA 12.8+: Add sm100 and sm120; remove < sm70 to align with PyTorch 2.8+cu128 minimum
+ [[ "${cuda_version}" == 12.8.* || "${cuda_version}" == 12.9.* ]] && build_capability="70;75;80;86;89;90;100;120"
fi
[[ "${build_os}" = windows-* ]] && python3 -m pip install ninja
diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml
index 827c2ffbf..a11b13f33 100644
--- a/.github/workflows/python-package.yml
+++ b/.github/workflows/python-package.yml
@@ -72,16 +72,17 @@ jobs:
- os: windows-latest
arch: x86_64
cuda_version:
- ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1"]
+ ["11.8.0", "12.0.1", "12.1.1", "12.2.2", "12.3.2", "12.4.1", "12.5.1", "12.6.3", "12.8.1", "12.9.1"]
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
# Windows: We install Cuda on the agent (slow)
- - uses: Jimver/cuda-toolkit@v0.2.22
+ - uses: Jimver/cuda-toolkit@c35baa1a18fd1fc9dcf47c5bd839bf30559c0bc3 # v0.2.24
if: startsWith(matrix.os, 'windows')
id: cuda-toolkit
with:
- cuda: ${{ matrix.cuda_version }}
+ # Temporary: Use CUDA 12.9.0 for Windows until 12.9.1 is supported with this action.
+ cuda: ${{ matrix.cuda_version == '12.9.1' && '12.9.0' || matrix.cuda_version }}
method: "network"
sub-packages: '["nvcc","cudart","cusparse","cublas","thrust","nvrtc_dev","cublas_dev","cusparse_dev"]'
linux-local-args: '["--toolkit"]'
diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml
index 0d3884593..997da52bd 100644
--- a/.github/workflows/tests.yml
+++ b/.github/workflows/tests.yml
@@ -49,8 +49,8 @@ jobs:
build-cuda:
strategy:
matrix:
- cuda_version: ["11.8.0", "12.6.3", "12.8.1"]
- os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025]
+ cuda_version: ["11.8.0", "12.6.3", "12.8.1", "12.9.1"]
+ os: [ubuntu-22.04, ubuntu-22.04-arm]
include:
- os: ubuntu-22.04
arch: x86_64
@@ -58,13 +58,14 @@ jobs:
arch: aarch64
- os: windows-2025
arch: x86_64
+ cuda_version: "11.8.0"
runs-on: ${{ matrix.os }}
steps:
- uses: actions/checkout@v4
- name: Install CUDA Toolkit
- uses: Jimver/cuda-toolkit@v0.2.23
+ uses: Jimver/cuda-toolkit@c35baa1a18fd1fc9dcf47c5bd839bf30559c0bc3 # v0.2.24
if: startsWith(matrix.os, 'windows')
id: cuda-toolkit
with:
@@ -100,8 +101,8 @@ jobs:
fail-fast: false
matrix:
os: [ubuntu-22.04, ubuntu-22.04-arm, windows-2025, macos-15]
- # Test with the oldest supported torch version and the two newest.
- torch_version: ["2.2.2", "2.6.0", "2.7.1"]
+ # Test with the oldest supported torch version, the newest two stable/RC.
+ torch_version: ["2.3.1", "2.7.1", "2.8.0"]
include:
- os: ubuntu-22.04
arch: x86_64
@@ -117,7 +118,7 @@ jobs:
arch: arm64
exclude:
- os: ubuntu-22.04-arm
- torch_version: "2.2.2"
+ torch_version: "2.3.1"
runs-on: ${{ matrix.runner || matrix.os }}
env:
@@ -147,9 +148,10 @@ jobs:
pip install -e ".[test]"
pip install pytest-cov
- # We need to downgrade to numpy<2 for torch<2.3 compatibility.
+ # We need to downgrade to numpy<2 for torch<2.4.1 compatibility on Windows
+ # See: https://github.com/pytorch/pytorch/issues/131668
- name: Downgrade NumPy
- if: startsWith(matrix.torch_version, '2.2.')
+ if: startsWith(matrix.os, 'windows') && startsWith(matrix.torch_version, '2.3.')
run: pip install "numpy<2"
- name: Show installed packages
@@ -161,7 +163,7 @@ jobs:
- name: Run tests
run: pytest --durations=100
- test-cpu-ipex:
+ test-cpu-intel:
if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cpu
runs-on: banb-aws-general-8-plus-use1-public-80
@@ -185,7 +187,6 @@ jobs:
- name: Install dependencies
run: |
pip install torch==2.7.1 --index-url https://download.pytorch.org/whl/cpu
- pip install intel_extension_for_pytorch==2.7.0 --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/cpu/us/
pip install -e ".[test]"
pip install pytest-cov
@@ -195,9 +196,6 @@ jobs:
- name: Show environment information
run: python -m torch.utils.collect_env
- - name: IPEX smoke test
- run: python -c "import torch; import intel_extension_for_pytorch as ipex; print(torch.__version__); print(ipex.__version__);"
-
- name: Run tests
run: pytest --durations=100
@@ -223,7 +221,7 @@ jobs:
# run: pip list
test-hpu:
- if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
+ if: false # github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cpu
strategy:
fail-fast: false
@@ -279,21 +277,12 @@ jobs:
run: pytest --durations=100
test-xpu:
- if: github.repository == 'bitsandbytes-foundation/bitsandbytes'
+ if: false # github.repository == 'bitsandbytes-foundation/bitsandbytes'
needs: build-cpu
strategy:
fail-fast: false
matrix:
torch_version: ["2.7.1"] #["2.6.0", "2.7.1"]
- ipex: [false]
- # ipex: [true, false]
- # include:
- # - torch_version: "2.6.0"
- # ipex: true
- # ipex_version: "2.6.10+xpu"
- # - torch_version: "2.7.1"
- # ipex: true
- # ipex_version: "2.7.10+xpu"
runs-on:
group: bandb-itac-bmsprpvc1550-8-1gpu
env:
@@ -329,10 +318,6 @@ jobs:
- name: Install PyTorch
run: pip install torch==${{ matrix.torch_version }} --index-url https://download.pytorch.org/whl/xpu
- - name: Install IPEX
- if: matrix.ipex == true
- run: pip install intel_extension_for_pytorch==${{ matrix.ipex_version }} --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/
-
- name: Install dependencies
run: |
pip install -e ".[test]"
@@ -358,10 +343,10 @@ jobs:
os: [ubuntu-22.04, windows-2025]
arch: [x86_64]
gpu: [T4, L40S]
- cuda_version: ["11.8.0", "12.6.3", "12.8.1"]
+ cuda_version: ["11.8.0", "12.6.3", "12.8.1", "12.9.1"]
include:
- cuda_version: "11.8.0"
- torch_version: "2.2.2"
+ torch_version: "2.3.1"
pypi_index: "https://download.pytorch.org/whl/cu118"
- cuda_version: "12.6.3"
torch_version: "2.6.0"
@@ -369,6 +354,9 @@ jobs:
- cuda_version: "12.8.1"
torch_version: "2.7.1"
pypi_index: "https://download.pytorch.org/whl/cu128"
+ - cuda_version: "12.9.1"
+ torch_version: "2.8.0"
+ pypi_index: "https://download.pytorch.org/whl/cu129"
# Linux L40S runners
@@ -387,7 +375,7 @@ jobs:
gpu: T4
runner: CUDA-Windows-x64
cuda_version: "11.8.0"
- torch_version: "2.2.0"
+ torch_version: "2.3.1"
pypi_index: "https://download.pytorch.org/whl/cu118"
- os: windows-2025
arch: x86_64
@@ -401,12 +389,14 @@ jobs:
gpu: T4
runner: CUDA-Windows-x64
cuda_version: "11.8.0"
- torch_version: "2.7.1"
+ torch_version: "2.7.1" # Note: this is the last PyTorch release supporting CUDA 11.8.
pypi_index: "https://download.pytorch.org/whl/cu118"
exclude:
# Our current T4 Windows runner has a driver too old (471.11)
# and cannot support CUDA 12+. Skip for now.
+ - os: windows-2025
+ cuda_version: "12.9.1"
- os: windows-2025
cuda_version: "12.8.1"
- os: windows-2025
@@ -438,15 +428,9 @@ jobs:
- name: Install dependencies
run: |
- pip install torch==${{ matrix.torch_version }} --index-url ${{ matrix.pypi_index }}
+ pip install --pre torch~=${{ matrix.torch_version }}.dev0 --index-url ${{ matrix.pypi_index }}
pip install -e ".[test]"
pip install pytest-cov
-
- # We need to downgrade to numpy<2 for torch<2.3 compatibility.
- - name: Downgrade NumPy
- if: startsWith(matrix.torch_version, '2.2.')
- run: pip install "numpy<2"
-
- name: Show installed packages
run: pip list
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 770b4ba30..d9529b0d7 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -28,11 +28,12 @@ set(CUDA_FILES csrc/ops.cu csrc/kernels.cu)
set(HIP_FILES csrc/ops.hip csrc/kernels.hip)
set(MPS_FILES csrc/mps_ops.mm)
set(METAL_FILES csrc/mps_kernels.metal)
+set(XPU_FILES csrc/xpu_ops.cpp csrc/xpu_kernels.cpp)
# C++ sources are always included
list(APPEND SRC_FILES ${CPP_FILES})
-set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps)")
-set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps)
+set(COMPUTE_BACKEND "cpu" CACHE STRING "The compute backend to use (cpu, cuda, hip, mps, xpu)")
+set_property(CACHE COMPUTE_BACKEND PROPERTY STRINGS cpu cuda hip mps xpu)
option(PTXAS_VERBOSE "Pass through -v flag to PTX Assembler" OFF)
if(APPLE)
@@ -64,10 +65,19 @@ elseif(${COMPUTE_BACKEND} STREQUAL "mps")
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS ON)
+elseif(${COMPUTE_BACKEND} STREQUAL "xpu")
+ if(APPLE)
+ message(FATAL_ERROR "XPU is not supported on macOS" )
+ endif()
+ set(BUILD_CUDA OFF)
+ set(BUILD_HIP OFF)
+ set(BUILD_MPS OFF)
+ set(BUILD_XPU ON)
else()
set(BUILD_CUDA OFF)
set(BUILD_HIP OFF)
set(BUILD_MPS OFF)
+ set(BUILD_XPU OFF)
endif()
@@ -217,6 +227,15 @@ elseif(BUILD_MPS)
COMMENT "Compiling Metal kernels"
VERBATIM)
add_custom_target(metallib DEPENDS "bitsandbytes/bitsandbytes.metallib")
+elseif(BUILD_XPU)
+ list(APPEND SRC_FILES ${XPU_FILES})
+ string(APPEND BNB_OUTPUT_NAME "_xpu")
+ add_compile_definitions(BUILD_XPU)
+ set(CMAKE_C_COMPILER icx)
+ set(CMAKE_CXX_COMPILER icpx)
+ if(WIN32)
+ set(CMAKE_CXX_COMPILER icx)
+ endif()
else()
string(APPEND BNB_OUTPUT_NAME "_cpu")
set(GPU_SOURCES)
@@ -285,6 +304,15 @@ if(BUILD_MPS)
add_dependencies(bitsandbytes metallib)
target_link_libraries(bitsandbytes objc "-framework Foundation" "-framework Metal" "-framework MetalPerformanceShaders" "-framework MetalPerformanceShadersGraph")
endif()
+if(BUILD_XPU)
+ set(SYCL_LINK_FLAGS "-fsycl;--offload-compress;-fsycl-targets=spir64_gen,spir64;-Xs;-device pvc,xe-lpg,ats-m150 -options ' -cl-intel-enable-auto-large-GRF-mode -cl-poison-unsupported-fp64-kernels -cl-intel-greater-than-4GB-buffer-required'")
+ set(SYCL_COMPILE_FLAGS "-fsycl;-fhonor-nans;-fhonor-infinities;-fno-associative-math;-fno-approx-func;-fno-sycl-instrument-device-code;--offload-compress;-fsycl-targets=spir64_gen,spir64;")
+
+ set_property(TARGET bitsandbytes PROPERTY CXX_STANDARD 20)
+ target_compile_options(bitsandbytes PRIVATE ${SYCL_COMPILE_FLAGS})
+ target_link_options(bitsandbytes PRIVATE ${SYCL_LINK_FLAGS})
+
+endif()
if(WIN32)
set_target_properties(bitsandbytes PROPERTIES PREFIX "lib")
diff --git a/MANIFEST.in b/MANIFEST.in
new file mode 100644
index 000000000..00bdaa214
--- /dev/null
+++ b/MANIFEST.in
@@ -0,0 +1,3 @@
+include CMakeLists.txt
+graft csrc
+graft include
diff --git a/README.md b/README.md
index c6c5ff25b..732baea69 100644
--- a/README.md
+++ b/README.md
@@ -20,13 +20,13 @@ The library includes quantization primitives for 8-bit & 4-bit operations, throu
bitsandbytes has the following minimum requirements for all platforms:
* Python 3.9+
-* [PyTorch](https://pytorch.org/get-started/locally/) 2.2+
+* [PyTorch](https://pytorch.org/get-started/locally/) 2.3+
* _Note: While we aim to provide wide backwards compatibility, we recommend using the latest version of PyTorch for the best experience._
#### Accelerator support:
Note: this table reflects the status of the current development branch. For the latest stable release, see the
-[document in the v0.46.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.46.0/README.md#accelerator-support).
+[document in the 0.47.0 tag](https://github.com/bitsandbytes-foundation/bitsandbytes/blob/0.47.0/README.md#accelerator-support).
##### Legend:
@@ -61,7 +61,7 @@ bitsandbytes has the following minimum requirements for all platforms:
|
🟩 NVIDIA GPU
cuda |
- SM50+ minimum SM75+ recommended |
+ SM60+ minimum SM75+ recommended |
✅ |
✅ |
✅ |
@@ -71,11 +71,11 @@ bitsandbytes has the following minimum requirements for all platforms:
🟥 AMD GPU
cuda |
CDNA: gfx90a, gfx942
- RDNA: gfx1100, gfx1200
+ RDNA: gfx1100
|
- 🚧 |
- 🚧 |
- 🚧 |
+ ✅ |
+ 〰️ |
+ ✅ |
|
@@ -85,14 +85,14 @@ bitsandbytes has the following minimum requirements for all platforms:
Arc A-Series (Alchemist)
Arc B-Series (Battlemage)
- 🚧 |
- 🚧 |
- 🚧 |
+ ✅ |
+ ✅ |
+ 〰️ |
|
🟪 Intel Gaudi
hpu |
- Gaudi1, Gaudi2, Gaudi3 |
+ Gaudi2, Gaudi3 |
✅ |
〰️ |
❌ |
@@ -108,7 +108,7 @@ bitsandbytes has the following minimum requirements for all platforms:
|
🟩 NVIDIA GPU
cuda |
- SM75, SM80, SM90, SM100 |
+ SM75+ |
✅ |
✅ |
✅ |
@@ -127,7 +127,7 @@ bitsandbytes has the following minimum requirements for all platforms:
|
🟩 NVIDIA GPU
cuda |
- SM50+ minimum SM75+ recommended |
+ SM60+ minimum SM75+ recommended |
✅ |
✅ |
✅ |
@@ -139,9 +139,9 @@ bitsandbytes has the following minimum requirements for all platforms:
Arc A-Series (Alchemist)
Arc B-Series (Battlemage)
- 🚧 |
- 🚧 |
- 🚧 |
+ ✅ |
+ ✅ |
+ 〰️ |
| 🍎 macOS 14+ |
@@ -173,7 +173,9 @@ bitsandbytes has the following minimum requirements for all platforms:
## :heart: Sponsors
The continued maintenance and development of `bitsandbytes` is made possible thanks to the generous support of our sponsors. Their contributions help ensure that we can keep improving the project and delivering valuable updates to the community.
-
+
+
+
## License
`bitsandbytes` is MIT licensed.
diff --git a/_typos.toml b/_typos.toml
index 955c6cb79..fce018f81 100644
--- a/_typos.toml
+++ b/_typos.toml
@@ -1,4 +1,11 @@
[files]
+# Skip these files in typo checks
+extend-exclude = [
+ "csrc/xpu_ops.h",
+ "csrc/xpu_ops.cpp",
+ "csrc/xpu_kernels.h",
+ "csrc/xpu_kernels.cpp"
+]
[default]
extend-ignore-re = [
diff --git a/benchmarking/inference_benchmark.py b/benchmarking/inference_benchmark.py
index 61ac570f2..72ee8cfae 100644
--- a/benchmarking/inference_benchmark.py
+++ b/benchmarking/inference_benchmark.py
@@ -21,6 +21,9 @@
--batches BATCHES [BATCHES ...]
--input-length INPUT_LENGTH
--out-dir OUT_DIR
+ --iterations ITERATIONS
+ --warmup-runs WARMUP_RUNS
+ --output-length OUTPUT_LENGTH
"""
import argparse
@@ -30,6 +33,9 @@
from optimum_benchmark.logging_utils import setup_logging
import torch
+torch.backends.cudnn.benchmark = False
+torch.backends.cudnn.deterministic = True
+
BFLOAT16_SUPPORT = torch.cuda.get_device_capability()[0] >= 8
WEIGHTS_CONFIGS = {
@@ -73,9 +79,8 @@
},
}
-if __name__ == "__main__":
- setup_logging(level="INFO")
+def parse_args():
parser = argparse.ArgumentParser(description="bitsandbytes inference benchmark tool")
parser.add_argument("model_id", type=str, help="The model checkpoint to use.")
@@ -98,37 +103,73 @@
parser.add_argument("--out-dir", type=str, default="reports")
- args = parser.parse_args()
+ parser.add_argument("--iterations", type=int, default=10, help="Number of iterations for each benchmark run")
+ parser.add_argument(
+ "--warmup-runs", type=int, default=10, help="Number of warmup runs to discard before measurement"
+ )
+ parser.add_argument(
+ "--output-length",
+ type=int,
+ default=64,
+ help="If set, `max_new_tokens` and `min_new_tokens` will be set to this value.",
+ )
+
+ return parser.parse_args()
+
+
+def run_benchmark(args, config, batch_size):
+ launcher_config = ProcessConfig(device_isolation=True, device_isolation_action="warn", start_method="spawn")
+ scenario_config = InferenceConfig(
+ latency=True,
+ memory=True,
+ input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
+ iterations=args.iterations,
+ warmup_runs=args.warmup_runs,
+ # set duration to 0 to disable the duration-based stopping criterion
+ # this is IMPORTANT to ensure that all benchmarks run the same number of operations, regardless of hardware speed/bottlenecks
+ duration=0,
+ # for consistent results, set a fixed min and max for output tokens
+ generate_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
+ forward_kwargs={"min_new_tokens": args.output_length, "max_new_tokens": args.output_length},
+ )
+
+ backend_config = PyTorchConfig(
+ device="cuda",
+ device_ids="0",
+ device_map="auto",
+ no_weights=False,
+ model=args.model_id,
+ **WEIGHTS_CONFIGS[config],
+ )
+
+ test_name = (
+ f"benchmark-{config}"
+ f"-bsz-{batch_size}"
+ f"-isz-{args.input_length}"
+ f"-osz-{args.output_length}"
+ f"-iter-{args.iterations}"
+ f"-wrmup-{args.warmup_runs}"
+ )
+ benchmark_config = BenchmarkConfig(
+ name=test_name,
+ scenario=scenario_config,
+ launcher=launcher_config,
+ backend=backend_config,
+ )
+
+ out_path = out_dir / (test_name + ".json")
+ print(f"[{test_name}] Starting:")
+ benchmark_report = Benchmark.launch(benchmark_config)
+ benchmark_report.save_json(out_path)
+
+
+if __name__ == "__main__":
+ setup_logging(level="INFO")
+ args = parse_args()
out_dir = Path(args.out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
for batch_size in args.batches:
- print(f"Benchmarking batch size: {batch_size}")
for config in args.configs:
- launcher_config = ProcessConfig(device_isolation=True, start_method="spawn")
- scenario_config = InferenceConfig(
- latency=True,
- memory=True,
- input_shapes={"batch_size": batch_size, "sequence_length": args.input_length},
- )
- backend_config = PyTorchConfig(
- device="cuda",
- device_ids="0",
- device_map="auto",
- no_weights=False,
- model=args.model_id,
- **WEIGHTS_CONFIGS[config],
- )
- benchmark_config = BenchmarkConfig(
- name=f"benchmark-{config}-bsz{batch_size}",
- scenario=scenario_config,
- launcher=launcher_config,
- backend=backend_config,
- )
-
- out_path = out_dir / f"benchmark_{config}_bsz{batch_size}.json"
-
- benchmark_report = Benchmark.launch(benchmark_config)
- benchmark_report.log()
- benchmark_report.save_json(out_path)
+ run_benchmark(args, config, batch_size)
diff --git a/benchmarking/xpu/inference_benchmark.py b/benchmarking/xpu/inference_benchmark.py
new file mode 100644
index 000000000..055abed2e
--- /dev/null
+++ b/benchmarking/xpu/inference_benchmark.py
@@ -0,0 +1,147 @@
+import argparse
+import time
+
+# import intel_extension_for_pytorch as ipex
+import numpy as np
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig
+
+MAX_NEW_TOKENS = 256
+
+get_time = time.time
+
+system_prompt = "You are a helpful assistant"
+user_prompt = """Summarize this text please:
+
+```Tell me, O muse, of that ingenious hero who travelled far and wide after he had sacked the famous town of Troy. Many cities did he visit, and many were the nations with whose manners and customs he was acquainted; moreover he suffered much by sea while trying to save his own life and bring his men safely home; but do what he might he could not save his men, for they perished through their own sheer folly in eating the cattle of the Sun-god Hyperion; so the god prevented them from ever reaching home. Tell me, too, about all these things, O daughter of Jove, from whatsoever source you may know them.
+
+So now all who escaped death in battle or by shipwreck had got safely home except Ulysses, and he, though he was longing to return to his wife and country, was detained by the goddess Calypso, who had got him into a large cave and wanted to marry him. But as years went by, there came a time when the gods settled that he should go back to Ithaca; even then, however, when he was among his own people, his troubles were not yet over; nevertheless all the gods had now begun to pity him except Neptune, who still persecuted him without ceasing and would not let him get home.
+
+Now Neptune had gone off to the Ethiopians, who are at the world's end, and lie in two halves, the one looking West and the other East. He had gone there to accept a hecatomb of sheep and oxen, and was enjoying himself at his festival; but the other gods met in the house of Olympian Jove, and the sire of gods and men spoke first. At that moment he was thinking of Aegisthus, who had been killed by Agamemnon's son Orestes; so he said to the other gods:
+
+"See now, how men lay blame upon us gods for what is after all nothing but their own folly. Look at Aegisthus; he must needs make love to Agamemnon's wife unrighteously and then kill Agamemnon, though he knew it would be the death of him; for I sent Mercury to warn him not to do either of these things, inasmuch as Orestes would be sure to take his revenge when he grew up and wanted to return home. Mercury told him this in all good will but he would not listen, and now he has paid for everything in full."
+
+Then Minerva said, "Father, son of Saturn, King of kings, it served Aegisthus right, and so it would any one else who does as he did; but Aegisthus is neither here nor there; it is for Ulysses that my heart bleeds, when I think of his sufferings in that lonely sea-girt island, far away, poor man, from all his friends. It is an island covered with forest, in the very middle of the sea, and a goddess lives there, daughter of the magician Atlas, who looks after the bottom of the ocean, and carries the great columns that keep heaven and earth asunder. This daughter of Atlas has got hold of poor unhappy Ulysses, and keeps trying by every kind of blandishment to make him forget his home, so that he is tired of life, and thinks of nothing but how he may once more see the smoke of his own chimneys. You, sir, take no heed of this, and yet when Ulysses was before Troy did he not propitiate you with many a burnt sacrifice? Why then should you keep on being so angry with him?"
+
+And Jove said, "My child, what are you talking about? How can I forget Ulysses than whom there is no more capable man on earth, nor more liberal in his offerings to the immortal gods that live in heaven? Bear in mind, however, that Neptune is still furious with Ulysses for having blinded an eye of Polyphemus king of the Cyclopes. Polyphemus is son to Neptune by the nymph Thoosa, daughter to the sea-king Phorcys; therefore though he will not kill Ulysses outright, he torments him by preventing him from getting home. Still, let us lay our heads together and see how we can help him to return; Neptune will then be pacified, for if we are all of a mind he can hardly stand out against us."```"""
+
+prompt = [
+ {"role": "system", "content": system_prompt},
+ {"role": "user", "content": user_prompt},
+]
+
+
+def get_inputs(tokenizer):
+ inputs = tokenizer.apply_chat_template(
+ prompt,
+ tokenize=True,
+ add_generation_prompt=True,
+ return_tensors="pt",
+ return_dict=True,
+ )
+ return inputs
+
+
+def get_streamer(tokenizer):
+ streamer = Streamer(tokenizer)
+ # streamer = TextStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+ return streamer
+
+
+class Streamer:
+ def __init__(self, tokenizer, print_median=False):
+ self.times = []
+ self.print_median = print_median
+ self.tokenizer = tokenizer
+
+ def put(self, t):
+ self.times.append(get_time())
+ if len(self.times) > 1:
+ print(f"Token latency: {1000 * (self.times[-1] - self.times[-2]):.1f} ms")
+
+ if len(self.times) % 10 == 3 and self.print_median:
+ ts = np.array(self.times)
+ diff = ts[1:] - ts[:-1]
+ # print("Token latency:", 1000 * diff, "ms")
+ print("Token latency median:", np.median(1000 * diff), "ms")
+
+ def print_report(self):
+ times = np.array(self.times)
+ diff = times[1:] - times[:-1]
+ print(f"Median latency: {round(np.median(diff) * 1000, 2)}ms")
+ percentiles = [10, 25, 50, 75, 90]
+ print(
+ "Latency percentiles",
+ {p: round(1000 * float(np.percentile(diff, p)), 1) for p in percentiles},
+ )
+
+ def end(self, *args):
+ pass
+
+
+def parse_arguments():
+ parser = argparse.ArgumentParser(description="Run inference benchmark for LLM models")
+ parser.add_argument(
+ "--device",
+ type=str,
+ default="xpu",
+ help="Device to run inference on (e.g., xpu, cuda, cpu)",
+ )
+ parser.add_argument(
+ "--model-id",
+ type=str,
+ default="unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit",
+ help="Model ID from Hugging Face or local path",
+ )
+ parser.add_argument(
+ "--attn",
+ type=str,
+ default="eager",
+ choices=["eager", "flash_attention", "sdpa"],
+ help="Attention implementation to use",
+ )
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ args = parse_arguments()
+
+ device = args.device
+ model_id = args.model_id
+
+ print(f"Running inference on {device} with model {model_id}")
+ print(f"Using attention implementation: {args.attn}")
+
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
+ model = AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=args.attn)
+
+ inputs = get_inputs(tokenizer)
+ streamer = get_streamer(tokenizer)
+
+ inputs = inputs.to(device)
+ model = model.to(device)
+
+ generation_config = GenerationConfig(
+ use_cache=True,
+ forced_eos_token_id=1,
+ eos_token_id=1,
+ max_new_tokens=MAX_NEW_TOKENS,
+ do_sample=False,
+ )
+
+ outputs = model.generate(
+ **inputs,
+ streamer=streamer,
+ generation_config=generation_config,
+ )
+
+ # Print the final outputs (including the input prompt)
+ output_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+ print(r"\Output (including prompt):")
+ print("-" * 40)
+ print(output_text)
+ print("-" * 40)
+ print(f"Peak memory usage: {torch.xpu.max_memory_allocated() / 1024**2:.0f}MB")
+
+ streamer.print_report()
diff --git a/bitsandbytes/__init__.py b/bitsandbytes/__init__.py
index 516afa51f..d58b7b441 100644
--- a/bitsandbytes/__init__.py
+++ b/bitsandbytes/__init__.py
@@ -38,7 +38,6 @@
if hasattr(torch, "xpu") and torch.xpu.is_available():
from .backends.xpu import ops as xpu_ops
-
if importlib.util.find_spec("habana_frameworks") and importlib.util.find_spec("habana_frameworks.torch"):
# In case not automatically imported
import habana_frameworks.torch
@@ -76,4 +75,4 @@ def _import_backends():
"optim.optimizer.MockArgs": False,
}
-__version__ = "0.47.0.dev0"
+__version__ = "0.48.0.dev0"
diff --git a/bitsandbytes/_ops.py b/bitsandbytes/_ops.py
index a260852f5..532fe7afa 100644
--- a/bitsandbytes/_ops.py
+++ b/bitsandbytes/_ops.py
@@ -4,8 +4,6 @@
import torch
-from .cextension import ipex_cpu, ipex_xpu
-
_IS_TORCH_GTE_24 = False
if hasattr(torch.library, "register_fake"):
@@ -331,20 +329,105 @@ def _(
torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
-if ipex_cpu or ipex_xpu:
- # Register the dequantize_nf4_ipex implementation
- torch.library.define(
- "bitsandbytes::dequantize_nf4_ipex",
- "(Tensor A, Tensor absmax, int blocksize, int[] shape, ScalarType dtype) -> Tensor",
+torch.library.define(
+ "bitsandbytes::optimizer_update_32bit",
+ "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, Tensor(a4!)? unorm_vec, float max_unorm, float param_norm, float beta1, float beta2, float beta3, float alpha, float eps, float weight_decay, int step, float lr, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_32bit")
+def _(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ unorm_vec: Optional[torch.Tensor],
+ max_unorm: float,
+ param_norm: float,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ weight_decay: float,
+ step: int,
+ lr: float,
+ gnorm_scale: float,
+ skip_zeros=False,
+) -> None:
+ torch._check(
+ g.numel() == p.numel(),
+ lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+ )
+ compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ torch._check(
+ g.dtype in compute_dtypes,
+ lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+ )
+ torch._check(
+ g.dtype == p.dtype,
+ lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
)
- @register_fake("bitsandbytes::dequantize_nf4_ipex")
- def _(
- A: torch.Tensor,
- absmax: torch.Tensor,
- blocksize: int,
- shape: Sequence[int],
- dtype: torch.dtype,
- ) -> torch.Tensor:
- torch._check_is_size(blocksize)
- return torch.empty(shape, dtype=dtype, device=A.device)
+
+torch.library.define(
+ "bitsandbytes::optimizer_update_8bit_blockwise",
+ "(str optimizer_name, Tensor(a0!) g, Tensor(a1!) p, Tensor(a2!) state1, Tensor(a3!)? state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, Tensor(a4!) qmap1, Tensor(a5!)? qmap2, Tensor(a6!) absmax1, Tensor(a7!)? absmax2, float weight_decay, float gnorm_scale, bool skip_zeros=False) -> ()",
+)
+
+
+@register_fake("bitsandbytes::optimizer_update_8bit_blockwise")
+def _(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: torch.Tensor,
+ qmap2: Optional[torch.Tensor],
+ absmax1: torch.Tensor,
+ absmax2: Optional[torch.Tensor],
+ weight_decay: float,
+ gnorm_scale: float,
+ skip_zeros=False,
+) -> None:
+ torch._check(
+ g.numel() == p.numel(),
+ lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+ )
+ compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ torch._check(
+ g.dtype in compute_dtypes,
+ lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+ )
+ torch._check(
+ g.dtype == p.dtype,
+ lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+ )
+ torch._check(
+ state1.dtype == torch.uint8,
+ lambda: f"state1 must be uint8, got {state1.dtype}",
+ )
+ torch._check(
+ qmap1.dtype == absmax1.dtype == torch.float32,
+ lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+ )
+ if state2 is not None:
+ torch._check(
+ state2.dtype == torch.uint8,
+ lambda: f"state2 must be uint8, got {state2.dtype}",
+ )
+ torch._check(
+ qmap2.dtype == absmax2.dtype == torch.float32,
+ lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+ )
diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py
index 80fc86861..ece18caa3 100644
--- a/bitsandbytes/autograd/_functions.py
+++ b/bitsandbytes/autograd/_functions.py
@@ -8,7 +8,6 @@
from typing_extensions import deprecated
import bitsandbytes.functional as F
-from bitsandbytes.functional import ipex_cpu, ipex_xpu
# The inverse transformation for the colTuring and colAmpere format were contributed by Alex Borzunov:
# https://github.com/bigscience-workshop/petals/blob/main/src/petals/utils/linear8bitlt_patch.py
@@ -85,11 +84,7 @@ def get_inverse_transform_indices(
return permuted_tile_indices
-# torch.compiler.is_compiling() is available only in torch >= 2.3
-if hasattr(torch.compiler, "is_compiling"):
- _is_compiling = torch.compiler.is_compiling
-else:
- _is_compiling = torch._dynamo.is_compiling
+_is_compiling = torch.compiler.is_compiling
@deprecated(
@@ -320,8 +315,6 @@ def forward(ctx, A, B, out=None, bias=None, state=MatmulLtState):
CB = state.CB.data.to(A.dtype).mul_(state.SCB.unsqueeze(1).mul(1.0 / 127.0))
output = torch.nn.functional.linear(A, CB, bias)
- # to pass the test: tests/test_modules.py::test_linear8bitlt_no_fp16_weights[2.0-xpu]
- state.idx = False
ctx.state = state
ctx.dtype_A = A.dtype
ctx.grad_shape = A.shape
@@ -426,7 +419,7 @@ def matmul(
state.threshold = threshold
# MatMul8bitLt is slower because no fast kernel for quant/dequant 8bit in CPU/XPU
if state.is_training:
- if (A.device.type == "cpu" and ipex_cpu) or (A.device.type == "xpu" and ipex_xpu):
+ if A.device.type in ("cpu", "xpu"):
return MatMul8bitFp.apply(A, B, out, bias, state)
return MatMul8bitLt.apply(A, B, out, bias, state)
@@ -440,17 +433,6 @@ def matmul_4bit(
):
assert quant_state is not None
- if A.device.type in ("cpu", "xpu") and A.requires_grad == False:
- if getattr(quant_state, "ipex", False):
- # IPEX CPU will change weight to 4D so don't need transpose
- B = B.t() if B.dim() == 2 else B
- out = F.gemv_4bit(A, B, out, state=quant_state)
- if bias is not None:
- out += bias
- return out
- else:
- return MatMul4Bit.apply(A, B, out, bias, quant_state)
-
if A.numel() == A.shape[-1] and A.requires_grad == False and A.device.type != "hpu":
if A.shape[-1] % quant_state.blocksize != 0:
warn(
diff --git a/bitsandbytes/backends/cpu/ops.py b/bitsandbytes/backends/cpu/ops.py
index 5f009ea40..e295cc2a3 100644
--- a/bitsandbytes/backends/cpu/ops.py
+++ b/bitsandbytes/backends/cpu/ops.py
@@ -1,13 +1,14 @@
-from collections.abc import Sequence
import ctypes as ct
+import logging
import torch
from bitsandbytes.functional import get_ptr
from ..._ops import register_kernel
-from ...cextension import lib
-from ..utils import ipex_cpu
+from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
+
+logger = logging.getLogger(__name__)
# torch._int_mm for s8@s8->s32 is supported on CPU from torch 2.4+.
# However, we can overflow if we use this without AVX512_VNNI support.
@@ -24,97 +25,77 @@ def _(A: torch.Tensor, B: torch.Tensor):
).reshape(*A.shape[:-1], B.shape[0])
-@register_kernel("bitsandbytes::quantize_blockwise", "cpu")
-def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
- torch._check_is_size(blocksize)
-
- n = A.numel()
-
- # Only FP32 has c++ kernrl
- if A.dtype == torch.float32:
- blocks = -(n // -blocksize)
-
- absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
- out = torch.empty_like(A, dtype=torch.uint8)
-
- lib.cquantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_longlong(blocksize),
- ct.c_longlong(n),
- )
- else:
- rem = n % blocksize
- has_rem = rem > 0
- blocks = n // blocksize + has_rem
- absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
- A_reshaped = A.reshape(n)
- A_com = A_reshaped[: n - rem]
- A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
- absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
- scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
- scaled_A = scaled_A.reshape(-1)
- if has_rem:
- absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
- scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
- scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
-
- diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
- out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
-
- return out, absmax
-
-
-@register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
-def _(A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype) -> torch.Tensor:
- torch._check_is_size(blocksize)
- torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
-
- # Only FP32 has c++ kernrl
- if dtype == torch.float32:
- out = torch.empty_like(A, dtype=dtype)
-
- lib.cdequantize_blockwise_cpu_fp32(
- get_ptr(code),
- get_ptr(A),
- get_ptr(absmax),
- get_ptr(out),
- ct.c_longlong(blocksize),
- ct.c_longlong(A.numel()),
- )
- else:
- out = code[A.reshape(-1).int()]
- blocks = out.shape[-1] // blocksize
- res = out.shape[-1] % blocksize
- if res != 0:
- out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
- out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
- out = out[: blocks * blocksize + res]
- out = out.reshape(A.shape)
-
- return out
-
-
-if ipex_cpu:
- from bitsandbytes.utils import _reverse_4bit_compress_format
-
- @register_kernel("bitsandbytes::dequantize_nf4_ipex", "cpu")
+if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
+
+ @register_kernel("bitsandbytes::quantize_blockwise", "cpu")
+ def _(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
+ torch._check_is_size(blocksize)
+
+ n = A.numel()
+
+ # Only FP32 has c++ kernrl
+ if A.dtype == torch.float32:
+ blocks = -(n // -blocksize)
+
+ absmax = torch.empty((blocks,), device=A.device, dtype=torch.float32)
+ out = torch.empty_like(A, dtype=torch.uint8)
+
+ lib.cquantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_longlong(blocksize),
+ ct.c_longlong(n),
+ )
+ else:
+ rem = n % blocksize
+ has_rem = rem > 0
+ blocks = n // blocksize + has_rem
+ absmax = torch.zeros((blocks,), device=A.device, dtype=torch.float32)
+ A_reshaped = A.reshape(n)
+ A_com = A_reshaped[: n - rem]
+ A_com_reshaped = A_com.reshape(n // blocksize, blocksize)
+ absmax[: blocks - has_rem] = torch.abs(A_com_reshaped).max(dim=-1)[0]
+ scaled_A = torch.clamp(A_com_reshaped * (1 / absmax[: blocks - has_rem].view(-1, 1)), -1, 1)
+ scaled_A = scaled_A.reshape(-1)
+ if has_rem:
+ absmax[-1] = torch.abs(A_reshaped[n - rem :]).max()
+ scaled_A_rem = torch.clamp(A_reshaped[n - rem :] * (1 / absmax[-1]), -1, 1)
+ scaled_A = torch.cat([scaled_A, scaled_A_rem], dim=0)
+
+ diff = torch.abs(scaled_A.unsqueeze(-1) - code.to(scaled_A.device))
+ out = torch.argmin(diff, dim=-1).to(torch.uint8).to(scaled_A.device).reshape(A.shape)
+
+ return out, absmax
+
+ @register_kernel("bitsandbytes::dequantize_blockwise", "cpu")
def _(
- A: torch.Tensor,
- absmax: torch.Tensor,
- blocksize: int,
- shape: Sequence[int],
- dtype: torch.dtype,
+ A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
) -> torch.Tensor:
- ipex_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(A, "nf4", shape, 2)
- A = _reverse_4bit_compress_format(ipex_weight.reshape(-1)).reshape(1, -1)
- return torch.ops.bitsandbytes.dequantize_4bit.default(
- A,
- absmax,
- blocksize,
- "nf4",
- shape,
- dtype,
- )
+ torch._check_is_size(blocksize)
+ torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
+
+ # Only FP32 has c++ kernrl
+ if dtype == torch.float32:
+ out = torch.empty_like(A, dtype=dtype)
+
+ lib.cdequantize_blockwise_cpu_fp32(
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_longlong(blocksize),
+ ct.c_longlong(A.numel()),
+ )
+ else:
+ out = code[A.reshape(-1).int()]
+ blocks = out.shape[-1] // blocksize
+ res = out.shape[-1] % blocksize
+ if res != 0:
+ out = torch.nn.functional.pad(out, (0, blocksize - res), mode="constant", value=0)
+ out = (out.view(-1, blocksize) * absmax.view(-1, 1)).to(dtype).reshape(-1)
+ out = out[: blocks * blocksize + res]
+ out = out.reshape(A.shape)
+
+ return out
diff --git a/bitsandbytes/backends/cuda/ops.py b/bitsandbytes/backends/cuda/ops.py
index 13359bbd8..30cad3e34 100644
--- a/bitsandbytes/backends/cuda/ops.py
+++ b/bitsandbytes/backends/cuda/ops.py
@@ -538,3 +538,229 @@ def _gemv_4bit_impl(
ct.c_int32(blocksize),
stream,
)
+
+
+"""C FUNCTIONS FOR OPTIMIZERS"""
+str2optimizer32bit = {
+ "adam": (
+ lib.cadam32bit_grad_fp32,
+ lib.cadam32bit_grad_fp16,
+ lib.cadam32bit_grad_bf16,
+ ),
+ "momentum": (
+ lib.cmomentum32bit_grad_32,
+ lib.cmomentum32bit_grad_16,
+ ),
+ "rmsprop": (
+ lib.crmsprop32bit_grad_32,
+ lib.crmsprop32bit_grad_16,
+ ),
+ "lion": (
+ lib.clion32bit_grad_fp32,
+ lib.clion32bit_grad_fp16,
+ lib.clion32bit_grad_bf16,
+ ),
+ "adagrad": (
+ lib.cadagrad32bit_grad_32,
+ lib.cadagrad32bit_grad_16,
+ ),
+ "lamb": (
+ lib.cadam32bit_grad_fp32,
+ lib.cadam32bit_grad_fp16,
+ lib.cadam32bit_grad_bf16,
+ ),
+ "ademamix": (
+ lib.cademamix32bit_grad_fp32,
+ lib.cademamix32bit_grad_fp16,
+ lib.cademamix32bit_grad_bf16,
+ ),
+}
+
+str2optimizer8bit_blockwise = {
+ "adam": (
+ lib.cadam_8bit_blockwise_grad_fp32,
+ lib.cadam_8bit_blockwise_grad_fp16,
+ lib.cadam_8bit_blockwise_grad_bf16,
+ ),
+ "momentum": (
+ lib.cmomentum_8bit_blockwise_grad_fp32,
+ lib.cmomentum_8bit_blockwise_grad_fp16,
+ lib.cmomentum_8bit_blockwise_grad_bf16,
+ ),
+ "rmsprop": (
+ lib.crmsprop_8bit_blockwise_grad_fp32,
+ lib.crmsprop_8bit_blockwise_grad_fp16,
+ lib.crmsprop_8bit_blockwise_grad_bf16,
+ ),
+ "lion": (
+ lib.clion_8bit_blockwise_grad_fp32,
+ lib.clion_8bit_blockwise_grad_fp16,
+ lib.clion_8bit_blockwise_grad_bf16,
+ ),
+ "adagrad": (
+ lib.cadagrad_8bit_blockwise_grad_fp32,
+ lib.cadagrad_8bit_blockwise_grad_fp16,
+ lib.cadagrad_8bit_blockwise_grad_bf16,
+ ),
+ "ademamix": (
+ lib.cademamix_8bit_blockwise_grad_fp32,
+ lib.cademamix_8bit_blockwise_grad_fp16,
+ lib.cademamix_8bit_blockwise_grad_bf16,
+ ),
+}
+
+
+def _optimizer_update_32bit_impl(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ unorm_vec: Optional[torch.Tensor],
+ max_unorm: float,
+ param_norm: float,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ weight_decay: float,
+ step: int,
+ lr: float,
+ gnorm_scale: float,
+ skip_zeros=False,
+) -> None:
+ optim_fns = str2optimizer32bit.get(optimizer_name, None)
+ if optim_fns is None:
+ raise ValueError(
+ f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+ )
+ if g.dtype == torch.float32:
+ optim_func = optim_fns[0]
+ elif g.dtype == torch.float16:
+ optim_func = optim_fns[1]
+ elif g.dtype == torch.bfloat16 and len(optim_fns) == 3:
+ optim_func = optim_fns[2]
+ else:
+ raise ValueError(
+ f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
+ )
+
+ with _cuda_device_of(g):
+ optim_func(
+ get_ptr(g),
+ get_ptr(p),
+ get_ptr(state1),
+ get_ptr(state2),
+ get_ptr(unorm_vec),
+ ct.c_float(max_unorm),
+ ct.c_float(param_norm),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(beta3),
+ ct.c_float(alpha),
+ ct.c_float(eps),
+ ct.c_float(weight_decay),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
+
+
+def _optimizer_update_8bit_blockwise_impl(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: torch.Tensor,
+ qmap2: Optional[torch.Tensor],
+ absmax1: torch.Tensor,
+ absmax2: Optional[torch.Tensor],
+ weight_decay: float,
+ gnorm_scale: float,
+ skip_zeros=False,
+) -> None:
+ # torch._check(
+ # g.numel() == p.numel(),
+ # lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+ # )
+ # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
+
+ # torch._check(
+ # g.dtype in compute_dtypes,
+ # lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+ # )
+ # torch._check(
+ # g.dtype == p.dtype,
+ # lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+ # )
+ # torch._check(
+ # state1.dtype == torch.uint8,
+ # lambda: f"state1 must be uint8, got {state1.dtype}",
+ # )
+ # torch._check(
+ # qmap1.dtype == absmax1.dtype == torch.float32,
+ # lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+ # )
+ # if state2 is not None:
+ # torch._check(
+ # state2.dtype == torch.uint8,
+ # lambda: f"state2 must be uint8, got {state2.dtype}",
+ # )
+ # torch._check(
+ # qmap2.dtype == absmax2.dtype == torch.float32,
+ # lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+ # )
+ optimizer_fns = str2optimizer8bit_blockwise.get(optimizer_name)
+ if optimizer_fns is None:
+ raise ValueError(
+ f"Unsupported optimizer name: {optimizer_name}. Supported optimizers: {list(str2optimizer8bit_blockwise.keys())}"
+ )
+
+ if g.dtype == torch.float32:
+ optimizer_fn = optimizer_fns[0]
+ elif g.dtype == torch.float16:
+ optimizer_fn = optimizer_fns[1]
+ elif g.dtype == torch.bfloat16:
+ optimizer_fn = optimizer_fns[2]
+ else:
+ raise ValueError(
+ f"Unsupported gradient dtype: {g.dtype}. Supported dtypes: torch.float32, torch.float16, torch.bfloat16"
+ )
+
+ with _cuda_device_of(g):
+ optimizer_fn(
+ get_ptr(p),
+ get_ptr(g),
+ get_ptr(state1),
+ get_ptr(state2),
+ ct.c_float(beta1),
+ ct.c_float(beta2),
+ ct.c_float(beta3),
+ ct.c_float(alpha),
+ ct.c_float(eps),
+ ct.c_int32(step),
+ ct.c_float(lr),
+ get_ptr(qmap1),
+ get_ptr(qmap2),
+ get_ptr(absmax1),
+ get_ptr(absmax2),
+ ct.c_float(weight_decay),
+ ct.c_float(gnorm_scale),
+ ct.c_bool(skip_zeros),
+ ct.c_int32(g.numel()),
+ )
+
+
+register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "cuda")(_optimizer_update_8bit_blockwise_impl)
+register_kernel("bitsandbytes::optimizer_update_32bit", "cuda")(_optimizer_update_32bit_impl)
diff --git a/bitsandbytes/backends/default/ops.py b/bitsandbytes/backends/default/ops.py
index ce5926979..067347d47 100644
--- a/bitsandbytes/backends/default/ops.py
+++ b/bitsandbytes/backends/default/ops.py
@@ -1,5 +1,5 @@
from collections.abc import Sequence
-from math import prod
+from math import prod, sqrt
from typing import Optional
import torch
@@ -301,3 +301,278 @@ def _(
B_dq,
bias=None,
)
+
+
+MOMENTUM = 0
+RMSPROP = 1
+ADAGRAD = 2
+ADAM = 3
+# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels
+LION = 4
+ADEMAMIX = 5
+
+name2optimizer_id = {
+ "momentum": MOMENTUM,
+ "rmsprop": RMSPROP,
+ "adagrad": ADAGRAD,
+ "adam": ADAM,
+ "lion": LION,
+ "ademamix": ADEMAMIX,
+}
+
+
+@torch.compile
+def _optimizer_precondition_32bit(
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ unorm_vec: torch.Tensor,
+ beta1: float,
+ beta2: float,
+ eps: float,
+ weight_decay: float,
+ step: int,
+ lr: float,
+ gnorm_scale: float,
+ optimizer_id: int,
+):
+ """Preprocessing optimizer, computing update norm"""
+
+ g_vals = gnorm_scale * g
+
+ if optimizer_id == 3: # ADAM
+ correction1 = 1.0 / (1.0 - beta1**step)
+ correction2 = 1.0 / (1.0 - beta2**step)
+
+ s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
+ s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+ s1_vals = s1_vals * correction1
+ s2_vals = s2_vals * correction2
+
+ update_vals = s1_vals / (torch.sqrt(s2_vals) + eps)
+ update_norm = update_vals * update_vals
+
+ elif optimizer_id == 5: # ADEMAMIX
+ update_norm = state1
+
+ elif optimizer_id == 0: # MOMENTUM
+ if step == 1:
+ s1_vals = g_vals
+ else:
+ s1_vals = state1 * beta1 + g_vals
+ update_norm = s1_vals * s1_vals
+
+ elif optimizer_id == 4: # LION
+ s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
+ update_norm = s1_vals
+
+ elif optimizer_id == 1: # RMSPROP
+ s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
+ update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
+ update_norm = update_vals * update_vals
+
+ elif optimizer_id == 2: # ADAGRAD
+ s1_vals = state1 + g_vals * g_vals
+ update_vals = g_vals / (torch.sqrt(s1_vals) + eps)
+ update_norm = update_vals * update_vals
+
+ total_norm = torch.sum(update_norm)
+ unorm_vec.add_(total_norm)
+
+
+@torch.compile
+def _optimizer_update_32bit(
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ unorm_vec: Optional[torch.Tensor],
+ max_unorm: float,
+ param_norm: float,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ weight_decay: float,
+ step: int,
+ lr: float,
+ gnorm_scale: float,
+ optimizer_id: int,
+):
+ """Unified optimizer update kernel"""
+
+ p_vals = p.float()
+ g_vals = (gnorm_scale * g).float()
+ if optimizer_id in [0, 1, 2, 4] and weight_decay > 0.0:
+ g_vals = g_vals + p_vals * weight_decay
+
+ update_scale = 1.0
+ if max_unorm > 0.0:
+ current_unorm = torch.sqrt(unorm_vec)
+ if optimizer_id in [0, 1, 2, 4]: # 1-state optimizers
+ if current_unorm > max_unorm * param_norm + eps:
+ update_scale = (max_unorm * param_norm + eps) / current_unorm
+ else: # 2-state optimizers
+ if current_unorm > max_unorm * param_norm:
+ update_scale = (max_unorm * param_norm) / current_unorm
+
+ if optimizer_id == 3: # ADAM
+ s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals
+ s2_vals = state2 * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+ correction1 = 1.0 - beta1**step
+ correction2 = sqrt(1.0 - beta2**step)
+ step_size = -lr * correction2 / correction1
+
+ if weight_decay > 0.0:
+ p_vals = p_vals * (1.0 - lr * weight_decay)
+
+ update_val = update_scale * step_size * (s1_vals / (torch.sqrt(s2_vals) + eps * correction2))
+ p_vals = p_vals + update_val
+
+ state1.copy_(s1_vals)
+ state2.copy_(s2_vals)
+
+ elif optimizer_id == 5: # ADEMAMIX
+ s1_vals = state1[0]
+ s3_vals = state1[1]
+ s2_vals = state2
+
+ m1 = s1_vals * beta1 + (1.0 - beta1) * g_vals
+ m2 = s3_vals * beta3 + (1.0 - beta3) * g_vals
+ nu = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+ correction1 = 1.0 - beta1**step
+ correction2 = sqrt(1.0 - beta2**step)
+
+ if weight_decay > 0.0:
+ p_vals = p_vals * (1.0 - lr * weight_decay)
+
+ mixed_momentum = (m1 / correction1) + (alpha * m2)
+ adaptive_term = (torch.sqrt(nu) / correction2) + eps
+ p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
+
+ state1[0].copy_(m1)
+ state1[1].copy_(m2)
+ state2.copy_(nu)
+
+ elif optimizer_id == 0: # MOMENTUM
+ if step == 1:
+ s1_vals = g_vals
+ else:
+ s1_vals = state1 * beta1 + g_vals
+
+ update_val = update_scale * (-lr * s1_vals)
+ p_vals = p_vals + update_val
+
+ state1.copy_(s1_vals)
+
+ elif optimizer_id == 4: # LION
+ momentum_update = state1 * beta1 + (1.0 - beta1) * g_vals
+ update_val = update_scale * lr * torch.sign(momentum_update)
+ p_vals = p_vals - update_val
+
+ s1_vals = state1 * beta2 + (1.0 - beta2) * g_vals
+ state1.copy_(s1_vals)
+
+ elif optimizer_id == 1: # RMSPROP
+ s1_vals = state1 * beta1 + (1.0 - beta1) * g_vals * g_vals
+ update_val = update_scale * lr * g_vals / (torch.sqrt(s1_vals) + eps)
+ p_vals = p_vals - update_val
+
+ state1.copy_(s1_vals)
+
+ elif optimizer_id == 2: # ADAGRAD
+ s1_vals = state1 + g_vals * g_vals
+ update_val = lr * g_vals / (torch.sqrt(s1_vals) + eps)
+ p_vals = p_vals - update_val
+
+ state1.copy_(s1_vals)
+
+ p.copy_(p_vals)
+
+
+@register_kernel("bitsandbytes::optimizer_update_32bit", "default")
+def _(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ unorm_vec: Optional[torch.Tensor],
+ max_unorm: float,
+ param_norm: float,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ weight_decay: float,
+ step: int,
+ lr: float,
+ gnorm_scale: float = 1.0,
+ skip_zeros=False,
+) -> None:
+ """
+ 32-bit optimizer implemented by PyTorch with @torch.compile
+ """
+ if skip_zeros:
+ raise NotImplementedError("skip_zeros is not supported yet")
+
+ optimizer_id = name2optimizer_id[optimizer_name]
+
+ if optimizer_name == "lion":
+ _optimizer_update_32bit(
+ g,
+ p,
+ state1,
+ state2,
+ unorm_vec,
+ max_unorm,
+ param_norm,
+ beta1,
+ beta2,
+ beta3,
+ alpha,
+ eps,
+ weight_decay,
+ step,
+ lr,
+ gnorm_scale,
+ optimizer_id,
+ )
+
+ if max_unorm > 0.0:
+ unorm_vec.zero_()
+ _optimizer_precondition_32bit(
+ g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
+ )
+ else:
+ if max_unorm > 0.0:
+ unorm_vec.zero_()
+ _optimizer_precondition_32bit(
+ g, p, state1, state2, unorm_vec, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, optimizer_id
+ )
+
+ _optimizer_update_32bit(
+ g,
+ p,
+ state1,
+ state2,
+ unorm_vec,
+ max_unorm,
+ param_norm,
+ beta1,
+ beta2,
+ beta3,
+ alpha,
+ eps,
+ weight_decay,
+ step,
+ lr,
+ gnorm_scale,
+ optimizer_id,
+ )
diff --git a/bitsandbytes/backends/hpu/ops.py b/bitsandbytes/backends/hpu/ops.py
index 4c43a3cb7..9ecd63e0b 100644
--- a/bitsandbytes/backends/hpu/ops.py
+++ b/bitsandbytes/backends/hpu/ops.py
@@ -3,12 +3,19 @@
import torch
-from bitsandbytes.utils import _reverse_4bit_compress_format
-
from ..._ops import register_kernel
from ..utils import GAUDI_SW_VER
+# convert btw standard 4-bit compression format and ipex compression format
+# needed for backward compatibility with older versions of gaudi sw
+def _reverse_4bit_compress_format(weight: torch.Tensor):
+ out_1 = (weight & 0xF0) >> 4
+ out_2 = (weight & 0xF) << 4
+ out = out_1 | out_2
+ return out
+
+
@register_kernel("bitsandbytes::dequantize_4bit", "hpu")
def _(
A: torch.Tensor,
diff --git a/bitsandbytes/backends/triton/triton_kernels.py b/bitsandbytes/backends/triton/kernels_4bit.py
similarity index 78%
rename from bitsandbytes/backends/triton/triton_kernels.py
rename to bitsandbytes/backends/triton/kernels_4bit.py
index 03ffa187d..0e94f49e8 100644
--- a/bitsandbytes/backends/triton/triton_kernels.py
+++ b/bitsandbytes/backends/triton/kernels_4bit.py
@@ -4,167 +4,6 @@
import triton.language as tl
-# @triton.autotune(
-# configs=[
-# # triton.Config({'SPLIT_SIZE': 64}),
-# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 128}),
-# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
-# triton.Config({"SPLIT_SIZE": 256}),
-# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
-# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
-# triton.Config({"SPLIT_SIZE": 512}),
-# # triton.Config({'SPLIT_SIZE': 1024}),
-# ],
-# key=["num_paired_elements", "QUANT_BLOCK"],
-# )
-@triton.jit
-def dequant_8bit_kernel(
- a_ptr,
- c_ptr,
- quant_ptr,
- absmax_ptr,
- num_paired_elements,
- QUANT_BLOCK: tl.constexpr,
- SPLIT_SIZE: tl.constexpr,
-):
- pid = tl.program_id(axis=0)
- block_start = pid * SPLIT_SIZE
- offsets = block_start + tl.arange(0, SPLIT_SIZE)
- mask = offsets < num_paired_elements
-
- a = tl.load(a_ptr + offsets, mask)
- a = a.to(tl.uint8)
-
- # apply conversion
- scaled_int8 = tl.load(quant_ptr + a, mask)
-
- abs_blocks_lim = (num_paired_elements // QUANT_BLOCK) * QUANT_BLOCK + num_paired_elements % QUANT_BLOCK
- abs_offsets = offsets // QUANT_BLOCK
- mask_blocked = offsets < abs_blocks_lim
-
- absmax = tl.load(absmax_ptr + abs_offsets, mask_blocked)
- # apply scales
- out_dq = scaled_int8 * absmax
-
- offs = block_start + tl.arange(0, SPLIT_SIZE)
- mask = offs < num_paired_elements
- tl.store(c_ptr + offs, out_dq, mask)
-
-
-def dequant_int8_blockwise(
- A_nf4: torch.Tensor,
- quant_state_code: torch.Tensor,
- absmax: torch.Tensor,
- out: torch.Tensor,
- quant_blocksize: int = 64,
-):
- number_of_paired_elements = A_nf4.numel()
-
- SPLIT_SIZE = 256
- # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
- grid = (triton.cdiv(number_of_paired_elements, SPLIT_SIZE),)
- dequant_8bit_kernel[grid](
- A_nf4,
- out,
- quant_state_code,
- absmax,
- number_of_paired_elements,
- quant_blocksize,
- SPLIT_SIZE,
- )
- return out
-
-
-# @triton.autotune(
-# configs=[
-# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
-# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
-# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
-# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
-# ],
-# key=["n_elements"],
-# )
-@triton.jit
-def quantize_blockwise_kernel(
- A_ptr,
- code_ptr,
- absmax_ptr,
- out_ptr,
- n_elements,
- BLOCK_SIZE: tl.constexpr,
- CODE_SIZE: tl.constexpr,
- SPLIT_NUM_BLOCKS: tl.constexpr,
-):
- block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
- thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
-
- offsets = block_start_idx * BLOCK_SIZE + thread_idx
- mask = offsets < n_elements
-
- A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
-
- # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
- A_reshaped = tl.reshape(A, (SPLIT_NUM_BLOCKS, BLOCK_SIZE))
-
- # Calculating absamax for each block
- absmax = tl.max(tl.abs(A_reshaped), axis=1)
- tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
-
- A_normalized = A_reshaped / absmax[:, None]
- A_normalized = tl.clamp(A_normalized, -1.0, 1.0)
-
- lower_pivot = tl.zeros((SPLIT_NUM_BLOCKS, BLOCK_SIZE), dtype=tl.int32)
- upper_pivot = tl.full((SPLIT_NUM_BLOCKS, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
-
- for _ in range(8): # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
- pivot = (lower_pivot + upper_pivot) // 2
- val = tl.load(code_ptr + pivot)
- is_higher = A_normalized > val # code[pivot]
- lower_pivot = tl.where(is_higher, pivot, lower_pivot)
- upper_pivot = tl.where(is_higher, upper_pivot, pivot)
-
- # Choose closest level
- lower_val = tl.load(code_ptr + lower_pivot)
- upper_val = tl.load(code_ptr + upper_pivot)
- lower_dist = tl.abs(A_normalized - lower_val)
- upper_dist = tl.abs(A_normalized - upper_val)
- quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
-
- # too slow approach
- # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
- # quantized = tl.argmin(diff, axis=2).to(tl.uint8)
-
- quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * SPLIT_NUM_BLOCKS,))
- tl.store(out_ptr + offsets, quantized_flat, mask=mask)
-
-
-def quantize_blockwise_triton(A, blocksize, code, blocks, absmax, quantized_out):
- n = A.numel()
-
- split_num_blocks = 1
- grid = (triton.cdiv(blocks, split_num_blocks),)
- # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
- quantize_blockwise_kernel[grid](
- A_ptr=A,
- code_ptr=code,
- absmax_ptr=absmax,
- out_ptr=quantized_out,
- n_elements=n,
- BLOCK_SIZE=blocksize,
- CODE_SIZE=code.numel(),
- SPLIT_NUM_BLOCKS=split_num_blocks,
- )
-
- return quantized_out, absmax
-
-
# Triton implementation of similar CUDA kernel to avoid loading code from csrc/kernels.cu::dQuantizeFP4
# @triton.autotune(
# configs=[
@@ -587,7 +426,7 @@ def dequant_nf4_kernel(
tl.store(c_ptr + offs, out_dq, mask)
-def _dequantize_4bit_impl(
+def dequantize_4bit_impl(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
@@ -611,7 +450,7 @@ def _dequantize_4bit_impl(
dequant_nf4_kernel[grid](A, out, absmax, number_of_paired_elements, blocksize, SPLIT_SIZE)
-def _dequantize_4bit_impl_passing_code(
+def dequantize_4bit_impl_passing_code(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
diff --git a/bitsandbytes/backends/triton/kernels_8bit_quant.py b/bitsandbytes/backends/triton/kernels_8bit_quant.py
new file mode 100644
index 000000000..c0a5a21ef
--- /dev/null
+++ b/bitsandbytes/backends/triton/kernels_8bit_quant.py
@@ -0,0 +1,195 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+# @triton.autotune(
+# configs=[
+# # triton.Config({'SPLIT_SIZE': 64}),
+# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 64, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 128}),
+# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'large'}, num_stages=4, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 128, 'grf_mode': 'auto'}, num_stages=4, num_warps=32),
+# triton.Config({"SPLIT_SIZE": 256}),
+# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'large'}, num_stages=2, num_warps=32),
+# # triton.Config({'SPLIT_SIZE': 256, 'grf_mode': 'auto'}, num_stages=2, num_warps=32),
+# triton.Config({"SPLIT_SIZE": 512}),
+# # triton.Config({'SPLIT_SIZE': 1024}),
+# ],
+# key=["num_paired_elements", "QUANT_BLOCK"],
+# )
+@triton.jit
+def dequant_8bit_kernel(
+ a_ptr,
+ out_ptr,
+ code_ptr,
+ absmax_ptr,
+ n,
+ QUANT_BLOCK: tl.constexpr,
+ SPLIT_SIZE: tl.constexpr,
+):
+ pid = tl.program_id(axis=0)
+ block_start = pid * SPLIT_SIZE
+ offsets = block_start + tl.arange(0, SPLIT_SIZE)
+ mask = offsets < n
+ out_dq = dequant_8bit_blockwise_kernel_util(a_ptr, offsets, code_ptr, absmax_ptr, mask, QUANT_BLOCK)
+ tl.store(out_ptr + offsets, out_dq, mask)
+
+
+def dequant_8bit_blockwise(
+ a: torch.Tensor,
+ absmax: torch.Tensor,
+ quant_state_code: torch.Tensor,
+ quant_blocksize: int = 64,
+ dtype: torch.dtype = None,
+ out: torch.Tensor = None,
+):
+ n = a.numel()
+ if out is None:
+ if dtype is None:
+ raise ValueError("If out is None, dtype must be specified")
+ out = torch.empty_like(a, dtype=dtype, device=a.device)
+
+ SPLIT_SIZE = 256
+ # grid = lambda META: (triton.cdiv(number_of_paired_elements, META["SPLIT_SIZE"]),)
+ grid = (triton.cdiv(n, SPLIT_SIZE),)
+ dequant_8bit_kernel[grid](
+ a,
+ out,
+ quant_state_code,
+ absmax,
+ n,
+ quant_blocksize,
+ SPLIT_SIZE,
+ )
+ return out
+
+
+# @triton.autotune(
+# configs=[
+# triton.Config({"SPLIT_NUM_BLOCKS": 1, "grf_mode": "auto"}, num_stages=4, num_warps=32),
+# triton.Config({"SPLIT_NUM_BLOCKS": 2, "grf_mode": "auto"}, num_stages=4, num_warps=32),
+# triton.Config({"SPLIT_NUM_BLOCKS": 1}),
+# triton.Config({"SPLIT_NUM_BLOCKS": 2}),
+# ],
+# key=["n_elements"],
+# )
+@triton.jit
+def quantize_8bit_blockwise_kernel(
+ A_ptr,
+ code_ptr,
+ absmax_ptr,
+ out_ptr,
+ n_elements,
+ BLOCK_SIZE: tl.constexpr,
+ CODE_SIZE: tl.constexpr,
+ SPLIT_NUM_BLOCKS: tl.constexpr,
+):
+ block_start_idx = tl.program_id(0) * SPLIT_NUM_BLOCKS
+ thread_idx = tl.arange(0, SPLIT_NUM_BLOCKS * BLOCK_SIZE)
+
+ offsets = block_start_idx * BLOCK_SIZE + thread_idx
+ mask = offsets < n_elements
+
+ A = tl.load(A_ptr + offsets, mask=mask, other=0.0)
+
+ quantized, absmax = quantize_8bit_blockwise_kernel_util(A, code_ptr, CODE_SIZE, BLOCK_SIZE, SPLIT_NUM_BLOCKS)
+ tl.store(absmax_ptr + block_start_idx + tl.arange(0, SPLIT_NUM_BLOCKS), absmax)
+ tl.store(out_ptr + offsets, quantized, mask=mask)
+
+
+def quantize_blockwise_triton(A, code, blocksize, absmax=None, out=None):
+ n = A.numel()
+ blocks = -(n // -blocksize)
+
+ if absmax is None:
+ absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
+ if out is None:
+ out = torch.empty_like(A.flatten(), dtype=torch.uint8)
+
+ split_num_blocks = 1
+ grid = (triton.cdiv(blocks, split_num_blocks),)
+ # grid = lambda META: (triton.cdiv(blocks, META["SPLIT_NUM_BLOCKS"]),)
+ quantize_8bit_blockwise_kernel[grid](
+ A_ptr=A,
+ code_ptr=code,
+ absmax_ptr=absmax,
+ out_ptr=out,
+ n_elements=n,
+ BLOCK_SIZE=blocksize,
+ CODE_SIZE=code.numel(),
+ SPLIT_NUM_BLOCKS=split_num_blocks,
+ # num_warps=1,
+ # num_stages=2,
+ )
+ out = out.reshape(A.shape)
+
+ return out, absmax
+
+
+@triton.jit
+def quantize_8bit_blockwise_kernel_util(
+ a,
+ code_ptr,
+ CODE_SIZE: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ N_PER_TH: tl.constexpr,
+):
+ # To be able process several blocks -> (BLOCK_SIZE, SPLIT_NUM_BLOCKS)
+ a_reshaped = tl.reshape(a, (N_PER_TH, BLOCK_SIZE))
+
+ # Calculating absmax for each block
+ absmax = tl.max(tl.abs(a_reshaped), axis=1)
+
+ a_normalized = a_reshaped / absmax[:, None]
+ a_normalized = tl.clamp(a_normalized, -1.0, 1.0)
+
+ lower_pivot = tl.zeros((N_PER_TH, BLOCK_SIZE), dtype=tl.int32)
+ upper_pivot = tl.full((N_PER_TH, BLOCK_SIZE), CODE_SIZE - 1, dtype=tl.int32)
+
+ # ceil(log2(code_size)) = 8, actually, in general case should be input parameter
+ for _ in range(8):
+ pivot = (lower_pivot + upper_pivot) // 2
+ val = tl.load(code_ptr + pivot)
+ is_higher = a_normalized > val # code[pivot]
+ lower_pivot = tl.where(is_higher, pivot, lower_pivot)
+ upper_pivot = tl.where(is_higher, upper_pivot, pivot)
+
+ # Choose closest level
+ lower_val = tl.load(code_ptr + lower_pivot)
+ upper_val = tl.load(code_ptr + upper_pivot)
+ lower_dist = tl.abs(a_normalized - lower_val)
+ upper_dist = tl.abs(a_normalized - upper_val)
+ quantized = tl.where(lower_dist <= upper_dist, lower_pivot, upper_pivot).to(tl.uint8)
+
+ # too slow approach
+ # diff = tl.abs(A_normalized[:, :, None] - code[None, None, :])
+ # quantized = tl.argmin(diff, axis=2).to(tl.uint8)
+
+ quantized_flat = tl.reshape(quantized, (BLOCK_SIZE * N_PER_TH,))
+ return quantized_flat, absmax
+
+
+@triton.jit
+def dequant_8bit_blockwise_kernel_util(
+ a_ptr,
+ offsets,
+ code_ptr,
+ absmax_ptr,
+ mask,
+ BLOCK_SIZE: tl.constexpr,
+):
+ a = tl.load(a_ptr + offsets, mask, other=0).to(tl.uint8)
+ scaled_int8 = tl.load(code_ptr + a, mask)
+ # Load scales
+ absmax_offsets = offsets // BLOCK_SIZE
+ absmax = tl.load(absmax_ptr + absmax_offsets, mask=mask, other=0.0, eviction_policy="evict_last")
+ # Apply scales
+ out_dq = scaled_int8 * absmax
+ return out_dq
diff --git a/bitsandbytes/backends/triton/kernels_optim.py b/bitsandbytes/backends/triton/kernels_optim.py
new file mode 100644
index 000000000..2cd6d8c93
--- /dev/null
+++ b/bitsandbytes/backends/triton/kernels_optim.py
@@ -0,0 +1,1154 @@
+import math
+from typing import Optional
+
+import torch
+
+import triton
+import triton.language as tl
+
+# from triton.language.extra import libdevice
+from .kernels_8bit_quant import (
+ dequant_8bit_blockwise,
+ dequant_8bit_blockwise_kernel_util,
+ quantize_8bit_blockwise_kernel_util,
+ quantize_blockwise_triton,
+)
+
+MOMENTUM = 0
+RMSPROP = 1
+ADAGRAD = 2
+ADAM = 3
+# LION should be larger than MOMENTUM, RMSPROP, ADAGRAD due to comparison in kernels
+LION = 4
+ADEMAMIX = 5
+
+name2optimizer_id = {
+ "momentum": MOMENTUM,
+ "rmsprop": RMSPROP,
+ "adagrad": ADAGRAD,
+ "adam": ADAM,
+ "lion": LION,
+ "ademamix": ADEMAMIX,
+}
+
+
+@triton.jit
+def _optimizer_precondition_2state_32bit(
+ g_ptr,
+ p_ptr,
+ state1_ptr,
+ state2_ptr,
+ unorm_ptr,
+ beta1: tl.constexpr,
+ beta2: tl.constexpr,
+ eps: tl.constexpr,
+ weight_decay: tl.constexpr,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale: tl.constexpr,
+ n_elements,
+ OPTIMIZER_ID: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ N_PER_TH: tl.constexpr,
+):
+ """Preprocessing optimizer, computing update norm (2-state optimizer)"""
+ pid = tl.program_id(axis=0)
+ block_start_idx = pid * N_PER_TH
+ offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
+ mask = offsets < n_elements
+
+ g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
+ s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
+ s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
+
+ g_vals = gnorm_scale * g_vals
+
+ correction1 = 1.0 / (1.0 - beta1_step)
+ correction2 = 1.0 / (1.0 - beta2_step)
+
+ if OPTIMIZER_ID == 3: # ADAM
+ s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
+ s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+ s1_vals = s1_vals * correction1
+ s2_vals = s2_vals * correction2
+
+ update_vals = s1_vals / (tl.sqrt(s2_vals) + eps)
+
+ update_norm = update_vals * update_vals
+
+ elif OPTIMIZER_ID == 5: # ADEMAMIX
+ update_norm = s1_vals
+
+ total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
+
+ tl.atomic_add(unorm_ptr, total_norm)
+
+
+@triton.jit
+def _optimizer_precondition_1state_32bit(
+ g_ptr,
+ p_ptr,
+ state1_ptr,
+ state2_ptr,
+ unorm_ptr,
+ beta1: tl.constexpr,
+ beta2: tl.constexpr,
+ eps: tl.constexpr,
+ weight_decay,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale: tl.constexpr,
+ n_elements,
+ OPTIMIZER_ID: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ N_PER_TH: tl.constexpr,
+):
+ """Preprocessing optimizer, computing update norm (1-state optimizer)"""
+ pid = tl.program_id(axis=0)
+ block_start_idx = pid * N_PER_TH
+ offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
+ mask = offsets < n_elements
+
+ g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0)
+ s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
+
+ g_vals = gnorm_scale * g_vals
+
+ if OPTIMIZER_ID == 0: # MOMENTUM
+ if step == 1:
+ s1_vals = g_vals
+ else:
+ s1_vals = s1_vals * beta1 + g_vals
+ update_norm = s1_vals * s1_vals
+
+ elif OPTIMIZER_ID == 4: # LION
+ s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals
+ update_norm = s1_vals
+
+ elif OPTIMIZER_ID == 1: # RMSPROP
+ s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals
+ update_vals = g_vals / (tl.sqrt(s1_vals) + eps)
+ update_norm = update_vals * update_vals
+
+ elif OPTIMIZER_ID == 2: # ADAGRAD
+ s1_vals = s1_vals + g_vals * g_vals
+ update_vals = g_vals / (tl.sqrt(s1_vals) + eps)
+ update_norm = update_vals * update_vals
+
+ total_norm = tl.sum(tl.where(mask, update_norm, 0.0))
+
+ tl.atomic_add(unorm_ptr, total_norm)
+
+
+@triton.jit
+def _optimizer_update_2state_32bit_triton_kernel(
+ g_ptr,
+ p_ptr,
+ state1_ptr,
+ state2_ptr,
+ unorm_ptr,
+ max_unorm: tl.constexpr,
+ param_norm,
+ beta1: tl.constexpr,
+ beta2: tl.constexpr,
+ beta3,
+ alpha,
+ eps: tl.constexpr,
+ weight_decay: tl.constexpr,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale: tl.constexpr,
+ skip_zeros,
+ n_elements,
+ OPTIMIZER_ID: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ N_PER_TH: tl.constexpr,
+):
+ """2-state optimizer kernel"""
+ pid = tl.program_id(axis=0)
+ block_start_idx = pid * N_PER_TH
+ offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
+ mask = offsets < n_elements
+
+ g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
+ s2_vals = tl.load(state2_ptr + offsets, mask=mask, other=0.0)
+
+ if OPTIMIZER_ID == 5: # ADEMAMIX
+ s3_vals = tl.load(state1_ptr + n_elements + offsets, mask=mask, other=0.0)
+
+ g_vals = gnorm_scale * g_vals
+
+ update_scale = 1.0
+ if max_unorm > 0.0:
+ current_unorm = tl.sqrt(tl.load(unorm_ptr))
+ if current_unorm > max_unorm * param_norm:
+ update_scale = (max_unorm * param_norm) / current_unorm
+
+ if OPTIMIZER_ID == 3: # ADAM
+ s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals
+ s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals
+
+ correction1 = 1.0 - beta1_step
+ correction2 = tl.sqrt(1.0 - beta2_step)
+ step_size = -lr * correction2 / correction1
+
+ if weight_decay > 0.0:
+ p_vals = p_vals * (1.0 - lr * weight_decay)
+
+ update_val = update_scale * step_size * (s1_vals / (tl.sqrt(s2_vals) + eps * correction2))
+ p_vals = p_vals + update_val
+
+ elif OPTIMIZER_ID == 5: # ADEMAMIX
+ s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals # m1
+ s3_vals = s3_vals * beta3 + (1.0 - beta3) * g_vals # m2
+ s2_vals = s2_vals * beta2 + (1.0 - beta2) * g_vals * g_vals # nu
+
+ correction1 = 1.0 - beta1_step
+ correction2 = tl.sqrt(1.0 - beta2_step)
+
+ if weight_decay > 0.0:
+ p_vals = p_vals * (1.0 - lr * weight_decay)
+
+ mixed_momentum = (s1_vals / correction1) + (alpha * s3_vals)
+ adaptive_term = (tl.sqrt(s2_vals) / correction2) + eps
+ p_vals = p_vals - lr * (mixed_momentum / adaptive_term)
+
+ tl.store(p_ptr + offsets, p_vals, mask=mask)
+ tl.store(state1_ptr + offsets, s1_vals, mask=mask)
+ tl.store(state2_ptr + offsets, s2_vals, mask=mask)
+
+ if OPTIMIZER_ID == 5: # ADEMAMIX
+ tl.store(state1_ptr + n_elements + offsets, s3_vals, mask=mask)
+
+
+@triton.jit
+def _optimizer_update_1state_32bit_triton_kernel(
+ g_ptr,
+ p_ptr,
+ state1_ptr,
+ state2_ptr,
+ unorm_ptr,
+ max_unorm: tl.constexpr,
+ param_norm,
+ beta1: tl.constexpr,
+ beta2: tl.constexpr,
+ beta3,
+ alpha,
+ eps: tl.constexpr,
+ weight_decay: tl.constexpr,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale: tl.constexpr,
+ skip_zeros,
+ n_elements,
+ OPTIMIZER_ID: tl.constexpr,
+ BLOCK_SIZE: tl.constexpr,
+ N_PER_TH: tl.constexpr,
+):
+ """1-state optimizer kernel"""
+ pid = tl.program_id(axis=0)
+ block_start_idx = pid * N_PER_TH
+ offsets = block_start_idx * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE * N_PER_TH)
+ mask = offsets < n_elements
+
+ g_vals = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ p_vals = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ s1_vals = tl.load(state1_ptr + offsets, mask=mask, other=0.0)
+
+ g_vals = gnorm_scale * g_vals
+ if weight_decay > 0.0:
+ g_vals = g_vals + p_vals * weight_decay
+
+ update_scale = 1.0
+ if max_unorm > 0.0:
+ current_unorm = tl.sqrt(tl.load(unorm_ptr))
+ if current_unorm > max_unorm * param_norm + eps:
+ update_scale = (max_unorm * param_norm + eps) / current_unorm
+
+ if OPTIMIZER_ID == 0: # MOMENTUM
+ if step == 1:
+ s1_vals = g_vals
+ else:
+ s1_vals = s1_vals * beta1 + g_vals
+
+ update_val = update_scale * (-lr * s1_vals)
+ p_vals = p_vals + update_val
+
+ elif OPTIMIZER_ID == 4: # LION
+ momentum_update = s1_vals * beta1 + (1.0 - beta1) * g_vals
+ update_val = update_scale * lr * tl.where(momentum_update > 0, 1.0, tl.where(momentum_update < 0, -1.0, 0.0))
+ p_vals = p_vals - update_val
+
+ s1_vals = s1_vals * beta2 + (1.0 - beta2) * g_vals
+
+ elif OPTIMIZER_ID == 1: # RMSPROP
+ s1_vals = s1_vals * beta1 + (1.0 - beta1) * g_vals * g_vals
+
+ update_val = update_scale * lr * g_vals / (tl.sqrt(s1_vals) + eps)
+ p_vals = p_vals - update_val
+
+ elif OPTIMIZER_ID == 2: # ADAGRAD
+ s1_vals = s1_vals + g_vals * g_vals
+
+ update_val = lr * g_vals / (tl.sqrt(s1_vals) + eps)
+ p_vals = p_vals - update_val
+
+ tl.store(p_ptr + offsets, p_vals, mask=mask)
+ tl.store(state1_ptr + offsets, s1_vals, mask=mask)
+
+
+name2optimizer_32bit_fn = {
+ "adam": {
+ "preprocess": _optimizer_precondition_2state_32bit,
+ "update": _optimizer_update_2state_32bit_triton_kernel,
+ },
+ "ademamix": {
+ "preprocess": _optimizer_precondition_2state_32bit,
+ "update": _optimizer_update_2state_32bit_triton_kernel,
+ },
+ "momentum": {
+ "preprocess": _optimizer_precondition_1state_32bit,
+ "update": _optimizer_update_1state_32bit_triton_kernel,
+ },
+ "rmsprop": {
+ "preprocess": _optimizer_precondition_1state_32bit,
+ "update": _optimizer_update_1state_32bit_triton_kernel,
+ },
+ "adagrad": {
+ "preprocess": _optimizer_precondition_1state_32bit,
+ "update": _optimizer_update_1state_32bit_triton_kernel,
+ },
+ "lion": {
+ "preprocess": _optimizer_precondition_1state_32bit,
+ "update": _optimizer_update_1state_32bit_triton_kernel,
+ },
+}
+
+
+def optimizer_update_32bit_impl(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ unorm_vec: Optional[torch.Tensor],
+ max_unorm: float,
+ param_norm: float,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ weight_decay: float,
+ step: int,
+ lr: float,
+ gnorm_scale: float = 1.0,
+ skip_zeros=False,
+) -> None:
+ """
+ 32-bit optimizer implemented by Triton
+ """
+ if skip_zeros:
+ raise NotImplementedError("skip_zeros is not supported on XPU yet")
+
+ BLOCK_SIZE = 256
+ N_PER_TH = 1 # Number of blocks processed per thread.
+ grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),)
+ optimizer_id = name2optimizer_id[optimizer_name]
+ fn_preprocess = name2optimizer_32bit_fn[optimizer_name]["preprocess"]
+ fn_update = name2optimizer_32bit_fn[optimizer_name]["update"]
+
+ # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error.
+ # For backwards compatibility we precompute the bias correction factors.
+ beta1_step = beta1**step
+ beta2_step = beta2**step
+
+ if optimizer_name == "lion":
+ fn_update[grid](
+ g,
+ p,
+ state1,
+ state2,
+ unorm_vec,
+ max_unorm,
+ param_norm,
+ beta1,
+ beta2,
+ beta3,
+ alpha,
+ eps,
+ weight_decay,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale,
+ skip_zeros,
+ p.numel(),
+ optimizer_id,
+ BLOCK_SIZE,
+ N_PER_TH,
+ num_warps=2,
+ )
+
+ if max_unorm > 0.0:
+ unorm_vec.zero_()
+ fn_preprocess[grid](
+ g,
+ p,
+ state1,
+ state2,
+ unorm_vec,
+ beta1,
+ beta2,
+ eps,
+ weight_decay,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale,
+ p.numel(),
+ optimizer_id,
+ BLOCK_SIZE,
+ N_PER_TH,
+ num_warps=2,
+ )
+
+ else:
+ if max_unorm > 0.0:
+ unorm_vec.zero_()
+ fn_preprocess[grid](
+ g,
+ p,
+ state1,
+ state2,
+ unorm_vec,
+ beta1,
+ beta2,
+ eps,
+ weight_decay,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale,
+ p.numel(),
+ optimizer_id,
+ BLOCK_SIZE,
+ N_PER_TH,
+ num_warps=2,
+ )
+
+ fn_update[grid](
+ g,
+ p,
+ state1,
+ state2,
+ unorm_vec,
+ max_unorm,
+ param_norm,
+ beta1,
+ beta2,
+ beta3,
+ alpha,
+ eps,
+ weight_decay,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ gnorm_scale,
+ skip_zeros,
+ p.numel(),
+ optimizer_id,
+ BLOCK_SIZE,
+ N_PER_TH,
+ num_warps=2,
+ )
+
+
+###########################################
+# Pure torch implementation for reference #
+###########################################
+
+
+@torch.compile
+def _dequantize_blockwise_pytorch(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ dtype: torch.dtype,
+) -> torch.Tensor:
+ """
+ Pure PyTorch reference implementation for block-wise dequantization.
+ """
+ if A.numel() == 0:
+ return torch.empty_like(A, dtype=dtype)
+
+ A_flat = A.flatten()
+ num_elements = A_flat.numel()
+
+ dequantized_flat = code.to(A.device)[A_flat.long()].to(dtype)
+
+ num_blocks = math.ceil(num_elements / blocksize)
+ pad_len = num_blocks * blocksize - num_elements
+ if pad_len > 0:
+ dequantized_flat = torch.nn.functional.pad(dequantized_flat, (0, pad_len))
+
+ dequantized_blocks = dequantized_flat.reshape(num_blocks, blocksize)
+
+ rescaled_blocks = dequantized_blocks * absmax.unsqueeze(1).to(dtype)
+
+ rescaled_flat = rescaled_blocks.flatten()
+ if pad_len > 0:
+ rescaled_flat = rescaled_flat[:-pad_len]
+
+ return rescaled_flat.reshape(A.shape)
+
+
+@torch.compile
+def _quantize_blockwise_pytorch(
+ A: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+) -> tuple[torch.Tensor, torch.Tensor]:
+ """
+ Pure PyTorch reference implementation for block-wise quantization.
+ """
+ if A.numel() == 0:
+ return torch.empty_like(A, dtype=torch.uint8), torch.empty(0, dtype=torch.float32, device=A.device)
+
+ A_flat = A.flatten()
+ num_elements = A_flat.numel()
+
+ num_blocks = math.ceil(num_elements / blocksize)
+
+ pad_len = num_blocks * blocksize - num_elements
+ if pad_len > 0:
+ A_flat = torch.nn.functional.pad(A_flat, (0, pad_len))
+
+ A_blocks = A_flat.reshape(num_blocks, blocksize)
+
+ absmax = torch.max(torch.abs(A_blocks), dim=1, keepdim=True)[0]
+ absmax[absmax == 0] = 1.0
+
+ scaled_blocks = A_blocks / absmax
+
+ # Inefficient but straightforward quantization, takes a lot of memory
+ diff = torch.abs(scaled_blocks.unsqueeze(2) - code.to(A.device))
+ quantized_indices = torch.argmin(diff, dim=2).to(torch.uint8)
+
+ quantized_flat = quantized_indices.flatten()
+ if pad_len > 0:
+ quantized_flat = quantized_flat[:-pad_len]
+
+ return quantized_flat.reshape(A.shape), absmax.flatten()
+
+
+# Main updated function
+def optimizer_update_8bit_blockwise_pytorch(
+ p: torch.Tensor,
+ g: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ beta1: float,
+ beta2: float,
+ beta3: float, # ADEMIX
+ alpha: float, # ADEMIX
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: torch.Tensor,
+ qmap2: Optional[torch.Tensor],
+ absmax1: torch.Tensor,
+ absmax2: Optional[torch.Tensor],
+ weight_decay: float,
+ gnorm_scale: float,
+ skip_zeros: bool,
+ # ADEMIX
+ *,
+ optimizer_name: str,
+) -> None:
+ """
+ Pure PyTorch implementation of the 8-bit block-wise optimizer update step.
+ This version ensures high-precision updates for float16 parameters.
+ """
+ if skip_zeros:
+ raise ValueError("skip_zeros is not supported on XPU yet.")
+
+ blocksize = 256
+
+ with torch.no_grad():
+ # Dequantize states to perform updates in 32-bit precision
+ if optimizer_name == "ademamix" and absmax1.ndim == 2:
+ # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked.
+ s1_1_fp32 = _dequantize_blockwise_pytorch(state1[0], absmax1[0], qmap1, blocksize, torch.float32)
+ s1_2_fp32 = _dequantize_blockwise_pytorch(state1[1], absmax1[1], qmap1, blocksize, torch.float32)
+ state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32])
+ else:
+ state1_fp32 = _dequantize_blockwise_pytorch(state1, absmax1, qmap1, blocksize, torch.float32)
+
+ state2_fp32 = None
+ if state2 is not None:
+ state2_fp32 = _dequantize_blockwise_pytorch(state2, absmax2, qmap2, blocksize, torch.float32)
+
+ grad = g.float() * gnorm_scale
+
+ # Create a 32-bit copy of the parameter for high-precision updates
+ p_fp32 = p.data.float()
+
+ if optimizer_name == "adam":
+ state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+ state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+ bias_correction1 = 1.0 - beta1**step
+ bias_correction2 = 1.0 - beta2**step
+
+ denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+
+ if weight_decay > 0.0:
+ p_fp32.mul_(1.0 - lr * weight_decay)
+ p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1)
+
+ elif optimizer_name == "ademamix":
+ m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
+ nu_fp32 = state2_fp32
+
+ m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+ m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
+ nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+ bias_correction1 = 1.0 - beta1**step
+ bias_correction2 = math.sqrt(1.0 - beta2**step)
+
+ update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps)
+
+ if weight_decay > 0.0:
+ p_fp32.mul_(1.0 - lr * weight_decay)
+
+ p_fp32.add_(update, alpha=-lr)
+ state1_fp32 = torch.stack([m1_fp32, m2_fp32])
+
+ elif optimizer_name == "momentum":
+ grad.add_(p_fp32, alpha=weight_decay)
+ if step == 1:
+ state1_fp32.copy_(grad)
+ else:
+ state1_fp32.mul_(beta1).add_(grad)
+ p_fp32.add_(state1_fp32, alpha=-lr)
+
+ elif optimizer_name == "rmsprop":
+ grad.add_(p_fp32, alpha=weight_decay)
+ state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
+ p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
+
+ elif optimizer_name == "lion":
+ if weight_decay > 0.0:
+ p_fp32.mul_(1.0 - lr * weight_decay)
+
+ update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
+ p_fp32.add_(update_dir, alpha=-lr)
+
+ state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)
+
+ elif optimizer_name == "adagrad":
+ grad.add_(p_fp32, alpha=weight_decay)
+ state1_fp32.addcmul_(grad, grad, value=1.0)
+ p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
+
+ else:
+ raise NotImplementedError(
+ f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available."
+ )
+
+ # Copy the updated 32-bit parameter back to the original tensor
+ p.data.copy_(p_fp32)
+
+ # Re-quantize states and update state tensors in-place
+ if optimizer_name == "ademamix":
+ new_m1_8bit, new_absmax_m1 = _quantize_blockwise_pytorch(state1_fp32[0], qmap1, blocksize)
+ new_m2_8bit, new_absmax_m2 = _quantize_blockwise_pytorch(state1_fp32[1], qmap1, blocksize)
+ state1[0].copy_(new_m1_8bit)
+ state1[1].copy_(new_m2_8bit)
+ absmax1[0].copy_(new_absmax_m1)
+ absmax1[1].copy_(new_absmax_m2)
+
+ new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize)
+ state2.copy_(new_state2_8bit)
+ absmax2.copy_(new_absmax2)
+ else:
+ new_state1_8bit, new_absmax1 = _quantize_blockwise_pytorch(state1_fp32, qmap1, blocksize)
+ state1.copy_(new_state1_8bit)
+ absmax1.copy_(new_absmax1)
+
+ if state2_fp32 is not None:
+ new_state2_8bit, new_absmax2 = _quantize_blockwise_pytorch(state2_fp32, qmap2, blocksize)
+ state2.copy_(new_state2_8bit)
+ absmax2.copy_(new_absmax2)
+
+
+#######################################
+# Mixed torch + triton implementation #
+#######################################
+
+
+# Much more memory efficient due to using triton for quantization/dequantization
+def optimizer_update_8bit_blockwise_triton_quant(
+ p: torch.Tensor,
+ g: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ beta1: float,
+ beta2: float,
+ beta3: float, # ADEMIX
+ alpha: float, # ADEMIX
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: torch.Tensor,
+ qmap2: Optional[torch.Tensor],
+ absmax1: torch.Tensor,
+ absmax2: Optional[torch.Tensor],
+ weight_decay: float,
+ gnorm_scale: float,
+ skip_zeros: bool,
+ # ADEMIX
+ *,
+ optimizer_name: str,
+) -> None:
+ """
+ Pure PyTorch implementation of the 8-bit block-wise optimizer update step.
+ This version ensures high-precision updates for float16 parameters.
+ """
+ if skip_zeros and not torch.any(g):
+ return
+
+ blocksize = 256
+ grad = g.float() * gnorm_scale
+
+ with torch.no_grad():
+ # Create a 32-bit copy of the parameter for high-precision updates
+ p_fp32 = p.data.float()
+
+ # Dequantize states to perform updates in 32-bit precision
+ if optimizer_name == "ademamix" and absmax1.ndim == 2:
+ # For AdEMAMix, state1 holds two EMAs, so absmax1 is stacked.
+ s1_1_fp32 = dequant_8bit_blockwise(state1[0], absmax1[0], qmap1, blocksize, dtype=torch.float32)
+ s1_2_fp32 = dequant_8bit_blockwise(state1[1], absmax1[1], qmap1, blocksize, dtype=torch.float32)
+ state1_fp32 = torch.stack([s1_1_fp32, s1_2_fp32])
+ else:
+ state1_fp32 = dequant_8bit_blockwise(state1, absmax1, qmap1, blocksize, dtype=torch.float32)
+
+ state2_fp32 = None
+ if state2 is not None:
+ state2_fp32 = dequant_8bit_blockwise(state2, absmax2, qmap2, blocksize, dtype=torch.float32)
+
+ # Apply optimizer-specific update logic
+ if optimizer_name == "adam":
+ if weight_decay > 0.0:
+ p_fp32.mul_(1.0 - lr * weight_decay)
+
+ state1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+ state2_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+ bias_correction1 = 1.0 - beta1**step
+ bias_correction2 = 1.0 - beta2**step
+
+ denom = (state2_fp32.sqrt() / math.sqrt(bias_correction2)).add_(eps)
+ p_fp32.addcdiv_(state1_fp32, denom, value=-lr / bias_correction1)
+
+ elif optimizer_name == "ademamix":
+ m1_fp32, m2_fp32 = state1_fp32[0], state1_fp32[1]
+ nu_fp32 = state2_fp32
+
+ m1_fp32.mul_(beta1).add_(grad, alpha=1.0 - beta1)
+ m2_fp32.mul_(beta3).add_(grad, alpha=1.0 - beta3)
+ nu_fp32.mul_(beta2).addcmul_(grad, grad, value=1.0 - beta2)
+
+ bias_correction1 = 1.0 - beta1**step
+ bias_correction2 = math.sqrt(1.0 - beta2**step)
+
+ update = (m1_fp32 / bias_correction1 + alpha * m2_fp32) / (nu_fp32.sqrt() / bias_correction2 + eps)
+
+ if weight_decay > 0.0:
+ p_fp32.mul_(1.0 - lr * weight_decay)
+
+ p_fp32.add_(update, alpha=-lr)
+ state1_fp32 = torch.stack([m1_fp32, m2_fp32])
+
+ elif optimizer_name == "momentum":
+ grad.add_(p_fp32, alpha=weight_decay)
+ if step == 1:
+ state1_fp32.copy_(grad)
+ else:
+ state1_fp32.mul_(beta1).add_(grad)
+ p_fp32.add_(state1_fp32, alpha=-lr)
+
+ elif optimizer_name == "rmsprop":
+ grad.add_(p_fp32, alpha=weight_decay)
+ state1_fp32.mul_(beta1).addcmul_(grad, grad, value=1.0 - beta1)
+ p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
+
+ elif optimizer_name == "lion":
+ if weight_decay > 0.0:
+ p_fp32.mul_(1.0 - lr * weight_decay)
+
+ update_dir = torch.sign(state1_fp32.mul(beta1) + grad.mul(1.0 - beta1))
+ p_fp32.add_(update_dir, alpha=-lr)
+
+ state1_fp32.mul_(beta2).add_(grad, alpha=1.0 - beta2)
+
+ elif optimizer_name == "adagrad":
+ grad.add_(p_fp32, alpha=weight_decay)
+ state1_fp32.addcmul_(grad, grad, value=1.0)
+ p_fp32.addcdiv_(grad, state1_fp32.sqrt().add_(eps), value=-lr)
+
+ else:
+ raise NotImplementedError(
+ f"Pure PyTorch implementation for optimizer '{optimizer_name}' is not available."
+ )
+
+ # Copy the updated 32-bit parameter back to the original tensor
+ p.data.copy_(p_fp32)
+
+ # Re-quantize states and update state tensors in-place
+ if optimizer_name == "ademamix":
+ new_m1_8bit, new_absmax_m1 = quantize_blockwise_triton(state1_fp32[0], qmap1, blocksize)
+ new_m2_8bit, new_absmax_m2 = quantize_blockwise_triton(state1_fp32[1], qmap1, blocksize)
+ state1[0].copy_(new_m1_8bit)
+ state1[1].copy_(new_m2_8bit)
+ absmax1[0].copy_(new_absmax_m1)
+ absmax1[1].copy_(new_absmax_m2)
+
+ new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize)
+ state2.copy_(new_state2_8bit)
+ absmax2.copy_(new_absmax2)
+ else:
+ new_state1_8bit, new_absmax1 = quantize_blockwise_triton(state1_fp32, qmap1, blocksize)
+ state1.copy_(new_state1_8bit)
+ absmax1.copy_(new_absmax1)
+
+ if state2_fp32 is not None:
+ new_state2_8bit, new_absmax2 = quantize_blockwise_triton(state2_fp32, qmap2, blocksize)
+ state2.copy_(new_state2_8bit)
+ absmax2.copy_(new_absmax2)
+
+
+#########################
+# Triton implementation #
+#########################
+
+
+@triton.jit
+def _optimizer_update_1state_8bit_blockwise_triton_kernel(
+ # Tensors
+ p_ptr,
+ g_ptr,
+ state1_ptr,
+ state2_ptr,
+ beta1: tl.constexpr,
+ beta2: tl.constexpr,
+ beta3,
+ alpha,
+ eps: tl.constexpr,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ qmap1_ptr,
+ qmap2_ptr,
+ absmax1_ptr,
+ absmax2_ptr,
+ weight_decay,
+ gnorm_scale,
+ # Meta-parameters
+ n_elements,
+ BLOCK_SIZE_N: tl.constexpr,
+ N_PER_TH: tl.constexpr,
+ OPTIMIZER_ID: tl.constexpr,
+):
+ """
+ Triton kernel for 8-bit optimizers that use one momentum state.
+ Supports: Momentum, RMSprop, Adagrad, Lion.
+ """
+ # 1. Boilerplate: pid, offsets, mask
+ pid = tl.program_id(axis=0)
+ block_start_idx = pid * N_PER_TH
+ offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH)
+ mask = offsets < n_elements
+
+ # 2. Load and dequantize tensors
+ g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale
+ p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+ s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
+
+ # 3. Optimizer-specific updates
+ # LION
+ if weight_decay > 0.0 and OPTIMIZER_ID == 2:
+ p *= 1.0 - lr * weight_decay
+ # Apply weight decay for momentum, rmsprop, adagrad
+ elif weight_decay > 0.0:
+ g += p * weight_decay
+
+ # Momentum update
+ if OPTIMIZER_ID == 0: # MOMENTUM
+ if step == 1:
+ s1 = g
+ else:
+ s1 = s1 * beta1 + g
+ p -= lr * s1
+
+ # RMSprop update
+ elif OPTIMIZER_ID == 1: # RMSPROP
+ s1 = s1 * beta1 + (1.0 - beta1) * g * g
+ p -= lr * (g / (tl.sqrt(s1) + eps))
+
+ # Adagrad update
+ elif OPTIMIZER_ID == 2: # ADAGRAD
+ s1 += g * g
+ p -= lr * (g / (tl.sqrt(s1) + eps))
+
+ # Lion update
+ elif OPTIMIZER_ID == 4: # LION
+ val = s1 * beta1 + (1.0 - beta1) * g
+ update = tl.where(val > 0.0, 1.0, tl.where(val < 0.0, -1.0, 0.0))
+ p -= lr * update
+ s1 = s1 * beta2 + (1.0 - beta2) * g
+
+ # 4. Store updated parameter and requantized state
+ tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
+ s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+ tl.store(state1_ptr + offsets, s1_codes, mask=mask)
+ tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
+
+
+@triton.jit
+def _optimizer_update_2state_8bit_blockwise_triton_kernel(
+ # Tensors
+ p_ptr,
+ g_ptr,
+ state1_ptr,
+ state2_ptr,
+ beta1: tl.constexpr,
+ beta2: tl.constexpr,
+ # ademamix changes alpha and beta3
+ beta3,
+ # ademamix changes alpha and beta3
+ alpha,
+ eps: tl.constexpr,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ qmap1_ptr,
+ qmap2_ptr,
+ absmax1_ptr,
+ absmax2_ptr,
+ weight_decay: tl.constexpr,
+ gnorm_scale: tl.constexpr,
+ # Meta-parameters
+ n_elements,
+ BLOCK_SIZE_N: tl.constexpr,
+ N_PER_TH: tl.constexpr,
+ OPTIMIZER_ID: tl.constexpr,
+):
+ """
+ Triton kernel for 8-bit optimizers that use two momentum states.
+ Supports: Adam, AdEMAMix.
+ """
+ # 1. Boilerplate: pid, offsets, mask
+ pid = tl.program_id(axis=0)
+ block_start_idx = pid * N_PER_TH
+ offsets = block_start_idx * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N * N_PER_TH)
+ mask = offsets < n_elements
+
+ # 2. Load and dequantize tensors
+ g = tl.load(g_ptr + offsets, mask=mask, other=0.0).to(tl.float32) * gnorm_scale
+ p = tl.load(p_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
+
+ # 3. Optimizer-specific updates
+ if OPTIMIZER_ID == 3: # ADAM
+ s1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
+ s2 = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
+
+ s1 = s1 * beta1 + (1.0 - beta1) * g
+ s2 = s2 * beta2 + (1.0 - beta2) * g * g
+
+ # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error.
+ # For backwards compatibility we precompute the bias correction factors.
+ # bias_correction1 = 1.0 - libdevice.pow(beta1, step)
+ # bias_correction2 = 1.0 - libdevice.pow(beta2, step)
+ bias_correction1 = 1.0 - beta1_step
+ bias_correction2 = 1.0 - beta2_step
+
+ if weight_decay > 0.0:
+ p *= 1.0 - lr * weight_decay
+
+ denom = tl.sqrt(s2) / tl.sqrt(bias_correction2) + eps
+ p -= (lr / bias_correction1) * (s1 / denom)
+
+ # Store updated parameter
+ tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
+
+ # Requantize and store states
+ s1_codes, new_absmax1 = quantize_8bit_blockwise_kernel_util(s1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+ tl.store(state1_ptr + offsets, s1_codes, mask=mask)
+ tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax1)
+
+ s2_codes, new_absmax2 = quantize_8bit_blockwise_kernel_util(s2, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+ tl.store(state2_ptr + offsets, s2_codes, mask=mask)
+ tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax2)
+
+ elif OPTIMIZER_ID == 5: # ADEMAMIX
+ # AdEMAMix has a stacked state1 (m1, m2) and state2 (nu)
+ m1 = dequant_8bit_blockwise_kernel_util(state1_ptr, offsets, qmap1_ptr, absmax1_ptr, mask, BLOCK_SIZE_N)
+ m2 = dequant_8bit_blockwise_kernel_util(
+ state1_ptr + n_elements,
+ offsets,
+ qmap1_ptr,
+ absmax1_ptr + n_elements // BLOCK_SIZE_N,
+ mask,
+ BLOCK_SIZE_N,
+ )
+ nu = dequant_8bit_blockwise_kernel_util(state2_ptr, offsets, qmap2_ptr, absmax2_ptr, mask, BLOCK_SIZE_N)
+
+ m1 = m1 * beta1 + (1.0 - beta1) * g
+ m2 = m2 * beta3 + (1.0 - beta3) * g
+ nu = nu * beta2 + (1.0 - beta2) * g * g
+
+ # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error.
+ # For backwards compatibility we precompute the bias correction factors.
+ # bias_correction1 = 1.0 - libdevice.pow(beta1, step)
+ # bias_correction2 = tl.sqrt(1.0 - libdevice.pow(beta2, step))
+ bias_correction1 = 1.0 - beta1_step
+ bias_correction2 = tl.sqrt(1.0 - beta2_step)
+
+ update = (m1 / bias_correction1 + alpha * m2) / (tl.sqrt(nu) / bias_correction2 + eps)
+
+ if weight_decay > 0.0:
+ p *= 1.0 - lr * weight_decay
+
+ p -= lr * update
+
+ # Store updated parameter
+ tl.store(p_ptr + offsets, p.to(p_ptr.dtype.element_ty), mask=mask)
+
+ # Requantize and store all three states
+ m1_codes, new_absmax_m1 = quantize_8bit_blockwise_kernel_util(m1, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+ tl.store(state1_ptr + offsets, m1_codes, mask=mask)
+ tl.store(absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_m1)
+
+ m2_codes, new_absmax_m2 = quantize_8bit_blockwise_kernel_util(m2, qmap1_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+ tl.store(state1_ptr + n_elements + offsets, m2_codes, mask=mask)
+ tl.store(
+ absmax1_ptr + block_start_idx + tl.arange(0, N_PER_TH) + n_elements // BLOCK_SIZE_N,
+ new_absmax_m2,
+ )
+
+ nu_codes, new_absmax_nu = quantize_8bit_blockwise_kernel_util(nu, qmap2_ptr, 256, BLOCK_SIZE_N, N_PER_TH)
+ tl.store(state2_ptr + offsets, nu_codes, mask=mask)
+ tl.store(absmax2_ptr + block_start_idx + tl.arange(0, N_PER_TH), new_absmax_nu)
+
+
+name2optimizer_fn = {
+ "momentum": _optimizer_update_1state_8bit_blockwise_triton_kernel,
+ "rmsprop": _optimizer_update_1state_8bit_blockwise_triton_kernel,
+ "adagrad": _optimizer_update_1state_8bit_blockwise_triton_kernel,
+ "adam": _optimizer_update_2state_8bit_blockwise_triton_kernel,
+ "lion": _optimizer_update_1state_8bit_blockwise_triton_kernel,
+ "ademamix": _optimizer_update_2state_8bit_blockwise_triton_kernel,
+}
+
+
+def optimizer_update_8bit_blockwise_impl(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: torch.Tensor,
+ qmap2: Optional[torch.Tensor],
+ absmax1: torch.Tensor,
+ absmax2: Optional[torch.Tensor],
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ skip_zeros=False,
+) -> None:
+ if skip_zeros:
+ raise NotImplementedError("skip_zeros is not supported on XPU yet")
+
+ if optimizer_name == "ademamix":
+ # Handle AdEMAMIX's stacked state tensors
+ if state1.dim() < 2 or state1.shape[0] != 2:
+ raise ValueError(
+ f"For ademamix, state1 must be a stacked tensor of shape (2, ...), but got {state1.shape}"
+ )
+ if absmax1.dim() < 2 or absmax1.shape[0] != 2:
+ raise ValueError(
+ f"For ademamix, absmax1 must be a stacked tensor of shape (2, ...), but got {absmax1.shape}"
+ )
+
+ BLOCK_SIZE = 256
+ N_PER_TH = 1 # Number of blocks processed per thread.
+ grid = (triton.cdiv(p.numel(), BLOCK_SIZE * N_PER_TH),)
+ fn = name2optimizer_fn[optimizer_name]
+ optimizer_id = name2optimizer_id[optimizer_name]
+
+ # In torch=2.7 on XPU there is an issue with libdevice.pow, leading to an error.
+ # For backwards compatibility we precompute the bias correction factors.
+ beta1_step = beta1**step
+ beta2_step = beta2**step
+
+ fn[grid](
+ p,
+ g,
+ state1,
+ state2,
+ beta1,
+ beta2,
+ beta3,
+ alpha,
+ eps,
+ step,
+ beta1_step,
+ beta2_step,
+ lr,
+ qmap1,
+ qmap2,
+ absmax1,
+ absmax2,
+ weight_decay,
+ gnorm_scale,
+ p.numel(),
+ BLOCK_SIZE_N=BLOCK_SIZE,
+ N_PER_TH=N_PER_TH,
+ OPTIMIZER_ID=optimizer_id,
+ num_warps=2,
+ )
+
+
+# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_pytorch
+# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_pytorch_impl)
+# optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_triton_quant
+# optimizer_update_8bit_blockwise_impl = torch.compile(optimizer_update_8bit_blockwise_triton_quant)
+optimizer_update_8bit_blockwise_impl = optimizer_update_8bit_blockwise_impl
diff --git a/bitsandbytes/backends/triton/ops.py b/bitsandbytes/backends/triton/ops.py
index 1e2802ab5..66bff3c94 100644
--- a/bitsandbytes/backends/triton/ops.py
+++ b/bitsandbytes/backends/triton/ops.py
@@ -1,30 +1,25 @@
from collections.abc import Sequence
+from typing import Optional
import torch
-from . import triton_kernels
+from . import kernels_4bit, kernels_8bit_quant, kernels_optim
# currently codes unused, kept for reference
# Should be the same for quant/dequant
# from bitsandbytes.functional import get_4bit_type
# _FP4_QUANT_TABLE = get_4bit_type("fp4", device="xpu")
# _NF4_QUANT_TABLE = get_4bit_type("nf4", device="xpu")
+device_type = torch.accelerator.current_accelerator().type if hasattr(torch, "accelerator") else "cuda"
+torch_accelerator_module = getattr(torch, device_type, torch.cuda)
def quantize_blockwise(A: torch.Tensor, code: torch.Tensor, blocksize: int) -> tuple[torch.Tensor, torch.Tensor]:
torch._check_is_size(blocksize)
# torch._check(A.dtype == torch.float32, lambda: f"A must be float32 on xpu, got {A.dtype}")
-
- n = A.numel()
- blocks = -(n // -blocksize)
-
- absmax = torch.empty((blocks,), device=A.device, dtype=A.dtype)
- out = torch.empty_like(A.flatten(), dtype=torch.uint8)
-
- triton_kernels.quantize_blockwise_triton(A, blocksize, code, blocks, absmax, out)
- out = out.reshape(A.shape)
-
- return out, absmax.float()
+ with torch_accelerator_module.device(A.device):
+ out, absmax = kernels_8bit_quant.quantize_blockwise_triton(A, code, blocksize)
+ return out, absmax.float()
def dequantize_blockwise(
@@ -33,21 +28,24 @@ def dequantize_blockwise(
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
# torch._check(dtype == torch.float32, lambda: f"dtype must be float32 on xpu, got {dtype}")
-
- out = torch.empty_like(A, dtype=dtype, device=A.device)
- triton_kernels.dequant_int8_blockwise(
- A,
- code,
- absmax,
- out,
- blocksize,
- )
-
+ with torch_accelerator_module.device(A.device):
+ out = kernels_8bit_quant.dequant_8bit_blockwise(
+ A,
+ absmax,
+ code,
+ blocksize,
+ dtype=dtype,
+ )
return out
def dequantize_blockwise_inplace(
- A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ dtype: torch.dtype,
+ out: torch.Tensor,
) -> None:
torch._check_is_size(blocksize)
torch._check(A.dtype == torch.uint8, lambda: f"A must be uint8, got {A.dtype}")
@@ -55,13 +53,15 @@ def dequantize_blockwise_inplace(
torch._check(out.device == A.device, lambda: f"Expected out.device == {A.device}, got {out.device}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
- triton_kernels.dequant_int8_blockwise(
- A,
- code,
- absmax,
- out,
- blocksize,
- )
+ with torch_accelerator_module.device(A.device):
+ kernels_8bit_quant.dequant_8bit_blockwise(
+ A,
+ absmax,
+ code,
+ blocksize,
+ dtype=dtype,
+ out=out,
+ )
def quantize_4bit(
@@ -84,9 +84,10 @@ def quantize_4bit(
absmax = torch.empty((blocks * 2,), device=A.device, dtype=A.dtype)
out = torch.empty((n // 2, 1), device=A.device, dtype=torch.uint8)
- triton_kernels.quantize_4bit_blockwise_triton(
- A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
- )
+ with torch_accelerator_module.device(A.device):
+ kernels_4bit.quantize_4bit_blockwise_triton(
+ A, blocksize, quant_type, blocks, absmax, num_elements=n, quantized_out=out
+ )
packed = out
if quant_storage != torch.uint8:
@@ -118,8 +119,9 @@ def dequantize_4bit(
A = A.squeeze().view(torch.uint8).unsqueeze(1)
out = torch.empty(shape, dtype=dtype, device=A.device)
+ with torch_accelerator_module.device(A.device):
+ kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
- triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
return out
@@ -134,7 +136,8 @@ def dequantize_4bit_inplace(
) -> None:
torch._check(out.shape == shape, lambda: f"Expected out.shape == {shape}, got {out.shape}")
torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
- triton_kernels._dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+ with torch_accelerator_module.device(A.device):
+ kernels_4bit.dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
def gemv_4bit(
@@ -150,17 +153,145 @@ def gemv_4bit(
B_dq_triton = torch.empty(shapeB, dtype=A.dtype, device=A.device)
- triton_kernels._dequantize_4bit_impl_passing_code(
- B,
- absmax,
- blocksize,
- code,
- dtype=A.dtype,
- out=B_dq_triton,
- )
+ with torch_accelerator_module.device(A.device):
+ kernels_4bit.dequantize_4bit_impl_passing_code(
+ B,
+ absmax,
+ blocksize,
+ code,
+ dtype=A.dtype,
+ out=B_dq_triton,
+ )
+
+ return torch.nn.functional.linear(
+ A,
+ B_dq_triton,
+ bias=None,
+ )
+
+
+# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_pytorch
+# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_pytorch) # 60ms
+# optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_triton_quant #2.8ms
+# optimizer_update_8bit_blockwise_impl = torch.compile(kernels_optim.optimizer_update_8bit_blockwise_triton_quant) # 2.3ms
+optimizer_update_8bit_blockwise_impl = kernels_optim.optimizer_update_8bit_blockwise_impl # ~0.95ms for adam
+
+
+def optimizer_update_8bit_blockwise(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ step: int,
+ lr: float,
+ qmap1: torch.Tensor,
+ qmap2: Optional[torch.Tensor],
+ absmax1: torch.Tensor,
+ absmax2: Optional[torch.Tensor],
+ weight_decay: float = 0.0,
+ gnorm_scale: float = 1.0,
+ skip_zeros=False,
+) -> None:
+ # torch._check(
+ # g.numel() == p.numel(),
+ # lambda: f"g and p must have the same number of elements, got {g.numel()} and {p.numel()}",
+ # )
+ # compute_dtypes = [torch.float16, torch.bfloat16, torch.float32]
- return torch.nn.functional.linear(
- A,
- B_dq_triton,
- bias=None,
- )
+ # torch._check(
+ # g.dtype in compute_dtypes,
+ # lambda: f"g must be bfloat16, float16, or float32, got {g.dtype}",
+ # )
+ # torch._check(
+ # g.dtype == p.dtype,
+ # lambda: f"Expected all tensors to have the same dtype, got g.dtype={g.dtype}, p.dtype={p.dtype}",
+ # )
+ # torch._check(
+ # state1.dtype == torch.uint8,
+ # lambda: f"state1 must be uint8, got {state1.dtype}",
+ # )
+ # torch._check(
+ # qmap1.dtype == absmax1.dtype == torch.float32,
+ # lambda: f"Expected qmap1 and absmax1 to be float32, got qmap1.dtype={qmap1.dtype}, absmax1.dtype={absmax1.dtype}",
+ # )
+ # if state2 is not None:
+ # torch._check(
+ # state2.dtype == torch.uint8,
+ # lambda: f"state2 must be uint8, got {state2.dtype}",
+ # )
+ # torch._check(
+ # qmap2.dtype == absmax2.dtype == torch.float32,
+ # lambda: f"Expected qmap2 and absmax2 to be float32, got qmap2.dtype={qmap2.dtype}, absmax2.dtype={absmax2.dtype}",
+ # )
+
+ with torch_accelerator_module.device(state1.device):
+ optimizer_update_8bit_blockwise_impl(
+ optimizer_name=optimizer_name,
+ g=g,
+ p=p,
+ state1=state1,
+ state2=state2,
+ beta1=beta1,
+ beta2=beta2,
+ beta3=beta3,
+ alpha=alpha,
+ eps=eps,
+ step=step,
+ lr=lr,
+ qmap1=qmap1,
+ qmap2=qmap2,
+ absmax1=absmax1,
+ absmax2=absmax2,
+ weight_decay=weight_decay,
+ gnorm_scale=gnorm_scale,
+ skip_zeros=skip_zeros,
+ )
+
+
+def optimizer_update_32bit(
+ optimizer_name: str,
+ g: torch.Tensor,
+ p: torch.Tensor,
+ state1: torch.Tensor,
+ state2: Optional[torch.Tensor],
+ unorm_vec: Optional[torch.Tensor],
+ max_unorm: float,
+ param_norm: float,
+ beta1: float,
+ beta2: float,
+ beta3: float,
+ alpha: float,
+ eps: float,
+ weight_decay: float,
+ step: int,
+ lr: float,
+ gnorm_scale: float,
+ skip_zeros=False,
+) -> None:
+ with torch_accelerator_module.device(state1.device):
+ kernels_optim.optimizer_update_32bit_impl(
+ optimizer_name=optimizer_name,
+ g=g,
+ p=p,
+ state1=state1,
+ state2=state2,
+ unorm_vec=unorm_vec,
+ max_unorm=max_unorm,
+ param_norm=param_norm,
+ beta1=beta1,
+ beta2=beta2,
+ beta3=beta3,
+ alpha=alpha,
+ eps=eps,
+ weight_decay=weight_decay,
+ step=step,
+ lr=lr,
+ gnorm_scale=gnorm_scale,
+ skip_zeros=skip_zeros,
+ )
diff --git a/bitsandbytes/backends/utils.py b/bitsandbytes/backends/utils.py
old mode 100755
new mode 100644
index 1543f3474..34e3d5faa
--- a/bitsandbytes/backends/utils.py
+++ b/bitsandbytes/backends/utils.py
@@ -3,22 +3,12 @@
from packaging import version
import torch
-try:
- # to support Intel CPU/XPU (IPEX) backend
- import intel_extension_for_pytorch as ipex
-
- ipex_cpu = ipex if ipex._C._has_cpu() else None
- ipex_xpu = ipex if ipex._C._has_xpu() else None
-except BaseException:
- ipex_cpu = None
- ipex_xpu = None
-
try:
import triton # noqa: F401
import triton.language as tl # noqa: F401
triton_available = True
-except ImportError as e:
+except ImportError:
triton_available = False
diff --git a/bitsandbytes/backends/xpu/__init__.py b/bitsandbytes/backends/xpu/__init__.py
old mode 100755
new mode 100644
diff --git a/bitsandbytes/backends/xpu/ops.py b/bitsandbytes/backends/xpu/ops.py
old mode 100755
new mode 100644
index 999116c97..a0620dc4b
--- a/bitsandbytes/backends/xpu/ops.py
+++ b/bitsandbytes/backends/xpu/ops.py
@@ -1,14 +1,20 @@
from collections.abc import Sequence
-import warnings
+import ctypes as ct
+import logging
+from packaging import version
import torch
+from bitsandbytes.functional import _get_tensor_stream, get_ptr
+
from ..._ops import register_kernel
-from ..utils import ipex_xpu, triton_available
+from ...cextension import ErrorHandlerMockBNBNativeLibrary, lib
+from ..utils import triton_available
+
+logger = logging.getLogger(__name__)
-# _int_mm is available in torch starting from 2.7 version,
-# but currently it's don't have xpu implementation.
-if ipex_xpu and torch.__version__ >= (2, 7):
+# _int_mm is available in torch starting from 2.9 version
+if version.parse(torch.__version__).release >= version.parse("2.9").release:
@register_kernel("bitsandbytes::int8_linear_matmul", "xpu")
def _(A: torch.Tensor, B: torch.Tensor):
@@ -18,42 +24,209 @@ def _(A: torch.Tensor, B: torch.Tensor):
).reshape(*A.shape[:-1], B.shape[0])
-# IPEX should be faster for xpu, so at first checking if it is available.
-if ipex_xpu:
+def _dequantize_4bit_impl(
+ A: torch.Tensor,
+ absmax: torch.Tensor,
+ blocksize: int,
+ quant_type: str,
+ dtype: torch.dtype,
+ out: torch.Tensor,
+) -> None:
+ args = (
+ None,
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(out.numel()),
+ _get_tensor_stream(A),
+ )
+ if dtype == torch.bfloat16:
+ if quant_type == "fp4":
+ lib.cdequantize_blockwise_bf16_fp4(*args)
+ else:
+ lib.cdequantize_blockwise_bf16_nf4(*args)
+ elif dtype == torch.float16:
+ if quant_type == "fp4":
+ lib.cdequantize_blockwise_fp16_fp4(*args)
+ else:
+ lib.cdequantize_blockwise_fp16_nf4(*args)
+ elif dtype == torch.float32:
+ if quant_type == "fp4":
+ lib.cdequantize_blockwise_fp32_fp4(*args)
+ else:
+ lib.cdequantize_blockwise_fp32_nf4(*args)
+
+
+def _dequantize_blockwise_impl(
+ A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype, out: torch.Tensor
+) -> None:
+ args = (
+ get_ptr(code),
+ get_ptr(A),
+ get_ptr(absmax),
+ get_ptr(out),
+ ct.c_int(blocksize),
+ ct.c_int(A.numel()),
+ _get_tensor_stream(A),
+ )
+ if dtype == torch.float16:
+ lib.cdequantize_blockwise_fp16(*args)
+ elif dtype == torch.bfloat16:
+ lib.cdequantize_blockwise_bf16(*args)
+ elif dtype == torch.float32:
+ lib.cdequantize_blockwise_fp32(*args)
+
+
+def _gemv_4bit_impl(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ shapeB: Sequence[int],
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ out: torch.Tensor,
+) -> None:
+ m = ct.c_int32(1)
+ n = ct.c_int32(shapeB[0])
+ k = ct.c_int32(shapeB[1])
+
+ lda = m
+ ldb = ct.c_int32((A.shape[-1] + 1) // 2)
+ ldc = m
+
+ stream = _get_tensor_stream(A)
+ if A.dtype == torch.float16:
+ lib.cgemv_4bit_inference_fp16(
+ m,
+ n,
+ k,
+ get_ptr(A),
+ get_ptr(B),
+ get_ptr(absmax),
+ get_ptr(code),
+ get_ptr(out),
+ lda,
+ ldb,
+ ldc,
+ ct.c_int32(blocksize),
+ stream,
+ )
+ elif A.dtype == torch.bfloat16:
+ lib.cgemv_4bit_inference_bf16(
+ m,
+ n,
+ k,
+ get_ptr(A),
+ get_ptr(B),
+ get_ptr(absmax),
+ get_ptr(code),
+ get_ptr(out),
+ lda,
+ ldb,
+ ldc,
+ ct.c_int32(blocksize),
+ stream,
+ )
+ elif A.dtype == torch.float32:
+ lib.cgemv_4bit_inference_fp32(
+ m,
+ n,
+ k,
+ get_ptr(A),
+ get_ptr(B),
+ get_ptr(absmax),
+ get_ptr(code),
+ get_ptr(out),
+ lda,
+ ldb,
+ ldc,
+ ct.c_int32(blocksize),
+ stream,
+ )
+
+
+# SYCL should be faster for xpu, so at first checking if it is available.
+if not isinstance(lib, ErrorHandlerMockBNBNativeLibrary):
+ logger.info("Register sycl bitsandbytes kernels for XPU")
+
+ # TODO: Remove the triton register when quantization sycl kernel is ready.
+ if triton_available:
+ from ..triton import ops as triton_ops
+
+ register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
+ register_kernel("bitsandbytes::quantize_4bit", "xpu")(triton_ops.quantize_4bit)
+ register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(
+ triton_ops.optimizer_update_8bit_blockwise
+ )
+ register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
- @register_kernel("bitsandbytes::dequantize_nf4_ipex", "xpu")
+ @register_kernel("bitsandbytes::dequantize_4bit", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
blocksize: int,
+ quant_type: str,
shape: Sequence[int],
dtype: torch.dtype,
) -> torch.Tensor:
- return torch.ops.torch_ipex.dequantize_4bit(A, "nf4", shape, absmax, None, blocksize).t().to(dtype)
+ out = torch.empty(shape, dtype=dtype, device=A.device)
+ _dequantize_4bit_impl(A, absmax, blocksize, quant_type, dtype, out=out)
+ return out
@register_kernel("bitsandbytes::dequantize_blockwise", "xpu")
+ def _(
+ A: torch.Tensor, absmax: torch.Tensor, code: torch.Tensor, blocksize: int, dtype: torch.dtype
+ ) -> torch.Tensor:
+ out = torch.empty_like(A, dtype=dtype)
+ _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
+ return out
+
+ @register_kernel("bitsandbytes::dequantize_blockwise.out", "xpu")
def _(
A: torch.Tensor,
absmax: torch.Tensor,
code: torch.Tensor,
blocksize: int,
dtype: torch.dtype,
+ out: torch.Tensor,
+ ) -> None:
+ torch._check(out.dtype == dtype, lambda: f"Expected out.dtype == {dtype}, got {out.dtype}")
+ torch._check(out.shape == A.shape, lambda: f"Expected out.shape == {A.shape}, got {out.shape}")
+ _dequantize_blockwise_impl(A, absmax, code, blocksize, dtype, out=out)
+
+ @register_kernel("bitsandbytes::gemv_4bit", "xpu")
+ def _(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ shapeB: Sequence[int],
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
) -> torch.Tensor:
- shape = A.shape
- out = torch.empty(A.reshape(-1).shape, dtype=dtype, device=A.device)
- # void cdequantize_blockwise_fp32(
- # float *code, unsigned char *A, float *absmax, float *out, int blocksize, const int n, cudaStream_t stream)
- if dtype == torch.float16:
- ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp16(code, A, absmax, out, blocksize, A.numel())
- elif dtype == torch.bfloat16:
- ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_bf16(code, A, absmax, out, blocksize, A.numel())
- elif dtype == torch.float32:
- ipex_xpu.xpu.bitsandbytes.cdequantize_blockwise_fp32(code, A, absmax, out, blocksize, A.numel())
- else:
- raise ValueError(f"Blockwise quantization only supports 16/32-bit floats, but got {out.dtype}")
+ shape = (*A.shape[:-1], shapeB[0])
+ out = torch.empty(shape, device=A.device, dtype=A.dtype)
+ _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
+ return out
- return out.reshape(shape)
+ @register_kernel("bitsandbytes::gemv_4bit.out", "xpu")
+ def _(
+ A: torch.Tensor,
+ B: torch.Tensor,
+ shapeB: Sequence[int],
+ absmax: torch.Tensor,
+ code: torch.Tensor,
+ blocksize: int,
+ out: torch.Tensor,
+ ) -> None:
+ torch._check(
+ out.shape == (*A.shape[:-1], shapeB[0]),
+ lambda: f"Expected out.shape == {(*A.shape[:-1], shapeB[0])}, got {out.shape}",
+ )
+ torch._check(out.dtype == A.dtype, lambda: f"Expected out.dtype == {A.dtype}, got {out.dtype}")
+ _gemv_4bit_impl(A, B, shapeB, absmax, code, blocksize, out=out)
elif triton_available:
+ logger.info("Register triton bitsandbytes kernels for XPU")
from ..triton import ops as triton_ops
register_kernel("bitsandbytes::quantize_blockwise", "xpu")(triton_ops.quantize_blockwise)
@@ -63,5 +236,7 @@ def _(
register_kernel("bitsandbytes::dequantize_4bit.out", "xpu")(triton_ops.dequantize_4bit_inplace)
register_kernel("bitsandbytes::dequantize_4bit", "xpu")(triton_ops.dequantize_4bit)
register_kernel("bitsandbytes::gemv_4bit", "xpu")(triton_ops.gemv_4bit)
+ register_kernel("bitsandbytes::optimizer_update_8bit_blockwise", "xpu")(triton_ops.optimizer_update_8bit_blockwise)
+ register_kernel("bitsandbytes::optimizer_update_32bit", "xpu")(triton_ops.optimizer_update_32bit)
else:
- warnings.warn("XPU available but no ipex or triton packages found.")
+ logger.warning("Register pytorch bitsandbytes kernels for XPU because no native library or triton packages found.")
diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py
index bb301e712..2eb584a66 100644
--- a/bitsandbytes/cextension.py
+++ b/bitsandbytes/cextension.py
@@ -283,6 +283,9 @@ def get_native_library() -> BNBNativeLibrary:
binary_path = cuda_binary_path
+ if torch._C._has_xpu:
+ binary_path = PACKAGE_DIR / f"libbitsandbytes_xpu{DYNAMIC_LIBRARY_SUFFIX}"
+
logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
# Try to load the library - any errors will propagate up
@@ -291,39 +294,32 @@ def get_native_library() -> BNBNativeLibrary:
if hasattr(dll, "get_context"): # only a CUDA-built library exposes this
return CudaBNBNativeLibrary(dll)
- logger.warning(
- "The installed version of bitsandbytes was compiled without GPU support. "
- "8-bit optimizers and GPU quantization are unavailable."
- )
return BNBNativeLibrary(dll)
ROCM_GPU_ARCH = get_rocm_gpu_arch()
-try:
- # to support Intel CPU/GPU (XPU) backend
- import intel_extension_for_pytorch as ipex
-
- ipex_cpu = ipex if ipex._C._has_cpu() else None
- ipex_xpu = ipex if ipex._C._has_xpu() else None
-except BaseException:
- ipex_cpu = None
- ipex_xpu = None
+HIP_ENVIRONMENT = False
+BNB_BACKEND = "CPU"
+if torch.version.hip:
+ HIP_ENVIRONMENT = True
+ BNB_BACKEND = "ROCm"
+elif torch.cuda.is_available():
+ BNB_BACKEND = "CUDA"
+elif torch._C._has_xpu:
+ BNB_BACKEND = "XPU"
try:
- if torch.version.hip:
- HIP_ENVIRONMENT, BNB_BACKEND = True, "ROCm"
- else:
- HIP_ENVIRONMENT, BNB_BACKEND = False, "CUDA"
-
lib = get_native_library()
except Exception as e:
- error_msg = str(e)
- if not (ipex_cpu or ipex_xpu):
+ if BNB_BACKEND in ("CPU", "XPU"):
+ lib = ErrorHandlerMockBNBNativeLibrary("XPU/CPU can run without native library.")
+ else:
+ error_msg = str(e)
logger.error(
- f"bitsandbytes library load error: {error_msg}\n If you are using Intel CPU/XPU, please install intel_extension_for_pytorch to enable required ops",
+ f"bitsandbytes library load error: {error_msg}",
exc_info=True,
)
- # create a mock with error messaging as fallback
- lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
+ # create a mock with error messaging as fallback
+ lib = ErrorHandlerMockBNBNativeLibrary(error_msg)
diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py
index 9b446a2de..7cca33dcf 100644
--- a/bitsandbytes/functional.py
+++ b/bitsandbytes/functional.py
@@ -13,48 +13,13 @@
from torch import Tensor
from typing_extensions import deprecated
-from bitsandbytes.utils import _reverse_4bit_compress_format, pack_dict_to_tensor, unpack_tensor_to_dict
+from bitsandbytes.utils import pack_dict_to_tensor, unpack_tensor_to_dict
-from .cextension import HIP_ENVIRONMENT, ipex_cpu, ipex_xpu, lib
+from .cextension import HIP_ENVIRONMENT, lib
name2qmap = {}
"""C FUNCTIONS FOR OPTIMIZERS"""
-str2optimizer32bit = {
- "adam": (
- lib.cadam32bit_grad_fp32,
- lib.cadam32bit_grad_fp16,
- lib.cadam32bit_grad_bf16,
- ),
- "momentum": (
- lib.cmomentum32bit_grad_32,
- lib.cmomentum32bit_grad_16,
- ),
- "rmsprop": (
- lib.crmsprop32bit_grad_32,
- lib.crmsprop32bit_grad_16,
- ),
- "lion": (
- lib.clion32bit_grad_fp32,
- lib.clion32bit_grad_fp16,
- lib.clion32bit_grad_bf16,
- ),
- "adagrad": (
- lib.cadagrad32bit_grad_32,
- lib.cadagrad32bit_grad_16,
- ),
- "lamb": (
- lib.cadam32bit_grad_fp32,
- lib.cadam32bit_grad_fp16,
- lib.cadam32bit_grad_bf16,
- ),
- "ademamix": (
- lib.cademamix32bit_grad_fp32,
- lib.cademamix32bit_grad_fp16,
- lib.cademamix32bit_grad_bf16,
- ),
-}
-
str2optimizer8bit = {
"adam": (
lib.cadam_static_8bit_grad_32,
@@ -82,39 +47,6 @@
),
}
-str2optimizer8bit_blockwise = {
- "adam": (
- lib.cadam_8bit_blockwise_grad_fp32,
- lib.cadam_8bit_blockwise_grad_fp16,
- lib.cadam_8bit_blockwise_grad_bf16,
- ),
- "momentum": (
- lib.cmomentum_8bit_blockwise_grad_fp32,
- lib.cmomentum_8bit_blockwise_grad_fp16,
- lib.cmomentum_8bit_blockwise_grad_bf16,
- ),
- "rmsprop": (
- lib.crmsprop_8bit_blockwise_grad_fp32,
- lib.crmsprop_8bit_blockwise_grad_fp16,
- lib.crmsprop_8bit_blockwise_grad_bf16,
- ),
- "lion": (
- lib.clion_8bit_blockwise_grad_fp32,
- lib.clion_8bit_blockwise_grad_fp16,
- lib.clion_8bit_blockwise_grad_bf16,
- ),
- "adagrad": (
- lib.cadagrad_8bit_blockwise_grad_fp32,
- lib.cadagrad_8bit_blockwise_grad_fp16,
- lib.cadagrad_8bit_blockwise_grad_bf16,
- ),
- "ademamix": (
- lib.cademamix_8bit_blockwise_grad_fp32,
- lib.cademamix_8bit_blockwise_grad_fp16,
- lib.cademamix_8bit_blockwise_grad_bf16,
- ),
-}
-
class GlobalPageManager:
_instance = None
@@ -310,7 +242,6 @@ def create_fp8_map(signed=True, exponent_bits=5, precision_bits=2, total_bits=8)
assert e + p == total_bits - has_sign
# the exponent is biased to 2^(e-1) -1 == 0
evalues = []
- pvalues = []
for i, val in enumerate(range(-(2 ** (exponent_bits - has_sign)), 2 ** (exponent_bits - has_sign), 1)):
evalues.append(2**val)
@@ -422,8 +353,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
for t in tensors:
# NULL pointers and paged tensors are OK.
if t is not None and not getattr(t, "is_paged", False):
- on_gpu &= t.is_cuda
- gpu_ids.add(t.device.index)
+ on_gpu &= t.device.type != "cpu"
+ gpu_ids.add((t.device.type, t.device.index))
if not on_gpu:
raise RuntimeError(
@@ -439,6 +370,8 @@ def is_on_gpu(tensors: Iterable[Optional[torch.Tensor]]):
def _get_tensor_stream(tensor: Tensor) -> ct.c_void_p:
# We use the raw stream for performance reasons.
+ if tensor.device.type == "xpu":
+ return ct.c_void_p(torch._C._xpu_getCurrentRawStream(tensor.device.index))
return ct.c_void_p(torch._C._cuda_getCurrentRawStream(tensor.device.index))
@@ -1053,16 +986,6 @@ def dequantize_4bit(
if absmax.dtype != torch.float32:
absmax = absmax.float()
- # IPEX format is different, we need extra process.
- if getattr(quant_state, "ipex", False) and quant_state.quant_type == "nf4":
- return torch.ops.bitsandbytes.dequantize_nf4_ipex(
- A,
- absmax,
- quant_state.blocksize,
- quant_state.shape,
- quant_state.dtype,
- )
-
if out is not None:
torch.ops.bitsandbytes.dequantize_4bit.out(
A, absmax, quant_state.blocksize, quant_state.quant_type, quant_state.shape, quant_state.dtype, out=out
@@ -1252,41 +1175,27 @@ def optimizer_update_32bit(
if max_unorm > 0.0:
param_norm = torch.norm(p.data.float())
- optim_func = None
- if g.dtype == torch.float32:
- optim_func = str2optimizer32bit[optimizer_name][0]
- elif g.dtype == torch.float16:
- optim_func = str2optimizer32bit[optimizer_name][1]
- elif g.dtype == torch.bfloat16 and len(str2optimizer32bit[optimizer_name]) == 3:
- optim_func = str2optimizer32bit[optimizer_name][2]
- else:
- raise ValueError(
- f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
- )
-
is_on_gpu([g, p, state1, state2, unorm_vec])
-
- with _cuda_device_of(g):
- optim_func(
- get_ptr(g),
- get_ptr(p),
- get_ptr(state1),
- get_ptr(state2),
- get_ptr(unorm_vec),
- ct.c_float(max_unorm),
- ct.c_float(param_norm),
- ct.c_float(beta1),
- ct.c_float(beta2),
- ct.c_float(beta3),
- ct.c_float(alpha),
- ct.c_float(eps),
- ct.c_float(weight_decay),
- ct.c_int32(step),
- ct.c_float(lr),
- ct.c_float(gnorm_scale),
- ct.c_bool(skip_zeros),
- ct.c_int32(g.numel()),
- )
+ torch.ops.bitsandbytes.optimizer_update_32bit(
+ optimizer_name,
+ g,
+ p,
+ state1,
+ state2,
+ unorm_vec,
+ max_unorm,
+ param_norm,
+ beta1,
+ beta2,
+ beta3,
+ alpha,
+ eps,
+ weight_decay,
+ step,
+ lr,
+ gnorm_scale,
+ skip_zeros,
+ )
@deprecated(
@@ -1447,47 +1356,29 @@ def optimizer_update_8bit_blockwise(
gnorm_scale: float = 1.0,
skip_zeros=False,
) -> None:
- optim_func = None
-
- if g.dtype == torch.float32 and state1.dtype == torch.uint8:
- optim_func = str2optimizer8bit_blockwise[optimizer_name][0]
- elif g.dtype == torch.float16 and state1.dtype == torch.uint8:
- optim_func = str2optimizer8bit_blockwise[optimizer_name][1]
- elif (
- g.dtype == torch.bfloat16
- and state1.dtype == torch.uint8
- and len(str2optimizer8bit_blockwise[optimizer_name]) == 3
- ):
- optim_func = str2optimizer8bit_blockwise[optimizer_name][2]
- else:
- raise ValueError(
- f"Gradient+optimizer bit data type combination not supported: grad {g.dtype}, optimizer {state1.dtype}",
- )
-
is_on_gpu([p, g, state1, state2, qmap1, qmap2, absmax1, absmax2])
- with _cuda_device_of(g):
- optim_func(
- get_ptr(p),
- get_ptr(g),
- get_ptr(state1),
- get_ptr(state2),
- ct.c_float(beta1),
- ct.c_float(beta2),
- ct.c_float(beta3),
- ct.c_float(alpha),
- ct.c_float(eps),
- ct.c_int32(step),
- ct.c_float(lr),
- get_ptr(qmap1),
- get_ptr(qmap2),
- get_ptr(absmax1),
- get_ptr(absmax2),
- ct.c_float(weight_decay),
- ct.c_float(gnorm_scale),
- ct.c_bool(skip_zeros),
- ct.c_int32(g.numel()),
- )
+ torch.ops.bitsandbytes.optimizer_update_8bit_blockwise(
+ optimizer_name,
+ g,
+ p,
+ state1,
+ state2,
+ beta1,
+ beta2,
+ beta3,
+ alpha,
+ eps,
+ step,
+ lr,
+ qmap1,
+ qmap2,
+ absmax1,
+ absmax2,
+ weight_decay,
+ gnorm_scale,
+ skip_zeros,
+ )
@deprecated("This function is deprecated and will be removed in a future release.", category=FutureWarning)
@@ -1631,25 +1522,6 @@ def gemv_4bit(
if state.nested:
absmax = dequantize_blockwise(absmax, state.state2) + state.offset
- if getattr(state, "ipex", False) and state.quant_type == "nf4":
- # compute_dtype: 1 indicates fp16, 2 indicates bf16
- compute_dtype = 2 if A.dtype == torch.bfloat16 else 1
- out = torch.ops.torch_ipex.woq_linear(
- A,
- B,
- "nf4",
- state.shape,
- state.new_scales,
- state.new_zeros,
- None,
- None,
- state.blocksize,
- compute_dtype,
- 1,
- state.compensation,
- )
- return out
-
if out is not None:
torch.ops.bitsandbytes.gemv_4bit.out(
A,
@@ -2214,7 +2086,7 @@ def spmm_coo(
assert cooA.values.numel() == nnz
assert cooA.cols == B.shape[0]
- transposed_B = False if B.is_contiguous() else True
+ transposed_B = not B.is_contiguous()
ldb = B.stride()[(1 if transposed_B else 0)]
ldc = B.shape[1]
@@ -2263,12 +2135,7 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
assert cooA.values.numel() == nnz
assert cooA.cols == B.shape[0], f"{cooA.cols} vs {B.shape}"
- transposed_B = False if B.is_contiguous() else True
-
- ldb = B.stride()[(1 if transposed_B else 0)]
- ldc = B.shape[1]
-
- values, counts = torch.unique(cooA.rowidx, return_counts=True)
+ _, counts = torch.unique(cooA.rowidx, return_counts=True)
offset = counts.cumsum(0).int()
max_count, max_idx = torch.sort(counts, descending=True)
max_idx = max_idx.int()
@@ -2288,11 +2155,8 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
cnnz_rows = ct.c_int32(counts.numel())
cnnz = ct.c_int32(cooA.nnz)
crowsA = ct.c_int32(cooA.rows)
- ccolsA = ct.c_int32(cooA.cols)
crowsB = ct.c_int32(B.shape[1])
ccolsB = ct.c_int32(B.shape[1])
- cldb = ct.c_int32(ldb)
- cldc = ct.c_int32(ldc)
with _cuda_device_of(B):
is_on_gpu([cooA.rowidx, cooA.colidx, cooA.values, B, out, dequant_stats])
@@ -2336,49 +2200,3 @@ def spmm_coo_very_sparse(cooA, B, dequant_stats=None, out=None):
C = 127.0
-
-
-def _enable_ipex_fusion(linear: torch.nn.Module, x: torch.Tensor):
- quant_state = linear.weight.quant_state
-
- if quant_state.nested:
- absmax = dequantize_blockwise(quant_state.absmax, quant_state.state2)
- absmax += quant_state.offset
- if absmax.dtype != torch.float32:
- absmax = absmax.float()
-
- quant_state.absmax = absmax
- quant_state.nested = False
- delattr(quant_state, "state2")
-
- if x.device.type == "cpu" and ipex_cpu:
- converted_weight = _reverse_4bit_compress_format(linear.weight.data)
- new_weight, new_scales, new_zeros, _, compensation = torch.ops.ipex_prepack.woq_linear_pack_weight(
- converted_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2]),
- "nf4",
- quant_state.shape, # weight shape
- quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize), # scales
- None, # zero_points
- None, # bias
- None, # batch_size
- quant_state.blocksize,
- 2,
- )
- elif x.device.type == "xpu" and ipex_xpu:
- new_weight = _reverse_4bit_compress_format(linear.weight.data)
- new_scales = quant_state.absmax.view(quant_state.shape[0], quant_state.shape[1] // quant_state.blocksize)
- new_zeros = None
- compensation = None
- new_scales = list(new_scales)
- if not linear.training and not x.requires_grad:
- new_weight = new_weight.reshape([quant_state.shape[0], quant_state.shape[1] // 2])
- else:
- raise ValueError(
- "Please check the device and ipex version. The device should be cpu or xpu while ipex version should >= 2.7"
- )
-
- linear.weight.data = new_weight.data
- linear.weight.quant_state.ipex = True
- linear.weight.quant_state.new_scales = new_scales
- linear.weight.quant_state.new_zeros = new_zeros
- linear.weight.quant_state.compensation = compensation
diff --git a/bitsandbytes/nn/modules.py b/bitsandbytes/nn/modules.py
index ba134f52a..1adf75e79 100644
--- a/bitsandbytes/nn/modules.py
+++ b/bitsandbytes/nn/modules.py
@@ -12,13 +12,9 @@
import bitsandbytes as bnb
from bitsandbytes.cextension import HIP_ENVIRONMENT
-from bitsandbytes.functional import QuantState, _enable_ipex_fusion, ipex_cpu, ipex_xpu
+from bitsandbytes.functional import QuantState
from bitsandbytes.optim import GlobalOptimManager
-from bitsandbytes.utils import (
- INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING,
- OutlierTracer,
- _reverse_4bit_compress_format,
-)
+from bitsandbytes.utils import INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING, OutlierTracer
T = TypeVar("T", bound="torch.nn.Module")
@@ -356,6 +352,46 @@ def to(self, *args, **kwargs):
return new_param
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ if func in [torch.chunk, torch.split]:
+ tensor = args[0]
+
+ result = super().__torch_function__(func, types, args, kwargs)
+
+ if isinstance(result, tuple):
+ return tuple(
+ cls(
+ data=chunk,
+ requires_grad=tensor.requires_grad,
+ quant_state=tensor.quant_state,
+ blocksize=tensor.blocksize,
+ compress_statistics=tensor.compress_statistics,
+ quant_type=tensor.quant_type,
+ quant_storage=tensor.quant_storage,
+ module=tensor.module,
+ bnb_quantized=tensor.bnb_quantized,
+ )
+ for chunk in result
+ )
+ else:
+ return cls(
+ data=result,
+ requires_grad=tensor.requires_grad,
+ quant_state=tensor.quant_state,
+ blocksize=tensor.blocksize,
+ compress_statistics=tensor.compress_statistics,
+ quant_type=tensor.quant_type,
+ quant_storage=tensor.quant_storage,
+ module=tensor.module,
+ bnb_quantized=tensor.bnb_quantized,
+ )
+
+ return super().__torch_function__(func, types, args, kwargs)
+
def fix_4bit_weight_quant_state_from_module(module: Union["Embedding4bit", "Linear4bit"]):
if getattr(module.weight, "quant_state", None) is not None:
@@ -440,10 +476,9 @@ def __init__(
)
# self.persistent_buffers = [] # TODO consider as way to save quant state
self.compute_dtype = compute_dtype
- self.compute_type_is_set = False if compute_dtype is None else True
+ self.compute_type_is_set = compute_dtype is not None
self.quant_state = None
self.quant_storage = quant_storage
- self.ipex_linear_is_set = False
def set_compute_type(self, x):
if x.dtype in [torch.float32, torch.bfloat16]:
@@ -470,40 +505,13 @@ def _save_to_state_dict(self, destination, prefix, keep_vars):
save weight and bias,
then fill state_dict with components of quant_state
"""
- if getattr(self.weight, "quant_state", None) is not None and getattr(self.weight.quant_state, "ipex", False):
- if self.weight.device.type == "cpu":
- original_weight = torch.ops.ipex_prepack.woq_linear_unpack_weight(
- self.weight, "nf4", self.weight.quant_state.shape, 2
- )
- self.weight.data = _reverse_4bit_compress_format(original_weight.data)
- elif self.weight.device.type == "xpu":
- self.weight.data = _reverse_4bit_compress_format(self.weight.data.reshape(1, -1))
-
- self.weight.quant_state.ipex = False
- self.ipex_linear_is_set = False
-
super()._save_to_state_dict(destination, prefix, keep_vars) # saving weight and bias
if getattr(self.weight, "quant_state", None) is not None:
for k, v in self.weight.quant_state.as_dict(packed=True).items():
destination[prefix + "weight." + k] = v if keep_vars else v.detach()
- def set_ipex_linear(self, x: torch.Tensor):
- if (
- not getattr(self.weight.quant_state, "ipex", False)
- and self.weight.data.dtype == torch.uint8
- and self.weight.quant_state.shape[1] % self.weight.quant_state.blocksize == 0
- and self.weight.quant_state.quant_type == "nf4"
- ):
- if x.device.type == "xpu" or (x.device.type == "cpu" and not self.training and x.requires_grad == False):
- _enable_ipex_fusion(self, x)
-
def forward(self, x: torch.Tensor):
- # Check if ipex fusion can be used
- if not self.ipex_linear_is_set and (ipex_cpu or ipex_xpu):
- self.set_ipex_linear(x)
- self.ipex_linear_is_set = True
-
fix_4bit_weight_quant_state_from_module(self)
# weights are cast automatically as Int8Params, but the bias has to be cast manually
@@ -519,8 +527,7 @@ def forward(self, x: torch.Tensor):
x = x.to(self.compute_dtype)
bias = None if self.bias is None else self.bias.to(self.compute_dtype)
- # IPEX CPU will change weight to 4D so don't need transpose
- weight = self.weight.t() if self.weight.dim() == 2 else self.weight
+ weight = self.weight.t()
return bnb.matmul_4bit(x, weight, bias=bias, quant_state=self.weight.quant_state).to(inp_dtype)
@@ -675,7 +682,7 @@ def to(self, *args, **kwargs):
if device is not None and device.type != "meta" and self.data.device.type == "cpu":
if device.type != "cpu" or self.data.dtype != torch.int8:
return self._quantize(device)
- elif self.data.dtype == torch.int8 and device.type in ("cpu", "xpu") and (ipex_cpu or ipex_xpu):
+ elif self.data.dtype == torch.int8 and device.type == "cpu":
self.CB = self.data
new_param = Int8Params(
@@ -1110,4 +1117,4 @@ def forward(self, x):
if self.weight.CB is not None:
self.init_8bit_state()
- out = bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
+ return bnb.matmul_mixed(x.half(), self.weight.half(), bias=None, state=self.state) + self.bias
diff --git a/bitsandbytes/nn/parametrize.py b/bitsandbytes/nn/parametrize.py
new file mode 100644
index 000000000..4a956c7fa
--- /dev/null
+++ b/bitsandbytes/nn/parametrize.py
@@ -0,0 +1,192 @@
+from functools import partial
+from typing import Any, Literal, Optional
+
+import torch
+import torch.nn as nn
+import torch.nn.utils.parametrize as P
+
+from .. import functional as F
+
+
+class Bnb4bitParametrization(nn.Module):
+ """
+ A parametrization module that handles dequantization of a 4-bit quantized parameter.
+
+ The parameter data is expected to be already quantized when this parametrization is applied.
+ This module will dequantize the parameter data to its original floating-point representation
+ when the forward method is called (i.e. when the parameter is accessed).
+
+ Args:
+ quant_state (`F.QuantState`):
+ The quantization state containing the necessary information for dequantization.
+ """
+
+ def __init__(self, quant_state: F.QuantState):
+ super().__init__()
+ self.quant_state = quant_state
+
+ @torch.no_grad()
+ def forward(self, quantized_param: torch.Tensor) -> torch.Tensor:
+ """
+ Forward pass to dequantize the parameter.
+
+ Args:
+ quantized_param (`torch.Tensor`): The quantized parameter tensor (from .original)
+
+ Returns:
+ `torch.Tensor`: The dequantized parameter tensor in the original shape and dtype.
+ """
+ return F.dequantize_4bit(quantized_param, self.quant_state)
+
+
+def replace_parameter_4bit_prequantized(
+ module: nn.Module, param_name: str, qs_dict: dict[str, Any], device: torch.device
+):
+ if not hasattr(module, param_name):
+ raise AttributeError(f"Module does not have parameter '{param_name}'")
+
+ original_param = getattr(module, param_name)
+
+ if not isinstance(original_param, nn.Parameter):
+ raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")
+
+ quant_state = F.QuantState.from_dict(qs_dict, device=device)
+
+ # Apply a parametrization to the module to handle dequantization.
+ P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
+
+ # Next, register hooks.
+ _register_parametrization_hooks(module, param_name)
+
+
+def replace_parameter_4bit(
+ module: nn.Module,
+ param_name: str,
+ compress_statistics: bool = False,
+ quant_type: Literal["nf4", "fp4"] = "nf4",
+ blocksize: Optional[int] = None,
+):
+ """
+ Replace a module parameter with a 4-bit quantized version using parametrization.
+
+ This function quantizes an existing parameter in a PyTorch module to 4-bit precision
+ and sets up parametrization to handle automatic dequantization during forward passes.
+ The original parameter is replaced with quantized data, and a parametrization layer
+ is registered to manage the quantization state and dequantization process.
+
+ Additional, it registers a state dict post-hook to ensure that the quantization state
+ is saved correctly when the model's state dict is saved.
+
+ It is useful for MoE models or other scenarios where you want to quantize parameters
+ outside of nn.Linear layers without changing the model's architecture.
+
+ This feature is experimental and may change in future releases.
+
+ Args:
+ module (`nn.Module`):
+ The PyTorch module containing the parameter to be quantized.
+ param_name (`str`):
+ The name of the parameter within the module to quantize.
+ compress_statistics (`bool`, *optional*, defaults to `False`):
+ Whether to compress quantization statistics to reduce memory usage.
+ quant_type (`Literal["nf4", "fp4"]`, *optional*, defaults to `"nf4"`):
+ The quantization format to use.
+ blocksize (`int`, *optional*, defaults to `None`):
+ The block size for quantization. If None, uses the default block size.
+
+ Raises:
+ AttributeError: If the module does not have the specified parameter.
+ TypeError: If the specified attribute is not an instance of nn.Parameter.
+ """
+
+ if not hasattr(module, param_name):
+ raise AttributeError(f"Module does not have parameter '{param_name}'")
+
+ original_param = getattr(module, param_name)
+
+ if not isinstance(original_param, nn.Parameter):
+ raise TypeError(f"Parameter '{param_name}' is not an instance of nn.Parameter")
+
+ # Quantize the original parameter.
+ quantized_data, quant_state = F.quantize_4bit(
+ original_param.data,
+ blocksize=blocksize,
+ compress_statistics=compress_statistics,
+ quant_type=quant_type,
+ )
+
+ # Replace the parameter with the quantized data.
+ setattr(module, param_name, nn.Parameter(quantized_data, requires_grad=False))
+ del original_param
+
+ # Apply a parametrization to the module to handle dequantization.
+ P.register_parametrization(module, param_name, Bnb4bitParametrization(quant_state), unsafe=True)
+
+ # Next, register hooks.
+ _register_parametrization_hooks(module, param_name)
+
+
+def _disable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...], output: Any):
+ P._cache_enabled -= 1
+ if not P._cache_enabled:
+ P._cache = {}
+
+
+def _enable_parametrization_cache(module: nn.Module, inputs: tuple[Any, ...]):
+ P._cache_enabled += 1
+
+
+def _register_parametrization_hooks(module: nn.Module, param_name: str):
+ # Register a state dict hook for saving. Note that this requires torch >= 2.5.0.
+ if torch.__version__ >= (2, 5):
+ module.register_state_dict_post_hook(
+ partial(
+ _parametrized_state_dict_post_hook,
+ param_name=param_name,
+ )
+ )
+
+ # Register hooks to enable caching for the dequantization parametrization.
+ # This helps preserve time and memory when the same quantized parameter
+ # is accessed multiple times in the forward computation.
+ module.register_forward_pre_hook(_enable_parametrization_cache)
+ module.register_forward_hook(_disable_parametrization_cache)
+
+
+def _parametrized_state_dict_post_hook(
+ module: nn.Module,
+ state_dict: dict[str, Any],
+ prefix: str,
+ local_metadata: Any,
+ *,
+ param_name: str = "weight",
+ **kwargs: dict[str, Any],
+) -> None:
+ """
+ Hook to modify the state dict to include the quantization state.
+ """
+
+ original_key = f"{prefix}parametrizations.{param_name}.original"
+
+ if original_key in state_dict:
+ # Create a clean entry.
+ # The `parametrizations.{param_name}.original` key will have the quantized data,
+ # but we would like it to keep it in the state_dict as `{param_name}`.
+ clean_key = f"{prefix}{param_name}"
+ state_dict[clean_key] = state_dict.pop(original_key)
+
+ assert P.is_parametrized(module, param_name)
+
+ # Find the parametrization, which should have the quantization state.
+ parametrization: Bnb4bitParametrization = next(
+ filter(lambda x: isinstance(x, Bnb4bitParametrization), module.parametrizations[param_name]), None
+ )
+
+ assert parametrization is not None, "Parametrization not found for the parameter."
+
+ quant_state = parametrization.quant_state
+
+ # Next, we need to store the quantization state.
+ if quant_state is not None:
+ for k, v in quant_state.as_dict(packed=True).items():
+ state_dict[f"{prefix}{param_name}.{k}"] = v
diff --git a/bitsandbytes/optim/adamw.py b/bitsandbytes/optim/adamw.py
index a32394bd5..5f225c9ad 100644
--- a/bitsandbytes/optim/adamw.py
+++ b/bitsandbytes/optim/adamw.py
@@ -26,7 +26,7 @@ def __init__(
Base AdamW optimizer.
Arguments:
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -87,7 +87,7 @@ def __init__(
8-bit AdamW optimizer.
Arguments:
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -159,7 +159,7 @@ def __init__(
32-bit AdamW optimizer.
Arguments:
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -219,7 +219,7 @@ def __init__(
Paged AdamW optimizer.
Arguments:
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -241,8 +241,6 @@ def __init__(
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
- is_paged (`bool`, defaults to `False`):
- Whether the optimizer is a paged optimizer or not.
"""
super().__init__(
"adam",
@@ -279,7 +277,7 @@ def __init__(
Paged 8-bit AdamW optimizer.
Arguments:
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -303,8 +301,6 @@ def __init__(
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
- is_paged (`bool`, defaults to `False`):
- Whether the optimizer is a paged optimizer or not.
"""
# Validate unsupported parameters
if amsgrad:
@@ -350,7 +346,7 @@ def __init__(
Paged 32-bit AdamW optimizer.
Arguments:
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -372,8 +368,6 @@ def __init__(
Adapts clipping threshold automatically by tracking the last 100 gradient norms and clipping the gradient at a certain percentile to improve stability.
block_wise (`bool`, defaults to `True`):
Whether to independently quantize each block of tensors to reduce outlier effects and improve stability.
- is_paged (`bool`, defaults to `False`):
- Whether the optimizer is a paged optimizer or not.
"""
super().__init__(
"adam",
diff --git a/bitsandbytes/optim/lars.py b/bitsandbytes/optim/lars.py
index 90c3686fe..fa2af57bc 100644
--- a/bitsandbytes/optim/lars.py
+++ b/bitsandbytes/optim/lars.py
@@ -231,9 +231,6 @@ def step(self, closure=None):
loss = closure()
for group in self.param_groups:
- params_with_grad = []
- d_p_list = []
- momentum_buffer_list = []
weight_decay = group["weight_decay"]
momentum = group["momentum"]
dampening = group["dampening"]
diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py
index 9c20f9376..ea3ff32c9 100644
--- a/bitsandbytes/optim/optimizer.py
+++ b/bitsandbytes/optim/optimizer.py
@@ -10,6 +10,7 @@
import torch
import bitsandbytes.functional as F
+from bitsandbytes.utils import sync_gpu
class MockArgs:
@@ -64,9 +65,9 @@ def override_config(self, parameters, key=None, value=None, key_value_dict=None)
parameters (`torch.Tensor` or `list(torch.Tensors)`):
The input parameters.
key (`str`):
- The hyperparamter to override.
+ The hyperparameter to override.
value:
- The hyperparameter values.
+ The hyperparameter value.
key_value_dict (`dict`):
A dictionary with multiple key-values to override.
@@ -115,7 +116,7 @@ def __init__(self, params, defaults, optim_bits=32, is_paged=False):
Base 8-bit optimizer class.
Arguments:
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
optim_bits (`int`, defaults to 32):
The number of bits of the optimizer state.
@@ -271,14 +272,13 @@ def step(self, closure=None):
with torch.enable_grad():
loss = closure()
- overflows = []
-
if not self.initialized:
self.check_overrides()
self.to_gpu() # needed for fairseq pure fp16 training
self.initialized = True
# if self.is_paged: self.page_mng.prefetch_all()
+ p = None
for gindex, group in enumerate(self.param_groups):
for pindex, p in enumerate(group["params"]):
if p.grad is None:
@@ -289,11 +289,11 @@ def step(self, closure=None):
self.prefetch_state(p)
self.update_step(group, p, gindex, pindex)
- torch.cuda.synchronize()
- if self.is_paged:
- # all paged operation are asynchronous, we need
+ sync_gpu(p)
+ if self.is_paged and p is not None:
+ # all paged operations are asynchronous, we need
# to sync to make sure all tensors are in the right state
- torch.cuda.synchronize()
+ sync_gpu(p)
return loss
@@ -371,7 +371,7 @@ def __init__(
Arguments:
optimizer_name (`str`):
The name of the optimizer.
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -428,7 +428,6 @@ def __init__(
if args is None:
args = {}
args["optim_bits"] = optim_bits
- args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
@@ -613,7 +612,7 @@ def __init__(
Arguments:
optimizer_name (`str`):
The name of the optimizer.
- params (`torch.tensor`):
+ params (`torch.Tensor`):
The input parameters to optimize.
lr (`float`, defaults to 1e-3):
The learning rate.
@@ -655,7 +654,6 @@ def __init__(
if args is None:
args = {}
args["optim_bits"] = optim_bits
- args["percentile_clipping"] = 100
args["min_8bit_size"] = min_8bit_size
args["percentile_clipping"] = percentile_clipping
args["block_wise"] = block_wise
diff --git a/bitsandbytes/py.typed b/bitsandbytes/py.typed
new file mode 100644
index 000000000..e69de29bb
diff --git a/bitsandbytes/research/autograd/_functions.py b/bitsandbytes/research/autograd/_functions.py
index d9718382b..9c7afc354 100644
--- a/bitsandbytes/research/autograd/_functions.py
+++ b/bitsandbytes/research/autograd/_functions.py
@@ -235,7 +235,7 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non
# 2. Quantize B
if state.has_fp16_weights:
# print('B shape', B.shape)
- has_grad = True if (getattr(B, "grad", None) is not None) else False
+ has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
B = B.contiguous()
diff --git a/bitsandbytes/triton/triton_utils.py b/bitsandbytes/triton/triton_utils.py
index b706ff1ba..f6bedd8cd 100644
--- a/bitsandbytes/triton/triton_utils.py
+++ b/bitsandbytes/triton/triton_utils.py
@@ -4,11 +4,8 @@
@functools.lru_cache(None)
def is_triton_available():
try:
- # torch>=2.2.0
from torch.utils._triton import has_triton, has_triton_package
return has_triton_package() and has_triton()
- except ImportError:
- from torch._inductor.utils import has_triton
-
- return has_triton()
+ except Exception:
+ return False
diff --git a/bitsandbytes/utils.py b/bitsandbytes/utils.py
index 7920e2188..1af07710c 100644
--- a/bitsandbytes/utils.py
+++ b/bitsandbytes/utils.py
@@ -38,14 +38,6 @@ def outlier_hook(module, input):
hook.remove()
-# convert btw standard 4-bit compression format and ipex compression format
-def _reverse_4bit_compress_format(weight: torch.Tensor):
- out_1 = (weight & 0xF0) >> 4
- out_2 = (weight & 0xF) << 4
- out = out_1 | out_2
- return out
-
-
class OutlierTracer:
_instance = None
@@ -92,11 +84,6 @@ def find_outlier_dims(weight, reduction_dim=0, zscore=4.0, topk=None, rdm=False)
if rdm:
return torch.randint(0, weight.shape[1], size=(topk,), device=weight.device).long()
- m = weight.mean(reduction_dim)
- mm = m.mean()
- mstd = m.std()
- zm = (m - mm) / mstd
-
std = weight.std(reduction_dim)
stdm = std.mean()
stdstd = std.std()
@@ -209,3 +196,10 @@ def unpack_tensor_to_dict(tensor_data):
LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {"row": 0, "col32": 1, "col_turing": 2, "col_ampere": 3}
INVERSE_LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING = {val: name for (name, val) in LINEAR_8BIT_WEIGHTS_FORMAT_MAPPING.items()}
+
+
+def sync_gpu(t: torch.Tensor):
+ if t.device.type == "cuda":
+ torch.cuda.synchronize()
+ elif t.device.type == "xpu":
+ torch.xpu.synchronize()
diff --git a/csrc/kernels.cu b/csrc/kernels.cu
index 649f2ee1f..738ae0cd1 100644
--- a/csrc/kernels.cu
+++ b/csrc/kernels.cu
@@ -21,23 +21,34 @@
#define NUM 4
#define NUM_BLOCK 4096
-__device__ static float nf4_data[16] = {
- -1.0,
- -0.6961928009986877,
- -0.5250730514526367,
- -0.39491748809814453,
- -0.28444138169288635,
- -0.18477343022823334,
- -0.09105003625154495,
- 0.0,
- 0.07958029955625534,
- 0.16093020141124725,
- 0.24611230194568634,
- 0.33791524171829224,
- 0.44070982933044434,
- 0.5626170039176941,
- 0.7229568362236023,
- 1.0
+__device__ static float fp4_dequantization_lut[8] = {
+ 0.0f, // 0b000
+ 0.005208333333f, // 0b001
+ 0.66666667f, // 0b010
+ 1.0f, // 0b011
+ 0.33333333f, // 0b100
+ 0.5f, // 0b101
+ 0.16666667f, // 0b110
+ 0.25f // 0b111
+};
+
+__device__ static float nf4_dequantization_lut[16] = {
+ -1.0f, // 0b0000
+ -0.6961928009986877f, // 0b0001
+ -0.5250730514526367f, // 0b0010
+ -0.39491748809814453f, // 0b0011
+ -0.28444138169288635f, // 0b0100
+ -0.18477343022823334f, // 0b0101
+ -0.09105003625154495f, // 0b0110
+ 0.0f, // 0b0111
+ 0.07958029955625534f, // 0b1000
+ 0.16093020141124725f, // 0b1001
+ 0.24611230194568634f, // 0b1010
+ 0.33791524171829224f, // 0b1011
+ 0.44070982933044434f, // 0b1100
+ 0.5626170039176941f, // 0b1101
+ 0.7229568362236023f, // 0b1110
+ 1.0f // 0b1111
};
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
@@ -51,27 +62,9 @@ __device__ float atomicMax(float* address, float val) {
return __int_as_float(old);
}
-__device__ float dDequantizeFP4Tree(unsigned char val, float absmax) {
- float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
- if ((val & 0b0100) == 4) // 0
- if ((val & 0b0010) == 2) // 01
- if ((val & 0b0001) == 1) // 111
- return 0.25000000f * absmax * sign; // 1111
- else
- return 0.16666667f * absmax * sign; // 1110
- else if ((val & 0b0001) == 1) // 110
- return 0.50000000f * absmax * sign; // 1101
- else
- return 0.33333333f * absmax * sign; // 1100
- else if ((val & 0b0010) == 2) // 10
- if ((val & 0b0001) == 1) // 101
- return 1.00000000f * absmax * sign; // 1011
- else
- return 0.66666667f * absmax * sign; // 1010
- else if ((val & 0b0001) == 1) // 100
- return 5.208333333e-03f * absmax * sign; // 1001
- else
- return 0.00000000f * absmax * sign; // 1000
+__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
+ float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
+ return fp4_dequantization_lut[val & 0b111] * sign;
}
__device__ unsigned char dQuantizeFP4(float x) {
@@ -118,51 +111,7 @@ __device__ unsigned char dQuantizeFP4(float x) {
return 0b0000 + sign;
}
-__device__ __forceinline__ float dDequantizeNF4(unsigned char val) {
-
- // the values for this tree was generated by test_normal_map_tree
- // in the file tests/test_functional.py
- if ((val & 0b1000) == 8)
- if ((val & 0b0100) == 4) // 1
- if ((val & 0b0010) == 2) // 11
- if ((val & 0b0001) == 1) // 111
- return 1.0f;
- else
- return 0.7229568362236023f;
- else if ((val & 0b0001) == 1) // 110
- return 0.5626170039176941f;
- else
- return 0.44070982933044434f;
- else if ((val & 0b0010) == 2) // 10
- if ((val & 0b0001) == 1) // 101
- return 0.33791524171829224f;
- else
- return 0.24611230194568634f;
- else if ((val & 0b0001) == 1) // 100
- return 0.16093020141124725f;
- else
- return 0.07958029955625534f;
-
- else if ((val & 0b0100) == 4) // 0
- if ((val & 0b0010) == 2) // 01
- if ((val & 0b0001) == 1) // 011
- return 0.0f;
- else
- return -0.09105003625154495f;
- else if ((val & 0b0001) == 1) // 010
- return -0.18477343022823334f;
- else
- return -0.28444138169288635f;
- else if ((val & 0b0010) == 2) // 00
- if ((val & 0b0001) == 1) // 001
- return -0.39491748809814453f;
- else
- return -0.5250730514526367f;
- else if ((val & 0b0001) == 1) // 000
- return -0.6961928009986877f;
- else
- return -1.0f;
-}
+__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
__device__ unsigned char dQuantizeNF4(float x) {
@@ -431,7 +380,6 @@ __global__ void kQuantizeBlockwise(
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}
- unsigned char packed_4bit = 0;
switch (DATA_TYPE) {
case General8bit:
#pragma unroll NUM_PER_TH
@@ -445,17 +393,15 @@ __global__ void kQuantizeBlockwise(
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
- packed_4bit |= dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
- packed_4bit |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
- qvals[j] = packed_4bit;
+ qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+ qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH / 2; j++) {
- packed_4bit |= dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
- packed_4bit |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
- qvals[j] = packed_4bit;
+ qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+ qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
}
@@ -513,8 +459,8 @@ __global__ void
case FP4:
#pragma unroll NUM_PER_TH
for (int j = 0; j < NUM_PER_TH; j++) {
- vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
- vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
+ vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
+ vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
}
break;
case NF4:
@@ -2355,7 +2301,7 @@ __global__ void kgemm_4bit_inference(
#pragma unroll 16
for (int i = 0; i < 16; i++)
- quant_map[i] = nf4_data[i];
+ quant_map[i] = nf4_dequantization_lut[i];
//__shared__ T quant_map[16*160];
T local_A[2];
diff --git a/csrc/kernels.hip b/csrc/kernels.hip
index 58f6ed065..bef6cffa6 100644
--- a/csrc/kernels.hip
+++ b/csrc/kernels.hip
@@ -19,37 +19,42 @@
#define NUM 4
#define NUM_BLOCK 4096
-__device__ static float nf4_data[16] = {-1.0, -0.6961928009986877, -0.5250730514526367, -0.39491748809814453, -0.28444138169288635, -0.18477343022823334, -0.09105003625154495, 0.0, 0.07958029955625534, 0.16093020141124725, 0.24611230194568634, 0.33791524171829224, 0.44070982933044434, 0.5626170039176941, 0.7229568362236023, 1.0};
+__device__ static float fp4_dequantization_lut[8] = {
+ 0.0f, // 0b000
+ 0.005208333333f, // 0b001
+ 0.66666667f, // 0b010
+ 1.0f, // 0b011
+ 0.33333333f, // 0b100
+ 0.5f, // 0b101
+ 0.16666667f, // 0b110
+ 0.25f // 0b111
+};
+
+__device__ static float nf4_dequantization_lut[16] = {
+ -1.0f, // 0b0000
+ -0.6961928009986877f, // 0b0001
+ -0.5250730514526367f, // 0b0010
+ -0.39491748809814453f, // 0b0011
+ -0.28444138169288635f, // 0b0100
+ -0.18477343022823334f, // 0b0101
+ -0.09105003625154495f, // 0b0110
+ 0.0f, // 0b0111
+ 0.07958029955625534f, // 0b1000
+ 0.16093020141124725f, // 0b1001
+ 0.24611230194568634f, // 0b1010
+ 0.33791524171829224f, // 0b1011
+ 0.44070982933044434f, // 0b1100
+ 0.5626170039176941f, // 0b1101
+ 0.7229568362236023f, // 0b1110
+ 1.0f // 0b1111
+};
// source: https://stackoverflow.com/questions/17399119/how-do-i-use-atomicmax-on-floating-point-values-in-cuda
// Luckily we have atomicmax and atomicmin in ROCm
-
-__device__ float dDequantizeFP4Tree(unsigned char val, float absmax)
-{
- float sign = (val & 0b1000) == 8 ? -1.0f : 1.0f;
- if((val & 0b0100) == 4) // 0
- if((val & 0b0010) == 2) //01
- if((val & 0b0001) == 1) // 111
- return 0.25000000f*absmax*sign; // 1111
- else
- return 0.16666667f*absmax*sign; // 1110
- else
- if((val & 0b0001) == 1) // 110
- return 0.50000000f*absmax*sign; // 1101
- else
- return 0.33333333f*absmax*sign; // 1100
- else
- if((val & 0b0010) == 2) //10
- if((val & 0b0001) == 1) // 101
- return 1.00000000f*absmax*sign; // 1011
- else
- return 0.66666667f*absmax*sign; // 1010
- else
- if((val & 0b0001) == 1) // 100
- return 5.208333333e-03f*absmax*sign; // 1001
- else
- return 0.00000000f*absmax*sign; // 1000
+__device__ __forceinline__ float dDequantizeFP4Tree(unsigned char val) {
+ float sign = 1.0f - 2 * ((val & 0b1000) >> 3);
+ return fp4_dequantization_lut[val & 0b111] * sign;
}
__device__ unsigned char dQuantizeFP4(float x)
@@ -101,61 +106,7 @@ __device__ unsigned char dQuantizeFP4(float x)
return 0b0000+sign;
}
-
-__device__ __forceinline__ float dDequantizeNF4(unsigned char val)
-{
-
- // the values for this tree was generated by test_normal_map_tree
- // in the file tests/test_functional.py
- if((val & 0b1000) == 8)
- if((val & 0b0100) == 4) // 1
- if((val & 0b0010) == 2) // 11
- if((val & 0b0001) == 1) // 111
- return 1.0f;
- else
- return 0.7229568362236023f;
- else
- if((val & 0b0001) == 1) // 110
- return 0.5626170039176941f;
- else
- return 0.44070982933044434f;
- else
- if((val & 0b0010) == 2) //10
- if((val & 0b0001) == 1) // 101
- return 0.33791524171829224f;
- else
- return 0.24611230194568634f;
- else
- if((val & 0b0001) == 1) // 100
- return 0.16093020141124725f;
- else
- return 0.07958029955625534f;
-
- else
- if((val & 0b0100) == 4) // 0
- if((val & 0b0010) == 2) //01
- if((val & 0b0001) == 1) // 011
- return 0.0f;
- else
- return -0.09105003625154495f;
- else
- if((val & 0b0001) == 1) // 010
- return -0.18477343022823334f;
- else
- return -0.28444138169288635f;
- else
- if((val & 0b0010) == 2) //00
- if((val & 0b0001) == 1) // 001
- return -0.39491748809814453f;
- else
- return -0.5250730514526367f;
- else
- if((val & 0b0001) == 1) // 000
- return -0.6961928009986877f;
- else
- return -1.0f;
-
-}
+__device__ __forceinline__ float dDequantizeNF4(unsigned char val) { return nf4_dequantization_lut[val & 0x0F]; }
__device__ unsigned char dQuantizeNF4(float x)
{
@@ -456,7 +407,6 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
LoadFloat(loadf).Load(&rand[local_rand_idx], rand_vals, BLOCK_SIZE, 0);
}
- unsigned char packed_4bit = 0;
switch(DATA_TYPE)
{
case General8bit:
@@ -473,18 +423,16 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
- packed_4bit |= dQuantizeFP4(((float)vals[2*j])*local_abs_max) << 4;
- packed_4bit |= dQuantizeFP4(((float)vals[2*j+1])*local_abs_max);
- qvals[j] = packed_4bit;
+ qvals[j] = dQuantizeFP4(((float)vals[2 * j]) * local_abs_max) << 4;
+ qvals[j] |= dQuantizeFP4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
case NF4:
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH/2; j++)
{
- packed_4bit |= dQuantizeNF4(((float)vals[2*j])*local_abs_max) << 4;
- packed_4bit |= dQuantizeNF4(((float)vals[2*j+1])*local_abs_max);
- qvals[j] = packed_4bit;
+ qvals[j] = dQuantizeNF4(((float)vals[2 * j]) * local_abs_max) << 4;
+ qvals[j] |= dQuantizeNF4(((float)vals[2 * j + 1]) * local_abs_max);
}
break;
}
@@ -546,8 +494,8 @@ __global__ void kDequantizeBlockwise(float *code, unsigned char * A, float * abs
#pragma unroll NUM_PER_TH
for(int j = 0; j < NUM_PER_TH; j++)
{
- vals[j*2] = dDequantizeFP4Tree(qvals[j] >> 4, local_abs_max);
- vals[j*2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F, local_abs_max);
+ vals[j * 2] = dDequantizeFP4Tree(qvals[j] >> 4) * local_abs_max;
+ vals[j * 2 + 1] = dDequantizeFP4Tree(qvals[j] & 0x0F) * local_abs_max;
}
break;
case NF4:
@@ -2507,7 +2455,7 @@ template __global__ void kgemm_4bit_inference(int M, i
#pragma unroll 16
for(int i = 0; i < 16; i++)
- quant_map[i] = nf4_data[i];
+ quant_map[i] = nf4_dequantization_lut[i];
//__shared__ T quant_map[16*160];
T local_A[2];
diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp
index 9c4cab9cc..b5d9afc6b 100644
--- a/csrc/pythonInterface.cpp
+++ b/csrc/pythonInterface.cpp
@@ -12,6 +12,9 @@
#if BUILD_MPS
// #include
#endif
+#if BUILD_XPU
+#include
+#endif
#include
// Compatibility between HIP/CUDA APIs
@@ -308,6 +311,90 @@ void spmm_coo_very_sparse_naive_int8(
}
#endif
+#if BUILD_XPU
+
+void dequantizeBlockwise_fp16(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_fp16_fp4(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_fp16_nf4(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_fp32(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_fp32_fp4(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_fp32_nf4(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_bf16(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+) {
+ dequantizeBlockwise(code, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_bf16_fp4(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+) {
+ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream);
+}
+
+void dequantizeBlockwise_bf16_nf4(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+) {
+ dequantizeBlockwise(NULL, A, absmax, out, blocksize, n, stream);
+}
+
+void gemv_4bit_inference_fp16(
+ int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda,
+ int ldb, int ldc, int blocksize, sycl::queue* stream
+) {
+ gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
+
+void gemv_4bit_inference_bf16(
+ int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype,
+ sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
+) {
+ gemv_4bit_inference(
+ m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream
+ );
+}
+
+void gemv_4bit_inference_fp32(
+ int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
+ int ldc, int blocksize, sycl::queue* stream
+) {
+ gemv_4bit_inference(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
+
+#endif
+
extern "C" {
#if BUILD_CUDA || BUILD_HIP
void cquantize(float* code, float* A, unsigned char* out, int n) { quantize(code, A, out, n); }
@@ -658,6 +745,88 @@ void cgemm_4bit_inference_naive_fp32(
#endif
+#if BUILD_XPU
+
+void cdequantize_blockwise_fp16_fp4(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise_fp16_fp4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp16(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise_fp16(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp16_nf4(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise_fp16_nf4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp32(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise_fp32(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp32_fp4(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise_fp32_fp4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_fp32_nf4(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+) {
+ dequantizeBlockwise_fp32_nf4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_bf16(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+) {
+ dequantizeBlockwise_bf16(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_bf16_fp4(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+) {
+ dequantizeBlockwise_bf16_fp4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cdequantize_blockwise_bf16_nf4(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+) {
+ dequantizeBlockwise_bf16_nf4(code, A, absmax, out, blocksize, n, stream);
+}
+
+void cgemv_4bit_inference_fp16(
+ int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda,
+ int ldb, int ldc, int blocksize, sycl::queue* stream
+) {
+ gemv_4bit_inference_fp16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
+
+void cgemv_4bit_inference_bf16(
+ int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype,
+ sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
+) {
+ gemv_4bit_inference_bf16(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
+
+void cgemv_4bit_inference_fp32(
+ int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
+ int ldc, int blocksize, sycl::queue* stream
+) {
+ gemv_4bit_inference_fp32(m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize, stream);
+}
+
+#endif
+
void cquantize_blockwise_cpu_fp32(
float* code, float* A, float* absmax, unsigned char* out, long long blocksize, long long n
) {
diff --git a/csrc/xpu_kernels.cpp b/csrc/xpu_kernels.cpp
new file mode 100644
index 000000000..8ee8add98
--- /dev/null
+++ b/csrc/xpu_kernels.cpp
@@ -0,0 +1,281 @@
+#include "xpu_kernels.h"
+#include
+#include
+#include
+
+#include
+
+inline float dDequantizeFP4(unsigned char val) {
+ if ((val & 0b1000) == 8)
+ if ((val & 0b0100) == 4)
+ if ((val & 0b0010) == 2)
+ if ((val & 0b0001) == 1)
+ return -0.25000000f;
+ else
+ return -0.16666667f;
+ else if ((val & 0b0001) == 1)
+ return -0.50000000f;
+ else
+ return -0.33333333f;
+ else if ((val & 0b0010) == 2)
+ if ((val & 0b0001) == 1)
+ return -1.00000000f;
+ else
+ return -0.66666667f;
+ else if ((val & 0b0001) == 1)
+ return -5.208333333e-03f;
+ else
+ return 0.00000000f;
+ else if ((val & 0b0100) == 4)
+ if ((val & 0b0010) == 2)
+ if ((val & 0b0001) == 1)
+ return 0.25000000f;
+ else
+ return 0.16666667f;
+ else if ((val & 0b0001) == 1)
+ return 0.50000000f;
+ else
+ return 0.33333333f;
+ else if ((val & 0b0010) == 2)
+ if ((val & 0b0001) == 1)
+ return 1.00000000f;
+ else
+ return 0.66666667f;
+ else if ((val & 0b0001) == 1)
+ return 5.208333333e-03f;
+ else
+ return 0.00000000f;
+}
+
+inline float dDequantizeNF4(unsigned char val) {
+
+ // the values for this tree was generated by test_normal_map_tree
+ // in the file tests/test_functional.py
+ if ((val & 0b1000) == 8)
+ if ((val & 0b0100) == 4) // 1
+ if ((val & 0b0010) == 2) // 11
+ if ((val & 0b0001) == 1) // 111
+ return 1.0f; //*1111
+ else
+ return 0.7229568362236023f; //*1110
+ else if ((val & 0b0001) == 1) // 110
+ return 0.5626170039176941f; //*1101
+ else
+ return 0.44070982933044434f; //*1100
+ else if ((val & 0b0010) == 2) // 10
+ if ((val & 0b0001) == 1) // 101
+ return 0.33791524171829224f; //*1011
+ else
+ return 0.24611230194568634f; //*1010
+ else if ((val & 0b0001) == 1) // 100
+ return 0.16093020141124725f; //*1001
+ else
+ return 0.07958029955625534f; //*1000
+
+ else if ((val & 0b0100) == 4) // 0
+ if ((val & 0b0010) == 2) // 01
+ if ((val & 0b0001) == 1) // 011
+ return 0.0f; //*0111
+ else
+ return -0.09105003625154495f; //*0110
+ else if ((val & 0b0001) == 1) // 010
+ return -0.18477343022823334f; //*0101
+ else
+ return -0.28444138169288635f; //*0100
+ else if ((val & 0b0010) == 2) // 00
+ if ((val & 0b0001) == 1) // 001
+ return -0.39491748809814453f; //*0011
+ else
+ return -0.5250730514526367f; //*0010
+ else if ((val & 0b0001) == 1) // 000
+ return -0.6961928009986877f; //*0001
+ else
+ return -1.0f; //*0000
+}
+
+template
+SYCL_EXTERNAL void kDequantizeBlockwise::operator()(sycl::nd_item<1> item) const {
+ const int base_idx = item.get_group(0) * TILE_SIZE;
+ size_t local_idx = item.get_local_id(0) * NUM_PER_TH;
+ float local_abs_max = -FLT_MAX;
+ int local_load_idx = 0;
+ int local_store_idx = 0;
+
+ uint8_t qvals[NUM_PER_TH];
+ T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)];
+
+ if (DATA_TYPE > 0) {
+ local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx);
+ local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2);
+ } else {
+ local_load_idx = sycl::min(TILE_SIZE, n - base_idx);
+ local_store_idx = local_load_idx;
+ }
+
+ // Avoid expensive division by the blocksize (as blocksize will always be a
+ // power-of-2)
+ local_abs_max = absmax[(base_idx + local_idx) >> (31 - std::countl_zero(blocksize))];
+
+ if (local_idx + NUM_PER_TH < local_load_idx) {
+ reinterpret_cast(&)[NUM_PER_TH]>(qvals)[0] =
+ reinterpret_cast*>(A)[(base_idx + local_idx) / NUM_PER_TH];
+ } else {
+#pragma unroll NUM_PER_TH
+ for (int i = 0; i < NUM_PER_TH; i++) {
+ if (local_idx + i < local_load_idx) {
+ qvals[i] = A[base_idx + local_idx + i];
+ } else {
+ qvals[i] = (uint8_t)0;
+ }
+ }
+ }
+
+ switch (DATA_TYPE) {
+ case General8bit:
+#pragma unroll NUM_PER_TH
+ for (int j = 0; j < NUM_PER_TH; j++)
+ vals[j] = code[qvals[j]] * local_abs_max;
+ break;
+ case FP4:
+#pragma unroll NUM_PER_TH
+ for (int j = 0; j < NUM_PER_TH; j++) {
+ vals[j * 2] = dDequantizeFP4(qvals[j] >> 4) * local_abs_max;
+ vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max;
+ }
+ break;
+ case NF4:
+#pragma unroll NUM_PER_TH
+ for (int j = 0; j < NUM_PER_TH; j++) {
+ vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max;
+ vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max;
+ }
+ break;
+ }
+
+ const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH;
+ int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx;
+
+ if (local_dst_idx + local_dst_size < local_store_idx) {
+ reinterpret_cast*>(
+ out
+ )[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / local_dst_size] =
+ reinterpret_cast(&)[local_dst_size]>(vals)[0];
+ } else {
+#pragma unroll NUM_PER_TH
+ for (int i = 0; i < local_dst_size; i++) {
+ if (local_dst_idx + i < local_store_idx) {
+ out[((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx + i] = vals[i];
+ }
+ }
+ }
+}
+
+template
+SYCL_EXTERNAL void
+ kgemv_4bit_inference::operator()(sycl::nd_item<1> item) const {
+ size_t idx = item.get_local_id();
+ const int sg_idx = idx / SUBG_SIZE;
+ const int sg_lane = idx % SUBG_SIZE;
+ const int num_values_4bit = SUBG_SIZE;
+ const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx;
+ const int offset_B = ldb * row_B;
+ const int num_values_8bit = num_values_4bit / 2;
+ float local_C = 0.0f;
+
+ unsigned char local_B_4bit[num_values_8bit];
+ T local_B[num_values_4bit / 4];
+ T local_A[num_values_4bit / 4];
+ T local_absmax = T(0.0f);
+
+ if (idx < 16) {
+ quant_map[idx] = T(datatype[idx]);
+ }
+
+ item.barrier(sycl::access::fence_space::local_space);
+
+ for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; inner_idx += SUBG_SIZE * num_values_4bit) {
+ const int inner_idx_halved = inner_idx / 2;
+
+ // Avoid expensive division by the blocksize (as blocksize will always be a
+ // power-of-2)
+ const int absidx = ((2 * offset_B) + inner_idx) >> (31 - std::countl_zero((unsigned int)blocksize));
+ local_absmax = absmax[absidx];
+
+ if (row_B < N) {
+ if ((inner_idx_halved + num_values_8bit) < (K / 2)) {
+ reinterpret_cast(&)[num_values_8bit]>(local_B_4bit)[0] =
+ reinterpret_cast*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)];
+ } else {
+#pragma unroll
+ for (int j = 0; j < (num_values_8bit); j++)
+ if ((inner_idx_halved) + j < (K / 2))
+ local_B_4bit[j] = B[offset_B + inner_idx_halved + j];
+ else
+ local_B_4bit[j] = 0b01110111;
+ }
+ } else {
+#pragma unroll
+ for (int j = 0; j < (num_values_8bit); j++)
+ local_B_4bit[j] = 0b01110111;
+ }
+
+ for (int i = 0; i < 4; i++) {
+#pragma unroll
+ for (int k = 0; k < num_values_8bit / 4; k++) {
+ local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax;
+ local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax;
+ }
+
+ if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) {
+ if (BITS == 16) {
+ reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] =
+ reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 4) + i];
+ } else {
+ reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] =
+ reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0];
+ reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[1] =
+ reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1];
+ }
+
+ } else {
+#pragma unroll
+ for (int k = 0; k < num_values_4bit / 4; k++)
+ if (inner_idx + (i * num_values_4bit / 4) + k < K)
+ local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)];
+ else
+ local_A[k] = T(0.0f);
+ }
+
+// accumulate in float for accuracy;
+#pragma unroll
+ for (int k = 0; k < num_values_4bit / 4; k++) {
+ local_C += (float)(local_A[k] * local_B[k]);
+ }
+ }
+ }
+
+ local_C = sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>());
+
+ if (row_B < N && sg_lane == 0)
+ out[row_B] = T(local_C);
+}
+
+//==============================================================
+// TEMPLATE DEFINITIONS
+//==============================================================
+
+template class kDequantizeBlockwise;
+template class kDequantizeBlockwise;
+template class kDequantizeBlockwise;
+
+template class kDequantizeBlockwise;
+template class kDequantizeBlockwise;
+template class kDequantizeBlockwise;
+
+template class kDequantizeBlockwise;
+template class kDequantizeBlockwise;
+template class kDequantizeBlockwise;
+
+template class kgemv_4bit_inference;
+template class kgemv_4bit_inference;
+template class kgemv_4bit_inference;
diff --git a/csrc/xpu_kernels.h b/csrc/xpu_kernels.h
new file mode 100644
index 000000000..caa7e6716
--- /dev/null
+++ b/csrc/xpu_kernels.h
@@ -0,0 +1,52 @@
+#include
+#include
+
+#ifndef xpu_kernels
+#define xpu_kernels
+
+template class kDequantizeBlockwise {
+ public:
+ SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;
+
+ kDequantizeBlockwise(float* code_, uint8_t* A_, float* absmax_, T* out_, const int blocksize_, const int n_)
+ : code(code_), A(A_), absmax(absmax_), out(out_), blocksize(blocksize_), n(n_) {}
+
+ private:
+ float* code;
+ uint8_t* A;
+ float* absmax;
+ T* out;
+ const int blocksize;
+ const int n;
+};
+
+template class kgemv_4bit_inference {
+ public:
+ SYCL_EXTERNAL void operator()(sycl::nd_item<1> item) const;
+
+ kgemv_4bit_inference(
+ int M_, int N_, int K_, T* A_, unsigned char* B_, float* absmax_, const float* datatype_, T* out_, int lda_,
+ int ldb_, int ldc_, int blocksize_
+ )
+ : M(M_), N(N_), K(K_), A(A_), B(B_), absmax(absmax_), datatype(datatype_), out(out_), lda(lda_), ldb(ldb_),
+ ldc(ldc_), blocksize(blocksize_), quant_map() {}
+
+ void sycl_ker_local_memory_creation(sycl::handler& cgh) { quant_map = sycl::local_accessor(16, cgh); }
+
+ private:
+ int M;
+ int N;
+ int K;
+ T* A;
+ unsigned char* B;
+ float* absmax;
+ const float* datatype;
+ T* out;
+ int lda;
+ int ldb;
+ int ldc;
+ int blocksize;
+ sycl::local_accessor quant_map;
+};
+
+#endif
diff --git a/csrc/xpu_ops.cpp b/csrc/xpu_ops.cpp
new file mode 100644
index 000000000..aa6ac808f
--- /dev/null
+++ b/csrc/xpu_ops.cpp
@@ -0,0 +1,102 @@
+#include
+#include
+#include
+
+template
+void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, T* out, int blocksize, const int n, sycl::queue* stream
+) {
+ auto& queue = *stream;
+ const int workgroup_size = 128;
+ const int num_per_th = 4;
+ const int tile_size = workgroup_size * num_per_th;
+ if (DATA_TYPE > 0) {
+ const int workgroup_num = (n + tile_size * 2 - 1) / (tile_size * 2);
+ sycl::range<1> local_range{(size_t)workgroup_size};
+ sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
+ kDequantizeBlockwise kfn(code, A, absmax, out, blocksize / 2, n);
+ sycl_kernel_submit(
+ sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
+ );
+ } else {
+ const int workgroup_num = (n + tile_size - 1) / tile_size;
+ sycl::range<1> local_range{(size_t)workgroup_size};
+ sycl::range<1> global_range{(size_t)workgroup_num * (size_t)workgroup_size};
+ kDequantizeBlockwise kfn(code, A, absmax, out, blocksize, n);
+ sycl_kernel_submit(
+ sycl::nd_range<1>(sycl::range<1>(global_range), sycl::range<1>(local_range)), queue, kfn
+ );
+ }
+}
+
+template
+void gemv_4bit_inference(
+ int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
+ int blocksize, sycl::queue* stream
+) {
+
+ auto& queue = *stream;
+
+ const size_t GROUP_SIZE = 128; // workgroup_size
+ const size_t SUBG_SIZE = 32; // subgroup_size
+ const size_t NUM_PER_THREAD = GROUP_SIZE / SUBG_SIZE;
+ size_t workgroup_num = (n + NUM_PER_THREAD - 1) / NUM_PER_THREAD;
+
+ kgemv_4bit_inference kfn(
+ m, n, k, A, B, absmax, datatype, out, lda, ldb, ldc, blocksize
+ );
+
+ sycl_comp_kernel_submit(
+ sycl::nd_range<1>(sycl::range<1>(GROUP_SIZE * workgroup_num), sycl::range<1>(GROUP_SIZE)), queue, kfn
+ );
+}
+
+//==============================================================
+// TEMPLATE DEFINITIONS
+//==============================================================
+
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+);
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+);
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, float* out, int blocksize, const int n, sycl::queue* stream
+);
+
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+);
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+);
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, sycl::half* out, int blocksize, const int n, sycl::queue* stream
+);
+
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+);
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+);
+template void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, sycl::ext::oneapi::bfloat16* out, int blocksize, const int n,
+ sycl::queue* stream
+);
+
+template void gemv_4bit_inference(
+ int m, int n, int k, sycl::half* A, unsigned char* B, float* absmax, float* datatype, sycl::half* out, int lda,
+ int ldb, int ldc, int blocksize, sycl::queue* stream
+);
+template void gemv_4bit_inference(
+ int m, int n, int k, sycl::ext::oneapi::bfloat16* A, unsigned char* B, float* absmax, float* datatype,
+ sycl::ext::oneapi::bfloat16* out, int lda, int ldb, int ldc, int blocksize, sycl::queue* stream
+);
+template void gemv_4bit_inference(
+ int m, int n, int k, float* A, unsigned char* B, float* absmax, float* datatype, float* out, int lda, int ldb,
+ int ldc, int blocksize, sycl::queue* stream
+);
diff --git a/csrc/xpu_ops.h b/csrc/xpu_ops.h
new file mode 100644
index 000000000..142d6c161
--- /dev/null
+++ b/csrc/xpu_ops.h
@@ -0,0 +1,46 @@
+#ifndef xpu_ops_H
+#define xpu_ops_H
+
+#include
+#include
+#include
+#include
+
+#include
+#include
+
+#include
+
+template
+static inline void sycl_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) {
+ auto cgf = [&](::sycl::handler& cgh)
+ [[sycl::reqd_sub_group_size(subgroup_size)]] { cgh.parallel_for(range, ker); };
+ q.submit(cgf);
+}
+
+template
+static inline void sycl_comp_kernel_submit(sycl::nd_range range, sycl::queue q, ker_t ker) {
+ auto cgf = [&](::sycl::handler& cgh) [[sycl::reqd_sub_group_size(subgroup_size)]] {
+ ker.sycl_ker_local_memory_creation(cgh);
+ cgh.parallel_for(range, ker);
+ };
+ q.submit(cgf);
+}
+
+typedef enum DataType_t {
+ General8bit = 0,
+ FP4 = 1,
+ NF4 = 2,
+} DataType_t;
+
+template
+void dequantizeBlockwise(
+ float* code, unsigned char* A, float* absmax, T* out, int workgroup_size, const int n, sycl::queue* stream
+);
+template
+void gemv_4bit_inference(
+ int m, int n, int k, T* A, unsigned char* B, float* absmax, float* datatype, T* out, int lda, int ldb, int ldc,
+ int blocksize, sycl::queue* stream
+);
+
+#endif
diff --git a/docs/source/installation.mdx b/docs/source/installation.mdx
index e61ce4655..daa06a3c6 100644
--- a/docs/source/installation.mdx
+++ b/docs/source/installation.mdx
@@ -16,17 +16,19 @@ Welcome to the installation guide for the `bitsandbytes` library! This document
## CUDA[[cuda]]
-`bitsandbytes` is currently supported on NVIDIA GPUs with [Compute Capability](https://developer.nvidia.com/cuda-gpus) 5.0+.
-The library can be built using CUDA Toolkit versions as old as **11.6** on Windows and **11.4** on Linux.
+`bitsandbytes` is currently supported on NVIDIA GPUs with [Compute Capability](https://developer.nvidia.com/cuda-gpus) 6.0+.
+The library can be built using CUDA Toolkit versions as old as **11.8**.
| **Feature** | **CC Required** | **Example Hardware Requirement** |
|---------------------------------|-----------------|---------------------------------------------|
-| LLM.int8() | 7.5+ | Turing (RTX 20 series, T4) or newer GPUs |
-| 8-bit optimizers/quantization | 5.0+ | Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs |
-| NF4/FP4 quantization | 5.0+ | Maxwell (GTX 900 series, TITAN X, M40) or newer GPUs |
+| LLM.int8() | 7.5+ | Turing (RTX 20 series, T4) or newer GPUs |
+| 8-bit optimizers/quantization | 6.0+ | Pascal (GTX 10X0 series, P100) or newer GPUs|
+| NF4/FP4 quantization | 6.0+ | Pascal (GTX 10X0 series, P100) or newer GPUs|
> [!WARNING]
-> Support for Maxwell GPUs is deprecated and will be removed in a future release. For the best results, a Turing generation device or newer is recommended.
+> Support for Maxwell GPUs is deprecated and will be removed in a future release.
+> Maxwell support is not included in PyPI distributions from `v0.48.0` on and must be built from source.
+> For the best results, a Turing generation device or newer is recommended.
### Installation via PyPI[[cuda-pip]]
@@ -36,12 +38,12 @@ The currently distributed `bitsandbytes` packages are built with the following c
| **OS** | **CUDA Toolkit** | **Host Compiler** | **Targets**
|--------------------|------------------|----------------------|--------------
-| **Linux x86-64** | 11.8 - 12.6 | GCC 11.2 | sm50, sm60, sm75, sm80, sm86, sm89, sm90
-| **Linux x86-64** | 12.8 | GCC 11.2 | sm75, sm80, sm86, sm89, sm90, sm100, sm120
+| **Linux x86-64** | 11.8 - 12.6 | GCC 11.2 | sm60, sm70, sm75, sm80, sm86, sm89, sm90
+| **Linux x86-64** | 12.8 - 12.9 | GCC 11.2 | sm70, sm75, sm80, sm86, sm89, sm90, sm100, sm120
| **Linux aarch64** | 11.8 - 12.6 | GCC 11.2 | sm75, sm80, sm90
-| **Linux aarch64** | 12.8 | GCC 11.2 | sm75, sm80, sm90, sm100
+| **Linux aarch64** | 12.8 - 12.9 | GCC 11.2 | sm75, sm80, sm90, sm100, sm120
| **Windows x86-64** | 11.8 - 12.6 | MSVC 19.43+ (VS2022) | sm50, sm60, sm75, sm80, sm86, sm89, sm90
-| **Windows x86-64** | 12.8 | MSVC 19.43+ (VS2022) | sm75, sm80, sm86, sm89, sm90, sm100, sm120
+| **Windows x86-64** | 12.8 - 12.9 | MSVC 19.43+ (VS2022) | sm70, sm75, sm80, sm86, sm89, sm90, sm100, sm120
Use `pip` or `uv` to install:
@@ -67,7 +69,7 @@ For example, to install a compiler and CMake on Ubuntu:
apt-get install -y build-essential cmake
```
-You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide. The current minimum supported CUDA Toolkit version that we test with is **11.8**.
+You should also install CUDA Toolkit by following the [NVIDIA CUDA Installation Guide for Linux](https://docs.nvidia.com/cuda/cuda-installation-guide-linux/index.html) guide. The current minimum supported CUDA Toolkit version that we support is **11.8**.
```bash
git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
@@ -84,7 +86,7 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
Compilation from source on Windows systems require Visual Studio with C++ support as well as an installation of the CUDA Toolkit.
-To compile from source, you need CMake >= **3.22.1** and Python >= **3.9** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. The current minimum supported CUDA Toolkit version that we test with is **11.8**.
+To compile from source, you need CMake >= **3.22.1** and Python >= **3.9** installed. You should also install CUDA Toolkit by following the [CUDA Installation Guide for Windows](https://docs.nvidia.com/cuda/cuda-installation-guide-microsoft-windows/index.html) guide from NVIDIA. The current minimum supported CUDA Toolkit version that we support is **11.8**.
```bash
git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
@@ -138,8 +140,8 @@ We provide an early preview of support for AMD and Intel hardware as part of a d
| **Backend** | **Supported Versions** | **Python versions** | **Architecture Support** | **Status** |
|-------------|------------------------|---------------------------|-------------------------|------------|
| **AMD ROCm** | 6.1+ | 3.10+ | minimum CDNA - `gfx90a`, RDNA - `gfx1100` | Alpha |
-| **Intel CPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel CPU | Alpha |
-| **Intel GPU** | v2.4.0+ (`ipex`) | 3.10+ | Intel GPU | Experimental |
+| **Intel CPU** | v2.4.0+ | 3.10+ | Intel CPU | Alpha |
+| **Intel GPU** | v2.7.0+ | 3.10+ | Intel GPU | Experimental |
| **Ascend NPU** | 2.1.0+ (`torch_npu`) | 3.10+ | Ascend NPU | Experimental |
For each supported backend, follow the respective instructions below:
@@ -179,7 +181,6 @@ pip install torch --index-url https://download.pytorch.org/whl/rocm6.3/
* A compatible PyTorch version with Intel XPU support is required. It is recommended to use the latest stable release. See [Getting Started on Intel GPU](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html) for guidance.
-* The [Intel Extension for PyTorch](https://intel.github.io/intel-extension-for-pytorch/xpu/latest/) is recommended for performance improvements.
@@ -235,27 +236,18 @@ pip install -e . # `-e` for "editable" install, when developing BNB (otherwise
-#### Intel CPU + XPU
+#### Intel CPU + GPU(XPU)
-
-If you are using Intel CPU on Linux or Intel XPU on Linux/Windows, please follow the [instruction](https://pytorch-extension.intel.com/) or the following command to install intel_extension_for_pytorch so you can get better performance.
-
-CPU: `pip install intel_extension_for_pytorch`
-XPU: `pip install intel_extension_for_pytorch --extra-index-url https://pytorch-extension.intel.com/release-whl/stable/xpu/us/`
-
-Install bitsandbytes:
-CPU: Need to build CPU C++ codes
+CPU needs to build CPU C++ codes, while XPU needs to build sycl codes.
+Run `export bnb_device=xpu` if you are using xpu, run `export bnb_device=cpu` if you are using cpu.
```
git clone https://github.com/bitsandbytes-foundation/bitsandbytes.git && cd bitsandbytes/
-cmake -DCOMPUTE_BACKEND=cpu -S .
+cmake -DCOMPUTE_BACKEND=$bnb_device -S .
make
-pip install .
-```
-XPU:
-```
-pip install git+https://github.com/bitsandbytes-foundation/bitsandbytes.git
+pip install -e .
```
+
diff --git a/install_cuda.py b/install_cuda.py
index c87deaedf..0122be04b 100644
--- a/install_cuda.py
+++ b/install_cuda.py
@@ -87,7 +87,7 @@ def main():
# Install CUDA version(s)
if version == "all":
- for ver in cuda_versions.keys():
+ for ver in cuda_versions:
install_cuda(ver, base_path, download_path)
elif version in cuda_versions:
install_cuda(version, base_path, download_path)
diff --git a/pyproject.toml b/pyproject.toml
index af4c8c240..61b35c648 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
[build-system]
-requires = ["setuptools >= 63.0.0"]
-build-backend = "setuptools.build_meta"
+requires = ["scikit-build-core", "setuptools >= 63.0.0"]
+build-backend = "scikit_build_core.setuptools.build_meta"
[project]
name = "bitsandbytes"
@@ -42,8 +42,9 @@ classifiers = [
"Topic :: Scientific/Engineering :: Artificial Intelligence"
]
dependencies = [
- "torch>=2.2,<3",
- "numpy>=1.17"
+ "torch>=2.3,<3",
+ "numpy>=1.17",
+ "packaging>=20.9"
]
[project.urls]
@@ -71,7 +72,7 @@ test = [
]
[tool.setuptools]
-package-data = { "*" = ["libbitsandbytes*.*"] }
+package-data = { "*" = ["libbitsandbytes*.*", "py.typed"] }
[tool.setuptools.packages.find]
include = ["bitsandbytes*"]
@@ -123,11 +124,10 @@ select = [
ignore = [
"B007", # Loop control variable not used within the loop body (TODO: enable)
"B028", # Warning without stacklevel (TODO: enable)
- "E501", # Supress line-too-long warnings: trust yapf's judgement on this one.
+ "E501", # Suppress line-too-long warnings: trust yapf's judgement on this one.
"E701", # Multiple statements on one line (TODO: enable)
"E712", # Allow using if x == False, as it's not always equivalent to if x.
"E731", # Do not use lambda
- "F841", # Local assigned but not used (TODO: enable, these are likely bugs)
"RUF012", # Mutable class attribute annotations
"RUF034", # Useless if-else (TODO: enable)
"ISC001", # single-line-implicit-string-concatenation incompatible with formatter
diff --git a/setup.py b/setup.py
index 8c84b2c73..a04630b8a 100644
--- a/setup.py
+++ b/setup.py
@@ -2,7 +2,11 @@
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
+from distutils.errors import DistutilsModuleError
+from warnings import warn
+
from setuptools import find_packages, setup
+from setuptools.command.build_py import build_py
from setuptools.dist import Distribution
@@ -12,4 +16,26 @@ def has_ext_modules(self):
return True
-setup(version="0.47.0.dev0", packages=find_packages(), distclass=BinaryDistribution)
+class ExtBuildPy(build_py):
+ def run(self):
+ # build_cmake needs to be called prior to build_py, as the latter
+ # collects the files output into the package directory.
+ try:
+ self.run_command("build_cmake")
+ except DistutilsModuleError:
+ warn(
+ "scikit-build-core not installed, CMake will not be invoked automatically. "
+ "Please install scikit-build-core or run CMake manually to build extensions."
+ )
+ super().run()
+
+
+setup(
+ version="0.48.0.dev0",
+ packages=find_packages(),
+ distclass=BinaryDistribution,
+ cmake_source_dir=".",
+ cmdclass={
+ "build_py": ExtBuildPy,
+ },
+)
diff --git a/tests/conftest.py b/tests/conftest.py
index a514e1284..f69b9ff2b 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -34,6 +34,8 @@ def pytest_runtest_teardown(item, nextitem):
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
+ elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
+ torch.mps.empty_cache()
@pytest.fixture(scope="session")
diff --git a/tests/helpers.py b/tests/helpers.py
index a87bc5d08..f1fa7eb62 100644
--- a/tests/helpers.py
+++ b/tests/helpers.py
@@ -18,12 +18,13 @@
@functools.cache
-def get_available_devices():
+def get_available_devices(no_cpu=False):
if "BNB_TEST_DEVICE" in os.environ:
# If the environment variable is set, use it directly.
- return [os.environ["BNB_TEST_DEVICE"]]
+ device = os.environ["BNB_TEST_DEVICE"]
+ return [] if no_cpu and device == "cpu" else [device]
- devices = [] if HIP_ENVIRONMENT else ["cpu"]
+ devices = [] if HIP_ENVIRONMENT else ["cpu"] if not no_cpu else []
if hasattr(torch, "accelerator"):
# PyTorch 2.6+ - determine accelerator using agnostic API.
diff --git a/tests/test_functional.py b/tests/test_functional.py
index b84db6502..6a4f72190 100644
--- a/tests/test_functional.py
+++ b/tests/test_functional.py
@@ -1,15 +1,16 @@
import math
+import platform
import random
import time
import einops
-import numpy as np
+from packaging import version
import pytest
import torch
import bitsandbytes as bnb
from bitsandbytes import functional as F
-from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH
+from bitsandbytes.cextension import HIP_ENVIRONMENT
from tests.helpers import (
BOOLEAN_TUPLES,
TRUE_FALSE,
@@ -101,16 +102,16 @@ class Test8BitBlockwiseQuantizeFunctional:
def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize, signed):
iters = 100
- if device == "cpu":
+ if device != "cuda":
iters = 10
- # This test is slow on CPU, so avoid atypical use cases.
+ # This test is slow in our non-CUDA implementations, so avoid atypical use cases.
if nested:
pytest.skip("Not a typical use case.")
if blocksize != 256:
- pytest.skip("Only blocksize 256 is used in CPU/XPU")
+ pytest.skip("Only blocksize 256 is used in CPU/MPS/XPU")
if dtype != torch.float32:
- pytest.skip("Only float32 is used in CPU/XPU")
+ pytest.skip("Only float32 is used in CPU/MPS/XPU")
diffs = []
reldiffs = []
@@ -142,11 +143,11 @@ def test_dynamic_blockwise_quantization(self, device, dtype, nested, blocksize,
abserr = sum(diffs) / len(diffs)
relerr = sum(reldiffs) / len(reldiffs)
if signed:
- threshold_abserr = 0.0036 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0035
+ threshold_abserr = 0.0035
assert abserr < 0.0036
assert relerr < 0.015
else:
- assert abserr < 0.00175 if device in ("cpu", "xpu") and (F.ipex_cpu or F.ipex_xpu) else 0.0023
+ assert abserr < 0.0023
assert relerr < 0.012
assert A2.dtype == dtype
@@ -177,8 +178,8 @@ def test_blockwise_cpu_large(self, hidden, blocksize):
@pytest.mark.parametrize("bits", range(2, 9), ids=id_formatter("bits"))
@pytest.mark.parametrize("method", ["linear", "fp8", "dynamic"])
def test_few_bit_quant(self, device, bits, method):
- if bits != 8 and (device == "cpu" or (device == "xpu" and F.ipex_xpu)):
- pytest.skip("CPU/XPU implementation only supports 8 bits")
+ if bits != 8 and device == "cpu":
+ pytest.skip("CPU implementation only supports 8 bits")
abserrs = []
relerrs = []
@@ -239,7 +240,7 @@ def test_fp8_quant(self, device):
abserr = []
relerr = []
- for i in range(100):
+ for i in range(10):
A1 = torch.randn(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
@@ -253,7 +254,7 @@ def test_fp8_quant(self, device):
abserr = []
relerr = []
- for i in range(100):
+ for i in range(10):
A1 = torch.rand(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1, code=code)
A2 = F.dequantize_blockwise(C, SC)
@@ -267,7 +268,7 @@ def test_fp8_quant(self, device):
abserr = []
relerr = []
- for i in range(100):
+ for i in range(10):
A1 = torch.randn(1024, 1024, device=device)
C, SC = F.quantize_blockwise(A1)
A2 = F.dequantize_blockwise(C, SC)
@@ -462,6 +463,7 @@ def test_dim3_igemm(self, seq_dim, hidden_dim, batch_dim):
@pytest.mark.parametrize("hidden_dim", [32, 1024 * 4], ids=id_formatter("hidden_dim"))
@pytest.mark.parametrize("batch_dim", [2, 16], ids=id_formatter("batch_dim"))
@pytest.mark.parametrize("transpose", TRUE_FALSE, ids=id_formatter("transpose"))
+ @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_minmax_igemm(self, seq_dim, hidden_dim, batch_dim, transpose):
def min_max(x):
maxA = torch.amax(x, dim=2, keepdim=True)
@@ -1109,6 +1111,7 @@ class TestQuantize4BitFunctional:
"blocksize",
[64, 128, 256, 512, 1024, 2048, 4096] if not HIP_ENVIRONMENT else [128, 256, 512, 1024, 2048, 4096],
)
+ @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_4bit_quant(self, device, dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
pytest.skip("This configuration is not supported on HPU.")
@@ -1125,21 +1128,56 @@ def test_4bit_quant(self, device, dtype, quant_type, blocksize):
# With larger block sizes, we can expect this to blow up.
# At blocksize>=1024, don't even bother looking at relerr.
- if blocksize <= 64:
- assert err.item() < 0.1
- assert relerr.item() < 0.28
- elif blocksize <= 256:
- assert err.item() < 0.11
- assert relerr.item() < 0.30
- elif blocksize <= 512:
- assert err.item() < 0.12
- assert relerr.item() < 0.31
- elif quant_type == "fp4":
- # 1024 => 0.48, 2048 => 0.52, 4096 => 0.56
- assert err.item() < 0.08 + math.log2(blocksize) * 4e-2
- else:
- # 1024 => 0.8, 2048 => 0.88, 4096 => 0.96
- assert err.item() < math.log2(blocksize) * 8e-2
+ #
+ # Actually, the above is not true anymore after fixing the integer packing bug.
+ # The following values were taken from averaging 1k samples per test configuration after fixing the bug.
+ error_dict = dict()
+ error_dict["fp4"] = dict()
+ error_dict["nf4"] = dict()
+ error_dict["fp4"]["err"] = {
+ 64: 0.096545,
+ 128: 0.102947,
+ 256: 0.108685,
+ 512: 0.114087,
+ 1024: 0.119312,
+ 2048: 0.124460,
+ 4096: 0.129573,
+ }
+ error_dict["fp4"]["rel_err"] = {
+ 64: 0.260130,
+ 128: 0.275734,
+ 256: 0.289842,
+ 512: 0.302852,
+ 1024: 0.314982,
+ 2048: 0.326402,
+ 4096: 0.337228,
+ }
+
+ error_dict["nf4"]["err"] = {
+ 64: 0.072792,
+ 128: 0.076835,
+ 256: 0.080326,
+ 512: 0.083535,
+ 1024: 0.086603,
+ 2048: 0.089592,
+ 4096: 0.092537,
+ }
+ error_dict["nf4"]["rel_err"] = {
+ 64: 0.203299,
+ 128: 0.215252,
+ 256: 0.226044,
+ 512: 0.236021,
+ 1024: 0.245365,
+ 2048: 0.254146,
+ 4096: 0.262457,
+ }
+
+ # Allow higher tolerance for fp32 on CPU with larger block sizes
+ reltol = 2.8e-3 if dtype == torch.float32 and blocksize >= 128 and device == "cpu" else 1e-3
+ errtol = 1.2e-3 if dtype == torch.float32 and blocksize >= 1024 and device == "cpu" else 1e-3
+
+ assert err < error_dict[quant_type]["err"][blocksize] + errtol
+ assert relerr < error_dict[quant_type]["rel_err"][blocksize] + reltol
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@@ -1238,8 +1276,8 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
max_errs3 = []
# Large number of iterations is excessive and slow on CPU.
- # Keep for CUDA for now.
- iters = 100 if device == "cuda" else 10
+ # Keep for CUDA/XPU for now.
+ iters = 10 if device == "cpu" else 100
for i in range(iters):
if kind == "fc1":
@@ -1341,13 +1379,13 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
assert err1 < 6e-5
assert relerr1 < 2e-4
assert absratio < 1.005 and absratio > 0.995
- assert relratio < 1.005 and relratio > 0.995
- assert maxratio < 1.005 and maxratio > 0.995
+ assert relratio < 1.005 and relratio > 0.992
+ assert maxratio < 1.005 and maxratio > 0.992
elif dtype == torch.float32:
if dim <= 512:
assert err1 < 5e-8
assert relerr1 < 1e-6
- assert maxerr1 < 1e-7
+ assert maxerr1 < 1.05e-7
else:
assert err1 < 5e-8
assert relerr1 < 8e-6
@@ -1357,34 +1395,34 @@ def test_gemv_4bit(self, device, dim, dtype, storage_type, quant_storage, double
assert maxratio < 1.005 and maxratio > 0.995
elif dtype == torch.bfloat16:
if dim <= 512:
+ relerr_thres = 0.013 if hasattr(torch, "xpu") and torch.xpu.is_available() else 0.007
assert err1 < 6e-4
- assert relerr1 < 0.007
+ assert relerr1 < relerr_thres
assert maxerr1 < 0.015
else:
assert err1 < 2e-4
assert relerr1 < 0.002
assert maxerr1 < 0.0012
assert absratio < 1.005 and absratio > 0.995
- assert relratio < 1.04 and relratio > 0.96
- assert maxratio < 1.02 and maxratio > 0.98
+ assert relratio < 1.05 and relratio > 0.96
+ assert maxratio < 1.05 and maxratio > 0.97
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("storage_type", ["nf4", "fp4"], ids=["nf4", "fp4"])
@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32], ids=describe_dtype)
- @pytest.mark.parametrize("double_quant", [False], ids=["DQ_True"])
- @pytest.mark.skipif(
- HIP_ENVIRONMENT and ROCM_GPU_ARCH == "gfx90a",
- reason="this test is not supported on ROCm with gfx90a architecture yet",
- )
- def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
- if device == "cpu" and dtype == torch.bfloat16 and torch.__version__ < (2, 3):
- pytest.skip("eye doe not support bfloat16 on CPU in torch < 2.3")
-
+ @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
+ def test_gemv_eye_4bit(self, device, storage_type, dtype):
if device == "hpu" and not is_supported_on_hpu(storage_type, dtype):
pytest.skip("This configuration is not supported on HPU.")
- dims = 10
- torch.random.manual_seed(np.random.randint(0, 412424242))
+ if (
+ device == "cpu"
+ and platform.system() == "Windows"
+ and version.parse(torch.__version__).release == (2, 8, 0)
+ ):
+ pytest.skip("Regression: CPU crash on Windows with torch 2.8.0")
+
+ dims = 4
dims = get_test_dims(0, 8192, n=dims)
dims = [dim + (64 - (dim % 64)) for dim in dims]
# for dim in [576, 5120, 3520, 5184, 1280, 4992, 5312, 2048]:
@@ -1392,7 +1430,7 @@ def test_gemv_eye_4bit(self, device, storage_type, dtype, double_quant):
A = torch.normal(0, 0.1, size=(1, 1, dim), dtype=dtype, device=device)
B = torch.eye(dim, dtype=dtype, device=device)
- qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=double_quant)
+ qB, state = F.quantize_4bit(B, quant_type=storage_type, compress_statistics=False)
C3 = torch.matmul(A, B.t())
C2 = bnb.matmul_4bit(A, qB.t(), state)
A.requires_grad = True
diff --git a/tests/test_generation.py b/tests/test_generation.py
index 38b5ce9bd..3ab1cc5bd 100644
--- a/tests/test_generation.py
+++ b/tests/test_generation.py
@@ -112,7 +112,7 @@ def test_pi(requires_cuda, model_and_tokenizer, inference_kernel, DQ, dtype):
assert len(outputs) == n_cases
failure_count = 0
for i in range(n_cases):
- if not outputs[i][: len(str(math.pi))] == str(math.pi):
+ if outputs[i][: len(str(math.pi))] != str(math.pi):
failure_count += 1
failure_max = 2 if fixture_config[0] == "huggyllama/llama-7b" else 4
if failure_count > failure_max:
diff --git a/tests/test_linear4bit.py b/tests/test_linear4bit.py
index e07b54d2d..1c5e77a32 100644
--- a/tests/test_linear4bit.py
+++ b/tests/test_linear4bit.py
@@ -212,6 +212,41 @@ def test_copy_param(device, quant_type, blocksize, compress_statistics):
assert param.data.data_ptr() == shallow_copy_param.data.data_ptr()
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+def test_params4bit_torch_chunk_split(device, quant_type):
+ """Test that torch.chunk and torch.split preserve Params4bit subclass for FSDP2 compatibility."""
+ if device == "hpu" and not is_supported_on_hpu(quant_type, torch.float16, torch.uint8):
+ pytest.skip("This configuration is not supported on HPU.")
+
+ if device == "cpu":
+ pytest.skip("CPU quantization causes segfault, skipping CPU test")
+
+ original_tensor = torch.randn(8, 4, dtype=torch.float16, device="cpu")
+
+ params4bit = bnb.nn.Params4bit(data=original_tensor, quant_type=quant_type, requires_grad=False)
+
+ if device != "cpu":
+ params4bit = params4bit.to(device)
+
+ chunks = torch.chunk(params4bit, 2, dim=0)
+
+ assert isinstance(chunks, tuple), "torch.chunk should return tuple"
+ for chunk in chunks:
+ assert isinstance(chunk, bnb.nn.Params4bit), "Chunk should preserve Params4bit subclass"
+ assert hasattr(chunk, "quant_type"), "Should preserve metadata"
+ assert chunk.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+ splits = torch.split(params4bit, 2, dim=0)
+
+ assert isinstance(splits, tuple), "torch.split should return tuple"
+ assert len(splits) > 0, "Should have at least one split"
+ for split in splits:
+ assert isinstance(split, bnb.nn.Params4bit), "Split should preserve Params4bit subclass"
+ assert hasattr(split, "quant_type"), "Should preserve metadata"
+ assert split.quant_type == params4bit.quant_type, "Should preserve quant_type value"
+
+
@pytest.mark.parametrize("device", get_available_devices())
@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
@pytest.mark.parametrize("blocksize", [64, 128] if not HIP_ENVIRONMENT else [128])
diff --git a/tests/test_linear8bitlt.py b/tests/test_linear8bitlt.py
index 86726bd44..51b4cf9cd 100644
--- a/tests/test_linear8bitlt.py
+++ b/tests/test_linear8bitlt.py
@@ -9,6 +9,7 @@
import torch
import bitsandbytes as bnb
+from bitsandbytes.cextension import HIP_ENVIRONMENT
from bitsandbytes.nn.modules import Linear8bitLt
from tests.helpers import (
TRUE_FALSE,
@@ -233,6 +234,7 @@ def test_linear8bit_serialization(linear8bit):
@pytest.mark.parametrize("fullgraph", TRUE_FALSE, ids=id_formatter("fullgraph"))
@pytest.mark.parametrize("mode", ["default", "reduce-overhead"], ids=id_formatter("mode"))
@pytest.mark.skipif(torch.__version__ < (2, 4), reason="Not supported in torch < 2.4")
+@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
if device == "cuda" and platform.system() == "Windows":
pytest.skip("Triton is not officially supported on Windows")
@@ -272,14 +274,11 @@ def test_linear8bitlt_torch_compile(device, threshold, bias, fullgraph, mode):
# Test with gradients. Currently only works with threshold=0.
# Has a strange regression on Linux aarch64 CPU in torch==2.6.0.
- # There is also an issue with torch==2.7.0 on x86-64 with IPEX.
is_broken_platform = (
device == "cpu"
and platform.system() == "Linux"
- and (
- (platform.machine() == "aarch64" and (2, 6) <= torch.__version__ < (2, 7))
- or (platform.machine() == "x86_64" and bnb.functional.ipex_cpu)
- )
+ and platform.machine() == "aarch64"
+ and (2, 6) <= torch.__version__ < (2, 7)
)
if threshold == 0 and not is_broken_platform:
diff --git a/tests/test_modules.py b/tests/test_modules.py
index 8946522d3..e5682e5c8 100644
--- a/tests/test_modules.py
+++ b/tests/test_modules.py
@@ -143,9 +143,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0:
+ if threshold > 0 and device not in ("cpu", "xpu"):
assert mlp.fc1.state.idx is not None
- if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).to(device).half()
@@ -156,9 +155,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0:
+ if threshold > 0 and device not in ("cpu", "xpu"):
assert mlp.fc1.state.idx is not None
- if threshold > 0:
assert mlp.fc2.state.idx is not None
mlp = MLP8bit(32, 64, threshold=threshold, has_fp16_weights=False).half().to(device)
@@ -167,9 +165,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0:
+ if threshold > 0 and device not in ("cpu", "xpu"):
assert mlp.fc1.state.idx is not None
- if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
@@ -189,9 +186,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0:
+ if threshold > 0 and device not in ("cpu", "xpu"):
assert mlp.fc1.state.idx is not None
- if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
assert mlp.fc2.weight.dtype == torch.int8
@@ -211,9 +207,8 @@ def test_linear8bitlt_no_fp16_weights(device, threshold):
b1 = torch.randn(16, 8, 32, device=device, dtype=torch.float16)
o1 = mlp(b1)
assert o1.dtype == torch.float16
- if threshold > 0:
+ if threshold > 0 and device not in ("cpu", "xpu"):
assert mlp.fc1.state.idx is not None
- if threshold > 0:
assert mlp.fc2.state.idx is not None
assert mlp.fc1.weight.dtype == torch.int8
diff --git a/tests/test_ops.py b/tests/test_ops.py
index 8aa0560fd..02472630e 100644
--- a/tests/test_ops.py
+++ b/tests/test_ops.py
@@ -5,7 +5,6 @@
import bitsandbytes
from bitsandbytes.cextension import HIP_ENVIRONMENT
-from bitsandbytes.functional import ipex_xpu
from tests.helpers import TRUE_FALSE, get_available_devices, id_formatter, is_supported_on_hpu
# torch.library.opcheck is only available in torch 2.4 and later.
@@ -145,10 +144,6 @@ def test_dequantize_blockwise(self, device, dtype, blocksize):
assert out.dtype == dtype
assert out.device == A.device
- # TODO: Enable it
- if device == "xpu" and ipex_xpu:
- pytest.skip("XPU implementation have torch.op inside torch.op, it will fail on op check")
-
opcheck(torch.ops.bitsandbytes.dequantize_blockwise.default, (A, absmax, code, blocksize, dtype))
@@ -216,6 +211,7 @@ def test_dequantize_4bit(self, device, dtype, storage_dtype, quant_type, blocksi
@pytest.mark.parametrize("storage_dtype", [torch.uint8, torch.bfloat16], ids=id_formatter("storage_dtype"))
@pytest.mark.parametrize("quant_type", ["fp4", "nf4"])
@pytest.mark.parametrize("blocksize", [64, 128, 256, 512] if not HIP_ENVIRONMENT else [128, 256, 512])
+ @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm yet")
def test_gemv_4bit(self, device, dtype, storage_dtype, quant_type, blocksize):
if device == "hpu" and not is_supported_on_hpu(quant_type, dtype, storage_dtype):
pytest.skip("This configuration is not supported on HPU.")
diff --git a/tests/test_optim.py b/tests/test_optim.py
index 75e5a1714..3d4157152 100644
--- a/tests/test_optim.py
+++ b/tests/test_optim.py
@@ -11,7 +11,8 @@
import bitsandbytes as bnb
import bitsandbytes.functional as F
-from tests.helpers import describe_dtype, id_formatter
+from bitsandbytes.utils import sync_gpu
+from tests.helpers import describe_dtype, get_available_devices, id_formatter
# import apex
@@ -168,15 +169,23 @@ def rm_path(path):
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097, 1], ids=id_formatter("dim2"))
-def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True), ids=id_formatter("device"))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
+def test_optimizer32bit(dim1, dim2, gtype, optim_name, device):
+ if device not in ["cuda", "xpu"]:
+ pytest.skip("Optimizers are only supported on CUDA and XPU")
+
if optim_name.startswith("paged_") and sys.platform == "win32":
pytest.skip("Paged optimizers can have issues on Windows.")
+ if optim_name.startswith("paged_") and device == "xpu":
+ pytest.skip("Paged optimizers are not supported on XPU currently.")
+
if gtype == torch.bfloat16 and optim_name in ["momentum", "rmsprop"]:
pytest.skip()
if dim1 == 1 and dim2 == 1:
return
- p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+ p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
@@ -191,7 +200,7 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
atol, rtol = 1e-4, 1e-3
for i in range(k):
- g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+ g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
@@ -201,14 +210,14 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
for name1, name2 in str2statenames[optim_name]:
torch.testing.assert_close(
torch_optimizer.state[p1][name1],
- bnb_optimizer.state[p2][name2].cuda(),
+ bnb_optimizer.state[p2][name2].to(device),
atol=atol,
rtol=rtol,
)
# since Lion can have pretty noisy updates where things lie at the boundary
- # allow up to 10 errors for Lion
- assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=10)
+ # allow up to 15 errors for Lion
+ assert_most_approx_close(p1, p2.float(), atol=atol, rtol=rtol, max_error_count=15)
if i % (k // 5) == 0 and i > 0:
path = get_temp_dir()
@@ -247,7 +256,12 @@ def test_optimizer32bit(requires_cuda, dim1, dim2, gtype, optim_name):
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16], ids=describe_dtype)
-def test_global_config(requires_cuda, dim1, dim2, gtype):
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
+def test_global_config(dim1, dim2, gtype, device):
+ if device not in ["cuda", "xpu"]:
+ pytest.skip("Optimizers are only supported on CUDA and XPU")
+
if dim1 == 1 and dim2 == 1:
return
p1 = torch.randn(dim1, dim2, device="cpu", dtype=gtype) * 0.1
@@ -263,9 +277,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
bnb.optim.GlobalOptimManager.get_instance().override_config(p3, "optim_bits", 8)
bnb.optim.GlobalOptimManager.get_instance().register_parameters([p1, p2, p3])
- p1 = p1.cuda()
- p2 = p2.cuda()
- p3 = p3.cuda()
+ p1 = p1.to(device)
+ p2 = p2.to(device)
+ p3 = p3.to(device)
adam2 = bnb.optim.Adam([p1, p2, p3], lr, (beta1, beta2), eps)
@@ -275,9 +289,9 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
atol, rtol = 1e-4, 1e-3
for i in range(50):
- g1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
- g2 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
- g3 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1 + 0.001
+ g1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
+ g2 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
+ g3 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1 + 0.001
p1.grad = g1
p2.grad = g2
p3.grad = g3
@@ -302,13 +316,18 @@ def test_global_config(requires_cuda, dim1, dim2, gtype):
@pytest.mark.parametrize("gtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
@pytest.mark.parametrize("dim2", [32, 1024, 4097], ids=id_formatter("dim2"))
@pytest.mark.parametrize("dim1", [1024], ids=id_formatter("dim1"))
-def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
+@pytest.mark.parametrize("device", get_available_devices(no_cpu=True))
+@pytest.mark.skipif(not get_available_devices(no_cpu=True), reason="No device")
+def test_optimizer8bit(dim1, dim2, gtype, optim_name, device):
+ if device not in ["cuda", "xpu"]:
+ pytest.skip("8-bit optimizers are only supported on CUDA and XPU")
+
torch.set_printoptions(precision=6)
if dim1 == 1 and dim2 == 1:
return
- p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+ p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
p2 = p1.clone()
p1 = p1.float()
blocksize = 256
@@ -330,15 +349,15 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
relerrors = []
for i in range(50):
- g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+ g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
p1.grad = g.clone().float()
p2.grad = g.clone()
- bnb_optimizer.step()
torch_optimizer.step()
+ bnb_optimizer.step()
# since Lion can have pretty noisy updates where things lie at the boundary
- assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
+ # assert_most_approx_close(p1, p2.float(), patol, prtol, max_error_count=0)
dequant_states = []
for name1, name2, qmap, max_val in str2statenames[optim_name]:
@@ -368,7 +387,7 @@ def test_optimizer8bit(requires_cuda, dim1, dim2, gtype, optim_name):
)
num_not_close = torch.isclose(torch_optimizer.state[p1][name1], s1, atol=atol, rtol=rtol) == 0
- # assert num_not_close.sum().item() < 20
+ assert num_not_close.sum().item() < 20
dequant_states.append(s1.clone())
err = torch.abs(p1 - p2)
@@ -549,25 +568,25 @@ def test_adam_percentile_clipping(requires_cuda, dim1, dim2, gtype, optim_bits):
@pytest.mark.parametrize("gtype", [torch.float32, torch.bfloat16, torch.float16], ids=describe_dtype)
@pytest.mark.parametrize("optim_name", optimizer_names_benchmark, ids=id_formatter("opt"))
@pytest.mark.benchmark
-def test_benchmark_blockwise(dim1, dim2, gtype, optim_name):
+def test_benchmark_blockwise(dim1, dim2, gtype, optim_name, device):
if dim1 == 1 and dim2 == 1:
return
- p1 = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.1
+ p1 = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.1
bnb_optimizer = str2optimizers[optim_name][1]([p1])
- g = torch.randn(dim1, dim2, device="cuda", dtype=gtype) * 0.01
+ g = torch.randn(dim1, dim2, device=device, dtype=gtype) * 0.01
p1.grad = g
total_steps = 500
for i in range(total_steps):
if i == total_steps // 5:
# 100 iterations for burn-in
- torch.cuda.synchronize()
+ sync_gpu(p1)
t0 = time.time()
bnb_optimizer.step()
- torch.cuda.synchronize()
+ sync_gpu(p1)
s = time.time() - t0
print("")
params = (total_steps - total_steps // 5) * dim1 * dim2
diff --git a/tests/test_parametrize.py b/tests/test_parametrize.py
new file mode 100644
index 000000000..9e661ee2f
--- /dev/null
+++ b/tests/test_parametrize.py
@@ -0,0 +1,411 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from bitsandbytes import functional as F
+from bitsandbytes.cextension import HIP_ENVIRONMENT
+from bitsandbytes.nn.parametrize import (
+ Bnb4bitParametrization,
+ replace_parameter_4bit,
+ replace_parameter_4bit_prequantized,
+)
+from tests.helpers import (
+ TRUE_FALSE,
+ describe_dtype,
+ get_available_devices,
+ id_formatter,
+ is_supported_on_hpu,
+)
+
+
+class ParametrizeTestModule(nn.Module):
+ """Test module with different parameter shapes for testing parametrization."""
+
+ def __init__(self, device="cpu", dtype=torch.float32):
+ super().__init__()
+ # 2D parameter (typical weight matrix)
+ self.weight_2d = nn.Parameter(torch.randn(1024, 1024, device=device, dtype=dtype))
+ # 3D parameter (MoE expert weights - the main use case for this feature)
+ self.expert_weights = nn.Parameter(torch.randn(8, 512, 256, device=device, dtype=dtype))
+ # 1D parameter (bias-like)
+ self.bias_1d = nn.Parameter(torch.randn(1024, device=device, dtype=dtype))
+ # Non-parameter attribute (should not be quantizable)
+ self.not_param = torch.randn(32, device=device, dtype=dtype)
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+@pytest.mark.parametrize(
+ "blocksize",
+ [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
+)
+def test_replace_parameter_4bit(device, dtype, quant_type, compress_statistics, blocksize):
+ """Test basic parameter replacement with 4-bit quantization on different dtypes."""
+ if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
+ pytest.skip("This configuration is not supported on HPU.")
+
+ # Create module directly on target device to avoid unnecessary transfers
+ module = ParametrizeTestModule(device=device, dtype=dtype)
+ original_param = module.weight_2d.clone()
+
+ # Apply 4-bit quantization parametrization to the weight parameter
+ replace_parameter_4bit(
+ module, "weight_2d", compress_statistics=compress_statistics, quant_type=quant_type, blocksize=blocksize
+ )
+
+ # Verify that parametrization was applied correctly
+ assert hasattr(module, "parametrizations"), "Module should have parametrizations attribute"
+ assert "weight_2d" in module.parametrizations, "weight_2d should be parametrized"
+
+ # Test that accessing the parameter returns dequantized version with correct properties
+ reconstructed = module.weight_2d
+ assert reconstructed.shape == original_param.shape, "Shape should be preserved"
+ assert reconstructed.dtype == dtype, "dtype should match original"
+ assert reconstructed.device.type == device, "Device should match target"
+
+ # Verify quantization quality using same approach as functional tests
+ err = (original_param - reconstructed.detach()).abs().float()
+ relerr = (err / (original_param.abs().float() + 1e-8)).mean()
+ err_mean = err.mean()
+
+ # Expected error bounds from test_functional.py
+ expected_errors = {
+ "nf4": {
+ 64: {"abs": 0.072792, "rel": 0.203299},
+ 128: {"abs": 0.076835, "rel": 0.215252},
+ 256: {"abs": 0.080326, "rel": 0.226044},
+ },
+ "fp4": {
+ 64: {"abs": 0.096545, "rel": 0.260130},
+ 128: {"abs": 0.102947, "rel": 0.275734},
+ 256: {"abs": 0.108685, "rel": 0.289842},
+ },
+ }
+
+ assert err_mean < expected_errors[quant_type][blocksize]["abs"] + 1e-3, f"Mean abs error {err_mean:.6f} too high"
+ assert relerr < expected_errors[quant_type][blocksize]["rel"] + 1e-3, f"Mean rel error {relerr:.6f} too high"
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+def test_moe_parameter_shape(device, dtype):
+ """Test parametrization with MoE-style parameter shape"""
+ if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+ pytest.skip("This configuration is not supported on HPU.")
+
+ param_shape = (8, 64, 32)
+
+ # Create module with custom parameter shape directly on target device
+ class MoEModule(nn.Module):
+ def __init__(self, device, dtype):
+ super().__init__()
+ self.param = nn.Parameter(torch.randn(*param_shape, dtype=dtype, device=device))
+
+ module = MoEModule(device=device, dtype=dtype)
+ original_param = module.param.clone()
+
+ # Apply quantization parametrization
+ replace_parameter_4bit(module, "param", quant_type="nf4")
+
+ # Verify reconstruction maintains all properties
+ reconstructed = module.param
+ assert reconstructed.shape == param_shape, f"Shape should be preserved: {reconstructed.shape} vs {param_shape}"
+ assert reconstructed.dtype == dtype, "dtype should match original"
+ assert reconstructed.device.type == device, "Device should match target"
+
+ # Verify quantization quality using error calculation approach from functional tests
+ err = (original_param - reconstructed.detach()).abs().float()
+ relerr = (err / (original_param.abs().float() + 1e-8)).mean()
+ err_mean = err.mean()
+
+ # Use slightly looser bounds for higher dimensional tensors
+ abs_bound = 0.085 # NF4 baseline + margin
+ rel_bound = 0.25 # NF4 baseline + margin
+
+ assert err_mean < abs_bound, f"Mean abs error {err_mean:.6f} too high for shape {param_shape}"
+ assert relerr < rel_bound, f"Mean rel error {relerr:.6f} too high for shape {param_shape}"
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+def test_prequantized_replacement(device, dtype, quant_type):
+ """Test applying parametrization to already quantized parameters."""
+ if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
+ pytest.skip("Configuration not supported on HPU.")
+
+ module = ParametrizeTestModule(device=device, dtype=dtype)
+ original_param = module.weight_2d.clone()
+
+ # Manually quantize the parameter data first (simulates loading pre-quantized weights)
+ quantized_data, quant_state = F.quantize_4bit(original_param.data, quant_type=quant_type)
+
+ # Replace parameter with quantized data (what would happen during model loading)
+ module.weight_2d = nn.Parameter(quantized_data, requires_grad=False)
+
+ # Apply parametrization to handle dequantization on access
+ replace_parameter_4bit_prequantized(
+ module, "weight_2d", quant_state.as_dict(packed=True), device=torch.device(device)
+ )
+
+ # Test that parameter access properly dequantizes
+ reconstructed = module.weight_2d
+ assert reconstructed.shape == original_param.shape, "Shape should be preserved"
+ assert reconstructed.dtype == dtype, "dtype should match original"
+ assert reconstructed.device.type == device, "Device should match target"
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.parametrize("quant_type", ["nf4", "fp4"])
+@pytest.mark.parametrize("compress_statistics", TRUE_FALSE, ids=id_formatter("compress_statistics"))
+@pytest.mark.skipif(torch.__version__ < (2, 5), reason="state dict hook requires torch >= 2.5.0")
+def test_state_dict_functionality(device, dtype, quant_type, compress_statistics):
+ """Test that state dict saving works with quantized parameters."""
+ if device == "hpu" and not is_supported_on_hpu(quant_type, dtype):
+ pytest.skip("Configuration not supported on HPU.")
+
+ module = ParametrizeTestModule(device=device, dtype=dtype)
+
+ # Apply parametrization to expert weights (main MoE use case)
+ replace_parameter_4bit(module, "expert_weights", quant_type=quant_type, compress_statistics=compress_statistics)
+
+ # Save state dict - should include quantization state, not parametrization internals
+ state_dict = module.state_dict()
+
+ # Verify state dict structure: quantized param + quantization metadata
+ assert "expert_weights" in state_dict, "Quantized parameter should be in state dict"
+ assert "expert_weights.absmax" in state_dict, "Quantization absmax should be saved"
+ assert "expert_weights.quant_map" in state_dict, "Quantization map should be saved"
+ assert f"expert_weights.quant_state.bitsandbytes__{quant_type}" in state_dict, "Quant state should be saved"
+
+ # Verify parametrization internals are NOT saved (clean state dict)
+ assert "parametrizations.expert_weights.original" not in state_dict, (
+ "Internal parametrization keys should not be saved"
+ )
+
+ # Test that the parameter can be accessed after state dict creation
+ reconstructed = module.expert_weights
+ assert reconstructed.shape == (8, 512, 256), "Shape should be preserved"
+ assert reconstructed.dtype == dtype, "dtype should match"
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+def test_moe_realistic_forward(device, dtype):
+ """Test realistic MoE forward computation with quantized expert weights."""
+ if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+ pytest.skip("Configuration not supported on HPU.")
+
+ class SimpleMoE(nn.Module):
+ def __init__(self, device, dtype):
+ super().__init__()
+ # Expert weights: [num_experts, input_dim, output_dim]
+ self.expert_weights = nn.Parameter(torch.randn(4, 32, 64, dtype=dtype, device=device))
+
+ def forward(self, x, expert_idx=0):
+ # Select and use specific expert weight matrix
+ expert_weight = self.expert_weights[expert_idx] # Shape: [input_dim, output_dim]
+ return torch.matmul(x, expert_weight)
+
+ module = SimpleMoE(device=device, dtype=dtype)
+ x = torch.randn(8, 32, dtype=dtype, device=device)
+
+ # Get reference output before quantization
+ with torch.no_grad():
+ reference_output = module(x, expert_idx=1)
+
+ # Apply 4-bit quantization to expert weights
+ replace_parameter_4bit(module, "expert_weights", quant_type="nf4")
+
+ # Get output after quantization - should be very close to original
+ with torch.no_grad():
+ quantized_output = module(x, expert_idx=1)
+
+ # Verify outputs match within quantization tolerance
+ assert quantized_output.shape == reference_output.shape, "Output shape should be preserved"
+
+ # Calculate error like functional tests (matrix ops may amplify quantization errors)
+ err = (reference_output - quantized_output).abs().float()
+ relerr = (err / (reference_output.abs().float() + 1e-8)).mean()
+ err_mean = err.mean()
+
+ # Allow for error amplification through matrix multiplication
+ assert err_mean < 0.5, f"Forward pass mean abs error {err_mean:.6f} too high"
+ assert relerr < 2.0, f"Forward pass mean rel error {relerr:.6f} too high"
+
+
+def test_error_conditions():
+ """Test that proper errors are raised for invalid inputs."""
+ module = ParametrizeTestModule()
+
+ # Test AttributeError for non-existent parameter
+ with pytest.raises(AttributeError, match="Module does not have parameter 'nonexistent'"):
+ replace_parameter_4bit(module, "nonexistent")
+
+ # Test TypeError for non-Parameter attribute
+ with pytest.raises(TypeError, match="Parameter 'not_param' is not an instance of nn.Parameter"):
+ replace_parameter_4bit(module, "not_param")
+
+ # Test same errors for prequantized version
+ with pytest.raises(AttributeError, match="Module does not have parameter 'nonexistent'"):
+ replace_parameter_4bit_prequantized(module, "nonexistent", {}, torch.device("cpu"))
+
+ with pytest.raises(TypeError, match="Parameter 'not_param' is not an instance of nn.Parameter"):
+ replace_parameter_4bit_prequantized(module, "not_param", {}, torch.device("cpu"))
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.skipif(torch.__version__ < (2, 5), reason="state dict hook requires torch >= 2.5.0")
+def test_quant_state_preservation(device, dtype):
+ """Test that quantization state is properly preserved and accessible."""
+ if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+ pytest.skip("Configuration not supported on HPU.")
+
+ module = ParametrizeTestModule(device=device, dtype=dtype)
+
+ blocksize = 128 if HIP_ENVIRONMENT else 64
+
+ # Apply parametrization with specific settings
+ replace_parameter_4bit(module, "weight_2d", quant_type="nf4", compress_statistics=True, blocksize=blocksize)
+
+ # Verify that quantization state is accessible through parametrization
+ parametrization = module.parametrizations.weight_2d[0]
+ assert isinstance(parametrization, Bnb4bitParametrization), "Should be Bnb4bitParametrization instance"
+
+ # Check quantization state properties
+ quant_state = parametrization.quant_state
+ assert isinstance(quant_state, F.QuantState), "Should have QuantState"
+ assert quant_state.quant_type == "nf4", "Quant type should be preserved"
+ assert quant_state.blocksize == blocksize, "Block size should be preserved"
+
+ # Verify that state dict includes all necessary quantization metadata
+ state_dict = module.state_dict()
+ quant_state_dict = quant_state.as_dict(packed=True)
+
+ for key in quant_state_dict.keys():
+ full_key = f"weight_2d.{key}"
+ assert full_key in state_dict, f"Quantization metadata '{full_key}' should be in state dict"
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.skipif(torch.__version__ < (2, 5), reason="state dict hook requires torch >= 2.5.0")
+def test_multiple_parameters(device, dtype):
+ """Test applying parametrization to multiple parameters in the same module."""
+ if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+ pytest.skip("Configuration not supported on HPU.")
+
+ module = ParametrizeTestModule(device=device, dtype=dtype)
+ original_2d = module.weight_2d.clone()
+ original_3d = module.expert_weights.clone()
+
+ # Apply parametrization to multiple parameters, with varying configurations
+ replace_parameter_4bit(module, "weight_2d", quant_type="nf4", blocksize=128)
+ replace_parameter_4bit(module, "expert_weights", quant_type="fp4", blocksize=256)
+
+ # Verify both parameters are parametrized and work correctly
+ reconstructed_2d = module.weight_2d
+ reconstructed_3d = module.expert_weights
+
+ assert reconstructed_2d.shape == original_2d.shape, "2D parameter shape should be preserved"
+ assert reconstructed_3d.shape == original_3d.shape, "3D parameter shape should be preserved"
+
+ # Check that state dict includes quantization info for both parameters
+ state_dict = module.state_dict()
+ assert "weight_2d" in state_dict, "2D parameter should be in state dict"
+ assert "expert_weights" in state_dict, "3D parameter should be in state dict"
+ assert "weight_2d.absmax" in state_dict, "2D parameter quantization metadata should be saved"
+ assert "expert_weights.absmax" in state_dict, "3D parameter quantization metadata should be saved"
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+@pytest.mark.parametrize(
+ "blocksize",
+ [64, 128, 256] if not HIP_ENVIRONMENT else [128, 256],
+)
+def test_different_blocksizes(device, dtype, blocksize):
+ """Test parametrization with different block sizes to verify flexibility."""
+ if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+ pytest.skip("Configuration not supported on HPU.")
+
+ module = ParametrizeTestModule(device=device, dtype=dtype)
+ original_param = module.expert_weights.clone()
+
+ # Apply parametrization with specified block size
+ replace_parameter_4bit(module, "expert_weights", quant_type="nf4", blocksize=blocksize)
+
+ # Verify reconstruction works with different block sizes
+ reconstructed = module.expert_weights
+ assert reconstructed.shape == original_param.shape, "Shape should be preserved"
+ assert reconstructed.device.type == device, "Device should match"
+
+ # Verify quantization quality using error calculation approach from functional tests
+ err = (original_param - reconstructed.detach()).abs().float()
+ relerr = (err / (original_param.abs().float() + 1e-8)).mean()
+ err_mean = err.mean()
+
+ # Expected error bounds from functional tests (using NF4 bounds since that's what we're testing)
+ expected_abs = {64: 0.072792, 128: 0.076835, 256: 0.080326}
+ expected_rel = {64: 0.203299, 128: 0.215252, 256: 0.226044}
+
+ assert err_mean < expected_abs[blocksize] + 0.01, (
+ f"Mean abs error {err_mean:.6f} too high for blocksize {blocksize}"
+ )
+ assert relerr < expected_rel[blocksize] + 0.02, f"Mean rel error {relerr:.6f} too high for blocksize {blocksize}"
+
+
+def test_parametrization_forward_method():
+ """Test the Bnb4bitParametrization forward method directly."""
+ device = "cpu"
+
+ # Create test tensor and manually quantize it
+ original_tensor = torch.randn(64, 32, dtype=torch.float32, device=device)
+ quantized_data, quant_state = F.quantize_4bit(original_tensor, quant_type="nf4")
+
+ # Create parametrization instance
+ parametrization = Bnb4bitParametrization(quant_state)
+
+ # Test forward pass (dequantization)
+ dequantized = parametrization.forward(quantized_data)
+
+ # Verify dequantization produces correct output
+ assert dequantized.shape == original_tensor.shape, "Shape should be preserved during dequantization"
+ assert dequantized.dtype == torch.float32, "dtype should be preserved"
+ assert dequantized.device == original_tensor.device, "Device should be preserved"
+
+ # Check that dequantization approximates original using mean error calculation
+ err = (original_tensor - dequantized.detach()).abs().float()
+ relerr = (err / (original_tensor.abs().float() + 1e-8)).mean()
+ err_mean = err.mean()
+
+ # Use NF4 bounds from functional tests with small margin
+ assert err_mean < 0.08, f"Mean abs error {err_mean:.6f} too high"
+ assert relerr < 0.25, f"Mean rel error {relerr:.6f} too high"
+
+
+@pytest.mark.parametrize("device", get_available_devices())
+@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.bfloat16], ids=describe_dtype)
+def test_gradient_behavior(device, dtype):
+ """Test that quantized parameters have proper gradient behavior."""
+ if device == "hpu" and not is_supported_on_hpu("nf4", dtype):
+ pytest.skip("Configuration not supported on HPU.")
+
+ module = ParametrizeTestModule(device=device, dtype=dtype)
+
+ # Ensure original parameter requires gradients
+ module.weight_2d.requires_grad_(True)
+ assert module.weight_2d.requires_grad, "Original parameter should require gradients"
+
+ # Apply quantization parametrization
+ replace_parameter_4bit(module, "weight_2d", quant_type="nf4")
+
+ # Verify that quantized parameters don't require gradients (expected behavior)
+ # The underlying quantized parameter should have requires_grad=False
+ # The dequantized output should also not require gradients
+ reconstructed = module.weight_2d
+ assert not reconstructed.requires_grad, "Dequantized parameter should not require gradients"