From 6d64b011d24d0807781c8faa5ba9ea7c40d38db7 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Fri, 6 Mar 2026 10:35:37 -0500 Subject: [PATCH 1/3] feat: Add AsyncSSEClient with aiohttp-based async/await support Adds AsyncSSEClient as a purely additive new public API alongside the existing SSEClient. Async users install with the [async] extra to get aiohttp; sync users have no new dependencies. All existing tests pass unchanged. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- .github/workflows/ci.yml | 6 + CONTRIBUTING.md | 11 + Makefile | 19 ++ README.md | 8 +- contract-tests/async_service.py | 120 +++++++++ contract-tests/async_stream_entity.py | 115 +++++++++ docs/conf.py | 3 + docs/index.rst | 19 +- ld_eventsource/__init__.py | 9 + ld_eventsource/async_client.py | 237 ++++++++++++++++++ ld_eventsource/async_http.py | 120 +++++++++ ld_eventsource/async_reader.py | 101 ++++++++ .../config/async_connect_strategy.py | 148 +++++++++++ ld_eventsource/testing/async_helpers.py | 96 +++++++ .../test_async_http_connect_strategy.py | 164 ++++++++++++ ld_eventsource/testing/test_async_reader.py | 152 +++++++++++ .../testing/test_async_sse_client_basic.py | 149 +++++++++++ .../testing/test_async_sse_client_retry.py | 201 +++++++++++++++ pyproject.toml | 10 +- 19 files changed, 1684 insertions(+), 4 deletions(-) create mode 100644 contract-tests/async_service.py create mode 100644 contract-tests/async_stream_entity.py create mode 100644 ld_eventsource/async_client.py create mode 100644 ld_eventsource/async_http.py create mode 100644 ld_eventsource/async_reader.py create mode 100644 ld_eventsource/config/async_connect_strategy.py create mode 100644 ld_eventsource/testing/async_helpers.py create mode 100644 ld_eventsource/testing/test_async_http_connect_strategy.py create mode 100644 ld_eventsource/testing/test_async_reader.py create mode 100644 ld_eventsource/testing/test_async_sse_client_basic.py create mode 100644 ld_eventsource/testing/test_async_sse_client_retry.py diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 10f593c..f11a75c 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,6 +46,12 @@ jobs: - name: run SSE contract tests run: make run-contract-tests + - name: start async SSE contract test service + run: make start-async-contract-test-service-bg + + - name: run async SSE contract tests + run: make run-async-contract-tests + windows: runs-on: windows-latest diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index adcf964..89e2ba8 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -23,6 +23,12 @@ poetry install eval $(poetry env activate) ``` +To also install the optional async dependencies (required to use `AsyncSSEClient`): + +``` +poetry install --extras async +``` + ### Testing To run all unit tests: @@ -36,6 +42,11 @@ To run the standardized contract tests that are run against all LaunchDarkly SSE make contract-tests ``` +To run the same contract tests against the async implementation: +``` +make async-contract-tests +``` + ### Linting To run the linter and check type hints: diff --git a/Makefile b/Makefile index 716ade2..7b2f851 100644 --- a/Makefile +++ b/Makefile @@ -1,6 +1,7 @@ PYTEST_FLAGS=-W error::SyntaxWarning TEMP_TEST_OUTPUT=/tmp/sse-contract-test-service.log +TEMP_ASYNC_TEST_OUTPUT=/tmp/sse-async-contract-test-service.log SPHINXOPTS = -W --keep-going SPHINXBUILD = sphinx-build @@ -70,3 +71,21 @@ run-contract-tests: .PHONY: contract-tests contract-tests: #! Run the SSE contract test harness contract-tests: install-contract-tests-deps start-contract-test-service-bg run-contract-tests + +.PHONY: start-async-contract-test-service +start-async-contract-test-service: + @cd contract-tests && poetry run python async_service.py 8001 + +.PHONY: start-async-contract-test-service-bg +start-async-contract-test-service-bg: + @echo "Async test service output will be captured in $(TEMP_ASYNC_TEST_OUTPUT)" + @make start-async-contract-test-service >$(TEMP_ASYNC_TEST_OUTPUT) 2>&1 & + +.PHONY: run-async-contract-tests +run-async-contract-tests: + @curl -s https://raw.githubusercontent.com/launchdarkly/sse-contract-tests/main/downloader/run.sh \ + | VERSION=v2 PARAMS="-url http://localhost:8001 -debug -stop-service-at-end" sh + +.PHONY: async-contract-tests +async-contract-tests: #! Run the SSE async contract test harness +async-contract-tests: install-contract-tests-deps start-async-contract-test-service-bg run-async-contract-tests diff --git a/README.md b/README.md index 8b2bd5e..ef67bbb 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,13 @@ This package's primary purpose is to support the [LaunchDarkly SDK for Python](h * Setting read timeouts, custom headers, and other HTTP request properties. * Specifying that connections should be retried under circumstances where the standard EventSource behavior would not retry them, such as if the server returns an HTTP error status. -This is a synchronous implementation which blocks the caller's thread when reading events or reconnecting. By default, it uses `urllib3` to make HTTP requests, but it can be configured to read any input stream. +The default `SSEClient` is a synchronous implementation which blocks the caller's thread when reading events or reconnecting. By default, it uses `urllib3` to make HTTP requests, but it can be configured to read any input stream. + +An async implementation, `AsyncSSEClient`, is also available for use with `asyncio`-based applications. It uses `aiohttp` for HTTP and requires installing the optional `async` extra: + +``` +pip install launchdarkly-eventsource[async] +``` ## Supported Python versions diff --git a/contract-tests/async_service.py b/contract-tests/async_service.py new file mode 100644 index 0000000..8d86be1 --- /dev/null +++ b/contract-tests/async_service.py @@ -0,0 +1,120 @@ +import json +import logging +import os +import sys +from logging.config import dictConfig + +import aiohttp.web + +from async_stream_entity import AsyncStreamEntity + +default_port = 8000 + +dictConfig({ + 'version': 1, + 'formatters': { + 'default': { + 'format': '[%(asctime)s] [%(name)s] %(levelname)s: %(message)s', + } + }, + 'handlers': { + 'console': { + 'class': 'logging.StreamHandler', + 'formatter': 'default' + } + }, + 'root': { + 'level': 'INFO', + 'handlers': ['console'] + }, +}) + +global_log = logging.getLogger('testservice') + +stream_counter = 0 +streams = {} + + +async def handle_get_status(request): + body = { + 'capabilities': [ + 'comments', + 'headers', + 'last-event-id', + 'read-timeout', + ] + } + return aiohttp.web.Response( + body=json.dumps(body), + content_type='application/json', + ) + + +async def handle_delete_stop(request): + global_log.info("Test service has told us to exit") + os._exit(0) + + +async def handle_post_create_stream(request): + global stream_counter, streams + + options = json.loads(await request.read()) + + stream_counter += 1 + stream_id = str(stream_counter) + resource_url = '/streams/%s' % stream_id + + stream = AsyncStreamEntity(options, request.app['http_session']) + streams[stream_id] = stream + + return aiohttp.web.Response(status=201, headers={'Location': resource_url}) + + +async def handle_post_stream_command(request): + stream_id = request.match_info['id'] + params = json.loads(await request.read()) + + stream = streams.get(stream_id) + if stream is None: + return aiohttp.web.Response(status=404) + if not await stream.do_command(params.get('command')): + return aiohttp.web.Response(status=400) + return aiohttp.web.Response(status=204) + + +async def handle_delete_stream(request): + stream_id = request.match_info['id'] + + stream = streams.get(stream_id) + if stream is None: + return aiohttp.web.Response(status=404) + await stream.close() + return aiohttp.web.Response(status=204) + + +async def on_startup(app): + app['http_session'] = aiohttp.ClientSession() + + +async def on_cleanup(app): + await app['http_session'].close() + + +def make_app(): + app = aiohttp.web.Application() + app.router.add_get('/', handle_get_status) + app.router.add_delete('/', handle_delete_stop) + app.router.add_post('/', handle_post_create_stream) + app.router.add_post('/streams/{id}', handle_post_stream_command) + app.router.add_delete('/streams/{id}', handle_delete_stream) + app.on_startup.append(on_startup) + app.on_cleanup.append(on_cleanup) + return app + + +if __name__ == "__main__": + port = default_port + if sys.argv[len(sys.argv) - 1] != 'async_service.py': + port = int(sys.argv[len(sys.argv) - 1]) + global_log.info('Listening on port %d', port) + aiohttp.web.run_app(make_app(), host='0.0.0.0', port=port) diff --git a/contract-tests/async_stream_entity.py b/contract-tests/async_stream_entity.py new file mode 100644 index 0000000..ce9628e --- /dev/null +++ b/contract-tests/async_stream_entity.py @@ -0,0 +1,115 @@ +import asyncio +import json +import logging +import os +import sys +import traceback + +import aiohttp + +# Import ld_eventsource from parent directory +sys.path.insert(1, os.path.join(sys.path[0], '..')) +from ld_eventsource.actions import Comment, Event, Fault # noqa: E402 +from ld_eventsource.async_client import AsyncSSEClient # noqa: E402 +from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy # noqa: E402 +from ld_eventsource.config.error_strategy import ErrorStrategy # noqa: E402 + + +def millis_to_seconds(t): + return None if t is None else t / 1000 + + +class AsyncStreamEntity: + def __init__(self, options, http_session: aiohttp.ClientSession): + self.options = options + self.callback_url = options["callbackUrl"] + self.log = logging.getLogger(options["tag"]) + self.closed = False + self.callback_counter = 0 + self.sse = None + self._http_session = http_session + asyncio.create_task(self.run()) + + async def run(self): + stream_url = self.options["streamUrl"] + try: + self.log.info('Opening stream from %s', stream_url) + + request_options = {} + if self.options.get("readTimeoutMs") is not None: + request_options["timeout"] = aiohttp.ClientTimeout( + sock_read=millis_to_seconds(self.options.get("readTimeoutMs")) + ) + + connect = AsyncConnectStrategy.http( + url=stream_url, + headers=self.options.get("headers"), + aiohttp_request_options=request_options if request_options else None, + ) + sse = AsyncSSEClient( + connect, + initial_retry_delay=millis_to_seconds(self.options.get("initialDelayMs")), + last_event_id=self.options.get("lastEventId"), + error_strategy=ErrorStrategy.from_lambda( + lambda _: ( + ErrorStrategy.FAIL if self.closed else ErrorStrategy.CONTINUE, + None, + ) + ), + logger=self.log, + ) + self.sse = sse + async for item in sse.all: + if isinstance(item, Event): + self.log.info('Received event from stream (%s)', item.event) + await self.send_message( + { + 'kind': 'event', + 'event': { + 'type': item.event, + 'data': item.data, + 'id': item.last_event_id, + }, + } + ) + elif isinstance(item, Comment): + self.log.info('Received comment from stream: %s', item.comment) + await self.send_message({'kind': 'comment', 'comment': item.comment}) + elif isinstance(item, Fault): + if self.closed: + break + if item.error: + self.log.info('Received error from stream: %s', item.error) + await self.send_message({'kind': 'error', 'error': str(item.error)}) + except Exception as e: + self.log.info('Received error from stream: %s', e) + self.log.info(traceback.format_exc()) + await self.send_message({'kind': 'error', 'error': str(e)}) + + async def do_command(self, command: str) -> bool: + self.log.info('Test service sent command: %s' % command) + # currently we support no special commands + return False + + async def send_message(self, message): + if self.closed: + return + self.callback_counter += 1 + callback_url = "%s/%d" % (self.callback_url, self.callback_counter) + try: + async with self._http_session.post( + callback_url, + data=json.dumps(message), + headers={'Content-Type': 'application/json'}, + ) as resp: + if resp.status >= 300 and not self.closed: + self.log.error('Callback request returned HTTP error %d', resp.status) + except Exception as e: + if not self.closed: + self.log.error('Callback request failed: %s', e) + + async def close(self): + self.closed = True + if self.sse is not None: + await self.sse.close() + self.log.info('Test ended') diff --git a/docs/conf.py b/docs/conf.py index c3f71d2..186a344 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -170,3 +170,6 @@ autodoc_default_options = { 'undoc-members': False } + +# aiohttp is an optional dependency not installed during doc builds +autodoc_mock_imports = ['aiohttp'] diff --git a/docs/index.rst b/docs/index.rst index 06dd66a..fa33901 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -3,7 +3,6 @@ LaunchDarkly Python SSE Client This is the API reference for the `launchdarkly-eventsource `_ package, a `Server-Sent Events `_ client for Python. This package is used internally by the `LaunchDarkly Python SDK `_, but may also be useful for other purposes. - ld_eventsource module --------------------- @@ -37,3 +36,21 @@ ld_eventsource.errors module :members: :special-members: __init__ :show-inheritance: + + +ld_eventsource.async_client module +----------------------------------- + +.. automodule:: ld_eventsource.async_client + :members: + :special-members: __init__ + :show-inheritance: + + +ld_eventsource.config.async_connect_strategy module +---------------------------------------------------- + +.. automodule:: ld_eventsource.config.async_connect_strategy + :members: + :special-members: __init__ + :show-inheritance: diff --git a/ld_eventsource/__init__.py b/ld_eventsource/__init__.py index 8e88c11..c4a208d 100644 --- a/ld_eventsource/__init__.py +++ b/ld_eventsource/__init__.py @@ -1 +1,10 @@ from ld_eventsource.sse_client import * + + +def __getattr__(name): + # Lazily import AsyncSSEClient so that aiohttp (an optional dependency) + # is never imported for sync-only users who don't have it installed. + if name == 'AsyncSSEClient': + from ld_eventsource.async_client import AsyncSSEClient + return AsyncSSEClient + raise AttributeError(f"module 'ld_eventsource' has no attribute {name!r}") diff --git a/ld_eventsource/async_client.py b/ld_eventsource/async_client.py new file mode 100644 index 0000000..b97d1bd --- /dev/null +++ b/ld_eventsource/async_client.py @@ -0,0 +1,237 @@ +import asyncio +import logging +import time +from typing import AsyncIterable, Optional, Union + +from ld_eventsource.actions import Action, Event, Fault, Start +from ld_eventsource.async_reader import _AsyncBufferedLineReader, _AsyncSSEReader +from ld_eventsource.config.async_connect_strategy import ( + AsyncConnectStrategy, AsyncConnectionClient, AsyncConnectionResult) +from ld_eventsource.config.error_strategy import ErrorStrategy +from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy + + +class AsyncSSEClient: + """ + An async client for reading a Server-Sent Events stream. + + This is an async/await implementation. The expected usage is to create an ``AsyncSSEClient`` + instance (either as an async context manager or directly), then read from it using the async + iterator properties :attr:`events` or :attr:`all`. + + By default, ``AsyncSSEClient`` uses ``aiohttp`` to make HTTP requests to an SSE endpoint. + You can customize this behavior using :class:`.AsyncConnectStrategy`. + + Connection failures and error responses can be handled in various ways depending on the + constructor parameters. The default behavior is the same as :class:`.SSEClient`. + + Example:: + + async with AsyncSSEClient("https://my-server/events") as client: + async for event in client.events: + print(event.data) + """ + + def __init__( + self, + connect: Union[str, AsyncConnectStrategy], + initial_retry_delay: float = 1, + retry_delay_strategy: Optional[RetryDelayStrategy] = None, + retry_delay_reset_threshold: float = 60, + error_strategy: Optional[ErrorStrategy] = None, + last_event_id: Optional[str] = None, + logger: Optional[logging.Logger] = None, + ): + """ + Creates an async client instance. + + :param connect: either an :class:`.AsyncConnectStrategy` instance or a URL string + :param initial_retry_delay: the initial delay before reconnecting after a failure, in seconds + :param retry_delay_strategy: allows customization of the delay behavior for retries + :param retry_delay_reset_threshold: minimum connection time before resetting retry delay + :param error_strategy: allows customization of the behavior after a stream failure + :param last_event_id: if provided, the ``Last-Event-Id`` value will be preset to this + :param logger: if provided, log messages will be written here + """ + if isinstance(connect, str): + connect = AsyncConnectStrategy.http(connect) + elif not isinstance(connect, AsyncConnectStrategy): + raise TypeError("connect must be either a string or AsyncConnectStrategy") + + self.__base_retry_delay = initial_retry_delay + self.__base_retry_delay_strategy = ( + retry_delay_strategy or RetryDelayStrategy.default() + ) + self.__retry_delay_reset_threshold = retry_delay_reset_threshold + self.__current_retry_delay_strategy = self.__base_retry_delay_strategy + self.__next_retry_delay = 0 + + self.__base_error_strategy = error_strategy or ErrorStrategy.always_fail() + self.__current_error_strategy = self.__base_error_strategy + + self.__last_event_id = last_event_id + + if logger is None: + logger = logging.getLogger('launchdarkly-eventsource-async.null') + logger.addHandler(logging.NullHandler()) + logger.propagate = False + self.__logger = logger + + self.__connection_client: AsyncConnectionClient = connect.create_client(logger) + self.__connection_result: Optional[AsyncConnectionResult] = None + self._retry_reset_baseline: float = 0 + self.__disconnected_time: float = 0 + + self.__closed = False + self.__interrupted = False + + async def start(self): + """ + Attempts to start the stream if it is not already active. + """ + await self._try_start(False) + + async def close(self): + """ + Permanently shuts down this client instance and closes any active connection. + """ + self.__closed = True + await self.interrupt() + await self.__connection_client.close() + + async def interrupt(self): + """ + Stops the stream connection if it is currently active, without permanently closing. + """ + if self.__connection_result: + self.__interrupted = True + await self.__connection_result.close() + self.__connection_result = None + self._compute_next_retry_delay() + + @property + def all(self) -> AsyncIterable[Action]: + """ + An async iterable series of notifications from the stream. + + Each can be any subclass of :class:`.Action`: :class:`.Event`, :class:`.Comment`, + :class:`.Start`, or :class:`.Fault`. + """ + return self._all_generator() + + @property + def events(self) -> AsyncIterable[Event]: + """ + An async iterable series of :class:`.Event` objects received from the stream. + """ + return self._events_generator() + + async def _all_generator(self): + while True: + while self.__connection_result is None: + result = await self._try_start(True) + if result is not None: + yield result + + lines = _AsyncBufferedLineReader.lines_from(self.__connection_result.stream) + reader = _AsyncSSEReader(lines, self.__last_event_id, None) + error: Optional[Exception] = None + try: + async for ec in reader.events_and_comments(): + self.__last_event_id = reader.last_event_id + yield ec + if self.__interrupted: + break + self.__connection_result = None + except Exception as e: + if self.__closed: + return + error = e + self.__connection_result = None + finally: + self.__last_event_id = reader.last_event_id + + self._compute_next_retry_delay() + fail_or_continue, self.__current_error_strategy = ( + self.__current_error_strategy.apply(error) + ) + if fail_or_continue == ErrorStrategy.FAIL: + if error is None: + yield Fault(None) + return + raise error + yield Fault(error) + continue + + async def _events_generator(self): + async for item in self._all_generator(): + if isinstance(item, Event): + yield item + + @property + def next_retry_delay(self) -> float: + """ + The retry delay that will be used for the next reconnection, in seconds. + """ + return self.__next_retry_delay + + def _compute_next_retry_delay(self): + if self.__retry_delay_reset_threshold > 0 and self._retry_reset_baseline != 0: + now = time.time() + connection_duration = now - self._retry_reset_baseline + if connection_duration >= self.__retry_delay_reset_threshold: + self.__current_retry_delay_strategy = self.__base_retry_delay_strategy + self._retry_reset_baseline = now + self.__next_retry_delay, self.__current_retry_delay_strategy = ( + self.__current_retry_delay_strategy.apply(self.__base_retry_delay) + ) + + async def _try_start(self, can_return_fault: bool): + if self.__connection_result is not None: + return None + while True: + if self.__next_retry_delay > 0: + delay = ( + self.__next_retry_delay + if self.__disconnected_time == 0 + else self.__next_retry_delay + - (time.time() - self.__disconnected_time) + ) + if delay > 0: + self.__logger.info("Will reconnect after delay of %fs" % delay) + await asyncio.sleep(delay) + try: + self.__connection_result = await self.__connection_client.connect( + self.__last_event_id + ) + except Exception as e: + self.__disconnected_time = time.time() + self._compute_next_retry_delay() + fail_or_continue, self.__current_error_strategy = ( + self.__current_error_strategy.apply(e) + ) + if fail_or_continue == ErrorStrategy.FAIL: + raise e + if can_return_fault: + return Fault(e) + continue + self._retry_reset_baseline = time.time() + self.__current_error_strategy = self.__base_error_strategy + self.__interrupted = False + return Start(self.__connection_result.headers) + + @property + def last_event_id(self) -> Optional[str]: + """ + The ID value, if any, of the last known event. + """ + return self.__last_event_id + + async def __aenter__(self): + return self + + async def __aexit__(self, type, value, traceback): + await self.close() + + +__all__ = ['AsyncSSEClient'] diff --git a/ld_eventsource/async_http.py b/ld_eventsource/async_http.py new file mode 100644 index 0000000..c8cb5fa --- /dev/null +++ b/ld_eventsource/async_http.py @@ -0,0 +1,120 @@ +import asyncio +from logging import Logger +from typing import Any, AsyncIterator, Callable, Dict, Optional, Tuple +from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit + +import aiohttp + +from ld_eventsource.errors import HTTPContentTypeError, HTTPStatusError + +_CHUNK_SIZE = 10000 + + +class _AsyncHttpConnectParams: + def __init__( + self, + url: str, + headers: Optional[dict] = None, + session: Optional[aiohttp.ClientSession] = None, + aiohttp_request_options: Optional[dict] = None, + query_params=None, + ): + self.__url = url + self.__headers = headers + self.__session = session + self.__aiohttp_request_options = aiohttp_request_options + self.__query_params = query_params + + @property + def url(self) -> str: + return self.__url + + @property + def headers(self) -> Optional[dict]: + return self.__headers + + @property + def session(self) -> Optional[aiohttp.ClientSession]: + return self.__session + + @property + def aiohttp_request_options(self) -> Optional[dict]: + return self.__aiohttp_request_options + + @property + def query_params(self): + return self.__query_params + + +class _AsyncHttpClientImpl: + def __init__(self, params: _AsyncHttpConnectParams, logger: Logger): + self.__params = params + self.__external_session = params.session + self.__session: Optional[aiohttp.ClientSession] = params.session + self.__session_lock = asyncio.Lock() + self.__logger = logger + + async def _get_session(self) -> aiohttp.ClientSession: + if self.__session is not None: + return self.__session + async with self.__session_lock: + if self.__session is None: + self.__session = aiohttp.ClientSession() + return self.__session + + async def connect( + self, last_event_id: Optional[str] + ) -> Tuple[AsyncIterator[bytes], Callable, Dict[str, Any]]: + url = self.__params.url + if self.__params.query_params is not None: + qp = self.__params.query_params() + if qp: + url_parts = list(urlsplit(url)) + query = dict(parse_qsl(url_parts[3])) + query.update(qp) + url_parts[3] = urlencode(query) + url = urlunsplit(url_parts) + self.__logger.info("Connecting to stream at %s" % url) + + headers = self.__params.headers.copy() if self.__params.headers else {} + headers['Cache-Control'] = 'no-cache' + headers['Accept'] = 'text/event-stream' + + if last_event_id: + headers['Last-Event-ID'] = last_event_id + + request_options = ( + self.__params.aiohttp_request_options.copy() + if self.__params.aiohttp_request_options + else {} + ) + request_options['headers'] = headers + + session = await self._get_session() + resp = await session.get(url, **request_options) + + response_headers: Dict[str, Any] = dict(resp.headers) + + if resp.status >= 400 or resp.status == 204: + await resp.release() + raise HTTPStatusError(resp.status, response_headers) + + content_type = resp.headers.get('Content-Type', None) + if content_type is None or not str(content_type).startswith("text/event-stream"): + await resp.release() + raise HTTPContentTypeError(content_type or '', response_headers) + + async def chunk_iterator() -> AsyncIterator[bytes]: + async for chunk in resp.content.iter_chunked(_CHUNK_SIZE): + yield chunk + + async def closer(): + await resp.release() + + return chunk_iterator(), closer, response_headers + + async def close(self): + # Only close the session if we created it ourselves + if self.__external_session is None and self.__session is not None: + await self.__session.close() + self.__session = None diff --git a/ld_eventsource/async_reader.py b/ld_eventsource/async_reader.py new file mode 100644 index 0000000..d71aa9b --- /dev/null +++ b/ld_eventsource/async_reader.py @@ -0,0 +1,101 @@ +from typing import AsyncIterator, Callable, Optional + +from ld_eventsource.actions import Comment, Event + + +class _AsyncBufferedLineReader: + """ + Async version of _BufferedLineReader. Reads UTF-8 stream data as a series of text lines, + each of which can be terminated by \n, \r, or \r\n. + """ + + @staticmethod + async def lines_from(chunks: AsyncIterator[bytes]) -> AsyncIterator[str]: + last_char_was_cr = False + partial_line = None + + async for chunk in chunks: + if len(chunk) == 0: + continue + + lines = chunk.splitlines() + if last_char_was_cr: + last_char_was_cr = False + if chunk[0] == 10: + lines.pop(0) + if len(lines) == 0: + continue + if partial_line is not None: + lines[0] = partial_line + lines[0] + partial_line = None + last_char = chunk[-1] + if last_char == 13: + last_char_was_cr = True + elif last_char != 10: + partial_line = lines.pop() + for line in lines: + yield line.decode() + + +class _AsyncSSEReader: + def __init__( + self, + lines_source: AsyncIterator[str], + last_event_id: Optional[str] = None, + set_retry: Optional[Callable[[int], None]] = None, + ): + self._lines_source = lines_source + self._last_event_id = last_event_id + self._set_retry = set_retry + + @property + def last_event_id(self): + return self._last_event_id + + async def events_and_comments(self) -> AsyncIterator: + event_type = "" + event_data = None + event_id = None + async for line in self._lines_source: + if line == "": + if event_data is not None: + if event_id is not None: + self._last_event_id = event_id + yield Event( + "message" if event_type == "" else event_type, + event_data, + event_id, + self._last_event_id, + ) + event_type = "" + event_data = None + event_id = None + continue + colon_pos = line.find(':') + if colon_pos == 0: + yield Comment(line[1:]) + continue + if colon_pos < 0: + name = line + value = "" + else: + name = line[:colon_pos] + if colon_pos < (len(line) - 1) and line[colon_pos + 1] == ' ': + colon_pos += 1 + value = line[colon_pos + 1:] + if name == 'event': + event_type = value + elif name == 'data': + event_data = ( + value if event_data is None else (event_data + "\n" + value) + ) + elif name == 'id': + if value.find("\x00") < 0: + event_id = value + elif name == 'retry': + try: + n = int(value) + if self._set_retry: + self._set_retry(n) + except Exception: + pass diff --git a/ld_eventsource/config/async_connect_strategy.py b/ld_eventsource/config/async_connect_strategy.py new file mode 100644 index 0000000..4c09cda --- /dev/null +++ b/ld_eventsource/config/async_connect_strategy.py @@ -0,0 +1,148 @@ +from __future__ import annotations + +from logging import Logger +from typing import Any, AsyncIterator, Callable, Dict, Optional + + +class AsyncConnectStrategy: + """ + An abstraction for how :class:`.AsyncSSEClient` should obtain an input stream. + + The default implementation is :meth:`http()`, which makes HTTP requests with ``aiohttp``. + Or, if you want to consume an input stream from some other source, you can create your own + subclass of :class:`AsyncConnectStrategy`. + + Instances of this class should be immutable and should not contain any state that is specific + to one active stream. The :class:`AsyncConnectionClient` that they produce is stateful and + belongs to a single :class:`.AsyncSSEClient`. + """ + + def create_client(self, logger: Logger) -> AsyncConnectionClient: + """ + Creates a client instance. + + This is called once when an :class:`.AsyncSSEClient` is created. + + :param logger: the logger being used by the AsyncSSEClient + """ + raise NotImplementedError("AsyncConnectStrategy base class cannot be used by itself") + + @staticmethod + def http( + url: str, + headers: Optional[dict] = None, + session=None, + aiohttp_request_options: Optional[dict] = None, + query_params=None, + ) -> AsyncConnectStrategy: + """ + Creates the default async HTTP implementation using aiohttp. + + :param url: the stream URL + :param headers: optional HTTP headers to add to the request + :param session: optional ``aiohttp.ClientSession`` to use + :param aiohttp_request_options: optional kwargs passed to the aiohttp ``get()`` call + :param query_params: optional callable that returns a dict of query params per connection + """ + # Import here to avoid requiring aiohttp for users who don't use async HTTP + from ld_eventsource.async_http import _AsyncHttpClientImpl, _AsyncHttpConnectParams + return _AsyncHttpConnectStrategy( + _AsyncHttpConnectParams(url, headers, session, aiohttp_request_options, query_params) + ) + + +class AsyncConnectionClient: + """ + An object provided by :class:`.AsyncConnectStrategy` that is retained by a single + :class:`.AsyncSSEClient` to perform all connection attempts by that instance. + """ + + async def connect(self, last_event_id: Optional[str]) -> AsyncConnectionResult: + """ + Attempts to connect to a stream. Raises an exception if unsuccessful. + + :param last_event_id: the current value of last_event_id (sent to server for resuming) + :return: an :class:`AsyncConnectionResult` representing the stream + """ + raise NotImplementedError("AsyncConnectionClient base class cannot be used by itself") + + async def close(self): + """ + Does whatever is necessary to release resources when the AsyncSSEClient is closed. + """ + pass + + async def __aenter__(self): + return self + + async def __aexit__(self, type, value, traceback): + await self.close() + + +class AsyncConnectionResult: + """ + The return type of :meth:`AsyncConnectionClient.connect()`. + """ + + def __init__( + self, + stream: AsyncIterator[bytes], + closer: Optional[Callable], + headers: Optional[Dict[str, Any]] = None, + ): + self.__stream = stream + self.__closer = closer + self.__headers = headers + + @property + def stream(self) -> AsyncIterator[bytes]: + """ + An async iterator that returns chunks of data. + """ + return self.__stream + + @property + def headers(self) -> Optional[Dict[str, Any]]: + """ + The HTTP response headers, if available. + """ + return self.__headers + + async def close(self): + """ + Does whatever is necessary to release the connection. + """ + if self.__closer: + await self.__closer() + self.__closer = None + + async def __aenter__(self): + return self + + async def __aexit__(self, type, value, traceback): + await self.close() + + +class _AsyncHttpConnectStrategy(AsyncConnectStrategy): + def __init__(self, params): + self.__params = params + + def create_client(self, logger: Logger) -> AsyncConnectionClient: + from ld_eventsource.async_http import _AsyncHttpClientImpl + return _AsyncHttpConnectionClient(self.__params, logger) + + +class _AsyncHttpConnectionClient(AsyncConnectionClient): + def __init__(self, params, logger: Logger): + from ld_eventsource.async_http import _AsyncHttpClientImpl + self.__impl = _AsyncHttpClientImpl(params, logger) + + async def connect(self, last_event_id: Optional[str]) -> AsyncConnectionResult: + stream, closer, headers = await self.__impl.connect(last_event_id) + return AsyncConnectionResult(stream, closer, headers) + + async def close(self): + await self.__impl.close() + + +__all__ = ['AsyncConnectStrategy', 'AsyncConnectionClient', 'AsyncConnectionResult'] diff --git a/ld_eventsource/testing/async_helpers.py b/ld_eventsource/testing/async_helpers.py new file mode 100644 index 0000000..f5f61ef --- /dev/null +++ b/ld_eventsource/testing/async_helpers.py @@ -0,0 +1,96 @@ +from __future__ import annotations + +import asyncio +from logging import Logger +from typing import AsyncIterable, AsyncIterator, List, Optional + +from ld_eventsource.config.async_connect_strategy import ( + AsyncConnectionClient, AsyncConnectionResult, AsyncConnectStrategy) +from ld_eventsource.config.error_strategy import ErrorStrategy +from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy +from ld_eventsource.errors import HTTPStatusError +from ld_eventsource.testing.http_util import ChunkedResponse + + +def make_stream() -> ChunkedResponse: + return ChunkedResponse({'Content-Type': 'text/event-stream'}) + + +def retry_for_status(status: int) -> ErrorStrategy: + return ErrorStrategy.from_lambda( + lambda error: ( + ( + ErrorStrategy.CONTINUE + if isinstance(error, HTTPStatusError) and error.status == status + else ErrorStrategy.FAIL + ), + None, + ) + ) + + +def no_delay() -> RetryDelayStrategy: + return RetryDelayStrategy.from_lambda(lambda _: (0, None)) + + +class MockAsyncConnectStrategy(AsyncConnectStrategy): + def __init__(self, *request_handlers: MockAsyncConnectionHandler): + self.__handlers = list(request_handlers) + + def create_client(self, logger: Logger) -> AsyncConnectionClient: + return MockAsyncConnectionClient(self.__handlers) + + +class MockAsyncConnectionClient(AsyncConnectionClient): + def __init__(self, handlers: List[MockAsyncConnectionHandler]): + self.__handlers = handlers + self.__request_count = 0 + + async def connect(self, last_event_id: Optional[str]) -> AsyncConnectionResult: + handler = self.__handlers[self.__request_count] + if self.__request_count < len(self.__handlers) - 1: + self.__request_count += 1 + return await handler.apply() + + +class MockAsyncConnectionHandler: + async def apply(self) -> AsyncConnectionResult: + raise NotImplementedError( + "MockAsyncConnectionHandler base class cannot be used by itself" + ) + + +class AsyncRejectConnection(MockAsyncConnectionHandler): + def __init__(self, error: Exception): + self.__error = error + + async def apply(self) -> AsyncConnectionResult: + raise self.__error + + +class AsyncRespondWithStream(MockAsyncConnectionHandler): + def __init__(self, stream: AsyncIterable[bytes], headers: Optional[dict] = None): + self.__stream = stream + self.__headers = headers + + async def apply(self) -> AsyncConnectionResult: + return AsyncConnectionResult( + stream=self.__stream.__aiter__(), + closer=None, + headers=self.__headers, + ) + + +class AsyncRespondWithData(AsyncRespondWithStream): + def __init__(self, data: str, headers: Optional[dict] = None): + super().__init__(_bytes_async_iter([bytes(data, 'utf-8')]), headers) + + +class AsyncExpectNoMoreRequests(MockAsyncConnectionHandler): + async def apply(self) -> AsyncConnectionResult: + assert False, "AsyncSSEClient should not have made another request" + + +async def _bytes_async_iter(items: List[bytes]) -> AsyncIterator[bytes]: + for item in items: + yield item diff --git a/ld_eventsource/testing/test_async_http_connect_strategy.py b/ld_eventsource/testing/test_async_http_connect_strategy.py new file mode 100644 index 0000000..1f1ca4d --- /dev/null +++ b/ld_eventsource/testing/test_async_http_connect_strategy.py @@ -0,0 +1,164 @@ +import logging + +import pytest + +from ld_eventsource.async_client import AsyncSSEClient +from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy +from ld_eventsource.errors import HTTPContentTypeError, HTTPStatusError +from ld_eventsource.testing.async_helpers import no_delay +from ld_eventsource.testing.helpers import retry_for_status +from ld_eventsource.testing.http_util import (BasicResponse, ChunkedResponse, + CauseNetworkError, start_server) + + +def logger(): + return logging.getLogger("test") + + +@pytest.mark.asyncio +async def test_http_request_gets_chunked_data(): + with start_server() as server: + with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream: + server.for_path('/', stream) + strategy = AsyncConnectStrategy.http(server.uri) + client_obj = strategy.create_client(logger()) + result = await client_obj.connect(None) + try: + stream.push('hello') + chunk = await result.stream.__anext__() + assert chunk == b'hello' + finally: + await result.close() + await client_obj.close() + + +@pytest.mark.asyncio +async def test_http_request_default_headers(): + with start_server() as server: + with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream: + server.for_path('/', stream) + strategy = AsyncConnectStrategy.http(server.uri) + client_obj = strategy.create_client(logger()) + result = await client_obj.connect(None) + try: + r = server.await_request() + assert r.headers['Accept'] == 'text/event-stream' + assert r.headers['Cache-Control'] == 'no-cache' + assert r.headers.get('Last-Event-Id') is None + finally: + await result.close() + await client_obj.close() + + +@pytest.mark.asyncio +async def test_http_request_custom_headers(): + with start_server() as server: + with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream: + server.for_path('/', stream) + strategy = AsyncConnectStrategy.http(server.uri, headers={'name1': 'value1'}) + client_obj = strategy.create_client(logger()) + result = await client_obj.connect(None) + try: + r = server.await_request() + assert r.headers['Accept'] == 'text/event-stream' + assert r.headers['Cache-Control'] == 'no-cache' + assert r.headers['name1'] == 'value1' + finally: + await result.close() + await client_obj.close() + + +@pytest.mark.asyncio +async def test_http_request_last_event_id_header(): + with start_server() as server: + with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream: + server.for_path('/', stream) + strategy = AsyncConnectStrategy.http(server.uri) + client_obj = strategy.create_client(logger()) + result = await client_obj.connect('id123') + try: + r = server.await_request() + assert r.headers['Last-Event-Id'] == 'id123' + finally: + await result.close() + await client_obj.close() + + +@pytest.mark.asyncio +async def test_http_status_error(): + with start_server() as server: + server.for_path('/', BasicResponse(400)) + strategy = AsyncConnectStrategy.http(server.uri) + client_obj = strategy.create_client(logger()) + try: + with pytest.raises(HTTPStatusError) as exc_info: + await client_obj.connect(None) + assert exc_info.value.status == 400 + finally: + await client_obj.close() + + +@pytest.mark.asyncio +async def test_http_content_type_error(): + with start_server() as server: + with ChunkedResponse({'Content-Type': 'text/plain'}) as stream: + server.for_path('/', stream) + strategy = AsyncConnectStrategy.http(server.uri) + client_obj = strategy.create_client(logger()) + try: + with pytest.raises(HTTPContentTypeError) as exc_info: + await client_obj.connect(None) + assert exc_info.value.content_type == "text/plain" + finally: + await client_obj.close() + + +@pytest.mark.asyncio +async def test_http_response_headers_captured(): + with start_server() as server: + custom_headers = { + 'Content-Type': 'text/event-stream', + 'X-Custom-Header': 'custom-value', + } + with ChunkedResponse(custom_headers) as stream: + server.for_path('/', stream) + strategy = AsyncConnectStrategy.http(server.uri) + client_obj = strategy.create_client(logger()) + result = await client_obj.connect(None) + try: + assert result.headers is not None + assert result.headers.get('X-Custom-Header') == 'custom-value' + finally: + await result.close() + await client_obj.close() + + +@pytest.mark.asyncio +async def test_http_status_error_includes_headers(): + with start_server() as server: + server.for_path('/', BasicResponse(429, None, { + 'Retry-After': '120', + })) + strategy = AsyncConnectStrategy.http(server.uri) + client_obj = strategy.create_client(logger()) + try: + with pytest.raises(HTTPStatusError) as exc_info: + await client_obj.connect(None) + assert exc_info.value.status == 429 + assert exc_info.value.headers is not None + assert exc_info.value.headers.get('Retry-After') == '120' + finally: + await client_obj.close() + + +@pytest.mark.asyncio +async def test_sse_client_with_http_connect_strategy(): + with start_server() as server: + with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream: + server.for_path('/', stream) + async with AsyncSSEClient(connect=AsyncConnectStrategy.http(server.uri)) as client: + await client.start() + stream.push("data: data1\n\n") + async for event in client.events: + assert event.data == 'data1' + break diff --git a/ld_eventsource/testing/test_async_reader.py b/ld_eventsource/testing/test_async_reader.py new file mode 100644 index 0000000..04ba805 --- /dev/null +++ b/ld_eventsource/testing/test_async_reader.py @@ -0,0 +1,152 @@ +import pytest + +from ld_eventsource.actions import Comment, Event +from ld_eventsource.async_reader import _AsyncBufferedLineReader, _AsyncSSEReader + + +async def lines_from_bytes(*chunks: bytes): + """Helper: collect all lines from given byte chunks.""" + async def gen(): + for chunk in chunks: + yield chunk + + result = [] + async for line in _AsyncBufferedLineReader.lines_from(gen()): + result.append(line) + return result + + +async def events_from_lines(*lines: str): + """Helper: collect all events/comments from given lines.""" + async def gen(): + for line in lines: + yield line + + reader = _AsyncSSEReader(gen()) + result = [] + async for item in reader.events_and_comments(): + result.append(item) + return result + + +@pytest.mark.asyncio +async def test_line_reader_simple_newline(): + lines = await lines_from_bytes(b"hello\nworld\n") + assert lines == ["hello", "world"] + + +@pytest.mark.asyncio +async def test_line_reader_carriage_return(): + lines = await lines_from_bytes(b"hello\rworld\r") + assert lines == ["hello", "world"] + + +@pytest.mark.asyncio +async def test_line_reader_crlf(): + lines = await lines_from_bytes(b"hello\r\nworld\r\n") + assert lines == ["hello", "world"] + + +@pytest.mark.asyncio +async def test_line_reader_crlf_split_across_chunks(): + lines = await lines_from_bytes(b"hello\r", b"\nworld\r\n") + assert lines == ["hello", "world"] + + +@pytest.mark.asyncio +async def test_line_reader_partial_line_across_chunks(): + lines = await lines_from_bytes(b"hel", b"lo\n") + assert lines == ["hello"] + + +@pytest.mark.asyncio +async def test_line_reader_empty_chunk(): + lines = await lines_from_bytes(b"hello\n", b"", b"world\n") + assert lines == ["hello", "world"] + + +@pytest.mark.asyncio +async def test_sse_reader_simple_event(): + items = await events_from_lines("data: hello", "") + assert len(items) == 1 + assert isinstance(items[0], Event) + assert items[0].data == "hello" + assert items[0].event == "message" + + +@pytest.mark.asyncio +async def test_sse_reader_event_with_type(): + items = await events_from_lines("event: ping", "data: test", "") + assert len(items) == 1 + assert isinstance(items[0], Event) + assert items[0].event == "ping" + assert items[0].data == "test" + + +@pytest.mark.asyncio +async def test_sse_reader_multiline_data(): + items = await events_from_lines("data: line1", "data: line2", "") + assert len(items) == 1 + assert items[0].data == "line1\nline2" + + +@pytest.mark.asyncio +async def test_sse_reader_comment(): + items = await events_from_lines(":this is a comment", "data: event", "") + assert len(items) == 2 + assert isinstance(items[0], Comment) + assert items[0].comment == "this is a comment" + assert isinstance(items[1], Event) + + +@pytest.mark.asyncio +async def test_sse_reader_event_id(): + items = await events_from_lines("id: 123", "data: hello", "") + assert len(items) == 1 + assert items[0].id == "123" + assert items[0].last_event_id == "123" + + +@pytest.mark.asyncio +async def test_sse_reader_id_persists_across_events(): + items = await events_from_lines( + "id: 1", "data: first", "", + "data: second", "", + ) + assert len(items) == 2 + assert items[0].last_event_id == "1" + assert items[1].last_event_id == "1" + + +@pytest.mark.asyncio +async def test_sse_reader_retry_field(): + retries = [] + + async def gen(): + for line in ["retry: 5000", "data: test", ""]: + yield line + + reader = _AsyncSSEReader(gen(), set_retry=lambda n: retries.append(n)) + async for _ in reader.events_and_comments(): + pass + assert retries == [5000] + + +@pytest.mark.asyncio +async def test_sse_reader_ignores_null_in_id(): + items = await events_from_lines("id: bad\x00id", "data: test", "") + assert len(items) == 1 + assert items[0].id is None + + +@pytest.mark.asyncio +async def test_sse_reader_multiple_events(): + items = await events_from_lines( + "event: e1", "data: d1", "", + "event: e2", "data: d2", "", + ) + assert len(items) == 2 + assert items[0].event == "e1" + assert items[0].data == "d1" + assert items[1].event == "e2" + assert items[1].data == "d2" diff --git a/ld_eventsource/testing/test_async_sse_client_basic.py b/ld_eventsource/testing/test_async_sse_client_basic.py new file mode 100644 index 0000000..10fd7cf --- /dev/null +++ b/ld_eventsource/testing/test_async_sse_client_basic.py @@ -0,0 +1,149 @@ +import pytest + +from ld_eventsource.actions import Comment, Event, Fault, Start +from ld_eventsource.async_client import AsyncSSEClient +from ld_eventsource.testing.async_helpers import (AsyncRespondWithData, + MockAsyncConnectStrategy) + + +@pytest.mark.asyncio +@pytest.mark.parametrize('explicitly_start', [False, True]) +async def test_receives_events(explicitly_start: bool): + mock = MockAsyncConnectStrategy( + AsyncRespondWithData( + "event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n" + ) + ) + async with AsyncSSEClient(connect=mock) as client: + if explicitly_start: + await client.start() + + events = client.events + event_iter = events.__aiter__() + + event1 = await event_iter.__anext__() + assert event1.event == 'event1' + assert event1.data == 'data1' + + event2 = await event_iter.__anext__() + assert event2.event == 'event2' + assert event2.data == 'data2' + + +@pytest.mark.asyncio +async def test_events_returns_eof_when_stream_ends(): + mock = MockAsyncConnectStrategy(AsyncRespondWithData("event: event1\ndata: data1\n\n")) + async with AsyncSSEClient(connect=mock) as client: + events = [] + async for event in client.events: + events.append(event) + + assert len(events) == 1 + assert events[0].event == 'event1' + assert events[0].data == 'data1' + + +@pytest.mark.asyncio +async def test_receives_all(): + mock = MockAsyncConnectStrategy( + AsyncRespondWithData( + "event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n" + ) + ) + async with AsyncSSEClient(connect=mock) as client: + all_iter = client.all.__aiter__() + + item1 = await all_iter.__anext__() + assert isinstance(item1, Start) + + item2 = await all_iter.__anext__() + assert isinstance(item2, Event) + assert item2.event == 'event1' + assert item2.data == 'data1' + + item3 = await all_iter.__anext__() + assert isinstance(item3, Comment) + assert item3.comment == 'whatever' + + item4 = await all_iter.__anext__() + assert isinstance(item4, Event) + assert item4.event == 'event2' + assert item4.data == 'data2' + + +@pytest.mark.asyncio +async def test_all_returns_fault_and_eof_when_stream_ends(): + mock = MockAsyncConnectStrategy(AsyncRespondWithData("event: event1\ndata: data1\n\n")) + async with AsyncSSEClient(connect=mock) as client: + all_iter = client.all.__aiter__() + + item1 = await all_iter.__anext__() + assert isinstance(item1, Start) + + item2 = await all_iter.__anext__() + assert isinstance(item2, Event) + assert item2.event == 'event1' + assert item2.data == 'data1' + + item3 = await all_iter.__anext__() + assert isinstance(item3, Fault) + assert item3.error is None + + with pytest.raises(StopAsyncIteration): + await all_iter.__anext__() + + +@pytest.mark.asyncio +async def test_start_headers_exposed(): + mock = MockAsyncConnectStrategy( + AsyncRespondWithData("data: hello\n\n", headers={'X-My-Header': 'myvalue'}) + ) + async with AsyncSSEClient(connect=mock) as client: + all_iter = client.all.__aiter__() + + start = await all_iter.__anext__() + assert isinstance(start, Start) + assert start.headers is not None + assert start.headers.get('X-My-Header') == 'myvalue' + + +@pytest.mark.asyncio +async def test_last_event_id_tracked(): + mock = MockAsyncConnectStrategy( + AsyncRespondWithData("id: abc\ndata: hello\n\n") + ) + async with AsyncSSEClient(connect=mock) as client: + async for event in client.events: + assert event.last_event_id == 'abc' + break + assert client.last_event_id == 'abc' + + +@pytest.mark.asyncio +async def test_close_stops_iteration(): + mock = MockAsyncConnectStrategy( + AsyncRespondWithData("data: first\n\ndata: second\n\n") + ) + async with AsyncSSEClient(connect=mock) as client: + events_seen = [] + async for event in client.events: + events_seen.append(event) + await client.close() + + assert len(events_seen) == 1 + + +@pytest.mark.asyncio +async def test_string_url_creates_http_strategy(): + # Just verifies the constructor accepts a string without crashing + # (actual HTTP is tested separately) + from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy + client = AsyncSSEClient(connect="http://localhost:9999/stream") + assert client is not None + await client.close() + + +@pytest.mark.asyncio +async def test_invalid_connect_type_raises(): + with pytest.raises(TypeError): + AsyncSSEClient(connect=12345) diff --git a/ld_eventsource/testing/test_async_sse_client_retry.py b/ld_eventsource/testing/test_async_sse_client_retry.py new file mode 100644 index 0000000..f125ec7 --- /dev/null +++ b/ld_eventsource/testing/test_async_sse_client_retry.py @@ -0,0 +1,201 @@ +import asyncio + +import pytest + +from ld_eventsource.actions import Event, Fault, Start +from ld_eventsource.async_client import AsyncSSEClient +from ld_eventsource.config.error_strategy import ErrorStrategy +from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy +from ld_eventsource.errors import HTTPStatusError +from ld_eventsource.testing.async_helpers import (AsyncExpectNoMoreRequests, + AsyncRejectConnection, + AsyncRespondWithData, + MockAsyncConnectStrategy, + no_delay, + retry_for_status) + + +@pytest.mark.asyncio +async def test_retry_during_initial_connect_succeeds(): + mock = MockAsyncConnectStrategy( + AsyncRejectConnection(HTTPStatusError(503)), + AsyncRespondWithData("data: data1\n\n"), + AsyncExpectNoMoreRequests(), + ) + async with AsyncSSEClient( + connect=mock, + retry_delay_strategy=no_delay(), + error_strategy=retry_for_status(503), + ) as client: + await client.start() + + events = [] + async for event in client.events: + events.append(event) + break + assert events[0].data == 'data1' + + +@pytest.mark.asyncio +async def test_retry_during_initial_connect_succeeds_then_fails(): + mock = MockAsyncConnectStrategy( + AsyncRejectConnection(HTTPStatusError(503)), + AsyncRejectConnection(HTTPStatusError(400)), + AsyncExpectNoMoreRequests(), + ) + with pytest.raises(HTTPStatusError) as exc_info: + async with AsyncSSEClient( + connect=mock, + retry_delay_strategy=no_delay(), + error_strategy=retry_for_status(503), + ) as client: + await client.start() + assert exc_info.value.status == 400 + + +@pytest.mark.asyncio +async def test_events_iterator_continues_after_retry(): + mock = MockAsyncConnectStrategy( + AsyncRespondWithData("data: data1\n\n"), + AsyncRespondWithData("data: data2\n\n"), + AsyncExpectNoMoreRequests(), + ) + async with AsyncSSEClient( + connect=mock, + error_strategy=ErrorStrategy.always_continue(), + retry_delay_strategy=no_delay(), + ) as client: + events = [] + async for event in client.events: + events.append(event) + if len(events) == 2: + break + + assert events[0].data == 'data1' + assert events[1].data == 'data2' + + +@pytest.mark.asyncio +async def test_all_iterator_continues_after_retry(): + initial_delay = 0.005 + mock = MockAsyncConnectStrategy( + AsyncRespondWithData("data: data1\n\n"), + AsyncRespondWithData("data: data2\n\n"), + AsyncRespondWithData("data: data3\n\n"), + AsyncExpectNoMoreRequests(), + ) + async with AsyncSSEClient( + connect=mock, + error_strategy=ErrorStrategy.always_continue(), + initial_retry_delay=initial_delay, + retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None), + ) as client: + all_iter = client.all.__aiter__() + + item1 = await all_iter.__anext__() + assert isinstance(item1, Start) + + item2 = await all_iter.__anext__() + assert isinstance(item2, Event) + assert item2.data == 'data1' + + item3 = await all_iter.__anext__() + assert isinstance(item3, Fault) + assert item3.error is None + assert client.next_retry_delay == initial_delay + + item4 = await all_iter.__anext__() + assert isinstance(item4, Start) + + item5 = await all_iter.__anext__() + assert isinstance(item5, Event) + assert item5.data == 'data2' + + item6 = await all_iter.__anext__() + assert isinstance(item6, Fault) + assert item6.error is None + assert client.next_retry_delay == initial_delay * 2 + + +@pytest.mark.asyncio +async def test_can_interrupt_and_restart_stream(): + initial_delay = 0.005 + mock = MockAsyncConnectStrategy( + AsyncRespondWithData("data: data1\n\ndata: data2\n\n"), + AsyncRespondWithData("data: data3\n\n"), + AsyncExpectNoMoreRequests(), + ) + async with AsyncSSEClient( + connect=mock, + error_strategy=ErrorStrategy.always_continue(), + initial_retry_delay=initial_delay, + retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None), + ) as client: + all_iter = client.all.__aiter__() + + item1 = await all_iter.__anext__() + assert isinstance(item1, Start) + + item2 = await all_iter.__anext__() + assert isinstance(item2, Event) + assert item2.data == 'data1' + + await client.interrupt() + assert client.next_retry_delay == initial_delay + + item3 = await all_iter.__anext__() + assert isinstance(item3, Fault) + + item4 = await all_iter.__anext__() + assert isinstance(item4, Start) + + item5 = await all_iter.__anext__() + assert isinstance(item5, Event) + assert item5.data == 'data3' + + +@pytest.mark.asyncio +async def test_retry_delay_gets_reset_after_threshold(): + initial_delay = 0.005 + retry_delay_reset_threshold = 0.1 + mock = MockAsyncConnectStrategy( + AsyncRespondWithData("data: data1\n\n"), + AsyncRejectConnection(HTTPStatusError(503)), + ) + async with AsyncSSEClient( + connect=mock, + error_strategy=ErrorStrategy.always_continue(), + initial_retry_delay=initial_delay, + retry_delay_reset_threshold=retry_delay_reset_threshold, + retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None), + ) as client: + assert client._retry_reset_baseline == 0 + all_iter = client.all.__aiter__() + + item1 = await all_iter.__anext__() + assert isinstance(item1, Start) + assert client._retry_reset_baseline != 0 + + item2 = await all_iter.__anext__() + assert isinstance(item2, Event) + assert item2.data == 'data1' + + item3 = await all_iter.__anext__() + assert isinstance(item3, Fault) + assert client.next_retry_delay == initial_delay + + item4 = await all_iter.__anext__() + assert isinstance(item4, Fault) + assert client.next_retry_delay == initial_delay * 2 + + await asyncio.sleep(retry_delay_reset_threshold) + + item5 = await all_iter.__anext__() + assert isinstance(item5, Fault) + assert client.next_retry_delay == initial_delay + + await asyncio.sleep(retry_delay_reset_threshold / 2) + + item6 = await all_iter.__anext__() + assert isinstance(item6, Fault) + assert client.next_retry_delay == initial_delay * 2 diff --git a/pyproject.toml b/pyproject.toml index 2948b3d..604fdca 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,11 +29,16 @@ exclude = [ [tool.poetry.dependencies] python = ">=3.9" urllib3 = ">=1.26.0,<3" +aiohttp = { version = ">=3.8.0,<4", optional = true } +[tool.poetry.extras] +async = ["aiohttp"] [tool.poetry.group.dev.dependencies] mock = ">=2.0.0" pytest = ">=2.8" +pytest-asyncio = ">=0.21" +aiohttp = ">=3.8.0,<4" mypy = "^1.4.0" pycodestyle = "^2.12.1" isort = "^5.13.2" @@ -43,7 +48,7 @@ isort = "^5.13.2" optional = true [tool.poetry.group.contract-tests.dependencies] -Flask = "2.2.5" +Flask = ">=3.0" [tool.poetry.group.docs] @@ -59,7 +64,7 @@ pyrfc3339 = ">=1.0" jsonpickle = ">1.4.1" semver = ">=2.7.9" urllib3 = ">=1.26.0" -jinja2 = "3.0.0" +jinja2 = ">=3.1.2" [tool.mypy] python_version = "3.9" @@ -70,6 +75,7 @@ non_interactive = true [tool.pytest.ini_options] addopts = ["-ra"] +asyncio_mode = "auto" [build-system] From f4c0e17031e157ded679b22eb5fc0f9d6671b092 Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Wed, 11 Mar 2026 16:44:37 -0400 Subject: [PATCH 2/3] fix: Remove duplicate helpers from async_helpers and consolidate imports Removes make_stream, retry_for_status, and no_delay from async_helpers.py as they were either unused or duplicates of the same functions in helpers.py. Updates async test files to import these from helpers.py consistently. Co-Authored-By: Claude Sonnet 4.6 (1M context) --- ld_eventsource/testing/async_helpers.py | 25 ------------------- .../test_async_http_connect_strategy.py | 3 +-- .../testing/test_async_sse_client_retry.py | 5 ++-- 3 files changed, 3 insertions(+), 30 deletions(-) diff --git a/ld_eventsource/testing/async_helpers.py b/ld_eventsource/testing/async_helpers.py index f5f61ef..6bfbd36 100644 --- a/ld_eventsource/testing/async_helpers.py +++ b/ld_eventsource/testing/async_helpers.py @@ -6,31 +6,6 @@ from ld_eventsource.config.async_connect_strategy import ( AsyncConnectionClient, AsyncConnectionResult, AsyncConnectStrategy) -from ld_eventsource.config.error_strategy import ErrorStrategy -from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy -from ld_eventsource.errors import HTTPStatusError -from ld_eventsource.testing.http_util import ChunkedResponse - - -def make_stream() -> ChunkedResponse: - return ChunkedResponse({'Content-Type': 'text/event-stream'}) - - -def retry_for_status(status: int) -> ErrorStrategy: - return ErrorStrategy.from_lambda( - lambda error: ( - ( - ErrorStrategy.CONTINUE - if isinstance(error, HTTPStatusError) and error.status == status - else ErrorStrategy.FAIL - ), - None, - ) - ) - - -def no_delay() -> RetryDelayStrategy: - return RetryDelayStrategy.from_lambda(lambda _: (0, None)) class MockAsyncConnectStrategy(AsyncConnectStrategy): diff --git a/ld_eventsource/testing/test_async_http_connect_strategy.py b/ld_eventsource/testing/test_async_http_connect_strategy.py index 1f1ca4d..5196c54 100644 --- a/ld_eventsource/testing/test_async_http_connect_strategy.py +++ b/ld_eventsource/testing/test_async_http_connect_strategy.py @@ -5,8 +5,7 @@ from ld_eventsource.async_client import AsyncSSEClient from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy from ld_eventsource.errors import HTTPContentTypeError, HTTPStatusError -from ld_eventsource.testing.async_helpers import no_delay -from ld_eventsource.testing.helpers import retry_for_status +from ld_eventsource.testing.helpers import no_delay, retry_for_status from ld_eventsource.testing.http_util import (BasicResponse, ChunkedResponse, CauseNetworkError, start_server) diff --git a/ld_eventsource/testing/test_async_sse_client_retry.py b/ld_eventsource/testing/test_async_sse_client_retry.py index f125ec7..e1e710c 100644 --- a/ld_eventsource/testing/test_async_sse_client_retry.py +++ b/ld_eventsource/testing/test_async_sse_client_retry.py @@ -10,9 +10,8 @@ from ld_eventsource.testing.async_helpers import (AsyncExpectNoMoreRequests, AsyncRejectConnection, AsyncRespondWithData, - MockAsyncConnectStrategy, - no_delay, - retry_for_status) + MockAsyncConnectStrategy) +from ld_eventsource.testing.helpers import no_delay, retry_for_status @pytest.mark.asyncio From 069bffeb6ee789112b862d8e45fb9f493503dadb Mon Sep 17 00:00:00 2001 From: Matthew Keeler Date: Wed, 11 Mar 2026 16:46:29 -0400 Subject: [PATCH 3/3] fix isort --- contract-tests/async_service.py | 1 - contract-tests/async_stream_entity.py | 3 ++- ld_eventsource/async_client.py | 5 +++-- ld_eventsource/config/async_connect_strategy.py | 3 ++- ld_eventsource/testing/test_async_http_connect_strategy.py | 4 ++-- ld_eventsource/testing/test_async_reader.py | 3 ++- ld_eventsource/testing/test_async_sse_client_basic.py | 5 +++-- ld_eventsource/testing/test_async_sse_client_retry.py | 6 +++--- 8 files changed, 17 insertions(+), 13 deletions(-) diff --git a/contract-tests/async_service.py b/contract-tests/async_service.py index 8d86be1..0c5a4ed 100644 --- a/contract-tests/async_service.py +++ b/contract-tests/async_service.py @@ -5,7 +5,6 @@ from logging.config import dictConfig import aiohttp.web - from async_stream_entity import AsyncStreamEntity default_port = 8000 diff --git a/contract-tests/async_stream_entity.py b/contract-tests/async_stream_entity.py index ce9628e..6ba5a1d 100644 --- a/contract-tests/async_stream_entity.py +++ b/contract-tests/async_stream_entity.py @@ -11,7 +11,8 @@ sys.path.insert(1, os.path.join(sys.path[0], '..')) from ld_eventsource.actions import Comment, Event, Fault # noqa: E402 from ld_eventsource.async_client import AsyncSSEClient # noqa: E402 -from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy # noqa: E402 +from ld_eventsource.config.async_connect_strategy import \ + AsyncConnectStrategy # noqa: E402 from ld_eventsource.config.error_strategy import ErrorStrategy # noqa: E402 diff --git a/ld_eventsource/async_client.py b/ld_eventsource/async_client.py index b97d1bd..8a96873 100644 --- a/ld_eventsource/async_client.py +++ b/ld_eventsource/async_client.py @@ -4,9 +4,10 @@ from typing import AsyncIterable, Optional, Union from ld_eventsource.actions import Action, Event, Fault, Start -from ld_eventsource.async_reader import _AsyncBufferedLineReader, _AsyncSSEReader +from ld_eventsource.async_reader import (_AsyncBufferedLineReader, + _AsyncSSEReader) from ld_eventsource.config.async_connect_strategy import ( - AsyncConnectStrategy, AsyncConnectionClient, AsyncConnectionResult) + AsyncConnectionClient, AsyncConnectionResult, AsyncConnectStrategy) from ld_eventsource.config.error_strategy import ErrorStrategy from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy diff --git a/ld_eventsource/config/async_connect_strategy.py b/ld_eventsource/config/async_connect_strategy.py index 4c09cda..c4014b4 100644 --- a/ld_eventsource/config/async_connect_strategy.py +++ b/ld_eventsource/config/async_connect_strategy.py @@ -45,7 +45,8 @@ def http( :param query_params: optional callable that returns a dict of query params per connection """ # Import here to avoid requiring aiohttp for users who don't use async HTTP - from ld_eventsource.async_http import _AsyncHttpClientImpl, _AsyncHttpConnectParams + from ld_eventsource.async_http import (_AsyncHttpClientImpl, + _AsyncHttpConnectParams) return _AsyncHttpConnectStrategy( _AsyncHttpConnectParams(url, headers, session, aiohttp_request_options, query_params) ) diff --git a/ld_eventsource/testing/test_async_http_connect_strategy.py b/ld_eventsource/testing/test_async_http_connect_strategy.py index 5196c54..497f7d0 100644 --- a/ld_eventsource/testing/test_async_http_connect_strategy.py +++ b/ld_eventsource/testing/test_async_http_connect_strategy.py @@ -6,8 +6,8 @@ from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy from ld_eventsource.errors import HTTPContentTypeError, HTTPStatusError from ld_eventsource.testing.helpers import no_delay, retry_for_status -from ld_eventsource.testing.http_util import (BasicResponse, ChunkedResponse, - CauseNetworkError, start_server) +from ld_eventsource.testing.http_util import (BasicResponse, CauseNetworkError, + ChunkedResponse, start_server) def logger(): diff --git a/ld_eventsource/testing/test_async_reader.py b/ld_eventsource/testing/test_async_reader.py index 04ba805..05c02ca 100644 --- a/ld_eventsource/testing/test_async_reader.py +++ b/ld_eventsource/testing/test_async_reader.py @@ -1,7 +1,8 @@ import pytest from ld_eventsource.actions import Comment, Event -from ld_eventsource.async_reader import _AsyncBufferedLineReader, _AsyncSSEReader +from ld_eventsource.async_reader import (_AsyncBufferedLineReader, + _AsyncSSEReader) async def lines_from_bytes(*chunks: bytes): diff --git a/ld_eventsource/testing/test_async_sse_client_basic.py b/ld_eventsource/testing/test_async_sse_client_basic.py index 10fd7cf..260deb6 100644 --- a/ld_eventsource/testing/test_async_sse_client_basic.py +++ b/ld_eventsource/testing/test_async_sse_client_basic.py @@ -3,7 +3,7 @@ from ld_eventsource.actions import Comment, Event, Fault, Start from ld_eventsource.async_client import AsyncSSEClient from ld_eventsource.testing.async_helpers import (AsyncRespondWithData, - MockAsyncConnectStrategy) + MockAsyncConnectStrategy) @pytest.mark.asyncio @@ -137,7 +137,8 @@ async def test_close_stops_iteration(): async def test_string_url_creates_http_strategy(): # Just verifies the constructor accepts a string without crashing # (actual HTTP is tested separately) - from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy + from ld_eventsource.config.async_connect_strategy import \ + AsyncConnectStrategy client = AsyncSSEClient(connect="http://localhost:9999/stream") assert client is not None await client.close() diff --git a/ld_eventsource/testing/test_async_sse_client_retry.py b/ld_eventsource/testing/test_async_sse_client_retry.py index e1e710c..8fc08d4 100644 --- a/ld_eventsource/testing/test_async_sse_client_retry.py +++ b/ld_eventsource/testing/test_async_sse_client_retry.py @@ -8,9 +8,9 @@ from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy from ld_eventsource.errors import HTTPStatusError from ld_eventsource.testing.async_helpers import (AsyncExpectNoMoreRequests, - AsyncRejectConnection, - AsyncRespondWithData, - MockAsyncConnectStrategy) + AsyncRejectConnection, + AsyncRespondWithData, + MockAsyncConnectStrategy) from ld_eventsource.testing.helpers import no_delay, retry_for_status