From 6be95c773422a5139d7202c9656af7039901c8a3 Mon Sep 17 00:00:00 2001 From: Noah Stapp Date: Tue, 7 Apr 2026 14:38:19 -0400 Subject: [PATCH] PYTHON-5788 - Refine withTransaction timeout error wrapping semantics and label propagation in spec and prose tests --- pymongo/asynchronous/client_session.py | 15 +++++++++++---- pymongo/synchronous/client_session.py | 15 +++++++++++---- test/asynchronous/test_transactions.py | 12 +++++++++--- test/test_transactions.py | 12 +++++++++--- 4 files changed, 40 insertions(+), 14 deletions(-) diff --git a/pymongo/asynchronous/client_session.py b/pymongo/asynchronous/client_session.py index 8212d1396a..ed19f16f4d 100644 --- a/pymongo/asynchronous/client_session.py +++ b/pymongo/asynchronous/client_session.py @@ -516,9 +516,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool: def _make_timeout_error(error: BaseException) -> PyMongoError: """Convert error to a NetworkTimeout or ExecutionTimeout as appropriate.""" if _csot.remaining() is not None: - return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}) + timeout_error: PyMongoError = ExecutionTimeout( + str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50} + ) else: - return NetworkTimeout(str(error)) + timeout_error = NetworkTimeout(str(error)) + if isinstance(error, PyMongoError): + timeout_error._error_labels = error._error_labels.copy() + return timeout_error _T = TypeVar("_T") @@ -804,15 +809,17 @@ async def callback(session, custom_arg, custom_kwarg=None): await self.commit_transaction() except PyMongoError as exc: last_error = exc - if not _within_time_limit(start_time): - raise _make_timeout_error(last_error) from exc if exc.has_error_label( "UnknownTransactionCommitResult" ) and not _max_time_expired_error(exc): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the commit. continue if exc.has_error_label("TransientTransactionError"): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the entire transaction. break raise diff --git a/pymongo/synchronous/client_session.py b/pymongo/synchronous/client_session.py index f8211a60b1..64056f33c4 100644 --- a/pymongo/synchronous/client_session.py +++ b/pymongo/synchronous/client_session.py @@ -514,9 +514,14 @@ def _within_time_limit(start_time: float, backoff: float = 0) -> bool: def _make_timeout_error(error: BaseException) -> PyMongoError: """Convert error to a NetworkTimeout or ExecutionTimeout as appropriate.""" if _csot.remaining() is not None: - return ExecutionTimeout(str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50}) + timeout_error: PyMongoError = ExecutionTimeout( + str(error), 50, {"ok": 0, "errmsg": str(error), "code": 50} + ) else: - return NetworkTimeout(str(error)) + timeout_error = NetworkTimeout(str(error)) + if isinstance(error, PyMongoError): + timeout_error._error_labels = error._error_labels.copy() + return timeout_error _T = TypeVar("_T") @@ -800,15 +805,17 @@ def callback(session, custom_arg, custom_kwarg=None): self.commit_transaction() except PyMongoError as exc: last_error = exc - if not _within_time_limit(start_time): - raise _make_timeout_error(last_error) from exc if exc.has_error_label( "UnknownTransactionCommitResult" ) and not _max_time_expired_error(exc): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the commit. continue if exc.has_error_label("TransientTransactionError"): + if not _within_time_limit(start_time): + raise _make_timeout_error(last_error) from exc # Retry the entire transaction. break raise diff --git a/test/asynchronous/test_transactions.py b/test/asynchronous/test_transactions.py index 95a07a743c..7b28b5bd91 100644 --- a/test/asynchronous/test_transactions.py +++ b/test/asynchronous/test_transactions.py @@ -500,10 +500,12 @@ async def callback(session): listener.reset() async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(NetworkTimeout): + with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @async_client_context.require_test_commands @async_client_context.require_transactions @@ -534,10 +536,12 @@ async def callback(session): async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(NetworkTimeout): + with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @async_client_context.require_test_commands @async_client_context.require_transactions @@ -565,7 +569,7 @@ async def callback(session): async with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(NetworkTimeout): + with self.assertRaises(NetworkTimeout) as context: await s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic @@ -573,6 +577,8 @@ async def callback(session): self.assertEqual( listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] ) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult")) @async_client_context.require_transactions async def test_callback_not_retried_after_csot_timeout(self): diff --git a/test/test_transactions.py b/test/test_transactions.py index 9e370294ef..861d6d0c02 100644 --- a/test/test_transactions.py +++ b/test/test_transactions.py @@ -492,10 +492,12 @@ def callback(session): listener.reset() with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(NetworkTimeout): + with self.assertRaises(NetworkTimeout) as context: s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "abortTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @client_context.require_test_commands @client_context.require_transactions @@ -524,10 +526,12 @@ def callback(session): with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(NetworkTimeout): + with self.assertRaises(NetworkTimeout) as context: s.with_transaction(callback) self.assertEqual(listener.started_command_names(), ["insert", "commitTransaction"]) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("TransientTransactionError")) @client_context.require_test_commands @client_context.require_transactions @@ -553,7 +557,7 @@ def callback(session): with client.start_session() as s: with PatchSessionTimeout(0): - with self.assertRaises(NetworkTimeout): + with self.assertRaises(NetworkTimeout) as context: s.with_transaction(callback) # One insert for the callback and two commits (includes the automatic @@ -561,6 +565,8 @@ def callback(session): self.assertEqual( listener.started_command_names(), ["insert", "commitTransaction", "commitTransaction"] ) + # Assert that the timeout error has the same labels as the error it wraps. + self.assertTrue(context.exception.has_error_label("UnknownTransactionCommitResult")) @client_context.require_transactions def test_callback_not_retried_after_csot_timeout(self):