From b041a22080d3e4c6e1fd8da592ba36278dd32fef Mon Sep 17 00:00:00 2001 From: Junru Shao Date: Mon, 12 Jan 2026 10:39:43 -0800 Subject: [PATCH] Revert "Revert "feat: Add `tvm_ffi.Function.__init__`" (#406)" This reverts commit 91fcaa8bd261d55255fe658ce026ce6840547234. --- include/tvm/ffi/base_details.h | 16 ++++++++++++---- python/tvm_ffi/core.pyi | 1 + python/tvm_ffi/cython/function.pxi | 21 +++++++++++++++++++++ tests/python/test_function.py | 9 +++++++++ 4 files changed, 43 insertions(+), 4 deletions(-) diff --git a/include/tvm/ffi/base_details.h b/include/tvm/ffi/base_details.h index 7224ac11..a00d4a4a 100644 --- a/include/tvm/ffi/base_details.h +++ b/include/tvm/ffi/base_details.h @@ -86,6 +86,7 @@ #else #define TVM_FFI_FUNC_SIG __func__ #endif +/// \endcond #if defined(__GNUC__) // gcc and clang and attribute constructor @@ -114,14 +115,22 @@ return 0; \ }(); \ static void FnName() - +/// \endcond +/*! + * \brief Macro that defines a block that will be called during static initialization. + * + * \code{.cpp} + * TVM_FFI_STATIC_INIT_BLOCK() { + * RegisterFunctions(); + * } + * \endcode + */ #define TVM_FFI_STATIC_INIT_BLOCK() \ TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__), \ TVM_FFI_STR_CONCAT(__TVMFFIStaticInitReg, __COUNTER__)) -/// \endcond #endif -/* +/*! * \brief Define the default copy/move constructor and assign operator * \param TypeName The class typename. */ @@ -313,5 +322,4 @@ using TypeSchema = TypeSchemaImpl } // namespace details } // namespace ffi } // namespace tvm -/// \endcond #endif // TVM_FFI_BASE_DETAILS_H_ diff --git a/python/tvm_ffi/core.pyi b/python/tvm_ffi/core.pyi index 2cee79cc..1c3a3dad 100644 --- a/python/tvm_ffi/core.pyi +++ b/python/tvm_ffi/core.pyi @@ -184,6 +184,7 @@ class DLTensorTestWrapper: def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ... class Function(Object): + def __init__(self, func: Callable[..., Any]) -> None: ... @property def release_gil(self) -> bool: ... @release_gil.setter diff --git a/python/tvm_ffi/cython/function.pxi b/python/tvm_ffi/cython/function.pxi index 01a33667..67429ea0 100644 --- a/python/tvm_ffi/cython/function.pxi +++ b/python/tvm_ffi/cython/function.pxi @@ -889,6 +889,27 @@ cdef class Function(Object): def __cinit__(self) -> None: self.c_release_gil = _RELEASE_GIL_BY_DEFAULT + def __init__(self, func: Callable[..., Any]) -> None: + """Initialize a Function from a Python callable. + + This constructor allows creating a `tvm_ffi.Function` directly + from a Python function or another `tvm_ffi.Function` instance. + + Parameters + ---------- + func : Callable[..., Any] + The Python callable to wrap. + """ + cdef TVMFFIObjectHandle chandle = NULL + if not callable(func): + raise TypeError(f"func must be callable, got {type(func)}") + if isinstance(func, Function): + chandle = (func).chandle + TVMFFIObjectIncRef(chandle) + else: + _convert_to_ffi_func_handle(func, &chandle) + self.chandle = chandle + property release_gil: """Whether calls release the Python GIL while executing.""" diff --git a/tests/python/test_function.py b/tests/python/test_function.py index 8a494fb1..34d3bdc8 100644 --- a/tests/python/test_function.py +++ b/tests/python/test_function.py @@ -153,6 +153,15 @@ def fapply(f: Any, *args: Any) -> Any: assert fapply(add, 1, 3.3) == 4.3 +def test_pyfunc_init() -> None: + def add(a: int, b: int) -> int: + return a + b + + fadd = tvm_ffi.Function(add) + assert isinstance(fadd, tvm_ffi.Function) + assert fadd(1, 2) == 3 + + def test_global_func() -> None: @tvm_ffi.register_global_func("mytest.echo") def echo(x: Any) -> Any: