From 9a9103f30fd3d64f49a56a7467289450b006fe5c Mon Sep 17 00:00:00 2001
From: Michael Chou
Date: Tue, 23 Jun 2026 13:21:51 -0700
Subject: [PATCH] refresh

---
 agentex/src/config/dependencies.py            | 157 +++++++++++++--
 agentex/src/config/environment_variables.py   |   9 +
 agentex/tests/integration/config/__init__.py  |   0
 .../test_mongodb_oidc_refresh_integration.py  |  71 ++++++++
 .../unit/config/test_mongodb_oidc_refresh.py  | 186 ++++++++++++++++++
 5 files changed, 412 insertions(+), 11 deletions(-)
 create mode 100644 agentex/tests/integration/config/__init__.py
 create mode 100644 agentex/tests/integration/config/test_mongodb_oidc_refresh_integration.py
 create mode 100644 agentex/tests/unit/config/test_mongodb_oidc_refresh.py

diff --git a/agentex/src/config/dependencies.py b/agentex/src/config/dependencies.py
index f37cf043..ad51e5ea 100644
--- a/agentex/src/config/dependencies.py
+++ b/agentex/src/config/dependencies.py
@@ -49,6 +49,8 @@ def __init__(self):
         self.redis_pool: redis.ConnectionPool | None = None
         self.database_async_read_only_engine: AsyncEngine | None = None
         self.postgres_metrics_collector: PostgresMetricsCollector | None = None
+        self._mongodb_refresh_task: asyncio.Task | None = None
+        self._mongodb_close_tasks: set[asyncio.Task] = set()
         self._loaded = False
 
     async def create_temporal_client(self):
@@ -122,17 +124,7 @@ async def load(self):
 
             logger.info("Connecting to MongoDB")
 
-            self.mongodb_client = AsyncMongoClient(
-                mongodb_uri,
-                serverSelectionTimeoutMS=20000,
-                connectTimeoutMS=20000,
-                socketTimeoutMS=20000,
-                retryWrites=False,  # Disable retryable writes for AWS DocumentDB compatibility
-                maxPoolSize=self.environment_variables.MONGODB_MAX_POOL_SIZE,
-                minPoolSize=self.environment_variables.MONGODB_MIN_POOL_SIZE,
-                maxIdleTimeMS=30000,  # Close connections after 30 seconds of inactivity
-                waitQueueTimeoutMS=5000,  # Wait up to 5 seconds for a connection from pool
-            )
+            self.mongodb_client = self._build_mongodb_client(mongodb_uri)
             self.mongodb_database = self.mongodb_client[mongodb_database_name]
 
             # Ping the database to verify connection
@@ -226,10 +218,150 @@ async def load(self):
                 service_name=service_name,
             )
 
+        self._start_mongodb_oidc_refresh_loop()
+
         self._loaded = True
 
+    def _build_mongodb_client(self, mongodb_uri: str) -> AsyncMongoClient:
+        """Construct an AsyncMongoClient with the shared pool/timeout settings.
+
+        Used both at startup and by the OIDC refresh, so the two paths can never
+        drift apart.
+        """
+        return AsyncMongoClient(
+            mongodb_uri,
+            serverSelectionTimeoutMS=20000,
+            connectTimeoutMS=20000,
+            socketTimeoutMS=20000,
+            retryWrites=False,  # Disable retryable writes for AWS DocumentDB compatibility
+            maxPoolSize=self.environment_variables.MONGODB_MAX_POOL_SIZE,
+            minPoolSize=self.environment_variables.MONGODB_MIN_POOL_SIZE,
+            maxIdleTimeMS=30000,  # Close connections after 30 seconds of inactivity
+            waitQueueTimeoutMS=5000,  # Wait up to 5 seconds for a connection from pool
+        )
+
+    def _mongodb_uses_oidc(self) -> bool:
+        """True only when the Mongo URI authenticates via MONGODB-OIDC.
+
+        Gates the refresh loop so standard-auth / AWS DocumentDB deployments are
+        never churned — only GCP OIDC tokens expire out from under a live client.
+        """
+        uri = self.environment_variables.MONGODB_URI or ""
+        return "MONGODB-OIDC" in uri.upper()
+
+    async def refresh_mongodb_client(self) -> None:
+        """Rebuild the Mongo client to renew the cached GCP OIDC token.
+
+        pymongo's built-in GCP OIDC provider caches the access token for the life
+        of the client and only refreshes it reactively (on a server reauth
+        challenge). GCP tokens expire after ~1h, so a long-lived client eventually
+        fails auth. A new client authenticates fresh, picking up a new token.
+
+        The new client is built and pinged (which forces fresh auth) before the
+        swap, and the old client is closed only after a drain delay, so no in-flight
+        operation is ever dropped and we never swap to a broken client.
+        A new client that fails validation is closed before the error propagates.
+        """
+        mongodb_uri = self.environment_variables.MONGODB_URI
+        if not mongodb_uri or not self._mongodb_uses_oidc():
+            return
+
+        new_client = self._build_mongodb_client(mongodb_uri)
+        try:
+            # Force fresh OIDC auth and validate the new client before trusting it.
+            await new_client.admin.command("ping")
+        except BaseException:
+            # Validation failed (or we were cancelled mid-ping): the existing
+            # client stays in place, but the rejected one must be closed or every
+            # failed refresh would leak an open client.
+            try:
+                await new_client.close()
+            except Exception as close_error:
+                logger.warning(
+                    f"Error closing rejected MongoDB client: {close_error}"
+                )
+            raise
+
+        old_client = self.mongodb_client
+        self.mongodb_client = new_client
+        self.mongodb_database = new_client[
+            self.environment_variables.MONGODB_DATABASE_NAME
+        ]
+        logger.info("Refreshed MongoDB client to renew OIDC credentials")
+
+        if old_client is not None and old_client is not new_client:
+            task = asyncio.create_task(
+                self._close_mongodb_client_after_delay(old_client)
+            )
+            # Keep a strong reference until done so the task is not GC'd mid-flight.
+            self._mongodb_close_tasks.add(task)
+            task.add_done_callback(self._mongodb_close_tasks.discard)
+
+    async def _close_mongodb_client_after_delay(
+        self, client: AsyncMongoClient, delay: float = 60.0
+    ) -> None:
+        """Close a superseded Mongo client after letting in-flight ops drain."""
+        try:
+            await asyncio.sleep(delay)
+            await client.close()
+        except asyncio.CancelledError:
+            await client.close()
+            raise
+        except Exception as e:
+            logger.warning(f"Error closing superseded MongoDB client: {e}")
+
+    def _start_mongodb_oidc_refresh_loop(self) -> None:
+        """Start the periodic refresh task when OIDC is in use and not yet running."""
+        interval = self.environment_variables.MONGODB_OIDC_REFRESH_INTERVAL_SECONDS
+        if (
+            self.mongodb_client is None
+            or not self._mongodb_uses_oidc()
+            or interval <= 0
+            or self._mongodb_refresh_task is not None
+        ):
+            return
+        self._mongodb_refresh_task = asyncio.create_task(
+            self._mongodb_oidc_refresh_loop(interval)
+        )
+        logger.info(f"Started MongoDB OIDC refresh loop (interval={interval}s)")
+
+    async def _mongodb_oidc_refresh_loop(self, interval: int) -> None:
+        """Refresh the client every `interval` seconds until the task is cancelled."""
+        while True:
+            try:
+                await asyncio.sleep(interval)
+                await self.refresh_mongodb_client()
+            except asyncio.CancelledError:
+                raise
+            except Exception as e:
+                logger.error(
+                    f"MongoDB OIDC refresh failed; retrying next interval: {e}"
+                )
+
+    async def _stop_mongodb_oidc_refresh_loop(self) -> None:
+        """Cancel the refresh loop and finish closing superseded clients."""
+        if self._mongodb_refresh_task is not None:
+            self._mongodb_refresh_task.cancel()
+            try:
+                await self._mongodb_refresh_task
+            except asyncio.CancelledError:
+                pass
+            self._mongodb_refresh_task = None
+
+        # Superseded clients still waiting out their drain delay would be orphaned
+        # at shutdown. Cancelling a close task makes it close its client right away
+        # (the CancelledError branch of _close_mongodb_client_after_delay).
+        pending_closes = list(self._mongodb_close_tasks)
+        for task in pending_closes:
+            task.cancel()
+        if pending_closes:
+            await asyncio.gather(*pending_closes, return_exceptions=True)
+
     async def force_reload(self):
         """Force reload all dependencies with fresh environment variables"""
+        # Stop the MongoDB OIDC refresh loop before tearing down the client
+        await self._stop_mongodb_oidc_refresh_loop()
+
         # Stop metrics collection
         if self.postgres_metrics_collector:
             await self.postgres_metrics_collector.stop_collection()
@@ -272,6 +404,9 @@ def shutdown():
 async def async_shutdown():
     global_dependencies = GlobalDependencies()
 
+    # Stop the MongoDB OIDC refresh loop
+    await global_dependencies._stop_mongodb_oidc_refresh_loop()
+
     # Stop PostgreSQL metrics collection
     if global_dependencies.postgres_metrics_collector:
         await global_dependencies.postgres_metrics_collector.stop_collection()
diff --git a/agentex/src/config/environment_variables.py b/agentex/src/config/environment_variables.py
index 0872c0cf..a6f1ca15 100644
--- a/agentex/src/config/environment_variables.py
+++ b/agentex/src/config/environment_variables.py
@@ -37,6 +37,7 @@ class EnvVarKeys(str, Enum):
     MONGODB_DATABASE_NAME = "MONGODB_DATABASE_NAME"
     MONGODB_MAX_POOL_SIZE = "MONGODB_MAX_POOL_SIZE"
     MONGODB_MIN_POOL_SIZE = "MONGODB_MIN_POOL_SIZE"
+    MONGODB_OIDC_REFRESH_INTERVAL_SECONDS = "MONGODB_OIDC_REFRESH_INTERVAL_SECONDS"
     REDIS_MAX_CONNECTIONS = "REDIS_MAX_CONNECTIONS"
     REDIS_CONNECTION_TIMEOUT = "REDIS_CONNECTION_TIMEOUT"
     REDIS_SOCKET_TIMEOUT = "REDIS_SOCKET_TIMEOUT"
@@ -96,6 +97,11 @@ class EnvironmentVariables(BaseModel):
     MONGODB_DATABASE_NAME: str | None = "agentex"
     MONGODB_MAX_POOL_SIZE: int = 50
     MONGODB_MIN_POOL_SIZE: int = 5
+    # Rebuild the Mongo client on this interval to renew GCP OIDC credentials.
+    # pymongo caches the OIDC token for the life of the client and never refreshes
+    # it proactively, so a long-lived client fails auth once the ~1h GCP token
+    # expires. Only applied to MONGODB-OIDC URIs; 0 disables. Default 45 min.
+    MONGODB_OIDC_REFRESH_INTERVAL_SECONDS: int = 2700
     REDIS_MAX_CONNECTIONS: int = 50  # Increased for SSE streaming
     REDIS_CONNECTION_TIMEOUT: int = 60  # Connection timeout in seconds
     REDIS_SOCKET_TIMEOUT: int = 30  # Socket timeout in seconds
@@ -163,6 +169,9 @@ def refresh(cls, force_refresh: bool = False) -> EnvironmentVariables | None:
             MONGODB_MIN_POOL_SIZE=int(
                 os.environ.get(EnvVarKeys.MONGODB_MIN_POOL_SIZE, "5")
             ),
+            MONGODB_OIDC_REFRESH_INTERVAL_SECONDS=int(
+                os.environ.get(EnvVarKeys.MONGODB_OIDC_REFRESH_INTERVAL_SECONDS, "2700")
+            ),
             REDIS_MAX_CONNECTIONS=int(
                 os.environ.get(EnvVarKeys.REDIS_MAX_CONNECTIONS, "100")
             ),
diff --git a/agentex/tests/integration/config/__init__.py b/agentex/tests/integration/config/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/agentex/tests/integration/config/test_mongodb_oidc_refresh_integration.py b/agentex/tests/integration/config/test_mongodb_oidc_refresh_integration.py
new file mode 100644
index 00000000..2fce71dc
--- /dev/null
+++ b/agentex/tests/integration/config/test_mongodb_oidc_refresh_integration.py
@@ -0,0 +1,71 @@
+"""Integration tests for the MongoDB client-refresh swap against a real Mongo.
+
+The unit tests mock the client; these prove the build-validate-swap-drain works
+end-to-end against a live MongoDB container: data written before the swap is still
+readable after it, the post-swap client is fully functional, and the superseded
+client is drained and closed. (The container doesn't speak GCP OIDC, so the OIDC
+gate is forced on to exercise the swap path itself.)
+"""
+
+from unittest.mock import AsyncMock
+
+import pytest
+from src.config.dependencies import GlobalDependencies, Singleton
+
+
+@pytest.fixture
+def deps(mongodb_connection_string):
+    """Fresh GlobalDependencies wired to the test Mongo container."""
+    Singleton._instances.pop(GlobalDependencies, None)
+    instance = GlobalDependencies()
+    instance.environment_variables = instance.environment_variables.model_copy(
+        update={
+            "MONGODB_URI": mongodb_connection_string,
+            "MONGODB_DATABASE_NAME": "agentex_oidc_refresh_test",
+        }
+    )
+    instance.mongodb_client = instance._build_mongodb_client(mongodb_connection_string)
+    instance.mongodb_database = instance.mongodb_client["agentex_oidc_refresh_test"]
+    yield instance
+    Singleton._instances.pop(GlobalDependencies, None)
+
+
+@pytest.mark.asyncio
+@pytest.mark.integration
+async def test_refresh_preserves_data_and_drains_old_client(deps, monkeypatch):
+    # Treat the container URI as OIDC so the refresh path actually runs, and
+    # collapse the drain delay so the close completes within the test.
+    monkeypatch.setattr(deps, "_mongodb_uses_oidc", lambda: True)
+    original_close_after_delay = deps._close_mongodb_client_after_delay
+
+    async def fast_close(client, delay=0.0):
+        await original_close_after_delay(client, delay=0.0)
+
+    monkeypatch.setattr(deps, "_close_mongodb_client_after_delay", fast_close)
+
+    collection = "docs"
+    await deps.mongodb_database[collection].insert_one({"_id": "before", "n": 1})
+
+    old_client = deps.mongodb_client
+    old_client.close = AsyncMock(wraps=old_client.close)
+
+    await deps.refresh_mongodb_client()
+
+    # A genuinely new client is now installed.
+    assert deps.mongodb_client is not old_client
+
+    # The new client can write, and reads the doc written before the swap.
+    await deps.mongodb_database[collection].insert_one({"_id": "after", "n": 2})
+    ids = {
+        doc["_id"]
+        async for doc in deps.mongodb_database[collection].find({}, {"_id": 1})
+    }
+    assert ids == {"before", "after"}
+
+    # The superseded client is drained and closed.
+    for task in list(deps._mongodb_close_tasks):
+        await task
+    old_client.close.assert_awaited_once()
+
+    await deps.mongodb_client.drop_database("agentex_oidc_refresh_test")
+    await deps.mongodb_client.close()
diff --git a/agentex/tests/unit/config/test_mongodb_oidc_refresh.py b/agentex/tests/unit/config/test_mongodb_oidc_refresh.py
new file mode 100644
index 00000000..148469b4
--- /dev/null
+++ b/agentex/tests/unit/config/test_mongodb_oidc_refresh.py
@@ -0,0 +1,186 @@
+"""Unit tests for the MongoDB OIDC client-refresh path in GlobalDependencies.
+
+pymongo's built-in GCP OIDC provider caches the access token for the life of the
+client and never refreshes it proactively, so a long-lived client fails auth once
+the ~1h GCP token expires. `refresh_mongodb_client()` rebuilds the client to renew
+the token without bouncing the process; these tests cover the gating, the
+build-validate-swap-drain ordering, and the loop lifecycle — all without a real
+MongoDB (the client is mocked).
+"""
+
+import asyncio
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+from src.config.dependencies import GlobalDependencies, Singleton
+from src.config.environment_variables import EnvironmentVariables
+
+OIDC_URI = (
+    "mongodb://host/?authMechanism=MONGODB-OIDC"
+    "&authMechanismProperties=ENVIRONMENT:gcp,TOKEN_RESOURCE:FIRESTORE"
+)
+PLAIN_URI = "mongodb://user:pass@host:27017/?authSource=admin"
+
+
+@pytest.fixture
+def deps():
+    """A fresh GlobalDependencies, isolated from the process-wide singleton."""
+    Singleton._instances.pop(GlobalDependencies, None)
+    instance = GlobalDependencies()
+    yield instance
+    Singleton._instances.pop(GlobalDependencies, None)
+
+
+def _mock_client() -> MagicMock:
+    client = MagicMock()
+    client.admin.command = AsyncMock(return_value={"ok": 1})
+    client.close = AsyncMock()
+    return client
+
+
+def _set_uri(deps: GlobalDependencies, uri: str | None) -> None:
+    deps.environment_variables = deps.environment_variables.model_copy(
+        update={"MONGODB_URI": uri, "MONGODB_DATABASE_NAME": "agentex"}
+    )
+
+
+@pytest.mark.unit
+def test_env_refresh_interval_parses_and_defaults(monkeypatch):
+    monkeypatch.setenv("MONGODB_OIDC_REFRESH_INTERVAL_SECONDS", "900")
+    assert (
+        EnvironmentVariables.refresh(
+            force_refresh=True
+        ).MONGODB_OIDC_REFRESH_INTERVAL_SECONDS
+        == 900
+    )
+
+    monkeypatch.delenv("MONGODB_OIDC_REFRESH_INTERVAL_SECONDS", raising=False)
+    assert (
+        EnvironmentVariables.refresh(
+            force_refresh=True
+        ).MONGODB_OIDC_REFRESH_INTERVAL_SECONDS
+        == 2700
+    )
+
+
+@pytest.mark.unit
+def test_uses_oidc_gate(deps):
+    _set_uri(deps, OIDC_URI)
+    assert deps._mongodb_uses_oidc() is True
+
+    _set_uri(deps, PLAIN_URI)
+    assert deps._mongodb_uses_oidc() is False
+
+    _set_uri(deps, None)
+    assert deps._mongodb_uses_oidc() is False
+
+
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_refresh_swaps_client_and_drains_old(deps, monkeypatch):
+    _set_uri(deps, OIDC_URI)
+    old_client = _mock_client()
+    new_client = _mock_client()
+    deps.mongodb_client = old_client
+    monkeypatch.setattr(deps, "_build_mongodb_client", lambda uri: new_client)
+
+    await deps.refresh_mongodb_client()
+
+    # New client validated before the swap, then installed.
+    new_client.admin.command.assert_awaited_once_with("ping")
+    assert deps.mongodb_client is new_client
+    assert deps.mongodb_database is new_client["agentex"]
+
+    # Old client is scheduled for a drained close, not closed immediately.
+    old_client.close.assert_not_awaited()
+    assert len(deps._mongodb_close_tasks) == 1
+    for task in list(deps._mongodb_close_tasks):
+        task.cancel()
+
+
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_refresh_noop_for_non_oidc(deps, monkeypatch):
+    _set_uri(deps, PLAIN_URI)
+    old_client = _mock_client()
+    deps.mongodb_client = old_client
+    build = MagicMock()
+    monkeypatch.setattr(deps, "_build_mongodb_client", build)
+
+    await deps.refresh_mongodb_client()
+
+    build.assert_not_called()
+    assert deps.mongodb_client is old_client
+
+
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_refresh_keeps_old_client_when_new_fails_validation(deps, monkeypatch):
+    _set_uri(deps, OIDC_URI)
+    old_client = _mock_client()
+    deps.mongodb_client = old_client
+
+    broken = _mock_client()
+    broken.admin.command = AsyncMock(side_effect=RuntimeError("auth failed"))
+    monkeypatch.setattr(deps, "_build_mongodb_client", lambda uri: broken)
+
+    with pytest.raises(RuntimeError):
+        await deps.refresh_mongodb_client()
+
+    # Never swapped to the broken client; never tore down the working one.
+    assert deps.mongodb_client is old_client
+    old_client.close.assert_not_awaited()
+
+    # The rejected client is closed rather than leaked.
+    broken.close.assert_awaited_once()
+
+
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_stop_closes_superseded_clients_immediately(deps, monkeypatch):
+    _set_uri(deps, OIDC_URI)
+    old_client = _mock_client()
+    deps.mongodb_client = old_client
+    monkeypatch.setattr(deps, "_build_mongodb_client", lambda uri: _mock_client())
+
+    await deps.refresh_mongodb_client()
+    # Let the drain task start sleeping so cancellation reaches its handler.
+    await asyncio.sleep(0)
+
+    await deps._stop_mongodb_oidc_refresh_loop()
+
+    old_client.close.assert_awaited_once()
+    assert not deps._mongodb_close_tasks
+
+
+@pytest.mark.asyncio
+@pytest.mark.unit
+async def test_loop_start_gating(deps):
+    deps.mongodb_client = _mock_client()
+
+    # Disabled by interval.
+    _set_uri(deps, OIDC_URI)
+    deps.environment_variables = deps.environment_variables.model_copy(
+        update={"MONGODB_OIDC_REFRESH_INTERVAL_SECONDS": 0}
+    )
+    deps._start_mongodb_oidc_refresh_loop()
+    assert deps._mongodb_refresh_task is None
+
+    # Disabled by non-OIDC URI.
+    _set_uri(deps, PLAIN_URI)
+    deps.environment_variables = deps.environment_variables.model_copy(
+        update={"MONGODB_OIDC_REFRESH_INTERVAL_SECONDS": 2700}
+    )
+    deps._start_mongodb_oidc_refresh_loop()
+    assert deps._mongodb_refresh_task is None
+
+    # Enabled: OIDC + positive interval.
+    _set_uri(deps, OIDC_URI)
+    deps.environment_variables = deps.environment_variables.model_copy(
+        update={"MONGODB_OIDC_REFRESH_INTERVAL_SECONDS": 2700}
+    )
+    deps._start_mongodb_oidc_refresh_loop()
+    assert deps._mongodb_refresh_task is not None
+
+    await deps._stop_mongodb_oidc_refresh_loop()
+    assert deps._mongodb_refresh_task is None