Skip to content

Commit 35b109c

Browse files
authored
Introduce RetryableError (#184)
Allow users to customize retry timing when needed, e.g., when receiving a `Retry-After` header. cf. restatedev/sdk-typescript#569
1 parent 538d03c commit 35b109c

5 files changed

Lines changed: 182 additions & 5 deletions

File tree

python/restate/exceptions.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,9 @@
1111
"""This module contains the restate exceptions"""
1212

1313
# pylint: disable=C0301
14+
from typing import Optional
15+
16+
from datetime import timedelta
1417

1518

1619
class TerminalError(Exception):
@@ -22,6 +25,23 @@ def __init__(self, message: str, status_code: int = 500) -> None:
2225
self.status_code = status_code
2326

2427

28+
class RetryableError(Exception):
29+
"""
30+
This exception is thrown to indicate that Restate should retry with an explicit delay.
31+
32+
Args:
33+
message: The error message.
34+
status_code: The HTTP status code to return for this error (default: 500).
35+
retry_after: The delay after which Restate should retry the invocation.
36+
"""
37+
38+
def __init__(self, message: str, status_code: int = 500, retry_after: Optional[timedelta] = None) -> None:
39+
super().__init__(message)
40+
self.message = message
41+
self.status_code = status_code
42+
self.retry_after = retry_after
43+
44+
2545
class SdkInternalBaseException(BaseException):
2646
"""This exception is internal, and you should not catch it.
2747
If you need to distinguish with other exceptions, use is_internal_exception."""

python/restate/server_context.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,13 @@
4646
RunOptions,
4747
P,
4848
)
49-
from restate.exceptions import TerminalError, SdkInternalBaseException, SdkInternalException, SuspendedException
49+
from restate.exceptions import (
50+
TerminalError,
51+
SdkInternalBaseException,
52+
SdkInternalException,
53+
SuspendedException,
54+
RetryableError,
55+
)
5056
from restate.handler import Handler, handler_from_callable, invoke_handler
5157
from restate.serde import BytesSerde, DefaultSerde, Serde
5258
from restate.server_types import ReceiveChannel, Send
@@ -404,6 +410,10 @@ async def enter(self):
404410
restate_context_is_replaying.set(False)
405411
self.vm.sys_write_output_failure(failure)
406412
self.vm.sys_end()
413+
except RetryableError as r:
414+
stacktrace = "".join(traceback.format_exception(r))
415+
restate_context_is_replaying.set(False)
416+
self.vm.notify_error(r.message, stacktrace, r.retry_after)
407417
# pylint: disable=W0718
408418
except asyncio.CancelledError:
409419
pass
@@ -422,6 +432,11 @@ async def enter(self):
422432
self.vm.sys_write_output_failure(failure)
423433
self.vm.sys_end()
424434
break
435+
elif isinstance(cause, RetryableError):
436+
stacktrace = "".join(traceback.format_exception(cause))
437+
restate_context_is_replaying.set(False)
438+
self.vm.notify_error(cause.message, stacktrace, cause.retry_after)
439+
break
425440
elif isinstance(cause, SdkInternalBaseException):
426441
break
427442
cause = cause.__cause__
@@ -674,6 +689,18 @@ async def create_run_coroutine(
674689
except TerminalError as t:
675690
failure = Failure(code=t.status_code, message=t.message)
676691
self.vm.propose_run_completion_failure(handle, failure)
692+
except RetryableError as r:
693+
failure = Failure(code=r.status_code, message=r.message)
694+
end = time.time()
695+
attempt_duration = int((end - start) * 1000)
696+
self.vm.propose_run_completion_transient_with_delay_override(
697+
handle,
698+
failure,
699+
attempt_duration_ms=attempt_duration,
700+
delay_override=r.retry_after,
701+
max_retry_attempts_override=max_attempts,
702+
max_retry_duration_override=max_duration,
703+
)
677704
except asyncio.CancelledError as e:
678705
raise e from None
679706
except SdkInternalBaseException as e:

python/restate/vm.py

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,12 @@
1111
"""
1212
wrap the restate._internal.PyVM class
1313
"""
14+
1415
# pylint: disable=E1101,R0917
1516
# pylint: disable=too-many-arguments
1617
# pylint: disable=too-few-public-methods
18+
from typing import Optional
19+
from datetime import timedelta
1720

1821
from dataclasses import dataclass
1922
import typing
@@ -173,9 +176,11 @@ def notify_input_closed(self):
173176
"""Notify the virtual machine that the input has been closed."""
174177
self.vm.notify_input_closed()
175178

176-
def notify_error(self, error: str, stacktrace: str):
179+
def notify_error(self, error: str, stacktrace: str, delay_override: Optional[timedelta] = None):
177180
"""Notify the virtual machine of an error."""
178-
self.vm.notify_error(error, stacktrace)
181+
self.vm.notify_error(
182+
error, stacktrace, int(delay_override.total_seconds() * 1000) if delay_override is not None else None
183+
)
179184

180185
def take_output(self) -> typing.Optional[bytes]:
181186
"""Take the output from the virtual machine."""
@@ -444,6 +449,28 @@ def propose_run_completion_transient(
444449
)
445450
self.vm.propose_run_completion_failure_transient(handle, py_failure, attempt_duration_ms, py_config)
446451

452+
def propose_run_completion_transient_with_delay_override(
453+
self,
454+
handle: int,
455+
failure: Failure,
456+
attempt_duration_ms: int,
457+
delay_override: timedelta | None,
458+
max_retry_attempts_override: int | None,
459+
max_retry_duration_override: timedelta | None,
460+
):
461+
"""
462+
Exit a side effect with a transient Error and override the retry policy with explicit parameters.
463+
"""
464+
py_failure = PyFailure(failure.code, failure.message, failure.stacktrace)
465+
self.vm.propose_run_completion_failure_transient_with_delay_override(
466+
handle,
467+
py_failure,
468+
attempt_duration_ms,
469+
int(delay_override.total_seconds() * 1000) if delay_override else None,
470+
max_retry_attempts_override,
471+
int(max_retry_duration_override.total_seconds() * 1000) if max_retry_duration_override else None,
472+
)
473+
447474
def sys_end(self):
448475
"""
449476
This method is responsible for ending the system.

src/lib.rs

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -335,12 +335,15 @@ impl PyVM {
335335
self_.vm.notify_input_closed();
336336
}
337337

338-
#[pyo3(signature = (error, stacktrace=None))]
339-
fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option<String>) {
338+
#[pyo3(signature = (error, stacktrace=None, delay_override_ms=None))]
339+
fn notify_error(mut self_: PyRefMut<'_, Self>, error: String, stacktrace: Option<String>, delay_override_ms: Option<u64>) {
340340
let mut error = Error::new(restate_sdk_shared_core::error::codes::INTERNAL, error);
341341
if let Some(desc) = stacktrace {
342342
error = error.with_stacktrace(desc);
343343
}
344+
if let Some(delay) = delay_override_ms {
345+
error = error.with_next_retry_delay_override(Duration::from_millis(delay));
346+
}
344347
CoreVM::notify_error(&mut self_.vm, error, None);
345348
}
346349

@@ -721,6 +724,37 @@ impl PyVM {
721724
.map_err(Into::into)
722725
}
723726

727+
fn propose_run_completion_failure_transient_with_delay_override(
728+
mut self_: PyRefMut<'_, Self>,
729+
handle: PyNotificationHandle,
730+
value: PyFailure,
731+
attempt_duration: u64,
732+
delay_override_ms: Option<u64>,
733+
max_retry_attempts_override: Option<u32>,
734+
max_retry_duration_override_ms: Option<u64>,
735+
) -> Result<(), PyVMError> {
736+
let retry_policy = if delay_override_ms.is_some() || max_retry_attempts_override.is_some() || max_retry_duration_override_ms.is_some() {
737+
RetryPolicy::FixedDelay {
738+
interval: delay_override_ms.map(Duration::from_millis),
739+
max_attempts: max_retry_attempts_override,
740+
max_duration: max_retry_duration_override_ms.map(Duration::from_millis),
741+
}
742+
} else {
743+
RetryPolicy::Infinite
744+
};
745+
self_
746+
.vm
747+
.propose_run_completion(
748+
handle.into(),
749+
RunExitResult::RetryableFailure {
750+
attempt_duration: Duration::from_millis(attempt_duration),
751+
error: value.into(),
752+
},
753+
retry_policy,
754+
)
755+
.map_err(Into::into)
756+
}
757+
724758
fn sys_write_output_success(
725759
mut self_: PyRefMut<'_, Self>,
726760
buffer: &Bound<'_, PyBytes>,

tests/servercontext.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,11 +8,15 @@
88
# directory of this repository or package, or at
99
# https://github.com/restatedev/sdk-typescript/blob/main/LICENSE
1010
#
11+
from datetime import timedelta
12+
from restate.exceptions import RetryableError
1113

1214
from contextlib import asynccontextmanager
1315
import restate
1416
from restate import (
1517
Context,
18+
HttpError,
19+
InvocationRetryPolicy,
1620
RunOptions,
1721
Service,
1822
TerminalError,
@@ -78,6 +82,71 @@ async def greet(ctx: Context, name: str) -> str:
7882
await client.service_call(greet, arg="bob")
7983

8084

85+
async def test_retryable_exception():
86+
greeter = Service("greeter")
87+
attempts = 0
88+
89+
@greeter.handler(
90+
invocation_retry_policy=InvocationRetryPolicy(
91+
max_attempts=3,
92+
# Something really long to trigger a test timeout.
93+
# Default httpx client timeout is 5 seconds.
94+
initial_interval=timedelta(hours=1),
95+
),
96+
)
97+
async def greet(ctx: Context, name: str) -> str:
98+
nonlocal attempts
99+
print(f"Attempt {attempts}")
100+
try:
101+
if attempts == 0:
102+
raise RetryableError("Simulated retryable error", retry_after=timedelta(milliseconds=100))
103+
else:
104+
raise TerminalError("Simulated terminal error")
105+
finally:
106+
attempts += 1
107+
108+
async with simple_harness(greeter) as client:
109+
with pytest.raises(HttpError): # Should be some sort of client error (not a timeout).
110+
await client.service_call(greet, arg="bob")
111+
112+
assert attempts == 2
113+
114+
115+
async def test_accidentally_wrapped_retryable_exception():
116+
greeter = Service("greeter")
117+
attempts = 0
118+
119+
@greeter.handler(
120+
invocation_retry_policy=InvocationRetryPolicy(
121+
max_attempts=3,
122+
# Something really long to trigger a test timeout.
123+
# Default httpx client timeout is 5 seconds.
124+
initial_interval=timedelta(hours=1),
125+
),
126+
)
127+
async def greet(ctx: Context, name: str) -> str:
128+
nonlocal attempts
129+
print(f"Attempt {attempts}")
130+
try:
131+
if attempts == 0:
132+
try:
133+
raise RetryableError("Simulated retryable error", retry_after=timedelta(milliseconds=100))
134+
except RetryableError as re:
135+
# Simulate a developer accidentally catching and wrapping a RetryableError, which should still
136+
# be treated as retryable by the system.
137+
raise ValueError("Wrapped retryable error") from re
138+
else:
139+
raise TerminalError("Simulated terminal error")
140+
finally:
141+
attempts += 1
142+
143+
async with simple_harness(greeter) as client:
144+
with pytest.raises(HttpError): # Should be some sort of client error (not a timeout).
145+
await client.service_call(greet, arg="bob")
146+
147+
assert attempts == 2
148+
149+
81150
async def test_promise_default_serde():
82151
workflow = Workflow("test_workflow")
83152

0 commit comments

Comments
 (0)