From c4862e78e06fa51803fd0b6bd0ff228ea137650a Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 09:37:12 -0400 Subject: [PATCH 1/3] connection: release stream ids after send failures --- cassandra/connection.py | 31 +++++++++++++++++++++---------- tests/unit/test_connection.py | 30 ++++++++++++++++++++++++++++-- 2 files changed, 49 insertions(+), 12 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 08501d0a2b..7d5ee0f47b 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1219,15 +1219,19 @@ def send_msg(self, msg, request_id, cb, encoder=ProtocolHandler.encode_message, # queue the decoder function with the request # this allows us to inject custom functions per request to encode, decode messages self._requests[request_id] = (cb, decoder, result_metadata) - msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, - allow_beta_protocol_version=self.allow_beta_protocol_version) + try: + msg = encoder(msg, request_id, self.protocol_version, compressor=self.compressor, + allow_beta_protocol_version=self.allow_beta_protocol_version) - if self._is_checksumming_enabled: - buffer = io.BytesIO() - self._segment_codec.encode(buffer, msg) - msg = buffer.getvalue() + if self._is_checksumming_enabled: + buffer = io.BytesIO() + self._segment_codec.encode(buffer, msg) + msg = buffer.getvalue() - self.push(msg) + self.push(msg) + except Exception: + self._requests.pop(request_id, None) + raise return len(msg) def wait_for_response(self, msg, timeout=None, **kwargs): @@ -1262,9 +1266,16 @@ def wait_for_responses(self, *msgs, **kwargs): self.in_flight += available for i, request_id in enumerate(request_ids): - self.send_msg(msgs[messages_sent + i], - request_id, - partial(waiter.got_response, index=messages_sent + i)) + try: + self.send_msg(msgs[messages_sent + i], + request_id, + partial(waiter.got_response, index=messages_sent + i)) + except Exception: + unsent_request_ids = request_ids[i:] + with self.lock: + self.in_flight -= len(unsent_request_ids) + self.request_ids.extend(unsent_request_ids) + raise messages_sent += available if messages_sent == len(msgs): diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 2fa7c71196..311e47a374 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -18,14 +18,14 @@ from threading import Lock from unittest.mock import Mock, ANY, call, patch -from cassandra import OperationTimedOut +from cassandra import ConsistencyLevel, OperationTimedOut from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, - SupportedMessage, ProtocolHandler, ResultMessage, + SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage, RESULT_KIND_SET_KEYSPACE) from tests.util import wait_until, assertRegex @@ -363,6 +363,32 @@ def test_wait_for_responses_shutdown_includes_last_error(self): assert "already closed" in error_message assert "Bad file descriptor" in error_message + def test_wait_for_responses_releases_request_id_when_send_fails(self): + c = self.make_connection() + c._socket_writable = False + initial_in_flight = c.in_flight + initial_request_ids = len(c.request_ids) + + with pytest.raises(ConnectionBusy): + c.wait_for_responses(Mock()) + + assert c.in_flight == initial_in_flight + assert len(c.request_ids) == initial_request_ids + assert not c._requests + + def test_wait_for_responses_releases_request_id_when_send_raises_after_registration(self): + c = self.make_connection() + c.push = Mock(side_effect=ConnectionException("write failed")) + initial_in_flight = c.in_flight + initial_request_ids = len(c.request_ids) + + with pytest.raises(ConnectionException): + c.wait_for_responses(QueryMessage("SELECT * FROM system.local", ConsistencyLevel.ONE)) + + assert c.in_flight == initial_in_flight + assert len(c.request_ids) == initial_request_ids + assert not c._requests + @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') class ConnectionHeartbeatTest(unittest.TestCase): From 856efaa5f8e26195d0f27bb41fbb2c93bdc8507d Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Tue, 5 May 2026 11:31:49 -0400 Subject: [PATCH 2/3] connection: clean up failed async keyspace sends --- cassandra/connection.py | 8 +++++++- tests/unit/test_connection.py | 22 +++++++++++++++++++++- 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 7d5ee0f47b..36d4da275f 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1745,7 +1745,13 @@ def process_result(result): # acquire a new request id request_id = self.get_request_id() - self.send_msg(query, request_id, process_result) + try: + self.send_msg(query, request_id, process_result) + except Exception as exc: + with self.lock: + if request_id not in self._requests and request_id not in self.request_ids: + self.request_ids.append(request_id) + callback(self, exc) @property def is_idle(self): diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 311e47a374..20da6d1875 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -22,7 +22,7 @@ from cassandra.cluster import Cluster from cassandra.connection import (Connection, HEADER_DIRECTION_TO_CLIENT, ProtocolError, locally_supported_compressions, ConnectionHeartbeat, _Frame, Timer, TimerManager, - ConnectionException, ConnectionShutdown, DefaultEndPoint, ShardAwarePortGenerator) + ConnectionException, ConnectionShutdown, ConnectionBusy, DefaultEndPoint, ShardAwarePortGenerator) from cassandra.marshal import uint8_pack, uint32_pack, int32_pack from cassandra.protocol import (write_stringmultimap, write_int, write_string, SupportedMessage, ProtocolHandler, ResultMessage, QueryMessage, @@ -389,6 +389,26 @@ def test_wait_for_responses_releases_request_id_when_send_raises_after_registrat assert len(c.request_ids) == initial_request_ids assert not c._requests + def test_set_keyspace_async_reports_send_failure_and_releases_request_id(self): + c = self.make_connection() + c.push = Mock(side_effect=ConnectionException("write failed")) + initial_in_flight = c.in_flight + initial_request_ids = len(c.request_ids) + callback_errors = [] + + def callback(conn, error): + callback_errors.append(error) + with conn.lock: + conn.in_flight -= 1 + + c.set_keyspace_async("ks", callback) + + assert len(callback_errors) == 1 + assert isinstance(callback_errors[0], ConnectionException) + assert c.in_flight == initial_in_flight + assert len(c.request_ids) == initial_request_ids + assert not c._requests + @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') class ConnectionHeartbeatTest(unittest.TestCase): From 52eed906ce07d6aaa6055076414dc80beff383a8 Mon Sep 17 00:00:00 2001 From: Dmitry Kropachev Date: Thu, 7 May 2026 11:16:10 -0400 Subject: [PATCH 3/3] connection: drop async keyspace send cleanup --- cassandra/connection.py | 8 +------- tests/unit/test_connection.py | 21 --------------------- 2 files changed, 1 insertion(+), 28 deletions(-) diff --git a/cassandra/connection.py b/cassandra/connection.py index 36d4da275f..7d5ee0f47b 100644 --- a/cassandra/connection.py +++ b/cassandra/connection.py @@ -1745,13 +1745,7 @@ def process_result(result): # acquire a new request id request_id = self.get_request_id() - try: - self.send_msg(query, request_id, process_result) - except Exception as exc: - with self.lock: - if request_id not in self._requests and request_id not in self.request_ids: - self.request_ids.append(request_id) - callback(self, exc) + self.send_msg(query, request_id, process_result) @property def is_idle(self): diff --git a/tests/unit/test_connection.py b/tests/unit/test_connection.py index 20da6d1875..5e2d57192e 100644 --- a/tests/unit/test_connection.py +++ b/tests/unit/test_connection.py @@ -389,27 +389,6 @@ def test_wait_for_responses_releases_request_id_when_send_raises_after_registrat assert len(c.request_ids) == initial_request_ids assert not c._requests - def test_set_keyspace_async_reports_send_failure_and_releases_request_id(self): - c = self.make_connection() - c.push = Mock(side_effect=ConnectionException("write failed")) - initial_in_flight = c.in_flight - initial_request_ids = len(c.request_ids) - callback_errors = [] - - def callback(conn, error): - callback_errors.append(error) - with conn.lock: - conn.in_flight -= 1 - - c.set_keyspace_async("ks", callback) - - assert len(callback_errors) == 1 - assert isinstance(callback_errors[0], ConnectionException) - assert c.in_flight == initial_in_flight - assert len(c.request_ids) == initial_request_ids - assert not c._requests - - @patch('cassandra.connection.ConnectionHeartbeat._raise_if_stopped') class ConnectionHeartbeatTest(unittest.TestCase):