From ff4de4eafdfe5c2346b252cf09c6c412e13f3135 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Tue, 13 Jan 2026 13:08:25 +0000 Subject: [PATCH 1/2] Add support for kernel versions This change adds support for kernel versions branches as implemented by `kernel-builder`: * `kernels upload` will upload the branch corresponding to the kernel's version. * `get_kernel`, `LayerRepository`, etc. accept version branches through the `version` argument. * Kernel locking (e.g. `kernels lock`) support reading version branches from `pyproject.yaml` and correctly locking them. All the kernel functions and layers already supported a `version` argument for version specifiers (based on tags). These are now overloaded -- if `version` is a `str`, it is parsed as a version specifier. If it is an `int`, it is used as a version branch. This overloading is used to gradually move to version branches, deprecating version specifiers. For now, using version specifiers will emit a warning. --- docs/source/basic-usage.md | 27 +++--- docs/source/kernel-requirements.md | 21 ++-- docs/source/layers.md | 16 ++- docs/source/locking.md | 2 +- src/kernels/_versions.py | 69 +++++++++---- src/kernels/cli.py | 55 ++--------- src/kernels/layer/func.py | 20 ++-- src/kernels/layer/layer.py | 18 ++-- src/kernels/lockfile.py | 2 +- src/kernels/metadata.py | 23 +++++ src/kernels/upload.py | 69 +++++++++++++ src/kernels/utils.py | 35 ++++--- tests/kernel_locking/kernels.lock | 124 +++++++++++++++++------- tests/kernel_locking/pyproject.toml | 4 +- tests/kernel_locking_old/kernels.lock | 82 ++++++++++++++++ tests/kernel_locking_old/pyproject.toml | 3 + tests/layer_locking/kernels.lock | 20 +++- tests/layer_locking/pyproject.toml | 2 +- tests/layer_locking_old/kernels.lock | 12 +++ tests/layer_locking_old/pyproject.toml | 2 + tests/test_basic.py | 17 +++- tests/test_kernel_locking.py | 80 ++++++++++++++- tests/test_layer.py | 82 +++++++++++++++- 23 files changed, 612 insertions(+), 173 deletions(-) create mode 100644 src/kernels/metadata.py create mode 100644 src/kernels/upload.py create mode 100644 tests/kernel_locking_old/kernels.lock create mode 100644 tests/kernel_locking_old/pyproject.toml create mode 100644 tests/layer_locking_old/kernels.lock create mode 100644 tests/layer_locking_old/pyproject.toml diff --git a/docs/source/basic-usage.md b/docs/source/basic-usage.md index 1d9fd93..bd520b6 100644 --- a/docs/source/basic-usage.md +++ b/docs/source/basic-usage.md @@ -9,7 +9,7 @@ import torch from kernels import get_kernel # Download optimized kernels from the Hugging Face hub -activation = get_kernel("kernels-community/activation") +activation = get_kernel("kernels-community/activation", version=1) # Create a random tensor x = torch.randn((10, 10), dtype=torch.float16, device="cuda") @@ -21,30 +21,25 @@ activation.gelu_fast(y, x) print(y) ``` -### Using version bounds +This fetches version `1` of the kernel `kernels-community/activation`. +Kernels are versioned using a major version number. Using `version=1` will +get the latest kernel build from the `v1` branch. The kernel version is +bumped is bumped in the following circumstances: -Kernels are versioned using tags of the form `v..`. -You can specify which version to download using Python version specifiers: +* The kernel API changes in an incompatible way. +* Support for an older PyTorch version is dropped. -```python -import torch -from kernels import get_kernel - -activation = get_kernel("kernels-community/activation", version=">=0.0.4,<0.1.0") -``` - -This will get the latest kernel tagged `v0.0.z` where `z` is at least 4. It -is strongly recommended to specify a version bound, since a kernel author -might push incompatible changes to the `main` branch. +In this way, you can ensure that your code will continue to work. ## Checking Kernel Availability -You can check if a specific kernel is available for your environment: +You can check if a particular version of a kernel supports the environment +that the program is running on: ```python from kernels import has_kernel # Check if kernel is available for current environment -is_available = has_kernel("kernels-community/activation") +is_available = has_kernel("kernels-community/activation", version=1) print(f"Kernel available: {is_available}") ``` diff --git a/docs/source/kernel-requirements.md b/docs/source/kernel-requirements.md index 0da3c4e..2e09bc2 100644 --- a/docs/source/kernel-requirements.md +++ b/docs/source/kernel-requirements.md @@ -40,21 +40,30 @@ must be available for that combination. ## Kernel metadata The build variant directory can optionally contain a `metadata.json` file. -Currently the only purpose of the metadata is to specify the kernel python dependencies, for example: +Currently the metadata specifies the kernel's versin and Python dependencies, +for example: ```json -{ "python-depends": ["nvidia-cutlass-dsl"] } +{ + "python-depends": ["nvidia-cutlass-dsl"], + "version": 1 +} ``` The following dependencies are the only ones allowed at this stage: `einops` and `nvidia-cutlass-dsl` ## Versioning -Kernels are versioned on the Hub using Git tags. Version tags must be of -the form `v..`. Versions are used by [locking](./locking.md) -to resolve the version constraints. +Kernels are versioned using a major version. The kernel revisions of a +version are stored in a branch of the form `v`. Each build +variant will also have the kernel version in `metadata.json`. -We recommend using [semver](https://semver.org/) to version kernels. +The version **must** be bumped in the following cases: + +- The kernel API is changed in an incompatible way. +- Support for one or more PyTorch version is dropped. This is usually done + when the API is extended and the builds for older PyTorch versions are not + updated anymore to have the new extensions. ## Native Python module diff --git a/docs/source/layers.md b/docs/source/layers.md index 99629e8..6650fc3 100644 --- a/docs/source/layers.md +++ b/docs/source/layers.md @@ -214,20 +214,18 @@ kernel_layer_mapping = { "cuda": LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", - version=">=0.0.4,<0.1.0", + version=1, ), "rocm": LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", - version=">=0.0.4,<0.1.0", + version=1, ) } } ``` -This will get the layer from latest kernel tagged `v0.0.z` where `z` is at -least 4. It is strongly recommended to specify a version bound, since a -kernel author might push incompatible changes to the `main` branch. +This will get the layer from lates version on the version 1 branch. ### Registering kernels for specific modes @@ -242,10 +240,12 @@ kernel_layer_mapping = { Mode.INFERENCE: LayerRepository( repo_id="kernels-community/activation-inference-optimized", layer_name="SiluAndMul", + version=1, ), Mode.TRAINING | Mode.TORCH_COMPILE: LayerRepository( repo_id="kernels-community/activation-training-optimized", layer_name="SiluAndMul", + version=1, ), } } @@ -273,14 +273,17 @@ kernel_layer_mapping = { Mode.FALLBACK: LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", + version=1, ), Mode.INFERENCE: LayerRepository( repo_id="kernels-community/activation-inference-optimized", layer_name="SiluAndMul", + version=1, ), Mode.TRAINING: LayerRepository( repo_id="kernels-community/activation-training-optimized", layer_name="SiluAndMul", + version=1, ), } } @@ -310,6 +313,7 @@ kernel_layer_mapping = { ): LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", + version=1, ), Device( type="cuda", @@ -319,6 +323,7 @@ kernel_layer_mapping = { ): LayerRepository( repo_id="kernels-community/activation-hopper", layer_name="SiluAndMul", + version=1, ), } } @@ -359,6 +364,7 @@ with use_kernel_mapping( repo_path="/home/daniel/kernels/activation", package_name="activation", layer_name="SiluAndMul", + version=1, ) } }, diff --git a/docs/source/locking.md b/docs/source/locking.md index cb4b125..7667052 100644 --- a/docs/source/locking.md +++ b/docs/source/locking.md @@ -10,7 +10,7 @@ requires = ["kernels", "setuptools"] build-backend = "setuptools.build_meta" [tool.kernels.dependencies] -"kernels-community/activation" = ">=0.0.1" +"kernels-community/activation" = 1 ``` Then run `kernels lock .` in the project directory. This generates a `kernels.lock` file with diff --git a/src/kernels/_versions.py b/src/kernels/_versions.py index cf8a616..48beea5 100644 --- a/src/kernels/_versions.py +++ b/src/kernels/_versions.py @@ -1,3 +1,4 @@ +import warnings from typing import Dict, Optional from huggingface_hub import HfApi @@ -6,9 +7,27 @@ from packaging.version import InvalidVersion, Version -def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: +def _get_available_versions(repo_id: str) -> Dict[int, GitRefInfo]: """Get kernel versions that are available in the repository.""" versions = {} + for branch in HfApi().list_repo_refs(repo_id).branches: + if not branch.name.startswith("v"): + continue + try: + versions[int(branch.name[1:])] = branch + except ValueError: + continue + + return versions + + +def _get_available_versions_old(repo_id: str) -> Dict[Version, GitRefInfo]: + """ + Get kernel versions that are available in the repository. + + This is for the old tag-based versioning scheme. + """ + versions = {} for tag in HfApi().list_repo_refs(repo_id).tags: if not tag.name.startswith("v"): continue @@ -20,33 +39,49 @@ def _get_available_versions(repo_id: str) -> Dict[Version, GitRefInfo]: return versions -def resolve_version_spec_as_ref(repo_id: str, version_spec: str) -> GitRefInfo: +def resolve_version_spec_as_ref(repo_id: str, version_spec: int | str) -> GitRefInfo: """ Get the locks for a kernel with the given version spec. The version specifier can be any valid Python version specifier: https://packaging.python.org/en/latest/specifications/version-specifiers/#version-specifiers """ - versions = _get_available_versions(repo_id) - requirement = SpecifierSet(version_spec) - accepted_versions = sorted(requirement.filter(versions.keys())) - - if len(accepted_versions) == 0: - raise ValueError( - f"No version of `{repo_id}` satisfies requirement: {version_spec}" + if isinstance(version_spec, int): + versions = _get_available_versions(repo_id) + ref = versions.get(version_spec, None) + if ref is None: + raise ValueError( + f"Version {version_spec} not found, available versions: {', '.join(sorted(str(v) for v in versions.keys()))}" + ) + return ref + else: + warnings.warn( + "Version specifiers are deprecated, support will be removed in a future `kernels` version,\na concrete version instead." ) + versions_old = _get_available_versions_old(repo_id) + requirement = SpecifierSet(version_spec) + accepted_versions = sorted(requirement.filter(versions_old.keys())) - return versions[accepted_versions[-1]] + if len(accepted_versions) == 0: + raise ValueError( + f"No version of `{repo_id}` satisfies requirement: {version_spec}" + ) + + return versions_old[accepted_versions[-1]] def select_revision_or_version( - repo_id: str, revision: Optional[str], version: Optional[str] + repo_id: str, + *, + revision: Optional[str], + version: Optional[int | str], ) -> str: if revision is not None and version is not None: - raise ValueError("Either a revision or a version must be specified, not both.") - elif revision is None and version is None: - revision = "main" + raise ValueError("Only one of `revision` or `version` must be specified.") + + if revision is not None: + return revision elif version is not None: - revision = resolve_version_spec_as_ref(repo_id, version).target_commit - assert revision is not None - return revision + return resolve_version_spec_as_ref(repo_id, version).target_commit + + return "main" diff --git a/src/kernels/cli.py b/src/kernels/cli.py index 6cb84a9..6a27d43 100644 --- a/src/kernels/cli.py +++ b/src/kernels/cli.py @@ -1,20 +1,17 @@ import argparse import dataclasses import json -import re import sys from pathlib import Path -from huggingface_hub import create_repo, upload_folder, create_branch from kernels.compat import tomllib from kernels.lockfile import KernelLock, get_kernel_locks +from kernels.upload import upload_kernels_dir from kernels.utils import install_kernel, install_kernel_all_variants from .doc import generate_readme_for_kernel -BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-(cpu|cuda|metal|rocm|xpu))") - def main(): parser = argparse.ArgumentParser( @@ -69,11 +66,13 @@ def main(): upload_parser.add_argument( "--repo-id", type=str, + required=True, help="Repository ID to use to upload to the Hugging Face Hub", ) upload_parser.add_argument( "--branch", - type=None, + type=str, + default=None, help="If set, the upload will be made to a particular branch of the provided `repo-id`.", ) upload_parser.add_argument( @@ -169,46 +168,12 @@ def lock_kernels(args): def upload_kernels(args): - # Resolve `kernel_dir` to be uploaded. - kernel_dir = Path(args.kernel_dir).resolve() - - build_dir = None - for candidate in [kernel_dir / "build", kernel_dir]: - variants = [ - variant_path - for variant_path in candidate.glob("torch*") - if BUILD_VARIANT_REGEX.match(variant_path.name) is not None - ] - if variants: - build_dir = candidate - break - if build_dir is None: - raise ValueError( - f"Couldn't find any build variants in: {kernel_dir.absolute()} or {(kernel_dir / 'build').absolute()}" - ) - - repo_id = create_repo( - repo_id=args.repo_id, private=args.private, exist_ok=True - ).repo_id - - if args.branch is not None: - create_branch(repo_id=repo_id, branch=args.branch, exist_ok=True) - - delete_patterns: set[str] = set() - for build_variant in build_dir.iterdir(): - if build_variant.is_dir(): - delete_patterns.add(f"{build_variant.name}/**") - - upload_folder( - repo_id=repo_id, - folder_path=build_dir, - revision=args.branch, - path_in_repo="build", - delete_patterns=list(delete_patterns), - commit_message="Build uploaded using `kernels`.", - allow_patterns=["torch*"], - ) - print(f"✅ Kernel upload successful. Find the kernel in https://hf.co/{repo_id}.") + upload_kernels_dir( + Path(args.kernel_dir).resolve(), + repo_id=args.repo_id, + branch=args.branch, + private=args.private, + ) class _JSONEncoder(json.JSONEncoder): diff --git a/src/kernels/layer/func.py b/src/kernels/layer/func.py index e2f1127..7fc7d1b 100644 --- a/src/kernels/layer/func.py +++ b/src/kernels/layer/func.py @@ -35,9 +35,9 @@ class FuncRepository: The name of the function within the kernel repository. revision (`str`, *optional*, defaults to `"main"`): The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. - version (`str`, *optional*): - The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. - Cannot be used together with `revision`. + version (`int|str`, *optional*): + The kernel version to download as an integer. The `str` variant is deprecated and will be + removed in a future release. Cannot be used together with `revision`. Example: ```python @@ -49,11 +49,11 @@ class FuncRepository: func_name="silu_and_mul", ) - # Reference a layer by version constraint + # Reference a layer by version layer_repo_versioned = FuncRepository( - repo_id="kernels-community/activation", - func_name="silu_and_mul", - version=">=0.0.3,<0.1" + repo_id="kernels-community/relu", + func_name="relu", + version=1 ) ``` """ @@ -64,7 +64,7 @@ def __init__( *, func_name: str, revision: Optional[str] = None, - version: Optional[str] = None, + version: Optional[int | str] = None, ): if revision is not None and version is not None: raise ValueError( @@ -82,7 +82,9 @@ def __init__( @functools.lru_cache() def _resolve_revision(self) -> str: return select_revision_or_version( - repo_id=self._repo_id, revision=self._revision, version=self._version + repo_id=self._repo_id, + revision=self._revision, + version=self._version, ) def load(self) -> Type["nn.Module"]: diff --git a/src/kernels/layer/layer.py b/src/kernels/layer/layer.py index 9965c64..94773ae 100644 --- a/src/kernels/layer/layer.py +++ b/src/kernels/layer/layer.py @@ -14,8 +14,6 @@ Type, ) -from .device import Device -from .globals import _DISABLE_KERNEL_MAPPING, _KERNEL_MAPPING from .._versions import select_revision_or_version from ..utils import ( _get_caller_locked_kernel, @@ -23,8 +21,10 @@ get_kernel, get_local_kernel, ) +from .device import Device +from .globals import _DISABLE_KERNEL_MAPPING, _KERNEL_MAPPING from .mode import Mode -from .repos import _select_repository, RepositoryProtocol +from .repos import RepositoryProtocol, _select_repository if TYPE_CHECKING: from torch import nn @@ -46,9 +46,9 @@ class LayerRepository: The name of the layer within the kernel repository. revision (`str`, *optional*, defaults to `"main"`): The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. - version (`str`, *optional*): - The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. - Cannot be used together with `revision`. + version (`int|str`, *optional*): + The kernel version to download as an integer. The `str` variant is deprecated and will be + removed in a future release. Cannot be used together with `revision`. Example: ```python @@ -75,7 +75,7 @@ def __init__( *, layer_name: str, revision: Optional[str] = None, - version: Optional[str] = None, + version: Optional[int | str] = None, ): if revision is not None and version is not None: raise ValueError( @@ -93,7 +93,9 @@ def __init__( @functools.lru_cache() def _resolve_revision(self) -> str: return select_revision_or_version( - repo_id=self._repo_id, revision=self._revision, version=self._version + repo_id=self._repo_id, + revision=self._revision, + version=self._version, ) def load(self) -> Type["nn.Module"]: diff --git a/src/kernels/lockfile.py b/src/kernels/lockfile.py index 5722bed..fce601b 100644 --- a/src/kernels/lockfile.py +++ b/src/kernels/lockfile.py @@ -29,7 +29,7 @@ def from_json(cls, o: Dict): return cls(repo_id=o["repo_id"], sha=o["sha"], variants=variants) -def get_kernel_locks(repo_id: str, version_spec: str) -> KernelLock: +def get_kernel_locks(repo_id: str, version_spec: int | str) -> KernelLock: """ Get the locks for a kernel with the given version spec. diff --git a/src/kernels/metadata.py b/src/kernels/metadata.py new file mode 100644 index 0000000..e8cd1ec --- /dev/null +++ b/src/kernels/metadata.py @@ -0,0 +1,23 @@ +import json +from dataclasses import dataclass +from pathlib import Path +from typing import List, Optional + + +@dataclass +class Metadata: + python_depends: List[str] + version: Optional[int] + + @staticmethod + def load_from_variant(variant_path: Path) -> "Metadata": + metadata_path = variant_path / "metadata.json" + if metadata_path.exists(): + with open(metadata_path, "r") as f: + metadata_dict = json.load(f) + return Metadata( + python_depends=metadata_dict.get("python-depends", []), + version=metadata_dict.get("version", None), + ) + + return Metadata(version=None, python_depends=[]) diff --git a/src/kernels/upload.py b/src/kernels/upload.py new file mode 100644 index 0000000..920a814 --- /dev/null +++ b/src/kernels/upload.py @@ -0,0 +1,69 @@ +import re +from pathlib import Path +from typing import Optional + +from huggingface_hub import create_branch, create_repo, upload_folder + +from kernels.metadata import Metadata + +BUILD_VARIANT_REGEX = re.compile(r"^(torch\d+\d+|torch-(cpu|cuda|metal|rocm|xpu))") + + +def upload_kernels_dir( + kernel_dir: Path, *, repo_id: str, branch: Optional[str], private: bool +): + kernel_dir = Path(kernel_dir).resolve() + + build_dir = None + variants = None + for candidate in [kernel_dir / "build", kernel_dir]: + variants = [ + variant_path + for variant_path in candidate.glob("torch*") + if BUILD_VARIANT_REGEX.match(variant_path.name) is not None + and (variant_path / "metadata.json").is_file() + ] + if variants: + build_dir = candidate + break + if build_dir is None: + raise ValueError( + f"Couldn't find any build variants in: {kernel_dir.absolute()} or {(kernel_dir / 'build').absolute()}" + ) + + if branch is None: + assert variants is not None + versions = set() + for variant in variants: + metadata = Metadata.load_from_variant(variant) + versions.add(metadata.version) + + if len(versions) > 1: + raise ValueError( + f"Found multiple versions in build variants: {', '.join(str(version) for version in versions)}" + ) + + version = versions.pop() + if version is not None: + branch = f"v{version}" + + repo_id = create_repo(repo_id=repo_id, private=private, exist_ok=True).repo_id + + if branch is not None: + create_branch(repo_id=repo_id, branch=branch, exist_ok=True) + + delete_patterns: set[str] = set() + for build_variant in build_dir.iterdir(): + if build_variant.is_dir(): + delete_patterns.add(f"{build_variant.name}/**") + + upload_folder( + repo_id=repo_id, + folder_path=build_dir, + revision=branch, + path_in_repo="build", + delete_patterns=list(delete_patterns), + commit_message="Build uploaded using `kernels`.", + allow_patterns=["torch*"], + ) + print(f"✅ Kernel upload successful. Find the kernel in: https://hf.co/{repo_id}") diff --git a/src/kernels/utils.py b/src/kernels/utils.py index 2cb3fdb..b6eb328 100644 --- a/src/kernels/utils.py +++ b/src/kernels/utils.py @@ -20,6 +20,7 @@ from kernels._versions import select_revision_or_version from kernels.deps import validate_dependencies from kernels.lockfile import KernelLock, VariantLock +from kernels.metadata import Metadata ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} @@ -129,12 +130,8 @@ def build_variants() -> List[str]: def _import_from_path(module_name: str, variant_path: Path) -> ModuleType: - metadata_path = variant_path / "metadata.json" - if metadata_path.exists(): - with open(metadata_path, "r") as f: - metadata = json.load(f) - deps = metadata.get("python-depends", []) - validate_dependencies(deps, backend()) + metadata = Metadata.load_from_variant(variant_path) + validate_dependencies(metadata.python_depends, backend()) file_path = variant_path / "__init__.py" if not file_path.exists(): @@ -278,7 +275,7 @@ def install_kernel_all_variants( def get_kernel( repo_id: str, revision: Optional[str] = None, - version: Optional[str] = None, + version: Optional[int | str] = None, user_agent: Optional[Union[str, dict]] = None, ) -> ModuleType: """ @@ -292,9 +289,9 @@ def get_kernel( The Hub repository containing the kernel. revision (`str`, *optional*, defaults to `"main"`): The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. - version (`str`, *optional*): - The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. - Cannot be used together with `revision`. + version (`int|str`, *optional*): + The kernel version to download as an integer. The `str` variant is deprecated and will be + removed in a future release. Cannot be used together with `revision`. user_agent (`Union[str, dict]`, *optional*): The `user_agent` info to pass to `snapshot_download()` for internal telemetry. @@ -306,13 +303,13 @@ def get_kernel( import torch from kernels import get_kernel - activation = get_kernel("kernels-community/activation") + activation = get_kernel("kernels-community/relu", version=1) x = torch.randn(10, 20, device="cuda") out = torch.empty_like(x) - result = activation.silu_and_mul(out, x) + result = activation.relu(out, x) ``` """ - revision = select_revision_or_version(repo_id, revision, version) + revision = select_revision_or_version(repo_id, revision=revision, version=version) package_name, variant_path = install_kernel( repo_id, revision=revision, user_agent=user_agent ) @@ -349,7 +346,9 @@ def get_local_kernel(repo_path: Path, package_name: str) -> ModuleType: def has_kernel( - repo_id: str, revision: Optional[str] = None, version: Optional[str] = None + repo_id: str, + revision: Optional[str] = None, + version: Optional[int | str] = None, ) -> bool: """ Check whether a kernel build exists for the current environment (Torch version and compute framework). @@ -359,14 +358,14 @@ def has_kernel( The Hub repository containing the kernel. revision (`str`, *optional*, defaults to `"main"`): The specific revision (branch, tag, or commit) to download. Cannot be used together with `version`. - version (`str`, *optional*): - The kernel version to download. This can be a Python version specifier, such as `">=1.0.0,<2.0.0"`. - Cannot be used together with `revision`. + version (`int|str`, *optional*): + The kernel version to download as an integer. The `str` variant is deprecated and will be + removed in a future release. Cannot be used together with `revision`. Returns: `bool`: `True` if a kernel is available for the current environment. """ - revision = select_revision_or_version(repo_id, revision, version) + revision = select_revision_or_version(repo_id, revision=revision, version=version) package_name = package_name_from_repo_id(repo_id) variant = build_variant() diff --git a/tests/kernel_locking/kernels.lock b/tests/kernel_locking/kernels.lock index 987d160..e115a97 100644 --- a/tests/kernel_locking/kernels.lock +++ b/tests/kernel_locking/kernels.lock @@ -1,80 +1,136 @@ [ { - "repo_id": "kernels-community/activation", - "sha": "83046852be158d525114f68513cd79fd88911b37", + "repo_id": "kernels-community/relu", + "sha": "8c883dab2ee1e812d6bf996ab30cb43e0dcf63ed", "variants": { - "torch27-cxx11-cu118-x86_64-linux": { - "hash": "sha256-e34965c814c4c092fcb634ebadefe82ea9a05b98343f8ebdefa7305dcc05359e", + "torch-ext": { + "hash": "sha256-fa95388531c6280130219f0e73e5daf521116da6c841fa5ab6a190c7994767d8", "hash_type": "git_lfs_concat" }, - "torch27-cxx11-cu126-x86_64-linux": { - "hash": "sha256-5f92b35922b37224a416398a39a29b7e5f1aca1df17d5c69f1b9e9cdb7033561", + "torch210-cxx11-cu126-x86_64-linux": { + "hash": "sha256-e7dc6b003dc4360f7da1bfff2f7d7de7ecf8ceb0015f9356b70bdbefa91c0ebb", + "hash_type": "git_lfs_concat" + }, + "torch210-cxx11-cu128-x86_64-linux": { + "hash": "sha256-a8617f0e9b60e81f95e03201bf4a9c51b2292ef891eaf533e26fa262a7c51713", + "hash_type": "git_lfs_concat" + }, + "torch210-cxx11-cu130-x86_64-linux": { + "hash": "sha256-cfc79ff94189e89d9209861aff2ae6a56829440f1c91e2cb54bdf029faded69f", + "hash_type": "git_lfs_concat" + }, + "torch210-cxx11-rocm70-x86_64-linux": { + "hash": "sha256-c79f25c7a6621e53edf3f2b548e09be2eac1f89d63686c974ba9fd93dee283ca", + "hash_type": "git_lfs_concat" + }, + "torch210-cxx11-rocm71-x86_64-linux": { + "hash": "sha256-c69930b897a3bdef5e1ec553d5206428ce5e25f7e82d69b9cf182a64e381ae4a", "hash_type": "git_lfs_concat" }, - "torch27-cxx11-cu128-aarch64-linux": { - "hash": "sha256-125967cb23bacd2cec443799f184ac08247dfff33f5027e54ee16d3779ca5986", + "torch210-cxx11-xpu20253-x86_64-linux": { + "hash": "sha256-6785b1154659ce5ecf73c86a0e4153daab2c1fa7a5c002c6c8e63a4e1b14e6d4", + "hash_type": "git_lfs_concat" + }, + "torch210-metal-aarch64-darwin": { + "hash": "sha256-99648d1f188fc15e1dcef35932115d36b7044e504e56330810dad77f7970e0ce", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-7429779f50c9fc301b0ea24dba9414d5f2ad1c282bcdf7764562a23c297983d7", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-d49807befdf0c158fa2db633128b1b9e9968de28441190e440ebeb2c78dba68e", "hash_type": "git_lfs_concat" }, "torch27-cxx11-cu128-x86_64-linux": { - "hash": "sha256-496a84c99d7035a1b6f0ea1c026b751c3a2677956f4c1be546d3cc1505a5fdbb", + "hash": "sha256-06ff9a19519870a58d81839b225ff78247c1dc069337f72896e865a840ad9e70", "hash_type": "git_lfs_concat" }, - "torch28-cxx11-cu126-aarch64-linux": { - "hash": "sha256-f0775a30ffa290c90aba3a41037e3ca91edb15b4a9367561fafd5f25455e117a", + "torch27-cxx11-rocm63-x86_64-linux": { + "hash": "sha256-7269c26076bd220067c790afce9d23c793a1fd02ed62270418fa2e4fbca49634", "hash_type": "git_lfs_concat" }, "torch28-cxx11-cu126-x86_64-linux": { - "hash": "sha256-081995e6230f306bdf6111186618794f2411cf0ffd9b4800330df60b4ebe1927", + "hash": "sha256-b60fb4b4bd0ddb35a661ee33177443112a0d2abde762dc03891c23616939cea0", "hash_type": "git_lfs_concat" }, - "torch28-cxx11-cu128-aarch64-linux": { - "hash": "sha256-b937fef62a0c1cd71ab98490b651c473577af209b9a3e2a6b452350283d8812c", + "torch28-cxx11-cu128-x86_64-linux": { + "hash": "sha256-35123770a7175e85e75dc35984308345a826270e2f4c4b6d22a83c32ebfa4388", "hash_type": "git_lfs_concat" }, - "torch28-cxx11-cu128-x86_64-linux": { - "hash": "sha256-a3915686cc58641a3361ece63ab77b33e9d30315dea12547e4bda008d8810a01", + "torch28-cxx11-cu129-x86_64-linux": { + "hash": "sha256-6faf41b72b64b94853b39935e4cdfe62a1242ffce3b0561ebcbd050c39a28b15", "hash_type": "git_lfs_concat" }, - "torch28-cxx11-cu129-aarch64-linux": { - "hash": "sha256-a24dca8e998f88be42491921c9df89d88a6112ca630acd2efc2dd34a64b91fcb", + "torch28-cxx11-rocm63-x86_64-linux": { + "hash": "sha256-520815c9095eba3ec12f689c987043f56626799e4474729467e9250a49fb47b7", "hash_type": "git_lfs_concat" }, - "torch28-cxx11-cu129-x86_64-linux": { - "hash": "sha256-df6c70a70f425db2f68b86561c6f93c5675c1d5e5d058766d88ab17472229907", + "torch28-cxx11-rocm64-x86_64-linux": { + "hash": "sha256-b62f9dcd72bd9b9fba8387a0731f05839f7a0e917eebcc22a4cf07be69130045", "hash_type": "git_lfs_concat" }, - "torch29-cxx11-cu126-aarch64-linux": { - "hash": "sha256-c120011c201072b4cfd70c2ba2d45c2f05337feaf604ddec3c6c4987def33ab3", + "torch28-cxx11-xpu20251-x86_64-linux": { + "hash": "sha256-e848d98d98d545c826326b1341969b41783c3a87ebf49a097598bbf9b9b0b700", "hash_type": "git_lfs_concat" }, - "torch29-cxx11-cu126-x86_64-linux": { - "hash": "sha256-765a7f3279009979be4001a23c5c70e5e6ab9553098d67886731a5275a6d4b32", + "torch28-metal-aarch64-darwin": { + "hash": "sha256-1b464c4ff101158571ad66b1b2060c0397607446cf6fae81d0d4255d0a7a677f", "hash_type": "git_lfs_concat" }, - "torch29-cxx11-cu128-aarch64-linux": { - "hash": "sha256-266d057a9cd82b872a0e02f09ac5e2660fcffcf9a7b7fa1fa8ff33dc19c0f5c2", + "torch29-cxx11-cu126-x86_64-linux": { + "hash": "sha256-c02518dda82caa737eaf4cb9463f7d561e183eb5bd7f0c5b34f2342d43f1647d", "hash_type": "git_lfs_concat" }, "torch29-cxx11-cu128-x86_64-linux": { - "hash": "sha256-6850e594ba4588f289b5904eb88eda5a41870ee20a3bf1586f3268307caf4b53", + "hash": "sha256-390c5c7e068003662ed53cf2dfe0928c24c8b14dec3909e2f17029b1982735b2", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-cu130-x86_64-linux": { + "hash": "sha256-b9af45dacb1ed41d5af9616850d665e57c0134e854dda9c9fc6475485cab2c42", "hash_type": "git_lfs_concat" }, - "torch29-cxx11-cu130-aarch64-linux": { - "hash": "sha256-23741b935462b53bdf868f8d1c9c8cff5f02f71ea3b0550df41dc8b030b0b474", + "torch29-cxx11-rocm63-x86_64-linux": { + "hash": "sha256-642f7e0d64f46a4927045720e7785c03d81b5cb53fa13d3dc49adc03792b2d42", "hash_type": "git_lfs_concat" }, - "torch29-cxx11-cu130-x86_64-linux": { - "hash": "sha256-b884ae792dc1eada071f31645add0c2c76d479864f25aebcdd8318b675aaaf29", + "torch29-cxx11-rocm64-x86_64-linux": { + "hash": "sha256-0b926ca76faf56d491c85400756bd95621c3d7ddb349700b000bfdcf7165e461", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-xpu20252-x86_64-linux": { + "hash": "sha256-1dfd8c3f3d7919c5afba5ea300e2d0d64ebca1392d0aded5c46d5d63c4053b11", + "hash_type": "git_lfs_concat" + }, + "torch29-metal-aarch64-darwin": { + "hash": "sha256-007a149ec48851db679736715c6296851256899647177d2e6430cca5ecc9b744", "hash_type": "git_lfs_concat" } } }, { - "repo_id": "kernels-community/triton-scaled-mm", - "sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f", + "repo_id": "kernels-test/versions", + "sha": "d50441c818a0fce4879f73ae2f3188274f261927", "variants": { + "torch-cpu": { + "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash_type": "git_lfs_concat" + }, + "torch-cuda": { + "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash_type": "git_lfs_concat" + }, + "torch-rocm": { + "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", + "hash_type": "git_lfs_concat" + }, "torch-universal": { - "hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52", + "hash": "sha256-57de77a9bde54f52d0a67eb9e5d259d223cae66963f644ed2b7386f59f7e2d23", + "hash_type": "git_lfs_concat" + }, + "torch-xpu": { + "hash": "sha256-d70e804797597372e50001a1e631e96bb38ccec669f5f0a47d7e9863af293447", "hash_type": "git_lfs_concat" } } diff --git a/tests/kernel_locking/pyproject.toml b/tests/kernel_locking/pyproject.toml index eeffd5a..e50f22e 100644 --- a/tests/kernel_locking/pyproject.toml +++ b/tests/kernel_locking/pyproject.toml @@ -1,3 +1,3 @@ [tool.kernels.dependencies] -"kernels-community/activation" = ">=0.0.2" -"kernels-community/triton-scaled-mm" = ">=0.0.2" +"kernels-community/relu" = 1 +"kernels-test/versions" = 2 diff --git a/tests/kernel_locking_old/kernels.lock b/tests/kernel_locking_old/kernels.lock new file mode 100644 index 0000000..987d160 --- /dev/null +++ b/tests/kernel_locking_old/kernels.lock @@ -0,0 +1,82 @@ +[ + { + "repo_id": "kernels-community/activation", + "sha": "83046852be158d525114f68513cd79fd88911b37", + "variants": { + "torch27-cxx11-cu118-x86_64-linux": { + "hash": "sha256-e34965c814c4c092fcb634ebadefe82ea9a05b98343f8ebdefa7305dcc05359e", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu126-x86_64-linux": { + "hash": "sha256-5f92b35922b37224a416398a39a29b7e5f1aca1df17d5c69f1b9e9cdb7033561", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-aarch64-linux": { + "hash": "sha256-125967cb23bacd2cec443799f184ac08247dfff33f5027e54ee16d3779ca5986", + "hash_type": "git_lfs_concat" + }, + "torch27-cxx11-cu128-x86_64-linux": { + "hash": "sha256-496a84c99d7035a1b6f0ea1c026b751c3a2677956f4c1be546d3cc1505a5fdbb", + "hash_type": "git_lfs_concat" + }, + "torch28-cxx11-cu126-aarch64-linux": { + "hash": "sha256-f0775a30ffa290c90aba3a41037e3ca91edb15b4a9367561fafd5f25455e117a", + "hash_type": "git_lfs_concat" + }, + "torch28-cxx11-cu126-x86_64-linux": { + "hash": "sha256-081995e6230f306bdf6111186618794f2411cf0ffd9b4800330df60b4ebe1927", + "hash_type": "git_lfs_concat" + }, + "torch28-cxx11-cu128-aarch64-linux": { + "hash": "sha256-b937fef62a0c1cd71ab98490b651c473577af209b9a3e2a6b452350283d8812c", + "hash_type": "git_lfs_concat" + }, + "torch28-cxx11-cu128-x86_64-linux": { + "hash": "sha256-a3915686cc58641a3361ece63ab77b33e9d30315dea12547e4bda008d8810a01", + "hash_type": "git_lfs_concat" + }, + "torch28-cxx11-cu129-aarch64-linux": { + "hash": "sha256-a24dca8e998f88be42491921c9df89d88a6112ca630acd2efc2dd34a64b91fcb", + "hash_type": "git_lfs_concat" + }, + "torch28-cxx11-cu129-x86_64-linux": { + "hash": "sha256-df6c70a70f425db2f68b86561c6f93c5675c1d5e5d058766d88ab17472229907", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-cu126-aarch64-linux": { + "hash": "sha256-c120011c201072b4cfd70c2ba2d45c2f05337feaf604ddec3c6c4987def33ab3", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-cu126-x86_64-linux": { + "hash": "sha256-765a7f3279009979be4001a23c5c70e5e6ab9553098d67886731a5275a6d4b32", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-cu128-aarch64-linux": { + "hash": "sha256-266d057a9cd82b872a0e02f09ac5e2660fcffcf9a7b7fa1fa8ff33dc19c0f5c2", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-cu128-x86_64-linux": { + "hash": "sha256-6850e594ba4588f289b5904eb88eda5a41870ee20a3bf1586f3268307caf4b53", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-cu130-aarch64-linux": { + "hash": "sha256-23741b935462b53bdf868f8d1c9c8cff5f02f71ea3b0550df41dc8b030b0b474", + "hash_type": "git_lfs_concat" + }, + "torch29-cxx11-cu130-x86_64-linux": { + "hash": "sha256-b884ae792dc1eada071f31645add0c2c76d479864f25aebcdd8318b675aaaf29", + "hash_type": "git_lfs_concat" + } + } + }, + { + "repo_id": "kernels-community/triton-scaled-mm", + "sha": "af10d8c1affe8efce93d228c3e6e64ff673d493f", + "variants": { + "torch-universal": { + "hash": "sha256-b843c5f30b52b6c1c56fca28cb0cf453be71d6ce7d308f383dce71a8050f7b52", + "hash_type": "git_lfs_concat" + } + } + } +] \ No newline at end of file diff --git a/tests/kernel_locking_old/pyproject.toml b/tests/kernel_locking_old/pyproject.toml new file mode 100644 index 0000000..eeffd5a --- /dev/null +++ b/tests/kernel_locking_old/pyproject.toml @@ -0,0 +1,3 @@ +[tool.kernels.dependencies] +"kernels-community/activation" = ">=0.0.2" +"kernels-community/triton-scaled-mm" = ">=0.0.2" diff --git a/tests/layer_locking/kernels.lock b/tests/layer_locking/kernels.lock index 806eab6..cd51755 100644 --- a/tests/layer_locking/kernels.lock +++ b/tests/layer_locking/kernels.lock @@ -1,10 +1,26 @@ [ { "repo_id": "kernels-test/versions", - "sha": "dc142fd6c9920c993d32be6358b78957c58681c3", + "sha": "7dc0217a6c2e9eb7d7470aa32554dadbbf619b02", "variants": { + "torch-cpu": { + "hash": "sha256-e1e60c9e2c69f1b60614f5867e22e9147a2419984b4ef77041ae920c7abc4a73", + "hash_type": "git_lfs_concat" + }, + "torch-cuda": { + "hash": "sha256-e1e60c9e2c69f1b60614f5867e22e9147a2419984b4ef77041ae920c7abc4a73", + "hash_type": "git_lfs_concat" + }, + "torch-rocm": { + "hash": "sha256-e1e60c9e2c69f1b60614f5867e22e9147a2419984b4ef77041ae920c7abc4a73", + "hash_type": "git_lfs_concat" + }, "torch-universal": { - "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c", + "hash": "sha256-57de77a9bde54f52d0a67eb9e5d259d223cae66963f644ed2b7386f59f7e2d23", + "hash_type": "git_lfs_concat" + }, + "torch-xpu": { + "hash": "sha256-e1e60c9e2c69f1b60614f5867e22e9147a2419984b4ef77041ae920c7abc4a73", "hash_type": "git_lfs_concat" } } diff --git a/tests/layer_locking/pyproject.toml b/tests/layer_locking/pyproject.toml index 92a0bc1..1b5a77f 100644 --- a/tests/layer_locking/pyproject.toml +++ b/tests/layer_locking/pyproject.toml @@ -1,2 +1,2 @@ [tool.kernels.dependencies] -"kernels-test/versions" = ">=0.1.0,<0.2.0" +"kernels-test/versions" = 1 diff --git a/tests/layer_locking_old/kernels.lock b/tests/layer_locking_old/kernels.lock new file mode 100644 index 0000000..806eab6 --- /dev/null +++ b/tests/layer_locking_old/kernels.lock @@ -0,0 +1,12 @@ +[ + { + "repo_id": "kernels-test/versions", + "sha": "dc142fd6c9920c993d32be6358b78957c58681c3", + "variants": { + "torch-universal": { + "hash": "sha256-35ce0ccfe68e392cbc06feef72268f4c41a74b9920496a2c6ee8978db7f7c17c", + "hash_type": "git_lfs_concat" + } + } + } +] \ No newline at end of file diff --git a/tests/layer_locking_old/pyproject.toml b/tests/layer_locking_old/pyproject.toml new file mode 100644 index 0000000..92a0bc1 --- /dev/null +++ b/tests/layer_locking_old/pyproject.toml @@ -0,0 +1,2 @@ +[tool.kernels.dependencies] +"kernels-test/versions" = ">=0.1.0,<0.2.0" diff --git a/tests/test_basic.py b/tests/test_basic.py index b8cb2db..5c90ad5 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -129,7 +129,8 @@ def test_has_kernel(kernel_exists): assert has_kernel(repo_id, revision=revision) == kernel -def test_version(): +def test_version_old(): + # Remove once we drop support for version specs. kernel = get_kernel("kernels-test/versions") assert kernel.version() == "0.2.0" kernel = get_kernel("kernels-test/versions", version="<1.0.0") @@ -142,12 +143,24 @@ def test_version(): with pytest.raises(ValueError, match=r"No version.*satisfies requirement"): get_kernel("kernels-test/versions", version=">0.2.0") - with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"): + with pytest.raises(ValueError, match=r"Only one of"): kernel = get_kernel( "kernels-test/versions", revision="v0.1.0", version="<1.0.0" ) +def test_version(): + kernel = get_kernel("kernels-test/versions", version=1) + assert kernel.version() == "1" + kernel = get_kernel("kernels-test/versions", version=2) + assert kernel.version() == "2" + + with pytest.raises( + ValueError, match="Version 0 not found, available versions: 1, 2.*" + ): + kernel = get_kernel("kernels-test/versions", version=0) + + @pytest.mark.cuda_only def test_universal_kernel(universal_kernel): torch.manual_seed(0) diff --git a/tests/test_kernel_locking.py b/tests/test_kernel_locking.py index 2cbe905..24de3e5 100644 --- a/tests/test_kernel_locking.py +++ b/tests/test_kernel_locking.py @@ -29,11 +29,24 @@ def test_download_all_hash_validation(): download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) +def test_download_all_hash_validation_old(): + project_dir = Path(__file__).parent / "kernel_locking_old" + download_kernels(DownloadArgs(all_variants=True, project_dir=project_dir)) + + @pytest.mark.cuda_only def test_load_locked(): project_dir = Path(__file__).parent / "kernel_locking" # Also validates that hashing works correctly. download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) + load_kernel("kernels-community/relu", lockfile=project_dir / "kernels.lock") + + +@pytest.mark.cuda_only +def test_load_locked_old(): + project_dir = Path(__file__).parent / "kernel_locking_old" + # Also validates that hashing works correctly. + download_kernels(DownloadArgs(all_variants=False, project_dir=project_dir)) load_kernel("kernels-community/activation", lockfile=project_dir / "kernels.lock") @@ -47,6 +60,31 @@ def forward(self) -> str: version = Version() + with use_kernel_mapping( + { + "Version": { + device: LockedLayerRepository( + repo_id="kernels-test/versions", + layer_name="Version", + lockfile=project_dir / "kernels.lock", + ) + }, + } + ): + version = kernelize(version, device=device, mode=Mode.INFERENCE) + assert version() == "1" + + +def test_layer_locked_old(device): + project_dir = Path(__file__).parent / "layer_locking_old" + + @use_kernel_forward_from_hub("Version") + class Version(nn.Module): + def forward(self) -> str: + return "0.0.0" + + version = Version() + with use_kernel_mapping( { "Version": { @@ -79,7 +117,43 @@ def forward(self) -> str: model = Version() - print(model.version.forward) + with use_kernel_mapping( + { + "version": { + device: LockedFuncRepository( + repo_id="kernels-test/versions", + func_name="version", + lockfile=project_dir / "kernels.lock", + ) + }, + } + ): + model = kernelize(model, device=device, mode=Mode.INFERENCE) + + assert version() == "1" + + with use_kernel_mapping({"version": {}}): + model = kernelize(model, mode=Mode.INFERENCE, device=device) + + assert version() == "0.0.0" + + +def test_func_locked_old(device): + project_dir = Path(__file__).parent / "layer_locking_old" + + @use_kernel_func_from_hub("version") + def version(): + return "0.0.0" + + class Version(nn.Module): + def __init__(self): + super().__init__() + self.version = version + + def forward(self) -> str: + return self.version() + + model = Version() with use_kernel_mapping( { @@ -96,11 +170,7 @@ def forward(self) -> str: assert version() == "0.1.1" - print(model.version.forward) - with use_kernel_mapping({"version": {}}): model = kernelize(model, mode=Mode.INFERENCE, device=device) assert version() == "0.0.0" - - print(model.version.forward) diff --git a/tests/test_layer.py b/tests/test_layer.py index c08e470..6cf0179 100644 --- a/tests/test_layer.py +++ b/tests/test_layer.py @@ -1053,7 +1053,7 @@ def test_kernel_modes_cross_fallback(): assert linear.n_calls == 2 -def test_layer_versions(device): +def test_layer_versions_old(device): @use_kernel_forward_from_hub("Version") class Version(nn.Module): def forward(self) -> str: @@ -1143,3 +1143,83 @@ def forward(self) -> str: } } ) + + +def test_layer_versions(device): + @use_kernel_forward_from_hub("Version") + class Version(nn.Module): + def forward(self) -> str: + return "0.0.0" + + version = Version() + + with use_kernel_mapping( + { + "Version": { + Device(type=device): LayerRepository( + repo_id="kernels-test/versions", + layer_name="Version", + ) + } + } + ): + version = kernelize(version, device=device, mode=Mode.INFERENCE) + assert version() == "0.2.0" + + with use_kernel_mapping( + { + "Version": { + Device(type=device): LayerRepository( + repo_id="kernels-test/versions", + layer_name="Version", + version=1, + ) + } + } + ): + version = kernelize(version, device=device, mode=Mode.INFERENCE) + assert version() == "1" + + with use_kernel_mapping( + { + "Version": { + Device(type=device): LayerRepository( + repo_id="kernels-test/versions", + layer_name="Version", + version=2, + ) + } + } + ): + version = kernelize(version, device=device, mode=Mode.INFERENCE) + assert version() == "2" + + with use_kernel_mapping( + { + "Version": { + Device(type=device): LayerRepository( + repo_id="kernels-test/versions", + layer_name="Version", + version=0, + ) + } + } + ): + with pytest.raises( + ValueError, match=r"Version 0 not found, available versions: 1, 2.*" + ): + kernelize(version, device=device, mode=Mode.INFERENCE) + + with pytest.raises(ValueError, match=r"Either a revision or a version.*not both"): + use_kernel_mapping( + { + "Version": { + Device(type=device): LayerRepository( + repo_id="kernels-test/versions", + layer_name="Version", + revision="v0.1.0", + version=1, + ) + } + } + ) From 537e8e6acf12b3269fce1a2fb037291aab9c15fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Dani=C3=ABl=20de=20Kok?= Date: Thu, 15 Jan 2026 12:21:37 +0000 Subject: [PATCH 2/2] Documtenation updates --- README.md | 2 +- docs/source/basic-usage.md | 9 +++------ docs/source/layers.md | 3 ++- examples/basic.py | 2 +- src/kernels/layer/kernelize.py | 9 ++++++--- src/kernels/layer/layer.py | 8 +------- 6 files changed, 14 insertions(+), 19 deletions(-) diff --git a/README.md b/README.md index 8eaed0f..3fdc373 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ import torch from kernels import get_kernel # Download optimized kernels from the Hugging Face hub -activation = get_kernel("kernels-community/activation") +activation = get_kernel("kernels-community/activation", version=1) # Random tensor x = torch.randn((10, 10), dtype=torch.float16, device="cuda") diff --git a/docs/source/basic-usage.md b/docs/source/basic-usage.md index bd520b6..1053dec 100644 --- a/docs/source/basic-usage.md +++ b/docs/source/basic-usage.md @@ -23,13 +23,10 @@ print(y) This fetches version `1` of the kernel `kernels-community/activation`. Kernels are versioned using a major version number. Using `version=1` will -get the latest kernel build from the `v1` branch. The kernel version is -bumped is bumped in the following circumstances: +get the latest kernel build from the `v1` branch. -* The kernel API changes in an incompatible way. -* Support for an older PyTorch version is dropped. - -In this way, you can ensure that your code will continue to work. +Kernels within a version branch must never break the API or remove builds +for older PyTorch versions. This ensures that your code will continue to work. ## Checking Kernel Availability diff --git a/docs/source/layers.md b/docs/source/layers.md index 6650fc3..254133d 100644 --- a/docs/source/layers.md +++ b/docs/source/layers.md @@ -159,10 +159,12 @@ kernel_layer_mapping = { "cuda": LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", + version=1, ), "rocm": LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", + version=1, ) } } @@ -364,7 +366,6 @@ with use_kernel_mapping( repo_path="/home/daniel/kernels/activation", package_name="activation", layer_name="SiluAndMul", - version=1, ) } }, diff --git a/examples/basic.py b/examples/basic.py index b814bc7..9106e6e 100644 --- a/examples/basic.py +++ b/examples/basic.py @@ -5,7 +5,7 @@ print("Starting examples/basic.py demo") # Download optimized kernels from the Hugging Face hub -activation = get_kernel("kernels-community/activation") +activation = get_kernel("kernels-community/activation", version=1) print("Activation kernel fetched") diff --git a/src/kernels/layer/kernelize.py b/src/kernels/layer/kernelize.py index 545595a..9ad23fe 100644 --- a/src/kernels/layer/kernelize.py +++ b/src/kernels/layer/kernelize.py @@ -63,6 +63,7 @@ def use_kernel_mapping( "cuda": LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", + version=1 ) } } @@ -131,9 +132,9 @@ def register_kernel_mapping( kernel_layer_mapping = { "LlamaRMSNorm": { "cuda": LayerRepository( - repo_id="kernels-community/activation", - layer_name="RmsNorm", - revision="layers", + repo_id="kernels-community/layer_norm", + layer_name="LlamaRMSNorm", + version=1, ), }, } @@ -146,10 +147,12 @@ def register_kernel_mapping( Mode.TRAINING: LayerRepository( repo_id="username/training-kernels", layer_name="TrainingAttention" + version=1, ), Mode.INFERENCE: LayerRepository( repo_id="username/inference-kernels", layer_name="FastAttention" + version=1, ), } } diff --git a/src/kernels/layer/layer.py b/src/kernels/layer/layer.py index 94773ae..25115a1 100644 --- a/src/kernels/layer/layer.py +++ b/src/kernels/layer/layer.py @@ -58,13 +58,7 @@ class LayerRepository: layer_repo = LayerRepository( repo_id="kernels-community/activation", layer_name="SiluAndMul", - ) - - # Reference a layer by version constraint - layer_repo_versioned = LayerRepository( - repo_id="kernels-community/activation", - layer_name="SiluAndMul", - version=">=0.0.3,<0.1" + version=1, ) ``` """