diff --git a/cassandra/io/asyncioreactor.py b/cassandra/io/asyncioreactor.py index 66e1d7295c..778717c08b 100644 --- a/cassandra/io/asyncioreactor.py +++ b/cassandra/io/asyncioreactor.py @@ -67,13 +67,46 @@ def finish(self): 'does not implement .finish()') +class _AsyncioProtocol(asyncio.Protocol): + """ + Protocol adapter for asyncio SSL connections. Bridges asyncio's + transport/protocol API back to AsyncioConnection's buffer processing. + """ + + def __init__(self, connection): + self._connection = connection + self.transport = None + + def connection_made(self, transport): + self.transport = transport + + def data_received(self, data): + conn = self._connection + conn._iobuf.write(data) + if conn._iobuf.tell(): + conn.process_io_buffer() + + def connection_lost(self, exc): + conn = self._connection + if exc: + log.debug("Connection %s lost: %s", conn, exc) + conn.defunct(exc) + else: + log.debug("Connection %s closed by server", conn) + conn.close() + + def eof_received(self): + return False + + class AsyncioConnection(Connection): """ - An experimental implementation of :class:`.Connection` that uses the - ``asyncio`` module in the Python standard library for its event loop. + An implementation of :class:`.Connection` that uses the ``asyncio`` + module in the Python standard library for its event loop. - Note that it requires ``asyncio`` features that were only introduced in the - 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. + Supports SSL connections via asyncio's native TLS transport, which + avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's + low-level socket methods (``sock_sendall``, ``sock_recv``). """ _loop = None @@ -88,6 +121,10 @@ class AsyncioConnection(Connection): def __init__(self, *args, **kwargs): Connection.__init__(self, *args, **kwargs) self._background_tasks = set() + self._transport = None + self._protocol = _AsyncioProtocol(self) if self.ssl_context else None + self._using_ssl = bool(self.ssl_context) + self._ssl_ready = asyncio.Event() if self.ssl_context else None self._connect_socket() self._socket.setblocking(0) @@ -99,15 +136,87 @@ def __init__(self, *args, **kwargs): # see initialize_reactor -- loop is running in a separate thread, so we # have to use a threadsafe call - self._read_watcher = asyncio.run_coroutine_threadsafe( - self.handle_read(), loop=self._loop - ) + if self._using_ssl: + # For SSL: set up asyncio transport/protocol, then start writer + self._read_watcher = asyncio.run_coroutine_threadsafe( + self._setup_ssl_and_run(), loop=self._loop + ) + else: + # For non-SSL: use low-level sock_sendall/sock_recv as before + self._read_watcher = asyncio.run_coroutine_threadsafe( + self.handle_read(), loop=self._loop + ) self._write_watcher = asyncio.run_coroutine_threadsafe( self.handle_write(), loop=self._loop ) self._send_options_message() + def _connect_socket(self): + """ + Override base class to skip SSL wrapping of the socket. + For SSL connections, the plain TCP socket is connected here, and TLS + is set up later via asyncio's native SSL transport in _setup_ssl_and_run(). + """ + sockerr = None + addresses = self._get_socket_addresses() + for af, socktype, proto, _, sockaddr in addresses: + try: + self._socket = self._socket_impl.socket(af, socktype, proto) + # Do NOT wrap with ssl_context here -- asyncio will handle TLS + self._socket.settimeout(self.connect_timeout) + self._initiate_connection(sockaddr) + self._socket.settimeout(None) + + local_addr = self._socket.getsockname() + log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr) + sockerr = None + break + except socket.error as err: + if self._socket: + self._socket.close() + self._socket = None + sockerr = err + + if sockerr: + raise socket.error( + sockerr.errno, + "Tried connecting to %s. Last error: %s" + % ([a[4] for a in addresses], sockerr.strerror or sockerr), + ) + async def _setup_ssl_and_run(self): + """ + Upgrade the plain TCP connection to TLS using asyncio's native SSL + transport, then continuously read data via the protocol callbacks. + """ + try: + ssl_context = self.ssl_context + server_hostname = None + if self.ssl_options: + server_hostname = self.ssl_options.get("server_hostname", None) + if not server_hostname: + # asyncio's create_connection requires server_hostname when + # ssl= is set, even if check_hostname is False + server_hostname = self.endpoint.address + + transport, protocol = await self._loop.create_connection( + lambda: self._protocol, + sock=self._socket, + ssl=ssl_context, + server_hostname=server_hostname, + ) + self._transport = transport + + if self._check_hostname: + self._validate_hostname() + + self._ssl_ready.set() + except Exception as exc: + log.debug("SSL setup failed for %s: %s", self, exc) + self.defunct(exc) + # Unblock handle_write so it can observe the defunct state and exit + self._ssl_ready.set() + return @classmethod def initialize_reactor(cls): @@ -152,7 +261,10 @@ async def _close(self): self._write_watcher.cancel() if self._read_watcher: self._read_watcher.cancel() - if self._socket: + if self._transport: + self._transport.close() + self._transport = None + elif self._socket: self._loop.remove_writer(self._socket.fileno()) self._loop.remove_reader(self._socket.fileno()) self._socket.close() @@ -196,11 +308,21 @@ async def _push_msg(self, chunks): async def handle_write(self): + # For SSL connections, wait until the TLS handshake completes + if self._ssl_ready: + await self._ssl_ready.wait() + if self.is_defunct: + return while True: try: next_msg = await self._write_queue.get() if next_msg: - await self._loop.sock_sendall(self._socket, next_msg) + if self._transport: + # SSL: use asyncio transport (handles TLS transparently) + self._transport.write(next_msg) + else: + # Non-SSL: use low-level socket API + await self._loop.sock_sendall(self._socket, next_msg) except socket.error as err: log.debug("Exception in send for %s: %s", self, err) self.defunct(err)