diff --git a/src/strands/models/bedrock.py b/src/strands/models/bedrock.py index ab9adb67a..95cc8522e 100644 --- a/src/strands/models/bedrock.py +++ b/src/strands/models/bedrock.py @@ -64,6 +64,12 @@ def _clear_skip_count_tokens_cache() -> None: _SKIP_COUNT_TOKENS_MODELS.clear() +def _suppress_task_exception(task: "asyncio.Task[None]") -> None: + """Consume exception from orphaned stream task to silence 'never retrieved' warning.""" + if not task.cancelled(): + task.exception() + + T = TypeVar("T", bound=BaseModel) DEFAULT_READ_TIMEOUT = 120 @@ -909,14 +915,17 @@ def callback(event: StreamEvent | None = None) -> None: thread = asyncio.to_thread(self._stream, callback, messages, tool_specs, system_prompt_content, tool_choice) task = asyncio.create_task(thread) - while True: - event = await queue.get() - if event is None: - break - - yield event - - await task + try: + while True: + event = await queue.get() + if event is None: + break + + yield event + await task + except BaseException: + task.add_done_callback(_suppress_task_exception) + raise def _stream( self, diff --git a/tests/strands/models/test_bedrock.py b/tests/strands/models/test_bedrock.py index 2e105d64a..2b1384783 100644 --- a/tests/strands/models/test_bedrock.py +++ b/tests/strands/models/test_bedrock.py @@ -1,3 +1,4 @@ +import asyncio import copy import logging import os