Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/advanced/error-handling.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ The SDK provides several exception types for different failure scenarios.
| `InvocationError` | Yes (by Lambda) | Lambda retries invocation | Transient infrastructure issues |
| `CallbackError` | No | Returns FAILED status | Callback handling failures |
| `StepInterruptedError` | Yes (automatic) | Retries on next invocation | Step interrupted before checkpoint |
| `CheckpointError` | Depends | Retries if 4xx (except invalid token) | Failed to save execution state |
| `CheckpointError` | Depends | Permanent on 4xx non-429 (except invalid checkpoint token); retries otherwise | Failed to save execution state |
| `SerDesError` | No | Returns FAILED status | Serialization failures |

### Base exceptions
Expand Down
23 changes: 8 additions & 15 deletions src/aws_durable_execution_sdk_python/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,32 +155,25 @@ def from_exception(cls, exception: Exception) -> CheckpointError:
error: AwsErrorObj | None = base.error
error_category: CheckpointErrorCategory = CheckpointErrorCategory.INVOCATION

# InvalidParameterValueException and error message starts with "Invalid Checkpoint Token" is an InvocationError
# all other 4xx errors are Execution Errors and should be retried
# all 5xx errors are Invocation Errors
# 4xx errors (except 429) are permanent failures (EXECUTION), unless it's an
# InvalidParameterValueException with "Invalid Checkpoint Token" which is retriable (INVOCATION).
# 5xx, 429, and network errors are retriable (INVOCATION).
status_code: int | None = (metadata and metadata.get("HTTPStatusCode")) or None
if (
status_code
# if we are in 4xx range (except 429) and is not an InvalidParameterValueException with Invalid Checkpoint Token
# then it's an execution error
and status_code < SERVICE_ERROR
and status_code >= BAD_REQUEST_ERROR
and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR
and status_code != TOO_MANY_REQUESTS_ERROR
and error
and (
# is not InvalidParam => Execution
(error.get("Code", "") or "") != "InvalidParameterValueException"
# is not Invalid Token => Execution
or not (error.get("Message") or "").startswith(
"Invalid Checkpoint Token"
)
and not (
(error.get("Code") or "") == "InvalidParameterValueException"
and (error.get("Message") or "").startswith("Invalid Checkpoint Token")
)
):
error_category = CheckpointErrorCategory.EXECUTION
return CheckpointError(str(exception), error_category, error, metadata)

def is_retriable(self):
return self.error_category == CheckpointErrorCategory.EXECUTION
return self.error_category == CheckpointErrorCategory.INVOCATION


class ValidationError(DurableExecutionsError):
Expand Down
27 changes: 22 additions & 5 deletions tests/exceptions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,23 @@ def test_checkpoint_error_classification_invalid_token_invocation():
result = CheckpointError.from_exception(client_error)

assert result.error_category == CheckpointErrorCategory.INVOCATION
assert result.is_retriable()


def test_checkpoint_error_classification_payload_size_exceeded_execution():
"""Test 4xx InvalidParameterValueException with STEP output payload size limit exceeded is execution error."""
error_response = {
"Error": {
"Code": "InvalidParameterValueException",
"Message": "STEP output payload size must be less than or equal to 262144 bytes.",
},
"ResponseMetadata": {"HTTPStatusCode": 400},
}
client_error = ClientError(error_response, "Checkpoint")

result = CheckpointError.from_exception(client_error)

assert result.error_category == CheckpointErrorCategory.EXECUTION
assert not result.is_retriable()


Expand All @@ -81,7 +98,7 @@ def test_checkpoint_error_classification_other_4xx_execution():
result = CheckpointError.from_exception(client_error)

assert result.error_category == CheckpointErrorCategory.EXECUTION
assert result.is_retriable()
assert not result.is_retriable()


def test_checkpoint_error_classification_429_invocation():
Expand All @@ -95,7 +112,7 @@ def test_checkpoint_error_classification_429_invocation():
result = CheckpointError.from_exception(client_error)

assert result.error_category == CheckpointErrorCategory.INVOCATION
assert not result.is_retriable()
assert result.is_retriable()


def test_checkpoint_error_classification_invalid_param_without_token_execution():
Expand All @@ -112,7 +129,7 @@ def test_checkpoint_error_classification_invalid_param_without_token_execution()
result = CheckpointError.from_exception(client_error)

assert result.error_category == CheckpointErrorCategory.EXECUTION
assert result.is_retriable()
assert not result.is_retriable()


def test_checkpoint_error_classification_5xx_invocation():
Expand All @@ -126,7 +143,7 @@ def test_checkpoint_error_classification_5xx_invocation():
result = CheckpointError.from_exception(client_error)

assert result.error_category == CheckpointErrorCategory.INVOCATION
assert not result.is_retriable()
assert result.is_retriable()


def test_checkpoint_error_classification_unknown_invocation():
Expand All @@ -136,7 +153,7 @@ def test_checkpoint_error_classification_unknown_invocation():
result = CheckpointError.from_exception(unknown_error)

assert result.error_category == CheckpointErrorCategory.INVOCATION
assert not result.is_retriable()
assert result.is_retriable()


def test_validation_error():
Expand Down
89 changes: 43 additions & 46 deletions tests/execution_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1066,8 +1066,9 @@ def test_handler(event: Any, context: DurableContext) -> dict:
# Make the service client checkpoint call fail with CheckpointError
mock_client.checkpoint.side_effect = failing_checkpoint

with pytest.raises(CheckpointError, match="Background checkpoint failed"):
test_handler(invocation_input, lambda_context)
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"


# endregion durable_execution
Expand Down Expand Up @@ -1120,16 +1121,13 @@ def slow_background():
"aws_durable_execution_sdk_python.state.ExecutionState.checkpoint_batches_forever",
side_effect=slow_background,
):
with pytest.raises(CheckpointError, match="Checkpoint system failed"):
test_handler(invocation_input, lambda_context)

response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"

def test_durable_execution_checkpoint_invocation_error_stops_background():
"""Test that CheckpointError handler stops background checkpointing.

When user code raises CheckpointError, the handler should stop the background
thread before re-raising to terminate the Lambda.
"""
def test_durable_execution_checkpoint_invocation_error_retries():
"""Test that CheckpointError with INVOCATION category re-raises to trigger Lambda retry."""
mock_client = Mock(spec=DurableServiceClient)

@durable_execution
Expand Down Expand Up @@ -1171,13 +1169,12 @@ def slow_background():
"aws_durable_execution_sdk_python.state.ExecutionState.checkpoint_batches_forever",
side_effect=slow_background,
):
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"
with pytest.raises(CheckpointError, match="Checkpoint system failed"):
test_handler(invocation_input, lambda_context)


def test_durable_execution_background_thread_execution_error_retries():
"""Test that background thread Execution errors are retried (re-raised)."""
def test_durable_execution_background_thread_execution_error_returns_failed():
"""Test that background thread Execution errors return FAILED (permanent, no retry)."""
mock_client = Mock(spec=DurableServiceClient)

def failing_checkpoint(*args, **kwargs):
Expand Down Expand Up @@ -1215,12 +1212,13 @@ def test_handler(event: Any, context: DurableContext) -> dict:

mock_client.checkpoint.side_effect = failing_checkpoint

with pytest.raises(CheckpointError, match="Background checkpoint failed"):
test_handler(invocation_input, lambda_context)
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"


def test_durable_execution_background_thread_invocation_error_returns_failed():
"""Test that background thread Invocation errors return FAILED status."""
def test_durable_execution_background_thread_invocation_error_retries():
"""Test that background thread Invocation errors re-raise to trigger Lambda retry."""
mock_client = Mock(spec=DurableServiceClient)

def failing_checkpoint(*args, **kwargs):
Expand Down Expand Up @@ -1258,13 +1256,12 @@ def test_handler(event: Any, context: DurableContext) -> dict:

mock_client.checkpoint.side_effect = failing_checkpoint

response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"
with pytest.raises(CheckpointError, match="Background checkpoint failed"):
test_handler(invocation_input, lambda_context)


def test_durable_execution_final_success_checkpoint_execution_error_retries():
"""Test that execution errors on final success checkpoint trigger retry."""
def test_durable_execution_final_success_checkpoint_execution_error_returns_failed():
"""Test that execution errors on final success checkpoint return FAILED (permanent, no retry)."""
mock_client = Mock(spec=DurableServiceClient)

def failing_final_checkpoint(*args, **kwargs):
Expand Down Expand Up @@ -1303,12 +1300,13 @@ def test_handler(event: Any, context: DurableContext) -> dict:

mock_client.checkpoint.side_effect = failing_final_checkpoint

with pytest.raises(CheckpointError, match="Final checkpoint failed"):
test_handler(invocation_input, lambda_context)
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"


def test_durable_execution_final_success_checkpoint_invocation_error_returns_failed():
"""Test that invocation errors on final success checkpoint return FAILED."""
def test_durable_execution_final_success_checkpoint_invocation_error_retries():
"""Test that invocation errors on final success checkpoint re-raise to trigger Lambda retry."""
mock_client = Mock(spec=DurableServiceClient)

def failing_final_checkpoint(*args, **kwargs):
Expand Down Expand Up @@ -1348,14 +1346,12 @@ def test_handler(event: Any, context: DurableContext) -> dict:

mock_client.checkpoint.side_effect = failing_final_checkpoint

response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"
assert response["Error"]["ErrorMessage"] == "Final checkpoint failed"
with pytest.raises(CheckpointError, match="Final checkpoint failed"):
test_handler(invocation_input, lambda_context)


def test_durable_execution_final_failure_checkpoint_execution_error_retries():
"""Test that execution errors on final failure checkpoint trigger retry."""
def test_durable_execution_final_failure_checkpoint_execution_error_returns_failed():
"""Test that execution errors on final failure checkpoint return FAILED (permanent, no retry)."""
mock_client = Mock(spec=DurableServiceClient)

def failing_final_checkpoint(*args, **kwargs):
Expand Down Expand Up @@ -1396,12 +1392,13 @@ def test_handler(event: Any, context: DurableContext) -> dict:

mock_client.checkpoint.side_effect = failing_final_checkpoint

with pytest.raises(CheckpointError, match="Final checkpoint failed"):
test_handler(invocation_input, lambda_context)
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"


def test_durable_execution_final_failure_checkpoint_invocation_error_returns_failed():
"""Test that invocation errors on final failure checkpoint return FAILED."""
def test_durable_execution_final_failure_checkpoint_invocation_error_retries():
"""Test that invocation errors on final failure checkpoint re-raise to trigger Lambda retry."""
mock_client = Mock(spec=DurableServiceClient)

def failing_final_checkpoint(*args, **kwargs):
Expand Down Expand Up @@ -1442,10 +1439,8 @@ def test_handler(event: Any, context: DurableContext) -> dict:

mock_client.checkpoint.side_effect = failing_final_checkpoint

response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"
assert response["Error"]["ErrorMessage"] == "Final checkpoint failed"
with pytest.raises(CheckpointError, match="Final checkpoint failed"):
test_handler(invocation_input, lambda_context)


def test_durable_handler_background_thread_failure_on_succeed_checkpoint():
Expand Down Expand Up @@ -1809,8 +1804,9 @@ def test_handler(event: Any, context: DurableContext) -> dict:
mock_client.checkpoint.side_effect = failing_checkpoint

with patch("aws_durable_execution_sdk_python.execution.logger", mock_logger):
with pytest.raises(CheckpointError):
test_handler(invocation_input, lambda_context)
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"

mock_logger.exception.assert_called_once()
call_args = mock_logger.exception.call_args
Expand Down Expand Up @@ -1922,8 +1918,9 @@ def test_handler(event: Any, context: DurableContext) -> dict:
lambda_context.tenant_id = None

with patch("aws_durable_execution_sdk_python.execution.logger", mock_logger):
with pytest.raises(CheckpointError):
test_handler(invocation_input, lambda_context)
response = test_handler(invocation_input, lambda_context)
assert response["Status"] == InvocationStatus.FAILED.value
assert response["Error"]["ErrorType"] == "CheckpointError"

mock_logger.exception.assert_called_once()
call_args = mock_logger.exception.call_args
Expand Down
Loading