Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
0abfffe
PyTorch version upgrade: tested on single-operator tests
wok1909 Sep 24, 2025
b7a275e
[Test] Add torch.no_grad(), change to use torch.nn.ReLU, fuion off
wok1909 Sep 24, 2025
5c5e61c
[Implement] Hook and GuardImpl for extension device
wok1909 Nov 6, 2025
74704b8
[CI] Change the trigger condition
YWHyuk Jan 6, 2026
d3f3298
[CI] Use CMake 3 to build pytorchsim
YWHyuk Jan 6, 2026
0763363
[CI] Seperate base image
YWHyuk Jan 6, 2026
4591403
[Fix] PyTorch2.8 support (WIP)
YWHyuk Jan 7, 2026
b9d4144
[Fix] Use official prologue fusion path
YWHyuk Jan 7, 2026
9abc060
[Fix] Don't split a reduce kernel
YWHyuk Jan 7, 2026
2c7264b
[Fix] Add a missing reduction fusion condition
YWHyuk Jan 7, 2026
b951b95
[Fix] update indirect_index interface for v2.8
YWHyuk Jan 7, 2026
c6ba98c
[Fix] Allow cpp kernel code in the wrapper function
YWHyuk Jan 7, 2026
fd07eda
[Ops] Use V.kernel instead of argument passing
YWHyuk Jan 8, 2026
4bed31b
[Fix] Set epilogue fusoin condition
YWHyuk Jan 8, 2026
758b5b3
[Fix] Support Identity indexing + Fix wrapper codegen
YWHyuk Jan 8, 2026
a7ab604
[Fix] Keep contextvar after reset()
YWHyuk Jan 8, 2026
cd52f57
[Frontend] Add decompsition of default attetnion
YWHyuk Jan 8, 2026
08e0c8b
[Fix] Add missing case
YWHyuk Jan 8, 2026
1d1508a
[Test] Add GQA test file
YWHyuk Jan 8, 2026
862ba44
[Fix+Log] Change logging system + Fix meta_code interface
YWHyuk Jan 9, 2026
75207a4
[Test] Wrap softmax module
YWHyuk Jan 9, 2026
8df5fef
[Log] Add progress bar for auto-tuning
YWHyuk Jan 9, 2026
d7c16b1
[Test/MoE] Disable compiling sparse dispatcher
YWHyuk Jan 9, 2026
c88cabc
[Fix] Support identity in the dram_stride extraction
YWHyuk Jan 12, 2026
67612bb
[Fix] index to float casting
YWHyuk Jan 12, 2026
50ceb58
[Fix] Change vlane_split_axis in case of group-dim
YWHyuk Jan 12, 2026
319fd6c
[Frontend] Fix any operation codegen
YWHyuk Jan 13, 2026
c223258
[Decompose] Use F.softmax for decomposed SDPA
YWHyuk Jan 13, 2026
07be94b
[Frontend] Add recompiliation for ModularIndexing
YWHyuk Jan 13, 2026
e999bfc
[Test] Fix minor bugs in the test folder
YWHyuk Jan 13, 2026
d747e7e
[Log] Add progress bar in spike simulation
YWHyuk Jan 13, 2026
b49b679
[Fix] Use extraction for vlane_offset + Register extract op
YWHyuk Jan 15, 2026
729b999
[Tests/Diffusion] Add embedding test case
YWHyuk Jan 15, 2026
7fa8d54
[Tests/MoE] Add patch to avoid dynamo bug
YWHyuk Jan 15, 2026
7919094
[Fix] Change wrong TORCHSIM_DUMP_PATH usage
YWHyuk Jan 15, 2026
1ca3348
[Scheduler] Validate pytorchsim_timing_mode != 0 in Scheduler constru…
YWHyuk Jan 15, 2026
8df3bee
[Fix] Move rename_indexing before load cacheing
YWHyuk Jan 15, 2026
ea79ad0
[Fusion] Fix template codegen + Add custom fusion hook
YWHyuk Jan 16, 2026
0c6175f
[Template] Fix template fusion codegen
YWHyuk Jan 19, 2026
a90f114
[Fix] Fusion axis mechanism change
YWHyuk Jan 20, 2026
78613ad
[Test] Fix syntax error in experiment scripts
YWHyuk Jan 22, 2026
21d08f2
[CI] Change base image for OpenReg build
YWHyuk Jan 22, 2026
24e67ed
[OpenReg] Use OpenReg style Custom device
YWHyuk Jan 22, 2026
468f414
[Device] Use torch.device(npu)
YWHyuk Jan 22, 2026
a625409
[SDPA] Use math as a default
YWHyuk Jan 23, 2026
a053314
[AMP] Add amp interface for OpenReg style device
YWHyuk Jan 23, 2026
eda34ff
[Tests] Cleanup unnecessary code in tests
YWHyuk Jan 23, 2026
3f8b866
[Cleanup] Remove built libraries
YWHyuk Jan 23, 2026
174e10f
[Device] Rename deivce PyTorchSimDevice2 to PyTorchSimDevice
YWHyuk Jan 23, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions .github/workflows/docker-base-image-2-8.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ name: Docker Base Image CI (PyTorch 2.8)

on:
push:
branches: [ "base" ]
branches: [ "base_v2.8" ]
workflow_dispatch:
repository_dispatch:
types: [ build_base ]
Expand Down Expand Up @@ -63,7 +63,7 @@ jobs:
file: ./Dockerfile.base
push: true
build-args: |
PYTORCH_IMAGE=pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime
PYTORCH_IMAGE=pytorch/pytorch:2.8.0-cuda12.6-cudnn9-devel
GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }}
LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }}
SPIKE_ASSET_ID=${{ env.SPIKE_ASSET_ID }}
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/docker-image-2-8.yml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name: Docker image CI (PyTorch 2.8)

on:
pull_request:
push:
branches: [ "torch_v2.8" ]
workflow_dispatch:

Expand Down
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
__pycache__/
TOGSim/build/
.vscode
*.txt
*.ipynb_checkpoints
output
togsim_results/*
Expand Down
5 changes: 4 additions & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,7 @@ RUN cd PyTorchSim/TOGSim && \
cd build && \
conan install .. --build=missing && \
cmake .. && \
make -j$(nproc)
make -j$(nproc)

RUN cd PyTorchSim/PyTorchSimDevice && \
python -m pip install --no-build-isolation -e .
2 changes: 1 addition & 1 deletion Dockerfile.base
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ RUN apt -y update && \
python3-dev python-is-python3 libboost-all-dev \
libhdf5-serial-dev python3-pydot libpng-dev libelf-dev pkg-config pip \
python3-venv black libssl-dev libasan5 libubsan1 curl device-tree-compiler wget ninja-build && \
pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 && rm -rf /var/lib/apt/lists/*
pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 cmake==3.26.4 && rm -rf /var/lib/apt/lists/*

# Download RISC-V tool chain
RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2023.12.14/riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.12.14-nightly.tar.gz && \
Expand Down
44 changes: 44 additions & 0 deletions PyTorchSimDevice/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# Top-level build script for the TORCH_OPENREG project.
#
# Variables that must be supplied by whoever invokes CMake:
#   PYTORCH_INSTALL_DIR  - root of the installed torch package; it is used to
#                          locate share/cmake/Torch here and the torch_python
#                          library in cmake/TorchPythonTargets.cmake.
#   PYTHON_INCLUDE_DIR   - directory added to the include path for Python headers.
cmake_minimum_required(VERSION 3.18 FATAL_ERROR)

project(TORCH_OPENREG CXX C)

include(GNUInstallDirs)
include(CheckCXXCompilerFlag)

# Require C++17 outright instead of letting CMake silently fall back to an
# older standard when the compiler does not support it.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
set(CMAKE_C_STANDARD 11)
set(CMAKE_CXX_EXTENSIONS OFF)

set(CMAKE_EXPORT_COMPILE_COMMANDS ON)
set(CMAKE_SKIP_BUILD_RPATH FALSE)
set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE)
set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE)
set(CMAKE_CXX_VISIBILITY_PRESET hidden)

# Default to a Release build, but only when the user has not chosen a build
# type and the generator is single-config. An unconditional FORCE would
# override an explicit -DCMAKE_BUILD_TYPE=Debug.
get_property(is_multi_config GLOBAL PROPERTY GENERATOR_IS_MULTI_CONFIG)
if(NOT is_multi_config AND NOT CMAKE_BUILD_TYPE)
  set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE)
endif()

# Built libraries look for their dependencies next to themselves and in ./lib.
if(APPLE)
  set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path")
elseif(UNIX)
  set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN")
elseif(WIN32)
  set(CMAKE_INSTALL_RPATH "")
endif()
# Overrides the value computed by GNUInstallDirs so every artifact lands in lib/.
set(CMAKE_INSTALL_LIBDIR lib)
set(CMAKE_INSTALL_MESSAGE NEVER)

# Without PYTORCH_INSTALL_DIR both Torch_DIR below and the imported
# torch_python location would silently become root-relative paths.
if(NOT DEFINED PYTORCH_INSTALL_DIR OR "${PYTORCH_INSTALL_DIR}" STREQUAL "")
  message(FATAL_ERROR "PYTORCH_INSTALL_DIR is not set; cannot locate the Torch CMake package")
endif()

set(Torch_DIR "${PYTORCH_INSTALL_DIR}/share/cmake/Torch")
find_package(Torch REQUIRED)

if(DEFINED PYTHON_INCLUDE_DIR)
  include_directories(${PYTHON_INCLUDE_DIR})
else()
  message(FATAL_ERROR "Cannot find Python directory")
endif()

# Kept at directory scope because the sub-projects added below inherit it.
include_directories(${CMAKE_CURRENT_SOURCE_DIR})
include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake)

add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg)
add_subdirectory(${PROJECT_SOURCE_DIR}/csrc)
add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc)
175 changes: 175 additions & 0 deletions PyTorchSimDevice/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# PyTorch OpenReg

## Background

The third-party device integration mechanism based on PrivateUse1 has become the official mainstream method for new backends to integrate with PyTorch. Ensuring the availability of this mechanism is crucial for enriching PyTorch's hardware ecosystem.

**Note:**

The goal of `torch_openreg` is **not to implement a fully functional, high-performance PyTorch backend**, but to serve as a **minimalist reference implementation for mechanism verification**.

### Purpose

- **Test Backend**: To serve as an in-tree test backend for PrivateUse1, ensuring quality stability through CI/CD.
- **Integration Example**: To serve as a reference example for new backend integration.
- **Integration Documentation**: To provide module-level integration documentation that corresponds with the code.

### Design Principles

- **Minimality Principle**: The fundamental goal is to enable/verify all integration paths/mechanisms for a new backend to integrate with PyTorch. All functions follow a "just right" strategy to ensure the correctness of relevant integration capabilities.
- **Authenticity Principle**: To complete the OpenReg integration in the same way a real accelerator backend would integrate with PyTorch.

## Directory Structure

```shell
torch_openreg/
├── CMakeLists.txt
├── csrc
│ ├── aten
│ │ ├── native
│ │ │ ├── Extra.cpp
│ │ │ ├── Minimal.cpp
│ │ │ └── ...
│ │ ├── OpenRegExtra.cpp
│ │ └── OpenRegMinimal.cpp
│ ├── CMakeLists.txt
│ └── runtime
│ ├── OpenRegDeviceAllocator.cpp
│ ├── OpenRegDeviceAllocator.h
│ ├── OpenRegFunctions.cpp
│ ├── OpenRegFunctions.h
│ ├── OpenRegGenerator.cpp
│ ├── OpenRegGenerator.h
│ ├── OpenRegGuard.cpp
│ ├── OpenRegGuard.h
│ ├── OpenRegHooks.cpp
│ ├── OpenRegHooks.h
│ ├── OpenRegHostAllocator.cpp
│ ├── OpenRegHostAllocator.h
│ └── ...
├── pyproject.toml
├── README.md
├── setup.py
├── third_party
│ └── openreg
└── torch_openreg
├── csrc
│ ├── CMakeLists.txt
│ ├── Module.cpp
│ └── stub.c
├── __init__.py
└── openreg
├── __init__.py
├── meta.py
└── random.py
```

**Dependencies**:

```mermaid
graph LR
A[Python]
B[_C.so]
C[libtorch_bindings.so]
D[libtorch_openreg.so]
E[libopenreg.so]

A --> B --> C --> D --> E
```

There are 4 DSOs in torch_openreg, and the dependencies between them are as follows:

- `_C.so`:
- **sources**: torch_openreg/csrc/stub.c
- **description**: Python C module entry point.
- `libtorch_bindings.so`: The bridging code between Python and C++ should go here.
- **sources**: torch_openreg/csrc
- **description**: A thin glue layer between Python and C++.
- `libtorch_openreg.so`: All core implementations should go here.
- **sources**: csrc
- **description**: All core functionality, such as device runtime, operators, etc.
- `libopenreg.so`: A DSO that uses the CPU to emulate a CUDA-like device; you can ignore it.
- **sources**: third_party/openreg
- **description**: Provides low-level device functionality similar to libcudart.so.

**Key Directories**:

- `csrc/`: Core device implementation, including operator registration, runtime, etc.
- `csrc/aten/`: Operator registration
- `csrc/aten/native/`: Specific operator implementations for the OpenReg device.
- `csrc/aten/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion).
- `csrc/aten/OpenRegExtra.cpp`: Implementations for other types of operators.
- `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc.
- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU.
- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings).
- `torch_openreg/csrc/`: Python C++ binding code.
- `torch_openreg/openreg/`: Python API.

## Currently Implemented Features

### Operator Registration

- Operator Implementation

- Register for builtin PyTorch Operators
- `TORCH_LIBRARY_IMPL` form: See `empty.memory_format`
- `STUB` form: See `abs_stub`
- Register for custom operators
- Schema Registration: See `custom_abs`
- Kernel Registration: See `custom_abs`
- Fallback Registration for `AutogradPrivateUse1`: See `custom_abs`
- Meta Registration: See `custom_abs`
- `torch.autograd.Function`: See `custom_autograd_fn_aliasing`
- Register for fallback
- Per-operator Fallback: See `sub.Tensor`
- Global Fallback: See `wrapper_cpu_fallback`

## Installation and Usage

### Installation

```shell
pip3 install --no-build-isolation -e . # for develop
pip3 install --no-build-isolation . # for install
```

### Usage Example

After installation, you can use the `openreg` device in Python just like any other regular device.

```python
import torch
import torch_openreg

if not torch.openreg.is_available():
print("OpenReg backend is not available in this build.")
exit()

print("OpenReg backend is available!")

device = torch.device("openreg")

x = torch.tensor([[1., 2.], [3., 4.]], device=device)
y = x + 2
print("Result y:\n", y)
print(f"Device of y: {y.device}")

z = y.cpu()
print("Result z:\n", z)
print(f"Device of z: {z.device}")
```

## Future Plans

- **Enhance Features**:
- Autoload
- AMP
- Device-agnostic APIs
- Memory Management
- Generator
- Distributed
- Custom Tensor&Storage
- ...
- **Improve Tests**: Add more test cases related to the integration mechanism.
- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation.
- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync.
22 changes: 22 additions & 0 deletions PyTorchSimDevice/cmake/TorchPythonTargets.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
# Declares imported targets for the prebuilt torch_python library found under
# ${PYTORCH_INSTALL_DIR}/lib:
#   torch_python          - the shared library itself
#   torch_python_library  - interface target that forwards its file, include
#                           directories and link dependencies to consumers

# Select the platform-specific library file.
if(APPLE)
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib")
elseif(WIN32)
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib")
else()
  set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so")
endif()

add_library(torch_python SHARED IMPORTED)
set_property(TARGET torch_python PROPERTY
  IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}")
set_property(TARGET torch_python PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include")
set_property(TARGET torch_python PROPERTY
  INTERFACE_LINK_LIBRARIES c10 torch_cpu)

# The generator expressions are resolved at generate time, so this target
# always mirrors whatever torch_python is configured with above.
add_library(torch_python_library INTERFACE IMPORTED)
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_INCLUDE_DIRECTORIES
    "$<TARGET_PROPERTY:torch_python,INTERFACE_INCLUDE_DIRECTORIES>")
set_property(TARGET torch_python_library PROPERTY
  INTERFACE_LINK_LIBRARIES
    "$<TARGET_FILE:torch_python>"
    "$<TARGET_PROPERTY:torch_python,INTERFACE_LINK_LIBRARIES>")
16 changes: 16 additions & 0 deletions PyTorchSimDevice/csrc/CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
# Builds the torch_openreg shared library from every .cpp file found below
# this directory and installs it into ${CMAKE_INSTALL_LIBDIR}.
set(LIBRARY_NAME torch_openreg)

# CONFIGURE_DEPENDS re-evaluates the glob at build time, so adding or removing
# a .cpp file triggers a re-configure instead of being silently missed.
file(GLOB_RECURSE SOURCE_FILES CONFIGURE_DEPENDS
  "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp"
)

add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES})

target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg)
# PUBLIC: targets linking against this library also get this directory on
# their include path.
target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR})

install(TARGETS ${LIBRARY_NAME}
  ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR}
  LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR}
  RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR}
)
15 changes: 15 additions & 0 deletions PyTorchSimDevice/csrc/amp/OpenRegAmp.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
#pragma once

#include <c10/core/ScalarType.h>
#include <c10/macros/Macros.h>

#include <include/Macros.h>

namespace c10::openreg {

// Autocast (AMP) on/off switch.
// set_amp_enabled stores the flag; is_amp_enabled reads it back.
OPENREG_EXPORT void set_amp_enabled(bool flag);
OPENREG_EXPORT bool is_amp_enabled();

// Autocast (AMP) dtype.
// set_amp_dtype stores the dtype; get_amp_dtype reads it back.
OPENREG_EXPORT void set_amp_dtype(at::ScalarType dtype);
OPENREG_EXPORT at::ScalarType get_amp_dtype();

} // namespace c10::openreg
28 changes: 28 additions & 0 deletions PyTorchSimDevice/csrc/amp/auto_cast_mode.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
#include <ATen/autocast_mode.h>
#include <iostream>
#include "OpenRegAmp.h"

namespace {

// File-local autocast (AMP) settings. They are plain variables shared by the
// whole process; no locking is performed around reads or writes.
struct AmpState {
  bool enabled = false;              // autocast starts switched off
  at::ScalarType dtype = at::kFloat; // initial autocast dtype
};

AmpState g_amp_state;

} // namespace

namespace c10::openreg {

// Returns the current autocast on/off flag.
OPENREG_EXPORT bool is_amp_enabled() {
  return g_amp_state.enabled;
}

// Overwrites the autocast on/off flag with `flag`.
OPENREG_EXPORT void set_amp_enabled(bool flag) {
  g_amp_state.enabled = flag;
}

// Returns the dtype currently stored for autocast.
OPENREG_EXPORT at::ScalarType get_amp_dtype() {
  return g_amp_state.dtype;
}

// Overwrites the dtype stored for autocast with `dtype`.
OPENREG_EXPORT void set_amp_dtype(at::ScalarType dtype) {
  g_amp_state.dtype = dtype;
}

} // namespace c10::openreg
Loading
Loading