Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/tvm_ffi/dataclasses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,6 @@
from dataclasses import MISSING

from .c_class import c_class
from .field import Field, field
from .field import KW_ONLY, Field, field

__all__ = ["MISSING", "Field", "c_class", "field"]
__all__ = ["KW_ONLY", "MISSING", "Field", "c_class", "field"]
58 changes: 41 additions & 17 deletions python/tvm_ffi/dataclasses/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,13 @@ def _add_method(name: str, func: Callable[..., Any]) -> None:
return cast(Type[_InputClsType], new_cls)


def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
def fill_dataclass_field(
type_cls: type,
type_field: TypeField,
*,
class_kw_only: bool = False,
kw_only_from_sentinel: bool = False,
) -> None:
from .field import Field, field # noqa: PLC0415

field_name = type_field.name
Expand All @@ -94,6 +100,14 @@ def fill_dataclass_field(type_cls: type, type_field: TypeField) -> None:
raise ValueError(f"Cannot recognize field: {type_field.name}: {rhs}")
assert isinstance(rhs, Field)
rhs.name = type_field.name

# Resolve kw_only: field-level > KW_ONLY sentinel > class-level
if rhs.kw_only is MISSING:
if kw_only_from_sentinel:
rhs.kw_only = True
else:
rhs.kw_only = class_kw_only

type_field.dataclass_field = rhs


Expand Down Expand Up @@ -148,47 +162,56 @@ def method_repr(type_cls: type, type_info: TypeInfo) -> Callable[..., str]:
return __repr__


def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
def method_init(_type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
"""Generate an ``__init__`` that forwards to the FFI constructor.

The generated initializer has a proper Python signature built from the
reflected field list, supporting default values and ``__post_init__``.
reflected field list, supporting default values, keyword-only args, and ``__post_init__``.
"""
# Step 0. Collect all fields from the type hierarchy
fields = _get_all_fields(type_info)
# sanity check
for type_method in type_info.methods:
if type_method.name == "__ffi_init__":
break
else:
if not any(m.name == "__ffi_init__" for m in type_info.methods):
raise ValueError(f"Cannot find constructor method: `{type_info.type_key}.__ffi_init__`")
# Step 1. Split args into sections and register default factories
args_no_defaults: list[str] = []
args_with_defaults: list[str] = []
pos_no_defaults: list[str] = []
pos_with_defaults: list[str] = []
kw_no_defaults: list[str] = []
kw_with_defaults: list[str] = []
fields_with_defaults: list[tuple[str, bool]] = []
ffi_arg_order: list[str] = []
exec_globals = {"MISSING": MISSING}
exec_globals: dict[str, Any] = {"MISSING": MISSING}

for field in fields:
assert field.name is not None
assert field.dataclass_field is not None
dataclass_field = field.dataclass_field
has_default_factory = (default_factory := dataclass_field.default_factory) is not MISSING
has_default = (default_factory := dataclass_field.default_factory) is not MISSING
is_kw_only = dataclass_field.kw_only is True

if dataclass_field.init:
ffi_arg_order.append(field.name)
if has_default_factory:
args_with_defaults.append(field.name)
if has_default:
(kw_with_defaults if is_kw_only else pos_with_defaults).append(field.name)
fields_with_defaults.append((field.name, True))
exec_globals[f"_default_factory_{field.name}"] = default_factory
else:
args_no_defaults.append(field.name)
elif has_default_factory:
(kw_no_defaults if is_kw_only else pos_no_defaults).append(field.name)
elif has_default:
ffi_arg_order.append(field.name)
fields_with_defaults.append((field.name, False))
exec_globals[f"_default_factory_{field.name}"] = default_factory

# Step 2. Build signature
args: list[str] = ["self"]
args.extend(args_no_defaults)
args.extend(f"{name}=MISSING" for name in args_with_defaults)
args.extend(pos_no_defaults)
args.extend(f"{name}=MISSING" for name in pos_with_defaults)
if kw_no_defaults or kw_with_defaults:
args.append("*")
args.extend(kw_no_defaults)
args.extend(f"{name}=MISSING" for name in kw_with_defaults)

# Step 3. Build body
body_lines: list[str] = []
for field_name, is_init in fields_with_defaults:
if is_init:
Expand All @@ -208,6 +231,7 @@ def method_init(type_cls: type, type_info: TypeInfo) -> Callable[..., None]:
" fn_post_init()",
]
)

source_lines = [f"def __init__({', '.join(args)}):"]
source_lines.extend(f" {line}" for line in body_lines)
source_lines.append(" ...")
Expand Down
47 changes: 39 additions & 8 deletions python/tvm_ffi/dataclasses/c_class.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@

from ..core import TypeField, TypeInfo, _lookup_or_register_type_info_from_type_key, _set_type_cls
from . import _utils
from .field import field
from .field import KW_ONLY, field

_InputClsType = TypeVar("_InputClsType")


@dataclass_transform(field_specifiers=(field,))
@dataclass_transform(field_specifiers=(field,), kw_only_default=False)
def c_class(
type_key: str, init: bool = True, repr: bool = True
type_key: str, init: bool = True, kw_only: bool = False, repr: bool = True
) -> Callable[[Type[_InputClsType]], Type[_InputClsType]]: # noqa: UP006
"""(Experimental) Create a dataclass-like proxy for a C++ class registered with TVM FFI.

Expand Down Expand Up @@ -71,6 +71,12 @@ def c_class(
signature. The generated initializer calls the C++ ``__init__``
function registered with ``ObjectDef`` and invokes ``__post_init__`` if
it exists on the Python class.

kw_only
If ``True``, all fields become keyword-only parameters in the generated
``__init__``. Individual fields can override this by setting
``kw_only=False`` in :func:`field`. Additionally, a ``KW_ONLY`` sentinel
annotation can be used to mark all subsequent fields as keyword-only.
repr
If ``True`` and the Python class does not define ``__repr__``, a
representation method is auto-generated that includes all fields with
Expand Down Expand Up @@ -129,9 +135,15 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
type_info: TypeInfo = _lookup_or_register_type_info_from_type_key(type_key)
assert type_info.parent_type_info is not None
# Step 2. Reflect all the fields of the type
type_info.fields = _inspect_c_class_fields(super_type_cls, type_info)
for type_field in type_info.fields:
_utils.fill_dataclass_field(super_type_cls, type_field)
type_info.fields, kw_only_start_idx = _inspect_c_class_fields(super_type_cls, type_info)
for idx, type_field in enumerate(type_info.fields):
kw_only_from_sentinel = kw_only_start_idx is not None and idx >= kw_only_start_idx
_utils.fill_dataclass_field(
super_type_cls,
type_field,
class_kw_only=kw_only,
kw_only_from_sentinel=kw_only_from_sentinel,
)
# Step 3. Create the proxy class with the fields as properties
fn_init = _utils.method_init(super_type_cls, type_info) if init else None
fn_repr = _utils.method_repr(super_type_cls, type_info) if repr else None
Expand All @@ -146,7 +158,9 @@ def decorator(super_type_cls: Type[_InputClsType]) -> Type[_InputClsType]: # no
return decorator


def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeField]:
def _inspect_c_class_fields(
type_cls: type, type_info: TypeInfo
) -> tuple[list[TypeField], int | None]:
if sys.version_info >= (3, 9):
type_hints_resolved = get_type_hints(type_cls, include_extras=True)
else:
Expand All @@ -159,7 +173,24 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
ClassVar,
InitVar,
]
and type_hints_resolved[name] is not KW_ONLY
}

# Detect KW_ONLY sentinel position
kw_only_start_idx: int | None = None
field_count = 0
for name in getattr(type_cls, "__annotations__", {}).keys():
resolved_type = type_hints_resolved.get(name)
if resolved_type is None:
continue
if get_origin(resolved_type) in [ClassVar, InitVar]:
continue
if resolved_type is KW_ONLY:
if kw_only_start_idx is not None:
raise ValueError(f"KW_ONLY may only be used once per class: {type_cls}")
kw_only_start_idx = field_count
continue
field_count += 1
del type_hints_resolved

type_fields_cxx: dict[str, TypeField] = {f.name: f for f in type_info.fields}
Expand All @@ -178,4 +209,4 @@ def _inspect_c_class_fields(type_cls: type, type_info: TypeInfo) -> list[TypeFie
raise ValueError(
f"Missing fields in `{type_cls}`: {extra_fields}. Defined in C++ but not in Python"
)
return type_fields
return type_fields, kw_only_start_idx
35 changes: 33 additions & 2 deletions python/tvm_ffi/dataclasses/field.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,17 @@
from dataclasses import _MISSING_TYPE, MISSING
from typing import Any, Callable, TypeVar, cast

try:
from dataclasses import KW_ONLY # type: ignore[attr-defined]
except ImportError:
# Python < 3.10: define our own KW_ONLY sentinel
class _KW_ONLY_Sentinel:
__slots__ = ()

KW_ONLY = _KW_ONLY_Sentinel()

_FieldValue = TypeVar("_FieldValue")
_KW_ONLY_TYPE = type(KW_ONLY)


class Field:
Expand All @@ -37,7 +47,7 @@ class Field:
way the decorator understands.
"""

__slots__ = ("default_factory", "init", "name", "repr")
__slots__ = ("default_factory", "init", "kw_only", "name", "repr")

def __init__(
self,
Expand All @@ -46,12 +56,14 @@ def __init__(
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING,
init: bool = True,
repr: bool = True,
kw_only: bool | _MISSING_TYPE = MISSING,
) -> None:
"""Do not call directly; use :func:`field` instead."""
self.name = name
self.default_factory = default_factory
self.init = init
self.repr = repr
self.kw_only = kw_only


def field(
Expand All @@ -60,6 +72,7 @@ def field(
default_factory: Callable[[], _FieldValue] | _MISSING_TYPE = MISSING, # type: ignore[assignment]
init: bool = True,
repr: bool = True,
kw_only: bool | _MISSING_TYPE = MISSING, # type: ignore[assignment]
) -> _FieldValue:
"""(Experimental) Declare a dataclass-style field on a :func:`c_class` proxy.

Expand All @@ -84,6 +97,10 @@ def field(
repr
If ``True`` the field is included in the generated ``__repr__``.
If ``False`` the field is omitted from the ``__repr__`` output.
kw_only
If ``True``, the field is a keyword-only argument in ``__init__``.
If ``MISSING``, inherits from the class-level ``kw_only`` setting or
from a preceding ``KW_ONLY`` sentinel annotation.

Note
----
Expand Down Expand Up @@ -124,16 +141,30 @@ class PyBase:
obj = PyBase(v_i64=4)
obj.v_i32 # -> 16

Use ``kw_only=True`` to make a field keyword-only:

.. code-block:: python

@c_class("testing.TestCxxClassBase")
class PyBase:
v_i64: int
v_i32: int = field(kw_only=True)


obj = PyBase(4, v_i32=8) # v_i32 must be keyword

"""
if default is not MISSING and default_factory is not MISSING:
raise ValueError("Cannot specify both `default` and `default_factory`")
if not isinstance(init, bool):
raise TypeError("`init` must be a bool")
if not isinstance(repr, bool):
raise TypeError("`repr` must be a bool")
if kw_only is not MISSING and not isinstance(kw_only, bool):
raise TypeError(f"`kw_only` must be a bool, got {type(kw_only).__name__!r}")
if default is not MISSING:
default_factory = _make_default_factory(default)
ret = Field(default_factory=default_factory, init=init, repr=repr)
ret = Field(default_factory=default_factory, init=init, repr=repr, kw_only=kw_only)
return cast(_FieldValue, ret)


Expand Down
1 change: 1 addition & 0 deletions python/tvm_ffi/testing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
_TestCxxClassDerived,
_TestCxxClassDerivedDerived,
_TestCxxInitSubset,
_TestCxxKwOnly,
add_one,
create_object,
make_unregistered_object,
Expand Down
8 changes: 8 additions & 0 deletions python/tvm_ffi/testing/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,11 @@ class _TestCxxInitSubset:
required_field: int
optional_field: int = field(init=False)
note: str = field(default_factory=lambda: "py-default", init=False)


@c_class("testing.TestCxxKwOnly", kw_only=True)
class _TestCxxKwOnly:
x: int
y: int
z: int
w: int = 100
20 changes: 20 additions & 0 deletions src/ffi/testing/testing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,19 @@ class TestCxxInitSubsetObj : public Object {
TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxInitSubset", TestCxxInitSubsetObj, Object);
};

class TestCxxKwOnly : public Object {
public:
int64_t x;
int64_t y;
int64_t z;
int64_t w;

TestCxxKwOnly(int64_t x, int64_t y, int64_t z, int64_t w) : x(x), y(y), z(z), w(w) {}

static constexpr bool _type_mutable = true;
TVM_FFI_DECLARE_OBJECT_INFO("testing.TestCxxKwOnly", TestCxxKwOnly, Object);
};

class TestUnregisteredBaseObject : public Object {
public:
int64_t v1;
Expand Down Expand Up @@ -229,6 +242,13 @@ TVM_FFI_STATIC_INIT_BLOCK() {
.def_rw("optional_field", &TestCxxInitSubsetObj::optional_field)
.def_rw("note", &TestCxxInitSubsetObj::note);

refl::ObjectDef<TestCxxKwOnly>()
.def(refl::init<int64_t, int64_t, int64_t, int64_t>())
.def_rw("x", &TestCxxKwOnly::x)
.def_rw("y", &TestCxxKwOnly::y)
.def_rw("z", &TestCxxKwOnly::z)
.def_rw("w", &TestCxxKwOnly::w);

refl::ObjectDef<TestUnregisteredBaseObject>()
.def(refl::init<int64_t>(), "Constructor of TestUnregisteredBaseObject")
.def_ro("v1", &TestUnregisteredBaseObject::v1)
Expand Down
Loading