From 1be949cc485fc6bed18b1e227fe0e086e59d4e47 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Fri, 27 Mar 2026 15:14:38 +0100 Subject: [PATCH 01/12] Add class based python api --- .gitignore | 2 + etc/run-integration-tests.sh | 82 ++ examples/README.md | 0 examples/class_based_greeter.py | 54 + examples/main.py | 6 + examples/pyproject.toml | 15 + examples/uv.lock | 277 ++++++ python/restate/__init__.py | 10 + python/restate/admin_client.py | 140 +++ python/restate/cls.py | 928 ++++++++++++++++++ python/restate/context_access.py | 307 ++++++ python/restate/endpoint.py | 17 +- python/restate/ext/a2a/PLAN.md | 247 +++++ python/restate/ext/a2a/__init__.py | 13 + python/restate/ext/adk/summarizer.py | 9 +- python/restate/ext/pydantic/__init__.py | 1 + test-services-cls/Dockerfile | 42 + test-services-cls/entrypoint.sh | 19 + test-services-cls/exclusions.yaml | 1 + test-services-cls/hypercorn-config.toml | 5 + test-services-cls/services/__init__.py | 42 + .../services/awakeable_holder.py | 36 + .../services/block_and_wait_workflow.py | 39 + test-services-cls/services/cancel_test.py | 66 ++ test-services-cls/services/counter.py | 56 ++ test-services-cls/services/failing.py | 88 ++ test-services-cls/services/interpreter.py | 295 ++++++ test-services-cls/services/kill_test.py | 42 + test-services-cls/services/list_object.py | 33 + test-services-cls/services/map_object.py | 42 + test-services-cls/services/non_determinism.py | 76 ++ test-services-cls/services/proxy.py | 95 ++ test-services-cls/services/test_utils.py | 66 ++ .../virtual_object_command_interpreter.py | 203 ++++ test-services-cls/testservices.py | 16 + tests/admin_client.py | 166 ++++ 36 files changed, 3523 insertions(+), 13 deletions(-) create mode 100755 etc/run-integration-tests.sh create mode 100644 examples/README.md create mode 100644 examples/class_based_greeter.py create mode 100644 examples/main.py create mode 100644 examples/pyproject.toml create mode 100644 examples/uv.lock create mode 100644 python/restate/admin_client.py create mode 100644 python/restate/cls.py create mode 100644 python/restate/context_access.py create mode 100644 python/restate/ext/a2a/PLAN.md create mode 100644 python/restate/ext/a2a/__init__.py create mode 100644 test-services-cls/Dockerfile create mode 100755 test-services-cls/entrypoint.sh create mode 100644 test-services-cls/exclusions.yaml create mode 100644 test-services-cls/hypercorn-config.toml create mode 100644 test-services-cls/services/__init__.py create mode 100644 test-services-cls/services/awakeable_holder.py create mode 100644 test-services-cls/services/block_and_wait_workflow.py create mode 100644 test-services-cls/services/cancel_test.py create mode 100644 test-services-cls/services/counter.py create mode 100644 test-services-cls/services/failing.py create mode 100644 test-services-cls/services/interpreter.py create mode 100644 test-services-cls/services/kill_test.py create mode 100644 test-services-cls/services/list_object.py create mode 100644 test-services-cls/services/map_object.py create mode 100644 test-services-cls/services/non_determinism.py create mode 100644 test-services-cls/services/proxy.py create mode 100644 test-services-cls/services/test_utils.py create mode 100644 test-services-cls/services/virtual_object_command_interpreter.py create mode 100644 test-services-cls/testservices.py create mode 100644 tests/admin_client.py diff --git a/.gitignore b/.gitignore index 7e6a605..cc374af 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,5 @@ +tmp/ + # Byte-compiled / optimized / DLL files __pycache__/ *.py[cod] diff --git a/etc/run-integration-tests.sh b/etc/run-integration-tests.sh new file mode 100755 index 0000000..df84f19 --- /dev/null +++ b/etc/run-integration-tests.sh @@ -0,0 +1,82 @@ +#!/usr/bin/env bash +set -euo pipefail + +# Run the sdk-test-suite integration tests locally. +# +# Prerequisites: +# - Docker running +# +# Usage: +# ./etc/run-integration-tests.sh # test original test-services +# ./etc/run-integration-tests.sh --cls # test class-based test-services +# ./etc/run-integration-tests.sh --skip-build # reuse existing image +# ./etc/run-integration-tests.sh --cls --skip-build # combine flags + +REPO_ROOT="$(cd "$(dirname "$0")/.." && pwd)" +SDK_TEST_SUITE_VERSION="v3.4" +JAR_URL="https://github.com/restatedev/sdk-test-suite/releases/download/${SDK_TEST_SUITE_VERSION}/restate-sdk-test-suite.jar" +JAR_PATH="${REPO_ROOT}/tmp/restate-sdk-test-suite.jar" +RESTATE_IMAGE="${RESTATE_CONTAINER_IMAGE:-ghcr.io/restatedev/restate:main}" +REPORT_DIR="${REPO_ROOT}/tmp/test-report" + +SKIP_BUILD=false +USE_CLS=false +for arg in "$@"; do + case "$arg" in + --skip-build) SKIP_BUILD=true ;; + --cls) USE_CLS=true ;; + esac +done + +if [ "$USE_CLS" = true ]; then + SERVICE_IMAGE="restatedev/test-services-python-cls" + DOCKERFILE="${REPO_ROOT}/test-services-cls/Dockerfile" + EXCLUSIONS="${REPO_ROOT}/test-services-cls/exclusions.yaml" + ENV_FILE="${REPO_ROOT}/test-services-cls/.env" + echo "==> Using class-based test-services (test-services-cls/)" +else + SERVICE_IMAGE="restatedev/test-services-python" + DOCKERFILE="${REPO_ROOT}/test-services/Dockerfile" + EXCLUSIONS="${REPO_ROOT}/test-services/exclusions.yaml" + ENV_FILE="${REPO_ROOT}/test-services/.env" + echo "==> Using original test-services (test-services/)" +fi + +# 1. Build the test-services Docker image +if [ "$SKIP_BUILD" = false ]; then + echo "==> Building test-services Docker image..." + docker build -f "${DOCKERFILE}" -t "${SERVICE_IMAGE}" "${REPO_ROOT}" +fi + +# 2. Download the test suite JAR (if not cached) +mkdir -p "$(dirname "$JAR_PATH")" +if [ ! -f "$JAR_PATH" ]; then + echo "==> Downloading sdk-test-suite ${SDK_TEST_SUITE_VERSION}..." + curl -fSL -o "$JAR_PATH" "$JAR_URL" +fi + +# 3. Pull restate image +echo "==> Pulling Restate image: ${RESTATE_IMAGE}" +docker pull "${RESTATE_IMAGE}" + +# 4. Run the test suite via Docker (no local Java needed) +echo "==> Running integration tests..." +rm -rf "${REPORT_DIR}" +mkdir -p "${REPORT_DIR}" + +docker run --rm \ + -v /var/run/docker.sock:/var/run/docker.sock \ + -v "${JAR_PATH}:/opt/test-suite.jar:ro" \ + -v "${EXCLUSIONS}:/opt/exclusions.yaml:ro" \ + -v "${ENV_FILE}:/opt/service.env:ro" \ + -v "${REPORT_DIR}:/opt/test-report" \ + -e RESTATE_CONTAINER_IMAGE="${RESTATE_IMAGE}" \ + --network host \ + eclipse-temurin:21-jre \ + java -jar /opt/test-suite.jar run \ + --exclusions-file=/opt/exclusions.yaml \ + --service-container-env-file=/opt/service.env \ + --report-dir=/opt/test-report \ + "${SERVICE_IMAGE}" + +echo "==> Done. Test report: ${REPORT_DIR}" diff --git a/examples/README.md b/examples/README.md new file mode 100644 index 0000000..e69de29 diff --git a/examples/class_based_greeter.py b/examples/class_based_greeter.py new file mode 100644 index 0000000..839dff9 --- /dev/null +++ b/examples/class_based_greeter.py @@ -0,0 +1,54 @@ +""" +Class-based API example for the Restate Python SDK. + +This example demonstrates the same services as the decorator-based examples, +but using the class-based API with @handler, @shared, and @main decorators. +""" + +import restate +from restate.cls import Service, VirtualObject, Workflow, handler, shared, main, Context + + +class Greeter(Service): + """A simple stateless greeting service.""" + + @handler + async def greet(self, name: str) -> str: + return f"Hello {name}!" + + +class Counter(VirtualObject): + """A stateful counter backed by durable state.""" + + @handler + async def increment(self, value: int) -> int: + n: int = await Context.get("counter", type_hint=int) or 0 + n += value + Context.set("counter", n) + return n + + @shared + async def count(self) -> int: + return await Context.get("counter", type_hint=int) or 0 + + +class PaymentWorkflow(Workflow): + """A durable payment workflow with external verification.""" + + @main + async def pay(self, amount: int) -> str: + Context.set("status", "processing") + + async def charge(): + return f"charged ${amount}" + + receipt = await Context.run_typed("charge", charge) + Context.set("status", "completed") + return receipt + + @handler + async def status(self) -> str: + return await Context.get("status", type_hint=str) or "unknown" + + +app = restate.app([Greeter, Counter, PaymentWorkflow]) diff --git a/examples/main.py b/examples/main.py new file mode 100644 index 0000000..9367ba4 --- /dev/null +++ b/examples/main.py @@ -0,0 +1,6 @@ +def main(): + print("Hello from examples!") + + +if __name__ == "__main__": + main() diff --git a/examples/pyproject.toml b/examples/pyproject.toml new file mode 100644 index 0000000..4db9d1a --- /dev/null +++ b/examples/pyproject.toml @@ -0,0 +1,15 @@ +[project] +name = "examples" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.12" +dependencies = [ + "hypercorn>=0.18.0", + "pydantic>=2.12.5", + "restate-sdk", + "uvicorn>=0.38.0", +] + +[tool.uv.sources] +restate-sdk = { path = "../dist/restate_sdk-0.14.2-cp313-cp313-macosx_14_0_arm64.whl" } diff --git a/examples/uv.lock b/examples/uv.lock new file mode 100644 index 0000000..17bf112 --- /dev/null +++ b/examples/uv.lock @@ -0,0 +1,277 @@ +version = 1 +revision = 3 +requires-python = ">=3.12" + +[[package]] +name = "annotated-types" +version = "0.7.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/ee/67/531ea369ba64dcff5ec9c3402f9f51bf748cec26dde048a2f973a4eea7f5/annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89", size = 16081, upload-time = "2024-05-20T21:33:25.928Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/78/b6/6307fbef88d9b5ee7421e68d78a9f162e0da4900bc5f5793f6d3d0e34fb8/annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53", size = 13643, upload-time = "2024-05-20T21:33:24.1Z" }, +] + +[[package]] +name = "click" +version = "8.3.1" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "colorama", marker = "sys_platform == 'win32'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/3d/fa/656b739db8587d7b5dfa22e22ed02566950fbfbcdc20311993483657a5c0/click-8.3.1.tar.gz", hash = "sha256:12ff4785d337a1bb490bb7e9c2b1ee5da3112e94a8622f26a6c77f5d2fc6842a", size = 295065, upload-time = "2025-11-15T20:45:42.706Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/98/78/01c019cdb5d6498122777c1a43056ebb3ebfeef2076d9d026bfe15583b2b/click-8.3.1-py3-none-any.whl", hash = "sha256:981153a64e25f12d547d3426c367a4857371575ee7ad18df2a6183ab0545b2a6", size = 108274, upload-time = "2025-11-15T20:45:41.139Z" }, +] + +[[package]] +name = "colorama" +version = "0.4.6" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/d8/53/6f443c9a4a8358a93a6792e2acffb9d9d5cb0a5cfd8802644b7b1c9a02e4/colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44", size = 27697, upload-time = "2022-10-25T02:36:22.414Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/d1/d6/3965ed04c63042e047cb6a3e6ed1a63a35087b6a609aa3a15ed8ac56c221/colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6", size = 25335, upload-time = "2022-10-25T02:36:20.889Z" }, +] + +[[package]] +name = "examples" +version = "0.1.0" +source = { virtual = "." } +dependencies = [ + { name = "hypercorn" }, + { name = "pydantic" }, + { name = "restate-sdk" }, + { name = "uvicorn" }, +] + +[package.metadata] +requires-dist = [ + { name = "hypercorn", specifier = ">=0.18.0" }, + { name = "pydantic", specifier = ">=2.12.5" }, + { name = "restate-sdk", path = "../dist/restate_sdk-0.14.2-cp313-cp313-macosx_14_0_arm64.whl" }, + { name = "uvicorn", specifier = ">=0.38.0" }, +] + +[[package]] +name = "h11" +version = "0.16.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/01/ee/02a2c011bdab74c6fb3c75474d40b3052059d95df7e73351460c8588d963/h11-0.16.0.tar.gz", hash = "sha256:4e35b956cf45792e4caa5885e69fba00bdbc6ffafbfa020300e549b208ee5ff1", size = 101250, upload-time = "2025-04-24T03:35:25.427Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/04/4b/29cac41a4d98d144bf5f6d33995617b185d14b22401f75ca86f384e87ff1/h11-0.16.0-py3-none-any.whl", hash = "sha256:63cf8bbe7522de3bf65932fda1d9c2772064ffb3dae62d55932da54b31cb6c86", size = 37515, upload-time = "2025-04-24T03:35:24.344Z" }, +] + +[[package]] +name = "h2" +version = "4.3.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "hpack" }, + { name = "hyperframe" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1d/17/afa56379f94ad0fe8defd37d6eb3f89a25404ffc71d4d848893d270325fc/h2-4.3.0.tar.gz", hash = "sha256:6c59efe4323fa18b47a632221a1888bd7fde6249819beda254aeca909f221bf1", size = 2152026, upload-time = "2025-08-23T18:12:19.778Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/69/b2/119f6e6dcbd96f9069ce9a2665e0146588dc9f88f29549711853645e736a/h2-4.3.0-py3-none-any.whl", hash = "sha256:c438f029a25f7945c69e0ccf0fb951dc3f73a5f6412981daee861431b70e2bdd", size = 61779, upload-time = "2025-08-23T18:12:17.779Z" }, +] + +[[package]] +name = "hpack" +version = "4.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/2c/48/71de9ed269fdae9c8057e5a4c0aa7402e8bb16f2c6e90b3aa53327b113f8/hpack-4.1.0.tar.gz", hash = "sha256:ec5eca154f7056aa06f196a557655c5b009b382873ac8d1e66e79e87535f1dca", size = 51276, upload-time = "2025-01-22T21:44:58.347Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/07/c6/80c95b1b2b94682a72cbdbfb85b81ae2daffa4291fbfa1b1464502ede10d/hpack-4.1.0-py3-none-any.whl", hash = "sha256:157ac792668d995c657d93111f46b4535ed114f0c9c8d672271bbec7eae1b496", size = 34357, upload-time = "2025-01-22T21:44:56.92Z" }, +] + +[[package]] +name = "hypercorn" +version = "0.18.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, + { name = "h2" }, + { name = "priority" }, + { name = "wsproto" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/44/01/39f41a014b83dd5c795217362f2ca9071cf243e6a75bdcd6cd5b944658cc/hypercorn-0.18.0.tar.gz", hash = "sha256:d63267548939c46b0247dc8e5b45a9947590e35e64ee73a23c074aa3cf88e9da", size = 68420, upload-time = "2025-11-08T13:54:04.78Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/93/35/850277d1b17b206bd10874c8a9a3f52e059452fb49bb0d22cbb908f6038b/hypercorn-0.18.0-py3-none-any.whl", hash = "sha256:225e268f2c1c2f28f6d8f6db8f40cb8c992963610c5725e13ccfcddccb24b1cd", size = 61640, upload-time = "2025-11-08T13:54:03.202Z" }, +] + +[[package]] +name = "hyperframe" +version = "6.1.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/02/e7/94f8232d4a74cc99514c13a9f995811485a6903d48e5d952771ef6322e30/hyperframe-6.1.0.tar.gz", hash = "sha256:f630908a00854a7adeabd6382b43923a4c4cd4b821fcb527e6ab9e15382a3b08", size = 26566, upload-time = "2025-01-22T21:41:49.302Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/48/30/47d0bf6072f7252e6521f3447ccfa40b421b6824517f82854703d0f5a98b/hyperframe-6.1.0-py3-none-any.whl", hash = "sha256:b03380493a519fce58ea5af42e4a42317bf9bd425596f7a0835ffce80f1a42e5", size = 13007, upload-time = "2025-01-22T21:41:47.295Z" }, +] + +[[package]] +name = "priority" +version = "2.0.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/f5/3c/eb7c35f4dcede96fca1842dac5f4f5d15511aa4b52f3a961219e68ae9204/priority-2.0.0.tar.gz", hash = "sha256:c965d54f1b8d0d0b19479db3924c7c36cf672dbf2aec92d43fbdaf4492ba18c0", size = 24792, upload-time = "2021-06-27T10:15:05.487Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5e/5f/82c8074f7e84978129347c2c6ec8b6c59f3584ff1a20bc3c940a3e061790/priority-2.0.0-py3-none-any.whl", hash = "sha256:6f8eefce5f3ad59baf2c080a664037bb4725cd0a790d53d59ab4059288faf6aa", size = 8946, upload-time = "2021-06-27T10:15:03.856Z" }, +] + +[[package]] +name = "pydantic" +version = "2.12.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "annotated-types" }, + { name = "pydantic-core" }, + { name = "typing-extensions" }, + { name = "typing-inspection" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/69/44/36f1a6e523abc58ae5f928898e4aca2e0ea509b5aa6f6f392a5d882be928/pydantic-2.12.5.tar.gz", hash = "sha256:4d351024c75c0f085a9febbb665ce8c0c6ec5d30e903bdb6394b7ede26aebb49", size = 821591, upload-time = "2025-11-26T15:11:46.471Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5a/87/b70ad306ebb6f9b585f114d0ac2137d792b48be34d732d60e597c2f8465a/pydantic-2.12.5-py3-none-any.whl", hash = "sha256:e561593fccf61e8a20fc46dfc2dfe075b8be7d0188df33f221ad1f0139180f9d", size = 463580, upload-time = "2025-11-26T15:11:44.605Z" }, +] + +[[package]] +name = "pydantic-core" +version = "2.41.5" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/71/70/23b021c950c2addd24ec408e9ab05d59b035b39d97cdc1130e1bce647bb6/pydantic_core-2.41.5.tar.gz", hash = "sha256:08daa51ea16ad373ffd5e7606252cc32f07bc72b28284b6bc9c6df804816476e", size = 460952, upload-time = "2025-11-04T13:43:49.098Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/5f/5d/5f6c63eebb5afee93bcaae4ce9a898f3373ca23df3ccaef086d0233a35a7/pydantic_core-2.41.5-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:f41a7489d32336dbf2199c8c0a215390a751c5b014c2c1c5366e817202e9cdf7", size = 2110990, upload-time = "2025-11-04T13:39:58.079Z" }, + { url = "https://files.pythonhosted.org/packages/aa/32/9c2e8ccb57c01111e0fd091f236c7b371c1bccea0fa85247ac55b1e2b6b6/pydantic_core-2.41.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:070259a8818988b9a84a449a2a7337c7f430a22acc0859c6b110aa7212a6d9c0", size = 1896003, upload-time = "2025-11-04T13:39:59.956Z" }, + { url = "https://files.pythonhosted.org/packages/68/b8/a01b53cb0e59139fbc9e4fda3e9724ede8de279097179be4ff31f1abb65a/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e96cea19e34778f8d59fe40775a7a574d95816eb150850a85a7a4c8f4b94ac69", size = 1919200, upload-time = "2025-11-04T13:40:02.241Z" }, + { url = "https://files.pythonhosted.org/packages/38/de/8c36b5198a29bdaade07b5985e80a233a5ac27137846f3bc2d3b40a47360/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:ed2e99c456e3fadd05c991f8f437ef902e00eedf34320ba2b0842bd1c3ca3a75", size = 2052578, upload-time = "2025-11-04T13:40:04.401Z" }, + { url = "https://files.pythonhosted.org/packages/00/b5/0e8e4b5b081eac6cb3dbb7e60a65907549a1ce035a724368c330112adfdd/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65840751b72fbfd82c3c640cff9284545342a4f1eb1586ad0636955b261b0b05", size = 2208504, upload-time = "2025-11-04T13:40:06.072Z" }, + { url = "https://files.pythonhosted.org/packages/77/56/87a61aad59c7c5b9dc8caad5a41a5545cba3810c3e828708b3d7404f6cef/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e536c98a7626a98feb2d3eaf75944ef6f3dbee447e1f841eae16f2f0a72d8ddc", size = 2335816, upload-time = "2025-11-04T13:40:07.835Z" }, + { url = "https://files.pythonhosted.org/packages/0d/76/941cc9f73529988688a665a5c0ecff1112b3d95ab48f81db5f7606f522d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eceb81a8d74f9267ef4081e246ffd6d129da5d87e37a77c9bde550cb04870c1c", size = 2075366, upload-time = "2025-11-04T13:40:09.804Z" }, + { url = "https://files.pythonhosted.org/packages/d3/43/ebef01f69baa07a482844faaa0a591bad1ef129253ffd0cdaa9d8a7f72d3/pydantic_core-2.41.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d38548150c39b74aeeb0ce8ee1d8e82696f4a4e16ddc6de7b1d8823f7de4b9b5", size = 2171698, upload-time = "2025-11-04T13:40:12.004Z" }, + { url = "https://files.pythonhosted.org/packages/b1/87/41f3202e4193e3bacfc2c065fab7706ebe81af46a83d3e27605029c1f5a6/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:c23e27686783f60290e36827f9c626e63154b82b116d7fe9adba1fda36da706c", size = 2132603, upload-time = "2025-11-04T13:40:13.868Z" }, + { url = "https://files.pythonhosted.org/packages/49/7d/4c00df99cb12070b6bccdef4a195255e6020a550d572768d92cc54dba91a/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_armv7l.whl", hash = "sha256:482c982f814460eabe1d3bb0adfdc583387bd4691ef00b90575ca0d2b6fe2294", size = 2329591, upload-time = "2025-11-04T13:40:15.672Z" }, + { url = "https://files.pythonhosted.org/packages/cc/6a/ebf4b1d65d458f3cda6a7335d141305dfa19bdc61140a884d165a8a1bbc7/pydantic_core-2.41.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:bfea2a5f0b4d8d43adf9d7b8bf019fb46fdd10a2e5cde477fbcb9d1fa08c68e1", size = 2319068, upload-time = "2025-11-04T13:40:17.532Z" }, + { url = "https://files.pythonhosted.org/packages/49/3b/774f2b5cd4192d5ab75870ce4381fd89cf218af999515baf07e7206753f0/pydantic_core-2.41.5-cp312-cp312-win32.whl", hash = "sha256:b74557b16e390ec12dca509bce9264c3bbd128f8a2c376eaa68003d7f327276d", size = 1985908, upload-time = "2025-11-04T13:40:19.309Z" }, + { url = "https://files.pythonhosted.org/packages/86/45/00173a033c801cacf67c190fef088789394feaf88a98a7035b0e40d53dc9/pydantic_core-2.41.5-cp312-cp312-win_amd64.whl", hash = "sha256:1962293292865bca8e54702b08a4f26da73adc83dd1fcf26fbc875b35d81c815", size = 2020145, upload-time = "2025-11-04T13:40:21.548Z" }, + { url = "https://files.pythonhosted.org/packages/f9/22/91fbc821fa6d261b376a3f73809f907cec5ca6025642c463d3488aad22fb/pydantic_core-2.41.5-cp312-cp312-win_arm64.whl", hash = "sha256:1746d4a3d9a794cacae06a5eaaccb4b8643a131d45fbc9af23e353dc0a5ba5c3", size = 1976179, upload-time = "2025-11-04T13:40:23.393Z" }, + { url = "https://files.pythonhosted.org/packages/87/06/8806241ff1f70d9939f9af039c6c35f2360cf16e93c2ca76f184e76b1564/pydantic_core-2.41.5-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:941103c9be18ac8daf7b7adca8228f8ed6bb7a1849020f643b3a14d15b1924d9", size = 2120403, upload-time = "2025-11-04T13:40:25.248Z" }, + { url = "https://files.pythonhosted.org/packages/94/02/abfa0e0bda67faa65fef1c84971c7e45928e108fe24333c81f3bfe35d5f5/pydantic_core-2.41.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:112e305c3314f40c93998e567879e887a3160bb8689ef3d2c04b6cc62c33ac34", size = 1896206, upload-time = "2025-11-04T13:40:27.099Z" }, + { url = "https://files.pythonhosted.org/packages/15/df/a4c740c0943e93e6500f9eb23f4ca7ec9bf71b19e608ae5b579678c8d02f/pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0cbaad15cb0c90aa221d43c00e77bb33c93e8d36e0bf74760cd00e732d10a6a0", size = 1919307, upload-time = "2025-11-04T13:40:29.806Z" }, + { url = "https://files.pythonhosted.org/packages/9a/e3/6324802931ae1d123528988e0e86587c2072ac2e5394b4bc2bc34b61ff6e/pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:03ca43e12fab6023fc79d28ca6b39b05f794ad08ec2feccc59a339b02f2b3d33", size = 2063258, upload-time = "2025-11-04T13:40:33.544Z" }, + { url = "https://files.pythonhosted.org/packages/c9/d4/2230d7151d4957dd79c3044ea26346c148c98fbf0ee6ebd41056f2d62ab5/pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:dc799088c08fa04e43144b164feb0c13f9a0bc40503f8df3e9fde58a3c0c101e", size = 2214917, upload-time = "2025-11-04T13:40:35.479Z" }, + { url = "https://files.pythonhosted.org/packages/e6/9f/eaac5df17a3672fef0081b6c1bb0b82b33ee89aa5cec0d7b05f52fd4a1fa/pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:97aeba56665b4c3235a0e52b2c2f5ae9cd071b8a8310ad27bddb3f7fb30e9aa2", size = 2332186, upload-time = "2025-11-04T13:40:37.436Z" }, + { url = "https://files.pythonhosted.org/packages/cf/4e/35a80cae583a37cf15604b44240e45c05e04e86f9cfd766623149297e971/pydantic_core-2.41.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:406bf18d345822d6c21366031003612b9c77b3e29ffdb0f612367352aab7d586", size = 2073164, upload-time = "2025-11-04T13:40:40.289Z" }, + { url = "https://files.pythonhosted.org/packages/bf/e3/f6e262673c6140dd3305d144d032f7bd5f7497d3871c1428521f19f9efa2/pydantic_core-2.41.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b93590ae81f7010dbe380cdeab6f515902ebcbefe0b9327cc4804d74e93ae69d", size = 2179146, upload-time = "2025-11-04T13:40:42.809Z" }, + { url = "https://files.pythonhosted.org/packages/75/c7/20bd7fc05f0c6ea2056a4565c6f36f8968c0924f19b7d97bbfea55780e73/pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:01a3d0ab748ee531f4ea6c3e48ad9dac84ddba4b0d82291f87248f2f9de8d740", size = 2137788, upload-time = "2025-11-04T13:40:44.752Z" }, + { url = "https://files.pythonhosted.org/packages/3a/8d/34318ef985c45196e004bc46c6eab2eda437e744c124ef0dbe1ff2c9d06b/pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_armv7l.whl", hash = "sha256:6561e94ba9dacc9c61bce40e2d6bdc3bfaa0259d3ff36ace3b1e6901936d2e3e", size = 2340133, upload-time = "2025-11-04T13:40:46.66Z" }, + { url = "https://files.pythonhosted.org/packages/9c/59/013626bf8c78a5a5d9350d12e7697d3d4de951a75565496abd40ccd46bee/pydantic_core-2.41.5-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:915c3d10f81bec3a74fbd4faebe8391013ba61e5a1a8d48c4455b923bdda7858", size = 2324852, upload-time = "2025-11-04T13:40:48.575Z" }, + { url = "https://files.pythonhosted.org/packages/1a/d9/c248c103856f807ef70c18a4f986693a46a8ffe1602e5d361485da502d20/pydantic_core-2.41.5-cp313-cp313-win32.whl", hash = "sha256:650ae77860b45cfa6e2cdafc42618ceafab3a2d9a3811fcfbd3bbf8ac3c40d36", size = 1994679, upload-time = "2025-11-04T13:40:50.619Z" }, + { url = "https://files.pythonhosted.org/packages/9e/8b/341991b158ddab181cff136acd2552c9f35bd30380422a639c0671e99a91/pydantic_core-2.41.5-cp313-cp313-win_amd64.whl", hash = "sha256:79ec52ec461e99e13791ec6508c722742ad745571f234ea6255bed38c6480f11", size = 2019766, upload-time = "2025-11-04T13:40:52.631Z" }, + { url = "https://files.pythonhosted.org/packages/73/7d/f2f9db34af103bea3e09735bb40b021788a5e834c81eedb541991badf8f5/pydantic_core-2.41.5-cp313-cp313-win_arm64.whl", hash = "sha256:3f84d5c1b4ab906093bdc1ff10484838aca54ef08de4afa9de0f5f14d69639cd", size = 1981005, upload-time = "2025-11-04T13:40:54.734Z" }, + { url = "https://files.pythonhosted.org/packages/ea/28/46b7c5c9635ae96ea0fbb779e271a38129df2550f763937659ee6c5dbc65/pydantic_core-2.41.5-cp314-cp314-macosx_10_12_x86_64.whl", hash = "sha256:3f37a19d7ebcdd20b96485056ba9e8b304e27d9904d233d7b1015db320e51f0a", size = 2119622, upload-time = "2025-11-04T13:40:56.68Z" }, + { url = "https://files.pythonhosted.org/packages/74/1a/145646e5687e8d9a1e8d09acb278c8535ebe9e972e1f162ed338a622f193/pydantic_core-2.41.5-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:1d1d9764366c73f996edd17abb6d9d7649a7eb690006ab6adbda117717099b14", size = 1891725, upload-time = "2025-11-04T13:40:58.807Z" }, + { url = "https://files.pythonhosted.org/packages/23/04/e89c29e267b8060b40dca97bfc64a19b2a3cf99018167ea1677d96368273/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:25e1c2af0fce638d5f1988b686f3b3ea8cd7de5f244ca147c777769e798a9cd1", size = 1915040, upload-time = "2025-11-04T13:41:00.853Z" }, + { url = "https://files.pythonhosted.org/packages/84/a3/15a82ac7bd97992a82257f777b3583d3e84bdb06ba6858f745daa2ec8a85/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:506d766a8727beef16b7adaeb8ee6217c64fc813646b424d0804d67c16eddb66", size = 2063691, upload-time = "2025-11-04T13:41:03.504Z" }, + { url = "https://files.pythonhosted.org/packages/74/9b/0046701313c6ef08c0c1cf0e028c67c770a4e1275ca73131563c5f2a310a/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4819fa52133c9aa3c387b3328f25c1facc356491e6135b459f1de698ff64d869", size = 2213897, upload-time = "2025-11-04T13:41:05.804Z" }, + { url = "https://files.pythonhosted.org/packages/8a/cd/6bac76ecd1b27e75a95ca3a9a559c643b3afcd2dd62086d4b7a32a18b169/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2b761d210c9ea91feda40d25b4efe82a1707da2ef62901466a42492c028553a2", size = 2333302, upload-time = "2025-11-04T13:41:07.809Z" }, + { url = "https://files.pythonhosted.org/packages/4c/d2/ef2074dc020dd6e109611a8be4449b98cd25e1b9b8a303c2f0fca2f2bcf7/pydantic_core-2.41.5-cp314-cp314-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:22f0fb8c1c583a3b6f24df2470833b40207e907b90c928cc8d3594b76f874375", size = 2064877, upload-time = "2025-11-04T13:41:09.827Z" }, + { url = "https://files.pythonhosted.org/packages/18/66/e9db17a9a763d72f03de903883c057b2592c09509ccfe468187f2a2eef29/pydantic_core-2.41.5-cp314-cp314-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2782c870e99878c634505236d81e5443092fba820f0373997ff75f90f68cd553", size = 2180680, upload-time = "2025-11-04T13:41:12.379Z" }, + { url = "https://files.pythonhosted.org/packages/d3/9e/3ce66cebb929f3ced22be85d4c2399b8e85b622db77dad36b73c5387f8f8/pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_aarch64.whl", hash = "sha256:0177272f88ab8312479336e1d777f6b124537d47f2123f89cb37e0accea97f90", size = 2138960, upload-time = "2025-11-04T13:41:14.627Z" }, + { url = "https://files.pythonhosted.org/packages/a6/62/205a998f4327d2079326b01abee48e502ea739d174f0a89295c481a2272e/pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_armv7l.whl", hash = "sha256:63510af5e38f8955b8ee5687740d6ebf7c2a0886d15a6d65c32814613681bc07", size = 2339102, upload-time = "2025-11-04T13:41:16.868Z" }, + { url = "https://files.pythonhosted.org/packages/3c/0d/f05e79471e889d74d3d88f5bd20d0ed189ad94c2423d81ff8d0000aab4ff/pydantic_core-2.41.5-cp314-cp314-musllinux_1_1_x86_64.whl", hash = "sha256:e56ba91f47764cc14f1daacd723e3e82d1a89d783f0f5afe9c364b8bb491ccdb", size = 2326039, upload-time = "2025-11-04T13:41:18.934Z" }, + { url = "https://files.pythonhosted.org/packages/ec/e1/e08a6208bb100da7e0c4b288eed624a703f4d129bde2da475721a80cab32/pydantic_core-2.41.5-cp314-cp314-win32.whl", hash = "sha256:aec5cf2fd867b4ff45b9959f8b20ea3993fc93e63c7363fe6851424c8a7e7c23", size = 1995126, upload-time = "2025-11-04T13:41:21.418Z" }, + { url = "https://files.pythonhosted.org/packages/48/5d/56ba7b24e9557f99c9237e29f5c09913c81eeb2f3217e40e922353668092/pydantic_core-2.41.5-cp314-cp314-win_amd64.whl", hash = "sha256:8e7c86f27c585ef37c35e56a96363ab8de4e549a95512445b85c96d3e2f7c1bf", size = 2015489, upload-time = "2025-11-04T13:41:24.076Z" }, + { url = "https://files.pythonhosted.org/packages/4e/bb/f7a190991ec9e3e0ba22e4993d8755bbc4a32925c0b5b42775c03e8148f9/pydantic_core-2.41.5-cp314-cp314-win_arm64.whl", hash = "sha256:e672ba74fbc2dc8eea59fb6d4aed6845e6905fc2a8afe93175d94a83ba2a01a0", size = 1977288, upload-time = "2025-11-04T13:41:26.33Z" }, + { url = "https://files.pythonhosted.org/packages/92/ed/77542d0c51538e32e15afe7899d79efce4b81eee631d99850edc2f5e9349/pydantic_core-2.41.5-cp314-cp314t-macosx_10_12_x86_64.whl", hash = "sha256:8566def80554c3faa0e65ac30ab0932b9e3a5cd7f8323764303d468e5c37595a", size = 2120255, upload-time = "2025-11-04T13:41:28.569Z" }, + { url = "https://files.pythonhosted.org/packages/bb/3d/6913dde84d5be21e284439676168b28d8bbba5600d838b9dca99de0fad71/pydantic_core-2.41.5-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:b80aa5095cd3109962a298ce14110ae16b8c1aece8b72f9dafe81cf597ad80b3", size = 1863760, upload-time = "2025-11-04T13:41:31.055Z" }, + { url = "https://files.pythonhosted.org/packages/5a/f0/e5e6b99d4191da102f2b0eb9687aaa7f5bea5d9964071a84effc3e40f997/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3006c3dd9ba34b0c094c544c6006cc79e87d8612999f1a5d43b769b89181f23c", size = 1878092, upload-time = "2025-11-04T13:41:33.21Z" }, + { url = "https://files.pythonhosted.org/packages/71/48/36fb760642d568925953bcc8116455513d6e34c4beaa37544118c36aba6d/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:72f6c8b11857a856bcfa48c86f5368439f74453563f951e473514579d44aa612", size = 2053385, upload-time = "2025-11-04T13:41:35.508Z" }, + { url = "https://files.pythonhosted.org/packages/20/25/92dc684dd8eb75a234bc1c764b4210cf2646479d54b47bf46061657292a8/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5cb1b2f9742240e4bb26b652a5aeb840aa4b417c7748b6f8387927bc6e45e40d", size = 2218832, upload-time = "2025-11-04T13:41:37.732Z" }, + { url = "https://files.pythonhosted.org/packages/e2/09/f53e0b05023d3e30357d82eb35835d0f6340ca344720a4599cd663dca599/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:bd3d54f38609ff308209bd43acea66061494157703364ae40c951f83ba99a1a9", size = 2327585, upload-time = "2025-11-04T13:41:40Z" }, + { url = "https://files.pythonhosted.org/packages/aa/4e/2ae1aa85d6af35a39b236b1b1641de73f5a6ac4d5a7509f77b814885760c/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2ff4321e56e879ee8d2a879501c8e469414d948f4aba74a2d4593184eb326660", size = 2041078, upload-time = "2025-11-04T13:41:42.323Z" }, + { url = "https://files.pythonhosted.org/packages/cd/13/2e215f17f0ef326fc72afe94776edb77525142c693767fc347ed6288728d/pydantic_core-2.41.5-cp314-cp314t-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:d0d2568a8c11bf8225044aa94409e21da0cb09dcdafe9ecd10250b2baad531a9", size = 2173914, upload-time = "2025-11-04T13:41:45.221Z" }, + { url = "https://files.pythonhosted.org/packages/02/7a/f999a6dcbcd0e5660bc348a3991c8915ce6599f4f2c6ac22f01d7a10816c/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_aarch64.whl", hash = "sha256:a39455728aabd58ceabb03c90e12f71fd30fa69615760a075b9fec596456ccc3", size = 2129560, upload-time = "2025-11-04T13:41:47.474Z" }, + { url = "https://files.pythonhosted.org/packages/3a/b1/6c990ac65e3b4c079a4fb9f5b05f5b013afa0f4ed6780a3dd236d2cbdc64/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_armv7l.whl", hash = "sha256:239edca560d05757817c13dc17c50766136d21f7cd0fac50295499ae24f90fdf", size = 2329244, upload-time = "2025-11-04T13:41:49.992Z" }, + { url = "https://files.pythonhosted.org/packages/d9/02/3c562f3a51afd4d88fff8dffb1771b30cfdfd79befd9883ee094f5b6c0d8/pydantic_core-2.41.5-cp314-cp314t-musllinux_1_1_x86_64.whl", hash = "sha256:2a5e06546e19f24c6a96a129142a75cee553cc018ffee48a460059b1185f4470", size = 2331955, upload-time = "2025-11-04T13:41:54.079Z" }, + { url = "https://files.pythonhosted.org/packages/5c/96/5fb7d8c3c17bc8c62fdb031c47d77a1af698f1d7a406b0f79aaa1338f9ad/pydantic_core-2.41.5-cp314-cp314t-win32.whl", hash = "sha256:b4ececa40ac28afa90871c2cc2b9ffd2ff0bf749380fbdf57d165fd23da353aa", size = 1988906, upload-time = "2025-11-04T13:41:56.606Z" }, + { url = "https://files.pythonhosted.org/packages/22/ed/182129d83032702912c2e2d8bbe33c036f342cc735737064668585dac28f/pydantic_core-2.41.5-cp314-cp314t-win_amd64.whl", hash = "sha256:80aa89cad80b32a912a65332f64a4450ed00966111b6615ca6816153d3585a8c", size = 1981607, upload-time = "2025-11-04T13:41:58.889Z" }, + { url = "https://files.pythonhosted.org/packages/9f/ed/068e41660b832bb0b1aa5b58011dea2a3fe0ba7861ff38c4d4904c1c1a99/pydantic_core-2.41.5-cp314-cp314t-win_arm64.whl", hash = "sha256:35b44f37a3199f771c3eaa53051bc8a70cd7b54f333531c59e29fd4db5d15008", size = 1974769, upload-time = "2025-11-04T13:42:01.186Z" }, + { url = "https://files.pythonhosted.org/packages/09/32/59b0c7e63e277fa7911c2fc70ccfb45ce4b98991e7ef37110663437005af/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_10_12_x86_64.whl", hash = "sha256:7da7087d756b19037bc2c06edc6c170eeef3c3bafcb8f532ff17d64dc427adfd", size = 2110495, upload-time = "2025-11-04T13:42:49.689Z" }, + { url = "https://files.pythonhosted.org/packages/aa/81/05e400037eaf55ad400bcd318c05bb345b57e708887f07ddb2d20e3f0e98/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-macosx_11_0_arm64.whl", hash = "sha256:aabf5777b5c8ca26f7824cb4a120a740c9588ed58df9b2d196ce92fba42ff8dc", size = 1915388, upload-time = "2025-11-04T13:42:52.215Z" }, + { url = "https://files.pythonhosted.org/packages/6e/0d/e3549b2399f71d56476b77dbf3cf8937cec5cd70536bdc0e374a421d0599/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c007fe8a43d43b3969e8469004e9845944f1a80e6acd47c150856bb87f230c56", size = 1942879, upload-time = "2025-11-04T13:42:56.483Z" }, + { url = "https://files.pythonhosted.org/packages/f7/07/34573da085946b6a313d7c42f82f16e8920bfd730665de2d11c0c37a74b5/pydantic_core-2.41.5-graalpy312-graalpy250_312_native-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76d0819de158cd855d1cbb8fcafdf6f5cf1eb8e470abe056d5d161106e38062b", size = 2139017, upload-time = "2025-11-04T13:42:59.471Z" }, +] + +[[package]] +name = "restate-sdk" +version = "0.14.2" +source = { path = "../dist/restate_sdk-0.14.2-cp313-cp313-macosx_14_0_arm64.whl" } +wheels = [ + { filename = "restate_sdk-0.14.2-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:e01a80d73e8856bd90c103ef99563cd4f5378db8e9638f861ec18ae70c8c4a83" }, +] + +[package.metadata] +requires-dist = [ + { name = "anyio", marker = "extra == 'test'" }, + { name = "dacite", marker = "extra == 'serde'" }, + { name = "google-adk", marker = "extra == 'adk'", specifier = ">=1.20.0" }, + { name = "httpx", marker = "extra == 'harness'" }, + { name = "httpx", extras = ["http2"], marker = "extra == 'client'" }, + { name = "hypercorn", marker = "extra == 'harness'" }, + { name = "hypercorn", marker = "extra == 'test'" }, + { name = "msgspec", marker = "extra == 'serde'" }, + { name = "mypy", marker = "extra == 'lint'", specifier = ">=1.11.2" }, + { name = "openai-agents", marker = "extra == 'openai'", specifier = ">=0.6.1" }, + { name = "pydantic", marker = "extra == 'serde'" }, + { name = "pydantic-ai-slim", marker = "extra == 'pydantic-ai'", specifier = ">=1.35.0" }, + { name = "pyright", marker = "extra == 'lint'", specifier = ">=1.1.390" }, + { name = "pytest", marker = "extra == 'test'" }, + { name = "ruff", marker = "extra == 'lint'", specifier = ">=0.6.9" }, + { name = "testcontainers", marker = "extra == 'harness'" }, +] +provides-extras = ["adk", "client", "harness", "lint", "openai", "pydantic-ai", "serde", "test"] + +[[package]] +name = "typing-extensions" +version = "4.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/72/94/1a15dd82efb362ac84269196e94cf00f187f7ed21c242792a923cdb1c61f/typing_extensions-4.15.0.tar.gz", hash = "sha256:0cea48d173cc12fa28ecabc3b837ea3cf6f38c6d1136f85cbaaf598984861466", size = 109391, upload-time = "2025-08-25T13:49:26.313Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/18/67/36e9267722cc04a6b9f15c7f3441c2363321a3ea07da7ae0c0707beb2a9c/typing_extensions-4.15.0-py3-none-any.whl", hash = "sha256:f0fa19c6845758ab08074a0cfa8b7aecb71c999ca73d62883bc25cc018c4e548", size = 44614, upload-time = "2025-08-25T13:49:24.86Z" }, +] + +[[package]] +name = "typing-inspection" +version = "0.4.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/55/e3/70399cb7dd41c10ac53367ae42139cf4b1ca5f36bb3dc6c9d33acdb43655/typing_inspection-0.4.2.tar.gz", hash = "sha256:ba561c48a67c5958007083d386c3295464928b01faa735ab8547c5692e87f464", size = 75949, upload-time = "2025-10-01T02:14:41.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/dc/9b/47798a6c91d8bdb567fe2698fe81e0c6b7cb7ef4d13da4114b41d239f65d/typing_inspection-0.4.2-py3-none-any.whl", hash = "sha256:4ed1cacbdc298c220f1bd249ed5287caa16f34d44ef4e9c3d0cbad5b521545e7", size = 14611, upload-time = "2025-10-01T02:14:40.154Z" }, +] + +[[package]] +name = "uvicorn" +version = "0.40.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c3/d1/8f3c683c9561a4e6689dd3b1d345c815f10f86acd044ee1fb9a4dcd0b8c5/uvicorn-0.40.0.tar.gz", hash = "sha256:839676675e87e73694518b5574fd0f24c9d97b46bea16df7b8c05ea1a51071ea", size = 81761, upload-time = "2025-12-21T14:16:22.45Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/3d/d8/2083a1daa7439a66f3a48589a57d576aa117726762618f6bb09fe3798796/uvicorn-0.40.0-py3-none-any.whl", hash = "sha256:c6c8f55bc8bf13eb6fa9ff87ad62308bbbc33d0b67f84293151efe87e0d5f2ee", size = 68502, upload-time = "2025-12-21T14:16:21.041Z" }, +] + +[[package]] +name = "wsproto" +version = "1.3.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "h11" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/c7/79/12135bdf8b9c9367b8701c2c19a14c913c120b882d50b014ca0d38083c2c/wsproto-1.3.2.tar.gz", hash = "sha256:b86885dcf294e15204919950f666e06ffc6c7c114ca900b060d6e16293528294", size = 50116, upload-time = "2025-11-20T18:18:01.871Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a4/f5/10b68b7b1544245097b2a1b8238f66f2fc6dcaeb24ba5d917f52bd2eed4f/wsproto-1.3.2-py3-none-any.whl", hash = "sha256:61eea322cdf56e8cc904bd3ad7573359a242ba65688716b0710a5eb12beab584", size = 24405, upload-time = "2025-11-20T18:18:00.454Z" }, +] diff --git a/python/restate/__init__.py b/python/restate/__init__.py index 84173d5..b9fc9cd 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -42,6 +42,7 @@ from .endpoint import app + from .logging import getLogger, RestateLoggingFilter try: @@ -87,6 +88,12 @@ async def create_client( yield # type: ignore +try: + from .admin_client import AdminClient, ServiceInfo, HandlerInfo +except ImportError: + pass + + __all__ = [ "Service", "VirtualObject", @@ -121,4 +128,7 @@ async def create_client( "RestateClientSendHandle", "HttpError", "create_client", + "AdminClient", + "ServiceInfo", + "HandlerInfo", ] diff --git a/python/restate/admin_client.py b/python/restate/admin_client.py new file mode 100644 index 0000000..4239627 --- /dev/null +++ b/python/restate/admin_client.py @@ -0,0 +1,140 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +Client for the Restate Admin API. + +Provides typed access to service and handler metadata +via the Restate admin API (default port 9070). +""" + +from __future__ import annotations + +from typing import Any, Literal + +import httpx +from pydantic import BaseModel, ConfigDict, Field + + +class HandlerInfo(BaseModel): + """Metadata about a handler returned by the Restate admin API.""" + + model_config = ConfigDict(extra="allow") + + name: str + ty: Literal["Exclusive", "Shared", "Workflow"] | None = None + documentation: str | None = None + metadata: dict[str, str] | None = None + input_description: str | None = None + output_description: str | None = None + input_json_schema: dict[str, Any] | None = None + output_json_schema: dict[str, Any] | None = None + + +class ServiceInfo(BaseModel): + """Metadata about a service returned by the Restate admin API.""" + + model_config = ConfigDict(extra="allow") + + name: str + ty: Literal["Service", "VirtualObject", "Workflow"] + handlers: list[HandlerInfo] = Field(default_factory=list) + deployment_id: str | None = None + revision: int | None = None + public: bool | None = None + documentation: str | None = None + metadata: dict[str, str] | None = None + + def get_handler(self, name: str) -> HandlerInfo | None: + """Get a handler by name, or None if not found.""" + for h in self.handlers: + if h.name == name: + return h + return None + + +class ListServicesResponse(BaseModel): + """Response from GET /services.""" + services: list[ServiceInfo] + + +class AdminClient: + """Client for the Restate Admin API. + + Example:: + + async with AdminClient("http://localhost:9070") as client: + services = await client.list_services() + for svc in services: + print(f"{svc.name} ({svc.ty}): {len(svc.handlers)} handlers") + for h in svc.handlers: + print(f" - {h.name} metadata={h.metadata}") + """ + + def __init__(self, admin_url: str): + self._admin_url = admin_url.rstrip("/") + self._client: httpx.AsyncClient | None = None + self._owns_client = False + + @classmethod + def from_client(cls, admin_url: str, client: httpx.AsyncClient) -> AdminClient: + """Create an AdminClient using an existing httpx client.""" + instance = cls(admin_url) + instance._client = client + instance._owns_client = False + return instance + + async def _get_client(self) -> httpx.AsyncClient: + if self._client is None: + self._client = httpx.AsyncClient(base_url=self._admin_url) + self._owns_client = True + return self._client + + async def close(self) -> None: + """Close the underlying HTTP client if we own it.""" + if self._client is not None and self._owns_client: + await self._client.aclose() + self._client = None + + async def __aenter__(self) -> AdminClient: + return self + + async def __aexit__(self, *args: Any) -> None: + await self.close() + + async def list_services(self) -> list[ServiceInfo]: + """List all registered services with their handlers and metadata. + + Returns: + A list of ServiceInfo objects, each containing the service's + handlers, metadata, documentation, and configuration. + """ + client = await self._get_client() + response = await client.get("/services") + response.raise_for_status() + parsed = ListServicesResponse.model_validate(response.json()) + return parsed.services + + async def get_service(self, name: str) -> ServiceInfo: + """Get detailed information about a specific service. + + Args: + name: The service name. + + Returns: + A ServiceInfo object with full handler and metadata details. + + Raises: + httpx.HTTPStatusError: If the service is not found (404) or other errors. + """ + client = await self._get_client() + response = await client.get(f"/services/{name}") + response.raise_for_status() + return ServiceInfo.model_validate(response.json()) diff --git a/python/restate/cls.py b/python/restate/cls.py new file mode 100644 index 0000000..a7d7505 --- /dev/null +++ b/python/restate/cls.py @@ -0,0 +1,928 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +Class-based API for defining Restate services. + +This module provides an alternative to the decorator-based API, allowing +services to be defined as classes with handler methods. + +Example:: + + from restate.cls import Service, VirtualObject, Workflow, handler, shared, main + import restate + + class Greeter(Service): + @handler + async def greet(self, name: str) -> str: + return f"Hello {name}!" + + class Counter(VirtualObject): + @handler + async def increment(self, value: int) -> int: + n = await restate.get("counter", type_hint=int) or 0 + n += value + restate.set("counter", n) + return n + + @shared + async def count(self) -> int: + return await restate.get("counter", type_hint=int) or 0 +""" + +from __future__ import annotations + +import inspect +from dataclasses import dataclass, field +from datetime import timedelta +from functools import wraps +from typing import Any, AsyncContextManager, Callable, Dict, List, Literal, Optional, TypeVar + +from restate.handler import HandlerIO, ServiceTag, make_handler +from restate.retry_policy import InvocationRetryPolicy +from restate.serde import DefaultSerde, Serde + +# Imports for type annotations only — the actual classes are used +# as companion objects inside __init_subclass__. +from restate.service import Service as _OriginalService +from restate.object import VirtualObject as _OriginalVirtualObject +from restate.workflow import Workflow as _OriginalWorkflow + +I = TypeVar("I") +O = TypeVar("O") +T = TypeVar("T") + +# ── Handler marker decorators ────────────────────────────────────────────── + +_HANDLER_MARKER = "__restate_handler_meta__" + +_MISSING = object() + + +@dataclass +class _HandlerMeta: + """Metadata attached to a method by @handler / @shared / @main.""" + + kind: Literal["handler", "shared", "main"] + name: Optional[str] = None + accept: str = "application/json" + content_type: str = "application/json" + input_serde: Serde = field(default_factory=DefaultSerde) + output_serde: Serde = field(default_factory=DefaultSerde) + metadata: Optional[Dict[str, str]] = None + inactivity_timeout: Optional[timedelta] = None + abort_timeout: Optional[timedelta] = None + journal_retention: Optional[timedelta] = None + idempotency_retention: Optional[timedelta] = None + workflow_retention: Optional[timedelta] = None + enable_lazy_state: Optional[bool] = None + ingress_private: Optional[bool] = None + invocation_retry_policy: Optional[InvocationRetryPolicy] = None + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None + + +def handler( + fn=None, + *, + name: Optional[str] = None, + accept: str = "application/json", + content_type: str = "application/json", + input_serde: Serde[I] = DefaultSerde(), + output_serde: Serde[O] = DefaultSerde(), + metadata: Optional[Dict[str, str]] = None, + inactivity_timeout: Optional[timedelta] = None, + abort_timeout: Optional[timedelta] = None, + journal_retention: Optional[timedelta] = None, + idempotency_retention: Optional[timedelta] = None, + enable_lazy_state: Optional[bool] = None, + ingress_private: Optional[bool] = None, + invocation_retry_policy: Optional[InvocationRetryPolicy] = None, + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None, +): + """Mark a method as a Restate handler. + + For Service: a regular handler. + For VirtualObject: an exclusive handler. + For Workflow: a shared handler. + + Can be used as ``@handler`` or ``@handler(name="customName")``. + Supports the same parameters as the decorator-based API. + """ + + def decorator(fn): + setattr( + fn, + _HANDLER_MARKER, + _HandlerMeta( + kind="handler", + name=name, + accept=accept, + content_type=content_type, + input_serde=input_serde, + output_serde=output_serde, + metadata=metadata, + inactivity_timeout=inactivity_timeout, + abort_timeout=abort_timeout, + journal_retention=journal_retention, + idempotency_retention=idempotency_retention, + enable_lazy_state=enable_lazy_state, + ingress_private=ingress_private, + invocation_retry_policy=invocation_retry_policy, + invocation_context_managers=invocation_context_managers, + ), + ) + return fn + + if fn is not None: + return decorator(fn) + return decorator + + +def shared( + fn=None, + *, + name: Optional[str] = None, + accept: str = "application/json", + content_type: str = "application/json", + input_serde: Serde[I] = DefaultSerde(), + output_serde: Serde[O] = DefaultSerde(), + metadata: Optional[Dict[str, str]] = None, + inactivity_timeout: Optional[timedelta] = None, + abort_timeout: Optional[timedelta] = None, + journal_retention: Optional[timedelta] = None, + idempotency_retention: Optional[timedelta] = None, + enable_lazy_state: Optional[bool] = None, + ingress_private: Optional[bool] = None, + invocation_retry_policy: Optional[InvocationRetryPolicy] = None, + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None, +): + """Mark a method as a shared (read-only) handler on a VirtualObject or Workflow.""" + + def decorator(fn): + setattr( + fn, + _HANDLER_MARKER, + _HandlerMeta( + kind="shared", + name=name, + accept=accept, + content_type=content_type, + input_serde=input_serde, + output_serde=output_serde, + metadata=metadata, + inactivity_timeout=inactivity_timeout, + abort_timeout=abort_timeout, + journal_retention=journal_retention, + idempotency_retention=idempotency_retention, + enable_lazy_state=enable_lazy_state, + ingress_private=ingress_private, + invocation_retry_policy=invocation_retry_policy, + invocation_context_managers=invocation_context_managers, + ), + ) + return fn + + if fn is not None: + return decorator(fn) + return decorator + + +def main( + fn=None, + *, + name: Optional[str] = None, + accept: str = "application/json", + content_type: str = "application/json", + input_serde: Serde[I] = DefaultSerde(), + output_serde: Serde[O] = DefaultSerde(), + metadata: Optional[Dict[str, str]] = None, + inactivity_timeout: Optional[timedelta] = None, + abort_timeout: Optional[timedelta] = None, + journal_retention: Optional[timedelta] = None, + workflow_retention: Optional[timedelta] = None, + enable_lazy_state: Optional[bool] = None, + ingress_private: Optional[bool] = None, + invocation_retry_policy: Optional[InvocationRetryPolicy] = None, + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None, +): + """Mark a method as the workflow entry point.""" + + def decorator(fn): + setattr( + fn, + _HANDLER_MARKER, + _HandlerMeta( + kind="main", + name=name, + accept=accept, + content_type=content_type, + input_serde=input_serde, + output_serde=output_serde, + metadata=metadata, + inactivity_timeout=inactivity_timeout, + abort_timeout=abort_timeout, + journal_retention=journal_retention, + workflow_retention=workflow_retention, + enable_lazy_state=enable_lazy_state, + ingress_private=ingress_private, + invocation_retry_policy=invocation_retry_policy, + invocation_context_managers=invocation_context_managers, + ), + ) + return fn + + if fn is not None: + return decorator(fn) + return decorator + + +# ── Class processing ─────────────────────────────────────────────────────── + + +def _resolve_handler_kind( + service_kind: Literal["service", "object", "workflow"], + marker_kind: Literal["handler", "shared", "main"], +) -> Optional[Literal["exclusive", "shared", "workflow"]]: + """Map (service_type, marker) → Handler.kind value.""" + if service_kind == "service": + return None + if service_kind == "object": + if marker_kind == "handler": + return "exclusive" + if marker_kind == "shared": + return "shared" + raise ValueError(f"VirtualObject does not support @{marker_kind}") + if service_kind == "workflow": + if marker_kind == "main": + return "workflow" + if marker_kind == "handler": + return "shared" + if marker_kind == "shared": + return "shared" + raise ValueError(f"Workflow does not support @{marker_kind}") + raise ValueError(f"Unknown service kind: {service_kind}") + + +@dataclass +class _ServiceConfig: + """Service-level configuration extracted from __init_subclass__ kwargs.""" + + name: Optional[str] = None + description: Optional[str] = None + metadata: Optional[Dict[str, str]] = None + inactivity_timeout: Optional[timedelta] = None + abort_timeout: Optional[timedelta] = None + journal_retention: Optional[timedelta] = None + idempotency_retention: Optional[timedelta] = None + enable_lazy_state: Optional[bool] = None + ingress_private: Optional[bool] = None + invocation_retry_policy: Optional[InvocationRetryPolicy] = None + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None + + +def _process_class( + cls: type, + service_kind: Literal["service", "object", "workflow"], + config: _ServiceConfig, +) -> None: + """Scan *cls* for marked methods and build the companion service object.""" + name = config.name or cls.__name__ + + service_tag = ServiceTag( + kind=service_kind, + name=name, + description=config.description, + metadata=config.metadata, + ) + handlers: Dict[str, Any] = {} + + for attr_name, attr_value in list(cls.__dict__.items()): + meta: Optional[_HandlerMeta] = getattr(attr_value, _HANDLER_MARKER, None) + if meta is None: + continue + + method = attr_value # unbound function + handler_kind = _resolve_handler_kind(service_kind, meta.kind) + handler_name = meta.name or method.__name__ + + # Create a wrapper that instantiates the class and calls the method. + # The wrapper has signature (ctx, *args) matching what invoke_handler expects. + @wraps(method) + async def wrapper(ctx, *args, _method=method, _cls=cls): + instance = object.__new__(_cls) + if args: + return await _method(instance, *args) + return await _method(instance) + + # Use the original method's signature for type/serde inspection + sig = inspect.signature(method, eval_str=True) + handler_io: HandlerIO = HandlerIO( + accept=meta.accept, + content_type=meta.content_type, + input_serde=meta.input_serde, + output_serde=meta.output_serde, + ) + + # Combine service-level and handler-level context managers + combined_context_managers = ( + (config.invocation_context_managers or []) + (meta.invocation_context_managers or []) + if config.invocation_context_managers or meta.invocation_context_managers + else None + ) + + h = make_handler( + service_tag=service_tag, + handler_io=handler_io, + name=handler_name, + kind=handler_kind, + wrapped=wrapper, + signature=sig, + description=inspect.getdoc(method), + metadata=meta.metadata, + inactivity_timeout=meta.inactivity_timeout, + abort_timeout=meta.abort_timeout, + journal_retention=meta.journal_retention, + idempotency_retention=meta.idempotency_retention, + workflow_retention=meta.workflow_retention, + enable_lazy_state=meta.enable_lazy_state, + ingress_private=meta.ingress_private, + invocation_retry_policy=meta.invocation_retry_policy, + context_managers=combined_context_managers, + ) + handlers[h.name] = h + + # Store handlers on the class for proxy access + cls._restate_handlers = handlers # type: ignore[attr-defined] + + # Build companion service object of the original type + svc: _OriginalService | _OriginalVirtualObject | _OriginalWorkflow + if service_kind == "service": + svc = _OriginalService( + name, + description=config.description, + metadata=config.metadata, + inactivity_timeout=config.inactivity_timeout, + abort_timeout=config.abort_timeout, + journal_retention=config.journal_retention, + idempotency_retention=config.idempotency_retention, + ingress_private=config.ingress_private, + invocation_retry_policy=config.invocation_retry_policy, + invocation_context_managers=config.invocation_context_managers, + ) + elif service_kind == "object": + svc = _OriginalVirtualObject( + name, + description=config.description, + metadata=config.metadata, + inactivity_timeout=config.inactivity_timeout, + abort_timeout=config.abort_timeout, + journal_retention=config.journal_retention, + idempotency_retention=config.idempotency_retention, + enable_lazy_state=config.enable_lazy_state, + ingress_private=config.ingress_private, + invocation_retry_policy=config.invocation_retry_policy, + invocation_context_managers=config.invocation_context_managers, + ) + elif service_kind == "workflow": + svc = _OriginalWorkflow( + name, + description=config.description, + metadata=config.metadata, + inactivity_timeout=config.inactivity_timeout, + abort_timeout=config.abort_timeout, + journal_retention=config.journal_retention, + idempotency_retention=config.idempotency_retention, + enable_lazy_state=config.enable_lazy_state, + ingress_private=config.ingress_private, + invocation_retry_policy=config.invocation_retry_policy, + invocation_context_managers=config.invocation_context_managers, + ) + else: + raise ValueError(f"Unknown service kind: {service_kind}") + + svc.handlers = handlers + cls._restate_service = svc # type: ignore[attr-defined] + + +# ── Fluent RPC proxy classes ────────────────────────────────────────────── + + +class _ServiceCallProxy: + """Proxy returned by Service.call() for type-safe RPC.""" + + def __init__(self, cls: type) -> None: + self._cls = cls + + def __getattr__(self, name: str): + handlers = getattr(self._cls, "_restate_handlers", {}) + h = handlers.get(name) + if h is None: + raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + ctx = _restate_context_var.get() + + def invoke(arg=_MISSING): + if arg is _MISSING: + return ctx.service_call(h.fn, arg=None) + return ctx.service_call(h.fn, arg=arg) + + return invoke + + +class _ServiceSendProxy: + """Proxy returned by Service.send() for fire-and-forget.""" + + def __init__(self, cls: type, delay: Optional[timedelta] = None) -> None: + self._cls = cls + self._delay = delay + + def __getattr__(self, name: str): + handlers = getattr(self._cls, "_restate_handlers", {}) + h = handlers.get(name) + if h is None: + raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + ctx = _restate_context_var.get() + + def invoke(arg=_MISSING): + if arg is _MISSING: + return ctx.service_send(h.fn, arg=None, send_delay=self._delay) + return ctx.service_send(h.fn, arg=arg, send_delay=self._delay) + + return invoke + + +class _ObjectCallProxy: + """Proxy returned by VirtualObject.call(key) for type-safe RPC.""" + + def __init__(self, cls: type, key: str) -> None: + self._cls = cls + self._key = key + + def __getattr__(self, name: str): + handlers = getattr(self._cls, "_restate_handlers", {}) + h = handlers.get(name) + if h is None: + raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + ctx = _restate_context_var.get() + + def invoke(arg=_MISSING): + if arg is _MISSING: + return ctx.object_call(h.fn, key=self._key, arg=None) + return ctx.object_call(h.fn, key=self._key, arg=arg) + + return invoke + + +class _ObjectSendProxy: + """Proxy returned by VirtualObject.send(key) for fire-and-forget.""" + + def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> None: + self._cls = cls + self._key = key + self._delay = delay + + def __getattr__(self, name: str): + handlers = getattr(self._cls, "_restate_handlers", {}) + h = handlers.get(name) + if h is None: + raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + ctx = _restate_context_var.get() + + def invoke(arg=_MISSING): + if arg is _MISSING: + return ctx.object_send(h.fn, key=self._key, arg=None, send_delay=self._delay) + return ctx.object_send(h.fn, key=self._key, arg=arg, send_delay=self._delay) + + return invoke + + +class _WorkflowCallProxy: + """Proxy returned by Workflow.call(key) for type-safe RPC.""" + + def __init__(self, cls: type, key: str) -> None: + self._cls = cls + self._key = key + + def __getattr__(self, name: str): + handlers = getattr(self._cls, "_restate_handlers", {}) + h = handlers.get(name) + if h is None: + raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + ctx = _restate_context_var.get() + + def invoke(arg=_MISSING): + if arg is _MISSING: + return ctx.workflow_call(h.fn, key=self._key, arg=None) + return ctx.workflow_call(h.fn, key=self._key, arg=arg) + + return invoke + + +class _WorkflowSendProxy: + """Proxy returned by Workflow.send(key) for fire-and-forget.""" + + def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> None: + self._cls = cls + self._key = key + self._delay = delay + + def __getattr__(self, name: str): + handlers = getattr(self._cls, "_restate_handlers", {}) + h = handlers.get(name) + if h is None: + raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + ctx = _restate_context_var.get() + + def invoke(arg=_MISSING): + if arg is _MISSING: + return ctx.workflow_send(h.fn, key=self._key, arg=None, send_delay=self._delay) + return ctx.workflow_send(h.fn, key=self._key, arg=arg, send_delay=self._delay) + + return invoke + + +# ── Base classes ─────────────────────────────────────────────────────────── + + +class Service: + """Base class for class-based Restate services. + + Supports the same service-level configuration as the decorator-based API:: + + class Greeter(Service, name="MyGreeter", ingress_private=True): + @handler + async def greet(self, name: str) -> str: + return f"Hello {name}!" + + app = restate.app([Greeter]) + """ + + _restate_service: _OriginalService + _restate_handlers: Dict[str, Any] + + def __init_subclass__( + cls, + *, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + inactivity_timeout: Optional[timedelta] = None, + abort_timeout: Optional[timedelta] = None, + journal_retention: Optional[timedelta] = None, + idempotency_retention: Optional[timedelta] = None, + ingress_private: Optional[bool] = None, + invocation_retry_policy: Optional[InvocationRetryPolicy] = None, + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None, + **kwargs: Any, + ) -> None: + super().__init_subclass__(**kwargs) + _process_class( + cls, + "service", + _ServiceConfig( + name=name, + description=description, + metadata=metadata, + inactivity_timeout=inactivity_timeout, + abort_timeout=abort_timeout, + journal_retention=journal_retention, + idempotency_retention=idempotency_retention, + ingress_private=ingress_private, + invocation_retry_policy=invocation_retry_policy, + invocation_context_managers=invocation_context_managers, + ), + ) + + @classmethod + def call(cls) -> "Service": # type: ignore[return-type] + """Return a proxy for making durable service calls. + + The proxy has the same method signatures as the class, + giving full IDE autocomplete and type inference. + """ + return _ServiceCallProxy(cls) # type: ignore[return-value] + + @classmethod + def send(cls, *, delay: Optional[timedelta] = None) -> "Service": # type: ignore[return-type] + """Return a proxy for fire-and-forget service sends.""" + return _ServiceSendProxy(cls, delay) # type: ignore[return-value] + + +class VirtualObject: + """Base class for class-based Restate virtual objects. + + Supports the same service-level configuration as the decorator-based API:: + + class Counter(VirtualObject, name="MyCounter", enable_lazy_state=True): + @handler + async def increment(self, value: int) -> int: ... + + app = restate.app([Counter]) + """ + + _restate_service: _OriginalVirtualObject + _restate_handlers: Dict[str, Any] + + def __init_subclass__( + cls, + *, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + inactivity_timeout: Optional[timedelta] = None, + abort_timeout: Optional[timedelta] = None, + journal_retention: Optional[timedelta] = None, + idempotency_retention: Optional[timedelta] = None, + enable_lazy_state: Optional[bool] = None, + ingress_private: Optional[bool] = None, + invocation_retry_policy: Optional[InvocationRetryPolicy] = None, + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None, + **kwargs: Any, + ) -> None: + super().__init_subclass__(**kwargs) + _process_class( + cls, + "object", + _ServiceConfig( + name=name, + description=description, + metadata=metadata, + inactivity_timeout=inactivity_timeout, + abort_timeout=abort_timeout, + journal_retention=journal_retention, + idempotency_retention=idempotency_retention, + enable_lazy_state=enable_lazy_state, + ingress_private=ingress_private, + invocation_retry_policy=invocation_retry_policy, + invocation_context_managers=invocation_context_managers, + ), + ) + + @classmethod + def call(cls, key: str) -> "VirtualObject": # type: ignore[return-type] + """Return a proxy for making durable object calls.""" + return _ObjectCallProxy(cls, key) # type: ignore[return-value] + + @classmethod + def send(cls, key: str, *, delay: Optional[timedelta] = None) -> "VirtualObject": # type: ignore[return-type] + """Return a proxy for fire-and-forget object sends.""" + return _ObjectSendProxy(cls, key, delay) # type: ignore[return-value] + + +class Workflow: + """Base class for class-based Restate workflows. + + Supports the same service-level configuration as the decorator-based API:: + + class Payment(Workflow, name="MyPayment"): + @main + async def pay(self, amount: int) -> dict: ... + + app = restate.app([Payment]) + """ + + _restate_service: _OriginalWorkflow + _restate_handlers: Dict[str, Any] + + def __init_subclass__( + cls, + *, + name: Optional[str] = None, + description: Optional[str] = None, + metadata: Optional[Dict[str, str]] = None, + inactivity_timeout: Optional[timedelta] = None, + abort_timeout: Optional[timedelta] = None, + journal_retention: Optional[timedelta] = None, + idempotency_retention: Optional[timedelta] = None, + enable_lazy_state: Optional[bool] = None, + ingress_private: Optional[bool] = None, + invocation_retry_policy: Optional[InvocationRetryPolicy] = None, + invocation_context_managers: Optional[List[Callable[[], AsyncContextManager[None]]]] = None, + **kwargs: Any, + ) -> None: + super().__init_subclass__(**kwargs) + _process_class( + cls, + "workflow", + _ServiceConfig( + name=name, + description=description, + metadata=metadata, + inactivity_timeout=inactivity_timeout, + abort_timeout=abort_timeout, + journal_retention=journal_retention, + idempotency_retention=idempotency_retention, + enable_lazy_state=enable_lazy_state, + ingress_private=ingress_private, + invocation_retry_policy=invocation_retry_policy, + invocation_context_managers=invocation_context_managers, + ), + ) + + @classmethod + def call(cls, key: str) -> "Workflow": # type: ignore[return-type] + """Return a proxy for making durable workflow calls.""" + return _WorkflowCallProxy(cls, key) # type: ignore[return-value] + + @classmethod + def send(cls, key: str, *, delay: Optional[timedelta] = None) -> "Workflow": # type: ignore[return-type] + """Return a proxy for fire-and-forget workflow sends.""" + return _WorkflowSendProxy(cls, key, delay) # type: ignore[return-value] + + +# ── Context accessor class ──────────────────────────────────────────────── + + +class Context: + """Static accessor for the current Restate invocation context. + + Use from within handler methods to access Restate functionality + without an explicit ``ctx`` parameter:: + + from restate.cls import Service, handler, Context + + class Greeter(Service): + @handler + async def greet(self, name: str) -> str: + count = await Context.get("visits", type_hint=int) or 0 + Context.set("visits", count + 1) + return f"Hello {name}!" + """ + + @staticmethod + def _ctx() -> Any: + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + try: + return _restate_context_var.get() + except LookupError: + raise RuntimeError( + "Not inside a Restate handler. Context methods can only be called within a handler invocation." + ) from None + + # ── State ── + + @staticmethod + def get(name: str, serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: + """Retrieve a state value by name.""" + return Context._ctx().get(name, serde=serde, type_hint=type_hint) + + @staticmethod + def set(name: str, value: Any, serde: Serde = DefaultSerde()) -> None: + """Set a state value by name.""" + Context._ctx().set(name, value, serde=serde) + + @staticmethod + def clear(name: str) -> None: + """Clear a state value by name.""" + Context._ctx().clear(name) + + @staticmethod + def clear_all() -> None: + """Clear all state values.""" + Context._ctx().clear_all() + + @staticmethod + def state_keys() -> Any: + """Return the list of state keys.""" + return Context._ctx().state_keys() + + # ── Identity & request ── + + @staticmethod + def key() -> str: + """Return the key of the current virtual object or workflow.""" + return Context._ctx().key() + + @staticmethod + def request() -> Any: + """Return the current request object.""" + return Context._ctx().request() + + @staticmethod + def random() -> Any: + """Return a deterministically-seeded Random instance.""" + return Context._ctx().random() + + @staticmethod + def uuid() -> Any: + """Return a deterministic UUID, stable across retries.""" + return Context._ctx().uuid() + + @staticmethod + def time() -> Any: + """Return a durable timestamp, stable across retries.""" + return Context._ctx().time() + + # ── Durable execution ── + + @staticmethod + def run(name: str, action: Any, serde: Serde = DefaultSerde(), **kwargs: Any) -> Any: + """Run a durable side effect (deprecated — use run_typed).""" + return Context._ctx().run(name, action, serde=serde, **kwargs) + + @staticmethod + def run_typed(name: str, action: Any, *args: Any, **kwargs: Any) -> Any: + """Run a durable side effect with typed arguments.""" + return Context._ctx().run_typed(name, action, *args, **kwargs) + + @staticmethod + def sleep(delta: timedelta, name: Optional[str] = None) -> Any: + """Suspend the current invocation for the given duration.""" + return Context._ctx().sleep(delta, name=name) + + # ── Service communication ── + + @staticmethod + def service_call(tpe: Any, arg: Any, **kwargs: Any) -> Any: + """Call a service handler.""" + return Context._ctx().service_call(tpe, arg=arg, **kwargs) + + @staticmethod + def service_send(tpe: Any, arg: Any, **kwargs: Any) -> Any: + """Send a message to a service handler (fire-and-forget).""" + return Context._ctx().service_send(tpe, arg=arg, **kwargs) + + @staticmethod + def object_call(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + """Call a virtual object handler.""" + return Context._ctx().object_call(tpe, key=key, arg=arg, **kwargs) + + @staticmethod + def object_send(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + """Send a message to a virtual object handler (fire-and-forget).""" + return Context._ctx().object_send(tpe, key=key, arg=arg, **kwargs) + + @staticmethod + def workflow_call(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + """Call a workflow handler.""" + return Context._ctx().workflow_call(tpe, key=key, arg=arg, **kwargs) + + @staticmethod + def workflow_send(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + """Send a message to a workflow handler (fire-and-forget).""" + return Context._ctx().workflow_send(tpe, key=key, arg=arg, **kwargs) + + @staticmethod + def generic_call(service: str, handler: str, arg: bytes, key: Optional[str] = None, **kwargs: Any) -> Any: + """Call a generic service/handler with raw bytes.""" + return Context._ctx().generic_call(service, handler, arg, key=key, **kwargs) + + @staticmethod + def generic_send(service: str, handler: str, arg: bytes, key: Optional[str] = None, **kwargs: Any) -> Any: + """Send a message to a generic service/handler with raw bytes.""" + return Context._ctx().generic_send(service, handler, arg, key=key, **kwargs) + + # ── Awakeables ── + + @staticmethod + def awakeable(serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: + """Create an awakeable and return (id, future).""" + return Context._ctx().awakeable(serde=serde, type_hint=type_hint) + + @staticmethod + def resolve_awakeable(name: str, value: Any, serde: Serde = DefaultSerde()) -> None: + """Resolve an awakeable by id.""" + Context._ctx().resolve_awakeable(name, value, serde=serde) + + @staticmethod + def reject_awakeable(name: str, failure_message: str, failure_code: int = 500) -> None: + """Reject an awakeable by id.""" + Context._ctx().reject_awakeable(name, failure_message, failure_code=failure_code) + + # ── Promises (Workflow only) ── + + @staticmethod + def promise(name: str, serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: + """Return a durable promise (workflow handlers only).""" + return Context._ctx().promise(name, serde=serde, type_hint=type_hint) + + # ── Invocation management ── + + @staticmethod + def cancel_invocation(invocation_id: str) -> None: + """Cancel an invocation by id.""" + Context._ctx().cancel_invocation(invocation_id) + + @staticmethod + def attach_invocation(invocation_id: str, serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: + """Attach to an invocation by id.""" + return Context._ctx().attach_invocation(invocation_id, serde=serde, type_hint=type_hint) diff --git a/python/restate/context_access.py b/python/restate/context_access.py new file mode 100644 index 0000000..1ebb2d8 --- /dev/null +++ b/python/restate/context_access.py @@ -0,0 +1,307 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +Module-level context accessor functions for the class-based API. + +These functions delegate to the current invocation context via contextvars, +allowing handlers to access Restate functionality without an explicit ctx parameter. +""" + +from datetime import timedelta +from random import Random +from typing import Any, Awaitable, Callable, Coroutine, Dict, List, Optional, Tuple, TypeVar, Union +from uuid import UUID + +from restate.context import ( + DurablePromise, + HandlerType, + RestateDurableCallFuture, + RestateDurableFuture, + RestateDurableSleepFuture, + Request, + RunAction, + RunOptions, + SendHandle, +) +from restate.serde import DefaultSerde, Serde + +T = TypeVar("T") +I = TypeVar("I") +O = TypeVar("O") + + +def _ctx() -> Any: + """Get the current restate context, raising if not inside a handler. + + Returns Any because the actual runtime type is ServerInvocationContext + (which implements ObjectContext, WorkflowContext, etc.) but we want all + methods accessible without narrowing — runtime raises if mismatched. + """ + # Import here to avoid circular imports + from restate.server_context import _restate_context_var # pylint: disable=C0415 + + try: + return _restate_context_var.get() + except LookupError: + raise RuntimeError( + "Not inside a Restate handler. " + "Module-level restate functions can only be called within a handler invocation." + ) from None + + +def current_context(): + """Get the current Restate context. + + Returns the context object for the current handler invocation. + Raises RuntimeError if called outside a handler. + """ + return _ctx() + + +# ── State operations ── + + +def get(name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None) -> Awaitable[Optional[T]]: + """Retrieve a state value by name.""" + return _ctx().get(name, serde=serde, type_hint=type_hint) + + +def set(name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None: + """Set a state value by name.""" + _ctx().set(name, value, serde=serde) + + +def clear(name: str) -> None: + """Clear a state value by name.""" + _ctx().clear(name) + + +def clear_all() -> None: + """Clear all state values.""" + _ctx().clear_all() + + +def state_keys() -> Awaitable[List[str]]: + """Return the list of state keys.""" + return _ctx().state_keys() + + +# ── Identity & request ── + + +def key() -> str: + """Return the key of the current virtual object or workflow.""" + return _ctx().key() + + +def request() -> Request: + """Return the current request object.""" + return _ctx().request() + + +def random() -> Random: + """Return a deterministically-seeded Random instance.""" + return _ctx().random() + + +def uuid() -> UUID: + """Return a deterministic UUID, stable across retries.""" + return _ctx().uuid() + + +def time() -> RestateDurableFuture[float]: + """Return a durable timestamp, stable across retries.""" + return _ctx().time() + + +# ── Durable execution ── + + +def run( + name: str, + action: RunAction[T], + serde: Serde[T] = DefaultSerde(), + max_attempts: Optional[int] = None, + max_retry_duration: Optional[timedelta] = None, + type_hint: Optional[type] = None, + args: Optional[tuple] = None, +) -> RestateDurableFuture[T]: + """Run a durable side effect (deprecated — use run_typed instead).""" + return _ctx().run( + name, + action, + serde=serde, + max_attempts=max_attempts, + max_retry_duration=max_retry_duration, + type_hint=type_hint, + args=args, + ) + + +def run_typed( + name: str, + action: Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]], + options: RunOptions[T] = RunOptions(), + /, + *args: Any, + **kwargs: Any, +) -> RestateDurableFuture[T]: + """Run a durable side effect with typed arguments.""" + return _ctx().run_typed(name, action, options, *args, **kwargs) + + +def sleep(delta: timedelta, name: Optional[str] = None) -> RestateDurableSleepFuture: + """Suspend the current invocation for the given duration.""" + return _ctx().sleep(delta, name=name) + + +# ── Service communication ── + + +def service_call( + tpe: HandlerType[I, O], + arg: I, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> RestateDurableCallFuture[O]: + """Call a service handler.""" + return _ctx().service_call(tpe, arg=arg, idempotency_key=idempotency_key, headers=headers) + + +def service_send( + tpe: HandlerType[I, O], + arg: I, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> SendHandle: + """Send a message to a service handler (fire-and-forget).""" + return _ctx().service_send(tpe, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers) + + +def object_call( + tpe: HandlerType[I, O], + key: str, + arg: I, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> RestateDurableCallFuture[O]: + """Call a virtual object handler.""" + return _ctx().object_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) + + +def object_send( + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> SendHandle: + """Send a message to a virtual object handler (fire-and-forget).""" + return _ctx().object_send( + tpe, key=key, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) + + +def workflow_call( + tpe: HandlerType[I, O], + key: str, + arg: I, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> RestateDurableCallFuture[O]: + """Call a workflow handler.""" + return _ctx().workflow_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) + + +def workflow_send( + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> SendHandle: + """Send a message to a workflow handler (fire-and-forget).""" + return _ctx().workflow_send( + tpe, key=key, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) + + +def generic_call( + service: str, + handler: str, + arg: bytes, + key: Optional[str] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> RestateDurableCallFuture[bytes]: + """Call a generic service/handler with raw bytes.""" + return _ctx().generic_call(service, handler, arg, key=key, idempotency_key=idempotency_key, headers=headers) + + +def generic_send( + service: str, + handler: str, + arg: bytes, + key: Optional[str] = None, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, +) -> SendHandle: + """Send a message to a generic service/handler with raw bytes.""" + return _ctx().generic_send( + service, handler, arg, key=key, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) + + +# ── Awakeables ── + + +def awakeable( + serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None +) -> Tuple[str, RestateDurableFuture[T]]: + """Create an awakeable and return (id, future).""" + return _ctx().awakeable(serde=serde, type_hint=type_hint) + + +def resolve_awakeable(name: str, value: I, serde: Serde[I] = DefaultSerde()) -> None: + """Resolve an awakeable by id.""" + _ctx().resolve_awakeable(name, value, serde=serde) + + +def reject_awakeable(name: str, failure_message: str, failure_code: int = 500) -> None: + """Reject an awakeable by id.""" + _ctx().reject_awakeable(name, failure_message, failure_code=failure_code) + + +# ── Promises (Workflow only) ── + + +def promise(name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None) -> DurablePromise[T]: + """Return a durable promise (workflow handlers only).""" + return _ctx().promise(name, serde=serde, type_hint=type_hint) + + +# ── Invocation management ── + + +def cancel_invocation(invocation_id: str): + """Cancel an invocation by id.""" + _ctx().cancel_invocation(invocation_id) + + +def attach_invocation( + invocation_id: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None +) -> RestateDurableFuture[T]: + """Attach to an invocation by id.""" + return _ctx().attach_invocation(invocation_id, serde=serde, type_hint=type_hint) diff --git a/python/restate/endpoint.py b/python/restate/endpoint.py index f8c3a83..0b59e96 100644 --- a/python/restate/endpoint.py +++ b/python/restate/endpoint.py @@ -44,12 +44,13 @@ def __init__(self): self.identity_keys = [] - def bind(self, *services: typing.Union[Service, VirtualObject, Workflow]): + def bind(self, *services: typing.Any): """ Bind a service to the endpoint Args: - service: The service or virtual object to bind to the endpoint + service: The service or virtual object to bind to the endpoint. + Also accepts class-based services (subclasses of restate.cls.Service etc.) Raises: ValueError: If a service with the same name already exists in the endpoint @@ -58,10 +59,12 @@ def bind(self, *services: typing.Union[Service, VirtualObject, Workflow]): The updated Endpoint instance """ for service in services: - if service.name in self.services: - raise ValueError(f"Service {service.name} already exists") - if isinstance(service, (Service, VirtualObject, Workflow)): - self.services[service.name] = service + # Support class-based services: extract companion object + actual = getattr(service, "_restate_service", service) + if actual.name in self.services: + raise ValueError(f"Service {actual.name} already exists") + if isinstance(actual, (Service, VirtualObject, Workflow)): + self.services[actual.name] = actual else: raise ValueError(f"Invalid service type {service}") return self @@ -98,7 +101,7 @@ def app(self): def app( - services: typing.Iterable[typing.Union[Service, VirtualObject, Workflow]], + services: typing.Iterable[typing.Any], protocol: typing.Optional[typing.Literal["bidi", "request_response"]] = None, identity_keys: typing.Optional[typing.List[str]] = None, ): diff --git a/python/restate/ext/a2a/PLAN.md b/python/restate/ext/a2a/PLAN.md new file mode 100644 index 0000000..2a82c54 --- /dev/null +++ b/python/restate/ext/a2a/PLAN.md @@ -0,0 +1,247 @@ +# A2A Integration for Restate Python SDK — v3 + +## Context + +Adding an A2A integration that combines the best of both approaches: +- **Inside Restate**: A `TaskObject` (VirtualObject) manages task state, durability, and agent invocation +- **Outside Restate**: A thin FastAPI gateway handles A2A protocol (JSON-RPC, agent card catalog) and discovers agents via the Restate admin API + +The gateway translates A2A JSON-RPC into Restate ingress HTTP calls. No `AgentExecutor`, no `DefaultRequestHandler`, no A2A server SDK needed on the gateway — it's a direct protocol translation. + +## Architecture + +``` +A2A Client + │ + ▼ +FastAPI Gateway (outside Restate) + │ + ├─ GET /.well-known/agent-card + │ → queries Restate admin API for services with a2a metadata + │ → builds and returns agent card catalog + │ + ├─ POST /{agent}/a2a (JSON-RPC: message/send) + │ → POST http://ingress/{agent}-task/{task_id}/handle_send_message_request + │ + ├─ POST /{agent}/a2a (JSON-RPC: tasks/get) + │ → POST http://ingress/{agent}-task/{task_id}/get_task + │ + └─ POST /{agent}/a2a (JSON-RPC: tasks/cancel) + → POST http://ingress/{agent}-task/{task_id}/get_invocation_id + → PATCH http://admin/invocations/{id}/cancel + → POST http://ingress/{agent}-task/{task_id}/get_task (poll for result) + +Inside Restate: + TaskObject (VirtualObject, keyed by task_id) + ├─ handle_send_message_request (exclusive) — runs agent, manages task state + ├─ get_task (shared) — returns task from K/V store + ├─ get_invocation_id (shared) — returns in-flight invocation ID + └─ cancel_task (exclusive) — marks task as canceled +``` + +## Agent Function Signature + +```python +async def my_agent(query: str, context_id: str) -> AgentInvokeResult: + ctx = restate_context() # available via current_context() + result = await ctx.run_typed("call_llm", llm.call, query) + return AgentInvokeResult(parts=[TextPart(text=result)]) +``` + +Runs within TaskObject's exclusive handler context — full access to `ctx.run_typed()`, service calls, etc. + +## Agent Discovery via Metadata + +Each `TaskObject` stores its agent card in service metadata: + +```python +metadata={"a2a.agent_card": agent_card.model_dump_json()} +``` + +Service `description` is used as the agent description. + +The gateway queries `GET http://admin:9070/services`, filters for services with `a2a.agent_card` metadata, and deserializes the agent cards. + +## User-Facing API + +### Restate side (agent definition): + +```python +from restate.ext.a2a import A2ATaskObject, AgentInvokeResult +from a2a.types import AgentCard, AgentSkill, TextPart + +async def weather_agent(query: str, context_id: str) -> AgentInvokeResult: + ctx = restate_context() + forecast = await ctx.run_typed("get_forecast", fetch_forecast, query) + return AgentInvokeResult(parts=[TextPart(text=forecast)]) + +weather = A2ATaskObject( + "weather", + invoke_function=weather_agent, + agent_card=AgentCard( + name="Weather Agent", + description="Provides weather forecasts", + url="http://gateway:8000/weather/a2a", # gateway URL + version="1.0", + skills=[AgentSkill(id="forecast", name="Forecast", description="...")], + default_input_modes=["text"], + default_output_modes=["text"], + ), +) + +# Standard restate app — TaskObject is just a VirtualObject +app = restate.app(services=[weather]) +``` + +### Gateway side (separate process): + +```python +from restate.ext.a2a import A2AGateway + +gateway = A2AGateway( + restate_admin_url="http://localhost:9070", + restate_ingress_url="http://localhost:8080", +) +app = gateway.build() # FastAPI app + +# Run: uvicorn gateway:app --port 8000 +``` + +The gateway auto-discovers all `A2ATaskObject` services from the admin API. + +## Files to Create/Modify + +### 1. `python/restate/ext/a2a/__init__.py` (new) + +Exports: +- `A2ATaskObject` — VirtualObject with built-in task management +- `AgentInvokeResult` — result type for invoke_function +- `A2AGateway` — FastAPI gateway +- `restate_context()` / `restate_object_context()` — context helpers + +### 2. `python/restate/ext/a2a/_models.py` (new) + +```python +@dataclass +class AgentInvokeResult: + parts: list[Part] + require_user_input: bool = False + +InvokeAgentType = Callable[[str, str], Awaitable[AgentInvokeResult]] +``` + +### 3. `python/restate/ext/a2a/_task.py` (new, based on reference) + +**`TaskObject`** — VirtualObject keyed by `task_id`. Copied from reference with adjustments: + +- `handle_send_message_request(ctx, request: SendMessageRequest) -> SendMessageResponse` + - Generates context_id if missing + - Stores invocation ID for cancellation + - Upserts task in K/V store + - Calls `invoke_function(query, context_id)` + - Updates task to completed/input_required/failed/canceled +- `get_task(ctx) -> Task | None` (shared) +- `get_invocation_id(ctx) -> str | None` (shared) +- `cancel_task(ctx, request) -> CancelTaskResponse` (exclusive) +- `update_store(ctx, state, ...) -> Task` (exclusive) +- `upsert_task(ctx, params) -> Task` (exclusive) + +**`A2ATaskObject`** — wrapper that creates a `TaskObject` with agent card stored in metadata: + +```python +class A2ATaskObject: + def __init__(self, name, invoke_function, agent_card): + self._task_object = TaskObject( + f"{name}", + invoke_function, + ) + # Store agent card in metadata for discovery + self._task_object.metadata = {"a2a.agent_card": agent_card.model_dump_json()} + self._task_object.description = agent_card.description +``` + +Exposes the same interface as VirtualObject so it can be passed to `restate.app()`. + +### 4. `python/restate/ext/a2a/_gateway.py` (new) + +**`A2AGateway`** — FastAPI app builder: + +- Constructor: `(restate_admin_url, restate_ingress_url)` +- `build()` → FastAPI app with: + - `GET /.well-known/agent-card` — returns agent card(s) from admin API discovery + - `POST /{agent_name}/a2a` — JSON-RPC dispatch endpoint per agent +- Agent discovery: queries `GET {admin_url}/services`, filters by `a2a.agent_card` metadata +- JSON-RPC dispatch: + - `message/send` → `POST {ingress_url}/{agent_name}/{task_id}/handle_send_message_request` + - `tasks/get` → `POST {ingress_url}/{agent_name}/{task_id}/get_task` + - `tasks/cancel`: + 1. `POST {ingress_url}/{agent_name}/{task_id}/get_invocation_id` + 2. `PATCH {admin_url}/invocations/{id}/cancel` + 3. Poll `get_task` for final state + - Other methods → return appropriate JSON-RPC error responses + +Uses `httpx.AsyncClient` for all HTTP calls. Lifespan manages the client. + +### 5. `pyproject.toml` (modify) +Add: `a2a = ["a2a-sdk", "fastapi", "httpx[http2]"]` + +## Key Design Details + +### Ingress URL patterns (VirtualObject) +- `POST http://ingress/{service}/{key}/{handler}` — blocking call +- `POST http://ingress/{service}/{key}/{handler}/send` — returns invocation ID immediately +- Request body: JSON-serialized handler input +- Response: JSON-serialized handler output +- Header `x-restate-id`: invocation ID + +### Admin API patterns +- `GET http://admin/services` — list all services with metadata +- `PATCH http://admin/invocations/{id}/cancel` — cancel in-flight invocation + +### Serialization +A2A types are Pydantic models → Restate's PydanticJsonSerde handles them correctly. The gateway serializes/deserializes using the same Pydantic models. + +### Agent Card in Metadata +```python +metadata = {"a2a.agent_card": agent_card.model_dump_json()} +``` +Gateway deserializes: `AgentCard.model_validate_json(metadata["a2a.agent_card"])` + +### Task ID +- If the A2A client provides a `task_id` in the message → use it as VirtualObject key +- If not → gateway generates a UUID and uses it as the key + +### Multi-agent Support +- Each agent is a separate `A2ATaskObject` (VirtualObject) +- Gateway discovers all of them via admin API +- Each agent gets its own endpoint: `POST /{agent_name}/a2a` +- Agent cards include their specific URL + +## Cancellation Flow (detailed) + +1. A2A client sends `tasks/cancel` JSON-RPC request +2. Gateway receives it, extracts `task_id` and `agent_name` +3. Gateway calls `POST {ingress}/{agent_name}/{task_id}/get_invocation_id` +4. If invocation_id exists: + a. `PATCH {admin}/invocations/{invocation_id}/cancel` + b. The running handler catches TerminalError(409, "cancelled"), updates task to canceled + c. Gateway calls `POST {ingress}/{agent_name}/{task_id}/get_task` to get final state +5. If no invocation_id (task already completed): + a. Gateway calls `POST {ingress}/{agent_name}/{task_id}/cancel_task` + +## v1 Limitations +- No streaming support +- No push notifications +- No resubscribe support +- No authenticated extended card +- Gateway discovery is one-shot on startup (could add polling later) + +## Verification +1. Create a simple echo agent with `A2ATaskObject` +2. Run restate app: `restate.app(services=[echo_agent])` +3. Run gateway: `uvicorn gateway:app` +4. Send JSON-RPC `message/send` via curl +5. Send `tasks/get` to retrieve task +6. Test multi-turn with same task_id +7. Test cancellation +8. Verify agent card discovery at `/.well-known/agent-card` diff --git a/python/restate/ext/a2a/__init__.py b/python/restate/ext/a2a/__init__.py new file mode 100644 index 0000000..7b8add8 --- /dev/null +++ b/python/restate/ext/a2a/__init__.py @@ -0,0 +1,13 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +""" +This module contains the optional A2A (Agent-to-Agent) integration for Restate. +""" diff --git a/python/restate/ext/adk/summarizer.py b/python/restate/ext/adk/summarizer.py index c50f17a..73ed47e 100644 --- a/python/restate/ext/adk/summarizer.py +++ b/python/restate/ext/adk/summarizer.py @@ -67,17 +67,14 @@ def from_summarizer( """Create a RestateEventSummarizer wrapping a custom summarizer.""" return RestateEventSummarizer(summarizer, max_retries=max_retries) - async def maybe_summarize_events( - self, *, events: list[Event] - ) -> Optional[Event]: + async def maybe_summarize_events(self, *, events: list[Event]) -> Optional[Event]: if not events: return None ctx = current_context() if ctx is None: raise RuntimeError( - "No Restate context found. " - "RestateEventSummarizer must be used from within a Restate handler." + "No Restate context found. RestateEventSummarizer must be used from within a Restate handler." ) inner = self._inner @@ -92,4 +89,4 @@ async def call_inner() -> Optional[Event]: max_attempts=self._max_retries, initial_retry_interval=timedelta(seconds=1), ), - ) \ No newline at end of file + ) diff --git a/python/restate/ext/pydantic/__init__.py b/python/restate/ext/pydantic/__init__.py index 76d633d..2b16785 100644 --- a/python/restate/ext/pydantic/__init__.py +++ b/python/restate/ext/pydantic/__init__.py @@ -8,6 +8,7 @@ from ._serde import PydanticTypeAdapter from ._toolset import RestateContextRunToolSet + def restate_object_context() -> ObjectContext: """Get the current Restate ObjectContext.""" ctx = current_context() diff --git a/test-services-cls/Dockerfile b/test-services-cls/Dockerfile new file mode 100644 index 0000000..4140858 --- /dev/null +++ b/test-services-cls/Dockerfile @@ -0,0 +1,42 @@ +# syntax=docker.io/docker/dockerfile:1.7-labs + +FROM ghcr.io/astral-sh/uv:python3.12-bookworm-slim as build-sdk + +ENV UV_PYTHON "3.12" + +RUN apt-get update -y && apt-get install -y build-essential + +WORKDIR /usr/src/app + +COPY src ./src/ +COPY python ./python/ +COPY Cargo.lock . +COPY Cargo.toml . +COPY rust-toolchain.toml . +COPY pyproject.toml . +COPY LICENSE . +COPY README.md . +COPY uv.lock . + + +RUN uv sync --all-extras --all-packages +RUN uv build --all-packages + +FROM python:3.12-slim AS test-services + +WORKDIR /usr/src/app + +COPY --from=build-sdk /usr/src/app/dist/* /usr/src/app/deps/ + +RUN pip install deps/*whl +RUN pip install hypercorn + +COPY test-services-cls/ . + +EXPOSE 9080 + +ENV RESTATE_CORE_LOG=debug +ENV RUST_BACKTRACE=1 +ENV PORT 9080 + +ENTRYPOINT ["./entrypoint.sh"] diff --git a/test-services-cls/entrypoint.sh b/test-services-cls/entrypoint.sh new file mode 100755 index 0000000..94b8e2c --- /dev/null +++ b/test-services-cls/entrypoint.sh @@ -0,0 +1,19 @@ +#!/usr/bin/env sh + +PORT=${PORT:-"9080"} + +if [ -n "$MAX_CONCURRENT_STREAMS" ]; then + # respect the MAX_CONCURRENT_STREAMS environment variable + awk 'BEGIN { FS=OFS="=" } $1=="h2_max_concurrent_streams " {$2=" '$MAX_CONCURRENT_STREAMS'"} {print}' hypercorn-config.toml > hypercorn-config.toml.new + mv hypercorn-config.toml.new hypercorn-config.toml +fi + +if [ -n "$RESTATE_LOGGING" ]; then + # unification of the RESTATE_LOGGING environment variable + # which is also used by the node-test-services. + # + # Set by the e2e-verification-runner + export RESTATE_CORE_LOG=$RESTATE_LOGGING +fi + +python3 -m hypercorn testservices:app --config hypercorn-config.toml --bind "0.0.0.0:${PORT}" diff --git a/test-services-cls/exclusions.yaml b/test-services-cls/exclusions.yaml new file mode 100644 index 0000000..7831c23 --- /dev/null +++ b/test-services-cls/exclusions.yaml @@ -0,0 +1 @@ +exclusions: {} diff --git a/test-services-cls/hypercorn-config.toml b/test-services-cls/hypercorn-config.toml new file mode 100644 index 0000000..0a02d23 --- /dev/null +++ b/test-services-cls/hypercorn-config.toml @@ -0,0 +1,5 @@ +h2_max_concurrent_streams = 2147483647 +keep_alive_max_requests = 2147483647 +keep_alive_timeout = 2147483647 +workers = 8 + diff --git a/test-services-cls/services/__init__.py b/test-services-cls/services/__init__.py new file mode 100644 index 0000000..fd31a2e --- /dev/null +++ b/test-services-cls/services/__init__.py @@ -0,0 +1,42 @@ +from restate.service import Service as _OrigService +from restate.object import VirtualObject as _OrigObject +from restate.workflow import Workflow as _OrigWorkflow + +from .counter import Counter as s1 +from .proxy import Proxy as s2 +from .awakeable_holder import AwakeableHolder as s3 +from .block_and_wait_workflow import BlockAndWaitWorkflow as s4 +from .cancel_test import CancelTestRunner, CancelTestBlockingService as s5 +from .failing import Failing as s6 +from .kill_test import KillTestRunner, KillTestSingleton as s7 +from .list_object import ListObject as s8 +from .map_object import MapObject as s9 +from .non_determinism import NonDeterministic as s10 +from .test_utils import TestUtilsService as s11 +from .virtual_object_command_interpreter import VirtualObjectCommandInterpreter as s16 + +from .interpreter import layer_0 as s12 +from .interpreter import layer_1 as s13 +from .interpreter import layer_2 as s14 +from .interpreter import helper as s15 + + +def list_services(bindings): + """List all services from local bindings — supports both class-based and decorator-based.""" + result = {} + for _, obj in bindings.items(): + svc = getattr(obj, '_restate_service', obj) + if isinstance(svc, (_OrigService, _OrigObject, _OrigWorkflow)): + result[svc.name] = obj + return result + + +def services_named(service_names): + return [_all_services[name] for name in service_names] + + +def all_services(): + return _all_services.values() + + +_all_services = list_services(locals()) diff --git a/test-services-cls/services/awakeable_holder.py b/test-services-cls/services/awakeable_holder.py new file mode 100644 index 0000000..dbf4c65 --- /dev/null +++ b/test-services-cls/services/awakeable_holder.py @@ -0,0 +1,36 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""awakeable_holder.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 +# pylint: disable=W0622 + +from restate.cls import VirtualObject, handler, Context +from restate.exceptions import TerminalError + + +class AwakeableHolder(VirtualObject, name="AwakeableHolder"): + + @handler + async def hold(self, id: str): + Context.set("id", id) + + @handler(name="hasAwakeable") + async def has_awakeable(self) -> bool: + res = await Context.get("id") + return res is not None + + @handler + async def unlock(self, payload: str): + id = await Context.get("id") + if id is None: + raise TerminalError(message="No awakeable is registered") + Context.resolve_awakeable(id, payload) diff --git a/test-services-cls/services/block_and_wait_workflow.py b/test-services-cls/services/block_and_wait_workflow.py new file mode 100644 index 0000000..5cead7d --- /dev/null +++ b/test-services-cls/services/block_and_wait_workflow.py @@ -0,0 +1,39 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""block_and_wait_workflow.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 +# pylint: disable=W0622 + +from restate.cls import Workflow, handler, main, Context +from restate.exceptions import TerminalError + + +class BlockAndWaitWorkflow(Workflow, name="BlockAndWaitWorkflow"): + + @main + async def run(self, input: str): + Context.set("my-state", input) + output = await Context.promise("durable-promise").value() + + peek = await Context.promise("durable-promise").peek() + if peek is None: + raise TerminalError(message="Durable promise should be completed") + + return output + + @handler + async def unblock(self, output: str): + await Context.promise("durable-promise").resolve(output) + + @handler(name="getState") + async def get_state(self, output: str) -> str | None: + return await Context.get("my-state") diff --git a/test-services-cls/services/cancel_test.py b/test-services-cls/services/cancel_test.py new file mode 100644 index 0000000..272c266 --- /dev/null +++ b/test-services-cls/services/cancel_test.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""cancel_test.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from datetime import timedelta +from typing import Literal +from restate.cls import VirtualObject, handler, Context +from restate.exceptions import TerminalError + +from . import awakeable_holder + +BlockingOperation = Literal["CALL", "SLEEP", "AWAKEABLE"] + + +class CancelTestRunner(VirtualObject, name="CancelTestRunner"): + + @handler(name="startTest") + async def start_test(self, op: BlockingOperation): + block_fn = CancelTestBlockingService._restate_handlers["block"].fn + try: + await Context.object_call(block_fn, key=Context.key(), arg=op) + except TerminalError as t: + if t.status_code == 409: + Context.set("state", True) + else: + raise t + + @handler(name="verifyTest") + async def verify_test(self) -> bool: + state = await Context.get("state") + if state is None: + return False + return state + + +class CancelTestBlockingService(VirtualObject, name="CancelTestBlockingService"): + + @handler + async def block(self, op: BlockingOperation): + hold_fn = awakeable_holder.AwakeableHolder._restate_handlers["hold"].fn + name, awakeable = Context.awakeable() + Context.object_send(hold_fn, key=Context.key(), arg=name) + await awakeable + + block_fn = CancelTestBlockingService._restate_handlers["block"].fn + if op == "CALL": + await Context.object_call(block_fn, key=Context.key(), arg=op) + elif op == "SLEEP": + await Context.sleep(timedelta(days=1024)) + elif op == "AWAKEABLE": + name, uncompleteable = Context.awakeable() + await uncompleteable + + @handler(name="isUnlocked") + async def is_unlocked(self): + return None diff --git a/test-services-cls/services/counter.py b/test-services-cls/services/counter.py new file mode 100644 index 0000000..1140db3 --- /dev/null +++ b/test-services-cls/services/counter.py @@ -0,0 +1,56 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""counter.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from typing import TypedDict +from restate.cls import VirtualObject, handler, Context +from restate.exceptions import TerminalError + +COUNTER_KEY = "counter" + + +class CounterUpdateResponse(TypedDict): + oldValue: int + newValue: int + + +class Counter(VirtualObject, name="Counter"): + + @handler + async def reset(self): + Context.clear(COUNTER_KEY) + + @handler + async def get(self) -> int: + c: int | None = await Context.get(COUNTER_KEY) + if c is None: + return 0 + return c + + @handler + async def add(self, addend: int) -> CounterUpdateResponse: + old_value: int | None = await Context.get(COUNTER_KEY) + if old_value is None: + old_value = 0 + new_value = old_value + addend + Context.set(COUNTER_KEY, new_value) + return CounterUpdateResponse(oldValue=old_value, newValue=new_value) + + @handler(name="addThenFail") + async def add_then_fail(self, addend: int): + old_value: int | None = await Context.get(COUNTER_KEY) + if old_value is None: + old_value = 0 + new_value = old_value + addend + Context.set(COUNTER_KEY, new_value) + raise TerminalError(message=Context.key()) diff --git a/test-services-cls/services/failing.py b/test-services-cls/services/failing.py new file mode 100644 index 0000000..760f339 --- /dev/null +++ b/test-services-cls/services/failing.py @@ -0,0 +1,88 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""failing.py — class-based""" + +from datetime import timedelta + +# pylint: disable=C0116 +# pylint: disable=W0613 +# pylint: disable=W0622 + +from restate.cls import VirtualObject, handler, Context +from restate.exceptions import TerminalError +from restate import RunOptions + +failures = 0 +eventual_success_side_effects = 0 +eventual_failure_side_effects = 0 + + +class Failing(VirtualObject, name="Failing"): + + @handler(name="terminallyFailingCall") + async def terminally_failing_call(self, msg: str): + raise TerminalError(message=msg) + + @handler(name="callTerminallyFailingCall") + async def call_terminally_failing_call(self, msg: str) -> str: + fn = Failing._restate_handlers["terminallyFailingCall"].fn + await Context.object_call(fn, key="random-583e1bf2", arg=msg) + raise Exception("Should not reach here") + + @handler(name="failingCallWithEventualSuccess") + async def failing_call_with_eventual_success(self) -> int: + global failures + failures += 1 + if failures >= 4: + failures = 0 + return 4 + raise ValueError(f"Failed at attempt: {failures}") + + @handler(name="terminallyFailingSideEffect") + async def terminally_failing_side_effect(self, error_message: str): + def side_effect(): + raise TerminalError(message=error_message) + + await Context.run_typed("sideEffect", side_effect) + raise ValueError("Should not reach here") + + @handler(name="sideEffectSucceedsAfterGivenAttempts") + async def side_effect_succeeds_after_given_attempts(self, minimum_attempts: int) -> int: + def side_effect(): + global eventual_success_side_effects + eventual_success_side_effects += 1 + if eventual_success_side_effects >= minimum_attempts: + return eventual_success_side_effects + raise ValueError(f"Failed at attempt: {eventual_success_side_effects}") + + options: RunOptions[int] = RunOptions( + max_attempts=minimum_attempts + 1, initial_retry_interval=timedelta(milliseconds=1), retry_interval_factor=1.0 + ) + return await Context.run_typed("sideEffect", side_effect, options) + + @handler(name="sideEffectFailsAfterGivenAttempts") + async def side_effect_fails_after_given_attempts(self, retry_policy_max_retry_count: int) -> int: + def side_effect(): + global eventual_failure_side_effects + eventual_failure_side_effects += 1 + raise ValueError(f"Failed at attempt: {eventual_failure_side_effects}") + + try: + options: RunOptions[int] = RunOptions( + max_attempts=retry_policy_max_retry_count, + initial_retry_interval=timedelta(milliseconds=1), + retry_interval_factor=1.0, + ) + await Context.run_typed("sideEffect", side_effect, options) + raise ValueError("Side effect did not fail.") + except TerminalError: + global eventual_failure_side_effects + return eventual_failure_side_effects diff --git a/test-services-cls/services/interpreter.py b/test-services-cls/services/interpreter.py new file mode 100644 index 0000000..528e364 --- /dev/null +++ b/test-services-cls/services/interpreter.py @@ -0,0 +1,295 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""Verification test — class-based""" + +from datetime import timedelta +import json +from typing import TypedDict +import typing +import random + +from restate.cls import Service, handler, Context +from restate.exceptions import TerminalError +from restate.serde import JsonSerde + +# Import the decorator-based VirtualObject for the dynamically-created layers +from restate.object import VirtualObject +from restate.context import ObjectContext, ObjectSharedContext + +# suppress missing docstring +# pylint: disable=C0115 +# pylint: disable=C0116 +# pylint: disable=C0301 +# pylint: disable=R0914, R0912, R0915, R0913 + +SET_STATE = 1 +GET_STATE = 2 +CLEAR_STATE = 3 +INCREMENT_STATE_COUNTER = 4 +INCREMENT_STATE_COUNTER_INDIRECTLY = 5 +SLEEP = 6 +CALL_SERVICE = 7 +CALL_SLOW_SERVICE = 8 +INCREMENT_VIA_DELAYED_CALL = 9 +SIDE_EFFECT = 10 +THROWING_SIDE_EFFECT = 11 +SLOW_SIDE_EFFECT = 12 +RECOVER_TERMINAL_CALL = 13 +RECOVER_TERMINAL_MAYBE_UN_AWAITED = 14 +AWAIT_PROMISE = 15 +RESOLVE_AWAKEABLE = 16 +REJECT_AWAKEABLE = 17 +INCREMENT_STATE_COUNTER_VIA_AWAKEABLE = 18 +CALL_NEXT_LAYER_OBJECT = 19 + + +class ServiceInterpreterHelper(Service, name="ServiceInterpreterHelper"): + + @handler + async def ping(self) -> None: + pass + + @handler + async def echo(self, parameters: str) -> str: + return parameters + + @handler(name="echoLater") + async def echo_later(self, parameter: dict[str, typing.Any]) -> str: + await Context.sleep(timedelta(milliseconds=parameter["sleep"])) + return parameter["parameter"] + + @handler(name="terminalFailure") + async def terminal_failure(self) -> str: + raise TerminalError("bye") + + @handler(name="incrementIndirectly") + async def increment_indirectly(self, parameter) -> None: + layer = parameter["layer"] + key = parameter["key"] + + program = { + "commands": [ + { + "kind": INCREMENT_STATE_COUNTER, + }, + ], + } + + program_bytes = json.dumps(program).encode("utf-8") + Context.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) + + @handler(name="resolveAwakeable") + async def resolve_awakeable(self, aid: str) -> None: + Context.resolve_awakeable(aid, "ok") + + @handler(name="rejectAwakeable") + async def reject_awakeable(self, aid: str) -> None: + Context.reject_awakeable(aid, "error") + + @handler(name="incrementViaAwakeableDance") + async def increment_via_awakeable_dance(self, input: dict[str, typing.Any]) -> None: + tx_promise_id = input["txPromiseId"] + layer = input["interpreter"]["layer"] + key = input["interpreter"]["key"] + + aid, promise = Context.awakeable() + Context.resolve_awakeable(tx_promise_id, aid) + await promise + + program = { + "commands": [ + { + "kind": INCREMENT_STATE_COUNTER, + }, + ], + } + + program_bytes = json.dumps(program).encode("utf-8") + Context.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) + + +# Keep helper as a reference to the class for the __init__.py import +helper = ServiceInterpreterHelper + + +class SupportService: + """Helper for making generic calls to ServiceInterpreterHelper.""" + def __init__(self) -> None: + self.serde = JsonSerde[typing.Any]() + + async def call(self, method: str, arg: typing.Any) -> typing.Any: + buffer = self.serde.serialize(arg) + out_buffer = await Context.generic_call("ServiceInterpreterHelper", method, buffer) + return self.serde.deserialize(out_buffer) + + def send(self, method: str, arg: typing.Any, delay: int | None = None) -> None: + buffer = self.serde.serialize(arg) + if delay is None: + send_delay = None + else: + send_delay = timedelta(milliseconds=delay) + Context.generic_send("ServiceInterpreterHelper", method, buffer, send_delay=send_delay) + + async def ping(self) -> None: + return await self.call(method="ping", arg=None) + + async def echo(self, parameters: str) -> str: + return await self.call(method="echo", arg=parameters) + + async def echo_later(self, parameter: str, sleep: int) -> str: + arg = {"parameter": parameter, "sleep": sleep} + return await self.call(method="echoLater", arg=arg) + + async def terminal_failure(self) -> str: + return await self.call(method="terminalFailure", arg=None) + + async def increment_indirectly(self, layer: int, key: str, delay: typing.Optional[int] = None) -> None: + arg = {"layer": layer, "key": key} + self.send(method="incrementIndirectly", arg=arg, delay=delay) + + def resolve_awakeable(self, aid: str) -> None: + self.send("resolveAwakeable", aid) + + def reject_awakeable(self, aid: str) -> None: + self.send("rejectAwakeable", aid) + + def increment_via_awakeable_dance(self, layer: int, key: str, tx_promise_id: str) -> None: + arg = {"interpreter": {"layer": layer, "key": key}, "txPromiseId": tx_promise_id} + self.send("incrementViaAwakeableDance", arg) + + +class Command(TypedDict): + kind: int + key: int + duration: int + sleep: int + index: int + program: typing.Any # avoid circular type + + +Program = dict[typing.Literal["commands"], typing.List[Command]] + + +async def interpreter(layer: int, program: Program) -> None: + """Interprets a command and executes it.""" + service = SupportService() + coros: dict[int, typing.Tuple[typing.Any, typing.Awaitable[typing.Any]]] = {} + + async def await_promise(index: int) -> None: + if index not in coros: + return + + expected, coro = coros[index] + del coros[index] + try: + result = await coro + except TerminalError: + result = "rejected" + + if result != expected: + raise TerminalError(f"Expected {expected} but got {result}") + + for i, command in enumerate(program["commands"]): + command_type = command["kind"] + if command_type == SET_STATE: + Context.set(f"key-{command['key']}", f"value-{command['key']}") + elif command_type == GET_STATE: + await Context.get(f"key-{command['key']}") + elif command_type == CLEAR_STATE: + Context.clear(f"key-{command['key']}") + elif command_type == INCREMENT_STATE_COUNTER: + c = await Context.get("counter") or 0 + c += 1 + Context.set("counter", c) + elif command_type == SLEEP: + duration = timedelta(milliseconds=command["duration"]) + await Context.sleep(duration) + elif command_type == CALL_SERVICE: + expected = f"hello-{i}" + coros[i] = (expected, service.echo(expected)) + elif command_type == INCREMENT_VIA_DELAYED_CALL: + delay = command["duration"] + await service.increment_indirectly(layer=layer, key=Context.key(), delay=delay) + elif command_type == CALL_SLOW_SERVICE: + expected = f"hello-{i}" + coros[i] = (expected, service.echo_later(expected, command["sleep"])) + elif command_type == SIDE_EFFECT: + expected = f"hello-{i}" + result = await Context.run_typed("sideEffect", lambda: expected) + if result != expected: + raise TerminalError(f"Expected {expected} but got {result}") + elif command_type == SLOW_SIDE_EFFECT: + pass + elif command_type == RECOVER_TERMINAL_CALL: + try: + await service.terminal_failure() + except TerminalError: + pass + else: + raise TerminalError("Expected terminal error") + elif command_type == RECOVER_TERMINAL_MAYBE_UN_AWAITED: + pass + elif command_type == THROWING_SIDE_EFFECT: + + async def side_effect(): + if bool(random.getrandbits(1)): + raise ValueError("Random error") + + await Context.run_typed("throwingSideEffect", side_effect) + elif command_type == INCREMENT_STATE_COUNTER_INDIRECTLY: + await service.increment_indirectly(layer=layer, key=Context.key()) + elif command_type == AWAIT_PROMISE: + index = command["index"] + await await_promise(index) + elif command_type == RESOLVE_AWAKEABLE: + name, promise = Context.awakeable() + coros[i] = ("ok", promise) + service.resolve_awakeable(name) + elif command_type == REJECT_AWAKEABLE: + name, promise = Context.awakeable() + coros[i] = ("rejected", promise) + service.reject_awakeable(name) + elif command_type == INCREMENT_STATE_COUNTER_VIA_AWAKEABLE: + tx_promise_id, tx_promise = Context.awakeable() + service.increment_via_awakeable_dance(layer=layer, key=Context.key(), tx_promise_id=tx_promise_id) + their_promise_for_us_to_resolve: str = await tx_promise + Context.resolve_awakeable(their_promise_for_us_to_resolve, "ok") + elif command_type == CALL_NEXT_LAYER_OBJECT: + next_layer = f"ObjectInterpreterL{layer + 1}" + key = f"{command['key']}" + program = command["program"] + js_program = json.dumps(program) + raw_js_program = js_program.encode("utf-8") + promise = Context.generic_call(next_layer, "interpret", raw_js_program, key) + coros[i] = (b"", promise) + else: + raise ValueError(f"Unknown command type: {command_type}") + await await_promise(i) + + +# The layers use the original decorator-based API since they're dynamically created +def make_layer(i): + layer = VirtualObject(f"ObjectInterpreterL{i}") + + @layer.handler() + async def interpret(ctx: ObjectContext, program: Program) -> None: + await interpreter(i, program) + + @layer.handler(kind="shared") + async def counter(ctx: ObjectSharedContext) -> int: + return await ctx.get("counter") or 0 + + return layer + + +layer_0 = make_layer(0) +layer_1 = make_layer(1) +layer_2 = make_layer(2) diff --git a/test-services-cls/services/kill_test.py b/test-services-cls/services/kill_test.py new file mode 100644 index 0000000..2994350 --- /dev/null +++ b/test-services-cls/services/kill_test.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""kill_test.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from restate.cls import VirtualObject, handler, Context + +from . import awakeable_holder + + +class KillTestRunner(VirtualObject, name="KillTestRunner"): + + @handler(name="startCallTree") + async def start_call_tree(self): + fn = KillTestSingleton._restate_handlers["recursiveCall"].fn + await Context.object_call(fn, key=Context.key(), arg=None) + + +class KillTestSingleton(VirtualObject, name="KillTestSingleton"): + + @handler(name="recursiveCall") + async def recursive_call(self): + hold_fn = awakeable_holder.AwakeableHolder._restate_handlers["hold"].fn + name, promise = Context.awakeable() + Context.object_send(hold_fn, key=Context.key(), arg=name) + await promise + + fn = KillTestSingleton._restate_handlers["recursiveCall"].fn + await Context.object_call(fn, key=Context.key(), arg=None) + + @handler(name="isUnlocked") + async def is_unlocked(self): + return None diff --git a/test-services-cls/services/list_object.py b/test-services-cls/services/list_object.py new file mode 100644 index 0000000..1786b7e --- /dev/null +++ b/test-services-cls/services/list_object.py @@ -0,0 +1,33 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""list_object.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from restate.cls import VirtualObject, handler, Context + + +class ListObject(VirtualObject, name="ListObject"): + + @handler + async def append(self, value: str): + lst = await Context.get("list") or [] + Context.set("list", lst + [value]) + + @handler + async def get(self) -> list[str]: + return await Context.get("list") or [] + + @handler + async def clear(self) -> list[str]: + result = await Context.get("list") or [] + Context.clear("list") + return result diff --git a/test-services-cls/services/map_object.py b/test-services-cls/services/map_object.py new file mode 100644 index 0000000..e039c46 --- /dev/null +++ b/test-services-cls/services/map_object.py @@ -0,0 +1,42 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""map_object.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from typing import TypedDict +from restate.cls import VirtualObject, handler, Context + + +class Entry(TypedDict): + key: str + value: str + + +class MapObject(VirtualObject, name="MapObject"): + + @handler(name="set") + async def map_set(self, entry: Entry): + Context.set(entry["key"], entry["value"]) + + @handler(name="get") + async def map_get(self, key: str) -> str: + return await Context.get(key) or "" + + @handler(name="clearAll") + async def map_clear_all(self) -> list[Entry]: + entries = [] + for key in await Context.state_keys(): + value: str = await Context.get(key) # type: ignore + entry = Entry(key=key, value=value) + entries.append(entry) + Context.clear(key) + return entries diff --git a/test-services-cls/services/non_determinism.py b/test-services-cls/services/non_determinism.py new file mode 100644 index 0000000..e11c237 --- /dev/null +++ b/test-services-cls/services/non_determinism.py @@ -0,0 +1,76 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""non_determinism.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from datetime import timedelta +from typing import Dict +from restate.cls import VirtualObject, handler, Context + +from . import counter + +invoke_counts: Dict[str, int] = {} + + +def do_left_action() -> bool: + count_key = Context.key() + invoke_counts[count_key] = invoke_counts.get(count_key, 0) + 1 + return invoke_counts[count_key] % 2 == 1 + + +def increment_counter(): + add_fn = counter.Counter._restate_handlers["add"].fn + Context.object_send(add_fn, key=Context.key(), arg=1) + + +class NonDeterministic(VirtualObject, name="NonDeterministic"): + + @handler(name="setDifferentKey") + async def set_different_key(self): + if do_left_action(): + Context.set("a", "my-state") + else: + Context.set("b", "my-state") + await Context.sleep(timedelta(milliseconds=100)) + increment_counter() + + @handler(name="backgroundInvokeWithDifferentTargets") + async def background_invoke_with_different_targets(self): + get_fn = counter.Counter._restate_handlers["get"].fn + reset_fn = counter.Counter._restate_handlers["reset"].fn + if do_left_action(): + Context.object_send(get_fn, key="abc", arg=None) + else: + Context.object_send(reset_fn, key="abc", arg=None) + await Context.sleep(timedelta(milliseconds=100)) + increment_counter() + + @handler(name="callDifferentMethod") + async def call_different_method(self): + get_fn = counter.Counter._restate_handlers["get"].fn + reset_fn = counter.Counter._restate_handlers["reset"].fn + if do_left_action(): + await Context.object_call(get_fn, key="abc", arg=None) + else: + await Context.object_call(reset_fn, key="abc", arg=None) + await Context.sleep(timedelta(milliseconds=100)) + increment_counter() + + @handler(name="eitherSleepOrCall") + async def either_sleep_or_call(self): + get_fn = counter.Counter._restate_handlers["get"].fn + if do_left_action(): + await Context.sleep(timedelta(milliseconds=100)) + else: + await Context.object_call(get_fn, key="abc", arg=None) + await Context.sleep(timedelta(milliseconds=100)) + increment_counter() diff --git a/test-services-cls/services/proxy.py b/test-services-cls/services/proxy.py new file mode 100644 index 0000000..902667c --- /dev/null +++ b/test-services-cls/services/proxy.py @@ -0,0 +1,95 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""proxy.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from datetime import timedelta +from typing import TypedDict, Optional, Iterable +from restate.cls import Service, handler, Context + + +class ProxyRequest(TypedDict): + serviceName: str + virtualObjectKey: Optional[str] + handlerName: str + message: Iterable[int] + delayMillis: Optional[int] + idempotencyKey: Optional[str] + + +class ManyCallRequest(TypedDict): + proxyRequest: ProxyRequest + oneWayCall: bool + awaitAtTheEnd: bool + + +class Proxy(Service, name="Proxy"): + + @handler(name="call") + async def proxy_call(self, req: ProxyRequest) -> Iterable[int]: + response = await Context.generic_call( + req["serviceName"], + req["handlerName"], + bytes(req["message"]), + req.get("virtualObjectKey"), + idempotency_key=req.get("idempotencyKey"), + ) + return list(response) + + @handler(name="oneWayCall") + async def one_way_call(self, req: ProxyRequest) -> str: + send_delay = None + delayMillis = req.get("delayMillis") + if delayMillis is not None: + send_delay = timedelta(milliseconds=delayMillis) + handle = Context.generic_send( + req["serviceName"], + req["handlerName"], + bytes(req["message"]), + req.get("virtualObjectKey"), + send_delay=send_delay, + idempotency_key=req.get("idempotencyKey"), + ) + invocation_id = await handle.invocation_id() + return invocation_id + + @handler(name="manyCalls") + async def many_calls(self, requests: Iterable[ManyCallRequest]): + to_await = [] + + for req in requests: + if req["oneWayCall"]: + send_delay = None + delayMillis = req["proxyRequest"].get("delayMillis") + if delayMillis is not None: + send_delay = timedelta(milliseconds=delayMillis) + Context.generic_send( + req["proxyRequest"]["serviceName"], + req["proxyRequest"]["handlerName"], + bytes(req["proxyRequest"]["message"]), + req["proxyRequest"].get("virtualObjectKey"), + send_delay=send_delay, + idempotency_key=req["proxyRequest"].get("idempotencyKey"), + ) + else: + awaitable = Context.generic_call( + req["proxyRequest"]["serviceName"], + req["proxyRequest"]["handlerName"], + bytes(req["proxyRequest"]["message"]), + req["proxyRequest"].get("virtualObjectKey"), + idempotency_key=req["proxyRequest"].get("idempotencyKey"), + ) + if req["awaitAtTheEnd"]: + to_await.append(awaitable) + + for awaitable in to_await: + await awaitable diff --git a/test-services-cls/services/test_utils.py b/test-services-cls/services/test_utils.py new file mode 100644 index 0000000..87c21c0 --- /dev/null +++ b/test-services-cls/services/test_utils.py @@ -0,0 +1,66 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""test_utils.py — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +from datetime import timedelta +from typing import Dict, List +from restate.cls import Service, handler, Context +from restate.serde import BytesSerde + + +class TestUtilsService(Service, name="TestUtilsService"): + + @handler + async def echo(self, input: str) -> str: + return input + + @handler(name="uppercaseEcho") + async def uppercase_echo(self, input: str) -> str: + return input.upper() + + @handler(name="echoHeaders") + async def echo_headers(self) -> Dict[str, str]: + return Context.request().headers + + @handler(name="sleepConcurrently") + async def sleep_concurrently(self, millis_duration: List[int]) -> None: + timers = [Context.sleep(timedelta(milliseconds=duration)) for duration in millis_duration] + for timer in timers: + await timer + + @handler(name="countExecutedSideEffects") + async def count_executed_side_effects(self, increments: int) -> int: + invoked_side_effects = 0 + + def effect(): + nonlocal invoked_side_effects + invoked_side_effects += 1 + + for _ in range(increments): + await Context.run("count", effect) + + return invoked_side_effects + + @handler(name="cancelInvocation") + async def cancel_invocation(self, invocation_id: str) -> None: + Context.cancel_invocation(invocation_id) + + @handler( + name="rawEcho", + accept="*/*", + content_type="application/octet-stream", + input_serde=BytesSerde(), + output_serde=BytesSerde(), + ) + async def raw_echo(self, input: bytes) -> bytes: + return input diff --git a/test-services-cls/services/virtual_object_command_interpreter.py b/test-services-cls/services/virtual_object_command_interpreter.py new file mode 100644 index 0000000..c40127f --- /dev/null +++ b/test-services-cls/services/virtual_object_command_interpreter.py @@ -0,0 +1,203 @@ +# +# Copyright (c) 2023-2024 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# +"""Virtual Object Command Interpreter — class-based""" +# pylint: disable=C0116 +# pylint: disable=W0613 + +import os +from datetime import timedelta +from typing import Iterable, List, Union, TypedDict, Literal, Any +from restate.cls import VirtualObject, handler, shared, Context +from restate import RestateDurableFuture, RestateDurableSleepFuture +from restate import select, wait_completed, as_completed +from restate.exceptions import TerminalError + + +class CreateAwakeable(TypedDict): + type: Literal["createAwakeable"] + awakeableKey: str + + +class Sleep(TypedDict): + type: Literal["sleep"] + timeoutMillis: int + + +class RunThrowTerminalException(TypedDict): + type: Literal["runThrowTerminalException"] + reason: str + + +AwaitableCommand = Union[CreateAwakeable, Sleep, RunThrowTerminalException] + + +class AwaitOne(TypedDict): + type: Literal["awaitOne"] + command: AwaitableCommand + + +class AwaitAnySuccessful(TypedDict): + type: Literal["awaitAnySuccessful"] + commands: List[AwaitableCommand] + + +class AwaitAny(TypedDict): + type: Literal["awaitAny"] + commands: List[AwaitableCommand] + + +class AwaitAwakeableOrTimeout(TypedDict): + type: Literal["awaitAwakeableOrTimeout"] + awakeableKey: str + timeoutMillis: int + + +class ResolveAwakeable(TypedDict): + type: Literal["resolveAwakeable"] + awakeableKey: str + value: str + + +class RejectAwakeable(TypedDict): + type: Literal["rejectAwakeable"] + awakeableKey: str + reason: str + + +class GetEnvVariable(TypedDict): + type: Literal["getEnvVariable"] + envName: str + + +Command = Union[ + AwaitOne, AwaitAny, AwaitAnySuccessful, AwaitAwakeableOrTimeout, ResolveAwakeable, RejectAwakeable, GetEnvVariable +] + + +class InterpretRequest(TypedDict): + commands: Iterable[Command] + + +def to_durable_future(cmd: AwaitableCommand) -> RestateDurableFuture[Any]: + if cmd["type"] == "createAwakeable": + awk_id, awakeable = Context.awakeable() + Context.set("awk-" + cmd["awakeableKey"], awk_id) + return awakeable + elif cmd["type"] == "sleep": + return Context.sleep(timedelta(milliseconds=cmd["timeoutMillis"])) + elif cmd["type"] == "runThrowTerminalException": + + def side_effect(reason: str): + raise TerminalError(message=reason) + + res = Context.run_typed("run should fail command", side_effect, reason=cmd["reason"]) + return res + + +async def _resolve_awakeable_impl(req: ResolveAwakeable): + awk_id = await Context.get("awk-" + req["awakeableKey"]) + if not awk_id: + raise TerminalError(message="No awakeable is registered") + Context.resolve_awakeable(awk_id, req["value"]) + + +async def _reject_awakeable_impl(req: RejectAwakeable): + awk_id = await Context.get("awk-" + req["awakeableKey"]) + if not awk_id: + raise TerminalError(message="No awakeable is registered") + Context.reject_awakeable(awk_id, req["reason"]) + + +class VirtualObjectCommandInterpreter(VirtualObject, name="VirtualObjectCommandInterpreter"): + + @shared(name="getResults") + async def get_results(self) -> List[str]: + return (await Context.get("results")) or [] + + @shared(name="hasAwakeable") + async def has_awakeable(self, awk_key: str) -> bool: + awk_id = await Context.get("awk-" + awk_key) + if awk_id: + return True + return False + + @shared(name="resolveAwakeable") + async def resolve_awakeable(self, req: ResolveAwakeable): + await _resolve_awakeable_impl(req) + + @shared(name="rejectAwakeable") + async def reject_awakeable(self, req: RejectAwakeable): + await _reject_awakeable_impl(req) + + @handler(name="interpretCommands") + async def interpret_commands(self, req: InterpretRequest): + result = "" + + for cmd in req["commands"]: + if cmd["type"] == "awaitAwakeableOrTimeout": + awk_id, awakeable = Context.awakeable() + Context.set("awk-" + cmd["awakeableKey"], awk_id) + match await select(awakeable=awakeable, timeout=Context.sleep(timedelta(milliseconds=cmd["timeoutMillis"]))): + case ["awakeable", awk_res]: + result = awk_res + case ["timeout", _]: + raise TerminalError(message="await-timeout", status_code=500) + elif cmd["type"] == "resolveAwakeable": + await _resolve_awakeable_impl(cmd) + result = "" + elif cmd["type"] == "rejectAwakeable": + await _reject_awakeable_impl(cmd) + result = "" + elif cmd["type"] == "getEnvVariable": + env_name = cmd["envName"] + + def side_effect(env_name: str): + return os.environ.get(env_name, "") + + result = await Context.run_typed("get_env", side_effect, env_name=env_name) + elif cmd["type"] == "awaitOne": + awaitable = to_durable_future(cmd["command"]) + # We need this dance because the Python SDK doesn't support .map on futures + if isinstance(awaitable, RestateDurableSleepFuture): + await awaitable + result = "sleep" + else: + result = await awaitable + elif cmd["type"] == "awaitAny": + futures = [to_durable_future(c) for c in cmd["commands"]] + done, _ = await wait_completed(*futures) + done_fut = done[0] + # We need this dance because the Python SDK doesn't support .map on futures + if isinstance(done_fut, RestateDurableSleepFuture): + await done_fut + result = "sleep" + else: + result = await done_fut + elif cmd["type"] == "awaitAnySuccessful": + futures = [to_durable_future(c) for c in cmd["commands"]] + async for done_fut in as_completed(*futures): + try: + # We need this dance because the Python SDK doesn't support .map on futures + if isinstance(done_fut, RestateDurableSleepFuture): + await done_fut + result = "sleep" + break + result = await done_fut + break + except TerminalError: + pass + + # Direct state access (same invocation, not RPC) + last_results = (await Context.get("results")) or [] + last_results.append(result) + Context.set("results", last_results) + + return result diff --git a/test-services-cls/testservices.py b/test-services-cls/testservices.py new file mode 100644 index 0000000..212bfee --- /dev/null +++ b/test-services-cls/testservices.py @@ -0,0 +1,16 @@ +"""testservices.py""" +import os +import restate +import services + + +def test_services(): + names = os.environ.get("SERVICES") + return services.services_named(names.split(",")) if names else services.all_services() + + +e2e_signing_key_env = os.environ.get("E2E_REQUEST_SIGNING_ENV") +if e2e_signing_key_env is not None: + e2e_signing_key_env = [e2e_signing_key_env] + +app = restate.app(services=test_services(), identity_keys=e2e_signing_key_env) diff --git a/tests/admin_client.py b/tests/admin_client.py new file mode 100644 index 0000000..af75f3b --- /dev/null +++ b/tests/admin_client.py @@ -0,0 +1,166 @@ +# +# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH +# +# This file is part of the Restate SDK for Python, +# which is released under the MIT license. +# +# You can find a copy of the license in file LICENSE in the root +# directory of this repository or package, or at +# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE +# + +import restate +from restate import ( + Context, + Service, + VirtualObject, + ObjectContext, + ObjectSharedContext, + HarnessEnvironment, + AdminClient, +) +import pytest + +# ----- Asyncio fixtures + + +@pytest.fixture(scope="session") +def anyio_backend(): + return "asyncio" + + +pytestmark = [ + pytest.mark.anyio, +] + +# -------- Services with metadata and descriptions + +greeter = Service( + "greeter", + description="A simple greeting service", + metadata={"team": "platform", "version": "1.0"}, +) + + +@greeter.handler( + metadata={"a2a.restate.dev/handler": "true", "a2a.restate.dev/skill": "greeting"}, +) +async def greet(ctx: Context, name: str) -> str: + """Greets a person by name.""" + return f"Hello {name}!" + + +counter = VirtualObject( + "counter", + description="A durable counter", +) + + +@counter.handler() +async def increment(ctx: ObjectContext, value: int) -> int: + n = await ctx.get("counter", type_hint=int) or 0 + n += value + ctx.set("counter", n) + return n + + +@counter.handler(kind="shared") +async def count(ctx: ObjectSharedContext) -> int: + """Returns the current count.""" + return await ctx.get("counter") or 0 + + +bare = Service("bare") + + +@bare.handler() +async def ping(ctx: Context) -> str: + return "pong" + + +# -------- Harness fixture + + +@pytest.fixture(scope="session") +async def harness(): + async with restate.create_test_harness( + restate.app([greeter, counter, bare]), + restate_image="ghcr.io/restatedev/restate:latest", + ) as env: + yield env + + +# -------- Tests + + +async def test_list_services(harness: HarnessEnvironment): + async with AdminClient(harness.admin_api_url) as admin: + services = await admin.list_services() + + names = {s.name for s in services} + assert "greeter" in names + assert "counter" in names + assert "bare" in names + + +async def test_get_service(harness: HarnessEnvironment): + async with AdminClient(harness.admin_api_url) as admin: + svc = await admin.get_service("greeter") + + assert svc.name == "greeter" + assert svc.ty == "Service" + assert svc.documentation == "A simple greeting service" + assert svc.metadata == {"team": "platform", "version": "1.0"} + + +async def test_service_handlers(harness: HarnessEnvironment): + async with AdminClient(harness.admin_api_url) as admin: + svc = await admin.get_service("greeter") + + handler = svc.get_handler("greet") + assert handler is not None + assert handler.name == "greet" + assert handler.documentation == "Greets a person by name." + assert handler.metadata is not None + assert handler.metadata["a2a.restate.dev/handler"] == "true" + assert handler.metadata["a2a.restate.dev/skill"] == "greeting" + + +async def test_virtual_object_type(harness: HarnessEnvironment): + async with AdminClient(harness.admin_api_url) as admin: + svc = await admin.get_service("counter") + + assert svc.ty == "VirtualObject" + assert svc.documentation == "A durable counter" + + inc = svc.get_handler("increment") + assert inc is not None + assert inc.ty == "Exclusive" + + cnt = svc.get_handler("count") + assert cnt is not None + assert cnt.ty == "Shared" + assert cnt.documentation == "Returns the current count." + + +async def test_handler_not_found(harness: HarnessEnvironment): + async with AdminClient(harness.admin_api_url) as admin: + svc = await admin.get_service("bare") + + assert svc.get_handler("nonexistent") is None + + +async def test_service_not_found(harness: HarnessEnvironment): + async with AdminClient(harness.admin_api_url) as admin: + with pytest.raises(Exception): + await admin.get_service("does_not_exist") + + +async def test_service_without_metadata(harness: HarnessEnvironment): + async with AdminClient(harness.admin_api_url) as admin: + svc = await admin.get_service("bare") + + assert svc.name == "bare" + # Services without explicit metadata may have None or empty dict + assert svc.metadata is None or svc.metadata == {} + assert svc.documentation is None From 3fb59f1306fda2c148074de8fda91a1cd27b3ff7 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Fri, 27 Mar 2026 15:54:51 +0100 Subject: [PATCH 02/12] Add an RPC example --- examples/class_based_greeter.py | 27 ++++++++++++++++++++++++++- python/restate/cls.py | 18 ++++++++++++------ 2 files changed, 38 insertions(+), 7 deletions(-) diff --git a/examples/class_based_greeter.py b/examples/class_based_greeter.py index 839dff9..640f377 100644 --- a/examples/class_based_greeter.py +++ b/examples/class_based_greeter.py @@ -5,6 +5,8 @@ but using the class-based API with @handler, @shared, and @main decorators. """ +from datetime import timedelta + import restate from restate.cls import Service, VirtualObject, Workflow, handler, shared, main, Context @@ -51,4 +53,27 @@ async def status(self) -> str: return await Context.get("status", type_hint=str) or "unknown" -app = restate.app([Greeter, Counter, PaymentWorkflow]) +class OrderProcessor(Service): + """Demonstrates type-safe RPC between services using fluent proxies.""" + + @handler + async def process(self, customer: str) -> str: + # Call a service handler — IDE knows .greet() takes str, returns str + greeting = await Greeter.call().greet(customer) + + # Call a virtual object — IDE knows .increment() takes int, returns int + count = await Counter.call(customer).increment(1) + + # Fire-and-forget send (returns SendHandle, not a coroutine) + Counter.send(customer).increment(1) # type: ignore[unused-coroutine] + + # Send with delay + Counter.send(customer, delay=timedelta(seconds=30)).increment(1) # type: ignore[unused-coroutine] + + # Call a workflow + receipt = await PaymentWorkflow.call(f"order-{count}").pay(100) + + return f"{greeting} (visit #{count}, {receipt})" + + +app = restate.app([Greeter, Counter, PaymentWorkflow, OrderProcessor]) diff --git a/python/restate/cls.py b/python/restate/cls.py index a7d7505..0ebcfa6 100644 --- a/python/restate/cls.py +++ b/python/restate/cls.py @@ -40,11 +40,17 @@ async def count(self) -> int: from __future__ import annotations import inspect +import sys from dataclasses import dataclass, field from datetime import timedelta from functools import wraps from typing import Any, AsyncContextManager, Callable, Dict, List, Literal, Optional, TypeVar +if sys.version_info >= (3, 11): + from typing import Self +else: + from typing_extensions import Self + from restate.handler import HandlerIO, ServiceTag, make_handler from restate.retry_policy import InvocationRetryPolicy from restate.serde import DefaultSerde, Serde @@ -612,7 +618,7 @@ def __init_subclass__( ) @classmethod - def call(cls) -> "Service": # type: ignore[return-type] + def call(cls) -> Self: # type: ignore[return-type] """Return a proxy for making durable service calls. The proxy has the same method signatures as the class, @@ -621,7 +627,7 @@ def call(cls) -> "Service": # type: ignore[return-type] return _ServiceCallProxy(cls) # type: ignore[return-value] @classmethod - def send(cls, *, delay: Optional[timedelta] = None) -> "Service": # type: ignore[return-type] + def send(cls, *, delay: Optional[timedelta] = None) -> Self: # type: ignore[return-type] """Return a proxy for fire-and-forget service sends.""" return _ServiceSendProxy(cls, delay) # type: ignore[return-value] @@ -677,12 +683,12 @@ def __init_subclass__( ) @classmethod - def call(cls, key: str) -> "VirtualObject": # type: ignore[return-type] + def call(cls, key: str) -> Self: # type: ignore[return-type] """Return a proxy for making durable object calls.""" return _ObjectCallProxy(cls, key) # type: ignore[return-value] @classmethod - def send(cls, key: str, *, delay: Optional[timedelta] = None) -> "VirtualObject": # type: ignore[return-type] + def send(cls, key: str, *, delay: Optional[timedelta] = None) -> Self: # type: ignore[return-type] """Return a proxy for fire-and-forget object sends.""" return _ObjectSendProxy(cls, key, delay) # type: ignore[return-value] @@ -738,12 +744,12 @@ def __init_subclass__( ) @classmethod - def call(cls, key: str) -> "Workflow": # type: ignore[return-type] + def call(cls, key: str) -> Self: # type: ignore[return-type] """Return a proxy for making durable workflow calls.""" return _WorkflowCallProxy(cls, key) # type: ignore[return-value] @classmethod - def send(cls, key: str, *, delay: Optional[timedelta] = None) -> "Workflow": # type: ignore[return-type] + def send(cls, key: str, *, delay: Optional[timedelta] = None) -> Self: # type: ignore[return-type] """Return a proxy for fire-and-forget workflow sends.""" return _WorkflowSendProxy(cls, key, delay) # type: ignore[return-value] From a58d5b0d57c4734ea91219c440b2390bd5a7b61e Mon Sep 17 00:00:00 2001 From: igalshilman Date: Fri, 27 Mar 2026 18:01:59 +0100 Subject: [PATCH 03/12] Allow the admin client to pass headers --- python/restate/admin_client.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/python/restate/admin_client.py b/python/restate/admin_client.py index 4239627..2874885 100644 --- a/python/restate/admin_client.py +++ b/python/restate/admin_client.py @@ -62,6 +62,7 @@ def get_handler(self, name: str) -> HandlerInfo | None: class ListServicesResponse(BaseModel): """Response from GET /services.""" + services: list[ServiceInfo] @@ -78,8 +79,9 @@ class AdminClient: print(f" - {h.name} metadata={h.metadata}") """ - def __init__(self, admin_url: str): + def __init__(self, admin_url: str, headers: dict[str, str] | None = None): self._admin_url = admin_url.rstrip("/") + self._headers = headers self._client: httpx.AsyncClient | None = None self._owns_client = False @@ -93,7 +95,7 @@ def from_client(cls, admin_url: str, client: httpx.AsyncClient) -> AdminClient: async def _get_client(self) -> httpx.AsyncClient: if self._client is None: - self._client = httpx.AsyncClient(base_url=self._admin_url) + self._client = httpx.AsyncClient(base_url=self._admin_url, headers=self._headers) self._owns_client = True return self._client From d42ff46fb0efdc2c2b2c0dcefc6d2d522bc774a2 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Fri, 27 Mar 2026 20:11:04 +0100 Subject: [PATCH 04/12] Remove unrelated changes --- python/restate/__init__.py | 9 -- python/restate/admin_client.py | 142 ----------------- python/restate/ext/a2a/PLAN.md | 247 ----------------------------- python/restate/ext/a2a/__init__.py | 13 -- tests/admin_client.py | 166 ------------------- 5 files changed, 577 deletions(-) delete mode 100644 python/restate/admin_client.py delete mode 100644 python/restate/ext/a2a/PLAN.md delete mode 100644 python/restate/ext/a2a/__init__.py delete mode 100644 tests/admin_client.py diff --git a/python/restate/__init__.py b/python/restate/__init__.py index b9fc9cd..e4e0e49 100644 --- a/python/restate/__init__.py +++ b/python/restate/__init__.py @@ -88,12 +88,6 @@ async def create_client( yield # type: ignore -try: - from .admin_client import AdminClient, ServiceInfo, HandlerInfo -except ImportError: - pass - - __all__ = [ "Service", "VirtualObject", @@ -128,7 +122,4 @@ async def create_client( "RestateClientSendHandle", "HttpError", "create_client", - "AdminClient", - "ServiceInfo", - "HandlerInfo", ] diff --git a/python/restate/admin_client.py b/python/restate/admin_client.py deleted file mode 100644 index 2874885..0000000 --- a/python/restate/admin_client.py +++ /dev/null @@ -1,142 +0,0 @@ -# -# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH -# -# This file is part of the Restate SDK for Python, -# which is released under the MIT license. -# -# You can find a copy of the license in file LICENSE in the root -# directory of this repository or package, or at -# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE -# -""" -Client for the Restate Admin API. - -Provides typed access to service and handler metadata -via the Restate admin API (default port 9070). -""" - -from __future__ import annotations - -from typing import Any, Literal - -import httpx -from pydantic import BaseModel, ConfigDict, Field - - -class HandlerInfo(BaseModel): - """Metadata about a handler returned by the Restate admin API.""" - - model_config = ConfigDict(extra="allow") - - name: str - ty: Literal["Exclusive", "Shared", "Workflow"] | None = None - documentation: str | None = None - metadata: dict[str, str] | None = None - input_description: str | None = None - output_description: str | None = None - input_json_schema: dict[str, Any] | None = None - output_json_schema: dict[str, Any] | None = None - - -class ServiceInfo(BaseModel): - """Metadata about a service returned by the Restate admin API.""" - - model_config = ConfigDict(extra="allow") - - name: str - ty: Literal["Service", "VirtualObject", "Workflow"] - handlers: list[HandlerInfo] = Field(default_factory=list) - deployment_id: str | None = None - revision: int | None = None - public: bool | None = None - documentation: str | None = None - metadata: dict[str, str] | None = None - - def get_handler(self, name: str) -> HandlerInfo | None: - """Get a handler by name, or None if not found.""" - for h in self.handlers: - if h.name == name: - return h - return None - - -class ListServicesResponse(BaseModel): - """Response from GET /services.""" - - services: list[ServiceInfo] - - -class AdminClient: - """Client for the Restate Admin API. - - Example:: - - async with AdminClient("http://localhost:9070") as client: - services = await client.list_services() - for svc in services: - print(f"{svc.name} ({svc.ty}): {len(svc.handlers)} handlers") - for h in svc.handlers: - print(f" - {h.name} metadata={h.metadata}") - """ - - def __init__(self, admin_url: str, headers: dict[str, str] | None = None): - self._admin_url = admin_url.rstrip("/") - self._headers = headers - self._client: httpx.AsyncClient | None = None - self._owns_client = False - - @classmethod - def from_client(cls, admin_url: str, client: httpx.AsyncClient) -> AdminClient: - """Create an AdminClient using an existing httpx client.""" - instance = cls(admin_url) - instance._client = client - instance._owns_client = False - return instance - - async def _get_client(self) -> httpx.AsyncClient: - if self._client is None: - self._client = httpx.AsyncClient(base_url=self._admin_url, headers=self._headers) - self._owns_client = True - return self._client - - async def close(self) -> None: - """Close the underlying HTTP client if we own it.""" - if self._client is not None and self._owns_client: - await self._client.aclose() - self._client = None - - async def __aenter__(self) -> AdminClient: - return self - - async def __aexit__(self, *args: Any) -> None: - await self.close() - - async def list_services(self) -> list[ServiceInfo]: - """List all registered services with their handlers and metadata. - - Returns: - A list of ServiceInfo objects, each containing the service's - handlers, metadata, documentation, and configuration. - """ - client = await self._get_client() - response = await client.get("/services") - response.raise_for_status() - parsed = ListServicesResponse.model_validate(response.json()) - return parsed.services - - async def get_service(self, name: str) -> ServiceInfo: - """Get detailed information about a specific service. - - Args: - name: The service name. - - Returns: - A ServiceInfo object with full handler and metadata details. - - Raises: - httpx.HTTPStatusError: If the service is not found (404) or other errors. - """ - client = await self._get_client() - response = await client.get(f"/services/{name}") - response.raise_for_status() - return ServiceInfo.model_validate(response.json()) diff --git a/python/restate/ext/a2a/PLAN.md b/python/restate/ext/a2a/PLAN.md deleted file mode 100644 index 2a82c54..0000000 --- a/python/restate/ext/a2a/PLAN.md +++ /dev/null @@ -1,247 +0,0 @@ -# A2A Integration for Restate Python SDK — v3 - -## Context - -Adding an A2A integration that combines the best of both approaches: -- **Inside Restate**: A `TaskObject` (VirtualObject) manages task state, durability, and agent invocation -- **Outside Restate**: A thin FastAPI gateway handles A2A protocol (JSON-RPC, agent card catalog) and discovers agents via the Restate admin API - -The gateway translates A2A JSON-RPC into Restate ingress HTTP calls. No `AgentExecutor`, no `DefaultRequestHandler`, no A2A server SDK needed on the gateway — it's a direct protocol translation. - -## Architecture - -``` -A2A Client - │ - ▼ -FastAPI Gateway (outside Restate) - │ - ├─ GET /.well-known/agent-card - │ → queries Restate admin API for services with a2a metadata - │ → builds and returns agent card catalog - │ - ├─ POST /{agent}/a2a (JSON-RPC: message/send) - │ → POST http://ingress/{agent}-task/{task_id}/handle_send_message_request - │ - ├─ POST /{agent}/a2a (JSON-RPC: tasks/get) - │ → POST http://ingress/{agent}-task/{task_id}/get_task - │ - └─ POST /{agent}/a2a (JSON-RPC: tasks/cancel) - → POST http://ingress/{agent}-task/{task_id}/get_invocation_id - → PATCH http://admin/invocations/{id}/cancel - → POST http://ingress/{agent}-task/{task_id}/get_task (poll for result) - -Inside Restate: - TaskObject (VirtualObject, keyed by task_id) - ├─ handle_send_message_request (exclusive) — runs agent, manages task state - ├─ get_task (shared) — returns task from K/V store - ├─ get_invocation_id (shared) — returns in-flight invocation ID - └─ cancel_task (exclusive) — marks task as canceled -``` - -## Agent Function Signature - -```python -async def my_agent(query: str, context_id: str) -> AgentInvokeResult: - ctx = restate_context() # available via current_context() - result = await ctx.run_typed("call_llm", llm.call, query) - return AgentInvokeResult(parts=[TextPart(text=result)]) -``` - -Runs within TaskObject's exclusive handler context — full access to `ctx.run_typed()`, service calls, etc. - -## Agent Discovery via Metadata - -Each `TaskObject` stores its agent card in service metadata: - -```python -metadata={"a2a.agent_card": agent_card.model_dump_json()} -``` - -Service `description` is used as the agent description. - -The gateway queries `GET http://admin:9070/services`, filters for services with `a2a.agent_card` metadata, and deserializes the agent cards. - -## User-Facing API - -### Restate side (agent definition): - -```python -from restate.ext.a2a import A2ATaskObject, AgentInvokeResult -from a2a.types import AgentCard, AgentSkill, TextPart - -async def weather_agent(query: str, context_id: str) -> AgentInvokeResult: - ctx = restate_context() - forecast = await ctx.run_typed("get_forecast", fetch_forecast, query) - return AgentInvokeResult(parts=[TextPart(text=forecast)]) - -weather = A2ATaskObject( - "weather", - invoke_function=weather_agent, - agent_card=AgentCard( - name="Weather Agent", - description="Provides weather forecasts", - url="http://gateway:8000/weather/a2a", # gateway URL - version="1.0", - skills=[AgentSkill(id="forecast", name="Forecast", description="...")], - default_input_modes=["text"], - default_output_modes=["text"], - ), -) - -# Standard restate app — TaskObject is just a VirtualObject -app = restate.app(services=[weather]) -``` - -### Gateway side (separate process): - -```python -from restate.ext.a2a import A2AGateway - -gateway = A2AGateway( - restate_admin_url="http://localhost:9070", - restate_ingress_url="http://localhost:8080", -) -app = gateway.build() # FastAPI app - -# Run: uvicorn gateway:app --port 8000 -``` - -The gateway auto-discovers all `A2ATaskObject` services from the admin API. - -## Files to Create/Modify - -### 1. `python/restate/ext/a2a/__init__.py` (new) - -Exports: -- `A2ATaskObject` — VirtualObject with built-in task management -- `AgentInvokeResult` — result type for invoke_function -- `A2AGateway` — FastAPI gateway -- `restate_context()` / `restate_object_context()` — context helpers - -### 2. `python/restate/ext/a2a/_models.py` (new) - -```python -@dataclass -class AgentInvokeResult: - parts: list[Part] - require_user_input: bool = False - -InvokeAgentType = Callable[[str, str], Awaitable[AgentInvokeResult]] -``` - -### 3. `python/restate/ext/a2a/_task.py` (new, based on reference) - -**`TaskObject`** — VirtualObject keyed by `task_id`. Copied from reference with adjustments: - -- `handle_send_message_request(ctx, request: SendMessageRequest) -> SendMessageResponse` - - Generates context_id if missing - - Stores invocation ID for cancellation - - Upserts task in K/V store - - Calls `invoke_function(query, context_id)` - - Updates task to completed/input_required/failed/canceled -- `get_task(ctx) -> Task | None` (shared) -- `get_invocation_id(ctx) -> str | None` (shared) -- `cancel_task(ctx, request) -> CancelTaskResponse` (exclusive) -- `update_store(ctx, state, ...) -> Task` (exclusive) -- `upsert_task(ctx, params) -> Task` (exclusive) - -**`A2ATaskObject`** — wrapper that creates a `TaskObject` with agent card stored in metadata: - -```python -class A2ATaskObject: - def __init__(self, name, invoke_function, agent_card): - self._task_object = TaskObject( - f"{name}", - invoke_function, - ) - # Store agent card in metadata for discovery - self._task_object.metadata = {"a2a.agent_card": agent_card.model_dump_json()} - self._task_object.description = agent_card.description -``` - -Exposes the same interface as VirtualObject so it can be passed to `restate.app()`. - -### 4. `python/restate/ext/a2a/_gateway.py` (new) - -**`A2AGateway`** — FastAPI app builder: - -- Constructor: `(restate_admin_url, restate_ingress_url)` -- `build()` → FastAPI app with: - - `GET /.well-known/agent-card` — returns agent card(s) from admin API discovery - - `POST /{agent_name}/a2a` — JSON-RPC dispatch endpoint per agent -- Agent discovery: queries `GET {admin_url}/services`, filters by `a2a.agent_card` metadata -- JSON-RPC dispatch: - - `message/send` → `POST {ingress_url}/{agent_name}/{task_id}/handle_send_message_request` - - `tasks/get` → `POST {ingress_url}/{agent_name}/{task_id}/get_task` - - `tasks/cancel`: - 1. `POST {ingress_url}/{agent_name}/{task_id}/get_invocation_id` - 2. `PATCH {admin_url}/invocations/{id}/cancel` - 3. Poll `get_task` for final state - - Other methods → return appropriate JSON-RPC error responses - -Uses `httpx.AsyncClient` for all HTTP calls. Lifespan manages the client. - -### 5. `pyproject.toml` (modify) -Add: `a2a = ["a2a-sdk", "fastapi", "httpx[http2]"]` - -## Key Design Details - -### Ingress URL patterns (VirtualObject) -- `POST http://ingress/{service}/{key}/{handler}` — blocking call -- `POST http://ingress/{service}/{key}/{handler}/send` — returns invocation ID immediately -- Request body: JSON-serialized handler input -- Response: JSON-serialized handler output -- Header `x-restate-id`: invocation ID - -### Admin API patterns -- `GET http://admin/services` — list all services with metadata -- `PATCH http://admin/invocations/{id}/cancel` — cancel in-flight invocation - -### Serialization -A2A types are Pydantic models → Restate's PydanticJsonSerde handles them correctly. The gateway serializes/deserializes using the same Pydantic models. - -### Agent Card in Metadata -```python -metadata = {"a2a.agent_card": agent_card.model_dump_json()} -``` -Gateway deserializes: `AgentCard.model_validate_json(metadata["a2a.agent_card"])` - -### Task ID -- If the A2A client provides a `task_id` in the message → use it as VirtualObject key -- If not → gateway generates a UUID and uses it as the key - -### Multi-agent Support -- Each agent is a separate `A2ATaskObject` (VirtualObject) -- Gateway discovers all of them via admin API -- Each agent gets its own endpoint: `POST /{agent_name}/a2a` -- Agent cards include their specific URL - -## Cancellation Flow (detailed) - -1. A2A client sends `tasks/cancel` JSON-RPC request -2. Gateway receives it, extracts `task_id` and `agent_name` -3. Gateway calls `POST {ingress}/{agent_name}/{task_id}/get_invocation_id` -4. If invocation_id exists: - a. `PATCH {admin}/invocations/{invocation_id}/cancel` - b. The running handler catches TerminalError(409, "cancelled"), updates task to canceled - c. Gateway calls `POST {ingress}/{agent_name}/{task_id}/get_task` to get final state -5. If no invocation_id (task already completed): - a. Gateway calls `POST {ingress}/{agent_name}/{task_id}/cancel_task` - -## v1 Limitations -- No streaming support -- No push notifications -- No resubscribe support -- No authenticated extended card -- Gateway discovery is one-shot on startup (could add polling later) - -## Verification -1. Create a simple echo agent with `A2ATaskObject` -2. Run restate app: `restate.app(services=[echo_agent])` -3. Run gateway: `uvicorn gateway:app` -4. Send JSON-RPC `message/send` via curl -5. Send `tasks/get` to retrieve task -6. Test multi-turn with same task_id -7. Test cancellation -8. Verify agent card discovery at `/.well-known/agent-card` diff --git a/python/restate/ext/a2a/__init__.py b/python/restate/ext/a2a/__init__.py deleted file mode 100644 index 7b8add8..0000000 --- a/python/restate/ext/a2a/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# -# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH -# -# This file is part of the Restate SDK for Python, -# which is released under the MIT license. -# -# You can find a copy of the license in file LICENSE in the root -# directory of this repository or package, or at -# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE -# -""" -This module contains the optional A2A (Agent-to-Agent) integration for Restate. -""" diff --git a/tests/admin_client.py b/tests/admin_client.py deleted file mode 100644 index af75f3b..0000000 --- a/tests/admin_client.py +++ /dev/null @@ -1,166 +0,0 @@ -# -# Copyright (c) 2023-2025 - Restate Software, Inc., Restate GmbH -# -# This file is part of the Restate SDK for Python, -# which is released under the MIT license. -# -# You can find a copy of the license in file LICENSE in the root -# directory of this repository or package, or at -# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE -# - -import restate -from restate import ( - Context, - Service, - VirtualObject, - ObjectContext, - ObjectSharedContext, - HarnessEnvironment, - AdminClient, -) -import pytest - -# ----- Asyncio fixtures - - -@pytest.fixture(scope="session") -def anyio_backend(): - return "asyncio" - - -pytestmark = [ - pytest.mark.anyio, -] - -# -------- Services with metadata and descriptions - -greeter = Service( - "greeter", - description="A simple greeting service", - metadata={"team": "platform", "version": "1.0"}, -) - - -@greeter.handler( - metadata={"a2a.restate.dev/handler": "true", "a2a.restate.dev/skill": "greeting"}, -) -async def greet(ctx: Context, name: str) -> str: - """Greets a person by name.""" - return f"Hello {name}!" - - -counter = VirtualObject( - "counter", - description="A durable counter", -) - - -@counter.handler() -async def increment(ctx: ObjectContext, value: int) -> int: - n = await ctx.get("counter", type_hint=int) or 0 - n += value - ctx.set("counter", n) - return n - - -@counter.handler(kind="shared") -async def count(ctx: ObjectSharedContext) -> int: - """Returns the current count.""" - return await ctx.get("counter") or 0 - - -bare = Service("bare") - - -@bare.handler() -async def ping(ctx: Context) -> str: - return "pong" - - -# -------- Harness fixture - - -@pytest.fixture(scope="session") -async def harness(): - async with restate.create_test_harness( - restate.app([greeter, counter, bare]), - restate_image="ghcr.io/restatedev/restate:latest", - ) as env: - yield env - - -# -------- Tests - - -async def test_list_services(harness: HarnessEnvironment): - async with AdminClient(harness.admin_api_url) as admin: - services = await admin.list_services() - - names = {s.name for s in services} - assert "greeter" in names - assert "counter" in names - assert "bare" in names - - -async def test_get_service(harness: HarnessEnvironment): - async with AdminClient(harness.admin_api_url) as admin: - svc = await admin.get_service("greeter") - - assert svc.name == "greeter" - assert svc.ty == "Service" - assert svc.documentation == "A simple greeting service" - assert svc.metadata == {"team": "platform", "version": "1.0"} - - -async def test_service_handlers(harness: HarnessEnvironment): - async with AdminClient(harness.admin_api_url) as admin: - svc = await admin.get_service("greeter") - - handler = svc.get_handler("greet") - assert handler is not None - assert handler.name == "greet" - assert handler.documentation == "Greets a person by name." - assert handler.metadata is not None - assert handler.metadata["a2a.restate.dev/handler"] == "true" - assert handler.metadata["a2a.restate.dev/skill"] == "greeting" - - -async def test_virtual_object_type(harness: HarnessEnvironment): - async with AdminClient(harness.admin_api_url) as admin: - svc = await admin.get_service("counter") - - assert svc.ty == "VirtualObject" - assert svc.documentation == "A durable counter" - - inc = svc.get_handler("increment") - assert inc is not None - assert inc.ty == "Exclusive" - - cnt = svc.get_handler("count") - assert cnt is not None - assert cnt.ty == "Shared" - assert cnt.documentation == "Returns the current count." - - -async def test_handler_not_found(harness: HarnessEnvironment): - async with AdminClient(harness.admin_api_url) as admin: - svc = await admin.get_service("bare") - - assert svc.get_handler("nonexistent") is None - - -async def test_service_not_found(harness: HarnessEnvironment): - async with AdminClient(harness.admin_api_url) as admin: - with pytest.raises(Exception): - await admin.get_service("does_not_exist") - - -async def test_service_without_metadata(harness: HarnessEnvironment): - async with AdminClient(harness.admin_api_url) as admin: - svc = await admin.get_service("bare") - - assert svc.name == "bare" - # Services without explicit metadata may have None or empty dict - assert svc.metadata is None or svc.metadata == {} - assert svc.documentation is None From 115a0b4da7903fa86f5f206991b21005ce49fe37 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Mon, 30 Mar 2026 11:59:45 +0200 Subject: [PATCH 05/12] allow passing in class instances --- examples/class_based_greeter.py | 50 ++++- python/restate/cls.py | 191 +++++++++++++----- python/restate/context_access.py | 97 ++++----- python/restate/endpoint.py | 24 ++- .../services/awakeable_holder.py | 10 +- .../services/block_and_wait_workflow.py | 12 +- test-services-cls/services/cancel_test.py | 18 +- test-services-cls/services/counter.py | 16 +- test-services-cls/services/failing.py | 10 +- test-services-cls/services/interpreter.py | 52 ++--- test-services-cls/services/kill_test.py | 10 +- test-services-cls/services/list_object.py | 12 +- test-services-cls/services/map_object.py | 12 +- test-services-cls/services/non_determinism.py | 30 +-- test-services-cls/services/proxy.py | 10 +- test-services-cls/services/test_utils.py | 10 +- .../virtual_object_command_interpreter.py | 34 ++-- 17 files changed, 350 insertions(+), 248 deletions(-) diff --git a/examples/class_based_greeter.py b/examples/class_based_greeter.py index 640f377..070febc 100644 --- a/examples/class_based_greeter.py +++ b/examples/class_based_greeter.py @@ -7,8 +7,23 @@ from datetime import timedelta +from pydantic import BaseModel + import restate -from restate.cls import Service, VirtualObject, Workflow, handler, shared, main, Context +from restate.cls import Service, VirtualObject, Workflow, handler, shared, main, Restate + + +# ── Pydantic models ── + + +class GreetingRequest(BaseModel): + name: str + language: str = "en" + + +class GreetingResponse(BaseModel): + message: str + language: str class Greeter(Service): @@ -24,14 +39,14 @@ class Counter(VirtualObject): @handler async def increment(self, value: int) -> int: - n: int = await Context.get("counter", type_hint=int) or 0 + n: int = await Restate.get("counter", type_hint=int) or 0 n += value - Context.set("counter", n) + Restate.set("counter", n) return n @shared async def count(self) -> int: - return await Context.get("counter", type_hint=int) or 0 + return await Restate.get("counter", type_hint=int) or 0 class PaymentWorkflow(Workflow): @@ -39,18 +54,18 @@ class PaymentWorkflow(Workflow): @main async def pay(self, amount: int) -> str: - Context.set("status", "processing") + Restate.set("status", "processing") async def charge(): return f"charged ${amount}" - receipt = await Context.run_typed("charge", charge) - Context.set("status", "completed") + receipt = await Restate.run("charge", charge) + Restate.set("status", "completed") return receipt @handler async def status(self) -> str: - return await Context.get("status", type_hint=str) or "unknown" + return await Restate.get("status", type_hint=str) or "unknown" class OrderProcessor(Service): @@ -76,4 +91,21 @@ async def process(self, customer: str) -> str: return f"{greeting} (visit #{count}, {receipt})" -app = restate.app([Greeter, Counter, PaymentWorkflow, OrderProcessor]) +class PydanticGreeter(Service): + """Demonstrates Pydantic model serde with the class-based API.""" + + def __init__(self, name): + self.name = name + + @handler + async def greet(self, req: GreetingRequest) -> GreetingResponse: + greetings = {"en": "Hello", "es": "Hola", "de": "Hallo"} + greeting = greetings.get(req.language, "Hello") + + async def translate() -> GreetingResponse: + return GreetingResponse(message=f"{greeting} {req.name} from {self.name}", language=req.language) + + return await Restate.run("translate", translate) + + +app = restate.app([Greeter, Counter, PaymentWorkflow, OrderProcessor, PydanticGreeter("Restate")]) diff --git a/python/restate/cls.py b/python/restate/cls.py index 0ebcfa6..d077234 100644 --- a/python/restate/cls.py +++ b/python/restate/cls.py @@ -11,13 +11,72 @@ """ Class-based API for defining Restate services. -This module provides an alternative to the decorator-based API, allowing -services to be defined as classes with handler methods. +This module lets you define Restate services as plain Python classes. +Under the hood, each class is transformed into the same primitives used +by the decorator-based API (``restate.Service``, ``restate.VirtualObject``, +``restate.Workflow``). + +Transformation overview +----------------------- + +Given user code like this:: + + class Greeter(Service): + def __init__(self, prefix: str): + self.prefix = prefix + + @handler + async def greet(self, name: str) -> str: + return f"{self.prefix} {name}!" + + app = restate.app([Greeter("Hello")]) + +The following happens at **class definition time** (``__init_subclass__``): + +1. ``_process_class`` scans ``Greeter.__dict__`` for methods marked with + ``@handler``, ``@shared``, or ``@main``. + +2. For each marked method it creates a **placeholder wrapper** with the + signature ``(ctx, *args)`` that Restate's ``invoke_handler`` expects. + This placeholder raises ``RuntimeError`` if called before binding. + +3. The original method's signature is preserved separately for **type + deduction** — ``inspect.signature(method)`` is passed to + ``make_handler`` so that Pydantic/msgspec serde and JSON schemas + are derived from the real ``(self, name: str) -> str`` annotations, + not from the ``(*args)`` wrapper. + +4. A **companion service object** (a plain ``restate.Service``, + ``restate.VirtualObject``, or ``restate.Workflow``) is created and + stored on the class as ``Greeter._restate_service``. This companion + holds the handler dict and all service-level configuration. + +Then at **bind time** (``restate.app([...])`` → ``Endpoint.bind``): + +5. If a **class** is passed (``restate.app([Greeter])``), it is + instantiated via ``Greeter()``. If the constructor requires arguments, + a ``TypeError`` tells the user to pass an instance instead. + +6. If an **instance** is passed (``restate.app([Greeter("Hello")])``), + it is used directly. + +7. ``_bind_instance(instance)`` is called, which replaces each handler's + placeholder wrapper with a real one that **closes over the instance**. + The companion ``_restate_service`` is then registered with the + endpoint just like any decorator-based service. + +At **invocation time**, Restate calls the wrapper which dispatches +to the bound method via its closure — no class-level state involved:: + + wrapper(ctx, "Alice") + → _method = Greeter.greet # captured in closure + → _inst = Greeter("Hello") obj # captured in closure + → Greeter.greet(_inst, "Alice") + → "Hello Alice!" Example:: from restate.cls import Service, VirtualObject, Workflow, handler, shared, main - import restate class Greeter(Service): @handler @@ -27,14 +86,14 @@ async def greet(self, name: str) -> str: class Counter(VirtualObject): @handler async def increment(self, value: int) -> int: - n = await restate.get("counter", type_hint=int) or 0 + n = await Restate.get("counter", type_hint=int) or 0 n += value - restate.set("counter", n) + Restate.set("counter", n) return n @shared async def count(self) -> int: - return await restate.get("counter", type_hint=int) or 0 + return await Restate.get("counter", type_hint=int) or 0 """ from __future__ import annotations @@ -318,14 +377,14 @@ def _process_class( handler_kind = _resolve_handler_kind(service_kind, meta.kind) handler_name = meta.name or method.__name__ - # Create a wrapper that instantiates the class and calls the method. - # The wrapper has signature (ctx, *args) matching what invoke_handler expects. + # Placeholder wrapper — replaced by _bind_instance() at bind time + # with one that closes over the actual instance. @wraps(method) - async def wrapper(ctx, *args, _method=method, _cls=cls): - instance = object.__new__(_cls) - if args: - return await _method(instance, *args) - return await _method(instance) + async def wrapper(ctx, *args): + raise RuntimeError( + f"Handler {handler_name} called before instance was bound. " + f"Use restate.app([{cls.__name__}(...)]) to bind an instance." + ) # Use the original method's signature for type/serde inspection sig = inspect.signature(method, eval_str=True) @@ -417,6 +476,36 @@ async def wrapper(ctx, *args, _method=method, _cls=cls): cls._restate_service = svc # type: ignore[attr-defined] +def _bind_instance(instance: Any) -> None: + """Create real handler wrappers that close over *instance*. + + Called from ``Endpoint.bind()`` once the instance is known. + Replaces the placeholder ``fn`` on each handler with a wrapper + that dispatches to the bound method on the instance. + """ + cls = type(instance) + svc = cls._restate_service # type: ignore[attr-defined] + for handler_name, h in svc.handlers.items(): + method = cls.__dict__.get(handler_name) + if method is None: + # handler name might differ from method name + for attr in cls.__dict__.values(): + meta = getattr(attr, _HANDLER_MARKER, None) + if meta and meta.name == handler_name: + method = attr + break + if method is None: + continue + + @wraps(method) + async def wrapper(ctx, *args, _method=method, _inst=instance): + if args: + return await _method(_inst, *args) + return await _method(_inst) + + h.fn = wrapper + + # ── Fluent RPC proxy classes ────────────────────────────────────────────── @@ -757,178 +846,168 @@ def send(cls, key: str, *, delay: Optional[timedelta] = None) -> Self: # type: # ── Context accessor class ──────────────────────────────────────────────── -class Context: +class Restate: """Static accessor for the current Restate invocation context. Use from within handler methods to access Restate functionality without an explicit ``ctx`` parameter:: - from restate.cls import Service, handler, Context + from restate.cls import Service, handler, Restate class Greeter(Service): @handler async def greet(self, name: str) -> str: - count = await Context.get("visits", type_hint=int) or 0 - Context.set("visits", count + 1) + count = await Restate.get("visits", type_hint=int) or 0 + Restate.set("visits", count + 1) return f"Hello {name}!" """ @staticmethod def _ctx() -> Any: - from restate.server_context import _restate_context_var # pylint: disable=C0415 + from restate.context_access import current_context # pylint: disable=C0415 - try: - return _restate_context_var.get() - except LookupError: - raise RuntimeError( - "Not inside a Restate handler. Context methods can only be called within a handler invocation." - ) from None + return current_context() # ── State ── @staticmethod def get(name: str, serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: """Retrieve a state value by name.""" - return Context._ctx().get(name, serde=serde, type_hint=type_hint) + return Restate._ctx().get(name, serde=serde, type_hint=type_hint) @staticmethod def set(name: str, value: Any, serde: Serde = DefaultSerde()) -> None: """Set a state value by name.""" - Context._ctx().set(name, value, serde=serde) + Restate._ctx().set(name, value, serde=serde) @staticmethod def clear(name: str) -> None: """Clear a state value by name.""" - Context._ctx().clear(name) + Restate._ctx().clear(name) @staticmethod def clear_all() -> None: """Clear all state values.""" - Context._ctx().clear_all() + Restate._ctx().clear_all() @staticmethod def state_keys() -> Any: """Return the list of state keys.""" - return Context._ctx().state_keys() + return Restate._ctx().state_keys() # ── Identity & request ── @staticmethod def key() -> str: """Return the key of the current virtual object or workflow.""" - return Context._ctx().key() + return Restate._ctx().key() @staticmethod def request() -> Any: """Return the current request object.""" - return Context._ctx().request() + return Restate._ctx().request() @staticmethod def random() -> Any: """Return a deterministically-seeded Random instance.""" - return Context._ctx().random() + return Restate._ctx().random() @staticmethod def uuid() -> Any: """Return a deterministic UUID, stable across retries.""" - return Context._ctx().uuid() + return Restate._ctx().uuid() @staticmethod def time() -> Any: """Return a durable timestamp, stable across retries.""" - return Context._ctx().time() + return Restate._ctx().time() # ── Durable execution ── @staticmethod - def run(name: str, action: Any, serde: Serde = DefaultSerde(), **kwargs: Any) -> Any: - """Run a durable side effect (deprecated — use run_typed).""" - return Context._ctx().run(name, action, serde=serde, **kwargs) - - @staticmethod - def run_typed(name: str, action: Any, *args: Any, **kwargs: Any) -> Any: + def run(name: str, action: Any, *args: Any, **kwargs: Any) -> Any: """Run a durable side effect with typed arguments.""" - return Context._ctx().run_typed(name, action, *args, **kwargs) + return Restate._ctx().run_typed(name, action, *args, **kwargs) @staticmethod def sleep(delta: timedelta, name: Optional[str] = None) -> Any: """Suspend the current invocation for the given duration.""" - return Context._ctx().sleep(delta, name=name) + return Restate._ctx().sleep(delta, name=name) # ── Service communication ── @staticmethod def service_call(tpe: Any, arg: Any, **kwargs: Any) -> Any: """Call a service handler.""" - return Context._ctx().service_call(tpe, arg=arg, **kwargs) + return Restate._ctx().service_call(tpe, arg=arg, **kwargs) @staticmethod def service_send(tpe: Any, arg: Any, **kwargs: Any) -> Any: """Send a message to a service handler (fire-and-forget).""" - return Context._ctx().service_send(tpe, arg=arg, **kwargs) + return Restate._ctx().service_send(tpe, arg=arg, **kwargs) @staticmethod def object_call(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: """Call a virtual object handler.""" - return Context._ctx().object_call(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().object_call(tpe, key=key, arg=arg, **kwargs) @staticmethod def object_send(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: """Send a message to a virtual object handler (fire-and-forget).""" - return Context._ctx().object_send(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().object_send(tpe, key=key, arg=arg, **kwargs) @staticmethod def workflow_call(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: """Call a workflow handler.""" - return Context._ctx().workflow_call(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().workflow_call(tpe, key=key, arg=arg, **kwargs) @staticmethod def workflow_send(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: """Send a message to a workflow handler (fire-and-forget).""" - return Context._ctx().workflow_send(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().workflow_send(tpe, key=key, arg=arg, **kwargs) @staticmethod def generic_call(service: str, handler: str, arg: bytes, key: Optional[str] = None, **kwargs: Any) -> Any: """Call a generic service/handler with raw bytes.""" - return Context._ctx().generic_call(service, handler, arg, key=key, **kwargs) + return Restate._ctx().generic_call(service, handler, arg, key=key, **kwargs) @staticmethod def generic_send(service: str, handler: str, arg: bytes, key: Optional[str] = None, **kwargs: Any) -> Any: """Send a message to a generic service/handler with raw bytes.""" - return Context._ctx().generic_send(service, handler, arg, key=key, **kwargs) + return Restate._ctx().generic_send(service, handler, arg, key=key, **kwargs) # ── Awakeables ── @staticmethod def awakeable(serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: """Create an awakeable and return (id, future).""" - return Context._ctx().awakeable(serde=serde, type_hint=type_hint) + return Restate._ctx().awakeable(serde=serde, type_hint=type_hint) @staticmethod def resolve_awakeable(name: str, value: Any, serde: Serde = DefaultSerde()) -> None: """Resolve an awakeable by id.""" - Context._ctx().resolve_awakeable(name, value, serde=serde) + Restate._ctx().resolve_awakeable(name, value, serde=serde) @staticmethod def reject_awakeable(name: str, failure_message: str, failure_code: int = 500) -> None: """Reject an awakeable by id.""" - Context._ctx().reject_awakeable(name, failure_message, failure_code=failure_code) + Restate._ctx().reject_awakeable(name, failure_message, failure_code=failure_code) # ── Promises (Workflow only) ── @staticmethod def promise(name: str, serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: """Return a durable promise (workflow handlers only).""" - return Context._ctx().promise(name, serde=serde, type_hint=type_hint) + return Restate._ctx().promise(name, serde=serde, type_hint=type_hint) # ── Invocation management ── @staticmethod def cancel_invocation(invocation_id: str) -> None: """Cancel an invocation by id.""" - Context._ctx().cancel_invocation(invocation_id) + Restate._ctx().cancel_invocation(invocation_id) @staticmethod def attach_invocation(invocation_id: str, serde: Serde = DefaultSerde(), type_hint: Optional[type] = None) -> Any: """Attach to an invocation by id.""" - return Context._ctx().attach_invocation(invocation_id, serde=serde, type_hint=type_hint) + return Restate._ctx().attach_invocation(invocation_id, serde=serde, type_hint=type_hint) diff --git a/python/restate/context_access.py b/python/restate/context_access.py index 1ebb2d8..b8465f7 100644 --- a/python/restate/context_access.py +++ b/python/restate/context_access.py @@ -27,7 +27,6 @@ RestateDurableFuture, RestateDurableSleepFuture, Request, - RunAction, RunOptions, SendHandle, ) @@ -38,14 +37,12 @@ O = TypeVar("O") -def _ctx() -> Any: - """Get the current restate context, raising if not inside a handler. +def current_context() -> Any: + """Get the current Restate context. - Returns Any because the actual runtime type is ServerInvocationContext - (which implements ObjectContext, WorkflowContext, etc.) but we want all - methods accessible without narrowing — runtime raises if mismatched. + Returns the context object for the current handler invocation. + Raises RuntimeError if called outside a handler. """ - # Import here to avoid circular imports from restate.server_context import _restate_context_var # pylint: disable=C0415 try: @@ -57,41 +54,32 @@ def _ctx() -> Any: ) from None -def current_context(): - """Get the current Restate context. - - Returns the context object for the current handler invocation. - Raises RuntimeError if called outside a handler. - """ - return _ctx() - - # ── State operations ── def get(name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None) -> Awaitable[Optional[T]]: """Retrieve a state value by name.""" - return _ctx().get(name, serde=serde, type_hint=type_hint) + return current_context().get(name, serde=serde, type_hint=type_hint) def set(name: str, value: T, serde: Serde[T] = DefaultSerde()) -> None: """Set a state value by name.""" - _ctx().set(name, value, serde=serde) + current_context().set(name, value, serde=serde) def clear(name: str) -> None: """Clear a state value by name.""" - _ctx().clear(name) + current_context().clear(name) def clear_all() -> None: """Clear all state values.""" - _ctx().clear_all() + current_context().clear_all() def state_keys() -> Awaitable[List[str]]: """Return the list of state keys.""" - return _ctx().state_keys() + return current_context().state_keys() # ── Identity & request ── @@ -99,54 +87,33 @@ def state_keys() -> Awaitable[List[str]]: def key() -> str: """Return the key of the current virtual object or workflow.""" - return _ctx().key() + return current_context().key() def request() -> Request: """Return the current request object.""" - return _ctx().request() + return current_context().request() def random() -> Random: """Return a deterministically-seeded Random instance.""" - return _ctx().random() + return current_context().random() def uuid() -> UUID: """Return a deterministic UUID, stable across retries.""" - return _ctx().uuid() + return current_context().uuid() def time() -> RestateDurableFuture[float]: """Return a durable timestamp, stable across retries.""" - return _ctx().time() + return current_context().time() # ── Durable execution ── def run( - name: str, - action: RunAction[T], - serde: Serde[T] = DefaultSerde(), - max_attempts: Optional[int] = None, - max_retry_duration: Optional[timedelta] = None, - type_hint: Optional[type] = None, - args: Optional[tuple] = None, -) -> RestateDurableFuture[T]: - """Run a durable side effect (deprecated — use run_typed instead).""" - return _ctx().run( - name, - action, - serde=serde, - max_attempts=max_attempts, - max_retry_duration=max_retry_duration, - type_hint=type_hint, - args=args, - ) - - -def run_typed( name: str, action: Union[Callable[..., Coroutine[Any, Any, T]], Callable[..., T]], options: RunOptions[T] = RunOptions(), @@ -155,12 +122,12 @@ def run_typed( **kwargs: Any, ) -> RestateDurableFuture[T]: """Run a durable side effect with typed arguments.""" - return _ctx().run_typed(name, action, options, *args, **kwargs) + return current_context().run_typed(name, action, options, *args, **kwargs) def sleep(delta: timedelta, name: Optional[str] = None) -> RestateDurableSleepFuture: """Suspend the current invocation for the given duration.""" - return _ctx().sleep(delta, name=name) + return current_context().sleep(delta, name=name) # ── Service communication ── @@ -173,7 +140,7 @@ def service_call( headers: Optional[Dict[str, str]] = None, ) -> RestateDurableCallFuture[O]: """Call a service handler.""" - return _ctx().service_call(tpe, arg=arg, idempotency_key=idempotency_key, headers=headers) + return current_context().service_call(tpe, arg=arg, idempotency_key=idempotency_key, headers=headers) def service_send( @@ -184,7 +151,9 @@ def service_send( headers: Optional[Dict[str, str]] = None, ) -> SendHandle: """Send a message to a service handler (fire-and-forget).""" - return _ctx().service_send(tpe, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers) + return current_context().service_send( + tpe, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) def object_call( @@ -195,7 +164,7 @@ def object_call( headers: Optional[Dict[str, str]] = None, ) -> RestateDurableCallFuture[O]: """Call a virtual object handler.""" - return _ctx().object_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) + return current_context().object_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) def object_send( @@ -207,7 +176,7 @@ def object_send( headers: Optional[Dict[str, str]] = None, ) -> SendHandle: """Send a message to a virtual object handler (fire-and-forget).""" - return _ctx().object_send( + return current_context().object_send( tpe, key=key, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers ) @@ -220,7 +189,7 @@ def workflow_call( headers: Optional[Dict[str, str]] = None, ) -> RestateDurableCallFuture[O]: """Call a workflow handler.""" - return _ctx().workflow_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) + return current_context().workflow_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) def workflow_send( @@ -232,7 +201,7 @@ def workflow_send( headers: Optional[Dict[str, str]] = None, ) -> SendHandle: """Send a message to a workflow handler (fire-and-forget).""" - return _ctx().workflow_send( + return current_context().workflow_send( tpe, key=key, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers ) @@ -246,7 +215,9 @@ def generic_call( headers: Optional[Dict[str, str]] = None, ) -> RestateDurableCallFuture[bytes]: """Call a generic service/handler with raw bytes.""" - return _ctx().generic_call(service, handler, arg, key=key, idempotency_key=idempotency_key, headers=headers) + return current_context().generic_call( + service, handler, arg, key=key, idempotency_key=idempotency_key, headers=headers + ) def generic_send( @@ -259,7 +230,7 @@ def generic_send( headers: Optional[Dict[str, str]] = None, ) -> SendHandle: """Send a message to a generic service/handler with raw bytes.""" - return _ctx().generic_send( + return current_context().generic_send( service, handler, arg, key=key, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers ) @@ -271,17 +242,17 @@ def awakeable( serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None ) -> Tuple[str, RestateDurableFuture[T]]: """Create an awakeable and return (id, future).""" - return _ctx().awakeable(serde=serde, type_hint=type_hint) + return current_context().awakeable(serde=serde, type_hint=type_hint) def resolve_awakeable(name: str, value: I, serde: Serde[I] = DefaultSerde()) -> None: """Resolve an awakeable by id.""" - _ctx().resolve_awakeable(name, value, serde=serde) + current_context().resolve_awakeable(name, value, serde=serde) def reject_awakeable(name: str, failure_message: str, failure_code: int = 500) -> None: """Reject an awakeable by id.""" - _ctx().reject_awakeable(name, failure_message, failure_code=failure_code) + current_context().reject_awakeable(name, failure_message, failure_code=failure_code) # ── Promises (Workflow only) ── @@ -289,7 +260,7 @@ def reject_awakeable(name: str, failure_message: str, failure_code: int = 500) - def promise(name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None) -> DurablePromise[T]: """Return a durable promise (workflow handlers only).""" - return _ctx().promise(name, serde=serde, type_hint=type_hint) + return current_context().promise(name, serde=serde, type_hint=type_hint) # ── Invocation management ── @@ -297,11 +268,11 @@ def promise(name: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[typ def cancel_invocation(invocation_id: str): """Cancel an invocation by id.""" - _ctx().cancel_invocation(invocation_id) + current_context().cancel_invocation(invocation_id) def attach_invocation( invocation_id: str, serde: Serde[T] = DefaultSerde(), type_hint: Optional[type] = None ) -> RestateDurableFuture[T]: """Attach to an invocation by id.""" - return _ctx().attach_invocation(invocation_id, serde=serde, type_hint=type_hint) + return current_context().attach_invocation(invocation_id, serde=serde, type_hint=type_hint) diff --git a/python/restate/endpoint.py b/python/restate/endpoint.py index 0b59e96..5d6760d 100644 --- a/python/restate/endpoint.py +++ b/python/restate/endpoint.py @@ -59,8 +59,28 @@ def bind(self, *services: typing.Any): The updated Endpoint instance """ for service in services: - # Support class-based services: extract companion object - actual = getattr(service, "_restate_service", service) + # Support class-based services: extract companion object. + if isinstance(service, type) and hasattr(service, "_restate_service"): + # Class passed — instantiate it and bind + from restate.cls import _bind_instance # pylint: disable=C0415 + + try: + instance = service() + except TypeError as e: + raise TypeError( + f"{service.__name__} requires constructor arguments. " + f"Pass an instance instead: restate.app([{service.__name__}(...)])" + ) from e + _bind_instance(instance) + actual = service._restate_service # type: ignore[attr-defined] + elif not isinstance(service, type) and hasattr(type(service), "_restate_service"): + # Instance passed — bind it + from restate.cls import _bind_instance # pylint: disable=C0415 + + _bind_instance(service) + actual = type(service)._restate_service # type: ignore[attr-defined] + else: + actual = getattr(service, "_restate_service", service) if actual.name in self.services: raise ValueError(f"Service {actual.name} already exists") if isinstance(actual, (Service, VirtualObject, Workflow)): diff --git a/test-services-cls/services/awakeable_holder.py b/test-services-cls/services/awakeable_holder.py index dbf4c65..bbd2bc4 100644 --- a/test-services-cls/services/awakeable_holder.py +++ b/test-services-cls/services/awakeable_holder.py @@ -13,7 +13,7 @@ # pylint: disable=W0613 # pylint: disable=W0622 -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate from restate.exceptions import TerminalError @@ -21,16 +21,16 @@ class AwakeableHolder(VirtualObject, name="AwakeableHolder"): @handler async def hold(self, id: str): - Context.set("id", id) + Restate.set("id", id) @handler(name="hasAwakeable") async def has_awakeable(self) -> bool: - res = await Context.get("id") + res = await Restate.get("id") return res is not None @handler async def unlock(self, payload: str): - id = await Context.get("id") + id = await Restate.get("id") if id is None: raise TerminalError(message="No awakeable is registered") - Context.resolve_awakeable(id, payload) + Restate.resolve_awakeable(id, payload) diff --git a/test-services-cls/services/block_and_wait_workflow.py b/test-services-cls/services/block_and_wait_workflow.py index 5cead7d..3ac6570 100644 --- a/test-services-cls/services/block_and_wait_workflow.py +++ b/test-services-cls/services/block_and_wait_workflow.py @@ -13,7 +13,7 @@ # pylint: disable=W0613 # pylint: disable=W0622 -from restate.cls import Workflow, handler, main, Context +from restate.cls import Workflow, handler, main, Restate from restate.exceptions import TerminalError @@ -21,10 +21,10 @@ class BlockAndWaitWorkflow(Workflow, name="BlockAndWaitWorkflow"): @main async def run(self, input: str): - Context.set("my-state", input) - output = await Context.promise("durable-promise").value() + Restate.set("my-state", input) + output = await Restate.promise("durable-promise").value() - peek = await Context.promise("durable-promise").peek() + peek = await Restate.promise("durable-promise").peek() if peek is None: raise TerminalError(message="Durable promise should be completed") @@ -32,8 +32,8 @@ async def run(self, input: str): @handler async def unblock(self, output: str): - await Context.promise("durable-promise").resolve(output) + await Restate.promise("durable-promise").resolve(output) @handler(name="getState") async def get_state(self, output: str) -> str | None: - return await Context.get("my-state") + return await Restate.get("my-state") diff --git a/test-services-cls/services/cancel_test.py b/test-services-cls/services/cancel_test.py index 272c266..d31f369 100644 --- a/test-services-cls/services/cancel_test.py +++ b/test-services-cls/services/cancel_test.py @@ -14,7 +14,7 @@ from datetime import timedelta from typing import Literal -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate from restate.exceptions import TerminalError from . import awakeable_holder @@ -28,16 +28,16 @@ class CancelTestRunner(VirtualObject, name="CancelTestRunner"): async def start_test(self, op: BlockingOperation): block_fn = CancelTestBlockingService._restate_handlers["block"].fn try: - await Context.object_call(block_fn, key=Context.key(), arg=op) + await Restate.object_call(block_fn, key=Restate.key(), arg=op) except TerminalError as t: if t.status_code == 409: - Context.set("state", True) + Restate.set("state", True) else: raise t @handler(name="verifyTest") async def verify_test(self) -> bool: - state = await Context.get("state") + state = await Restate.get("state") if state is None: return False return state @@ -48,17 +48,17 @@ class CancelTestBlockingService(VirtualObject, name="CancelTestBlockingService") @handler async def block(self, op: BlockingOperation): hold_fn = awakeable_holder.AwakeableHolder._restate_handlers["hold"].fn - name, awakeable = Context.awakeable() - Context.object_send(hold_fn, key=Context.key(), arg=name) + name, awakeable = Restate.awakeable() + Restate.object_send(hold_fn, key=Restate.key(), arg=name) await awakeable block_fn = CancelTestBlockingService._restate_handlers["block"].fn if op == "CALL": - await Context.object_call(block_fn, key=Context.key(), arg=op) + await Restate.object_call(block_fn, key=Restate.key(), arg=op) elif op == "SLEEP": - await Context.sleep(timedelta(days=1024)) + await Restate.sleep(timedelta(days=1024)) elif op == "AWAKEABLE": - name, uncompleteable = Context.awakeable() + name, uncompleteable = Restate.awakeable() await uncompleteable @handler(name="isUnlocked") diff --git a/test-services-cls/services/counter.py b/test-services-cls/services/counter.py index 1140db3..1ca93f1 100644 --- a/test-services-cls/services/counter.py +++ b/test-services-cls/services/counter.py @@ -13,7 +13,7 @@ # pylint: disable=W0613 from typing import TypedDict -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate from restate.exceptions import TerminalError COUNTER_KEY = "counter" @@ -28,29 +28,29 @@ class Counter(VirtualObject, name="Counter"): @handler async def reset(self): - Context.clear(COUNTER_KEY) + Restate.clear(COUNTER_KEY) @handler async def get(self) -> int: - c: int | None = await Context.get(COUNTER_KEY) + c: int | None = await Restate.get(COUNTER_KEY) if c is None: return 0 return c @handler async def add(self, addend: int) -> CounterUpdateResponse: - old_value: int | None = await Context.get(COUNTER_KEY) + old_value: int | None = await Restate.get(COUNTER_KEY) if old_value is None: old_value = 0 new_value = old_value + addend - Context.set(COUNTER_KEY, new_value) + Restate.set(COUNTER_KEY, new_value) return CounterUpdateResponse(oldValue=old_value, newValue=new_value) @handler(name="addThenFail") async def add_then_fail(self, addend: int): - old_value: int | None = await Context.get(COUNTER_KEY) + old_value: int | None = await Restate.get(COUNTER_KEY) if old_value is None: old_value = 0 new_value = old_value + addend - Context.set(COUNTER_KEY, new_value) - raise TerminalError(message=Context.key()) + Restate.set(COUNTER_KEY, new_value) + raise TerminalError(message=Restate.key()) diff --git a/test-services-cls/services/failing.py b/test-services-cls/services/failing.py index 760f339..8d36382 100644 --- a/test-services-cls/services/failing.py +++ b/test-services-cls/services/failing.py @@ -16,7 +16,7 @@ # pylint: disable=W0613 # pylint: disable=W0622 -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate from restate.exceptions import TerminalError from restate import RunOptions @@ -34,7 +34,7 @@ async def terminally_failing_call(self, msg: str): @handler(name="callTerminallyFailingCall") async def call_terminally_failing_call(self, msg: str) -> str: fn = Failing._restate_handlers["terminallyFailingCall"].fn - await Context.object_call(fn, key="random-583e1bf2", arg=msg) + await Restate.object_call(fn, key="random-583e1bf2", arg=msg) raise Exception("Should not reach here") @handler(name="failingCallWithEventualSuccess") @@ -51,7 +51,7 @@ async def terminally_failing_side_effect(self, error_message: str): def side_effect(): raise TerminalError(message=error_message) - await Context.run_typed("sideEffect", side_effect) + await Restate.run("sideEffect", side_effect) raise ValueError("Should not reach here") @handler(name="sideEffectSucceedsAfterGivenAttempts") @@ -66,7 +66,7 @@ def side_effect(): options: RunOptions[int] = RunOptions( max_attempts=minimum_attempts + 1, initial_retry_interval=timedelta(milliseconds=1), retry_interval_factor=1.0 ) - return await Context.run_typed("sideEffect", side_effect, options) + return await Restate.run("sideEffect", side_effect, options) @handler(name="sideEffectFailsAfterGivenAttempts") async def side_effect_fails_after_given_attempts(self, retry_policy_max_retry_count: int) -> int: @@ -81,7 +81,7 @@ def side_effect(): initial_retry_interval=timedelta(milliseconds=1), retry_interval_factor=1.0, ) - await Context.run_typed("sideEffect", side_effect, options) + await Restate.run("sideEffect", side_effect, options) raise ValueError("Side effect did not fail.") except TerminalError: global eventual_failure_side_effects diff --git a/test-services-cls/services/interpreter.py b/test-services-cls/services/interpreter.py index 528e364..bb66c52 100644 --- a/test-services-cls/services/interpreter.py +++ b/test-services-cls/services/interpreter.py @@ -16,7 +16,7 @@ import typing import random -from restate.cls import Service, handler, Context +from restate.cls import Service, handler, Restate from restate.exceptions import TerminalError from restate.serde import JsonSerde @@ -63,7 +63,7 @@ async def echo(self, parameters: str) -> str: @handler(name="echoLater") async def echo_later(self, parameter: dict[str, typing.Any]) -> str: - await Context.sleep(timedelta(milliseconds=parameter["sleep"])) + await Restate.sleep(timedelta(milliseconds=parameter["sleep"])) return parameter["parameter"] @handler(name="terminalFailure") @@ -84,15 +84,15 @@ async def increment_indirectly(self, parameter) -> None: } program_bytes = json.dumps(program).encode("utf-8") - Context.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) + Restate.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) @handler(name="resolveAwakeable") async def resolve_awakeable(self, aid: str) -> None: - Context.resolve_awakeable(aid, "ok") + Restate.resolve_awakeable(aid, "ok") @handler(name="rejectAwakeable") async def reject_awakeable(self, aid: str) -> None: - Context.reject_awakeable(aid, "error") + Restate.reject_awakeable(aid, "error") @handler(name="incrementViaAwakeableDance") async def increment_via_awakeable_dance(self, input: dict[str, typing.Any]) -> None: @@ -100,8 +100,8 @@ async def increment_via_awakeable_dance(self, input: dict[str, typing.Any]) -> N layer = input["interpreter"]["layer"] key = input["interpreter"]["key"] - aid, promise = Context.awakeable() - Context.resolve_awakeable(tx_promise_id, aid) + aid, promise = Restate.awakeable() + Restate.resolve_awakeable(tx_promise_id, aid) await promise program = { @@ -113,7 +113,7 @@ async def increment_via_awakeable_dance(self, input: dict[str, typing.Any]) -> N } program_bytes = json.dumps(program).encode("utf-8") - Context.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) + Restate.generic_send(f"ObjectInterpreterL{layer}", "interpret", program_bytes, key) # Keep helper as a reference to the class for the __init__.py import @@ -127,7 +127,7 @@ def __init__(self) -> None: async def call(self, method: str, arg: typing.Any) -> typing.Any: buffer = self.serde.serialize(arg) - out_buffer = await Context.generic_call("ServiceInterpreterHelper", method, buffer) + out_buffer = await Restate.generic_call("ServiceInterpreterHelper", method, buffer) return self.serde.deserialize(out_buffer) def send(self, method: str, arg: typing.Any, delay: int | None = None) -> None: @@ -136,7 +136,7 @@ def send(self, method: str, arg: typing.Any, delay: int | None = None) -> None: send_delay = None else: send_delay = timedelta(milliseconds=delay) - Context.generic_send("ServiceInterpreterHelper", method, buffer, send_delay=send_delay) + Restate.generic_send("ServiceInterpreterHelper", method, buffer, send_delay=send_delay) async def ping(self) -> None: return await self.call(method="ping", arg=None) @@ -200,30 +200,30 @@ async def await_promise(index: int) -> None: for i, command in enumerate(program["commands"]): command_type = command["kind"] if command_type == SET_STATE: - Context.set(f"key-{command['key']}", f"value-{command['key']}") + Restate.set(f"key-{command['key']}", f"value-{command['key']}") elif command_type == GET_STATE: - await Context.get(f"key-{command['key']}") + await Restate.get(f"key-{command['key']}") elif command_type == CLEAR_STATE: - Context.clear(f"key-{command['key']}") + Restate.clear(f"key-{command['key']}") elif command_type == INCREMENT_STATE_COUNTER: - c = await Context.get("counter") or 0 + c = await Restate.get("counter") or 0 c += 1 - Context.set("counter", c) + Restate.set("counter", c) elif command_type == SLEEP: duration = timedelta(milliseconds=command["duration"]) - await Context.sleep(duration) + await Restate.sleep(duration) elif command_type == CALL_SERVICE: expected = f"hello-{i}" coros[i] = (expected, service.echo(expected)) elif command_type == INCREMENT_VIA_DELAYED_CALL: delay = command["duration"] - await service.increment_indirectly(layer=layer, key=Context.key(), delay=delay) + await service.increment_indirectly(layer=layer, key=Restate.key(), delay=delay) elif command_type == CALL_SLOW_SERVICE: expected = f"hello-{i}" coros[i] = (expected, service.echo_later(expected, command["sleep"])) elif command_type == SIDE_EFFECT: expected = f"hello-{i}" - result = await Context.run_typed("sideEffect", lambda: expected) + result = await Restate.run("sideEffect", lambda: expected) if result != expected: raise TerminalError(f"Expected {expected} but got {result}") elif command_type == SLOW_SIDE_EFFECT: @@ -243,32 +243,32 @@ async def side_effect(): if bool(random.getrandbits(1)): raise ValueError("Random error") - await Context.run_typed("throwingSideEffect", side_effect) + await Restate.run("throwingSideEffect", side_effect) elif command_type == INCREMENT_STATE_COUNTER_INDIRECTLY: - await service.increment_indirectly(layer=layer, key=Context.key()) + await service.increment_indirectly(layer=layer, key=Restate.key()) elif command_type == AWAIT_PROMISE: index = command["index"] await await_promise(index) elif command_type == RESOLVE_AWAKEABLE: - name, promise = Context.awakeable() + name, promise = Restate.awakeable() coros[i] = ("ok", promise) service.resolve_awakeable(name) elif command_type == REJECT_AWAKEABLE: - name, promise = Context.awakeable() + name, promise = Restate.awakeable() coros[i] = ("rejected", promise) service.reject_awakeable(name) elif command_type == INCREMENT_STATE_COUNTER_VIA_AWAKEABLE: - tx_promise_id, tx_promise = Context.awakeable() - service.increment_via_awakeable_dance(layer=layer, key=Context.key(), tx_promise_id=tx_promise_id) + tx_promise_id, tx_promise = Restate.awakeable() + service.increment_via_awakeable_dance(layer=layer, key=Restate.key(), tx_promise_id=tx_promise_id) their_promise_for_us_to_resolve: str = await tx_promise - Context.resolve_awakeable(their_promise_for_us_to_resolve, "ok") + Restate.resolve_awakeable(their_promise_for_us_to_resolve, "ok") elif command_type == CALL_NEXT_LAYER_OBJECT: next_layer = f"ObjectInterpreterL{layer + 1}" key = f"{command['key']}" program = command["program"] js_program = json.dumps(program) raw_js_program = js_program.encode("utf-8") - promise = Context.generic_call(next_layer, "interpret", raw_js_program, key) + promise = Restate.generic_call(next_layer, "interpret", raw_js_program, key) coros[i] = (b"", promise) else: raise ValueError(f"Unknown command type: {command_type}") diff --git a/test-services-cls/services/kill_test.py b/test-services-cls/services/kill_test.py index 2994350..70254fb 100644 --- a/test-services-cls/services/kill_test.py +++ b/test-services-cls/services/kill_test.py @@ -12,7 +12,7 @@ # pylint: disable=C0116 # pylint: disable=W0613 -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate from . import awakeable_holder @@ -22,7 +22,7 @@ class KillTestRunner(VirtualObject, name="KillTestRunner"): @handler(name="startCallTree") async def start_call_tree(self): fn = KillTestSingleton._restate_handlers["recursiveCall"].fn - await Context.object_call(fn, key=Context.key(), arg=None) + await Restate.object_call(fn, key=Restate.key(), arg=None) class KillTestSingleton(VirtualObject, name="KillTestSingleton"): @@ -30,12 +30,12 @@ class KillTestSingleton(VirtualObject, name="KillTestSingleton"): @handler(name="recursiveCall") async def recursive_call(self): hold_fn = awakeable_holder.AwakeableHolder._restate_handlers["hold"].fn - name, promise = Context.awakeable() - Context.object_send(hold_fn, key=Context.key(), arg=name) + name, promise = Restate.awakeable() + Restate.object_send(hold_fn, key=Restate.key(), arg=name) await promise fn = KillTestSingleton._restate_handlers["recursiveCall"].fn - await Context.object_call(fn, key=Context.key(), arg=None) + await Restate.object_call(fn, key=Restate.key(), arg=None) @handler(name="isUnlocked") async def is_unlocked(self): diff --git a/test-services-cls/services/list_object.py b/test-services-cls/services/list_object.py index 1786b7e..122f37e 100644 --- a/test-services-cls/services/list_object.py +++ b/test-services-cls/services/list_object.py @@ -12,22 +12,22 @@ # pylint: disable=C0116 # pylint: disable=W0613 -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate class ListObject(VirtualObject, name="ListObject"): @handler async def append(self, value: str): - lst = await Context.get("list") or [] - Context.set("list", lst + [value]) + lst = await Restate.get("list") or [] + Restate.set("list", lst + [value]) @handler async def get(self) -> list[str]: - return await Context.get("list") or [] + return await Restate.get("list") or [] @handler async def clear(self) -> list[str]: - result = await Context.get("list") or [] - Context.clear("list") + result = await Restate.get("list") or [] + Restate.clear("list") return result diff --git a/test-services-cls/services/map_object.py b/test-services-cls/services/map_object.py index e039c46..d8d7b54 100644 --- a/test-services-cls/services/map_object.py +++ b/test-services-cls/services/map_object.py @@ -13,7 +13,7 @@ # pylint: disable=W0613 from typing import TypedDict -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate class Entry(TypedDict): @@ -25,18 +25,18 @@ class MapObject(VirtualObject, name="MapObject"): @handler(name="set") async def map_set(self, entry: Entry): - Context.set(entry["key"], entry["value"]) + Restate.set(entry["key"], entry["value"]) @handler(name="get") async def map_get(self, key: str) -> str: - return await Context.get(key) or "" + return await Restate.get(key) or "" @handler(name="clearAll") async def map_clear_all(self) -> list[Entry]: entries = [] - for key in await Context.state_keys(): - value: str = await Context.get(key) # type: ignore + for key in await Restate.state_keys(): + value: str = await Restate.get(key) # type: ignore entry = Entry(key=key, value=value) entries.append(entry) - Context.clear(key) + Restate.clear(key) return entries diff --git a/test-services-cls/services/non_determinism.py b/test-services-cls/services/non_determinism.py index e11c237..bc3a8ff 100644 --- a/test-services-cls/services/non_determinism.py +++ b/test-services-cls/services/non_determinism.py @@ -14,7 +14,7 @@ from datetime import timedelta from typing import Dict -from restate.cls import VirtualObject, handler, Context +from restate.cls import VirtualObject, handler, Restate from . import counter @@ -22,14 +22,14 @@ def do_left_action() -> bool: - count_key = Context.key() + count_key = Restate.key() invoke_counts[count_key] = invoke_counts.get(count_key, 0) + 1 return invoke_counts[count_key] % 2 == 1 def increment_counter(): add_fn = counter.Counter._restate_handlers["add"].fn - Context.object_send(add_fn, key=Context.key(), arg=1) + Restate.object_send(add_fn, key=Restate.key(), arg=1) class NonDeterministic(VirtualObject, name="NonDeterministic"): @@ -37,10 +37,10 @@ class NonDeterministic(VirtualObject, name="NonDeterministic"): @handler(name="setDifferentKey") async def set_different_key(self): if do_left_action(): - Context.set("a", "my-state") + Restate.set("a", "my-state") else: - Context.set("b", "my-state") - await Context.sleep(timedelta(milliseconds=100)) + Restate.set("b", "my-state") + await Restate.sleep(timedelta(milliseconds=100)) increment_counter() @handler(name="backgroundInvokeWithDifferentTargets") @@ -48,10 +48,10 @@ async def background_invoke_with_different_targets(self): get_fn = counter.Counter._restate_handlers["get"].fn reset_fn = counter.Counter._restate_handlers["reset"].fn if do_left_action(): - Context.object_send(get_fn, key="abc", arg=None) + Restate.object_send(get_fn, key="abc", arg=None) else: - Context.object_send(reset_fn, key="abc", arg=None) - await Context.sleep(timedelta(milliseconds=100)) + Restate.object_send(reset_fn, key="abc", arg=None) + await Restate.sleep(timedelta(milliseconds=100)) increment_counter() @handler(name="callDifferentMethod") @@ -59,18 +59,18 @@ async def call_different_method(self): get_fn = counter.Counter._restate_handlers["get"].fn reset_fn = counter.Counter._restate_handlers["reset"].fn if do_left_action(): - await Context.object_call(get_fn, key="abc", arg=None) + await Restate.object_call(get_fn, key="abc", arg=None) else: - await Context.object_call(reset_fn, key="abc", arg=None) - await Context.sleep(timedelta(milliseconds=100)) + await Restate.object_call(reset_fn, key="abc", arg=None) + await Restate.sleep(timedelta(milliseconds=100)) increment_counter() @handler(name="eitherSleepOrCall") async def either_sleep_or_call(self): get_fn = counter.Counter._restate_handlers["get"].fn if do_left_action(): - await Context.sleep(timedelta(milliseconds=100)) + await Restate.sleep(timedelta(milliseconds=100)) else: - await Context.object_call(get_fn, key="abc", arg=None) - await Context.sleep(timedelta(milliseconds=100)) + await Restate.object_call(get_fn, key="abc", arg=None) + await Restate.sleep(timedelta(milliseconds=100)) increment_counter() diff --git a/test-services-cls/services/proxy.py b/test-services-cls/services/proxy.py index 902667c..22fc0a7 100644 --- a/test-services-cls/services/proxy.py +++ b/test-services-cls/services/proxy.py @@ -14,7 +14,7 @@ from datetime import timedelta from typing import TypedDict, Optional, Iterable -from restate.cls import Service, handler, Context +from restate.cls import Service, handler, Restate class ProxyRequest(TypedDict): @@ -36,7 +36,7 @@ class Proxy(Service, name="Proxy"): @handler(name="call") async def proxy_call(self, req: ProxyRequest) -> Iterable[int]: - response = await Context.generic_call( + response = await Restate.generic_call( req["serviceName"], req["handlerName"], bytes(req["message"]), @@ -51,7 +51,7 @@ async def one_way_call(self, req: ProxyRequest) -> str: delayMillis = req.get("delayMillis") if delayMillis is not None: send_delay = timedelta(milliseconds=delayMillis) - handle = Context.generic_send( + handle = Restate.generic_send( req["serviceName"], req["handlerName"], bytes(req["message"]), @@ -72,7 +72,7 @@ async def many_calls(self, requests: Iterable[ManyCallRequest]): delayMillis = req["proxyRequest"].get("delayMillis") if delayMillis is not None: send_delay = timedelta(milliseconds=delayMillis) - Context.generic_send( + Restate.generic_send( req["proxyRequest"]["serviceName"], req["proxyRequest"]["handlerName"], bytes(req["proxyRequest"]["message"]), @@ -81,7 +81,7 @@ async def many_calls(self, requests: Iterable[ManyCallRequest]): idempotency_key=req["proxyRequest"].get("idempotencyKey"), ) else: - awaitable = Context.generic_call( + awaitable = Restate.generic_call( req["proxyRequest"]["serviceName"], req["proxyRequest"]["handlerName"], bytes(req["proxyRequest"]["message"]), diff --git a/test-services-cls/services/test_utils.py b/test-services-cls/services/test_utils.py index 87c21c0..19cc49c 100644 --- a/test-services-cls/services/test_utils.py +++ b/test-services-cls/services/test_utils.py @@ -14,7 +14,7 @@ from datetime import timedelta from typing import Dict, List -from restate.cls import Service, handler, Context +from restate.cls import Service, handler, Restate from restate.serde import BytesSerde @@ -30,11 +30,11 @@ async def uppercase_echo(self, input: str) -> str: @handler(name="echoHeaders") async def echo_headers(self) -> Dict[str, str]: - return Context.request().headers + return Restate.request().headers @handler(name="sleepConcurrently") async def sleep_concurrently(self, millis_duration: List[int]) -> None: - timers = [Context.sleep(timedelta(milliseconds=duration)) for duration in millis_duration] + timers = [Restate.sleep(timedelta(milliseconds=duration)) for duration in millis_duration] for timer in timers: await timer @@ -47,13 +47,13 @@ def effect(): invoked_side_effects += 1 for _ in range(increments): - await Context.run("count", effect) + await Restate.run("count", effect) return invoked_side_effects @handler(name="cancelInvocation") async def cancel_invocation(self, invocation_id: str) -> None: - Context.cancel_invocation(invocation_id) + Restate.cancel_invocation(invocation_id) @handler( name="rawEcho", diff --git a/test-services-cls/services/virtual_object_command_interpreter.py b/test-services-cls/services/virtual_object_command_interpreter.py index c40127f..d51c0f6 100644 --- a/test-services-cls/services/virtual_object_command_interpreter.py +++ b/test-services-cls/services/virtual_object_command_interpreter.py @@ -15,7 +15,7 @@ import os from datetime import timedelta from typing import Iterable, List, Union, TypedDict, Literal, Any -from restate.cls import VirtualObject, handler, shared, Context +from restate.cls import VirtualObject, handler, shared, Restate from restate import RestateDurableFuture, RestateDurableSleepFuture from restate import select, wait_completed, as_completed from restate.exceptions import TerminalError @@ -88,43 +88,43 @@ class InterpretRequest(TypedDict): def to_durable_future(cmd: AwaitableCommand) -> RestateDurableFuture[Any]: if cmd["type"] == "createAwakeable": - awk_id, awakeable = Context.awakeable() - Context.set("awk-" + cmd["awakeableKey"], awk_id) + awk_id, awakeable = Restate.awakeable() + Restate.set("awk-" + cmd["awakeableKey"], awk_id) return awakeable elif cmd["type"] == "sleep": - return Context.sleep(timedelta(milliseconds=cmd["timeoutMillis"])) + return Restate.sleep(timedelta(milliseconds=cmd["timeoutMillis"])) elif cmd["type"] == "runThrowTerminalException": def side_effect(reason: str): raise TerminalError(message=reason) - res = Context.run_typed("run should fail command", side_effect, reason=cmd["reason"]) + res = Restate.run("run should fail command", side_effect, reason=cmd["reason"]) return res async def _resolve_awakeable_impl(req: ResolveAwakeable): - awk_id = await Context.get("awk-" + req["awakeableKey"]) + awk_id = await Restate.get("awk-" + req["awakeableKey"]) if not awk_id: raise TerminalError(message="No awakeable is registered") - Context.resolve_awakeable(awk_id, req["value"]) + Restate.resolve_awakeable(awk_id, req["value"]) async def _reject_awakeable_impl(req: RejectAwakeable): - awk_id = await Context.get("awk-" + req["awakeableKey"]) + awk_id = await Restate.get("awk-" + req["awakeableKey"]) if not awk_id: raise TerminalError(message="No awakeable is registered") - Context.reject_awakeable(awk_id, req["reason"]) + Restate.reject_awakeable(awk_id, req["reason"]) class VirtualObjectCommandInterpreter(VirtualObject, name="VirtualObjectCommandInterpreter"): @shared(name="getResults") async def get_results(self) -> List[str]: - return (await Context.get("results")) or [] + return (await Restate.get("results")) or [] @shared(name="hasAwakeable") async def has_awakeable(self, awk_key: str) -> bool: - awk_id = await Context.get("awk-" + awk_key) + awk_id = await Restate.get("awk-" + awk_key) if awk_id: return True return False @@ -143,9 +143,9 @@ async def interpret_commands(self, req: InterpretRequest): for cmd in req["commands"]: if cmd["type"] == "awaitAwakeableOrTimeout": - awk_id, awakeable = Context.awakeable() - Context.set("awk-" + cmd["awakeableKey"], awk_id) - match await select(awakeable=awakeable, timeout=Context.sleep(timedelta(milliseconds=cmd["timeoutMillis"]))): + awk_id, awakeable = Restate.awakeable() + Restate.set("awk-" + cmd["awakeableKey"], awk_id) + match await select(awakeable=awakeable, timeout=Restate.sleep(timedelta(milliseconds=cmd["timeoutMillis"]))): case ["awakeable", awk_res]: result = awk_res case ["timeout", _]: @@ -162,7 +162,7 @@ async def interpret_commands(self, req: InterpretRequest): def side_effect(env_name: str): return os.environ.get(env_name, "") - result = await Context.run_typed("get_env", side_effect, env_name=env_name) + result = await Restate.run("get_env", side_effect, env_name=env_name) elif cmd["type"] == "awaitOne": awaitable = to_durable_future(cmd["command"]) # We need this dance because the Python SDK doesn't support .map on futures @@ -196,8 +196,8 @@ def side_effect(env_name: str): pass # Direct state access (same invocation, not RPC) - last_results = (await Context.get("results")) or [] + last_results = (await Restate.get("results")) or [] last_results.append(result) - Context.set("results", last_results) + Restate.set("results", last_results) return result From d9c09c7ba0e07a57486c03e681ffd8aee7104c64 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 1 Apr 2026 18:47:58 +0200 Subject: [PATCH 06/12] Add experimental class based interface --- python/restate/cls.py | 73 ++++++++++++------- python/restate/endpoint.py | 11 +-- test-services-cls/services/__init__.py | 2 +- test-services-cls/services/cancel_test.py | 9 +-- test-services-cls/services/failing.py | 3 +- test-services-cls/services/kill_test.py | 11 +-- test-services-cls/services/non_determinism.py | 20 ++--- 7 files changed, 69 insertions(+), 60 deletions(-) diff --git a/python/restate/cls.py b/python/restate/cls.py index d077234..b15994e 100644 --- a/python/restate/cls.py +++ b/python/restate/cls.py @@ -48,7 +48,7 @@ async def greet(self, name: str) -> str: 4. A **companion service object** (a plain ``restate.Service``, ``restate.VirtualObject``, or ``restate.Workflow``) is created and - stored on the class as ``Greeter._restate_service``. This companion + stored on the class as ``Greeter.__restate_service__``. This companion holds the handler dict and all service-level configuration. Then at **bind time** (``restate.app([...])`` → ``Endpoint.bind``): @@ -62,7 +62,7 @@ async def greet(self, name: str) -> str: 7. ``_bind_instance(instance)`` is called, which replaces each handler's placeholder wrapper with a real one that **closes over the instance**. - The companion ``_restate_service`` is then registered with the + The companion ``__restate_service__`` is then registered with the endpoint just like any decorator-based service. At **invocation time**, Restate calls the wrapper which dispatches @@ -98,9 +98,10 @@ async def count(self) -> int: from __future__ import annotations +import copy import inspect import sys -from dataclasses import dataclass, field +from dataclasses import dataclass, field, replace from datetime import timedelta from functools import wraps from typing import Any, AsyncContextManager, Callable, Dict, List, Literal, Optional, TypeVar @@ -367,6 +368,9 @@ def _process_class( metadata=config.metadata, ) handlers: Dict[str, Any] = {} + # Proxy lookup index: maps both handler names and Python method names. + # Kept separate from svc.handlers which only has canonical handler names. + handler_index: Dict[str, Any] = {} for attr_name, attr_value in list(cls.__dict__.items()): meta: Optional[_HandlerMeta] = getattr(attr_value, _HANDLER_MARKER, None) @@ -380,13 +384,16 @@ def _process_class( # Placeholder wrapper — replaced by _bind_instance() at bind time # with one that closes over the actual instance. @wraps(method) - async def wrapper(ctx, *args): + async def wrapper(_ctx, *args, _handler_name=handler_name): raise RuntimeError( - f"Handler {handler_name} called before instance was bound. " + f"Handler {_handler_name} called before instance was bound. " f"Use restate.app([{cls.__name__}(...)]) to bind an instance." ) - # Use the original method's signature for type/serde inspection + # Use the original method's signature for type/serde inspection. + # Note: arity is derived from this signature (including `self`), + # which matches the (ctx, arg) calling convention of invoke_handler — + # `self` occupies the same slot as `ctx` in the decorator-based API. sig = inspect.signature(method, eval_str=True) handler_io: HandlerIO = HandlerIO( accept=meta.accept, @@ -422,9 +429,12 @@ async def wrapper(ctx, *args): context_managers=combined_context_managers, ) handlers[h.name] = h + handler_index[h.name] = h + if method.__name__ != h.name: + handler_index[method.__name__] = h - # Store handlers on the class for proxy access - cls._restate_handlers = handlers # type: ignore[attr-defined] + # Store handler index on the class for proxy lookup (method name + handler name) + cls.__restate_handlers__ = handler_index # type: ignore[attr-defined] # Build companion service object of the original type svc: _OriginalService | _OriginalVirtualObject | _OriginalWorkflow @@ -473,18 +483,21 @@ async def wrapper(ctx, *args): raise ValueError(f"Unknown service kind: {service_kind}") svc.handlers = handlers - cls._restate_service = svc # type: ignore[attr-defined] + cls.__restate_service__ = svc # type: ignore[attr-defined] def _bind_instance(instance: Any) -> None: """Create real handler wrappers that close over *instance*. Called from ``Endpoint.bind()`` once the instance is known. - Replaces the placeholder ``fn`` on each handler with a wrapper - that dispatches to the bound method on the instance. + Creates a **copy** of the companion service with new handler objects + whose ``fn`` dispatches to the bound method on the instance. + The copy is stored on the *instance* so that binding a second instance + of the same class (e.g. to a different endpoint) does not clobber the first. """ cls = type(instance) - svc = cls._restate_service # type: ignore[attr-defined] + svc = cls.__restate_service__ # type: ignore[attr-defined] + new_handlers: Dict[str, Any] = {} for handler_name, h in svc.handlers.items(): method = cls.__dict__.get(handler_name) if method is None: @@ -495,15 +508,23 @@ def _bind_instance(instance: Any) -> None: method = attr break if method is None: + new_handlers[handler_name] = h continue @wraps(method) - async def wrapper(ctx, *args, _method=method, _inst=instance): + async def wrapper(_ctx, *args, _method=method, _inst=instance): + # _ctx is passed by invoke_handler but unused here; + # context is accessed via Restate._ctx() (contextvars). if args: return await _method(_inst, *args) return await _method(_inst) - h.fn = wrapper + new_handlers[handler_name] = replace(h, fn=wrapper) + + # Create a shallow copy of the companion service with the new handlers + bound_svc = copy.copy(svc) + bound_svc.handlers = new_handlers + instance.__restate_service__ = bound_svc # ── Fluent RPC proxy classes ────────────────────────────────────────────── @@ -516,7 +537,7 @@ def __init__(self, cls: type) -> None: self._cls = cls def __getattr__(self, name: str): - handlers = getattr(self._cls, "_restate_handlers", {}) + handlers = getattr(self._cls, "__restate_handlers__", {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -540,7 +561,7 @@ def __init__(self, cls: type, delay: Optional[timedelta] = None) -> None: self._delay = delay def __getattr__(self, name: str): - handlers = getattr(self._cls, "_restate_handlers", {}) + handlers = getattr(self._cls, "__restate_handlers__", {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -564,7 +585,7 @@ def __init__(self, cls: type, key: str) -> None: self._key = key def __getattr__(self, name: str): - handlers = getattr(self._cls, "_restate_handlers", {}) + handlers = getattr(self._cls, "__restate_handlers__", {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -589,7 +610,7 @@ def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> No self._delay = delay def __getattr__(self, name: str): - handlers = getattr(self._cls, "_restate_handlers", {}) + handlers = getattr(self._cls, "__restate_handlers__", {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -613,7 +634,7 @@ def __init__(self, cls: type, key: str) -> None: self._key = key def __getattr__(self, name: str): - handlers = getattr(self._cls, "_restate_handlers", {}) + handlers = getattr(self._cls, "__restate_handlers__", {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -638,7 +659,7 @@ def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> No self._delay = delay def __getattr__(self, name: str): - handlers = getattr(self._cls, "_restate_handlers", {}) + handlers = getattr(self._cls, "__restate_handlers__", {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -670,8 +691,8 @@ async def greet(self, name: str) -> str: app = restate.app([Greeter]) """ - _restate_service: _OriginalService - _restate_handlers: Dict[str, Any] + __restate_service__: _OriginalService + __restate_handlers__: Dict[str, Any] def __init_subclass__( cls, @@ -733,8 +754,8 @@ async def increment(self, value: int) -> int: ... app = restate.app([Counter]) """ - _restate_service: _OriginalVirtualObject - _restate_handlers: Dict[str, Any] + __restate_service__: _OriginalVirtualObject + __restate_handlers__: Dict[str, Any] def __init_subclass__( cls, @@ -794,8 +815,8 @@ async def pay(self, amount: int) -> dict: ... app = restate.app([Payment]) """ - _restate_service: _OriginalWorkflow - _restate_handlers: Dict[str, Any] + __restate_service__: _OriginalWorkflow + __restate_handlers__: Dict[str, Any] def __init_subclass__( cls, diff --git a/python/restate/endpoint.py b/python/restate/endpoint.py index 5d6760d..ae9afec 100644 --- a/python/restate/endpoint.py +++ b/python/restate/endpoint.py @@ -60,7 +60,7 @@ def bind(self, *services: typing.Any): """ for service in services: # Support class-based services: extract companion object. - if isinstance(service, type) and hasattr(service, "_restate_service"): + if isinstance(service, type) and hasattr(service, "__restate_service__"): # Class passed — instantiate it and bind from restate.cls import _bind_instance # pylint: disable=C0415 @@ -72,15 +72,16 @@ def bind(self, *services: typing.Any): f"Pass an instance instead: restate.app([{service.__name__}(...)])" ) from e _bind_instance(instance) - actual = service._restate_service # type: ignore[attr-defined] - elif not isinstance(service, type) and hasattr(type(service), "_restate_service"): + # Read from instance — _bind_instance stores a per-instance copy there + actual = instance.__restate_service__ # type: ignore[attr-defined] + elif not isinstance(service, type) and hasattr(type(service), "__restate_service__"): # Instance passed — bind it from restate.cls import _bind_instance # pylint: disable=C0415 _bind_instance(service) - actual = type(service)._restate_service # type: ignore[attr-defined] + actual = service.__restate_service__ # type: ignore[attr-defined] else: - actual = getattr(service, "_restate_service", service) + actual = getattr(service, "__restate_service__", service) if actual.name in self.services: raise ValueError(f"Service {actual.name} already exists") if isinstance(actual, (Service, VirtualObject, Workflow)): diff --git a/test-services-cls/services/__init__.py b/test-services-cls/services/__init__.py index fd31a2e..17c9497 100644 --- a/test-services-cls/services/__init__.py +++ b/test-services-cls/services/__init__.py @@ -25,7 +25,7 @@ def list_services(bindings): """List all services from local bindings — supports both class-based and decorator-based.""" result = {} for _, obj in bindings.items(): - svc = getattr(obj, '_restate_service', obj) + svc = getattr(obj, '__restate_service__', obj) if isinstance(svc, (_OrigService, _OrigObject, _OrigWorkflow)): result[svc.name] = obj return result diff --git a/test-services-cls/services/cancel_test.py b/test-services-cls/services/cancel_test.py index d31f369..ab6c103 100644 --- a/test-services-cls/services/cancel_test.py +++ b/test-services-cls/services/cancel_test.py @@ -26,9 +26,8 @@ class CancelTestRunner(VirtualObject, name="CancelTestRunner"): @handler(name="startTest") async def start_test(self, op: BlockingOperation): - block_fn = CancelTestBlockingService._restate_handlers["block"].fn try: - await Restate.object_call(block_fn, key=Restate.key(), arg=op) + await CancelTestBlockingService.call(Restate.key()).block(op) except TerminalError as t: if t.status_code == 409: Restate.set("state", True) @@ -47,14 +46,12 @@ class CancelTestBlockingService(VirtualObject, name="CancelTestBlockingService") @handler async def block(self, op: BlockingOperation): - hold_fn = awakeable_holder.AwakeableHolder._restate_handlers["hold"].fn name, awakeable = Restate.awakeable() - Restate.object_send(hold_fn, key=Restate.key(), arg=name) + awakeable_holder.AwakeableHolder.send(Restate.key()).hold(name) # type: ignore[unused-coroutine] await awakeable - block_fn = CancelTestBlockingService._restate_handlers["block"].fn if op == "CALL": - await Restate.object_call(block_fn, key=Restate.key(), arg=op) + await CancelTestBlockingService.call(Restate.key()).block(op) elif op == "SLEEP": await Restate.sleep(timedelta(days=1024)) elif op == "AWAKEABLE": diff --git a/test-services-cls/services/failing.py b/test-services-cls/services/failing.py index 8d36382..60f1b28 100644 --- a/test-services-cls/services/failing.py +++ b/test-services-cls/services/failing.py @@ -33,8 +33,7 @@ async def terminally_failing_call(self, msg: str): @handler(name="callTerminallyFailingCall") async def call_terminally_failing_call(self, msg: str) -> str: - fn = Failing._restate_handlers["terminallyFailingCall"].fn - await Restate.object_call(fn, key="random-583e1bf2", arg=msg) + await Failing.call("random-583e1bf2").terminally_failing_call(msg) raise Exception("Should not reach here") @handler(name="failingCallWithEventualSuccess") diff --git a/test-services-cls/services/kill_test.py b/test-services-cls/services/kill_test.py index 70254fb..8e16e85 100644 --- a/test-services-cls/services/kill_test.py +++ b/test-services-cls/services/kill_test.py @@ -14,28 +14,25 @@ from restate.cls import VirtualObject, handler, Restate -from . import awakeable_holder +from .awakeable_holder import AwakeableHolder class KillTestRunner(VirtualObject, name="KillTestRunner"): @handler(name="startCallTree") async def start_call_tree(self): - fn = KillTestSingleton._restate_handlers["recursiveCall"].fn - await Restate.object_call(fn, key=Restate.key(), arg=None) + await KillTestSingleton.call(Restate.key()).recursive_call() class KillTestSingleton(VirtualObject, name="KillTestSingleton"): @handler(name="recursiveCall") async def recursive_call(self): - hold_fn = awakeable_holder.AwakeableHolder._restate_handlers["hold"].fn name, promise = Restate.awakeable() - Restate.object_send(hold_fn, key=Restate.key(), arg=name) + AwakeableHolder.send(Restate.key()).hold(name) # type: ignore[unused-coroutine] await promise - fn = KillTestSingleton._restate_handlers["recursiveCall"].fn - await Restate.object_call(fn, key=Restate.key(), arg=None) + await KillTestSingleton.call(Restate.key()).recursive_call() @handler(name="isUnlocked") async def is_unlocked(self): diff --git a/test-services-cls/services/non_determinism.py b/test-services-cls/services/non_determinism.py index bc3a8ff..65b90bd 100644 --- a/test-services-cls/services/non_determinism.py +++ b/test-services-cls/services/non_determinism.py @@ -16,7 +16,7 @@ from typing import Dict from restate.cls import VirtualObject, handler, Restate -from . import counter +from .counter import Counter invoke_counts: Dict[str, int] = {} @@ -28,8 +28,7 @@ def do_left_action() -> bool: def increment_counter(): - add_fn = counter.Counter._restate_handlers["add"].fn - Restate.object_send(add_fn, key=Restate.key(), arg=1) + Counter.send(Restate.key()).add(1) # type: ignore[unused-coroutine] class NonDeterministic(VirtualObject, name="NonDeterministic"): @@ -45,32 +44,27 @@ async def set_different_key(self): @handler(name="backgroundInvokeWithDifferentTargets") async def background_invoke_with_different_targets(self): - get_fn = counter.Counter._restate_handlers["get"].fn - reset_fn = counter.Counter._restate_handlers["reset"].fn if do_left_action(): - Restate.object_send(get_fn, key="abc", arg=None) + Counter.send("abc").get() # type: ignore[unused-coroutine] else: - Restate.object_send(reset_fn, key="abc", arg=None) + Counter.send("abc").reset() # type: ignore[unused-coroutine] await Restate.sleep(timedelta(milliseconds=100)) increment_counter() @handler(name="callDifferentMethod") async def call_different_method(self): - get_fn = counter.Counter._restate_handlers["get"].fn - reset_fn = counter.Counter._restate_handlers["reset"].fn if do_left_action(): - await Restate.object_call(get_fn, key="abc", arg=None) + await Counter.call("abc").get() else: - await Restate.object_call(reset_fn, key="abc", arg=None) + await Counter.call("abc").reset() await Restate.sleep(timedelta(milliseconds=100)) increment_counter() @handler(name="eitherSleepOrCall") async def either_sleep_or_call(self): - get_fn = counter.Counter._restate_handlers["get"].fn if do_left_action(): await Restate.sleep(timedelta(milliseconds=100)) else: - await Restate.object_call(get_fn, key="abc", arg=None) + await Counter.call("abc").get() await Restate.sleep(timedelta(milliseconds=100)) increment_counter() From c5ecda35f8d9ca9f7dd17e5f4116dc11bde1109d Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 1 Apr 2026 20:03:59 +0200 Subject: [PATCH 07/12] Add an example --- examples/class_based_greeter.py | 33 +++++++++++++++- python/restate/cls.py | 55 +++++++++++++++++++------- python/restate/endpoint.py | 11 +++--- test-services-cls/services/__init__.py | 3 +- 4 files changed, 80 insertions(+), 22 deletions(-) diff --git a/examples/class_based_greeter.py b/examples/class_based_greeter.py index 070febc..6480b98 100644 --- a/examples/class_based_greeter.py +++ b/examples/class_based_greeter.py @@ -108,4 +108,35 @@ async def translate() -> GreetingResponse: return await Restate.run("translate", translate) -app = restate.app([Greeter, Counter, PaymentWorkflow, OrderProcessor, PydanticGreeter("Restate")]) +# ── Service contract without implementation ── +# +# Define the shape of a service (handlers + types) without providing an +# implementation. The proxy only needs class-level metadata created by +# __init_subclass__, so you can call a service that lives in another +# process — or is written in another language — just from its contract. + + +class ExternalInventory(VirtualObject, name="Inventory"): + """Contract for an Inventory service whose implementation lives elsewhere.""" + + @handler + async def reserve(self, item_id: str) -> bool: ... # type: ignore[empty-body] + + @handler + async def current_stock(self) -> int: ... # type: ignore[empty-body] + + +class Shop(Service): + """Demonstrates calling a service defined only by its contract.""" + + @handler + async def buy(self, item_id: str) -> str: + # Full IDE autocomplete — reserve(str) -> bool, current_stock() -> int + ok = await ExternalInventory.call(item_id).reserve(item_id) + if not ok: + return "out of stock" + stock = await ExternalInventory.call(item_id).current_stock() + return f"reserved (remaining: {stock})" + + +app = restate.app([Greeter, Counter, PaymentWorkflow, OrderProcessor, PydanticGreeter("Restate"), Shop]) diff --git a/python/restate/cls.py b/python/restate/cls.py index b15994e..c15a2d3 100644 --- a/python/restate/cls.py +++ b/python/restate/cls.py @@ -104,14 +104,27 @@ async def count(self) -> int: from dataclasses import dataclass, field, replace from datetime import timedelta from functools import wraps -from typing import Any, AsyncContextManager, Callable, Dict, List, Literal, Optional, TypeVar +from typing import ( + Any, + AsyncContextManager, + Callable, + Coroutine, + Dict, + List, + Literal, + Optional, + ParamSpec, + TypeVar, + Union, +) if sys.version_info >= (3, 11): from typing import Self else: from typing_extensions import Self -from restate.handler import HandlerIO, ServiceTag, make_handler +from restate.context import RestateDurableFuture, RunOptions +from restate.handler import RESTATE_UNIQUE_HANDLER_SYMBOL, HandlerIO, ServiceTag, make_handler from restate.retry_policy import InvocationRetryPolicy from restate.serde import DefaultSerde, Serde @@ -124,10 +137,13 @@ async def count(self) -> int: I = TypeVar("I") O = TypeVar("O") T = TypeVar("T") +P = ParamSpec("P") # ── Handler marker decorators ────────────────────────────────────────────── _HANDLER_MARKER = "__restate_handler_meta__" +_SERVICE_ATTR = "__restate_service__" +_HANDLERS_ATTR = "__restate_handlers__" _MISSING = object() @@ -434,7 +450,7 @@ async def wrapper(_ctx, *args, _handler_name=handler_name): handler_index[method.__name__] = h # Store handler index on the class for proxy lookup (method name + handler name) - cls.__restate_handlers__ = handler_index # type: ignore[attr-defined] + setattr(cls, _HANDLERS_ATTR, handler_index) # Build companion service object of the original type svc: _OriginalService | _OriginalVirtualObject | _OriginalWorkflow @@ -483,7 +499,7 @@ async def wrapper(_ctx, *args, _handler_name=handler_name): raise ValueError(f"Unknown service kind: {service_kind}") svc.handlers = handlers - cls.__restate_service__ = svc # type: ignore[attr-defined] + setattr(cls, _SERVICE_ATTR, svc) def _bind_instance(instance: Any) -> None: @@ -496,7 +512,7 @@ def _bind_instance(instance: Any) -> None: of the same class (e.g. to a different endpoint) does not clobber the first. """ cls = type(instance) - svc = cls.__restate_service__ # type: ignore[attr-defined] + svc = getattr(cls, _SERVICE_ATTR) new_handlers: Dict[str, Any] = {} for handler_name, h in svc.handlers.items(): method = cls.__dict__.get(handler_name) @@ -519,12 +535,14 @@ async def wrapper(_ctx, *args, _method=method, _inst=instance): return await _method(_inst, *args) return await _method(_inst) - new_handlers[handler_name] = replace(h, fn=wrapper) + new_h = replace(h, fn=wrapper) + vars(wrapper)[RESTATE_UNIQUE_HANDLER_SYMBOL] = new_h + new_handlers[handler_name] = new_h # Create a shallow copy of the companion service with the new handlers bound_svc = copy.copy(svc) bound_svc.handlers = new_handlers - instance.__restate_service__ = bound_svc + setattr(instance, _SERVICE_ATTR, bound_svc) # ── Fluent RPC proxy classes ────────────────────────────────────────────── @@ -537,7 +555,7 @@ def __init__(self, cls: type) -> None: self._cls = cls def __getattr__(self, name: str): - handlers = getattr(self._cls, "__restate_handlers__", {}) + handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -561,7 +579,7 @@ def __init__(self, cls: type, delay: Optional[timedelta] = None) -> None: self._delay = delay def __getattr__(self, name: str): - handlers = getattr(self._cls, "__restate_handlers__", {}) + handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -585,7 +603,7 @@ def __init__(self, cls: type, key: str) -> None: self._key = key def __getattr__(self, name: str): - handlers = getattr(self._cls, "__restate_handlers__", {}) + handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -610,7 +628,7 @@ def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> No self._delay = delay def __getattr__(self, name: str): - handlers = getattr(self._cls, "__restate_handlers__", {}) + handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -634,7 +652,7 @@ def __init__(self, cls: type, key: str) -> None: self._key = key def __getattr__(self, name: str): - handlers = getattr(self._cls, "__restate_handlers__", {}) + handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -659,7 +677,7 @@ def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> No self._delay = delay def __getattr__(self, name: str): - handlers = getattr(self._cls, "__restate_handlers__", {}) + handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: raise AttributeError(f"No handler '{name}' on {self._cls.__name__}") @@ -946,9 +964,16 @@ def time() -> Any: # ── Durable execution ── @staticmethod - def run(name: str, action: Any, *args: Any, **kwargs: Any) -> Any: + def run( + name: str, + action: Union[Callable[P, Coroutine[Any, Any, T]], Callable[P, T]], + options: RunOptions[T] = RunOptions(), + /, + *args: P.args, + **kwargs: P.kwargs, + ) -> RestateDurableFuture[T]: """Run a durable side effect with typed arguments.""" - return Restate._ctx().run_typed(name, action, *args, **kwargs) + return Restate._ctx().run_typed(name, action, options, *args, **kwargs) @staticmethod def sleep(delta: timedelta, name: Optional[str] = None) -> Any: diff --git a/python/restate/endpoint.py b/python/restate/endpoint.py index ae9afec..fc58e9d 100644 --- a/python/restate/endpoint.py +++ b/python/restate/endpoint.py @@ -14,6 +14,7 @@ import typing +from restate.cls import _SERVICE_ATTR from restate.service import Service from restate.object import VirtualObject from restate.workflow import Workflow @@ -60,7 +61,7 @@ def bind(self, *services: typing.Any): """ for service in services: # Support class-based services: extract companion object. - if isinstance(service, type) and hasattr(service, "__restate_service__"): + if isinstance(service, type) and hasattr(service, _SERVICE_ATTR): # Class passed — instantiate it and bind from restate.cls import _bind_instance # pylint: disable=C0415 @@ -73,15 +74,15 @@ def bind(self, *services: typing.Any): ) from e _bind_instance(instance) # Read from instance — _bind_instance stores a per-instance copy there - actual = instance.__restate_service__ # type: ignore[attr-defined] - elif not isinstance(service, type) and hasattr(type(service), "__restate_service__"): + actual = getattr(instance, _SERVICE_ATTR) + elif not isinstance(service, type) and hasattr(type(service), _SERVICE_ATTR): # Instance passed — bind it from restate.cls import _bind_instance # pylint: disable=C0415 _bind_instance(service) - actual = service.__restate_service__ # type: ignore[attr-defined] + actual = getattr(service, _SERVICE_ATTR) else: - actual = getattr(service, "__restate_service__", service) + actual = service if actual.name in self.services: raise ValueError(f"Service {actual.name} already exists") if isinstance(actual, (Service, VirtualObject, Workflow)): diff --git a/test-services-cls/services/__init__.py b/test-services-cls/services/__init__.py index 17c9497..e49fdb0 100644 --- a/test-services-cls/services/__init__.py +++ b/test-services-cls/services/__init__.py @@ -1,3 +1,4 @@ +from restate.cls import _SERVICE_ATTR from restate.service import Service as _OrigService from restate.object import VirtualObject as _OrigObject from restate.workflow import Workflow as _OrigWorkflow @@ -25,7 +26,7 @@ def list_services(bindings): """List all services from local bindings — supports both class-based and decorator-based.""" result = {} for _, obj in bindings.items(): - svc = getattr(obj, '__restate_service__', obj) + svc = getattr(obj, _SERVICE_ATTR, obj) if isinstance(svc, (_OrigService, _OrigObject, _OrigWorkflow)): result[svc.name] = obj return result From fd02ddfc041758b336e70cffbaa89cd2df70013d Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 1 Apr 2026 20:08:01 +0200 Subject: [PATCH 08/12] Add it for ci --- .github/workflows/integration-cls.yaml | 136 +++++++++++++++++++++++++ 1 file changed, 136 insertions(+) create mode 100644 .github/workflows/integration-cls.yaml diff --git a/.github/workflows/integration-cls.yaml b/.github/workflows/integration-cls.yaml new file mode 100644 index 0000000..1b0f3b5 --- /dev/null +++ b/.github/workflows/integration-cls.yaml @@ -0,0 +1,136 @@ +name: Integration (class-based API) + +# Controls when the workflow will run +on: + pull_request: + push: + branches: + - main + schedule: + - cron: "0 */6 * * *" # Every 6 hours + workflow_dispatch: + inputs: + restateCommit: + description: "restate commit" + required: false + default: "" + type: string + restateImage: + description: "restate image, superseded by restate commit" + required: false + default: "ghcr.io/restatedev/restate:main" + type: string + serviceImage: + description: "service image, if provided it will skip building the image from sdk main branch" + required: false + default: "" + type: string + workflow_call: + inputs: + restateCommit: + description: "restate commit" + required: false + default: "" + type: string + restateImage: + description: "restate image, superseded by restate commit" + required: false + default: "ghcr.io/restatedev/restate:main" + type: string + serviceImage: + description: "service image, if provided it will skip building the image from sdk main branch" + required: false + default: "" + type: string + +jobs: + sdk-test-suite: + if: github.repository_owner == 'restatedev' + runs-on: warp-ubuntu-latest-x64-4x + name: Features integration test (class-based API) + permissions: + contents: read + issues: read + checks: write + pull-requests: write + actions: read + + steps: + - uses: actions/checkout@v4 + with: + repository: restatedev/sdk-python + + - name: Set up Docker containerd snapshotter + uses: docker/setup-docker-action@v4 + with: + version: "v28.5.2" + set-host: true + daemon-config: | + { + "features": { + "containerd-snapshotter": true + } + } + + ### Download the Restate container image, if needed + # Setup restate snapshot if necessary + # Due to https://github.com/actions/upload-artifact/issues/53 + # We must use download-artifact to get artifacts created during *this* workflow run, ie by workflow call + - name: Download restate snapshot from in-progress workflow + if: ${{ inputs.restateCommit != '' && github.event_name != 'workflow_dispatch' }} + uses: actions/download-artifact@v4 + with: + name: restate.tar + # In the workflow dispatch case where the artifact was created in a previous run, we can download as normal + - name: Download restate snapshot from completed workflow + if: ${{ inputs.restateCommit != '' && github.event_name == 'workflow_dispatch' }} + uses: dawidd6/action-download-artifact@v3 + with: + repo: restatedev/restate + workflow: ci.yml + commit: ${{ inputs.restateCommit }} + name: restate.tar + - name: Install restate snapshot + if: ${{ inputs.restateCommit != '' }} + run: | + output=$(docker load --input restate.tar | head -n 1) + docker tag "${output#*: }" "localhost/restatedev/restate-commit-download:latest" + docker image ls -a + + # Either build the docker image from source + - name: Set up QEMU + if: ${{ inputs.serviceImage == '' }} + uses: docker/setup-qemu-action@v3 + - name: Set up Docker Buildx + if: ${{ inputs.serviceImage == '' }} + uses: docker/setup-buildx-action@v3 + with: + driver-opts: | + network=host + - name: Build Python test-services-cls image + if: ${{ inputs.serviceImage == '' }} + id: build + uses: docker/build-push-action@v6 + with: + context: . + file: "test-services-cls/Dockerfile" + push: false + load: true + tags: restatedev/test-services-python-cls + cache-from: type=gha,url=http://127.0.0.1:49160/,version=1,scope=${{ github.workflow }} + cache-to: type=gha,url=http://127.0.0.1:49160/,mode=max,version=1,scope=${{ github.workflow }} + + # Or use the provided one + - name: Pull test services image + if: ${{ inputs.serviceImage != '' }} + shell: bash + run: docker pull ${{ inputs.serviceImage }} + + - name: Run test tool + uses: restatedev/sdk-test-suite@v3.4 + with: + restateContainerImage: ${{ inputs.restateCommit != '' && 'localhost/restatedev/restate-commit-download:latest' || (inputs.restateImage != '' && inputs.restateImage || 'ghcr.io/restatedev/restate:main') }} + serviceContainerImage: ${{ inputs.serviceImage != '' && inputs.serviceImage || 'restatedev/test-services-python-cls' }} + exclusionsFile: "test-services-cls/exclusions.yaml" + testArtifactOutput: "sdk-python-cls-integration-test-report" + serviceContainerEnvFile: "test-services-cls/.env" From b8e9b7b15e7c4ee8dcfb7ed64b984dfc02149536 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 1 Apr 2026 20:11:13 +0200 Subject: [PATCH 09/12] Remove hardcoded whl --- examples/pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/pyproject.toml b/examples/pyproject.toml index 4db9d1a..ff9b22b 100644 --- a/examples/pyproject.toml +++ b/examples/pyproject.toml @@ -12,4 +12,4 @@ dependencies = [ ] [tool.uv.sources] -restate-sdk = { path = "../dist/restate_sdk-0.14.2-cp313-cp313-macosx_14_0_arm64.whl" } +restate-sdk = { path = ".." } From d517cf096622587cb9df6f28c59eaf51e5a2a989 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 1 Apr 2026 20:18:29 +0200 Subject: [PATCH 10/12] Add .env --- test-services-cls/.env | 1 + 1 file changed, 1 insertion(+) create mode 100644 test-services-cls/.env diff --git a/test-services-cls/.env b/test-services-cls/.env new file mode 100644 index 0000000..9840038 --- /dev/null +++ b/test-services-cls/.env @@ -0,0 +1 @@ +RESTATE_CORE_LOG=trace \ No newline at end of file From aa831dd45ef576bf399b6c85ce49e044eeb2c889 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 1 Apr 2026 20:20:41 +0200 Subject: [PATCH 11/12] Use proper args instead of kwargs --- python/restate/cls.py | 96 +++++++++++++++++++++++++++++++++++-------- 1 file changed, 79 insertions(+), 17 deletions(-) diff --git a/python/restate/cls.py b/python/restate/cls.py index c15a2d3..7b03946 100644 --- a/python/restate/cls.py +++ b/python/restate/cls.py @@ -123,7 +123,7 @@ async def count(self) -> int: else: from typing_extensions import Self -from restate.context import RestateDurableFuture, RunOptions +from restate.context import HandlerType, RestateDurableCallFuture, RestateDurableFuture, RunOptions, SendHandle from restate.handler import RESTATE_UNIQUE_HANDLER_SYMBOL, HandlerIO, ServiceTag, make_handler from restate.retry_policy import InvocationRetryPolicy from restate.serde import DefaultSerde, Serde @@ -983,44 +983,106 @@ def sleep(delta: timedelta, name: Optional[str] = None) -> Any: # ── Service communication ── @staticmethod - def service_call(tpe: Any, arg: Any, **kwargs: Any) -> Any: + def service_call( + tpe: HandlerType[I, O], + arg: I, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> RestateDurableCallFuture[O]: """Call a service handler.""" - return Restate._ctx().service_call(tpe, arg=arg, **kwargs) + return Restate._ctx().service_call(tpe, arg=arg, idempotency_key=idempotency_key, headers=headers) @staticmethod - def service_send(tpe: Any, arg: Any, **kwargs: Any) -> Any: + def service_send( + tpe: HandlerType[I, O], + arg: I, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> SendHandle: """Send a message to a service handler (fire-and-forget).""" - return Restate._ctx().service_send(tpe, arg=arg, **kwargs) + return Restate._ctx().service_send( + tpe, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) @staticmethod - def object_call(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + def object_call( + tpe: HandlerType[I, O], + key: str, + arg: I, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> RestateDurableCallFuture[O]: """Call a virtual object handler.""" - return Restate._ctx().object_call(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().object_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) @staticmethod - def object_send(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + def object_send( + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> SendHandle: """Send a message to a virtual object handler (fire-and-forget).""" - return Restate._ctx().object_send(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().object_send( + tpe, key=key, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) @staticmethod - def workflow_call(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + def workflow_call( + tpe: HandlerType[I, O], + key: str, + arg: I, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> RestateDurableCallFuture[O]: """Call a workflow handler.""" - return Restate._ctx().workflow_call(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().workflow_call(tpe, key=key, arg=arg, idempotency_key=idempotency_key, headers=headers) @staticmethod - def workflow_send(tpe: Any, key: str, arg: Any, **kwargs: Any) -> Any: + def workflow_send( + tpe: HandlerType[I, O], + key: str, + arg: I, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> SendHandle: """Send a message to a workflow handler (fire-and-forget).""" - return Restate._ctx().workflow_send(tpe, key=key, arg=arg, **kwargs) + return Restate._ctx().workflow_send( + tpe, key=key, arg=arg, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) @staticmethod - def generic_call(service: str, handler: str, arg: bytes, key: Optional[str] = None, **kwargs: Any) -> Any: + def generic_call( + service: str, + handler: str, + arg: bytes, + key: Optional[str] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> RestateDurableCallFuture[bytes]: """Call a generic service/handler with raw bytes.""" - return Restate._ctx().generic_call(service, handler, arg, key=key, **kwargs) + return Restate._ctx().generic_call( + service, handler, arg, key=key, idempotency_key=idempotency_key, headers=headers + ) @staticmethod - def generic_send(service: str, handler: str, arg: bytes, key: Optional[str] = None, **kwargs: Any) -> Any: + def generic_send( + service: str, + handler: str, + arg: bytes, + key: Optional[str] = None, + send_delay: Optional[timedelta] = None, + idempotency_key: Optional[str] = None, + headers: Optional[Dict[str, str]] = None, + ) -> SendHandle: """Send a message to a generic service/handler with raw bytes.""" - return Restate._ctx().generic_send(service, handler, arg, key=key, **kwargs) + return Restate._ctx().generic_send( + service, handler, arg, key=key, send_delay=send_delay, idempotency_key=idempotency_key, headers=headers + ) # ── Awakeables ── From 9e174881d1d732de2dab3f633724fe19bd932a40 Mon Sep 17 00:00:00 2001 From: igalshilman Date: Wed, 1 Apr 2026 20:49:23 +0200 Subject: [PATCH 12/12] Add Restate.call_handle() --- examples/class_based_greeter.py | 4 +- python/restate/cls.py | 87 +++++++++++++++---- test-services-cls/services/cancel_test.py | 2 +- test-services-cls/services/kill_test.py | 2 +- test-services-cls/services/non_determinism.py | 6 +- 5 files changed, 75 insertions(+), 26 deletions(-) diff --git a/examples/class_based_greeter.py b/examples/class_based_greeter.py index 6480b98..92c589a 100644 --- a/examples/class_based_greeter.py +++ b/examples/class_based_greeter.py @@ -80,10 +80,10 @@ async def process(self, customer: str) -> str: count = await Counter.call(customer).increment(1) # Fire-and-forget send (returns SendHandle, not a coroutine) - Counter.send(customer).increment(1) # type: ignore[unused-coroutine] + Counter.send(customer).increment(1) # Send with delay - Counter.send(customer, delay=timedelta(seconds=30)).increment(1) # type: ignore[unused-coroutine] + Counter.send(customer, delay=timedelta(seconds=30)).increment(1) # Call a workflow receipt = await PaymentWorkflow.call(f"order-{count}").pay(100) diff --git a/python/restate/cls.py b/python/restate/cls.py index 7b03946..1e87632 100644 --- a/python/restate/cls.py +++ b/python/restate/cls.py @@ -107,6 +107,7 @@ async def count(self) -> int: from typing import ( Any, AsyncContextManager, + Awaitable, Callable, Coroutine, Dict, @@ -578,7 +579,7 @@ def __init__(self, cls: type, delay: Optional[timedelta] = None) -> None: self._cls = cls self._delay = delay - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Callable[..., SendHandle]: handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: @@ -587,7 +588,7 @@ def __getattr__(self, name: str): ctx = _restate_context_var.get() - def invoke(arg=_MISSING): + def invoke(arg=_MISSING) -> SendHandle: if arg is _MISSING: return ctx.service_send(h.fn, arg=None, send_delay=self._delay) return ctx.service_send(h.fn, arg=arg, send_delay=self._delay) @@ -627,7 +628,7 @@ def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> No self._key = key self._delay = delay - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Callable[..., SendHandle]: handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: @@ -636,7 +637,7 @@ def __getattr__(self, name: str): ctx = _restate_context_var.get() - def invoke(arg=_MISSING): + def invoke(arg=_MISSING) -> SendHandle: if arg is _MISSING: return ctx.object_send(h.fn, key=self._key, arg=None, send_delay=self._delay) return ctx.object_send(h.fn, key=self._key, arg=arg, send_delay=self._delay) @@ -676,7 +677,7 @@ def __init__(self, cls: type, key: str, delay: Optional[timedelta] = None) -> No self._key = key self._delay = delay - def __getattr__(self, name: str): + def __getattr__(self, name: str) -> Callable[..., SendHandle]: handlers = getattr(self._cls, _HANDLERS_ATTR, {}) h = handlers.get(name) if h is None: @@ -685,7 +686,7 @@ def __getattr__(self, name: str): ctx = _restate_context_var.get() - def invoke(arg=_MISSING): + def invoke(arg=_MISSING) -> SendHandle: if arg is _MISSING: return ctx.workflow_send(h.fn, key=self._key, arg=None, send_delay=self._delay) return ctx.workflow_send(h.fn, key=self._key, arg=arg, send_delay=self._delay) @@ -749,15 +750,27 @@ def __init_subclass__( def call(cls) -> Self: # type: ignore[return-type] """Return a proxy for making durable service calls. - The proxy has the same method signatures as the class, - giving full IDE autocomplete and type inference. + Typed as ``Self`` so the IDE sees the real handler signatures:: + + greeting = await Greeter.call().greet("Alice") # str + + At runtime returns a ``_ServiceCallProxy`` whose methods yield + ``RestateDurableCallFuture[T]``. After ``await`` the types + align (both are ``Awaitable[T]``). For advanced use cases + (invocation id, ``gather``), unwrap with ``Restate.call_handle()``. """ return _ServiceCallProxy(cls) # type: ignore[return-value] @classmethod - def send(cls, *, delay: Optional[timedelta] = None) -> Self: # type: ignore[return-type] - """Return a proxy for fire-and-forget service sends.""" - return _ServiceSendProxy(cls, delay) # type: ignore[return-value] + def send(cls, *, delay: Optional[timedelta] = None) -> _ServiceSendProxy: + """Return a proxy for fire-and-forget service sends. + + Returns a ``_ServiceSendProxy`` whose methods return ``SendHandle`` + (not a coroutine), so there is no need to ``await``:: + + Greeter.send().greet("Alice") # SendHandle — no await needed + """ + return _ServiceSendProxy(cls, delay) class VirtualObject: @@ -812,13 +825,19 @@ def __init_subclass__( @classmethod def call(cls, key: str) -> Self: # type: ignore[return-type] - """Return a proxy for making durable object calls.""" + """Return a proxy for making durable object calls. + + Typed as ``Self`` for IDE autocomplete — see ``Service.call()`` docstring. + """ return _ObjectCallProxy(cls, key) # type: ignore[return-value] @classmethod - def send(cls, key: str, *, delay: Optional[timedelta] = None) -> Self: # type: ignore[return-type] - """Return a proxy for fire-and-forget object sends.""" - return _ObjectSendProxy(cls, key, delay) # type: ignore[return-value] + def send(cls, key: str, *, delay: Optional[timedelta] = None) -> _ObjectSendProxy: + """Return a proxy for fire-and-forget object sends. + + Returns ``_ObjectSendProxy`` — methods return ``SendHandle``, not a coroutine. + """ + return _ObjectSendProxy(cls, key, delay) class Workflow: @@ -873,13 +892,19 @@ def __init_subclass__( @classmethod def call(cls, key: str) -> Self: # type: ignore[return-type] - """Return a proxy for making durable workflow calls.""" + """Return a proxy for making durable workflow calls. + + Typed as ``Self`` for IDE autocomplete — see ``Service.call()`` docstring. + """ return _WorkflowCallProxy(cls, key) # type: ignore[return-value] @classmethod - def send(cls, key: str, *, delay: Optional[timedelta] = None) -> Self: # type: ignore[return-type] - """Return a proxy for fire-and-forget workflow sends.""" - return _WorkflowSendProxy(cls, key, delay) # type: ignore[return-value] + def send(cls, key: str, *, delay: Optional[timedelta] = None) -> _WorkflowSendProxy: + """Return a proxy for fire-and-forget workflow sends. + + Returns ``_WorkflowSendProxy`` — methods return ``SendHandle``, not a coroutine. + """ + return _WorkflowSendProxy(cls, key, delay) # ── Context accessor class ──────────────────────────────────────────────── @@ -907,6 +932,30 @@ def _ctx() -> Any: return current_context() + # ── Call handle ── + + @staticmethod + def call_handle(coro: Awaitable[T]) -> RestateDurableCallFuture[T]: + """Unwrap a ``call()`` proxy result to its real ``RestateDurableCallFuture``. + + The ``call()`` fluent proxy returns ``Self`` for IDE autocomplete, + so the type checker sees handler return types (e.g. ``Awaitable[int]``). + For simple ``await`` usage that's fine. But when you need the full + ``RestateDurableCallFuture`` — for example to read ``invocation_id`` + or to pass it to ``restate.gather()`` — use this method:: + + handle = Restate.call_handle(Counter.call("key").increment(1)) + # handle: RestateDurableCallFuture[int] + invocation_id = handle.invocation_id + result = await handle + + At runtime the proxy already returns a ``RestateDurableCallFuture``, + so this is a safe cast with a runtime sanity check. + """ + if not isinstance(coro, RestateDurableCallFuture): + raise TypeError(f"Expected a RestateDurableCallFuture from a .call() proxy, got {type(coro).__name__}") + return coro + # ── State ── @staticmethod diff --git a/test-services-cls/services/cancel_test.py b/test-services-cls/services/cancel_test.py index ab6c103..6816c16 100644 --- a/test-services-cls/services/cancel_test.py +++ b/test-services-cls/services/cancel_test.py @@ -47,7 +47,7 @@ class CancelTestBlockingService(VirtualObject, name="CancelTestBlockingService") @handler async def block(self, op: BlockingOperation): name, awakeable = Restate.awakeable() - awakeable_holder.AwakeableHolder.send(Restate.key()).hold(name) # type: ignore[unused-coroutine] + awakeable_holder.AwakeableHolder.send(Restate.key()).hold(name) await awakeable if op == "CALL": diff --git a/test-services-cls/services/kill_test.py b/test-services-cls/services/kill_test.py index 8e16e85..2eceb35 100644 --- a/test-services-cls/services/kill_test.py +++ b/test-services-cls/services/kill_test.py @@ -29,7 +29,7 @@ class KillTestSingleton(VirtualObject, name="KillTestSingleton"): @handler(name="recursiveCall") async def recursive_call(self): name, promise = Restate.awakeable() - AwakeableHolder.send(Restate.key()).hold(name) # type: ignore[unused-coroutine] + AwakeableHolder.send(Restate.key()).hold(name) await promise await KillTestSingleton.call(Restate.key()).recursive_call() diff --git a/test-services-cls/services/non_determinism.py b/test-services-cls/services/non_determinism.py index 65b90bd..0c83f29 100644 --- a/test-services-cls/services/non_determinism.py +++ b/test-services-cls/services/non_determinism.py @@ -28,7 +28,7 @@ def do_left_action() -> bool: def increment_counter(): - Counter.send(Restate.key()).add(1) # type: ignore[unused-coroutine] + Counter.send(Restate.key()).add(1) class NonDeterministic(VirtualObject, name="NonDeterministic"): @@ -45,9 +45,9 @@ async def set_different_key(self): @handler(name="backgroundInvokeWithDifferentTargets") async def background_invoke_with_different_targets(self): if do_left_action(): - Counter.send("abc").get() # type: ignore[unused-coroutine] + Counter.send("abc").get() else: - Counter.send("abc").reset() # type: ignore[unused-coroutine] + Counter.send("abc").reset() await Restate.sleep(timedelta(milliseconds=100)) increment_counter()