Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
84 changes: 72 additions & 12 deletions sdk/src/opendecree/_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,34 @@

import grpc

# Default channel options for keepalive and reconnection.
_DEFAULT_OPTIONS: list[tuple[str, int]] = [
("grpc.keepalive_time_ms", 30000),
("grpc.keepalive_timeout_ms", 10000),
("grpc.keepalive_permit_without_calls", 1),
("grpc.initial_reconnect_backoff_ms", 1000),
("grpc.max_reconnect_backoff_ms", 30000),
]
_DEFAULT_KEEPALIVE_TIME_MS = 30000
_DEFAULT_KEEPALIVE_TIMEOUT_MS = 10000
_DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS = 1
_DEFAULT_RECONNECT_BACKOFF_INITIAL_MS = 1000
_DEFAULT_RECONNECT_BACKOFF_MAX_MS = 30000


def _build_options(
max_send_message_length: int | None,
max_recv_message_length: int | None,
keepalive_time_ms: int,
keepalive_timeout_ms: int,
keepalive_permit_without_calls: int,
reconnect_backoff_initial_ms: int,
reconnect_backoff_max_ms: int,
) -> list[tuple[str, int]]:
opts: list[tuple[str, int]] = [
("grpc.keepalive_time_ms", keepalive_time_ms),
("grpc.keepalive_timeout_ms", keepalive_timeout_ms),
("grpc.keepalive_permit_without_calls", keepalive_permit_without_calls),
("grpc.initial_reconnect_backoff_ms", reconnect_backoff_initial_ms),
("grpc.max_reconnect_backoff_ms", reconnect_backoff_max_ms),
]
if max_send_message_length is not None:
opts.append(("grpc.max_send_message_length", max_send_message_length))
if max_recv_message_length is not None:
opts.append(("grpc.max_receive_message_length", max_recv_message_length))
return opts


def _token_call_credentials(token: str) -> grpc.CallCredentials:
Expand All @@ -30,6 +50,13 @@ def create_channel(
insecure: bool = True,
credentials: grpc.ChannelCredentials | None = None,
token: str | None = None,
max_send_message_length: int | None = None,
max_recv_message_length: int | None = None,
keepalive_time_ms: int = _DEFAULT_KEEPALIVE_TIME_MS,
keepalive_timeout_ms: int = _DEFAULT_KEEPALIVE_TIMEOUT_MS,
keepalive_permit_without_calls: int = _DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS,
reconnect_backoff_initial_ms: int = _DEFAULT_RECONNECT_BACKOFF_INITIAL_MS,
reconnect_backoff_max_ms: int = _DEFAULT_RECONNECT_BACKOFF_MAX_MS,
) -> grpc.Channel:
"""Create a gRPC channel with sensible defaults.

Expand All @@ -38,7 +65,20 @@ def create_channel(
``composite_channel_credentials`` so it is protected by the TLS layer.
On an insecure channel the token is sent as a raw header — callers should
warn the user before allowing this.

Pass *max_send_message_length* or *max_recv_message_length* (bytes) to
override gRPC's 4 MB default, which can be too small for large JSON values.
"""
options = _build_options(
max_send_message_length,
max_recv_message_length,
keepalive_time_ms,
keepalive_timeout_ms,
keepalive_permit_without_calls,
reconnect_backoff_initial_ms,
reconnect_backoff_max_ms,
)

channel_creds: grpc.ChannelCredentials | None = credentials
if channel_creds is None and not insecure:
channel_creds = grpc.ssl_channel_credentials()
Expand All @@ -48,9 +88,9 @@ def create_channel(
channel_creds = grpc.composite_channel_credentials(
channel_creds, _token_call_credentials(token)
)
return grpc.secure_channel(target, channel_creds, options=_DEFAULT_OPTIONS)
return grpc.secure_channel(target, channel_creds, options=options)

return grpc.insecure_channel(target, options=_DEFAULT_OPTIONS)
return grpc.insecure_channel(target, options=options)


def create_aio_channel(
Expand All @@ -59,6 +99,13 @@ def create_aio_channel(
insecure: bool = True,
credentials: grpc.ChannelCredentials | None = None,
token: str | None = None,
max_send_message_length: int | None = None,
max_recv_message_length: int | None = None,
keepalive_time_ms: int = _DEFAULT_KEEPALIVE_TIME_MS,
keepalive_timeout_ms: int = _DEFAULT_KEEPALIVE_TIMEOUT_MS,
keepalive_permit_without_calls: int = _DEFAULT_KEEPALIVE_PERMIT_WITHOUT_CALLS,
reconnect_backoff_initial_ms: int = _DEFAULT_RECONNECT_BACKOFF_INITIAL_MS,
reconnect_backoff_max_ms: int = _DEFAULT_RECONNECT_BACKOFF_MAX_MS,
) -> grpc.aio.Channel:
"""Create an async gRPC channel with sensible defaults.

Expand All @@ -67,7 +114,20 @@ def create_aio_channel(
``composite_channel_credentials`` so it is protected by the TLS layer.
On an insecure channel the token is sent as a raw header — callers should
warn the user before allowing this.

Pass *max_send_message_length* or *max_recv_message_length* (bytes) to
override gRPC's 4 MB default, which can be too small for large JSON values.
"""
options = _build_options(
max_send_message_length,
max_recv_message_length,
keepalive_time_ms,
keepalive_timeout_ms,
keepalive_permit_without_calls,
reconnect_backoff_initial_ms,
reconnect_backoff_max_ms,
)

channel_creds: grpc.ChannelCredentials | None = credentials
if channel_creds is None and not insecure:
channel_creds = grpc.ssl_channel_credentials()
Expand All @@ -77,6 +137,6 @@ def create_aio_channel(
channel_creds = grpc.composite_channel_credentials(
channel_creds, _token_call_credentials(token)
)
return grpc.aio.secure_channel(target, channel_creds, options=_DEFAULT_OPTIONS)
return grpc.aio.secure_channel(target, channel_creds, options=options)

return grpc.aio.insecure_channel(target, options=_DEFAULT_OPTIONS)
return grpc.aio.insecure_channel(target, options=options)
62 changes: 62 additions & 0 deletions sdk/tests/test_channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,37 @@ def test_insecure_with_token_does_not_use_composite(self):
mock_insecure.assert_called_once()
mock_comp.assert_not_called()

def test_message_size_options_included_when_set(self):
with patch("opendecree._channel.grpc.insecure_channel") as mock:
mock.return_value = MagicMock()
create_channel(
"localhost:9090",
max_send_message_length=16 * 1024 * 1024,
max_recv_message_length=32 * 1024 * 1024,
)
_, kwargs = mock.call_args
opts = dict(kwargs["options"])
assert opts["grpc.max_send_message_length"] == 16 * 1024 * 1024
assert opts["grpc.max_receive_message_length"] == 32 * 1024 * 1024

def test_message_size_options_absent_by_default(self):
with patch("opendecree._channel.grpc.insecure_channel") as mock:
mock.return_value = MagicMock()
create_channel("localhost:9090")
_, kwargs = mock.call_args
keys = [k for k, _ in kwargs["options"]]
assert "grpc.max_send_message_length" not in keys
assert "grpc.max_receive_message_length" not in keys

def test_keepalive_override(self):
with patch("opendecree._channel.grpc.insecure_channel") as mock:
mock.return_value = MagicMock()
create_channel("localhost:9090", keepalive_time_ms=60000, keepalive_timeout_ms=5000)
_, kwargs = mock.call_args
opts = dict(kwargs["options"])
assert opts["grpc.keepalive_time_ms"] == 60000
assert opts["grpc.keepalive_timeout_ms"] == 5000


class TestCreateAioChannel:
def test_insecure(self):
Expand Down Expand Up @@ -119,3 +150,34 @@ def test_insecure_with_token_does_not_use_composite(self):
create_aio_channel("localhost:9090", insecure=True, token="tok")
mock_insecure.assert_called_once()
mock_comp.assert_not_called()

def test_message_size_options_included_when_set(self):
with patch("opendecree._channel.grpc.aio.insecure_channel") as mock:
mock.return_value = MagicMock()
create_aio_channel(
"localhost:9090",
max_send_message_length=16 * 1024 * 1024,
max_recv_message_length=32 * 1024 * 1024,
)
_, kwargs = mock.call_args
opts = dict(kwargs["options"])
assert opts["grpc.max_send_message_length"] == 16 * 1024 * 1024
assert opts["grpc.max_receive_message_length"] == 32 * 1024 * 1024

def test_message_size_options_absent_by_default(self):
with patch("opendecree._channel.grpc.aio.insecure_channel") as mock:
mock.return_value = MagicMock()
create_aio_channel("localhost:9090")
_, kwargs = mock.call_args
keys = [k for k, _ in kwargs["options"]]
assert "grpc.max_send_message_length" not in keys
assert "grpc.max_receive_message_length" not in keys

def test_keepalive_override(self):
with patch("opendecree._channel.grpc.aio.insecure_channel") as mock:
mock.return_value = MagicMock()
create_aio_channel("localhost:9090", keepalive_time_ms=60000, keepalive_timeout_ms=5000)
_, kwargs = mock.call_args
opts = dict(kwargs["options"])
assert opts["grpc.keepalive_time_ms"] == 60000
assert opts["grpc.keepalive_timeout_ms"] == 5000