diff --git a/massive/rest/__init__.py b/massive/rest/__init__.py index 5a00da5a..d1f6aa2c 100644 --- a/massive/rest/__init__.py +++ b/massive/rest/__init__.py @@ -20,7 +20,7 @@ ContractsClient, ) from .vX import VXClient -from typing import Optional, Any +from typing import Optional, Any, Dict import os BASE = "https://api.massive.com" @@ -60,6 +60,7 @@ def __init__( verbose: bool = False, trace: bool = False, custom_json: Optional[Any] = None, + connection_pool_kw: Optional[Dict[str, Any]] = None, ): super().__init__( api_key=api_key, @@ -72,6 +73,7 @@ def __init__( verbose=verbose, trace=trace, custom_json=custom_json, + connection_pool_kw=connection_pool_kw, ) self.vx = VXClient( api_key=api_key, @@ -84,4 +86,5 @@ def __init__( verbose=verbose, trace=trace, custom_json=custom_json, + connection_pool_kw=connection_pool_kw, ) diff --git a/massive/rest/base.py b/massive/rest/base.py index 3349d7ef..dadead3f 100644 --- a/massive/rest/base.py +++ b/massive/rest/base.py @@ -28,6 +28,7 @@ def __init__( connect_timeout: float, read_timeout: float, num_pools: int, + connection_pool_kw: Optional[Dict[str, Any]], retries: int, base: str, pagination: bool, @@ -80,6 +81,7 @@ def __init__( cert_reqs="CERT_REQUIRED", retries=retry_strategy, # use the customized Retry instance timeout=self.timeout, # set timeout for each request + **(connection_pool_kw if connection_pool_kw is not None else {}), ) if verbose: diff --git a/test_rest/test_connection_pool_kw.py b/test_rest/test_connection_pool_kw.py new file mode 100644 index 00000000..d1ed0b31 --- /dev/null +++ b/test_rest/test_connection_pool_kw.py @@ -0,0 +1,36 @@ +import sys +import os + +sys.path.insert(0, os.path.dirname(__file__)) + +import unittest +from unittest.mock import patch, MagicMock +from massive import RESTClient + + +class ConnectionPoolKwTest(unittest.TestCase): + def _make_client(self, **kwargs): + with patch("urllib3.PoolManager") as mock_pm: + mock_pm.return_value = MagicMock() + RESTClient(api_key="test", **kwargs) + return mock_pm + + def test_default_no_extra_kwargs(self): + """connection_pool_kw=None passes no extra kwargs to PoolManager.""" + mock_pm = self._make_client() + _, call_kwargs = mock_pm.call_args + self.assertNotIn("connection_pool_kw", call_kwargs) + + def test_connection_pool_kw_passed_as_kwargs(self): + """connection_pool_kw dict is unpacked into PoolManager as kwargs.""" + extra = {"maxsize": 20, "block": True} + mock_pm = self._make_client(connection_pool_kw=extra) + _, call_kwargs = mock_pm.call_args + self.assertEqual(call_kwargs["maxsize"], 20) + self.assertEqual(call_kwargs["block"], True) + + def test_empty_connection_pool_kw(self): + """Empty dict connection_pool_kw adds no extra kwargs to PoolManager.""" + mock_pm = self._make_client(connection_pool_kw={}) + _, call_kwargs = mock_pm.call_args + self.assertNotIn("connection_pool_kw", call_kwargs)