Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ tvm_option(USE_OPENCL "Build with OpenCL" OFF)
tvm_option(USE_OPENCL_ENABLE_HOST_PTR "Enable OpenCL memory object access to host" OFF)
tvm_option(USE_OPENCL_GTEST "Path to OpenCL specific gtest version for runtime cpp tests." /path/to/opencl/gtest)
tvm_option(USE_VULKAN "Build with Vulkan" OFF)
tvm_option(USE_VULKAN_GTEST "Path to Vulkan specific gtest version for runtime cpp tests." /path/to/vulkan/gtest)


# Whether to use spirv-tools.and SPIRV-Headers from Khronos github or gitlab.
Expand Down Expand Up @@ -454,6 +455,9 @@ set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
set(CMAKE_CUDA_STANDARD 17)

#include centralized gtest setup
include(cmake/modules/GTestConfig.cmake)

# Module rules
include(cmake/modules/CUDA.cmake)
include(cmake/modules/Hexagon.cmake) # This must come before logging.cmake
Expand Down
52 changes: 52 additions & 0 deletions cmake/modules/GTestConfig.cmake
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

set(Build_GTests OFF)
if(NOT TARGET gtest)
unset(runtime_gtests)
if(DEFINED USE_OPENCL_GTEST AND EXISTS ${USE_OPENCL_GTEST})
set(runtime_gtests ${USE_OPENCL_GTEST})
elseif(DEFINED USE_VULKAN_GTEST AND EXISTS ${USE_VULKAN_GTEST})
set(runtime_gtests ${USE_VULKAN_GTEST})
elseif(ANDROID_ABI AND DEFINED ENV{ANDROID_NDK_HOME})
set(GOOGLETEST_ROOT $ENV{ANDROID_NDK_HOME}/sources/third_party/googletest)
add_library(gtest_main STATIC
${GOOGLETEST_ROOT}/src/gtest_main.cc
${GOOGLETEST_ROOT}/src/gtest-all.cc)
target_include_directories(gtest_main PRIVATE ${GOOGLETEST_ROOT})
target_include_directories(gtest_main PUBLIC ${GOOGLETEST_ROOT}/include)
set(Build_GTests ON)
message(STATUS "Using gtest from Android NDK")
return()
else()
message(STATUS "No valid GTest path found, skipping GTest configuration")
return()
endif()

# Configure if runtime_gtests is valid
if(runtime_gtests AND EXISTS ${runtime_gtests})
include(FetchContent)
FetchContent_Declare(googletest SOURCE_DIR "${runtime_gtests}")
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})
set(Build_GTests ON)
else()
set(Build_GTests OFF)
return()
endif()
endif()
1 change: 1 addition & 0 deletions cmake/modules/LibInfo.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ function(add_lib_info src_file)
TVM_INFO_USE_THRUST="${USE_THRUST}"
TVM_INFO_USE_CURAND="${USE_CURAND}"
TVM_INFO_USE_VULKAN="${USE_VULKAN}"
TVM_INFO_USE_VULKAN_GTEST="${USE_VULKAN_GTEST}"
TVM_INFO_USE_CLML="${USE_CLML}"
TVM_INFO_USE_CLML_GRAPH_EXECUTOR="${USE_CLML_GRAPH_EXECUTOR}"
TVM_INFO_USE_TVM_CLML_VERSION="${CLML_VERSION_MAJOR}"
Expand Down
38 changes: 8 additions & 30 deletions cmake/modules/OpenCL.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

if(USE_OPENCL)
tvm_file_glob(GLOB RUNTIME_OPENCL_SRCS src/runtime/opencl/*.cc)
list(APPEND COMPILER_SRCS src/target/spirv/spirv_utils.cc)
Expand All @@ -35,36 +34,15 @@ if(USE_OPENCL)
list(APPEND TVM_RUNTIME_LINKER_LIBS ${OpenCL_LIBRARIES})
endif()

if(DEFINED USE_OPENCL_GTEST)
if(EXISTS ${USE_OPENCL_GTEST})
include(FetchContent)
FetchContent_Declare(googletest SOURCE_DIR "${USE_OPENCL_GTEST}")
set(gtest_force_shared_crt ON CACHE BOOL "" FORCE)
FetchContent_MakeAvailable(googletest)
install(TARGETS gtest EXPORT ${PROJECT_NAME}Targets DESTINATION lib${LIB_SUFFIX})

message(STATUS "Found OpenCL gtest at ${USE_OPENCL_GTEST}")
set(Build_OpenCL_GTests ON)
elseif (ANDROID_ABI AND DEFINED ENV{ANDROID_NDK_HOME})
set(GOOGLETEST_ROOT $ENV{ANDROID_NDK_HOME}/sources/third_party/googletest)
add_library(gtest_main STATIC ${GOOGLETEST_ROOT}/src/gtest_main.cc ${GOOGLETEST_ROOT}/src/gtest-all.cc)
target_include_directories(gtest_main PRIVATE ${GOOGLETEST_ROOT})
target_include_directories(gtest_main PUBLIC ${GOOGLETEST_ROOT}/include)
message(STATUS "Using gtest from Android NDK")
set(Build_OpenCL_GTests ON)
endif()

if(Build_OpenCL_GTests)
message(STATUS "Building OpenCL-Gtests")
tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS
"tests/cpp-runtime/opencl/*.cc"
)
add_executable(opencl-cpptest ${OPENCL_TEST_SRCS})
target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime ${OpenCL_LIBRARIES})
else()
message(STATUS "Couldn't build OpenCL-Gtests")
endif()
if(Build_GTests)
message(STATUS "Building OpenCL GTests")
tvm_file_glob(GLOB_RECURSE OPENCL_TEST_SRCS "tests/cpp-runtime/opencl/*.cc")
add_executable(opencl-cpptest ${OPENCL_TEST_SRCS})
target_link_libraries(opencl-cpptest PRIVATE gtest_main tvm_runtime ${OpenCL_LIBRARIES})
else()
message(STATUS "Couldn't build OpenCL-Gtests")
endif()

list(APPEND RUNTIME_SRCS ${RUNTIME_OPENCL_SRCS})
if(USE_OPENCL_ENABLE_HOST_PTR)
add_definitions(-DOPENCL_ENABLE_HOST_PTR)
Expand Down
10 changes: 10 additions & 0 deletions cmake/modules/Vulkan.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ if(USE_VULKAN)
message(STATUS "Build with Vulkan support")
tvm_file_glob(GLOB RUNTIME_VULKAN_SRCS src/runtime/vulkan/*.cc)
tvm_file_glob(GLOB COMPILER_VULKAN_SRCS src/target/spirv/*.cc)

if(Build_GTests)
message(STATUS "Building Vulkan GTests")
tvm_file_glob(GLOB_RECURSE VULKAN_TEST_SRCS "tests/cpp-runtime/vulkan/*.cc")
add_executable(vulkan-cpptest ${VULKAN_TEST_SRCS})
target_link_libraries(vulkan-cpptest PRIVATE gtest_main tvm_runtime)
else()
message(STATUS "Couldn't build Vulkan-Gtests")
endif()

list(APPEND RUNTIME_SRCS ${RUNTIME_VULKAN_SRCS})
list(APPEND COMPILER_SRCS ${COMPILER_VULKAN_SRCS})
list(APPEND TVM_LINKER_LIBS ${Vulkan_SPIRV_TOOLS_LIBRARY})
Expand Down
54 changes: 37 additions & 17 deletions cmake/utils/FindVulkan.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ macro(find_vulkan use_vulkan use_khronos_spirv)
set(__use_vulkan ${use_vulkan})
if(IS_DIRECTORY ${__use_vulkan})
set(__vulkan_sdk ${__use_vulkan})
message(STATUS "Custom Vulkan SDK PATH=" ${__use_vulkan})
message(STATUS "Using custom Vulkan SDK: ${__vulkan_sdk}")
elseif(IS_DIRECTORY $ENV{VULKAN_SDK})
set(__vulkan_sdk $ENV{VULKAN_SDK})
else()
Expand All @@ -46,45 +46,64 @@ macro(find_vulkan use_vulkan use_khronos_spirv)

if(IS_DIRECTORY ${use_khronos_spirv})
set(__use_khronos_spirv ${use_khronos_spirv})
message(STATUS "Custom khronos spirv PATH=" ${__use_khronos_spirv})
message(STATUS "Using custom Khronos SPIRV path: ${__use_khronos_spirv}")
else()
set(__use_khronos_spirv "")
endif()

if(CMAKE_SYSTEM_NAME STREQUAL "Android")
set(VULKAN_NDK_SRC ${CMAKE_ANDROID_NDK}/sources/third_party/vulkan/src)
set(Vulkan_INCLUDE_DIRS ${VULKAN_NDK_SRC}/include)
set(Vulkan_FOUND TRUE)
message(STATUS "Android Vulkan_INCLUDE_DIRS=" ${Vulkan_INCLUDE_DIRS})
message(STATUS "Skip finding SPIRV in Android, make sure you only build tvm runtime.")
return()
endif()
message(STATUS "Detected Android build")

set(Vulkan_INCLUDE_DIRS "${CMAKE_SYSROOT}/usr/include/vulkan")

# Map Android ABI to architecture
set(ANDROID_LIB_ARCH "")
if(CMAKE_ANDROID_ARCH_ABI STREQUAL "arm64-v8a")
set(ANDROID_LIB_ARCH "aarch64-linux-android")
elseif(CMAKE_ANDROID_ARCH_ABI STREQUAL "armeabi-v7a")
set(ANDROID_LIB_ARCH "arm-linux-androideabi")
elseif(CMAKE_ANDROID_ARCH_ABI STREQUAL "x86")
set(ANDROID_LIB_ARCH "i686-linux-android")
elseif(CMAKE_ANDROID_ARCH_ABI STREQUAL "x86_64")
set(ANDROID_LIB_ARCH "x86_64-linux-android")
else()
message(FATAL_ERROR "Unsupported Android ABI: ${CMAKE_ANDROID_ARCH_ABI}")
endif()

# Find Vulkan library for Android
set(Vulkan_LIB_PATH "${CMAKE_SYSROOT}/usr/lib/${ANDROID_LIB_ARCH}/27")
find_library(Vulkan_LIBRARY NAMES vulkan libvulkan.so PATHS ${Vulkan_LIB_PATH} NO_DEFAULT_PATH)

if(Vulkan_LIBRARY)
set(Vulkan_FOUND TRUE)
else()
message(FATAL_ERROR "Could not find Vulkan lib in ${Vulkan_LIB_PATH}")
endif()

else()

if(__vulkan_sdk)
set(Vulkan_INCLUDE_DIRS ${__vulkan_sdk}/include)
find_library(Vulkan_LIBRARY NAMES vulkan vulkan-1 PATHS ${__vulkan_sdk}/lib)
if(Vulkan_LIBRARY)
set(Vulkan_FOUND TRUE)
endif()
endif(__vulkan_sdk)
endif()

# resort to find vulkan of option is on
if(NOT Vulkan_FOUND)
if(${__use_vulkan} MATCHES ${IS_TRUE_PATTERN})
find_package(Vulkan QUIET)
endif()
if(NOT Vulkan_FOUND AND ${use_vulkan} MATCHES ${IS_TRUE_PATTERN})
find_package(Vulkan QUIET)
endif()

if(Vulkan_FOUND)
get_filename_component(VULKAN_LIBRARY_PATH ${Vulkan_LIBRARY} DIRECTORY)
if (WIN32)
find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools
HINTS ${__use_khronos_spirv}/spirv-tools/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib)
HINTS ${__use_khronos_spirv}/spirv-tools/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${__vulkan_sdk}/lib)
find_path(_libspirv libspirv.h HINTS ${__use_khronos_spirv}/spirv-tools/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools)
find_path(_spirv spirv.hpp HINTS ${__use_khronos_spirv}/SPIRV-Headers/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
else()
find_library(Vulkan_SPIRV_TOOLS_LIBRARY SPIRV-Tools
HINTS ${__use_khronos_spirv}/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${VULKAN_SDK}/lib)
HINTS ${__use_khronos_spirv}/lib ${VULKAN_LIBRARY_PATH} ${VULKAN_LIBRARY_PATH}/spirv-tools ${__vulkan_sdk}/lib)
find_path(_libspirv libspirv.h HINTS ${__use_khronos_spirv}/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan spirv-tools)
find_path(_spirv spirv.hpp HINTS ${__use_khronos_spirv}/include ${Vulkan_INCLUDE_DIRS} PATH_SUFFIXES vulkan SPIRV spirv/unified1 spirv-headers)
endif()
Expand All @@ -95,4 +114,5 @@ macro(find_vulkan use_vulkan use_khronos_spirv)
message(STATUS "Vulkan_LIBRARY=" ${Vulkan_LIBRARY})
message(STATUS "Vulkan_SPIRV_TOOLS_LIBRARY=" ${Vulkan_SPIRV_TOOLS_LIBRARY})
endif(Vulkan_FOUND)
endif()
endmacro(find_vulkan)
2 changes: 1 addition & 1 deletion python/tvm/relax/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def get_default_pipeline(target: tvm.target.Target):
return backend.gpu_generic.get_default_pipeline(target)
if target.kind.name == "llvm":
return backend.cpu_generic.get_default_pipeline(target)
if target.kind.name == "opencl" and "adreno" in target.keys:
if target.kind.name in ["opencl", "vulkan"] and "adreno" in target.keys:
return backend.adreno.get_default_pipeline(target)
if BackendDispatcher.is_gpu_target(target):
return backend.gpu_generic.get_default_pipeline(target)
Expand Down
19 changes: 17 additions & 2 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -838,6 +838,14 @@ def _multi_gpu_exists():
)


def _check_opencl_vulkan():
return (
(_cmake_flag_enabled("USE_OPENCL") and tvm.opencl(0).exist)
or (_cmake_flag_enabled("USE_VULKAN") and tvm.vulkan(0).exist)
or "RPC_TARGET" in os.environ
)


# Mark a test as requiring llvm to run
requires_llvm = Feature(
"llvm", "LLVM", cmake_flag="USE_LLVM", target_kind_enabled="llvm", target_kind_hardware="llvm"
Expand Down Expand Up @@ -976,8 +984,8 @@ def _multi_gpu_exists():
"Vulkan",
cmake_flag="USE_VULKAN",
target_kind_enabled="vulkan",
target_kind_hardware="vulkan",
parent_features="gpu",
target_kind_hardware="vulkan" if "RPC_TARGET" not in os.environ else None,
parent_features="gpu" if "RPC_TARGET" not in os.environ else None,
)

# Mark a test as requiring OpenCLML support in build.
Expand All @@ -988,6 +996,13 @@ def _multi_gpu_exists():
target_kind_enabled="opencl",
)

requires_opencl_vulkan = Feature(
"opencl_vulkan",
"OpenCL or Vulkan",
run_time_check=_check_opencl_vulkan,
parent_features=["opencl", "gpu"],
)

# Mark a test as requiring NNAPI support in build.
requires_nnapi = Feature(
"NNAPI",
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/tir/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ def get_default_tir_pipeline(
target: tvm.target.Target, # pylint: disable=unused-argument
) -> tvm.transform.Pass:
"""Get the default TIR pipeline for the given target."""
if target.kind.name == "opencl" and "adreno" in target.keys:
if target.kind.name in ["opencl", "vulkan"] and "adreno" in target.keys:
return backend.adreno.get_tir_pipeline(target)
else:
return default_tir_pipeline()
4 changes: 4 additions & 0 deletions src/runtime/file_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ void FunctionInfo::Save(dmlc::JSONWriter* writer) const {
writer->BeginObject();
writer->WriteObjectKeyValue("name", name);
writer->WriteObjectKeyValue("arg_types", sarg_types);
writer->WriteObjectKeyValue("storage_scopes", storage_scopes);
writer->WriteObjectKeyValue("launch_param_tags", launch_param_tags);
std::vector<int> iarg_extra_tags(arg_extra_tags.size());
for (size_t i = 0; i < arg_extra_tags.size(); ++i) {
Expand All @@ -59,6 +60,7 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
std::vector<std::string> sarg_types;
helper.DeclareField("name", &name);
helper.DeclareField("arg_types", &sarg_types);
helper.DeclareOptionalField("storage_scopes", &storage_scopes);
helper.DeclareOptionalField("launch_param_tags", &launch_param_tags);
helper.DeclareOptionalField("thread_axis_tags",
&launch_param_tags); // for backward compatibility
Expand All @@ -78,13 +80,15 @@ void FunctionInfo::Load(dmlc::JSONReader* reader) {
void FunctionInfo::Save(dmlc::Stream* writer) const {
writer->Write(name);
writer->Write(arg_types);
writer->Write(storage_scopes);
writer->Write(launch_param_tags);
writer->Write(arg_extra_tags);
}

bool FunctionInfo::Load(dmlc::Stream* reader) {
if (!reader->Read(&name)) return false;
if (!reader->Read(&arg_types)) return false;
if (!reader->Read(&storage_scopes)) return false;
if (!reader->Read(&launch_param_tags)) return false;
if (!reader->Read(&arg_extra_tags)) return false;
return true;
Expand Down
1 change: 1 addition & 0 deletions src/runtime/meta_data.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ constexpr const char* kUseCooperativeLaunch = "tir.use_cooperative_launch";
struct FunctionInfo {
std::string name;
std::vector<DLDataType> arg_types;
std::vector<std::string> storage_scopes;
std::vector<std::string> launch_param_tags;

enum class ArgExtraTags : int { kNone = 0, kTensorMap = 1 };
Expand Down
Loading
Loading