Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 239 additions & 0 deletions include/tvm/ffi/expected.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,239 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file tvm/ffi/expected.h
* \brief Runtime Expected container type for exception-free error handling.
*/
#ifndef TVM_FFI_EXPECTED_H_
#define TVM_FFI_EXPECTED_H_

#include <tvm/ffi/any.h>
#include <tvm/ffi/error.h>

#include <type_traits>
#include <utility>

namespace tvm {
namespace ffi {

/*!
* \brief Expected<T> provides exception-free error handling for FFI functions.
*
* Expected<T> is similar to Rust's Result<T, Error> or C++23's std::expected.
* It can hold either a success value of type T or an error of type Error.
*
* \tparam T The success type. Must be Any-compatible and cannot be Error.
*
* Usage:
* \code
* Expected<int> divide(int a, int b) {
* if (b == 0) {
* return ExpectedErr(Error("ValueError", "Division by zero"));
* }
* return ExpectedOk(a / b);
* }
*
* Expected<int> result = divide(10, 2);
* if (result.is_ok()) {
* int value = result.value();
* } else {
* Error err = result.error();
* }
* \endcode
*/
template <typename T>
class Expected {
public:
static_assert(!std::is_same_v<T, Error>, "Expected<Error> is not allowed. Use Error directly.");

/*!
* \brief Create an Expected with a success value.
* \param value The success value.
* \return Expected containing the success value.
*/
static Expected Ok(T value) { return Expected(Any(std::move(value))); }

/*!
* \brief Create an Expected with an error.
* \param error The error value.
* \return Expected containing the error.
*/
static Expected Err(Error error) { return Expected(Any(std::move(error))); }

/*!
* \brief Check if the Expected contains a success value.
* \return True if contains success value, false if contains error.
* \note Checks for Error first to handle cases where T is a base class of Error.
*/
TVM_FFI_INLINE bool is_ok() const { return !data_.as<Error>().has_value(); }

/*!
* \brief Check if the Expected contains an error.
* \return True if contains error, false if contains success value.
*/
TVM_FFI_INLINE bool is_err() const { return !is_ok(); }

/*!
* \brief Alias for is_ok().
* \return True if contains success value.
*/
TVM_FFI_INLINE bool has_value() const { return is_ok(); }

/*! \brief Access the success value. Throws the contained error if is_err(). */
TVM_FFI_INLINE T value() const& {
if (is_err()) throw data_.cast<Error>();
return data_.cast<T>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For large data, we can use move instead of copy, so I agree with Gemini that we can add an overload function here:

TVM_FFI_INLINE T value() && {
  if (is_err()) { throw error(); }
  return std::move(data_).cast<T>();
}

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Agree as well. I've added both const& and && qualified overloads for value():

}
/*! \brief Access the success value (rvalue). Throws the contained error if is_err(). */
TVM_FFI_INLINE T value() && {
if (is_err()) throw std::move(data_).template cast<Error>();
return std::move(data_).template cast<T>();
}

/*! \brief Access the error. Throws RuntimeError if is_ok(). */
TVM_FFI_INLINE Error error() const& {
if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not error";
return data_.cast<Error>();
}
/*! \brief Access the error (rvalue). Throws RuntimeError if is_ok(). */
TVM_FFI_INLINE Error error() && {
if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not error";
return std::move(data_).template cast<Error>();
}

/*!
* \brief Get the success value or a default value.
* \param default_value The value to return if Expected contains an error.
* \return The success value if present, otherwise the default value.
*/
template <typename U = std::remove_cv_t<T>>
TVM_FFI_INLINE T value_or(U&& default_value) const {
if (is_ok()) {
return data_.cast<T>();
}
return T(std::forward<U>(default_value));
}

private:
friend struct TypeTraits<Expected<T>>;

/*!
* \brief Private constructor from Any.
* \param data The data containing either T or Error.
* \note This constructor is used by TypeTraits for conversion.
*/
explicit Expected(Any data) : data_(std::move(data)) {
TVM_FFI_ICHECK(data_.as<T>().has_value() || data_.as<Error>().has_value())
<< "Expected must contain either T or Error";
}

Any data_; // Holds either T or Error
};

/*!
* \brief Helper function to create Expected::Ok with type deduction.
* \tparam T The success type (deduced from argument).
* \param value The success value.
* \return Expected<T> containing the success value.
*/
template <typename T>
TVM_FFI_INLINE Expected<T> ExpectedOk(T value) {
return Expected<T>::Ok(std::move(value));
}

/*!
* \brief Helper function to create Expected::Err.
* \param error The error value.
* \return Expected<Any> containing the error.
* \note Returns Expected<Any> to allow usage in contexts where T is inferred.
*/
template <typename T = Any>
TVM_FFI_INLINE Expected<T> ExpectedErr(Error error) {
return Expected<T>::Err(std::move(error));
}

// TypeTraits specialization for Expected<T>
template <typename T>
inline constexpr bool use_default_type_traits_v<Expected<T>> = false;

template <typename T>
struct TypeTraits<Expected<T>> : public TypeTraitsBase {
TVM_FFI_INLINE static void CopyToAnyView(const Expected<T>& src, TVMFFIAny* result) {
const TVMFFIAny* src_any = reinterpret_cast<const TVMFFIAny*>(&src.data_);
if (TypeTraits<T>::CheckAnyStrict(src_any)) {
TypeTraits<T>::MoveToAny(TypeTraits<T>::CopyFromAnyViewAfterCheck(src_any), result);
} else {
TypeTraits<Error>::MoveToAny(TypeTraits<Error>::CopyFromAnyViewAfterCheck(src_any), result);
}
}

TVM_FFI_INLINE static void MoveToAny(Expected<T> src, TVMFFIAny* result) {
TVMFFIAny* src_any = reinterpret_cast<TVMFFIAny*>(&src.data_);
if (TypeTraits<T>::CheckAnyStrict(src_any)) {
TypeTraits<T>::MoveToAny(TypeTraits<T>::MoveFromAnyAfterCheck(src_any), result);
} else {
TypeTraits<Error>::MoveToAny(TypeTraits<Error>::MoveFromAnyAfterCheck(src_any), result);
}
}

TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) {
return TypeTraits<T>::CheckAnyStrict(src) || TypeTraits<Error>::CheckAnyStrict(src);
}

TVM_FFI_INLINE static Expected<T> CopyFromAnyViewAfterCheck(const TVMFFIAny* src) {
if (TypeTraits<T>::CheckAnyStrict(src)) {
return Expected<T>::Ok(TypeTraits<T>::CopyFromAnyViewAfterCheck(src));
}
return Expected<T>::Err(TypeTraits<Error>::CopyFromAnyViewAfterCheck(src));
}

TVM_FFI_INLINE static Expected<T> MoveFromAnyAfterCheck(TVMFFIAny* src) {
if (TypeTraits<T>::CheckAnyStrict(src)) {
return Expected<T>::Ok(TypeTraits<T>::MoveFromAnyAfterCheck(src));
}
return Expected<T>::Err(TypeTraits<Error>::MoveFromAnyAfterCheck(src));
}

TVM_FFI_INLINE static std::optional<Expected<T>> TryCastFromAnyView(const TVMFFIAny* src) {
if (auto opt = TypeTraits<T>::TryCastFromAnyView(src)) {
return Expected<T>::Ok(*std::move(opt));
}
if (auto opt_err = TypeTraits<Error>::TryCastFromAnyView(src)) {
return Expected<T>::Err(*std::move(opt_err));
}
return std::nullopt;
}

TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) {
return TypeTraitsBase::GetMismatchTypeInfo(src);
}

TVM_FFI_INLINE static std::string TypeStr() {
return "Expected<" + TypeTraits<T>::TypeStr() + ">";
}

TVM_FFI_INLINE static std::string TypeSchema() {
return R"({"type":"Expected","args":[)" + details::TypeSchema<T>::v() + "]}";
}
};

} // namespace ffi
} // namespace tvm
#endif // TVM_FFI_EXPECTED_H_
45 changes: 45 additions & 0 deletions include/tvm/ffi/function.h
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,51 @@ class Function : public ObjectRef {
static_cast<FunctionObj*>(data_.get())->CallPacked(args.data(), args.size(), result);
}

/*!
* \brief Call the function and return Expected<T> for exception-free error handling.
* \tparam T The expected return type (default: Any).
* \param args The arguments to pass to the function.
* \return Expected<T> containing either the result or an error.
*
* This method provides exception-free calling by catching all exceptions
* and returning them as Error values in the Expected type.
*
* \code
* Function func = Function::GetGlobal("risky_function");
* Expected<int> result = func.CallExpected<int>(arg1, arg2);
* if (result.is_ok()) {
* int value = result.value();
* } else {
* Error err = result.error();
* }
* \endcode
*/
template <typename T = Any, typename... Args>
TVM_FFI_INLINE Expected<T> CallExpected(Args&&... args) const {
constexpr size_t kNumArgs = sizeof...(Args);
AnyView args_pack[kNumArgs > 0 ? kNumArgs : 1];
PackedArgs::Fill(args_pack, std::forward<Args>(args)...);

Any result;
FunctionObj* func_obj = static_cast<FunctionObj*>(data_.get());

// Use safe_call path to catch exceptions
int ret_code = func_obj->safe_call(func_obj, reinterpret_cast<const TVMFFIAny*>(args_pack),
kNumArgs, reinterpret_cast<TVMFFIAny*>(&result));

if (ret_code == 0) {
// Success - cast result to T and return Ok
return Expected<T>::Ok(std::move(result).cast<T>());
} else if (ret_code == -2) {
// Environment error already set (e.g., Python KeyboardInterrupt)
// We still throw this since it's a signal, not a normal error
throw ::tvm::ffi::EnvErrorAlreadySet();
} else {
// Error occurred - retrieve from safe call context and return Err
return Expected<T>::Err(details::MoveFromSafeCallRaised());
}
}

/*! \return Whether the packed function is nullptr */
TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; }
/*! \return Whether the packed function is not nullptr */
Expand Down
30 changes: 28 additions & 2 deletions include/tvm/ffi/function_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,11 @@

namespace tvm {
namespace ffi {

// Forward declaration for Expected<T>
template <typename T>
class Expected;

namespace details {

template <typename ArgType>
Expand Down Expand Up @@ -67,10 +72,23 @@ static constexpr bool ArgSupported =
std::is_same_v<std::remove_const_t<std::remove_reference_t<T>>, AnyView> ||
TypeTraitsNoCR<T>::convert_enabled));

template <typename T>
struct is_expected : std::false_type {
using value_type = void;
};

template <typename T>
struct is_expected<Expected<T>> : std::true_type {
using value_type = T;
};

template <typename T>
inline constexpr bool is_expected_v = is_expected<T>::value;

// NOTE: return type can only support non-reference managed returns
template <typename T>
static constexpr bool RetSupported =
(std::is_same_v<T, Any> || std::is_void_v<T> || TypeTraits<T>::convert_enabled);
static constexpr bool RetSupported = (std::is_same_v<T, Any> || std::is_void_v<T> ||
TypeTraits<T>::convert_enabled || is_expected_v<T>);

template <typename R, typename... Args>
struct FuncFunctorImpl {
Expand Down Expand Up @@ -219,6 +237,14 @@ TVM_FFI_INLINE void unpack_call(std::index_sequence<Is...>, const std::string* o
// use index sequence to do recursive-less unpacking
if constexpr (std::is_same_v<R, void>) {
f(ArgValueWithContext<std::tuple_element_t<Is, PackedArgs>>{args, Is, optional_name, f_sig}...);
} else if constexpr (is_expected_v<R>) {
R expected_result = f(ArgValueWithContext<std::tuple_element_t<Is, PackedArgs>>{
args, Is, optional_name, f_sig}...);
if (expected_result.is_ok()) {
*rv = expected_result.value();
} else {
throw expected_result.error();
}
} else {
*rv = R(f(ArgValueWithContext<std::tuple_element_t<Is, PackedArgs>>{args, Is, optional_name,
f_sig}...));
Expand Down
1 change: 1 addition & 0 deletions include/tvm/ffi/tvm_ffi.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@
#include <tvm/ffi/dtype.h>
#include <tvm/ffi/endian.h>
#include <tvm/ffi/error.h>
#include <tvm/ffi/expected.h>
#include <tvm/ffi/function.h>
#include <tvm/ffi/function_details.h>
#include <tvm/ffi/memory.h>
Expand Down
Loading