diff --git a/.github/workflows/docker-base-image-2-8.yml b/.github/workflows/docker-base-image-2-8.yml index f8649303..74e81e07 100644 --- a/.github/workflows/docker-base-image-2-8.yml +++ b/.github/workflows/docker-base-image-2-8.yml @@ -2,7 +2,7 @@ name: Docker Base Image CI (PyTorch 2.8) on: push: - branches: [ "base" ] + branches: [ "base_v2.8" ] workflow_dispatch: repository_dispatch: types: [ build_base ] @@ -63,7 +63,7 @@ jobs: file: ./Dockerfile.base push: true build-args: | - PYTORCH_IMAGE=pytorch/pytorch:2.8.0-cuda12.6-cudnn9-runtime + PYTORCH_IMAGE=pytorch/pytorch:2.8.0-cuda12.6-cudnn9-devel GEM5_ASSET_ID=${{ env.GEM5_ASSET_ID }} LLVM_ASSET_ID=${{ env.LLVM_ASSET_ID }} SPIKE_ASSET_ID=${{ env.SPIKE_ASSET_ID }} diff --git a/.github/workflows/docker-image-2-8.yml b/.github/workflows/docker-image-2-8.yml index cb5f73d1..4d511a1a 100644 --- a/.github/workflows/docker-image-2-8.yml +++ b/.github/workflows/docker-image-2-8.yml @@ -1,7 +1,7 @@ name: Docker image CI (PyTorch 2.8) on: - pull_request: + push: branches: [ "torch_v2.8" ] workflow_dispatch: diff --git a/.gitignore b/.gitignore index b42d5f6b..3ca1e54b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,6 @@ __pycache__/ TOGSim/build/ .vscode -*.txt *.ipynb_checkpoints output togsim_results/* diff --git a/Dockerfile b/Dockerfile index 088daa43..1c52d32f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -10,4 +10,7 @@ RUN cd PyTorchSim/TOGSim && \ cd build && \ conan install .. --build=missing && \ cmake .. && \ - make -j$(nproc) \ No newline at end of file + make -j$(nproc) + +RUN cd PyTorchSim/PyTorchSimDevice && \ + python -m pip install --no-build-isolation -e . \ No newline at end of file diff --git a/Dockerfile.base b/Dockerfile.base index 897b8195..c5f200bc 100644 --- a/Dockerfile.base +++ b/Dockerfile.base @@ -34,7 +34,7 @@ RUN apt -y update && \ python3-dev python-is-python3 libboost-all-dev \ libhdf5-serial-dev python3-pydot libpng-dev libelf-dev pkg-config pip \ python3-venv black libssl-dev libasan5 libubsan1 curl device-tree-compiler wget ninja-build && \ - pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 && rm -rf /var/lib/apt/lists/* + pip install onnx matplotlib scikit-learn pydot tabulate && pip install --user conan==1.56.0 cmake==3.26.4 && rm -rf /var/lib/apt/lists/* # Download RISC-V tool chain RUN wget https://github.com/riscv-collab/riscv-gnu-toolchain/releases/download/2023.12.14/riscv64-glibc-ubuntu-22.04-llvm-nightly-2023.12.14-nightly.tar.gz && \ diff --git a/PyTorchSimDevice/CMakeLists.txt b/PyTorchSimDevice/CMakeLists.txt new file mode 100644 index 00000000..2c207ca6 --- /dev/null +++ b/PyTorchSimDevice/CMakeLists.txt @@ -0,0 +1,44 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(TORCH_OPENREG CXX C) + +include(GNUInstallDirs) +include(CheckCXXCompilerFlag) + +set(CMAKE_CXX_STANDARD 17) +set(CMAKE_C_STANDARD 11) +set(CMAKE_CXX_EXTENSIONS OFF) + +set(CMAKE_EXPORT_COMPILE_COMMANDS ON) +set(CMAKE_SKIP_BUILD_RPATH FALSE) +set(CMAKE_BUILD_WITH_INSTALL_RPATH TRUE) +set(CMAKE_INSTALL_RPATH_USE_LINK_PATH FALSE) +set(CMAKE_CXX_VISIBILITY_PRESET hidden) + +set(CMAKE_BUILD_TYPE Release CACHE STRING "Build type" FORCE) + +if(APPLE) + set(CMAKE_INSTALL_RPATH "@loader_path/lib;@loader_path") +elseif(UNIX) + set(CMAKE_INSTALL_RPATH "$ORIGIN/lib:$ORIGIN") +elseif(WIN32) + set(CMAKE_INSTALL_RPATH "") +endif() +set(CMAKE_INSTALL_LIBDIR lib) +set(CMAKE_INSTALL_MESSAGE NEVER) + +set(Torch_DIR ${PYTORCH_INSTALL_DIR}/share/cmake/Torch) +find_package(Torch REQUIRED) + +if(DEFINED PYTHON_INCLUDE_DIR) + include_directories(${PYTHON_INCLUDE_DIR}) +else() + message(FATAL_ERROR "Cannot find Python directory") +endif() + +include_directories(${CMAKE_CURRENT_SOURCE_DIR}) +include(${PROJECT_SOURCE_DIR}/cmake/TorchPythonTargets.cmake) + +add_subdirectory(${PROJECT_SOURCE_DIR}/third_party/openreg) +add_subdirectory(${PROJECT_SOURCE_DIR}/csrc) +add_subdirectory(${PROJECT_SOURCE_DIR}/torch_openreg/csrc) diff --git a/PyTorchSimDevice/README.md b/PyTorchSimDevice/README.md new file mode 100644 index 00000000..83ec85b1 --- /dev/null +++ b/PyTorchSimDevice/README.md @@ -0,0 +1,175 @@ +# PyTorch OpenReg + +## Background + +The third-party device integration mechanism based on PrivateUse1 has become the official mainstream method for new backends to integrate with PyTorch. Ensuring the availability of this mechanism is crucial for enriching PyTorch's hardware ecosystem. + +**Note:** + +The goal of `torch_openreg` is **not to implement a fully functional, high-performance PyTorch backend**, but to serve as a **minimalist reference implementation for mechanism verification**. + +### Purpose + +- **Test Backend**: To serve as an in-tree test backend for PrivateUse1, ensuring quality stability through CI/CD. +- **Integration Example**: To serve as a reference example for new backend integration. +- **Integration Documentation**: To provide module-level integration documentation that corresponds with the code. + +### Design Principles + +- **Minimality Principle**: The fundamental goal is to enable/verify all integration paths/mechanisms for a new backend to integrate to PyTorch. All functions follow a "just right" strategy to ensure the correctness of relevant integration capabilities. +- **Authenticity Principle**: To complete the OpenReg integration in the same way a real accelerator backend would integrate with PyTorch. + +## Directory Structure + +```shell +torch_openreg/ +├── CMakeLists.txt +├── csrc +│ ├── aten +│ │ ├── native +│ │ │ ├── Extra.cpp +│ │ │ ├── Minimal.cpp +│ │ │ └── ... +│ │ ├── OpenRegExtra.cpp +│ │ └── OpenRegMinimal.cpp +│ ├── CMakeLists.txt +│ └── runtime +│ ├── OpenRegDeviceAllocator.cpp +│ ├── OpenRegDeviceAllocator.h +│ ├── OpenRegFunctions.cpp +│ ├── OpenRegFunctions.h +│ ├── OpenRegGenerator.cpp +│ ├── OpenRegGenerator.h +│ ├── OpenRegGuard.cpp +│ ├── OpenRegGuard.h +│ ├── OpenRegHooks.cpp +│ ├── OpenRegHooks.h +│ ├── OpenRegHostAllocator.cpp +│ ├── OpenRegHostAllocator.h +│ └── ... +├── pyproject.toml +├── README.md +├── setup.py +├── third_party +│ └── openreg +└── torch_openreg + ├── csrc + │ ├── CMakeLists.txt + │ ├── Module.cpp + │ └── stub.c + ├── __init__.py + └── openreg + ├── __init__.py + ├── meta.py + └── random.py +``` + +**Dependencies**: + +```mermaid +graph LR + A[Python] + B[_C.so] + C[libtorch_bindings.so] + D[libtorch_openreg.so] + E[libopenreg.so] + + A --> B --> C --> D --> E +``` + +There are 4 DSOs in torch_openreg, and the dependencies between them are as follows: + +- `_C.so`: + - **sources**: torch_openreg/csrc/stub.c + - **description**: Python C module entry point. +- `libtorch_bindings.so`: The bridging code between Python and C++ should go here. + - **sources**: torch_openreg/csrc + - **description**: A thin glue layer between Python and C++. +- `libtorch_openreg.so`: All core implementations should go here. + - **sources**: csrc + - **description**: All core functionality, such as device runtime, operators, etc. +- `libopenreg.so`: A DSO that uses the CPU to emulate a CUDA-like device, you can ignore it. + - **sources**: third_party/openreg + - **description**: Provides low-level device functionality similar to libcudart.so. + +**Key Directories**: + +- `csrc/`: Core device implementation, including operator registration, runtime, etc. + - `csrc/aten/`: Operator registration + - `csrc/aten/native/`: Specific operator implementations for the OpenReg device. + - `csrc/aten/OpenRegMinimal.cpp`: The most minimal set of operator implementations (allowing for the creation of Tensors and related operations upon completion). + - `csrc/aten/OpenRegExtra.cpp`: Implementations for other types of operators. + - `csrc/runtime/`: Implementations for Host memory, device memory, Guard, Hooks, etc. +- `third_party/`: A C++ library that simulates a CUDA-like device using the CPU. +- `torch_openreg/`: Python interface implementation (Python code and C++ Bindings). + - `torch_openreg/csrc/`: Python C++ binding code. + - `torch_openreg/openreg/`: Python API. + +## Currently Implemented Features + +### Operator Registration + +- Operator Implementation + + - Register for builtin PyTorch Operators + - `TORCH_LIBRARY_IMPL` form: See `empty.memory_format + - `STUB` form: See `abs_stub` + - Register for custom operators + - Schema Registration: See `custom_abs` + - Kernel Registration: See `custom_abs` + - Fallback Registration for `AutogradPriavateUse1`: See `custom_abs` + - Meta Registration: See `custom_abs` + - `torch.autograd.Function`: See `custom_autograd_fn_aliasing` + - Register for fallback + - Per-operator Fallback: See `sub.Tensor` + - Global Fallback: See `wrapper_cpu_fallback` + +## Installation and Usage + +### Installation + +```python +pip3 install --no-build-isolation -e . # for develop +pip3 install --no-build-isolation . # for install +``` + +### Usage Example + +After installation, you can use the `openreg` device in Python just like any other regular device. + +```python +import torch +import torch_openreg + +if not torch.openreg.is_available(): + print("OpenReg backend is not available in this build.") + exit() + +print("OpenReg backend is available!") + +device = torch.device("openreg") + +x = torch.tensor([[1., 2.], [3., 4.]], device=device) +y = x + 2 +print("Result y:\n", y) +print(f"Device of y: {y.device}") + +z = y.cpu() +print("Result z:\n", z) +print(f"Device of z: {z.device}") +``` + +## Future Plans + +- **Enhance Features**: + - Autoload + - AMP + - Device-agnostic APIs + - Memory Management + - Generator + - Distrubuted + - Custom Tensor&Storage + - ... +- **Improve Tests**: Add more test cases related to the integration mechanism. +- **Improve Documentation**: Add a new chapter on third-party device integration in the `Developer Notes` section of the PyTorch documentation. +- **Real-time Synchronization**: Keep the code and documentation updated iteratively and in sync. diff --git a/PyTorchSimDevice/cmake/TorchPythonTargets.cmake b/PyTorchSimDevice/cmake/TorchPythonTargets.cmake new file mode 100644 index 00000000..b7a807d2 --- /dev/null +++ b/PyTorchSimDevice/cmake/TorchPythonTargets.cmake @@ -0,0 +1,22 @@ +if(WIN32) + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/torch_python.lib") +elseif(APPLE) + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.dylib") +else() + set(TORCH_PYTHON_IMPORTED_LOCATION "${PYTORCH_INSTALL_DIR}/lib/libtorch_python.so") +endif() + +add_library(torch_python SHARED IMPORTED) + +set_target_properties(torch_python PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "${PYTORCH_INSTALL_DIR}/include" + INTERFACE_LINK_LIBRARIES "c10;torch_cpu" + IMPORTED_LOCATION "${TORCH_PYTHON_IMPORTED_LOCATION}" +) + +add_library(torch_python_library INTERFACE IMPORTED) + +set_target_properties(torch_python_library PROPERTIES + INTERFACE_INCLUDE_DIRECTORIES "\$" + INTERFACE_LINK_LIBRARIES "\$;\$" +) diff --git a/PyTorchSimDevice/csrc/CMakeLists.txt b/PyTorchSimDevice/csrc/CMakeLists.txt new file mode 100644 index 00000000..e2ae2b3f --- /dev/null +++ b/PyTorchSimDevice/csrc/CMakeLists.txt @@ -0,0 +1,16 @@ +set(LIBRARY_NAME torch_openreg) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_cpu_library openreg) +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/PyTorchSimDevice/csrc/amp/OpenRegAmp.h b/PyTorchSimDevice/csrc/amp/OpenRegAmp.h new file mode 100644 index 00000000..2f81e9d2 --- /dev/null +++ b/PyTorchSimDevice/csrc/amp/OpenRegAmp.h @@ -0,0 +1,15 @@ +#pragma once + +#include +#include + +#include + +namespace c10::openreg { + +OPENREG_EXPORT bool is_amp_enabled(); +OPENREG_EXPORT void set_amp_enabled(bool flag); +OPENREG_EXPORT at::ScalarType get_amp_dtype(); +OPENREG_EXPORT void set_amp_dtype(at::ScalarType dtype); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/amp/auto_cast_mode.cpp b/PyTorchSimDevice/csrc/amp/auto_cast_mode.cpp new file mode 100644 index 00000000..fd650026 --- /dev/null +++ b/PyTorchSimDevice/csrc/amp/auto_cast_mode.cpp @@ -0,0 +1,28 @@ +#include +#include +#include "OpenRegAmp.h" + +namespace { + bool g_amp_enabled = false; + at::ScalarType g_amp_dtype = at::kFloat; +} + +namespace c10::openreg { + +OPENREG_EXPORT bool is_amp_enabled() { + return g_amp_enabled; +} + +OPENREG_EXPORT void set_amp_enabled(bool flag) { + g_amp_enabled = flag; +} + +OPENREG_EXPORT at::ScalarType get_amp_dtype() { + return g_amp_dtype; +} + +OPENREG_EXPORT void set_amp_dtype(at::ScalarType dtype) { + g_amp_dtype = dtype; +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp new file mode 100644 index 00000000..04ba6d48 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/OpenRegExtra.cpp @@ -0,0 +1,195 @@ +#include "native/Extra.h" + +#include +#include + +#include +#include + +namespace at::openreg { + +namespace { +at::Tensor wrapper_quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::openreg::quantize_per_tensor( + self, scale, zero_point, dtype); +} + +int64_t wrapper__fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + return at::native::openreg::_fused_sdp_choice( + query, key, value, attn_mask, dropout_p, is_causal, scale, enable_gqa); +} + +void wrapper_quantize_tensor_per_tensor_affine_stub( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) { + at::native::openreg::quantize_tensor_per_tensor_affine_stub( + rtensor, qtensor, scale, zero_point); +} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +wrapper__scaled_dot_product_fused_attention_overrideable( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + return at::native::openreg::_scaled_dot_product_fused_attention_overrideable( + query, + key, + value, + attn_bias, + dropout_p, + is_causal, + return_debug_mask, + scale); +} + +std::tuple +wrapper_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return at::native::openreg:: + _scaled_dot_product_fused_attention_overrideable_backward( + grad_out, + query, + key, + value, + attn_bias, + grad_input_mask, + out, + logsumexp, + cum_seq_q, + cum_seq_k, + max_q, + max_k, + dropout_p, + is_causal, + philox_seed, + philox_offset, + scale); +} + +at::Tensor wrapper_custom_autograd_fn_returns_self(at::Tensor x) { + return at::native::openreg::custom_autograd_fn_returns_self(x); +} + +at::Tensor wrapper_custom_autograd_fn_aliasing(at::Tensor x) { + return at::native::openreg::custom_autograd_fn_aliasing(x); +} + +at::Tensor& wrapper_abs_out(const at::Tensor& self, at::Tensor& out) { + return at::native::openreg::abs_out(self, out); +} + +void wrapper_abs_stub(at::TensorIteratorBase& iter) { + at::native::openreg::abs_kernel(iter); +} + +at::Tensor wrapper_custom_abs(at::Tensor x) { + return at::native::openreg::custom_abs(x); +} +} // namespace + +using namespace at::native; +// Registration via STUB +// LITERALINCLUDE START: STUB DEFAULT +REGISTER_PRIVATEUSE1_DISPATCH(abs_stub, &wrapper_abs_stub); +REGISTER_PRIVATEUSE1_DISPATCH( + quantize_tensor_per_tensor_affine_stub, + &wrapper_quantize_tensor_per_tensor_affine_stub); +REGISTER_PRIVATEUSE1_DISPATCH( + _fused_sdp_choice_stub, + &wrapper__fused_sdp_choice); +// LITERALINCLUDE END: STUB DEFAULT + +// Registration of custom operators +// LITERALINCLUDE START: CUSTOM OPERATOR SCHEMA +TORCH_LIBRARY(openreg, m) { + m.def("custom_abs(Tensor input)-> Tensor"); +} +// LITERALINCLUDE END: CUSTOM OPERATOR SCHEMA + +// LITERALINCLUDE START: CUSTOM OPERATOR DEFAULT +TORCH_LIBRARY_IMPL(openreg, PrivateUse1, m) { + m.impl("custom_abs", &wrapper_custom_abs); +} +// LITERALINCLUDE END: CUSTOM OPERATOR DEFAULT + +// LITERALINCLUDE START: CUSTOM OPERATOR FALLBACK +TORCH_LIBRARY_IMPL(_, AutogradPrivateUse1, m) { + m.fallback(torch::autograd::autogradNotImplementedFallback()); +} +// LITERALINCLUDE END: CUSTOM OPERATOR FALLBACK + +// The rest is for testing purposes +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + /* + abs_stub only works if abs.out is also registered with PrivateUse1, because + abs.default is designed to redirect directly to abs.out, which calls + abs_stub. + */ + m.impl("abs.out", &wrapper_abs_out); + m.impl("quantize_per_tensor", &wrapper_quantize_per_tensor); + m.impl("_fused_sdp_choice", &wrapper__fused_sdp_choice); + m.impl( + "_scaled_dot_product_fused_attention_overrideable", + &wrapper__scaled_dot_product_fused_attention_overrideable); + m.impl( + "_scaled_dot_product_fused_attention_overrideable_backward", + &wrapper_scaled_dot_product_fused_attention_overrideable_backward); +} + +TORCH_LIBRARY_FRAGMENT(openreg, m) { + m.def("custom_autograd_fn_returns_self(Tensor input)-> Tensor"); + m.def("custom_autograd_fn_aliasing(Tensor(a) input)-> Tensor(a)"); +} + +TORCH_LIBRARY_IMPL(openreg, AutogradPrivateUse1, m) { + m.impl( + "custom_autograd_fn_returns_self", + &wrapper_custom_autograd_fn_returns_self); + m.impl("custom_autograd_fn_aliasing", &wrapper_custom_autograd_fn_aliasing); +} + +} // namespace at::openreg diff --git a/PyTorchSimDevice/csrc/aten/OpenRegMinimal.cpp b/PyTorchSimDevice/csrc/aten/OpenRegMinimal.cpp new file mode 100644 index 00000000..39f019c5 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/OpenRegMinimal.cpp @@ -0,0 +1,169 @@ +#include "native/Minimal.h" + +#include +#include + +#include +#include +#include +#include +#include + +namespace at::openreg { + +namespace { + +// LITERALINCLUDE START: EMPTY.MEMORY_FORMAT WRAPPER +at::Tensor wrapper_empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + return at::native::openreg::empty_memory_format( + size, + dtype_opt, + layout_opt, + device_opt, + pin_memory_opt, + memory_format_opt); +} +// LITERALINCLUDE END: EMPTY.MEMORY_FORMAT WRAPPER + +at::Tensor wrapper_empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + return at::native::openreg::empty_strided( + size, stride, dtype_opt, layout_opt, device_opt, pin_memory_opt); +} + +at::Tensor wrapper_as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + return at::native::openreg::as_strided(self, size, stride, storage_offset); +} + +const at::Tensor& wrapper_resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::openreg::resize_(self, size, memory_format); +} + +at::Tensor wrapper__reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::openreg::_reshape_alias(self, size, stride); +} + +at::Tensor wrapper__copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + return at::native::openreg::_copy_from(self, dst, non_blocking); +} + +at::Tensor wrapper__copy_from_and_resize( + const at::Tensor& self, + const at::Tensor& dst) { + return at::native::openreg::_copy_from_and_resize(self, dst); +} + +at::Scalar wrapper__local_scalar_densor(const at::Tensor& self) { + return at::native::openreg::_local_scalar_dense(self); +} + +at::Tensor& wrapper_set_source_Tensor_( + at::Tensor& self, + const at::Tensor& source) { + return at::native::openreg::set_source_Tensor_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_(at::Tensor& self, at::Storage source) { + return at::native::openreg::set_source_Storage_(self, source); +} + +at::Tensor& wrapper_set_source_Storage_storage_offsetset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::native::openreg::set_source_Storage_storage_offset_( + result, storage, storage_offset, size, stride); +} + +at::Tensor wrapper_view(const at::Tensor& self, c10::SymIntArrayRef size) { + return at::native::openreg::view(self, size); +} + +// LITERALINCLUDE START: FALLBACK WRAPPER +void wrapper_cpu_fallback( + const c10::OperatorHandle& op, + torch::jit::Stack* stack) { + const auto& op_name = op.schema().operator_name(); + + // Generate timestamp in format [YYYY-MM-DD HH:MM:SS.mmm] + auto now = std::chrono::system_clock::now(); + auto time_t = std::chrono::system_clock::to_time_t(now); + auto ms = std::chrono::duration_cast( + now.time_since_epoch()) % 1000; + + std::tm tm_buf; + localtime_r(&time_t, &tm_buf); + + std::ostringstream oss; + oss << std::put_time(&tm_buf, "%Y-%m-%d %H:%M:%S"); + oss << '.' << std::setfill('0') << std::setw(3) << ms.count(); + + std::cerr << "[" << oss.str() << "] [INFO] [PyTorchSimDevice] [Eager Mode] Operator: " << op_name << std::endl; + + at::native::openreg::cpu_fallback(op, stack); +} +// LITERALINCLUDE END: FALLBACK WRAPPER + +} // namespace + +// LITERALINCLUDE START: TORCH_LIBRARY_IMPL DEFAULT +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl("empty.memory_format", wrapper_empty_memory_format); + m.impl("empty_strided", wrapper_empty_strided); + m.impl("as_strided", wrapper_as_strided); + m.impl("resize_", wrapper_resize_); + m.impl("_reshape_alias", wrapper__reshape_alias); + m.impl("_copy_from", wrapper__copy_from); + m.impl("_copy_from_and_resize", wrapper__copy_from_and_resize); + m.impl("_local_scalar_dense", wrapper__local_scalar_densor); + m.impl("set_.source_Tensor", wrapper_set_source_Tensor_); + m.impl("set_.source_Storage", wrapper_set_source_Storage_); + m.impl( + "set_.source_Storage_storage_offset", + wrapper_set_source_Storage_storage_offsetset_); + m.impl("view", wrapper_view); +} +// LITERALINCLUDE END: TORCH_LIBRARY_IMPL DEFAULT + +// LITERALINCLUDE START: FALLBACK GLOBAL +TORCH_LIBRARY_IMPL(_, PrivateUse1, m) { + m.fallback( + torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>()); +} +// LITERALINCLUDE END: FALLBACK GLOBAL + +// LITERALINCLUDE START: FALLBACK SINGLE +TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { + m.impl( + "sub.Tensor", + torch::CppFunction::makeFromBoxedFunction<&wrapper_cpu_fallback>()); +} +// LITERALINCLUDE END: FALLBACK SINGLE + +} // namespace at::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Common.h b/PyTorchSimDevice/csrc/aten/native/Common.h new file mode 100644 index 00000000..c17196d0 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Common.h @@ -0,0 +1,97 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#include + +#include + +namespace at::native::openreg { + +class MemoryGuard { + public: + template + explicit MemoryGuard(const Args&... args) { + (find_and_unprotect_tensors(args), ...); + } + + ~MemoryGuard() noexcept { + for (void* ptr : unprotected_pointers_) { + orMemoryProtect(ptr); + } + } + + MemoryGuard(const MemoryGuard&) = delete; + MemoryGuard& operator=(const MemoryGuard&) = delete; + MemoryGuard(MemoryGuard&&) = delete; + MemoryGuard& operator=(MemoryGuard&&) = delete; + + private: + template + void find_and_unprotect_tensors(const T& item) { + if constexpr (std::is_base_of_v) { + unprotect_if_needed(item); + } else if constexpr (std::is_same_v) { + if (item.isTensor()) { + unprotect_if_needed(item.toTensor()); + } else if (item.isTensorList()) { + for (const at::Tensor& tensor : item.toTensorListRef()) { + unprotect_if_needed(tensor); + } + } else if (item.isList()) { + for (const c10::IValue& element : item.toListRef()) { + find_and_unprotect_tensors(element); + } + } else if (item.isGenericDict()) { + for (const auto& [key, value] : item.toGenericDict()) { + find_and_unprotect_tensors(key); + find_and_unprotect_tensors(value); + } + } + } + } + + void unprotect_if_needed(const at::TensorBase& tensor) { + if (!tensor.defined() || !tensor.has_storage()) { + return; + } + + void* ptr = tensor.data_ptr(); + orPointerAttributes attr; + + if (orPointerGetAttributes(&attr, ptr) != orSuccess || + attr.type != orMemoryTypeDevice) { + return; + } + + auto [it, inserted] = unprotected_pointers_.insert(attr.pointer); + if (inserted) { + orMemoryUnprotect(attr.pointer); + } + } + + std::unordered_set unprotected_pointers_; +}; + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.cpp b/PyTorchSimDevice/csrc/aten/native/Extra.cpp new file mode 100644 index 00000000..711d114c --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Extra.cpp @@ -0,0 +1,210 @@ +#include "Extra.h" + +namespace at::native::openreg { + +at::Tensor quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype) { + return at::native::quantize_per_tensor(self, scale, zero_point, dtype); +} + +int64_t _fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa) { + auto backend = sdp::SDPBackend::math; + return static_cast(backend); +} + +void quantize_tensor_per_tensor_affine_stub( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point) {} + +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale) { + const int64_t batch_size = query.size(0); + const int64_t num_heads = query.size(1); + const int64_t head_dim_v = value.size(3); + const int64_t max_seqlen_q = query.size(2); + const int64_t max_seqlen_kv = key.size(2); + + auto opts = query.options(); + auto output = + at::empty({batch_size, num_heads, max_seqlen_q, head_dim_v}, opts); + auto logsumexp = + at::empty({batch_size, num_heads, max_seqlen_q}, opts.dtype(at::kFloat)); + auto debug_attn_mask = at::empty( + {batch_size, num_heads, max_seqlen_q, max_seqlen_kv}, + opts.dtype(at::kFloat)); + auto philox_seed = at::empty({}, at::dtype(at::kLong)); + auto philox_offset = at::empty({}, at::dtype(at::kLong)); + + return std::make_tuple( + output, + logsumexp, + at::Tensor(), + at::Tensor(), + max_seqlen_q, + max_seqlen_kv, + philox_seed, + philox_offset, + debug_attn_mask); +} + +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale) { + return std::tuple( + at::empty_like(query), + at::empty_like(key), + at::empty_like(value), + at::empty_like(attn_bias)); +} + +namespace { +struct CustomAutogradFnReturnsSelf + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self; + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; + +struct CustomAutogradFnAliasing + : public torch::autograd::Function { + static at::Tensor forward( + torch::autograd::AutogradContext* ctx, + at::Tensor self) { + return self.view_symint(self.sym_sizes()); + } + + static torch::autograd::variable_list backward( + torch::autograd::AutogradContext* ctx, + torch::autograd::variable_list grad_output) { + return {grad_output[0] * 0.5}; + } +}; +} // namespace + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x) { + return CustomAutogradFnReturnsSelf::apply(x); +} + +at::Tensor custom_autograd_fn_aliasing(at::Tensor x) { + return CustomAutogradFnAliasing::apply(x); +} + +/* + This implementation is only used to test stub registration, so not all + capabilities are fully supported. + + Current Limitations: + - dtype: Float only + - input tensor: must be contiguous layout +*/ +// LITERALINCLUDE START: STUB ABS +void abs_kernel(at::TensorIteratorBase& iter) { + TORCH_CHECK(iter.ntensors() == 2, "Abs kernel expects 2 tensors"); + TORCH_CHECK( + iter.common_dtype() == at::ScalarType::Float, + "Abs kernel only supports float type"); + + auto& output_tensor = iter.tensor(0); + auto& input_tensor = iter.tensor(1); + + TORCH_CHECK( + input_tensor.sizes() == output_tensor.sizes(), + "Input and output tensor sizes must match."); + + auto abs_loop = [](float* out_ptr, const float* in_ptr, int64_t n) { + for (int64_t i = 0; i < n; ++i) { + out_ptr[i] = std::abs(in_ptr[i]); + } + }; + + MemoryGuard guard(input_tensor, output_tensor); + + if (iter.is_contiguous()) { + abs_loop( + static_cast(iter.data_ptr(0)), + static_cast(iter.data_ptr(1)), + iter.numel()); + } else { + TORCH_CHECK( + input_tensor.is_contiguous(), "Input tensor must be contiguous.") + + auto output = at::empty( + input_tensor.sizes(), + input_tensor.options().memory_format( + input_tensor.suggest_memory_format())); + + MemoryGuard guard(output); + + abs_loop( + static_cast(output.data_ptr()), + static_cast(iter.data_ptr(1)), + iter.numel()); + + output_tensor.copy_(output); + } +} +// LITERALINCLUDE END: STUB ABS + +at::Tensor& abs_out(const at::Tensor& self, at::Tensor& out) { + return at::native::abs_out(self, out); +} + +at::Tensor custom_abs(at::Tensor x) { + return at::abs(x); +} + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Extra.h b/PyTorchSimDevice/csrc/aten/native/Extra.h new file mode 100644 index 00000000..f002949a --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Extra.h @@ -0,0 +1,69 @@ +#include "Common.h" + +namespace at::native::openreg { + +at::Tensor quantize_per_tensor( + const at::Tensor& self, + double scale, + int64_t zero_point, + at::ScalarType dtype); +int64_t _fused_sdp_choice( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_mask, + double dropout_p, + bool is_causal, + std::optional scale, + bool enable_gqa); +void quantize_tensor_per_tensor_affine_stub( + const at::Tensor& rtensor, + at::Tensor& qtensor, + double scale, + int64_t zero_point); +std::tuple< + at::Tensor, + at::Tensor, + at::Tensor, + at::Tensor, + c10::SymInt, + c10::SymInt, + at::Tensor, + at::Tensor, + at::Tensor> +_scaled_dot_product_fused_attention_overrideable( + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const std::optional& attn_bias, + double dropout_p, + bool is_causal, + bool return_debug_mask, + std::optional scale); +std::tuple +_scaled_dot_product_fused_attention_overrideable_backward( + const at::Tensor& grad_out, + const at::Tensor& query, + const at::Tensor& key, + const at::Tensor& value, + const at::Tensor& attn_bias, + std::array grad_input_mask, + const at::Tensor& out, + const at::Tensor& logsumexp, + const at::Tensor& cum_seq_q, + const at::Tensor& cum_seq_k, + int64_t max_q, + int64_t max_k, + double dropout_p, + bool is_causal, + const at::Tensor& philox_seed, + const at::Tensor& philox_offset, + std::optional scale); + +at::Tensor custom_autograd_fn_returns_self(at::Tensor x); +at::Tensor custom_autograd_fn_aliasing(at::Tensor x); +at::Tensor& abs_out(const at::Tensor& self, at::Tensor& out); +void abs_kernel(at::TensorIteratorBase& iter); +at::Tensor custom_abs(at::Tensor x); + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Minimal.cpp b/PyTorchSimDevice/csrc/aten/native/Minimal.cpp new file mode 100644 index 00000000..8a3263bb --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Minimal.cpp @@ -0,0 +1,185 @@ +#include "Minimal.h" + +#include + +namespace at::native::openreg { + +// LITERALINCLUDE START: EMPTY.MEMORY_FORMAT IMPL +at::Tensor empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_generic( + size, allocator, pu1_dks, dtype, memory_format_opt); +} +// LITERALINCLUDE END: EMPTY.MEMORY_FORMAT IMPL + +at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt) { + const auto device = c10::device_or_default(device_opt); + const auto dtype = c10::dtype_or_default(dtype_opt); + TORCH_CHECK(device.is_privateuseone()); + TORCH_CHECK( + c10::layout_or_default(layout_opt) == c10::Layout::Strided, + "Non strided layout not supported"); + TORCH_CHECK( + !c10::pinned_memory_or_default(pin_memory_opt), + "Pin memory can only be on CPU"); + const c10::DeviceGuard device_guard(device); + constexpr c10::DispatchKeySet pu1_dks(c10::DispatchKey::PrivateUse1); + auto allocator = at::GetAllocator(at::kPrivateUse1); + return at::detail::empty_strided_generic( + size, stride, allocator, pu1_dks, dtype); +} + +at::Tensor as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset) { + MemoryGuard guard(self); + + return at::cpu::as_strided_symint(self, size, stride, storage_offset); +} + +const at::Tensor& resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format) { + return at::native::resize_( + self, C10_AS_INTARRAYREF_SLOW(size), memory_format); +} + +at::Tensor _reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride) { + return at::native::_reshape_alias( + self, C10_AS_INTARRAYREF_SLOW(size), C10_AS_INTARRAYREF_SLOW(stride)); +} + +at::Tensor _copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking) { + TORCH_CHECK(self.defined(), "Source tensor (self) is not defined."); + TORCH_CHECK(dst.defined(), "Destination tensor (dst) is not defined."); + + MemoryGuard guard(self, dst); + + if (self.device() == dst.device()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + const at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self_as_cpu, non_blocking); + + } else { + if (self.is_cpu()) { + at::Tensor dst_as_cpu = at::from_blob( + dst.data_ptr(), + dst.sizes(), + dst.strides(), + dst.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst_as_cpu), self, non_blocking); + + } else { + at::Tensor self_as_cpu = at::from_blob( + self.data_ptr(), + self.sizes(), + self.strides(), + self.options().device(at::kCPU)); + + at::native::copy_( + const_cast(dst), self_as_cpu, non_blocking); + } + } + + return dst; +} + +at::Tensor _copy_from_and_resize( + const at::Tensor& self, + const at::Tensor& dst) { + at::native::resize_(dst, self.sizes(), std::nullopt); + return at::native::copy_(const_cast(dst), self, false); +} + +at::Scalar _local_scalar_dense(const at::Tensor& self) { + MemoryGuard guard(self); + return at::native::_local_scalar_dense_cpu(self); +} + +at::Tensor& set_source_Tensor_(at::Tensor& self, const at::Tensor& source) { + return at::native::set_tensor_(self, source); +} + +at::Tensor& set_source_Storage_(at::Tensor& self, at::Storage source) { + return at::native::set_(self, source); +} + +at::Tensor& set_source_Storage_storage_offset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride) { + return at::cpu::set_(result, storage, storage_offset, size, stride); +} + +at::Tensor view(const at::Tensor& self, c10::SymIntArrayRef size) { + MemoryGuard guard(self); + return at::native::view(self, C10_AS_INTARRAYREF_SLOW(size)); +} + +// LITERALINCLUDE START: FALLBACK IMPL +void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { + static const std::unordered_set cpu_fallback_blocklist = { + c10::OperatorName("aten::abs", ""), + c10::OperatorName("aten::abs", "out"), + }; + + const auto& op_name = op.schema().operator_name(); + if (cpu_fallback_blocklist.count(op_name)) { + TORCH_CHECK( + false, + "Operator '", + op_name, + "' is not implemented for device openreg."); + } else { + at::native::cpu_fallback(op, stack); + } +} +// LITERALINCLUDE END: FALLBACK IMPL + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/aten/native/Minimal.h b/PyTorchSimDevice/csrc/aten/native/Minimal.h new file mode 100644 index 00000000..a2e5cf02 --- /dev/null +++ b/PyTorchSimDevice/csrc/aten/native/Minimal.h @@ -0,0 +1,61 @@ +#include "Common.h" + +namespace at::native::openreg { + +at::Tensor empty_memory_format( + c10::IntArrayRef size, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt, + std::optional memory_format_opt); + +at::Tensor empty_strided( + c10::IntArrayRef size, + c10::IntArrayRef stride, + std::optional dtype_opt, + std::optional layout_opt, + std::optional device_opt, + std::optional pin_memory_opt); + +at::Tensor as_strided( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride, + std::optional storage_offset); + +const at::Tensor& resize_( + const at::Tensor& self, + c10::SymIntArrayRef size, + ::std::optional memory_format); + +at::Tensor _reshape_alias( + const at::Tensor& self, + c10::SymIntArrayRef size, + c10::SymIntArrayRef stride); + +at::Tensor _copy_from( + const at::Tensor& self, + const at::Tensor& dst, + bool non_blocking); + +at::Tensor _copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst); + +at::Scalar _local_scalar_dense(const at::Tensor& self); + +at::Tensor& set_source_Tensor_(at::Tensor& self, const at::Tensor& source); + +at::Tensor& set_source_Storage_(at::Tensor& self, at::Storage source); + +at::Tensor& set_source_Storage_storage_offset_( + at::Tensor& result, + at::Storage storage, + int64_t storage_offset, + c10::IntArrayRef size, + c10::IntArrayRef stride); + +at::Tensor view(const at::Tensor& self, c10::SymIntArrayRef size); + +void cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack); + +} // namespace at::native::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.cpp new file mode 100644 index 00000000..3d35b677 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegDeviceAllocator.h" + +namespace c10::openreg { + +static OpenRegDeviceAllocator global_openreg_alloc; +REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_openreg_alloc); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.h b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.h new file mode 100644 index 00000000..c9aea4a9 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegDeviceAllocator.h @@ -0,0 +1,43 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegDeviceAllocator final : at::Allocator { + OpenRegDeviceAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + int current_device_index = -1; + orGetDevice(¤t_device_index); + + auto curr_device = + c10::Device(c10::DeviceType::PrivateUse1, current_device_index); + void* data = nullptr; + if (nbytes > 0) { + orMalloc(&data, nbytes); + TORCH_CHECK( + data, "Failed to allocator ", nbytes, " bytes on openreg device."); + } + return {data, data, &ReportAndDelete, curr_device}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyDeviceToDevice); + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegEvent.h b/PyTorchSimDevice/csrc/runtime/OpenRegEvent.h new file mode 100644 index 00000000..e869cf0d --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegEvent.h @@ -0,0 +1,146 @@ +#pragma once + +#include + +#include "OpenRegException.h" +#include "OpenRegStream.h" + +namespace c10::openreg { + +struct OpenRegEvent { + OpenRegEvent(bool enable_timing) noexcept : enable_timing_{enable_timing} {} + + ~OpenRegEvent() { + if (is_created_) { + OPENREG_CHECK(orEventDestroy(event_)); + } + } + + OpenRegEvent(const OpenRegEvent&) = delete; + OpenRegEvent& operator=(const OpenRegEvent&) = delete; + + OpenRegEvent(OpenRegEvent&& other) noexcept { + moveHelper(std::move(other)); + } + OpenRegEvent& operator=(OpenRegEvent&& other) noexcept { + if (this != &other) { + moveHelper(std::move(other)); + } + return *this; + } + + operator orEvent_t() const { + return event(); + } + + std::optional device() const { + if (is_created_) { + return at::Device(at::kPrivateUse1, device_index_); + } else { + return std::nullopt; + } + } + + bool isCreated() const { + return is_created_; + } + + DeviceIndex device_index() const { + return device_index_; + } + + orEvent_t event() const { + return event_; + } + + bool query() const { + if (!is_created_) { + return true; + } + + orError_t err = orEventQuery(event_); + if (err == orSuccess) { + return true; + } + + return false; + } + + void record() { + record(getCurrentOpenRegStream()); + } + + void recordOnce(const OpenRegStream& stream) { + if (!was_recorded_) + record(stream); + } + + void record(const OpenRegStream& stream) { + if (!is_created_) { + createEvent(stream.device_index()); + } + + TORCH_CHECK( + device_index_ == stream.device_index(), + "Event device ", + device_index_, + " does not match recording stream's device ", + stream.device_index(), + "."); + + OPENREG_CHECK(orEventRecord(event_, stream)); + was_recorded_ = true; + } + + void block(const OpenRegStream& stream) { + if (is_created_) { + OPENREG_CHECK(orStreamWaitEvent(stream, event_, 0)); + } + } + + float elapsed_time(const OpenRegEvent& other) const { + TORCH_CHECK_VALUE( + !(enable_timing_ & orEventDisableTiming) && + !(other.enable_timing_ & orEventDisableTiming), + "Both events must be created with argument 'enable_timing=True'."); + TORCH_CHECK_VALUE( + is_created_ && other.isCreated(), + "Both events must be recorded before calculating elapsed time."); + TORCH_CHECK( + query() && other.query(), + "Both events must be completed before calculating elapsed time."); + + float time_ms = 0; + OPENREG_CHECK(orEventElapsedTime(&time_ms, event_, other.event_)); + return time_ms; + } + + void synchronize() const { + if (is_created_) { + OPENREG_CHECK(orEventSynchronize(event_)); + } + } + + private: + unsigned int enable_timing_{orEventDisableTiming}; + bool is_created_{false}; + bool was_recorded_{false}; + DeviceIndex device_index_{-1}; + orEvent_t event_{}; + + void createEvent(DeviceIndex device_index) { + device_index_ = device_index; + OPENREG_CHECK(orEventCreateWithFlags(&event_, enable_timing_)); + is_created_ = true; + } + + void moveHelper(OpenRegEvent&& other) { + std::swap(enable_timing_, other.enable_timing_); + std::swap(is_created_, other.is_created_); + std::swap(was_recorded_, other.was_recorded_); + std::swap(device_index_, other.device_index_); + std::swap(event_, other.event_); + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegException.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegException.cpp new file mode 100644 index 00000000..09eb09b6 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegException.cpp @@ -0,0 +1,9 @@ +#include "OpenRegException.h" + +void orCheckFail( + const char* func, + const char* file, + uint32_t line, + const char* msg) { + throw ::c10::Error({func, file, line}, msg); +} diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegException.h b/PyTorchSimDevice/csrc/runtime/OpenRegException.h new file mode 100644 index 00000000..16c1ee1c --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegException.h @@ -0,0 +1,20 @@ +#pragma once + +#include + +#include + +void orCheckFail( + const char* func, + const char* file, + uint32_t line, + const char* msg = ""); + +#define OPENREG_CHECK(EXPR, ...) \ + do { \ + const orError_t __err = EXPR; \ + if (__err != orSuccess) { \ + orCheckFail( \ + __func__, __FILE__, static_cast(__LINE__), ##__VA_ARGS__); \ + } \ + } while (0) diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.cpp new file mode 100644 index 00000000..566bacd0 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.cpp @@ -0,0 +1,74 @@ +#include + +#include "OpenRegException.h" +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +orError_t GetDeviceCount(int* dev_count) { + return orGetDeviceCount(dev_count); +} + +orError_t GetDevice(c10::DeviceIndex* device) { + int tmp_device = -1; + auto err = orGetDevice(&tmp_device); + *device = static_cast(tmp_device); + return err; +} + +orError_t SetDevice(c10::DeviceIndex device) { + int cur_device = -1; + orGetDevice(&cur_device); + if (device == cur_device) { + return orSuccess; + } + return orSetDevice(device); +} + +int device_count_impl() { + int count = 0; + GetDeviceCount(&count); + return count; +} + +OPENREG_EXPORT c10::DeviceIndex device_count() noexcept { + // initialize number of devices only once + static int count = []() { + try { + auto result = device_count_impl(); + TORCH_INTERNAL_ASSERT( + result <= std::numeric_limits::max(), + "Too many devices, DeviceIndex overflowed"); + return result; + } catch (const c10::Error& ex) { + // We don't want to fail, but still log the warning + // msg() returns the message without the stack trace + TORCH_WARN("Device initialization: ", ex.msg()); + return 0; + } + }(); + return static_cast(count); +} + +OPENREG_EXPORT c10::DeviceIndex current_device() { + c10::DeviceIndex cur_device = -1; + GetDevice(&cur_device); + return cur_device; +} + +OPENREG_EXPORT void set_device(c10::DeviceIndex device) { + SetDevice(device); +} + +OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device) { + int current_device = -1; + orGetDevice(¤t_device); + + if (device != current_device) { + orSetDevice(device); + } + + return current_device; +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.h b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.h new file mode 100644 index 00000000..c2eb1e80 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegFunctions.h @@ -0,0 +1,18 @@ +#pragma once + +#include +#include + +#include + +#include + +namespace c10::openreg { + +OPENREG_EXPORT c10::DeviceIndex device_count() noexcept; +OPENREG_EXPORT c10::DeviceIndex current_device(); +OPENREG_EXPORT void set_device(c10::DeviceIndex device); + +OPENREG_EXPORT DeviceIndex ExchangeDevice(DeviceIndex device); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.cpp new file mode 100644 index 00000000..c2e03f66 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.cpp @@ -0,0 +1,28 @@ +#include "OpenRegGenerator.h" + +// Default, global generators, one per device. +static std::vector default_generators; + +namespace c10::openreg { + +const at::Generator& getDefaultOpenRegGenerator(c10::DeviceIndex device_index) { + static bool flag [[maybe_unused]] = []() { + auto deivce_nums = device_count(); + default_generators.resize(deivce_nums); + for (auto i = 0; i < deivce_nums; i++) { + default_generators[i] = at::make_generator(i); + default_generators[i].seed(); + } + return true; + }(); + + c10::DeviceIndex idx = device_index; + if (idx == -1) { + idx = current_device(); + } else { + TORCH_CHECK(idx >= 0 && idx < device_count()); + } + return default_generators[idx]; +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.h b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.h new file mode 100644 index 00000000..877a9707 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGenerator.h @@ -0,0 +1,21 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { +class OpenRegGeneratorImpl : public at::CPUGeneratorImpl { + public: + OpenRegGeneratorImpl(c10::DeviceIndex device_index) { + device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); + key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); + } + ~OpenRegGeneratorImpl() override = default; +}; + +const at::Generator& getDefaultOpenRegGenerator( + c10::DeviceIndex device_index = -1); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGuard.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.cpp new file mode 100644 index 00000000..d50e56e4 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.cpp @@ -0,0 +1,7 @@ +#include "OpenRegGuard.h" + +namespace c10::openreg { + +C10_REGISTER_GUARD_IMPL(PrivateUse1, OpenRegGuardImpl); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegGuard.h b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.h new file mode 100644 index 00000000..f0150fe6 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegGuard.h @@ -0,0 +1,197 @@ +#include +#include + +#include + +#include "OpenRegFunctions.h" + +namespace c10::openreg { + +// Device guard registration +struct OpenRegGuardImpl final : public c10::impl::DeviceGuardImplInterface { + static constexpr c10::DeviceType static_type = c10::DeviceType::PrivateUse1; + + OpenRegGuardImpl() = default; + explicit OpenRegGuardImpl(c10::DeviceType t) { + TORCH_INTERNAL_ASSERT(t == static_type); + } + + /** + * Return the type of device managed by this guard implementation. + */ + c10::DeviceType type() const override { + return static_type; + } + + /** + * Set the current device to Device, and return the previous c10::Device. + */ + c10::Device exchangeDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + auto old_device_index = ExchangeDevice(d.index()); + return c10::Device(static_type, old_device_index); + } + + /** + * Get the current device. + */ + c10::Device getDevice() const override { + int device_index = current_device(); + return c10::Device(static_type, device_index); + } + + /** + * Set the current device to c10::Device. + */ + void setDevice(c10::Device d) const override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Set the current device to c10::Device, without checking for errors + * (so, e.g., this can be called from a destructor). + */ + void uncheckedSetDevice(c10::Device d) const noexcept override { + TORCH_CHECK(d.is_privateuseone()); + + set_device(d.index()); + } + + /** + * Get the current stream for a given device. + */ + c10::Stream getStream(c10::Device d) const noexcept override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get the default stream for a given device. + */ + c10::Stream getDefaultStream(c10::Device d) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Get a stream from the global pool for a given device. + */ + c10::Stream getStreamFromGlobalPool( + c10::Device d, + bool isHighPriority = false) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Return a new stream for a given device and priority. The stream will be + * copied and shared around, device backend should be able to correctly handle + * the lifetime of the stream. + */ + c10::Stream getNewStream(c10::Device d, int priority = 0) const override { + return c10::Stream(c10::Stream::DEFAULT, d); + } + + /** + * Set a stream to be the thread local current stream for its device. + * Return the previous stream for that device. You are NOT required + * to set the current device to match the device of this stream. + */ + c10::Stream exchangeStream(c10::Stream s) const noexcept override { + return s; + } + + /** + * Destroys the given event. + */ + void destroyEvent(void* event, const c10::DeviceIndex device_index) + const noexcept override {} + + /** + * Increments the event's version and enqueues a job with this version + * in the stream's work queue. When the stream process that job + * it notifies all streams waiting on / blocked by that version of the + * event to continue and marks that version as recorded. + * */ + void record( + void** event, + const c10::Stream& stream, + const c10::DeviceIndex device_index, + const c10::EventFlag flag) const override { + static int event_id = 1; + + if (!*event) + *event = reinterpret_cast(event_id++); + } + + /** + * Does nothing if the event has not been scheduled to be recorded. + * If the event was previously enqueued to be recorded, a command + * to wait for the version of the event that exists at the time of this call + * is inserted in the stream's work queue. + * When the stream reaches this command it will stop processing + * additional commands until that version of the event is marked as recorded. + */ + void block(void* event, const c10::Stream& stream) const override {} + + /** + * Returns true if (and only if) + * (1) the event has never been scheduled to be recorded + * (2) the current version is marked as recorded. + * Returns false otherwise. + */ + bool queryEvent(void* event) const override { + return true; + } + + /** + * Get the number of devices. WARNING: This is REQUIRED to not raise + * an exception. If there is some sort of problem, e.g., driver error, + * you should report that there are zero available devices. + */ + c10::DeviceIndex deviceCount() const noexcept override { + int device_index = -1; + orGetDeviceCount(&device_index); + return device_index; + } + /** + * Return true if all the work previously enqueued on the stream for + * asynchronous execution has completed running on the device. + */ + bool queryStream(const c10::Stream& stream) const override { + return true; + } + + /** + * Wait (by blocking the calling thread) until all the work previously + * enqueued on the stream has completed running on the device. + */ + void synchronizeStream(const c10::Stream& stream) const override {} + + /** + * Wait (by blocking the calling thread) until all the work previously + * recorded on the event has completed running on the device. + */ + void synchronizeEvent(void* event) const override {} + + /** + * Ensure the caching allocator (if any) is aware that the given DataPtr is + * being used on the given stream, and that it should thus avoid recycling the + * DataPtr until all work on that stream is done. + */ + void recordDataPtrOnStream( + const c10::DataPtr& data_ptr, + const c10::Stream& stream) const override {} + + /** + * Fetch the elapsed time between two recorded events. + */ + double elapsedTime( + void* event1, + void* event2, + const c10::DeviceIndex device_index) const override { + return 1; + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHooks.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.cpp new file mode 100644 index 00000000..57bc2d9f --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.cpp @@ -0,0 +1,11 @@ +#include "OpenRegHooks.h" + +namespace c10::openreg { + +static bool register_hook_flag [[maybe_unused]] = []() { + at::RegisterPrivateUse1HooksInterface(new OpenRegHooksInterface()); + + return true; +}(); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHooks.h b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.h new file mode 100644 index 00000000..656fba8e --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHooks.h @@ -0,0 +1,41 @@ +#include +#include + +#include +#include + +#include + +#include "OpenRegGenerator.h" + +namespace c10::openreg { +struct OpenRegHooksInterface : public at::PrivateUse1HooksInterface { + OpenRegHooksInterface() {}; + ~OpenRegHooksInterface() override = default; + + bool hasPrimaryContext(c10::DeviceIndex device_index) const override { + return true; + } + + at::Allocator* getPinnedMemoryAllocator() const override { + return at::getHostAllocator(at::kPrivateUse1); + } + + bool isPinnedPtr(const void* data) const override { + orPointerAttributes attr{}; + orPointerGetAttributes(&attr, data); + + return attr.type == orMemoryTypeHost; + } + + const at::Generator& getDefaultGenerator( + c10::DeviceIndex device_index) const override { + return getDefaultOpenRegGenerator(device_index); + } + + at::Generator getNewGenerator(c10::DeviceIndex device_index) const override { + return at::make_generator(device_index); + } +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.cpp new file mode 100644 index 00000000..55263803 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.cpp @@ -0,0 +1,8 @@ +#include "OpenRegHostAllocator.h" + +namespace c10::openreg { + +OpenRegHostAllocator caching_host_allocator; +REGISTER_HOST_ALLOCATOR(at::kPrivateUse1, &caching_host_allocator); + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.h b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.h new file mode 100644 index 00000000..edef545a --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegHostAllocator.h @@ -0,0 +1,48 @@ +#include + +#include +#include + +#include + +namespace c10::openreg { +struct OpenRegHostAllocator final : at::HostAllocator { + OpenRegHostAllocator() = default; + + static void ReportAndDelete(void* ptr) { + if (!ptr) { + return; + } + orFreeHost(ptr); + } + + at::DataPtr allocate(size_t nbytes) override { + void* data = nullptr; + if (nbytes > 0) { + orMallocHost(&data, nbytes); + TORCH_CHECK(data, "Failed to allocator ", nbytes, " bytes on host."); + } + return {data, data, &ReportAndDelete, at::Device(at::kCPU)}; + } + + at::DeleterFnPtr raw_deleter() const override { + return &ReportAndDelete; + } + + void copy_data(void* dest, const void* src, std::size_t count) const final { + orMemcpy(dest, src, count, orMemcpyHostToHost); + } + + // ignore + bool record_event(void* ptr, void* ctx, c10::Stream stream) override { + return true; + } + void empty_cache() override {} + at::HostStats get_stats() override { + return at::HostStats(); + } + void reset_accumulated_stats() override {} + void reset_peak_stats() override {} +}; + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.cpp new file mode 100644 index 00000000..43809d60 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.cpp @@ -0,0 +1,48 @@ +#include "OpenRegSerialization.h" + +namespace c10::openreg { +struct OpenRegBackendMeta : public c10::BackendMeta { + OpenRegBackendMeta(int version_number, int format_number) + : version_number_(version_number), format_number_(format_number) {} + + int version_number_{-1}; + int format_number_{-1}; +}; + +void for_serialization( + const at::Tensor& t, + std::unordered_map& m) { + auto meta_ptr = t.unsafeGetTensorImpl()->get_backend_meta(); + + if (meta_ptr != nullptr) { + auto o_meta_ptr = dynamic_cast(meta_ptr); + if (o_meta_ptr->version_number_ == 1) { + m["version_number"] = true; + } + if (o_meta_ptr->format_number_ == 29) { + m["format_number"] = true; + } + } +} + +void for_deserialization( + const at::Tensor& t, + std::unordered_map& m) { + int version_number{-1}; + int format_number{-1}; + + if (m.find("version_number") != m.end()) { + version_number = 1; + } + if (m.find("format_number") != m.end()) { + format_number = 29; + } + + c10::intrusive_ptr meta{std::unique_ptr( + new OpenRegBackendMeta(version_number, format_number))}; + t.unsafeGetTensorImpl()->set_backend_meta(meta); +} + +REGISTER_PRIVATEUSE1_SERIALIZATION(&for_serialization, &for_deserialization) + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.h b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.h new file mode 100644 index 00000000..559e92ea --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegSerialization.h @@ -0,0 +1,10 @@ +#include + +#define REGISTER_PRIVATEUSE1_SERIALIZATION( \ + FOR_SERIALIZATION, FOR_DESERIALIZATION) \ + static int register_serialization() { \ + torch::jit::TensorBackendMetaRegistry( \ + c10::DeviceType::PrivateUse1, FOR_SERIALIZATION, FOR_DESERIALIZATION); \ + return 0; \ + } \ + static const int _temp = register_serialization(); diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegStream.cpp b/PyTorchSimDevice/csrc/runtime/OpenRegStream.cpp new file mode 100644 index 00000000..aa6c325d --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegStream.cpp @@ -0,0 +1,253 @@ +#include "OpenRegStream.h" + +#include +#include +#include + +#include +#include +#include +#include + +namespace c10::openreg { + +namespace { + +// Global stream state and constants +static c10::once_flag init_flag; + +static DeviceIndex num_devices = -1; +static constexpr int kStreamsPerPoolBits = 5; +static constexpr int kStreamsPerPool = 1 << kStreamsPerPoolBits; +static constexpr int kStreamTypeBits = 2; + +/* + * The stream pools are lazily initialized when the first queue is requested + * for a device. The device flags track the initialization of each device. When + * a queue is requested, the next queue in the pool to be returned in a + * round-robin fashion, see Note [Stream Management]. + */ +static std::deque device_flags; +static std::vector, + c10::openreg::max_compile_time_stream_priorities>> + streams; +static std::deque< + std::array, max_compile_time_stream_priorities>> + priority_counters; + +static thread_local std::unique_ptr current_streams = nullptr; + +/* + * Note [StreamId assignment] + * ~~~~~~~~~~~~~~~~~~~~~~~~~~ + * How do we assign stream IDs? + * + * -- 56 bits -- -- 5 bits -- -- 2 bits -- -- 1 bit -- + * zeros StreamIdIndex StreamIdType Ext/native stream + * ignored for ext ignored for ext + * + * Where StreamIdType: + * 00 = default stream + * 01 = normal stream + * 11 = external stream + * + * For external stream, StreamID is a orStream_t pointer. This means that last + * bit will always be 0. So when constructing StreamId for a native stream we + * set last bit to 1 to distinguish between native and external streams. + * + * StreamId is 64-bit, so we can just rely on regular promotion rules. + * We rely on StreamIdIndex and StreamIdType being non-negative; + */ +using StreamIdIndex = uint8_t; +enum class StreamIdType : uint8_t { + DEFAULT = 0x0, + NORMAL = 0x1, + EXT = 0x3, +}; + +inline std::ostream& operator<<(std::ostream& stream, StreamIdType s) { + switch (s) { + case StreamIdType::DEFAULT: + return stream << "DEFAULT"; + case StreamIdType::NORMAL: + return stream << "NORMAL"; + case StreamIdType::EXT: + return stream << "EXT"; + default: + break; + } + + return stream << static_cast(s); +} + +static inline StreamIdType streamIdType(StreamId s) { + // Externally allocated streams have their id being the orStream_ptr + // so the last bit will be 0 + if (!(s & 1)) { + return StreamIdType(StreamIdType::EXT); + } + + int mask_for_type = (1 << kStreamTypeBits) - 1; + auto st = static_cast((s >> 1) & mask_for_type); + TORCH_CHECK( + st == StreamIdType::DEFAULT || st == StreamIdType::NORMAL, + "invalid StreamId: ", + s); + return st; +} + +static inline size_t streamIdIndex(StreamId s) { + return static_cast( + (s >> (kStreamTypeBits + 1)) & ((1 << kStreamsPerPoolBits) - 1)); +} + +StreamId makeStreamId(StreamIdType st, size_t si) { + if (st == StreamIdType::EXT) { + return static_cast(0); + } + + return (static_cast(si) << (kStreamTypeBits + 1)) | + (static_cast(st) << 1) | 1; +} + +static void initGlobalStreamState() { + num_devices = device_count(); + device_flags.resize(num_devices); + streams.resize(num_devices); + priority_counters.resize(num_devices); +} + +static void initSingleDeviceStream( + int priority, + DeviceIndex device_index, + int i) { + auto& stream = streams[device_index][priority][i]; + + OPENREG_CHECK(orStreamCreateWithPriority(&stream, 0, priority)); + priority_counters[device_index][priority] = 0; +} + +// Creates stream pools for the specified device. It should be call only once. +static void initDeviceStreamState(DeviceIndex device_index) { + for (const auto i : c10::irange(kStreamsPerPool)) { + for (const auto p : c10::irange(max_compile_time_stream_priorities)) { + initSingleDeviceStream(p, device_index, i); + } + } +} + +static void initOpenRegStreamsOnce() { + c10::call_once(init_flag, initGlobalStreamState); + + if (current_streams) { + return; + } + + // Inits current streams (thread local) to the last queue in the "normal + // priority" queue pool. Note: the queue pool have not been initialized yet. + // It will be initialized in initDeviceStreamState for the specified device. + current_streams = std::make_unique(num_devices); + for (const auto i : c10::irange(num_devices)) { + current_streams[i] = makeStreamId(StreamIdType::DEFAULT, 0); + } +} + +static uint32_t get_idx(std::atomic& counter) { + auto raw_idx = counter++; + return raw_idx % kStreamsPerPool; +} + +OpenRegStream OpenRegStreamForId(DeviceIndex device_index, StreamId stream_id) { + return OpenRegStream( + OpenRegStream::UNCHECKED, + Stream( + Stream::UNSAFE, + c10::Device(DeviceType::PrivateUse1, device_index), + stream_id)); +} + +} // anonymous namespace + +// See Note [StreamId assignment] +orStream_t OpenRegStream::stream() const { + c10::DeviceIndex device_index = stream_.device_index(); + StreamId stream_id = stream_.id(); + StreamIdType st = streamIdType(stream_id); + size_t si = streamIdIndex(stream_id); + switch (st) { + // The index 0 stream is default as well. + case StreamIdType::DEFAULT: + case StreamIdType::NORMAL: + return streams[device_index][static_cast(st)][si]; + case StreamIdType::EXT: + return reinterpret_cast(stream_id); + default: + TORCH_CHECK( + false, + "Unrecognized stream ", + stream_, + " (I didn't recognize the stream type, ", + st, + ").", + " Did you manufacture the StreamId yourself? Don't do that;"); + } +} + +// Returns a stream from the requested pool +// Note: when called the first time on a device, this will create the +// stream pools for that device. +OpenRegStream getStreamFromPool(const int priority, DeviceIndex device_index) { + initOpenRegStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + } + c10::call_once( + device_flags[device_index], initDeviceStreamState, device_index); + auto pri_idx = + std::clamp(priority, 0, max_compile_time_stream_priorities - 1); + const auto idx = get_idx(priority_counters[device_index][pri_idx]); + auto id_type = static_cast(pri_idx); + return OpenRegStreamForId(device_index, makeStreamId(id_type, idx)); +} + +OpenRegStream getStreamFromPool(const bool isHighPriority, DeviceIndex device) { + initOpenRegStreamsOnce(); + int priority = 0; + return getStreamFromPool(priority, device); +} + +OpenRegStream getStreamFromExternal( + orStream_t ext_stream, + DeviceIndex device_index) { + return OpenRegStreamForId( + device_index, reinterpret_cast(ext_stream)); +} + +OpenRegStream getDefaultOpenRegStream(DeviceIndex device_index) { + initOpenRegStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + } + return OpenRegStreamForId( + device_index, makeStreamId(StreamIdType::DEFAULT, 0)); +} + +OpenRegStream getCurrentOpenRegStream(DeviceIndex device_index) { + initOpenRegStreamsOnce(); + if (device_index == -1) { + device_index = current_device(); + } + return OpenRegStreamForId(device_index, current_streams[device_index]); +} + +void setCurrentOpenRegStream(OpenRegStream stream) { + initOpenRegStreamsOnce(); + current_streams[stream.device_index()] = stream.id(); +} + +std::ostream& operator<<(std::ostream& stream, const OpenRegStream& s) { + return stream << s.unwrap(); +} + +} // namespace c10::openreg diff --git a/PyTorchSimDevice/csrc/runtime/OpenRegStream.h b/PyTorchSimDevice/csrc/runtime/OpenRegStream.h new file mode 100644 index 00000000..e1fd0c71 --- /dev/null +++ b/PyTorchSimDevice/csrc/runtime/OpenRegStream.h @@ -0,0 +1,162 @@ +#pragma once + +#include + +#include "OpenRegException.h" +#include "OpenRegFunctions.h" + +#include +#include +#include + +namespace c10::openreg { + +static constexpr int max_compile_time_stream_priorities = 1; + +class OpenRegStream { + public: + enum Unchecked { UNCHECKED }; + + explicit OpenRegStream(Stream stream) : stream_(stream) { + TORCH_CHECK(stream_.device_type() == DeviceType::PrivateUse1); + } + + explicit OpenRegStream(Unchecked, Stream stream) : stream_(stream) {} + + bool operator==(const OpenRegStream& other) const noexcept { + return unwrap() == other.unwrap(); + } + + bool operator!=(const OpenRegStream& other) const noexcept { + return unwrap() != other.unwrap(); + } + + operator orStream_t() const { + return stream(); + } + + operator Stream() const { + return unwrap(); + } + + DeviceType device_type() const { + return DeviceType::PrivateUse1; + } + + DeviceIndex device_index() const { + return stream_.device_index(); + } + + Device device() const { + return Device(DeviceType::PrivateUse1, device_index()); + } + + StreamId id() const { + return stream_.id(); + } + + bool query() const { + DeviceGuard guard{stream_.device()}; + + if (orStreamQuery(stream()) == orSuccess) { + return true; + } + + return false; + } + + void synchronize() const { + DeviceGuard guard{stream_.device()}; + OPENREG_CHECK(orStreamSynchronize(stream())); + } + + int priority() const { + DeviceGuard guard{stream_.device()}; + int priority = 0; + OPENREG_CHECK(orStreamGetPriority(stream(), &priority)); + return priority; + } + + orStream_t stream() const; + + Stream unwrap() const { + return stream_; + } + + struct c10::StreamData3 pack3() const { + return stream_.pack3(); + } + + static OpenRegStream unpack3( + StreamId stream_id, + DeviceIndex device_index, + DeviceType device_type) { + return OpenRegStream(Stream::unpack3(stream_id, device_index, device_type)); + } + + private: + Stream stream_; +}; + +/* + * Get a stream from the pool in a round-robin fashion. + * + * You can request a stream from the highest priority pool by setting + * isHighPriority to true for a specific device. + */ +OPENREG_EXPORT OpenRegStream +getStreamFromPool(const bool isHighPriority = false, DeviceIndex device = -1); + +/* + * Get a stream from the pool in a round-robin fashion. + * + * You can request a stream by setting a priority value for a specific device. + * The priority number lower, the priority higher. + */ +OPENREG_EXPORT OpenRegStream +getStreamFromPool(const int priority, DeviceIndex device = -1); + +/* + * Get a OpenRegStream from a externally allocated one. + * + * This is mainly for interoperability with different libraries where we + * want to operate on a non-torch allocated stream for data exchange or similar + * purposes + */ +OPENREG_EXPORT OpenRegStream +getStreamFromExternal(orStream_t ext_stream, DeviceIndex device_index); + +/* + * Get the default OpenReg stream, for the passed OpenReg device, or for the + * current device if no device index is passed. + */ +OPENREG_EXPORT OpenRegStream +getDefaultOpenRegStream(DeviceIndex device_index = -1); + +/* + * Get the current OpenReg stream, for the passed OpenReg device, or for the + * current device if no device index is passed. + */ +OPENREG_EXPORT OpenRegStream +getCurrentOpenRegStream(DeviceIndex device_index = -1); + +/* + * Set the current stream on the device of the passed in stream to be the passed + * in stream. + */ +OPENREG_EXPORT void setCurrentOpenRegStream(OpenRegStream stream); + +OPENREG_EXPORT std::ostream& operator<<( + std::ostream& stream, + const OpenRegStream& s); + +} // namespace c10::openreg + +namespace std { +template <> +struct hash { + size_t operator()(c10::openreg::OpenRegStream s) const noexcept { + return std::hash{}(s.unwrap()); + } +}; +} // namespace std diff --git a/PyTorchSimDevice/include/Macros.h b/PyTorchSimDevice/include/Macros.h new file mode 100644 index 00000000..c75523c2 --- /dev/null +++ b/PyTorchSimDevice/include/Macros.h @@ -0,0 +1,7 @@ +#pragma once + +#ifdef _WIN32 +#define OPENREG_EXPORT __declspec(dllexport) +#else +#define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif diff --git a/PyTorchSimDevice/pyproject.toml b/PyTorchSimDevice/pyproject.toml new file mode 100644 index 00000000..774fe5cd --- /dev/null +++ b/PyTorchSimDevice/pyproject.toml @@ -0,0 +1,35 @@ +[build-system] +requires = [ + "setuptools", + "wheel", + "torch", # Needed by setup.py for getting include of PyTorch +] + +build-backend = "setuptools.build_meta" + +[project] +name = "torch_openreg" +version = "0.0.1" +description = "A minimal reference implementation of an out-of-tree backend" +readme = "README.md" +requires-python = ">=3.9" +license = { text = "BSD-3-Clause" } +authors = [{ name = "PyTorch Team", email = "packages@pytorch.org" }] +dependencies = [ + "torch", +] +# Add classifiers info for making lint happy +classifiers = [ + "Development Status :: 4 - Beta", + "Topic :: Software Development", + "Topic :: Software Development :: Libraries", + "Topic :: Software Development :: Libraries :: Python Modules", + "Programming Language :: C++", + "Programming Language :: Python :: 3 :: Only", +] + +[project.urls] +Homepage = "https://pytorch.org" +Repository = "https://github.com/pytorch/pytorch" +Documentation = "https://pytorch.org/docs" +Forum = "https://discuss.pytorch.org" diff --git a/PyTorchSimDevice/setup.py b/PyTorchSimDevice/setup.py new file mode 100644 index 00000000..01e2f065 --- /dev/null +++ b/PyTorchSimDevice/setup.py @@ -0,0 +1,148 @@ +import multiprocessing +import os +import platform +import shutil +import subprocess +import sys +import sysconfig +from distutils.command.clean import clean + +from setuptools import Extension, find_packages, setup + + +# Env Variables +IS_DARWIN = platform.system() == "Darwin" +IS_WINDOWS = platform.system() == "Windows" + +BASE_DIR = os.path.dirname(os.path.realpath(__file__)) +RUN_BUILD_DEPS = any(arg in {"clean", "dist_info"} for arg in sys.argv) + + +def make_relative_rpath_args(path): + if IS_DARWIN: + return ["-Wl,-rpath,@loader_path/" + path] + elif IS_WINDOWS: + return [] + else: + return ["-Wl,-rpath,$ORIGIN/" + path] + + +def get_pytorch_dir(): + os.environ["TORCH_DEVICE_BACKEND_AUTOLOAD"] = "0" + import torch + + return os.path.dirname(os.path.realpath(torch.__file__)) + + +def build_deps(): + build_dir = os.path.join(BASE_DIR, "build") + os.makedirs(build_dir, exist_ok=True) + + cmake_args = [ + "-DCMAKE_INSTALL_PREFIX=" + + os.path.realpath(os.path.join(BASE_DIR, "torch_openreg")), + "-DPYTHON_INCLUDE_DIR=" + sysconfig.get_paths().get("include"), + "-DPYTORCH_INSTALL_DIR=" + get_pytorch_dir(), + ] + + subprocess.check_call( + ["cmake", BASE_DIR] + cmake_args, cwd=build_dir, env=os.environ + ) + + build_args = [ + "--build", + ".", + "--target", + "install", + "--config", # For multi-config generators + "Release", + "--", + ] + + if IS_WINDOWS: + build_args += ["/m:" + str(multiprocessing.cpu_count())] + else: + build_args += ["-j", str(multiprocessing.cpu_count())] + + command = ["cmake"] + build_args + subprocess.check_call(command, cwd=build_dir, env=os.environ) + + +class BuildClean(clean): + def run(self): + for i in ["build", "install", "torch_openreg/lib"]: + dirs = os.path.join(BASE_DIR, i) + if os.path.exists(dirs) and os.path.isdir(dirs): + shutil.rmtree(dirs) + + for dirpath, _, filenames in os.walk(os.path.join(BASE_DIR, "torch_openreg")): + for filename in filenames: + if filename.endswith(".so"): + os.remove(os.path.join(dirpath, filename)) + + +def main(): + if not RUN_BUILD_DEPS: + build_deps() + + if IS_WINDOWS: + # /NODEFAULTLIB makes sure we only link to DLL runtime + # and matches the flags set for protobuf and ONNX + extra_link_args: list[str] = ["/NODEFAULTLIB:LIBCMT.LIB"] + [ + *make_relative_rpath_args("lib") + ] + # /MD links against DLL runtime + # and matches the flags set for protobuf and ONNX + # /EHsc is about standard C++ exception handling + extra_compile_args: list[str] = ["/MD", "/FS", "/EHsc"] + else: + extra_link_args = [*make_relative_rpath_args("lib")] + extra_compile_args = [ + "-Wall", + "-Wextra", + "-Wno-strict-overflow", + "-Wno-unused-parameter", + "-Wno-missing-field-initializers", + "-Wno-unknown-pragmas", + "-fno-strict-aliasing", + ] + + ext_modules = [ + Extension( + name="torch_openreg._C", + sources=["torch_openreg/csrc/stub.c"], + language="c", + extra_compile_args=extra_compile_args, + libraries=["torch_bindings"], + library_dirs=[os.path.join(BASE_DIR, "torch_openreg/lib")], + extra_link_args=extra_link_args, + ) + ] + + package_data = { + "torch_openreg": [ + "lib/*.so*", + "lib/*.dylib*", + "lib/*.dll", + "lib/*.lib", + ] + } + + setup( + packages=find_packages(), + package_data=package_data, + ext_modules=ext_modules, + cmdclass={ + "clean": BuildClean, # type: ignore[misc] + }, + include_package_data=False, + entry_points={ + "torch.backends": [ + "torch_openreg = torch_openreg:_autoload", + ], + }, + ) + + +if __name__ == "__main__": + main() diff --git a/PyTorchSimDevice/third_party/openreg/CMakeLists.txt b/PyTorchSimDevice/third_party/openreg/CMakeLists.txt new file mode 100644 index 00000000..1bde7e00 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/CMakeLists.txt @@ -0,0 +1,21 @@ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(TORCH_OPENREG CXX C) + + +set(LIBRARY_NAME openreg) +set(LIBRARY_TEST ortests) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/csrc/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_include_directories(${LIBRARY_NAME} PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) + +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/PyTorchSimDevice/third_party/openreg/README.md b/PyTorchSimDevice/third_party/openreg/README.md new file mode 100644 index 00000000..0cee2c87 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/README.md @@ -0,0 +1,151 @@ +# OpenReg: An Accelerator Backend that Simulates CUDA Behavior on a CPU + +## Introduction + +OpenReg is a C++ backend library that simulates the behavior of a CUDA-like device on a CPU. Its core objective is **not to accelerate computation or improve performance**, but rather to **simulate modern CUDA programming, enabling developers to prototype and test in an environment without actual GPU hardware**. The current design principles are as follows: + +* **API Consistency**: Provide an interface consistent with the CUDA Runtime API, allowing upper-level applications (like PyTorch's `PrivateUse1` backend) to switch and test seamlessly. +* **Functional Consistency**: Provide behavior consistent with the CUDA Runtime, such as memory isolation, device context management, etc. +* **Completeness**: Aim to support `PrivateUse1` device integration and safeguard the third-party device integration mechanism, without striving to cover all capabilities of the CUDA Runtime. + +## Directory Structure + +The project's code is organized with a clear structure and separation of responsibilities: + +```text +openreg/ +├── README.md # Comprehensive introduction of OpenReg. +├── CMakeLists.txt # Top-level CMake build script, used to compile and generate libopenreg.so +├── cmake/ +│ └── GTestTargets.cmake # Utils of fetching GoogleTest. +├── include/ +│ ├── openreg.h # Public API header file, external users only need to include this file +│ └── openreg.inl # Public API header file, as an extension of openreg.h, cannot be included separately. +├── example/ +│ └── example.cpp # Example for OpenReg. +├── tests/ +│ ├── event_tests.cpp # Testcases about OpenReg Event. +│ ├── stream_tests.cpp # Testcases about OpenReg Stream. +│ ├── device_tests.cpp # Testcases about OpenReg Device. +│ └── memory_tests.cpp # Testcases about OpenReg Memory. +└── csrc/ + ├── device.cpp # Implementation of device management APIs + ├── memory.cpp # Implementation of memory management APIs + └── stream.cpp # Implementation of stream and event APIs. +``` + +* `CMakeLists.txt`: Responsible for compiling and linking all source files under the `csrc/` directory to generate the final `libopenreg.so` shared library. +* `include`: Defines all externally exposed APIs, data structures, and enums. + * `openreg.h`: Defines all externally exposed C-style APIs. + * `openreg.inl`: Defines all externally exposed C++ APIs. +* `csrc/`: Contains the C++ implementation source code for all core functionalities. + * `device.cpp`: Implements the core functions of device management: device discovery and context management. + * `memory.cpp`: Implements the core functions of memory management: allocation, free, copy and memory protection. + * `stream.cpp`: Implements the core functions of stream and event: creation, destroy, record, synchronization and so on. + +## Implemented APIs + +OpenReg currently provides a set of APIs covering basic memory and device management. + +### Device Management APIs + +| OpenReg | CUDA | Feature Description | +| :------------------------------- | :--------------------------------- | :--------------------------------- | +| `orGetDeviceCount` | `cudaGetDeviceCount` | Get the number of available GPUs | +| `orSetDevice` | `cudaSetDevice` | Set the active GPU | +| `orGetDevice` | `cudaGetDevice` | Get the current GPU | +| `orDeviceSynchronize` | `cudaDeviceSynchronize` | Wait for all GPU tasks to finish | +| `orDeviceGetStreamPriorityRange` | `cudaDeviceGetStreamPriorityRange` | Get the range of stream priorities | + +### Memory Management APIs + +| OpenReg | CUDA | Feature Description | +| :----------------------- | :------------------------- | :---------------------------------------- | +| `orMalloc` | `cudaMalloc` | Allocate device memory | +| `orFree` | `cudaFree` | Free device memory | +| `orMallocHost` | `cudaMallocHost` | Allocate page-locked (Pinned) host memory | +| `orFreeHost` | `cudaFreeHost` | Free page-locked host memory | +| `orMemcpy` | `cudaMemcpy` | Synchronous memory copy | +| `orMemcpyAsyn` | `cudaMemcpyAsyn` | Asynchronous memory copy | +| `orPointerGetAttributes` | `cudaPointerGetAttributes` | Get pointer attributes | + +### Stream APIs + +| OpenReg | CUDA | Feature Description | +| :--------------------------- | :----------------------------- | :------------------------------------- | +| `orStreamCreate` | `cudaStreamCreate` | Create a default-priority stream | +| `orStreamCreateWithPriority` | `cudaStreamCreateWithPriority` | Create a stream with a given priority | +| `orStreamDestroy` | `cudaStreamDestroy` | Destroy a stream | +| `orStreamQuery` | `cudaStreamQuery` | Check if a stream has completed | +| `orStreamSynchronize` | `cudaStreamSynchronize` | Wait for a stream to complete | +| `orStreamWaitEvent` | `cudaStreamWaitEvent` | Make a stream wait for an event | +| `orStreamGetPriority` | `cudaStreamGetPriority` | Get a stream’s priority | + +### Event APIs + +| OpenReg | CUDA | Feature Description | +| :----------------------- | :------------------------- | :---------------------------------- | +| `orEventCreate` | `cudaEventCreate` | Create an event with default flag | +| `orEventCreateWithFlags` | `cudaEventCreateWithFlags` | Create an event with specific flag | +| `orEventDestroy` | `cudaEventDestroy` | Destroy an event | +| `orEventRecord` | `cudaEventRecord` | Record an event in a stream | +| `orEventSynchronize` | `cudaEventSynchronize` | Wait for an event to complete | +| `orEventQuery` | `cudaEventQuery` | Check if an event has completed | +| `orEventElapsedTime` | `cudaEventElapsedTime` | Get time elapsed between two events | + +## Implementation Principles + +### Device Management Principles + +Simulating multiple devices and thread-safe device context switching: + +1. **Device Count**: The total number of simulated devices is defined by the compile-time constant `constexpr int kDeviceCount`. +2. **Device Switching**: Device switching in multi-threaded scenarios is simulated using a **TLS (Thread-Local Storage) global variable**. + +### Memory Management Principles + +Simulating device memory, host memory, and memory copies: + +1. **Allocation**: A page-aligned memory block is allocated using `mmap` + `mprotect` with the permission flag `PROT_NONE`. Read, write, and execute operations on this memory region are all prohibited. +2. **Deallocation**: Memory is freed using `munmap`. +3. **Authorization**: When a legitimate memory access is required, an RAII guard restores the memory permissions to `PROT_READ | PROT_WRITE`. The permissions are automatically reverted to `PROT_NONE` when the scope is exited. + +### Stream&Event Principles + +Simulating creation, release and synchronization for event and steam: + +1. **Event**: Each event is encapsulated as a task function and placed into a stream, which acts as a thread. Upon completion of the task, a flag within the event is modified to simulate the event's status. +2. **Stream**: When each stream is requested, a new thread is created, which sequentially processes each task in the task queue within the stream structure. Tasks can be wrappers around kernel functions or events. +3. **Synchronization**: Synchronization between streams and events is achieved using multithreading, condition variables, and mutexes. + +## Usage Example + +Please refer to [example](example/example.cpp) for example. + +The command to compile example.cpp is as follow: + +```Shell +mkdir build + +pushd build +cmake .. +make -j 32 +popd + +g++ -o out example/example.cpp -L ./build -lopenreg +LD_LIBRARY_PATH=./build ./out +``` + +The output is as follow: + +```Shell +Current environment have 2 devices +Current is 0 device +All tasks have been submitted. +Kernel execution time: 0.238168 ms +Verification PASSED! +``` + +## Next Steps + +The most basic functions of the OpenReg backend are currently supported, and will be dynamically optimized and expanded based on the needs of PyTorch integration. diff --git a/PyTorchSimDevice/third_party/openreg/cmake/GTestTargets.cmake b/PyTorchSimDevice/third_party/openreg/cmake/GTestTargets.cmake new file mode 100644 index 00000000..777fc489 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/cmake/GTestTargets.cmake @@ -0,0 +1,12 @@ +set(GTest_REL_PATH "../../../../../../../third_party/googletest") +get_filename_component(GTest_DIR "${CMAKE_CURRENT_LIST_DIR}/${GTest_REL_PATH}" ABSOLUTE) + +if(EXISTS "${GTest_DIR}/CMakeLists.txt") + message(STATUS "Found GTest: ${GTest_DIR}") + + set(BUILD_GMOCK OFF CACHE BOOL "Disable GMock build") + set(INSTALL_GTEST OFF CACHE BOOL "Disable GTest install") + add_subdirectory(${GTest_DIR} "${CMAKE_BINARY_DIR}/gtest") +else() + message(FATAL_ERROR "GTest Not Found") +endif() diff --git a/PyTorchSimDevice/third_party/openreg/csrc/device.cpp b/PyTorchSimDevice/third_party/openreg/csrc/device.cpp new file mode 100644 index 00000000..9643bc59 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/device.cpp @@ -0,0 +1,37 @@ +#include + +namespace { + +// Total device numbers +constexpr int DEVICE_COUNT = 2; +// Current device index +thread_local int gCurrentDevice = 0; + +} // namespace + +orError_t orGetDeviceCount(int* count) { + if (!count) { + return orErrorUnknown; + } + + *count = DEVICE_COUNT; + return orSuccess; +} + +orError_t orGetDevice(int* device) { + if (!device) { + return orErrorUnknown; + } + + *device = gCurrentDevice; + return orSuccess; +} + +orError_t orSetDevice(int device) { + if (device < 0 || device >= DEVICE_COUNT) { + return orErrorUnknown; + } + + gCurrentDevice = device; + return orSuccess; +} diff --git a/PyTorchSimDevice/third_party/openreg/csrc/memory.cpp b/PyTorchSimDevice/third_party/openreg/csrc/memory.cpp new file mode 100644 index 00000000..6f02eeb0 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/memory.cpp @@ -0,0 +1,259 @@ +#include "memory.h" + +#include + +#include +#include + +namespace { + +struct Block { + orMemoryType type = orMemoryType::orMemoryTypeUnmanaged; + int device = -1; + void* pointer = nullptr; + size_t size = 0; + int refcount{0}; +}; + +class MemoryManager { + public: + static MemoryManager& getInstance() { + static MemoryManager instance; + return instance; + } + + orError_t allocate(void** ptr, size_t size, orMemoryType type) { + if (!ptr || size == 0) + return orErrorUnknown; + + std::lock_guard lock(m_mutex); + long page_size = openreg::get_pagesize(); + size_t aligned_size = ((size - 1) / page_size + 1) * page_size; + void* mem = nullptr; + int current_device = -1; + + if (type == orMemoryType::orMemoryTypeDevice) { + orGetDevice(¤t_device); + + mem = openreg::mmap(aligned_size); + if (mem == nullptr) + return orErrorUnknown; + if (openreg::mprotect(mem, aligned_size, F_PROT_NONE) != 0) { + openreg::munmap(mem, aligned_size); + return orErrorUnknown; + } + } else { + if (openreg::alloc(&mem, page_size, aligned_size) != 0) { + return orErrorUnknown; + } + } + + m_registry[mem] = {type, current_device, mem, aligned_size, 0}; + *ptr = mem; + return orSuccess; + } + + orError_t free(void* ptr) { + if (!ptr) + return orSuccess; + + std::lock_guard lock(m_mutex); + auto it = m_registry.find(ptr); + if (it == m_registry.end()) + return orErrorUnknown; + + const auto& info = it->second; + if (info.type == orMemoryType::orMemoryTypeDevice) { + openreg::mprotect(info.pointer, info.size, F_PROT_READ | F_PROT_WRITE); + openreg::munmap(info.pointer, info.size); + } else { + openreg::free(info.pointer); + } + + m_registry.erase(it); + return orSuccess; + } + + orError_t memcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + if (!dst || !src || count == 0) + return orErrorUnknown; + + std::lock_guard lock(m_mutex); + Block* dst_info = getBlockInfoNoLock(dst); + Block* src_info = getBlockInfoNoLock(src); + + switch (kind) { + case orMemcpyHostToDevice: + if ((!dst_info || dst_info->type != orMemoryType::orMemoryTypeDevice) || + (src_info && src_info->type == orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + case orMemcpyDeviceToHost: + if ((dst_info && dst_info->type == orMemoryType::orMemoryTypeDevice) || + (!src_info || src_info->type != orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + case orMemcpyDeviceToDevice: + if ((!dst_info || dst_info->type != orMemoryType::orMemoryTypeDevice) || + (!src_info || src_info->type != orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + case orMemcpyHostToHost: + if ((dst_info && dst_info->type == orMemoryType::orMemoryTypeDevice) || + (src_info && src_info->type == orMemoryType::orMemoryTypeDevice)) + return orErrorUnknown; + break; + } + + unprotectNoLock(dst_info); + unprotectNoLock(src_info); + ::memcpy(dst, src, count); + protectNoLock(dst_info); + protectNoLock(src_info); + + return orSuccess; + } + + orError_t getPointerAttributes( + orPointerAttributes* attributes, + const void* ptr) { + if (!attributes || !ptr) + return orErrorUnknown; + + std ::lock_guard lock(m_mutex); + Block* info = getBlockInfoNoLock(ptr); + + if (!info) { + attributes->type = orMemoryType::orMemoryTypeUnmanaged; + attributes->device = -1; + attributes->pointer = const_cast(ptr); + } else { + attributes->type = info->type; + attributes->device = info->device; + attributes->pointer = info->pointer; + } + + return orSuccess; + } + + orError_t unprotect(void* ptr) { + std::lock_guard lock(m_mutex); + return unprotectNoLock(getBlockInfoNoLock(ptr)); + } + + orError_t protect(void* ptr) { + std::lock_guard lock(m_mutex); + return protectNoLock(getBlockInfoNoLock(ptr)); + } + + private: + MemoryManager() = default; + + orError_t unprotectNoLock(Block* info) { + if (info && info->type == orMemoryType::orMemoryTypeDevice) { + if (info->refcount == 0) { + if (openreg::mprotect( + info->pointer, info->size, F_PROT_READ | F_PROT_WRITE) != 0) { + return orErrorUnknown; + } + } + + info->refcount++; + } + + return orSuccess; + } + + orError_t protectNoLock(Block* info) { + if (info && info->type == orMemoryType::orMemoryTypeDevice) { + if (info->refcount == 1) { + if (openreg::mprotect(info->pointer, info->size, F_PROT_NONE) != 0) { + return orErrorUnknown; + } + } + + info->refcount--; + } + + return orSuccess; + } + + Block* getBlockInfoNoLock(const void* ptr) { + auto it = m_registry.upper_bound(const_cast(ptr)); + if (it != m_registry.begin()) { + --it; + const char* p_char = static_cast(ptr); + const char* base_char = static_cast(it->first); + if (p_char >= base_char && p_char < (base_char + it->second.size)) { + return &it->second; + } + } + + return nullptr; + } + + std::map m_registry; + std::mutex m_mutex; +}; + +} // namespace + +orError_t orMalloc(void** devPtr, size_t size) { + return MemoryManager::getInstance().allocate( + devPtr, size, orMemoryType::orMemoryTypeDevice); +} + +orError_t orFree(void* devPtr) { + return MemoryManager::getInstance().free(devPtr); +} + +orError_t orMallocHost(void** hostPtr, size_t size) { + return MemoryManager::getInstance().allocate( + hostPtr, size, orMemoryType::orMemoryTypeHost); +} + +orError_t orFreeHost(void* hostPtr) { + return MemoryManager::getInstance().free(hostPtr); +} + +orError_t orMemcpy( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind) { + return MemoryManager::getInstance().memcpy(dst, src, count, kind); +} + +orError_t orMemcpyAsync( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind, + orStream_t stream) { + if (!stream) { + return orErrorUnknown; + } + + auto& mm = MemoryManager::getInstance(); + + return orLaunchKernel( + stream, &MemoryManager::memcpy, &mm, dst, src, count, kind); +} + +orError_t orPointerGetAttributes( + orPointerAttributes* attributes, + const void* ptr) { + return MemoryManager::getInstance().getPointerAttributes(attributes, ptr); +} + +orError_t orMemoryUnprotect(void* devPtr) { + return MemoryManager::getInstance().unprotect(devPtr); +} + +orError_t orMemoryProtect(void* devPtr) { + return MemoryManager::getInstance().protect(devPtr); +} diff --git a/PyTorchSimDevice/third_party/openreg/csrc/memory.h b/PyTorchSimDevice/third_party/openreg/csrc/memory.h new file mode 100644 index 00000000..35851ac9 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/memory.h @@ -0,0 +1,96 @@ +#pragma once + +#include +#include +#include + +#if defined(_WIN32) +#include +#else +#include +#include +#endif + +#define F_PROT_NONE 0x0 +#define F_PROT_READ 0x1 +#define F_PROT_WRITE 0x2 + +namespace openreg { + +void* mmap(size_t size) { +#if defined(_WIN32) + return VirtualAlloc(nullptr, size, MEM_RESERVE | MEM_COMMIT, PAGE_READWRITE); +#else + void* addr = ::mmap( + nullptr, + size, + PROT_READ | PROT_WRITE, + MAP_PRIVATE | MAP_ANONYMOUS, + -1, + 0); + return (addr == MAP_FAILED) ? nullptr : addr; +#endif +} + +void munmap(void* addr, size_t size) { +#if defined(_WIN32) + VirtualFree(addr, 0, MEM_RELEASE); +#else + ::munmap(addr, size); +#endif +} + +int mprotect(void* addr, size_t size, int prot) { +#if defined(_WIN32) + DWORD win_prot = 0; + DWORD old; + if (prot == F_PROT_NONE) { + win_prot = PAGE_NOACCESS; + } else { + win_prot = PAGE_READWRITE; + } + + return VirtualProtect(addr, size, win_prot, &old) ? 0 : -1; +#else + int native_prot = 0; + if (prot == F_PROT_NONE) + native_prot = PROT_NONE; + else { + if (prot & F_PROT_READ) + native_prot |= PROT_READ; + if (prot & F_PROT_WRITE) + native_prot |= PROT_WRITE; + } + + return ::mprotect(addr, size, native_prot); +#endif +} + +int alloc(void** mem, size_t alignment, size_t size) { +#ifdef _WIN32 + *mem = _aligned_malloc(size, alignment); + return *mem ? 0 : -1; +#else + return posix_memalign(mem, alignment, size); +#endif +} + +void free(void* mem) { +#ifdef _WIN32 + _aligned_free(mem); +#else + ::free(mem); +#endif +} + +long get_pagesize() { +#ifdef _WIN32 + SYSTEM_INFO si; + GetSystemInfo(&si); + return static_cast(si.dwPageSize); +#else + return sysconf(_SC_PAGESIZE); +#endif +} + +} // namespace openreg diff --git a/PyTorchSimDevice/third_party/openreg/csrc/stream.cpp b/PyTorchSimDevice/third_party/openreg/csrc/stream.cpp new file mode 100644 index 00000000..30f50b1a --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/csrc/stream.cpp @@ -0,0 +1,313 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +static std::mutex g_mutex; +static std::once_flag g_flag; +static std::vector> g_streams_per_device; + +static void initialize_registries() { + int device_count = 0; + orGetDeviceCount(&device_count); + g_streams_per_device.resize(device_count); +} + +struct orEventImpl { + std::mutex mtx; + std::condition_variable cv; + std::atomic completed{true}; + int device_index = -1; + bool timing_enabled{false}; + std::chrono::high_resolution_clock::time_point completion_time; +}; + +struct orEvent { + std::shared_ptr impl; +}; + +struct orStream { + std::queue> tasks; + std::mutex mtx; + std::condition_variable cv; + std::thread worker; + std::atomic stop_flag{false}; + int device_index = -1; + + orStream() { + worker = std::thread([this] { + while (true) { + std::function task; + { + std::unique_lock lock(this->mtx); + this->cv.wait(lock, [this] { + return this->stop_flag.load() || !this->tasks.empty(); + }); + if (this->stop_flag.load() && this->tasks.empty()) { + return; + } + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + task(); + } + }); + } + + ~orStream() { + stop_flag.store(true); + cv.notify_one(); + worker.join(); + } +}; + +orError_t openreg::addTaskToStream( + orStream_t stream, + std::function task) { + if (!stream) + return orErrorUnknown; + + { + std::lock_guard lock(stream->mtx); + stream->tasks.push(std::move(task)); + } + + stream->cv.notify_one(); + return orSuccess; +} + +orError_t orEventCreateWithFlags(orEvent_t* event, unsigned int flags) { + if (!event) + return orErrorUnknown; + + auto impl = std::make_shared(); + orGetDevice(&(impl->device_index)); + if (flags & orEventEnableTiming) { + impl->timing_enabled = true; + } + + *event = new orEvent{std::move(impl)}; + return orSuccess; +} + +orError_t orEventCreate(orEvent_t* event) { + return orEventCreateWithFlags(event, orEventDisableTiming); +} + +orError_t orEventDestroy(orEvent_t event) { + if (!event) + return orErrorUnknown; + + delete event; + return orSuccess; +} + +orError_t orEventRecord(orEvent_t event, orStream_t stream) { + if (!event || !stream) + return orErrorUnknown; + + auto event_impl = event->impl; + event_impl->completed.store(false); + auto record_task = [event_impl]() { + if (event_impl->timing_enabled) { + event_impl->completion_time = std::chrono::high_resolution_clock::now(); + } + + { + std::lock_guard lock(event_impl->mtx); + event_impl->completed.store(true); + } + + event_impl->cv.notify_all(); + }; + + return openreg::addTaskToStream(stream, record_task); +} + +orError_t orEventSynchronize(orEvent_t event) { + if (!event) + return orErrorUnknown; + + auto event_impl = event->impl; + std::unique_lock lock(event_impl->mtx); + event_impl->cv.wait(lock, [&] { return event_impl->completed.load(); }); + + return orSuccess; +} + +orError_t orEventQuery(orEvent_t event) { + if (!event) + return orErrorUnknown; + + return event->impl->completed.load() ? orSuccess : orErrorNotReady; +} + +orError_t orEventElapsedTime(float* ms, orEvent_t start, orEvent_t end) { + if (!ms || !start || !end) + return orErrorUnknown; + + auto start_impl = start->impl; + auto end_impl = end->impl; + + if (start_impl->device_index != end_impl->device_index) { + return orErrorUnknown; + } + + if (!start_impl->timing_enabled || !end_impl->timing_enabled) { + return orErrorUnknown; + } + + if (!start_impl->completed.load() || !end_impl->completed.load()) { + return orErrorUnknown; + } + + auto duration = end_impl->completion_time - start_impl->completion_time; + *ms = std::chrono::duration_cast>( + duration) + .count(); + + return orSuccess; +} + +orError_t orStreamCreateWithPriority( + orStream_t* stream, + [[maybe_unused]] unsigned int flag, + int priority) { + if (!stream) { + return orErrorUnknown; + } + + int min_p, max_p; + orDeviceGetStreamPriorityRange(&min_p, &max_p); + if (priority < min_p || priority > max_p) { + return orErrorUnknown; + } + + int current_device = 0; + orGetDevice(¤t_device); + + orStream_t new_stream = nullptr; + new_stream = new orStream(); + new_stream->device_index = current_device; + + { + std::lock_guard lock(g_mutex); + std::call_once(g_flag, initialize_registries); + g_streams_per_device[current_device].insert(new_stream); + } + + *stream = new_stream; + + return orSuccess; +} + +orError_t orStreamCreate(orStream_t* stream) { + int min_p, max_p; + orDeviceGetStreamPriorityRange(&min_p, &max_p); + + return orStreamCreateWithPriority(stream, 0, max_p); +} + +orError_t orStreamGetPriority( + [[maybe_unused]] orStream_t stream, + int* priority) { + // Since OpenReg has only one priority level, the following code + // returns 0 directly for convenience. + *priority = 0; + + return orSuccess; +} + +orError_t orStreamDestroy(orStream_t stream) { + if (!stream) + return orErrorUnknown; + + { + std::lock_guard lock(g_mutex); + + int device_idx = stream->device_index; + if (device_idx >= 0 && device_idx < g_streams_per_device.size()) { + g_streams_per_device[device_idx].erase(stream); + } + } + + delete stream; + return orSuccess; +} + +orError_t orStreamQuery(orStream_t stream) { + if (!stream) { + return orErrorUnknown; + } + + std::lock_guard lock(stream->mtx); + return stream->tasks.empty() ? orSuccess : orErrorNotReady; +} + +orError_t orStreamSynchronize(orStream_t stream) { + if (!stream) + return orErrorUnknown; + + orEvent_t event; + orEventCreate(&event); + orEventRecord(event, stream); + + orError_t status = orEventSynchronize(event); + orEventDestroy(event); + + return status; +} + +orError_t orStreamWaitEvent(orStream_t stream, orEvent_t event, unsigned int) { + if (!stream || !event) + return orErrorUnknown; + + auto event_impl = event->impl; + auto wait_task = [event_impl]() { + std::unique_lock lock(event_impl->mtx); + event_impl->cv.wait(lock, [&] { return event_impl->completed.load(); }); + }; + + return openreg::addTaskToStream(stream, wait_task); +} + +orError_t orDeviceGetStreamPriorityRange( + int* leastPriority, + int* greatestPriority) { + if (!leastPriority || !greatestPriority) { + return orErrorUnknown; + } + + // OpenReg have only one priority now. + *leastPriority = 0; + *greatestPriority = 0; + return orSuccess; +} + +orError_t orDeviceSynchronize(void) { + int current_device = 0; + orGetDevice(¤t_device); + + std::vector streams; + { + std::lock_guard lock(g_mutex); + std::call_once(g_flag, initialize_registries); + + auto& streams_on_device = g_streams_per_device[current_device]; + streams.assign(streams_on_device.begin(), streams_on_device.end()); + } + + for (orStream_t stream : streams) { + orError_t status = orStreamSynchronize(stream); + if (status != orSuccess) { + return status; + } + } + + return orSuccess; +} diff --git a/PyTorchSimDevice/third_party/openreg/example/example.cpp b/PyTorchSimDevice/third_party/openreg/example/example.cpp new file mode 100644 index 00000000..f00f1909 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/example/example.cpp @@ -0,0 +1,112 @@ +#include "include/openreg.h" + +#include +#include +#include +#include + +struct MemoryGuard { + MemoryGuard(void* ptr) : ptr_(ptr) { + orMemoryUnprotect(ptr_); + } + ~MemoryGuard() { + orMemoryProtect(ptr_); + } + + private: + void* ptr_{}; +}; + +void add_kernel(float* out, float* a, float* b, int num) { + for (int i = 0; i < num; ++i) { + out[i] = a[i] + b[i]; + } +} + +int main() { + int device_count = 0; + orGetDeviceCount(&device_count); + + std::cout << "Current environment have " << device_count << " devices" + << std::endl; + + orSetDevice(0); + int current_device = -1; + orGetDevice(¤t_device); + + std::cout << "Current is " << current_device << " device" << std::endl; + + constexpr int num = 50000; + constexpr size_t size = num * sizeof(float); + + std::vector host_a(num), host_b(num), host_out(num, 0.0f); + std::iota(host_a.begin(), host_a.end(), 0.0f); + for (int i = 0; i < num; ++i) { + host_b[i] = 2.0f; + } + + float *dev_a, *dev_b, *dev_out; + orMalloc((void**)&dev_a, size); + orMalloc((void**)&dev_b, size); + orMalloc((void**)&dev_out, size); + + // There will be subsequent memory access operations, so memory protection + // needs to be released + MemoryGuard a{dev_a}; + MemoryGuard b{dev_b}; + MemoryGuard c{dev_out}; + + orStream_t stream1, stream2; + orEvent_t start_event, stop_event; + + orStreamCreate(&stream1); + orStreamCreate(&stream2); + orEventCreateWithFlags(&start_event, orEventEnableTiming); + orEventCreateWithFlags(&stop_event, orEventEnableTiming); + + // Copy input from host to device + orMemcpyAsync(dev_a, host_a.data(), size, orMemcpyHostToDevice, stream1); + orMemcpyAsync(dev_b, host_b.data(), size, orMemcpyHostToDevice, stream1); + + // Submit compute kernel and two events those are used for calculating time. + orEventRecord(start_event, stream1); + orLaunchKernel(stream1, add_kernel, dev_out, dev_a, dev_b, num); + orEventRecord(stop_event, stream1); + + // Synchronization between streams. + orStreamWaitEvent(stream2, stop_event, 0); + orMemcpyAsync(host_out.data(), dev_out, size, orMemcpyDeviceToHost, stream2); + orStreamSynchronize(stream2); + + std::cout << "All tasks have been submitted." << std::endl; + + float elapsed_ms = 0.0f; + orEventElapsedTime(&elapsed_ms, start_event, stop_event); + std::cout << "Kernel execution time: " << elapsed_ms << " ms" << std::endl; + + bool success = true; + for (int i = 0; i < num; ++i) { + if (std::abs(host_out[i] - (host_a[i] + host_b[i])) > 1e-5) { + std::cout << "Verification FAILED at index " << i << "! Expected " + << (host_a[i] + host_b[i]) << ", got " << host_out[i] + << std::endl; + success = false; + break; + } + } + if (success) { + std::cout << "Verification PASSED!" << std::endl; + } + + orFree(dev_a); + orFree(dev_b); + orFree(dev_out); + + orStreamDestroy(stream1); + orStreamDestroy(stream2); + + orEventDestroy(start_event); + orEventDestroy(stop_event); + + return 0; +} diff --git a/PyTorchSimDevice/third_party/openreg/include/openreg.h b/PyTorchSimDevice/third_party/openreg/include/openreg.h new file mode 100644 index 00000000..a5e4af55 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/include/openreg.h @@ -0,0 +1,109 @@ +#pragma once + +#include + +#ifdef _WIN32 +#define OPENREG_EXPORT __declspec(dllexport) +#else +#define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum orError_t { + orSuccess = 0, + orErrorUnknown = 1, + orErrorNotReady = 2 +} orError_t; + +typedef enum orMemcpyKind { + orMemcpyHostToHost = 0, + orMemcpyHostToDevice = 1, + orMemcpyDeviceToHost = 2, + orMemcpyDeviceToDevice = 3 +} orMemcpyKind; + +typedef enum orMemoryType { + orMemoryTypeUnmanaged = 0, + orMemoryTypeHost = 1, + orMemoryTypeDevice = 2 +} orMemoryType; + +struct orPointerAttributes { + orMemoryType type = orMemoryType::orMemoryTypeUnmanaged; + int device; + void* pointer; +}; + +typedef enum orEventFlags { + orEventDisableTiming = 0x0, + orEventEnableTiming = 0x1, +} orEventFlags; + +struct orStream; +struct orEvent; +typedef struct orStream* orStream_t; +typedef struct orEvent* orEvent_t; + +// Memory +OPENREG_EXPORT orError_t orMalloc(void** devPtr, size_t size); +OPENREG_EXPORT orError_t orFree(void* devPtr); +OPENREG_EXPORT orError_t orMallocHost(void** hostPtr, size_t size); +OPENREG_EXPORT orError_t orFreeHost(void* hostPtr); +OPENREG_EXPORT orError_t +orMemcpy(void* dst, const void* src, size_t count, orMemcpyKind kind); +OPENREG_EXPORT orError_t orMemcpyAsync( + void* dst, + const void* src, + size_t count, + orMemcpyKind kind, + orStream_t stream); +OPENREG_EXPORT orError_t +orPointerGetAttributes(orPointerAttributes* attributes, const void* ptr); +OPENREG_EXPORT orError_t orMemoryUnprotect(void* devPtr); +OPENREG_EXPORT orError_t orMemoryProtect(void* devPtr); + +// Device +OPENREG_EXPORT orError_t orGetDeviceCount(int* count); +OPENREG_EXPORT orError_t orSetDevice(int device); +OPENREG_EXPORT orError_t orGetDevice(int* device); +OPENREG_EXPORT orError_t +orDeviceGetStreamPriorityRange(int* leastPriority, int* greatestPriority); +OPENREG_EXPORT orError_t orDeviceSynchronize(void); + +// Stream +OPENREG_EXPORT orError_t orStreamCreateWithPriority( + orStream_t* stream, + unsigned int flags, + int priority); +OPENREG_EXPORT orError_t orStreamCreate(orStream_t* stream); +OPENREG_EXPORT orError_t orStreamGetPriority(orStream_t stream, int* priority); +OPENREG_EXPORT orError_t orStreamDestroy(orStream_t stream); +OPENREG_EXPORT orError_t orStreamQuery(orStream_t stream); +OPENREG_EXPORT orError_t orStreamSynchronize(orStream_t stream); +OPENREG_EXPORT orError_t +orStreamWaitEvent(orStream_t stream, orEvent_t event, unsigned int flags); + +// Event +OPENREG_EXPORT orError_t +orEventCreateWithFlags(orEvent_t* event, unsigned int flags); +OPENREG_EXPORT orError_t orEventCreate(orEvent_t* event); +OPENREG_EXPORT orError_t orEventDestroy(orEvent_t event); +OPENREG_EXPORT orError_t orEventRecord(orEvent_t event, orStream_t stream); +OPENREG_EXPORT orError_t orEventSynchronize(orEvent_t event); +OPENREG_EXPORT orError_t orEventQuery(orEvent_t event); +OPENREG_EXPORT orError_t +orEventElapsedTime(float* ms, orEvent_t start, orEvent_t end); + +#ifdef __cplusplus +} // extern "C" +#endif + +#ifdef __cplusplus + +#define OPENREG_H +#include "openreg.inl" + +#endif diff --git a/PyTorchSimDevice/third_party/openreg/include/openreg.inl b/PyTorchSimDevice/third_party/openreg/include/openreg.inl new file mode 100644 index 00000000..851be132 --- /dev/null +++ b/PyTorchSimDevice/third_party/openreg/include/openreg.inl @@ -0,0 +1,42 @@ +#ifndef OPENREG_H +#error "Don`t include openreg.inl directly, include openreg.h instead." +#endif + +#include +#include +#include + +namespace openreg { +OPENREG_EXPORT orError_t +addTaskToStream(orStream* stream, std::function task); +} + +template +OPENREG_EXPORT inline orError_t orLaunchKernel( + orStream* stream, + Func&& kernel_func, + Args&&... args) { + if (!stream) { + return orErrorUnknown; + } + +/* + * Some tests in PyTorch still use C++11, so we use conditional macro to + * select different approaches for different C++ version. + * + * Std::apply is only supported in C++17, so for C++11/14, std::bind is + * a more appropriate approach, but the former has better performance. + */ +#if __cplusplus >= 201703L + auto task = [func = std::forward(kernel_func), + args_tuple = + std::make_tuple(std::forward(args)...)]() mutable { + std::apply(func, std::move(args_tuple)); + }; +#else + auto task = + std::bind(std::forward(kernel_func), std::forward(args)...); +#endif + + return openreg::addTaskToStream(stream, std::move(task)); +} diff --git a/PyTorchSimDevice/torch_openreg/_C.cpython-311-x86_64-linux-gnu.so b/PyTorchSimDevice/torch_openreg/_C.cpython-311-x86_64-linux-gnu.so new file mode 100755 index 00000000..04b3b4e1 Binary files /dev/null and b/PyTorchSimDevice/torch_openreg/_C.cpython-311-x86_64-linux-gnu.so differ diff --git a/PyTorchSimDevice/torch_openreg/__init__.py b/PyTorchSimDevice/torch_openreg/__init__.py new file mode 100644 index 00000000..5e404f7d --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/__init__.py @@ -0,0 +1,33 @@ +import sys +import os +import torch + + +if sys.platform == "win32": + from ._utils import _load_dll_libraries + + _load_dll_libraries() + del _load_dll_libraries + +import torch_openreg._C # type: ignore[misc] +import torch_openreg.openreg + +torch.utils.rename_privateuse1_backend("npu") +torch._register_device_module("npu", torch_openreg.openreg) +torch.utils.generate_methods_for_privateuse1_backend(for_storage=True) + +sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +from PyTorchSimFrontend.mlir.mlir_codegen_backend import ExtensionWrapperCodegen +from PyTorchSimFrontend.mlir.mlir_scheduling import MLIRScheduling +torch._inductor.codegen.common.register_backend_for_device( + "npu", + lambda scheduling: MLIRScheduling(scheduling), + ExtensionWrapperCodegen +) + +torch_openreg.openreg.init() +sys.modules['torch.npu'] = torch_openreg.openreg + +def _autoload(): + # It is a placeholder function here to be registered as an entry point. + pass \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/_utils.py b/PyTorchSimDevice/torch_openreg/_utils.py new file mode 100644 index 00000000..1c26f475 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/_utils.py @@ -0,0 +1,42 @@ +import ctypes +import glob +import os + + +def _load_dll_libraries() -> None: + openreg_dll_path = os.path.join(os.path.dirname(__file__), "lib") + + kernel32 = ctypes.WinDLL("kernel32.dll", use_last_error=True) + with_load_library_flags = hasattr(kernel32, "AddDllDirectory") + prev_error_mode = kernel32.SetErrorMode(0x0001) + + kernel32.LoadLibraryW.restype = ctypes.c_void_p + if with_load_library_flags: + kernel32.LoadLibraryExW.restype = ctypes.c_void_p + + os.add_dll_directory(openreg_dll_path) + + dlls = glob.glob(os.path.join(openreg_dll_path, "*.dll")) + path_patched = False + for dll in dlls: + is_loaded = False + if with_load_library_flags: + res = kernel32.LoadLibraryExW(dll, None, 0x00001100) + last_error = ctypes.get_last_error() + if res is None and last_error != 126: + err = ctypes.WinError(last_error) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + elif res is not None: + is_loaded = True + if not is_loaded: + if not path_patched: + os.environ["PATH"] = ";".join([openreg_dll_path] + [os.environ["PATH"]]) + path_patched = True + res = kernel32.LoadLibraryW(dll) + if res is None: + err = ctypes.WinError(ctypes.get_last_error()) + err.strerror += f' Error loading "{dll}" or one of its dependencies.' + raise err + + kernel32.SetErrorMode(prev_error_mode) diff --git a/PyTorchSimDevice/torch_openreg/csrc/CMakeLists.txt b/PyTorchSimDevice/torch_openreg/csrc/CMakeLists.txt new file mode 100644 index 00000000..4ff321c4 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/csrc/CMakeLists.txt @@ -0,0 +1,24 @@ +set(LIBRARY_NAME torch_bindings) + +file(GLOB_RECURSE SOURCE_FILES + "${CMAKE_CURRENT_SOURCE_DIR}/*.cpp" +) + +add_library(${LIBRARY_NAME} SHARED ${SOURCE_FILES}) + +target_link_libraries(${LIBRARY_NAME} PRIVATE torch_python_library torch_openreg) + +if(WIN32) + find_package(Python3 COMPONENTS Interpreter Development REQUIRED) + target_link_libraries(${LIBRARY_NAME} PRIVATE ${Python3_LIBRARIES}) +elseif(APPLE) + set_target_properties(${LIBRARY_NAME} PROPERTIES LINK_FLAGS "-undefined dynamic_lookup") +endif() + +target_link_directories(${LIBRARY_NAME} PRIVATE ${PYTORCH_INSTALL_DIR}/lib) + +install(TARGETS ${LIBRARY_NAME} + ARCHIVE DESTINATION ${CMAKE_INSTALL_LIBDIR} + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + RUNTIME DESTINATION ${CMAKE_INSTALL_LIBDIR} +) diff --git a/PyTorchSimDevice/torch_openreg/csrc/Module.cpp b/PyTorchSimDevice/torch_openreg/csrc/Module.cpp new file mode 100644 index 00000000..052a9ed4 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/csrc/Module.cpp @@ -0,0 +1,166 @@ +#include + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +static PyObject* _initExtension(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + + at::globalContext().lazyInitDevice(c10::DeviceType::PrivateUse1); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +static PyObject* _getDefaultGenerator(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPUtils_checkLong(arg), + "_get_default_generator expects an int, but got ", + THPUtils_typename(arg)); + auto idx = static_cast(THPUtils_unpackLong(arg)); + + return THPGenerator_initDefaultGenerator( + at::globalContext().defaultGenerator( + c10::Device(c10::DeviceType::PrivateUse1, idx))); + + END_HANDLE_TH_ERRORS +} + +PyObject* _setDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to setDevice"); + auto device = THPUtils_unpackLong(arg); + + torch::utils::device_lazy_init(at::kPrivateUse1); + c10::openreg::set_device(static_cast(device)); + + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _exchangeDevice(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK(THPUtils_checkLong(arg), "invalid argument to exchangeDevice"); + auto device_index = THPUtils_unpackDeviceIndex(arg); + if (device_index < 0) { + return THPUtils_packInt32(-1); + } + + torch::utils::device_lazy_init(at::kPrivateUse1); + auto current_device = c10::openreg::ExchangeDevice(device_index); + + return THPUtils_packDeviceIndex(current_device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDevice(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + torch::utils::device_lazy_init(at::kPrivateUse1); + auto device = static_cast(c10::openreg::current_device()); + return THPUtils_packInt32(device); + END_HANDLE_TH_ERRORS +} + +PyObject* _getDeviceCount(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + return THPUtils_packUInt64(c10::openreg::device_count()); + END_HANDLE_TH_ERRORS +} + +PyObject* _isAutocastEnabled(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + if (c10::openreg::is_amp_enabled()) { + Py_RETURN_TRUE; + } else { + Py_RETURN_FALSE; + } + END_HANDLE_TH_ERRORS +} + +PyObject* _setAutocastEnabled(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + PyBool_Check(arg), + "set_autocast_enabled expects a bool, but got ", + THPUtils_typename(arg)); + c10::openreg::set_amp_enabled(arg == Py_True); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _getAutocastDtype(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + THPDtype* dtype_obj = torch::getTHPDtype(c10::openreg::get_amp_dtype()); + Py_INCREF(dtype_obj); + return reinterpret_cast(dtype_obj); + END_HANDLE_TH_ERRORS +} + +PyObject* _setAutocastDtype(PyObject* self, PyObject* arg) { + HANDLE_TH_ERRORS + TORCH_CHECK( + THPDtype_Check(arg), + "set_autocast_dtype expects a dtype, but got ", + THPUtils_typename(arg)); + THPDtype* dtype_obj = reinterpret_cast(arg); + at::ScalarType dtype = dtype_obj->scalar_type; + c10::openreg::set_amp_dtype(dtype); + Py_RETURN_NONE; + END_HANDLE_TH_ERRORS +} + +PyObject* _getAmpSupportedDtype(PyObject* self, PyObject* noargs) { + HANDLE_TH_ERRORS + PyObject* torch_mod = PyImport_ImportModule("torch"); + TORCH_CHECK(torch_mod != nullptr, "Failed to import torch module"); + + PyObject* float16 = PyObject_GetAttrString(torch_mod, "float16"); + PyObject* float32 = PyObject_GetAttrString(torch_mod, "float32"); + + PyObject* lst = PyList_New(1); + PyList_SetItem(lst, 0, float32); + //PyList_SetItem(lst, 1, float32); + + Py_DECREF(torch_mod); + return lst; + END_HANDLE_TH_ERRORS +} + +static PyMethodDef methods[] = { + {"_init", _initExtension, METH_NOARGS, nullptr}, + {"_get_default_generator", _getDefaultGenerator, METH_O, nullptr}, + {"_get_device", _getDevice, METH_NOARGS, nullptr}, + {"_set_device", _setDevice, METH_O, nullptr}, + {"_exchangeDevice", _exchangeDevice, METH_O, nullptr}, + {"_get_device_count", _getDeviceCount, METH_NOARGS, nullptr}, + {"is_autocast_enabled", _isAutocastEnabled, METH_NOARGS, nullptr}, + {"set_autocast_enabled", _setAutocastEnabled, METH_O, nullptr}, + {"get_autocast_dtype", _getAutocastDtype, METH_NOARGS, nullptr}, + {"set_autocast_dtype", _setAutocastDtype, METH_O, nullptr}, + {"get_amp_supported_dtype", _getAmpSupportedDtype, METH_NOARGS, nullptr}, + {nullptr, nullptr, 0, nullptr}}; + +/* + * When ASAN is enabled, PyTorch modifies the dlopen flag during import, + * causing all global and weak symbols in _C.so and its dependent libraries + * to be exposed to the global symbol scope, which in turn causes + * subsequent symbols with the same name in other libraries to be intercepted. + * Therefore, it cannot be named initModule here, otherwise initModule + * in torch/csrc/Module.cpp will be called, resulting in failure. + */ +extern "C" OPENREG_EXPORT PyObject* initOpenRegModule(void) { + static struct PyModuleDef openreg_C_module = { + PyModuleDef_HEAD_INIT, "torch_openreg._C", nullptr, -1, methods}; + PyObject* mod = PyModule_Create(&openreg_C_module); + + return mod; +} diff --git a/PyTorchSimDevice/torch_openreg/csrc/stub.c b/PyTorchSimDevice/torch_openreg/csrc/stub.c new file mode 100644 index 00000000..4e02f9fd --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/csrc/stub.c @@ -0,0 +1,20 @@ +#include + +#ifdef _WIN32 +#define OPENREG_EXPORT __declspec(dllexport) +#else +#define OPENREG_EXPORT __attribute__((visibility("default"))) +#endif + +extern OPENREG_EXPORT PyObject* initOpenRegModule(void); + +#ifdef __cplusplus +extern "C" +#endif + + OPENREG_EXPORT PyObject* + PyInit__C(void); + +PyMODINIT_FUNC PyInit__C(void) { + return initOpenRegModule(); +} diff --git a/PyTorchSimDevice/torch_openreg/openreg/__init__.py b/PyTorchSimDevice/torch_openreg/openreg/__init__.py new file mode 100644 index 00000000..81c2fc60 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/__init__.py @@ -0,0 +1,91 @@ +import torch +from torch._dynamo.device_interface import register_interface_for_device + +import torch_openreg._C # type: ignore[misc] + +from . import meta # noqa: F401 +from . import extension_device_op_overrides +from .extension_device_interface import ExtensionDeviceInterface + +_initialized = False + + +class device: + r"""Context-manager that changes the selected device. + + Args: + device (torch.device or int): device index to select. It's a no-op if + this argument is a negative integer or ``None``. + """ + + def __init__(self, device): + self.idx = torch.accelerator._get_device_index(device, optional=True) + self.prev_idx = -1 + + def __enter__(self): + self.prev_idx = torch_openreg._C._exchangeDevice(self.idx) + + def __exit__(self, type, value, traceback): + self.idx = torch_openreg._C._set_device(self.prev_idx) + return False + + +def is_available(): + return True + + +def device_count() -> int: + return torch_openreg._C._get_device_count() + + +def current_device(): + return torch_openreg._C._get_device() + + +def set_device(device) -> None: + return torch_openreg._C._set_device(device) + +def custom_device(): + return torch.device("npu:0") + +def init(): + _lazy_init() + + +def is_initialized(): + return _initialized + + +def _lazy_init(): + global _initialized + if is_initialized(): + return + torch_openreg._C._init() + register_interface_for_device(custom_device(), ExtensionDeviceInterface) + _initialized = True + + +from .random import * # noqa: F403 +from .amp import * + +__all__ = [ + "device", + "device_count", + "current_device", + "set_device", + "custom_device", + "initial_seed", + "is_available", + "init", + "is_initialized", + "random", + "manual_seed", + "manual_seed_all", + "get_rng_state", + "set_rng_state", + "is_autocast_enabled", + "set_autocast_enabled", + "get_autocast_dtype", + "set_autocast_dtype", + "get_amp_supported_dtype", +] diff --git a/PyTorchSimDevice/torch_openreg/openreg/amp.py b/PyTorchSimDevice/torch_openreg/openreg/amp.py new file mode 100644 index 00000000..0a9dfdf0 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/amp.py @@ -0,0 +1,33 @@ +import torch + +import torch_openreg._C # type: ignore[misc] + +from . import _lazy_init + + +__all__ = [ + "is_autocast_enabled", + "set_autocast_enabled", + "get_autocast_dtype", + "set_autocast_dtype", + "get_amp_supported_dtype", +] + +def is_autocast_enabled(): + return torch_openreg._C.is_autocast_enabled() + + +def set_autocast_enabled(enabled: bool) -> None: + torch_openreg._C.set_autocast_enabled(enabled) + + +def get_autocast_dtype(): + return torch_openreg._C.get_autocast_dtype() + + +def set_autocast_dtype(dtype) -> None: + torch_openreg._C.set_autocast_dtype(dtype) + + +def get_amp_supported_dtype(): + return torch_openreg._C.get_amp_supported_dtype() \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/openreg/extension_device_interface.py b/PyTorchSimDevice/torch_openreg/openreg/extension_device_interface.py new file mode 100644 index 00000000..e5875ab7 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/extension_device_interface.py @@ -0,0 +1,63 @@ +import torch +from torch._dynamo.device_interface import DeviceInterface, caching_worker_current_devices, caching_worker_device_properties + +class _ExtensionDeviceProperties: # FIXME: Dummy property values + name: str = "Extension_device" + platform_name: str + vendor: str + driver_version: str + version: str + max_compute_units: int + gpu_eu_count: int + max_work_group_size: int + max_num_sub_groups: int + sub_group_sizes: list[int] + has_fp16: bool + has_fp64: bool + has_atomic64: bool + has_bfloat16_conversions: bool + has_subgroup_matrix_multiply_accumulate: bool + has_subgroup_matrix_multiply_accumulate_tensor_float32: bool + has_subgroup_2d_block_io: bool + total_memory: int + multi_processor_count: int = 128 # gpu_subslice_count, num_sm + architecture: int + type: str + +_ExtensionDeviceProperties = _ExtensionDeviceProperties + +class ExtensionDeviceInterface(DeviceInterface): + class Worker: + @staticmethod + def set_device(device: int): + caching_worker_current_devices["extension_device"] = device + + @staticmethod + def current_device() -> int: + if "extension_device" in caching_worker_current_devices: + return caching_worker_current_devices["extension_device"] + return torch.xpu.current_device() + + @staticmethod + def get_device_properties(device: torch.types.Device = None) -> _ExtensionDeviceProperties: + if device is not None: + if isinstance(device, str): + device = torch.device(device) + assert device.type == "extension_device" + if isinstance(device, torch.device): + device = device.index + if device is None: + device = ExtensionDeviceInterface.Worker.current_device() + + if "extension_device" not in caching_worker_device_properties: + device_prop = [ + torch.cuda.get_device_properties(i) + for i in range(torch.cuda.device_count()) + ] + caching_worker_device_properties["extension_device"] = device_prop + + return _ExtensionDeviceProperties + + @staticmethod + def get_compute_capability(device: torch.types.Device = None): + return 36 \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/openreg/extension_device_op_overrides.py b/PyTorchSimDevice/torch_openreg/openreg/extension_device_op_overrides.py new file mode 100644 index 00000000..27a47357 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/extension_device_op_overrides.py @@ -0,0 +1,27 @@ +from __future__ import annotations + +from textwrap import dedent + +from torch._inductor.codegen.common import DeviceOpOverrides, register_device_op_overrides +from torch._inductor.codegen.cpu_device_op_overrides import CpuDeviceOpOverrides + +class ExtensionDeviceOpOverrides(DeviceOpOverrides): + def import_get_raw_stream_as(self, name: str) -> str: + return dedent( + """ + def get_raw_stream(_): + return 0 + """ + ) + + def set_device(self, device_idx: int) -> str: + return "pass" + + def synchronize(self) -> str: + return "pass" + + def device_guard(self, device_idx: int) -> str: + return "pass" + +register_device_op_overrides("npu", ExtensionDeviceOpOverrides()) +register_device_op_overrides("cpu", CpuDeviceOpOverrides()) \ No newline at end of file diff --git a/PyTorchSimDevice/torch_openreg/openreg/meta.py b/PyTorchSimDevice/torch_openreg/openreg/meta.py new file mode 100644 index 00000000..c475e8e0 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/meta.py @@ -0,0 +1,13 @@ +import torch + + +# LITERALINCLUDE START: CUSTOM OPERATOR META +lib = torch.library.Library("openreg", "IMPL", "Meta") # noqa: TOR901 + + +@torch.library.impl(lib, "custom_abs") +def custom_abs(self): + return torch.empty_like(self) + + +# LITERALINCLUDE END: CUSTOM OPERATOR META diff --git a/PyTorchSimDevice/torch_openreg/openreg/random.py b/PyTorchSimDevice/torch_openreg/openreg/random.py new file mode 100644 index 00000000..6817bd79 --- /dev/null +++ b/PyTorchSimDevice/torch_openreg/openreg/random.py @@ -0,0 +1,61 @@ +import torch + +import torch_openreg._C # type: ignore[misc] + +from . import _lazy_init, current_device, device_count + + +__all__ = [ + "get_rng_state", + "set_rng_state", + "manual_seed", + "manual_seed_all", + "initial_seed", +] + + +def get_rng_state(device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.get_state() + + +def set_rng_state(new_state, device="openreg"): + if isinstance(device, str): + device = torch.device(device) + elif isinstance(device, int): + device = torch.device("openreg", device) + idx = device.index + if idx is None: + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.set_state(new_state) + + +def initial_seed() -> int: + _lazy_init() + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + return default_generator.initial_seed() + + +def manual_seed(seed: int) -> None: + seed = int(seed) + + idx = current_device() + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) + + +def manual_seed_all(seed: int) -> None: + seed = int(seed) + + for idx in range(device_count()): + default_generator = torch_openreg._C._get_default_generator(idx) + default_generator.manual_seed(seed) diff --git a/PyTorchSimFrontend/extension_codecache.py b/PyTorchSimFrontend/extension_codecache.py index 2e35220c..5066d214 100644 --- a/PyTorchSimFrontend/extension_codecache.py +++ b/PyTorchSimFrontend/extension_codecache.py @@ -3,12 +3,16 @@ import shlex import subprocess -from torch._inductor.codecache import AsyncCompile, get_lock_dir, get_hash, write +from torch._inductor.codecache import get_lock_dir, get_hash, write +from torch._inductor.async_compile import AsyncCompile from AsmParser.tog_generator import tog_generator from PyTorchSimFrontend.mlir.mlir_caller_codegen import MLIRKernelCallerCodeGen from PyTorchSimFrontend import extension_config from Simulator.simulator import FunctionalSimulator, CycleSimulator, TOGSimulator +# Configure logger for extension_codecache module (WARNING level by default) +logger = extension_config.setup_logger() + LOCK_TIMEOUT = 600 def hash_prefix(hash_value): @@ -165,8 +169,8 @@ def load(cls, source_code, subprocess.check_call(translate_cmd) subprocess.check_call(llc_cmd) except subprocess.CalledProcessError as e: - print("Command failed with exit code", e.returncode) - print("Error output:", e.output) + logger.error(f"Command failed with exit code {e.returncode}") + logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert(0) val_llvm_caller = MLIRKernelCallerCodeGen(extension_config.pytorchsim_functional_mode, arg_attributes) @@ -178,8 +182,10 @@ def load(cls, source_code, spad_size = val_llvm_caller.get_spad_size(target) spad_usage = stack_size + spad_size # Spad usage per lane if extension_config.CONFIG_SPAD_INFO["spad_size"] < spad_usage: - print(f"[Warning] Scratchpad size exceeded: required {spad_usage} bytes, " - f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available.") + logger.debug( + f"Scratchpad size exceeded: required {spad_usage} bytes, " + f"but only {extension_config.CONFIG_SPAD_INFO['spad_size']} bytes available." + ) raise SpadOverflowError() # Launch tile graph generator @@ -196,8 +202,8 @@ def load(cls, source_code, subprocess.check_call(gem5_translate_cmd) subprocess.check_call(gem5_llc_cmd) except subprocess.CalledProcessError as e: - print("Command failed with exit code", e.returncode) - print("Error output:", e.output) + logger.error(f"Command failed with exit code {e.returncode}") + logger.error(f"Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert(0) if not extension_config.pytorchsim_timing_mode: diff --git a/PyTorchSimFrontend/extension_config.py b/PyTorchSimFrontend/extension_config.py index 2b1b3102..b0bcac7f 100644 --- a/PyTorchSimFrontend/extension_config.py +++ b/PyTorchSimFrontend/extension_config.py @@ -2,6 +2,7 @@ import sys import importlib import yaml +import logging CONFIG_TORCHSIM_DIR = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') CONFIG_GEM5_PATH = os.environ.get('GEM5_PATH', default="/workspace/gem5/build/RISCV/gem5.opt") @@ -134,4 +135,43 @@ def load_plan_from_module(module_path): CONFIG_USE_TIMING_POOLING = int(os.environ.get('TORCHSIM_USE_TIMING_POOLING', default=0)) -CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) \ No newline at end of file +CONFIG_DEBUG_MODE = int(os.environ.get('TORCHSIM_DEBUG_MODE', default=0)) + + +def setup_logger(name=None, level=None): + """ + Setup a logger with consistent formatting across all modules. + + Args: + name: Logger name (default: __name__ of calling module) + level: Logging level (default: DEBUG if CONFIG_DEBUG_MODE else INFO) + + Returns: + Logger instance + """ + if name is None: + import inspect + # Get the calling module's name + frame = inspect.currentframe().f_back + name = frame.f_globals.get('__name__', 'PyTorchSim') + + # Convert logger name to lowercase + name = name.lower() + logger = logging.getLogger(name) + + # Only configure if not already configured (avoid duplicate handlers) + if not logger.handlers: + handler = logging.StreamHandler() + formatter = logging.Formatter( + fmt='[%(asctime)s.%(msecs)03d] [%(levelname)s] [%(name)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S' + ) + handler.setFormatter(formatter) + logger.addHandler(handler) + + # Set log level + if level is None: + level = logging.DEBUG if CONFIG_DEBUG_MODE else logging.INFO + logger.setLevel(level) + + return logger \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_device.cpp b/PyTorchSimFrontend/extension_device.cpp deleted file mode 100644 index cfaecf2b..00000000 --- a/PyTorchSimFrontend/extension_device.cpp +++ /dev/null @@ -1,711 +0,0 @@ -#include -#include - -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include -#include -#include -#include -#include -namespace py = pybind11; - -namespace { - bool g_amp_enabled = false; - at::ScalarType g_amp_dtype = at::kFloat; -} - -static at::ScalarType to_scalar_type(const py::object& dtype_obj) { - py::module torch_mod = py::module::import("torch"); - if (dtype_obj.is(torch_mod.attr("bfloat16"))) return at::kBFloat16; - if (dtype_obj.is(torch_mod.attr("float16"))) return at::kHalf; - if (dtype_obj.is(torch_mod.attr("float32"))) return at::kFloat; - if (dtype_obj.is(torch_mod.attr("float64"))) return at::kDouble; - throw std::runtime_error("Unsupported dtype for extension_device AMP"); -} - -static py::object to_torch_dtype(at::ScalarType st) { - py::module torch_mod = py::module::import("torch"); - switch (st) { - case at::kBFloat16: return torch_mod.attr("bfloat16"); - case at::kHalf: return torch_mod.attr("float16"); - case at::kFloat: return torch_mod.attr("float32"); - case at::kDouble: return torch_mod.attr("float64"); - default: - throw std::runtime_error("Unsupported scalar type in get_autocast_dtype"); - } -} - -static inline at::MemoryFormat fix_memory_format(c10::optional mf_opt) { - if (!mf_opt.has_value()) return at::MemoryFormat::Contiguous; - - auto mf = mf_opt.value(); - if (mf == at::MemoryFormat::Preserve) { - return at::MemoryFormat::Contiguous; - } - return mf; -} - -static uint64_t op_counter = 0; -static uint64_t last_saved_value = 0; - -// register guard -namespace at { -namespace detail { - -C10_REGISTER_GUARD_IMPL(PrivateUse1, c10::impl::NoOpDeviceGuardImpl); - -}} // namespace at::detail - -// basic dummy add function -at::Tensor custom_add_Tensor(const at::Tensor & self, const at::Tensor & other, const at::Scalar & alpha) { - op_counter += 1; - // Since this custom device is just for testing, not bothering to implement kernels. - return at::empty(self.sizes(), self.options()); -} - -// basic dummy mul function -at::Tensor custom_mul_Tensor(const at::Tensor & self, const at::Tensor & other) { - op_counter += 1; - // Since this custom device is just for testing, not bothering to implement kernels. - return at::empty(self.sizes(), self.options()); -} - -at::Tensor _reinterpret_tensor( - const at::Tensor& self, - c10::IntArrayRef size, - c10::IntArrayRef stride, - int64_t offset_increment) { - at::Tensor self_ = at::detail::make_tensor( - c10::Storage(self.storage()), self.key_set(), self.dtype()); - auto* self_tmp_ = self_.unsafeGetTensorImpl(); - self_tmp_->set_storage_offset(self.storage_offset() + offset_increment); - self_tmp_->set_sizes_and_strides(size, stride); - return self_; -} - -at::Tensor& zero_inplace_batching_rule(at::Tensor &self) { - op_counter += 1; - // Since this custom device is just for testing, not bothering to implement kernels. - return self; -} - -const at::Tensor& custom_resize_(const at::Tensor& self, at::IntArrayRef size, - std::optional optional_memory_format) { - at::TensorImpl* tensor_impl = self.unsafeGetTensorImpl(); - tensor_impl->set_sizes_contiguous(size); - const auto itemsize = tensor_impl->dtype().itemsize(); - const auto offset = tensor_impl->storage_offset(); - const auto storage_size = at::detail::computeStorageNbytesContiguous(size, itemsize, offset); - // Dummy device is using cpu allocator, so here just call cpu - // function maybe_resize_storage_cpu in aten/src/ATen/native/Resize.h - // to get a sufficient memory space. - at::native::maybe_resize_storage_cpu(tensor_impl, storage_size); - if (optional_memory_format.has_value()) { - auto memory_format = - optional_memory_format.value(); - TORCH_CHECK( - memory_format != at::MemoryFormat::Preserve, - "Unsupported memory format", - memory_format); - tensor_impl->empty_tensor_restride(memory_format); - } - return self; -} - -// basic dummy eq function: Only support CPU -at::Tensor custom_to_device( - const at::Tensor & self, - at::Device device, - at::ScalarType dtype, - bool non_blocking, - bool copy, - c10::optional memory_format) { - TORCH_CHECK(self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); - TORCH_CHECK(device.is_cpu() || device.type() == c10::DeviceType::PrivateUse1, "Dummy test only allows copy from cpu -> dummy device."); - // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. - TORCH_CHECK(self.scalar_type() == dtype); - TORCH_CHECK(self.is_contiguous()); - - op_counter += 1; - if (device.type() == at::DeviceType::CPU) { - auto out = at::empty(self.sizes(), dtype, self.options().layout(), - device, false, memory_format); - std::memcpy(out.mutable_data_ptr(), self.data_ptr(), self.nbytes()); - return out; - } else { - auto opts = self.options().device(device).dtype(dtype); - auto out = at::empty(self.sizes(), opts); - std::memcpy(out.mutable_data_ptr(), self.data_ptr(), self.nbytes()); - return out; - } - - auto out = at::empty(self.sizes(), dtype, self.options().layout(), device, false, memory_format); - memcpy(out.mutable_data_ptr(), self.mutable_data_ptr(), self.nbytes()); - // Since this custom device is just for testing, not bothering to implement kernels. - return out; -} - - -// A dummy allocator for our custom device, that secretly uses the CPU -struct DummyCustomAllocator final : at::Allocator { - DummyCustomAllocator() = default; - at::DataPtr allocate(size_t nbytes) const override { - void* data = c10::alloc_cpu(nbytes); - return {data, data, &ReportAndDelete, at::Device(at::DeviceType::PrivateUse1, 0)}; - } - - static void ReportAndDelete(void* ptr) { - if (!ptr) { - return; - } - c10::free_cpu(ptr); - } - - at::DeleterFnPtr raw_deleter() const override { - return &ReportAndDelete; - } -}; - -// Register our dummy allocator -static DummyCustomAllocator global_custom_alloc; -REGISTER_ALLOCATOR(c10::DeviceType::PrivateUse1, &global_custom_alloc); - -at::Tensor & custom_fill__scalar(at::Tensor & self, const at::Scalar & value) { - TORCH_CHECK(self.device().type() == c10::DeviceType::PrivateUse1, - "Dummy test only allows dummy device."); - TORCH_CHECK(self.is_contiguous()); - - op_counter += 1; - - switch (self.scalar_type()) { - case c10::ScalarType::Float: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = value.toFloat(); - } - break; - } - case c10::ScalarType::Double: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = value.toDouble(); - } - break; - } - case c10::ScalarType::Half: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = at::Half(value.toHalf()); - } - break; - } - case c10::ScalarType::BFloat16: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = at::BFloat16(value.toBFloat16()); - } - break; - } - case c10::ScalarType::Int: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = value.toInt(); - } - break; - } - case c10::ScalarType::Long: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = value.toLong(); - } - break; - } - case c10::ScalarType::Short: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = static_cast(value.toShort()); - } - break; - } - case c10::ScalarType::Char: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = static_cast(value.toChar()); - } - break; - } - case c10::ScalarType::Byte: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = static_cast(value.toByte()); - } - break; - } - case c10::ScalarType::Bool: { - auto* data = self.mutable_data_ptr(); - for (int64_t i = 0; i < self.numel(); i++) { - data[i] = value.toBool(); - } - break; - } - default: - TORCH_CHECK(false, "Unsupported scalar type: ", self.scalar_type()); - } - return self; -} - -at::Tensor unsafe_create_cpu_tensor_from_dummy_tensor(const at::Tensor& src) { - // TORCH_CHECK(src.device().type() == c10::DeviceType::PrivateUse1, - // "Only support dummy device."); - const auto& sizes_ = src.sizes(); - const auto& strides_ = src.strides(); - auto storage_offset_ = src.storage_offset(); - at::detail::check_size_nonnegative(sizes_); - - size_t size_bytes = at::detail::computeStorageNbytes(sizes_, strides_, - src.element_size(), - storage_offset_); - - at::DataPtr data_ptr = - c10::InefficientStdFunctionContext::makeDataPtr(src.storage().mutable_data_ptr().get(), - [](void*){}, at::kCPU); - - c10::Storage storage{c10::Storage::use_byte_size_t{}, size_bytes, std::move(data_ptr), - /*allocator=*/&global_custom_alloc, /*resizeable=*/false}; - - constexpr c10::DispatchKeySet cpu_ks(c10::DispatchKey::CPU); - at::Tensor tensor = at::detail::make_tensor( - std::move(storage), cpu_ks, src.dtype()); - - c10::TensorImpl* tensor_impl = tensor.unsafeGetTensorImpl(); - tensor_impl->set_sizes_and_strides(sizes_, strides_); - tensor_impl->set_storage_offset(storage_offset_); - return tensor; -} - -// basic dummy copy_() function, so we can copy from the custom device to/from CPU -at::Tensor custom__copy_from(const at::Tensor& self, const at::Tensor& dst, bool non_blocking) { - TORCH_CHECK( - self.is_cpu() || self.device().type() == c10::DeviceType::PrivateUse1, - "Dummy test only allows copy from cpu -> dummy device."); - TORCH_CHECK( - dst.is_cpu() || dst.device().type() == c10::DeviceType::PrivateUse1, - "Dummy test only allows copy from cpu -> dummy device."); - - // Some dummy asserts for the basic use case: inputs are the same size / dtype, all contiguous. - if (self.numel() != dst.numel()) { - custom_resize_(dst, self.sizes(), c10::nullopt); - } - TORCH_CHECK(self.sizes() == dst.sizes()); - - const bool same_dtype = (self.scalar_type() == dst.scalar_type()); - const bool both_contig = self.is_contiguous() && dst.is_contiguous(); - - // 1) fast path - if (same_dtype && both_contig) { - std::memcpy(dst.mutable_data_ptr(), - self.data_ptr(), - dst.storage().nbytes()); - return dst; - } - - // 2) slow path - at::Tensor cpu_self = unsafe_create_cpu_tensor_from_dummy_tensor(self); - at::Tensor cpu_dst = unsafe_create_cpu_tensor_from_dummy_tensor(dst); - if (!same_dtype) { - cpu_self = cpu_self.to(cpu_dst.scalar_type(), /*non_blocking=*/false, /*copy=*/true); - } - cpu_dst.copy_(cpu_self); - return dst; -} - -at::Tensor custom__copy_from_and_resize(const at::Tensor& self, const at::Tensor& dst) { - return custom__copy_from(self, dst, false); -} - -at::Tensor& custom_abs_out(const at::Tensor& self, at::Tensor& out) { - return at::native::abs_out(self, out); -} - -at::Tensor custom_empty_strided(c10::IntArrayRef size, c10::IntArrayRef stride, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt) { - op_counter += 1; - constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); - auto dtype = c10::dtype_or_default(dtype_opt); - return at::detail::empty_strided_generic(size, stride, &global_custom_alloc, private_use_ks, dtype); -} - -at::Tensor custom_empty(c10::IntArrayRef size, c10::optional dtype_opt, c10::optional layout_opt, c10::optional device_opt, c10::optional pin_memory_opt, c10::optional optional_memory_format) { - op_counter += 1; - - constexpr c10::DispatchKeySet private_use_ks(c10::DispatchKey::PrivateUse1); - auto dtype = c10::dtype_or_default(dtype_opt); - return at::detail::empty_generic(size, &global_custom_alloc, private_use_ks, dtype, fix_memory_format(optional_memory_format)); -} - -at::Tensor& custom_arange_start_out_impl( - const c10::Scalar& start, - const c10::Scalar& end, - const c10::Scalar& step, - at::Tensor& out) { - double s = start.toDouble(); - double e = end.toDouble(); - double st = step.toDouble(); - TORCH_CHECK(st != 0.0, "step must be nonzero"); - - int64_t length = 0; - if (st > 0) { - if (e > s) length = static_cast(std::ceil((e - s) / st)); - } else { - if (e < s) length = static_cast(std::ceil((e - s) / st)); - } - - // Resize out tensor - custom_resize_(out, {length}, c10::nullopt); - - if (out.scalar_type() == at::kFloat || out.scalar_type() == at::kDouble) { - double* data = out.mutable_data_ptr(); - for (int64_t i = 0; i < length; i++) { - data[i] = s + i * st; - } - } else if (out.scalar_type() == at::kLong) { - int64_t* data = out.mutable_data_ptr(); - for (int64_t i = 0; i < length; i++) { - data[i] = static_cast(s + i * st); - } - } else { - TORCH_CHECK(false, "Unsupported dtype for arange on dummy device"); - } - - return out; -} - -static at::Tensor custom_to_dtype_impl(const at::Tensor& self, - c10::ScalarType dtype, - bool non_blocking, bool copy, - c10::optional memory_format) { - return at::native::to(self, dtype, non_blocking, copy, memory_format); -} - -at::Tensor custom_zeros_like( - const at::Tensor& input, - c10::optional dtype_opt, - c10::optional layout_opt, - c10::optional device_opt, - c10::optional pin_memory_opt, - c10::optional memory_format_opt) -{ - // dtype / layout / device fallback - auto dtype = dtype_opt.value_or(input.scalar_type()); - auto layout = layout_opt.value_or(input.layout()); - auto device = device_opt.value_or(input.device()); - auto memfmt = memory_format_opt.value_or(c10::MemoryFormat::Contiguous); - - TORCH_CHECK( - device.type() == c10::DeviceType::PrivateUse1, - "custom_zeros_like: device must be PrivateUse1"); - - at::Tensor out = custom_empty( - input.sizes(), - dtype, - layout, - device, - pin_memory_opt, - memfmt - ); - size_t nbytes = out.numel() * out.element_size(); - void* ptr = out.mutable_data_ptr(); - - TORCH_CHECK(ptr != nullptr, - "custom_zeros_like: out.mutable_data_ptr() returned NULL"); - std::memset(ptr, 0, nbytes); - return out; -} - -at::Tensor& custom_zero_impl(at::Tensor& self) -{ - TORCH_CHECK( - self.device().type() == c10::DeviceType::PrivateUse1, - "custom_zero_: expected a PrivateUse1 device tensor"); - - if (self.numel() == 0) { - return self; - } - - void* data = self.mutable_data_ptr(); - TORCH_CHECK(data != nullptr, - "custom_zero_: self.mutable_data_ptr() returned NULL " - "(storage was not allocated)"); - - size_t nbytes = self.numel() * self.element_size(); - std::memset(data, 0, nbytes); - - return self; -} - -// With TORCH_LIBRARY_IMPL, you can register custom kernels for your backend. -// For open registration, we're registering all of our kernels to the PrivateUse1 dispatch key. -// Later in this file, we map a custom device to the PrivateUse1 device type, -// which allows user code that puts a tensor on your custom_device to eventually get plumbed -// into the kernels registered here. -// -// This macro registers your kernels to the PyTorch Dispatcher. -// More details on the dispatcher can be found at http://blog.ezyang.com/2020/09/lets-talk-about-the-pytorch-dispatcher/. -TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - m.impl("to.Device", &custom_to_device); - m.impl("to.dtype", &custom_to_dtype_impl); - m.impl("fill_.Scalar", &custom_fill__scalar); - m.impl("_copy_from", &custom__copy_from); - m.impl("_copy_from_and_resize", &custom__copy_from_and_resize); - m.impl("empty_strided", &custom_empty_strided); - m.impl("empty.memory_format", &custom_empty); - m.impl("as_strided", at::native::as_strided_tensorimpl); - m.impl("view", at::native::view); - m.impl("arange.start_out", &custom_arange_start_out_impl); - m.impl("zeros_like", &custom_zeros_like); - m.impl("zero_", &custom_zero_impl); -} - -TORCH_LIBRARY_IMPL(aten, AutogradPrivateUse1, m) { - m.impl("to.dtype", &custom_to_dtype_impl); -} - -TORCH_LIBRARY_FRAGMENT(aten, m) { - m.def( - "_reinterpret_tensor(Tensor self, int[] size, int[] stride, int offset_increment=0) -> Tensor", - torch::dispatch(c10::DispatchKey::AutogradPrivateUse1, _reinterpret_tensor), - {at::Tag::pt2_compliant_tag} - ); -} - -void custom_cpu_fallback(const c10::OperatorHandle& op, torch::jit::Stack* stack) { - at::native::cpu_fallback(op, stack); -} - -TORCH_LIBRARY_IMPL(aten, PrivateUse1, m) { - m.impl("abs", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("abs.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("abs_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("absolute", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("absolute.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("absolute_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("add.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("add.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("add.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("add_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("cat", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("cat.names", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("cat.names_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("cat.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("div.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("div.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("div.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("div_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("div_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("eq.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("eq.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("eq.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("eq.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("equal", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("erf", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("erf.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("erf_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("erfc", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("erfc.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("erfc_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("exp", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("exp.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("ge.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("ge.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("ge.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("ge.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("gt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("gt.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("gt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("gt.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("le.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("le.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("le.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("le.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("lt.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("lt.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("lt.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("lt.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("ne.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("ne.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("ne.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("ne.Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("logical_and", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_and.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_and_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_not", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_not.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_not_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_or", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_or.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_or_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_xor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_xor.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("logical_xor_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("neg", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("neg.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("neg_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("mul.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("mul.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("mul_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("pow.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow.Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow.Tensor_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow.Tensor_Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow.Tensor_Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow.Tensor_Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("pow_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("sub.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sub.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sub.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sub_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sub_.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("sum", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sum.DimnameList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sum.IntList_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sum.dim_DimnameList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sum.dim_IntList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("resize_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("resize_as_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - // Foreach ops - m.impl("_foreach_add.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add_.Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add_.ScalarList", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_foreach_add_.List", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - // Indexed - m.impl("index_add.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_add_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_copy.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_copy_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_fill.int_Scalar", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_fill.int_Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_fill.int_Scalar_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_fill.int_Tensor_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_fill_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("tril", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("tril_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("triu", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("triu_", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("triu_indices", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("nll_loss2d_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nll_loss2d_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nll_loss_backward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nll_loss_forward", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("scatter.src_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("scatter.value_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("index_put.Default", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index.Tensor", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("mm.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("sigmoid.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("gather.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("silu.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - - m.impl("all.all_out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_local_scalar_dense", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_log_softmax", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_log_softmax_backward_data", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("mse_loss.out", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("_native_multi_head_attention", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("where.self", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("min", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("max", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("index_select", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); - m.impl("nonzero", torch::CppFunction::makeFromBoxedFunction<&custom_cpu_fallback>()); -} - -// This basic implementation doesn't bother dealing with different device indices -// (e.g. custom_device:0 vs. custom_device:1). -// We could do that by letting the user pass in a device index in our exposed device function. -// Note that if you do that, you'll also need to register a device guard to core. -// See `c10/core/impl/DeviceGuardImplInterface.h:C10_REGISTER_GUARD_IMPL`. -c10::Device get_custom_device() { - return c10::Device(c10::DeviceType::PrivateUse1, 0); -} - -bool custom_op_called() { - bool called = false; - if (op_counter > last_saved_value) { - called = true; - last_saved_value = op_counter; - } - return called; -} - -class PrivateGeneratorImpl : public at::CPUGeneratorImpl { -public: - PrivateGeneratorImpl(c10::DeviceIndex device_index) { - device_ = c10::Device(c10::DeviceType::PrivateUse1, device_index); - key_set_ = c10::DispatchKeySet(c10::DispatchKey::PrivateUse1); - } - ~PrivateGeneratorImpl() override = default; -}; - -// this is used to register generator -at::Generator make_generator_privateuse1(c10::DeviceIndex device_index) { - return at::make_generator(device_index); -} - -void register_generator() { - REGISTER_GENERATOR_PRIVATEUSE1(make_generator_privateuse1) -} - -// Here, we're exposing a custom device object that corresponds to our custom backend. -// We do this using pybind: exposing an "extension_name.custom_device()" function in python, -// that's implemented in C++. -// The implementation in this file maps directly to the `PrivateUse1` device type. -PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { - m.def("custom_device", &get_custom_device, "get custom device object"); - m.def("custom_op_called", &custom_op_called, "check if our custom function was called"); - m.def("register_generator", ®ister_generator, "register generator for custom device"); - m.def("is_autocast_enabled", []() -> bool { return g_amp_enabled;}); - m.def("set_autocast_enabled", [](bool flag) -> void {g_amp_enabled = flag;}); - m.def("get_autocast_dtype", []() -> py::object { return to_torch_dtype(g_amp_dtype); }); - m.def("set_autocast_dtype", [](py::object dtype_obj) -> void { - auto st = to_scalar_type(dtype_obj); - g_amp_dtype = st; - }); - m.def("get_amp_supported_dtype", []() -> py::list { - py::module torch_mod = py::module::import("torch"); - py::list lst; - lst.append(torch_mod.attr("float16")); - lst.append(torch_mod.attr("float32")); - return lst; - }); -} \ No newline at end of file diff --git a/PyTorchSimFrontend/extension_utils.py b/PyTorchSimFrontend/extension_utils.py new file mode 100644 index 00000000..0418cacd --- /dev/null +++ b/PyTorchSimFrontend/extension_utils.py @@ -0,0 +1,26 @@ +import sympy +import torch + +""" +NOTE: Temporary File + +This file contains functions that were removed or changed in newer versions +of PyTorch. It is kept here only to temporarily enable compatibility while +upgrading to PyTorch 2.8 from PyTorch 2.2. + +These functions will eventually be integrated into the appropriate source files +or removed once no longer needed. + +This file is not intended to be permanent and should be deleted in the future. +""" + +def free_symbol_startswith(index: sympy.Expr, prefix: str): + return any(v.name.startswith(prefix) for v in index.free_symbols) + +def sympy_symbol(name: str) -> sympy.Symbol: + # This should never be used for creating shape/stride symbols, as those + # should all be allocated before Inductor. + assert name[0] != "s" + # NOTE: shape symbols are positive (> 0), but index variables are only + # non-negative (>= 0). + return sympy.Symbol(name, integer=True, nonnegative=True) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_autotune.py b/PyTorchSimFrontend/mlir/mlir_autotune.py index 988408ea..138bec50 100644 --- a/PyTorchSimFrontend/mlir/mlir_autotune.py +++ b/PyTorchSimFrontend/mlir/mlir_autotune.py @@ -49,6 +49,9 @@ def __init__( self.extra_args = extra_args #self.hash_key, self.source_file = CUDACodeCache.write(self.source_code, "so") + def __str__(self) -> str: + return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" + def make_run_fn( self, input_tensors: torch.Tensor, output_tensors: torch.Tensor ) -> Callable[[], None]: @@ -84,5 +87,6 @@ def cached_run_fn(*args, **kwargs): *args, ) - def __str__(self) -> str: - return f"{self.kernel_name=}, {self.source_file=}, {self.hash_key=}" \ No newline at end of file + def update_workspace_size(self) -> None: + # FIXME: Not implemented yet. Checkout torch/_inductor/codegen/rocm/rocm_benchmark_request.py + return \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py index 297ea162..1565a26b 100644 --- a/PyTorchSimFrontend/mlir/mlir_codegen_backend.py +++ b/PyTorchSimFrontend/mlir/mlir_codegen_backend.py @@ -2,16 +2,17 @@ import sympy import re import os -import math from functools import reduce from operator import mul import torch +from typing import Optional from collections import defaultdict from concurrent.futures import ThreadPoolExecutor from torch._dynamo.testing import rand_strided from torch._inductor.autotune_process import TensorMeta from torch._dynamo.utils import dynamo_timed from torch._inductor.codegen import cpp, wrapper, common, memory_planning +from torch._inductor.ir import GraphPartitionSignature from torch._inductor.virtualized import V, _ops as ops from torch._inductor.codecache import write_atomic from torch._inductor.utils import ( @@ -27,6 +28,11 @@ from .mlir_ops import ExtensionOverrides from PyTorchSimFrontend.mlir.mlir_autotune import MLIRBenchmarkRequest +# Configure logger for mlir_codegen_backend module +logger = extension_config.setup_logger() + +from Simulator.simulator import ProgressBar + def reduction_init(reduction_type, dtype): if dtype in cpp.DTYPE_LOWP_FP: # Since load promotes all half-precision inputs to float, the initial @@ -54,13 +60,28 @@ def reduction_partial_combine_vec(reduction_type, vector_value, init_value): if reduction_type == "min": return ops.minimum(vector_value, init_value) if reduction_type == "any": - return ops.logical_and(vector_value, init_value) + return ops.logical_or(vector_value, init_value) raise AssertionError(reduction_type) -class ExtensionWrapperCodegen(wrapper.WrapperCodeGen): +class ExtensionWrapperCodegen(wrapper.PythonWrapperCodegen): def __init__(self): super().__init__() + @classmethod + def create( + cls, + is_subgraph: bool, + subgraph_name: Optional[str], + parent_wrapper: Optional[wrapper.PythonWrapperCodegen], + partition_signatures: Optional[GraphPartitionSignature] = None, + ): + if is_subgraph: + assert subgraph_name is not None and parent_wrapper is not None + return wrapper.SubgraphPythonWrapperCodegen( + subgraph_name, parent_wrapper, partition_signatures + ) + return cls() + def write_header(self): self.header.splice( f""" @@ -74,21 +95,27 @@ def write_header(self): from torch._inductor.hooks import run_intermediate_hooks from torch._inductor.utils import maybe_profile from torch._inductor.codegen.memory_planning import _align as align + from torch._inductor.async_compile import AsyncCompile from torch import device, empty, empty_strided from {extension_codecache.__name__} import CustomAsyncCompile - from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE + from PyTorchSimFrontend.extension_config import CONFIG_SRAM_BUFFER_PLAN, CONFIG_TOGSIM_EAGER_MODE, setup_logger from Simulator.simulator import TOGSimulator from PyTorchSimFrontend.extension_op import sparse_mm_dummy_stonne_outer from torch._inductor.select_algorithm import extern_kernels + # Configure logger for generated wrapper code + _logger = setup_logger("PyTorchSimFrontend.mlir.generated_wrapper") + aten = torch.ops.aten inductor_ops = torch.ops.inductor assert_size_stride = torch._C._dynamo.guards.assert_size_stride alloc_from_pool = torch.ops.inductor._alloc_from_pool - reinterpret_tensor = torch.ops.aten._reinterpret_tensor + reinterpret_tensor = torch.ops.inductor._reinterpret_tensor custom_async_compile = CustomAsyncCompile() + async_compile = AsyncCompile() os.environ["TORCHSIM_LAST_COMPILED_MODULE"] = __file__ + _logger.info(f'Wrapper Codegen Path = {{__file__}}') """ ) self.header.splice( @@ -120,6 +147,7 @@ def device2host_memcpy(buffer): ) def write_prefix(self): + self.write_async_compile_wait() self.prefix.splice( """ def call(args): @@ -132,7 +160,7 @@ def call(args): self.prefix.writeline(f"{lhs} = args") self.prefix.writeline("args.clear()") - self.codegen_inputs(self.prefix, V.graph.graph_inputs) + self.codegen_inputs() self.codegen_input_size_asserts() self.codegen_sram_plan_prefix() @@ -152,35 +180,60 @@ def codegen_sram_plan_postfix(self, outputs): continue self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") - @dynamo_timed + def _generate_kernel_call_helper( + self, + kernel_name: str, + call_args, + *, + device=None, + triton=True, + arg_types=None, + raw_keys=None, + raw_args=None, + triton_meta=None, + graph_name="", + original_fxnode_name=None, + ): + device = device or V.graph.get_current_device_or_throw() + self.writeline(self.wrap_kernel_call(kernel_name, call_args)) + return + def generate(self, is_inference): result = IndentedBuffer() - result.splice(self.header) + # result.splice(self.header) with contextlib.ExitStack() as stack: stack.enter_context(self.wrapper_call.indent()) self.memory_plan_reuse() - for line in self.lines: - # Add buffer plan hook for dealloc - if isinstance(line, memory_planning.DeallocFromPoolLine): - self.wrapper_call.writeline(f"sram_plan_postfix('{line.node.get_name()}', {line.node.get_name()})") - elif isinstance(line, str) and "del" in line: - name = line.split(" ")[1] - self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") - - if isinstance(line, wrapper.MemoryPlanningLine): - line.codegen(self.wrapper_call) - else: - self.wrapper_call.writeline(line) - # Add buffer plan hook for alloc - if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): - self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") + with self.set_writeline(self.wrapper_call.writeline): + for line in self.lines: + # Add buffer plan hook for dealloc + if isinstance(line, memory_planning.DeallocFromPoolLine): + self.wrapper_call.writeline(f"sram_plan_postfix('{line.node.get_name()}', {line.node.get_name()})") + elif isinstance(line, str) and "del" in line: + name = line.split(" ")[1] + self.wrapper_call.writeline(f"sram_plan_postfix('{name}', {name})") + + if isinstance(line, wrapper.MemoryPlanningLine): + line.codegen(self.wrapper_call) + elif isinstance(line, wrapper.KernelCallLine): + self.wrapper_call.writeline(self.wrap_kernel_call(line.kernel_name, line.call_args)) + else: + if isinstance(line, wrapper.WrapperLine): + line.codegen(self.wrapper_call) + else: + self.wrapper_call.writeline(line) + # Add buffer plan hook for alloc + if isinstance(line, memory_planning.AllocFromPoolLine) or isinstance(line, wrapper.AllocateLine): + self.wrapper_call.writeline(f"sram_plan_prefix('{line.node.get_name()}', {line.node.get_name()})") output_refs = self.get_output_refs() self.codegen_sram_plan_postfix(output_refs) self.mark_output_type() self.generate_return(output_refs) - self.append_precomputed_sizes_to_prefix() + # self.append_precomputed_sizes_to_prefix() # FIXME: Need to replace append_precomputed_sizes_to_prefix() + result.splice(self.header) + self.finalize_prefix() result.splice(self.prefix) @@ -189,7 +242,10 @@ def generate(self, is_inference): self.generate_end(result) self.add_benchmark_harness(result) - return result.getvaluewithlinemap() + return ( + result.getvaluewithlinemap(), + self.kernel_declarations.getvaluewithlinemap(), + ) def memory_plan(self): self.lines = memory_planning.MemoryPlanner(self).plan(self.lines) @@ -257,7 +313,9 @@ def __init__(self, kernel_group, reason=None): self.base_vector_initialized = False def reset(self, reason): + save = self.exit_stack, self._nested_context_depth self.__init__(self.kernel_group, reason=reason) + self.exit_stack, self._nested_context_depth = save # padding type 0: zero-padding 1: negative-padding(-inf) ... def get_padding_type(self): @@ -271,7 +329,7 @@ def get_padding_type(self): # return 1 return 0 - def convert_index(self, expr, buffer): + def convert_index(self, expr): if len(expr.free_symbols) != 1: raise NotImplementedError("Not supporting this view operation...!") @@ -286,20 +344,41 @@ def convert_index(self, expr, buffer): expr_str = expr_str.replace("//", " floordiv ") else: raise NotImplementedError("What is this case?") + first_arg = expr.args[0] if len(first_arg.free_symbols) != 1: raise NotImplementedError("What is this case?") + + # Create affine.apply operation indices = [list(first_arg.free_symbols)[0]] - args = ", ".join(map(str, indices)) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args}) -> ({expr_str})>") - args = ", ".join([f"%{i}" for i in indices]) - index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})") + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(indices, expr_str) + index = ops.affine_apply(map_var, indices) return index - def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> common.CSEVariable: - if buffer is None: - buffer = self.applys + def _convert_sympy_to_mlir_expr(self, expr, sorted_args): + """ + Convert sympy expression to MLIR affine map expression by replacing index variables. + """ + indices = [] + for arg in sorted_args: + if arg.is_Mul and arg.args[0].is_number: + target_arg = arg.args[1] + elif not arg.is_number: + target_arg = arg + else: + continue + new_arg = sympy.Symbol(str(self.convert_index(target_arg))) + expr = expr.replace(target_arg, new_arg) + indices.append(str(new_arg)) + + expr_str = str(expr) + if "//" in expr_str: + expr_str = expr_str.replace("//", " floordiv ") + return expr_str, indices + + def parse_indices(self, expr, comments="", indices=None, indirect_dims=[]) -> common.CSEVariable: # Constant case if expr.is_number and len(indirect_dims) == 0: return self.get_const_cse(int(expr)) @@ -315,33 +394,19 @@ def parse_indices(self, expr, buffer=None, comments="", indirect_dims=[]) -> com # Sort index variable.. ex) (%index1, %index0) args_dict = {term: list(term.free_symbols)[0] for term in args if term.free_symbols} sorted_args = sorted(args_dict.keys(), key=lambda term: str(args_dict[term])) - indices = [] - for arg in sorted_args: - if arg.is_Mul and arg.args[0].is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) - expr = expr.replace(arg.args[1], new_arg) - indices.append(str(new_arg)) - elif not arg.is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) - expr = expr.replace(arg, new_arg) - indices.append(str(new_arg)) - # Extract index var + # Convert sympy expression to affine map expression + expr_str, indices = self._convert_sympy_to_mlir_expr(expr, sorted_args) indirect_args = [f"%{i}" for i in indirect_dims] - if len(indirect_args): - comments = "{indirect_access} " + comments # Add indirect access attribute - expr_str = str(expr) - if "//" in expr_str: - expr_str = expr_str.replace("//", " floordiv ") - args = ", ".join(map(str, indices)) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[{','.join(indirect_dims)}] -> ({expr_str})>") - args = ", ".join([f"%{i}" for i in indices]) - index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[{','.join(indirect_args)}] {comments}") + # Create affine.apply operation + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(indices, expr_str, symbol_names=indirect_dims) + + index = ops.affine_apply(map_var, indices, indirect_dims=indirect_args, comment=comments) return index - def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) -> common.CSEVariable: - if buffer is None: - buffer = self.applys + def parse_index_list(self, expr_list:list, offset=sympy.Number(0)) -> common.CSEVariable: + """ Need to override buffer and cse to use this function. """ expr_list = [arg for arg in expr_list] dim_list = [f"d{i}" for i in range(len(expr_list))] @@ -356,11 +421,11 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) new_expr_list = [0] * len(expr_list) for idx, arg in enumerate(expr_list): if arg.is_Mul and arg.args[0].is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg.args[1], buffer))) + new_arg = sympy.Symbol(str(self.convert_index(arg.args[1]))) new_expr_list[idx] = arg.subs(arg.args[1], dim_list[idx]) indices.append(str(new_arg)) elif not arg.is_number: - new_arg = sympy.Symbol(str(self.convert_index(arg, buffer))) + new_arg = sympy.Symbol(str(self.convert_index(arg))) new_expr_list[idx] = new_arg.subs(new_arg, dim_list[idx]) indices.append(str(new_arg)) else: @@ -370,15 +435,14 @@ def parse_index_list(self, expr_list:list, buffer=None, offset=sympy.Number(0)) indices.append(str(new_arg)) # Extract index var + # Create affine.apply operation expr_str = str(sum(new_expr_list) + offset) - args = ", ".join(map(str, dim_list)) - map_var = self.map_cse.generate(self.global_vars, f"affine_map<({args})[] -> ({expr_str})>") - args = ", ".join([f"%{i}" for i in indices]) - index = self.apply_cse.generate(buffer, f"affine.apply #{map_var}({args})[]") + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(dim_list, expr_str) + index = ops.affine_apply(map_var, indices) return index def load(self, name: str, index: sympy.Expr): - index = self.rename_indexing(index) index, comptute_depedency = self.convert_indirect_indexing(index) padding = self.get_padding_type() @@ -432,7 +496,6 @@ def load(self, name: str, index: sympy.Expr): return out def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs): - index = self.rename_indexing(index) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] @@ -477,8 +540,8 @@ def store(self, name: str, index: sympy.Expr, value, mode=None, *args, **kwargs) value = ops.to_dtype(value, mlir_dtype) if compute_vec_size < self.var_info[value][0]: - value = self.cse.generate(self.stores, f"vector.extract_strided_slice %{value} {{offsets = [0], sizes = [{compute_vec_size}], strides = [1]}}: vector<{self.var_info[value][0]}x{self.var_info[value][1]}> to {vshape}") - self.register_var_info(value, [compute_vec_size, mlir_dtype]) + with self.override_buffer_cse(buffer=self.stores): + value = ops.extract_strided_slice(value, compute_vec_size) with self.override_buffer_cse(buffer=self.stores): ops._store(value, sram_var, compute_index_var, tile_shape, buffer_name=name) @@ -585,7 +648,6 @@ def store_reduction(self, name, index, value): dram_var = self.kernel_group.args.output(name) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - index = self.rename_indexing(index) with self.override_buffer_cse(cse=self.reduction_cse): # Tile is always reuduced in inner loop @@ -622,16 +684,17 @@ def store_reduction(self, name, index, value): dram_shape, tile_shape, attribute) self.reductions_suffix.writeline(common.DeferredLine(name, code)) - def indirect_indexing(self, index_var, size, check=True): + def indirect_indexing(self, index_var, size, check=True, wrap_neg=True): return str(index_var) def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): # In case of index expr, dimension size should be divisible by tile size if not self.kernel_group.tile_desc.is_dim_dividable(self.ranges): new_tile_size = self.kernel_group.tile_desc.adjust_tile_to_divisible(self.ranges) + prior_tile_size, prior_ranges = self.kernel_group.tile_desc.get_tile_size(), self.ranges self.kernel_group.tile_desc.set_tile_size(new_tile_size) self.reset("recompile") - raise mlir_common.RecompileSignal(f"Index access (tile size {self.kernel_group.tile_desc.get_tile_size()} is not divisible by {self.ranges})") + raise mlir_common.RecompileSignal(f"Index access (tile size {prior_tile_size} is not divisible by {prior_ranges})") tile_size = tile_desc.get_tile_size_per_lane() compute_vec_size = tile_desc.get_compute_vec_size() @@ -671,9 +734,11 @@ def _index_expr(self, tile_desc, renamed_expression, index, base_vector_index): outer_dim = ops.remainder(ops.truncdiv(dim, vlane_stride_vec), vlane_outer_vec) dim = ops.add(stride_dim, ops.mul(outer_dim, nr_vector_lane_vec)) - vlane_offset = self.const_cse.generate(self.const_buffer, f"arith.addi %{vlane_vec}, %{vlane_vec} {{ vlane_offset={offset} }} : vector<{vlane_vec_size}xi64> // vlane offset") - self.register_var_info(vlane_offset, [vlane_vec_size, "i64"]) - vlane_offset = ops.index_cast(vlane_offset, "index") + with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): + vlane_offset = ops.vlane_offset(vlane_vec, vlane_vec, attributes={"vlane_offset": offset}, comment="vlane offset") + if compute_vec_size < self.var_info[vlane_offset][0]: + vlane_offset = ops.extract_strided_slice(vlane_offset, compute_vec_size) + vlane_offset = ops.index_cast(vlane_offset, "index") dim = ops.add(dim, vlane_offset) dim_list.append(dim) @@ -736,7 +801,6 @@ def index_expr(self, index, dtype): tile_desc = base_tile_desc compute_vec_size = tile_desc.get_compute_vec_size() - tile_shape = f"memref<{compute_vec_size*self.vector_lane}xindex, 1>" vshape = f"vector<{compute_vec_size}xindex>" @@ -851,15 +915,14 @@ def make_choices(self, nodes, kernel_name): # Try initial tile size self.reset(None) - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) search_space.add(current_tile_sz) - if extension_config.CONFIG_DEBUG_MODE: - print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) + choices.append((bench_runner, src_code, meta_code, current_tile_sz, self.kernel_group.tile_desc.vmap.vlane_stride)) while prevent_infinite_loop < 10 and candidate_axes: for axis in list(candidate_axes): @@ -881,7 +944,7 @@ def make_choices(self, nodes, kernel_name): continue self.reset(None) - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) current_tile_sz = tuple(self.kernel_group.tile_desc.get_tile_size()) # FIXME. How to intergrate this constraint to tile system? @@ -898,11 +961,10 @@ def make_choices(self, nodes, kernel_name): # Add this choice search_space.add(current_tile_sz) - if extension_config.CONFIG_DEBUG_MODE: - print(f"[Auto-tune] Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") + logger.debug(f"Auto-tune: Trying tile size: {list(current_tile_sz)}, vlane_stride: {self.kernel_group.tile_desc.vmap.vlane_stride}, split_axis: {self.kernel_group.tile_desc.vmap.vlane_split_axis}") self._prepare_simulator_headers(src_code) bench_runner = self.run_bench(nodes, kernel_name, src_code) - choices.append((bench_runner, src_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride)) + choices.append((bench_runner, src_code, meta_code, self.kernel_group.tile_desc.get_tile_size(), self.kernel_group.tile_desc.vmap.vlane_stride)) prevent_infinite_loop += 1 self.kernel_group.tile_desc.prev_tail_threshold = prev_tail_threshold return choices @@ -918,18 +980,24 @@ def get_cycle(choice): return float("inf") return float("inf") # Exceeded maximum number of autotuning attempts choices = self.make_choices(*args) - if len(choices) == 0: # Can't autotune - return [None, None] - with ThreadPoolExecutor(max_workers=8) as executor: - results = list(executor.map(get_cycle, choices)) - max_idx = results.index(min(results)) + return [None, None, None] + + # Get cycle time for each choice + # Show progress bar only when CONFIG_DEBUG_MODE is off + show_progress = not extension_config.CONFIG_DEBUG_MODE + with ProgressBar("[Auto-tune] Running benchmarks", silent_mode=not show_progress) if show_progress else contextlib.nullcontext(): + with ThreadPoolExecutor(max_workers=8) as executor: + results = list(executor.map(get_cycle, choices)) + + min_idx = results.index(min(results)) if min(results) == float("inf"): raise RuntimeError("Failed to find optimal tile size...") - if extension_config.CONFIG_DEBUG_MODE: - self._log_autotune_result(choices[max_idx], results[max_idx]) - optimal_src_code, loop_size = choices[max_idx][1], choices[max_idx][-1] - return optimal_src_code, loop_size + + self._log_autotune_result(choices[min_idx], results[min_idx]) + + optimal_src_code, meta_code, loop_size = choices[min_idx][1], choices[min_idx][2], choices[min_idx][-1] + return optimal_src_code, meta_code, loop_size def run_bench(self, nodes, kernel_name, src_code): _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() @@ -957,20 +1025,20 @@ def run_bench(self, nodes, kernel_name, src_code): return bmreq.make_run_fn(dummy_inputs, dummy_outputs) def _log_autotune_result(self, best_choice, best_cycle): - print( - f"[Auto-tune] Optimal tile size: {list(best_choice[2])}, " - f"vlane_stride: {best_choice[3]}, " + logger.debug( + f"Auto-tune: Optimal tile size: {list(best_choice[3])}, " + f"vlane_stride: {best_choice[4]}, " f"cycles: {best_cycle}" ) def codegen_nodes(self, nodes, kernel_name): - src_code = super().codegen_nodes(nodes, kernel_name) + src_code, meta_code = super().codegen_nodes(nodes, kernel_name) self._prepare_simulator_headers(src_code) if "autotune" in extension_config.codegen_mapping_strategy and extension_config.pytorchsim_timing_mode: - optimal_src_code = self.autotune(nodes, kernel_name)[0] + optimal_src_code, meta_code = self.autotune(nodes, kernel_name)[:2] if optimal_src_code is not None: - return optimal_src_code - return src_code + return optimal_src_code, meta_code + return src_code, meta_code def _prepare_simulator_headers(self, src_code): write_path = extension_codecache.get_write_path(src_code) @@ -1020,7 +1088,8 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe if broadcast and (total_dims != local_dims or (self.reduction_depth!=len(total_dims) and total_dims[:self.reduction_depth] == local_dims)): local_dims = total_dims # Brodatcast tile shape - index_var = self.parse_indices(index, buffer=buffer, indirect_dims=indirect_dims, comments=f"// store_reduction={store_reduction}") + with self.override_buffer_cse(buffer=buffer, cse=self.apply_cse): + index_var = self.parse_indices(index, indirect_dims=indirect_dims, comments=f"// store_reduction={store_reduction}") if kg_tile_desc.vmap.vlane_split_axis in local_dims: local_vlane_split_axis = local_dims.index(kg_tile_desc.vmap.vlane_split_axis) @@ -1110,7 +1179,7 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe max_dim = len(self.ranges) if not store_reduction else len(self.ranges) - 1 for i in range(max_dim): target_dim = f"index{i}" - if target_dim not in str(index): + if sympy.Symbol(target_dim) not in index.free_symbols: dram_dict[target_dim] = [0] sorted_keys = sorted(dram_dict.keys()) dram_stride = sum((dram_dict[key] for key in sorted_keys), []) @@ -1127,14 +1196,19 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe dim_idx = int((str(sub.args[0])[5:])) if int(self.kernel_group.tile_desc.get_tile_size()[dim_idx] % sub.args[1]) != 0: # In this case, need to recompile - original_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] - divisor = sub.args[1] + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + divisor = sub.args[1] * self.kernel_group.tile_desc.vmap.vlane_stride new_size = ((original_size + divisor - 1) // divisor) * divisor new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) new_tile_sizes[dim_idx] = new_size self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + # Can't use dim_idx as vlane_split_axis + if dim_idx == self.kernel_group.tile_desc.vmap.vlane_split_axis: + self.kernel_group.tile_desc.vmap.vlane_split_axis = (dim_idx + 1) % len(original_tile) + # Send recompile signal self.reset("recompile") raise mlir_common.RecompileSignal(f"Tile size {self.kernel_group.tile_desc.get_tile_size()[dim_idx]} is not divisible by {sub.args[1]}") @@ -1150,6 +1224,57 @@ def get_dma_info(self, name, index, broadcast=True, store_reduction=False, buffe local_tile_desc.apply_divisor(dim_idx+offset, divisor, "split") offset = offset+1 + # Support ModularIndexing pattern + # This pattern can be used to broadcast ex) torch.cat([a,a]) + # ModularIndexing(x, y, z) means (x // y) % z + # tile_size must be: multiple of y (floorDiv divisor) and divisor of z (modular divisor) + if index.has(ModularIndexing): + for sub in sympy.preorder_traversal(index): + if isinstance(sub, ModularIndexing): + if not str(sub.args[0]).startswith("index"): + continue + dim_idx = int((str(sub.args[0])[5:])) + floor_divisor = sub.args[1] # y: floorDiv divisor + mod_divisor = sub.args[2] # z: modular divisor + current_tile_size = self.kernel_group.tile_desc.get_tile_size()[dim_idx] + + # Check if tile_size is multiple of floorDiv divisor + if int(current_tile_size % floor_divisor) != 0: + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + divisor = floor_divisor * self.kernel_group.tile_desc.vmap.vlane_stride + new_size = ((original_size + divisor - 1) // divisor) * divisor + new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) + new_tile_sizes[dim_idx] = new_size + self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) + self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a multiple of floorDiv divisor {floor_divisor} in ModularIndexing") + + # Check if tile_size is a divisor of modular divisor + if int((mod_divisor * floor_divisor) % current_tile_size) != 0: + original_tile = self.kernel_group.tile_desc.get_tile_size() + original_size = original_tile[dim_idx] + # Find the largest divisor of mod_divisor that is <= original_size + # and is a multiple of floor_divisor + new_size = original_size + while new_size > 0: + if mod_divisor % new_size == 0 and new_size % floor_divisor == 0: + break + new_size -= floor_divisor + + if new_size <= 0: + new_size = mod_divisor * floor_divisor + + new_tile_sizes = list(self.kernel_group.tile_desc.get_tile_size()) + new_tile_sizes[dim_idx] = new_size + self.kernel_group.tile_desc.set_tile_size(new_tile_sizes) + self.kernel_group.tile_desc.tile_constraint[dim_idx].fixed = True + + self.reset("recompile") + raise mlir_common.RecompileSignal(f"Tile size {current_tile_size} is not a divisor of modular divisor {mod_divisor} in ModularIndexing") + # FIXME. It will be nice to modify node instead of this exception handling... if len(self.itervars) == 1 and self.reduction_depth == 0: # In case of reduction loop only case, we will add dummy loop so shift it once diff --git a/PyTorchSimFrontend/mlir/mlir_common.py b/PyTorchSimFrontend/mlir/mlir_common.py index b86607ea..be491925 100644 --- a/PyTorchSimFrontend/mlir/mlir_common.py +++ b/PyTorchSimFrontend/mlir/mlir_common.py @@ -14,7 +14,8 @@ from torch._inductor.virtualized import V from torch._inductor.ir import MultiOutputLayout from torch._inductor.dependencies import MemoryDep, StarDep, WeakDep -from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod +from torch._inductor.codegen.wrapper import KernelDefinitionLine +from torch.utils._sympy.functions import ModularIndexing, FloorDiv, Mod, Identity import sympy import contextlib @@ -22,18 +23,21 @@ import sympy -import torch.fx from torch.utils._sympy.value_ranges import ValueRanges from torch._inductor.utils import ( - free_symbol_startswith, get_sympy_Expr_dtype, IndentedBuffer, sympy_subs, - sympy_symbol, unique, ) from PyTorchSimFrontend import extension_config from PyTorchSimFrontend import extension_codecache + +from PyTorchSimFrontend.extension_utils import ( + free_symbol_startswith, + sympy_symbol +) + schedule_log = torch._logging.getArtifactLogger(__name__, "schedule") DTYPE_TO_MLIR = { @@ -605,11 +609,12 @@ def __init__(self, kernel_group, reason=None): self.recodegen = reason # spad overflow, tile size, vlane stride self.stop_autotune = False - # Context var for codegen - self.target_buffer_override = contextvars.ContextVar("Handler_compute_override", default=self.compute) - self.target_cse_override = contextvars.ContextVar("Handler_cse_override", default=self.cse) + instance_id = id(self) + self.target_buffer_override = contextvars.ContextVar(f"Handler_compute_override_{instance_id}", default=self.compute) + self.target_cse_override = contextvars.ContextVar(f"Handler_cse_override_{instance_id}", default=self.cse) + self._nested_context_depth = 0 - def set_ranges(self, lengths, reduction_lengths): + def set_ranges(self, lengths, reduction_lengths, index_names=None): if self.call_ranges: assert self.call_ranges == tuple(lengths) + tuple( reduction_lengths @@ -618,7 +623,12 @@ def set_ranges(self, lengths, reduction_lengths): else: self.call_ranges = tuple(lengths) + tuple(reduction_lengths) self.ranges = [self.rename_indexing(x) for x in self.call_ranges] - self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] + if index_names is None: + self.itervars = [sympy.Symbol(f"index{n}") for n in range(len(self.ranges))] + else: + assert len(index_names) == len(self.ranges), f"Index names length mismatch: {len(index_names)} != {len(self.ranges)}" + self.itervars = [sympy.Symbol(str(n)) for n in index_names] + self.itervar_cses = {str(index) : self.register_var_cse(str(index), 1, "index") for index in self.itervars} self.reduction_depth = len(lengths) return ( @@ -641,7 +651,7 @@ def store(self, name, index, value, mode=None): def reduction(self, dtype, src_dtype, reduction_type, value): raise NotImplementedError() - def indirect_indexing(self, index_var, size, check): + def indirect_indexing(self, index_var, size, check, wrap_neg): raise NotImplementedError() def codegen_global_init(self): @@ -654,7 +664,7 @@ def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this - wrapper.generate_kernel_call(kernel_name, call_args, cuda=False) + wrapper.generate_kernel_call(kernel_name, call_args, triton=False) def is_modular_indexing(self, expr): return "ModularIndexing" in str(expr) @@ -688,7 +698,9 @@ def extract_dividers(self, implicit_ops): } new_index = operand.index.subs(subs_map) for arg in new_index.args: - if len(arg.free_symbols) != 1: + if arg.is_number: + continue + if len(arg.free_symbols) > 1: raise NotImplementedError("Not supporting this view operation...!") if arg.is_Mul and arg.args[0].is_number: arg = arg.args[1] @@ -778,8 +790,8 @@ def codegen_nodes(self, nodes, kernel_name): V.graph.removed_buffers |= self.removed_buffers # V.graph.inplaced_to_remove |= self.inplaced_to_remove src_code = self.codegen_kernel(kernel_name=kernel_name) - self.meta_kernel() - return src_code + meta_code = self.meta_kernel() + return src_code, meta_code def codegen_kernel(self, kernel_name): arg_defs, _, _, _ = self.kernel_group.args.mlir_argdefs() @@ -797,12 +809,9 @@ def codegen_kernel(self, kernel_name): return code.getvalue() def meta_kernel(self): - wrapper = V.graph.wrapper_code _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") - return arg_attributes + meta_code = arg_attributes + return meta_code def get_constant_vector(self, expr): constant_vector = [[int(expr.coeff(var)),None] for var in self.itervars] @@ -835,6 +844,21 @@ def rename_indexing(self, index) -> sympy.Expr: # and renames variables in index expressions to kernel arg names if isinstance(index, (list, tuple)): return [self.rename_indexing(x) for x in index] + + # FIXME. This is a temporary solution to remove Identity wrappers from index expression. + # Remove Identity wrappers from index expression + # Check if index itself is Identity + if isinstance(index, Identity): + index = index.args[0] if index.args else index + + # Replace Identity arguments with Identity.args[0] + if hasattr(index, 'args') and len(index.args) > 0: + for arg in index.args: + if arg.is_Mul and arg.args[0].is_number and isinstance(arg.args[1], Identity): + index = index.replace(arg.args[1], arg.args[1].args[0] if arg.args[1].args else arg.args[1]) + if isinstance(arg, Identity): + index = index.replace(arg, arg.args[0] if arg.args else arg) + index = V.graph.sizevars.simplify(index) sorted_symbols = sorted(index.free_symbols, key=lambda s: s.name) replacements = { @@ -846,18 +870,24 @@ def rename_indexing(self, index) -> sympy.Expr: @contextmanager def override_buffer_cse(self, *, buffer=None, cse=None): - target_buffer = target_cse = None + buffer_override = self.target_buffer_override + cse_override = self.target_cse_override + buffer_token = cse_token = None try: + # Store tokens for proper restoration in nested contexts + # contextvars.set() returns the previous value (token) which can be used for reset() if buffer is not None: - target_buffer = self.target_buffer_override.set(buffer) + buffer_token = buffer_override.set(buffer) if cse is not None: - target_cse = self.target_cse_override.set(cse) + cse_token = cse_override.set(cse) yield self finally: - if target_cse is not None: - self.target_cse_override.reset(target_cse) - if target_buffer is not None: - self.target_buffer_override.reset(target_buffer) + # Restore using tokens - contextvars automatically handles nested contexts + # Each level restores to its own previous value + if cse_token is not None: + cse_override.reset(cse_token) + if buffer_token is not None: + buffer_override.reset(buffer_token) def __enter__(self): class CSEProxy: @@ -866,7 +896,7 @@ class CSEProxy: @staticmethod def __getattr__(name: str) -> Callable[..., common.CSEVariable]: # type: ignore[misc] def inner(*args, **kwargs): - code, ret_info = getattr(parent_handler, name)(*args, var_info=self.var_info, **kwargs) + code, ret_info = getattr(parent_handler, name)(*args, **kwargs) target_buffer = self.target_buffer_override.get() target_cse = self.target_cse_override.get() if isinstance(code, common.DeferredLine): @@ -887,12 +917,13 @@ def inner(*args, **kwargs): return inner @staticmethod - def indirect_indexing(index_var, size, check=True): + def indirect_indexing(index_var, size, check=True, wrap_neg=True): # Skip CSE since this doesn't return an expression - return self.indirect_indexing(index_var, size, check) + return self.indirect_indexing(index_var, size, check, wrap_neg) @staticmethod def load(name: str, index: sympy.Expr): + index = self.rename_indexing(index) if name in self.cse.invalidated_stores: # A load from an invalidated store requires us to # keep the actual buffer around @@ -903,10 +934,10 @@ def load(name: str, index: sympy.Expr): if name in store_cache: return store_cache[name] key = name+str(index) - if key not in self.cse.cache: + if key not in self.cse._cache: result = self.load(name, index) - self.cse.cache[key] = result - return self.cse.cache[key] + self.cse._cache[key] = result + return self.cse._cache[key] @staticmethod def store(name, index, value, mode=None): @@ -914,9 +945,10 @@ def store(name, index, value, mode=None): if mode is None: self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: + index = self.rename_indexing(index) return self.store(name, index, value, mode=mode) @staticmethod @@ -924,10 +956,11 @@ def store_reduction(name, index, value): self.store_buffer_names.add(name) self.cse.store_cache[name] = value if self.current_node: - for other_name in self.current_node.get_mutations(): + for other_name in self.current_node.get_output(name).get_mutations(): self.cse.store_cache[other_name] = value if name not in V.graph.removed_buffers: + index = self.rename_indexing(index) return self.store_reduction(name, index, value) @staticmethod @@ -940,6 +973,7 @@ def _index_expr(tile_size, buffer, renamed_expression, index): @staticmethod def index_expr(index, dtype): + index = self.rename_indexing(index) return self.index_expr(index, dtype) @staticmethod @@ -968,13 +1002,20 @@ def bucketize( values, offsets_name, offsets_size, indexing_dtype, right ) - super().__enter__() - assert self.overrides - parent_handler = self.overrides(V.get_ops_handler()) - self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) - self.exit_stack.enter_context(V.set_kernel_handler(self)) + if self._nested_context_depth == 0: + self.exit_stack.__enter__() + assert self.overrides + parent_handler = self.overrides() + + self.exit_stack.enter_context(V.set_ops_handler(CSEProxy())) + self.exit_stack.enter_context(V.set_kernel_handler(self)) + self._nested_context_depth += 1 return self + def __exit__(self, exc_type, exc_val, exc_tb): + self._nested_context_depth -= 1 + if self._nested_context_depth == 0: + super().__exit__(exc_type, exc_val, exc_tb) @dataclasses.dataclass class LoopLevel: diff --git a/PyTorchSimFrontend/mlir/mlir_decomposition.py b/PyTorchSimFrontend/mlir/mlir_decomposition.py new file mode 100644 index 00000000..284d25d7 --- /dev/null +++ b/PyTorchSimFrontend/mlir/mlir_decomposition.py @@ -0,0 +1,167 @@ +import math +import torch +import torch.nn.functional as F +from torch._inductor.decomposition import register_decomposition + +aten = torch.ops.aten + +@register_decomposition(aten._native_multi_head_attention.default) +def decompose_native_multi_head_attention( + query, + key, + value, + embed_dim: int, + num_heads: int, + qkv_weight, + qkv_bias, + proj_weight, + proj_bias, + mask=None, + need_weights: bool = False, +): + """ + Decompose _native_multi_head_attention into scaled_dot_product_attention operations. + + Based on F.scaled_dot_product_attention and nn.MultiheadAttention implementation: + 1. QKV projection (if needed - but query/key/value may already be projected) + 2. Reshape to multi-head format + 3. Scaled dot product: Q @ K^T / sqrt(head_dim) + 4. Softmax + 5. Attention @ V + 6. Reshape back and output projection + """ + head_dim = embed_dim // num_heads + scale_factor = 1.0 / math.sqrt(head_dim) + + # Get input shapes - assuming [batch, seq_len, embed_dim] format + query_shape = query.shape + if len(query_shape) == 3: + # [batch, seq_len, embed_dim] format + batch_size = query_shape[0] + seq_len = query_shape[1] + elif len(query_shape) == 2: + # [seq_len, embed_dim] -> add batch dimension + batch_size = 1 + seq_len = query_shape[0] + query = query.unsqueeze(0) # [1, seq_len, embed_dim] + key = key.unsqueeze(0) + value = value.unsqueeze(0) + else: + # Fallback: assume first dim is batch, second is seq_len + batch_size = query_shape[0] if len(query_shape) > 0 else 1 + seq_len = query_shape[1] if len(query_shape) > 1 else query_shape[0] + + # Step 1: QKV projection (if query/key/value are not already projected) + # In many cases, query/key/value are already projected, so we check if qkv_weight is used + # For now, assume they might need projection + # Note: In practice, _native_multi_head_attention often receives already projected inputs + + # Reshape for projection: [batch, seq_len, embed_dim] -> [batch*seq_len, embed_dim] + if len(query.shape) == 3: + query_flat = query.view(-1, embed_dim) + key_flat = key.view(-1, embed_dim) + value_flat = value.view(-1, embed_dim) + else: + query_flat = query + key_flat = key + value_flat = value + + # QKV projection using qkv_weight and qkv_bias + # Check if GQA (Grouped Query Attention) is used + # Standard MHA: qkv_weight shape = [3*embed_dim, embed_dim] + # GQA: qkv_weight shape = [embed_dim + 2*kv_embed_dim, embed_dim] where kv_embed_dim < embed_dim + qkv_weight_total = qkv_weight.shape[0] + + # Determine if GQA: if qkv_weight is not exactly 3*embed_dim, it might be GQA + if qkv_weight_total == 3 * embed_dim: + # Standard MHA: split equally + qkv_weight_q, qkv_weight_k, qkv_weight_v = torch.split(qkv_weight, embed_dim, dim=0) + if qkv_bias is not None: + qkv_bias_q, qkv_bias_k, qkv_bias_v = torch.split(qkv_bias, embed_dim, dim=0) + else: + qkv_bias_q = qkv_bias_k = qkv_bias_v = None + kv_embed_dim = embed_dim + kv_heads = num_heads + else: + # GQA: Q has embed_dim, K and V share the rest + # Assume Q = embed_dim, K = V = (qkv_weight_total - embed_dim) / 2 + q_dim = embed_dim + kv_dim = (qkv_weight_total - embed_dim) // 2 + qkv_weight_q = qkv_weight[:q_dim] + qkv_weight_k = qkv_weight[q_dim:q_dim + kv_dim] + qkv_weight_v = qkv_weight[q_dim + kv_dim:] + if qkv_bias is not None: + qkv_bias_q = qkv_bias[:q_dim] + qkv_bias_k = qkv_bias[q_dim:q_dim + kv_dim] + qkv_bias_v = qkv_bias[q_dim + kv_dim:] + else: + qkv_bias_q = qkv_bias_k = qkv_bias_v = None + kv_embed_dim = kv_dim + kv_heads = kv_embed_dim // head_dim # Number of KV heads + + # Project Q, K, V + q = torch.nn.functional.linear(query_flat, qkv_weight_q, qkv_bias_q) + k = torch.nn.functional.linear(key_flat, qkv_weight_k, qkv_bias_k) + v = torch.nn.functional.linear(value_flat, qkv_weight_v, qkv_bias_v) + + # Reshape back: [batch*seq_len, embed_dim] -> [batch, seq_len, embed_dim] + q = q.view(batch_size, seq_len, embed_dim) + k = k.view(batch_size, seq_len, kv_embed_dim) + v = v.view(batch_size, seq_len, kv_embed_dim) + + # Step 2: Reshape to multi-head format + # [batch, seq_len, embed_dim] -> [batch, seq_len, num_heads, head_dim] + q = q.view(batch_size, seq_len, num_heads, head_dim) + k = k.view(batch_size, seq_len, kv_heads, head_dim) + v = v.view(batch_size, seq_len, kv_heads, head_dim) + + # Transpose to [batch, num_heads, seq_len, head_dim] for bmm + q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) # [batch, kv_heads, seq_len, head_dim] + v = v.transpose(1, 2) # [batch, kv_heads, seq_len, head_dim] + + # GQA: If key/value have fewer heads, repeat them to match query heads + if kv_heads < num_heads: + repeat_factor = num_heads // kv_heads + k = k.repeat_interleave(repeat_factor, dim=1) # [batch, num_heads, seq_len, head_dim] + v = v.repeat_interleave(repeat_factor, dim=1) # [batch, num_heads, seq_len, head_dim] + + # Step 3: Scaled dot product attention + # Scale Q + q_scaled = q * scale_factor + + # Q @ K^T: [batch, num_heads, seq_len, head_dim] @ [batch, num_heads, head_dim, seq_len] + # -> [batch, num_heads, seq_len, seq_len] + k_transposed = k.transpose(-2, -1) # [batch, num_heads, head_dim, seq_len] + scores = torch.matmul(q_scaled, k_transposed) # [batch, num_heads, seq_len, seq_len] + + # Step 4: Apply mask if provided + if mask is not None: + if mask.dtype == torch.bool: + attn_bias.masked_fill_(mask.logical_not(), float("-inf")) + else: + attn_bias = mask + attn_bias + + # Step 5: Softmax along the last dimension (seq_len dimension) + attn_weights = F.softmax(scores, dim=-1) # [batch, num_heads, seq_len, seq_len] + + # Step 6: Attention @ V + # [batch, num_heads, seq_len, seq_len] @ [batch, num_heads, seq_len, head_dim] + # -> [batch, num_heads, seq_len, head_dim] + attn_output = torch.matmul(attn_weights, v) + + # Step 7: Reshape back to [batch, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2) # [batch, seq_len, num_heads, head_dim] + attn_output = attn_output.contiguous().view(batch_size, seq_len, embed_dim) + + # Step 8: Output projection + attn_output_flat = attn_output.view(-1, embed_dim) + output = torch.nn.functional.linear(attn_output_flat, proj_weight, proj_bias) + output = output.view(batch_size, seq_len, embed_dim) + + if need_weights: + # Return attention weights: [batch, num_heads, seq_len, seq_len] -> [batch, seq_len, seq_len] + attn_weights_mean = attn_weights.mean(dim=1) # Average over heads + return output, attn_weights_mean + else: + return (output, None) \ No newline at end of file diff --git a/PyTorchSimFrontend/mlir/mlir_gemm_template.py b/PyTorchSimFrontend/mlir/mlir_gemm_template.py index bbc63b45..0158caa6 100644 --- a/PyTorchSimFrontend/mlir/mlir_gemm_template.py +++ b/PyTorchSimFrontend/mlir/mlir_gemm_template.py @@ -154,7 +154,7 @@ def render(self, W_tile_desc.set_tile_size_stride(W_tile_size, W_tile_stride) W_tile_desc.set_name("W_buffer") W_tile_desc.offset = W.get_layout().offset - W_stride = W.get_layout().stride + W_stride = W.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0] W_idx = [sympy.Symbol("index2") * W_stride[0], sympy.Symbol("index1") * W_stride[1]] vlane_split_axis = vlane_split_axis if nr_rdim==0 else 0 @@ -163,7 +163,7 @@ def render(self, Y_tile_desc = mlir_common.MLIRMultiDimTile(Y_tile_size, kernel.vector_lane, vlane_split_axis, vlane_stride) Y_tile_desc.set_tile_size_stride(Y_tile_size, Y_tile_stride) Y_tile_desc.set_name("Y_buffer") - Y_stride = Y.get_layout().stride + Y_stride = Y.get_layout().stride if N>1 else [Y.get_layout().stride[0], 0] if nr_rdim == 0: Y_idx = [sympy.Symbol("index0") * Y_stride[0], sympy.Symbol("index1") * Y_stride[1]] else: diff --git a/PyTorchSimFrontend/mlir/mlir_ops.py b/PyTorchSimFrontend/mlir/mlir_ops.py index 21995512..9edd2e44 100644 --- a/PyTorchSimFrontend/mlir/mlir_ops.py +++ b/PyTorchSimFrontend/mlir/mlir_ops.py @@ -1,10 +1,13 @@ import math import torch +import warnings from torch._inductor.codegen import common from torch._inductor.virtualized import V, _ops as ops from . import mlir_common +warnings.filterwarnings('ignore', message='undefined OpHandler\\..*, please add missing op schema') + def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, reduced_shape): if reduction_type == "sum": return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" @@ -15,12 +18,41 @@ def reduction_combine_vec(reduction_type, vector_value, init_value, axis, shape, if reduction_type == "min": return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" if reduction_type == "any": - return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" + return f"vector.multi_reduction , %{vector_value}, %{init_value} [{axis}] : {shape} to {reduced_shape}" raise AssertionError(reduction_type) +def format_mlir_op(op_str, shape, **kwargs): + """ + Format MLIR operation string with optional attributes and comment. + + Args: + op_str: Base operation string (e.g., "arith.addi %0, %1") + shape: Type shape string (e.g., "vector<4xi64>" or "i64") + **kwargs: May contain 'attributes' (dict or str) and 'comment' (str) + + Returns: + Formatted MLIR operation string + """ + result = op_str + attributes = kwargs.get('attributes', None) + comment = kwargs.get('comment', None) + + if attributes: + if isinstance(attributes, dict): + # Format: { key1=value1, key2=value2 } + attrs_str = ", ".join(f"{k}={v}" for k, v in attributes.items()) + result += f" {{ {attrs_str} }}" + elif isinstance(attributes, str): + # Direct string format + result += f" {{ {attributes} }}" + result += f" : {shape}" + if comment: + result += f" // {comment}" + return result + class ExtensionOverrides(common.OpOverrides): @staticmethod - def constant(value, src_type, *args, var_info=None, **kwargs): + def constant(value, src_type, *args, **kwargs): if isinstance(src_type, torch.dtype): src_type = mlir_common.DTYPE_TO_MLIR[src_type] @@ -33,12 +65,12 @@ def constant(value, src_type, *args, var_info=None, **kwargs): elif src_type[0] == "f": value = format(float(value), ".20f") elif src_type[0] == "i": - value = int(float(value)) - return f'arith.constant {value} : {src_type}', [1, src_type] + value = int(float(value)) + return format_mlir_op(f'arith.constant {value}', src_type, **kwargs), [1, src_type] @staticmethod - def broadcast(operand, target_size, *args, var_info=None, **kwargs): - src_size, dtype = var_info[operand] + def broadcast(operand, target_size, *args, **kwargs): + src_size, dtype = V.kernel.var_info[operand] src_shape = f"vector<{src_size}x{dtype}>" if src_size > 1 else dtype dst_shape = f"vector<{target_size}x{dtype}>" @@ -51,27 +83,30 @@ def broadcast(operand, target_size, *args, var_info=None, **kwargs): outer_dim = target_size // src_size unflat_shape = f"vector<{outer_dim}x{src_size}x{dtype}>" # Flatten back to 1D - op_str = f"vector.shape_cast %{unflat_operand} : {unflat_shape} to {dst_shape}" + op_str = f"vector.shape_cast %{unflat_operand}" + shape = f"{unflat_shape} to {dst_shape}" else: raise NotImplementedError( f"Vector broadcast size mismatch: src={src_size} cannot broadcast to target={target_size}" ) elif src_size == 1: - op_str = f"vector.broadcast %{operand} : {src_shape} to {dst_shape}" + op_str = f"vector.broadcast %{operand}" + shape = f"{src_shape} to {dst_shape}" else: raise ValueError(f"Invalid source size: {src_size}") - return op_str, [target_size, dtype] + return format_mlir_op(op_str, shape, **kwargs), [target_size, dtype] @staticmethod - def broadcast_unflat(operand, target_size, *args, var_info=None, **kwargs): - src_size, dtype = var_info[operand] + def broadcast_unflat(operand, target_size, *args, **kwargs): + src_size, dtype = V.kernel.var_info[operand] outer_dim = target_size // src_size src_shape = f"vector<{src_size}x{dtype}>" dst_shape = f"vector<{outer_dim}x{src_size}x{dtype}>" - op_str = f"vector.broadcast %{operand} : {src_shape} to {dst_shape}" - return op_str, [target_size, dtype] + op_str = f"vector.broadcast %{operand}" + shape = f"{src_shape} to {dst_shape}" + return format_mlir_op(op_str, shape, **kwargs), [target_size, dtype] def load_seed(self, *args, **kwargs): raise NotImplementedError @@ -87,33 +122,36 @@ def randint64(self, *args, **kwargs): # Special operaitons @staticmethod - def masked(mask, body, other, *args, var_info=None, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): + def masked(mask, body, other, *args, tile_size=16, dtype="f32", ninf_declared=False, **kwargs): result = body() val = ops.constant(other, dtype, *args, **kwargs) result = ops.where(mask, result, val) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def where(condition, operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - cond_type = var_info[condition] - operand_type = var_info[operand1] + def where(condition, operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + cond_type = V.kernel.var_info[condition] + operand_type = V.kernel.var_info[operand1] condition = ops.to_bool(condition) if cond_type[0] < tile_size: condition = ops.broadcast(condition, tile_size) elif cond_type[0] > tile_size: operand1 = ops.broadcast(operand1, cond_type[0]) operand2 = ops.broadcast(operand2, cond_type[0]) - tile_size, ret_type = var_info[operand1] + tile_size, ret_type = V.kernel.var_info[operand1] shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type cond_shape = f"vector<{tile_size}xi1>" if tile_size > 1 else "" - return f"arith.select %{condition}, %{operand1}, %{operand2} : {cond_shape}, {shape}", [tile_size, ret_type] + + op_str = f"arith.select %{condition}, %{operand1}, %{operand2}" + shape = f"{cond_shape}, {shape}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): + def to_dtype(operand, dst_mlir_dtype, *args, **kwargs): # Extract source information - src_mlir_dtype = var_info[operand][1] - tile_size = var_info[operand][0] + src_mlir_dtype = V.kernel.var_info[operand][1] + tile_size = V.kernel.var_info[operand][0] # Normalize destination type (Torch dtype -> MLIR string) if isinstance(dst_mlir_dtype, torch.dtype): @@ -154,7 +192,7 @@ def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): op_str = f"arith.extsi %{operand} : {src_shape} to {shape}" elif dst_bits < src_bits: # Use arith.trunci for integer truncation - op_str = f"arith.trunci %{operand} : {src_shape} to {shape}" + op_str = f"arith.trunci %{operand} : {src_shape} to {shape}" else: return operand, [tile_size, dst_mlir_dtype] # Case D: Float -> Float (Extension / Truncation) @@ -163,7 +201,7 @@ def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): op_str = f"arith.extf %{operand} : {src_shape} to {shape}" elif dst_bits < src_bits: # Corrected 'trunf' to 'truncf' - op_str = f"arith.truncf %{operand} : {src_shape} to {shape}" + op_str = f"arith.truncf %{operand} : {src_shape} to {shape}" else: return operand, [tile_size, dst_mlir_dtype] else: @@ -172,13 +210,13 @@ def to_dtype(operand, dst_mlir_dtype, *args, var_info=None, **kwargs): return op_str, [tile_size, dst_mlir_dtype] @staticmethod - def identity(operand, *args, var_info=None, **kwargs): - operand_info = var_info[operand] + def identity(operand, *args, **kwargs): + operand_info = V.kernel.var_info[operand] return operand, operand_info @staticmethod - def to_dtype_bitcast(operand, dtype, *args, var_info=None, **kwargs): - tile_size, current_src_type = var_info[operand] + def to_dtype_bitcast(operand, dtype, *args, **kwargs): + tile_size, current_src_type = V.kernel.var_info[operand] if isinstance(dtype, torch.dtype): dst_mlir_type = mlir_common.DTYPE_TO_MLIR[dtype] @@ -197,15 +235,18 @@ def to_dtype_bitcast(operand, dtype, *args, var_info=None, **kwargs): src_shape = f"vector<{tile_size}x{current_src_type}>" if tile_size > 1 else current_src_type dst_shape = f"vector<{tile_size}x{dst_mlir_type}>" if tile_size > 1 else dst_mlir_type - return f"arith.bitcast %{operand} : {src_shape} to {dst_shape}", [tile_size, dst_mlir_type] + op_str = f"arith.bitcast %{operand}" + shape = f"{src_shape} to {dst_shape}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dst_mlir_type] # Binary element wise operations @staticmethod - def binary_elementwise_common(operand1, operand2, var_info): + def binary_elementwise_common(operand1, operand2): + V.kernel.var_info = V.kernel.var_info operand1.bounds = operand1.bounds.unknown() operand2.bounds = operand2.bounds.unknown() - op_type1 = var_info[operand1] - op_type2 = var_info[operand2] + op_type1 = V.kernel.var_info[operand1] + op_type2 = V.kernel.var_info[operand2] # Tile size check if op_type1[0] != op_type2[0]: # Try to broad cast @@ -213,33 +254,47 @@ def binary_elementwise_common(operand1, operand2, var_info): rhs_tile_size, rhs_dtype = op_type2 if lhs_tile_size > rhs_tile_size: operand2 = ops.broadcast(operand2, lhs_tile_size) - op_type2 = var_info[operand2] + op_type2 = V.kernel.var_info[operand2] elif lhs_tile_size < rhs_tile_size: operand1 = ops.broadcast(operand1, rhs_tile_size) - op_type1 = var_info[operand1] + op_type1 = V.kernel.var_info[operand1] # Data type check if op_type1[1] != op_type2[1]: if op_type1[1] == "index" or op_type1 == "index": if op_type1[1] == "index": - operand1 = ops.index_cast(operand1, op_type2[1]) - op_type1 = var_info[operand1] + # index -> target type: 2-step casting if target is float + if op_type2[1][0] == "f": + operand1 = ops.index_cast(operand1, "i64") + operand1 = ops.to_dtype(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] + else: + # index -> integer: direct casting + operand1 = ops.index_cast(operand1, op_type2[1]) + op_type1 = V.kernel.var_info[operand1] if op_type2[1] == "index": - operand2 = ops.index_cast(operand2, op_type1[1]) - op_type2 = var_info[operand2] + # index -> target type: 2-step casting if target is float + if op_type1[1][0] == "f": + operand2 = ops.index_cast(operand2, "i64") + operand2 = ops.to_dtype(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] + else: + # index -> integer: direct casting + operand2 = ops.index_cast(operand2, op_type1[1]) + op_type2 = V.kernel.var_info[operand2] elif op_type1[1][0] == "i" and op_type2[1][0] == "f": operand1 = ops.to_dtype(operand1, op_type2[1]) - op_type1 = var_info[operand1] + op_type1 = V.kernel.var_info[operand1] elif op_type1[1][0] == "f" and op_type2[1][0] == "i": operand2 = ops.to_dtype(operand2, op_type1[1]) - op_type2 = var_info[operand2] + op_type2 = V.kernel.var_info[operand2] elif op_type1[1][0] == op_type2[1][0]: if mlir_common.MLIR_TO_BIT[op_type1[1]] > mlir_common.MLIR_TO_BIT[op_type2[1]]: operand2 = ops.ext(operand2, op_type1[1]) - op_type2 = var_info[operand2] + op_type2 = V.kernel.var_info[operand2] elif mlir_common.MLIR_TO_BIT[op_type1[1]] < mlir_common.MLIR_TO_BIT[op_type2[1]]: operand1 = ops.ext(operand1, op_type2[1]) - op_type1 = var_info[operand1] + op_type1 = V.kernel.var_info[operand1] else: raise NotImplementedError("Unsupported type converting") @@ -249,45 +304,45 @@ def binary_elementwise_common(operand1, operand2, var_info): return tile_size, ret_type, operand1, operand2 @staticmethod - def abs(operand, *args, var_info=None, **kwargs): + def abs(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def exp(operand, *args, var_info=None, **kwargs): + def exp(operand, *args, **kwargs): # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.exp(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.exp %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.exp %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def exp2(operand, *args, var_info=None, **kwargs): + def exp2(operand, *args, **kwargs): # Hands-on part: implement exp2 using math.exp2 - # var_info = {operand: [tile_size, dtype]} - # Ex) var_info[operand] = [8, "f32"] + # V.kernel.var_info = {operand: [tile_size, dtype]} + # Ex) V.kernel.var_info[operand] = [8, "f32"] ln2 = math.log(2) coeff = ops.constant(ln2, "f32") operand = ops.mul(operand, coeff) - return ops.exp(operand), var_info[operand] + return ops.exp(operand), V.kernel.var_info[operand] @staticmethod - def expm1(operand, *args, var_info=None, **kwargs): + def expm1(operand, *args, **kwargs): coeff = ops.constant(1.0, "f32") operand = ops.exp(operand) operand = ops.sub(operand, coeff) - return operand, var_info[operand] + return operand, V.kernel.var_info[operand] @staticmethod - def sqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def sqrt(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -297,46 +352,48 @@ def sqrt(operand, *args, var_info=None, **kwargs): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sqrt %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.sqrt %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def relu(operand, *args, var_info=None, **kwargs): - src_mlir_dtype = var_info[operand][1] - tile_size = var_info[operand][0] + def relu(operand, *args, **kwargs): + src_mlir_dtype = V.kernel.var_info[operand][1] + tile_size = V.kernel.var_info[operand][0] return ops.maximum(operand, ops.constant(0, src_mlir_dtype)), [tile_size, src_mlir_dtype] @staticmethod - def minimum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def minimum(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": opcode = f'arith.minimumf' else: opcode = f'arith.minsi' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def maximum(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def maximum(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": opcode = f'arith.maximumf' else: opcode = f'arith.maxsi' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def cos(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def cos(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.cos(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -344,20 +401,20 @@ def cos(operand, *args, var_info=None, **kwargs): if dtype.startswith("f"): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.cos %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.cos %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def sin(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def sin(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.sin(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -365,54 +422,54 @@ def sin(operand, *args, var_info=None, **kwargs): if dtype.startswith("f"): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.sin %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.sin %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def tan(operand, *args, var_info=None, **kwargs): + def tan(operand, *args, **kwargs): sin_res = ops.sin(operand) cos_res = ops.cos(operand) operand = ops.truediv(sin_res, cos_res) - return operand, var_info[operand] + return operand, V.kernel.var_info[operand] @staticmethod - def lgamma(operand, *args, var_info=None, **kwargs): + def lgamma(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def erf(operand, *args, var_info=None, **kwargs): + def erf(operand, *args, **kwargs): # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.erf(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.erf %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.erf %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def cosh(operand, *args, var_info=None, **kwargs): + def cosh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def sinh(operand, *args, var_info=None, **kwargs): + def sinh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def tanh(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def tanh(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] # Check scalar - op_type = var_info[operand] + op_type = V.kernel.var_info[operand] if op_type[0] == 1: operand = ops.broadcast(operand, 4) val = ops.tanh(operand) result = ops.extractelement(val, 0) - return result, var_info[result] - op_type = var_info[operand] + return result, V.kernel.var_info[result] + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -420,83 +477,82 @@ def tanh(operand, *args, var_info=None, **kwargs): if dtype.startswith("f"): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.tanh %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.tanh %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def acos(operand, *args, var_info=None, **kwargs): + def acos(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def acosh(operand, *args, var_info=None, **kwargs): + def acosh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def asin(operand, *args, var_info=None, **kwargs): + def asin(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def asinh(operand, *args, var_info=None, **kwargs): + def asinh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def atan2(operand1, operand2, *args, var_info=None, **kwargs): + def atan2(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def atan(operand, *args, var_info=None, **kwargs): + def atan(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def atanh(operand, *args, var_info=None, **kwargs): + def atanh(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def copysign(operand1, operand2, *args, var_info=None, **kwargs): + def copysign(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def erfc(operand, *args, var_info=None, **kwargs): + def erfc(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def erfinv(operand, *args, var_info=None, **kwargs): + def erfinv(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def frexp(operand, *args, var_info=None, **kwargs): + def frexp(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def hypot(operand1, operand2, *args, var_info=None, **kwargs): + def hypot(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def log10(operand, *args, var_info=None, **kwargs): + def log10(operand, *args, **kwargs): val_ln = ops.log(operand) - - tile_size, dtype = var_info[val_ln] + + tile_size, dtype = V.kernel.var_info[val_ln] inv_ln10 = 1/math.log(10) const_op = ops.constant(inv_ln10, dtype) - + # Multiply: ln(x) * (1/ln(10)) result = ops.mul(val_ln, const_op) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def log2(operand, *args, var_info=None, **kwargs): + def log2(operand, *args, **kwargs): val_ln = ops.log(operand) - - tile_size, dtype = var_info[val_ln] + tile_size, dtype = V.kernel.var_info[val_ln] inv_ln10 = 1/math.log(2) const_op = ops.constant(inv_ln10, dtype) - + # Multiply: ln(x) * (1/ln(10)) result = ops.mul(val_ln, const_op) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def log(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def log(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -505,112 +561,103 @@ def log(operand, *args, var_info=None, **kwargs): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.log %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.log %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def log1p(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def log1p(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] const_one = ops.constant(1, dtype) - # 3. 덧셈 연산: (x + 1) - # ops.add가 (result_ssa, result_info)를 반환한다고 가정 val_add = ops.add(operand, const_one) result = ops.log(val_add) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def nextafter(operand1, operand2, *args, var_info=None, **kwargs): + def nextafter(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def logical_and(operand1, operand2, *args, var_info=None, **kwargs): - if var_info[operand1][1] != "i1": + def logical_and(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": operand1 = ops.to_bool(operand1) - - if var_info[operand2][1] != "i1": + if V.kernel.var_info[operand2][1] != "i1": operand2 = ops.to_bool(operand2) result = ops.and_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def logical_or(operand1, operand2, *args, var_info=None, **kwargs): - if var_info[operand1][1] != "i1": + def logical_or(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": operand1 = ops.to_bool(operand1) - - if var_info[operand2][1] != "i1": + if V.kernel.var_info[operand2][1] != "i1": operand2 = ops.to_bool(operand2) result = ops.or_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def logical_xor(operand1, operand2, *args, var_info=None, **kwargs): - if var_info[operand1][1] != "i1": + def logical_xor(operand1, operand2, *args, **kwargs): + if V.kernel.var_info[operand1][1] != "i1": operand1 = ops.to_bool(operand1) - - if var_info[operand2][1] != "i1": + if V.kernel.var_info[operand2][1] != "i1": operand2 = ops.to_bool(operand2) result = ops.xor(operand1, operand2) - return result, var_info[result] - + return result, V.kernel.var_info[result] + @staticmethod - def logical_not(operand, *args, var_info=None, **kwargs): - op_info = var_info[operand] + def logical_not(operand, *args, **kwargs): + op_info = V.kernel.var_info[operand] tile_size = op_info[0] dtype = op_info[1] - zero_const = ops.constant(0, dtype) result = ops.eq(operand, zero_const) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_and(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_and(operand1, operand2, *args, **kwargs): # Float check - if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): raise ValueError("Bitwise AND not supported for floats") - result = ops.and_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_not(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def bitwise_not(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] # Float check - if var_info[operand][1].startswith("f"): + if V.kernel.var_info[operand][1].startswith("f"): raise ValueError("Bitwise NOT not supported for floats") - neg_one = ops.constant(-1, dtype) - result = ops.xor(operand, neg_one) - return result, var_info[result] + result = ops.xor(operand, neg_one) + return result, V.kernel.var_info[result] @staticmethod - def bitwise_or(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_or(operand1, operand2, *args, **kwargs): # Float check - if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): raise ValueError("Bitwise AND not supported for floats") - + result = ops.or_(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_xor(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_xor(operand1, operand2, *args, **kwargs): # Float check - if var_info[operand1][1].startswith("f") or var_info[operand2][1].startswith("f"): + if V.kernel.var_info[operand1][1].startswith("f") or V.kernel.var_info[operand2][1].startswith("f"): raise ValueError("Bitwise AND not supported for floats") - result = ops.xor(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def bitwise_left_shift(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_left_shift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def bitwise_right_shift(operand1, operand2, *args, var_info=None, **kwargs): + def bitwise_right_shift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def rsqrt(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def rsqrt(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] @@ -619,101 +666,104 @@ def rsqrt(operand, *args, var_info=None, **kwargs): operand = ops.to_dtype(operand, "f32") shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'math.rsqrt %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(f'math.rsqrt %{operand}', shape, **kwargs), [tile_size, dtype] @staticmethod - def sigmoid(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def sigmoid(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] one = ops.constant(1, dtype) return ops.truediv(one, ops.add(one, ops.exp(ops.neg(operand)))), [tile_size, dtype] @staticmethod - def fmod(operand1, operand2, *args, var_info=None, **kwargs): + def fmod(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def isinf(operand, *args, var_info=None, **kwargs): + def isinf(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def isnan(operand, *args, var_info=None, **kwargs): + def isnan(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def round(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def round(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): - return f"math.roundeven %{operand} : {shape}", [tile_size, dtype] + op_str = f"math.roundeven %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] else: return operand, [tile_size, dtype] @staticmethod - def floor(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def floor(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): - return f"math.floor %{operand} : {shape}", [tile_size, dtype] + op_str = f"math.floor %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] else: return operand, [tile_size, dtype] @staticmethod - def sign(operand, *args, var_info=None, **kwargs): + def sign(operand, *args, **kwargs): raise NotImplementedError @staticmethod - def trunc(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def trunc(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): - return f"math.trunc %{operand} : {shape}", [tile_size, dtype] + op_str = f"math.trunc %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] else: return operand, [tile_size, dtype] @staticmethod - def ceil(operand, *args, var_info=None, **kwargs): - tile_size, dtype = var_info[operand] + def ceil(operand, *args, **kwargs): + tile_size, dtype = V.kernel.var_info[operand] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype if dtype.startswith("f"): - return f"math.ceil %{operand} : {shape}", [tile_size, dtype] + op_str = f"math.ceil %{operand}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] else: return operand, [tile_size, dtype] # Logical operations @staticmethod - def neg(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def neg(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] # Type check & auto cast if dtype.startswith("f"): operand = ops.to_dtype(operand, "f32") - + op_str = f"arith.negf %{operand}" shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f'arith.negf %{operand} : {shape}', [tile_size, dtype] + return format_mlir_op(op_str, shape, **kwargs), [tile_size, dtype] @staticmethod - def reciprocal(operand, *args, var_info=None, **kwargs): - op_type = var_info[operand] - tile_size = op_type[0] - dtype = op_type[1] - - # Type check & auto cast - if dtype.startswith("f"): - operand = ops.to_dtype(operand, "f32") + def reciprocal(operand, *args, **kwargs): + op_type = V.kernel.var_info[operand] + tile_size, dtype = op_type[0], op_type[1] + if dtype.startswith("i"): + openand = ops.to_dtype(operand, "f32") + op_type = V.kernel.var_info[operand] + tile_size, dtype = op_type[0], op_type[1] return ops.truediv(ops.constant(1.0, dtype), operand), [tile_size, dtype] @staticmethod - def eq(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def eq(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "oeq" @@ -723,12 +773,13 @@ def eq(operand1, operand2, *args, var_info=None, **kwargs): else: raise ValueError(f"Unsupported data type for 'eq' operation: {ret_type}") + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] @staticmethod - def ne(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def ne(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "one" @@ -738,12 +789,13 @@ def ne(operand1, operand2, *args, var_info=None, **kwargs): else: raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] @staticmethod - def lt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def lt(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "olt" @@ -753,12 +805,13 @@ def lt(operand1, operand2, *args, var_info=None, **kwargs): else: raise ValueError(f"Unsupported data type for 'lt' operation: {ret_type}") + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] @staticmethod - def gt(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def gt(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "ogt" @@ -768,12 +821,13 @@ def gt(operand1, operand2, *args, var_info=None, **kwargs): else: raise ValueError(f"Unsupported data type for 'gt' operation: {ret_type}") + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] @staticmethod - def le(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def le(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "ole" @@ -783,12 +837,13 @@ def le(operand1, operand2, *args, var_info=None, **kwargs): else: raise ValueError(f"Unsupported data type for 'le' operation: {ret_type}") + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] @staticmethod - def ge(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def ge(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if ret_type[0] == "f": op_type = "arith.cmpf" attribute = "oge" @@ -798,33 +853,37 @@ def ge(operand1, operand2, *args, var_info=None, **kwargs): else: raise ValueError(f"Unsupported data type for 'ne' operation: {ret_type}") + op_str = f'{op_type} {attribute}, %{operand1}, %{operand2}' shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'{op_type} {attribute}, %{operand1}, %{operand2} : {shape}', [tile_size, "i1"] + return format_mlir_op(op_str, shape, **kwargs), [tile_size, "i1"] @staticmethod - def add(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def add(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.add{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def sub(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def sub(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.sub{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def mul(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def mul(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type opcode = f'arith.mul{ret_type[0]}' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def pow(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def pow(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) # Type check & auto cast if ret_type.startswith("f"): operand1 = ops.to_dtype(operand1, "f32") @@ -834,51 +893,56 @@ def pow(operand1, operand2, *args, var_info=None, **kwargs): operand2 = ops.to_dtype(operand2, "f32") shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f"math.pow{ret_type[0]} %{operand1}, %{operand2} : {shape}", [tile_size, ret_type] + op_str = f"math.pow{ret_type[0]} %{operand1}, %{operand2}" + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def and_(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - + def and_(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.andi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'arith.andi %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def or_(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - + def or_(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.ori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'arith.ori %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def xor(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) - + def xor(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type - return f'arith.xori %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'arith.xori %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def lshift(operand1, operand2, *args, var_info=None, **kwargs): + def lshift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def rshift(operand1, operand2, *args, var_info=None, **kwargs): + def rshift(operand1, operand2, *args, **kwargs): raise NotImplementedError @staticmethod - def truncdiv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def truncdiv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type.startswith("f"): raise ValueError("truncdiv is strictly for integers. Use truediv for floats.") - + # arith.divsi: Signed Integer Division (Result is truncated) - return f'arith.divsi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'arith.divsi %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def floordiv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def floordiv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type.startswith("f"): @@ -886,25 +950,27 @@ def floordiv(operand1, operand2, *args, var_info=None, **kwargs): raise ValueError("floordiv implementation expects integers based on definition.") # arith.floordivsi: Floor Division for Signed Integers - return f'arith.floordivsi %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'arith.floordivsi %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def truediv(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def truediv(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if not ret_type.startswith("f"): raise ValueError(f"truediv expects float inputs, but got {ret_type}. Use int_truediv for integers.") - return f'arith.divf %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'arith.divf %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def int_truediv(operand1, operand2, *args, var_info=None, **kwargs): + def int_truediv(operand1, operand2, *args, **kwargs): """ True division for Integers (Int -> Float). Promotes integers to floats, then performs floating-point division. """ - tile_size, src_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + tile_size, src_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) if not src_type.startswith("f"): target_float_type = "f32" operand1 = ops.to_dtype(operand1, target_float_type) @@ -912,21 +978,22 @@ def int_truediv(operand1, operand2, *args, var_info=None, **kwargs): src_type = target_float_type result = ops.truediv(operand1, operand2) - return result, var_info[result] + return result, V.kernel.var_info[result] @staticmethod - def mod(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def mod(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type[0] == "f": raise NotImplementedError("Not support remainder operation for floating point") else: opcode = f'arith.remsi' - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def remainder(operand1, operand2, *args, var_info=None, **kwargs): - tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2, var_info) + def remainder(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type if ret_type.startswith("f"): @@ -934,66 +1001,134 @@ def remainder(operand1, operand2, *args, var_info=None, **kwargs): else: opcode = 'arith.remsi' # Signed Integer Remainder (LHS sign) - return f'{opcode} %{operand1}, %{operand2} : {shape}', [tile_size, ret_type] + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod - def square(operand, *args, var_info=None, **kwargs): + def square(operand, *args, **kwargs): result = ops.mul(operand, operand) - return result, var_info[result] + return result, V.kernel.var_info[result] + + @staticmethod + def fma(operand1, operand2, operand3, *args, **kwargs): + result = ops.mul(operand1, operand2) + result = ops.add(result, operand3) + return result, V.kernel.var_info[result] # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - # PyTorchSim specific operations + # PyTorchSim specific operations @staticmethod - def alloc(size, src_type, *args, var_info=None, **kwargs): + def alloc(size, src_type, *args, **kwargs): return f"memref.alloc() : memref<{size}x{src_type}>", [size, src_type] @staticmethod - def extractelement(operand, idx, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def extractelement(operand, idx, *args, **kwargs): + op_type = V.kernel.var_info[operand] tile_size = op_type[0] dtype = op_type[1] shape = f"vector<{tile_size}x{dtype}>" if tile_size > 1 else dtype - return f"vector.extract %{operand}[{idx}]: {dtype} from {shape}", [1, dtype] + op_str = f"vector.extract %{operand}[{idx}]" + shape = f"{dtype} from {shape}" + return format_mlir_op(op_str, shape, **kwargs), [1, dtype] @staticmethod - def ext(operand, dtype, *args, var_info=None, **kwargs): - op_type = var_info[operand] + def ext(operand, dtype, *args, **kwargs): + op_type = V.kernel.var_info[operand] shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else f"{op_type[1]}" target_type = f"vector<{op_type[0]}x{dtype}>" if op_type[0] > 1 else f"{dtype}" if op_type[0] == "f": opcode = f'arith.extf' else: opcode = f'arith.extui' - return f'{opcode} %{operand} : {shape} to {target_type}', [op_type[0], dtype] + op_str = f'{opcode} %{operand}' + shape = f"{shape} to {target_type}" + return format_mlir_op(op_str, shape, **kwargs), [op_type[0], dtype] @staticmethod - def to_bool(operand, *args, var_info=None, **kwargs): - tile_size, ret_type = var_info[operand] + def to_bool(operand, *args, **kwargs): + tile_size, ret_type = V.kernel.var_info[operand] if ret_type == "i1": return operand, [tile_size, ret_type] - const_one = ops.constant(0, ret_type) + const_zero = ops.constant(0, ret_type) if tile_size > 1: - const_one = ops.broadcast(const_one, tile_size) - ret = ops.ne(operand, const_one) + const_zero = ops.broadcast(const_zero, tile_size) + ret = ops.ne(operand, const_zero) return ret, [tile_size, "i1"] @staticmethod def step(size, dtype, *args, **kwargs): index_shape = f"vector<{size}x{dtype}>" - return f"vector.step : {index_shape}", [size, dtype] + op_str = f"vector.step" + return format_mlir_op(op_str, index_shape, **kwargs), [size, dtype] @staticmethod - def index_cast(operand, target_type, *args, var_info=None, **kwrags): - op_type = var_info[operand] + def index_cast(operand, target_type, *args, **kwargs): + op_type = V.kernel.var_info[operand] src_shape = f"vector<{op_type[0]}x{op_type[1]}>" if op_type[0] > 1 else op_type[1] des_shape = f"vector<{op_type[0]}x{target_type}>" if op_type[0] > 1 else target_type - return f"arith.index_cast %{operand} : {src_shape} to {des_shape}", [op_type[0], target_type] + op_str = f"arith.index_cast %{operand}" + shape = f"{src_shape} to {des_shape}" + return format_mlir_op(op_str, shape, **kwargs), [op_type[0], target_type] @staticmethod - def shape_cast(operand, src_shape, dst_shape, *args, var_info=None, **kwargs): - operand_type = var_info[operand] - return f"vector.shape_cast %{operand} : {src_shape} to {dst_shape}", operand_type + def shape_cast(operand, src_shape, dst_shape, *args, **kwargs): + operand_type = V.kernel.var_info[operand] + op_str = f"vector.shape_cast %{operand}" + shape = f"{src_shape} to {dst_shape}" + return format_mlir_op(op_str, shape, **kwargs), operand_type + + @staticmethod + def extract_strided_slice(operand, target_size, offsets=None, sizes=None, strides=None, *args, **kwargs): + op_type = V.kernel.var_info[operand] + src_size = op_type[0] + dtype = op_type[1] + + if offsets is None: + offsets = [0] + if sizes is None: + sizes = [target_size] + if strides is None: + strides = [1] + + src_shape = f"vector<{src_size}x{dtype}>" + dst_shape = f"vector<{target_size}x{dtype}>" + + offsets_str = ", ".join(str(o) for o in offsets) + sizes_str = ", ".join(str(s) for s in sizes) + strides_str = ", ".join(str(s) for s in strides) + + # Build attributes dict for offsets, sizes, strides + built_attributes = { + "offsets": f"[{offsets_str}]", + "sizes": f"[{sizes_str}]", + "strides": f"[{strides_str}]" + } + + # Merge with any existing attributes from kwargs + existing_attributes = kwargs.get('attributes', {}) + if isinstance(existing_attributes, dict): + merged_attributes = {**built_attributes, **existing_attributes} + elif isinstance(existing_attributes, str): + built_attrs_str = ", ".join(f"{k}={v}" for k, v in built_attributes.items()) + merged_attributes = f"{built_attrs_str}, {existing_attributes}" + else: + merged_attributes = built_attributes + + op_str = f"vector.extract_strided_slice %{operand}" + shape = f"{src_shape} to {dst_shape}" + + # Pass merged attributes to format_mlir_op + updated_kwargs = {**kwargs, 'attributes': merged_attributes} + return format_mlir_op(op_str, shape, **updated_kwargs), [target_size, dtype] + + @staticmethod + def vlane_offset(operand1, operand2, *args, **kwargs): + tile_size, ret_type, operand1, operand2 = ExtensionOverrides.binary_elementwise_common(operand1, operand2) + shape = f"vector<{tile_size}x{ret_type}>" if tile_size > 1 else ret_type + opcode = f'arith.add{ret_type[0]}' + op_str = f'{opcode} %{operand1}, %{operand2}' + return format_mlir_op(op_str, shape, **kwargs), [tile_size, ret_type] @staticmethod def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_name, *args, **kwargs): @@ -1008,31 +1143,73 @@ def multi_reduction(acc, init, vec_size, red_size, red_shape, red_type, type_nam return line, [red_size, type_name] @staticmethod - def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, var_info=None, **kwargs): + def _load(compute_vec_size, mlir_dtype, buffer, indices, buffer_shape, *args, **kwargs): if compute_vec_size == 1: vshape = f"{mlir_dtype}" operation = "affine.load" - line = f"{operation} %{buffer}[{indices}] : {buffer_shape}" + line = f"{operation} %{buffer}[{indices}]" + shape = buffer_shape else: vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" operation = "affine.vector_load" - line = f"{operation} %{buffer}[{indices}] : {buffer_shape}, {vshape}" - return line, [compute_vec_size, mlir_dtype] + line = f"{operation} %{buffer}[{indices}]" + shape = f"{buffer_shape}, {vshape}" + return format_mlir_op(line, shape, **kwargs), [compute_vec_size, mlir_dtype] @staticmethod - def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, var_info=None, **kwargs): - compute_vec_size, mlir_dtype = var_info[operand][0], var_info[operand][1] + def _store(operand, buffer, indices, buffer_shape, *args, buffer_name=None, **kwargs): + compute_vec_size, mlir_dtype = V.kernel.var_info[operand][0], V.kernel.var_info[operand][1] if compute_vec_size == 1: vshape = f"{mlir_dtype}" operation = "affine.store" - line = f"{operation} %{operand}, %{buffer}[{indices}] : {buffer_shape}" + line = f"{operation} %{operand}, %{buffer}[{indices}]" + shape = buffer_shape else: vshape = f"vector<{compute_vec_size}x{mlir_dtype}>" operation = "affine.vector_store" - line = f"{operation} %{operand}, %{buffer}[{indices}] : {buffer_shape}, {vshape}" + line = f"{operation} %{operand}, %{buffer}[{indices}]" + shape = f"{buffer_shape}, {vshape}" + line = format_mlir_op(line, shape, **kwargs) if buffer_name is not None: return common.DeferredLine(buffer_name, line), [None, None] else: - return line, [None, None] \ No newline at end of file + return line, [None, None] + + @staticmethod + def affine_apply(map_var, indices, indirect_dims=None, comment=None, *args, **kwargs): + # Format indices arguments + indices_str = ", ".join([f"%{i}" for i in indices]) + op_str = f"affine.apply #{map_var}({indices_str})" + + # Add indirect dimensions if provided + if indirect_dims: + indirect_str = ", ".join(indirect_dims) + op_str += f"[{indirect_str}] {{indirect_access}}" + if comment: + op_str += f" // {comment}" + return op_str, [1, "index"] + + @staticmethod + def affine_map(dim_names, expr_str, symbol_names=None, comment=None, *args, **kwargs): + # Handle dim_names as list or string + if isinstance(dim_names, list): + dims_str = ", ".join([str(dim) for dim in dim_names]) + else: + dims_str = dim_names + + # Build the map string + if symbol_names: + if isinstance(symbol_names, list): + symbols_str = ", ".join(symbol_names) + else: + symbols_str = symbol_names + map_str = f"affine_map<({dims_str})[{symbols_str}] -> ({expr_str})>" + else: + map_str = f"affine_map<({dims_str}) -> ({expr_str})>" + + if comment: + map_str += f" // {comment}" + + return map_str, [1, "map"] diff --git a/PyTorchSimFrontend/mlir/mlir_scheduling.py b/PyTorchSimFrontend/mlir/mlir_scheduling.py index 23be941c..faf5e69c 100644 --- a/PyTorchSimFrontend/mlir/mlir_scheduling.py +++ b/PyTorchSimFrontend/mlir/mlir_scheduling.py @@ -7,24 +7,27 @@ from PyTorchSimFrontend import extension_config from PyTorchSimFrontend.mlir.mlir_codegen_backend import MLIRKernel +from torch.utils._ordered_set import OrderedSet from torch._inductor import config from torch._inductor.scheduler import BaseScheduling, FusedSchedulerNode, SchedulerNode, BaseSchedulerNode from torch._inductor.utils import IndentedBuffer from torch._inductor.virtualized import V from torch._inductor.ir import LoopBody from torch._inductor import dependencies +from torch._inductor.codegen.common import BackendFeature from . import mlir_common from . import mlir_lowering # DO NOT REMOVE THIS LINE, it is used for lowering +from . import mlir_decomposition # DO NOT REMOVE THIS LINE, it is used for decomposition class MLIRScheduling(BaseScheduling): count = 0 target_kernel = MLIRKernel def __init__(self, scheduler): self.scheduler = scheduler - self.scheduler.can_fuse_origin = self.scheduler.can_fuse - self.scheduler.can_fuse = self.can_fuse_with_exceptions - #self.scheduler.enter_context = self.enter_context_fixed # FIXME. Monkey patch: For fixing the inductor bug + if scheduler is not None: + self.scheduler.can_fuse_origin = self.scheduler.can_fuse + self.scheduler.can_fuse = self.can_fuse_with_exceptions # FIXME. Monkey patch: For prolouge fusion self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() self._ready_to_flush = False self.outer_function = set() @@ -32,51 +35,28 @@ def __init__(self, scheduler): self.max_fusion_size = 5 def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedulerNode) -> bool: + if not extension_config.CONFIG_FUSION_PROLOGUE: + return self.scheduler.can_fuse_origin(node1, node2) + # Extract base template node base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] - if node1.get_device() != node2.get_device(): - return False - if not (isinstance(node1, (SchedulerNode, FusedSchedulerNode)) and isinstance(node2, (SchedulerNode, FusedSchedulerNode))): - return False - if len(base_template_node1) == 1 and len(base_template_node2) == 0 and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: + # Case 3: Prologue(Pointwise) + Tempalte + if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate - if (isinstance(base_template_node1[0].node.template, MLIRGemmTemplate) or isinstance(base_template_node1[0].node.template, MLIRBMMTemplate)) and node2.is_reduction(): - # For matmul/bmm+reduction case - size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) - target_symbol = symbols("r0") - try: - stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] - stride = int(sympify(stride).coeff(target_symbol)) - except: - return False - - # We can't fuse dim=-1 - layout_possible = stride != 1 - # Directed linked? - dependency_check = node2.get_nodes()[0] in [node.node for node in base_template_node1[0].users]# and len(node2.read_writes.reads)==1 - dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) - return size_match and layout_possible and dependency_check and dependency_size - # For prologue fusion case - if extension_config.CONFIG_FUSION_PROLOGUE and len(base_template_node1) == 0 and len(node1.get_nodes())==1 and len(base_template_node2) == 1: - from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate target_node = base_template_node2[0].node - if target_node.origin_node is not None and hasattr(target_node.origin_node.target, "_name") and target_node.origin_node.target._name == 'aten::convolution': - return False - if node1.is_reduction(): + # Currently only BMM, MM support prologue fusion + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): return False + if len(node1.read_writes.writes) != 1: return False if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME return False - # Currently only BMM, MM support prologue fusion - if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): - return False # We don't fuse this edge case... if base_template_node2[0].group[1][0][0] == 1: return False @@ -84,28 +64,39 @@ def can_fuse_with_exceptions(self, node1: BaseSchedulerNode, node2: BaseSchedule if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: node1 = self.revert_group(node1) return True - return self.scheduler.can_fuse_origin(node1, node2) + def _set_flush_status(self, status: bool): self._ready_to_flush = status + def reset_kernel_group(self): + self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + + def get_backend_features(self, device): + """Return a set of .codegen.common.BackendFeature()""" + return OrderedSet([BackendFeature.REDUCE_TO_SINGLE_ELEMENT]) + def can_fuse_vertical(self, node1, node2): return self.can_fuse_horizontal(node1, node2) + def can_fuse_multi_outputs_template(self, node1, node2): + return self.can_fuse_horizontal(node1, node2) + def can_fuse_horizontal(self, node1, node2): if not extension_config.CONFIG_FUSION: return False + if (len(node1.get_nodes())+ len(node2.get_nodes())) > self.max_fusion_size: return False + _, (vars1, reduce1) = node1.group _, (vars2, reduce2) = node2.group - - # Reduction is currently not supported - if node1.is_reduction() and node2.is_reduction() and not node1.is_template() and not node2.is_template() and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION: - return vars1 == vars2 and reduce1 == reduce2 and node1.inverse_users == node2.inverse_users - if node1.is_reduction() or node2.is_reduction(): - return False + # For input/dependency checks + reads1 = {dep.name for dep in node1.read_writes.reads} + reads2 = {dep.name for dep in node2.read_writes.reads} + writes1 = {dep.name for dep in node1.read_writes.writes} + writes2 = {dep.name for dep in node2.read_writes.writes} # Can't fuse two template node if node1.is_template() and node2.is_template(): @@ -114,17 +105,37 @@ def can_fuse_horizontal(self, node1, node2): if '_unsafe_index' in node1.get_nodes()[0].node.origins or "_unsafe_index" in node2.get_nodes()[0].node.origins: return False - # Check template node fusion - if node1.is_template() or node2.is_template(): + # Extract base template node + base_template_node1 = [node for node in node1.get_nodes() if node.is_template()] + base_template_node2 = [node for node in node2.get_nodes() if node.is_template()] + + # Case 0: Reduction fusion + if ( + node1.is_reduction() + and node2.is_reduction() + and not node1.is_template() + and not node2.is_template() + and extension_config.CONFIG_FUSION_REDUCTION_REDUCTION + ): + # 1) Same loop/iteration domain + same_iter = vars1 == vars2 and reduce1 == reduce2 + # 2) No data dependency between the two reductions + no_dependency = not ( + writes1 & (reads2 | writes2) or writes2 & (reads1 | writes1) + ) + return same_iter and no_dependency + + # Case 1: Template + Pointwise fusion + if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and not node2.is_reduction(): # Don't fuse maxpool template code from PyTorchSimFrontend.mlir.mlir_maxpool_template import MLIRMaxPoolTemplate from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate - template_node1 = next((n for n in node1.get_nodes() if n.is_template()), None) - template_node2 = next((n for n in node2.get_nodes() if n.is_template()), None) - if template_node1 and len(node1.get_nodes()) == 1 and isinstance(template_node1.node.template, MLIRMaxPoolTemplate) or \ - template_node2 and len(node2.get_nodes()) == 1 and isinstance(template_node2.node.template, MLIRMaxPoolTemplate): + template_node = base_template_node1[0] + epilogue_node = node2 + + if isinstance(template_node.node.template, MLIRMaxPoolTemplate): return False # Pointwise check @@ -133,26 +144,76 @@ def can_fuse_horizontal(self, node1, node2): if v1_total != v2_total: return False - # Pattern check - template_node, act_node = (template_node1, node2) if template_node1 else (template_node2, node1) - has_depedency = set(act_node.inverse_users) <= set(template_node.get_nodes()) + # Pattern check: check data dependency between act_node and template_node + template_sched_nodes = list(template_node.get_nodes()) + # Buffers produced by the template (its outputs) + template_writes = { + dep + for n in template_sched_nodes + for dep in n.read_writes.writes + } + # Buffers still required by the activation node (unmet) or read by it + epilogue_unmet = { dep for dep in epilogue_node.unmet_dependencies } + has_depedency = bool(template_writes) and epilogue_unmet.issubset(template_writes) if not has_depedency: return False # Revert act_node.group : simplify_and_reorder() modified _body, _size, group - if template_node.group != act_node.group: + if template_node.group != epilogue_node.group: # We don't fuse this case... if (isinstance(template_node.node.template, MLIRBMMTemplate) or isinstance(template_node.node.template, MLIRGemmTemplate)) and template_node.group[1][0][0] == 1: return False - if list(template_node.group[1][0]) != list(act_node.get_nodes()[0].node.data.get_size()): + if list(template_node.group[1][0]) != list(epilogue_node.get_nodes()[0].node.data.get_size()): return False - self.revert_group(act_node) + self.revert_group(epilogue_node) return True - # Check elementwise fusion - if vars1 == vars2 and reduce1 == reduce2: - return True + # Case 2: Tempalte + Reduction fusion + if len(base_template_node1) == 1 and len(node1.get_nodes())==1 and len(node2.get_nodes())==1 and len(base_template_node2) == 0 and node2.is_reduction() and extension_config.CONFIG_FUSION_REDUCTION_EPILOGUE: + from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + target_node = base_template_node1[0].node + if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + return False + + size_match = node1.get_nodes()[0].node.get_numel() == reduce(operator.mul, node2.get_nodes()[0].node.get_size(), 1) * reduce(operator.mul, node2.get_nodes()[0].node.get_reduction_size(), 1) + target_symbol = symbols("r0_0") + try: + stride = [i.strip()[:-1].split(",")[-1].strip() for i in str(node2.get_nodes()[0].node).split("\n") if "r0" in i][1] + stride = int(sympify(stride).coeff(target_symbol)) + except: + return False + + # We can't fuse dim=-1 & N == 1 + layout_possible = stride != 1 and (1 not in node1.node.get_size()) + # Directed linked? + dependency_check = writes1 & reads2 + dependency_size = all([i.get_numel() == node1.get_nodes()[0].node.get_numel() for i in node2.read_writes.reads]) + return size_match and layout_possible and dependency_check and dependency_size + + # Case 3: Prologue(Pointwise) + Tempalte + # if len(base_template_node1) == 0 and len(node1.get_nodes())==1 and not node1.is_reduction() and len(base_template_node2) == 1 and extension_config.CONFIG_FUSION_PROLOGUE: + # from PyTorchSimFrontend.mlir.mlir_gemm_template import MLIRGemmTemplate + # from PyTorchSimFrontend.mlir.mlir_bmm_template import MLIRBMMTemplate + + # target_node = base_template_node2[0].node + # # Currently only BMM, MM support prologue fusion + # if not isinstance(target_node.template, (MLIRBMMTemplate, MLIRGemmTemplate)): + # return False + + # if len(node1.read_writes.writes) != 1: + # return False + # if node1.node not in target_node.inputs or any(["view" in str(ori) for ori in node1.node.origins]): #FIXME + # return False + + # # We don't fuse this edge case... + # if base_template_node2[0].group[1][0][0] == 1: + # return False + + # if list(node1.read_writes.writes)[0].name in [dep.name for dep in node2.read_writes.reads]: + # node1 = self.revert_group(node1) + # return True return False def revert_group(self, act_nodes, args=None, var_ranges=None): @@ -165,6 +226,8 @@ def revert_group(self, act_nodes, args=None, var_ranges=None): act_node.node.get_store_function(), (args if act_node.node.get_reduction_type() else args[:1]), var_ranges, + args[0], + args[1] ) index_size = [] reduce_size = [] @@ -180,12 +243,13 @@ def revert_group(self, act_nodes, args=None, var_ranges=None): def group_fn(self, sizes): return tuple(tuple(map(V.graph.sizevars.simplify, s)) for s in sizes) - def codegen_nodes(self, nodes): + def codegen_node(self, _node): + nodes = _node.get_nodes() _, (group, reduction_group) = max( nodes, key=lambda x: int(x.is_reduction()) ).group - # Note: We assume that ther is at least one loop in the nodes + # Note: We assume that there is at least one loop in the nodes # But, inductor simplifies the group, there could be no loop # In that case, we add dummy loop(size=1) to the group if len(group) == 0: @@ -210,8 +274,8 @@ def codegen_nodes(self, nodes): kernel_name_candidate = f"extension_kernel_{MLIRScheduling.count}" MLIRScheduling.count += 1 - src_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) - kernel_name = self.define_kernel(src_code, kernel_name_candidate, ex_kernel.vector_lane, + src_code, meta_code = ex_kernel.codegen_nodes(nodes, kernel_name_candidate) + kernel_name = self.define_kernel(src_code, meta_code, kernel_name_candidate, ex_kernel.vector_lane, ex_kernel.spad_info, origins= {str(i) for i in nodes[0].node.origins}) ex_kernel.call_kernel(kernel_name) _, args, _, _ = ex_kernel.args.mlir_argdefs() @@ -230,57 +294,50 @@ def codegen_sync(self): pass def flush(self): - self.kernel_group.codegen_define_and_call(V.graph.wrapper_code) - self.kernel_group = mlir_common.MLIRWrapperKenrelGroup() + src_code = self.kernel_group.codegen_group() + if src_code: + kernel_name = self.define_kernel( + src_code, self.kernel_group.scheduled_nodes + ) + self.kernel_group.call_kernel(V.graph.wrapper_code, kernel_name) + self.reset_kernel_group() self._set_flush_status(False) def define_function(self, kernel): partial_code, function_name = kernel.def_function() if partial_code is not None and function_name not in self.outer_function: with V.set_kernel_handler(kernel): - code = partial_code.finalize() + code = partial_code.finalize_all() wrapper = V.graph.wrapper_code wrapper.header.writeline(code) self.outer_function.add(function_name) - def define_kernel(self, src_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): + def define_kernel(self, src_code, meta_code, kernel_name, vector_lane, spad_info, loop_size=None, origins={}): wrapper = V.graph.wrapper_code if src_code in wrapper.src_to_kernel: kernel_name = wrapper.src_to_kernel[src_code] else: wrapper.src_to_kernel[src_code] = kernel_name - codecache_def = IndentedBuffer() codecache_def.writeline(f"custom_async_compile.mlir('''{src_code}''', ") codecache_def.writeline(f"vectorlane_size={vector_lane},") codecache_def.writeline(f"loop_size={loop_size},") codecache_def.writeline(f"spad_info={spad_info},") codecache_def.writeline(f"origins={origins},") - codecache_def.writeline("arg_attributes=arg_attributes,") + codecache_def.writeline(f"arg_attributes={meta_code},") codecache_def.writeline(f"vlen={extension_config.vpu_vector_length_bits})") - wrapper.define_kernel(kernel_name, codecache_def.getvalue(), cuda=False) + wrapper.define_kernel(kernel_name, codecache_def.getvalue(), gpu=False) return kernel_name - def codegen_template(self, template_node, epilogue_nodes): - # Handle prologue pattern - prologue_nodes = [] - if not template_node.is_template(): - epilogue_nodes = [template_node] + epilogue_nodes - for i, node in enumerate(epilogue_nodes): - if node.is_template(): - template_node = node - prologue_nodes = epilogue_nodes[:i] - epilogue_nodes = epilogue_nodes[i+1:] - break - + def codegen_template(self, template_node, epilogue_nodes, prologue_nodes): # Generate template code template_buffer = template_node.node kernel, tile_candidates, render = template_buffer.make_kernel_render(template_buffer, prologue_nodes=prologue_nodes, epilogue_nodes=epilogue_nodes, kernel_group=self.kernel_group) _, _, _, kernel.buffer_types = self.kernel_group.args.mlir_argdefs() - src_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) + src_code, meta_code = kernel.codegen_nodes(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) - with V.set_kernel_handler(kernel): - kernel_name = self.define_kernel(src_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, + with kernel: + kernel_name = self.define_kernel(src_code, meta_code, kernel.kernel_name, kernel.vector_lane, kernel.spad_info, kernel.loop_size, origins={str(i) for i in template_node.node.origins}) self.define_function(kernel) diff --git a/PyTorchSimFrontend/mlir/mlir_template.py b/PyTorchSimFrontend/mlir/mlir_template.py index a36bc907..b864e5f2 100644 --- a/PyTorchSimFrontend/mlir/mlir_template.py +++ b/PyTorchSimFrontend/mlir/mlir_template.py @@ -13,8 +13,8 @@ from typing import List, Optional from unittest.mock import patch -from torch._inductor.codegen.common import KernelTemplate, ChoiceCaller, CSE, DeferredLine -from torch._inductor.ir import Buffer, IRNode, TemplateBuffer +from torch._inductor.codegen.common import KernelTemplate, CSE, DeferredLine +from torch._inductor.ir import Buffer, IRNode, TemplateBuffer, ChoiceCaller from torch._inductor.select_algorithm import PartialRender from torch._inductor.codegen.cuda.cuda_kernel import CUDATemplateCaller from torch._inductor.autotune_process import TensorMeta @@ -32,6 +32,9 @@ from PyTorchSimFrontend import extension_config from . import mlir_common +# Configure logger for mlir_template module +logger = extension_config.setup_logger() + class IndentedBufferGroup: def __init__(self, kernel: 'MLIRTemplateKernel', prefix=""): self.kernel = kernel @@ -386,7 +389,6 @@ def conv_single_batch_mapping(self, M, N, K, K_H, K_W, O_H, O_W, stride, dilatio return tile_candidates def meta_kernel(self): - wrapper = V.graph.wrapper_code kernel_arg_attributes = self.kernel_arg_attributes _, _, arg_attributes, _ = self.kernel_group.args.mlir_argdefs() if kernel_arg_attributes is not None: @@ -394,18 +396,14 @@ def meta_kernel(self): for idx in range(len(arg_attributes)): if arg_attributes[idx][0] == name: arg_attributes[idx][1] = attr - wrapper.add_import_once('\nprint(f\'Wrapper Codegen Path = {__file__}\')') - # Dump loop and load/store information - wrapper.add_import_once(f"loop_info = {self.loop_info}") - wrapper.add_import_once(f"arg_attributes = {arg_attributes}") + return arg_attributes def call_kernel(self, kernel_name): wrapper = V.graph.wrapper_code _, call_args, _, _ = self.kernel_group.args.mlir_argdefs() # generate the code to call this wrapper.generate_kernel_call( - kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", - call_args, cuda=False) + kernel_name if self.outer_func_name is None else self.outer_func_name + f"_{len(call_args)}", call_args) def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_nodes, tile_info): with self as kernel: @@ -431,7 +429,7 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ ).group prologue_tile_desc = kernel.set_tile_size(kernel.prologue_info, prologue=True) kernel.kernel_group.set_tile_info(prologue_tile_desc) - vars, reduction_vars = kernel.set_ranges(group, reduction_group) + vars, reduction_vars = kernel.set_ranges(group, reduction_group, list(self.dim_aliasing.values())) for node in prologue_nodes: # Reuse created spad read_list = sorted([i.name for i in node.read_writes.reads]) @@ -471,15 +469,15 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ _, (group, reduction_group) = max( epilogue_nodes, key=lambda x: int(x.is_reduction()) ).group - vars, reduction_vars = kernel.set_ranges(group, reduction_group) + vars, reduction_vars = kernel.set_ranges(group, reduction_group, list(self.dim_aliasing.values())) for node in epilogue_nodes: node.codegen((vars, reduction_vars)) - with V.set_kernel_handler(kernel): + with self as kernel: src_code = ( partial_code if isinstance(partial_code, str) - else partial_code.finalize() + else partial_code.finalize_all() ) # For consistency, white space could make wrong write_path @@ -487,38 +485,36 @@ def codegen_template_code(self, render, template_node, prologue_nodes, epilogue_ buffer.splice(src_code) src_code = buffer.getvalue() self._prepare_simulator_headers(src_code) - return src_code + meta_code = self.meta_kernel() + return src_code, meta_code def make_choices(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): choices = [] for tile_info in tile_candidates: - if extension_config.CONFIG_DEBUG_MODE: - # Compute Tile M, N, K DMA Tile M, N, K - print(f"[Auto-tune] Trying tile size: {list(tile_info)}") - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) + # Compute Tile M, N, K DMA Tile M, N, K + logger.debug(f"Auto-tune: Trying tile size: {list(tile_info)}") + src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) bench_runner = self.run_bench([template_node], self.kernel_name, src_code) - choices.append((bench_runner, src_code, tile_info, self.loop_size)) + choices.append((bench_runner, src_code, meta_code, tile_info, self.loop_size)) self.reset(reason=None) return choices def _log_autotune_result(self, best_choice, best_cycle): - tile_size = best_choice[2] - print( - f"[Auto-tune] Optimal tile size: {list(tile_size)}, " + tile_size = best_choice[3] + logger.debug( + f"Auto-tune: Optimal tile size: {list(tile_size)}, " f"cycles: {best_cycle}" ) def codegen_nodes(self, tile_candidates, render, template_node, prologue_nodes, epilogue_nodes): if "autotune" in extension_config.codegen_mapping_strategy and len(tile_candidates): - src_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) + src_code, meta_code, loop_size = self.autotune(tile_candidates, render, template_node, prologue_nodes, epilogue_nodes) self.loop_size = loop_size else: tile_info = tile_candidates[0] if tile_candidates else None - src_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) + src_code, meta_code = self.codegen_template_code(render, template_node, prologue_nodes, epilogue_nodes, tile_info) - with V.set_kernel_handler(self): - self.meta_kernel() - return src_code + return src_code, meta_code def _prepare_simulator_headers(self, src_code): spad_end_symbol = f"int spad_end[0] __attribute__ ((section(\".spad\")));\n" @@ -577,8 +573,8 @@ def template_store(): with contextlib.ExitStack() as stack: stack.enter_context(compute_body.indent(attribute="{inner_loop=false}",suffix=self.compute_body_loop.epilogue_line())) if self.reduction_fusion: - compute_body.writelines(self.reduction_body_loop.lines()) compute_body.splice(self.masks) + compute_body.writelines(self.reduction_body_loop.lines()) stack.enter_context(compute_body.indent(attribute="{inner_loop=false}")) compute_body.splice(self.loads) compute_body.splice(self.compute) @@ -753,7 +749,7 @@ def hook(): return "" def def_function(self): - _, call_args, _ = self.kernel_group.args.python_argdefs() + _, call_args, _, _ = self.kernel_group.args.python_argdefs() if self.outer_func_render is not None: partial_code, function_name = self.outer_func_render(input_args=call_args) return PartialRender( @@ -789,8 +785,8 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com subtile_size:list=[], async_type=None, indent_size=0): # Prepare code block local_code = IndentedBuffer() - with V.set_kernel_handler(self): - index_var = self.parse_index_list(index_list, local_code, offset=tile_desc.offset) + with self, self.override_buffer_cse(buffer=local_code, cse=self.apply_cse): + index_var = self.parse_index_list(index_list, offset=tile_desc.offset) node_layout = self.named_nodes[dram_var].get_layout() if dram_var in self.exception_nodes: numel = self.exception_nodes[dram_var]["numel"] @@ -830,7 +826,7 @@ def def_dma_op(self, dma_type, dram_var:str, index_list:list, tile_desc:mlir_com def def_sram_buffer(self, dram_name, tile_desc, id=0, indent_size=0): # Prepare code block - with V.set_kernel_handler(self): + with self: dtype = self.named_nodes[dram_name].get_layout().dtype tile_shape = tile_desc.get_mlir_shape(mlir_common.DTYPE_TO_MLIR[dtype]) buffer_name = self.allocate_sram_buffer(dtype, dram_name, tile_desc, id, forced_name=dram_name) @@ -852,14 +848,14 @@ def get_spad_size_per_lane(self, tile_m, tile_n): return max(size, 2) # vector load/store def load_epilogue(self, name: str, index: sympy.Expr): - index = self.rename_indexing(index) dram_var = self.kernel_group.args.input(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] # Want to use tile_desc from epilogue_info - index_var = self.parse_indices(index) + with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse): + index_var = self.parse_indices(index) dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride @@ -893,7 +889,10 @@ def load_epilogue(self, name: str, index: sympy.Expr): vsize = compute_vec_size//reduce_size if compute_vec_size > 1: - offset = self.cse.generate(self.loads, f"affine.apply affine_map<(d0, d1) -> (d0 + d1*{(self.r_tile_size)})>(%{self.compute_idx}, %{self.reduction_loop_idx})") + with self.override_buffer_cse(buffer=self.global_vars, cse=self.map_cse): + map_var = ops.affine_map(["d0", "d1"], f"d0 + d1*{(self.r_tile_size)}") + with self.override_buffer_cse(buffer=self.loads): + offset = ops.affine_apply(map_var, [self.compute_idx, self.reduction_loop_idx]) compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{offset}"]) with self.override_buffer_cse(buffer=self.loads): @@ -902,13 +901,13 @@ def load_epilogue(self, name: str, index: sympy.Expr): return out def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): - index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - index_var = self.parse_indices(index) + with self.override_buffer_cse(buffer=self.applys, cse=self.apply_cse): + index_var = self.parse_indices(index) dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()] vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride @@ -929,7 +928,7 @@ def store_epilogue(self, name: str, index: sympy.Expr, value, *args, **kwargs): _, operand_type = self.var_info[value] if mlir_dtype != operand_type: - value = ops.to_dtype(value, mlir_dtype, var_info=self.var_info) + value = ops.to_dtype(value, mlir_dtype) compute_index_var = ",".join([f"%{zero_var}"] * (self.kernel_group.tile_desc.get_nr_dim()-1) + [f"%{self.compute_idx}"]) # Generate vector load instruction buffer_name = name if not store_force else None @@ -987,15 +986,17 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): compute_index_var = ", ".join(zero_var_list) with self.override_buffer_cse(buffer=self.loads): out = ops._load(vec_size, type_name, sram_var, compute_index_var, tile_shape) - # Reduction body codegen with self.override_buffer_cse(buffer=self.const_buffer, cse=self.const_cse): init = ops.constant(reduction_init(reduction_type, dtype), type_name) init_vec = ops.broadcast(init, compute_vec_size) + init_vec2 = ops.broadcast(init, local_tile_desc.get_numel_per_lane()) + ops._store(init_vec2, sram_var, ", ".join([f"%{self.get_const_cse(0)}"] * local_tile_desc.get_nr_dim()), tile_shape) mask_shape, mask_var = self.get_mask() if mask_var is not None: value = ops.where(mask_var, value, init_vec) + result = reduction_partial_combine_vec(reduction_type, value, out) # Store partial result @@ -1004,13 +1005,13 @@ def reduction_epilogue(self, dtype, src_dtype, reduction_type, value): return sram_var def store_reduction_epilogue(self, name, index, value): - index = self.rename_indexing(index) dram_var = self.kernel_group.args.output(name) dram_shape = mlir_common.MLIRKernelArgs.get_mlir_shape(self.buffer_types[name]) dtype = V.graph.get_dtype(name) mlir_dtype = mlir_common.DTYPE_TO_MLIR[dtype] - index_var = self.parse_indices(index, self.reductions_suffix, comments="// Store reduction") + with self.override_buffer_cse(buffer=self.reductions_suffix, cse=self.apply_cse): + index_var = self.parse_indices(index, comments="// Store reduction") dram_stride = [index.coeff(sympy.Symbol(val)) for val in self.dim_aliasing.values()][:-1] # Assume that there is only one reduction axis vlane_split_axis = self.kernel_group.tile_desc.vmap.vlane_split_axis vlane_stride = self.kernel_group.tile_desc.vmap.vlane_stride @@ -1107,7 +1108,7 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.r_tile_size = tile_desc.get_tile_size()[-1] self.r_dim_size = template_fusion_info['r_dim_size'] self.reduction_nr_outer_loop = nr_outer_loop - self.reduction_loop_idx = "reduce_loop_idx" + self.reduction_loop_idx = self.register_var_cse("reduce_loop_idx", 1, "index") self.compute_body_loop.size = r_tile_size self.compute_body_loop.step = tile_desc.get_compute_vec_size() // nr_outer_loop self.reduction_body_loop = mlir_common.LoopLevel(self.reduction_loop_idx, nr_outer_loop) @@ -1122,14 +1123,6 @@ def set_tile_size(self, template_fusion_info, prologue=False): self.compute_body_loop.step = tile_desc.get_compute_vec_size() return tile_desc - def rename_indexing(self, index) -> sympy.Expr: - for dim_name, dim_aliased_name in self.dim_aliasing.items(): - index = index.subs(sympy.Symbol(dim_name), sympy.Symbol("tmp_"+dim_aliased_name)) - # To avoid this case ({"index0":"index1", "index1":"index0"}) - for dim_aliased_name in self.dim_aliasing.values(): - index = index.subs(sympy.Symbol("tmp_"+dim_aliased_name), sympy.Symbol(dim_aliased_name)) - return index - class MLIRTemplateCaller(CUDATemplateCaller): def __str__(self): return f"MLIRTemplateCaller(source_file={self.bmreq.source_file})" @@ -1153,7 +1146,7 @@ def __init__(self, name, input_nodes, layout, input_reorder = None): """ super().__init__(name) self.input_nodes = [node for node in input_nodes if node is not None] - self.output_node: Buffer = Buffer("buf_out", layout) + self.output_node: Buffer = Buffer(name="buf_out", layout=layout) self.input_reorder = input_reorder self.layout = layout @@ -1218,7 +1211,10 @@ def make_kernel_render( self.output_node.get_layout(), make_kernel_render, bmreq, + False, # supports_epilogue_fusion self, + kwargs, + "" # Currently Empty description ) def get_tile_candidates(self, **kwargs): diff --git a/README.md b/README.md index 4d98baa4..4a3ef145 100644 --- a/README.md +++ b/README.md @@ -99,7 +99,7 @@ The `tests` directory contains several AI workloads examples. ```bash python tests/test_matmul.py ``` -The result is stored to `TORCHSIM_DUMP_PATH/hash/togsim_result/`. The log file contains detailed core, memory, and interconnect stats. +The result is stored to `TORCHSIM_LOG_PATH/hash/togsim_result/`. The log file contains detailed core, memory, and interconnect stats. ### Run Your Own Model on PyTorchSim You can run your own PyTorch model on PyTorchSim by setting up a custom NPU device. @@ -197,9 +197,9 @@ Log contains memory & core stats. [2025-12-05 08:05:52.538] [info] Total execution cycles: 2065 [2025-12-05 08:05:52.538] [info] Wall-clock time for simulation: 0.147463 seconds ``` -The log is dumped in `TORCHSIM_DUMP_PATH` and you can set the path as below. +The log is dumped in `TORCHSIM_LOG_PATH` and you can set the path as below. ```bash -export TORCHSIM_DUMP_PATH=/tmp/torchinductor # output file dump path +export TORCHSIM_LOG_PATH=/tmp/torchinductor # output file dump path ``` ## Training diff --git a/Scheduler/scheduler.py b/Scheduler/scheduler.py index 8aa849b1..cdcdd2a7 100644 --- a/Scheduler/scheduler.py +++ b/Scheduler/scheduler.py @@ -1,5 +1,6 @@ from typing import List import os +import sys import numpy as np import torch from pathlib import Path @@ -8,6 +9,10 @@ from Simulator.simulator import TOGSimulator from PyTorchSimFrontend import extension_config +# Configure logger for Scheduler module +logger = extension_config.setup_logger() + + def import_module_from_path(module_name, path): module_path = Path(path) # Convert to Path object for safety if not module_path.exists() or not module_path.is_file(): @@ -166,46 +171,24 @@ def __init__(self, tog_simulator : TOGSimulator, num_partion=1) -> None: def setup_device(cls): if cls.NPU_MODULE is not None: return cls.NPU_MODULE - source_file_path = os.path.dirname(os.path.abspath(__file__)) - source_file = os.path.join( - source_file_path, f"{extension_config.CONFIG_TORCHSIM_DIR}/PyTorchSimFrontend/extension_device.cpp" - ) - import torch.utils.cpp_extension - module = torch.utils.cpp_extension.load( - name="npu", - sources=[ - str(source_file), - ], - extra_cflags=["-g"], - verbose=True, - ) + try: + from torch._inductor.codegen.common import register_backend_for_device + from PyTorchSimFrontend.mlir.mlir_codegen_backend import ExtensionWrapperCodegen + from PyTorchSimFrontend.mlir.mlir_scheduling import MLIRScheduling + except ImportError as e: + logger.error(f"Failed to import torch_openreg: {e}") + logger.error("Please ensure PyTorchSimDevice2 is installed: pip install -e PyTorchSimDevice2") + raise - torch.utils.rename_privateuse1_backend("npu") - torch._register_device_module("npu", module) - from torch._inductor.codegen.common import ( - get_scheduling_for_device, - get_wrapper_codegen_for_device, - register_backend_for_device, - ) - from PyTorchSimFrontend.mlir.mlir_codegen_backend import ( - ExtensionWrapperCodegen, - ) - from PyTorchSimFrontend.mlir.mlir_scheduling import ( - MLIRScheduling - ) register_backend_for_device( - "npu", MLIRScheduling, ExtensionWrapperCodegen - ) - assert( - get_scheduling_for_device("npu") == MLIRScheduling - ) - assert( - get_wrapper_codegen_for_device("npu") - == ExtensionWrapperCodegen + "npu", + lambda scheduling: MLIRScheduling(scheduling), + ExtensionWrapperCodegen ) - cls.NPU_MODULE = module - return module + + cls.NPU_MODULE = torch.npu + return cls.NPU_MODULE def submit(self, batched_req, partition_idx) -> List[RequestReturn]: # FIXME. Construct SchedulerDNNModel @@ -362,6 +345,12 @@ def __init__(self, num_request_queue=1, max_batch=1, engine_select=FIFO_ENGINE, togsim_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "TOGSim") self.tog_simulator = TOGSimulator(togsim_path, togsim_config) + if self.tog_simulator.config_yaml['pytorchsim_timing_mode'] == 0: + # Scheduler requires timing mode to be enabled (pytorchsim_timing_mode != 0). + logger.error(f"pytorchsim_timing_mode is set to 0 in config file '{togsim_config}'. ") + logger.error(f"Scheduler requires timing mode to be enabled (pytorchsim_timing_mode != 0).") + exit(0) + os.environ['TOGSIM_CONFIG'] = togsim_config self.tog_simulator.interactive_simulation() if engine_select == Scheduler.FIFO_ENGINE: @@ -369,7 +358,7 @@ def __init__(self, num_request_queue=1, max_batch=1, engine_select=FIFO_ENGINE, elif engine_select == Scheduler.RR_ENGINE: self.execution_engine = RoundRobinRunner(self.tog_simulator, self.num_request_queue) else: - print(f"Not supporetd engine type {engine_select}") + logger.error(f"Not supported engine type {engine_select}") exit(1) def add_request(self, request: Request, request_time=-1): @@ -430,9 +419,11 @@ def finish_request(self, req : Request): self.finish_queue.append(req) self.request_queue[req.request_queue_idx].remove(req) turnaround_time, response_time, tbt_time = req.get_latency() - print(f"[Request-{req.id} finished] partition: {req.request_queue_idx} arrival_time: " - f"{req.arrival_time} start_time: {req.start_time[0]} turnaround latency: {turnaround_time}, " - f"response time: {response_time} tbt_time: {tbt_time}") + logger.info( + f"[Request-{req.id} finished] partition: {req.request_queue_idx} arrival_time: " + f"{req.arrival_time} start_time: {req.start_time[0]} turnaround latency: {turnaround_time}, " + f"response time: {response_time} tbt_time: {tbt_time}" + ) def per_schedule(self, request_queue_idx): # Wait partition is idle @@ -443,11 +434,13 @@ def per_schedule(self, request_queue_idx): if not request_list: return False - print(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}", flush=True) + logger.info(f"[Request issue] partition: {request_queue_idx} batch size: {len(request_list)}") for req in request_list: req.set_start(self.current_time()) - print(f"[Request-{req.id} issue] partition: {req.request_queue_idx} " - f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}", flush=True) + logger.info( + f"[Request-{req.id} issue] partition: {req.request_queue_idx} " + f"arrival_time: {req.arrival_time} start_time: {req.start_time[0]}" + ) # Submit batched request self.execution_engine.submit(request_list, request_queue_idx) diff --git a/Simulator/simulator.py b/Simulator/simulator.py index 672ae6ec..96a1fc86 100644 --- a/Simulator/simulator.py +++ b/Simulator/simulator.py @@ -17,7 +17,46 @@ from PyTorchSimFrontend.mlir.mlir_common import MLIRKernelArgs from PyTorchSimFrontend import extension_config -print_lock = threading.Lock() +# Configure logger for Simulator module +logger = extension_config.setup_logger() +from tqdm import tqdm + + +class ProgressBar: + def __init__(self, desc, silent_mode=False, update_interval=0.5): + self.desc = desc + self.silent_mode = silent_mode + self.update_interval = update_interval + self.pbar = None + self.finished = False + self.progress_thread = None + + def __enter__(self): + if not self.silent_mode: + self.pbar = tqdm( + desc=self.desc, + bar_format='{desc}: {elapsed}', + leave=False, # Don't leave the bar when done (it will disappear) + ncols=80, + disable=False, + total=100, # Use a total for smooth animation + ) + # Update progress bar in a separate thread + def update_progress(): + while not self.finished: + self.pbar.update(1) + time.sleep(self.update_interval) + + self.progress_thread = threading.Thread(target=update_progress, daemon=True) + self.progress_thread.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.finished = True + if not self.silent_mode and self.pbar is not None: + self.pbar.close() + return False + TORCH_TO_NUMPY = { torch.float32: np.float32, @@ -105,17 +144,18 @@ def run_spike(self, args, arg_attributes, runtime_path, binary, vectorlane_size= os.makedirs(os.path.join(runtime_path, "indirect_access"), exist_ok=True) os.makedirs(os.path.join(runtime_path, "dma_access"), exist_ok=True) run = f'spike --isa rv64gcv --varch=vlen:256,elen:64 {vectorlane_option} {spad_option} {kernel_address} {base_path} /workspace/riscv-pk/build/pk {target_binary} {file_path_str}' - if not silent_mode and extension_config.CONFIG_DEBUG_MODE: - print("[Spike] cmd> ", run) - print("[Spike] Running Spike simulator") + if not silent_mode: + logger.debug(f"[Spike] cmd> {run}") + logger.info("[Spike] Running Spike simulator") run_cmd = shlex.split(run) try: stdout_setting = subprocess.DEVNULL if silent_mode else None stderr_setting = subprocess.DEVNULL if silent_mode else None - subprocess.check_call(run_cmd, stdout=stdout_setting, stderr=stderr_setting) + with ProgressBar("[Spike] Running simulation", silent_mode=silent_mode): + subprocess.check_call(run_cmd, stdout=stdout_setting, stderr=stderr_setting) except subprocess.CalledProcessError as e: if not silent_mode: - print("[Spike] Command failed with exit code", e.returncode) + logger.error(f"[Spike] Command failed with exit code {e.returncode}") error_msg = "" if e.returncode == 200: error_msg = "INVALID_SPAD_ACCESS" @@ -155,41 +195,23 @@ def __init__(self) -> None: pass def compile_and_simulate(self, target_binary, array_size, vectorlane_size, silent_mode=False): - def show_progress(): - i = 0 - while not finished: - i = (i + 1) % 3 - tail = "." * i + " " * (3-i) - with print_lock: - sys.stdout.write("\r[Gem5] Gem5 is running." + tail) - sys.stdout.flush() - time.sleep(1) - with print_lock: - print("") - dir_path = os.path.join(os.path.dirname(target_binary), "m5out") gem5_script_path = os.path.join(extension_config.CONFIG_TORCHSIM_DIR, "gem5_script/script_systolic.py") gem5_cmd = [extension_config.CONFIG_GEM5_PATH, "-r", "--stdout-file=sto.log", "-d", dir_path, gem5_script_path, "-c", target_binary, "--vlane", str(vectorlane_size)] + + is_dryrun = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) or silent_mode + + if not is_dryrun: + logger.debug(f"[Gem5] cmd> {' '.join(gem5_cmd)}") + logger.info("[Gem5] Gem5 simulation started") + try: - # Create progress thread - is_dryrun = int(os.environ.get('TOGSIM_EAGER_MODE', default=False)) or silent_mode - if not is_dryrun: - if extension_config.CONFIG_DEBUG_MODE: - print("[Gem5] cmd> ", " ".join(gem5_cmd)) - finished = False - progress_thread = threading.Thread(target=show_progress) - progress_thread.start() - output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) - finished = True - progress_thread.join() - else: - output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) + #with ProgressBar("[Gem5] Running simulation", silent_mode=is_dryrun): + output = subprocess.check_output(gem5_cmd, stderr=subprocess.DEVNULL) except subprocess.CalledProcessError as e: - print(f"[Gem5] Gem5 simulation failed with error: \"{e.output.decode()}\"") - if not is_dryrun: - finished = True - progress_thread.join() - raise RuntimeError(f"Gem5 Simulation Failed: \"{e.output.decode()}\"") + output_error = e.output.decode() if isinstance(e.output, bytes) else str(e.output) + logger.debug(f"[Gem5] Gem5 simulation failed with error: \"{output_error}\"") + raise RuntimeError(f"Gem5 Simulation Failed: \"{output_error}\"") with open(f"{dir_path}/stats.txt", "r") as stat_file: raw_list = stat_file.readlines() @@ -216,39 +238,21 @@ def get_togsim_command(self): return cmd def simulation(self, model_path, attribute_path="", silent_mode=False, autotune_mode=False): - def show_progress(): - i = 0 - while not finished: - i = (i + 1) % 3 - tail = "." * i + " " * (3-i) - sys.stdout.write("\r[TOGSim] TOGSim is running." + tail) - time.sleep(1) - print("") cmd = f"{self.get_togsim_command()} --models_list {model_path}" if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: cmd += f" --log_level {extension_config.CONFIG_TOGSIM_DEBUG_LEVEL}" if attribute_path: cmd = f"{cmd} --attributes_list {attribute_path}" - if not silent_mode and extension_config.CONFIG_DEBUG_MODE: - print("[TOGSim] cmd> ", cmd) - - # Create progress thread if not silent_mode: - finished = False - progress_thread = threading.Thread(target=show_progress) - progress_thread.start() + logger.debug(f"[TOGSim] cmd> {cmd}") + logger.info("[TOGSim] TOGSim simulation started") + try: - result = subprocess.check_output(shlex.split(cmd)) - if not silent_mode: - finished = True - progress_thread.join() + with ProgressBar("[TOGSim] Running simulation", silent_mode=silent_mode): + result = subprocess.check_output(shlex.split(cmd)) except subprocess.CalledProcessError as e: - if not silent_mode: - finished = True - progress_thread.join() - with print_lock: - print("[TOGSim] Command failed with exit code", e.returncode) - print("[TOGSim] Error output:", e.output) + logger.error(f"[TOGSim] Command failed with exit code {e.returncode}") + logger.error(f"[TOGSim] Error output: {e.output.decode() if isinstance(e.output, bytes) else e.output}") assert 0 # Separate Autotune logs @@ -271,10 +275,10 @@ def show_progress(): f.flush() os.fsync(f.fileno()) - if not silent_mode or extension_config.CONFIG_DEBUG_MODE: - model_path_log = f' of "{model_path}" ' if extension_config.CONFIG_DEBUG_MODE else " " - with print_lock: - print(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') + if not silent_mode: + import logging as _logging + model_path_log = f' of "{model_path}" ' if logger.isEnabledFor(_logging.DEBUG) else " " + logger.info(f'[TOGSim] Simulation log{model_path_log}is stored to "{result_path}"') return result_path def interactive_simulation(self): @@ -282,8 +286,7 @@ def interactive_simulation(self): if extension_config.CONFIG_TOGSIM_DEBUG_LEVEL: cmd += f" --log_level {extension_config.CONFIG_TOGSIM_DEBUG_LEVEL}" - if extension_config.CONFIG_DEBUG_MODE: - print("[TOGSim] cmd> ", cmd) + logger.debug(f"[TOGSim] cmd> {cmd}") if self.process is None: self.process = subprocess.Popen( shlex.split(cmd), @@ -292,28 +295,27 @@ def interactive_simulation(self): universal_newlines=True ) else: - print("[TOGSim] Simulator is already running.") + logger.warning("[TOGSim] Simulator is already running.") def stop(self): if self.process: self.process.terminate() self.process.wait() self.process = None - print("[TOGSim] Simulator stopped.") + logger.info("[TOGSim] Simulator stopped.") def wait(self): if self.process: - print("[TOGSim] Waiting for simulation to complete...") + logger.info("[TOGSim] Waiting for simulation to complete...") self.quit() self.process.wait() self.process = None - print("[TOGSim] Simulation completed.") + logger.info("[TOGSim] Simulation completed.") def send_command(self, command): if self.process: try: - if extension_config.CONFIG_TORCHSIM_DEBUG_MODE: - print(command, flush=True) + logger.debug(command) self.process.stdin.write(command + '\n') self.process.stdin.flush() ret = self.process.stderr.readline().strip() @@ -321,11 +323,11 @@ def send_command(self, command): except BrokenPipeError: err = self.process.stderr.readlines() for line in err: - print(line) + logger.error(line.strip()) self.process = None exit(1) else: - print("Simulator is not running.") + logger.warning("Simulator is not running.") return None def launch(self, onnx_path, attribute_path, arrival_time=0, partion_id=0): @@ -440,7 +442,7 @@ def get_result_from_file(result_path): break if simulation_finished_idx == -1: - print(f"[TOGSim] Warning: Unable to parse the output file ({result_path}). The file may be improperly formatted.") + logger.warning(f"[TOGSim] Warning: Unable to parse the output file ({result_path}). The file may be improperly formatted.") return core_metrics, dram_channel_bw, avg_dram_bw, simulation_time total_stat_lines = lines[simulation_finished_idx:] diff --git a/experiments/BERT.py b/experiments/BERT.py index 5ccd3084..fd671833 100644 --- a/experiments/BERT.py +++ b/experiments/BERT.py @@ -48,7 +48,7 @@ def run_BERT(size, input_seq, config): input_seq = args.input_size result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"BERT_{size}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" if 'pytorchsim_functional_mode' in os.environ: diff --git a/experiments/attention.py b/experiments/attention.py index 842f105a..211433f1 100644 --- a/experiments/attention.py +++ b/experiments/attention.py @@ -47,7 +47,7 @@ def attention(query, key, value): size_str = "x".join([str(i) for i in size]) result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"attention_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" if 'pytorchsim_functional_mode' in os.environ: diff --git a/experiments/conv.py b/experiments/conv.py index 25952fb0..61f7ad80 100644 --- a/experiments/conv.py +++ b/experiments/conv.py @@ -48,7 +48,7 @@ def custom_conv2d(a, b, bias): size_str = "_".join([str(i) for i in size]) result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"CONV_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" if 'pytorchsim_functional_mode' in os.environ: diff --git a/experiments/gemm.py b/experiments/gemm.py index 3090e331..0e1a15e4 100644 --- a/experiments/gemm.py +++ b/experiments/gemm.py @@ -31,7 +31,7 @@ def custom_matmul(a, b): import os import sys base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim') - config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml) + config = os.environ.get('TORCHSIM_CONFIG', default=f'{base_dir}/configs/systolic_ws_128x128_c2_simple_noc_tpuv4.yml') config_prefix = config.split('/')[-1].split('.')[0][9:] # extract config name from config path sys.path.append(base_dir) args = argparse.ArgumentParser() @@ -42,13 +42,10 @@ def custom_matmul(a, b): size_str = "x".join([str(i) for i in size]) result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"GEMM_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" if 'pytorchsim_functional_mode' in os.environ: del os.environ['pytorchsim_functional_mode'] - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() run_matmul(size[0], size[1], size[2], config) diff --git a/experiments/layernorm.py b/experiments/layernorm.py index 9c9934a1..a6b16986 100644 --- a/experiments/layernorm.py +++ b/experiments/layernorm.py @@ -38,7 +38,7 @@ def run_layernorm(size, config): size_str = "x".join([str(i) for i in size]) result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"LayerNorm_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path os.environ['TORCHSIM_FUSION_REDUCTION_REDUCTION'] = "0" # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" diff --git a/experiments/resnet18.py b/experiments/resnet18.py index 5451e0f5..c7763d86 100644 --- a/experiments/resnet18.py +++ b/experiments/resnet18.py @@ -39,7 +39,7 @@ def run_resnet(batch, config): batch = args.batch result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet18_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" diff --git a/experiments/resnet50.py b/experiments/resnet50.py index 83d82db4..4e611541 100644 --- a/experiments/resnet50.py +++ b/experiments/resnet50.py @@ -39,7 +39,7 @@ def run_resnet(batch, config): batch = args.batch result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"resnet50_{batch}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path os.environ['TORCHSIM_USE_TIMING_POOLING'] = "1" # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" diff --git a/experiments/softmax.py b/experiments/softmax.py index 580d56ca..d30559f7 100644 --- a/experiments/softmax.py +++ b/experiments/softmax.py @@ -38,7 +38,7 @@ def run_softmax(size, config, dim=1): size_str = "x".join([str(i) for i in size]) result_path = os.path.join(base_dir, args.dump_path, config_prefix, f"Softmax_{size_str}_{datetime.datetime.now().strftime('%Y%m%d_%H%M%S')}") # setting environment variables - os.environ['TORCHSIM_DUMP_PATH'] = result_path + os.environ['TORCHSIM_LOG_PATH'] = result_path # only timing simulation os.environ['TORCHSIM_VALIDATION_MODE'] = "0" if 'pytorchsim_functional_mode' in os.environ: diff --git a/scripts/ILS_experiment/test_matmul.py b/scripts/ILS_experiment/test_matmul.py index 667dfc66..b0bc474c 100644 --- a/scripts/ILS_experiment/test_matmul.py +++ b/scripts/ILS_experiment/test_matmul.py @@ -52,15 +52,9 @@ def custom_matmul(bias, a, b): test_result("Addmm Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run matmul with given shape") parser.add_argument('--shape', type=str, default="(512,512,512)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() - test_matmul(device, *shape) + device = torch.device("npu:0") + test_matmul(device, *shape) \ No newline at end of file diff --git a/scripts/chiplet_prep.py b/scripts/chiplet_prep.py index 4f8b7f7c..2266d74c 100644 --- a/scripts/chiplet_prep.py +++ b/scripts/chiplet_prep.py @@ -1,10 +1,7 @@ import os import yaml -import shutil import argparse import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -60,20 +57,14 @@ def modify_file(dump_path, name, address_numa_stride=None, subgraph_map=None): print(f"Modified file saved to {output_file}") if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") parser = argparse.ArgumentParser(description='Process folder argument.') parser.add_argument('size', type=int, help='Folder value', default=256) args = parser.parse_args() folder = int(args.size) print("Taget size: ", folder) - folder_path = os.environ.get("TORCHSIM_DUMP_PATH") + folder_path = os.environ.get("TORCHSIM_LOG_PATH") print(folder_path) os.makedirs(folder_path, exist_ok=True) test_matmul(device, folder, folder, folder) diff --git a/scripts/chiplet_prep.sh b/scripts/chiplet_prep.sh index cddf1a58..f3bd1a1c 100755 --- a/scripts/chiplet_prep.sh +++ b/scripts/chiplet_prep.sh @@ -8,7 +8,7 @@ for size in "${sizes[@]}"; do export TORCHSIM_TILE_M=$((size / 2)) export TORCHSIM_TILE_K=$((size / 2)) export TORCHSIM_TILE_N=$((size / 2)) - export TORCHSIM_DUMP_PATH=$(pwd)/chiplet_result/$size + export TORCHSIM_LOG_PATH=$(pwd)/chiplet_result/$size python3 chiplet_prep.py $size #python3 chiplet_run.py $(pwd)/chiplet_result done \ No newline at end of file diff --git a/scripts/sparsity_experiment/run.sh b/scripts/sparsity_experiment/run.sh index 84c818ac..da9b73cc 100755 --- a/scripts/sparsity_experiment/run.sh +++ b/scripts/sparsity_experiment/run.sh @@ -1,4 +1,4 @@ -export TORCHSIM_DUMP_PATH=$(pwd)/result +export TORCHSIM_LOG_PATH=$(pwd)/result export SPIKE_DUMP_SPARSE_TILE=1 export TORCHSIM_FORCE_TIME_K=8 export TORCHSIM_FORCE_TIME_M=8 diff --git a/tests/Diffusion/test_diffusion.py b/tests/Diffusion/test_diffusion.py index c5170209..85eaba9f 100644 --- a/tests/Diffusion/test_diffusion.py +++ b/tests/Diffusion/test_diffusion.py @@ -8,6 +8,7 @@ from diffusers.models.unets.unet_2d_condition import UNet2DConditionModel from diffusers.models.upsampling import Upsample2D from diffusers.models.resnet import ResnetBlock2D +from diffusers.models.embeddings import Timesteps def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -313,7 +314,7 @@ def test_cross_attn_down_block2d( dual_cross_attention=False ): print(f"Testing CrossAttnDownBlock2D on device: {device}") - + # 1. Initialize the module on CPU cpu_block = CrossAttnDownBlock2D( in_channels=in_channels, @@ -338,7 +339,7 @@ def test_cross_attn_down_block2d( temb=temb_cpu, encoder_hidden_states=encoder_hidden_states_cpu, ) - + # 4. Initialize the module on the custom device device_block = cpu_block.to(device).eval() device_block = torch.compile(device_block, dynamic=False) @@ -347,7 +348,7 @@ def test_cross_attn_down_block2d( hidden_states_dev = hidden_states_cpu.to(device) temb_dev = temb_cpu.to(device) encoder_hidden_states_dev = encoder_hidden_states_cpu.to(device) - + # 6. Get the output from the custom device module with torch.no_grad(): dev_out, _ = device_block( @@ -442,9 +443,9 @@ def test_groupnorm( # 1. Initialize the module on CPU cpu_norm = torch.nn.GroupNorm( - num_groups=num_groups, - num_channels=channels, - eps=eps, + num_groups=num_groups, + num_channels=channels, + eps=eps, affine=True ).to("cpu").eval() @@ -462,13 +463,13 @@ def test_groupnorm( # 4. Initialize the module on the custom device device_norm = torch.nn.GroupNorm( - num_groups=num_groups, - num_channels=channels, - eps=eps, + num_groups=num_groups, + num_channels=channels, + eps=eps, affine=True ).to(device).eval() device_norm = torch.compile(device_norm, dynamic=False) - + # Copy the weights from the CPU module to ensure they are identical device_norm.weight.data.copy_(cpu_norm.weight.data) device_norm.bias.data.copy_(cpu_norm.bias.data) @@ -541,6 +542,89 @@ def test_upsample2d( print("Max diff >", torch.max(torch.abs(y_dev.cpu() - y_cpu)).item()) print("Upsample2D simulation done.") + +def test_flip_sin_to_cos_embedding( + device, + batch=1, + embedding_dim=256, + rtol=1e-4, + atol=1e-4, +): + def create_embeddings(timesteps, embedding_dim, scale=1.0, flip_sin_to_cos=False): + """ + Replicate the embedding creation logic from Timesteps class. + """ + half_dim = embedding_dim // 2 + exponent = -math.log(10000) * torch.arange(start=0, end=half_dim, dtype=torch.float32, device=timesteps.device) + exponent = exponent / half_dim + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + emb = scale * emb + + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + # flip sine and cosine embeddings + if flip_sin_to_cos: + new_emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + return emb, new_emb + return emb, emb + + g = torch.Generator().manual_seed(0) + timesteps_cpu = torch.randint(low=0, high=1000, size=(batch,), generator=g, dtype=torch.long) + + # Test with flip_sin_to_cos=True + with torch.no_grad(): + emb_flip_cpu = create_embeddings(timesteps_cpu, embedding_dim, flip_sin_to_cos=True) + + # Move to device and test + timesteps_dev = timesteps_cpu.to(device) + @torch.compile(dynamic=False) + def create_embeddings_compiled(timesteps, embedding_dim, scale=1.0, flip_sin_to_cos=False): + return create_embeddings(timesteps, embedding_dim, scale, flip_sin_to_cos) + + with torch.no_grad(): + emb_flip_dev = create_embeddings_compiled(timesteps_dev, embedding_dim, flip_sin_to_cos=True) + + # Verify flip case + test_result("Embedding (flip_sin_to_cos=True)", emb_flip_dev[0], emb_flip_cpu[0], rtol=rtol, atol=atol) + print("Max diff (flip) >", torch.max(torch.abs(emb_flip_dev[0].cpu() - emb_flip_cpu[0])).item()) + test_result("Embedding (flip_sin_to_cos=True)", emb_flip_dev[1], emb_flip_cpu[1], rtol=rtol, atol=atol) + print("Max diff (flip) >", torch.max(torch.abs(emb_flip_dev[1].cpu() - emb_flip_cpu[1])).item()) + + +def test_timesteps( + device, + batch=1, + num_channels=64, + flip_sin_to_cos=True, + downscale_freq_shift=1.0, + rtol=1e-4, + atol=1e-4, +): + print(f"Testing Timesteps on device: {device}") + + cpu_timesteps = Timesteps( + num_channels=num_channels, + flip_sin_to_cos=flip_sin_to_cos, + downscale_freq_shift=downscale_freq_shift, + ).to("cpu").eval() + + g = torch.Generator().manual_seed(0) + timesteps_cpu = torch.randint(low=0, high=1000, size=(batch,), generator=g, dtype=torch.long) + + with torch.no_grad(): + cpu_out = cpu_timesteps(timesteps_cpu) + + dev_timesteps = cpu_timesteps.to(device).eval() + dev_timesteps = torch.compile(dev_timesteps, dynamic=False) + + timesteps_dev = timesteps_cpu.to(device) + with torch.no_grad(): + dev_out = dev_timesteps(timesteps_dev) + + test_result("Timesteps", dev_out, cpu_out, rtol=rtol, atol=atol) + print("Max diff >", torch.max(torch.abs(dev_out.cpu() - cpu_out)).item()) + print("Timesteps simulation done.") + if __name__ == "__main__": parser = argparse.ArgumentParser(description="Run UNet (diffusers) test with comparison") parser.add_argument("--model", type=str, default="runwayml/stable-diffusion-v1-5", @@ -553,18 +637,18 @@ def test_upsample2d( args = parser.parse_args() sys.path.append(os.environ.get("TORCHSIM_DIR", "/workspace/PyTorchSim")) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_upsample2d(device) #test_groupnorm(device) #test_groupnorm(device, stride=[1, 1, 320*32, 320]) - #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=320) + #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=256, resnet_act_fn='silu') #test_resnetblock2d(device, in_channels=640, out_channels=320, temb_channels=1280) #test_cross_attn_down_block2d(device) #test_unet_mid_block2d_cross_attn(device) #test_cross_attn_up_block2d(device) + #test_flip_sin_to_cos_embedding(device) + #test_timesteps(device) test_unet2d_condition_model(device) #test_unet_conditional( # device=device, diff --git a/tests/Fusion/test_addmm_residual.py b/tests/Fusion/test_addmm_residual.py index ef753a67..917628e3 100644 --- a/tests/Fusion/test_addmm_residual.py +++ b/tests/Fusion/test_addmm_residual.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -38,14 +36,7 @@ def addmm_residual(a, b, c, d): y = addmm_residual(b2, x2, w2, r2) test_result("Addmm + Residual Fusion Forward", res, y) -if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() +if __name__ == "__main__": device = torch.device("npu:0") test_addmm_residual(device, 32, 32, 32) test_addmm_residual(device, 128, 128, 128) test_addmm_residual(device, 512, 512, 512) diff --git a/tests/Fusion/test_attention_fusion.py b/tests/Fusion/test_attention_fusion.py index 123376d1..ebbd3037 100644 --- a/tests/Fusion/test_attention_fusion.py +++ b/tests/Fusion/test_attention_fusion.py @@ -1,8 +1,5 @@ -import math import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -70,14 +67,7 @@ def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): test_result("MHA Forward", res, cpu_res) -if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() +if __name__ == "__main__": device = torch.device("npu:0") test_MHA(device) # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/Fusion/test_bmm_reduction.py b/tests/Fusion/test_bmm_reduction.py index 4f4d3ad6..45e31dab 100644 --- a/tests/Fusion/test_bmm_reduction.py +++ b/tests/Fusion/test_bmm_reduction.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -38,13 +36,7 @@ def bmm(a, b): test_result("BMM Reduction Fusion reduction", res[1], y[1]) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_bmm_reduce(device) test_bmm_reduce(device, 12, 512) test_bmm_reduce(device, 4, 256) diff --git a/tests/Fusion/test_conv_fusion.py b/tests/Fusion/test_conv_fusion.py index 694f3bb9..bc200ff2 100644 --- a/tests/Fusion/test_conv_fusion.py +++ b/tests/Fusion/test_conv_fusion.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): message = f"|{name} Test Passed|" @@ -97,13 +95,7 @@ def custom_conv_bn_relu(a, b, bias, c, d, e, f): print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") # Vanila test test_conv_residual(device, batch_size=3, in_channels=64, out_channels=64, input_size=28, kernel_size=3, stride=1, padding=1) diff --git a/tests/Fusion/test_matmul_activation.py b/tests/Fusion/test_matmul_activation.py index 2f1d014f..232ec98d 100644 --- a/tests/Fusion/test_matmul_activation.py +++ b/tests/Fusion/test_matmul_activation.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -69,13 +67,7 @@ def test_matmul_activation(device, batch_size=16, input_size=32, output_size=8, print("CPU output > ", cpu_y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul_activation(device) test_matmul_activation(device, batch_size=32, input_size=32, output_size=32, activation_fn="sigmoid") test_matmul_activation(device, batch_size=42, input_size=42, output_size=42, activation_fn="sigmoid") diff --git a/tests/Fusion/test_matmul_reduction.py b/tests/Fusion/test_matmul_reduction.py index df8cf969..9b09214a 100644 --- a/tests/Fusion/test_matmul_reduction.py +++ b/tests/Fusion/test_matmul_reduction.py @@ -85,13 +85,7 @@ def matmul_fused(a, b, c, d): test_result("Matmul+residual+var_mean Fusion reduction", res[2], y[2]) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul_reduce(device, 3072, 512, 768) test_matmul_var_mean(device) test_matmul_add_var_mean(device) diff --git a/tests/Fusion/test_matmul_scalar.py b/tests/Fusion/test_matmul_scalar.py index 0815bb90..d5a159ed 100644 --- a/tests/Fusion/test_matmul_scalar.py +++ b/tests/Fusion/test_matmul_scalar.py @@ -35,11 +35,5 @@ def matmul_fused(a, b, c): test_result("Matmul Scalar Fusion Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul_scalar(device) diff --git a/tests/Fusion/test_matmul_vector.py b/tests/Fusion/test_matmul_vector.py index bf1bd513..f87f9432 100644 --- a/tests/Fusion/test_matmul_vector.py +++ b/tests/Fusion/test_matmul_vector.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -41,12 +39,6 @@ def matmul_fused(a, b, c, d): test_result("Matmul Vector Fusion Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import ExecutionEngine - module = ExecutionEngine.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul_vector(device, size=[253, 123, 47], dim=0) test_matmul_vector(device, size=[253, 123, 47], dim=1) \ No newline at end of file diff --git a/tests/Fusion/test_prologue_fusion.py b/tests/Fusion/test_prologue_fusion.py index b27312a9..ecfd5fbf 100644 --- a/tests/Fusion/test_prologue_fusion.py +++ b/tests/Fusion/test_prologue_fusion.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -84,13 +82,7 @@ def bmm(a, b, c, d): test_result("BMM Element-wise Fusion Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_elem_broadcast_fusion(device) test_elem_fusion(device) test_elem_bmm_input_fusion(device, batch_size=4, m=512, n=512, k=64) diff --git a/tests/Fusion/test_transformer_fusion.py b/tests/Fusion/test_transformer_fusion.py index b1cceb2c..1581cd97 100644 --- a/tests/Fusion/test_transformer_fusion.py +++ b/tests/Fusion/test_transformer_fusion.py @@ -1,8 +1,6 @@ import math import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -199,13 +197,7 @@ def test_EncoderBlock_validation(head=12, embed_dim=768, input_seq=512): test_result("Encoder Block Validation", res, origin_res) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_MHA(device) test_EncoderBlock(device) # test_EncoderBlock_validation() diff --git a/tests/Llama/test_llama.py b/tests/Llama/test_llama.py index 443f3fc2..5e87b8e7 100644 --- a/tests/Llama/test_llama.py +++ b/tests/Llama/test_llama.py @@ -101,7 +101,8 @@ def run_rotary_embedding_test( vocab_size=8192, _attn_implementation = "sdpa" ) - base_rope = LlamaRotaryEmbedding(cfg) + # Pass dim explicitly to avoid config parsing issues + base_rope = LlamaRotaryEmbedding(dim=head_dim, max_position_embeddings=cfg.max_position_embeddings, base=cfg.rope_theta, config=cfg) cpu_rope = copy.deepcopy(base_rope) @@ -368,21 +369,19 @@ def run_llama_model_test( args = parser.parse_args() sys.path.append(os.environ.get("PYTORCHSIM_ROOT_PATH", "/workspace/PyTorchSim")) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_triu(device, size=(32, 128), diagonal=1) torch.compiler.is_compiling = lambda: True # FIXME. How to fix this? #run_rmsnorm_test(device) #run_rotary_embedding_test(device) - #run_decoder_layer_test( - # device=device, - # batch=args.batch, - # seq_len=args.seq_len, - # dtype=args.dtype, - # rtol=args.rtol, - # atol=args.atol, - #) + run_decoder_layer_test( + device=device, + batch=args.batch, + seq_len=args.seq_len, + dtype=args.dtype, + rtol=args.rtol, + atol=args.atol, + ) run_llama_model_test(device) #run_custom_llama_test( # device=device, diff --git a/tests/Mixtral_8x7B/test_attention.py b/tests/Mixtral_8x7B/test_attention.py index 58955928..57760370 100644 --- a/tests/Mixtral_8x7B/test_attention.py +++ b/tests/Mixtral_8x7B/test_attention.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension from model import Transformer, TransformerBlock, ModelArgs, Attention, FeedForward, KVCache, RMSNorm, precompute_freqs_cis, sample def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): @@ -159,13 +157,7 @@ def test_rmsnorm(device, seq=32): test_result("RMSNorm", res, cpu_res) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_rmsnorm(device, seq=1) #test_concat(device, size1=(1, 8, 64, 64), size2=(1,8,1,64), dim=2) test_decode(device, 32, 3) diff --git a/tests/MoE/test_moe.py b/tests/MoE/test_moe.py index ae16f0b0..f9c96aff 100644 --- a/tests/MoE/test_moe.py +++ b/tests/MoE/test_moe.py @@ -4,7 +4,6 @@ import copy import matplotlib.pyplot as plt - import torch import torch.nn as nn from torch.distributions.normal import Normal @@ -17,6 +16,20 @@ sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) +# FIXME. This is a Dynamo bug. Solution to avoid is_forward conflict during backward +def patch_metrics_context_update(): + """Patch MetricsContext.update to set overwrite=True by default.""" + from torch._dynamo.utils import get_metrics_context + ctx = get_metrics_context() + original_update = ctx.update + + def patched_update(values, overwrite=True): + """Patched version that sets overwrite=True by default.""" + return original_update(values, overwrite=True) + + # Patch the method + get_metrics_context().update = patched_update + def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): pass_message = f"|{name} Test Passed|" fail_message = f"|{name} Test Failed|" @@ -64,6 +77,7 @@ class SparseDispatcher(object): `Tensor`s for expert i only the batch elements for which `gates[b, i] > 0`. """ + @torch.compiler.disable(recursive=True) def __init__(self, num_experts, gates): """Create a SparseDispatcher.""" gates = gates.cpu() @@ -443,6 +457,7 @@ def test_moe(device): total_cpu_loss = cpu_loss + cpu_aux_loss total_loss.to(device) + patch_metrics_context_update() print("Backward Started!") total_loss.backward() total_cpu_loss.backward() @@ -469,6 +484,9 @@ def test_moe(device): print("\n") def train_moe(device): + # Patch CompileEventLogger to avoid metric conflicts + patch_metrics_context_update() + def perceptron(a, b, c): return a * b + c @@ -589,6 +607,9 @@ def weight_update(a, b, lr): plt.savefig('result.png') def train_moe_mnist(device): + # Patch CompileEventLogger to avoid metric conflicts + patch_metrics_context_update() + torch.manual_seed(0) batch_size = 32 input_size = 28*28 @@ -670,6 +691,9 @@ def train(model, device, train_loader, optimizer, epochs): plt.savefig(f'{name}_result.png') def train_moe_single_iteration(device, iter_idx, is_evaluation=0): + # Patch CompileEventLogger to avoid metric conflicts + patch_metrics_context_update() + # Training moe with mnist dataset for sinlge iteration torch.manual_seed(0) batch_size = 128 diff --git a/tests/test_activation.py b/tests/test_activation.py index 575fc7e8..dacc102e 100644 --- a/tests/test_activation.py +++ b/tests/test_activation.py @@ -23,9 +23,10 @@ def test_ReLU(device, size=(128, 128)): input = torch.randn(size) x1 = input.to(device=device) x2 = input.to("cpu") - opt_fn = torch.compile(dynamic=False)(torch.nn.functional.relu) + ReLU = torch.nn.ReLU() + opt_fn = torch.compile(dynamic=False)(ReLU) y = opt_fn(x1) - cpu_y = torch.nn.functional.relu(x2) + cpu_y = ReLU(x2) test_result("ReLU", y, cpu_y) def test_GeLU(device, size=(128, 128), approximate='none'): @@ -78,19 +79,14 @@ def test_SwiGLU(device, size=(128, 128)): test_result("SwiGLU", y, cpu_y) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_ReLU(device, (47, 10)) test_ReLU(device, (128, 128)) test_ReLU(device, (4071, 429)) diff --git a/tests/test_add.py b/tests/test_add.py index 118632d5..7a0d23d9 100644 --- a/tests/test_add.py +++ b/tests/test_add.py @@ -48,19 +48,14 @@ def vectoradd(a, b): test_result("VectorTensorAdd", res, out) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_vectoradd(device, (1, 1)) test_vectoradd(device, (47, 10)) test_vectoradd(device, (128, 128)) diff --git a/tests/test_batchnorm.py b/tests/test_batchnorm.py index 251805f5..065c0870 100644 --- a/tests/test_batchnorm.py +++ b/tests/test_batchnorm.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -33,13 +31,7 @@ def test_BatchNorm(device, size=(1, 16, 64, 64)): test_result("BatchNorm Forward", y, cpu_y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_BatchNorm(device) test_BatchNorm(device, size=(1,64, 32, 32)) test_BatchNorm(device, size=(1, 8, 4, 4)) diff --git a/tests/test_bmm.py b/tests/test_bmm.py index d90410db..02a6460e 100644 --- a/tests/test_bmm.py +++ b/tests/test_bmm.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -42,13 +40,7 @@ def bmm(a, b, bias): test_result("BMM Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_BMM(device) test_BMM(device, 2, 256, 128, 256) test_BMM(device, 2, 128, 256, 256) diff --git a/tests/test_cnn.py b/tests/test_cnn.py index 54225747..e6b01bbd 100644 --- a/tests/test_cnn.py +++ b/tests/test_cnn.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -49,11 +47,5 @@ def test_CNN(device): print("Max diff > ", torch.max(torch.abs(y.cpu() - cpu_y))) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_CNN(device) diff --git a/tests/test_conv2d.py b/tests/test_conv2d.py index e964319d..533a04db 100644 --- a/tests/test_conv2d.py +++ b/tests/test_conv2d.py @@ -1,6 +1,5 @@ import torch import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -36,23 +35,18 @@ def custom_conv2d(a, b, bias): print("Max diff > ", torch.max(torch.abs(res.cpu() - out))) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") torch._dynamo.config.cache_size_limit = 64 - test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) - test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) - test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) - test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) - test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0) + with torch.no_grad(): + test_conv2d(device, batch_size=8, in_channels=3, out_channels=32, input_size=32, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=2, padding=3) + test_conv2d(device, batch_size=2, in_channels=3, out_channels=64, input_size=32//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=4, in_channels=3, out_channels=64, input_size=64//2, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=256, input_size=13, kernel_size=5, stride=1, padding=2) + test_conv2d(device, batch_size=2, in_channels=128, out_channels=512, input_size=14, kernel_size=7, stride=1, padding=3) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=7, kernel_size=3, stride=2, padding=1) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=2, kernel_size=1, stride=1, padding=0) + test_conv2d(device, batch_size=1, in_channels=128, out_channels=256, input_size=14, kernel_size=1, stride=2, padding=0) + test_conv2d(device, batch_size=1, in_channels=3, out_channels=768, input_size=224, kernel_size=16,stride=16, padding=0) diff --git a/tests/test_eager.py b/tests/test_eager.py new file mode 100644 index 00000000..7a2df6e2 --- /dev/null +++ b/tests/test_eager.py @@ -0,0 +1,8 @@ +import torch + +if __name__ == "__main__": + device = torch.device("npu:0") + x = torch.zeros(10, 10).to(device) + y = torch.zeros(10, 10).to(device) + z = x + y + print(z.cpu()) \ No newline at end of file diff --git a/tests/test_exponent.py b/tests/test_exponent.py index e60f8407..20f0a143 100644 --- a/tests/test_exponent.py +++ b/tests/test_exponent.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -27,11 +25,5 @@ def exponent(a): test_result("exponent", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_exponent(device, size=(32, 32)) diff --git a/tests/test_gqa.py b/tests/test_gqa.py new file mode 100644 index 00000000..ba262fa6 --- /dev/null +++ b/tests/test_gqa.py @@ -0,0 +1,333 @@ +import sys +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch._dynamo +import argparse + + +def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): + if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): + message = f"|{name} Test Passed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + else: + message = f"|{name} Test Failed|" + print("-" * len(message)) + print(message) + print("-" * len(message)) + print("custom out: ", out.cpu()) + print("cpu out: ", cpu_out) + exit(1) + + +class GQAMultiheadAttention(nn.Module): + """ + Grouped Query Attention (GQA) implementation. + Query has num_heads, but key/value have num_kv_heads (num_kv_heads < num_heads). + """ + def __init__(self, embed_dim, num_heads, num_kv_heads=None, head_dim=None, bias=True, dropout=0.0): + super().__init__() + assert embed_dim % num_heads == 0 + if head_dim is None: + head_dim = embed_dim // num_heads + assert embed_dim == num_heads * head_dim + + # If num_kv_heads is not specified, use num_heads (standard MHA) + if num_kv_heads is None: + num_kv_heads = num_heads + + assert num_kv_heads <= num_heads + assert embed_dim % num_kv_heads == 0 + + self.embed_dim = embed_dim + self.num_heads = num_heads + self.num_kv_heads = num_kv_heads + self.head_dim = head_dim + self.dropout = dropout + + # QKV projection: Q has embed_dim, K and V have kv_embed_dim each + kv_embed_dim = num_kv_heads * head_dim + total_qkv_dim = embed_dim + 2 * kv_embed_dim + + self.qkv_proj = nn.Linear(embed_dim, total_qkv_dim, bias=bias) + self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias) + + def forward(self, query, key=None, value=None, attn_mask=None, need_weights=False): + """ + Args: + query: [batch, seq_len, embed_dim] or [seq_len, batch, embed_dim] + key: optional, same shape as query + value: optional, same shape as query + attn_mask: optional attention mask + need_weights: whether to return attention weights + """ + # For compatibility with nn.MultiheadAttention API + if key is None: + key = query + if value is None: + value = query + + # Handle batch_first vs batch_second + if query.dim() == 3: + batch_first = True + batch_size, seq_len, _ = query.shape + else: + batch_first = False + seq_len, batch_size, _ = query.shape + query = query.transpose(0, 1) + key = key.transpose(0, 1) + value = value.transpose(0, 1) + + # Project QKV + # Use query for QKV projection (standard MHA/GQA pattern) + qkv = self.qkv_proj(query) # [batch, seq_len, total_qkv_dim] + + # Split into Q, K, V + kv_embed_dim = self.num_kv_heads * self.head_dim + q = qkv[:, :, :self.embed_dim] # [batch, seq_len, embed_dim] + k = qkv[:, :, self.embed_dim:self.embed_dim + kv_embed_dim] # [batch, seq_len, kv_embed_dim] + v = qkv[:, :, self.embed_dim + kv_embed_dim:] # [batch, seq_len, kv_embed_dim] + + # Reshape to multi-head format + q = q.view(batch_size, seq_len, self.num_heads, self.head_dim) # [batch, seq_len, num_heads, head_dim] + k = k.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) # [batch, seq_len, num_kv_heads, head_dim] + v = v.view(batch_size, seq_len, self.num_kv_heads, self.head_dim) # [batch, seq_len, num_kv_heads, head_dim] + + # Transpose for attention: [batch, num_heads, seq_len, head_dim] + q = q.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] + k = k.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] + v = v.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] + + # Scaled dot product attention with GQA support + # enable_gqa=True allows different number of heads for Q vs K/V + attn_output = F.scaled_dot_product_attention( + q, k, v, + attn_mask=attn_mask, + dropout_p=self.dropout if self.training else 0.0, + is_causal=False, + enable_gqa=(self.num_kv_heads < self.num_heads) + ) # [batch, num_heads, seq_len, head_dim] + + # Reshape back: [batch, num_heads, seq_len, head_dim] -> [batch, seq_len, embed_dim] + attn_output = attn_output.transpose(1, 2) # [batch, seq_len, num_heads, head_dim] + attn_output = attn_output.contiguous().view(batch_size, seq_len, self.embed_dim) + + # Output projection + output = self.out_proj(attn_output) # [batch, seq_len, embed_dim] + + if not batch_first: + output = output.transpose(0, 1) # [seq_len, batch, embed_dim] + + if need_weights: + # Compute attention weights for return + # This is simplified - in practice you'd want the actual attention weights + attn_weights = None + return output, attn_weights + else: + return output + + +def test_gqa_attention(device, batch=1, seq_len=32, embed_dim=768, num_heads=12, num_kv_heads=4): + """ + Test Grouped Query Attention (GQA) where num_kv_heads < num_heads. + + Args: + device: target device + batch: batch size + seq_len: sequence length + embed_dim: embedding dimension + num_heads: number of query heads + num_kv_heads: number of key/value heads (should be <= num_heads) + """ + print(f"Testing GQA Attention (batch={batch}, seq_len={seq_len}, embed_dim={embed_dim}, " + f"num_heads={num_heads}, num_kv_heads={num_kv_heads})") + + # Create GQA model + gqa = GQAMultiheadAttention( + embed_dim=embed_dim, + num_heads=num_heads, + num_kv_heads=num_kv_heads, + bias=True, + dropout=0.0 + ).eval() + + # Initialize weights + torch.nn.init.normal_(gqa.qkv_proj.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.qkv_proj.bias, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.out_proj.weight, mean=0.0, std=0.02) + torch.nn.init.normal_(gqa.out_proj.bias, mean=0.0, std=0.02) + + # Create input + x = torch.randn(batch, seq_len, embed_dim) + query = x.clone() + key = x.clone() + value = x.clone() + + # Run on custom device + gqa_device = gqa.to(device) + q1, k1, v1 = query.to(device), key.to(device), value.to(device) + + compiled_gqa = torch.compile(gqa_device, dynamic=False) + with torch.no_grad(): + out_device = compiled_gqa(q1, k1, v1) + + # Run on CPU + gqa_cpu = gqa.cpu() + q2, k2, v2 = query.cpu(), key.cpu(), value.cpu() + with torch.no_grad(): + out_cpu = gqa_cpu(q2, k2, v2) + + test_result("GQA Attention", out_device, out_cpu) + print("Max diff > ", torch.max(torch.abs(out_device.cpu() - out_cpu))) + print("GQA Attention Simulation Done") + + +def test_standard_mha_via_gqa(device, batch=1, seq_len=32, embed_dim=768, num_heads=12): + """ + Test standard Multi-Head Attention using GQA with num_kv_heads == num_heads. + This should behave the same as standard MHA. + """ + print(f"Testing Standard MHA via GQA (batch={batch}, seq_len={seq_len}, " + f"embed_dim={embed_dim}, num_heads={num_heads})") + + test_gqa_attention(device, batch, seq_len, embed_dim, num_heads, num_kv_heads=num_heads) + + +def test_repeat_interleave_compilation(device, batch=1, seq_len=32, embed_dim=768, num_heads=12, num_kv_heads=4): + """ + Test that repeat_interleave operation compiles and works correctly using scaled_dot_product_attention implementation. + + This test uses the exact implementation from F.scaled_dot_product_attention to verify + that repeat_interleave works correctly when enable_gqa=True. + + Args: + device: target device + batch: batch size + seq_len: sequence length + embed_dim: embedding dimension + num_heads: number of query heads + num_kv_heads: number of key/value heads (should be < num_heads) + """ + import math + + print(f"Testing repeat_interleave compilation using scaled_dot_product_attention implementation " + f"(batch={batch}, seq_len={seq_len}, embed_dim={embed_dim}, " + f"num_heads={num_heads}, num_kv_heads={num_kv_heads})") + + head_dim = embed_dim // num_heads + assert num_kv_heads < num_heads, "num_kv_heads must be less than num_heads for GQA" + + # Create Q, K, V tensors + # Q: [batch, num_heads, seq_len, head_dim] + # K, V: [batch, num_kv_heads, seq_len, head_dim] + q = torch.randn(batch, num_heads, seq_len, head_dim) + k = torch.randn(batch, num_kv_heads, seq_len, head_dim) + v = torch.randn(batch, num_kv_heads, seq_len, head_dim) + + # Move to device + q_device = q.to(device) + k_device = k.to(device) + v_device = v.to(device) + + # Implementation from F.scaled_dot_product_attention + def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor: + L, S = query.size(-2), key.size(-2) + scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale + attn_bias = torch.zeros(L, S, dtype=query.dtype, device=query.device) + if is_causal: + assert attn_mask is None + temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0) + attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf")) + attn_bias.to(query.dtype) + + if attn_mask is not None: + if attn_mask.dtype == torch.bool: + attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf")) + else: + attn_bias = attn_mask + attn_bias + + if enable_gqa: + key = key.repeat_interleave(query.size(-3)//key.size(-3), -3) + value = value.repeat_interleave(query.size(-3)//value.size(-3), -3) + + attn_weight = query @ key.transpose(-2, -1) * scale_factor + attn_weight += attn_bias + attn_weight = torch.softmax(attn_weight, dim=-1) + return attn_weight, value, attn_weight @ value + + # Compile the function + compiled_attn = torch.compile(scaled_dot_product_attention, dynamic=False) + + # Run on custom device with enable_gqa=True + with torch.no_grad(): + output_device = compiled_attn(q_device, k_device, v_device, + attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=True) + + # Run on CPU for comparison + q_cpu = q.cpu() + k_cpu = k.cpu() + v_cpu = v.cpu() + with torch.no_grad(): + output_cpu = scaled_dot_product_attention(q_cpu, k_cpu, v_cpu, + attn_mask=None, dropout_p=0.0, + is_causal=False, scale=None, enable_gqa=True) + + # Compare results + test_result("repeat_interleave in scaled_dot_product_attention", output_device[0], output_cpu[0]) + print("Max diff > ", torch.max(torch.abs(output_device[0].cpu() - output_cpu[0]))) + test_result("repeat_interleave in scaled_dot_product_attention", output_device[1], output_cpu[1]) + print("Max diff > ", torch.max(torch.abs(output_device[1].cpu() - output_cpu[1]))) + test_result("repeat_interleave in scaled_dot_product_attention", output_device[2], output_cpu[2]) + print("Max diff > ", torch.max(torch.abs(output_device[2].cpu() - output_cpu[2]))) + print("repeat_interleave compilation test Done") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--device", type=str, default="npu", help="Device to use") + parser.add_argument("--batch", type=int, default=1, help="Batch size") + parser.add_argument("--seq_len", type=int, default=32, help="Sequence length") + parser.add_argument("--embed_dim", type=int, default=768, help="Embedding dimension") + parser.add_argument("--num_heads", type=int, default=8, help="Number of query heads") + parser.add_argument("--num_kv_heads", type=int, default=4, help="Number of key/value heads") + parser.add_argument("--test_standard", action="store_true", help="Also test standard MHA via GQA") + parser.add_argument("--test_repeat_interleave", action="store_true", help="Test repeat_interleave compilation") + + args = parser.parse_args() + + sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) + device = torch.device("npu:0") + + test_repeat_interleave_compilation( + device=device, + batch=args.batch, + seq_len=args.seq_len, + embed_dim=args.embed_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads + ) + + # Test GQA + test_gqa_attention( + device=device, + batch=args.batch, + seq_len=args.seq_len, + embed_dim=args.embed_dim, + num_heads=args.num_heads, + num_kv_heads=args.num_kv_heads + ) + + # Optionally test standard MHA via GQA + # if args.test_standard: + # test_standard_mha_via_gqa( + # device=args.device, + # batch=args.batch, + # seq_len=args.seq_len, + # embed_dim=args.embed_dim, + # num_heads=args.num_heads + # ) diff --git a/tests/test_indirect_access.py b/tests/test_indirect_access.py index d103ee1b..95167d1e 100644 --- a/tests/test_indirect_access.py +++ b/tests/test_indirect_access.py @@ -1,7 +1,5 @@ import torch import copy -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -79,13 +77,7 @@ def vectoradd(a, idx, b): test_result("Indirect VectorAdd", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_scatter_full(device) test_scatter_full(device, size=(2048, 2048)) test_scatter_add(device) diff --git a/tests/test_layernorm.py b/tests/test_layernorm.py index 28e38d37..3db27dc5 100644 --- a/tests/test_layernorm.py +++ b/tests/test_layernorm.py @@ -31,18 +31,14 @@ def test_LayerNorm(device, size=(64, 64)): test_result("LayerNorm Forward", y, cpu_y) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, help="Shape of the tensor in the format (batch_size, features)", default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() - #test_LayerNorm(device) - test_LayerNorm(device, shape) + device = torch.device("npu:0") + with torch.no_grad(): + #test_LayerNorm(device) + test_LayerNorm(device, shape) diff --git a/tests/test_matmul.py b/tests/test_matmul.py index cd30bd30..a5bdf422 100644 --- a/tests/test_matmul.py +++ b/tests/test_matmul.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -90,13 +88,7 @@ def custom_linear(a, b, bias): test_result("Linear Forward", res, y) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_matmul(device, 32, 32, 32) test_matmul(device, 128, 128, 128) test_matmul(device, 256, 256, 256) diff --git a/tests/test_mlp.py b/tests/test_mlp.py index 423d6e8e..e3f79561 100644 --- a/tests/test_mlp.py +++ b/tests/test_mlp.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -105,13 +103,7 @@ def test_optimizer(device): test_result("Optimizer", model.linear1.weight, cpu_model.linear1.weight) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_mlp(device) test_mlp_inf(device, batch_size=1, input_size=256, hidden_size=512, output_size=256) test_mlp_inf(device, batch_size=8, input_size=256, hidden_size=512, output_size=256) diff --git a/tests/test_pool.py b/tests/test_pool.py index f5505dba..2848e04b 100644 --- a/tests/test_pool.py +++ b/tests/test_pool.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -43,13 +41,7 @@ def avgpool(a): test_result("Avgpool Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_maxpool(device, b=1, c=8, h=16, w=16) #test_maxpool(device, b=1, c=8, h=112, w=112) test_avgpool(device, b=1, c=512, h=7, w=7) diff --git a/tests/test_reduce.py b/tests/test_reduce.py index 4781112d..07f8fef2 100644 --- a/tests/test_reduce.py +++ b/tests/test_reduce.py @@ -37,19 +37,14 @@ def reduce_sum(a, dim, keepdim): test_result("ReduceMax", res, out) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(128,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_reduce_sum(device, (29, 47), 1, keepdim=True) test_reduce_sum(device, (17, 68), 0, keepdim=True) test_reduce_sum(device, (327, 447), 1, keepdim=True) diff --git a/tests/test_resnet.py b/tests/test_resnet.py index c83f13ba..2459cd58 100644 --- a/tests/test_resnet.py +++ b/tests/test_resnet.py @@ -49,7 +49,5 @@ def test_resnet(device, batch=1, model_type='resnet18'): args = args.parse_args() sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_resnet(device, model_type=args.model_type) diff --git a/tests/test_single_perceptron.py b/tests/test_single_perceptron.py index beab1c54..7d3401a3 100644 --- a/tests/test_single_perceptron.py +++ b/tests/test_single_perceptron.py @@ -1,7 +1,5 @@ import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -78,11 +76,5 @@ def weight_update(a, b, lr): # plt.savefig('result.png') if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_single_perceptron(device) diff --git a/tests/test_softmax.py b/tests/test_softmax.py index e6e8cc1e..2dca97b7 100644 --- a/tests/test_softmax.py +++ b/tests/test_softmax.py @@ -42,25 +42,29 @@ def test_softmax(device, size=(128, 128), dim=1): #cpu_y = softmax3(x2, cpu_max, cpu_sum) #test_result("Softmax", y, cpu_y) - opt_fn = torch.compile(dynamic=False)(torch.nn.functional.softmax) - y = opt_fn(x1, dim=dim) + class SoftmaxModule(torch.nn.Module): + def __init__(self, dim): + super().__init__() + self.dim = dim + + def forward(self, x): + return torch.nn.functional.softmax(x, dim=self.dim) + + softmax_module = SoftmaxModule(dim=dim).to(device) + opt_fn = torch.compile(dynamic=False)(softmax_module) + y = opt_fn(x1) cpu_y = torch.nn.functional.softmax(x2, dim=dim) test_result("Softmax", y, cpu_y) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, help="Shape of the tensor in the format (batch_size, features)", default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_softmax(device, size=(64, 128)) test_softmax(device, size=(64, 128), dim=0) test_softmax(device, size=(256, 128)) diff --git a/tests/test_sparsity.py b/tests/test_sparsity.py index a2493673..eaa7c63c 100644 --- a/tests/test_sparsity.py +++ b/tests/test_sparsity.py @@ -96,9 +96,7 @@ def test_mlp_inf(device, batch_size=64, input_size=64, hidden_size=32, output_si ) args = parser.parse_args() - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_dec_inf(device, sparsity=args.sparsity, block=args.block) test_mlp_inf(device, batch_size=32, input_size=784, hidden_size=512, output_size=256, sparsity=args.sparsity, block=args.block) diff --git a/tests/test_stonne.py b/tests/test_stonne.py index 04ad05a8..ac26c273 100644 --- a/tests/test_stonne.py +++ b/tests/test_stonne.py @@ -54,7 +54,5 @@ def test_sparse_mm(device, input_size=128, hidden_size=128, output_size=128, spa args = parser.parse_args() sys.path.append(os.environ.get('TORCHSIM_DIR', default='/root/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_sparse_mm(device, args.sz, args.sz, args.sz, args.sparsity) \ No newline at end of file diff --git a/tests/test_topk.py b/tests/test_topk.py index 0d5c08ec..c8565310 100644 --- a/tests/test_topk.py +++ b/tests/test_topk.py @@ -38,10 +38,7 @@ def topk_fn(a): test_result("TopK/indices", res_indices, ref_indices) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(512,768)") diff --git a/tests/test_transcendental.py b/tests/test_transcendental.py index 38c2f4f6..34546539 100644 --- a/tests/test_transcendental.py +++ b/tests/test_transcendental.py @@ -63,19 +63,14 @@ def cos(a): test_result("Cos", res, out) if __name__ == "__main__": - import os - import sys import argparse - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) parser = argparse.ArgumentParser(description="Run LayerNorm test with dynamic shape") parser.add_argument('--shape', type=str, default="(512,768)") args = parser.parse_args() shape = tuple(map(int, args.shape.strip('()').split(','))) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_tanh(device) test_exp(device) test_erf(device) diff --git a/tests/test_transformer.py b/tests/test_transformer.py index a3ac55d7..2b7f308c 100644 --- a/tests/test_transformer.py +++ b/tests/test_transformer.py @@ -1,8 +1,6 @@ import math import copy import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -115,13 +113,7 @@ def test_MHA(device, num_heads=12, embed_dim=768, input_seq=512): test_result("MHA Forward", res, cpu_res) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_EncoderBlock(device) # test_Attention(device, head=16, seq=512, d_k=64) # test_MHA(device, num_heads=12, embed_dim=768) diff --git a/tests/test_transpose2D.py b/tests/test_transpose2D.py index af5aacf7..4e9807ce 100644 --- a/tests/test_transpose2D.py +++ b/tests/test_transpose2D.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -42,13 +40,7 @@ def transpose(a, b): test_result("Transpose2 Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_Transpose2D(device, [64, 156]) test_Transpose2D_2(device, [16, 64]) test_Transpose2D(device, [640, 256]) diff --git a/tests/test_transpose3D.py b/tests/test_transpose3D.py index d6c1092d..e4d4e952 100644 --- a/tests/test_transpose3D.py +++ b/tests/test_transpose3D.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -57,13 +55,7 @@ def transpose(a, b): test_result("Transpose 3D Forward", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_Transpose3D_1(device, [62, 34, 44]) test_Transpose3D_1(device, [62, 134, 144]) test_Transpose3D_2(device, [62, 34, 44]) diff --git a/tests/test_vectorops.py b/tests/test_vectorops.py index ed895171..90e9c0f5 100644 --- a/tests/test_vectorops.py +++ b/tests/test_vectorops.py @@ -1,14 +1,7 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") # Target shape seq_list = [1,128,512,2048,8192] diff --git a/tests/test_view3D_2D.py b/tests/test_view3D_2D.py index 148fe8fa..cc7b5e41 100644 --- a/tests/test_view3D_2D.py +++ b/tests/test_view3D_2D.py @@ -1,6 +1,4 @@ import torch -import torch._dynamo -import torch.utils.cpp_extension def test_result(name, out, cpu_out, rtol=1e-4, atol=1e-4): if torch.allclose(out.cpu(), cpu_out, rtol=rtol, atol=atol): @@ -40,13 +38,7 @@ def view2D_3D(a): test_result("view 2D->3D", res, out) if __name__ == "__main__": - import os - import sys - sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") test_view3D_2D(device) test_view3D_2D(device, [12, 512, 64]) test_view2D_3D(device, size=(512, 1024), h=16, d_k=64) diff --git a/tests/test_vit.py b/tests/test_vit.py index aeb4f148..6149166d 100644 --- a/tests/test_vit.py +++ b/tests/test_vit.py @@ -202,9 +202,7 @@ def test_encoder_block_with_class_token( shape = tuple(map(int, args.shape.strip('()').split(','))) sys.path.append(os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')) - from Scheduler.scheduler import PyTorchSimRunner - module = PyTorchSimRunner.setup_device() - device = module.custom_device() + device = torch.device("npu:0") #test_multihead_attention(device) #test_encoder_block(device, seq_len=197) #test_encoder_block_with_class_token(device, seq_len=196) diff --git a/tutorial/session2/Hands_on.ipynb b/tutorial/session2/Hands_on.ipynb index 2d5a5cdc..2964f293 100644 --- a/tutorial/session2/Hands_on.ipynb +++ b/tutorial/session2/Hands_on.ipynb @@ -32,7 +32,7 @@ "import torch._dynamo\n", "import torch.utils.cpp_extension\n", "base_dir = os.environ.get('TORCHSIM_DIR', default='/workspace/PyTorchSim')\n", - "os.environ['TORCHSIM_DUMP_PATH']=os.path.join(os.getcwd(), \"togsim_results\")\n", + "os.environ['TORCHSIM_LOG_PATH']=os.path.join(os.getcwd(), \"togsim_results\")\n", "sys.path.append(base_dir)\n", "\n", "from Scheduler.scheduler import PyTorchSimRunner\n",