Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions sdk/src/opendecree/async_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ def __init__(
credentials: grpc.ChannelCredentials | None = None,
timeout: float = 10.0,
retry: RetryConfig | None = None,
check_version: bool = False,
) -> None:
"""Create a new AsyncConfigClient.

Expand All @@ -72,9 +73,14 @@ def __init__(
timeout: Default per-RPC timeout in seconds. Defaults to 10.
retry: Retry configuration. Defaults to ``RetryConfig()``.
Pass ``None`` to disable retry.
check_version: When True, run :meth:`check_compatibility` lazily
on the first RPC call. Raises :exc:`IncompatibleServerError`
if the server version is outside the supported range.
"""
self._timeout = timeout
self._retry = retry if retry is not None else RetryConfig()
self._check_version = check_version
self._version_checked = False

tls_active = credentials is not None or not insecure
if token and not tls_active:
Expand Down Expand Up @@ -151,6 +157,11 @@ async def check_compatibility(self) -> None:
sv = await self.get_server_version()
check_version_compatible(sv.version)

async def _ensure_version_checked(self) -> None:
if self._check_version and not self._version_checked:
self._version_checked = True
await self.check_compatibility()

def _metadata(self) -> list[tuple[str, str]]:
"""Return auth metadata for each call."""
return list(self._auth_metadata)
Expand Down Expand Up @@ -211,6 +222,7 @@ async def get(
TypeMismatchError: If the value cannot be converted to the requested type.
"""
target_type = value_type or str
await self._ensure_version_checked()

async def _call() -> object:
resp = await self._stub.GetField(
Expand Down Expand Up @@ -238,6 +250,8 @@ async def get_all(self, tenant_id: str) -> dict[str, str]:
NotFoundError: If the tenant does not exist.
"""

await self._ensure_version_checked()

async def _call() -> dict[str, str]:
resp = await self._stub.GetConfig(
self._pb2.GetConfigRequest(tenant_id=tenant_id),
Expand Down Expand Up @@ -289,6 +303,7 @@ async def set(
ChecksumMismatchError: If ``expected_checksum`` is set and does not match.
"""
retry_cfg = self._retry if idempotency_key is not None else write_safe_config(self._retry)
await self._ensure_version_checked()

async def _call() -> None:
await self._stub.SetField(
Expand Down Expand Up @@ -335,6 +350,7 @@ async def set_many(
ChecksumMismatchError: If any ``expected_checksum`` does not match.
"""
retry_cfg = self._retry if idempotency_key is not None else write_safe_config(self._retry)
await self._ensure_version_checked()

async def _call() -> None:
proto_updates = [
Expand Down Expand Up @@ -382,6 +398,7 @@ async def set_null(
LockedError: If the field is locked.
"""
retry_cfg = self._retry if idempotency_key is not None else write_safe_config(self._retry)
await self._ensure_version_checked()

async def _call() -> None:
await self._stub.SetField(
Expand Down
17 changes: 17 additions & 0 deletions sdk/src/opendecree/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ def __init__(
credentials: grpc.ChannelCredentials | None = None,
timeout: float = 10.0,
retry: RetryConfig | None = None,
check_version: bool = False,
) -> None:
"""Create a new ConfigClient.

Expand All @@ -75,9 +76,14 @@ def __init__(
timeout: Default per-RPC timeout in seconds. Defaults to 10.
retry: Retry configuration. Defaults to ``RetryConfig()``.
Pass ``None`` to disable retry.
check_version: When True, run :meth:`check_compatibility` lazily
on the first RPC call. Raises :exc:`IncompatibleServerError`
if the server version is outside the supported range.
"""
self._timeout = timeout
self._retry = retry if retry is not None else RetryConfig()
self._check_version = check_version
self._version_checked = False

tls_active = credentials is not None or not insecure
if token and not tls_active:
Expand Down Expand Up @@ -170,6 +176,11 @@ def check_compatibility(self) -> None:
"""
check_version_compatible(self.get_server_version().version)

def _ensure_version_checked(self) -> None:
if self._check_version and not self._version_checked:
self._version_checked = True
self.check_compatibility()

# --- get() with @overload for type safety ---

@overload
Expand Down Expand Up @@ -224,6 +235,7 @@ def get(
TypeMismatchError: If the value cannot be converted to the requested type.
"""
target_type = value_type or str
self._ensure_version_checked()

def _call() -> object:
resp = self._stub.GetField(
Expand All @@ -250,6 +262,8 @@ def get_all(self, tenant_id: str) -> dict[str, str]:
NotFoundError: If the tenant does not exist.
"""

self._ensure_version_checked()

def _call() -> dict[str, str]:
resp = self._stub.GetConfig(
self._pb2.GetConfigRequest(tenant_id=tenant_id),
Expand Down Expand Up @@ -300,6 +314,7 @@ def set(
ChecksumMismatchError: If ``expected_checksum`` is set and does not match.
"""
retry_cfg = self._retry if idempotency_key is not None else write_safe_config(self._retry)
self._ensure_version_checked()

def _call() -> None:
self._stub.SetField(
Expand Down Expand Up @@ -345,6 +360,7 @@ def set_many(
ChecksumMismatchError: If any ``expected_checksum`` does not match.
"""
retry_cfg = self._retry if idempotency_key is not None else write_safe_config(self._retry)
self._ensure_version_checked()

def _call() -> None:
proto_updates = [
Expand Down Expand Up @@ -391,6 +407,7 @@ def set_null(
LockedError: If the field is locked.
"""
retry_cfg = self._retry if idempotency_key is not None else write_safe_config(self._retry)
self._ensure_version_checked()

def _call() -> None:
self._stub.SetField(
Expand Down
82 changes: 82 additions & 0 deletions sdk/tests/test_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,3 +212,85 @@ def test_client_check_compatibility_fails():

with pytest.raises(IncompatibleServerError):
client.check_compatibility()


# --- check_version ctor flag ---


def _make_client_with_version(server_version: str, check_version: bool = True):
"""Return a ConfigClient.__new__ instance wired with a mock server version."""
from opendecree import ConfigClient

client = ConfigClient.__new__(ConfigClient)
client._timeout = 5.0
client._check_version = check_version
client._version_checked = False
client._server_version = ServerVersion(version=server_version, commit="abc")
client._version_stub = MagicMock()
client._version_pb2 = MagicMock()
return client


def test_ensure_version_checked_runs_once():
client = _make_client_with_version("0.3.1")
with patch.object(client, "check_compatibility") as mock_check:
client._ensure_version_checked()
client._ensure_version_checked()
mock_check.assert_called_once()


def test_ensure_version_checked_noop_when_disabled():
client = _make_client_with_version("0.3.1", check_version=False)
with patch.object(client, "check_compatibility") as mock_check:
client._ensure_version_checked()
mock_check.assert_not_called()


def test_ensure_version_checked_raises_on_incompatible():
client = _make_client_with_version("0.1.0")
with pytest.raises(IncompatibleServerError):
client._ensure_version_checked()


@pytest.mark.asyncio
async def test_async_ensure_version_checked_runs_once():
from opendecree import AsyncConfigClient

client = AsyncConfigClient.__new__(AsyncConfigClient)
client._timeout = 5.0
client._check_version = True
client._version_checked = False
client._server_version = ServerVersion(version="0.3.1", commit="abc")
client._version_stub = MagicMock()
client._version_pb2 = MagicMock()

call_count = 0

async def fake_check():
nonlocal call_count
call_count += 1

client.check_compatibility = fake_check
await client._ensure_version_checked()
await client._ensure_version_checked()
assert call_count == 1


@pytest.mark.asyncio
async def test_async_ensure_version_checked_noop_when_disabled():
from opendecree import AsyncConfigClient

client = AsyncConfigClient.__new__(AsyncConfigClient)
client._timeout = 5.0
client._check_version = False
client._version_checked = False

called = False

async def fake_check():
nonlocal called
called = True

client.check_compatibility = fake_check
await client._ensure_version_checked()
assert not called