-
Notifications
You must be signed in to change notification settings - Fork 54
asyncio: fix SSL connections by using native TLS transport #884
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -67,13 +67,46 @@ def finish(self): | |
| 'does not implement .finish()') | ||
|
|
||
|
|
||
| class _AsyncioProtocol(asyncio.Protocol): | ||
| """ | ||
| Protocol adapter for asyncio SSL connections. Bridges asyncio's | ||
| transport/protocol API back to AsyncioConnection's buffer processing. | ||
| """ | ||
|
|
||
| def __init__(self, connection): | ||
| self._connection = connection | ||
| self.transport = None | ||
|
|
||
| def connection_made(self, transport): | ||
| self.transport = transport | ||
|
|
||
| def data_received(self, data): | ||
| conn = self._connection | ||
| conn._iobuf.write(data) | ||
| if conn._iobuf.tell(): | ||
| conn.process_io_buffer() | ||
|
|
||
| def connection_lost(self, exc): | ||
| conn = self._connection | ||
| if exc: | ||
| log.debug("Connection %s lost: %s", conn, exc) | ||
| conn.defunct(exc) | ||
| else: | ||
| log.debug("Connection %s closed by server", conn) | ||
| conn.close() | ||
|
|
||
| def eof_received(self): | ||
| return False | ||
|
|
||
|
|
||
| class AsyncioConnection(Connection): | ||
| """ | ||
| An experimental implementation of :class:`.Connection` that uses the | ||
| ``asyncio`` module in the Python standard library for its event loop. | ||
| An implementation of :class:`.Connection` that uses the ``asyncio`` | ||
| module in the Python standard library for its event loop. | ||
|
|
||
| Note that it requires ``asyncio`` features that were only introduced in the | ||
| 3.4 line in 3.4.6, and in the 3.5 line in 3.5.1. | ||
| Supports SSL connections via asyncio's native TLS transport, which | ||
| avoids the incompatibility between ``ssl.SSLSocket`` and asyncio's | ||
| low-level socket methods (``sock_sendall``, ``sock_recv``). | ||
| """ | ||
|
|
||
| _loop = None | ||
|
|
@@ -88,6 +121,10 @@ class AsyncioConnection(Connection): | |
| def __init__(self, *args, **kwargs): | ||
| Connection.__init__(self, *args, **kwargs) | ||
| self._background_tasks = set() | ||
| self._transport = None | ||
| self._protocol = _AsyncioProtocol(self) if self.ssl_context else None | ||
| self._using_ssl = bool(self.ssl_context) | ||
| self._ssl_ready = asyncio.Event() if self.ssl_context else None | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Small compatibility point: the repo still supports Python 3.9, and the existing code passes loop=self._loop to Queue and Lock for that runtime. Should this new Event follow the same pattern so it is bound to the reactor loop rather than the caller thread's loop? |
||
|
|
||
| self._connect_socket() | ||
| self._socket.setblocking(0) | ||
|
|
@@ -99,15 +136,87 @@ def __init__(self, *args, **kwargs): | |
|
|
||
| # see initialize_reactor -- loop is running in a separate thread, so we | ||
| # have to use a threadsafe call | ||
| self._read_watcher = asyncio.run_coroutine_threadsafe( | ||
| self.handle_read(), loop=self._loop | ||
| ) | ||
| if self._using_ssl: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Would it make sense to use the transport/protocol path for both TLS and non-TLS connections here? With the current split, asyncio has two different I/O models in the same reactor: TLS uses asyncio.Protocol/transport, while plain CQL stays on sock_recv()/sock_sendall(). I think unifying them in this PR may be cleaner because the work needed for TLS is mostly the same work needed for a correct transport-based reactor: setup, read callbacks, write flow control, close handling, and tests. If those pieces are implemented only for TLS now, we end up with two connection state machines to maintain and a higher chance of future fixes applying to one path but not the other. A shared create_connection(sock=..., ssl=self.ssl_context or None) path could make lifecycle, reads, writes, and flow control consistent across both modes. |
||
| # For SSL: set up asyncio transport/protocol, then start writer | ||
| self._read_watcher = asyncio.run_coroutine_threadsafe( | ||
| self._setup_ssl_and_run(), loop=self._loop | ||
| ) | ||
| else: | ||
| # For non-SSL: use low-level sock_sendall/sock_recv as before | ||
| self._read_watcher = asyncio.run_coroutine_threadsafe( | ||
| self.handle_read(), loop=self._loop | ||
| ) | ||
| self._write_watcher = asyncio.run_coroutine_threadsafe( | ||
| self.handle_write(), loop=self._loop | ||
| ) | ||
| self._send_options_message() | ||
|
|
||
| def _connect_socket(self): | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This override looks like it accidentally skips the base class sockopts handling. It would be good to preserve Cluster(sockopts=...) for asyncio connections: if self.sockopts:
for args in self.sockopts:
self._socket.setsockopt(*args) |
||
| """ | ||
| Override base class to skip SSL wrapping of the socket. | ||
| For SSL connections, the plain TCP socket is connected here, and TLS | ||
| is set up later via asyncio's native SSL transport in _setup_ssl_and_run(). | ||
| """ | ||
| sockerr = None | ||
| addresses = self._get_socket_addresses() | ||
| for af, socktype, proto, _, sockaddr in addresses: | ||
| try: | ||
| self._socket = self._socket_impl.socket(af, socktype, proto) | ||
| # Do NOT wrap with ssl_context here -- asyncio will handle TLS | ||
| self._socket.settimeout(self.connect_timeout) | ||
| self._initiate_connection(sockaddr) | ||
| self._socket.settimeout(None) | ||
|
|
||
| local_addr = self._socket.getsockname() | ||
| log.debug("Connection %s: '%s' -> '%s'", id(self), local_addr, sockaddr) | ||
| sockerr = None | ||
| break | ||
| except socket.error as err: | ||
| if self._socket: | ||
| self._socket.close() | ||
| self._socket = None | ||
| sockerr = err | ||
|
|
||
| if sockerr: | ||
| raise socket.error( | ||
| sockerr.errno, | ||
| "Tried connecting to %s. Last error: %s" | ||
| % ([a[4] for a in addresses], sockerr.strerror or sockerr), | ||
| ) | ||
|
|
||
| async def _setup_ssl_and_run(self): | ||
| """ | ||
| Upgrade the plain TCP connection to TLS using asyncio's native SSL | ||
| transport, then continuously read data via the protocol callbacks. | ||
| """ | ||
| try: | ||
| ssl_context = self.ssl_context | ||
| server_hostname = None | ||
| if self.ssl_options: | ||
| server_hostname = self.ssl_options.get("server_hostname", None) | ||
| if not server_hostname: | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This may change behavior for an explicit ssl_options={"server_hostname": ""}. For asyncio create_connection(sock=..., ssl=...), server_hostname="" is the documented way to bypass hostname/SNI handling when no host argument is provided. Could we distinguish a missing value from an explicitly empty one, for example with is None or a sentinel? |
||
| # asyncio's create_connection requires server_hostname when | ||
| # ssl= is set, even if check_hostname is False | ||
| server_hostname = self.endpoint.address | ||
|
|
||
| transport, protocol = await self._loop.create_connection( | ||
| lambda: self._protocol, | ||
| sock=self._socket, | ||
| ssl=ssl_context, | ||
| server_hostname=server_hostname, | ||
| ) | ||
| self._transport = transport | ||
|
|
||
| if self._check_hostname: | ||
| self._validate_hostname() | ||
|
|
||
| self._ssl_ready.set() | ||
| except Exception as exc: | ||
| log.debug("SSL setup failed for %s: %s", self, exc) | ||
| self.defunct(exc) | ||
| # Unblock handle_write so it can observe the defunct state and exit | ||
| self._ssl_ready.set() | ||
| return | ||
|
|
||
| @classmethod | ||
| def initialize_reactor(cls): | ||
|
|
@@ -152,7 +261,10 @@ async def _close(self): | |
| self._write_watcher.cancel() | ||
| if self._read_watcher: | ||
| self._read_watcher.cancel() | ||
| if self._socket: | ||
| if self._transport: | ||
| self._transport.close() | ||
| self._transport = None | ||
| elif self._socket: | ||
| self._loop.remove_writer(self._socket.fileno()) | ||
| self._loop.remove_reader(self._socket.fileno()) | ||
| self._socket.close() | ||
|
|
@@ -196,11 +308,21 @@ async def _push_msg(self, chunks): | |
|
|
||
|
|
||
| async def handle_write(self): | ||
| # For SSL connections, wait until the TLS handshake completes | ||
| if self._ssl_ready: | ||
| await self._ssl_ready.wait() | ||
| if self.is_defunct: | ||
| return | ||
| while True: | ||
| try: | ||
| next_msg = await self._write_queue.get() | ||
| if next_msg: | ||
| await self._loop.sock_sendall(self._socket, next_msg) | ||
| if self._transport: | ||
| # SSL: use asyncio transport (handles TLS transparently) | ||
| self._transport.write(next_msg) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could we account for transport backpressure here? transport.write() queues into asyncio's transport buffer and returns immediately, unlike await sock_sendall(). Under sustained writes, this loop can drain _write_queue faster than the socket can actually send. One possible shape is to have the protocol expose a write-ready event: def pause_writing(self):
self.write_ready.clear()
def resume_writing(self):
self.write_ready.set()and then wait before writing: await self._protocol.write_ready.wait()
self._transport.write(next_msg) |
||
| else: | ||
| # Non-SSL: use low-level socket API | ||
| await self._loop.sock_sendall(self._socket, next_msg) | ||
| except socket.error as err: | ||
| log.debug("Exception in send for %s: %s", self, err) | ||
| self.defunct(err) | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since this protocol owns the transport callbacks, I think it would be useful for it to also own flow-control state. Implementing pause_writing() / resume_writing() here would let the writer respect asyncio's high/low watermarks. It may also be worth unblocking any paused writer from connection_lost() so shutdown cannot leave the write coroutine waiting indefinitely.