Skip to content

Commit fc29f1b

Browse files
committed
Introduce RetryableError
Allow users to customize retry timing when needed, e.g., when receiving a `Retry-After` header. cf. restatedev/sdk-typescript#569
1 parent 9e75895 commit fc29f1b

File tree

5 files changed

+163
-1
lines changed

5 files changed

+163
-1
lines changed

python/restate/exceptions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
"""This module contains the restate exceptions"""
1212

1313
# pylint: disable=C0301
14+
from typing import Optional
15+
16+
from datetime import timedelta
1417

1518

1619
class TerminalError(Exception):
@@ -22,6 +25,23 @@ def __init__(self, message: str, status_code: int = 500) -> None:
2225
self.status_code = status_code
2326

2427

28+
class RetryableError(Exception):
29+
"""
30+
This exception is thrown to indicate that Restate should retry with an explicit delay.
31+
32+
Args:
33+
message: The error message.
34+
retry_after: The delay after which Restate should retry the invocation.
35+
status_code: The HTTP status code to return for this error (default: 500).
36+
"""
37+
38+
def __init__(self, message: str, status_code: int = 500, retry_after: Optional[timedelta] = None) -> None:
39+
super().__init__(message)
40+
self.message = message
41+
self.status_code = status_code
42+
self.retry_after = retry_after
43+
44+
2545
class SdkInternalBaseException(BaseException):
2646
"""This exception is internal, and you should not catch it.
2747
If you need to distinguish with other exceptions, use is_internal_exception."""

python/restate/server_context.py

Lines changed: 23 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@
4646
RunOptions,
4747
P,
4848
)
49-
from restate.exceptions import TerminalError, SdkInternalBaseException, SdkInternalException, SuspendedException
49+
from restate.exceptions import (
50+
TerminalError,
51+
SdkInternalBaseException,
52+
SdkInternalException,
53+
SuspendedException,
54+
RetryableError,
55+
)
5056
from restate.handler import Handler, handler_from_callable, invoke_handler
5157
from restate.serde import BytesSerde, DefaultSerde, Serde
5258
from restate.server_types import ReceiveChannel, Send
@@ -404,6 +410,10 @@ async def enter(self):
404410
restate_context_is_replaying.set(False)
405411
self.vm.sys_write_output_failure(failure)
406412
self.vm.sys_end()
413+
except RetryableError as r:
414+
stacktrace = "".join(traceback.format_exception(r))
415+
restate_context_is_replaying.set(False)
416+
self.vm.notify_error_with_delay_override(r.message, stacktrace, r.retry_after)
407417
# pylint: disable=W0718
408418
except asyncio.CancelledError:
409419
pass
@@ -674,6 +684,18 @@ async def create_run_coroutine(
674684
except TerminalError as t:
675685
failure = Failure(code=t.status_code, message=t.message)
676686
self.vm.propose_run_completion_failure(handle, failure)
687+
except RetryableError as r:
688+
failure = Failure(code=r.status_code, message=r.message)
689+
end = time.time()
690+
attempt_duration = int((end - start) * 1000)
691+
self.vm.propose_run_completion_transient_with_delay_override(
692+
handle,
693+
failure,
694+
attempt_duration_ms=attempt_duration,
695+
delay_override=r.retry_after,
696+
max_retry_attempts_override=max_attempts,
697+
max_retry_duration_override=max_duration,
698+
)
677699
except asyncio.CancelledError as e:
678700
raise e from None
679701
except SdkInternalBaseException as e:

python/restate/vm.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
"""
1212
wrap the restate._internal.PyVM class
1313
"""
14+
1415
# pylint: disable=E1101,R0917
1516
# pylint: disable=too-many-arguments
1617
# pylint: disable=too-few-public-methods
18+
from typing import Optional
19+
from datetime import timedelta
1720

1821
from dataclasses import dataclass
1922
import typing
@@ -177,6 +180,13 @@ def notify_error(self, error: str, stacktrace: str):
177180
"""Notify the virtual machine of an error."""
178181
self.vm.notify_error(error, stacktrace)
179182

183+
def notify_error_with_delay_override(self, error: str, stacktrace: str, delay_override: Optional[timedelta]):
184+
"""Notify the virtual machine of an error, with a delay override for retrying."""
185+
if delay_override is None:
186+
self.vm.notify_error(error, stacktrace)
187+
else:
188+
self.vm.notify_error_with_delay_override(error, stacktrace, int(delay_override.total_seconds() * 1000))
189+
180190
def take_output(self) -> typing.Optional[bytes]:
181191
"""Take the output from the virtual machine."""
182192
return self.vm.take_output()
@@ -444,6 +454,28 @@ def propose_run_completion_transient(
444454
)
445455
self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config)
446456

457+
def propose_run_completion_transient_with_delay_override(
458+
self,
459+
handle: int,
460+
failure: Failure,
461+
attempt_duration_ms: int,
462+
delay_override: timedelta | None,
463+
max_retry_attempts_override: int | None,
464+
max_retry_duration_override: timedelta | None,
465+
):
466+
"""
467+
Exit a side effect with a transient Error and override the retry policy with explicit parameters.
468+
"""
469+
py_failure = PyFailure(failure.code, failure.message, failure.stacktrace)
470+
self.vm.propose_run_completion_failure_transient_with_delay_override(
471+
handle,
472+
py_failure,
473+
attempt_duration_ms,
474+
int(delay_override.total_seconds() * 1000) if delay_override else None,
475+
max_retry_attempts_override,
476+
int(max_retry_duration_override.total_seconds() * 1000) if max_retry_duration_override else None,
477+
)
478+
447479
def sys_end(self):
448480
"""
449481
This method is responsible for ending the system.

src/lib.rs

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,18 @@ impl PyVM {
344344
CoreVM::notify_error(&mut self_.vm, error, None);
345345
}
346346

347+
#[pyo3(signature = (error, stacktrace=None, delay_override_ms=None))]
348+
fn notify_error_with_delay_override(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option<String>, delay_override_ms: Option<u64>) {
349+
let mut error = Error::new(restate_sdk_shared_core::error::codes::INTERNAL, error);
350+
if let Some(desc) = stacktrace {
351+
error = error.with_stacktrace(desc);
352+
}
353+
if let Some(delay) = delay_override_ms {
354+
error = error.with_next_retry_delay_override(Duration::from_millis(delay));
355+
}
356+
CoreVM::notify_error(&mut self_.vm, error, None);
357+
}
358+
347359
// Take(s)
348360

349361
/// Returns either bytes or None, indicating EOF
@@ -721,6 +733,37 @@ impl PyVM {
721733
.map_err(Into::into)
722734
}
723735

736+
fn propose_run_completion_failure_transient_with_delay_override(
737+
mut self_: PyRefMut<'_, Self>,
738+
handle: PyNotificationHandle,
739+
value: PyFailure,
740+
attempt_duration: u64,
741+
delay_override_ms: Option<u64>,
742+
max_retry_attempts_override: Option<u32>,
743+
max_retry_duration_override_ms: Option<u64>,
744+
) -> Result<(), PyVMError> {
745+
let retry_policy = if delay_override_ms.is_some() || max_retry_attempts_override.is_some() || max_retry_duration_override_ms.is_some() {
746+
RetryPolicy::FixedDelay {
747+
interval: delay_override_ms.map(Duration::from_millis),
748+
max_attempts: max_retry_attempts_override,
749+
max_duration: max_retry_duration_override_ms.map(Duration::from_millis),
750+
}
751+
} else {
752+
RetryPolicy::Infinite
753+
};
754+
self_
755+
.vm
756+
.propose_run_completion(
757+
handle.into(),
758+
RunExitResult::RetryableFailure {
759+
attempt_duration: Duration::from_millis(attempt_duration),
760+
error: value.into(),
761+
},
762+
retry_policy,
763+
)
764+
.map_err(Into::into)
765+
}
766+
724767
fn sys_write_output_success(
725768
mut self_: PyRefMut<'_, Self>,
726769
buffer: &Bound<'_, PyBytes>,

tests/servercontext.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,9 @@
88
# directory of this repository or package, or at
99
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
1010
#
11+
import time
12+
from datetime import timedelta
13+
from restate.exceptions import RetryableError
1114

1215
from contextlib import asynccontextmanager
1316
import restate
@@ -19,6 +22,8 @@
1922
VirtualObject,
2023
Workflow,
2124
WorkflowContext,
25+
InvocationRetryPolicy,
26+
HttpError,
2227
)
2328
from restate.serde import DefaultSerde
2429
import pytest
@@ -78,6 +83,46 @@ async def greet(ctx: Context, name: str) -> str:
7883
await client.service_call(greet, arg="bob")
7984

8085

86+
async def test_retryable_exception():
87+
# TODO: This test is not deterministic because of the timing of the retry.
88+
# However, it should only have false positives, not false negatives.
89+
90+
greeter = Service("greeter")
91+
attempts = 0
92+
93+
@greeter.handler(
94+
invocation_retry_policy=InvocationRetryPolicy(
95+
max_attempts=3,
96+
# Something really long to trigger a test timeout.
97+
# Default httpx client timeout is 5 seconds.
98+
initial_interval=timedelta(hours=1),
99+
),
100+
)
101+
async def greet(ctx: Context, name: str) -> str:
102+
nonlocal attempts
103+
print(f"Attempt {attempts}")
104+
try:
105+
if attempts == 0:
106+
raise RetryableError("Simulated retryable error", retry_after=timedelta(seconds=1))
107+
else:
108+
raise TerminalError("Simulated terminal error")
109+
finally:
110+
attempts += 1
111+
112+
async with simple_harness(greeter) as client:
113+
start = time.monotonic()
114+
with pytest.raises(HttpError): # Should be some sort of client error (not a timeout).
115+
await client.service_call(greet, arg="bob")
116+
end = time.monotonic()
117+
duration = end - start
118+
# Retry should take _at least_ 1 second.
119+
assert duration >= 1
120+
# The upper limit is really ~5s due to the httpx client timeout.
121+
assert duration < 300
122+
123+
assert attempts == 2
124+
125+
81126
async def test_promise_default_serde():
82127
workflow = Workflow("test_workflow")
83128

0 commit comments

Comments
 (0)