5 changes: 5 additions & 0 deletions bitsandbytes/__init__.py
@@ -4,6 +4,7 @@
# LICENSE file in the root directory of this source tree.


import logging
import importlib
import sys

@@ -20,6 +21,10 @@
from .nn import modules
from .optim import adam

# Library logging should be opt-in for downstream users.
# (No handlers are configured by default; CLI entrypoints may configure logging.)
logging.getLogger(__name__).addHandler(logging.NullHandler())

# This is a signal for integrations with transformers/diffusers.
# Eventually we may remove this but it is currently required for compatibility.
features = {"multi_backend"}
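Note: because of the NullHandler added above, bitsandbytes emits no log output unless the embedding application opts in. A minimal sketch of how a downstream user might surface these logs (the handler and level choices are illustrative, not part of this change):

    import logging

    # Simplest opt-in: configure the root logger; bitsandbytes records propagate to it.
    logging.basicConfig(level=logging.INFO)

    # Or configure only the library's logger, leaving the root logger untouched.
    bnb_logger = logging.getLogger("bitsandbytes")
    bnb_logger.setLevel(logging.DEBUG)
    bnb_logger.addHandler(logging.StreamHandler())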
17 changes: 9 additions & 8 deletions bitsandbytes/diagnostics/cuda.py
@@ -108,9 +108,10 @@ def find_cudart_libraries() -> Iterator[Path]:


def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
print(
f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
logger.info(
"PyTorch settings found: CUDA_VERSION=%s, Highest Compute Capability: %s.",
cuda_specs.cuda_version_string,
cuda_specs.highest_compute_capability,
)

binary_path = get_cuda_bnb_library_path(cuda_specs)
@@ -133,7 +134,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:


def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")
logger.info("PyTorch settings found: ROCM_VERSION=%s", cuda_specs.cuda_version_string)

binary_path = get_cuda_bnb_library_path(cuda_specs)
if not binary_path.exists():
@@ -165,7 +166,7 @@ def print_diagnostics(cuda_specs: CUDASpecs) -> None:
def _print_cuda_runtime_diagnostics() -> None:
cudart_paths = list(find_cudart_libraries())
if not cudart_paths:
print("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.")
logger.warning("CUDA SETUP: WARNING! CUDA runtime files not found in any environmental path.")
elif len(cudart_paths) > 1:
print_dedented(
f"""
@@ -186,13 +187,13 @@ def _print_cuda_runtime_diagnostics() -> None:
""",
)
for pth in cudart_paths:
print(f"* Found CUDA runtime at: {pth}")
logger.info("* Found CUDA runtime at: %s", pth)


def _print_hip_runtime_diagnostics() -> None:
cudart_paths = list(find_cudart_libraries())
if not cudart_paths:
print("WARNING! ROCm runtime files not found in any environmental path.")
logger.warning("WARNING! ROCm runtime files not found in any environmental path.")
elif len(cudart_paths) > 1:
print_dedented(
f"""
@@ -209,7 +210,7 @@ def _print_hip_runtime_diagnostics() -> None:
)

for pth in cudart_paths:
print(f"* Found ROCm runtime at: {pth}")
logger.info("* Found ROCm runtime at: %s", pth)


def print_runtime_diagnostics() -> None:
49 changes: 30 additions & 19 deletions bitsandbytes/diagnostics/main.py
@@ -1,4 +1,6 @@
import importlib
import logging
import os
import platform
import sys
import traceback
@@ -26,6 +28,8 @@
"trl",
]

logger = logging.getLogger(__name__)


def sanity_check():
from bitsandbytes.optim import Adam
@@ -53,24 +57,30 @@ def get_package_version(name: str) -> str:
def show_environment():
"""Simple utility to print out environment information."""

print(f"Platform: {platform.platform()}")
logger.info("Platform: %s", platform.platform())
if platform.system() == "Linux":
print(f" libc: {'-'.join(platform.libc_ver())}")
logger.info(" libc: %s", "-".join(platform.libc_ver()))

print(f"Python: {platform.python_version()}")
logger.info("Python: %s", platform.python_version())

print(f"PyTorch: {torch.__version__}")
print(f" CUDA: {torch.version.cuda or 'N/A'}")
print(f" HIP: {torch.version.hip or 'N/A'}")
print(f" XPU: {getattr(torch.version, 'xpu', 'N/A') or 'N/A'}")
logger.info("PyTorch: %s", torch.__version__)
logger.info(" CUDA: %s", torch.version.cuda or "N/A")
logger.info(" HIP: %s", torch.version.hip or "N/A")
logger.info(" XPU: %s", getattr(torch.version, "xpu", "N/A") or "N/A")

print("Related packages:")
logger.info("Related packages:")
for pkg in _RELATED_PACKAGES:
version = get_package_version(pkg)
print(f" {pkg}: {version}")
logger.info(" %s: %s", pkg, version)


def main():
# bitsandbytes' CLI entrypoint: configure logging for human-readable output.
# Library imports do not configure logging; downstream apps should decide.
level_name = os.environ.get("BNB_LOG_LEVEL", "INFO").upper()
level = getattr(logging, level_name, logging.INFO)
logging.basicConfig(level=level, format="%(message)s")

print_header(f"bitsandbytes v{bnb_version}")
show_environment()
print_header("")
@@ -84,29 +94,30 @@ def main():
# print_cuda_runtime_diagnostics()

if not torch.cuda.is_available():
print(f"PyTorch says {BNB_BACKEND} is not available. Possible reasons:")
print(f"1. {BNB_BACKEND} driver not installed")
print("2. Using a CPU-only PyTorch build")
print("3. No GPU detected")
logger.warning("PyTorch says %s is not available. Possible reasons:", BNB_BACKEND)
logger.warning("1. %s driver not installed", BNB_BACKEND)
logger.warning("2. Using a CPU-only PyTorch build")
logger.warning("3. No GPU detected")

else:
print(f"Checking that the library is importable and {BNB_BACKEND} is callable...")
logger.info("Checking that the library is importable and %s is callable...", BNB_BACKEND)

try:
sanity_check()
print("SUCCESS!")
logger.info("SUCCESS!")
return
except RuntimeError as e:
if "not available in CPU-only" in str(e):
print(
f"WARNING: {__package__} is currently running as CPU-only!\n"
logger.warning(
"WARNING: %s is currently running as CPU-only!\n"
"Therefore, 8-bit optimizers and GPU quantization are unavailable.\n\n"
f"If you think that this is so erroneously,\nplease report an issue!",
"If you think that this is so erroneously,\nplease report an issue!",
__package__,
)
else:
raise e
except Exception:
traceback.print_exc()
logger.exception("Diagnostics sanity check failed:")

print_dedented(
f"""
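Note: with the BNB_LOG_LEVEL handling above, the verbosity of the diagnostics output is controlled from the environment, and an unrecognized level name silently falls back to INFO. A usage sketch, assuming the diagnostics entrypoint is exposed as python -m bitsandbytes (not shown in this diff):

    import os
    import subprocess
    import sys

    # Roughly equivalent to running:  BNB_LOG_LEVEL=DEBUG python -m bitsandbytes
    env = dict(os.environ, BNB_LOG_LEVEL="DEBUG")
    subprocess.run([sys.executable, "-m", "bitsandbytes"], env=env, check=False)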
7 changes: 5 additions & 2 deletions bitsandbytes/diagnostics/utils.py
@@ -1,12 +1,15 @@
import logging
import textwrap

HEADER_WIDTH = 60

logger = logging.getLogger(__name__)


def print_header(txt: str, width: int = HEADER_WIDTH, filler: str = "=") -> None:
txt = f" {txt} " if txt else ""
print(txt.center(width, filler))
logger.info(txt.center(width, filler))


def print_dedented(text):
print("\n".join(textwrap.dedent(text).strip().split("\n")))
logger.info("\n".join(textwrap.dedent(text).strip().split("\n")))
8 changes: 6 additions & 2 deletions bitsandbytes/nn/modules.py
@@ -3,6 +3,7 @@
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import copy
import logging
from typing import Any, Optional, TypeVar, Union, overload
import warnings

@@ -23,6 +24,8 @@

T = TypeVar("T", bound="torch.nn.Module")

logger = logging.getLogger(__name__)


class StableEmbedding(torch.nn.Embedding):
"""
@@ -1115,9 +1118,10 @@ def forward(self, x):
if self.outlier_dim is None:
tracer = OutlierTracer.get_instance()
if not tracer.is_initialized():
print("Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer")
logger.warning(
"Please use OutlierTracer.initialize(model) before using the OutlierAwareLinear layer",
)
outlier_idx = tracer.get_outliers(self.weight)
# print(outlier_idx, tracer.get_hvalue(self.weight))
self.outlier_dim = outlier_idx

if not self.is_quantized:
9 changes: 6 additions & 3 deletions bitsandbytes/nn/triton_based_modules.py
@@ -1,4 +1,5 @@
from functools import partial
import logging

import torch
import torch.nn as nn
@@ -20,6 +21,8 @@
from bitsandbytes.triton.quantize_rowwise import quantize_rowwise
from bitsandbytes.triton.triton_utils import is_triton_available

logger = logging.getLogger(__name__)


class _switchback_global(torch.autograd.Function):
@staticmethod
@@ -173,8 +176,8 @@ def __init__(
if self.vector_wise_quantization:
self._fn = _switchback_vectorrize
if mem_efficient:
print("mem efficient is not supported for vector-wise quantization.")
exit(1)
logger.error("mem efficient is not supported for vector-wise quantization.")
raise ValueError("mem_efficient is not supported for vector-wise quantization.")
else:
if mem_efficient:
self._fn = _switchback_global_mem_efficient
@@ -189,7 +192,7 @@ def prepare_for_eval(self):
# if hasattr(m, "prepare_for_eval"):
# m.prepare_for_eval()
# model.apply(cond_prepare)
print("=> preparing for eval.")
logger.info("Preparing SwitchBackLinear for eval.")
if self.vector_wise_quantization:
W_int8, state_W = quantize_rowwise(self.weight)
else:
3 changes: 0 additions & 3 deletions bitsandbytes/research/autograd/_functions.py
@@ -234,7 +234,6 @@ def forward(ctx, A, B, out=None, bias=None, state: Optional[MatmulLtState] = Non

# 2. Quantize B
if state.has_fp16_weights:
# print('B shape', B.shape)
has_grad = getattr(B, "grad", None) is not None
is_transposed = not B.is_contiguous() and B.shape[0] == B.stride(1)
if is_transposed:
@@ -323,8 +322,6 @@ def backward(ctx, grad_output):
_Cgrad, _Cgradt, _SCgrad, _SCgradt, _outlier_cols = F.int8_double_quant(grad_output.to(torch.float16))

if req_gradB:
# print('back A shape', A.shape)
# print('grad output t shape', grad_output.t().shape)
grad_B = torch.matmul(grad_output.t(), A)

if req_gradA:
14 changes: 10 additions & 4 deletions bitsandbytes/triton/matmul_perf_model.py
@@ -3,6 +3,7 @@

import functools
import heapq
import logging

import torch

@@ -15,6 +16,8 @@
nvsmi,
)

logger = logging.getLogger(__name__)


@functools.lru_cache
def get_clock_rate_in_khz():
@@ -125,10 +128,13 @@ def estimate_matmul_time(

total_time_ms = max(compute_ms, load_ms) + store_ms
if debug:
print(
f"Total time: {total_time_ms}ms, compute time: {compute_ms}ms, "
f"loading time: {load_ms}ms, store time: {store_ms}ms, "
f"Activate CTAs: {active_cta_ratio * 100}%"
logger.debug(
"Total time: %sms, compute time: %sms, loading time: %sms, store time: %sms, Activate CTAs: %s%%",
total_time_ms,
compute_ms,
load_ms,
store_ms,
active_cta_ratio * 100,
)
return total_time_ms

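Note: the conversions in this file, like the rest of the PR, pass %-style arguments instead of pre-formatted f-strings, so the message is only interpolated when a handler actually emits the record and disabled DEBUG output costs little. A small illustration using only the standard library (logger name and values are placeholders):

    import logging

    logger = logging.getLogger("bitsandbytes.triton.matmul_perf_model")
    logging.basicConfig(level=logging.INFO)  # DEBUG records are dropped at this level

    # Deferred formatting: the message string is never built because the record is not emitted.
    logger.debug("Total time: %sms, compute time: %sms", 1.23, 0.45)

    # Guard explicitly if computing the arguments is itself costly.
    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("Total time: %sms", 1.23)  # stand-in for an expensive computation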
5 changes: 4 additions & 1 deletion bitsandbytes/utils.py
@@ -1,9 +1,12 @@
import json
import logging
import shlex
import subprocess

import torch

logger = logging.getLogger(__name__)


def outlier_hook(module, input):
assert isinstance(module, torch.nn.Linear)
@@ -65,7 +68,7 @@ def get_hvalue(self, weight):

def get_outliers(self, weight):
if not self.is_initialized():
print("Outlier tracer is not initialized...")
logger.warning("Outlier tracer is not initialized...")
return None
hvalue = self.get_hvalue(weight)
if hvalue in self.hvalue2outlier_idx: