From b510df61fe6acd3a9ff638480d364c28ae151f4f Mon Sep 17 00:00:00 2001 From: Guan-Ming Chiu Date: Sun, 11 Jan 2026 23:22:39 +0800 Subject: [PATCH 1/2] feat: Add ffi::Expected for exception-free error handling --- include/tvm/ffi/expected.h | 256 +++++++++++++++++++++++++++ include/tvm/ffi/function.h | 45 +++++ include/tvm/ffi/function_details.h | 30 +++- include/tvm/ffi/tvm_ffi.h | 1 + tests/cpp/test_expected.cc | 274 +++++++++++++++++++++++++++++ 5 files changed, 604 insertions(+), 2 deletions(-) create mode 100644 include/tvm/ffi/expected.h create mode 100644 tests/cpp/test_expected.cc diff --git a/include/tvm/ffi/expected.h b/include/tvm/ffi/expected.h new file mode 100644 index 00000000..5e0a1cd1 --- /dev/null +++ b/include/tvm/ffi/expected.h @@ -0,0 +1,256 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/ffi/expected.h + * \brief Runtime Expected container type for exception-free error handling. + */ +#ifndef TVM_FFI_EXPECTED_H_ +#define TVM_FFI_EXPECTED_H_ + +#include +#include + +#include +#include + +namespace tvm { +namespace ffi { + +/*! + * \brief Expected provides exception-free error handling for FFI functions. + * + * Expected is similar to Rust's Result or C++23's std::expected. + * It can hold either a success value of type T or an error of type Error. + * + * \tparam T The success type. Must be Any-compatible and cannot be Error. + * + * Usage: + * \code + * Expected divide(int a, int b) { + * if (b == 0) { + * return ExpectedErr(Error("ValueError", "Division by zero")); + * } + * return ExpectedOk(a / b); + * } + * + * Expected result = divide(10, 2); + * if (result.is_ok()) { + * int value = result.value(); + * } else { + * Error err = result.error(); + * } + * \endcode + */ +template +class Expected { + public: + static_assert(!std::is_same_v, "Expected is not allowed. Use Error directly."); + + /*! + * \brief Create an Expected with a success value. + * \param value The success value. + * \return Expected containing the success value. + */ + static Expected Ok(T value) { return Expected(Any(std::move(value))); } + + /*! + * \brief Create an Expected with an error. + * \param error The error value. + * \return Expected containing the error. + */ + static Expected Err(Error error) { return Expected(Any(std::move(error))); } + + /*! + * \brief Check if the Expected contains a success value. + * \return True if contains success value, false if contains error. + */ + TVM_FFI_INLINE bool is_ok() const { return data_.as().has_value(); } + + /*! + * \brief Check if the Expected contains an error. + * \return True if contains error, false if contains success value. + */ + TVM_FFI_INLINE bool is_err() const { return data_.as().has_value(); } + + /*! + * \brief Alias for is_ok(). + * \return True if contains success value. + */ + TVM_FFI_INLINE bool has_value() const { return is_ok(); } + + /*! + * \brief Access the success value. + * \return The success value. + * \throws RuntimeError if the Expected contains an error. + */ + TVM_FFI_INLINE T value() const { + if (is_err()) { + TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains error"; + } + return data_.cast(); + } + + /*! + * \brief Access the error value. + * \return The error value. + * \note Behavior is undefined if the Expected contains a success value. + * Always check is_err() before calling this method. + */ + TVM_FFI_INLINE Error error() const { + TVM_FFI_ICHECK(is_err()) << "Expected does not contain an error"; + return data_.cast(); + } + + /*! + * \brief Get the success value or a default value. + * \param default_value The value to return if Expected contains an error. + * \return The success value if present, otherwise the default value. + */ + template > + TVM_FFI_INLINE T value_or(U&& default_value) const { + if (is_ok()) { + return data_.cast(); + } + return T(std::forward(default_value)); + } + + private: + friend struct TypeTraits>; + + /*! + * \brief Private constructor from Any. + * \param data The data containing either T or Error. + * \note This constructor is used by TypeTraits for conversion. + */ + explicit Expected(Any data) : data_(std::move(data)) { + TVM_FFI_ICHECK(data_.as().has_value() || data_.as().has_value()) + << "Expected must contain either T or Error"; + } + + Any data_; // Holds either T or Error +}; + +/*! + * \brief Helper function to create Expected::Ok with type deduction. + * \tparam T The success type (deduced from argument). + * \param value The success value. + * \return Expected containing the success value. + */ +template +TVM_FFI_INLINE Expected ExpectedOk(T value) { + return Expected::Ok(std::move(value)); +} + +/*! + * \brief Helper function to create Expected::Err. + * \param error The error value. + * \return Expected containing the error. + * \note Returns Expected to allow usage in contexts where T is inferred. + */ +template +TVM_FFI_INLINE Expected ExpectedErr(Error error) { + return Expected::Err(std::move(error)); +} + +// TypeTraits specialization for Expected +template +inline constexpr bool use_default_type_traits_v> = false; + +template +struct TypeTraits> : public TypeTraitsBase { + TVM_FFI_INLINE static void CopyToAnyView(const Expected& src, TVMFFIAny* result) { + // Extract value from src.data_ and copy it properly + const TVMFFIAny* src_any = reinterpret_cast(&src.data_); + + if (TypeTraits::CheckAnyStrict(src_any)) { + // It contains T, copy it out and move to result + T value = TypeTraits::CopyFromAnyViewAfterCheck(src_any); + TypeTraits::MoveToAny(std::move(value), result); + } else { + // It contains Error, copy it out and move to result + Error err = TypeTraits::CopyFromAnyViewAfterCheck(src_any); + TypeTraits::MoveToAny(std::move(err), result); + } + } + + TVM_FFI_INLINE static void MoveToAny(Expected src, TVMFFIAny* result) { + // Extract value from src.data_ and move it properly + TVMFFIAny* src_any = reinterpret_cast(&src.data_); + + if (TypeTraits::CheckAnyStrict(src_any)) { + // It contains T, move it out and move to result + T value = TypeTraits::MoveFromAnyAfterCheck(src_any); + TypeTraits::MoveToAny(std::move(value), result); + } else { + // It contains Error, move it out and move to result + Error err = TypeTraits::MoveFromAnyAfterCheck(src_any); + TypeTraits::MoveToAny(std::move(err), result); + } + } + + TVM_FFI_INLINE static bool CheckAnyStrict(const TVMFFIAny* src) { + return TypeTraits::CheckAnyStrict(src) || TypeTraits::CheckAnyStrict(src); + } + + TVM_FFI_INLINE static Expected CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { + if (TypeTraits::CheckAnyStrict(src)) { + return Expected::Ok(TypeTraits::CopyFromAnyViewAfterCheck(src)); + } else { + return Expected::Err(TypeTraits::CopyFromAnyViewAfterCheck(src)); + } + } + + TVM_FFI_INLINE static Expected MoveFromAnyAfterCheck(TVMFFIAny* src) { + if (TypeTraits::CheckAnyStrict(src)) { + return Expected::Ok(TypeTraits::MoveFromAnyAfterCheck(src)); + } else { + return Expected::Err(TypeTraits::MoveFromAnyAfterCheck(src)); + } + } + + TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { + // Try to convert to T first + if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { + return Expected::Ok(*std::move(opt)); + } + // Try to convert to Error + if (std::optional opt_err = TypeTraits::TryCastFromAnyView(src)) { + return Expected::Err(*std::move(opt_err)); + } + // Conversion failed - return explicit nullopt to indicate failure + return std::optional>(std::nullopt); + } + + TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { + return TypeTraitsBase::GetMismatchTypeInfo(src); + } + + TVM_FFI_INLINE static std::string TypeStr() { + return "Expected<" + TypeTraits::TypeStr() + ">"; + } + + TVM_FFI_INLINE static std::string TypeSchema() { + return R"({"type":"Expected","args":[)" + details::TypeSchema::v() + "]}"; + } +}; + +} // namespace ffi +} // namespace tvm +#endif // TVM_FFI_EXPECTED_H_ diff --git a/include/tvm/ffi/function.h b/include/tvm/ffi/function.h index 30437315..7adef638 100644 --- a/include/tvm/ffi/function.h +++ b/include/tvm/ffi/function.h @@ -643,6 +643,51 @@ class Function : public ObjectRef { static_cast(data_.get())->CallPacked(args.data(), args.size(), result); } + /*! + * \brief Call the function and return Expected for exception-free error handling. + * \tparam T The expected return type (default: Any). + * \param args The arguments to pass to the function. + * \return Expected containing either the result or an error. + * + * This method provides exception-free calling by catching all exceptions + * and returning them as Error values in the Expected type. + * + * \code + * Function func = Function::GetGlobal("risky_function"); + * Expected result = func.CallExpected(arg1, arg2); + * if (result.is_ok()) { + * int value = result.value(); + * } else { + * Error err = result.error(); + * } + * \endcode + */ + template + TVM_FFI_INLINE Expected CallExpected(Args&&... args) const { + constexpr size_t kNumArgs = sizeof...(Args); + AnyView args_pack[kNumArgs > 0 ? kNumArgs : 1]; + PackedArgs::Fill(args_pack, std::forward(args)...); + + Any result; + FunctionObj* func_obj = static_cast(data_.get()); + + // Use safe_call path to catch exceptions + int ret_code = func_obj->safe_call(func_obj, reinterpret_cast(args_pack), + kNumArgs, reinterpret_cast(&result)); + + if (ret_code == 0) { + // Success - cast result to T and return Ok + return Expected::Ok(std::move(result).cast()); + } else if (ret_code == -2) { + // Environment error already set (e.g., Python KeyboardInterrupt) + // We still throw this since it's a signal, not a normal error + throw ::tvm::ffi::EnvErrorAlreadySet(); + } else { + // Error occurred - retrieve from safe call context and return Err + return Expected::Err(details::MoveFromSafeCallRaised()); + } + } + /*! \return Whether the packed function is nullptr */ TVM_FFI_INLINE bool operator==(std::nullptr_t) const { return data_ == nullptr; } /*! \return Whether the packed function is not nullptr */ diff --git a/include/tvm/ffi/function_details.h b/include/tvm/ffi/function_details.h index 8f163405..b3a1fe94 100644 --- a/include/tvm/ffi/function_details.h +++ b/include/tvm/ffi/function_details.h @@ -34,6 +34,11 @@ namespace tvm { namespace ffi { + +// Forward declaration for Expected +template +class Expected; + namespace details { template @@ -67,10 +72,23 @@ static constexpr bool ArgSupported = std::is_same_v>, AnyView> || TypeTraitsNoCR::convert_enabled)); +template +struct is_expected : std::false_type { + using value_type = void; +}; + +template +struct is_expected> : std::true_type { + using value_type = T; +}; + +template +inline constexpr bool is_expected_v = is_expected::value; + // NOTE: return type can only support non-reference managed returns template -static constexpr bool RetSupported = - (std::is_same_v || std::is_void_v || TypeTraits::convert_enabled); +static constexpr bool RetSupported = (std::is_same_v || std::is_void_v || + TypeTraits::convert_enabled || is_expected_v); template struct FuncFunctorImpl { @@ -219,6 +237,14 @@ TVM_FFI_INLINE void unpack_call(std::index_sequence, const std::string* o // use index sequence to do recursive-less unpacking if constexpr (std::is_same_v) { f(ArgValueWithContext>{args, Is, optional_name, f_sig}...); + } else if constexpr (is_expected_v) { + R expected_result = f(ArgValueWithContext>{ + args, Is, optional_name, f_sig}...); + if (expected_result.is_ok()) { + *rv = expected_result.value(); + } else { + throw expected_result.error(); + } } else { *rv = R(f(ArgValueWithContext>{args, Is, optional_name, f_sig}...)); diff --git a/include/tvm/ffi/tvm_ffi.h b/include/tvm/ffi/tvm_ffi.h index be26aed3..9d0c1d04 100644 --- a/include/tvm/ffi/tvm_ffi.h +++ b/include/tvm/ffi/tvm_ffi.h @@ -41,6 +41,7 @@ #include #include #include +#include #include #include #include diff --git a/tests/cpp/test_expected.cc b/tests/cpp/test_expected.cc new file mode 100644 index 00000000..72f0fcdc --- /dev/null +++ b/tests/cpp/test_expected.cc @@ -0,0 +1,274 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include "./testing_object.h" + +namespace { + +using namespace tvm::ffi; +using namespace tvm::ffi::testing; + +// Test basic construction with Ok +TEST(Expected, BasicOk) { + Expected result = ExpectedOk(42); + + EXPECT_TRUE(result.is_ok()); + EXPECT_FALSE(result.is_err()); + EXPECT_TRUE(result.has_value()); + EXPECT_EQ(result.value(), 42); + EXPECT_EQ(result.value_or(0), 42); +} + +// Test basic construction with Err +TEST(Expected, BasicErr) { + Expected result = ExpectedErr(Error("RuntimeError", "test error", "")); + + EXPECT_FALSE(result.is_ok()); + EXPECT_TRUE(result.is_err()); + EXPECT_FALSE(result.has_value()); + + Error err = result.error(); + EXPECT_EQ(err.kind(), "RuntimeError"); + EXPECT_EQ(err.message(), "test error"); +} + +// Test value_or with error +TEST(Expected, ValueOrWithError) { + Expected result = ExpectedErr(Error("RuntimeError", "test error", "")); + EXPECT_EQ(result.value_or(99), 99); +} + +// Test with ObjectRef types +TEST(Expected, ObjectRefType) { + Expected result = ExpectedOk(TInt(123)); + + EXPECT_TRUE(result.is_ok()); + EXPECT_EQ(result.value()->value, 123); +} + +// Test with String type +TEST(Expected, StringType) { + Expected result = ExpectedOk(String("hello")); + + EXPECT_TRUE(result.is_ok()); + EXPECT_EQ(result.value(), "hello"); + + Expected err_result = ExpectedErr(Error("ValueError", "bad string", "")); + EXPECT_TRUE(err_result.is_err()); +} + +// Test TypeTraits conversion: Expected -> Any -> Expected +TEST(Expected, TypeTraitsRoundtrip) { + Expected original = ExpectedOk(42); + + // Convert to Any (should unwrap to int) + Any any_value = original; + EXPECT_EQ(any_value.cast(), 42); + + // Convert back to Expected (should reconstruct as Ok) + Expected recovered = any_value.cast>(); + EXPECT_TRUE(recovered.is_ok()); + EXPECT_EQ(recovered.value(), 42); +} + +// Test TypeTraits conversion with Error +TEST(Expected, TypeTraitsErrorRoundtrip) { + Expected original = ExpectedErr(Error("TypeError", "conversion failed", "")); + + // Convert to Any (should unwrap to Error) + Any any_value = original; + EXPECT_TRUE(any_value.as().has_value()); + + // Convert back to Expected (should reconstruct as Err) + Expected recovered = any_value.cast>(); + EXPECT_TRUE(recovered.is_err()); + EXPECT_EQ(recovered.error().kind(), "TypeError"); +} + +// Test move semantics +TEST(Expected, MoveSemantics) { + Expected result = ExpectedOk(String("test")); + EXPECT_TRUE(result.is_ok()); + + String value = std::move(result).value(); + EXPECT_EQ(value, "test"); +} + +// Test CallExpected with normal function +TEST(Expected, CallExpectedNormal) { + auto safe_add = [](int a, int b) { return a + b; }; + + Function func = Function::FromTyped(safe_add); + Expected result = func.CallExpected(5, 3); + + EXPECT_TRUE(result.is_ok()); + EXPECT_EQ(result.value(), 8); +} + +// Test CallExpected with throwing function +TEST(Expected, CallExpectedThrowing) { + auto throwing_func = [](int a) -> int { + if (a < 0) { + TVM_FFI_THROW(ValueError) << "Negative value not allowed"; + } + return a * 2; + }; + + Function func = Function::FromTyped(throwing_func); + + // Normal case + Expected result_ok = func.CallExpected(5); + EXPECT_TRUE(result_ok.is_ok()); + EXPECT_EQ(result_ok.value(), 10); + + // Error case + Expected result_err = func.CallExpected(-1); + EXPECT_TRUE(result_err.is_err()); + EXPECT_EQ(result_err.error().kind(), "ValueError"); +} + +// Test that lambda returning Expected works directly +TEST(Expected, LambdaDirectCall) { + auto safe_divide = [](int a, int b) -> Expected { + if (b == 0) { + return ExpectedErr(Error("ValueError", "Division by zero", "")); + } + return ExpectedOk(a / b); + }; + + // Direct call to lambda should work + Expected result = safe_divide(10, 2); + EXPECT_TRUE(result.is_ok()); + EXPECT_EQ(result.value(), 5); + + // Check the value can be extracted + int val = result.value(); + EXPECT_EQ(val, 5); + + // Check assigning to Any works + Any any_val = result.value(); + EXPECT_EQ(any_val.cast(), 5); +} + +// Test registering function that returns Expected +TEST(Expected, RegisterExpectedReturning) { + auto safe_divide = [](int a, int b) -> Expected { + if (b == 0) { + return ExpectedErr(Error("ValueError", "Division by zero", "")); + } + return ExpectedOk(a / b); + }; + + // Verify the FunctionInfo extracts Expected as return type + using FuncInfo = tvm::ffi::details::FunctionInfo; + static_assert(std::is_same_v>, + "Return type should be Expected"); + static_assert(tvm::ffi::details::is_expected_v, + "RetType should be detected as Expected"); + + Function::SetGlobal("test.safe_divide3", Function::FromTyped(safe_divide)); + + Function func = Function::GetGlobalRequired("test.safe_divide3"); + + // Normal call should throw when function returns Err + EXPECT_THROW({ func(10, 0).cast(); }, Error); + + // Normal call should succeed when function returns Ok + int result = func(10, 2).cast(); + EXPECT_EQ(result, 5); + + // CallExpected should return Expected + Expected exp_ok = func.CallExpected(10, 2); + EXPECT_TRUE(exp_ok.is_ok()); + EXPECT_EQ(exp_ok.value(), 5); + + Expected exp_err = func.CallExpected(10, 0); + EXPECT_TRUE(exp_err.is_err()); + EXPECT_EQ(exp_err.error().message(), "Division by zero"); +} + +// Test Expected with Optional (nested types) +TEST(Expected, NestedOptional) { + Expected> result = ExpectedOk(Optional(42)); + + EXPECT_TRUE(result.is_ok()); + EXPECT_TRUE(result.value().has_value()); + EXPECT_EQ(result.value().value(), 42); + + Expected> result_none = ExpectedOk(Optional(std::nullopt)); + EXPECT_TRUE(result_none.is_ok()); + EXPECT_FALSE(result_none.value().has_value()); +} + +// Test Expected with Array +TEST(Expected, ArrayType) { + Array arr{1, 2, 3}; + Expected> result = ExpectedOk(arr); + + EXPECT_TRUE(result.is_ok()); + EXPECT_EQ(result.value().size(), 3); + EXPECT_EQ(result.value()[0], 1); +} + +// Test complex example: function returning Expected> +TEST(Expected, ComplexExample) { + auto parse_csv = [](const String& input) -> Expected> { + if (input.size() == 0) { + return ExpectedErr>(Error("ValueError", "Empty input", "")); + } + // Simple split by comma + Array result; + result.push_back(input); // Simplified for test + return ExpectedOk(result); + }; + + Function::SetGlobal("test.parse_csv", Function::FromTyped(parse_csv)); + Function func = Function::GetGlobalRequired("test.parse_csv"); + + Expected> result_ok = func.CallExpected>(String("a,b,c")); + EXPECT_TRUE(result_ok.is_ok()); + + Expected> result_err = func.CallExpected>(String("")); + EXPECT_TRUE(result_err.is_err()); + EXPECT_EQ(result_err.error().message(), "Empty input"); +} + +// Test bad access throws +TEST(Expected, BadAccessThrows) { + Expected result = ExpectedErr(Error("RuntimeError", "error", "")); + EXPECT_THROW({ result.value(); }, Error); +} + +// Test TryCastFromAnyView with incompatible type +TEST(Expected, TryCastIncompatible) { + Any any_str = String("hello"); + auto result = any_str.try_cast>(); + EXPECT_FALSE(result.has_value()); // Cannot convert String to Expected +} + +} // namespace From ad48e1b101694c0ec4ba9ff7c17eef7f1396ab6a Mon Sep 17 00:00:00 2001 From: Guan-Ming Chiu Date: Tue, 20 Jan 2026 19:27:16 +0800 Subject: [PATCH 2/2] refactor: apply review comments --- include/tvm/ffi/expected.h | 73 +++++++++++++++----------------------- tests/cpp/test_expected.cc | 61 ++++++++++++++++++++++++++++--- 2 files changed, 85 insertions(+), 49 deletions(-) diff --git a/include/tvm/ffi/expected.h b/include/tvm/ffi/expected.h index 5e0a1cd1..3cb8c0f5 100644 --- a/include/tvm/ffi/expected.h +++ b/include/tvm/ffi/expected.h @@ -80,14 +80,15 @@ class Expected { /*! * \brief Check if the Expected contains a success value. * \return True if contains success value, false if contains error. + * \note Checks for Error first to handle cases where T is a base class of Error. */ - TVM_FFI_INLINE bool is_ok() const { return data_.as().has_value(); } + TVM_FFI_INLINE bool is_ok() const { return !data_.as().has_value(); } /*! * \brief Check if the Expected contains an error. * \return True if contains error, false if contains success value. */ - TVM_FFI_INLINE bool is_err() const { return data_.as().has_value(); } + TVM_FFI_INLINE bool is_err() const { return !is_ok(); } /*! * \brief Alias for is_ok(). @@ -95,28 +96,27 @@ class Expected { */ TVM_FFI_INLINE bool has_value() const { return is_ok(); } - /*! - * \brief Access the success value. - * \return The success value. - * \throws RuntimeError if the Expected contains an error. - */ - TVM_FFI_INLINE T value() const { - if (is_err()) { - TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains error"; - } + /*! \brief Access the success value. Throws the contained error if is_err(). */ + TVM_FFI_INLINE T value() const& { + if (is_err()) throw data_.cast(); return data_.cast(); } + /*! \brief Access the success value (rvalue). Throws the contained error if is_err(). */ + TVM_FFI_INLINE T value() && { + if (is_err()) throw std::move(data_).template cast(); + return std::move(data_).template cast(); + } - /*! - * \brief Access the error value. - * \return The error value. - * \note Behavior is undefined if the Expected contains a success value. - * Always check is_err() before calling this method. - */ - TVM_FFI_INLINE Error error() const { - TVM_FFI_ICHECK(is_err()) << "Expected does not contain an error"; + /*! \brief Access the error. Throws RuntimeError if is_ok(). */ + TVM_FFI_INLINE Error error() const& { + if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not error"; return data_.cast(); } + /*! \brief Access the error (rvalue). Throws RuntimeError if is_ok(). */ + TVM_FFI_INLINE Error error() && { + if (!is_err()) TVM_FFI_THROW(RuntimeError) << "Bad expected access: contains value, not error"; + return std::move(data_).template cast(); + } /*! * \brief Get the success value or a default value. @@ -176,32 +176,20 @@ inline constexpr bool use_default_type_traits_v> = false; template struct TypeTraits> : public TypeTraitsBase { TVM_FFI_INLINE static void CopyToAnyView(const Expected& src, TVMFFIAny* result) { - // Extract value from src.data_ and copy it properly const TVMFFIAny* src_any = reinterpret_cast(&src.data_); - if (TypeTraits::CheckAnyStrict(src_any)) { - // It contains T, copy it out and move to result - T value = TypeTraits::CopyFromAnyViewAfterCheck(src_any); - TypeTraits::MoveToAny(std::move(value), result); + TypeTraits::MoveToAny(TypeTraits::CopyFromAnyViewAfterCheck(src_any), result); } else { - // It contains Error, copy it out and move to result - Error err = TypeTraits::CopyFromAnyViewAfterCheck(src_any); - TypeTraits::MoveToAny(std::move(err), result); + TypeTraits::MoveToAny(TypeTraits::CopyFromAnyViewAfterCheck(src_any), result); } } TVM_FFI_INLINE static void MoveToAny(Expected src, TVMFFIAny* result) { - // Extract value from src.data_ and move it properly TVMFFIAny* src_any = reinterpret_cast(&src.data_); - if (TypeTraits::CheckAnyStrict(src_any)) { - // It contains T, move it out and move to result - T value = TypeTraits::MoveFromAnyAfterCheck(src_any); - TypeTraits::MoveToAny(std::move(value), result); + TypeTraits::MoveToAny(TypeTraits::MoveFromAnyAfterCheck(src_any), result); } else { - // It contains Error, move it out and move to result - Error err = TypeTraits::MoveFromAnyAfterCheck(src_any); - TypeTraits::MoveToAny(std::move(err), result); + TypeTraits::MoveToAny(TypeTraits::MoveFromAnyAfterCheck(src_any), result); } } @@ -212,30 +200,25 @@ struct TypeTraits> : public TypeTraitsBase { TVM_FFI_INLINE static Expected CopyFromAnyViewAfterCheck(const TVMFFIAny* src) { if (TypeTraits::CheckAnyStrict(src)) { return Expected::Ok(TypeTraits::CopyFromAnyViewAfterCheck(src)); - } else { - return Expected::Err(TypeTraits::CopyFromAnyViewAfterCheck(src)); } + return Expected::Err(TypeTraits::CopyFromAnyViewAfterCheck(src)); } TVM_FFI_INLINE static Expected MoveFromAnyAfterCheck(TVMFFIAny* src) { if (TypeTraits::CheckAnyStrict(src)) { return Expected::Ok(TypeTraits::MoveFromAnyAfterCheck(src)); - } else { - return Expected::Err(TypeTraits::MoveFromAnyAfterCheck(src)); } + return Expected::Err(TypeTraits::MoveFromAnyAfterCheck(src)); } TVM_FFI_INLINE static std::optional> TryCastFromAnyView(const TVMFFIAny* src) { - // Try to convert to T first - if (std::optional opt = TypeTraits::TryCastFromAnyView(src)) { + if (auto opt = TypeTraits::TryCastFromAnyView(src)) { return Expected::Ok(*std::move(opt)); } - // Try to convert to Error - if (std::optional opt_err = TypeTraits::TryCastFromAnyView(src)) { + if (auto opt_err = TypeTraits::TryCastFromAnyView(src)) { return Expected::Err(*std::move(opt_err)); } - // Conversion failed - return explicit nullopt to indicate failure - return std::optional>(std::nullopt); + return std::nullopt; } TVM_FFI_INLINE static std::string GetMismatchTypeInfo(const TVMFFIAny* src) { diff --git a/tests/cpp/test_expected.cc b/tests/cpp/test_expected.cc index 72f0fcdc..d56b6103 100644 --- a/tests/cpp/test_expected.cc +++ b/tests/cpp/test_expected.cc @@ -258,10 +258,63 @@ TEST(Expected, ComplexExample) { EXPECT_EQ(result_err.error().message(), "Empty input"); } -// Test bad access throws -TEST(Expected, BadAccessThrows) { - Expected result = ExpectedErr(Error("RuntimeError", "error", "")); - EXPECT_THROW({ result.value(); }, Error); +// Test bad access throws the original error +TEST(Expected, BadAccessThrowsOriginalError) { + Expected result = ExpectedErr(Error("CustomError", "original message", "")); + try { + result.value(); + FAIL() << "Expected Error to be thrown"; + } catch (const Error& e) { + // Verify the original error is preserved + EXPECT_EQ(e.kind(), "CustomError"); + EXPECT_EQ(e.message(), "original message"); + } +} + +// Test error() throws RuntimeError on bad access +TEST(Expected, ErrorBadAccessThrows) { + Expected result = ExpectedOk(42); + EXPECT_THROW({ result.error(); }, Error); +} + +// Test rvalue overload for value() +TEST(Expected, RvalueValueAccess) { + auto get_expected = []() -> Expected { return ExpectedOk(String("rvalue test")); }; + + // Call value() on rvalue + String val = get_expected().value(); + EXPECT_EQ(val, "rvalue test"); +} + +// Test rvalue overload for error() +TEST(Expected, RvalueErrorAccess) { + auto get_expected = []() -> Expected { + return ExpectedErr(Error("TestError", "rvalue error", "")); + }; + + // Call error() on rvalue + Error err = get_expected().error(); + EXPECT_EQ(err.kind(), "TestError"); + EXPECT_EQ(err.message(), "rvalue error"); +} + +// Test Expected with inheritance (Error is a subclass of ObjectRef) +// This ensures is_ok() and is_err() work correctly for ObjectRef types +TEST(Expected, ObjectRefInheritance) { + // Expected with an actual ObjectRef value + ObjectRef obj = TInt(100); + Expected result_ok = ExpectedOk(obj); + + EXPECT_TRUE(result_ok.is_ok()); + EXPECT_FALSE(result_ok.is_err()); + EXPECT_TRUE(result_ok.value().defined()); + + // Expected with an error + Expected result_err = ExpectedErr(Error("TestError", "test", "")); + + EXPECT_FALSE(result_err.is_ok()); + EXPECT_TRUE(result_err.is_err()); + EXPECT_EQ(result_err.error().kind(), "TestError"); } // Test TryCastFromAnyView with incompatible type