Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions .sampo/changesets/roguish-stormcaller-tuonetar.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
pypi/posthog: minor
---

Add async flag definition cache providers
149 changes: 149 additions & 0 deletions examples/async_redis_flag_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
"""
Async Redis-based distributed cache for PostHog feature flag definitions.

This example demonstrates how to implement a FlagDefinitionCacheProvider with
redis.asyncio for async-first applications. The PostHog SDK accepts async cache
provider methods and runs their awaitables to completion before continuing.

Usage:
import redis.asyncio as redis
from posthog import Posthog

# Use a Redis client dedicated to this cache provider. The SDK runs async
# provider methods on its own background event loop.
redis_client = redis.Redis(host="localhost", port=6379, decode_responses=True)
cache = AsyncRedisFlagCache(redis_client, service_key="my-service")

posthog = Posthog(
"<project_api_key>",
personal_api_key="<personal_api_key>",
flag_definition_cache_provider=cache,
)

Requirements:
pip install redis
"""

import json
import uuid
from typing import Optional

from posthog import FlagDefinitionCacheData, FlagDefinitionCacheProvider
from redis.asyncio import Redis


class AsyncRedisFlagCache(FlagDefinitionCacheProvider):
"""
A distributed cache for PostHog feature flag definitions using redis.asyncio.

In a multi-instance deployment, only one instance should poll PostHog for
flag updates while all instances share the cached results. This prevents N
instances from making N redundant API calls.

The implementation uses leader election:
- One instance "wins" and becomes responsible for fetching
- Other instances read from the shared cache
- If the leader dies, the lock expires and another instance takes over

Uses Lua scripts for atomic operations, following Redis distributed lock best
practices: https://redis.io/docs/latest/develop/clients/patterns/distributed-locks/
"""

LOCK_TTL_MS = 60 * 1000 # 60 seconds, should be longer than the flags poll interval
CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours

# Lua script: acquire lock if free, or extend if we own it
_LUA_TRY_LEAD = """
local current = redis.call('GET', KEYS[1])
if current == false then
redis.call('SET', KEYS[1], ARGV[1], 'PX', ARGV[2])
return 1
elseif current == ARGV[1] then
redis.call('PEXPIRE', KEYS[1], ARGV[2])
return 1
end
return 0
"""

# Lua script: release lock only if we own it
_LUA_STOP_LEAD = """
if redis.call('GET', KEYS[1]) == ARGV[1] then
return redis.call('DEL', KEYS[1])
end
return 0
"""

def __init__(self, redis: Redis, service_key: str):
"""
Initialize the async Redis flag cache.

Args:
redis: A redis.asyncio client instance dedicated to this cache provider.
The SDK runs async provider methods on its own background event
loop, so avoid sharing the same asyncio Redis client with a
different application event loop. Configure decode_responses=True
for string responses, or bytes responses will be decoded here.
service_key: A unique identifier for this service/environment.
Used to scope Redis keys, allowing multiple services
or environments to share the same Redis instance.
"""
self._redis = redis
self._cache_key = f"posthog:flags:{service_key}"
self._lock_key = f"posthog:flags:{service_key}:lock"
self._instance_id = str(uuid.uuid4())

async def get_flag_definitions(self) -> Optional[FlagDefinitionCacheData]:
"""
Retrieve cached flag definitions from Redis.

Returns:
Cached flag definitions if available, None otherwise.
"""
cached = await self._redis.get(self._cache_key)
if not cached:
return None
if isinstance(cached, bytes):
cached = cached.decode("utf-8")
return json.loads(cached)

async def should_fetch_flag_definitions(self) -> bool:
"""
Determines if this instance should fetch flag definitions from PostHog.

Atomically either:
- Acquires the lock if no one holds it, OR
- Extends the lock TTL if we already hold it

Returns:
True if this instance is the leader and should fetch, False otherwise.
"""
result = await self._redis.eval(
self._LUA_TRY_LEAD,
1,
self._lock_key,
self._instance_id,
self.LOCK_TTL_MS,
)
return result == 1

async def on_flag_definitions_received(self, data: FlagDefinitionCacheData) -> None:
"""
Store fetched flag definitions in Redis.

Args:
data: The flag definitions to cache.
"""
await self._redis.set(
self._cache_key, json.dumps(data), ex=self.CACHE_TTL_SECONDS
)

async def shutdown(self) -> None:
"""
Release leadership if we hold it. Safe to call even if not the leader.
"""
await self._redis.eval(
self._LUA_STOP_LEAD,
1,
self._lock_key,
self._instance_id,
)
85 changes: 85 additions & 0 deletions posthog/_async_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
import asyncio
import threading
from collections.abc import Awaitable
from typing import Any


class _BackgroundEventLoopRunner:
"""Run awaitables to completion on a reusable background event loop."""

def __init__(self) -> None:
self._loop: asyncio.AbstractEventLoop | None = None
self._thread: threading.Thread | None = None
self._started = threading.Event()
self._lock = threading.Lock()

def run(self, awaitable: Awaitable[Any]) -> Any:
loop = self._ensure_loop()
future = asyncio.run_coroutine_threadsafe(self._await_result(awaitable), loop)
return future.result()

def close(self) -> None:
with self._lock:
loop = self._loop
thread = self._thread
self._loop = None
self._thread = None

if loop is None or thread is None or loop.is_closed():
return

if thread is threading.current_thread():
loop.call_soon(loop.stop)
return

loop.call_soon_threadsafe(loop.stop)
thread.join()

@staticmethod
async def _await_result(awaitable: Awaitable[Any]) -> Any:
return await awaitable

def _ensure_loop(self) -> asyncio.AbstractEventLoop:
with self._lock:
if (
self._loop is not None
and self._thread is not None
and self._thread.is_alive()
and not self._loop.is_closed()
):
return self._loop

self._started.clear()
self._thread = threading.Thread(
target=self._run_loop,
name="PostHogBackgroundEventLoopRunner",
daemon=True,
)
self._thread.start()

self._started.wait()
with self._lock:
assert self._loop is not None
return self._loop

def _run_loop(self) -> None:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
with self._lock:
self._loop = loop
self._started.set()

try:
loop.run_forever()
finally:
pending = asyncio.all_tasks(loop)
for task in pending:
task.cancel()
if pending:
loop.run_until_complete(
asyncio.gather(*pending, return_exceptions=True)
)
loop.run_until_complete(loop.shutdown_asyncgens())
loop.run_until_complete(loop.shutdown_default_executor())
asyncio.set_event_loop(None)
loop.close()
62 changes: 49 additions & 13 deletions posthog/client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import atexit
import inspect
import json
import logging
import os
import sys
import threading
import warnings
import weakref
from datetime import datetime, timedelta, timezone
Expand All @@ -11,6 +13,7 @@

from typing_extensions import Unpack

from posthog._async_utils import _BackgroundEventLoopRunner
from posthog.args import ID_TYPES, ExceptionArg, OptionalCaptureArgs, OptionalSetArgs
from posthog.consumer import Consumer
from posthog.contexts import (
Expand Down Expand Up @@ -305,6 +308,10 @@ def __init__(
self.flag_definition_version = 0
self._flags_etag: Optional[str] = None
self._flag_definition_cache_provider = flag_definition_cache_provider
self._flag_definition_cache_provider_async_runner: Optional[
_BackgroundEventLoopRunner
] = None
self._flag_definition_cache_provider_async_runner_lock = threading.Lock()
self.disabled = disabled or not self.api_key
self.disable_geoip = disable_geoip
self.historical_migration = historical_migration
Expand Down Expand Up @@ -1246,6 +1253,10 @@ def _reinit_after_fork(self):
else:
self.poller = None

# Async runner threads do not survive fork(); recreate lazily on next async cache call.
self._flag_definition_cache_provider_async_runner = None
self._flag_definition_cache_provider_async_runner_lock = threading.Lock()

# If using Redis cache, we must reinitialize to get a fresh connection (fork-safe).
# If using Memory cache, we keep it as-is to benefit from the inherited warm cache.
if isinstance(self.flag_cache, RedisFlagCache):
Expand Down Expand Up @@ -1373,11 +1384,7 @@ def join(self):
self.poller.stop()

# Shutdown the cache provider (release locks, cleanup)
if self._flag_definition_cache_provider:
try:
self._flag_definition_cache_provider.shutdown()
except Exception as e:
self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}")
self._shutdown_flag_definition_cache_provider()

def shutdown(self):
"""
Expand All @@ -1394,6 +1401,33 @@ def shutdown(self):
if self.exception_capture:
self.exception_capture.close()

def _resolve_flag_definition_cache_provider_result(self, result):
if not inspect.isawaitable(result):
return result

with self._flag_definition_cache_provider_async_runner_lock:
if self._flag_definition_cache_provider_async_runner is None:
self._flag_definition_cache_provider_async_runner = (
_BackgroundEventLoopRunner()
)
return self._flag_definition_cache_provider_async_runner.run(result)
Comment thread
dustinbyrne marked this conversation as resolved.

def _shutdown_flag_definition_cache_provider(self):
if not self._flag_definition_cache_provider:
return

try:
self._resolve_flag_definition_cache_provider_result(
self._flag_definition_cache_provider.shutdown()
)
except Exception as e:
self.log.error(f"[FEATURE FLAGS] Cache provider shutdown error: {e}")
finally:
with self._flag_definition_cache_provider_async_runner_lock:
if self._flag_definition_cache_provider_async_runner:
self._flag_definition_cache_provider_async_runner.close()
self._flag_definition_cache_provider_async_runner = None

def _update_flag_state(
self, data: FlagDefinitionCacheData, old_flags_by_key: Optional[dict] = None
) -> None:
Expand All @@ -1416,7 +1450,7 @@ def _load_feature_flags(self):
should_fetch = True
if self._flag_definition_cache_provider:
try:
should_fetch = (
should_fetch = self._resolve_flag_definition_cache_provider_result(
self._flag_definition_cache_provider.should_fetch_flag_definitions()
)
except Exception as e:
Expand All @@ -1429,7 +1463,7 @@ def _load_feature_flags(self):
# If not fetching, try to get from cache
if not should_fetch and self._flag_definition_cache_provider:
try:
cached_data = (
cached_data = self._resolve_flag_definition_cache_provider_result(
self._flag_definition_cache_provider.get_flag_definitions()
)
if cached_data:
Expand Down Expand Up @@ -1500,12 +1534,14 @@ def _fetch_feature_flags_from_api(self):
# Store in external cache if provider is configured
if self._flag_definition_cache_provider:
try:
self._flag_definition_cache_provider.on_flag_definitions_received(
{
"flags": self.feature_flags or [],
"group_type_mapping": self.group_type_mapping or {},
"cohorts": self.cohorts or {},
}
self._resolve_flag_definition_cache_provider_result(
self._flag_definition_cache_provider.on_flag_definitions_received(
{
"flags": self.feature_flags or [],
"group_type_mapping": self.group_type_mapping or {},
"cohorts": self.cohorts or {},
}
)
)
except Exception as e:
self.log.error(f"[FEATURE FLAGS] Cache provider store error: {e}")
Expand Down
Loading
Loading