diff --git a/.sampo/changesets/roguish-stormcaller-tuonetar.md b/.sampo/changesets/roguish-stormcaller-tuonetar.md new file mode 100644 index 00000000..c6213223 --- /dev/null +++ b/.sampo/changesets/roguish-stormcaller-tuonetar.md @@ -0,0 +1,5 @@ +--- +pypi/posthog: minor +--- + +Add async flag definition cache providers diff --git a/examples/async_redis_flag_cache.py b/examples/async_redis_flag_cache.py new file mode 100644 index 00000000..957bc9d5 --- /dev/null +++ b/examples/async_redis_flag_cache.py @@ -0,0 +1,149 @@ +""" +Async Redis-based distributed cache for PostHog feature flag definitions. + +This example demonstrates how to implement a FlagDefinitionCacheProvider with +redis.asyncio for async-first applications. The PostHog SDK accepts async cache +provider methods and runs their awaitables to completion before continuing. + +Usage: + import redis.asyncio as redis + from posthog import Posthog + + # Use a Redis client dedicated to this cache provider. The SDK runs async + # provider methods on its own background event loop. + redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True) + cache = AsyncRedisFlagCache(redis_client, service_key="my-service") + + posthog = Posthog( + "", + personal_api_key="", + flag_definition_cache_provider=cache, + ) + +Requirements: + pip install redis +""" + +import json +import uuid +from typing import Optional + +from posthog import FlagDefinitionCacheData, FlagDefinitionCacheProvider +from redis.asyncio import Redis + + +class AsyncRedisFlagCache(FlagDefinitionCacheProvider): + """ + A distributed cache for PostHog feature flag definitions using redis.asyncio. + + In a multi-instance deployment, only one instance should poll PostHog for + flag updates while all instances share the cached results. This prevents N + instances from making N redundant API calls. + + The implementation uses leader election: + - One instance "wins" and becomes responsible for fetching + - Other instances read from the shared cache + - If the leader dies, the lock expires and another instance takes over + + Uses Lua scripts for atomic operations, following Redis distributed lock best + practices: https://redis.io/docs/latest/develop/clients/patterns/distributed-locks/ + """ + + LOCK_TTL_MS = 60 * 1000 # 60 seconds, should be longer than the flags poll interval + CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours + + # Lua script: acquire lock if free, or extend if we own it + _LUA_TRY_LEAD = """ + local current = redis.call('GET', KEYS[1]) + if current == false then + redis.call('SET', KEYS[1], ARGV[1], 'PX', ARGV[2]) + return 1 + elseif current == ARGV[1] then + redis.call('PEXPIRE', KEYS[1], ARGV[2]) + return 1 + end + return 0 + """ + + # Lua script: release lock only if we own it + _LUA_STOP_LEAD = """ + if redis.call('GET', KEYS[1]) == ARGV[1] then + return redis.call('DEL', KEYS[1]) + end + return 0 + """ + + def __init__(self, redis: Redis, service_key: str): + """ + Initialize the async Redis flag cache. + + Args: + redis: A redis.asyncio client instance dedicated to this cache provider. + The SDK runs async provider methods on its own background event + loop, so avoid sharing the same asyncio Redis client with a + different application event loop. Configure decode_responses=True + for string responses, or bytes responses will be decoded here. + service_key: A unique identifier for this service/environment. + Used to scope Redis keys, allowing multiple services + or environments to share the same Redis instance. + """ + self._redis = redis + self._cache_key = f"posthog:flags:{service_key}" + self._lock_key = f"posthog:flags:{service_key}:lock" + self._instance_id = str(uuid.uuid4()) + + async def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]: + """ + Retrieve cached flag definitions from Redis. + + Returns: + Cached flag definitions if available, None otherwise. + """ + cached = await self._redis.get(self._cache_key) + if not cached: + return None + if isinstance(cached, bytes): + cached = cached.decode("utf-8") + return json.loads(cached) + + async def should_fetch_flag_definitions(self) -> bool: + """ + Determines if this instance should fetch flag definitions from PostHog. + + Atomically either: + - Acquires the lock if no one holds it, OR + - Extends the lock TTL if we already hold it + + Returns: + True if this instance is the leader and should fetch, False otherwise. + """ + result = await self._redis.eval( + self._LUA_TRY_LEAD, + 1, + self._lock_key, + self._instance_id, + self.LOCK_TTL_MS, + ) + return result == 1 + + async def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None: + """ + Store fetched flag definitions in Redis. + + Args: + data: The flag definitions to cache. + """ + await self._redis.set( + self._cache_key, json.dumps(data), ex=self.CACHE_TTL_SECONDS + ) + + async def shutdown(self) -> None: + """ + Release leadership if we hold it. Safe to call even if not the leader. + """ + await self._redis.eval( + self._LUA_STOP_LEAD, + 1, + self._lock_key, + self._instance_id, + ) diff --git a/posthog/_async_utils.py b/posthog/_async_utils.py new file mode 100644 index 00000000..8bc076a1 --- /dev/null +++ b/posthog/_async_utils.py @@ -0,0 +1,85 @@ +import asyncio +import threading +from collections.abc import Awaitable +from typing import Any + + +class _BackgroundEventLoopRunner: + """Run awaitables to completion on a reusable background event loop.""" + + def __init__(self) -> None: + self._loop: asyncio.AbstractEventLoop | None = None + self._thread: threading.Thread | None = None + self._started = threading.Event() + self._lock = threading.Lock() + + def run(self, awaitable: Awaitable[Any]) -> Any: + loop = self._ensure_loop() + future = asyncio.run_coroutine_threadsafe(self._await_result(awaitable), loop) + return future.result() + + def close(self) -> None: + with self._lock: + loop = self._loop + thread = self._thread + self._loop = None + self._thread = None + + if loop is None or thread is None or loop.is_closed(): + return + + if thread is threading.current_thread(): + loop.call_soon(loop.stop) + return + + loop.call_soon_threadsafe(loop.stop) + thread.join() + + @staticmethod + async def _await_result(awaitable: Awaitable[Any]) -> Any: + return await awaitable + + def _ensure_loop(self) -> asyncio.AbstractEventLoop: + with self._lock: + if ( + self._loop is not None + and self._thread is not None + and self._thread.is_alive() + and not self._loop.is_closed() + ): + return self._loop + + self._started.clear() + self._thread = threading.Thread( + target=self._run_loop, + name="PostHogBackgroundEventLoopRunner", + daemon=True, + ) + self._thread.start() + + self._started.wait() + with self._lock: + assert self._loop is not None + return self._loop + + def _run_loop(self) -> None: + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + with self._lock: + self._loop = loop + self._started.set() + + try: + loop.run_forever() + finally: + pending = asyncio.all_tasks(loop) + for task in pending: + task.cancel() + if pending: + loop.run_until_complete( + asyncio.gather(*pending, return_exceptions=True) + ) + loop.run_until_complete(loop.shutdown_asyncgens()) + loop.run_until_complete(loop.shutdown_default_executor()) + asyncio.set_event_loop(None) + loop.close() diff --git a/posthog/client.py b/posthog/client.py index dc1c6f8c..42a826ac 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -1,8 +1,10 @@ import atexit +import inspect import json import logging import os import sys +import threading import warnings import weakref from datetime import datetime, timedelta, timezone @@ -11,6 +13,7 @@ from typing_extensions import Unpack +from posthog._async_utils import _BackgroundEventLoopRunner from posthog.args import ID_TYPES, ExceptionArg, OptionalCaptureArgs, OptionalSetArgs from posthog.consumer import Consumer from posthog.contexts import ( @@ -305,6 +308,10 @@ def __init__( self.flag_definition_version = 0 self._flags_etag: Optional[str] = None self._flag_definition_cache_provider = flag_definition_cache_provider + self._flag_definition_cache_provider_async_runner: Optional[ + _BackgroundEventLoopRunner + ] = None + self._flag_definition_cache_provider_async_runner_lock = threading.Lock() self.disabled = disabled or not self.api_key self.disable_geoip = disable_geoip self.historical_migration = historical_migration @@ -1246,6 +1253,10 @@ def _reinit_after_fork(self): else: self.poller = None + # Async runner threads do not survive fork(); recreate lazily on next async cache call. + self._flag_definition_cache_provider_async_runner = None + self._flag_definition_cache_provider_async_runner_lock = threading.Lock() + # If using Redis cache, we must reinitialize to get a fresh connection (fork-safe). # If using Memory cache, we keep it as-is to benefit from the inherited warm cache. if isinstance(self.flag_cache, RedisFlagCache): @@ -1373,11 +1384,7 @@ def join(self): self.poller.stop() # Shutdown the cache provider (release locks, cleanup) - if self._flag_definition_cache_provider: - try: - self._flag_definition_cache_provider.shutdown() - except Exception as e: - self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}") + self._shutdown_flag_definition_cache_provider() def shutdown(self): """ @@ -1394,6 +1401,33 @@ def shutdown(self): if self.exception_capture: self.exception_capture.close() + def _resolve_flag_definition_cache_provider_result(self, result): + if not inspect.isawaitable(result): + return result + + with self._flag_definition_cache_provider_async_runner_lock: + if self._flag_definition_cache_provider_async_runner is None: + self._flag_definition_cache_provider_async_runner = ( + _BackgroundEventLoopRunner() + ) + return self._flag_definition_cache_provider_async_runner.run(result) + + def _shutdown_flag_definition_cache_provider(self): + if not self._flag_definition_cache_provider: + return + + try: + self._resolve_flag_definition_cache_provider_result( + self._flag_definition_cache_provider.shutdown() + ) + except Exception as e: + self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}") + finally: + with self._flag_definition_cache_provider_async_runner_lock: + if self._flag_definition_cache_provider_async_runner: + self._flag_definition_cache_provider_async_runner.close() + self._flag_definition_cache_provider_async_runner = None + def _update_flag_state( self, data: FlagDefinitionCacheData, old_flags_by_key: Optional[dict] = None ) -> None: @@ -1416,7 +1450,7 @@ def _load_feature_flags(self): should_fetch = True if self._flag_definition_cache_provider: try: - should_fetch = ( + should_fetch = self._resolve_flag_definition_cache_provider_result( self._flag_definition_cache_provider.should_fetch_flag_definitions() ) except Exception as e: @@ -1429,7 +1463,7 @@ def _load_feature_flags(self): # If not fetching, try to get from cache if not should_fetch and self._flag_definition_cache_provider: try: - cached_data = ( + cached_data = self._resolve_flag_definition_cache_provider_result( self._flag_definition_cache_provider.get_flag_definitions() ) if cached_data: @@ -1500,12 +1534,14 @@ def _fetch_feature_flags_from_api(self): # Store in external cache if provider is configured if self._flag_definition_cache_provider: try: - self._flag_definition_cache_provider.on_flag_definitions_received( - { - "flags": self.feature_flags or [], - "group_type_mapping": self.group_type_mapping or {}, - "cohorts": self.cohorts or {}, - } + self._resolve_flag_definition_cache_provider_result( + self._flag_definition_cache_provider.on_flag_definitions_received( + { + "flags": self.feature_flags or [], + "group_type_mapping": self.group_type_mapping or {}, + "cohorts": self.cohorts or {}, + } + ) ) except Exception as e: self.log.error(f"[FEATURE FLAGS] Cache provider store error: {e}") diff --git a/posthog/flag_definition_cache.py b/posthog/flag_definition_cache.py index 330bbd45..0a55680a 100644 --- a/posthog/flag_definition_cache.py +++ b/posthog/flag_definition_cache.py @@ -20,7 +20,16 @@ ) """ -from typing import Any, Dict, List, Optional, Protocol, runtime_checkable +from typing import ( + Any, + Awaitable, + Dict, + List, + Optional, + Protocol, + Union, + runtime_checkable, +) from typing_extensions import Required, TypedDict @@ -50,6 +59,10 @@ class FlagDefinitionCacheProvider(Protocol): EXPERIMENTAL: This API may change in future minor version bumps. + Methods may be implemented as either synchronous functions or async + functions. If a method returns an awaitable, the SDK runs it to completion + before continuing. + The four methods handle the complete lifecycle of flag definition caching: 1. `should_fetch_flag_definitions()` - Called before each poll to determine @@ -75,18 +88,23 @@ class FlagDefinitionCacheProvider(Protocol): - `shutdown()` errors are logged but shutdown continues """ - def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]: + def get_flag_definitions( + self, + ) -> Union[ + Optional[FlagDefinitionCacheData], Awaitable[Optional[FlagDefinitionCacheData]] + ]: """ Retrieve cached flag definitions. Returns: Cached flag definitions if available and valid, None otherwise. - Returning None will trigger a fetch from the API if this worker - has no flags loaded yet. + May return an awaitable resolving to the same value. Returning None + will trigger a fetch from the API if this worker has no flags loaded + yet. """ ... - def should_fetch_flag_definitions(self) -> bool: + def should_fetch_flag_definitions(self) -> Union[bool, Awaitable[bool]]: """ Determine whether this instance should fetch new flag definitions. @@ -97,12 +115,15 @@ def should_fetch_flag_definitions(self) -> bool: Returns: True if this instance should fetch from the API, False otherwise. - When False, the client will call `get_flag_definitions()` to - retrieve cached data instead. + May return an awaitable resolving to the same value. When False, the + client will call `get_flag_definitions()` to retrieve cached data + instead. """ ... - def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None: + def on_flag_definitions_received( + self, data: FlagDefinitionCacheData + ) -> Optional[Awaitable[None]]: """ Called after successfully receiving new flag definitions from PostHog. @@ -115,7 +136,7 @@ def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None: """ ... - def shutdown(self) -> None: + def shutdown(self) -> Optional[Awaitable[None]]: """ Called when the PostHog client shuts down. diff --git a/posthog/test/test_flag_definition_cache.py b/posthog/test/test_flag_definition_cache.py index ee17eb9a..829f8454 100644 --- a/posthog/test/test_flag_definition_cache.py +++ b/posthog/test/test_flag_definition_cache.py @@ -4,6 +4,7 @@ These tests follow the patterns from the TypeScript implementation in posthog-js/packages/node. """ +import asyncio import threading import unittest from typing import Optional @@ -57,6 +58,36 @@ def shutdown(self) -> None: raise self.shutdown_error +class AsyncMockCacheProvider(MockCacheProvider): + """An async implementation of FlagDefinitionCacheProvider for testing.""" + + def __init__(self): + super().__init__() + self.loop_ids = [] + self.thread_ids = [] + + async def _record_async_call(self): + self.loop_ids.append(id(asyncio.get_running_loop())) + self.thread_ids.append(threading.get_ident()) + await asyncio.sleep(0) + + async def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]: + await self._record_async_call() + return super().get_flag_definitions() + + async def should_fetch_flag_definitions(self) -> bool: + await self._record_async_call() + return super().should_fetch_flag_definitions() + + async def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None: + await self._record_async_call() + super().on_flag_definitions_received(data) + + async def shutdown(self) -> None: + await self._record_async_call() + super().shutdown() + + class TestFlagDefinitionCacheProvider(unittest.TestCase): """Tests for the FlagDefinitionCacheProvider protocol.""" @@ -369,6 +400,155 @@ def test_shutdown_error_is_logged_but_continues(self, mock_get): self.assertEqual(self.cache_provider.shutdown_call_count, 1) +class TestAsyncCacheProvider(TestFlagDefinitionCacheProvider): + """Tests for async FlagDefinitionCacheProvider implementations.""" + + def setUp(self): + super().setUp() + self.cache_provider = AsyncMockCacheProvider() + + @mock.patch("posthog.client.get") + def test_awaits_async_provider_for_cached_data(self, mock_get): + """Async providers can serve cached data without an API fetch.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.stored_data = self.sample_flags_data + + client = self._create_client_with_cache() + client._load_feature_flags() + + mock_get.assert_not_called() + self.assertEqual(self.cache_provider.should_fetch_call_count, 1) + self.assertEqual(self.cache_provider.get_call_count, 1) + self.assertEqual(len(client.feature_flags), 2) + self.assertEqual(client.feature_flags[0]["key"], "test-flag") + self.assertEqual(len(set(self.cache_provider.loop_ids)), 1) + + client.join() + + @mock.patch("posthog.client.get") + def test_awaits_async_provider_when_fetching_from_api(self, mock_get): + """Async should_fetch and on_received methods are awaited.""" + self.cache_provider.should_fetch_return_value = True + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + mock_get.assert_called_once() + self.assertEqual(self.cache_provider.should_fetch_call_count, 1) + self.assertEqual(self.cache_provider.get_call_count, 0) + self.assertEqual(self.cache_provider.on_received_call_count, 1) + self.assertEqual(self.cache_provider.stored_data, self.sample_flags_data) + self.assertEqual(len(set(self.cache_provider.loop_ids)), 1) + + client.join() + + @mock.patch("posthog.client.get") + def test_reuses_async_provider_loop_across_polls_and_shutdown(self, mock_get): + """All async cache provider methods run on one reusable event loop.""" + self.cache_provider.should_fetch_return_value = True + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + client._load_feature_flags() + client.join() + + self.assertEqual(self.cache_provider.should_fetch_call_count, 2) + self.assertEqual(self.cache_provider.on_received_call_count, 2) + self.assertEqual(self.cache_provider.shutdown_call_count, 1) + self.assertEqual(len(set(self.cache_provider.loop_ids)), 1) + self.assertEqual(len(set(self.cache_provider.thread_ids)), 1) + + @mock.patch("posthog.client.get") + def test_async_provider_can_load_while_caller_loop_is_running(self, mock_get): + """Async providers work even when called from an async app context.""" + self.cache_provider.should_fetch_return_value = True + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + + async def load_flags(): + client._load_feature_flags() + + asyncio.run(load_flags()) + + self.assertEqual(len(client.feature_flags), 2) + self.assertEqual(self.cache_provider.on_received_call_count, 1) + + client.join() + + @mock.patch("posthog.client.get") + def test_async_should_fetch_error_defaults_to_fetching(self, mock_get): + """Async should_fetch errors fail open to an API fetch.""" + self.cache_provider.should_fetch_error = Exception("Lock acquisition failed") + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + mock_get.assert_called_once() + self.assertEqual(len(client.feature_flags), 2) + + client.join() + + @mock.patch("posthog.client.get") + def test_async_get_error_falls_back_to_api_fetch(self, mock_get): + """Async get_flag_definitions errors fail open to an API fetch.""" + self.cache_provider.should_fetch_return_value = False + self.cache_provider.get_error = Exception("Cache read failed") + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + mock_get.assert_called_once() + self.assertEqual(len(client.feature_flags), 2) + + client.join() + + @mock.patch("posthog.client.get") + def test_async_on_received_error_keeps_flags_in_memory(self, mock_get): + """Async store errors do not discard freshly fetched flag definitions.""" + self.cache_provider.should_fetch_return_value = True + self.cache_provider.on_received_error = Exception("Cache write failed") + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + + self.assertEqual(len(client.feature_flags), 2) + self.assertEqual(client.feature_flags[0]["key"], "test-flag") + + client.join() + + @mock.patch("posthog.client.get") + def test_async_shutdown_error_is_logged_but_continues(self, mock_get): + """Async shutdown errors are logged and do not escape join().""" + self.cache_provider.shutdown_error = Exception("Async lock release failed") + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + client._load_feature_flags() + client.join() + + self.assertEqual(self.cache_provider.shutdown_call_count, 1) + + class TestShutdownLifecycle(TestFlagDefinitionCacheProvider): """Tests for shutdown lifecycle.""" @@ -588,6 +768,38 @@ def load_flags(): client.join() + @mock.patch("posthog.client.get") + def test_concurrent_async_load_feature_flags_uses_single_runner(self, mock_get): + """Concurrent async provider calls share one runner loop and shut down cleanly.""" + self.cache_provider = AsyncMockCacheProvider() + mock_get.return_value = GetResponse( + data=self.sample_flags_data, etag="test-etag", not_modified=False + ) + + client = self._create_client_with_cache() + errors = [] + barrier = threading.Barrier(10) + + def load_flags(): + try: + barrier.wait() + client._load_feature_flags() + except Exception as e: + errors.append(e) + + threads = [threading.Thread(target=load_flags) for _ in range(10)] + for t in threads: + t.start() + for t in threads: + t.join() + + self.assertEqual(len(errors), 0, f"Unexpected errors: {errors}") + self.assertEqual(len(client.feature_flags), 2) + self.assertEqual(len(set(self.cache_provider.loop_ids)), 1) + self.assertEqual(len(set(self.cache_provider.thread_ids)), 1) + + client.join() + class TestProtocolCompliance(unittest.TestCase): """Tests for Protocol compliance.""" @@ -597,6 +809,11 @@ def test_mock_provider_is_protocol_instance(self): provider = MockCacheProvider() self.assertIsInstance(provider, FlagDefinitionCacheProvider) + def test_async_mock_provider_is_protocol_instance(self): + """AsyncMockCacheProvider satisfies FlagDefinitionCacheProvider protocol.""" + provider = AsyncMockCacheProvider() + self.assertIsInstance(provider, FlagDefinitionCacheProvider) + def test_incomplete_provider_is_not_protocol_instance(self): """Class missing methods is not a FlagDefinitionCacheProvider."""