Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ jobs:
- "3.14t"
- "pypy-3.11"
fail-fast: false
timeout-minutes: 90
timeout-minutes: 120
steps:
- name: Checkout
uses: actions/checkout@v6
Expand Down
30 changes: 30 additions & 0 deletions include/optree/pytypes.h
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ inline bool IsNamedTupleClass(const py::handle &type) {
static read_write_mutex mutex{};

{
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_release_simple gil_release{};
#endif
const scoped_read_lock lock{mutex};
const auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
Expand All @@ -270,8 +273,14 @@ inline bool IsNamedTupleClass(const py::handle &type) {

const bool result = EVALUATE_WITH_LOCK_HELD(IsNamedTupleClassImpl(type), type);
{
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_release_simple gil_release{};
#endif
const scoped_write_lock lock{mutex};
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_acquire_simple gil_acquire{};
#endif
cache.emplace(type, result);
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
const scoped_write_lock lock{mutex};
Expand Down Expand Up @@ -363,6 +372,9 @@ inline bool IsStructSequenceClass(const py::handle &type) {
static read_write_mutex mutex{};

{
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_release_simple gil_release{};
#endif
const scoped_read_lock lock{mutex};
const auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
Expand All @@ -372,8 +384,14 @@ inline bool IsStructSequenceClass(const py::handle &type) {

const bool result = EVALUATE_WITH_LOCK_HELD(IsStructSequenceClassImpl(type), type);
{
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_release_simple gil_release{};
#endif
const scoped_write_lock lock{mutex};
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_acquire_simple gil_acquire{};
#endif
cache.emplace(type, result);
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
const scoped_write_lock lock{mutex};
Expand Down Expand Up @@ -446,17 +464,29 @@ inline py::tuple StructSequenceGetFields(const py::handle &object) {
static read_write_mutex mutex{};

{
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_release_simple gil_release{};
#endif
const scoped_read_lock lock{mutex};
const auto it = cache.find(type);
if (it != cache.end()) [[likely]] {
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_acquire_simple gil_acquire{};
#endif
return py::reinterpret_borrow<py::tuple>(it->second);
}
}

const py::tuple fields = EVALUATE_WITH_LOCK_HELD(StructSequenceGetFieldsImpl(type), type);
{
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_release_simple gil_release{};
#endif
const scoped_write_lock lock{mutex};
if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] {
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_acquire_simple gil_acquire{};
#endif
cache.emplace(type, fields);
fields.inc_ref();
(void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void {
Expand Down
2 changes: 0 additions & 2 deletions include/optree/treespec.h
Original file line number Diff line number Diff line change
Expand Up @@ -464,9 +464,7 @@ class PyTreeIter {
const bool m_none_is_leaf;
const std::string m_namespace;
const bool m_is_dict_insertion_ordered;
#if defined(Py_GIL_DISABLED)
mutable mutex m_mutex{};
#endif

template <bool NoneIsLeaf>
[[nodiscard]] py::object NextImpl();
Expand Down
12 changes: 4 additions & 8 deletions src/registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ template <bool NoneIsLeaf>
EXPECT_TRUE(registry.m_builtins_types.emplace(cls).second,
"PyTree type " + PyRepr(cls) +
" is already registered in the built-in types set.");
cls.inc_ref();
if (!NoneIsLeaf || kind != PyTreeKind::None) {
if (!NoneIsLeaf || kind != PyTreeKind::None) [[likely]] {
auto registration =
std::make_shared<std::remove_const_t<RegistrationPtr::element_type>>();
registration->kind = kind;
Expand All @@ -50,9 +49,9 @@ template <bool NoneIsLeaf>
registry.m_registrations.emplace(cls, std::move(registration)).second,
"PyTree type " + PyRepr(cls) +
" is already registered in the global namespace.");
if constexpr (!NoneIsLeaf) {
cls.inc_ref();
}
}
if constexpr (!NoneIsLeaf) {
cls.inc_ref();
}
};
add_builtin_type(PyNoneTypeObject, PyTreeKind::None);
Expand Down Expand Up @@ -368,9 +367,6 @@ template PyTreeKind PyTreeTypeRegistry::GetKind<NONE_IS_LEAF>(
}
#endif

for (const auto &cls : registry1.m_builtins_types) {
cls.dec_ref();
}
for (const auto &[_, registration1] : registry1.m_registrations) {
registration1->type.dec_ref();
registration1->flatten_func.dec_ref();
Expand Down
17 changes: 11 additions & 6 deletions src/treespec/traversal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,19 @@ py::object PyTreeIter::NextImpl() {
}

py::object PyTreeIter::Next() {
#if defined(Py_GIL_DISABLED)
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_release_simple gil_release{};
#endif
const scoped_lock lock{m_mutex};
{
#if !defined(Py_GIL_DISABLED)
const py::gil_scoped_acquire_simple gil_acquire{};
#endif

if (m_none_is_leaf) [[unlikely]] {
return NextImpl<NONE_IS_LEAF>();
} else [[likely]] {
return NextImpl<NONE_IS_NODE>();
if (m_none_is_leaf) [[unlikely]] {
return NextImpl<NONE_IS_LEAF>();
} else [[likely]] {
return NextImpl<NONE_IS_NODE>();
}
}
}

Expand Down
52 changes: 27 additions & 25 deletions tests/concurrent/test_threading.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,31 +343,33 @@ def test_tree_iter_thread_safe(
dict_should_be_sorted,
dict_session_namespace,
):
counter = itertools.count()
with optree.dict_insertion_ordered(
not dict_should_be_sorted,
namespace=dict_session_namespace or GLOBAL_NAMESPACE,
):
new_tree = optree.tree_map(
lambda x: next(counter),
tree,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
num_leaves = next(counter)
assert optree.tree_leaves(
new_tree,
none_is_leaf=none_is_leaf,
namespace=namespace,
) == list(range(num_leaves))

it = optree.tree_iter(
new_tree,
none_is_leaf=none_is_leaf,
namespace=namespace,
)

def get_iterator():
counter = itertools.count()
with optree.dict_insertion_ordered(
not dict_should_be_sorted,
namespace=dict_session_namespace or GLOBAL_NAMESPACE,
):
new_tree = optree.tree_map(
lambda x: next(counter),
tree,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
num_leaves = next(counter)
it = optree.tree_iter(
new_tree,
none_is_leaf=none_is_leaf,
namespace=namespace,
)
return it, new_tree, num_leaves

it, _, num_leaves = get_iterator()
sentinel = object()
assert list(it) == list(range(num_leaves))
assert next(it, sentinel) is sentinel

it, new_tree, _ = get_iterator()
results = concurrent_run(list, it)
assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves))
for seq in results:
assert sorted(seq) == seq
assert seq == sorted(seq), f'Expected {sorted(seq)}, but got {seq}: tree {new_tree!r}.'
Loading