6 changes: 3 additions & 3 deletions bitsandbytes/autograd/_functions.py
@@ -7,8 +7,8 @@

 import torch

-from bitsandbytes.cextension import BNB_HIP_VERSION
 import bitsandbytes.functional as F
+from bitsandbytes.gpu_specs import get_compute_capabilities


 # math.prod not compatible with python < 3.8
@@ -224,8 +224,8 @@ def supports_igemmlt(device: torch.device) -> bool:
     if device == torch.device("cpu"):
         return True
     if torch.version.hip:
-        return False if BNB_HIP_VERSION < 601 else True
-    if torch.cuda.get_device_capability(device=device) < (7, 5):
+        return False if get_compute_capabilities() < (6, 1) else True
+    if get_compute_capabilities() < (7, 5):
         return False
     device_name = torch.cuda.get_device_name(device=device)
     nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660")  # https://en.wikipedia.org/wiki/GeForce_16_series
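For orientation, a sketch of how this predicate is typically consumed; the import path matches this file, but the snippet itself is illustrative and not part of the diff:

    import torch

    from bitsandbytes.autograd._functions import supports_igemmlt

    # True on Turing (CC >= 7.5) and newer, except the GTX 16xx parts listed
    # above; on HIP this build requires get_compute_capabilities() >= (6, 1).
    use_int8_kernels = supports_igemmlt(torch.device("cuda:0"))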
2 changes: 1 addition & 1 deletion bitsandbytes/backends/cuda.py
@@ -24,7 +24,7 @@

 from .base import Backend

-if lib and lib.compiled_with_cuda:
+if lib and lib.compiled_with_gpu:

[Review comment] These are only for CUDA and ROCm as of now; maybe use BNB_BACKEND from cextension instead?
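A minimal sketch of what that could look like, assuming BNB_BACKEND ("CUDA" or "ROCm") stays exported from bitsandbytes.cextension (this PR currently removes those assignments in cextension.py):

    from bitsandbytes.cextension import BNB_BACKEND, lib

    # Hypothetical guard keyed on the detected backend name rather than a
    # flag attribute on the native library wrapper.
    if lib is not None and BNB_BACKEND in ("CUDA", "ROCm"):
        ...  # register the CUDA/ROCm optimizer tables here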

"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
"adam": (
51 changes: 23 additions & 28 deletions bitsandbytes/cextension.py
@@ -24,29 +24,28 @@
 import torch

 from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR
-from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch
+from bitsandbytes.gpu_specs import GPUSpecs, get_gpu_specs, get_rocm_gpu_arch

 logger = logging.getLogger(__name__)


-def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:
+def get_gpu_bnb_library_path(gpu_specs: GPUSpecs) -> Path:
     """
-    Get the disk path to the CUDA BNB native library specified by the
-    given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable.
+    Get the disk path to the GPU BNB native library specified by the
+    given GPU specs, taking into account the `BNB_GPU_VERSION` override environment variable.

[Review comment] Leave the env var as BNB_CUDA_VERSION and mention that it is 'for cuda'.
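A possible docstring wording along the reviewer's lines (a sketch only; the final PR text may differ):

    Get the disk path to the GPU BNB native library specified by the
    given GPU specs, taking into account the `BNB_CUDA_VERSION` override
    environment variable (used for the CUDA backend only).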


     The library is not guaranteed to exist at the returned path.
     """
-    if torch.version.hip:
-        if BNB_HIP_VERSION < 601:
-            return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}_nohipblaslt{DYNAMIC_LIBRARY_SUFFIX}"
-        else:
-            return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}{DYNAMIC_LIBRARY_SUFFIX}"
-    library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}"
-    if not cuda_specs.has_cublaslt:
+    library_name = f"libbitsandbytes_{gpu_specs.gpu_backend}{gpu_specs.backend_version_string}"
+    if not gpu_specs.enable_blaslt:
         # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt
-        library_name += "_nocublaslt"
+        if gpu_specs.gpu_backend == "rocm":
+            library_name += "_nohipblaslt"
+        else:
+            library_name += "_nocublaslt"
     library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}"

+    # Do I need to change it to BNB_GPU_VERSION here? IGNORE FOR NOW!

[Review comment] Please remove this comment.
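For reference, the naming logic above composes the backend name, the backend version string, and an optional no-BLASLt suffix; it would produce file names like the following (version strings here are illustrative):

    libbitsandbytes_cuda121.so             # CUDA 12.1, compute capability >= 7.5
    libbitsandbytes_cuda121_nocublaslt.so  # CUDA 12.1, compute capability < 7.5
    libbitsandbytes_rocm61_nohipblaslt.so  # ROCm 6.1 without hipBLASLt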

     override_value = os.environ.get("BNB_CUDA_VERSION")
     if override_value:

[Review comment] Add a condition for the CUDA backend here.
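A minimal sketch of that condition, reusing the gpu_backend field this PR adds to GPUSpecs (placement and naming are illustrative):

    # Hypothetical: only honor the CUDA-specific override on the CUDA backend.
    override_value = None
    if gpu_specs.gpu_backend == "cuda":
        override_value = os.environ.get("BNB_CUDA_VERSION")
    if override_value:
        ...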

         library_name_stem, _, library_name_ext = library_name.rpartition(".")
@@ -69,7 +68,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path:

 class BNBNativeLibrary:
     _lib: ct.CDLL
-    compiled_with_cuda = False
+    compiled_with_gpu = False

     def __init__(self, lib: ct.CDLL):
         self._lib = lib
@@ -78,8 +77,8 @@ def __getattr__(self, item):
         return getattr(self._lib, item)


-class CudaBNBNativeLibrary(BNBNativeLibrary):
-    compiled_with_cuda = True
+class GpuBNBNativeLibrary(BNBNativeLibrary):
+    compiled_with_gpu = True

     def __init__(self, lib: ct.CDLL):
         super().__init__(lib)
@@ -93,18 +92,18 @@ def __init__(self, lib: ct.CDLL):

 def get_native_library() -> BNBNativeLibrary:
     binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}"
-    cuda_specs = get_cuda_specs()
-    if cuda_specs:
-        cuda_binary_path = get_cuda_bnb_library_path(cuda_specs)
-        if cuda_binary_path.exists():
-            binary_path = cuda_binary_path
+    gpu_specs = get_gpu_specs()
+    if gpu_specs:
+        gpu_binary_path = get_gpu_bnb_library_path(gpu_specs)
+        if gpu_binary_path.exists():
+            binary_path = gpu_binary_path
         else:
-            logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path)
+            logger.warning("Could not find the bitsandbytes %s binary at %r", gpu_specs.gpu_backend, gpu_binary_path)
     logger.debug(f"Loading bitsandbytes native library from: {binary_path}")
     dll = ct.cdll.LoadLibrary(str(binary_path))

     if hasattr(dll, "get_context"):  # only a CUDA-built library exposes this
-        return CudaBNBNativeLibrary(dll)
+        return GpuBNBNativeLibrary(dll)

     return BNBNativeLibrary(dll)

@@ -113,15 +112,11 @@ def get_native_library() -> BNBNativeLibrary:

 try:
     if torch.version.hip:
-        hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2])
-        HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor
-        BNB_HIP_VERSION_SHORT = f"{hip_major}{hip_minor}"
-        BNB_BACKEND = "ROCm"
+        HIP_ENVIRONMENT = True
     else:
-        HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0
-        BNB_HIP_VERSION_SHORT = ""
-        BNB_BACKEND = "CUDA"
+        HIP_ENVIRONMENT = False

     lib = get_native_library()
 except Exception as e:
     lib = None
bitsandbytes/diagnostics/{cuda.py → gpu.py}

@@ -5,14 +5,14 @@

 import torch

-from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path
+from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path
 from bitsandbytes.consts import NONPYTORCH_DOC_URL
-from bitsandbytes.cuda_specs import CUDASpecs
 from bitsandbytes.diagnostics.utils import print_dedented
+from bitsandbytes.gpu_specs import GPUSpecs

-CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH")
+GPU_RT_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH")

-CUDART_PATH_IGNORED_ENVVARS = {
+GPU_RT_PATH_IGNORED_ENVVARS = {
     "DBUS_SESSION_BUS_ADDRESS",  # hardware related
     "GOOGLE_VM_CONFIG_LOCK_FILE",  # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks
     "HOME",  # Linux shell default
@@ -46,7 +46,7 @@ def get_runtime_lib_patterns() -> tuple:
     )


-def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
+def find_gpu_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]:
     for dir_string in paths_list_candidate.split(os.pathsep):
         if not dir_string:
             continue
@@ -70,10 +70,10 @@ def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path

 def is_relevant_candidate_env_var(env_var: str, value: str) -> bool:
     return (
-        env_var in CUDART_PATH_PREFERRED_ENVVARS  # is a preferred location
+        env_var in GPU_RT_PATH_PREFERRED_ENVVARS  # is a preferred location
         or (
             os.sep in value  # might contain a path
-            and env_var not in CUDART_PATH_IGNORED_ENVVARS  # not ignored
+            and env_var not in GPU_RT_PATH_IGNORED_ENVVARS  # not ignored
             and "CONDA" not in env_var  # not another conda envvar
             and "BASH_FUNC" not in env_var  # not a bash function defined via envvar
             and "\n" not in value  # likely e.g. a script or something?
@@ -85,7 +85,7 @@ def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]:
     return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)}


-def find_cudart_libraries() -> Iterator[Path]:
+def find_gpu_rt_libraries() -> Iterator[Path]:
     """
     Searches for a cuda installations, in the following order of priority:

[Review comment] Please make this docstring generic for the GPU backend.
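A generic wording for that docstring, as requested (a sketch only):

    Searches for GPU runtime (CUDA or ROCm) installations, in the following
    order of priority: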

     1. active conda env
@@ -99,23 +99,23 @@
     """
     candidate_env_vars = get_potentially_lib_path_containing_env_vars()

-    for envvar in CUDART_PATH_PREFERRED_ENVVARS:
+    for envvar in GPU_RT_PATH_PREFERRED_ENVVARS:
         if envvar in candidate_env_vars:
             directory = candidate_env_vars[envvar]
-            yield from find_cuda_libraries_in_path_list(directory)
+            yield from find_gpu_libraries_in_path_list(directory)
             candidate_env_vars.pop(envvar)

     for env_var, value in candidate_env_vars.items():
-        yield from find_cuda_libraries_in_path_list(value)
+        yield from find_gpu_libraries_in_path_list(value)


-def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
+def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None:
     print(
-        f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, "
-        f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.",
+        f"PyTorch settings found: CUDA_VERSION={gpu_specs.backend_version_string}, "
+        f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.",
     )

-    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
     if not binary_path.exists():
         print_dedented(
             f"""
@@ -128,7 +128,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )

-    cuda_major, cuda_minor = cuda_specs.cuda_version_tuple
+    cuda_major, cuda_minor = gpu_specs.backend_version_tuple
     if cuda_major < 11:
         print_dedented(
             """
@@ -140,7 +140,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}")

     # 7.5 is the minimum CC for cublaslt
-    if not cuda_specs.has_cublaslt:
+    if not gpu_specs.enable_blaslt:
         print_dedented(
             """
             WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU!
@@ -154,10 +154,10 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None:
     # (2) Multiple CUDA versions installed


-def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
-    print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}")
+def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None:
+    print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.backend_version_string}")

-    binary_path = get_cuda_bnb_library_path(cuda_specs)
+    binary_path = get_gpu_bnb_library_path(gpu_specs)
     if not binary_path.exists():
         print_dedented(
             f"""
@@ -168,7 +168,7 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
             """,
         )

-    hip_major, hip_minor = cuda_specs.cuda_version_tuple
+    hip_major, hip_minor = gpu_specs.backend_version_tuple
     if (hip_major, hip_minor) < (6, 1):
         print_dedented(
             """
@@ -177,18 +177,18 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None:
         )


-def print_diagnostics(cuda_specs: CUDASpecs) -> None:
+def print_diagnostics(gpu_specs: GPUSpecs) -> None:
     if HIP_ENVIRONMENT:
-        _print_hip_diagnostics(cuda_specs)
+        _print_hip_diagnostics(gpu_specs)
     else:
-        _print_cuda_diagnostics(cuda_specs)
+        _print_cuda_diagnostics(gpu_specs)


 def _print_cuda_runtime_diagnostics() -> None:
-    cudart_paths = list(find_cudart_libraries())
-    if not cudart_paths:
+    gpu_rt_paths = list(find_gpu_rt_libraries())
+    if not gpu_rt_paths:
         print("WARNING! CUDA runtime files not found in any environmental path.")
-    elif len(cudart_paths) > 1:
+    elif len(gpu_rt_paths) > 1:
         print_dedented(
             f"""
             Found duplicate CUDA runtime files (see below).
@@ -207,15 +207,15 @@ def _print_cuda_runtime_diagnostics() -> None:
             export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2,
             """,
         )
-    for pth in cudart_paths:
+    for pth in gpu_rt_paths:
         print(f"* Found CUDA runtime at: {pth}")


 def _print_hip_runtime_diagnostics() -> None:
-    cudart_paths = list(find_cudart_libraries())
-    if not cudart_paths:
+    gpu_rt_paths = list(find_gpu_rt_libraries())
+    if not gpu_rt_paths:
         print("WARNING! ROCm runtime files not found in any environmental path.")
-    elif len(cudart_paths) > 1:
+    elif len(gpu_rt_paths) > 1:
         print_dedented(
             f"""
             Found duplicate ROCm runtime files (see below).
@@ -230,7 +230,7 @@ def _print_hip_runtime_diagnostics() -> None:
             """,
         )

-    for pth in cudart_paths:
+    for pth in gpu_rt_paths:
         print(f"* Found ROCm runtime at: {pth}")
16 changes: 8 additions & 8 deletions bitsandbytes/diagnostics/main.py
@@ -5,12 +5,12 @@

 from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT
 from bitsandbytes.consts import PACKAGE_GITHUB_URL
-from bitsandbytes.cuda_specs import get_cuda_specs
-from bitsandbytes.diagnostics.cuda import (
+from bitsandbytes.diagnostics.gpu import (
     print_diagnostics,
     print_runtime_diagnostics,
 )
 from bitsandbytes.diagnostics.utils import print_dedented, print_header
+from bitsandbytes.gpu_specs import get_gpu_specs


 def sanity_check():
@@ -50,20 +50,20 @@ def main():
     print_header("")

     print_header("OTHER")
-    cuda_specs = get_cuda_specs()
+    gpu_specs = get_gpu_specs()
     if HIP_ENVIRONMENT:

[Review comment] Refactor the code to print the GPU specs directly.
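One reading of that suggestion (a sketch; it assumes GPUSpecs' own repr is informative for both backends, so the ROCm-specific string building can go away):

    print_header("OTHER")
    gpu_specs = get_gpu_specs()
    # Hypothetical refactor: let GPUSpecs print itself for either backend.
    print(f"{BNB_BACKEND} specs:{gpu_specs}")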

rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}',"
rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}"
rocm_specs = f" rocm_version_string='{gpu_specs.backend_version_string}',"
rocm_specs += f" rocm_version_tuple={gpu_specs.backend_version_tuple}"
print(f"{BNB_BACKEND} specs:{rocm_specs}")
else:
print(f"{BNB_BACKEND} specs:{cuda_specs}")
print(f"{BNB_BACKEND} specs:{gpu_specs}")
if not torch.cuda.is_available():
print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:")
print(f"1. {BNB_BACKEND} driver not installed")
print(f"2. {BNB_BACKEND} not installed")
print(f"3. You have multiple conflicting {BNB_BACKEND} libraries")
if cuda_specs:
print_diagnostics(cuda_specs)
if gpu_specs:
print_diagnostics(gpu_specs)
print_runtime_diagnostics()
print_header("")
print_header("DEBUG INFO END")
Expand Down
2 changes: 1 addition & 1 deletion bitsandbytes/functional.py
@@ -25,7 +25,7 @@ def prod(iterable):

 name2qmap = {}

-if lib and lib.compiled_with_cuda:
+if lib and lib.compiled_with_gpu:

[Review comment] Change this to check BNB_BACKEND for cuda and rocm.
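As with backends/cuda.py above, one reading is an explicit backend check; a sketch assuming BNB_BACKEND remains importable from cextension:

    from bitsandbytes.cextension import BNB_BACKEND, lib

    # Hypothetical: bind the optimizer tables only for the CUDA and ROCm backends.
    if lib is not None and BNB_BACKEND in ("CUDA", "ROCm"):
        ...  # define str2optimizer32bit and friends here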

"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
"adam": (