diff --git a/docs/advanced/error-handling.md b/docs/advanced/error-handling.md index 1123ad7..59f002a 100644 --- a/docs/advanced/error-handling.md +++ b/docs/advanced/error-handling.md @@ -98,7 +98,7 @@ The SDK provides several exception types for different failure scenarios. | `InvocationError` | Yes (by Lambda) | Lambda retries invocation | Transient infrastructure issues | | `CallbackError` | No | Returns FAILED status | Callback handling failures | | `StepInterruptedError` | Yes (automatic) | Retries on next invocation | Step interrupted before checkpoint | -| `CheckpointError` | Depends | Retries if 4xx (except invalid token) | Failed to save execution state | +| `CheckpointError` | Depends | Permanent on 4xx non-429 (except invalid checkpoint token); retries otherwise | Failed to save execution state | | `SerDesError` | No | Returns FAILED status | Serialization failures | ### Base exceptions diff --git a/src/aws_durable_execution_sdk_python/exceptions.py b/src/aws_durable_execution_sdk_python/exceptions.py index 72f0aa0..336996c 100644 --- a/src/aws_durable_execution_sdk_python/exceptions.py +++ b/src/aws_durable_execution_sdk_python/exceptions.py @@ -155,32 +155,25 @@ def from_exception(cls, exception: Exception) -> CheckpointError: error: AwsErrorObj | None = base.error error_category: CheckpointErrorCategory = CheckpointErrorCategory.INVOCATION - # InvalidParameterValueException and error message starts with "Invalid Checkpoint Token" is an InvocationError - # all other 4xx errors are Execution Errors and should be retried - # all 5xx errors are Invocation Errors + # 4xx errors (except 429) are permanent failures (EXECUTION), unless it's an + # InvalidParameterValueException with "Invalid Checkpoint Token" which is retriable (INVOCATION). + # 5xx, 429, and network errors are retriable (INVOCATION). status_code: int | None = (metadata and metadata.get("HTTPStatusCode")) or None if ( status_code - # if we are in 4xx range (except 429) and is not an InvalidParameterValueException with Invalid Checkpoint Token - # then it's an execution error - and status_code < SERVICE_ERROR - and status_code >= BAD_REQUEST_ERROR + and BAD_REQUEST_ERROR <= status_code < SERVICE_ERROR and status_code != TOO_MANY_REQUESTS_ERROR and error - and ( - # is not InvalidParam => Execution - (error.get("Code", "") or "") != "InvalidParameterValueException" - # is not Invalid Token => Execution - or not (error.get("Message") or "").startswith( - "Invalid Checkpoint Token" - ) + and not ( + (error.get("Code") or "") == "InvalidParameterValueException" + and (error.get("Message") or "").startswith("Invalid Checkpoint Token") ) ): error_category = CheckpointErrorCategory.EXECUTION return CheckpointError(str(exception), error_category, error, metadata) def is_retriable(self): - return self.error_category == CheckpointErrorCategory.EXECUTION + return self.error_category == CheckpointErrorCategory.INVOCATION class ValidationError(DurableExecutionsError): diff --git a/tests/exceptions_test.py b/tests/exceptions_test.py index f3ed213..f350b42 100644 --- a/tests/exceptions_test.py +++ b/tests/exceptions_test.py @@ -67,6 +67,23 @@ def test_checkpoint_error_classification_invalid_token_invocation(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.INVOCATION + assert result.is_retriable() + + +def test_checkpoint_error_classification_payload_size_exceeded_execution(): + """Test 4xx InvalidParameterValueException with STEP output payload size limit exceeded is execution error.""" + error_response = { + "Error": { + "Code": "InvalidParameterValueException", + "Message": "STEP output payload size must be less than or equal to 262144 bytes.", + }, + "ResponseMetadata": {"HTTPStatusCode": 400}, + } + client_error = ClientError(error_response, "Checkpoint") + + result = CheckpointError.from_exception(client_error) + + assert result.error_category == CheckpointErrorCategory.EXECUTION assert not result.is_retriable() @@ -81,7 +98,7 @@ def test_checkpoint_error_classification_other_4xx_execution(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.EXECUTION - assert result.is_retriable() + assert not result.is_retriable() def test_checkpoint_error_classification_429_invocation(): @@ -95,7 +112,7 @@ def test_checkpoint_error_classification_429_invocation(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.INVOCATION - assert not result.is_retriable() + assert result.is_retriable() def test_checkpoint_error_classification_invalid_param_without_token_execution(): @@ -112,7 +129,7 @@ def test_checkpoint_error_classification_invalid_param_without_token_execution() result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.EXECUTION - assert result.is_retriable() + assert not result.is_retriable() def test_checkpoint_error_classification_5xx_invocation(): @@ -126,7 +143,7 @@ def test_checkpoint_error_classification_5xx_invocation(): result = CheckpointError.from_exception(client_error) assert result.error_category == CheckpointErrorCategory.INVOCATION - assert not result.is_retriable() + assert result.is_retriable() def test_checkpoint_error_classification_unknown_invocation(): @@ -136,7 +153,7 @@ def test_checkpoint_error_classification_unknown_invocation(): result = CheckpointError.from_exception(unknown_error) assert result.error_category == CheckpointErrorCategory.INVOCATION - assert not result.is_retriable() + assert result.is_retriable() def test_validation_error(): diff --git a/tests/execution_test.py b/tests/execution_test.py index 27f90a4..485d400 100644 --- a/tests/execution_test.py +++ b/tests/execution_test.py @@ -1066,8 +1066,9 @@ def test_handler(event: Any, context: DurableContext) -> dict: # Make the service client checkpoint call fail with CheckpointError mock_client.checkpoint.side_effect = failing_checkpoint - with pytest.raises(CheckpointError, match="Background checkpoint failed"): - test_handler(invocation_input, lambda_context) + response = test_handler(invocation_input, lambda_context) + assert response["Status"] == InvocationStatus.FAILED.value + assert response["Error"]["ErrorType"] == "CheckpointError" # endregion durable_execution @@ -1120,16 +1121,13 @@ def slow_background(): "aws_durable_execution_sdk_python.state.ExecutionState.checkpoint_batches_forever", side_effect=slow_background, ): - with pytest.raises(CheckpointError, match="Checkpoint system failed"): - test_handler(invocation_input, lambda_context) - + response = test_handler(invocation_input, lambda_context) + assert response["Status"] == InvocationStatus.FAILED.value + assert response["Error"]["ErrorType"] == "CheckpointError" -def test_durable_execution_checkpoint_invocation_error_stops_background(): - """Test that CheckpointError handler stops background checkpointing. - When user code raises CheckpointError, the handler should stop the background - thread before re-raising to terminate the Lambda. - """ +def test_durable_execution_checkpoint_invocation_error_retries(): + """Test that CheckpointError with INVOCATION category re-raises to trigger Lambda retry.""" mock_client = Mock(spec=DurableServiceClient) @durable_execution @@ -1171,13 +1169,12 @@ def slow_background(): "aws_durable_execution_sdk_python.state.ExecutionState.checkpoint_batches_forever", side_effect=slow_background, ): - response = test_handler(invocation_input, lambda_context) - assert response["Status"] == InvocationStatus.FAILED.value - assert response["Error"]["ErrorType"] == "CheckpointError" + with pytest.raises(CheckpointError, match="Checkpoint system failed"): + test_handler(invocation_input, lambda_context) -def test_durable_execution_background_thread_execution_error_retries(): - """Test that background thread Execution errors are retried (re-raised).""" +def test_durable_execution_background_thread_execution_error_returns_failed(): + """Test that background thread Execution errors return FAILED (permanent, no retry).""" mock_client = Mock(spec=DurableServiceClient) def failing_checkpoint(*args, **kwargs): @@ -1215,12 +1212,13 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_client.checkpoint.side_effect = failing_checkpoint - with pytest.raises(CheckpointError, match="Background checkpoint failed"): - test_handler(invocation_input, lambda_context) + response = test_handler(invocation_input, lambda_context) + assert response["Status"] == InvocationStatus.FAILED.value + assert response["Error"]["ErrorType"] == "CheckpointError" -def test_durable_execution_background_thread_invocation_error_returns_failed(): - """Test that background thread Invocation errors return FAILED status.""" +def test_durable_execution_background_thread_invocation_error_retries(): + """Test that background thread Invocation errors re-raise to trigger Lambda retry.""" mock_client = Mock(spec=DurableServiceClient) def failing_checkpoint(*args, **kwargs): @@ -1258,13 +1256,12 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_client.checkpoint.side_effect = failing_checkpoint - response = test_handler(invocation_input, lambda_context) - assert response["Status"] == InvocationStatus.FAILED.value - assert response["Error"]["ErrorType"] == "CheckpointError" + with pytest.raises(CheckpointError, match="Background checkpoint failed"): + test_handler(invocation_input, lambda_context) -def test_durable_execution_final_success_checkpoint_execution_error_retries(): - """Test that execution errors on final success checkpoint trigger retry.""" +def test_durable_execution_final_success_checkpoint_execution_error_returns_failed(): + """Test that execution errors on final success checkpoint return FAILED (permanent, no retry).""" mock_client = Mock(spec=DurableServiceClient) def failing_final_checkpoint(*args, **kwargs): @@ -1303,12 +1300,13 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_client.checkpoint.side_effect = failing_final_checkpoint - with pytest.raises(CheckpointError, match="Final checkpoint failed"): - test_handler(invocation_input, lambda_context) + response = test_handler(invocation_input, lambda_context) + assert response["Status"] == InvocationStatus.FAILED.value + assert response["Error"]["ErrorType"] == "CheckpointError" -def test_durable_execution_final_success_checkpoint_invocation_error_returns_failed(): - """Test that invocation errors on final success checkpoint return FAILED.""" +def test_durable_execution_final_success_checkpoint_invocation_error_retries(): + """Test that invocation errors on final success checkpoint re-raise to trigger Lambda retry.""" mock_client = Mock(spec=DurableServiceClient) def failing_final_checkpoint(*args, **kwargs): @@ -1348,14 +1346,12 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_client.checkpoint.side_effect = failing_final_checkpoint - response = test_handler(invocation_input, lambda_context) - assert response["Status"] == InvocationStatus.FAILED.value - assert response["Error"]["ErrorType"] == "CheckpointError" - assert response["Error"]["ErrorMessage"] == "Final checkpoint failed" + with pytest.raises(CheckpointError, match="Final checkpoint failed"): + test_handler(invocation_input, lambda_context) -def test_durable_execution_final_failure_checkpoint_execution_error_retries(): - """Test that execution errors on final failure checkpoint trigger retry.""" +def test_durable_execution_final_failure_checkpoint_execution_error_returns_failed(): + """Test that execution errors on final failure checkpoint return FAILED (permanent, no retry).""" mock_client = Mock(spec=DurableServiceClient) def failing_final_checkpoint(*args, **kwargs): @@ -1396,12 +1392,13 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_client.checkpoint.side_effect = failing_final_checkpoint - with pytest.raises(CheckpointError, match="Final checkpoint failed"): - test_handler(invocation_input, lambda_context) + response = test_handler(invocation_input, lambda_context) + assert response["Status"] == InvocationStatus.FAILED.value + assert response["Error"]["ErrorType"] == "CheckpointError" -def test_durable_execution_final_failure_checkpoint_invocation_error_returns_failed(): - """Test that invocation errors on final failure checkpoint return FAILED.""" +def test_durable_execution_final_failure_checkpoint_invocation_error_retries(): + """Test that invocation errors on final failure checkpoint re-raise to trigger Lambda retry.""" mock_client = Mock(spec=DurableServiceClient) def failing_final_checkpoint(*args, **kwargs): @@ -1442,10 +1439,8 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_client.checkpoint.side_effect = failing_final_checkpoint - response = test_handler(invocation_input, lambda_context) - assert response["Status"] == InvocationStatus.FAILED.value - assert response["Error"]["ErrorType"] == "CheckpointError" - assert response["Error"]["ErrorMessage"] == "Final checkpoint failed" + with pytest.raises(CheckpointError, match="Final checkpoint failed"): + test_handler(invocation_input, lambda_context) def test_durable_handler_background_thread_failure_on_succeed_checkpoint(): @@ -1809,8 +1804,9 @@ def test_handler(event: Any, context: DurableContext) -> dict: mock_client.checkpoint.side_effect = failing_checkpoint with patch("aws_durable_execution_sdk_python.execution.logger", mock_logger): - with pytest.raises(CheckpointError): - test_handler(invocation_input, lambda_context) + response = test_handler(invocation_input, lambda_context) + assert response["Status"] == InvocationStatus.FAILED.value + assert response["Error"]["ErrorType"] == "CheckpointError" mock_logger.exception.assert_called_once() call_args = mock_logger.exception.call_args @@ -1922,8 +1918,9 @@ def test_handler(event: Any, context: DurableContext) -> dict: lambda_context.tenant_id = None with patch("aws_durable_execution_sdk_python.execution.logger", mock_logger): - with pytest.raises(CheckpointError): - test_handler(invocation_input, lambda_context) + response = test_handler(invocation_input, lambda_context) + assert response["Status"] == InvocationStatus.FAILED.value + assert response["Error"]["ErrorType"] == "CheckpointError" mock_logger.exception.assert_called_once() call_args = mock_logger.exception.call_args