diff --git a/kanboard.py b/kanboard.py index 8076031..4f4dead 100644 --- a/kanboard.py +++ b/kanboard.py @@ -98,11 +98,25 @@ def __init__( if loop: self._event_loop = loop + self._owns_event_loop = False else: try: self._event_loop = asyncio.get_event_loop() + self._owns_event_loop = False except RuntimeError: self._event_loop = asyncio.new_event_loop() + self._owns_event_loop = True + + def __enter__(self) -> "Client": + return self + + def __exit__(self, *args: Any) -> None: + self.close() + + def close(self) -> None: + """Close the event loop if it was created by this client.""" + if self._owns_event_loop and not self._event_loop.is_closed(): + self._event_loop.close() def __getattr__(self, name: str) -> Callable[..., Any]: if self.is_async_method_name(name): diff --git a/tests/test_kanboard.py b/tests/test_kanboard.py index e03118f..99fd187 100644 --- a/tests/test_kanboard.py +++ b/tests/test_kanboard.py @@ -21,9 +21,7 @@ # THE SOFTWARE. import asyncio -import types import unittest -import warnings from unittest import mock import kanboard @@ -35,13 +33,8 @@ def setUp(self): self.client = kanboard.Client(self.url, "username", "password") self.request, self.urlopen = self._create_mocks() - def ignore_warnings(test_func): - def do_test(self, *args, **kwargs): - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - test_func(self, *args, **kwargs) - - return do_test + def tearDown(self): + self.client.close() def test_api_call(self): body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' @@ -106,13 +99,6 @@ def test_method_name_extracted_from_async_name(self): result = self.client.get_funcname_from_async_name(async_method_name) self.assertEqual(expected_method_name, result) - # suppress a RuntimeWarning because coro is not awaited - # this is done on purpose - @ignore_warnings - def test_async_call_generates_coro(self): - method = self.client.my_method_async() - self.assertIsInstance(method, types.CoroutineType) - def test_async_call_returns_result(self): body = b'{"jsonrpc": "2.0", "result": 42, "id": 123}' self.urlopen.return_value.read.return_value = body @@ -125,16 +111,37 @@ def test_custom_event_loop(self): try: client = kanboard.Client(self.url, "username", "password", loop=custom_loop) self.assertIs(client._event_loop, custom_loop) + self.assertFalse(client._owns_event_loop) + finally: + custom_loop.close() + + def test_close_owned_event_loop(self): + client = kanboard.Client(self.url, "username", "password") + if client._owns_event_loop: + self.assertFalse(client._event_loop.is_closed()) + client.close() + self.assertTrue(client._event_loop.is_closed()) + + def test_close_does_not_close_external_loop(self): + custom_loop = asyncio.new_event_loop() + try: + client = kanboard.Client(self.url, "username", "password", loop=custom_loop) + client.close() + self.assertFalse(custom_loop.is_closed()) finally: custom_loop.close() + def test_context_manager(self): + with kanboard.Client(self.url, "username", "password") as client: + self.assertIsNotNone(client._event_loop) + def test_custom_user_agent(self): - client = kanboard.Client(self.url, "username", "password", user_agent="CustomAgent/1.0") - body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' - self.urlopen.return_value.read.return_value = body - client.remote_procedure() - _, kwargs = self.request.call_args - self.assertEqual("CustomAgent/1.0", kwargs["headers"]["User-Agent"]) + with kanboard.Client(self.url, "username", "password", user_agent="CustomAgent/1.0") as client: + body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' + self.urlopen.return_value.read.return_value = body + client.remote_procedure() + _, kwargs = self.request.call_args + self.assertEqual("CustomAgent/1.0", kwargs["headers"]["User-Agent"]) def test_default_user_agent(self): body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' @@ -145,30 +152,30 @@ def test_default_user_agent(self): @mock.patch("ssl.create_default_context") def test_insecure_disables_ssl_verification(self, mock_ssl_context): - client = kanboard.Client(self.url, "username", "password", insecure=True) - ctx = mock_ssl_context.return_value - body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' - self.urlopen.return_value.read.return_value = body - client.remote_procedure() - self.assertFalse(ctx.check_hostname) - self.assertEqual(ctx.verify_mode, __import__("ssl").CERT_NONE) + with kanboard.Client(self.url, "username", "password", insecure=True) as client: + ctx = mock_ssl_context.return_value + body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' + self.urlopen.return_value.read.return_value = body + client.remote_procedure() + self.assertFalse(ctx.check_hostname) + self.assertEqual(ctx.verify_mode, __import__("ssl").CERT_NONE) @mock.patch("ssl.create_default_context") def test_ignore_hostname_verification(self, mock_ssl_context): - client = kanboard.Client(self.url, "username", "password", ignore_hostname_verification=True) - ctx = mock_ssl_context.return_value - body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' - self.urlopen.return_value.read.return_value = body - client.remote_procedure() - self.assertFalse(ctx.check_hostname) + with kanboard.Client(self.url, "username", "password", ignore_hostname_verification=True) as client: + ctx = mock_ssl_context.return_value + body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' + self.urlopen.return_value.read.return_value = body + client.remote_procedure() + self.assertFalse(ctx.check_hostname) @mock.patch("ssl.create_default_context") def test_cafile_passed_to_ssl_context(self, mock_ssl_context): - client = kanboard.Client(self.url, "username", "password", cafile="/path/to/ca.pem") - body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' - self.urlopen.return_value.read.return_value = body - client.remote_procedure() - mock_ssl_context.assert_called_once_with(cafile="/path/to/ca.pem") + with kanboard.Client(self.url, "username", "password", cafile="/path/to/ca.pem") as client: + body = b'{"jsonrpc": "2.0", "result": true, "id": 123}' + self.urlopen.return_value.read.return_value = body + client.remote_procedure() + mock_ssl_context.assert_called_once_with(cafile="/path/to/ca.pem") @staticmethod def _create_mocks():