Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 23 additions & 15 deletions src/strands/experimental/bidi/_async/_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,17 @@
"""

import asyncio
from typing import Any, Coroutine
from typing import Any, Coroutine, cast


class _TaskGroup:
"""Shim of asyncio.TaskGroup for use in Python 3.10.

Attributes:
_tasks: List of tasks in group.
_tasks: Set of tasks in group.
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Switching to set to be consistent with asyncio.wait which puts done_tasks and pending_tasks into sets. This in turn helps to resolve a mypy error.

"""

_tasks: list[asyncio.Task]
_tasks: set[asyncio.Task]

def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
"""Create an async task and add to group.
Expand All @@ -25,12 +25,12 @@ def create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
The created task.
"""
task = asyncio.create_task(coro)
self._tasks.append(task)
self._tasks.add(task)
return task

async def __aenter__(self) -> "_TaskGroup":
"""Setup self managed task group context."""
self._tasks = []
self._tasks = set()
return self

async def __aexit__(self, *_: Any) -> None:
Expand All @@ -42,20 +42,28 @@ async def __aexit__(self, *_: Any) -> None:
- The context re-raises CancelledErrors to the caller only if the context itself was cancelled.
"""
try:
Copy link
Member Author

@pgrayy pgrayy Dec 17, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have the execution rules listed above in the docstring. To repeat here:

"""Execute tasks in group.

The following execution rules are enforced:
- The context stops executing all tasks if at least one task raises an Exception or the context is cancelled.
- The context re-raises Exceptions to the caller.
- The context re-raises CancelledErrors to the caller only if the context itself was cancelled.
"""

await asyncio.gather(*self._tasks)
pending_tasks = self._tasks
while pending_tasks:
done_tasks, pending_tasks = await asyncio.wait(pending_tasks, return_when=asyncio.FIRST_EXCEPTION)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here we reassign pending_tasks which led mypy to complain about set[Task] vs list[Task] and thus why I switched self._tasks to set[Task].


except (Exception, asyncio.CancelledError) as error:
if any(exception := done_task.exception() for done_task in done_tasks if not done_task.cancelled()):
break

else: # all tasks completed/cancelled successfully
return

for pending_task in pending_tasks:
pending_task.cancel()

await asyncio.gather(*pending_tasks, return_exceptions=True)
raise cast(BaseException, exception)

except asyncio.CancelledError: # context itself was cancelled
for task in self._tasks:
task.cancel()

await asyncio.gather(*self._tasks, return_exceptions=True)

if not isinstance(error, asyncio.CancelledError):
raise

context_task = asyncio.current_task()
if context_task and context_task.cancelling() > 0: # context itself was cancelled
raise
raise

finally:
self._tasks = []
self._tasks = set()
16 changes: 14 additions & 2 deletions tests/strands/experimental/bidi/_async/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ async def test_task_group__aexit__():


@pytest.mark.asyncio
async def test_task_group__aexit__exception():
async def test_task_group__aexit__task_exception():
wait_event = asyncio.Event()
async def wait():
await wait_event.wait()
Expand All @@ -35,7 +35,19 @@ async def fail():


@pytest.mark.asyncio
async def test_task_group__aexit__cancelled():
async def test_task_group__aexit__task_cancelled():
async def wait():
asyncio.current_task().cancel()
await asyncio.sleep(0)

async with _TaskGroup() as task_group:
wait_task = task_group.create_task(wait())

assert wait_task.cancelled()


@pytest.mark.asyncio
async def test_task_group__aexit__context_cancelled():
wait_event = asyncio.Event()
async def wait():
await wait_event.wait()
Expand Down
Loading