From 0a54aef6bda850949df53b04f3ef046415f4bc9f Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 19 Jan 2026 14:07:28 +0530 Subject: [PATCH 1/7] [VULKAN][ADRENO] Vulkan support for Adreno GPU in line with OpenCL support This PR enables the Vulkan backend for Adreno GPUs alongside OpenCL. With this, Adreno GPUs can be used with the Vulkan backend, achieving performance similar to OpenCL. This is the foundation for adding Vulkan-specific extensions in the near future, such as cooperative matmul, etc. The common Relax and TIR pipeline is used here, with only the codegen and runtime being different. --- CMakeLists.txt | 4 + cmake/modules/GTestConfig.cmake | 52 ++ cmake/modules/LibInfo.cmake | 1 + cmake/modules/OpenCL.cmake | 38 +- cmake/modules/Vulkan.cmake | 10 + cmake/utils/FindVulkan.cmake | 55 +- python/tvm/relax/pipeline.py | 2 +- python/tvm/testing/utils.py | 19 +- src/runtime/file_utils.cc | 4 + src/runtime/meta_data.h | 1 + src/runtime/vulkan/vulkan_buffer.cc | 59 +- src/runtime/vulkan/vulkan_buffer.h | 23 +- src/runtime/vulkan/vulkan_device.cc | 8 + src/runtime/vulkan/vulkan_device.h | 18 +- src/runtime/vulkan/vulkan_device_api.cc | 533 ++++++++++++++++-- src/runtime/vulkan/vulkan_device_api.h | 81 ++- src/runtime/vulkan/vulkan_image.cc | 167 ++++++ src/runtime/vulkan/vulkan_image.h | 133 +++++ src/runtime/vulkan/vulkan_resource.cc | 88 +++ src/runtime/vulkan/vulkan_resource.h | 131 +++++ src/runtime/vulkan/vulkan_stream.h | 2 + src/runtime/vulkan/vulkan_timer.cc | 99 ++++ src/runtime/vulkan/vulkan_timer.h | 102 ++++ src/runtime/vulkan/vulkan_wrapped_func.cc | 143 +++-- src/support/libinfo.cc | 5 + src/target/build_common.h | 8 + src/target/spirv/build_vulkan.cc | 18 + src/target/spirv/codegen_spirv.cc | 86 ++- src/target/spirv/ir_builder.cc | 130 +++++ src/target/spirv/ir_builder.h | 48 +- src/target/target_kind.cc | 7 +- tests/cpp-runtime/vulkan/texture_copy_test.cc | 158 ++++++ tests/cpp-runtime/vulkan/vulkan_timer_test.cc | 63 +++ tests/python/relax/texture/test_ops.py | 112 ++--
tests/scripts/ci.py | 1 + tests/scripts/task_config_build_adreno.sh | 3 + tests/scripts/task_config_build_gpu.sh | 1 + tests/scripts/task_python_adreno.sh | 5 +- tests/scripts/task_vulkan_cpp_unittest.sh | 38 ++ 39 files changed, 2201 insertions(+), 255 deletions(-) create mode 100644 cmake/modules/GTestConfig.cmake create mode 100644 src/runtime/vulkan/vulkan_image.cc create mode 100644 src/runtime/vulkan/vulkan_image.h create mode 100644 src/runtime/vulkan/vulkan_resource.cc create mode 100644 src/runtime/vulkan/vulkan_resource.h create mode 100644 src/runtime/vulkan/vulkan_timer.cc create mode 100644 src/runtime/vulkan/vulkan_timer.h create mode 100644 tests/cpp-runtime/vulkan/texture_copy_test.cc create mode 100644 tests/cpp-runtime/vulkan/vulkan_timer_test.cc create mode 100644 tests/scripts/task_vulkan_cpp_unittest.sh diff --git a/CMakeLists.txt b/CMakeLists.txt index ec7bd6c51453..87b1d43748ee 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -33,6 +33,7 @@ tvm_option(USE_OPENCL "Build with OpenCL" OFF) tvm_option(USE_OPENCL_ENABLE_HOST_PTR "Enable OpenCL memory object access to host" OFF) tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest) tvm_option(USE_VULKAN "Build with Vulkan" OFF) +tvm_option(USE_VULKAN_GTEST "Path to Vulkan specific gtest version for runtime cpp tests." /path/to/vulkan/gtest) # Whether to use spirv-tools.and SPIRV-Headers from Khronos github or gitlab. 
@@ -454,6 +455,9 @@ set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD_REQUIRED ON) set(CMAKE_CUDA_STANDARD 17) +#include centralized gtest setup +include(cmake/modules/GTestConfig.cmake) + # Module rules include(cmake/modules/CUDA.cmake) include(cmake/modules/Hexagon.cmake) # This must come before logging.cmake diff --git a/cmake/modules/GTestConfig.cmake b/cmake/modules/GTestConfig.cmake new file mode 100644 index 000000000000..beb0aa12a9b2 --- /dev/null +++ b/cmake/modules/GTestConfig.cmake @@ -0,0 +1,52 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +set(Build_GTests OFF) +if(NOT TARGET gtest) + unset(runtime_gtests) + if(DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST}) + set(runtime_gtests ${USE_OPENCL_GTEST}) + elseif(DEFINED USE_VULKAN_GTEST AND EXISTS ${USE_VULKAN_GTEST}) + set(runtime_gtests ${USE_VULKAN_GTEST}) + elseif(ANDROID_ABI AND DEFINED ENV{ANDROID_NDK_HOME}) + set(GOOGLETEST_ROOT $ENV{ANDROID_NDK_HOME}/sources/third_party/googletest) + add_library(gtest_main STATIC + ${GOOGLETEST_ROOT}/src/gtest_main.cc + ${GOOGLETEST_ROOT}/src/gtest-all.cc) + target_include_directories(gtest_main PRIVATE ${GOOGLETEST_ROOT}) + target_include_directories(gtest_main PUBLIC ${GOOGLETEST_ROOT}/include) + set(Build_GTests ON) + message(STATUS "Using gtest from Android NDK") + return() + else() + message(STATUS "No valid GTest path found, skipping GTest configuration") + return() + endif() + + # Configure if runtime_gtests is valid + if(runtime_gtests AND EXISTS ${runtime_gtests}) + include(FetchContent) + FetchContent_Declare(googletest SOURCE_DIR "${runtime_gtests}") + set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) + FetchContent_MakeAvailable(googletest) + install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) + set(Build_GTests ON) + else() + set(Build_GTests OFF) + return() + endif() +endif() diff --git a/cmake/modules/LibInfo.cmake b/cmake/modules/LibInfo.cmake index f286d9f7d9fa..0a83a59fbf05 100644 --- a/cmake/modules/LibInfo.cmake +++ b/cmake/modules/LibInfo.cmake @@ -120,6 +120,7 @@ function(add_lib_info src_file) TVM_INFO_USE_THRUST="${USE_THRUST}" TVM_INFO_USE_CURAND="${USE_CURAND}" TVM_INFO_USE_VULKAN="${USE_VULKAN}" + TVM_INFO_USE_VULKAN_GTEST="${USE_VULKAN_GTEST}" TVM_INFO_USE_CLML="${USE_CLML}" TVM_INFO_USE_CLML_GRAPH_EXECUTOR="${USE_CLML_GRAPH_EXECUTOR}" TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}" diff --git a/cmake/modules/OpenCL.cmake b/cmake/modules/OpenCL.cmake index c5c8eae721fa..32520f044e89 100644 --- a/cmake/modules/OpenCL.cmake +++ 
b/cmake/modules/OpenCL.cmake @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - if(USE_OPENCL) tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc) list(APPEND COMPILER_SRCS src/target/spirv/spirv_utils.cc) @@ -35,36 +34,15 @@ if(USE_OPENCL) list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES}) endif() - if(DEFINED USE_OPENCL_GTEST) - if(EXISTS ${USE_OPENCL_GTEST}) - include(FetchContent) - FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}") - set(gtest_force_shared_crt ON CACHE BOOL "" FORCE) - FetchContent_MakeAvailable(googletest) - install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX}) - - message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}") - set(Build_OpenCL_GTests ON) - elseif (ANDROID_ABI AND DEFINED ENV{ANDROID_NDK_HOME}) - set(GOOGLETEST_ROOT $ENV{ANDROID_NDK_HOME}/sources/third_party/googletest) - add_library(gtest_main STATIC ${GOOGLETEST_ROOT}/src/gtest_main.cc ${GOOGLETEST_ROOT}/src/gtest-all.cc) - target_include_directories(gtest_main PRIVATE ${GOOGLETEST_ROOT}) - target_include_directories(gtest_main PUBLIC ${GOOGLETEST_ROOT}/include) - message(STATUS "Using gtest from Android NDK") - set(Build_OpenCL_GTests ON) - endif() - - if(Build_OpenCL_GTests) - message(STATUS "Building OpenCL-Gtests") - tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS - "tests/cpp-runtime/opencl/*.cc" - ) - add_executable(opencl-cpptest ${OPENCL_TEST_SRCS}) - target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime ${OpenCL_LIBRARIES}) - else() - message(STATUS "Couldn't build OpenCL-Gtests") - endif() + if(Build_GTests) + message(STATUS "Building OpenCL GTests") + tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS "tests/cpp-runtime/opencl/*.cc") + add_executable(opencl-cpptest ${OPENCL_TEST_SRCS}) + target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime ${OpenCL_LIBRARIES}) + else() + 
message(STATUS "Couldn't build OpenCL-Gtests") endif() + list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS}) if(USE_OPENCL_ENABLE_HOST_PTR) add_definitions(-DOPENCL_ENABLE_HOST_PTR) diff --git a/cmake/modules/Vulkan.cmake b/cmake/modules/Vulkan.cmake index 1f303f3a032b..35994c6dc92b 100644 --- a/cmake/modules/Vulkan.cmake +++ b/cmake/modules/Vulkan.cmake @@ -30,6 +30,16 @@ if(USE_VULKAN) message(STATUS "Build with Vulkan support") tvm_file_glob(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc) tvm_file_glob(GLOB COMPILER_VULKAN_SRCS src/target/spirv/*.cc) + + if(Build_GTests) + message(STATUS "Building Vulkan GTests") + tvm_file_glob(GLOB_RECURSE VULKAN_TEST_SRCS "tests/cpp-runtime/vulkan/*.cc") + add_executable(vulkan-cpptest ${VULKAN_TEST_SRCS}) + target_link_libraries(vulkan-cpptest PRIVATE gtest_main tvm_runtime) + else() + message(STATUS "Couldn't build Vulkan-Gtests") + endif() + list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS}) list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS}) list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY}) diff --git a/cmake/utils/FindVulkan.cmake b/cmake/utils/FindVulkan.cmake index 032ff1dffa21..c01d019e14d8 100644 --- a/cmake/utils/FindVulkan.cmake +++ b/cmake/utils/FindVulkan.cmake @@ -36,7 +36,7 @@ macro(find_vulkan use_vulkan use_khronos_spirv) set(__use_vulkan ${use_vulkan}) if(IS_DIRECTORY ${__use_vulkan}) set(__vulkan_sdk ${__use_vulkan}) - message(STATUS "Custom Vulkan SDK PATH=" ${__use_vulkan}) + message(STATUS "Using custom Vulkan SDK: ${__vulkan_sdk}") elseif(IS_DIRECTORY $ENV{VULKAN_SDK}) set(__vulkan_sdk $ENV{VULKAN_SDK}) else() @@ -46,45 +46,65 @@ macro(find_vulkan use_vulkan use_khronos_spirv) if(IS_DIRECTORY ${use_khronos_spirv}) set(__use_khronos_spirv ${use_khronos_spirv}) - message(STATUS "Custom khronos spirv PATH=" ${__use_khronos_spirv}) + message(STATUS "Using custom Khronos SPIRV path: ${__use_khronos_spirv}") else() set(__use_khronos_spirv "") endif() if(CMAKE_SYSTEM_NAME STREQUAL "Android") - 
set(VULKAN_NDK_SRC ${CMAKE_ANDROID_NDK}/sources/third_party/vulkan/src) - set(Vulkan_INCLUDE_DIRS ${VULKAN_NDK_SRC}/include) - set(Vulkan_FOUND TRUE) - message(STATUS "Android Vulkan_INCLUDE_DIRS=" ${Vulkan_INCLUDE_DIRS}) - message(STATUS "Skip finding SPIRV in Android, make sure you only build tvm runtime.") - return() - endif() + message(STATUS "Detected Android build") + + set(Vulkan_INCLUDE_DIRS "${CMAKE_SYSROOT}/usr/include/vulkan") + + # Map Android ABI to architecture + set(ANDROID_LIB_ARCH "") + if(CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a") + set(ANDROID_LIB_ARCH "aarch64-linux-android") + elseif(CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a") + set(ANDROID_LIB_ARCH "arm-linux-androideabi") + elseif(CMAKE_ANDROID_ARCH_ABI STREQUAL "x86") + set(ANDROID_LIB_ARCH "i686-linux-android") + elseif(CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64") + set(ANDROID_LIB_ARCH "x86_64-linux-android") + else() + message(FATAL_ERROR "Unsupported Android ABI: ${CMAKE_ANDROID_ARCH_ABI}") + endif() + + # Find Vulkan library for Android + set(Vulkan_LIB_PATH "${CMAKE_SYSROOT}/usr/lib/${ANDROID_LIB_ARCH}/27") + find_library(Vulkan_LIBRARY NAMES vulkan libvulkan.so PATHS ${Vulkan_LIB_PATH} NO_DEFAULT_PATH) + + if(Vulkan_LIBRARY) + set(Vulkan_FOUND TRUE) + else() + message(FATAL_ERROR "Could not find Vulkan lib in ${Vulkan_LIB_PATH}") + endif() + + else() + message(STATUS "__vulkan_sdk:- " ${__vulkan_sdk}) if(__vulkan_sdk) set(Vulkan_INCLUDE_DIRS ${__vulkan_sdk}/include) find_library(Vulkan_LIBRARY NAMES vulkan vulkan-1 PATHS ${__vulkan_sdk}/lib) if(Vulkan_LIBRARY) set(Vulkan_FOUND TRUE) endif() - endif(__vulkan_sdk) + endif() - # resort to find vulkan of option is on - if(NOT Vulkan_FOUND) - if(${__use_vulkan} MATCHES ${IS_TRUE_PATTERN}) - find_package(Vulkan QUIET) - endif() + if(NOT Vulkan_FOUND AND ${use_vulkan} MATCHES ${IS_TRUE_PATTERN}) + find_package(Vulkan QUIET) endif() if(Vulkan_FOUND) get_filename_component(VULKAN_LIBRARY_PATH ${Vulkan_LIBRARY} DIRECTORY) if (WIN32) 
find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools - HINTS ${__use_khronos_spirv}/spirv-tools/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib) + HINTS ${__use_khronos_spirv}/spirv-tools/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${__vulkan_sdk}/lib) find_path(_libspirv libspirv.h HINTS ${__use_khronos_spirv}/spirv-tools/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools) find_path(_spirv spirv.hpp HINTS ${__use_khronos_spirv}/SPIRV-Headers/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers) else() find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools - HINTS ${__use_khronos_spirv}/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib) + HINTS ${__use_khronos_spirv}/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${__vulkan_sdk}/lib) find_path(_libspirv libspirv.h HINTS ${__use_khronos_spirv}/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools) find_path(_spirv spirv.hpp HINTS ${__use_khronos_spirv}/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers) endif() @@ -95,4 +115,5 @@ macro(find_vulkan use_vulkan use_khronos_spirv) message(STATUS "Vulkan_LIBRARY=" ${Vulkan_LIBRARY}) message(STATUS "Vulkan_SPIRV_TOOLS_LIBRARY=" ${Vulkan_SPIRV_TOOLS_LIBRARY}) endif(Vulkan_FOUND) + endif() endmacro(find_vulkan) diff --git a/python/tvm/relax/pipeline.py b/python/tvm/relax/pipeline.py index 1c25b2053bc2..79fde6c362f4 100644 --- a/python/tvm/relax/pipeline.py +++ b/python/tvm/relax/pipeline.py @@ -337,7 +337,7 @@ def get_default_pipeline(target: tvm.target.Target): return backend.gpu_generic.get_default_pipeline(target) if target.kind.name == "llvm": return backend.cpu_generic.get_default_pipeline(target) - if target.kind.name == "opencl" and "adreno" in target.keys: + if target.kind.name in ["opencl", "vulkan"] and "adreno" in target.keys: return backend.adreno.get_default_pipeline(target) if 
BackendDispatcher.is_gpu_target(target): return backend.gpu_generic.get_default_pipeline(target) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index 51fad1803ad9..b2d685a49e7e 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -838,6 +838,14 @@ def _multi_gpu_exists(): ) +def _check_opencl_vulkan(): + return ( + (_cmake_flag_enabled("USE_OPENCL") and tvm.opencl(0).exist) + or (_cmake_flag_enabled("USE_VULKAN") and tvm.vulkan(0).exist) + or "RPC_TARGET" in os.environ + ) + + # Mark a test as requiring llvm to run requires_llvm = Feature( "llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm" @@ -976,8 +984,8 @@ def _multi_gpu_exists(): "Vulkan", cmake_flag="USE_VULKAN", target_kind_enabled="vulkan", - target_kind_hardware="vulkan", - parent_features="gpu", + target_kind_hardware="vulkan" if "RPC_TARGET" not in os.environ else None, + parent_features="gpu" if "RPC_TARGET" not in os.environ else None, ) # Mark a test as requiring OpenCLML support in build. @@ -988,6 +996,13 @@ def _multi_gpu_exists(): target_kind_enabled="opencl", ) +requires_opencl_vulkan = Feature( + "opencl_vulkan", + "OpenCL or Vulkan", + run_time_check=_check_opencl_vulkan, + parent_features=["opencl", "vulkan"], +) + # Mark a test as requiring NNAPI support in build. 
requires_nnapi = Feature( "NNAPI", diff --git a/src/runtime/file_utils.cc b/src/runtime/file_utils.cc index b3733ee6fdff..d64c95df83be 100644 --- a/src/runtime/file_utils.cc +++ b/src/runtime/file_utils.cc @@ -45,6 +45,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const { writer->BeginObject(); writer->WriteObjectKeyValue("name", name); writer->WriteObjectKeyValue("arg_types", sarg_types); + writer->WriteObjectKeyValue("storage_scopes", storage_scopes); writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags); std::vector iarg_extra_tags(arg_extra_tags.size()); for (size_t i = 0; i < arg_extra_tags.size(); ++i) { @@ -59,6 +60,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { std::vector sarg_types; helper.DeclareField("name", &name); helper.DeclareField("arg_types", &sarg_types); + helper.DeclareOptionalField("storage_scopes", &storage_scopes); helper.DeclareOptionalField("launch_param_tags", &launch_param_tags); helper.DeclareOptionalField("thread_axis_tags", &launch_param_tags); // for backward compatibility @@ -78,6 +80,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) { void FunctionInfo::Save(dmlc::Stream* writer) const { writer->Write(name); writer->Write(arg_types); + writer->Write(storage_scopes); writer->Write(launch_param_tags); writer->Write(arg_extra_tags); } @@ -85,6 +88,7 @@ void FunctionInfo::Save(dmlc::Stream* writer) const { bool FunctionInfo::Load(dmlc::Stream* reader) { if (!reader->Read(&name)) return false; if (!reader->Read(&arg_types)) return false; + if (!reader->Read(&storage_scopes)) return false; if (!reader->Read(&launch_param_tags)) return false; if (!reader->Read(&arg_extra_tags)) return false; return true; diff --git a/src/runtime/meta_data.h b/src/runtime/meta_data.h index aceb97b58374..61e4fde31a6a 100644 --- a/src/runtime/meta_data.h +++ b/src/runtime/meta_data.h @@ -59,6 +59,7 @@ constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch"; struct FunctionInfo { std::string name; 
std::vector arg_types; + std::vector storage_scopes; std::vector launch_param_tags; enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 }; diff --git a/src/runtime/vulkan/vulkan_buffer.cc b/src/runtime/vulkan/vulkan_buffer.cc index f8d40b030919..646e03e441ad 100644 --- a/src/runtime/vulkan/vulkan_buffer.cc +++ b/src/runtime/vulkan/vulkan_buffer.cc @@ -22,6 +22,7 @@ #include #include "vulkan_device_api.h" +#include "vulkan_resource.h" namespace tvm { namespace runtime { @@ -29,6 +30,7 @@ namespace vulkan { VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) { VkBufferCreateInfo info = {VK_STRUCTURE_TYPE_BUFFER_CREATE_INFO}; + info.size = nbytes; // Since sharingMode is not VK_SHARING_MODE_CONCURRENT, no need to // specify the queue families. @@ -38,46 +40,48 @@ VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage) } VulkanBuffer::VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index) - : device_(device) { - // Create a buffer + uint32_t mem_type_index, std::optional mem_scope, + std::shared_ptr back_memory) + : VulkanResource(device, mem_scope, back_memory), size(nbytes) { VkBufferCreateInfo buffer_info = MakeBufferCreateInfo(nbytes, usage); VULKAN_CALL(vkCreateBuffer(device, &buffer_info, nullptr, &buffer)); - // Allocate memory - VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO}; - mem_info.allocationSize = buffer_info.size; - mem_info.memoryTypeIndex = mem_type_index; - - VkMemoryDedicatedAllocateInfoKHR dedicated_info = { - VK_STRUCTURE_TYPE_MEMORY_DEDICATED_ALLOCATE_INFO_KHR}; + VkMemoryRequirements mem_reqs; + vkGetBufferMemoryRequirements(device, buffer, &mem_reqs); - bool use_dedicated_allocation = UseDedicatedAllocation(device, buffer, &mem_info.allocationSize); - if (use_dedicated_allocation) { - dedicated_info.buffer = buffer; - mem_info.pNext = &dedicated_info; + // Allocate new memory if no memory is passed in or if 
the existing memory is not compatible + if (!memory) { + AllocateMemory(mem_reqs, mem_type_index); } - VULKAN_CALL(vkAllocateMemory(device, &mem_info, nullptr, &memory)); - // Bind the buffer to the allocated memory - VULKAN_CALL(vkBindBufferMemory(device, buffer, memory, 0)); + VULKAN_CALL(vkBindBufferMemory(device, buffer, memory->memory_, 0)); +} + +void VulkanBuffer::AllocateMemory(const VkMemoryRequirements& mem_reqs, uint32_t mem_type_index) { + VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO}; + mem_info.allocationSize = mem_reqs.size; + mem_info.memoryTypeIndex = mem_type_index; + + // Allocate memory + VkDeviceMemory raw_memory; + VULKAN_CALL(vkAllocateMemory(device_, &mem_info, nullptr, &raw_memory)); + + // Store the allocated memory along with its requirements + memory = std::make_shared(raw_memory, mem_reqs); } VulkanBuffer::~VulkanBuffer() { if (buffer) { vkDestroyBuffer(device_, buffer, nullptr); - } - if (memory) { - vkFreeMemory(device_, memory, nullptr); + buffer = VK_NULL_HANDLE; } } VulkanBuffer::VulkanBuffer(VulkanBuffer&& other) - : device_(other.device_), buffer(other.buffer), memory(other.memory) { - other.device_ = VK_NULL_HANDLE; + : VulkanResource(std::move(other)), buffer(other.buffer) { other.buffer = VK_NULL_HANDLE; - other.memory = VK_NULL_HANDLE; + other.size = 0; } VulkanBuffer& VulkanBuffer::operator=(VulkanBuffer&& other) { @@ -115,14 +119,15 @@ bool VulkanBuffer::UseDedicatedAllocation(const VulkanDevice& device, VkBuffer b } VulkanHostVisibleBuffer::VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, - VkBufferUsageFlags usage, uint32_t mem_type_index) - : vk_buf(device, nbytes, usage, mem_type_index), size(nbytes) { - VULKAN_CALL(vkMapMemory(device, vk_buf.memory, 0, size, 0, &host_addr)); + VkBufferUsageFlags usage, uint32_t mem_type_index, + std::optional mem_scope) + : vk_buf(device, nbytes, usage, mem_type_index, mem_scope), size(nbytes) { + VULKAN_CALL(vkMapMemory(device, 
vk_buf.memory->memory_, 0, size, 0, &host_addr)); } VulkanHostVisibleBuffer::~VulkanHostVisibleBuffer() { if (host_addr) { - vkUnmapMemory(vk_buf.device_, vk_buf.memory); + vkUnmapMemory(vk_buf.device_, vk_buf.memory->memory_); } } diff --git a/src/runtime/vulkan/vulkan_buffer.h b/src/runtime/vulkan/vulkan_buffer.h index a3e37431e434..d5cbd2149dc5 100644 --- a/src/runtime/vulkan/vulkan_buffer.h +++ b/src/runtime/vulkan/vulkan_buffer.h @@ -23,15 +23,17 @@ #include #include +#include +#include #include +#include "vulkan_resource.h" + namespace tvm { namespace runtime { namespace vulkan { -class VulkanDevice; - -class VulkanBuffer { +class VulkanBuffer : public VulkanResource { public: /* \brief Allocate memory on the device * @@ -47,8 +49,10 @@ class VulkanBuffer { * an index to a compatible memory located in * VkPhysicalDeviceMemoryProperties. */ + VulkanBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index); + uint32_t mem_type_index, std::optional mem_scope = std::nullopt, + std::shared_ptr back_memory = nullptr); //! \brief Destructor, deallocates the memory and buffer. ~VulkanBuffer(); @@ -61,6 +65,8 @@ class VulkanBuffer { VulkanBuffer(VulkanBuffer&&); VulkanBuffer& operator=(VulkanBuffer&&); + void AllocateMemory(const VkMemoryRequirements& mem_reqs, uint32_t mem_type_index); + private: /*! \brief Whether this buffer should be allocated using dedicated * allocation @@ -95,15 +101,11 @@ class VulkanBuffer { * VulkanDevice may be moved to a different location while the * VulkanBuffer is alive. */ - VkDevice device_{VK_NULL_HANDLE}; //! \brief Handle to the logical buffer on the device VkBuffer buffer{VK_NULL_HANDLE}; - //! \brief Handle to the physical device memory - VkDeviceMemory memory{VK_NULL_HANDLE}; - - friend class VulkanHostVisibleBuffer; + size_t size{0}; // buffer size }; /*! 
\brief A struct to represent Vulkan buffers backed by host visible memory */ @@ -124,7 +126,8 @@ class VulkanHostVisibleBuffer { * VkPhysicalDeviceMemoryProperties. */ VulkanHostVisibleBuffer(const VulkanDevice& device, size_t nbytes, VkBufferUsageFlags usage, - uint32_t mem_type_index); + uint32_t mem_type_index, + std::optional mem_scope = std::nullopt); //! \brief Unmap memory and deallocate. ~VulkanHostVisibleBuffer(); diff --git a/src/runtime/vulkan/vulkan_device.cc b/src/runtime/vulkan/vulkan_device.cc index cc39972432a3..e809d9951ff5 100644 --- a/src/runtime/vulkan/vulkan_device.cc +++ b/src/runtime/vulkan/vulkan_device.cc @@ -143,7 +143,9 @@ VulkanDeviceProperties::VulkanDeviceProperties(const VulkanInstance& instance, supported_subgroup_operations = (subgroup.supportedStages & VK_SHADER_STAGE_COMPUTE_BIT) ? subgroup.supportedOperations : 0; + timestamp_period = properties.properties.limits.timestampPeriod; max_num_threads = properties.properties.limits.maxComputeWorkGroupInvocations; + image_row_align = properties.properties.limits.optimalBufferCopyRowPitchAlignment; // Even if we can't query it, warp size must be at least 1. 
// thread_warp_size = std::max(subgroup.subgroupSize, 1U); @@ -234,6 +236,12 @@ VulkanGetBufferMemoryRequirements2Functions::VulkanGetBufferMemoryRequirements2F vkGetDeviceProcAddr(device, "vkGetBufferMemoryRequirements2KHR")); } +VulkanGetImageMemoryRequirements2Functions::VulkanGetImageMemoryRequirements2Functions( + VkDevice device) { + vkGetImageMemoryRequirements2KHR = (PFN_vkGetImageMemoryRequirements2KHR)ICHECK_NOTNULL( + vkGetDeviceProcAddr(device, "vkGetImageMemoryRequirements2KHR")); +} + VulkanQueueInsertDebugUtilsLabelFunctions::VulkanQueueInsertDebugUtilsLabelFunctions( VkInstance instance) { vkQueueInsertDebugUtilsLabelEXT = (PFN_vkQueueInsertDebugUtilsLabelEXT)ICHECK_NOTNULL( diff --git a/src/runtime/vulkan/vulkan_device.h b/src/runtime/vulkan/vulkan_device.h index 0573a00e5c9e..4497afb018a1 100644 --- a/src/runtime/vulkan/vulkan_device.h +++ b/src/runtime/vulkan/vulkan_device.h @@ -33,6 +33,7 @@ #include "../thread_map.h" #include "vulkan/vulkan_core.h" #include "vulkan_buffer.h" +#include "vulkan_image.h" #include "vulkan_stream.h" namespace tvm { @@ -57,6 +58,12 @@ struct VulkanGetBufferMemoryRequirements2Functions { PFN_vkGetBufferMemoryRequirements2KHR vkGetBufferMemoryRequirements2KHR{nullptr}; }; +struct VulkanGetImageMemoryRequirements2Functions { + explicit VulkanGetImageMemoryRequirements2Functions(VkDevice device); + + PFN_vkGetImageMemoryRequirements2KHR vkGetImageMemoryRequirements2KHR{nullptr}; +}; + struct VulkanQueueInsertDebugUtilsLabelFunctions { explicit VulkanQueueInsertDebugUtilsLabelFunctions(VkInstance instance); @@ -96,16 +103,18 @@ struct VulkanDeviceProperties { uint32_t max_block_size_y{1}; uint32_t max_block_size_z{1}; uint32_t max_push_constants_size{128}; - uint32_t max_uniform_buffer_range{16384}; + uint32_t max_uniform_buffer_range{65536}; uint32_t max_storage_buffer_range{1 << 27}; uint32_t max_per_stage_descriptor_storage_buffer{4}; - uint32_t max_shared_memory_per_block{16384}; + uint32_t 
max_shared_memory_per_block{32768}; std::string device_type{"unknown_device_type"}; std::string device_name{"unknown_device_name"}; std::string driver_name{"unknown_driver_name"}; uint32_t driver_version{0}; uint32_t vulkan_api_version{VK_API_VERSION_1_0}; uint32_t max_spirv_version{0x10000}; + uint32_t image_row_align{0}; + float timestamp_period{0}; }; /*! \brief Handle to the Vulkan API's VkDevice @@ -219,6 +228,8 @@ class VulkanDevice { std::unique_ptr descriptor_template_khr_functions{nullptr}; std::unique_ptr get_buffer_memory_requirements_2_functions{nullptr}; + std::unique_ptr + get_image_memory_requirements_2_functions{nullptr}; std::unique_ptr queue_insert_debug_utils_label_functions{nullptr}; // Memory type index for compute @@ -308,6 +319,9 @@ uint32_t FindMemoryType(const VulkanDevice& device, VkBufferCreateInfo info, VkBufferCreateInfo MakeBufferCreateInfo(size_t nbytes, VkBufferUsageFlags usage); +VkImageCreateInfo MakeImageCreateInfo(VkFormat format, uint32_t width, uint32_t height, + uint32_t layers, VkImageUsageFlags usage); + } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device_api.cc b/src/runtime/vulkan/vulkan_device_api.cc index a2ff8bb7ce0e..0b025bc41c7f 100644 --- a/src/runtime/vulkan/vulkan_device_api.cc +++ b/src/runtime/vulkan/vulkan_device_api.cc @@ -22,17 +22,48 @@ #include #include +#include +#include +#include #include #include #include #include +#include "../memory/pooled_allocator.h" +#include "vulkan_buffer.h" #include "vulkan_common.h" +#include "vulkan_image.h" +#include "vulkan_timer.h" namespace tvm { namespace runtime { namespace vulkan { +using tvm::runtime::memory::Buffer; + +struct ImageInfo { + VkOffset3D origin; + VkExtent3D region; + uint32_t layer_count; +}; + +ImageInfo GetImageInfo(const VulkanImage* image, const DLTensor* tensor) { + ImageInfo info{}; + + ICHECK(tensor->dtype.lanes == 1) << "Image dtype has lanes: " << tensor->dtype.lanes; + + info.origin = {0, 
0, 0}; + info.layer_count = 0; + size_t axis = DefaultTextureLayoutSeparator(tensor->ndim, + VulkanResource::ScopeFromMemoryLayout(image->layout)); + auto texture_shape = ApplyTexture2DFlattening(tensor->shape, tensor->ndim, axis); + info.region = {static_cast(texture_shape.width), + static_cast(texture_shape.height), 1}; + info.layer_count = static_cast(texture_shape.depth); + return info; +} + VulkanDeviceAPI* VulkanDeviceAPI::Global() { // Most of the TVM Global() functions allocate with "new" and do // not deallocate, as the OS can clean up any leftover buffers at @@ -175,7 +206,8 @@ void VulkanDeviceAPI::GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) { // devices that support the VK_EXT_memory_budget extension. break; case kImagePitchAlignment: - return; + *rv = int64_t(prop.image_row_align); + break; } } @@ -279,39 +311,153 @@ void VulkanDeviceAPI::GetTargetProperty(Device dev, const std::string& property, if (property == "max_spirv_version") { *rv = int64_t(prop.max_spirv_version); } + if (property == "image_row_align") { + *rv = int64_t(prop.image_row_align); + } } -void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, - DLDataType type_hint) { +size_t VulkanDeviceAPI::GetImageAlignment(Device dev) { + const auto& device = this->device(dev.device_id); + return device.device_properties.image_row_align; +} + +size_t VulkanDeviceAPI::GetDataSize(const DLTensor& arr, ffi::Optional mem_scope) { + if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { + return DeviceAPI::GetDataSize(arr); + } + + uint32_t row_align = static_cast(GetImageAlignment(arr.device)); + std::vector shape; + shape.assign(arr.shape, arr.shape + arr.ndim); + return runtime::GetTextureMemorySize>(shape, arr.dtype.bits, arr.dtype.lanes, + mem_scope.value(), row_align); +} + +static size_t GetMemObjectSize(Device dev, int ndim, const int64_t* shape, DLDataType dtype) { + DLTensor temp; + temp.data = nullptr; + 
temp.device = dev; + temp.ndim = ndim; + temp.dtype = dtype; + temp.shape = const_cast(shape); + temp.strides = nullptr; + temp.byte_offset = 0; + size_t size = DeviceAPI::Get(dev)->GetDataSize(temp); + return size; +} + +void* VulkanDeviceAPI::AllocVulkanBuffer(Device dev, size_t nbytes, DLDataType type_hint, + std::shared_ptr memory) { if (nbytes == 0) { // Vulkan seems to have issues if we return nullptr on zero size alloc nbytes = 1; } + + // For a standard buffer allocation, use the default layout (1D Buffer) + auto mem_scope = std::optional("global"); + const auto& device = this->device(dev.device_id); auto usage = VK_BUFFER_USAGE_TRANSFER_SRC_BIT | VK_BUFFER_USAGE_TRANSFER_DST_BIT | VK_BUFFER_USAGE_STORAGE_BUFFER_BIT; - return new VulkanBuffer(device, nbytes, usage, device.compute_mtype_index); + + return new VulkanBuffer(device, nbytes, usage, device.compute_mtype_index, mem_scope, memory); } -void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { - // Before releasing the vkBuffer, call sync to - // finish all the vulkan commands that reference the buffer. 
+void* VulkanDeviceAPI::AllocVulkanImage(Device dev, size_t width, size_t height, size_t layers, + DLDataType type_hint, ffi::Optional mem_scope, + std::shared_ptr memory) { + const auto& device = this->device(dev.device_id); + auto format = DTypeToVulkanFormat(type_hint); // Use the new function to get the format + auto usage = VK_IMAGE_USAGE_STORAGE_BIT | VK_IMAGE_USAGE_TRANSFER_DST_BIT | + VK_IMAGE_USAGE_TRANSFER_SRC_BIT; + + // image and view creation + VulkanImage* image = new VulkanImage(device, format, width, height, layers, usage, + device.compute_mtype_index, mem_scope.value(), memory); + image->CreateImageView(format); + return image; +} + +void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t nbytes, size_t alignment, + DLDataType type_hint) { + return AllocVulkanBuffer(dev, nbytes, type_hint, nullptr); +} + +void* VulkanDeviceAPI::AllocDataSpace(Device dev, size_t width, size_t height, size_t depth, + DLDataType type_hint, ffi::Optional mem_scope) { + if (!mem_scope.has_value()) { + mem_scope = ffi::String("global.texture"); + } + return AllocVulkanImage(dev, width, height, depth, type_hint, mem_scope, nullptr); +} + +void* VulkanDeviceAPI::AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + ffi::Optional mem_scope) { + if (!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { + size_t size = GetMemObjectSize(dev, ndim, shape, dtype); + auto buf = MemoryManager::GetOrCreateAllocator(dev, AllocatorType::kPooled) + ->Alloc(dev, size, kTempAllocaAlignment, dtype); + return buf.data; + } + + size_t axis = DefaultTextureLayoutSeparator(ndim, mem_scope.value()); + auto texture = ApplyTexture2DFlattening(shape, ndim, axis); + + return AllocDataSpace(dev, texture.width, texture.height, texture.depth, dtype, mem_scope); +} + +void* VulkanDeviceAPI::AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, + DLDataType dtype, ffi::Optional mem_scope) { + const auto* res = static_cast(data); + + if 
(!mem_scope.has_value() || mem_scope.value().empty() || mem_scope.value() == "global") { + size_t nbytes = GetMemObjectSize(dev, shape.size(), shape.data(), dtype); + return AllocVulkanBuffer(dev, nbytes, dtype, res->memory); + } + size_t axis = DefaultTextureLayoutSeparator(shape.size(), mem_scope.value()); + auto texture = ApplyTexture2DFlattening(shape.data(), shape.size(), axis); + return AllocVulkanImage(dev, texture.width, texture.height, texture.depth, dtype, mem_scope, + res->memory); +} + +void VulkanDeviceAPI::FreeDataSpaceView(Device dev, void* ptr) { StreamSync(dev, nullptr); + const auto* res = static_cast(ptr); - auto* pbuf = static_cast(ptr); - delete pbuf; + if (const auto* buf_res = dynamic_cast(res)) { + delete buf_res; + } else if (const auto* img_res = dynamic_cast(res)) { + delete img_res; + } +} + +void VulkanDeviceAPI::FreeDataSpace(Device dev, void* ptr) { + // Get Vulkan stream associated with the device + VulkanStream& stream = device(dev.device_id).ThreadLocalStream(); + const auto* res = static_cast(ptr); + + if (const auto* buf_res = dynamic_cast(res)) { + // Defer buffer destruction by scheduling it in VulkanStream + stream.Launch([buf_res](VulkanStreamState* state) { delete buf_res; }); + } else if (const auto* img_res = dynamic_cast(res)) { + // Defer image destruction in VulkanStream + stream.Launch([img_res](VulkanStreamState* state) { delete img_res; }); + } } void* VulkanDeviceAPI::AllocWorkspace(Device dev, size_t size, DLDataType type_hint) { - auto& pool = pool_per_thread.GetOrMake(kDLVulkan, this); - return pool.AllocWorkspace(dev, size); + // Use MemoryManager to allocate workspace memory. 
+ auto buffer = MemoryManager::GetOrCreateAllocator(dev, AllocatorType::kPooled) + ->Alloc(dev, size, kTempAllocaAlignment, type_hint); + return buffer.data; } void VulkanDeviceAPI::FreeWorkspace(Device dev, void* data) { - auto* pool = pool_per_thread.Get(); - ICHECK(pool) << "Attempted to free a vulkan workspace on a CPU-thread " - << "that has never allocated a workspace"; - pool->FreeWorkspace(dev, data); + // Use MemoryManager to free workspace memory. + Allocator* allocator = MemoryManager::GetAllocator(dev, AllocatorType::kPooled); + Buffer buffer; + buffer.data = data; + allocator->Free(buffer); } TVMStreamHandle VulkanDeviceAPI::CreateStream(Device dev) { return nullptr; } @@ -332,33 +478,99 @@ void VulkanDeviceAPI::StreamSync(Device dev, TVMStreamHandle stream) { device(dev.device_id).ThreadLocalStream().Synchronize(); } -void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* to, - size_t to_offset, size_t size, Device dev_from, Device dev_to, - DLDataType type_hint, TVMStreamHandle stream) { +void VulkanDeviceAPI::SetStream(Device dev, TVMStreamHandle stream) { + ICHECK_EQ(stream, static_cast(nullptr)); +} + +TVMStreamHandle VulkanDeviceAPI::GetCurrentStream(Device dev) { return nullptr; } + +void VulkanDeviceAPI::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) { ICHECK(stream == nullptr); - Device dev = dev_from; - if (dev_from.device_type == kDLCPU) { - dev = dev_to; - } + ICHECK(from->device.device_type == kDLVulkan || from->device.device_type == kDLCPU); + ICHECK(to->device.device_type == kDLVulkan || to->device.device_type == kDLCPU); + + size_t nbytes = GetDataSize(*from); + ICHECK_EQ(nbytes, GetDataSize(*to)); + ICHECK(IsContiguous(*from) && IsContiguous(*to)) + << "CopyDataFromTo only supports contiguous array for now"; + + Device dev_from = from->device; + Device dev_to = to->device; + const auto* from_res = static_cast(from->data); + const auto* to_res = static_cast(to->data); int from_dev_type = 
static_cast(dev_from.device_type); int to_dev_type = static_cast(dev_to.device_type); + if (from_dev_type == kDLVulkan && to_dev_type == kDLVulkan) { ICHECK_EQ(dev_from.device_id, dev_to.device_id) << "The Vulkan runtime does not support deviceA to deviceB copies. " << "This should be changed to a deviceA to CPU copy, followed by a CPU to deviceB copy"; device(dev_from.device_id).ThreadLocalStream().Launch([=](VulkanStreamState* state) { - // 1: copy - const auto* from_buf = static_cast(from); - auto* to_buf = static_cast(to); - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); - // 2: barrier(transfer-> compute|transfer) - VkMemoryBarrier barrier_info; + // Buffer to Buffer Copy + if (const auto* from_buf = dynamic_cast(from_res)) { + if (const auto* to_buf = dynamic_cast(to_res)) { + VkBufferCopy copy_info = {}; + copy_info.srcOffset = from->byte_offset; + copy_info.dstOffset = to->byte_offset; + copy_info.size = nbytes; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, to_buf->buffer, 1, ©_info); + } else if (const auto* to_img = dynamic_cast(to_res)) { + auto image_info = GetImageInfo(to_img, to); + + VkBufferImageCopy copy_info = {}; + copy_info.bufferOffset = from->byte_offset; + copy_info.bufferRowLength = 0; + copy_info.bufferImageHeight = 0; + copy_info.imageSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy_info.imageSubresource.mipLevel = 0; + copy_info.imageSubresource.baseArrayLayer = 0; + copy_info.imageSubresource.layerCount = image_info.layer_count; + copy_info.imageOffset = {0, 0, 0}; + copy_info.imageExtent = image_info.region; + vkCmdCopyBufferToImage(state->cmd_buffer_, from_buf->buffer, to_img->image, + VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, ©_info); + } + } else if (const auto* from_img = dynamic_cast(from_res)) { + if (const auto* to_buf = dynamic_cast(to_res)) { + auto 
image_info = GetImageInfo(from_img, from); + + VkBufferImageCopy copy_info = {}; + copy_info.bufferOffset = to->byte_offset; + copy_info.bufferRowLength = 0; + copy_info.bufferImageHeight = 0; + copy_info.imageSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy_info.imageSubresource.mipLevel = 0; + copy_info.imageSubresource.baseArrayLayer = 0; + copy_info.imageSubresource.layerCount = image_info.layer_count; + copy_info.imageOffset = {0, 0, 0}; + copy_info.imageExtent = image_info.region; + vkCmdCopyImageToBuffer(state->cmd_buffer_, from_img->image, + VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, to_buf->buffer, 1, + ©_info); + } else if (const auto* to_img = dynamic_cast(to_res)) { + auto image_info = GetImageInfo(from_img, from); + + VkImageCopy copy_info = {}; + copy_info.srcSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy_info.srcSubresource.mipLevel = 0; + copy_info.srcSubresource.baseArrayLayer = 0; + copy_info.srcSubresource.layerCount = image_info.layer_count; + copy_info.srcOffset = {0, 0, 0}; + copy_info.dstSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy_info.dstSubresource.mipLevel = 0; + copy_info.dstSubresource.baseArrayLayer = 0; + copy_info.dstSubresource.layerCount = image_info.layer_count; + copy_info.dstOffset = {0, 0, 0}; + copy_info.extent = image_info.region; + vkCmdCopyImage(state->cmd_buffer_, from_img->image, VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, + to_img->image, VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, ©_info); + } + } + + // Memory barrier to ensure proper synchronization + VkMemoryBarrier barrier_info = {}; barrier_info.sType = VK_STRUCTURE_TYPE_MEMORY_BARRIER; barrier_info.pNext = nullptr; barrier_info.srcAccessMask = VK_ACCESS_TRANSFER_WRITE_BIT; @@ -370,43 +582,93 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* }); } else if (from_dev_type == kDLVulkan && to_dev_type == kDLCPU) { - const auto* from_buf = static_cast(from); auto& device = this->device(dev_from.device_id); auto& 
stream = device.ThreadLocalStream(); - auto& staging_buffer = device.ThreadLocalStagingBuffer(size); + auto& staging_buffer = device.ThreadLocalStagingBuffer(nbytes); + stream.Launch([&](VulkanStreamState* state) { - VkBufferCopy copy_info; - copy_info.srcOffset = from_offset; - copy_info.dstOffset = 0; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, staging_buffer.vk_buf.buffer, 1, - ©_info); + if (const auto* from_buf = dynamic_cast(from_res)) { + VkBufferCopy copy_info = {}; + copy_info.srcOffset = from->byte_offset; + copy_info.dstOffset = 0; + copy_info.size = nbytes; + vkCmdCopyBuffer(state->cmd_buffer_, from_buf->buffer, staging_buffer.vk_buf.buffer, 1, + ©_info); + } else if (const auto* from_img = dynamic_cast(from_res)) { + auto image_info = GetImageInfo(from_img, from); + + // Ensure the image is in the correct layout for transfer + VkImageMemoryBarrier img_barrier = {}; + img_barrier.sType = VK_STRUCTURE_TYPE_IMAGE_MEMORY_BARRIER; + img_barrier.oldLayout = VK_IMAGE_LAYOUT_SHADER_READ_ONLY_OPTIMAL; // Original layout + img_barrier.newLayout = VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL; + img_barrier.srcQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + img_barrier.dstQueueFamilyIndex = VK_QUEUE_FAMILY_IGNORED; + img_barrier.image = from_img->image; + img_barrier.subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + img_barrier.subresourceRange.baseMipLevel = 0; + img_barrier.subresourceRange.levelCount = 1; + img_barrier.subresourceRange.baseArrayLayer = 0; + img_barrier.subresourceRange.layerCount = image_info.layer_count; + img_barrier.srcAccessMask = VK_ACCESS_SHADER_READ_BIT; + img_barrier.dstAccessMask = VK_ACCESS_TRANSFER_READ_BIT; + + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, + VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 0, nullptr, 0, nullptr, 1, + &img_barrier); + VkBufferImageCopy copy_info = {}; + copy_info.bufferOffset = 0; + copy_info.bufferRowLength = 0; + copy_info.bufferImageHeight = 
0; + copy_info.imageSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy_info.imageSubresource.mipLevel = 0; + copy_info.imageSubresource.baseArrayLayer = 0; + copy_info.imageSubresource.layerCount = image_info.layer_count; + copy_info.imageOffset = {0, 0, 0}; + copy_info.imageExtent = image_info.region; + vkCmdCopyImageToBuffer(state->cmd_buffer_, from_img->image, + VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL, staging_buffer.vk_buf.buffer, + 1, ©_info); + + // Restore the image layout + img_barrier.oldLayout = VK_IMAGE_LAYOUT_TRANSFER_SRC_OPTIMAL; + img_barrier.newLayout = VK_IMAGE_LAYOUT_GENERAL; + img_barrier.srcAccessMask = VK_ACCESS_TRANSFER_READ_BIT; + img_barrier.dstAccessMask = VK_ACCESS_SHADER_READ_BIT; + vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_TRANSFER_BIT, + VK_PIPELINE_STAGE_FRAGMENT_SHADER_BIT, 0, 0, nullptr, 0, nullptr, 1, + &img_barrier); + } }); + stream.Synchronize(); stream.ProfilerReset(); if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; - mrange.memory = staging_buffer.vk_buf.memory; + mrange.memory = staging_buffer.vk_buf.memory->memory_; mrange.offset = 0; - mrange.size = VK_WHOLE_SIZE; // size; + mrange.size = VK_WHOLE_SIZE; VULKAN_CALL(vkInvalidateMappedMemoryRanges(device, 1, &mrange)); } - memcpy(static_cast(to) + to_offset, static_cast(staging_buffer.host_addr), size); + memcpy(static_cast(to->data) + to->byte_offset, + static_cast(staging_buffer.host_addr), nbytes); + } else if (from_dev_type == kDLCPU && to_dev_type == kDLVulkan) { auto& device = this->device(dev_to.device_id); auto& stream = device.ThreadLocalStream(); - const auto* to_buf = static_cast(to); - auto& staging_buffer = device.ThreadLocalStagingBuffer(size); - memcpy(staging_buffer.host_addr, static_cast(from) + from_offset, size); + auto& staging_buffer = device.ThreadLocalStagingBuffer(nbytes); + memcpy(staging_buffer.host_addr, static_cast(from->data) + 
from->byte_offset, + nbytes); + // host side flush if access is not coherent. // so writes from CPU is visible to GPU if (!device.coherent_staging) { VkMappedMemoryRange mrange; mrange.sType = VK_STRUCTURE_TYPE_MAPPED_MEMORY_RANGE; mrange.pNext = nullptr; - mrange.memory = staging_buffer.vk_buf.memory; + mrange.memory = staging_buffer.vk_buf.memory->memory_; mrange.offset = 0; mrange.size = VK_WHOLE_SIZE; // size; VULKAN_CALL(vkFlushMappedMemoryRanges(device, 1, &mrange)); @@ -422,19 +684,35 @@ void VulkanDeviceAPI::CopyDataFromTo(const void* from, size_t from_offset, void* vkCmdPipelineBarrier(state->cmd_buffer_, VK_PIPELINE_STAGE_HOST_BIT, VK_PIPELINE_STAGE_TRANSFER_BIT, 0, 1, &barrier_info, 0, nullptr, 0, nullptr); - // 1: copy - VkBufferCopy copy_info; - copy_info.srcOffset = 0; - copy_info.dstOffset = to_offset; - copy_info.size = size; - vkCmdCopyBuffer(state->cmd_buffer_, staging_buffer.vk_buf.buffer, to_buf->buffer, 1, - ©_info); + + if (const auto* to_buf = dynamic_cast(to_res)) { + VkBufferCopy copy_info; + copy_info.srcOffset = 0; + copy_info.dstOffset = to->byte_offset; + copy_info.size = nbytes; + vkCmdCopyBuffer(state->cmd_buffer_, staging_buffer.vk_buf.buffer, to_buf->buffer, 1, + ©_info); + } else if (const auto* to_img = dynamic_cast(to_res)) { + auto image_info = GetImageInfo(to_img, to); + + VkBufferImageCopy copy_info = {}; + copy_info.bufferOffset = 0; + copy_info.bufferRowLength = 0; + copy_info.bufferImageHeight = 0; + copy_info.imageSubresource.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + copy_info.imageSubresource.mipLevel = 0; + copy_info.imageSubresource.baseArrayLayer = 0; + copy_info.imageSubresource.layerCount = image_info.layer_count; + copy_info.imageOffset = {0, 0, 0}; + copy_info.imageExtent = image_info.region; + vkCmdCopyBufferToImage(state->cmd_buffer_, staging_buffer.vk_buf.buffer, to_img->image, + VK_IMAGE_LAYOUT_TRANSFER_DST_OPTIMAL, 1, ©_info); + } }); stream.ProfilerReady(); - // TODO(tulloch): should we instead make the 
staging buffer a property of the - // Stream? This would allow us to elide synchronizations here. stream.Synchronize(); + } else { LOG(FATAL) << "Expect copy from/to Vulkan or between Vulkan" << ", from=" << from_dev_type << ", to=" << to_dev_type; @@ -459,13 +737,148 @@ TVM_FFI_STATIC_INIT_BLOCK() { DeviceAPI* ptr = VulkanDeviceAPI::Global(); *rv = static_cast(ptr); }) - .def("device_api.vulkan.get_target_property", [](Device dev, const std::string& property) { - ffi::Any rv; - VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); - return rv; + .def("device_api.vulkan.get_target_property", + [](Device dev, const std::string& property) { + ffi::Any rv; + VulkanDeviceAPI::Global()->GetTargetProperty(dev, property, &rv); + return rv; + }) + .def_packed("device_api.vulkan.alloc_nd", + [](ffi::PackedArgs args, ffi::Any* rv) { + int32_t device_type = args[0].cast(); + int32_t device_id = args[1].cast(); + int32_t dtype_code_hint = args[2].cast(); + int32_t dtype_bits_hint = args[3].cast(); + std::string scope = args[4].cast(); + + CHECK(scope.find("texture") != std::string::npos); + int64_t ndim = args[5].cast(); + CHECK_EQ(ndim, 2); + int64_t* shape = static_cast(args[6].cast()); + int64_t width = shape[0]; + int64_t height = shape[1]; + int64_t depth = shape[2]; + + Device dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + + DLDataType type_hint; + type_hint.code = static_cast(dtype_code_hint); + type_hint.bits = static_cast(dtype_bits_hint); + type_hint.lanes = 1; + + *rv = VulkanDeviceAPI::Global()->AllocDataSpace( + dev, static_cast(width), static_cast(height), + static_cast(depth), type_hint, + ffi::Optional("global.texture")); + }) + .def_packed("device_api.vulkan.free_nd", [](ffi::PackedArgs args, ffi::Any* rv) { + int32_t device_type = args[0].cast(); + int32_t device_id = args[1].cast(); + std::string scope = args[2].cast(); + CHECK(scope.find("texture") != std::string::npos); + void* data = args[3].cast(); + Device 
dev; + dev.device_type = static_cast(device_type); + dev.device_id = device_id; + VulkanDeviceAPI::Global()->FreeDataSpace(dev, data); + *rv = static_cast(0); }); } +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("profiling.timer.vulkan", + [](Device dev) { return Timer(ffi::make_object(dev)); }); +} + +class VulkanPooledAllocator final : public memory::PooledAllocator { + public: + explicit VulkanPooledAllocator() : PooledAllocator() {} + + bool AllowMemoryScope(const std::string& mem_scope) const final { + return ((mem_scope.find("texture") != std::string::npos) || mem_scope.empty() || + ("global" == mem_scope)); + } + + Buffer Alloc(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) override { + std::lock_guard lock(mu_); + size_t size = ((nbytes + page_size_ - 1) / page_size_) * page_size_; + auto&& it = memory_pool_.find(size); + if (it != memory_pool_.end() && !it->second.empty()) { + auto&& pool = it->second; + auto ret = pool.back(); + pool.pop_back(); + return ret; + } + Buffer buf; + buf.device = dev; + buf.size = size; + buf.alloc_type = AllocatorType::kPooled; + try { + buf.data = DeviceAllocDataSpace(dev, size, alignment, type_hint); + } catch (InternalError& err) { + LOG(WARNING) << "PooledAllocator got InternalError during allocation: " << err.message(); + LOG(WARNING) << "Trying to release all unused memory and reallocate..."; + ReleaseAll(); + buf.data = DeviceAllocDataSpace(dev, size, alignment, type_hint); + } + + used_memory_.fetch_add(size, std::memory_order_relaxed); + VLOG(1) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + return buf; + } + + Buffer Alloc(Device dev, ffi::Shape shape, DLDataType type_hint, + const std::string& mem_scope) override { + if (AllowMemoryScope(mem_scope)) { + size_t size = GetMemObjectSize(dev, shape.size(), shape.data(), type_hint); + Buffer buf; + buf.device = dev; + buf.size = size; + buf.alloc_type = AllocatorType::kPooled; 
+ buf.data = DeviceAPI::Get(dev)->AllocDataSpace(dev, shape.size(), shape.data(), type_hint, + ffi::String(mem_scope)); + if (mem_scope.find("texture") == std::string::npos) { + // All textures are backed by buffers - don't count in total memory + used_memory_.fetch_add(size, std::memory_order_relaxed); + } + DLOG(INFO) << "allocate " << size << " B, used memory " << used_memory_ << " B"; + return buf; + } + LOG(FATAL) << "Unsupported memory scope for this Allocator:" << mem_scope; + return {}; + } + + void Free(const Buffer& buffer) override { + std::lock_guard lock(mu_); + if (memory_pool_.find(buffer.size) == memory_pool_.end()) { + memory_pool_.emplace(buffer.size, std::vector{}); + } + memory_pool_.at(buffer.size).push_back(buffer); + VLOG(1) << "reclaim buffer " << buffer.size; + } + + void* CreateView(const Buffer& buffer, ffi::Shape shape, DLDataType type_hint, + const std::string& mem_scope) final { + return VulkanDeviceAPI::Global()->AllocDataSpaceView( + buffer.device, buffer.data, shape, type_hint, ffi::Optional(mem_scope)); + } + + void FreeView(Device dev, void* data) final { + return VulkanDeviceAPI::Global()->FreeDataSpaceView(dev, data); + } +}; + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def_packed("DeviceAllocator.vulkan", [](ffi::PackedArgs args, ffi::Any* rv) { + Allocator* alloc = new VulkanPooledAllocator(); + *rv = static_cast(alloc); + }); +} + } // namespace vulkan } // namespace runtime } // namespace tvm diff --git a/src/runtime/vulkan/vulkan_device_api.h b/src/runtime/vulkan/vulkan_device_api.h index 5e9bfeb8c086..b286b2766d37 100644 --- a/src/runtime/vulkan/vulkan_device_api.h +++ b/src/runtime/vulkan/vulkan_device_api.h @@ -21,11 +21,15 @@ #define TVM_RUNTIME_VULKAN_VULKAN_DEVICE_API_H_ #include +#include +#include #include +#include #include #include +#include "../texture.h" #include "../thread_map.h" #include "../workspace_pool.h" #include "vulkan/vulkan_core.h" @@ -47,8 +51,22 @@ 
class VulkanDeviceAPI final : public DeviceAPI { void GetAttr(Device dev, DeviceAttrKind kind, ffi::Any* rv) final; // Implement memory management required by DeviceAPI + void* AllocVulkanBuffer(Device dev, size_t nbytes, DLDataType type_hint, + std::shared_ptr memory); + void* AllocVulkanImage(Device dev, size_t width, size_t height, size_t layers, + DLDataType type_hint, ffi::Optional mem_scope, + std::shared_ptr memory); void* AllocDataSpace(Device dev, size_t nbytes, size_t alignment, DLDataType type_hint) final; + void* AllocDataSpace(Device dev, int ndim, const int64_t* shape, DLDataType dtype, + ffi::Optional mem_scope = std::nullopt) final; + void* AllocDataSpace(Device dev, size_t width, size_t height, size_t depth, DLDataType type_hint, + ffi::Optional mem_scope = std::nullopt); + void* AllocDataSpaceView(Device dev, void* data, ffi::Shape shape, DLDataType dtype, + ffi::Optional mem_scope = std::nullopt); + void FreeDataSpace(Device dev, void* ptr) final; + void FreeDataSpaceView(Device dev, void* ptr); + void* AllocWorkspace(Device dev, size_t size, DLDataType type_hint) final; void FreeWorkspace(Device dev, void* data) final; @@ -61,11 +79,62 @@ class VulkanDeviceAPI final : public DeviceAPI { void FreeStream(Device dev, TVMStreamHandle stream) final; void SyncStreamFromTo(Device dev, TVMStreamHandle event_src, TVMStreamHandle event_dst) final; void StreamSync(Device dev, TVMStreamHandle stream) final; - - protected: - void CopyDataFromTo(const void* from, size_t from_offset, void* to, size_t to_offset, size_t size, - Device dev_from, Device dev_to, DLDataType type_hint, - TVMStreamHandle stream) final; + void SetStream(Device dev, TVMStreamHandle stream) final; + TVMStreamHandle GetCurrentStream(Device dev) final; + size_t GetDataSize(const DLTensor& arr, + ffi::Optional mem_scope = std::nullopt) final; + + void CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHandle stream) final; + + // Check if the device is a Vulkan device + virtual bool 
IsVulkanDevice(Device dev) { return dev.device_type == kDLVulkan; } + + inline VkFormat DTypeToVulkanFormat(DLDataType data_type, int num_channels = 4) { + DataType dtype(data_type); + + // Print information about the DataType for debugging + // PrintDataTypeInfo(dtype); + if (num_channels == 1) { + if (dtype == DataType::Float(32)) { + return VK_FORMAT_R32_SFLOAT; + } else if (dtype == DataType::Float(16)) { + return VK_FORMAT_R16_SFLOAT; + } else if (dtype == DataType::Int(8)) { + return VK_FORMAT_R8_SINT; + } else if (dtype == DataType::Int(16)) { + return VK_FORMAT_R16_SINT; + } else if (dtype == DataType::Int(32)) { + return VK_FORMAT_R32_SINT; + } else if (dtype == DataType::UInt(8)) { + return VK_FORMAT_R8_UINT; + } else if (dtype == DataType::UInt(16)) { + return VK_FORMAT_R16_UINT; + } else if (dtype == DataType::UInt(32)) { + return VK_FORMAT_R32_UINT; + } + } else if (num_channels == 4) { + if (dtype == DataType::Float(32)) { + return VK_FORMAT_R32G32B32A32_SFLOAT; // 4-channel 32-bit float + } else if (dtype == DataType::Float(16)) { + return VK_FORMAT_R16G16B16A16_SFLOAT; // 4-channel 16-bit float + } else if (dtype == DataType::Int(8)) { + return VK_FORMAT_R8G8B8A8_SINT; // 4-channel 8-bit signed integer + } else if (dtype == DataType::Int(16)) { + return VK_FORMAT_R16G16B16A16_SINT; // 4-channel 16-bit signed integer + } else if (dtype == DataType::Int(32)) { + return VK_FORMAT_R32G32B32A32_SINT; // 4-channel 32-bit signed integer + } else if (dtype == DataType::UInt(8)) { + return VK_FORMAT_R8G8B8A8_UINT; // 4-channel 8-bit unsigned integer + } else if (dtype == DataType::UInt(16)) { + return VK_FORMAT_R16G16B16A16_UINT; // 4-channel 16-bit unsigned integer + } else if (dtype == DataType::UInt(32)) { + return VK_FORMAT_R32G32B32A32_UINT; // 4-channel 32-bit unsigned integer + } + } + LOG(FATAL) << "Unsupported data type or channel count for Vulkan runtime: " << dtype + << ", channels: " << num_channels; + return VK_FORMAT_UNDEFINED; // Fallback, 
should not reach here + } // End of required methods for the DeviceAPI interface @@ -107,6 +176,8 @@ class VulkanDeviceAPI final : public DeviceAPI { */ void GetTargetProperty(Device dev, const std::string& property, ffi::Any* rv) final; + size_t GetImageAlignment(Device dev); + private: std::vector GetComputeQueueFamilies(VkPhysicalDevice phy_dev); diff --git a/src/runtime/vulkan/vulkan_image.cc b/src/runtime/vulkan/vulkan_image.cc new file mode 100644 index 000000000000..d6f962814e4e --- /dev/null +++ b/src/runtime/vulkan/vulkan_image.cc @@ -0,0 +1,167 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "vulkan_image.h" + +#include +#include +#include + +#include "vulkan_device_api.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VkImageCreateInfo MakeImageCreateInfo(VkFormat format, uint32_t width, uint32_t height, + uint32_t layers, VkImageUsageFlags usage) { + VkImageCreateInfo info = {VK_STRUCTURE_TYPE_IMAGE_CREATE_INFO}; + info.imageType = VK_IMAGE_TYPE_2D; + info.flags = 0; + info.format = format; + info.extent.width = width; + info.extent.height = height; + info.extent.depth = 1; // Must be 1 for 2d images + info.mipLevels = 1; + info.arrayLayers = layers; + info.samples = VK_SAMPLE_COUNT_1_BIT; + info.tiling = VK_IMAGE_TILING_LINEAR; + info.usage = usage; + info.sharingMode = VK_SHARING_MODE_EXCLUSIVE; + info.initialLayout = VK_IMAGE_LAYOUT_UNDEFINED; + return info; +} + +VulkanImage::VulkanImage(const VulkanDevice& device, VkFormat format, uint32_t width, + uint32_t height, uint32_t layers, VkImageUsageFlags usage, + uint32_t mem_type_index, std::optional mem_scope, + std::shared_ptr back_memory) + : VulkanResource(device, mem_scope, back_memory), width(width), height(height), layers(layers) { + // Create an image + VkImageCreateInfo image_info = MakeImageCreateInfo(format, width, height, layers, usage); + VULKAN_CALL(vkCreateImage(device, &image_info, nullptr, &image)); + + VkMemoryRequirements mem_reqs; + vkGetImageMemoryRequirements(device, image, &mem_reqs); + + // Allocate new memory if no memory is passed in or if the existing memory is not compatible + if (!memory) { + AllocateMemory(mem_reqs, mem_type_index); + } + // Bind the image to the allocated memory + VULKAN_CALL(vkBindImageMemory(device, image, memory->memory_, 0)); +} + +void VulkanImage::AllocateMemory(const VkMemoryRequirements& mem_reqs, uint32_t mem_type_index) { + VkMemoryAllocateInfo mem_info = {VK_STRUCTURE_TYPE_MEMORY_ALLOCATE_INFO}; + mem_info.allocationSize = mem_reqs.size; + mem_info.memoryTypeIndex = mem_type_index; + + // Allocate memory + 
VkDeviceMemory raw_memory; + VULKAN_CALL(vkAllocateMemory(device_, &mem_info, nullptr, &raw_memory)); + + // Store the allocated memory along with its requirements + memory = std::make_shared(raw_memory, mem_reqs); +} + +VulkanImage::~VulkanImage() { + if (imageView) { + vkDestroyImageView(device_, imageView, nullptr); + } + if (image) { + vkDestroyImage(device_, image, nullptr); + } +} + +VulkanImage::VulkanImage(VulkanImage&& other) + : VulkanResource(std::move(other)), image(other.image), imageView(other.imageView) { + other.image = VK_NULL_HANDLE; + other.imageView = VK_NULL_HANDLE; +} + +uint32_t VulkanImage::FindMemoryTypeForImage(const VulkanDevice& device, + VkMemoryPropertyFlags properties, + uint32_t typeFilter) { + VkPhysicalDeviceMemoryProperties memProperties; + VkPhysicalDevice physicalDeviceHandle = + device; // Implicit conversion using operator VkPhysicalDevice() + vkGetPhysicalDeviceMemoryProperties(physicalDeviceHandle, &memProperties); + + for (uint32_t i = 0; i < memProperties.memoryTypeCount; i++) { + if ((typeFilter & (1 << i)) && + (memProperties.memoryTypes[i].propertyFlags & properties) == properties) { + return i; + } + } + + throw std::runtime_error("Failed to find suitable memory type!"); +} + +VulkanImage& VulkanImage::operator=(VulkanImage&& other) { + std::swap(device_, other.device_); + std::swap(image, other.image); + std::swap(memory, other.memory); + std::swap(imageView, other.imageView); + return *this; +} + +void VulkanImage::CreateImageView(VkFormat format) { + VkImageViewCreateInfo view_info = {VK_STRUCTURE_TYPE_IMAGE_VIEW_CREATE_INFO}; + view_info.image = image; + view_info.viewType = VK_IMAGE_VIEW_TYPE_2D_ARRAY; + view_info.format = format; + view_info.subresourceRange.aspectMask = VK_IMAGE_ASPECT_COLOR_BIT; + view_info.subresourceRange.baseMipLevel = 0; + view_info.subresourceRange.levelCount = 1; + view_info.subresourceRange.baseArrayLayer = 0; + view_info.subresourceRange.layerCount = layers; + + 
VULKAN_CALL(vkCreateImageView(device_, &view_info, nullptr, &imageView)); +} + +bool VulkanImage::UseDedicatedAllocation(const VulkanDevice& device, VkImage image, + VkDeviceSize* nbytes) { + if (device.get_image_memory_requirements_2_functions) { + // Which image to request information about + VkImageMemoryRequirementsInfo2KHR req_info2 = { + VK_STRUCTURE_TYPE_IMAGE_MEMORY_REQUIREMENTS_INFO_2_KHR}; + req_info2.image = image; + + // What information to request + VkMemoryDedicatedRequirementsKHR dedicated_req; + dedicated_req.sType = VK_STRUCTURE_TYPE_MEMORY_DEDICATED_REQUIREMENTS_KHR; + dedicated_req.pNext = nullptr; + + VkMemoryRequirements2KHR req2 = {VK_STRUCTURE_TYPE_MEMORY_REQUIREMENTS_2_KHR}; + req2.pNext = &dedicated_req; + + device.get_image_memory_requirements_2_functions->vkGetImageMemoryRequirements2KHR( + device, &req_info2, &req2); + if (dedicated_req.requiresDedicatedAllocation || dedicated_req.prefersDedicatedAllocation) { + *nbytes = req2.memoryRequirements.size; + return true; + } + } + return false; +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_image.h b/src/runtime/vulkan/vulkan_image.h new file mode 100644 index 000000000000..a83875534c13 --- /dev/null +++ b/src/runtime/vulkan/vulkan_image.h @@ -0,0 +1,133 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_IMAGE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_IMAGE_H_ + +#include + +#include +#include +#include +#include +#include + +#include "vulkan_resource.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanImage : public VulkanResource { + public: + /* \brief Allocate and create an image on the device + * + * \param device Which device should have the image allocation. + * The VulkanDevice given should outlive the VulkanImage. + * + * \param format The format of the image (e.g., VK_FORMAT_R32_SFLOAT) + * + * \param width The width of the image + * + * \param height The height of the image + * + * \param layers The array layers of the image + * + * \param usage The usage flags for the image (e.g. sampled, transfer destination, etc.) + * + * \param mem_type_index The memory type to index. This should be + * an index to a compatible memory located in + * VkPhysicalDeviceMemoryProperties. + */ + VulkanImage(const VulkanDevice& device, VkFormat format, uint32_t width, uint32_t height, + uint32_t depth, VkImageUsageFlags usage, uint32_t mem_type_index, + std::optional mem_scope = std::nullopt, + std::shared_ptr back_memory = nullptr); + + //! \brief Destructor, deallocates the memory, image, and image view. 
+ ~VulkanImage(); + + // Forbid copy assignment/constructor + VulkanImage(const VulkanImage&) = delete; + VulkanImage& operator=(const VulkanImage&) = delete; + + // Allow move assignment/constructor + VulkanImage(VulkanImage&&); + VulkanImage& operator=(VulkanImage&&); + + void AllocateMemory(const VkMemoryRequirements& mem_reqs, uint32_t mem_type_index); + + void CreateImageView(VkFormat format); + + inline uint32_t FindMemoryTypeForImage(const VulkanDevice& device, + VkMemoryPropertyFlags properties, uint32_t typeFilter); + + private: + /*! + * \brief Whether this image should be allocated using dedicated allocation + * + * In typical usage, there will be one VkDeviceMemory that has a + * large number of VkImages pointing to it. Currently, the TVM + * Vulkan runtime has a single VkImage for each VkDeviceMemory. In + * this case, there can be performance benefits by explicitly + * marking this as a dedicated allocation. The function returns + * true if the device supports the dedicated allocation extension, + * and the image either requires or has better performance with a + * dedicated allocation. + * + * \param[out] nbytes If using dedicated allocation, the number of + * bytes required for the allocation. If not using dedicated + * allocation, this value is unchanged. + * + * \returns Whether the allocation should use the dedicated + * allocation extension. + */ + static bool UseDedicatedAllocation(const VulkanDevice& device, VkImage image, + VkDeviceSize* nbytes); + + public: + /*! \brief Pointer to the device that owns this image. + * + * Assumes that the VulkanImage will be destructed before the + * VulkanDevice, and this will never be a dangling reference. + * Stores a VkDevice and not a VulkanDevice, because the + * VulkanDevice may be moved to a different location while the + * VulkanImage is alive. + */ + + //! \brief Handle to the logical image on the device + VkImage image{VK_NULL_HANDLE}; + + //! 
\brief Handle to the image view + VkImageView imageView{VK_NULL_HANDLE}; + + // capture the memory requirements. + // VkMemoryRequirements mem_reqs; + + // Add width and height members + uint32_t width{0}; // Width of the image + uint32_t height{0}; // Height of the image + uint32_t layers{0}; // Depth of the image +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_IMAGE_H_ diff --git a/src/runtime/vulkan/vulkan_resource.cc b/src/runtime/vulkan/vulkan_resource.cc new file mode 100644 index 000000000000..ccc7f78e4845 --- /dev/null +++ b/src/runtime/vulkan/vulkan_resource.cc @@ -0,0 +1,88 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "vulkan_resource.h" + +#include + +#include +#include + +#include "vulkan_device_api.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanResource::VulkanResource(const VulkanDevice& device, std::optional mem_scope, + std::shared_ptr back_memory) + : device_(device), layout(MemoryLayoutFromScope(mem_scope)), memory(back_memory) {} + +VulkanResource::~VulkanResource() {} + +VulkanResource::VulkanResource(VulkanResource&& other) + : device_(other.device_), layout(other.layout), memory(std::move(other.memory)) { + other.device_ = VK_NULL_HANDLE; + other.memory.reset(); +} + +VulkanResource& VulkanResource::operator=(VulkanResource&& other) { + if (this != &other) { + device_ = other.device_; + layout = other.layout; + memory = std::move(other.memory); + } + return *this; +} + +VulkanResource::MemoryLayout VulkanResource::MemoryLayoutFromScope( + std::optional mem_scope) { + if (!mem_scope) { + return MemoryLayout::kBuffer1D; + } else if (*mem_scope == "global") { + return MemoryLayout::kBuffer1D; + } else if (*mem_scope == "global.texture") { + return MemoryLayout::kImage2DActivation; + } else if (*mem_scope == "global.texture-weight") { + return MemoryLayout::kImage2DWeight; + } else if (*mem_scope == "global.texture-nhwc") { + return MemoryLayout::kImage2DNHWC; + } + throw std::runtime_error("No memory layout defined for memory of scope: " + *mem_scope); +} + +std::string VulkanResource::ScopeFromMemoryLayout(MemoryLayout layout) { + switch (layout) { + case MemoryLayout::kBuffer1D: + return "global"; + case MemoryLayout::kImage2DActivation: + return "global.texture"; + case MemoryLayout::kImage2DWeight: + return "global.texture-weight"; + case MemoryLayout::kImage2DNHWC: + return "global.texture-nhwc"; + default: + throw std::runtime_error("No scope corresponding to the provided memory layout: " + + std::to_string(static_cast(layout))); + } +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git 
a/src/runtime/vulkan/vulkan_resource.h b/src/runtime/vulkan/vulkan_resource.h new file mode 100644 index 000000000000..3c19f432c618 --- /dev/null +++ b/src/runtime/vulkan/vulkan_resource.h @@ -0,0 +1,131 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_RESOURCE_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_RESOURCE_H_ + +#include + +#include +#include +#include + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanDevice; + +/*! + * \brief Class representing Vulkan device memory allocations. + * + * This class encapsulates a Vulkan device memory allocation and its memory requirements. + * It provides functionality to check memory compatibility with new resource requirements. + */ +class VulkanMemory { + public: + /*! + * \brief Constructor to create a VulkanMemory instance. + * + * \param mem The Vulkan device memory handle. + * \param mem_reqs The memory requirements associated with this allocation. + */ + VulkanMemory(VkDeviceMemory mem, const VkMemoryRequirements& mem_reqs) + : memory_(mem), mem_reqs_(mem_reqs) {} + + /*! + * \brief Destructor to free the Vulkan device memory. 
+ */ + ~VulkanMemory() { + if (memory_ != VK_NULL_HANDLE) { + memory_ = VK_NULL_HANDLE; + } + } + + VkDeviceMemory memory_; + VkMemoryRequirements mem_reqs_; +}; + +/*! + * \brief Base class for Vulkan resources such as buffers and images. + * + * This class holds common properties and functionalities for Vulkan resources, + * including device association, memory layout, and memory management. + */ +class VulkanResource { + public: + /*! + * \brief Enumeration of memory layout types. + */ + enum class MemoryLayout { + kBuffer1D, + kImage2DActivation, + kImage2DWeight, + kImage2DNHWC, + }; + + /*! + * \brief Constructor to create a VulkanResource. + * + * \param device The Vulkan device associated with this resource. + * \param mem_scope Optional memory scope string specifying the memory layout. + * \param back_memory Optional shared pointer to existing VulkanMemory. + */ + VulkanResource(const VulkanDevice& device, std::optional mem_scope, + std::shared_ptr back_memory = nullptr); + + /*! + * \brief Virtual destructor. + */ + virtual ~VulkanResource(); + + // Forbid copy assignment/constructor + VulkanResource(const VulkanResource&) = delete; + VulkanResource& operator=(const VulkanResource&) = delete; + + // Allow move assignment/constructor + VulkanResource(VulkanResource&& other); + VulkanResource& operator=(VulkanResource&& other); + + /*! + * \brief Converts a memory scope string to a MemoryLayout enumeration. + * + * \param mem_scope The optional memory scope string. + * \return The corresponding MemoryLayout value. + */ + static MemoryLayout MemoryLayoutFromScope(std::optional mem_scope); + + /*! + * \brief Converts a MemoryLayout enumeration to a memory scope string. + * + * \param layout The MemoryLayout value. + * \return The corresponding memory scope string. 
+ */ + static std::string ScopeFromMemoryLayout(MemoryLayout layout); + + VkDevice device_{VK_NULL_HANDLE}; + MemoryLayout layout{MemoryLayout::kBuffer1D}; + std::shared_ptr memory{nullptr}; +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_RESOURCE_H_ diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index 742a66f15dd4..c8e08f5d57b8 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -117,6 +117,8 @@ class VulkanStream { // Synchronize the current stream `state_` with respect to the host. void Synchronize(); + VkCommandPool CommanPool() const { return cmd_pool_; } + private: const VulkanDevice* device_; std::unique_ptr state_; diff --git a/src/runtime/vulkan/vulkan_timer.cc b/src/runtime/vulkan/vulkan_timer.cc new file mode 100644 index 000000000000..87758f88e722 --- /dev/null +++ b/src/runtime/vulkan/vulkan_timer.cc @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "vulkan_timer.h" + +#include + +#include "vulkan_device_api.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +VulkanTimerNode::VulkanTimerNode(Device dev) : dev_(dev) { + // Get the Vulkan device and stream + auto& vk_dev = VulkanDeviceAPI::Global()->device(dev_.device_id); + stream_ = &vk_dev.ThreadLocalStream(); + device_ = vk_dev; + + // Retrieve the timestamp period from device properties + timestamp_period_ = vk_dev.device_properties.timestamp_period; + + CreateQueryPool(); +} + +VulkanTimerNode::~VulkanTimerNode() { Cleanup(); } + +void VulkanTimerNode::CreateQueryPool() { + VkQueryPoolCreateInfo query_pool_info{}; + query_pool_info.sType = VK_STRUCTURE_TYPE_QUERY_POOL_CREATE_INFO; + query_pool_info.queryType = VK_QUERY_TYPE_TIMESTAMP; + query_pool_info.queryCount = 2; + + VkResult res = vkCreateQueryPool(device_, &query_pool_info, nullptr, &query_pool_); + ICHECK(res == VK_SUCCESS) << "Failed to create Vulkan query pool."; +} + +void VulkanTimerNode::Start() { + stream_->Launch([this](VulkanStreamState* state) { + // Reset the query pool before writing timestamps + vkCmdResetQueryPool(state->cmd_buffer_, query_pool_, start_query_, 2); + vkCmdWriteTimestamp(state->cmd_buffer_, VK_PIPELINE_STAGE_TOP_OF_PIPE_BIT, query_pool_, + start_query_); + }); +} + +void VulkanTimerNode::Stop() { + stream_->Launch([this](VulkanStreamState* state) { + vkCmdWriteTimestamp(state->cmd_buffer_, VK_PIPELINE_STAGE_BOTTOM_OF_PIPE_BIT, query_pool_, + end_query_); + }); + + // Ensure GPU has finished writing timestamps before collecting them + stream_->Synchronize(); + CollectTimestamps(); +} + +int64_t VulkanTimerNode::SyncAndGetElapsedNanos() { return duration_; } + +void VulkanTimerNode::CollectTimestamps() { + uint64_t timestamps[2] = {0}; + + VkResult result = + vkGetQueryPoolResults(device_, query_pool_, 0, 2, sizeof(timestamps), timestamps, + sizeof(uint64_t), VK_QUERY_RESULT_64_BIT | VK_QUERY_RESULT_WAIT_BIT); + + ICHECK(result == VK_SUCCESS) 
<< "Failed to get Vulkan query pool results."; + + // Calculate the duration in nanoseconds; compute in double to avoid losing + // precision (float's 24-bit mantissa truncates long tick counts). + uint64_t diff = timestamps[1] - timestamps[0]; + duration_ = static_cast<int64_t>(static_cast<double>(diff) * timestamp_period_); +} + +void VulkanTimerNode::Cleanup() { + if (query_pool_ != VK_NULL_HANDLE) { + vkDestroyQueryPool(device_, query_pool_, nullptr); + query_pool_ = VK_NULL_HANDLE; + } +} + +} // namespace vulkan +} // namespace runtime +} // namespace tvm diff --git a/src/runtime/vulkan/vulkan_timer.h b/src/runtime/vulkan/vulkan_timer.h new file mode 100644 index 000000000000..9ca420f86a47 --- /dev/null +++ b/src/runtime/vulkan/vulkan_timer.h @@ -0,0 +1,102 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +#ifndef TVM_RUNTIME_VULKAN_VULKAN_TIMER_H_ +#define TVM_RUNTIME_VULKAN_VULKAN_TIMER_H_ +
+#include +#include + +#include "vulkan_device.h" +#include "vulkan_stream.h" + +namespace tvm { +namespace runtime { +namespace vulkan { + +class VulkanDevice; + +/*! + * \brief Timer node for measuring GPU execution time using Vulkan. + * + * This class uses Vulkan timestamp queries to measure the time taken + * by GPU operations between `Start()` and `Stop()` calls. + */ +class VulkanTimerNode : public TimerNode { + public: + /*! 
+ * \brief Constructs a VulkanTimerNode for the specified device. + * \param dev The TVM device to be used for timing. + */ + explicit VulkanTimerNode(Device dev); + + /*! + * \brief Destructor to clean up Vulkan resources. + */ + ~VulkanTimerNode() override; + + /*! + * \brief Starts the timer by recording a timestamp. + */ + void Start() override; + + /*! + * \brief Stops the timer by recording another timestamp. + */ + void Stop() override; + + /*! + * \brief Retrieves the elapsed time in nanoseconds. + * \return The elapsed time in nanoseconds between Start and Stop. + */ + int64_t SyncAndGetElapsedNanos() override; + + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.vulkan.VulkanTimerNode", VulkanTimerNode, TimerNode); + + private: + Device dev_; ///< The TVM device being used. + VkDevice device_{VK_NULL_HANDLE}; ///< The Vulkan device handle. + VulkanStream* stream_{nullptr}; ///< The Vulkan stream for command buffer management. + VkQueryPool query_pool_{VK_NULL_HANDLE}; ///< The Vulkan query pool for timestamp queries. + float timestamp_period_; ///< The period (in nanoseconds) for each timestamp tick. + uint32_t start_query_ = 0; ///< The index for the start timestamp query. + uint32_t end_query_ = 1; ///< The index for the end timestamp query. + int64_t duration_ = 0; ///< The measured duration in nanoseconds. + + /*! + * \brief Creates a Vulkan query pool for timestamp queries. + */ + void CreateQueryPool(); + + /*! + * \brief Collects timestamps and calculates the duration. + */ + void CollectTimestamps(); + + /*! + * \brief Cleans up the Vulkan query pool. 
+ */ + void Cleanup(); +}; + +} // namespace vulkan +} // namespace runtime +} // namespace tvm + +#endif // TVM_RUNTIME_VULKAN_VULKAN_TIMER_H_ diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index 007d6abdbadb..b282fc7f931d 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -25,6 +25,7 @@ #include "../file_utils.h" #include "vulkan_device_api.h" +#include "vulkan_resource.h" namespace tvm { namespace runtime { @@ -45,6 +46,7 @@ void VulkanWrappedFunc::Init(VulkanModuleNode* m, ObjectPtr sptr, void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, const ArgUnion64* pack_args) const { int device_id = VulkanDeviceAPI::Global()->GetActiveDeviceID(); + const auto total_function_args = num_buffer_args_; auto& device = VulkanDeviceAPI::Global()->device(device_id); if (!scache_[device_id]) { scache_[device_id] = m_->GetPipeline(device_id, func_name_, num_pack_args_); @@ -52,15 +54,37 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, const auto& pipeline = scache_[device_id]; ThreadWorkLoad wl = launch_param_config_.Extract(args); std::vector descriptor_buffers; - descriptor_buffers.resize(num_buffer_args_); + std::vector descriptor_images; + + descriptor_buffers.reserve(num_buffer_args_); + descriptor_images.reserve(num_buffer_args_); + for (size_t i = 0; i < num_buffer_args_; ++i) { - void* buf = args[static_cast(i)].cast(); - VkDescriptorBufferInfo binfo; - binfo.buffer = static_cast(buf)->buffer; - binfo.offset = 0; - binfo.range = VK_WHOLE_SIZE; - descriptor_buffers[i] = binfo; + void* res_ = args[static_cast(i)].cast(); + VulkanResource* res = static_cast(res_); + + if (auto* buffer = dynamic_cast(res)) { + VkDescriptorBufferInfo binfo; + binfo.buffer = buffer->buffer; + binfo.offset = 0; + binfo.range = VK_WHOLE_SIZE; + descriptor_buffers.push_back(binfo); + } else if (auto* image = dynamic_cast(res)) { + 
VkDescriptorImageInfo iinfo; + iinfo.imageView = image->imageView; + iinfo.imageLayout = VK_IMAGE_LAYOUT_GENERAL; + descriptor_images.push_back(iinfo); + } + } + + // Check that the total number of descriptors matches num_buffer_args_; a mismatch + // indicates an argument that is neither a VulkanBuffer nor a VulkanImage. + ICHECK_EQ(descriptor_buffers.size() + descriptor_images.size(), num_buffer_args_) + << "Mismatch in the number of function arguments and descriptor sets."; } + const size_t nbytes_scalars = num_pack_args_ * sizeof(ArgUnion64); if (pipeline->use_ubo) { auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); @@ -71,13 +95,39 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, descriptor_buffers.push_back(binfo); } if (device.UseImmediate()) { + std::vector descriptor_data; + descriptor_data.resize(descriptor_buffers.size() * sizeof(VkDescriptorBufferInfo) + + descriptor_images.size() * sizeof(VkDescriptorImageInfo)); + + size_t offset = 0; + size_t buffer_idx = 0, image_idx = 0; + + for (size_t i = 0; i < total_function_args; ++i) { + void* res_ = args[static_cast(i)].cast(); + VulkanResource* res = static_cast(res_); + if (dynamic_cast(res)) { + std::memcpy(descriptor_data.data() + offset, &descriptor_buffers[buffer_idx++], + sizeof(VkDescriptorBufferInfo)); + offset += sizeof(VkDescriptorBufferInfo); + } else if (dynamic_cast(res)) { + std::memcpy(descriptor_data.data() + offset, &descriptor_images[image_idx++], + sizeof(VkDescriptorImageInfo)); + offset += sizeof(VkDescriptorImageInfo); + } + } + + if (pipeline->use_ubo) { + std::memcpy(descriptor_data.data() + offset, &descriptor_buffers[buffer_idx++], + sizeof(VkDescriptorBufferInfo)); + offset += sizeof(VkDescriptorBufferInfo); + } // Can safely capture by reference as this lambda is immediately executed on the calling thread. 
device.ThreadLocalStream().Launch([&](VulkanStreamState* state) { vkCmdBindPipeline(state->cmd_buffer_, VK_PIPELINE_BIND_POINT_COMPUTE, pipeline->pipeline); ICHECK(pipeline->descriptor_update_template != VK_NULL_HANDLE); device.descriptor_template_khr_functions->vkCmdPushDescriptorSetWithTemplateKHR( state->cmd_buffer_, pipeline->descriptor_update_template, pipeline->pipeline_layout, 0, - descriptor_buffers.data()); + descriptor_data.data()); if (pipeline->use_ubo) { auto& ubo = device.ThreadLocalUniformBuffer(nbytes_scalars); @@ -101,7 +151,7 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, if (device.UseDebugUtilsLabel()) { VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, - nullptr, + nullptr, func_name_.c_str(), {0.0f, 0.0f, 0.0f, 0.0f}}; device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT( @@ -113,29 +163,38 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, // Otherwise, the more expensive deferred path. 
std::vector pack_args_storage(pack_args, pack_args + num_pack_args_); - const auto& deferred_initializer = [&device, pipeline, descriptor_buffers]() { + const auto& deferred_initializer = [&device, pipeline, descriptor_buffers, descriptor_images, + args, total_function_args]() { std::vector write_descriptor_sets; - write_descriptor_sets.resize(descriptor_buffers.size()); - for (size_t i = 0; i < write_descriptor_sets.size(); i++) { - write_descriptor_sets[i].sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; - write_descriptor_sets[i].pNext = nullptr; - write_descriptor_sets[i].dstSet = pipeline->descriptor_set; - write_descriptor_sets[i].dstBinding = i; - write_descriptor_sets[i].dstArrayElement = 0; - write_descriptor_sets[i].descriptorCount = 1; - write_descriptor_sets[i].pImageInfo = nullptr; - write_descriptor_sets[i].pBufferInfo = &(descriptor_buffers[i]); - write_descriptor_sets[i].pTexelBufferView = nullptr; - - if (pipeline->use_ubo && i == write_descriptor_sets.size() - 1) { - // The last binding is for UBO - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER; - } else { - write_descriptor_sets[i].descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + write_descriptor_sets.reserve(descriptor_buffers.size() + descriptor_images.size()); + + size_t buffer_idx = 0, image_idx = 0; + // Iterate over the arguments to determine their bindings + for (size_t i = 0; i < total_function_args; ++i) { + void* res_ = args[static_cast(i)].cast(); + VulkanResource* res = static_cast(res_); + VkWriteDescriptorSet write_set = {}; + write_set.sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET; + write_set.pNext = nullptr; + write_set.dstSet = pipeline->descriptor_set; + write_set.dstBinding = i; + write_set.dstArrayElement = 0; + write_set.descriptorCount = 1; + if (dynamic_cast(res)) { + write_set.descriptorType = + (buffer_idx == descriptor_buffers.size() - 1 && pipeline->use_ubo) + ? 
VK_DESCRIPTOR_TYPE_UNIFORM_BUFFER + : VK_DESCRIPTOR_TYPE_STORAGE_BUFFER; + write_set.pBufferInfo = &(descriptor_buffers[buffer_idx++]); + write_descriptor_sets.push_back(write_set); + } else if (dynamic_cast(res)) { + write_set.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE; + write_set.pImageInfo = &(descriptor_images[image_idx++]); + write_descriptor_sets.push_back(write_set); } } vkUpdateDescriptorSets(device, write_descriptor_sets.size(), write_descriptor_sets.data(), 0, - nullptr); + nullptr); }; const auto& deferred_kernel = [this, pipeline, wl, pack_args_storage, nbytes_scalars, device_id](VulkanStreamState* state) { @@ -176,7 +235,7 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, if (device.UseDebugUtilsLabel()) { VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, - nullptr, + nullptr, func_name_.c_str(), {0.0f, 0.0f, 0.0f, 0.0f}}; device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT( @@ -278,8 +337,15 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, tpl.dstArrayElement = 0; tpl.descriptorCount = 1; tpl.descriptorType = desc_type; - tpl.offset = binding * sizeof(VkDescriptorBufferInfo); - tpl.stride = sizeof(VkDescriptorBufferInfo); + + // Choose the appropriate size for image descriptors + if (desc_type == VK_DESCRIPTOR_TYPE_STORAGE_IMAGE) { + tpl.offset = binding * sizeof(VkDescriptorImageInfo); + tpl.stride = sizeof(VkDescriptorImageInfo); + } else { + tpl.offset = binding * sizeof(VkDescriptorBufferInfo); + tpl.stride = sizeof(VkDescriptorBufferInfo); + } arg_template.push_back(tpl); } }; @@ -287,10 +353,17 @@ std::shared_ptr VulkanModuleNode::GetPipeline(size_t device_id, { auto fit = fmap_.find(func_name); ICHECK(fit != fmap_.end()); - for (DLDataType arg_type : fit->second.arg_types) { - if (arg_type.code == kDLOpaqueHandle) { - push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); - ++num_buffer; + const auto& info = fit->second; + + for (size_t i 
= 0; i < info.arg_types.size(); ++i) { + if (info.arg_types[i].code == kDLOpaqueHandle) { + if (runtime::IsTextureStorage(info.storage_scopes[i])) { + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_IMAGE); + ++num_buffer; // Increment num_image here + } else { + push_arg_info(num_buffer, VK_DESCRIPTOR_TYPE_STORAGE_BUFFER); + ++num_buffer; // Increment num_buffer here + } } else { ++num_pod; } diff --git a/src/support/libinfo.cc b/src/support/libinfo.cc index d0646ee8b06f..93d88bc6fe56 100644 --- a/src/support/libinfo.cc +++ b/src/support/libinfo.cc @@ -78,6 +78,10 @@ #define TVM_INFO_USE_VULKAN "NOT-FOUND" #endif +#ifndef TVM_INFO_USE_VULKAN_GTEST +#define TVM_INFO_USE_VULKAN_GTEST "NOT-FOUND" +#endif + #ifndef TVM_INFO_USE_METAL #define TVM_INFO_USE_METAL "NOT-FOUND" #endif @@ -352,6 +356,7 @@ TVM_DLL ffi::Map GetLibInfo() { {"USE_THRUST", TVM_INFO_USE_THRUST}, {"USE_CURAND", TVM_INFO_USE_CURAND}, {"USE_VULKAN", TVM_INFO_USE_VULKAN}, + {"USE_VULKAN_GTEST", TVM_INFO_USE_VULKAN_GTEST}, {"USE_CLML", TVM_INFO_USE_CLML}, {"TVM_CLML_VERSION", TVM_INFO_USE_TVM_CLML_VERSION}, {"USE_CLML_GRAPH_EXECUTOR", TVM_INFO_USE_CLML_GRAPH_EXECUTOR}, diff --git a/src/target/build_common.h b/src/target/build_common.h index cf1e3344fc3c..afc09e47aa1a 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -59,6 +59,14 @@ inline std::unordered_map ExtractFuncInfo(co info.arg_extra_tags.push_back(is_tensormap(f->params[i]) ? 
runtime::FunctionInfo::ArgExtraTags::kTensorMap : runtime::FunctionInfo::ArgExtraTags::kNone); + + // Get the storage scope from the type annotation if available + if (auto* ptr = f->params[i]->type_annotation.as()) { + info.storage_scopes.push_back(std::string(ptr->storage_scope)); + } else { + info.storage_scopes.push_back( + ""); // Use an empty string or "default" if no storage scope is provided + } } if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { for (const auto& tag : opt.value()) { diff --git a/src/target/spirv/build_vulkan.cc b/src/target/spirv/build_vulkan.cc index f71b7ef8d6fa..35f9e333c14f 100644 --- a/src/target/spirv/build_vulkan.cc +++ b/src/target/spirv/build_vulkan.cc @@ -24,6 +24,8 @@ #include +#include + #include "../../runtime/spirv/spirv_shader.h" #include "../../runtime/vulkan/vulkan_module.h" #include "../build_common.h" @@ -43,5 +45,21 @@ TVM_FFI_STATIC_INIT_BLOCK() { [](IRModule mod, Target target) { return BuildSPIRV(mod, target); }); } +ffi::String VulkanDeviceScopeCompatibilityFromTarget(Target target, ffi::String memory_scope) { + auto prototype_keys = target->GetKeys(); + bool is_adreno = + std::find(prototype_keys.begin(), prototype_keys.end(), "adreno") != prototype_keys.end(); + if (is_adreno) { + return ffi::String("global"); + } + return memory_scope; +} + +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("DeviceScopeCompatibility.vulkan", + VulkanDeviceScopeCompatibilityFromTarget); +} + } // namespace codegen } // namespace tvm diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 136f969896f5..a513690ef713 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -30,6 +30,7 @@ #include #include "../../runtime/pack_args.h" +#include "../../runtime/texture.h" #include "../../runtime/vulkan/vulkan_common.h" #include "../../tir/transforms/ir_utils.h" @@ -42,7 +43,9 @@ runtime::SPIRVShader 
CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s this->InitFuncState(); ICHECK(f->HasNonzeroAttr(tir::attr::kNoAlias)) << "SPIRV only takes restricted memory model"; std::vector pod_args; - uint32_t i_buffer = 0; + + // binding for images and buffers + uint32_t binding_index = 0; // Currently, all storage and uniform buffer arguments are passed as // a single descriptor set at index 0. If ever non-zero, must @@ -66,8 +69,16 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s // The loaded byte is cast to bool inside the LoadNode visitor below. value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); } - spirv::Value arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), - descriptor_set, i_buffer++); + + spirv::Value arg_value; // Declare arg_value before the if-else block + if (ptr && runtime::IsTextureStorage(std::string(ptr->storage_scope))) { + arg_value = builder_->StorageImageArgument(arg->name_hint, value_storage_type, 2, 2, + descriptor_set, binding_index++); + } else { + arg_value = builder_->BufferArgument(builder_->GetSType(value_storage_type), descriptor_set, + binding_index++); + } + builder_->SetName(arg_value, arg->name_hint); storage_info_[arg.get()].SetContentType(value_storage_type, arg->name_hint); var_map_[arg.get()] = arg_value; @@ -77,7 +88,6 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s } spirv::Value func_ptr = builder_->NewFunction(); builder_->StartFunction(func_ptr); - runtime::SPIRVShader shader; if (pod_args.size() != 0) { @@ -95,7 +105,8 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s } else { shader.flag |= 1 << runtime::vulkan::ShaderMetaDataFlagMask::kUseUBO; // If we need to pass more arguments than push constants could handle, we use UBO. 
- spirv::Value ptr = builder_->DeclareUniformBuffer(value_types, descriptor_set, i_buffer++); + spirv::Value ptr = + builder_->DeclareUniformBuffer(value_types, descriptor_set, binding_index++); for (size_t i = 0; i < pod_args.size(); ++i) { spirv::Value value = builder_->GetUniform(ptr, value_types[i], static_cast(i)); var_map_[pod_args[i].get()] = value; @@ -511,6 +522,71 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { return builder_->StructArrayAccess(ptr_type, var_map_[buffer_node], MakeValue(index)); } else if (op->op.same_as(builtin::tvm_thread_invariant())) { return MakeValue(op->args[0]); + } else if (op->op.same_as(builtin::texture2d_store())) { + ICHECK_EQ(op->args.size(), 6U); + + // Extract the four arguments and convert them to SPIR-V values + spirv::Value image = MakeValue(op->args[0]); // image + spirv::Value coord_x = MakeValue(op->args[1]); // x-coordinate + spirv::Value coord_y = MakeValue(op->args[2]); // y-coordinate + spirv::Value layer_index = MakeValue(op->args[3]); // layer_index + spirv::Value texel = MakeValue(op->args.back()); + + // Create a composite value representing the coordinates (int3) + spirv::Value coord = + builder_->MakeComposite(builder_->GetSType(DataType::Int(32).with_lanes(3)), // Type: int3 + {coord_x, coord_y, layer_index}); + + spirv::SType image_type = builder_->QuerySType(op->args[0].as()->name_hint); + spirv::Value loaded_image = builder_->MakeValue(spv::OpLoad, image_type, image); + + // Generate the SPIR-V instruction to store the value in the texture + builder_->MakeInst(spv::OpImageWrite, loaded_image, coord, texel); + return spirv::Value(); // No result for image store + + } else if (op->op.same_as(builtin::texture2d_load())) { + ICHECK_EQ(op->args.size(), 6U); + + // Extract the three arguments and convert them to SPIR-V values + spirv::SType image_type = builder_->QuerySType(op->args[0].as()->name_hint); + spirv::Value image = MakeValue(op->args[0]); // image + spirv::Value coord_x = 
MakeValue(op->args[1]); // x-coordinate + spirv::Value coord_y = MakeValue(op->args[2]); // y-coordinate + spirv::Value layer_index = MakeValue(op->args[3]); // layer_index + + // Attempt to create a composite value representing the coordinates (int3) + spirv::Value coord = + builder_->MakeComposite(builder_->GetSType(DataType::Int(32).with_lanes(3)), // Type: int3 + {coord_x, coord_y, layer_index}); + + spirv::Value loaded_image = + builder_->MakeValue(spv::OpLoad, image_type, image); // Load the image handle + spirv::Value image_texel = builder_->MakeValue( + spv::OpImageRead, builder_->GetSType(op->dtype.with_lanes(4)), loaded_image, coord); + + if (op->args.back().as()) { + return image_texel; + } else { + std::vector components; + // Extract the required component from the vector + spirv::SType element_type = + builder_->GetSType(op->dtype.with_lanes(1)); // Scalar type (float) + spirv::Value index = MakeValue(op->args.back()); // Index to extract + spirv::Value component = + builder_->MakeValue(spv::OpVectorExtractDynamic, element_type, image_texel, index); + + if (op->dtype.lanes() > 1) { + // Create a vector by duplicating the extracted component + for (int i = 0; i < op->dtype.lanes(); i++) { + components.push_back(component); + } + // Combine the components into a single vector + return builder_->Concat(components); + } else { + return component; + } + } + } else { LOG(FATAL) << "Unresolved call " << op->op; } diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index bac66a3aacf7..33fca7701288 100644 --- a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -534,6 +534,136 @@ SType IRBuilder::DeclareType(const DataType& dtype, uint32_t row, uint32_t col) } } +void IRBuilder::RegisterSType(const std::string& name, SType var_stype) { + auto it = stype_name_tbl_.find(name); + if (it != stype_name_tbl_.end()) { + LOG(FATAL) << name << " already exists."; + return; + } + stype_name_tbl_[name] = var_stype; +} + +SType 
IRBuilder::QuerySType(const std::string& name) { + auto it = stype_name_tbl_.find(name); + if (it != stype_name_tbl_.end()) { + return it->second; + } + LOG(FATAL) << "Value \"" << name << "\" does not yet exist."; + return SType(); // Return an empty Value (this line may not be reached due to LOG(FATAL)) +} + +bool IRBuilder::CheckSTypeExistence(const std::string& name) { + return stype_name_tbl_.find(name) != stype_name_tbl_.end(); +} + +spv::ImageFormat IRBuilder::GetImageFormat(const DataType& dtype, int channels) { + // Handle float formats + if (dtype.is_float()) { + switch (dtype.bits()) { + case 32: + if (channels == 1) return spv::ImageFormatR32f; + if (channels == 2) return spv::ImageFormatRg32f; + if (channels == 4) return spv::ImageFormatRgba32f; + break; + case 16: + if (channels == 1) return spv::ImageFormatR16f; + if (channels == 2) return spv::ImageFormatRg16f; + if (channels == 4) return spv::ImageFormatRgba16f; + break; + default: + return spv::ImageFormatUnknown; + } + } else if (dtype.is_int()) { + switch (dtype.bits()) { + case 32: + if (channels == 1) return spv::ImageFormatR32i; + if (channels == 2) return spv::ImageFormatRg32i; + if (channels == 4) return spv::ImageFormatRgba32i; + break; + case 16: + if (channels == 1) return spv::ImageFormatR16i; + if (channels == 2) return spv::ImageFormatRg16i; + if (channels == 4) return spv::ImageFormatRgba16i; + break; + case 8: + if (channels == 1) return spv::ImageFormatR8i; + if (channels == 2) return spv::ImageFormatRg8i; + if (channels == 4) return spv::ImageFormatRgba8i; + break; + default: + return spv::ImageFormatUnknown; + } + } + return spv::ImageFormatUnknown; +} + +SType IRBuilder::GetStorageImageSType(const DataType& dtype, int num_dimensions, uint32_t sampled) { + // Get the appropriate SPIR-V ImageFormat using the dtype + spv::ImageFormat spv_format = GetImageFormat(dtype, 4); + + // get SPIR-V type for image + SType value_type = GetSType(dtype); + if (spv_format == 
spv::ImageFormatUnknown) { + LOG(FATAL) << "Unsupported image format for dtype: " << dtype; + } + + // Create a key to cache and reuse image types + auto key = std::make_tuple(spv_format, num_dimensions, sampled); + auto it = storage_image_ptr_tbl_.find(key); + if (it != storage_image_ptr_tbl_.end()) { + return it->second; + } + + // Determine the SPIR-V dimension based on the number of dimensions + spv::Dim dim; + if (num_dimensions == 1) { + dim = spv::Dim1D; + } else if (num_dimensions == 2) { + dim = spv::Dim2D; + } else if (num_dimensions == 3) { + dim = spv::Dim3D; + } else { + LOG(FATAL) << "Unsupported number of dimensions: " << num_dimensions; + } + + // Generate a unique ID for the new image type + int img_id = id_counter_++; + + // Declare the SPIR-V image type + ib_.Begin(spv::OpTypeImage) + .AddSeq(img_id, value_type, dim, + /*Depth=*/0, /*Arrayed=*/1, /*MS=*/0, /*Sampled=*/sampled, spv_format) + .Commit(&global_); + + // Create and cache the new image type + SType img_t; + img_t.id = img_id; + img_t.element_type_id = value_type.id; + storage_image_ptr_tbl_[key] = img_t; + return img_t; +} + +Value IRBuilder::StorageImageArgument(const std::string& name, const DataType& dtype, + int num_dimensions, uint32_t sampled, uint32_t descriptor_set, + uint32_t binding) { + auto texture_type = GetStorageImageSType(dtype, num_dimensions, sampled); + auto texture_ptr_type = GetPointerType(texture_type, spv::StorageClassUniformConstant); + + // Store the type in the map + RegisterSType(name, texture_type); + Value val = NewValue(texture_ptr_type, kVariablePtr); + + // Variable declaration + ib_.Begin(spv::OpVariable) + .AddSeq(texture_ptr_type, val, spv::StorageClassUniformConstant) + .Commit(&global_); + + // Decorate the image argument + this->Decorate(spv::OpDecorate, val, spv::DecorationDescriptorSet, descriptor_set); + this->Decorate(spv::OpDecorate, val, spv::DecorationBinding, binding); + return val; +} + void IRBuilder::AddCapabilityFor(const DataType& 
dtype) { // Declare appropriate capabilities for int/float types if (dtype.is_int() || dtype.is_uint()) { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 5df779c59547..5f06c9025c09 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -63,6 +63,7 @@ enum ValueKind { kVectorPtr, kStructArrayPtr, kPushConstantPtr, + kVariablePtr, kFunction, kExtInst, kUniformPtr, @@ -607,6 +608,47 @@ class IRBuilder { */ Value GetSpecConst(const SType& dtype, uint64_t value); + SType f32_type() const { return t_fp32_; } + + SType i32_type() const { return t_int32_; } + + SType f32_v4_type() const { return t_v4_fp32_; } + + // Register name to corresponding Value/VariablePointer + void RegisterSType(const std::string& name, SType var_stype); + // Query Value/VariablePointer by name + SType QuerySType(const std::string& name); + // Check whether a value has been evaluated + bool CheckSTypeExistence(const std::string& name); + + Value MakeComposite(const SType& composite_type, const std::vector& constituents) { + // Create a new SSA value for the composite type + Value composite_value = NewValue(composite_type, kNormal); + + // Begin the OpCompositeConstruct instruction + ib_.Begin(spv::OpCompositeConstruct) + .Add(composite_type) // The type of the composite + .Add(composite_value); // The resulting value + + // Add each constituent value + for (const Value& val : constituents) { + ib_.Add(val); + } + + // Commit the instruction to the function segment + ib_.Commit(&function_); + + // Return the composite value + return composite_value; + } + + spv::ImageFormat GetImageFormat(const DataType& dtype, int channels); + + Value StorageImageArgument(const std::string& name, const DataType& dtype, int num_dimensions, + uint32_t sampled, uint32_t descriptor_set, uint32_t binding); + + SType GetStorageImageSType(const DataType& dtype, int num_dimensions, uint32_t sampled); + private: /*! 
* \brief Create new value @@ -679,7 +721,8 @@ class IRBuilder { /*! \brief glsl 450 extension */ Value ext_glsl450_; /*! \brief Special cache int32, fp32, void*/ - SType t_bool_, t_int32_, t_uint32_, t_fp32_, t_void_, t_void_func_; + SType t_bool_, t_int32_, t_uint32_, t_v2_int_, t_fp16_, t_fp32_, t_v4_fp32_, t_void_, + t_void_func_; /*! \brief quick cache for const one i32 */ Value const_i32_zero_; @@ -723,6 +766,9 @@ class IRBuilder { /*! \brief map from name of a ExtInstImport to its value */ std::map ext_inst_tbl_; + std::map, SType> storage_image_ptr_tbl_; + std::unordered_map stype_name_tbl_; + /*! \brief Header segment * * 5 words long, described in "First Words of Physical Layout" diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc index 96e90f17ac79..0e4af4b12b43 100644 --- a/src/target/target_kind.cc +++ b/src/target/target_kind.cc @@ -393,10 +393,10 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("supports_int8") .add_attr_option("supports_int16") .add_attr_option("supports_int32", true) - .add_attr_option("supports_int64") + .add_attr_option("supports_int64", true) .add_attr_option("supports_8bit_buffer") .add_attr_option("supports_16bit_buffer") - .add_attr_option("supports_storage_buffer_storage_class") + .add_attr_option("supports_storage_buffer_storage_class", true) .add_attr_option("supports_push_descriptor") .add_attr_option("supports_dedicated_allocation") .add_attr_option("supports_integer_dot_product") @@ -406,6 +406,8 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("max_num_threads", 256) .add_attr_option("max_threads_per_block", 256) .add_attr_option("thread_warp_size", 1) + .add_attr_option("texture_spatial_limit", 16384) + .add_attr_option("texture_depth_limit", 2048) .add_attr_option("max_block_size_x") .add_attr_option("max_block_size_y") .add_attr_option("max_block_size_z") @@ -421,6 +423,7 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan) .add_attr_option("driver_version") 
.add_attr_option("vulkan_api_version") .add_attr_option("max_spirv_version") + .add_attr_option("image_base_address_alignment", 64) // Tags .set_default_keys({"vulkan", "gpu"}); diff --git a/tests/cpp-runtime/vulkan/texture_copy_test.cc b/tests/cpp-runtime/vulkan/texture_copy_test.cc new file mode 100644 index 000000000000..0762d73e1672 --- /dev/null +++ b/tests/cpp-runtime/vulkan/texture_copy_test.cc @@ -0,0 +1,158 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include +#include +#include +#include +#include + +#include +#include + +#include "../src/runtime/vulkan/vulkan_device_api.h" + +using tvm::runtime::memory::AllocatorType; +using tvm::runtime::memory::MemoryManager; +using tvm::runtime::memory::Storage; + +class VulkanTextureCopyTest : public ::testing::Test { + protected: + void SetUp() override { + bool enabled = tvm::runtime::RuntimeEnabled("vulkan"); + if (!enabled) { + GTEST_SKIP() << "Skip texture copy test because Vulkan runtime is disabled.\n"; + } + } +}; + +TEST_F(VulkanTextureCopyTest, ViewBufferAsBuffer) { + using namespace tvm; + std::vector shape{1, 16, 16, 8}; + std::vector same_shape{1, 8, 16, 16}; + auto cpu_arr = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr_ret = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + ffi::String mem_scope = "global"; + + DLDevice cl_dev = {kDLVulkan, 0}; + auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); + auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); + auto stor = Storage(buffer, allocator); + + auto vulkan_memobj = stor->AllocTensorScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, mem_scope); + auto vulkan_memview = + stor->AllocTensorScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, mem_scope); + + std::random_device dev; + std::mt19937 mt(dev()); + std::uniform_real_distribution<> random(-10.0, 10.0); + + size_t size = 1; + for (size_t i = 0; i < shape.size(); ++i) { + size *= static_cast(shape[i]); + } + + /* Check original object round trip */ + // Random initialize host pool storage + for (size_t i = 0; i < size; i++) { + static_cast(cpu_arr->data)[i] = random(mt); + } + // Copy to VulkanBuffer + cpu_arr.CopyTo(vulkan_memobj); + // Copy from VulkanBuffer + vulkan_memobj.CopyTo(cpu_arr_ret); + for (size_t i = 0; i < size; i++) { + ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); + } + + 
/* Check view object round trip */ + // Random initialize host pool storage + for (size_t i = 0; i < size; i++) { + static_cast(cpu_arr->data)[i] = random(mt); + } + // Copy to VulkanBuffer + cpu_arr.CopyTo(vulkan_memview); + // Copy from VulkanBuffer + vulkan_memview.CopyTo(cpu_arr_ret); + for (size_t i = 0; i < size; i++) { + ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); + } +} + +TEST_F(VulkanTextureCopyTest, ViewBufferAsImage) { + using namespace tvm; + // Shape that doesn't cause padding for image row + std::vector shape{1, 16, 16, 8, 4}; + std::vector same_shape{1, 8, 16, 16, 4}; + auto cpu_arr = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + auto cpu_arr_ret = runtime::Tensor::Empty(shape, {kDLFloat, 32, 1}, {kDLCPU, 0}); + + DLDevice cl_dev = {kDLVulkan, 0}; + auto allocator = MemoryManager::GetOrCreateAllocator(cl_dev, AllocatorType::kPooled); + auto buffer = allocator->Alloc(cl_dev, ffi::Shape(shape), {kDLFloat, 32, 1}); + auto stor = Storage(buffer, allocator); + + auto vulkan_buf_obj = stor->AllocTensorScoped(0, ffi::Shape(shape), {kDLFloat, 32, 1}, "global"); + auto vulkan_img_obj = + stor->AllocTensorScoped(0, ffi::Shape(same_shape), {kDLFloat, 32, 1}, "global.texture"); + + std::random_device dev; + std::mt19937 mt(dev()); + std::uniform_real_distribution<> random(-10.0, 10.0); + + size_t size = 1; + for (size_t i = 0; i < shape.size(); ++i) { + size *= static_cast(shape[i]); + } + + /* Check original object round trip */ + // Random initialize host pool storage + for (size_t i = 0; i < size; i++) { + static_cast(cpu_arr->data)[i] = random(mt); + } + // Copy to VulkanBuffer + cpu_arr.CopyTo(vulkan_buf_obj); + // Copy from VulkanBuffer + vulkan_buf_obj.CopyTo(cpu_arr_ret); + for (size_t i = 0; i < size; i++) { + ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); + } + + /* Check view object round trip */ + // Random initialize host 
pool storage + for (size_t i = 0; i < size; i++) { + static_cast(cpu_arr->data)[i] = random(mt); + } + // Copy to VulkanBuffer + cpu_arr.CopyTo(vulkan_img_obj); + // Copy from VulkanBuffer + vulkan_img_obj.CopyTo(cpu_arr_ret); + for (size_t i = 0; i < size; i++) { + ICHECK_LT(std::fabs(static_cast(cpu_arr->data)[i] - + static_cast(cpu_arr_ret->data)[i]), + 1e-5); + } +} diff --git a/tests/cpp-runtime/vulkan/vulkan_timer_test.cc b/tests/cpp-runtime/vulkan/vulkan_timer_test.cc new file mode 100644 index 000000000000..b02f6c9ceec9 --- /dev/null +++ b/tests/cpp-runtime/vulkan/vulkan_timer_test.cc @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +#include "../src/runtime/vulkan/vulkan_timer.h" + +#include +#include + +#include "../src/runtime/vulkan/vulkan_device_api.h" + +using namespace tvm::runtime; +using namespace tvm::runtime::vulkan; + +#define BUFF_SIZE 1024 +#define NUM_REPEAT 10 + +TEST(VulkanTimerNode, TimerCorrectness) { + VulkanDeviceAPI* api = VulkanDeviceAPI::Global(); + auto device_id = api->GetActiveDeviceID(); + tvm::Device dev{kDLVulkan, device_id}; + + constexpr int32_t kBufferSize = 1024; + Tensor src = Tensor::Empty({kBufferSize}, {kDLInt, 32, 1}, {kDLCPU, 0}); + Tensor dst = Tensor::Empty({kBufferSize}, {kDLInt, 32, 1}, {kDLVulkan, device_id}); + + // Fill CPU array with dummy data + for (int32_t i = 0; i < kBufferSize; ++i) { + static_cast(src->data)[i] = i; + } + + // Create a Timer + Timer timer = Timer::Start(dev); + + // Perform a CPU -> Vulkan copy + src.CopyTo(dst); + + // Important: Force Vulkan to flush and sync work + api->StreamSync(dev, nullptr); + + timer->Stop(); + int64_t elapsed_nanos = timer->SyncAndGetElapsedNanos(); + + std::cout << "Elapsed time (nanoseconds): " << elapsed_nanos << std::endl; + + // Check that some time was measured + ASSERT_GT(elapsed_nanos, 0); +} diff --git a/tests/python/relax/texture/test_ops.py b/tests/python/relax/texture/test_ops.py index 7011852aa3ab..5f85b2040837 100644 --- a/tests/python/relax/texture/test_ops.py +++ b/tests/python/relax/texture/test_ops.py @@ -23,8 +23,8 @@ from adreno_utils import verify -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d(target): @I.ir_module class Input: @@ -40,8 +40,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_relu(target): @I.ir_module class Input: @@ -58,8 +58,8 @@ def main( verify(Input, 
target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_relu_conv2d_relu(target): @I.ir_module class Input: @@ -77,8 +77,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_relu_tanh(target): @I.ir_module class Input: @@ -96,8 +96,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_add(target): @I.ir_module class Input: @@ -116,8 +116,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_sum(target): @I.ir_module class Input: @@ -134,8 +134,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_sum_keepdims(target): @I.ir_module class Input: @@ -152,8 +152,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_sum_reduce(target): @I.ir_module class Input: @@ -170,8 +170,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_transpose(target): @I.ir_module class Input: @@ -188,8 +188,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl 
-@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_expand_dims(target): @I.ir_module class Input: @@ -206,8 +206,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_squeeze(target): @I.ir_module class Input: @@ -224,8 +224,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_strided_slice(target): @I.ir_module class Input: @@ -244,8 +244,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_relu_concat(target): @I.ir_module class Input: @@ -263,8 +263,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_relu_concat_split(target): @I.ir_module class Input: @@ -283,8 +283,8 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_relu_concat_split_transpose_concat(target): @I.ir_module class Input: @@ -304,8 +304,8 @@ def main(x: R.Tensor((2, 16, 28, 28), "float32"), w: R.Tensor((4, 16, 3, 3), "fl verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def 
test_conv2d_maxpool2d(target): @I.ir_module class Input: @@ -329,8 +329,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_avgpool2d(target): @I.ir_module class Input: @@ -347,8 +347,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_softmax(target): @I.ir_module class Input: @@ -365,8 +365,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_layernorm(target): @I.ir_module class Input: @@ -388,8 +388,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_binary_broadcast(target): @I.ir_module class Input: @@ -408,8 +408,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_binary_ewise_scalar(target): @I.ir_module class Input: @@ -426,8 +426,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_residual_block(target): """ - some kind of residual block followed by convolution to have texture after residual block @@ -474,8 +474,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", 
"vulkan") def test_conv2d_conv2d_fallback_to_buffer_conv2d(target): """ layout_transform (NCHW->NCHW4c) @@ -512,11 +512,11 @@ def main( R.output(gv7) return gv7 - verify(Input, "opencl") + verify(Input, "opencl", "vulkan") -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_conv2d_conv2d_conv2d_concat(target): """ layout_transform (NCHW->NCHW4c) @@ -553,11 +553,11 @@ def main( R.output(gv7) return gv7 - verify(Input, "opencl") + verify(Input, "opencl", "vulkan") -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_pooling_branching_texture_params(target): """ Verification of the pooling and many branches having textures @@ -610,8 +610,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_injective_inputs1(target): """ Input @@ -659,8 +659,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_injective_nwo_inputs2(target): """ Input diff --git a/tests/scripts/ci.py b/tests/scripts/ci.py index 96f0a5f7edaa..9b0c4101437a 100755 --- a/tests/scripts/ci.py +++ b/tests/scripts/ci.py @@ -593,6 +593,7 @@ def add_subparser( [ "./tests/scripts/task_java_unittest.sh", "./tests/scripts/task_opencl_cpp_unittest.sh {build_dir}", + "./tests/scripts/task_vulkan_cpp_unittest.sh {build_dir}", "./tests/scripts/task_python_unittest_gpuonly.sh", "./tests/scripts/task_python_integration_gpuonly.sh", ], diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh index 10fefefbe800..424e11280ed5 100755 
--- a/tests/scripts/task_config_build_adreno.sh +++ b/tests/scripts/task_config_build_adreno.sh @@ -24,6 +24,8 @@ cd "$BUILD_DIR" cp ../cmake/config.cmake . echo set\(USE_OPENCL_GTEST /googletest\) >> config.cmake +#echo set\(USE_VULKAN_GTEST /googletest\) >> config.cmake + if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then echo set\(USE_CLML ${ADRENO_OPENCL}\) >> config.cmake fi @@ -31,3 +33,4 @@ echo set\(USE_OPENCL ON\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake echo set\(USE_LLVM ON\) >> config.cmake +#echo set\(USE_VULKAN ON\) >> config.cmake diff --git a/tests/scripts/task_config_build_gpu.sh b/tests/scripts/task_config_build_gpu.sh index f306bdf8bf74..a71a2cfedfe4 100755 --- a/tests/scripts/task_config_build_gpu.sh +++ b/tests/scripts/task_config_build_gpu.sh @@ -29,6 +29,7 @@ echo set\(USE_CUDA ON\) >> config.cmake echo set\(USE_VULKAN ON\) >> config.cmake echo set\(USE_OPENCL ON\) >> config.cmake echo set\(USE_OPENCL_GTEST \"/googletest\"\) >> config.cmake +echo set\(USE_VULKAN_GTEST \"/googletest\"\) >> config.cmake echo set\(USE_LLVM \"/usr/bin/llvm-config-15 --link-static\"\) >> config.cmake echo set\(USE_RPC ON\) >> config.cmake echo set\(USE_SORT ON\) >> config.cmake diff --git a/tests/scripts/task_python_adreno.sh b/tests/scripts/task_python_adreno.sh index b381fddc2427..ceb045b06657 100755 --- a/tests/scripts/task_python_adreno.sh +++ b/tests/scripts/task_python_adreno.sh @@ -18,8 +18,9 @@ set -euxo pipefail -export TVM_TEST_TARGETS="opencl" -export TVM_RELAY_OPENCL_TEXTURE_TARGETS="opencl -device=adreno" +export TVM_TEST_TARGETS="opencl;vulkan" +export TVM_RELAX_TEXTURE_TARGETS="opencl -device=adreno;vulkan -device=adreno" + source tests/scripts/setup-pytest-env.sh export LD_LIBRARY_PATH="build:${LD_LIBRARY_PATH:-}" diff --git a/tests/scripts/task_vulkan_cpp_unittest.sh b/tests/scripts/task_vulkan_cpp_unittest.sh new file mode 100644 index 000000000000..79debce955bc --- /dev/null 
+++ b/tests/scripts/task_vulkan_cpp_unittest.sh @@ -0,0 +1,38 @@ +#!/usr/bin/env bash +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +set -euxo pipefail + +if [ $# -gt 0 ]; then + BUILD_DIR="$1" +elif [ -n "${TVM_BUILD_PATH:-}" ]; then + # TVM_BUILD_PATH may contain multiple space-separated paths. If + # so, use the first one. + BUILD_DIR=$(IFS=" "; set -- $TVM_BUILD_PATH; echo $1) +else + BUILD_DIR=build +fi + +# to avoid CI thread throttling. +export TVM_BIND_THREADS=0 +export OMP_NUM_THREADS=1 + +pushd "${BUILD_DIR}" +# run cpp test executable +./vulkan-cpptest +popd From 1f78997be4994f499dce82f25c2b209c3cbf74d6 Mon Sep 17 00:00:00 2001 From: Siva Date: Mon, 19 Jan 2026 14:44:12 +0530 Subject: [PATCH 2/7] review. 
--- cmake/utils/FindVulkan.cmake | 1 - src/runtime/vulkan/vulkan_stream.h | 2 +- src/runtime/vulkan/vulkan_timer.h | 2 +- src/runtime/vulkan/vulkan_wrapped_func.cc | 4 ++-- 4 files changed, 4 insertions(+), 5 deletions(-) diff --git a/cmake/utils/FindVulkan.cmake b/cmake/utils/FindVulkan.cmake index c01d019e14d8..b1b9693a2f6c 100644 --- a/cmake/utils/FindVulkan.cmake +++ b/cmake/utils/FindVulkan.cmake @@ -82,7 +82,6 @@ macro(find_vulkan use_vulkan use_khronos_spirv) else() - message(STATUS "__vulkan_sdk:- " ${__vulkan_sdk}) if(__vulkan_sdk) set(Vulkan_INCLUDE_DIRS ${__vulkan_sdk}/include) find_library(Vulkan_LIBRARY NAMES vulkan vulkan-1 PATHS ${__vulkan_sdk}/lib) diff --git a/src/runtime/vulkan/vulkan_stream.h b/src/runtime/vulkan/vulkan_stream.h index c8e08f5d57b8..89d79d8a09aa 100644 --- a/src/runtime/vulkan/vulkan_stream.h +++ b/src/runtime/vulkan/vulkan_stream.h @@ -117,7 +117,7 @@ class VulkanStream { // Synchronize the current stream `state_` with respect to the host. void Synchronize(); - VkCommandPool CommanPool() const { return cmd_pool_; } + VkCommandPool CommandPool() const { return cmd_pool_; } private: const VulkanDevice* device_; diff --git a/src/runtime/vulkan/vulkan_timer.h b/src/runtime/vulkan/vulkan_timer.h index 9ca420f86a47..7636fcd812bf 100644 --- a/src/runtime/vulkan/vulkan_timer.h +++ b/src/runtime/vulkan/vulkan_timer.h @@ -67,7 +67,7 @@ class VulkanTimerNode : public TimerNode { */ int64_t SyncAndGetElapsedNanos() override; - TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.opencl.VulkanTimerNode", VulkanTimerNode, TimerNode); + TVM_FFI_DECLARE_OBJECT_INFO_FINAL("runtime.vulkan.VulkanTimerNode", VulkanTimerNode, TimerNode); private: Device dev_; ///< The TVM device being used. 
diff --git a/src/runtime/vulkan/vulkan_wrapped_func.cc b/src/runtime/vulkan/vulkan_wrapped_func.cc index b282fc7f931d..8d78aa7da85a 100644 --- a/src/runtime/vulkan/vulkan_wrapped_func.cc +++ b/src/runtime/vulkan/vulkan_wrapped_func.cc @@ -151,7 +151,7 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, if (device.UseDebugUtilsLabel()) { VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, - NULL, + nullptr, func_name_.c_str(), {0.0f, 0.0f, 0.0f, 0.0f}}; device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT( @@ -235,7 +235,7 @@ void VulkanWrappedFunc::operator()(ffi::PackedArgs args, ffi::Any* rv, if (device.UseDebugUtilsLabel()) { VkDebugUtilsLabelEXT dispatch_label = {VK_STRUCTURE_TYPE_DEBUG_UTILS_LABEL_EXT, - NULL, + nullptr, func_name_.c_str(), {0.0f, 0.0f, 0.0f, 0.0f}}; device.queue_insert_debug_utils_label_functions->vkQueueInsertDebugUtilsLabelEXT( From 1c9aa97ab0cae5b511806f14550667ba2757af12 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 20 Jan 2026 09:16:45 +0530 Subject: [PATCH 3/7] Targets to test. --- python/tvm/testing/utils.py | 2 +- tests/scripts/task_python_unittest_gpuonly.sh | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py index b2d685a49e7e..4167759a8caf 100644 --- a/python/tvm/testing/utils.py +++ b/python/tvm/testing/utils.py @@ -1000,7 +1000,7 @@ def _check_opencl_vulkan(): "opencl_vulkan", "OpenCL or Vulkan", run_time_check=_check_opencl_vulkan, - parent_features=["opencl", "vulkan"], + parent_features=["opencl", "gpu"], ) # Mark a test as requiring NNAPI support in build. 
diff --git a/tests/scripts/task_python_unittest_gpuonly.sh b/tests/scripts/task_python_unittest_gpuonly.sh index 776d29fda07f..217c6dbd2755 100755 --- a/tests/scripts/task_python_unittest_gpuonly.sh +++ b/tests/scripts/task_python_unittest_gpuonly.sh @@ -38,7 +38,7 @@ run_pytest ${TVM_UNITTEST_TESTSUITE_NAME}-1 tests/python/codegen/test_target_cod # Adreno : A comprehensive Texture tests on Nvidia GPU and clml codegen tests. export PYTEST_ADDOPTS="" -export TVM_TEST_TARGETS="opencl" +export TVM_TEST_TARGETS="opencl;vulkan -from_device=0" export TVM_UNITTEST_TESTSUITE_NAME=python-codegen-clml-texture source tests/scripts/setup-pytest-env.sh From 4ddd4efbbb50a0e312fd5720088b962c0a4de133 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 20 Jan 2026 15:27:03 +0530 Subject: [PATCH 4/7] Vulkan tests and Vulkan codegen issue fixed. --- python/tvm/tir/pipeline.py | 2 +- src/target/spirv/codegen_spirv.cc | 6 +++--- tests/python/relax/texture/test_ops.py | 4 ++-- 3 files changed, 6 insertions(+), 6 deletions(-) diff --git a/python/tvm/tir/pipeline.py b/python/tvm/tir/pipeline.py index 1ee4a5b1d315..8dd563752a95 100644 --- a/python/tvm/tir/pipeline.py +++ b/python/tvm/tir/pipeline.py @@ -178,7 +178,7 @@ def get_default_tir_pipeline( target: tvm.target.Target, # pylint: disable=unused-argument ) -> tvm.transform.Pass: """Get the default TIR pipeline for the given target.""" - if target.kind.name == "opencl" and "adreno" in target.keys: + if target.kind.name in ["opencl", "vulkan"] and "adreno" in target.keys: return backend.adreno.get_tir_pipeline(target) else: return default_tir_pipeline() diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index a513690ef713..9918803571a3 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -752,13 +752,12 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { spirv::Value init_value = MakeValue(op->min); PrimExpr end = is_zero(op->min) ?
op->extent : analyzer_->Simplify(op->min + op->extent); spirv::Value end_value = MakeValue(end); - spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); // loop step spirv::Value step; if (op->HasTrivialStep()) { - step = op->loop_var.dtype().is_int() ? builder_->IntImm(loop_var.stype, 1) - : builder_->UIntImm(loop_var.stype, 1); + step = op->loop_var.dtype().is_int() ? builder_->IntImm(init_value.stype, 1) + : builder_->UIntImm(init_value.stype, 1); } else { step = MakeValue(tvm::cast(end->dtype, *op->step)); } @@ -777,6 +776,7 @@ void CodeGenSPIRV::VisitStmt_(const ForNode* op) { // Loop head builder_->StartLabel(head_label); + spirv::PhiValue loop_var = builder_->MakePhi(init_value.stype, 2); loop_var.SetIncoming(0, init_value, init_label); spirv::Value loop_cond = builder_->LT(loop_var, end_value); uint32_t control = diff --git a/tests/python/relax/texture/test_ops.py b/tests/python/relax/texture/test_ops.py index 5f85b2040837..a8abd22b90b4 100644 --- a/tests/python/relax/texture/test_ops.py +++ b/tests/python/relax/texture/test_ops.py @@ -512,7 +512,7 @@ def main( R.output(gv7) return gv7 - verify(Input, "opencl", "vulkan") + verify(Input, target) @tvm.testing.requires_opencl_vulkan @@ -553,7 +553,7 @@ def main( R.output(gv7) return gv7 - verify(Input, "opencl", "vulkan") + verify(Input, target) @tvm.testing.requires_opencl_vulkan From 66b90acccd3a9decc8eea7efcf442b9706c59cee Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 20 Jan 2026 15:38:37 +0530 Subject: [PATCH 5/7] tests --- tests/python/relax/texture/test_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/texture/test_ops.py b/tests/python/relax/texture/test_ops.py index a8abd22b90b4..bbd593bc97e2 100644 --- a/tests/python/relax/texture/test_ops.py +++ b/tests/python/relax/texture/test_ops.py @@ -474,8 +474,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl_vulkan -@tvm.testing.parametrize_targets("opencl", "vulkan") 
+@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl") def test_conv2d_conv2d_fallback_to_buffer_conv2d(target): """ layout_transform (NCHW->NCHW4c) @@ -515,8 +515,8 @@ def main( verify(Input, target) -@tvm.testing.requires_opencl_vulkan -@tvm.testing.parametrize_targets("opencl", "vulkan") +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl") def test_conv2d_conv2d_conv2d_concat(target): """ layout_transform (NCHW->NCHW4c) From a283dcbb8d66dbbb9f09d82b6167d5d9c718c784 Mon Sep 17 00:00:00 2001 From: Siva Date: Tue, 20 Jan 2026 20:22:09 +0530 Subject: [PATCH 6/7] tests --- tests/python/relax/texture/test_network.py | 4 ++-- tests/python/relax/texture/test_ops.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python/relax/texture/test_network.py b/tests/python/relax/texture/test_network.py index aeb4f1248c10..7321d3827120 100644 --- a/tests/python/relax/texture/test_network.py +++ b/tests/python/relax/texture/test_network.py @@ -37,8 +37,8 @@ import copy -@tvm.testing.requires_opencl -@tvm.testing.parametrize_targets("opencl") +@tvm.testing.requires_opencl_vulkan +@tvm.testing.parametrize_targets("opencl", "vulkan") def test_network_resnet(target): @I.ir_module class Resnet: diff --git a/tests/python/relax/texture/test_ops.py b/tests/python/relax/texture/test_ops.py index bbd593bc97e2..7c0f523b42ba 100644 --- a/tests/python/relax/texture/test_ops.py +++ b/tests/python/relax/texture/test_ops.py @@ -23,8 +23,8 @@ from adreno_utils import verify -@tvm.testing.requires_opencl_vulkan -@tvm.testing.parametrize_targets("opencl", "vulkan") +@tvm.testing.requires_opencl +@tvm.testing.parametrize_targets("opencl") def test_conv2d(target): @I.ir_module class Input: From 35c37cba6377603997ce33b1a50f0ac294649022 Mon Sep 17 00:00:00 2001 From: ramashin Date: Thu, 22 Jan 2026 12:00:35 +0530 Subject: [PATCH 7/7] Lint and Vulkan image fix --- src/runtime/vulkan/vulkan_buffer.h | 1 - 
src/runtime/vulkan/vulkan_image.cc | 18 ------------------ src/runtime/vulkan/vulkan_image.h | 4 ---- src/target/build_common.h | 3 +-- src/target/spirv/codegen_spirv.cc | 14 +++++++------- src/target/spirv/ir_builder.cc | 2 +- src/target/spirv/ir_builder.h | 6 ++---- 7 files changed, 11 insertions(+), 37 deletions(-) diff --git a/src/runtime/vulkan/vulkan_buffer.h b/src/runtime/vulkan/vulkan_buffer.h index d5cbd2149dc5..464829489f5c 100644 --- a/src/runtime/vulkan/vulkan_buffer.h +++ b/src/runtime/vulkan/vulkan_buffer.h @@ -54,7 +54,6 @@ class VulkanBuffer : public VulkanResource { uint32_t mem_type_index, std::optional mem_scope = std::nullopt, std::shared_ptr back_memory = nullptr); - //! \brief Destructor, deallocates the memory and buffer. ~VulkanBuffer(); // Forbid copy assignment/constructor diff --git a/src/runtime/vulkan/vulkan_image.cc b/src/runtime/vulkan/vulkan_image.cc index d6f962814e4e..8a6f731a2bdb 100644 --- a/src/runtime/vulkan/vulkan_image.cc +++ b/src/runtime/vulkan/vulkan_image.cc @@ -96,24 +96,6 @@ VulkanImage::VulkanImage(VulkanImage&& other) other.imageView = VK_NULL_HANDLE; } -uint32_t VulkanImage::FindMemoryTypeForImage(const VulkanDevice& device, - VkMemoryPropertyFlags properties, - uint32_t typeFilter) { - VkPhysicalDeviceMemoryProperties memProperties; - VkPhysicalDevice physicalDeviceHandle = - device; // Implicit conversion using operator VkPhysicalDevice() - vkGetPhysicalDeviceMemoryProperties(physicalDeviceHandle, &memProperties); - - for (uint32_t i = 0; i < memProperties.memoryTypeCount; i++) { - if ((typeFilter & (1 << i)) && - (memProperties.memoryTypes[i].propertyFlags & properties) == properties) { - return i; - } - } - - throw std::runtime_error("Failed to find suitable memory type!"); -} - VulkanImage& VulkanImage::operator=(VulkanImage&& other) { std::swap(device_, other.device_); std::swap(image, other.image); diff --git a/src/runtime/vulkan/vulkan_image.h b/src/runtime/vulkan/vulkan_image.h index 
a83875534c13..ddea19184b55 100644 --- a/src/runtime/vulkan/vulkan_image.h +++ b/src/runtime/vulkan/vulkan_image.h @@ -60,7 +60,6 @@ class VulkanImage : public VulkanResource { std::optional mem_scope = std::nullopt, std::shared_ptr back_memory = nullptr); - //! \brief Destructor, deallocates the memory, image, and image view. ~VulkanImage(); // Forbid copy assignment/constructor @@ -75,9 +74,6 @@ class VulkanImage : public VulkanResource { void CreateImageView(VkFormat format); - inline uint32_t FindMemoryTypeForImage(const VulkanDevice& device, - VkMemoryPropertyFlags properties, uint32_t typeFilter); - private: /*! * \brief Whether this image should be allocated using dedicated allocation diff --git a/src/target/build_common.h b/src/target/build_common.h index afc09e47aa1a..190943a313b5 100644 --- a/src/target/build_common.h +++ b/src/target/build_common.h @@ -64,8 +64,7 @@ inline std::unordered_map ExtractFuncInfo(co if (auto* ptr = f->params[i]->type_annotation.as()) { info.storage_scopes.push_back(std::string(ptr->storage_scope)); } else { - info.storage_scopes.push_back( - ""); // Use an empty string or "default" if no storage scope is provided + info.storage_scopes.push_back(""); } } if (auto opt = f->GetAttr>(tir::attr::kKernelLaunchParams)) { diff --git a/src/target/spirv/codegen_spirv.cc b/src/target/spirv/codegen_spirv.cc index 9918803571a3..e201172cfb21 100644 --- a/src/target/spirv/codegen_spirv.cc +++ b/src/target/spirv/codegen_spirv.cc @@ -70,7 +70,7 @@ runtime::SPIRVShader CodeGenSPIRV::BuildFunction(const PrimFunc& f, const std::s value_storage_type = boolean_storage_type_.with_lanes(value_storage_type.lanes()); } - spirv::Value arg_value; // Declare arg_value before the if-else block + spirv::Value arg_value; if (ptr && runtime::IsTextureStorage(std::string(ptr->storage_scope))) { arg_value = builder_->StorageImageArgument(arg->name_hint, value_storage_type, 2, 2, descriptor_set, binding_index++); @@ -532,9 +532,9 @@ spirv::Value 
CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value layer_index = MakeValue(op->args[3]); // layer_index spirv::Value texel = MakeValue(op->args.back()); - // Create a composite value representing the coordinates (int3) + // Composite value representing the coordinates (int3) spirv::Value coord = - builder_->MakeComposite(builder_->GetSType(DataType::Int(32).with_lanes(3)), // Type: int3 + builder_->MakeComposite(builder_->GetSType(DataType::Int(32).with_lanes(3)), {coord_x, coord_y, layer_index}); spirv::SType image_type = builder_->QuerySType(op->args[0].as()->name_hint); @@ -542,7 +542,7 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { // Generate the SPIR-V instruction to store the value in the texture builder_->MakeInst(spv::OpImageWrite, loaded_image, coord, texel); - return spirv::Value(); // No result for image store + return spirv::Value(); } else if (op->op.same_as(builtin::texture2d_load())) { ICHECK_EQ(op->args.size(), 6U); @@ -554,13 +554,13 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) { spirv::Value coord_y = MakeValue(op->args[2]); // y-coordinate spirv::Value layer_index = MakeValue(op->args[3]); // layer_index - // Attempt to create a composite value representing the coordinates (int3) + // Create a composite value representing the coordinates (int3) spirv::Value coord = - builder_->MakeComposite(builder_->GetSType(DataType::Int(32).with_lanes(3)), // Type: int3 + builder_->MakeComposite(builder_->GetSType(DataType::Int(32).with_lanes(3)), {coord_x, coord_y, layer_index}); spirv::Value loaded_image = - builder_->MakeValue(spv::OpLoad, image_type, image); // Load the image handle + builder_->MakeValue(spv::OpLoad, image_type, image); spirv::Value image_texel = builder_->MakeValue( spv::OpImageRead, builder_->GetSType(op->dtype.with_lanes(4)), loaded_image, coord); diff --git a/src/target/spirv/ir_builder.cc b/src/target/spirv/ir_builder.cc index 33fca7701288..951c799e0cee 100644 --- 
a/src/target/spirv/ir_builder.cc +++ b/src/target/spirv/ir_builder.cc @@ -549,7 +549,7 @@ SType IRBuilder::QuerySType(const std::string& name) { return it->second; } LOG(FATAL) << "Value \"" << name << "\" does not yet exist."; - return SType(); // Return an empty Value (this line may not be reached due to LOG(FATAL)) + return SType(); } bool IRBuilder::CheckSTypeExistence(const std::string& name) { diff --git a/src/target/spirv/ir_builder.h b/src/target/spirv/ir_builder.h index 5f06c9025c09..6c57c6982f65 100644 --- a/src/target/spirv/ir_builder.h +++ b/src/target/spirv/ir_builder.h @@ -627,8 +627,8 @@ class IRBuilder { // Begin the OpCompositeConstruct instruction ib_.Begin(spv::OpCompositeConstruct) - .Add(composite_type) // The type of the composite - .Add(composite_value); // The resulting value + .Add(composite_type) + .Add(composite_value); // Add each constituent value for (const Value& val : constituents) { @@ -637,8 +637,6 @@ class IRBuilder { // Commit the instruction to the function segment ib_.Commit(&function_); - - // Return the composite value return composite_value; }