diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index 10f593c..f11a75c 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -46,6 +46,12 @@ jobs:
- name: run SSE contract tests
run: make run-contract-tests
+ - name: start async SSE contract test service
+ run: make start-async-contract-test-service-bg
+
+ - name: run async SSE contract tests
+ run: make run-async-contract-tests
+
windows:
runs-on: windows-latest
diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md
index adcf964..89e2ba8 100644
--- a/CONTRIBUTING.md
+++ b/CONTRIBUTING.md
@@ -23,6 +23,12 @@ poetry install
eval $(poetry env activate)
```
+To also install the optional async dependencies (required to use `AsyncSSEClient`):
+
+```
+poetry install --extras async
+```
+
### Testing
To run all unit tests:
@@ -36,6 +42,11 @@ To run the standardized contract tests that are run against all LaunchDarkly SSE
make contract-tests
```
+To run the same contract tests against the async implementation:
+```
+make async-contract-tests
+```
+
### Linting
To run the linter and check type hints:
diff --git a/Makefile b/Makefile
index 716ade2..7b2f851 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,7 @@
PYTEST_FLAGS=-W error::SyntaxWarning
TEMP_TEST_OUTPUT=/tmp/sse-contract-test-service.log
+TEMP_ASYNC_TEST_OUTPUT=/tmp/sse-async-contract-test-service.log
SPHINXOPTS = -W --keep-going
SPHINXBUILD = sphinx-build
@@ -70,3 +71,21 @@ run-contract-tests:
.PHONY: contract-tests
contract-tests: #! Run the SSE contract test harness
contract-tests: install-contract-tests-deps start-contract-test-service-bg run-contract-tests
+
+.PHONY: start-async-contract-test-service
+start-async-contract-test-service:
+ @cd contract-tests && poetry run python async_service.py 8001
+
+.PHONY: start-async-contract-test-service-bg
+start-async-contract-test-service-bg:
+ @echo "Async test service output will be captured in $(TEMP_ASYNC_TEST_OUTPUT)"
+ @make start-async-contract-test-service >$(TEMP_ASYNC_TEST_OUTPUT) 2>&1 &
+
+.PHONY: run-async-contract-tests
+run-async-contract-tests:
+ @curl -s https://raw.githubusercontent.com/launchdarkly/sse-contract-tests/main/downloader/run.sh \
+ | VERSION=v2 PARAMS="-url http://localhost:8001 -debug -stop-service-at-end" sh
+
+.PHONY: async-contract-tests
+async-contract-tests: #! Run the SSE async contract test harness
+async-contract-tests: install-contract-tests-deps start-async-contract-test-service-bg run-async-contract-tests
diff --git a/README.md b/README.md
index 8b2bd5e..ef67bbb 100644
--- a/README.md
+++ b/README.md
@@ -14,7 +14,13 @@ This package's primary purpose is to support the [LaunchDarkly SDK for Python](h
* Setting read timeouts, custom headers, and other HTTP request properties.
* Specifying that connections should be retried under circumstances where the standard EventSource behavior would not retry them, such as if the server returns an HTTP error status.
-This is a synchronous implementation which blocks the caller's thread when reading events or reconnecting. By default, it uses `urllib3` to make HTTP requests, but it can be configured to read any input stream.
+The default `SSEClient` is a synchronous implementation which blocks the caller's thread when reading events or reconnecting. By default, it uses `urllib3` to make HTTP requests, but it can be configured to read any input stream.
+
+An async implementation, `AsyncSSEClient`, is also available for use with `asyncio`-based applications. It uses `aiohttp` for HTTP and requires installing the optional `async` extra:
+
+```
+pip install launchdarkly-eventsource[async]
+```
## Supported Python versions
diff --git a/contract-tests/async_service.py b/contract-tests/async_service.py
new file mode 100644
index 0000000..0c5a4ed
--- /dev/null
+++ b/contract-tests/async_service.py
@@ -0,0 +1,119 @@
+import json
+import logging
+import os
+import sys
+from logging.config import dictConfig
+
+import aiohttp.web
+from async_stream_entity import AsyncStreamEntity
+
+default_port = 8000
+
+dictConfig({
+ 'version': 1,
+ 'formatters': {
+ 'default': {
+ 'format': '[%(asctime)s] [%(name)s] %(levelname)s: %(message)s',
+ }
+ },
+ 'handlers': {
+ 'console': {
+ 'class': 'logging.StreamHandler',
+ 'formatter': 'default'
+ }
+ },
+ 'root': {
+ 'level': 'INFO',
+ 'handlers': ['console']
+ },
+})
+
+global_log = logging.getLogger('testservice')
+
+stream_counter = 0
+streams = {}
+
+
+async def handle_get_status(request):
+ body = {
+ 'capabilities': [
+ 'comments',
+ 'headers',
+ 'last-event-id',
+ 'read-timeout',
+ ]
+ }
+ return aiohttp.web.Response(
+ body=json.dumps(body),
+ content_type='application/json',
+ )
+
+
+async def handle_delete_stop(request):
+ global_log.info("Test service has told us to exit")
+ os._exit(0)
+
+
+async def handle_post_create_stream(request):
+ global stream_counter, streams
+
+ options = json.loads(await request.read())
+
+ stream_counter += 1
+ stream_id = str(stream_counter)
+ resource_url = '/streams/%s' % stream_id
+
+ stream = AsyncStreamEntity(options, request.app['http_session'])
+ streams[stream_id] = stream
+
+ return aiohttp.web.Response(status=201, headers={'Location': resource_url})
+
+
+async def handle_post_stream_command(request):
+ stream_id = request.match_info['id']
+ params = json.loads(await request.read())
+
+ stream = streams.get(stream_id)
+ if stream is None:
+ return aiohttp.web.Response(status=404)
+ if not await stream.do_command(params.get('command')):
+ return aiohttp.web.Response(status=400)
+ return aiohttp.web.Response(status=204)
+
+
+async def handle_delete_stream(request):
+ stream_id = request.match_info['id']
+
+ stream = streams.get(stream_id)
+ if stream is None:
+ return aiohttp.web.Response(status=404)
+ await stream.close()
+ return aiohttp.web.Response(status=204)
+
+
+async def on_startup(app):
+ app['http_session'] = aiohttp.ClientSession()
+
+
+async def on_cleanup(app):
+ await app['http_session'].close()
+
+
+def make_app():
+ app = aiohttp.web.Application()
+ app.router.add_get('/', handle_get_status)
+ app.router.add_delete('/', handle_delete_stop)
+ app.router.add_post('/', handle_post_create_stream)
+ app.router.add_post('/streams/{id}', handle_post_stream_command)
+ app.router.add_delete('/streams/{id}', handle_delete_stream)
+ app.on_startup.append(on_startup)
+ app.on_cleanup.append(on_cleanup)
+ return app
+
+
+if __name__ == "__main__":
+ port = default_port
+ if sys.argv[len(sys.argv) - 1] != 'async_service.py':
+ port = int(sys.argv[len(sys.argv) - 1])
+ global_log.info('Listening on port %d', port)
+ aiohttp.web.run_app(make_app(), host='0.0.0.0', port=port)
diff --git a/contract-tests/async_stream_entity.py b/contract-tests/async_stream_entity.py
new file mode 100644
index 0000000..6ba5a1d
--- /dev/null
+++ b/contract-tests/async_stream_entity.py
@@ -0,0 +1,116 @@
+import asyncio
+import json
+import logging
+import os
+import sys
+import traceback
+
+import aiohttp
+
+# Import ld_eventsource from parent directory
+sys.path.insert(1, os.path.join(sys.path[0], '..'))
+from ld_eventsource.actions import Comment, Event, Fault # noqa: E402
+from ld_eventsource.async_client import AsyncSSEClient # noqa: E402
+from ld_eventsource.config.async_connect_strategy import \
+ AsyncConnectStrategy # noqa: E402
+from ld_eventsource.config.error_strategy import ErrorStrategy # noqa: E402
+
+
+def millis_to_seconds(t):
+ return None if t is None else t / 1000
+
+
+class AsyncStreamEntity:
+ def __init__(self, options, http_session: aiohttp.ClientSession):
+ self.options = options
+ self.callback_url = options["callbackUrl"]
+ self.log = logging.getLogger(options["tag"])
+ self.closed = False
+ self.callback_counter = 0
+ self.sse = None
+ self._http_session = http_session
+ asyncio.create_task(self.run())
+
+ async def run(self):
+ stream_url = self.options["streamUrl"]
+ try:
+ self.log.info('Opening stream from %s', stream_url)
+
+ request_options = {}
+ if self.options.get("readTimeoutMs") is not None:
+ request_options["timeout"] = aiohttp.ClientTimeout(
+ sock_read=millis_to_seconds(self.options.get("readTimeoutMs"))
+ )
+
+ connect = AsyncConnectStrategy.http(
+ url=stream_url,
+ headers=self.options.get("headers"),
+ aiohttp_request_options=request_options if request_options else None,
+ )
+ sse = AsyncSSEClient(
+ connect,
+ initial_retry_delay=millis_to_seconds(self.options.get("initialDelayMs")),
+ last_event_id=self.options.get("lastEventId"),
+ error_strategy=ErrorStrategy.from_lambda(
+ lambda _: (
+ ErrorStrategy.FAIL if self.closed else ErrorStrategy.CONTINUE,
+ None,
+ )
+ ),
+ logger=self.log,
+ )
+ self.sse = sse
+ async for item in sse.all:
+ if isinstance(item, Event):
+ self.log.info('Received event from stream (%s)', item.event)
+ await self.send_message(
+ {
+ 'kind': 'event',
+ 'event': {
+ 'type': item.event,
+ 'data': item.data,
+ 'id': item.last_event_id,
+ },
+ }
+ )
+ elif isinstance(item, Comment):
+ self.log.info('Received comment from stream: %s', item.comment)
+ await self.send_message({'kind': 'comment', 'comment': item.comment})
+ elif isinstance(item, Fault):
+ if self.closed:
+ break
+ if item.error:
+ self.log.info('Received error from stream: %s', item.error)
+ await self.send_message({'kind': 'error', 'error': str(item.error)})
+ except Exception as e:
+ self.log.info('Received error from stream: %s', e)
+ self.log.info(traceback.format_exc())
+ await self.send_message({'kind': 'error', 'error': str(e)})
+
+ async def do_command(self, command: str) -> bool:
+ self.log.info('Test service sent command: %s' % command)
+ # currently we support no special commands
+ return False
+
+ async def send_message(self, message):
+ if self.closed:
+ return
+ self.callback_counter += 1
+ callback_url = "%s/%d" % (self.callback_url, self.callback_counter)
+ try:
+ async with self._http_session.post(
+ callback_url,
+ data=json.dumps(message),
+ headers={'Content-Type': 'application/json'},
+ ) as resp:
+ if resp.status >= 300 and not self.closed:
+ self.log.error('Callback request returned HTTP error %d', resp.status)
+ except Exception as e:
+ if not self.closed:
+ self.log.error('Callback request failed: %s', e)
+
+ async def close(self):
+ self.closed = True
+ if self.sse is not None:
+ await self.sse.close()
+ self.log.info('Test ended')
diff --git a/docs/conf.py b/docs/conf.py
index c3f71d2..186a344 100644
--- a/docs/conf.py
+++ b/docs/conf.py
@@ -170,3 +170,6 @@
autodoc_default_options = {
'undoc-members': False
}
+
+# aiohttp is an optional dependency not installed during doc builds
+autodoc_mock_imports = ['aiohttp']
diff --git a/docs/index.rst b/docs/index.rst
index 06dd66a..fa33901 100644
--- a/docs/index.rst
+++ b/docs/index.rst
@@ -3,7 +3,6 @@ LaunchDarkly Python SSE Client
This is the API reference for the `launchdarkly-eventsource `_ package, a `Server-Sent Events `_ client for Python. This package is used internally by the `LaunchDarkly Python SDK `_, but may also be useful for other purposes.
-
ld_eventsource module
---------------------
@@ -37,3 +36,21 @@ ld_eventsource.errors module
:members:
:special-members: __init__
:show-inheritance:
+
+
+ld_eventsource.async_client module
+-----------------------------------
+
+.. automodule:: ld_eventsource.async_client
+ :members:
+ :special-members: __init__
+ :show-inheritance:
+
+
+ld_eventsource.config.async_connect_strategy module
+----------------------------------------------------
+
+.. automodule:: ld_eventsource.config.async_connect_strategy
+ :members:
+ :special-members: __init__
+ :show-inheritance:
diff --git a/ld_eventsource/__init__.py b/ld_eventsource/__init__.py
index 8e88c11..c4a208d 100644
--- a/ld_eventsource/__init__.py
+++ b/ld_eventsource/__init__.py
@@ -1 +1,10 @@
from ld_eventsource.sse_client import *
+
+
+def __getattr__(name):
+ # Lazily import AsyncSSEClient so that aiohttp (an optional dependency)
+ # is never imported for sync-only users who don't have it installed.
+ if name == 'AsyncSSEClient':
+ from ld_eventsource.async_client import AsyncSSEClient
+ return AsyncSSEClient
+ raise AttributeError(f"module 'ld_eventsource' has no attribute {name!r}")
diff --git a/ld_eventsource/async_client.py b/ld_eventsource/async_client.py
new file mode 100644
index 0000000..8a96873
--- /dev/null
+++ b/ld_eventsource/async_client.py
@@ -0,0 +1,238 @@
+import asyncio
+import logging
+import time
+from typing import AsyncIterable, Optional, Union
+
+from ld_eventsource.actions import Action, Event, Fault, Start
+from ld_eventsource.async_reader import (_AsyncBufferedLineReader,
+ _AsyncSSEReader)
+from ld_eventsource.config.async_connect_strategy import (
+ AsyncConnectionClient, AsyncConnectionResult, AsyncConnectStrategy)
+from ld_eventsource.config.error_strategy import ErrorStrategy
+from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy
+
+
+class AsyncSSEClient:
+ """
+ An async client for reading a Server-Sent Events stream.
+
+ This is an async/await implementation. The expected usage is to create an ``AsyncSSEClient``
+ instance (either as an async context manager or directly), then read from it using the async
+ iterator properties :attr:`events` or :attr:`all`.
+
+ By default, ``AsyncSSEClient`` uses ``aiohttp`` to make HTTP requests to an SSE endpoint.
+ You can customize this behavior using :class:`.AsyncConnectStrategy`.
+
+ Connection failures and error responses can be handled in various ways depending on the
+ constructor parameters. The default behavior is the same as :class:`.SSEClient`.
+
+ Example::
+
+ async with AsyncSSEClient("https://my-server/events") as client:
+ async for event in client.events:
+ print(event.data)
+ """
+
+ def __init__(
+ self,
+ connect: Union[str, AsyncConnectStrategy],
+ initial_retry_delay: float = 1,
+ retry_delay_strategy: Optional[RetryDelayStrategy] = None,
+ retry_delay_reset_threshold: float = 60,
+ error_strategy: Optional[ErrorStrategy] = None,
+ last_event_id: Optional[str] = None,
+ logger: Optional[logging.Logger] = None,
+ ):
+ """
+ Creates an async client instance.
+
+ :param connect: either an :class:`.AsyncConnectStrategy` instance or a URL string
+ :param initial_retry_delay: the initial delay before reconnecting after a failure, in seconds
+ :param retry_delay_strategy: allows customization of the delay behavior for retries
+ :param retry_delay_reset_threshold: minimum connection time before resetting retry delay
+ :param error_strategy: allows customization of the behavior after a stream failure
+ :param last_event_id: if provided, the ``Last-Event-Id`` value will be preset to this
+ :param logger: if provided, log messages will be written here
+ """
+ if isinstance(connect, str):
+ connect = AsyncConnectStrategy.http(connect)
+ elif not isinstance(connect, AsyncConnectStrategy):
+ raise TypeError("connect must be either a string or AsyncConnectStrategy")
+
+ self.__base_retry_delay = initial_retry_delay
+ self.__base_retry_delay_strategy = (
+ retry_delay_strategy or RetryDelayStrategy.default()
+ )
+ self.__retry_delay_reset_threshold = retry_delay_reset_threshold
+ self.__current_retry_delay_strategy = self.__base_retry_delay_strategy
+ self.__next_retry_delay = 0
+
+ self.__base_error_strategy = error_strategy or ErrorStrategy.always_fail()
+ self.__current_error_strategy = self.__base_error_strategy
+
+ self.__last_event_id = last_event_id
+
+ if logger is None:
+ logger = logging.getLogger('launchdarkly-eventsource-async.null')
+ logger.addHandler(logging.NullHandler())
+ logger.propagate = False
+ self.__logger = logger
+
+ self.__connection_client: AsyncConnectionClient = connect.create_client(logger)
+ self.__connection_result: Optional[AsyncConnectionResult] = None
+ self._retry_reset_baseline: float = 0
+ self.__disconnected_time: float = 0
+
+ self.__closed = False
+ self.__interrupted = False
+
+ async def start(self):
+ """
+ Attempts to start the stream if it is not already active.
+ """
+ await self._try_start(False)
+
+ async def close(self):
+ """
+ Permanently shuts down this client instance and closes any active connection.
+ """
+ self.__closed = True
+ await self.interrupt()
+ await self.__connection_client.close()
+
+ async def interrupt(self):
+ """
+ Stops the stream connection if it is currently active, without permanently closing.
+ """
+ if self.__connection_result:
+ self.__interrupted = True
+ await self.__connection_result.close()
+ self.__connection_result = None
+ self._compute_next_retry_delay()
+
+ @property
+ def all(self) -> AsyncIterable[Action]:
+ """
+ An async iterable series of notifications from the stream.
+
+ Each can be any subclass of :class:`.Action`: :class:`.Event`, :class:`.Comment`,
+ :class:`.Start`, or :class:`.Fault`.
+ """
+ return self._all_generator()
+
+ @property
+ def events(self) -> AsyncIterable[Event]:
+ """
+ An async iterable series of :class:`.Event` objects received from the stream.
+ """
+ return self._events_generator()
+
+ async def _all_generator(self):
+ while True:
+ while self.__connection_result is None:
+ result = await self._try_start(True)
+ if result is not None:
+ yield result
+
+ lines = _AsyncBufferedLineReader.lines_from(self.__connection_result.stream)
+ reader = _AsyncSSEReader(lines, self.__last_event_id, None)
+ error: Optional[Exception] = None
+ try:
+ async for ec in reader.events_and_comments():
+ self.__last_event_id = reader.last_event_id
+ yield ec
+ if self.__interrupted:
+ break
+ self.__connection_result = None
+ except Exception as e:
+ if self.__closed:
+ return
+ error = e
+ self.__connection_result = None
+ finally:
+ self.__last_event_id = reader.last_event_id
+
+ self._compute_next_retry_delay()
+ fail_or_continue, self.__current_error_strategy = (
+ self.__current_error_strategy.apply(error)
+ )
+ if fail_or_continue == ErrorStrategy.FAIL:
+ if error is None:
+ yield Fault(None)
+ return
+ raise error
+ yield Fault(error)
+ continue
+
+ async def _events_generator(self):
+ async for item in self._all_generator():
+ if isinstance(item, Event):
+ yield item
+
+ @property
+ def next_retry_delay(self) -> float:
+ """
+ The retry delay that will be used for the next reconnection, in seconds.
+ """
+ return self.__next_retry_delay
+
+ def _compute_next_retry_delay(self):
+ if self.__retry_delay_reset_threshold > 0 and self._retry_reset_baseline != 0:
+ now = time.time()
+ connection_duration = now - self._retry_reset_baseline
+ if connection_duration >= self.__retry_delay_reset_threshold:
+ self.__current_retry_delay_strategy = self.__base_retry_delay_strategy
+ self._retry_reset_baseline = now
+ self.__next_retry_delay, self.__current_retry_delay_strategy = (
+ self.__current_retry_delay_strategy.apply(self.__base_retry_delay)
+ )
+
+ async def _try_start(self, can_return_fault: bool):
+ if self.__connection_result is not None:
+ return None
+ while True:
+ if self.__next_retry_delay > 0:
+ delay = (
+ self.__next_retry_delay
+ if self.__disconnected_time == 0
+ else self.__next_retry_delay
+ - (time.time() - self.__disconnected_time)
+ )
+ if delay > 0:
+ self.__logger.info("Will reconnect after delay of %fs" % delay)
+ await asyncio.sleep(delay)
+ try:
+ self.__connection_result = await self.__connection_client.connect(
+ self.__last_event_id
+ )
+ except Exception as e:
+ self.__disconnected_time = time.time()
+ self._compute_next_retry_delay()
+ fail_or_continue, self.__current_error_strategy = (
+ self.__current_error_strategy.apply(e)
+ )
+ if fail_or_continue == ErrorStrategy.FAIL:
+ raise e
+ if can_return_fault:
+ return Fault(e)
+ continue
+ self._retry_reset_baseline = time.time()
+ self.__current_error_strategy = self.__base_error_strategy
+ self.__interrupted = False
+ return Start(self.__connection_result.headers)
+
+ @property
+ def last_event_id(self) -> Optional[str]:
+ """
+ The ID value, if any, of the last known event.
+ """
+ return self.__last_event_id
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, type, value, traceback):
+ await self.close()
+
+
+__all__ = ['AsyncSSEClient']
diff --git a/ld_eventsource/async_http.py b/ld_eventsource/async_http.py
new file mode 100644
index 0000000..c8cb5fa
--- /dev/null
+++ b/ld_eventsource/async_http.py
@@ -0,0 +1,120 @@
+import asyncio
+from logging import Logger
+from typing import Any, AsyncIterator, Callable, Dict, Optional, Tuple
+from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit
+
+import aiohttp
+
+from ld_eventsource.errors import HTTPContentTypeError, HTTPStatusError
+
+_CHUNK_SIZE = 10000
+
+
+class _AsyncHttpConnectParams:
+ def __init__(
+ self,
+ url: str,
+ headers: Optional[dict] = None,
+ session: Optional[aiohttp.ClientSession] = None,
+ aiohttp_request_options: Optional[dict] = None,
+ query_params=None,
+ ):
+ self.__url = url
+ self.__headers = headers
+ self.__session = session
+ self.__aiohttp_request_options = aiohttp_request_options
+ self.__query_params = query_params
+
+ @property
+ def url(self) -> str:
+ return self.__url
+
+ @property
+ def headers(self) -> Optional[dict]:
+ return self.__headers
+
+ @property
+ def session(self) -> Optional[aiohttp.ClientSession]:
+ return self.__session
+
+ @property
+ def aiohttp_request_options(self) -> Optional[dict]:
+ return self.__aiohttp_request_options
+
+ @property
+ def query_params(self):
+ return self.__query_params
+
+
+class _AsyncHttpClientImpl:
+ def __init__(self, params: _AsyncHttpConnectParams, logger: Logger):
+ self.__params = params
+ self.__external_session = params.session
+ self.__session: Optional[aiohttp.ClientSession] = params.session
+ self.__session_lock = asyncio.Lock()
+ self.__logger = logger
+
+ async def _get_session(self) -> aiohttp.ClientSession:
+ if self.__session is not None:
+ return self.__session
+ async with self.__session_lock:
+ if self.__session is None:
+ self.__session = aiohttp.ClientSession()
+ return self.__session
+
+ async def connect(
+ self, last_event_id: Optional[str]
+ ) -> Tuple[AsyncIterator[bytes], Callable, Dict[str, Any]]:
+ url = self.__params.url
+ if self.__params.query_params is not None:
+ qp = self.__params.query_params()
+ if qp:
+ url_parts = list(urlsplit(url))
+ query = dict(parse_qsl(url_parts[3]))
+ query.update(qp)
+ url_parts[3] = urlencode(query)
+ url = urlunsplit(url_parts)
+ self.__logger.info("Connecting to stream at %s" % url)
+
+ headers = self.__params.headers.copy() if self.__params.headers else {}
+ headers['Cache-Control'] = 'no-cache'
+ headers['Accept'] = 'text/event-stream'
+
+ if last_event_id:
+ headers['Last-Event-ID'] = last_event_id
+
+ request_options = (
+ self.__params.aiohttp_request_options.copy()
+ if self.__params.aiohttp_request_options
+ else {}
+ )
+ request_options['headers'] = headers
+
+ session = await self._get_session()
+ resp = await session.get(url, **request_options)
+
+ response_headers: Dict[str, Any] = dict(resp.headers)
+
+ if resp.status >= 400 or resp.status == 204:
+ await resp.release()
+ raise HTTPStatusError(resp.status, response_headers)
+
+ content_type = resp.headers.get('Content-Type', None)
+ if content_type is None or not str(content_type).startswith("text/event-stream"):
+ await resp.release()
+ raise HTTPContentTypeError(content_type or '', response_headers)
+
+ async def chunk_iterator() -> AsyncIterator[bytes]:
+ async for chunk in resp.content.iter_chunked(_CHUNK_SIZE):
+ yield chunk
+
+ async def closer():
+ await resp.release()
+
+ return chunk_iterator(), closer, response_headers
+
+ async def close(self):
+ # Only close the session if we created it ourselves
+ if self.__external_session is None and self.__session is not None:
+ await self.__session.close()
+ self.__session = None
diff --git a/ld_eventsource/async_reader.py b/ld_eventsource/async_reader.py
new file mode 100644
index 0000000..d71aa9b
--- /dev/null
+++ b/ld_eventsource/async_reader.py
@@ -0,0 +1,101 @@
+from typing import AsyncIterator, Callable, Optional
+
+from ld_eventsource.actions import Comment, Event
+
+
+class _AsyncBufferedLineReader:
+ """
+ Async version of _BufferedLineReader. Reads UTF-8 stream data as a series of text lines,
+ each of which can be terminated by \n, \r, or \r\n.
+ """
+
+ @staticmethod
+ async def lines_from(chunks: AsyncIterator[bytes]) -> AsyncIterator[str]:
+ last_char_was_cr = False
+ partial_line = None
+
+ async for chunk in chunks:
+ if len(chunk) == 0:
+ continue
+
+ lines = chunk.splitlines()
+ if last_char_was_cr:
+ last_char_was_cr = False
+ if chunk[0] == 10:
+ lines.pop(0)
+ if len(lines) == 0:
+ continue
+ if partial_line is not None:
+ lines[0] = partial_line + lines[0]
+ partial_line = None
+ last_char = chunk[-1]
+ if last_char == 13:
+ last_char_was_cr = True
+ elif last_char != 10:
+ partial_line = lines.pop()
+ for line in lines:
+ yield line.decode()
+
+
+class _AsyncSSEReader:
+ def __init__(
+ self,
+ lines_source: AsyncIterator[str],
+ last_event_id: Optional[str] = None,
+ set_retry: Optional[Callable[[int], None]] = None,
+ ):
+ self._lines_source = lines_source
+ self._last_event_id = last_event_id
+ self._set_retry = set_retry
+
+ @property
+ def last_event_id(self):
+ return self._last_event_id
+
+ async def events_and_comments(self) -> AsyncIterator:
+ event_type = ""
+ event_data = None
+ event_id = None
+ async for line in self._lines_source:
+ if line == "":
+ if event_data is not None:
+ if event_id is not None:
+ self._last_event_id = event_id
+ yield Event(
+ "message" if event_type == "" else event_type,
+ event_data,
+ event_id,
+ self._last_event_id,
+ )
+ event_type = ""
+ event_data = None
+ event_id = None
+ continue
+ colon_pos = line.find(':')
+ if colon_pos == 0:
+ yield Comment(line[1:])
+ continue
+ if colon_pos < 0:
+ name = line
+ value = ""
+ else:
+ name = line[:colon_pos]
+ if colon_pos < (len(line) - 1) and line[colon_pos + 1] == ' ':
+ colon_pos += 1
+ value = line[colon_pos + 1:]
+ if name == 'event':
+ event_type = value
+ elif name == 'data':
+ event_data = (
+ value if event_data is None else (event_data + "\n" + value)
+ )
+ elif name == 'id':
+ if value.find("\x00") < 0:
+ event_id = value
+ elif name == 'retry':
+ try:
+ n = int(value)
+ if self._set_retry:
+ self._set_retry(n)
+ except Exception:
+ pass
diff --git a/ld_eventsource/config/async_connect_strategy.py b/ld_eventsource/config/async_connect_strategy.py
new file mode 100644
index 0000000..c4014b4
--- /dev/null
+++ b/ld_eventsource/config/async_connect_strategy.py
@@ -0,0 +1,149 @@
+from __future__ import annotations
+
+from logging import Logger
+from typing import Any, AsyncIterator, Callable, Dict, Optional
+
+
+class AsyncConnectStrategy:
+ """
+ An abstraction for how :class:`.AsyncSSEClient` should obtain an input stream.
+
+ The default implementation is :meth:`http()`, which makes HTTP requests with ``aiohttp``.
+ Or, if you want to consume an input stream from some other source, you can create your own
+ subclass of :class:`AsyncConnectStrategy`.
+
+ Instances of this class should be immutable and should not contain any state that is specific
+ to one active stream. The :class:`AsyncConnectionClient` that they produce is stateful and
+ belongs to a single :class:`.AsyncSSEClient`.
+ """
+
+ def create_client(self, logger: Logger) -> AsyncConnectionClient:
+ """
+ Creates a client instance.
+
+ This is called once when an :class:`.AsyncSSEClient` is created.
+
+ :param logger: the logger being used by the AsyncSSEClient
+ """
+ raise NotImplementedError("AsyncConnectStrategy base class cannot be used by itself")
+
+ @staticmethod
+ def http(
+ url: str,
+ headers: Optional[dict] = None,
+ session=None,
+ aiohttp_request_options: Optional[dict] = None,
+ query_params=None,
+ ) -> AsyncConnectStrategy:
+ """
+ Creates the default async HTTP implementation using aiohttp.
+
+ :param url: the stream URL
+ :param headers: optional HTTP headers to add to the request
+ :param session: optional ``aiohttp.ClientSession`` to use
+ :param aiohttp_request_options: optional kwargs passed to the aiohttp ``get()`` call
+ :param query_params: optional callable that returns a dict of query params per connection
+ """
+ # Import here to avoid requiring aiohttp for users who don't use async HTTP
+ from ld_eventsource.async_http import (_AsyncHttpClientImpl,
+ _AsyncHttpConnectParams)
+ return _AsyncHttpConnectStrategy(
+ _AsyncHttpConnectParams(url, headers, session, aiohttp_request_options, query_params)
+ )
+
+
+class AsyncConnectionClient:
+ """
+ An object provided by :class:`.AsyncConnectStrategy` that is retained by a single
+ :class:`.AsyncSSEClient` to perform all connection attempts by that instance.
+ """
+
+ async def connect(self, last_event_id: Optional[str]) -> AsyncConnectionResult:
+ """
+ Attempts to connect to a stream. Raises an exception if unsuccessful.
+
+ :param last_event_id: the current value of last_event_id (sent to server for resuming)
+ :return: an :class:`AsyncConnectionResult` representing the stream
+ """
+ raise NotImplementedError("AsyncConnectionClient base class cannot be used by itself")
+
+ async def close(self):
+ """
+ Does whatever is necessary to release resources when the AsyncSSEClient is closed.
+ """
+ pass
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, type, value, traceback):
+ await self.close()
+
+
+class AsyncConnectionResult:
+ """
+ The return type of :meth:`AsyncConnectionClient.connect()`.
+ """
+
+ def __init__(
+ self,
+ stream: AsyncIterator[bytes],
+ closer: Optional[Callable],
+ headers: Optional[Dict[str, Any]] = None,
+ ):
+ self.__stream = stream
+ self.__closer = closer
+ self.__headers = headers
+
+ @property
+ def stream(self) -> AsyncIterator[bytes]:
+ """
+ An async iterator that returns chunks of data.
+ """
+ return self.__stream
+
+ @property
+ def headers(self) -> Optional[Dict[str, Any]]:
+ """
+ The HTTP response headers, if available.
+ """
+ return self.__headers
+
+ async def close(self):
+ """
+ Does whatever is necessary to release the connection.
+ """
+ if self.__closer:
+ await self.__closer()
+ self.__closer = None
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, type, value, traceback):
+ await self.close()
+
+
+class _AsyncHttpConnectStrategy(AsyncConnectStrategy):
+ def __init__(self, params):
+ self.__params = params
+
+ def create_client(self, logger: Logger) -> AsyncConnectionClient:
+ from ld_eventsource.async_http import _AsyncHttpClientImpl
+ return _AsyncHttpConnectionClient(self.__params, logger)
+
+
+class _AsyncHttpConnectionClient(AsyncConnectionClient):
+ def __init__(self, params, logger: Logger):
+ from ld_eventsource.async_http import _AsyncHttpClientImpl
+ self.__impl = _AsyncHttpClientImpl(params, logger)
+
+ async def connect(self, last_event_id: Optional[str]) -> AsyncConnectionResult:
+ stream, closer, headers = await self.__impl.connect(last_event_id)
+ return AsyncConnectionResult(stream, closer, headers)
+
+ async def close(self):
+ await self.__impl.close()
+
+
+__all__ = ['AsyncConnectStrategy', 'AsyncConnectionClient', 'AsyncConnectionResult']
diff --git a/ld_eventsource/testing/async_helpers.py b/ld_eventsource/testing/async_helpers.py
new file mode 100644
index 0000000..6bfbd36
--- /dev/null
+++ b/ld_eventsource/testing/async_helpers.py
@@ -0,0 +1,71 @@
+from __future__ import annotations
+
+import asyncio
+from logging import Logger
+from typing import AsyncIterable, AsyncIterator, List, Optional
+
+from ld_eventsource.config.async_connect_strategy import (
+ AsyncConnectionClient, AsyncConnectionResult, AsyncConnectStrategy)
+
+
+class MockAsyncConnectStrategy(AsyncConnectStrategy):
+ def __init__(self, *request_handlers: MockAsyncConnectionHandler):
+ self.__handlers = list(request_handlers)
+
+ def create_client(self, logger: Logger) -> AsyncConnectionClient:
+ return MockAsyncConnectionClient(self.__handlers)
+
+
+class MockAsyncConnectionClient(AsyncConnectionClient):
+ def __init__(self, handlers: List[MockAsyncConnectionHandler]):
+ self.__handlers = handlers
+ self.__request_count = 0
+
+ async def connect(self, last_event_id: Optional[str]) -> AsyncConnectionResult:
+ handler = self.__handlers[self.__request_count]
+ if self.__request_count < len(self.__handlers) - 1:
+ self.__request_count += 1
+ return await handler.apply()
+
+
+class MockAsyncConnectionHandler:
+ async def apply(self) -> AsyncConnectionResult:
+ raise NotImplementedError(
+ "MockAsyncConnectionHandler base class cannot be used by itself"
+ )
+
+
+class AsyncRejectConnection(MockAsyncConnectionHandler):
+ def __init__(self, error: Exception):
+ self.__error = error
+
+ async def apply(self) -> AsyncConnectionResult:
+ raise self.__error
+
+
+class AsyncRespondWithStream(MockAsyncConnectionHandler):
+ def __init__(self, stream: AsyncIterable[bytes], headers: Optional[dict] = None):
+ self.__stream = stream
+ self.__headers = headers
+
+ async def apply(self) -> AsyncConnectionResult:
+ return AsyncConnectionResult(
+ stream=self.__stream.__aiter__(),
+ closer=None,
+ headers=self.__headers,
+ )
+
+
+class AsyncRespondWithData(AsyncRespondWithStream):
+ def __init__(self, data: str, headers: Optional[dict] = None):
+ super().__init__(_bytes_async_iter([bytes(data, 'utf-8')]), headers)
+
+
+class AsyncExpectNoMoreRequests(MockAsyncConnectionHandler):
+ async def apply(self) -> AsyncConnectionResult:
+ assert False, "AsyncSSEClient should not have made another request"
+
+
+async def _bytes_async_iter(items: List[bytes]) -> AsyncIterator[bytes]:
+ for item in items:
+ yield item
diff --git a/ld_eventsource/testing/test_async_http_connect_strategy.py b/ld_eventsource/testing/test_async_http_connect_strategy.py
new file mode 100644
index 0000000..497f7d0
--- /dev/null
+++ b/ld_eventsource/testing/test_async_http_connect_strategy.py
@@ -0,0 +1,163 @@
+import logging
+
+import pytest
+
+from ld_eventsource.async_client import AsyncSSEClient
+from ld_eventsource.config.async_connect_strategy import AsyncConnectStrategy
+from ld_eventsource.errors import HTTPContentTypeError, HTTPStatusError
+from ld_eventsource.testing.helpers import no_delay, retry_for_status
+from ld_eventsource.testing.http_util import (BasicResponse, CauseNetworkError,
+ ChunkedResponse, start_server)
+
+
+def logger():
+ return logging.getLogger("test")
+
+
+@pytest.mark.asyncio
+async def test_http_request_gets_chunked_data():
+ with start_server() as server:
+ with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream:
+ server.for_path('/', stream)
+ strategy = AsyncConnectStrategy.http(server.uri)
+ client_obj = strategy.create_client(logger())
+ result = await client_obj.connect(None)
+ try:
+ stream.push('hello')
+ chunk = await result.stream.__anext__()
+ assert chunk == b'hello'
+ finally:
+ await result.close()
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_http_request_default_headers():
+ with start_server() as server:
+ with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream:
+ server.for_path('/', stream)
+ strategy = AsyncConnectStrategy.http(server.uri)
+ client_obj = strategy.create_client(logger())
+ result = await client_obj.connect(None)
+ try:
+ r = server.await_request()
+ assert r.headers['Accept'] == 'text/event-stream'
+ assert r.headers['Cache-Control'] == 'no-cache'
+ assert r.headers.get('Last-Event-Id') is None
+ finally:
+ await result.close()
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_http_request_custom_headers():
+ with start_server() as server:
+ with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream:
+ server.for_path('/', stream)
+ strategy = AsyncConnectStrategy.http(server.uri, headers={'name1': 'value1'})
+ client_obj = strategy.create_client(logger())
+ result = await client_obj.connect(None)
+ try:
+ r = server.await_request()
+ assert r.headers['Accept'] == 'text/event-stream'
+ assert r.headers['Cache-Control'] == 'no-cache'
+ assert r.headers['name1'] == 'value1'
+ finally:
+ await result.close()
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_http_request_last_event_id_header():
+ with start_server() as server:
+ with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream:
+ server.for_path('/', stream)
+ strategy = AsyncConnectStrategy.http(server.uri)
+ client_obj = strategy.create_client(logger())
+ result = await client_obj.connect('id123')
+ try:
+ r = server.await_request()
+ assert r.headers['Last-Event-Id'] == 'id123'
+ finally:
+ await result.close()
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_http_status_error():
+ with start_server() as server:
+ server.for_path('/', BasicResponse(400))
+ strategy = AsyncConnectStrategy.http(server.uri)
+ client_obj = strategy.create_client(logger())
+ try:
+ with pytest.raises(HTTPStatusError) as exc_info:
+ await client_obj.connect(None)
+ assert exc_info.value.status == 400
+ finally:
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_http_content_type_error():
+ with start_server() as server:
+ with ChunkedResponse({'Content-Type': 'text/plain'}) as stream:
+ server.for_path('/', stream)
+ strategy = AsyncConnectStrategy.http(server.uri)
+ client_obj = strategy.create_client(logger())
+ try:
+ with pytest.raises(HTTPContentTypeError) as exc_info:
+ await client_obj.connect(None)
+ assert exc_info.value.content_type == "text/plain"
+ finally:
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_http_response_headers_captured():
+ with start_server() as server:
+ custom_headers = {
+ 'Content-Type': 'text/event-stream',
+ 'X-Custom-Header': 'custom-value',
+ }
+ with ChunkedResponse(custom_headers) as stream:
+ server.for_path('/', stream)
+ strategy = AsyncConnectStrategy.http(server.uri)
+ client_obj = strategy.create_client(logger())
+ result = await client_obj.connect(None)
+ try:
+ assert result.headers is not None
+ assert result.headers.get('X-Custom-Header') == 'custom-value'
+ finally:
+ await result.close()
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_http_status_error_includes_headers():
+ with start_server() as server:
+ server.for_path('/', BasicResponse(429, None, {
+ 'Retry-After': '120',
+ }))
+ strategy = AsyncConnectStrategy.http(server.uri)
+ client_obj = strategy.create_client(logger())
+ try:
+ with pytest.raises(HTTPStatusError) as exc_info:
+ await client_obj.connect(None)
+ assert exc_info.value.status == 429
+ assert exc_info.value.headers is not None
+ assert exc_info.value.headers.get('Retry-After') == '120'
+ finally:
+ await client_obj.close()
+
+
+@pytest.mark.asyncio
+async def test_sse_client_with_http_connect_strategy():
+ with start_server() as server:
+ with ChunkedResponse({'Content-Type': 'text/event-stream'}) as stream:
+ server.for_path('/', stream)
+ async with AsyncSSEClient(connect=AsyncConnectStrategy.http(server.uri)) as client:
+ await client.start()
+ stream.push("data: data1\n\n")
+ async for event in client.events:
+ assert event.data == 'data1'
+ break
diff --git a/ld_eventsource/testing/test_async_reader.py b/ld_eventsource/testing/test_async_reader.py
new file mode 100644
index 0000000..05c02ca
--- /dev/null
+++ b/ld_eventsource/testing/test_async_reader.py
@@ -0,0 +1,153 @@
+import pytest
+
+from ld_eventsource.actions import Comment, Event
+from ld_eventsource.async_reader import (_AsyncBufferedLineReader,
+ _AsyncSSEReader)
+
+
+async def lines_from_bytes(*chunks: bytes):
+ """Helper: collect all lines from given byte chunks."""
+ async def gen():
+ for chunk in chunks:
+ yield chunk
+
+ result = []
+ async for line in _AsyncBufferedLineReader.lines_from(gen()):
+ result.append(line)
+ return result
+
+
+async def events_from_lines(*lines: str):
+ """Helper: collect all events/comments from given lines."""
+ async def gen():
+ for line in lines:
+ yield line
+
+ reader = _AsyncSSEReader(gen())
+ result = []
+ async for item in reader.events_and_comments():
+ result.append(item)
+ return result
+
+
+@pytest.mark.asyncio
+async def test_line_reader_simple_newline():
+ lines = await lines_from_bytes(b"hello\nworld\n")
+ assert lines == ["hello", "world"]
+
+
+@pytest.mark.asyncio
+async def test_line_reader_carriage_return():
+ lines = await lines_from_bytes(b"hello\rworld\r")
+ assert lines == ["hello", "world"]
+
+
+@pytest.mark.asyncio
+async def test_line_reader_crlf():
+ lines = await lines_from_bytes(b"hello\r\nworld\r\n")
+ assert lines == ["hello", "world"]
+
+
+@pytest.mark.asyncio
+async def test_line_reader_crlf_split_across_chunks():
+ lines = await lines_from_bytes(b"hello\r", b"\nworld\r\n")
+ assert lines == ["hello", "world"]
+
+
+@pytest.mark.asyncio
+async def test_line_reader_partial_line_across_chunks():
+ lines = await lines_from_bytes(b"hel", b"lo\n")
+ assert lines == ["hello"]
+
+
+@pytest.mark.asyncio
+async def test_line_reader_empty_chunk():
+ lines = await lines_from_bytes(b"hello\n", b"", b"world\n")
+ assert lines == ["hello", "world"]
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_simple_event():
+ items = await events_from_lines("data: hello", "")
+ assert len(items) == 1
+ assert isinstance(items[0], Event)
+ assert items[0].data == "hello"
+ assert items[0].event == "message"
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_event_with_type():
+ items = await events_from_lines("event: ping", "data: test", "")
+ assert len(items) == 1
+ assert isinstance(items[0], Event)
+ assert items[0].event == "ping"
+ assert items[0].data == "test"
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_multiline_data():
+ items = await events_from_lines("data: line1", "data: line2", "")
+ assert len(items) == 1
+ assert items[0].data == "line1\nline2"
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_comment():
+ items = await events_from_lines(":this is a comment", "data: event", "")
+ assert len(items) == 2
+ assert isinstance(items[0], Comment)
+ assert items[0].comment == "this is a comment"
+ assert isinstance(items[1], Event)
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_event_id():
+ items = await events_from_lines("id: 123", "data: hello", "")
+ assert len(items) == 1
+ assert items[0].id == "123"
+ assert items[0].last_event_id == "123"
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_id_persists_across_events():
+ items = await events_from_lines(
+ "id: 1", "data: first", "",
+ "data: second", "",
+ )
+ assert len(items) == 2
+ assert items[0].last_event_id == "1"
+ assert items[1].last_event_id == "1"
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_retry_field():
+ retries = []
+
+ async def gen():
+ for line in ["retry: 5000", "data: test", ""]:
+ yield line
+
+ reader = _AsyncSSEReader(gen(), set_retry=lambda n: retries.append(n))
+ async for _ in reader.events_and_comments():
+ pass
+ assert retries == [5000]
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_ignores_null_in_id():
+ items = await events_from_lines("id: bad\x00id", "data: test", "")
+ assert len(items) == 1
+ assert items[0].id is None
+
+
+@pytest.mark.asyncio
+async def test_sse_reader_multiple_events():
+ items = await events_from_lines(
+ "event: e1", "data: d1", "",
+ "event: e2", "data: d2", "",
+ )
+ assert len(items) == 2
+ assert items[0].event == "e1"
+ assert items[0].data == "d1"
+ assert items[1].event == "e2"
+ assert items[1].data == "d2"
diff --git a/ld_eventsource/testing/test_async_sse_client_basic.py b/ld_eventsource/testing/test_async_sse_client_basic.py
new file mode 100644
index 0000000..260deb6
--- /dev/null
+++ b/ld_eventsource/testing/test_async_sse_client_basic.py
@@ -0,0 +1,150 @@
+import pytest
+
+from ld_eventsource.actions import Comment, Event, Fault, Start
+from ld_eventsource.async_client import AsyncSSEClient
+from ld_eventsource.testing.async_helpers import (AsyncRespondWithData,
+ MockAsyncConnectStrategy)
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize('explicitly_start', [False, True])
+async def test_receives_events(explicitly_start: bool):
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData(
+ "event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n"
+ )
+ )
+ async with AsyncSSEClient(connect=mock) as client:
+ if explicitly_start:
+ await client.start()
+
+ events = client.events
+ event_iter = events.__aiter__()
+
+ event1 = await event_iter.__anext__()
+ assert event1.event == 'event1'
+ assert event1.data == 'data1'
+
+ event2 = await event_iter.__anext__()
+ assert event2.event == 'event2'
+ assert event2.data == 'data2'
+
+
+@pytest.mark.asyncio
+async def test_events_returns_eof_when_stream_ends():
+ mock = MockAsyncConnectStrategy(AsyncRespondWithData("event: event1\ndata: data1\n\n"))
+ async with AsyncSSEClient(connect=mock) as client:
+ events = []
+ async for event in client.events:
+ events.append(event)
+
+ assert len(events) == 1
+ assert events[0].event == 'event1'
+ assert events[0].data == 'data1'
+
+
+@pytest.mark.asyncio
+async def test_receives_all():
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData(
+ "event: event1\ndata: data1\n\n:whatever\nevent: event2\ndata: data2\n\n"
+ )
+ )
+ async with AsyncSSEClient(connect=mock) as client:
+ all_iter = client.all.__aiter__()
+
+ item1 = await all_iter.__anext__()
+ assert isinstance(item1, Start)
+
+ item2 = await all_iter.__anext__()
+ assert isinstance(item2, Event)
+ assert item2.event == 'event1'
+ assert item2.data == 'data1'
+
+ item3 = await all_iter.__anext__()
+ assert isinstance(item3, Comment)
+ assert item3.comment == 'whatever'
+
+ item4 = await all_iter.__anext__()
+ assert isinstance(item4, Event)
+ assert item4.event == 'event2'
+ assert item4.data == 'data2'
+
+
+@pytest.mark.asyncio
+async def test_all_returns_fault_and_eof_when_stream_ends():
+ mock = MockAsyncConnectStrategy(AsyncRespondWithData("event: event1\ndata: data1\n\n"))
+ async with AsyncSSEClient(connect=mock) as client:
+ all_iter = client.all.__aiter__()
+
+ item1 = await all_iter.__anext__()
+ assert isinstance(item1, Start)
+
+ item2 = await all_iter.__anext__()
+ assert isinstance(item2, Event)
+ assert item2.event == 'event1'
+ assert item2.data == 'data1'
+
+ item3 = await all_iter.__anext__()
+ assert isinstance(item3, Fault)
+ assert item3.error is None
+
+ with pytest.raises(StopAsyncIteration):
+ await all_iter.__anext__()
+
+
+@pytest.mark.asyncio
+async def test_start_headers_exposed():
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData("data: hello\n\n", headers={'X-My-Header': 'myvalue'})
+ )
+ async with AsyncSSEClient(connect=mock) as client:
+ all_iter = client.all.__aiter__()
+
+ start = await all_iter.__anext__()
+ assert isinstance(start, Start)
+ assert start.headers is not None
+ assert start.headers.get('X-My-Header') == 'myvalue'
+
+
+@pytest.mark.asyncio
+async def test_last_event_id_tracked():
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData("id: abc\ndata: hello\n\n")
+ )
+ async with AsyncSSEClient(connect=mock) as client:
+ async for event in client.events:
+ assert event.last_event_id == 'abc'
+ break
+ assert client.last_event_id == 'abc'
+
+
+@pytest.mark.asyncio
+async def test_close_stops_iteration():
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData("data: first\n\ndata: second\n\n")
+ )
+ async with AsyncSSEClient(connect=mock) as client:
+ events_seen = []
+ async for event in client.events:
+ events_seen.append(event)
+ await client.close()
+
+ assert len(events_seen) == 1
+
+
+@pytest.mark.asyncio
+async def test_string_url_creates_http_strategy():
+ # Just verifies the constructor accepts a string without crashing
+ # (actual HTTP is tested separately)
+ from ld_eventsource.config.async_connect_strategy import \
+ AsyncConnectStrategy
+ client = AsyncSSEClient(connect="http://localhost:9999/stream")
+ assert client is not None
+ await client.close()
+
+
+@pytest.mark.asyncio
+async def test_invalid_connect_type_raises():
+ with pytest.raises(TypeError):
+ AsyncSSEClient(connect=12345)
diff --git a/ld_eventsource/testing/test_async_sse_client_retry.py b/ld_eventsource/testing/test_async_sse_client_retry.py
new file mode 100644
index 0000000..8fc08d4
--- /dev/null
+++ b/ld_eventsource/testing/test_async_sse_client_retry.py
@@ -0,0 +1,200 @@
+import asyncio
+
+import pytest
+
+from ld_eventsource.actions import Event, Fault, Start
+from ld_eventsource.async_client import AsyncSSEClient
+from ld_eventsource.config.error_strategy import ErrorStrategy
+from ld_eventsource.config.retry_delay_strategy import RetryDelayStrategy
+from ld_eventsource.errors import HTTPStatusError
+from ld_eventsource.testing.async_helpers import (AsyncExpectNoMoreRequests,
+ AsyncRejectConnection,
+ AsyncRespondWithData,
+ MockAsyncConnectStrategy)
+from ld_eventsource.testing.helpers import no_delay, retry_for_status
+
+
+@pytest.mark.asyncio
+async def test_retry_during_initial_connect_succeeds():
+ mock = MockAsyncConnectStrategy(
+ AsyncRejectConnection(HTTPStatusError(503)),
+ AsyncRespondWithData("data: data1\n\n"),
+ AsyncExpectNoMoreRequests(),
+ )
+ async with AsyncSSEClient(
+ connect=mock,
+ retry_delay_strategy=no_delay(),
+ error_strategy=retry_for_status(503),
+ ) as client:
+ await client.start()
+
+ events = []
+ async for event in client.events:
+ events.append(event)
+ break
+ assert events[0].data == 'data1'
+
+
+@pytest.mark.asyncio
+async def test_retry_during_initial_connect_succeeds_then_fails():
+ mock = MockAsyncConnectStrategy(
+ AsyncRejectConnection(HTTPStatusError(503)),
+ AsyncRejectConnection(HTTPStatusError(400)),
+ AsyncExpectNoMoreRequests(),
+ )
+ with pytest.raises(HTTPStatusError) as exc_info:
+ async with AsyncSSEClient(
+ connect=mock,
+ retry_delay_strategy=no_delay(),
+ error_strategy=retry_for_status(503),
+ ) as client:
+ await client.start()
+ assert exc_info.value.status == 400
+
+
+@pytest.mark.asyncio
+async def test_events_iterator_continues_after_retry():
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData("data: data1\n\n"),
+ AsyncRespondWithData("data: data2\n\n"),
+ AsyncExpectNoMoreRequests(),
+ )
+ async with AsyncSSEClient(
+ connect=mock,
+ error_strategy=ErrorStrategy.always_continue(),
+ retry_delay_strategy=no_delay(),
+ ) as client:
+ events = []
+ async for event in client.events:
+ events.append(event)
+ if len(events) == 2:
+ break
+
+ assert events[0].data == 'data1'
+ assert events[1].data == 'data2'
+
+
+@pytest.mark.asyncio
+async def test_all_iterator_continues_after_retry():
+ initial_delay = 0.005
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData("data: data1\n\n"),
+ AsyncRespondWithData("data: data2\n\n"),
+ AsyncRespondWithData("data: data3\n\n"),
+ AsyncExpectNoMoreRequests(),
+ )
+ async with AsyncSSEClient(
+ connect=mock,
+ error_strategy=ErrorStrategy.always_continue(),
+ initial_retry_delay=initial_delay,
+ retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None),
+ ) as client:
+ all_iter = client.all.__aiter__()
+
+ item1 = await all_iter.__anext__()
+ assert isinstance(item1, Start)
+
+ item2 = await all_iter.__anext__()
+ assert isinstance(item2, Event)
+ assert item2.data == 'data1'
+
+ item3 = await all_iter.__anext__()
+ assert isinstance(item3, Fault)
+ assert item3.error is None
+ assert client.next_retry_delay == initial_delay
+
+ item4 = await all_iter.__anext__()
+ assert isinstance(item4, Start)
+
+ item5 = await all_iter.__anext__()
+ assert isinstance(item5, Event)
+ assert item5.data == 'data2'
+
+ item6 = await all_iter.__anext__()
+ assert isinstance(item6, Fault)
+ assert item6.error is None
+ assert client.next_retry_delay == initial_delay * 2
+
+
+@pytest.mark.asyncio
+async def test_can_interrupt_and_restart_stream():
+ initial_delay = 0.005
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData("data: data1\n\ndata: data2\n\n"),
+ AsyncRespondWithData("data: data3\n\n"),
+ AsyncExpectNoMoreRequests(),
+ )
+ async with AsyncSSEClient(
+ connect=mock,
+ error_strategy=ErrorStrategy.always_continue(),
+ initial_retry_delay=initial_delay,
+ retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None),
+ ) as client:
+ all_iter = client.all.__aiter__()
+
+ item1 = await all_iter.__anext__()
+ assert isinstance(item1, Start)
+
+ item2 = await all_iter.__anext__()
+ assert isinstance(item2, Event)
+ assert item2.data == 'data1'
+
+ await client.interrupt()
+ assert client.next_retry_delay == initial_delay
+
+ item3 = await all_iter.__anext__()
+ assert isinstance(item3, Fault)
+
+ item4 = await all_iter.__anext__()
+ assert isinstance(item4, Start)
+
+ item5 = await all_iter.__anext__()
+ assert isinstance(item5, Event)
+ assert item5.data == 'data3'
+
+
+@pytest.mark.asyncio
+async def test_retry_delay_gets_reset_after_threshold():
+ initial_delay = 0.005
+ retry_delay_reset_threshold = 0.1
+ mock = MockAsyncConnectStrategy(
+ AsyncRespondWithData("data: data1\n\n"),
+ AsyncRejectConnection(HTTPStatusError(503)),
+ )
+ async with AsyncSSEClient(
+ connect=mock,
+ error_strategy=ErrorStrategy.always_continue(),
+ initial_retry_delay=initial_delay,
+ retry_delay_reset_threshold=retry_delay_reset_threshold,
+ retry_delay_strategy=RetryDelayStrategy.default(jitter_multiplier=None),
+ ) as client:
+ assert client._retry_reset_baseline == 0
+ all_iter = client.all.__aiter__()
+
+ item1 = await all_iter.__anext__()
+ assert isinstance(item1, Start)
+ assert client._retry_reset_baseline != 0
+
+ item2 = await all_iter.__anext__()
+ assert isinstance(item2, Event)
+ assert item2.data == 'data1'
+
+ item3 = await all_iter.__anext__()
+ assert isinstance(item3, Fault)
+ assert client.next_retry_delay == initial_delay
+
+ item4 = await all_iter.__anext__()
+ assert isinstance(item4, Fault)
+ assert client.next_retry_delay == initial_delay * 2
+
+ await asyncio.sleep(retry_delay_reset_threshold)
+
+ item5 = await all_iter.__anext__()
+ assert isinstance(item5, Fault)
+ assert client.next_retry_delay == initial_delay
+
+ await asyncio.sleep(retry_delay_reset_threshold / 2)
+
+ item6 = await all_iter.__anext__()
+ assert isinstance(item6, Fault)
+ assert client.next_retry_delay == initial_delay * 2
diff --git a/pyproject.toml b/pyproject.toml
index 2948b3d..604fdca 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -29,11 +29,16 @@ exclude = [
[tool.poetry.dependencies]
python = ">=3.9"
urllib3 = ">=1.26.0,<3"
+aiohttp = { version = ">=3.8.0,<4", optional = true }
+[tool.poetry.extras]
+async = ["aiohttp"]
[tool.poetry.group.dev.dependencies]
mock = ">=2.0.0"
pytest = ">=2.8"
+pytest-asyncio = ">=0.21"
+aiohttp = ">=3.8.0,<4"
mypy = "^1.4.0"
pycodestyle = "^2.12.1"
isort = "^5.13.2"
@@ -43,7 +48,7 @@ isort = "^5.13.2"
optional = true
[tool.poetry.group.contract-tests.dependencies]
-Flask = "2.2.5"
+Flask = ">=3.0"
[tool.poetry.group.docs]
@@ -59,7 +64,7 @@ pyrfc3339 = ">=1.0"
jsonpickle = ">1.4.1"
semver = ">=2.7.9"
urllib3 = ">=1.26.0"
-jinja2 = "3.0.0"
+jinja2 = ">=3.1.2"
[tool.mypy]
python_version = "3.9"
@@ -70,6 +75,7 @@ non_interactive = true
[tool.pytest.ini_options]
addopts = ["-ra"]
+asyncio_mode = "auto"
[build-system]