From a45bd9d6c171c1e418555470a7ec3ece03bb6ee6 Mon Sep 17 00:00:00 2001 From: Xuehai Pan Date: Fri, 2 Jan 2026 22:19:52 +0800 Subject: [PATCH] fix: fix potential deadlock when using GIL with locks from STL --- .github/workflows/tests.yml | 2 +- include/optree/pytypes.h | 30 +++++++++++++++++ include/optree/treespec.h | 2 -- src/registry.cpp | 12 +++---- src/treespec/traversal.cpp | 17 ++++++---- tests/concurrent/test_threading.py | 52 ++++++++++++++++-------------- 6 files changed, 73 insertions(+), 42 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index edf87196..fc731376 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -70,7 +70,7 @@ jobs: - "3.14t" - "pypy-3.11" fail-fast: false - timeout-minutes: 90 + timeout-minutes: 120 steps: - name: Checkout uses: actions/checkout@v6 diff --git a/include/optree/pytypes.h b/include/optree/pytypes.h index a932bd9e..2dfea5f9 100644 --- a/include/optree/pytypes.h +++ b/include/optree/pytypes.h @@ -261,6 +261,9 @@ inline bool IsNamedTupleClass(const py::handle &type) { static read_write_mutex mutex{}; { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { @@ -270,8 +273,14 @@ inline bool IsNamedTupleClass(const py::handle &type) { const bool result = EVALUATE_WITH_LOCK_HELD(IsNamedTupleClassImpl(type), type); { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_acquire_simple gil_acquire{}; +#endif cache.emplace(type, result); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { const scoped_write_lock lock{mutex}; @@ -363,6 +372,9 @@ inline bool IsStructSequenceClass(const py::handle &type) { static read_write_mutex mutex{}; { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { @@ -372,8 +384,14 @@ inline bool IsStructSequenceClass(const py::handle &type) { const bool result = EVALUATE_WITH_LOCK_HELD(IsStructSequenceClassImpl(type), type); { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_acquire_simple gil_acquire{}; +#endif cache.emplace(type, result); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { const scoped_write_lock lock{mutex}; @@ -446,17 +464,29 @@ inline py::tuple StructSequenceGetFields(const py::handle &object) { static read_write_mutex mutex{}; { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_read_lock lock{mutex}; const auto it = cache.find(type); if (it != cache.end()) [[likely]] { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_acquire_simple gil_acquire{}; +#endif return py::reinterpret_borrow(it->second); } } const py::tuple fields = EVALUATE_WITH_LOCK_HELD(StructSequenceGetFieldsImpl(type), type); { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_write_lock lock{mutex}; if (cache.size() < MAX_TYPE_CACHE_SIZE) [[likely]] { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_acquire_simple gil_acquire{}; +#endif cache.emplace(type, fields); fields.inc_ref(); (void)py::weakref(type, py::cpp_function([type](py::handle weakref) -> void { diff --git a/include/optree/treespec.h b/include/optree/treespec.h index 1ff25883..ecc6079e 100644 --- a/include/optree/treespec.h +++ b/include/optree/treespec.h @@ -464,9 +464,7 @@ class PyTreeIter { const bool m_none_is_leaf; const std::string m_namespace; const bool m_is_dict_insertion_ordered; -#if defined(Py_GIL_DISABLED) mutable mutex m_mutex{}; -#endif template [[nodiscard]] py::object NextImpl(); diff --git a/src/registry.cpp b/src/registry.cpp index c5c10636..faa6c662 100644 --- a/src/registry.cpp +++ b/src/registry.cpp @@ -40,8 +40,7 @@ template EXPECT_TRUE(registry.m_builtins_types.emplace(cls).second, "PyTree type " + PyRepr(cls) + " is already registered in the built-in types set."); - cls.inc_ref(); - if (!NoneIsLeaf || kind != PyTreeKind::None) { + if (!NoneIsLeaf || kind != PyTreeKind::None) [[likely]] { auto registration = std::make_shared>(); registration->kind = kind; @@ -50,9 +49,9 @@ template registry.m_registrations.emplace(cls, std::move(registration)).second, "PyTree type " + PyRepr(cls) + " is already registered in the global namespace."); - if constexpr (!NoneIsLeaf) { - cls.inc_ref(); - } + } + if constexpr (!NoneIsLeaf) { + cls.inc_ref(); } }; add_builtin_type(PyNoneTypeObject, PyTreeKind::None); @@ -368,9 +367,6 @@ template PyTreeKind PyTreeTypeRegistry::GetKind( } #endif - for (const auto &cls : registry1.m_builtins_types) { - cls.dec_ref(); - } for (const auto &[_, registration1] : registry1.m_registrations) { registration1->type.dec_ref(); registration1->flatten_func.dec_ref(); diff --git a/src/treespec/traversal.cpp b/src/treespec/traversal.cpp index e22ba47e..aea43087 100644 --- a/src/treespec/traversal.cpp +++ b/src/treespec/traversal.cpp @@ -162,14 +162,19 @@ py::object PyTreeIter::NextImpl() { } py::object PyTreeIter::Next() { -#if defined(Py_GIL_DISABLED) +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_release_simple gil_release{}; +#endif const scoped_lock lock{m_mutex}; + { +#if !defined(Py_GIL_DISABLED) + const py::gil_scoped_acquire_simple gil_acquire{}; #endif - - if (m_none_is_leaf) [[unlikely]] { - return NextImpl(); - } else [[likely]] { - return NextImpl(); + if (m_none_is_leaf) [[unlikely]] { + return NextImpl(); + } else [[likely]] { + return NextImpl(); + } } } diff --git a/tests/concurrent/test_threading.py b/tests/concurrent/test_threading.py index d1888b42..ec172926 100644 --- a/tests/concurrent/test_threading.py +++ b/tests/concurrent/test_threading.py @@ -343,31 +343,33 @@ def test_tree_iter_thread_safe( dict_should_be_sorted, dict_session_namespace, ): - counter = itertools.count() - with optree.dict_insertion_ordered( - not dict_should_be_sorted, - namespace=dict_session_namespace or GLOBAL_NAMESPACE, - ): - new_tree = optree.tree_map( - lambda x: next(counter), - tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - num_leaves = next(counter) - assert optree.tree_leaves( - new_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) == list(range(num_leaves)) - - it = optree.tree_iter( - new_tree, - none_is_leaf=none_is_leaf, - namespace=namespace, - ) - + def get_iterator(): + counter = itertools.count() + with optree.dict_insertion_ordered( + not dict_should_be_sorted, + namespace=dict_session_namespace or GLOBAL_NAMESPACE, + ): + new_tree = optree.tree_map( + lambda x: next(counter), + tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + num_leaves = next(counter) + it = optree.tree_iter( + new_tree, + none_is_leaf=none_is_leaf, + namespace=namespace, + ) + return it, new_tree, num_leaves + + it, _, num_leaves = get_iterator() + sentinel = object() + assert list(it) == list(range(num_leaves)) + assert next(it, sentinel) is sentinel + + it, new_tree, _ = get_iterator() results = concurrent_run(list, it) assert sorted(itertools.chain.from_iterable(results)) == list(range(num_leaves)) for seq in results: - assert sorted(seq) == seq + assert seq == sorted(seq), f'Expected {sorted(seq)}, but got {seq}: tree {new_tree!r}.'