Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 12 additions & 4 deletions include/tvm/ffi/base_details.h
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,7 @@
#else
#define TVM_FFI_FUNC_SIG __func__
#endif
/// \endcond

#if defined(__GNUC__)
// gcc and clang and attribute constructor
Expand Down Expand Up @@ -114,14 +115,22 @@
return 0; \
}(); \
static void FnName()

/// \endcond
/*!
* \brief Macro that defines a block that will be called during static initialization.
*
* \code{.cpp}
* TVM_FFI_STATIC_INIT_BLOCK() {
* RegisterFunctions();
* }
* \endcode
*/
#define TVM_FFI_STATIC_INIT_BLOCK() \
TVM_FFI_STATIC_INIT_BLOCK_DEF_(TVM_FFI_STR_CONCAT(__TVMFFIStaticInitFunc, __COUNTER__), \
TVM_FFI_STR_CONCAT(__TVMFFIStaticInitReg, __COUNTER__))
/// \endcond
#endif

/*
/*!
* \brief Define the default copy/move constructor and assign operator
* \param TypeName The class typename.
*/
Expand Down Expand Up @@ -313,5 +322,4 @@ using TypeSchema = TypeSchemaImpl<std::remove_const_t<std::remove_reference_t<T>
} // namespace details
} // namespace ffi
} // namespace tvm
/// \endcond
#endif // TVM_FFI_BASE_DETAILS_H_
1 change: 1 addition & 0 deletions python/tvm_ffi/core.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,7 @@ class DLTensorTestWrapper:
def _dltensor_test_wrapper_c_dlpack_from_pyobject_as_intptr() -> int: ...

class Function(Object):
def __init__(self, func: Callable[..., Any]) -> None: ...
@property
def release_gil(self) -> bool: ...
@release_gil.setter
Expand Down
21 changes: 21 additions & 0 deletions python/tvm_ffi/cython/function.pxi
Original file line number Diff line number Diff line change
Expand Up @@ -889,6 +889,27 @@ cdef class Function(Object):
def __cinit__(self) -> None:
self.c_release_gil = _RELEASE_GIL_BY_DEFAULT

def __init__(self, func: Callable[..., Any]) -> None:
"""Initialize a Function from a Python callable.

This constructor allows creating a `tvm_ffi.Function` directly
from a Python function or another `tvm_ffi.Function` instance.

Parameters
----------
func : Callable[..., Any]
The Python callable to wrap.
"""
cdef TVMFFIObjectHandle chandle = NULL
if not callable(func):
raise TypeError(f"func must be callable, got {type(func)}")
if isinstance(func, Function):
chandle = (<Object>func).chandle
TVMFFIObjectIncRef(chandle)
else:
_convert_to_ffi_func_handle(func, &chandle)
self.chandle = chandle
Comment on lines +903 to +911
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Calling __init__ on an already-initialized tvm_ffi.Function object will cause a memory leak. The current implementation overwrites self.chandle without decrementing the reference count of the previously held handle. To prevent this and enforce that __init__ acts as a one-time constructor, you should add a check to ensure the object is not already initialized.

        if self.chandle != NULL:
            raise TypeError("A tvm_ffi.Function object can only be initialized once.")

        cdef TVMFFIObjectHandle chandle = NULL
        if not callable(func):
            raise TypeError(f"func must be callable, got {type(func)}")
        if isinstance(func, Function):
            chandle = (<Object>func).chandle
            TVMFFIObjectIncRef(chandle)
        else:
            _convert_to_ffi_func_handle(func, &chandle)
        self.chandle = chandle


property release_gil:
"""Whether calls release the Python GIL while executing."""

Expand Down
9 changes: 9 additions & 0 deletions tests/python/test_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,15 @@ def fapply(f: Any, *args: Any) -> Any:
assert fapply(add, 1, 3.3) == 4.3


def test_pyfunc_init() -> None:
def add(a: int, b: int) -> int:
return a + b

fadd = tvm_ffi.Function(add)
assert isinstance(fadd, tvm_ffi.Function)
assert fadd(1, 2) == 3


def test_global_func() -> None:
@tvm_ffi.register_global_func("mytest.echo")
def echo(x: Any) -> Any:
Expand Down
Loading