From 98be798a796cb67fbeb6f7e81b5f61ac6860acc5 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Mon, 4 Nov 2024 15:01:50 +0000 Subject: [PATCH 01/10] implement GPUSpecs --- bitsandbytes/autograd/_functions.py | 6 +- bitsandbytes/backends/cuda.py | 2 +- bitsandbytes/cextension.py | 51 ++++++++--------- bitsandbytes/cuda_specs.py | 25 +-------- bitsandbytes/functional.py | 2 +- bitsandbytes/gpu_specs.py | 86 +++++++++++++++++++++++++++++ tests/test_autograd.py | 4 +- tests/test_cuda_setup_evaluator.py | 14 ++--- tests/test_functional.py | 7 ++- 9 files changed, 128 insertions(+), 69 deletions(-) create mode 100644 bitsandbytes/gpu_specs.py diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 59e26ad09..02a7f45af 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -7,7 +7,7 @@ import torch -from bitsandbytes.cextension import BNB_HIP_VERSION +from bitsandbytes.gpu_specs import get_compute_capabilities import bitsandbytes.functional as F @@ -224,8 +224,8 @@ def supports_igemmlt(device: torch.device) -> bool: if device == torch.device("cpu"): return True if torch.version.hip: - return False if BNB_HIP_VERSION < 601 else True - if torch.cuda.get_device_capability(device=device) < (7, 5): + return False if get_compute_capabilities() < 601 else True + if get_compute_capabilities() < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) nvidia16_models = ("GTX 1630", "GTX 1650", "GTX 1660") # https://en.wikipedia.org/wiki/GeForce_16_series diff --git a/bitsandbytes/backends/cuda.py b/bitsandbytes/backends/cuda.py index ad478431c..53edc94ad 100644 --- a/bitsandbytes/backends/cuda.py +++ b/bitsandbytes/backends/cuda.py @@ -24,7 +24,7 @@ from .base import Backend -if lib and lib.compiled_with_cuda: +if lib and lib.compiled_with_gpu: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = { "adam": ( diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index cc5d8deff..532f6970b 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -24,29 +24,28 @@ import torch from bitsandbytes.consts import DYNAMIC_LIBRARY_SUFFIX, PACKAGE_DIR -from bitsandbytes.cuda_specs import CUDASpecs, get_cuda_specs, get_rocm_gpu_arch +from bitsandbytes.gpu_specs import GPUSpecs, get_gpu_specs, get_rocm_gpu_arch logger = logging.getLogger(__name__) -def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: +def get_gpu_bnb_library_path(gpu_specs: GPUSpecs) -> Path: """ - Get the disk path to the CUDA BNB native library specified by the - given CUDA specs, taking into account the `BNB_CUDA_VERSION` override environment variable. + Get the disk path to the GPU BNB native library specified by the + given GPU specs, taking into account the `BNB_GPU_VERSION` override environment variable. The library is not guaranteed to exist at the returned path. 
""" - if torch.version.hip: - if BNB_HIP_VERSION < 601: - return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}_nohipblaslt{DYNAMIC_LIBRARY_SUFFIX}" - else: - return PACKAGE_DIR / f"libbitsandbytes_rocm{BNB_HIP_VERSION_SHORT}{DYNAMIC_LIBRARY_SUFFIX}" - library_name = f"libbitsandbytes_cuda{cuda_specs.cuda_version_string}" - if not cuda_specs.has_cublaslt: + library_name = f"libbitsandbytes_{gpu_specs.gpu_backend}{gpu_specs.backend_version_string}" + if not gpu_specs.has_blaslt: # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt - library_name += "_nocublaslt" + if gpu_specs.gpu_backend == "rocm": + library_name += "_nohipblaslt" + else: + library_name += "_nocublaslt" library_name = f"{library_name}{DYNAMIC_LIBRARY_SUFFIX}" + # Do I need to change it to BNB_GPU_VERSION here? IGNORE FOR NOW! override_value = os.environ.get("BNB_CUDA_VERSION") if override_value: library_name_stem, _, library_name_ext = library_name.rpartition(".") @@ -69,7 +68,7 @@ def get_cuda_bnb_library_path(cuda_specs: CUDASpecs) -> Path: class BNBNativeLibrary: _lib: ct.CDLL - compiled_with_cuda = False + compiled_with_gpu = False def __init__(self, lib: ct.CDLL): self._lib = lib @@ -78,8 +77,8 @@ def __getattr__(self, item): return getattr(self._lib, item) -class CudaBNBNativeLibrary(BNBNativeLibrary): - compiled_with_cuda = True +class GpuBNBNativeLibrary(BNBNativeLibrary): + compiled_with_gpu = True def __init__(self, lib: ct.CDLL): super().__init__(lib) @@ -93,18 +92,18 @@ def __init__(self, lib: ct.CDLL): def get_native_library() -> BNBNativeLibrary: binary_path = PACKAGE_DIR / f"libbitsandbytes_cpu{DYNAMIC_LIBRARY_SUFFIX}" - cuda_specs = get_cuda_specs() - if cuda_specs: - cuda_binary_path = get_cuda_bnb_library_path(cuda_specs) - if cuda_binary_path.exists(): - binary_path = cuda_binary_path + gpu_specs = get_gpu_specs() + if gpu_specs: + gpu_binary_path = get_gpu_bnb_library_path(gpu_specs) + if gpu_binary_path.exists(): + binary_path = gpu_binary_path else: - logger.warning("Could not find the bitsandbytes %s binary at %r", BNB_BACKEND, cuda_binary_path) + logger.warning("Could not find the bitsandbytes %s binary at %r", gpu_specs.gpu_backend, gpu_binary_path) logger.debug(f"Loading bitsandbytes native library from: {binary_path}") dll = ct.cdll.LoadLibrary(str(binary_path)) if hasattr(dll, "get_context"): # only a CUDA-built library exposes this - return CudaBNBNativeLibrary(dll) + return GpuBNBNativeLibrary(dll) return BNBNativeLibrary(dll) @@ -113,15 +112,11 @@ def get_native_library() -> BNBNativeLibrary: try: if torch.version.hip: - hip_major, hip_minor = map(int, torch.version.hip.split(".")[0:2]) - HIP_ENVIRONMENT, BNB_HIP_VERSION = True, hip_major * 100 + hip_minor - BNB_HIP_VERSION_SHORT = f"{hip_major}{hip_minor}" BNB_BACKEND = "ROCm" + HIP_ENVIRONMENT = True else: - HIP_ENVIRONMENT, BNB_HIP_VERSION = False, 0 - BNB_HIP_VERSION_SHORT = "" BNB_BACKEND = "CUDA" - + HIP_ENVIRONMENT = False lib = get_native_library() except Exception as e: lib = None diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py index 0afecd3ea..d77c5a3fd 100644 --- a/bitsandbytes/cuda_specs.py +++ b/bitsandbytes/cuda_specs.py @@ -44,27 +44,4 @@ def get_cuda_specs() -> Optional[CUDASpecs]: highest_compute_capability=(get_compute_capabilities()[-1]), cuda_version_string=(get_cuda_version_string()), cuda_version_tuple=get_cuda_version_tuple(), - ) - - -def get_rocm_gpu_arch() -> str: - logger = logging.getLogger(__name__) - try: - if torch.version.hip: - result = 
subprocess.run(["rocminfo"], capture_output=True, text=True) - match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) - if match: - return "gfx" + match.group(1) - else: - return "unknown" - else: - return "unknown" - except Exception as e: - logger.error(f"Could not detect ROCm GPU architecture: {e}") - if torch.cuda.is_available(): - logger.warning( - """ -ROCm GPU architecture detection failed despite ROCm being available. - """, - ) - return "unknown" + ) \ No newline at end of file diff --git a/bitsandbytes/functional.py b/bitsandbytes/functional.py index 6cf64df28..6fa74d5aa 100644 --- a/bitsandbytes/functional.py +++ b/bitsandbytes/functional.py @@ -25,7 +25,7 @@ def prod(iterable): name2qmap = {} -if lib and lib.compiled_with_cuda: +if lib and lib.compiled_with_gpu: """C FUNCTIONS FOR OPTIMIZERS""" str2optimizer32bit = { "adam": ( diff --git a/bitsandbytes/gpu_specs.py b/bitsandbytes/gpu_specs.py new file mode 100644 index 000000000..49517f28c --- /dev/null +++ b/bitsandbytes/gpu_specs.py @@ -0,0 +1,86 @@ +import dataclasses +import logging +import re +import subprocess +from typing import List, Optional, Tuple, Union + +import torch + + +@dataclasses.dataclass(frozen=True) +class GPUSpecs: + gpu_backend: str + highest_compute_capability: Union[int, Tuple[int, int]] + backend_version_string: str + backend_version_tuple: Tuple[int, int] + + @property + def has_blaslt(self) -> bool: + if torch.version.hip: + return self.highest_compute_capability >= 601 + else: + return self.highest_compute_capability >= (7, 5) + + +def get_gpu_backend() -> str: + if torch.version.hip: + return "rocm" + else: + return "cuda" + + +def get_compute_capabilities() -> Union[int, Tuple[int, int]]: + if torch.version.hip: + hip_major, hip_minor = get_backend_version_tuple() + return hip_major * 100 + hip_minor + else: + return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))[-1] + + +def get_backend_version_tuple() -> Tuple[int, int]: + # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION + if torch.version.cuda: + major, minor = map(int, torch.version.cuda.split(".")) + elif torch.version.hip: + major, minor = map(int, torch.version.hip.split(".")[0:2]) + return major, minor + + +def get_backend_version_string() -> str: + major, minor = get_backend_version_tuple() + return f"{major}{minor}" + + +def get_gpu_specs() -> Optional[GPUSpecs]: + if not torch.cuda.is_available(): + return None + + return GPUSpecs( + gpu_backend=get_gpu_backend(), + highest_compute_capability=(get_compute_capabilities()), + backend_version_string=(get_backend_version_string()), + backend_version_tuple=get_backend_version_tuple(), + ) + + +def get_rocm_gpu_arch() -> str: + logger = logging.getLogger(__name__) + try: + if torch.version.hip: + result = subprocess.run(["rocminfo"], capture_output=True, text=True) + match = re.search(r"Name:\s+gfx([a-zA-Z\d]+)", result.stdout) + if match: + return "gfx" + match.group(1) + else: + return "unknown" + else: + return "unknown" + except Exception as e: + logger.error(f"Could not detect ROCm GPU architecture: {e}") + if torch.cuda.is_available(): + logger.warning( + """ +ROCm GPU architecture detection failed despite ROCm being available. 
+ """, + ) + return "unknown" diff --git a/tests/test_autograd.py b/tests/test_autograd.py index eafa01f0e..ac89c9195 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -4,7 +4,7 @@ import torch import bitsandbytes as bnb -from bitsandbytes.cextension import BNB_HIP_VERSION +from bitsandbytes.gpu_specs import get_compute_capabilities from tests.helpers import ( BOOLEAN_TRIPLES, BOOLEAN_TUPLES, @@ -199,7 +199,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool assert (idx == 0).sum().item() < n * 0.02 -@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1") @pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 53dd25044..42749ef00 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,6 +1,6 @@ import pytest -from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path from bitsandbytes.cuda_specs import CUDASpecs @@ -23,19 +23,19 @@ def cuda111_noblas_spec() -> CUDASpecs: @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") -def test_get_cuda_bnb_library_path(monkeypatch, cuda120_spec): +def test_get_gpu_bnb_library_path(monkeypatch, cuda120_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) - assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" + assert get_gpu_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda120" @pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") -def test_get_cuda_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): +def test_get_gpu_bnb_library_path_override(monkeypatch, cuda120_spec, caplog): monkeypatch.setenv("BNB_CUDA_VERSION", "110") - assert get_cuda_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" + assert get_gpu_bnb_library_path(cuda120_spec).stem == "libbitsandbytes_cuda110" assert "BNB_CUDA_VERSION" in caplog.text # did we get the warning? 
@pytest.mark.skipif(HIP_ENVIRONMENT, reason="this test is not supported on ROCm") -def test_get_cuda_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): +def test_get_gpu_bnb_library_path_nocublaslt(monkeypatch, cuda111_noblas_spec): monkeypatch.delenv("BNB_CUDA_VERSION", raising=False) - assert get_cuda_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" + assert get_gpu_bnb_library_path(cuda111_noblas_spec).stem == "libbitsandbytes_cuda111_nocublaslt" diff --git a/tests/test_functional.py b/tests/test_functional.py index 35187db78..01d27cf80 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -11,7 +11,8 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.cextension import BNB_HIP_VERSION, HIP_ENVIRONMENT, ROCM_GPU_ARCH +from bitsandbytes.gpu_specs import get_compute_capabilities +from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, @@ -512,7 +513,7 @@ def test_vector_quant(dim1, dim2, dim3): assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) -@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1") @pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) @@ -1817,7 +1818,7 @@ def quant_zp(x): print(err1, err2, err3, err4, err5, err6) -@pytest.mark.skipif(0 < BNB_HIP_VERSION < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1") @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_extract_outliers(device): for i in range(k): From 250690b493f5dac9137941c8c5d0c2271efc27ee Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Mon, 4 Nov 2024 15:04:41 +0000 Subject: [PATCH 02/10] replace cuda_specs --- bitsandbytes/cuda_specs.py | 47 -------------------------------------- 1 file changed, 47 deletions(-) delete mode 100644 bitsandbytes/cuda_specs.py diff --git a/bitsandbytes/cuda_specs.py b/bitsandbytes/cuda_specs.py deleted file mode 100644 index d77c5a3fd..000000000 --- a/bitsandbytes/cuda_specs.py +++ /dev/null @@ -1,47 +0,0 @@ -import dataclasses -import logging -import re -import subprocess -from typing import List, Optional, Tuple - -import torch - - -@dataclasses.dataclass(frozen=True) -class CUDASpecs: - highest_compute_capability: Tuple[int, int] - cuda_version_string: str - cuda_version_tuple: Tuple[int, int] - - @property - def has_cublaslt(self) -> bool: - return self.highest_compute_capability >= (7, 5) - - -def get_compute_capabilities() -> List[Tuple[int, int]]: - return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count())) - - -def get_cuda_version_tuple() -> Tuple[int, int]: - # https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART____VERSION.html#group__CUDART____VERSION - if torch.version.cuda: - major, minor = map(int, torch.version.cuda.split(".")) - elif torch.version.hip: - major, minor = map(int, torch.version.hip.split(".")[0:2]) - return major, minor - - -def get_cuda_version_string() -> str: - major, minor = get_cuda_version_tuple() - return f"{major}{minor}" - - -def 
get_cuda_specs() -> Optional[CUDASpecs]: - if not torch.cuda.is_available(): - return None - - return CUDASpecs( - highest_compute_capability=(get_compute_capabilities()[-1]), - cuda_version_string=(get_cuda_version_string()), - cuda_version_tuple=get_cuda_version_tuple(), - ) \ No newline at end of file From a86edc951462a9ec4033e6dd3f0894b1a7ff01f6 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Mon, 4 Nov 2024 15:42:14 +0000 Subject: [PATCH 03/10] fixing lint --- bitsandbytes/gpu_specs.py | 6 ++++-- tests/test_functional.py | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/bitsandbytes/gpu_specs.py b/bitsandbytes/gpu_specs.py index 49517f28c..b01a38390 100644 --- a/bitsandbytes/gpu_specs.py +++ b/bitsandbytes/gpu_specs.py @@ -2,7 +2,7 @@ import logging import re import subprocess -from typing import List, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch @@ -34,7 +34,9 @@ def get_compute_capabilities() -> Union[int, Tuple[int, int]]: hip_major, hip_minor = get_backend_version_tuple() return hip_major * 100 + hip_minor else: - return sorted(torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()))[-1] + return sorted( + torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()) + )[-1] def get_backend_version_tuple() -> Tuple[int, int]: diff --git a/tests/test_functional.py b/tests/test_functional.py index 01d27cf80..f03a0203c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -11,8 +11,8 @@ import bitsandbytes as bnb from bitsandbytes import functional as F -from bitsandbytes.gpu_specs import get_compute_capabilities from bitsandbytes.cextension import HIP_ENVIRONMENT, ROCM_GPU_ARCH +from bitsandbytes.gpu_specs import get_compute_capabilities from tests.helpers import ( BOOLEAN_TUPLES, TRUE_FALSE, From 69803b618d13f96b96474eefc4f83df80e6e6688 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Mon, 4 Nov 2024 15:44:24 +0000 Subject: [PATCH 04/10] fixing lint --- bitsandbytes/autograd/_functions.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/autograd/_functions.py b/bitsandbytes/autograd/_functions.py index 02a7f45af..dce45fa5e 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -7,8 +7,8 @@ import torch -from bitsandbytes.gpu_specs import get_compute_capabilities import bitsandbytes.functional as F +from bitsandbytes.gpu_specs import get_compute_capabilities # math.prod not compatible with python < 3.8 From 716e01040320afbb0fcf841d1a47088d81973172 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Fri, 8 Nov 2024 19:23:59 +0000 Subject: [PATCH 05/10] fix diagnostics --- bitsandbytes/cextension.py | 2 +- bitsandbytes/diagnostics/cuda.py | 30 ++-- bitsandbytes/diagnostics/gpu.py | 241 +++++++++++++++++++++++++++++ bitsandbytes/diagnostics/main.py | 14 +- bitsandbytes/gpu_specs.py | 2 +- tests/test_cuda_setup_evaluator.py | 10 +- 6 files changed, 270 insertions(+), 29 deletions(-) create mode 100644 bitsandbytes/diagnostics/gpu.py diff --git a/bitsandbytes/cextension.py b/bitsandbytes/cextension.py index 532f6970b..d863ad41e 100644 --- a/bitsandbytes/cextension.py +++ b/bitsandbytes/cextension.py @@ -37,7 +37,7 @@ def get_gpu_bnb_library_path(gpu_specs: GPUSpecs) -> Path: The library is not guaranteed to exist at the returned path. 
""" library_name = f"libbitsandbytes_{gpu_specs.gpu_backend}{gpu_specs.backend_version_string}" - if not gpu_specs.has_blaslt: + if not gpu_specs.enable_blaslt: # if not has_cublaslt (CC < 7.5), then we have to choose _nocublaslt if gpu_specs.gpu_backend == "rocm": library_name += "_nohipblaslt" diff --git a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py index 014b753a9..0e8593fdd 100644 --- a/bitsandbytes/diagnostics/cuda.py +++ b/bitsandbytes/diagnostics/cuda.py @@ -5,9 +5,9 @@ import torch -from bitsandbytes.cextension import HIP_ENVIRONMENT, get_cuda_bnb_library_path +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path from bitsandbytes.consts import NONPYTORCH_DOC_URL -from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.gpu_specs import GPUSpecs from bitsandbytes.diagnostics.utils import print_dedented CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") @@ -109,13 +109,13 @@ def find_cudart_libraries() -> Iterator[Path]: yield from find_cuda_libraries_in_path_list(value) -def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: +def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None: print( - f"PyTorch settings found: CUDA_VERSION={cuda_specs.cuda_version_string}, " - f"Highest Compute Capability: {cuda_specs.highest_compute_capability}.", + f"PyTorch settings found: CUDA_VERSION={gpu_specs.cuda_version_string}, " + f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.", ) - binary_path = get_cuda_bnb_library_path(cuda_specs) + binary_path = get_gpu_bnb_library_path(gpu_specs) if not binary_path.exists(): print_dedented( f""" @@ -128,7 +128,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: """, ) - cuda_major, cuda_minor = cuda_specs.cuda_version_tuple + cuda_major, cuda_minor = gpu_specs.cuda_version_tuple if cuda_major < 11: print_dedented( """ @@ -140,7 +140,7 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") # 7.5 is the minimum CC for cublaslt - if not cuda_specs.has_cublaslt: + if not gpu_specs.has_cublaslt: print_dedented( """ WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! 
@@ -154,10 +154,10 @@ def _print_cuda_diagnostics(cuda_specs: CUDASpecs) -> None: # (2) Multiple CUDA versions installed -def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: - print(f"PyTorch settings found: ROCM_VERSION={cuda_specs.cuda_version_string}") +def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.cuda_version_string}") - binary_path = get_cuda_bnb_library_path(cuda_specs) + binary_path = get_gpu_bnb_library_path(gpu_specs) if not binary_path.exists(): print_dedented( f""" @@ -168,7 +168,7 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: """, ) - hip_major, hip_minor = cuda_specs.cuda_version_tuple + hip_major, hip_minor = gpu_specs.cuda_version_tuple if (hip_major, hip_minor) < (6, 1): print_dedented( """ @@ -177,11 +177,11 @@ def _print_hip_diagnostics(cuda_specs: CUDASpecs) -> None: ) -def print_diagnostics(cuda_specs: CUDASpecs) -> None: +def print_diagnostics(gpu_specs: GPUSpecs) -> None: if HIP_ENVIRONMENT: - _print_hip_diagnostics(cuda_specs) + _print_hip_diagnostics(gpu_specs) else: - _print_cuda_diagnostics(cuda_specs) + _print_cuda_diagnostics(gpu_specs) def _print_cuda_runtime_diagnostics() -> None: diff --git a/bitsandbytes/diagnostics/gpu.py b/bitsandbytes/diagnostics/gpu.py new file mode 100644 index 000000000..18db9592c --- /dev/null +++ b/bitsandbytes/diagnostics/gpu.py @@ -0,0 +1,241 @@ +import logging +import os +from pathlib import Path +from typing import Dict, Iterable, Iterator + +import torch + +from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path +from bitsandbytes.consts import NONPYTORCH_DOC_URL +from bitsandbytes.gpu_specs import GPUSpecs +from bitsandbytes.diagnostics.utils import print_dedented + +GPU_RT_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") + +GPU_RT_PATH_IGNORED_ENVVARS = { + "DBUS_SESSION_BUS_ADDRESS", # hardware related + "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks + "HOME", # Linux shell default + "LESSCLOSE", + "LESSOPEN", # related to the `less` command + "MAIL", # something related to emails + "OLDPWD", + "PATH", # this is for finding binaries, not libraries + "PWD", # PWD: this is how the shell keeps track of the current working dir + "SHELL", # binary for currently invoked shell + "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated + "SSH_TTY", + "TMUX", # Terminal Multiplexer + "XDG_DATA_DIRS", # XDG: Desktop environment stuff + "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff + "XDG_RUNTIME_DIR", + "_", # current Python interpreter +} + +logger = logging.getLogger(__name__) + + +def get_runtime_lib_patterns() -> tuple: + if HIP_ENVIRONMENT: + return ("libamdhip64.so*",) + else: + return ( + "cudart64*.dll", # Windows + "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. 
+ "nvcuda*.dll", # Windows + ) + + +def find_gpu_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]: + for dir_string in paths_list_candidate.split(os.pathsep): + if not dir_string: + continue + if os.sep not in dir_string: + continue + try: + dir = Path(dir_string) + try: + if not dir.exists(): + logger.warning(f"The directory listed in your path is found to be non-existent: {dir}") + continue + except OSError: # Assume an esoteric error trying to poke at the directory + pass + for lib_pattern in get_runtime_lib_patterns(): + for pth in dir.glob(lib_pattern): + if pth.is_file() and not pth.is_symlink(): + yield pth + except (OSError, PermissionError): + pass + + +def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: + return ( + env_var in GPU_RT_PATH_PREFERRED_ENVVARS # is a preferred location + or ( + os.sep in value # might contain a path + and env_var not in GPU_RT_PATH_IGNORED_ENVVARS # not ignored + and "CONDA" not in env_var # not another conda envvar + and "BASH_FUNC" not in env_var # not a bash function defined via envvar + and "\n" not in value # likely e.g. a script or something? + ) + ) + + +def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: + return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)} + + +def find_gpu_rt_libraries() -> Iterator[Path]: + """ + Searches for a cuda installations, in the following order of priority: + 1. active conda env + 2. LD_LIBRARY_PATH + 3. any other env vars, while ignoring those that + - are known to be unrelated + - don't contain the path separator `/` + + If multiple libraries are found in part 3, we optimistically try one, + while giving a warning message. + """ + candidate_env_vars = get_potentially_lib_path_containing_env_vars() + + for envvar in GPU_RT_PATH_PREFERRED_ENVVARS: + if envvar in candidate_env_vars: + directory = candidate_env_vars[envvar] + yield from find_gpu_libraries_in_path_list(directory) + candidate_env_vars.pop(envvar) + + for env_var, value in candidate_env_vars.items(): + yield from find_gpu_libraries_in_path_list(value) + + +def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None: + print( + f"PyTorch settings found: CUDA_VERSION={gpu_specs.backend_version_string}, " + f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.", + ) + + binary_path = get_gpu_bnb_library_path(gpu_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. Maybe you need to compile it from source? + If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, + for example, `make CUDA_VERSION=113`. + + The CUDA version for the compile might depend on your conda install, if using conda. + Inspect CUDA version via `conda list | grep cuda`. + """, + ) + + cuda_major, cuda_minor = gpu_specs.backend_version_tuple + if cuda_major < 11: + print_dedented( + """ + WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). + You will be only to use 8-bit optimizers and quantization routines! + """, + ) + + print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") + + # 7.5 is the minimum CC for cublaslt + if not gpu_specs.enable_blaslt: + print_dedented( + """ + WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! 
+ If you run into issues with 8-bit matmul, you can try 4-bit quantization: + https://huggingface.co/blog/4bit-transformers-bitsandbytes + """, + ) + + # TODO: + # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) + # (2) Multiple CUDA versions installed + + +def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None: + print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.backend_version_string}") + + binary_path = get_gpu_bnb_library_path(gpu_specs) + if not binary_path.exists(): + print_dedented( + f""" + Library not found: {binary_path}. + Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION + in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version + and rebuild bitsandbytes. + """, + ) + + hip_major, hip_minor = gpu_specs.backend_version_tuple + if (hip_major, hip_minor) < (6, 1): + print_dedented( + """ + WARNING: bitsandbytes is fully supported only from ROCm 6.1. + """, + ) + + +def print_diagnostics(gpu_specs: GPUSpecs) -> None: + if HIP_ENVIRONMENT: + _print_hip_diagnostics(gpu_specs) + else: + _print_cuda_diagnostics(gpu_specs) + + +def _print_cuda_runtime_diagnostics() -> None: + gpu_rt_paths = list(find_gpu_rt_libraries()) + if not gpu_rt_paths: + print("WARNING! CUDA runtime files not found in any environmental path.") + elif len(gpu_rt_paths) > 1: + print_dedented( + f""" + Found duplicate CUDA runtime files (see below). + + We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, + but this might mismatch with the CUDA version that is needed for bitsandbytes. + To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. + + For example, if you want to use the CUDA version 122, + BNB_CUDA_VERSION=122 python ... + + OR set the environmental variable in your .bashrc: + export BNB_CUDA_VERSION=122 + + In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. + export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, + """, + ) + for pth in gpu_rt_paths: + print(f"* Found CUDA runtime at: {pth}") + + +def _print_hip_runtime_diagnostics() -> None: + gpu_rt_paths = list(find_gpu_rt_libraries()) + if not gpu_rt_paths: + print("WARNING! ROCm runtime files not found in any environmental path.") + elif len(gpu_rt_paths) > 1: + print_dedented( + f""" + Found duplicate ROCm runtime files (see below). + + We select the PyTorch default ROCm runtime, which is {torch.version.hip}, + but this might mismatch with the ROCm version that is needed for bitsandbytes. + + To resolve it, install PyTorch built for the ROCm version you want to use + + and set LD_LIBRARY_PATH to your ROCm install path, e.g. 
+ export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib, + """, + ) + + for pth in gpu_rt_paths: + print(f"* Found ROCm runtime at: {pth}") + + +def print_runtime_diagnostics() -> None: + if HIP_ENVIRONMENT: + _print_hip_runtime_diagnostics() + else: + _print_cuda_runtime_diagnostics() diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 8dc43ed2a..65e0fe924 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -5,7 +5,7 @@ from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT from bitsandbytes.consts import PACKAGE_GITHUB_URL -from bitsandbytes.cuda_specs import get_cuda_specs +from bitsandbytes.gpu_specs import get_gpu_specs from bitsandbytes.diagnostics.cuda import ( print_diagnostics, print_runtime_diagnostics, @@ -50,20 +50,20 @@ def main(): print_header("") print_header("OTHER") - cuda_specs = get_cuda_specs() + gpu_specs = get_gpu_specs() if HIP_ENVIRONMENT: - rocm_specs = f" rocm_version_string='{cuda_specs.cuda_version_string}'," - rocm_specs += f" rocm_version_tuple={cuda_specs.cuda_version_tuple}" + rocm_specs = f" rocm_version_string='{gpu_specs.cuda_version_string}'," + rocm_specs += f" rocm_version_tuple={gpu_specs.cuda_version_tuple}" print(f"{BNB_BACKEND} specs:{rocm_specs}") else: - print(f"{BNB_BACKEND} specs:{cuda_specs}") + print(f"{BNB_BACKEND} specs:{gpu_specs}") if not torch.cuda.is_available(): print(f"Torch says {BNB_BACKEND} is not available. Possible reasons:") print(f"1. {BNB_BACKEND} driver not installed") print(f"2. {BNB_BACKEND} not installed") print(f"3. You have multiple conflicting {BNB_BACKEND} libraries") - if cuda_specs: - print_diagnostics(cuda_specs) + if gpu_specs: + print_diagnostics(gpu_specs) print_runtime_diagnostics() print_header("") print_header("DEBUG INFO END") diff --git a/bitsandbytes/gpu_specs.py b/bitsandbytes/gpu_specs.py index b01a38390..822ad3fb2 100644 --- a/bitsandbytes/gpu_specs.py +++ b/bitsandbytes/gpu_specs.py @@ -15,7 +15,7 @@ class GPUSpecs: backend_version_tuple: Tuple[int, int] @property - def has_blaslt(self) -> bool: + def enable_blaslt(self) -> bool: if torch.version.hip: return self.highest_compute_capability >= 601 else: diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index 42749ef00..a8597acae 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -1,12 +1,12 @@ import pytest from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path -from bitsandbytes.cuda_specs import CUDASpecs +from bitsandbytes.gpu_specs import GPUSpecs @pytest.fixture -def cuda120_spec() -> CUDASpecs: - return CUDASpecs( +def cuda120_spec() -> GPUSpecs: + return GPUSpecs( cuda_version_string="120", highest_compute_capability=(8, 6), cuda_version_tuple=(12, 0), @@ -14,8 +14,8 @@ def cuda120_spec() -> CUDASpecs: @pytest.fixture -def cuda111_noblas_spec() -> CUDASpecs: - return CUDASpecs( +def cuda111_noblas_spec() -> GPUSpecs: + return GPUSpecs( cuda_version_string="111", highest_compute_capability=(7, 2), cuda_version_tuple=(11, 1), From dc4057ed7386f6c718bb4f5786dea89c44b38659 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Fri, 8 Nov 2024 19:34:50 +0000 Subject: [PATCH 06/10] fix lint --- bitsandbytes/diagnostics/cuda.py | 241 ------------------------------- bitsandbytes/diagnostics/gpu.py | 2 +- bitsandbytes/diagnostics/main.py | 4 +- 3 files changed, 3 insertions(+), 244 deletions(-) delete mode 100644 bitsandbytes/diagnostics/cuda.py diff --git 
a/bitsandbytes/diagnostics/cuda.py b/bitsandbytes/diagnostics/cuda.py deleted file mode 100644 index 0e8593fdd..000000000 --- a/bitsandbytes/diagnostics/cuda.py +++ /dev/null @@ -1,241 +0,0 @@ -import logging -import os -from pathlib import Path -from typing import Dict, Iterable, Iterator - -import torch - -from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path -from bitsandbytes.consts import NONPYTORCH_DOC_URL -from bitsandbytes.gpu_specs import GPUSpecs -from bitsandbytes.diagnostics.utils import print_dedented - -CUDART_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") - -CUDART_PATH_IGNORED_ENVVARS = { - "DBUS_SESSION_BUS_ADDRESS", # hardware related - "GOOGLE_VM_CONFIG_LOCK_FILE", # GCP: requires elevated permissions, causing problems in VMs and Jupyter notebooks - "HOME", # Linux shell default - "LESSCLOSE", - "LESSOPEN", # related to the `less` command - "MAIL", # something related to emails - "OLDPWD", - "PATH", # this is for finding binaries, not libraries - "PWD", # PWD: this is how the shell keeps track of the current working dir - "SHELL", # binary for currently invoked shell - "SSH_AUTH_SOCK", # SSH stuff, therefore unrelated - "SSH_TTY", - "TMUX", # Terminal Multiplexer - "XDG_DATA_DIRS", # XDG: Desktop environment stuff - "XDG_GREETER_DATA_DIR", # XDG: Desktop environment stuff - "XDG_RUNTIME_DIR", - "_", # current Python interpreter -} - -logger = logging.getLogger(__name__) - - -def get_runtime_lib_patterns() -> tuple: - if HIP_ENVIRONMENT: - return ("libamdhip64.so*",) - else: - return ( - "cudart64*.dll", # Windows - "libcudart*.so*", # libcudart.so, libcudart.so.11.0, libcudart.so.12.0, libcudart.so.12.1, libcudart.so.12.2 etc. - "nvcuda*.dll", # Windows - ) - - -def find_cuda_libraries_in_path_list(paths_list_candidate: str) -> Iterable[Path]: - for dir_string in paths_list_candidate.split(os.pathsep): - if not dir_string: - continue - if os.sep not in dir_string: - continue - try: - dir = Path(dir_string) - try: - if not dir.exists(): - logger.warning(f"The directory listed in your path is found to be non-existent: {dir}") - continue - except OSError: # Assume an esoteric error trying to poke at the directory - pass - for lib_pattern in get_runtime_lib_patterns(): - for pth in dir.glob(lib_pattern): - if pth.is_file() and not pth.is_symlink(): - yield pth - except (OSError, PermissionError): - pass - - -def is_relevant_candidate_env_var(env_var: str, value: str) -> bool: - return ( - env_var in CUDART_PATH_PREFERRED_ENVVARS # is a preferred location - or ( - os.sep in value # might contain a path - and env_var not in CUDART_PATH_IGNORED_ENVVARS # not ignored - and "CONDA" not in env_var # not another conda envvar - and "BASH_FUNC" not in env_var # not a bash function defined via envvar - and "\n" not in value # likely e.g. a script or something? - ) - ) - - -def get_potentially_lib_path_containing_env_vars() -> Dict[str, str]: - return {env_var: value for env_var, value in os.environ.items() if is_relevant_candidate_env_var(env_var, value)} - - -def find_cudart_libraries() -> Iterator[Path]: - """ - Searches for a cuda installations, in the following order of priority: - 1. active conda env - 2. LD_LIBRARY_PATH - 3. any other env vars, while ignoring those that - - are known to be unrelated - - don't contain the path separator `/` - - If multiple libraries are found in part 3, we optimistically try one, - while giving a warning message. 
- """ - candidate_env_vars = get_potentially_lib_path_containing_env_vars() - - for envvar in CUDART_PATH_PREFERRED_ENVVARS: - if envvar in candidate_env_vars: - directory = candidate_env_vars[envvar] - yield from find_cuda_libraries_in_path_list(directory) - candidate_env_vars.pop(envvar) - - for env_var, value in candidate_env_vars.items(): - yield from find_cuda_libraries_in_path_list(value) - - -def _print_cuda_diagnostics(gpu_specs: GPUSpecs) -> None: - print( - f"PyTorch settings found: CUDA_VERSION={gpu_specs.cuda_version_string}, " - f"Highest Compute Capability: {gpu_specs.highest_compute_capability}.", - ) - - binary_path = get_gpu_bnb_library_path(gpu_specs) - if not binary_path.exists(): - print_dedented( - f""" - Library not found: {binary_path}. Maybe you need to compile it from source? - If you compiled from source, try again with `make CUDA_VERSION=DETECTED_CUDA_VERSION`, - for example, `make CUDA_VERSION=113`. - - The CUDA version for the compile might depend on your conda install, if using conda. - Inspect CUDA version via `conda list | grep cuda`. - """, - ) - - cuda_major, cuda_minor = gpu_specs.cuda_version_tuple - if cuda_major < 11: - print_dedented( - """ - WARNING: CUDA versions lower than 11 are currently not supported for LLM.int8(). - You will be only to use 8-bit optimizers and quantization routines! - """, - ) - - print(f"To manually override the PyTorch CUDA version please see: {NONPYTORCH_DOC_URL}") - - # 7.5 is the minimum CC for cublaslt - if not gpu_specs.has_cublaslt: - print_dedented( - """ - WARNING: Compute capability < 7.5 detected! Only slow 8-bit matmul is supported for your GPU! - If you run into issues with 8-bit matmul, you can try 4-bit quantization: - https://huggingface.co/blog/4bit-transformers-bitsandbytes - """, - ) - - # TODO: - # (1) CUDA missing cases (no CUDA installed by CUDA driver (nvidia-smi accessible) - # (2) Multiple CUDA versions installed - - -def _print_hip_diagnostics(gpu_specs: GPUSpecs) -> None: - print(f"PyTorch settings found: ROCM_VERSION={gpu_specs.cuda_version_string}") - - binary_path = get_gpu_bnb_library_path(gpu_specs) - if not binary_path.exists(): - print_dedented( - f""" - Library not found: {binary_path}. - Maybe you need to compile it from source? If you compiled from source, check that ROCM_VERSION - in PyTorch Settings matches your ROCm install. If not, reinstall PyTorch for your ROCm version - and rebuild bitsandbytes. - """, - ) - - hip_major, hip_minor = gpu_specs.cuda_version_tuple - if (hip_major, hip_minor) < (6, 1): - print_dedented( - """ - WARNING: bitsandbytes is fully supported only from ROCm 6.1. - """, - ) - - -def print_diagnostics(gpu_specs: GPUSpecs) -> None: - if HIP_ENVIRONMENT: - _print_hip_diagnostics(gpu_specs) - else: - _print_cuda_diagnostics(gpu_specs) - - -def _print_cuda_runtime_diagnostics() -> None: - cudart_paths = list(find_cudart_libraries()) - if not cudart_paths: - print("WARNING! CUDA runtime files not found in any environmental path.") - elif len(cudart_paths) > 1: - print_dedented( - f""" - Found duplicate CUDA runtime files (see below). - - We select the PyTorch default CUDA runtime, which is {torch.version.cuda}, - but this might mismatch with the CUDA version that is needed for bitsandbytes. - To override this behavior set the `BNB_CUDA_VERSION=` environmental variable. - - For example, if you want to use the CUDA version 122, - BNB_CUDA_VERSION=122 python ... 
- - OR set the environmental variable in your .bashrc: - export BNB_CUDA_VERSION=122 - - In the case of a manual override, make sure you set LD_LIBRARY_PATH, e.g. - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/usr/local/cuda-11.2, - """, - ) - for pth in cudart_paths: - print(f"* Found CUDA runtime at: {pth}") - - -def _print_hip_runtime_diagnostics() -> None: - cudart_paths = list(find_cudart_libraries()) - if not cudart_paths: - print("WARNING! ROCm runtime files not found in any environmental path.") - elif len(cudart_paths) > 1: - print_dedented( - f""" - Found duplicate ROCm runtime files (see below). - - We select the PyTorch default ROCm runtime, which is {torch.version.hip}, - but this might mismatch with the ROCm version that is needed for bitsandbytes. - - To resolve it, install PyTorch built for the ROCm version you want to use - - and set LD_LIBRARY_PATH to your ROCm install path, e.g. - export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/opt/rocm-6.1.2/lib, - """, - ) - - for pth in cudart_paths: - print(f"* Found ROCm runtime at: {pth}") - - -def print_runtime_diagnostics() -> None: - if HIP_ENVIRONMENT: - _print_hip_runtime_diagnostics() - else: - _print_cuda_runtime_diagnostics() diff --git a/bitsandbytes/diagnostics/gpu.py b/bitsandbytes/diagnostics/gpu.py index 18db9592c..e9ae0c71e 100644 --- a/bitsandbytes/diagnostics/gpu.py +++ b/bitsandbytes/diagnostics/gpu.py @@ -7,8 +7,8 @@ from bitsandbytes.cextension import HIP_ENVIRONMENT, get_gpu_bnb_library_path from bitsandbytes.consts import NONPYTORCH_DOC_URL -from bitsandbytes.gpu_specs import GPUSpecs from bitsandbytes.diagnostics.utils import print_dedented +from bitsandbytes.gpu_specs import GPUSpecs GPU_RT_PATH_PREFERRED_ENVVARS = ("CONDA_PREFIX", "LD_LIBRARY_PATH") diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 65e0fe924..1bf1e5cd7 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -5,12 +5,12 @@ from bitsandbytes.cextension import BNB_BACKEND, HIP_ENVIRONMENT from bitsandbytes.consts import PACKAGE_GITHUB_URL -from bitsandbytes.gpu_specs import get_gpu_specs -from bitsandbytes.diagnostics.cuda import ( +from bitsandbytes.diagnostics.gpu import ( print_diagnostics, print_runtime_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header +from bitsandbytes.gpu_specs import GPUSpecs def sanity_check(): From d673950060f23e6b62ebe866f9346caa9a7b9000 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Fri, 8 Nov 2024 20:01:14 +0000 Subject: [PATCH 07/10] fix lint --- bitsandbytes/diagnostics/main.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 1bf1e5cd7..560f651ec 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -10,7 +10,7 @@ print_runtime_diagnostics, ) from bitsandbytes.diagnostics.utils import print_dedented, print_header -from bitsandbytes.gpu_specs import GPUSpecs +from bitsandbytes.gpu_specs import get_gpu_specs def sanity_check(): From 8fa795acf80f4cac2035093eab360281c8d15963 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Mon, 13 Jan 2025 13:17:45 +0000 Subject: [PATCH 08/10] change version to tuple format --- bitsandbytes/autograd/_functions.py | 2 +- bitsandbytes/gpu_specs.py | 6 +++--- tests/test_autograd.py | 2 +- tests/test_cuda_setup_evaluator.py | 8 ++++---- tests/test_functional.py | 4 ++-- 5 files changed, 11 insertions(+), 11 deletions(-) diff --git a/bitsandbytes/autograd/_functions.py 
b/bitsandbytes/autograd/_functions.py index dce45fa5e..ca568f72d 100644 --- a/bitsandbytes/autograd/_functions.py +++ b/bitsandbytes/autograd/_functions.py @@ -224,7 +224,7 @@ def supports_igemmlt(device: torch.device) -> bool: if device == torch.device("cpu"): return True if torch.version.hip: - return False if get_compute_capabilities() < 601 else True + return False if get_compute_capabilities() < (6, 1) else True if get_compute_capabilities() < (7, 5): return False device_name = torch.cuda.get_device_name(device=device) diff --git a/bitsandbytes/gpu_specs.py b/bitsandbytes/gpu_specs.py index 822ad3fb2..3a1d19a2e 100644 --- a/bitsandbytes/gpu_specs.py +++ b/bitsandbytes/gpu_specs.py @@ -10,14 +10,14 @@ @dataclasses.dataclass(frozen=True) class GPUSpecs: gpu_backend: str - highest_compute_capability: Union[int, Tuple[int, int]] + highest_compute_capability: Tuple[int, int] backend_version_string: str backend_version_tuple: Tuple[int, int] @property def enable_blaslt(self) -> bool: if torch.version.hip: - return self.highest_compute_capability >= 601 + return self.highest_compute_capability >= (6, 1) else: return self.highest_compute_capability >= (7, 5) @@ -32,7 +32,7 @@ def get_gpu_backend() -> str: def get_compute_capabilities() -> Union[int, Tuple[int, int]]: if torch.version.hip: hip_major, hip_minor = get_backend_version_tuple() - return hip_major * 100 + hip_minor + return (hip_major, hip_minor) else: return sorted( torch.cuda.get_device_capability(torch.cuda.device(i)) for i in range(torch.cuda.device_count()) diff --git a/tests/test_autograd.py b/tests/test_autograd.py index ac89c9195..49d5368a6 100644 --- a/tests/test_autograd.py +++ b/tests/test_autograd.py @@ -199,7 +199,7 @@ def test_matmul(dim1, dim2, dim3, dim4, funcs, dtype, req_grad: Tuple[bool, bool assert (idx == 0).sum().item() < n * 0.02 -@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1") @pytest.mark.parametrize("dim1", get_test_dims(16, 64, n=1), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", [*get_test_dims(32, 96, n=1), 0], ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(32, 96, n=1), ids=id_formatter("dim3")) diff --git a/tests/test_cuda_setup_evaluator.py b/tests/test_cuda_setup_evaluator.py index a8597acae..7612d68af 100644 --- a/tests/test_cuda_setup_evaluator.py +++ b/tests/test_cuda_setup_evaluator.py @@ -7,18 +7,18 @@ @pytest.fixture def cuda120_spec() -> GPUSpecs: return GPUSpecs( - cuda_version_string="120", + backend_version_string="120", highest_compute_capability=(8, 6), - cuda_version_tuple=(12, 0), + backend_version_tuple=(12, 0), ) @pytest.fixture def cuda111_noblas_spec() -> GPUSpecs: return GPUSpecs( - cuda_version_string="111", + backend_version_string="111", highest_compute_capability=(7, 2), - cuda_version_tuple=(11, 1), + backend_version_tuple=(11, 1), ) diff --git a/tests/test_functional.py b/tests/test_functional.py index f03a0203c..2e3e8c90c 100644 --- a/tests/test_functional.py +++ b/tests/test_functional.py @@ -513,7 +513,7 @@ def test_vector_quant(dim1, dim2, dim3): assert_all_approx_close(A1, A, atol=0.01, rtol=0.1, count=int(n * 0.002)) -@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1") 
@pytest.mark.parametrize("dim1", get_test_dims(2, 256, n=2), ids=id_formatter("dim1")) @pytest.mark.parametrize("dim2", get_test_dims(2, 256, n=2), ids=id_formatter("dim2")) @pytest.mark.parametrize("dim3", get_test_dims(2, 256, n=2), ids=id_formatter("dim3")) @@ -1818,7 +1818,7 @@ def quant_zp(x): print(err1, err2, err3, err4, err5, err6) -@pytest.mark.skipif(0 < get_compute_capabilities() < 601, reason="this test is supported on ROCm from 6.1") +@pytest.mark.skipif((0, 0) < get_compute_capabilities() < (6, 1), reason="this test is supported on ROCm from 6.1") @pytest.mark.parametrize("device", ["cuda", "cpu"]) def test_extract_outliers(device): for i in range(k): From fa5004fc07ba642c4c8a36a17860b172e9429cc8 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Mon, 13 Jan 2025 13:29:57 +0000 Subject: [PATCH 09/10] change version to tuple format --- bitsandbytes/gpu_specs.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bitsandbytes/gpu_specs.py b/bitsandbytes/gpu_specs.py index 3a1d19a2e..2df48441b 100644 --- a/bitsandbytes/gpu_specs.py +++ b/bitsandbytes/gpu_specs.py @@ -29,7 +29,7 @@ def get_gpu_backend() -> str: return "cuda" -def get_compute_capabilities() -> Union[int, Tuple[int, int]]: +def get_compute_capabilities() -> Tuple[int, int]: if torch.version.hip: hip_major, hip_minor = get_backend_version_tuple() return (hip_major, hip_minor) From 3ed572376ef178f88474fb49ca9e4abb783ab9b0 Mon Sep 17 00:00:00 2001 From: Lzy17 Date: Mon, 13 Jan 2025 13:38:18 +0000 Subject: [PATCH 10/10] debug --- bitsandbytes/diagnostics/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bitsandbytes/diagnostics/main.py b/bitsandbytes/diagnostics/main.py index 560f651ec..a47ce9f7e 100644 --- a/bitsandbytes/diagnostics/main.py +++ b/bitsandbytes/diagnostics/main.py @@ -52,8 +52,8 @@ def main(): print_header("OTHER") gpu_specs = get_gpu_specs() if HIP_ENVIRONMENT: - rocm_specs = f" rocm_version_string='{gpu_specs.cuda_version_string}'," - rocm_specs += f" rocm_version_tuple={gpu_specs.cuda_version_tuple}" + rocm_specs = f" rocm_version_string='{gpu_specs.backend_version_string}'," + rocm_specs += f" rocm_version_tuple={gpu_specs.backend_version_tuple}" print(f"{BNB_BACKEND} specs:{rocm_specs}") else: print(f"{BNB_BACKEND} specs:{gpu_specs}")
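
Taken together, this series replaces the CUDA-specific CUDASpecs with a backend-agnostic
GPUSpecs covering both CUDA and ROCm, renames the library-path helper from
get_cuda_bnb_library_path to get_gpu_bnb_library_path, and moves the diagnostics from
bitsandbytes/diagnostics/cuda.py to bitsandbytes/diagnostics/gpu.py. As a minimal sketch of
the resulting API, assuming bitsandbytes is installed from this branch in its final state
(PATCH 08/10 onward, where the compute capability is a (major, minor) tuple on both backends):

    from bitsandbytes.gpu_specs import get_gpu_specs

    specs = get_gpu_specs()  # returns None when torch.cuda.is_available() is False
    if specs is not None:
        # e.g. gpu_backend="cuda", backend_version_string="120",
        # highest_compute_capability=(8, 6) on a CUDA 12.0 / SM 8.6 machine;
        # on ROCm, the "capability" is the HIP version tuple, e.g. (6, 1)
        print(specs.gpu_backend, specs.backend_version_string)

        # enable_blaslt gates the *blaslt binary variants: compute capability
        # >= (7, 5) on CUDA, ROCm >= (6, 1) on HIP; otherwise cextension.py
        # falls back to the _nocublaslt / _nohipblaslt library.
        if not specs.enable_blaslt:
            print("slow 8-bit matmul only (no cuBLASLt/hipBLASLt)")

Note that get_gpu_bnb_library_path still honors only the BNB_CUDA_VERSION override
environment variable for both backends (see the "IGNORE FOR NOW" comment in PATCH 01/10),
so a backend-neutral override variable appears to be left as follow-up work.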