9 changes: 9 additions & 0 deletions API.md
@@ -13,6 +13,15 @@ Or you can use the library:
torchruntime.install(["torch", "torchvision<0.20"])
```

On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).
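
As a quick post-install sanity check, something like the following can be used (a minimal sketch; it assumes only the packages installed above and the Python standard library):

```py
# Hedged check: confirm Triton is importable and that torch.compile round-trips a simple function.
import importlib.util

import torch

print("triton installed:", importlib.util.find_spec("triton") is not None)

compiled = torch.compile(lambda x: x * 2 + 1)  # requires torch >= 2.0
print(compiled(torch.ones(4)))  # expected: tensor([3., 3., 3., 3.])
```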

## Test torch
Run:
`python -m torchruntime test`

To specifically verify `torch.compile` / Triton:
`python -m torchruntime test compile`

## Get device info
You can use the device database built into `torchruntime` for your projects:
```py
4 changes: 3 additions & 1 deletion README.md
@@ -30,6 +30,8 @@ Supports Windows, Linux, and Mac.

This will install `torch`, `torchvision`, and `torchaudio`, and will decide the variant based on the user's OS, GPU manufacturer and GPU model number. See [customizing packages](#customizing-packages) for more options.

On Windows CUDA, Linux ROCm (6.x+), and Linux XPU, this also installs the appropriate Triton package to enable `torch.compile` (`triton-windows`, `pytorch-triton-rocm`, or `pytorch-triton-xpu`).

**Tip:** You can also add the `--uv` flag to install packages using [uv](https://docs.astral.sh/uv/) (instead of `pip`). For example: `python -m torchruntime install --uv`

### Step 2. Configure torch
@@ -42,7 +44,7 @@ torchruntime.configure()
```

### (Optional) Step 3. Test torch
Run `python -m torchruntime test` to run a set of tests to check whether the installed version of torch is working correctly.
Run `python -m torchruntime test` to run a set of tests that check whether the installed version of torch is working correctly (including a `torch.compile` / Triton check on CUDA/XPU systems). To run only the compile check, use `python -m torchruntime test compile`.

## Customizing packages
By default, `python -m torchruntime install` will install the latest available `torch`, `torchvision` and `torchaudio` suitable on the user's platform.
2 changes: 1 addition & 1 deletion setup.py
@@ -1,5 +1,5 @@
import setuptools

setuptools.setup(
    install_requires=[],
    install_requires=["packaging"],
)
40 changes: 37 additions & 3 deletions tests/test_installer.py
@@ -16,27 +16,57 @@ def test_cpu_platform():
    assert result == [packages]


def test_cuda_platform():
def test_cuda_platform(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
    packages = ["torch", "torchvision"]
    result = get_install_commands("cu112", packages)
    expected_url = "https://download.pytorch.org/whl/cu112"
    assert result == [packages + ["--index-url", expected_url]]


def test_cuda_nightly_platform():
def test_cuda_platform_windows_installs_triton(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
    packages = ["torch", "torchvision"]
    result = get_install_commands("cu112", packages)
    expected_url = "https://download.pytorch.org/whl/cu112"
    assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]


def test_cuda_nightly_platform(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
    packages = ["torch", "torchvision"]
    result = get_install_commands("nightly/cu112", packages)
    expected_url = "https://download.pytorch.org/whl/nightly/cu112"
    assert result == [packages + ["--index-url", expected_url]]


def test_cuda_nightly_platform_windows_installs_triton(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
    packages = ["torch", "torchvision"]
    result = get_install_commands("nightly/cu112", packages)
    expected_url = "https://download.pytorch.org/whl/nightly/cu112"
    assert result == [packages + ["--index-url", expected_url], ["triton-windows"]]


def test_rocm_platform():
    packages = ["torch", "torchvision"]
    result = get_install_commands("rocm4.2", packages)
    expected_url = "https://download.pytorch.org/whl/rocm4.2"
    assert result == [packages + ["--index-url", expected_url]]


def test_rocm_platform_linux_installs_triton(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Linux")
    packages = ["torch", "torchvision"]
    result = get_install_commands("rocm6.2", packages)
    expected_url = "https://download.pytorch.org/whl/rocm6.2"
    triton_index_url = "https://download.pytorch.org/whl"
    assert result == [
        packages + ["--index-url", expected_url],
        ["pytorch-triton-rocm", "--index-url", triton_index_url],
    ]


def test_xpu_platform_windows_with_torch_only(monkeypatch):
    monkeypatch.setattr("torchruntime.installer.os_name", "Windows")
    packages = ["torch"]
@@ -60,7 +90,11 @@ def test_xpu_platform_linux(monkeypatch):
    packages = ["torch", "torchvision"]
    result = get_install_commands("xpu", packages)
    expected_url = "https://download.pytorch.org/whl/test/xpu"
    assert result == [packages + ["--index-url", expected_url]]
    triton_index_url = "https://download.pytorch.org/whl"
    assert result == [
        packages + ["--index-url", expected_url],
        ["pytorch-triton-xpu", "--index-url", triton_index_url],
    ]


def test_directml_platform():
5 changes: 3 additions & 2 deletions torchruntime/__main__.py
@@ -10,7 +10,7 @@ def print_usage(entry_command: str):

Commands:
install Install PyTorch packages
test [subcommand] Run tests (subcommands: all, devices, math, functions)
test [subcommand] Run tests (subcommands: all, import, devices, compile, math, functions)
--help Show this help message

Examples:
@@ -20,10 +20,11 @@ def print_usage(entry_command: str):
{entry_command} install --uv torch>=2.0.0 torchaudio
{entry_command} install torch==2.1.* torchvision>=0.16.0 torchaudio==2.1.0

{entry_command} test # Runs all tests (import, devices, math, functions)
{entry_command} test # Runs all tests (import, devices, compile, math, functions)
{entry_command} test all # Same as above
{entry_command} test import # Test only import
{entry_command} test devices # Test only devices
{entry_command} test compile # Test torch.compile (Triton)
{entry_command} test math # Test only math
{entry_command} test functions # Test only functions

21 changes: 19 additions & 2 deletions torchruntime/installer.py
@@ -12,6 +12,7 @@
PIP_PREFIX = [sys.executable, "-m", "pip", "install"]
CUDA_REGEX = re.compile(r"^(nightly/)?cu\d+$")
ROCM_REGEX = re.compile(r"^(nightly/)?rocm\d+\.\d+$")
ROCM_VERSION_REGEX = re.compile(r"^(?:nightly/)?rocm(?P<major>\d+)\.(?P<minor>\d+)$")


def get_install_commands(torch_platform, packages):
@@ -43,6 +44,9 @@ def get_install_commands(torch_platform, packages):
- For "xpu" on Windows, if torchvision or torchaudio are included, the function switches to nightly builds.
- For "directml", the "torch-directml" package is returned as part of the installation commands.
- For "ipex", the "intel-extension-for-pytorch" package is returned as part of the installation commands.
- For Windows CUDA, the function also installs "triton-windows" (for torch.compile and Triton kernels).
- For Linux ROCm 6.x, the function also installs "pytorch-triton-rocm".
- For Linux XPU, the function also installs "pytorch-triton-xpu".
"""
if not packages:
packages = ["torch", "torchaudio", "torchvision"]
@@ -52,7 +56,17 @@ def get_install_commands(torch_platform, packages):

    if CUDA_REGEX.match(torch_platform) or ROCM_REGEX.match(torch_platform):
        index_url = f"https://download.pytorch.org/whl/{torch_platform}"
        return [packages + ["--index-url", index_url]]
        cmds = [packages + ["--index-url", index_url]]

        if os_name == "Windows" and CUDA_REGEX.match(torch_platform):
            cmds.append(["triton-windows"])

        if os_name == "Linux" and ROCM_REGEX.match(torch_platform):
            match = ROCM_VERSION_REGEX.match(torch_platform)
            if match and int(match.group("major")) >= 6:
                cmds.append(["pytorch-triton-rocm", "--index-url", "https://download.pytorch.org/whl"])

        return cmds

if torch_platform == "xpu":
if os_name == "Windows" and ("torchvision" in packages or "torchaudio" in packages):
@@ -65,7 +79,10 @@
        else:
            index_url = f"https://download.pytorch.org/whl/test/{torch_platform}"

        return [packages + ["--index-url", index_url]]
        cmds = [packages + ["--index-url", index_url]]
        if os_name == "Linux":
            cmds.append(["pytorch-triton-xpu", "--index-url", "https://download.pytorch.org/whl"])
        return cmds

    if torch_platform == "directml":
        return [["torch-directml"], packages]
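
For reference, a minimal sketch of what the updated `get_install_commands` is expected to return for Linux ROCm 6.x (mirroring the expectations in `tests/test_installer.py` above; the second command is only added when the detected OS is Linux):

```py
from torchruntime.installer import get_install_commands

cmds = get_install_commands("rocm6.2", ["torch", "torchvision"])
# Expected on Linux:
# [['torch', 'torchvision', '--index-url', 'https://download.pytorch.org/whl/rocm6.2'],
#  ['pytorch-triton-rocm', '--index-url', 'https://download.pytorch.org/whl']]
```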
77 changes: 75 additions & 2 deletions torchruntime/utils/torch_test/__init__.py
@@ -1,6 +1,8 @@
import importlib.util
import platform
import time

from ..torch_device_utils import get_installed_torch_platform, get_device_count, get_device_name, get_device
from ..torch_device_utils import get_device, get_device_count, get_device_name, get_installed_torch_platform


def test(subcommand):
@@ -16,7 +18,7 @@ def test(subcommand):


def test_all():
    for fn in (test_import, test_devices, test_math, test_functions):
    for fn in (test_import, test_devices, test_compile, test_math, test_functions):
        fn()
        print("")

@@ -101,3 +103,74 @@ def test_functions():
    t.run_all_tests()

    print("--- / FUNCTIONAL TEST ---")


def test_compile():
print("--- COMPILE TEST ---")

try:
import torch
except ImportError:
print("torch.compile: SKIPPED (torch not installed)")
print("--- / COMPILE TEST ---")
return

if not hasattr(torch, "compile"):
print("torch.compile: SKIPPED (requires torch>=2.0)")
print("--- / COMPILE TEST ---")
return

torch_platform_name, _ = get_installed_torch_platform()
if torch_platform_name not in ("cuda", "xpu"):
print(f"torch.compile: SKIPPED (unsupported backend: {torch_platform_name})")
print("--- / COMPILE TEST ---")
return

if importlib.util.find_spec("triton") is None:
print("triton: NOT INSTALLED")
else:
print("triton: installed")

device = get_device(0)
print("On torch device:", device)

def f(x):
return x * 2 + 1

    failed = False
    try:
        compiled_f = torch.compile(f)
        x = torch.randn((1024,), device=device)
        y = compiled_f(x)
        expected = f(x)
        if not torch.allclose(y, expected):
            print("torch.compile: FAILED (output mismatch)")
            failed = True
        else:
            if torch_platform_name == "cuda":
                torch.cuda.synchronize()
            if torch_platform_name == "xpu" and hasattr(torch, "xpu") and hasattr(torch.xpu, "synchronize"):
                torch.xpu.synchronize()
            print("torch.compile: PASSED")
    except Exception as e:
        print(f"torch.compile: FAILED ({type(e).__name__}: {e})")
        failed = True

    # Only suggest a Triton install when the compile check did not pass.
    if failed:
        hint = None
        os_name = platform.system()
        if torch_platform_name == "cuda" and os_name == "Windows":
            hint = "pip install triton-windows (or: python -m torchruntime install)"
        elif torch_platform_name == "cuda" and os_name == "Linux":
            if getattr(torch.version, "hip", None):
                hint = (
                    "pip install pytorch-triton-rocm --index-url https://download.pytorch.org/whl "
                    "(or: python -m torchruntime install)"
                )
        elif torch_platform_name == "xpu" and os_name == "Linux":
            hint = (
                "pip install pytorch-triton-xpu --index-url https://download.pytorch.org/whl "
                "(or: python -m torchruntime install)"
            )

        if hint:
            print("If this failed due to Triton, try:")
            print(" ", hint)

    print("--- / COMPILE TEST ---")