diff --git a/sql/schema.sql b/sql/schema.sql index 5109ab7..a0b8d66 100644 --- a/sql/schema.sql +++ b/sql/schema.sql @@ -857,12 +857,16 @@ begin end; $$; +--- Marks a run as failed. +--- If p_force_fail is true, then the retry policy and `p_retry_at` are ignored, +--- and the task is immediately failed (as though it had reached the max retries) create function durable.fail_run ( p_queue_name text, p_task_id uuid, p_run_id uuid, p_reason jsonb, - p_retry_at timestamptz default null + p_retry_at timestamptz default null, + p_force_fail boolean default false ) returns void language plpgsql @@ -941,8 +945,8 @@ begin v_task_state_after := 'failed'; v_recorded_attempt := v_attempt; - -- Compute the next retry time - if v_max_attempts is null or v_next_attempt <= v_max_attempts then + -- Compute the next retry time, unless we're forcing a failure + if (not p_force_fail) and (v_max_attempts is null or v_next_attempt <= v_max_attempts) then if p_retry_at is not null then v_next_available := p_retry_at; else diff --git a/src/context.rs b/src/context.rs index 1827eae..09bf129 100644 --- a/src/context.rs +++ b/src/context.rs @@ -186,7 +186,13 @@ where } // Execute the step - let result = f(params, self.durable.state().clone()).await?; + let result = + f(params, self.durable.state().clone()) + .await + .map_err(|e| TaskError::Step { + base_name: base_name.to_string(), + error: e, + })?; // Persist checkpoint (also extends claim lease) #[cfg(feature = "telemetry")] @@ -262,7 +268,8 @@ where .bind(self.run_id) .bind(self.claim_timeout.as_secs() as i32) .execute(self.durable.pool()) - .await?; + .await + .map_err(TaskError::from_sqlx_error)?; self.checkpoint_cache.insert(name.to_string(), state_json); @@ -301,7 +308,8 @@ where .bind(&checkpoint_name) .bind(duration_ms) .fetch_one(self.durable.pool()) - .await?; + .await + .map_err(TaskError::from_sqlx_error)?; if needs_suspend { return Err(TaskError::Control(ControlFlow::Suspend)); @@ -379,7 +387,8 @@ where .bind(event_name) 
.bind(timeout_secs) .fetch_one(self.durable.pool()) - .await?; + .await + .map_err(TaskError::from_sqlx_error)?; if result.should_suspend { return Err(TaskError::Control(ControlFlow::Suspend)); } @@ -480,7 +489,8 @@ where .bind(self.run_id) .bind(extend_by.as_secs() as i32) .execute(self.durable.pool()) - .await?; + .await + .map_err(TaskError::from_sqlx_error)?; // Notify worker that lease was extended so it can reset timers self.lease_extender.notify(extend_by); @@ -743,7 +753,8 @@ where .bind(&event_name) .bind(None::<i32>) // No timeout .fetch_one(self.durable.pool()) - .await?; + .await + .map_err(TaskError::from_sqlx_error)?; if result.should_suspend { return Err(TaskError::Control(ControlFlow::Suspend)); } diff --git a/src/error.rs b/src/error.rs index baeaff6..b5a69dc 100644 --- a/src/error.rs +++ b/src/error.rs @@ -128,12 +128,41 @@ pub enum TaskError { error_data: JsonValue, }, - /// An internal error from user task code. - /// - /// This is the catch-all variant for errors propagated via `?` on anyhow errors. - /// For structured user errors, prefer using [`TaskError::user()`]. - #[error(transparent)] - TaskInternal(#[from] anyhow::Error), + /// The user callback provided to `step` failed. + /// We treat this as a non-deterministic error, and will retry the task + #[error("user step `{base_name}` failed: {error}")] + Step { + base_name: String, + error: anyhow::Error, + }, + + /// The task panicked. + #[error("task panicked: {message}")] + TaskPanicked { + /// The error message from the task. + message: String, + }, +} + +impl TaskError { + pub fn retryable(&self) -> bool { + match self { + // These are non-deterministic errors, which might succeed on a retry + // (which will have the same checkpoint cache up to the point of the error) + TaskError::Timeout { .. } | TaskError::Database(_) | TaskError::Step { .. } => true, + // Everything else is considered to be a deterministic error, which will fail again + // on a retry + TaskError::SubtaskSpawnFailed { .. 
} + | TaskError::EmitEventFailed { .. } + | TaskError::Control(_) + | TaskError::Serialization(_) + | TaskError::ChildFailed { .. } + | TaskError::ChildCancelled { .. } + | TaskError::Validation { .. } + | TaskError::User { .. } + | TaskError::TaskPanicked { .. } => false, + } + } } /// Result type alias for task execution. @@ -183,8 +212,10 @@ impl From for TaskError { } } -impl From<sqlx::Error> for TaskError { - fn from(err: sqlx::Error) -> Self { +impl TaskError { + // This is explicitly *not* a `From<sqlx::Error> for TaskError` impl, + // because we don't want user code to be performing database queries directly. + pub(crate) fn from_sqlx_error(err: sqlx::Error) -> Self { if is_cancelled_error(&err) { TaskError::Control(ControlFlow::Cancelled) } else { @@ -275,11 +306,17 @@ pub fn serialize_task_error(err: &TaskError) -> JsonValue { "error_data": error_data, }) } - TaskError::TaskInternal(e) => { + TaskError::Step { base_name, error } => { serde_json::json!({ - "name": "TaskInternal", - "message": e.to_string(), - "backtrace": format!("{:?}", e) + "name": "Step", + "base_name": base_name, + "message": error.to_string(), + }) + } + TaskError::TaskPanicked { message } => { + serde_json::json!({ + "name": "TaskPanicked", + "message": message, }) } } diff --git a/src/postgres/migrations/20260129181016_add_force_fail_to_fail_run.sql b/src/postgres/migrations/20260129181016_add_force_fail_to_fail_run.sql new file mode 100644 index 0000000..e2e859a --- /dev/null +++ b/src/postgres/migrations/20260129181016_add_force_fail_to_fail_run.sql @@ -0,0 +1,179 @@ +-- Add p_force_fail parameter to durable.fail_run function +-- When p_force_fail is true, the retry policy and p_retry_at are ignored, +-- and the task is immediately failed (as though it had reached the max retries) + +drop function if exists durable.fail_run(text, uuid, uuid, jsonb, timestamptz); + +create function durable.fail_run ( + p_queue_name text, + p_task_id uuid, + p_run_id uuid, + p_reason jsonb, + p_retry_at timestamptz default 
null, + p_force_fail boolean default false +) + returns void + language plpgsql +as $$ +declare + v_run_task_id uuid; + v_attempt integer; + v_retry_strategy jsonb; + v_max_attempts integer; + v_now timestamptz := durable.current_time(); + v_next_attempt integer; + v_delay_seconds double precision := 0; + v_next_available timestamptz; + v_retry_kind text; + v_base double precision; + v_factor double precision; + v_max_seconds double precision; + v_first_started timestamptz; + v_cancellation jsonb; + v_max_duration bigint; + v_task_state text; + v_task_cancel boolean := false; + v_new_run_id uuid; + v_task_state_after text; + v_recorded_attempt integer; + v_last_attempt_run uuid := p_run_id; + v_cancelled_at timestamptz := null; +begin + -- Lock task first to keep a consistent task -> run lock order. + execute format( + 'select retry_strategy, max_attempts, first_started_at, cancellation, state + from durable.%I + where task_id = $1 + for update', + 't_' || p_queue_name + ) + into v_retry_strategy, v_max_attempts, v_first_started, v_cancellation, v_task_state + using p_task_id; + + if v_task_state is null then + raise exception 'Task "%" not found in queue "%"', p_task_id, p_queue_name; + end if; + + -- Lock run after task and ensure it's still eligible + execute format( + 'select task_id, attempt + from durable.%I + where run_id = $1 + and state in (''running'', ''sleeping'') + for update', + 'r_' || p_queue_name + ) + into v_run_task_id, v_attempt + using p_run_id; + + if v_run_task_id is null then + raise exception 'Run "%" cannot be failed in queue "%"', p_run_id, p_queue_name; + end if; + + if v_run_task_id <> p_task_id then + raise exception 'Run "%" does not belong to task "%"', p_run_id, p_task_id; + end if; + + -- Actually fail the run + execute format( + 'update durable.%I + set state = ''failed'', + wake_event = null, + failed_at = $2, + failure_reason = $3 + where run_id = $1', + 'r_' || p_queue_name + ) using p_run_id, v_now, p_reason; + + 
v_next_attempt := v_attempt + 1; + v_task_state_after := 'failed'; + v_recorded_attempt := v_attempt; + + -- Compute the next retry time, unless we're forcing a failure + if (not p_force_fail) and (v_max_attempts is null or v_next_attempt <= v_max_attempts) then + if p_retry_at is not null then + v_next_available := p_retry_at; + else + v_retry_kind := coalesce(v_retry_strategy->>'kind', 'none'); + if v_retry_kind = 'fixed' then + v_base := coalesce((v_retry_strategy->>'base_seconds')::double precision, 60); + v_delay_seconds := v_base; + elsif v_retry_kind = 'exponential' then + v_base := coalesce((v_retry_strategy->>'base_seconds')::double precision, 30); + v_factor := coalesce((v_retry_strategy->>'factor')::double precision, 2); + v_delay_seconds := v_base * power(v_factor, greatest(v_attempt - 1, 0)); + v_max_seconds := (v_retry_strategy->>'max_seconds')::double precision; + if v_max_seconds is not null then + v_delay_seconds := least(v_delay_seconds, v_max_seconds); + end if; + else + v_delay_seconds := 0; + end if; + v_next_available := v_now + (v_delay_seconds * interval '1 second'); + end if; + + if v_next_available < v_now then + v_next_available := v_now; + end if; + + if v_cancellation is not null then + v_max_duration := (v_cancellation->>'max_duration')::bigint; + if v_max_duration is not null and v_first_started is not null then + if extract(epoch from (v_next_available - v_first_started)) >= v_max_duration then + v_task_cancel := true; + end if; + end if; + end if; + + -- Set up the new run if not cancelling + if not v_task_cancel then + v_task_state_after := case when v_next_available > v_now then 'sleeping' else 'pending' end; + v_new_run_id := durable.portable_uuidv7(); + v_recorded_attempt := v_next_attempt; + v_last_attempt_run := v_new_run_id; + execute format( + 'insert into durable.%I (run_id, task_id, attempt, state, available_at, wake_event, event_payload, result, failure_reason) + values ($1, $2, $3, %L, $4, null, null, null, null)', + 
'r_' || p_queue_name, + v_task_state_after + ) + using v_new_run_id, p_task_id, v_next_attempt, v_next_available; + end if; + end if; + + if v_task_cancel then + v_task_state_after := 'cancelled'; + v_cancelled_at := v_now; + v_recorded_attempt := greatest(v_recorded_attempt, v_attempt); + v_last_attempt_run := p_run_id; + end if; + + execute format( + 'update durable.%I + set state = %L, + attempts = greatest(attempts, $3), + last_attempt_run = $4, + cancelled_at = coalesce(cancelled_at, $5) + where task_id = $1', + 't_' || p_queue_name, + v_task_state_after + ) using p_task_id, v_task_state_after, v_recorded_attempt, v_last_attempt_run, v_cancelled_at; + + -- Delete wait registrations for this run + execute format( + 'delete from durable.%I where run_id = $1', + 'w_' || p_queue_name + ) using p_run_id; + + -- If task reached terminal state, cleanup (emit event, cascade cancel) + if v_task_state_after in ('failed', 'cancelled') then + perform durable.cleanup_task_terminal( + p_queue_name, + p_task_id, + v_task_state_after, + jsonb_build_object('error', p_reason), + true -- cascade cancel children + ); + end if; +end; +$$; diff --git a/src/worker.rs b/src/worker.rs index f7742af..0da6563 100644 --- a/src/worker.rs +++ b/src/worker.rs @@ -309,7 +309,7 @@ impl Worker { durable.queue_name(), task.task_id, task.run_id, - &e.into(), + &TaskError::from_sqlx_error(e), ) .await; return; @@ -412,8 +412,9 @@ impl Worker { Ok(r) => Some(r), Err(e) if e.is_cancelled() => None, // Task was aborted Err(e) => { - tracing::error!("Task {} panicked: {}", task_label, e); - Some(Err(TaskError::TaskInternal(anyhow::anyhow!("Task panicked: {e}")))) + let message = format!("Task {} panicked: {}", task_label, e); + tracing::error!("{}", message); + Some(Err(TaskError::TaskPanicked { message })) } } } @@ -563,13 +564,15 @@ impl Worker { error: &TaskError, ) { let error_json = serialize_task_error(error); - let query = "SELECT durable.fail_run($1, $2, $3, $4, $5)"; + let query = "SELECT 
durable.fail_run($1, $2, $3, $4, $5, $6)"; + let force_fail = !error.retryable(); if let Err(e) = sqlx::query(query) .bind(queue_name) .bind(task_id) .bind(run_id) .bind(&error_json) .bind(None::<chrono::DateTime<chrono::Utc>>) + .bind(force_fail) .execute(pool) .await { diff --git a/tests/checkpoint_test.rs b/tests/checkpoint_test.rs index 3d43186..b05e890 100644 --- a/tests/checkpoint_test.rs +++ b/tests/checkpoint_test.rs @@ -116,12 +116,12 @@ async fn test_checkpoint_prevents_step_reexecution(pool: PgPool) -> sqlx::Result assert_eq!(terminal2, Some("completed".to_string())); - // Should have 3 checkpoints + // Should have 4 checkpoints (step1, step2, maybe_fail, step3) let checkpoint_count2 = get_checkpoint_count(&pool, "ckpt_replay2", spawn_result2.task_id).await?; assert_eq!( - checkpoint_count2, 3, - "Should have 3 checkpoints for all steps" + checkpoint_count2, 4, + "Should have 4 checkpoints for all steps" ); Ok(()) diff --git a/tests/common/tasks.rs b/tests/common/tasks.rs index c760866..9b84b7c 100644 --- a/tests/common/tasks.rs +++ b/tests/common/tasks.rs @@ -128,13 +128,53 @@ impl Task<()> for FailingTask { type Params = FailingParams; type Output = (); + async fn run( + &self, + params: Self::Params, + mut ctx: TaskContext, + _state: (), + ) -> TaskResult { + ctx.step( + "fail_task", + params.error_message, + |error_message, _| async move { + Err(anyhow::anyhow!("Intentional failure: {}", error_message)) + }, + ) + .await + } +} + +// ============================================================================ +// UserErrorTask - Task that emits a UserError (non-retryable) +// ============================================================================ + +#[allow(dead_code)] +#[derive(Default)] +pub struct UserErrorTask; + +#[allow(dead_code)] +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct UserErrorParams { + pub error_message: String, +} + +#[async_trait] +impl Task<()> for UserErrorTask { + fn name(&self) -> Cow<'static, str> { + Cow::Borrowed("user-error") + }
type Params = UserErrorParams; + type Output = (); + async fn run( &self, params: Self::Params, _ctx: TaskContext, _state: (), ) -> TaskResult { - Err(TaskError::user_message(params.error_message.to_string())) + // Emit a UserError directly - this should not be retried + Err(TaskError::user_message(params.error_message)) } } @@ -308,11 +348,17 @@ impl Task<()> for StepCountingTask { }) .await?; - if params.fail_after_step2 { - return Err(TaskError::user_message( - "Intentional failure after step2".to_string(), - )); - } + ctx.step( + "maybe_fail", + params.fail_after_step2, + |fail_after_step2, _| async move { + if fail_after_step2 { + return Err(anyhow::anyhow!("Intentional failure after step2")); + } + Ok(()) + }, + ) + .await?; let step3_value: String = ctx .step("step3", (), |_, _| async { @@ -1120,9 +1166,14 @@ impl Task<()> for DeterministicReplayTask { } }); - if should_fail { - return Err(TaskError::user_message("First attempt failure".to_string())); - } + // Fail within a step so that the task gets retried + ctx.step("maybe_fail", should_fail, |should_fail, _| async move { + if should_fail { + return Err(anyhow::anyhow!("First attempt failure")); + } + Ok(()) + }) + .await?; } Ok(DeterministicReplayOutput { @@ -1184,11 +1235,13 @@ impl Task<()> for EventThenFailTask { } }); - if should_fail { - return Err(TaskError::user_message( - "First attempt failure after event".to_string(), - )); - } + ctx.step("maybe_fail", should_fail, |should_fail, _| async move { + if should_fail { + return Err(anyhow::anyhow!("First attempt failure")); + } + Ok(()) + }) + .await?; // Second attempt succeeds with the same payload (from checkpoint) Ok(payload) diff --git a/tests/crash_test.rs b/tests/crash_test.rs index bf186b0..7107f75 100644 --- a/tests/crash_test.rs +++ b/tests/crash_test.rs @@ -415,9 +415,9 @@ async fn test_step_idempotency_after_retry(pool: PgPool) -> sqlx::Result<()> { assert_eq!(terminal, Some("completed".to_string())); - // Verify exactly 3 checkpoints 
(step1, step2, step3) + // Verify exactly 4 checkpoints (step1, step2, maybe_fail, step3) let checkpoint_count = get_checkpoint_count(&pool, "crash_step", spawn_result.task_id).await?; - assert_eq!(checkpoint_count, 3, "Should have exactly 3 checkpoints"); + assert_eq!(checkpoint_count, 4, "Should have exactly 4 checkpoints"); // Verify each checkpoint has unique name (no duplicates) let query = AssertSqlSafe( @@ -429,7 +429,7 @@ async fn test_step_idempotency_after_retry(pool: PgPool) -> sqlx::Result<()> { .fetch_one(&pool) .await?; - assert_eq!(distinct_count, 3, "Each checkpoint name should be unique"); + assert_eq!(distinct_count, 4, "Each checkpoint name should be unique"); Ok(()) } diff --git a/tests/fanout_test.rs b/tests/fanout_test.rs index 37b5036..6312690 100644 --- a/tests/fanout_test.rs +++ b/tests/fanout_test.rs @@ -833,8 +833,7 @@ async fn test_join_timeout_when_parent_claim_expires(pool: PgPool) -> sqlx::Resu assert!( error_name == Some("Timeout") || error_name == Some("ChildCancelled") - || error_name == Some("ChildFailed") - || error_name == Some("TaskInternal"), + || error_name == Some("ChildFailed"), "Expected timeout-related error, got: {:?}", error_name ); diff --git a/tests/retry_test.rs b/tests/retry_test.rs index 937e5d6..63717d8 100644 --- a/tests/retry_test.rs +++ b/tests/retry_test.rs @@ -3,7 +3,7 @@ mod common; use common::helpers::{advance_time, count_runs_for_task, set_fake_time, wait_for_task_terminal}; -use common::tasks::{FailingParams, FailingTask}; +use common::tasks::{FailingParams, FailingTask, UserErrorParams, UserErrorTask}; use durable::{Durable, MIGRATOR, RetryStrategy, SpawnOptions, WorkerOptions}; use sqlx::{AssertSqlSafe, PgPool}; use std::time::Duration; @@ -301,3 +301,71 @@ async fn test_max_attempts_honored(pool: PgPool) -> sqlx::Result<()> { Ok(()) } + +/// Test that a UserError (non-retryable error) does not get retried even with retry strategy. 
+#[sqlx::test(migrator = "MIGRATOR")] +async fn test_user_error_not_retried(pool: PgPool) -> sqlx::Result<()> { + let client = create_client(pool.clone(), "user_error_no_retry").await; + client.create_queue(None).await.unwrap(); + client.register::<UserErrorTask>().await.unwrap(); + + // Spawn task WITH retry strategy configured - but UserError should still not retry + let spawn_result = client + .spawn_with_options::<UserErrorTask>( + UserErrorParams { + error_message: "User error - should not retry".to_string(), + }, + { + let mut opts = SpawnOptions::default(); + // Configure retry strategy that would normally allow retries + opts.retry_strategy = Some(RetryStrategy::Fixed { + base_delay: Duration::from_secs(0), + }); + opts.max_attempts = Some(5); // Allow up to 5 attempts + opts + }, + ) + .await + .expect("Failed to spawn task"); + + let worker = client + .start_worker(WorkerOptions { + poll_interval: Duration::from_millis(50), + claim_timeout: Duration::from_secs(30), + ..Default::default() + }) + .await + .unwrap(); + + let terminal = wait_for_task_terminal( + &pool, + "user_error_no_retry", + spawn_result.task_id, + Duration::from_secs(5), + ) + .await?; + worker.shutdown().await; + + assert_eq!(terminal, Some("failed".to_string())); + + // Verify only 1 run was created - UserError should NOT trigger retry + let run_count = count_runs_for_task(&pool, "user_error_no_retry", spawn_result.task_id).await?; + assert_eq!( + run_count, 1, + "UserError should not be retried - expected 1 run, got {}", + run_count + ); + + // Verify the task's attempts counter is 1 + let query = AssertSqlSafe( + "SELECT attempts FROM durable.t_user_error_no_retry WHERE task_id = $1".to_string(), + ); + let result: (i32,) = sqlx::query_as(query) + .bind(spawn_result.task_id) + .fetch_one(&pool) + .await?; + + assert_eq!(result.0, 1, "Task should show only 1 attempt"); + + Ok(()) +}