diff --git a/cpgqls_client/client.py b/cpgqls_client/client.py index aa77bd7..4c7b4b0 100644 --- a/cpgqls_client/client.py +++ b/cpgqls_client/client.py @@ -3,6 +3,15 @@ import websockets +def _get_or_create_event_loop(): + try: + return asyncio.get_event_loop() + except RuntimeError: + event_loop = asyncio.new_event_loop() + asyncio.set_event_loop(event_loop) + return event_loop + + class CPGQLSTransport: def __init__(self): @@ -32,7 +41,7 @@ def __init__(self, server_endpoint, event_loop=None, transport=None, auth_creden if not isinstance(server_endpoint, str): raise ValueError("server_endpoint parameter has to be a string") - self._loop = asyncio.get_event_loop() if not event_loop else event_loop + self._loop = _get_or_create_event_loop() if event_loop is None else event_loop self._transport = CPGQLSTransport() if not transport else transport self._endpoint = server_endpoint.rstrip("/") self._auth_creds = auth_credentials diff --git a/tests/test_client.py b/tests/test_client.py index ac6d99d..bdfcfc8 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -92,6 +92,21 @@ def test_basic_execution(): assert result == post_response_mock.json() +def test_client_creates_event_loop_when_current_thread_has_none(): + asyncio.set_event_loop(None) + client = None + + try: + client = CPGQLSClient("localhost:8080", transport=Mock()) + + assert isinstance(client._loop, asyncio.AbstractEventLoop) + assert asyncio.get_event_loop() is client._loop + finally: + if client is not None: + client._loop.close() + asyncio.set_event_loop(None) + + def test_get_response_not_200(): event_loop = asyncio.new_event_loop() conn = MockCPGQLTransportConnection("connected", "received")