Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 131 additions & 9 deletions cassandra/io/asyncioreactor.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,46 @@ def finish(self):
'does not implement .finish()')


class _AsyncioProtocol(asyncio.Protocol):
    """
    Glue between asyncio's transport/protocol callbacks and the buffer
    processing of the owning AsyncioConnection (used for SSL connections).

    NOTE(review): pause_writing()/resume_writing() are not overridden here,
    so the transport's write-buffer watermarks are not relayed to the
    connection's writer -- confirm whether flow control is needed.
    """

    def __init__(self, connection):
        # ``transport`` stays None until asyncio reports connection_made().
        self._connection, self.transport = connection, None

    def connection_made(self, transport):
        """Remember the transport asyncio attached to this protocol."""
        self.transport = transport

    def data_received(self, data):
        """Append *data* to the connection's input buffer, then process it."""
        owner = self._connection
        iobuf = owner._iobuf
        iobuf.write(data)
        if not iobuf.tell():
            # Nothing buffered, so there is nothing to parse.
            return
        owner.process_io_buffer()

    def connection_lost(self, exc):
        """Close the connection on a clean shutdown, defunct it on error."""
        owner = self._connection
        if not exc:
            log.debug("Connection %s closed by server", owner)
            owner.close()
            return
        log.debug("Connection %s lost: %s", owner, exc)
        owner.defunct(exc)

    def eof_received(self):
        """Return False; per asyncio, the transport then closes itself."""
        return False


class AsyncioConnection(Connection):
"""
An experimental implementation of :class:`.Connection` that uses the
``asyncio`` module in the Python standard library for its event loop.
An implementation of :class:`.Connection` that uses the ``asyncio``
module in the Python standard library for its event loop.

Note that it requires ``asyncio`` features that were only introduced in the
3.4 line in 3.4.6, and in the 3.5 line in 3.5.1.
Supports SSL connections via asyncio's native TLS transport, which
avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's
low-level socket methods (``sock_sendall``, ``sock_recv``).
"""

_loop = None
Expand All @@ -88,6 +121,10 @@ class AsyncioConnection(Connection):
def __init__(self, *args, **kwargs):
Connection.__init__(self, *args, **kwargs)
self._background_tasks = set()
self._transport = None
self._protocol = _AsyncioProtocol(self) if self.ssl_context else None
self._using_ssl = bool(self.ssl_context)
self._ssl_ready = asyncio.Event() if self.ssl_context else None
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Small compatibility point: the repo still supports Python 3.9, and the existing code passes loop=self._loop to Queue and Lock for that runtime. Should this new Event follow the same pattern so it is bound to the reactor loop rather than the caller thread's loop?


self._connect_socket()
self._socket.setblocking(0)
Expand All @@ -99,15 +136,87 @@ def __init__(self, *args, **kwargs):

# see initialize_reactor -- loop is running in a separate thread, so we
# have to use a threadsafe call
self._read_watcher = asyncio.run_coroutine_threadsafe(
self.handle_read(), loop=self._loop
)
if self._using_ssl:
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Would it make sense to use the transport/protocol path for both TLS and non-TLS connections here? With the current split, asyncio has two different I/O models in the same reactor: TLS uses asyncio.Protocol/transport, while plain CQL stays on sock_recv()/sock_sendall().

I think unifying them in this PR may be cleaner because the work needed for TLS is mostly the same work needed for a correct transport-based reactor: setup, read callbacks, write flow control, close handling, and tests. If those pieces are implemented only for TLS now, we end up with two connection state machines to maintain and a higher chance of future fixes applying to one path but not the other.

A shared create_connection(sock=..., ssl=self.ssl_context or None) path could make lifecycle, reads, writes, and flow control consistent across both modes.

# For SSL: set up asyncio transport/protocol, then start writer
self._read_watcher = asyncio.run_coroutine_threadsafe(
self._setup_ssl_and_run(), loop=self._loop
)
else:
# For non-SSL: use low-level sock_sendall/sock_recv as before
self._read_watcher = asyncio.run_coroutine_threadsafe(
self.handle_read(), loop=self._loop
)
self._write_watcher = asyncio.run_coroutine_threadsafe(
self.handle_write(), loop=self._loop
)
self._send_options_message()

def _connect_socket(self):
    """
    Connect a plain TCP socket, deliberately without TLS wrapping.

    Overrides the base class so the socket is never wrapped with
    ``ssl_context`` here. For SSL connections the plain TCP socket is
    connected in this method and TLS is set up later via asyncio's native
    SSL transport in ``_setup_ssl_and_run()``.

    Every address returned by ``_get_socket_addresses()`` is tried in
    order until one connects. Once connected, any configured ``sockopts``
    are applied to the socket so that ``Cluster(sockopts=...)`` keeps
    working with this reactor.

    :raises socket.error: if the last attempted address failed; the
        message lists every address that was tried.
    """
    sockerr = None
    addresses = self._get_socket_addresses()
    for af, socktype, proto, _, sockaddr in addresses:
        try:
            self._socket = self._socket_impl.socket(af, socktype, proto)
            # Do NOT wrap with ssl_context here -- asyncio will handle TLS
            self._socket.settimeout(self.connect_timeout)
            self._initiate_connection(sockaddr)
            self._socket.settimeout(None)

            local_addr = self._socket.getsockname()
            log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr)
            sockerr = None
            break
        except socket.error as err:
            # Release the half-open socket before trying the next address.
            if self._socket:
                self._socket.close()
                self._socket = None
            sockerr = err

    if sockerr:
        raise socket.error(
            sockerr.errno,
            "Tried connecting to %s. Last error: %s"
            % ([a[4] for a in addresses], sockerr.strerror or sockerr),
        )

    # This override replaces the base implementation entirely, so the
    # user-supplied socket options have to be applied here as well.
    if self.sockopts:
        for args in self.sockopts:
            self._socket.setsockopt(*args)

async def _setup_ssl_and_run(self):
    """
    Upgrade the already-connected plain TCP socket to TLS using asyncio's
    native SSL transport; incoming data is then delivered through the
    protocol callbacks.

    ``server_hostname`` is taken from ``ssl_options["server_hostname"]``
    when that key is present. An explicit empty string is passed through
    unchanged, because for ``create_connection(sock=..., ssl=...)`` that
    is the documented way to disable hostname matching. Only a missing
    value falls back to the endpoint address.

    On success the transport is stored on ``self._transport``. On any
    failure the connection is marked defunct. In both cases
    ``self._ssl_ready`` is set so a writer waiting on it can proceed.
    """
    try:
        ssl_context = self.ssl_context
        server_hostname = None
        if self.ssl_options:
            server_hostname = self.ssl_options.get("server_hostname")
        if server_hostname is None:
            # create_connection() is given sock= and no host, so asyncio
            # has no default and needs an explicit server_hostname.
            server_hostname = self.endpoint.address

        transport, _protocol = await self._loop.create_connection(
            lambda: self._protocol,
            sock=self._socket,
            ssl=ssl_context,
            server_hostname=server_hostname,
        )
        self._transport = transport

        if self._check_hostname:
            self._validate_hostname()

        self._ssl_ready.set()
    except Exception as exc:
        log.debug("SSL setup failed for %s: %s", self, exc)
        self.defunct(exc)
        # Unblock handle_write so it can observe the defunct state and exit
        self._ssl_ready.set()

@classmethod
def initialize_reactor(cls):
Expand Down Expand Up @@ -152,7 +261,10 @@ async def _close(self):
self._write_watcher.cancel()
if self._read_watcher:
self._read_watcher.cancel()
if self._socket:
if self._transport:
self._transport.close()
self._transport = None
elif self._socket:
self._loop.remove_writer(self._socket.fileno())
self._loop.remove_reader(self._socket.fileno())
self._socket.close()
Expand Down Expand Up @@ -196,11 +308,21 @@ async def _push_msg(self, chunks):


async def handle_write(self):
# For SSL connections, wait until the TLS handshake completes
if self._ssl_ready:
await self._ssl_ready.wait()
if self.is_defunct:
return
while True:
try:
next_msg = await self._write_queue.get()
if next_msg:
await self._loop.sock_sendall(self._socket, next_msg)
if self._transport:
# SSL: use asyncio transport (handles TLS transparently)
self._transport.write(next_msg)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could we account for transport backpressure here? transport.write() queues into asyncio's transport buffer and returns immediately, unlike await sock_sendall(). Under sustained writes, this loop can drain _write_queue faster than the socket can actually send.

One possible shape is to have the protocol expose a write-ready event:

def pause_writing(self):
    self.write_ready.clear()

def resume_writing(self):
    self.write_ready.set()

and then wait before writing:

await self._protocol.write_ready.wait()
self._transport.write(next_msg)

else:
# Non-SSL: use low-level socket API
await self._loop.sock_sendall(self._socket, next_msg)
except socket.error as err:
log.debug("Exception in send for %s: %s", self, err)
self.defunct(err)
Expand Down
Loading