//! Workflow runtime (taquba_workflow/runtime.rs): drives multi-step runs on
//! top of a Taquba queue, one durable job per step.

1use std::collections::HashMap;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5
6use taquba::{EnqueueOptions, JobRecord, PermanentFailure, Queue, Worker, WorkerError};
7use tokio::sync::Mutex;
8use tokio_util::sync::CancellationToken;
9use tracing::{debug, instrument, warn};
10
11use crate::error::{Error, Result};
12use crate::runner::{Step, StepError, StepErrorKind, StepOutcome, StepRunner};
13use crate::terminal::{RunOutcome, TerminalHook, TerminalStatus};
14
/// Header key carrying the run identifier on every step job.
pub const HEADER_RUN_ID: &str = "workflow.run_id";
/// Header key carrying the zero-based step number on every step job.
pub const HEADER_STEP: &str = "workflow.step";
/// Reserved prefix the runtime owns on step-job headers. Submitter-supplied
/// headers must not start with this prefix; if they do, the runtime treats
/// them as its own and strips them before invoking the runner.
pub const RESERVED_HEADER_PREFIX: &str = "workflow.";

/// Prefix of the per-step Taquba dedup key, `run:{run_id}:{step_number}`
/// (derived in `RuntimeInner::enqueue_step`).
const DEDUP_PREFIX: &str = "run:";
25
/// Per-step enqueue options the runtime forwards through to Taquba. The
/// runtime always owns `headers` (it injects [`HEADER_RUN_ID`] and
/// [`HEADER_STEP`]) and `dedup_key` (it derives one from
/// `(run_id, step_number)`), so callers only pick the three fields below.
///
/// `Default` means "claimable immediately, queue defaults for priority and
/// attempts".
#[derive(Debug, Default)]
struct StepEnqueueOpts {
    /// Earliest claimable time for the step. `None` means immediate.
    run_at: Option<SystemTime>,
    /// Per-step priority override.
    priority: Option<u32>,
    /// Per-step `max_attempts` override.
    max_attempts: Option<u32>,
}
39
/// Spec passed to [`WorkflowRuntime::submit`].
///
/// `Default` yields an empty-input run with a generated `run_id` and all
/// queue defaults.
#[derive(Debug, Clone, Default)]
pub struct RunSpec {
    /// Caller-supplied run identifier. If `None`, the runtime generates a
    /// ULID. The dedup key for the first step job is `run:{run_id}:0`, so
    /// re-submitting the same `run_id` while the run is active returns the
    /// existing job rather than creating a duplicate.
    pub run_id: Option<String>,
    /// Bytes handed to the runner as the first step's payload.
    pub input: Vec<u8>,
    /// Submitter-supplied metadata, threaded through every step of the run
    /// and surfaced to the terminal hook. Reserved `workflow.*` keys are
    /// rejected at submission with [`Error::ReservedHeaderInSubmit`].
    pub headers: HashMap<String, String>,
    /// Override the queue's default priority for every step of this run.
    pub priority: Option<u32>,
    /// Override the queue's `max_attempts` for every step of this run.
    pub max_attempts_per_step: Option<u32>,
}
59
/// Returned by [`WorkflowRuntime::submit`]. Purely informational; holding
/// or dropping it has no effect on the run.
#[derive(Debug, Clone)]
pub struct RunHandle {
    /// The run's identifier (generated if the spec didn't carry one).
    pub run_id: String,
    /// Taquba job ID of the first enqueued step.
    pub first_job_id: String,
}
68
/// In-memory status snapshot for an active run. Returned by
/// [`WorkflowRuntime::status`]. Terminal runs are not retained; once the
/// terminal hook fires, the registry entry is removed.
///
/// This is a point-in-time copy — it does not update as the run advances.
#[derive(Debug, Clone)]
pub struct RunStatus {
    /// The run's identifier.
    pub run_id: String,
    /// Lifecycle state of the run within this runtime process.
    pub state: RunState,
    /// Step number of the most recently observed step.
    pub current_step: u32,
}
81
/// Lifecycle state tracked in [`RunStatus::state`].
///
/// Marked `#[non_exhaustive]`: downstream matches must keep a wildcard arm
/// so new states can be added without a breaking change.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RunState {
    /// A step job exists in the queue but has not yet been claimed.
    Pending,
    /// A step is currently being processed by a worker.
    Running,
    /// [`WorkflowRuntime::cancel`] was called for this run and the
    /// terminal hook has not yet fired. Reported until the in-flight
    /// step returns and the runtime settles the run as
    /// [`crate::TerminalStatus::Cancelled`] (entry removed and hook
    /// fired); after that, [`WorkflowRuntime::status`] returns `None`.
    ///
    /// Only set by external cancellation. A pure runner-issued
    /// [`crate::StepOutcome::Cancel`] (with no external `cancel()`
    /// call) terminates as `Cancelled` without ever transitioning
    /// through `Cancelling`: the registry only learns the runner's
    /// verdict when `run_step` returns, at which point the entry is
    /// removed.
    Cancelling,
}
104
/// Builder for [`WorkflowRuntime`].
///
/// Construct via [`WorkflowRuntime::builder`], which takes the three required
/// fields (queue, runner, terminal hook) directly so missing-required-field
/// errors are caught at compile time rather than at `build()`.
pub struct WorkflowRuntimeBuilder<R, H> {
    /// Taquba queue handle step jobs are enqueued onto (required).
    queue: Arc<Queue>,
    /// Queue name; defaults to `"workflow-steps"`.
    queue_name: String,
    /// Step execution logic (required).
    runner: R,
    /// Callback fired once per run on termination (required).
    terminal_hook: H,
    /// Worker-loop parallelism; defaults to 16.
    max_concurrent_steps: usize,
    /// Empty-queue re-check interval; defaults to 250ms.
    poll_interval: Duration,
}
118
119impl<R: StepRunner, H: TerminalHook> WorkflowRuntimeBuilder<R, H> {
120    /// The Taquba queue name that step jobs are enqueued onto. Defaults to
121    /// `"workflow-steps"`. Multiple runtimes can share a `Queue` handle by
122    /// using distinct queue names.
123    pub fn queue_name(mut self, name: impl Into<String>) -> Self {
124        self.queue_name = name.into();
125        self
126    }
127
128    /// Maximum number of steps processed concurrently in [`WorkflowRuntime::run`].
129    /// Defaults to 16.
130    pub fn max_concurrent_steps(mut self, n: usize) -> Self {
131        assert!(n > 0, "max_concurrent_steps must be at least 1");
132        self.max_concurrent_steps = n;
133        self
134    }
135
136    /// Maximum time the worker loop waits on an empty queue before re-checking.
137    /// Defaults to 250ms.
138    pub fn poll_interval(mut self, interval: Duration) -> Self {
139        self.poll_interval = interval;
140        self
141    }
142
143    /// Finalize the builder.
144    pub fn build(self) -> WorkflowRuntime<R, H> {
145        let inner = RuntimeInner {
146            queue: self.queue,
147            queue_name: self.queue_name,
148            runner: self.runner,
149            terminal_hook: self.terminal_hook,
150            max_concurrent_steps: self.max_concurrent_steps,
151            poll_interval: self.poll_interval,
152            registry: Mutex::new(HashMap::new()),
153        };
154        WorkflowRuntime {
155            inner: Arc::new(inner),
156        }
157    }
158}
159
/// Durable runtime for workflow runs. Cheap to clone (internally `Arc`).
pub struct WorkflowRuntime<R, H> {
    /// Shared state; every clone observes the same registry and queue.
    inner: Arc<RuntimeInner<R, H>>,
}
164
165impl<R, H> Clone for WorkflowRuntime<R, H> {
166    fn clone(&self) -> Self {
167        Self {
168            inner: self.inner.clone(),
169        }
170    }
171}
172
/// Shared state behind [`WorkflowRuntime`] and [`StepWorker`]: queue handle,
/// user-supplied runner/hook, worker-loop tuning, and the in-process
/// registry of active runs.
struct RuntimeInner<R, H> {
    queue: Arc<Queue>,
    /// Taquba queue name step jobs are enqueued onto.
    queue_name: String,
    runner: R,
    terminal_hook: H,
    max_concurrent_steps: usize,
    poll_interval: Duration,
    /// Active runs keyed by `run_id`; entries are removed on termination.
    registry: Mutex<HashMap<String, RegistryEntry>>,
}
182
/// Per-active-run state retained by the runtime. Combines the publicly
/// observable [`RunStatus`] with the in-process state needed to resolve
/// [`WorkflowRuntime::cancel`] races: the Taquba job currently
/// representing the run (so `cancel` can target it), the submitter's
/// headers (so the terminal hook fires with the right metadata even when
/// `cancel` fires it directly from a pending step), a flag for any
/// pending cancellation request, and a [`CancellationToken`] cloned into
/// the in-flight [`Step`] so runners can short-circuit cooperatively.
struct RegistryEntry {
    /// Snapshot returned by [`WorkflowRuntime::status`].
    status: RunStatus,
    /// Taquba job currently representing the run; `cancel` targets this.
    current_job_id: String,
    /// Submitter-supplied metadata, replayed on every step and the hook.
    user_headers: HashMap<String, String>,
    /// Set by [`WorkflowRuntime::cancel`]; read by the worker after `run_step`.
    cancel_requested: bool,
    /// Cloned into each in-flight [`Step`] for cooperative cancellation.
    cancel_token: CancellationToken,
}
198
199impl<R: StepRunner, H: TerminalHook> WorkflowRuntime<R, H> {
200    /// Start configuring a runtime. Takes the three required dependencies
201    /// (Taquba queue, [`StepRunner`], [`TerminalHook`]); optional fields are
202    /// set via [`WorkflowRuntimeBuilder`] methods before [`build`].
203    ///
204    /// Use [`crate::NoopTerminalHook`] if you don't need terminal callbacks.
205    ///
206    /// [`build`]: WorkflowRuntimeBuilder::build
207    pub fn builder(queue: Arc<Queue>, runner: R, terminal_hook: H) -> WorkflowRuntimeBuilder<R, H> {
208        WorkflowRuntimeBuilder {
209            queue,
210            queue_name: "workflow-steps".to_string(),
211            runner,
212            terminal_hook,
213            max_concurrent_steps: 16,
214            poll_interval: Duration::from_millis(250),
215        }
216    }
217
218    /// Submit a new run. Enqueues step 0 with payload `spec.input`. Idempotent
219    /// against in-process duplicates: if a run with the same `run_id` is
220    /// already active in this runtime, returns [`Error::DuplicateRun`].
221    /// Cross-process / cross-restart duplicate-prevention is enforced by
222    /// Taquba's dedup key on the step job.
223    #[instrument(skip(self, spec), fields(run_id))]
224    pub async fn submit(&self, spec: RunSpec) -> Result<RunHandle> {
225        let run_id = spec.run_id.unwrap_or_else(|| ulid::Ulid::new().to_string());
226        tracing::Span::current().record("run_id", run_id.as_str());
227
228        for k in spec.headers.keys() {
229            if k.starts_with(RESERVED_HEADER_PREFIX) {
230                return Err(Error::ReservedHeaderInSubmit(k.clone()));
231            }
232        }
233
234        // Hold the registry lock across enqueue so two concurrent submits
235        // with the same `run_id` can't both pass the duplicate check before
236        // either inserts. Submission is not on a hot path; queue I/O latency
237        // here is acceptable.
238        let mut registry = self.inner.registry.lock().await;
239        if registry.contains_key(&run_id) {
240            return Err(Error::DuplicateRun(run_id));
241        }
242
243        let job_id = self
244            .inner
245            .enqueue_step(
246                &run_id,
247                0,
248                spec.input,
249                &spec.headers,
250                StepEnqueueOpts {
251                    priority: spec.priority,
252                    max_attempts: spec.max_attempts_per_step,
253                    ..Default::default()
254                },
255            )
256            .await?;
257
258        registry.insert(
259            run_id.clone(),
260            RegistryEntry {
261                status: RunStatus {
262                    run_id: run_id.clone(),
263                    state: RunState::Pending,
264                    current_step: 0,
265                },
266                current_job_id: job_id.clone(),
267                user_headers: spec.headers.clone(),
268                cancel_requested: false,
269                cancel_token: CancellationToken::new(),
270            },
271        );
272        drop(registry);
273
274        debug!(run_id = %run_id, job_id = %job_id, "run submitted");
275        Ok(RunHandle {
276            run_id,
277            first_job_id: job_id,
278        })
279    }
280
281    /// Look up the in-process status of a run. Returns `None` for unknown or
282    /// already-terminated runs (the registry only retains active runs).
283    ///
284    /// Returns [`RunState::Cancelling`] for any run with a pending
285    /// cancellation request, regardless of its underlying step lifecycle
286    /// position; the cancellation overlay wins over `Pending`/`Running`
287    /// until the terminal hook fires.
288    pub async fn status(&self, run_id: &str) -> Option<RunStatus> {
289        self.inner.registry.lock().await.get(run_id).map(|e| {
290            let mut status = e.status.clone();
291            if e.cancel_requested {
292                status.state = RunState::Cancelling;
293            }
294            status
295        })
296    }
297
    /// Request cancellation of an active run.
    ///
    /// Returns `Ok(true)` if a cancellation was initiated for `run_id`, or
    /// `Ok(false)` if the run is not active in this runtime (already
    /// terminal, never submitted here, or owned by a different runtime
    /// instance).
    ///
    /// The terminal hook fires once with [`TerminalStatus::Cancelled`]:
    ///
    /// - **Pending / scheduled step**: the queued step job is cancelled in
    ///   Taquba and the hook fires from this call before it returns.
    /// - **Running step**: cancellation is delivered to the runner via
    ///   [`Step::cancel_token`]; runners that watch the token short-circuit
    ///   immediately. Runners that ignore the token are allowed to run to
    ///   completion (futures cannot be safely aborted mid-step). In both
    ///   cases the runner's [`StepOutcome`] / [`StepError`] is discarded
    ///   and the hook fires from the worker once the step returns, with
    ///   any pending transient retry suppressed and the step acked rather
    ///   than nacked.
    ///
    /// Cancellation is best-effort: if the run is already terminal by the
    /// time `cancel` is called (either because the runner returned a
    /// terminating [`StepOutcome`] or a prior `cancel` already settled
    /// it), `cancel` returns `Ok(false)`, the run keeps whatever terminal
    /// outcome it already delivered, and no additional hook fires.
    pub async fn cancel(&self, run_id: &str) -> Result<bool> {
        // Snapshot everything the hook might need while holding the lock,
        // then release the lock before touching the queue.
        let (job_id, headers, current_step) = {
            let mut registry = self.inner.registry.lock().await;
            let Some(entry) = registry.get_mut(run_id) else {
                return Ok(false);
            };
            entry.cancel_requested = true;
            // Signal cooperative cancellation. Idempotent on
            // `CancellationToken`: a second `cancel()` is a no-op. Runners
            // that watch `step.cancel_token` can short-circuit; runners
            // that ignore it still get terminated by the worker via the
            // `cancel_requested` flag after `run_step` returns.
            entry.cancel_token.cancel();
            (
                entry.current_job_id.clone(),
                entry.user_headers.clone(),
                entry.status.current_step,
            )
        };

        let cancelled_in_queue = self.inner.queue.cancel(&job_id).await?;
        if cancelled_in_queue {
            // Job was Pending/Scheduled and is now removed; no worker will
            // ever see it. Fire the hook here. `error` is `None`: external
            // cancellation carries no reason at the API level.
            self.inner
                .terminate(RunOutcome {
                    run_id: run_id.to_string(),
                    status: TerminalStatus::Cancelled,
                    result: None,
                    error: None,
                    headers,
                    final_step: current_step,
                })
                .await;
        }
        // Otherwise the job is Claimed (worker has it). The worker will read
        // `cancel_requested` after `run_step` returns and fire the hook.
        Ok(true)
    }
363
364    /// Drive the step worker loop until `shutdown` resolves. Spawns up to
365    /// `max_concurrent_steps` step processors and drains them on shutdown.
366    pub async fn run<F>(&self, shutdown: F) -> Result<()>
367    where
368        F: Future<Output = ()>,
369        R: 'static,
370        H: 'static,
371    {
372        let worker = Arc::new(StepWorker {
373            inner: self.inner.clone(),
374        });
375        taquba::run_worker_concurrent(
376            &self.inner.queue,
377            &self.inner.queue_name,
378            worker,
379            self.inner.max_concurrent_steps,
380            self.inner.poll_interval,
381            shutdown,
382        )
383        .await?;
384        Ok(())
385    }
386}
387
/// Adapter implementing Taquba's [`Worker`] over the shared runtime state.
struct StepWorker<R, H> {
    inner: Arc<RuntimeInner<R, H>>,
}
391
impl<R: StepRunner + 'static, H: TerminalHook + 'static> Worker for StepWorker<R, H> {
    // Thin delegation: all step semantics live in `RuntimeInner::process_step`.
    async fn process(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
        self.inner.process_step(job).await
    }
}
397
398impl<R: StepRunner, H: TerminalHook> RuntimeInner<R, H> {
399    async fn enqueue_step(
400        &self,
401        run_id: &str,
402        step_number: u32,
403        payload: Vec<u8>,
404        user_headers: &HashMap<String, String>,
405        opts: StepEnqueueOpts,
406    ) -> Result<String> {
407        let mut headers = user_headers.clone();
408        headers.insert(HEADER_RUN_ID.to_string(), run_id.to_string());
409        headers.insert(HEADER_STEP.to_string(), step_number.to_string());
410
411        let enqueue_opts = EnqueueOptions {
412            headers,
413            run_at: opts.run_at,
414            priority: opts.priority,
415            max_attempts: opts.max_attempts,
416            dedup_key: Some(format!("{DEDUP_PREFIX}{run_id}:{step_number}")),
417        };
418        Ok(self
419            .queue
420            .enqueue_with(&self.queue_name, payload, enqueue_opts)
421            .await?)
422    }
423
424    fn split_headers(headers: &HashMap<String, String>) -> HashMap<String, String> {
425        headers
426            .iter()
427            .filter(|(k, _)| !k.starts_with(RESERVED_HEADER_PREFIX))
428            .map(|(k, v)| (k.clone(), v.clone()))
429            .collect()
430    }
431
432    fn parse_step_headers(job: &JobRecord) -> std::result::Result<(String, u32), Error> {
433        let run_id = job
434            .headers
435            .get(HEADER_RUN_ID)
436            .ok_or(Error::MissingHeader(HEADER_RUN_ID))?
437            .to_string();
438        let step_str = job
439            .headers
440            .get(HEADER_STEP)
441            .ok_or(Error::MissingHeader(HEADER_STEP))?;
442        let step_number: u32 = step_str.parse().map_err(|_| Error::InvalidStepHeader {
443            header: HEADER_STEP,
444            value: step_str.clone(),
445        })?;
446        Ok((run_id, step_number))
447    }
448
449    /// Settle a run into its terminal state: drop its registry entry and
450    /// fire the terminal hook. Removal happens first so that
451    /// [`WorkflowRuntime::status`] doesn't briefly report an
452    /// already-terminated run as active while a slow hook (e.g. a webhook
453    /// delivery) is in flight.
454    async fn terminate(&self, outcome: RunOutcome) {
455        self.registry.lock().await.remove(&outcome.run_id);
456        self.terminal_hook.on_termination(&outcome).await;
457    }
458
459    /// Transition the entry for `run_id` into [`RunState::Running`] for
460    /// `step_number`, recording the Taquba job ID powering the step so a
461    /// concurrent [`WorkflowRuntime::cancel`] can target it. Creates a
462    /// fresh entry if the run is unknown to this runtime (e.g. after a
463    /// restart on another runtime, where the worker first learns of the
464    /// run by claiming its step). Returns the entry's
465    /// [`CancellationToken`] for cloning into the in-flight [`Step`].
466    async fn registry_mark_running(
467        &self,
468        run_id: &str,
469        step_number: u32,
470        job_id: &str,
471        user_headers: &HashMap<String, String>,
472    ) -> CancellationToken {
473        let mut registry = self.registry.lock().await;
474        match registry.get_mut(run_id) {
475            Some(entry) => {
476                entry.status.state = RunState::Running;
477                entry.status.current_step = step_number;
478                entry.current_job_id = job_id.to_string();
479                entry.cancel_token.clone()
480            }
481            None => {
482                let cancel_token = CancellationToken::new();
483                registry.insert(
484                    run_id.to_string(),
485                    RegistryEntry {
486                        status: RunStatus {
487                            run_id: run_id.to_string(),
488                            state: RunState::Running,
489                            current_step: step_number,
490                        },
491                        current_job_id: job_id.to_string(),
492                        user_headers: user_headers.clone(),
493                        cancel_requested: false,
494                        cancel_token: cancel_token.clone(),
495                    },
496                );
497                cancel_token
498            }
499        }
500    }
501
    /// Worker entry point for one claimed step job: parse the runtime
    /// headers, mark the run running, invoke the runner, then settle or
    /// advance the run according to the outcome and any external
    /// cancellation that landed while the step was executing.
    async fn process_step(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
        let (run_id, step_number) = match Self::parse_step_headers(job) {
            Ok(v) => v,
            Err(e) => {
                warn!(job_id = %job.id, error = %e, "workflow step has malformed headers");
                if e.is_permanent() {
                    return Err(PermanentFailure::new(e.to_string()).into());
                }
                return Err(e.to_string().into());
            }
        };

        // Submitter-visible headers only; runtime-owned `workflow.*` stripped.
        let user_headers = Self::split_headers(&job.headers);

        let cancel_token = self
            .registry_mark_running(&run_id, step_number, &job.id, &user_headers)
            .await;

        let step = Step {
            run_id: run_id.clone(),
            step_number,
            payload: job.payload.clone(),
            headers: user_headers.clone(),
            job_id: job.id.clone(),
            attempts: job.attempts,
            cancel_token,
        };

        // Preserve the run's per-step priority and max_attempts across the
        // boundary by re-using the values from the just-processed job.
        let inherit_opts = || StepEnqueueOpts {
            run_at: None,
            priority: Some(job.priority),
            max_attempts: Some(job.max_attempts),
        };

        let outcome = self.runner.run_step(&step).await;
        // Re-read the flag after the step returns: an external cancel may
        // have landed while the runner was executing.
        let external_cancel = self
            .registry
            .lock()
            .await
            .get(&run_id)
            .is_some_and(|e| e.cancel_requested);

        // Cancellation precedence:
        // 1. A runner-issued `StepOutcome::Cancel` wins (it carries an
        //    in-step reason that we surface on `RunOutcome::error`).
        // 2. Otherwise an external `WorkflowRuntime::cancel` overrides
        //    whatever outcome the runner returned (including transient
        //    retries and permanent dead-letters), with `error: None` so
        //    consumers can distinguish external vs. runner-issued cancel.
        match outcome {
            Ok(StepOutcome::Cancel { reason }) => {
                self.terminate(RunOutcome {
                    run_id: run_id.clone(),
                    status: TerminalStatus::Cancelled,
                    result: None,
                    error: Some(reason),
                    headers: user_headers,
                    final_step: step_number,
                })
                .await;
                Ok(())
            }
            // Guard arm must stay ahead of the remaining outcomes so the
            // external-cancel overlay suppresses them (precedence rule 2).
            _ if external_cancel => {
                self.terminate(RunOutcome {
                    run_id: run_id.clone(),
                    status: TerminalStatus::Cancelled,
                    result: None,
                    error: None,
                    headers: user_headers,
                    final_step: step_number,
                })
                .await;
                Ok(())
            }
            Ok(StepOutcome::Continue { payload }) => {
                self.advance(
                    &run_id,
                    step_number + 1,
                    payload,
                    &user_headers,
                    inherit_opts(),
                )
                .await
            }
            Ok(StepOutcome::ContinueAfter { payload, delay }) => {
                let opts = StepEnqueueOpts {
                    run_at: Some(SystemTime::now() + delay),
                    ..inherit_opts()
                };
                self.advance(&run_id, step_number + 1, payload, &user_headers, opts)
                    .await
            }
            Ok(StepOutcome::Succeed { result }) => {
                self.terminate(RunOutcome {
                    run_id: run_id.clone(),
                    status: TerminalStatus::Succeeded,
                    result: Some(result),
                    error: None,
                    headers: user_headers,
                    final_step: step_number,
                })
                .await;
                Ok(())
            }
            Ok(StepOutcome::Fail { reason }) => {
                // Runner verdict: workflow failed but the step itself ran
                // cleanly. Ack the step (no dead-letter) and fire the hook
                // with `Failed`.
                self.terminate(RunOutcome {
                    run_id: run_id.clone(),
                    status: TerminalStatus::Failed,
                    result: None,
                    error: Some(reason),
                    headers: user_headers,
                    final_step: step_number,
                })
                .await;
                Ok(())
            }
            Err(StepError {
                message,
                kind: StepErrorKind::Permanent,
            }) => {
                // Settle the run, then surface a `PermanentFailure` so the
                // step job is not retried.
                self.terminate(RunOutcome {
                    run_id: run_id.clone(),
                    status: TerminalStatus::Failed,
                    result: None,
                    error: Some(message.clone()),
                    headers: user_headers,
                    final_step: step_number,
                })
                .await;
                Err(PermanentFailure::new(message).into())
            }
            Err(StepError {
                message,
                kind: StepErrorKind::Transient,
            }) => {
                // Last attempt: this nack will dead-letter. Fire the failure
                // hook now so the user is notified once, before the job
                // record disappears from the registry.
                if job.attempts >= job.max_attempts {
                    self.terminate(RunOutcome {
                        run_id: run_id.clone(),
                        status: TerminalStatus::Failed,
                        result: None,
                        error: Some(message.clone()),
                        headers: user_headers,
                        final_step: step_number,
                    })
                    .await;
                }
                Err(message.into())
            }
        }
    }
660
661    async fn advance(
662        &self,
663        run_id: &str,
664        next_step: u32,
665        payload: Vec<u8>,
666        user_headers: &HashMap<String, String>,
667        opts: StepEnqueueOpts,
668    ) -> std::result::Result<(), WorkerError> {
669        match self
670            .enqueue_step(run_id, next_step, payload, user_headers, opts)
671            .await
672        {
673            Ok(new_job_id) => {
674                // Make sure to preserve `cancel_requested`.
675                if let Some(entry) = self.registry.lock().await.get_mut(run_id) {
676                    entry.status.state = RunState::Pending;
677                    entry.status.current_step = next_step;
678                    entry.current_job_id = new_job_id;
679                }
680                Ok(())
681            }
682            // Transient: the runner already executed for this step; failing
683            // the worker triggers a retry of the same step. The runner must be
684            // idempotent for `(run_id, step_number)`.
685            Err(e) => Err(e.to_string().into()),
686        }
687    }
688}
689
690#[cfg(test)]
691mod tests {
692    use super::*;
693    use crate::terminal::NoopTerminalHook;
694    use std::sync::Mutex as StdMutex;
695    use std::sync::atomic::{AtomicU32, Ordering};
696    use taquba::object_store::memory::InMemory;
697    use taquba::{OpenOptions, QueueConfig};
698    use tokio::sync::oneshot;
699
    /// Recording terminal hook backed by an mpsc channel.
    struct ChannelHook {
        tx: tokio::sync::mpsc::UnboundedSender<RunOutcome>,
    }

    impl TerminalHook for ChannelHook {
        async fn on_termination(&self, outcome: &RunOutcome) {
            // Ignore send errors: a test may have dropped its receiver.
            let _ = self.tx.send(outcome.clone());
        }
    }
710
    /// Runner that executes a fixed list of step outcomes in order.
    struct ScriptedRunner {
        /// Remaining outcomes; the front entry is consumed per step.
        script: Arc<StdMutex<Vec<StepOutcome>>>,
    }

    impl ScriptedRunner {
        fn new(steps: Vec<StepOutcome>) -> Self {
            Self {
                script: Arc::new(StdMutex::new(steps)),
            }
        }
    }

    impl StepRunner for ScriptedRunner {
        async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
            // Panics if the script runs dry — a test-authoring bug.
            let next = self.script.lock().unwrap().remove(0);
            Ok(next)
        }
    }
730
    /// Fresh Taquba queue backed by the in-memory object store.
    async fn fresh_queue() -> Arc<Queue> {
        Arc::new(
            Queue::open(Arc::new(InMemory::new()), "test")
                .await
                .unwrap(),
        )
    }
738
    /// Queue with zero retry backoff and a tight reaper, so multi-attempt
    /// tests run in well under a second. (50ms intervals keep reaper and
    /// scheduler latency low without busy-spinning.)
    async fn fresh_queue_fast_retry() -> Arc<Queue> {
        let opts = OpenOptions {
            default_queue_config: QueueConfig {
                retry_backoff_base: Duration::ZERO,
                ..QueueConfig::default()
            },
            reaper_interval: Duration::from_millis(50),
            scheduler_interval: Duration::from_millis(50),
            ..OpenOptions::default()
        };
        Arc::new(
            Queue::open_with_options(Arc::new(InMemory::new()), "test", opts)
                .await
                .unwrap(),
        )
    }
757
    /// Run the runtime's worker loop on a background task. Sending on (or
    /// dropping) the returned channel resolves the shutdown future.
    fn spawn_runtime<R, H>(runtime: WorkflowRuntime<R, H>) -> oneshot::Sender<()>
    where
        R: StepRunner + 'static,
        H: TerminalHook + 'static,
    {
        let (tx, rx) = oneshot::channel::<()>();
        tokio::spawn(async move {
            let _ = runtime
                .run(async move {
                    // Resolves on send *or* sender drop, either way shutting down.
                    let _ = rx.await;
                })
                .await;
        });
        tx
    }
773
    /// One-step run: the runner succeeds immediately, the hook fires once
    /// with the result, and the registry entry is gone afterwards.
    #[tokio::test]
    async fn single_step_succeeds_and_fires_hook() {
        let queue = fresh_queue().await;
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let runtime = WorkflowRuntime::builder(
            queue,
            ScriptedRunner::new(vec![StepOutcome::Succeed {
                result: b"done".to_vec(),
            }]),
            ChannelHook { tx },
        )
        .build();
        let shutdown = spawn_runtime(runtime.clone());

        let handle = runtime
            .submit(RunSpec {
                input: b"in".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();
        // 2s is generous for an in-memory queue; timeout guards against hangs.
        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .unwrap()
            .unwrap();

        assert_eq!(outcome.run_id, handle.run_id);
        assert_eq!(outcome.status, TerminalStatus::Succeeded);
        assert_eq!(outcome.result.as_deref(), Some(b"done".as_slice()));
        assert_eq!(outcome.final_step, 0);
        // Terminal runs are dropped from the registry, so status is None.
        assert!(runtime.status(&handle.run_id).await.is_none());

        let _ = shutdown.send(());
    }
808
809    #[tokio::test]
810    async fn multi_step_run_advances_through_continue() {
811        let queue = fresh_queue().await;
812        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
813        let runtime = WorkflowRuntime::builder(
814            queue,
815            ScriptedRunner::new(vec![
816                StepOutcome::Continue {
817                    payload: b"step1".to_vec(),
818                },
819                StepOutcome::Continue {
820                    payload: b"step2".to_vec(),
821                },
822                StepOutcome::Succeed {
823                    result: b"final".to_vec(),
824                },
825            ]),
826            ChannelHook { tx },
827        )
828        .build();
829        let shutdown = spawn_runtime(runtime.clone());
830
831        let handle = runtime
832            .submit(RunSpec {
833                input: b"start".to_vec(),
834                ..Default::default()
835            })
836            .await
837            .unwrap();
838        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
839            .await
840            .unwrap()
841            .unwrap();
842
843        assert_eq!(outcome.run_id, handle.run_id);
844        assert_eq!(outcome.final_step, 2);
845        assert_eq!(outcome.status, TerminalStatus::Succeeded);
846        assert_eq!(outcome.result.as_deref(), Some(b"final".as_slice()));
847
848        let _ = shutdown.send(());
849    }
850
851    #[tokio::test]
852    async fn permanent_failure_dead_letters_and_fires_hook() {
853        struct FailingRunner;
854        impl StepRunner for FailingRunner {
855            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
856                Err(StepError::permanent("nope"))
857            }
858        }
859
860        let queue = fresh_queue().await;
861        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
862        let runtime =
863            WorkflowRuntime::builder(queue.clone(), FailingRunner, ChannelHook { tx }).build();
864        let shutdown = spawn_runtime(runtime.clone());
865
866        let handle = runtime
867            .submit(RunSpec {
868                input: b"x".to_vec(),
869                ..Default::default()
870            })
871            .await
872            .unwrap();
873        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
874            .await
875            .unwrap()
876            .unwrap();
877
878        assert_eq!(outcome.run_id, handle.run_id);
879        assert_eq!(outcome.status, TerminalStatus::Failed);
880        assert_eq!(outcome.error.as_deref(), Some("nope"));
881        assert!(runtime.status(&handle.run_id).await.is_none());
882
883        // Permanent runner errors *do* dead-letter the step.
884        let stats = queue.stats("workflow-steps").await.unwrap();
885        assert_eq!(stats.dead, 1, "permanent error should dead-letter");
886
887        let _ = shutdown.send(());
888    }
889
890    #[tokio::test]
891    async fn fail_outcome_terminates_run_without_dead_letter() {
892        // StepOutcome::Fail is the runner's *verdict* path, not an
893        // infrastructure error: the hook fires with Failed, the registry
894        // entry is cleaned up, but the step is acked normally so no dead
895        // job is left behind for operators to inspect.
896        struct VerdictRunner;
897        impl StepRunner for VerdictRunner {
898            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
899                Ok(StepOutcome::Fail {
900                    reason: "agent declined the task".to_string(),
901                })
902            }
903        }
904
905        let queue = fresh_queue().await;
906        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
907        let runtime =
908            WorkflowRuntime::builder(queue.clone(), VerdictRunner, ChannelHook { tx }).build();
909        let shutdown = spawn_runtime(runtime.clone());
910
911        let handle = runtime
912            .submit(RunSpec {
913                input: b"x".to_vec(),
914                ..Default::default()
915            })
916            .await
917            .unwrap();
918
919        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
920            .await
921            .expect("hook fired in time")
922            .expect("hook channel open");
923
924        assert_eq!(outcome.run_id, handle.run_id);
925        assert_eq!(outcome.status, TerminalStatus::Failed);
926        assert_eq!(outcome.error.as_deref(), Some("agent declined the task"));
927        assert!(runtime.status(&handle.run_id).await.is_none());
928
929        // Crucially: no dead-letter, distinguishing runner verdict from
930        // infrastructure failure at the queue level.
931        let stats = queue.stats("workflow-steps").await.unwrap();
932        assert_eq!(stats.dead, 0, "Fail verdict must not dead-letter");
933
934        let _ = shutdown.send(());
935    }
936
937    #[tokio::test]
938    async fn duplicate_submit_in_process_is_rejected() {
939        // Pause forever on the first step so the run stays active in the
940        // registry while we attempt the duplicate submit.
941        struct PauseRunner;
942        impl StepRunner for PauseRunner {
943            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
944                std::future::pending().await
945            }
946        }
947
948        let queue = fresh_queue().await;
949        let runtime = WorkflowRuntime::builder(queue, PauseRunner, NoopTerminalHook).build();
950        let shutdown = spawn_runtime(runtime.clone());
951
952        let handle = runtime
953            .submit(RunSpec {
954                run_id: Some("fixed-id".to_string()),
955                input: b"x".to_vec(),
956                ..Default::default()
957            })
958            .await
959            .unwrap();
960        // Wait for the worker to start the step so the registry observes the
961        // run as Running (or at least Pending).
962        for _ in 0..40 {
963            if runtime.status(&handle.run_id).await.is_some() {
964                break;
965            }
966            tokio::time::sleep(Duration::from_millis(25)).await;
967        }
968        assert!(runtime.status(&handle.run_id).await.is_some());
969
970        let err = runtime
971            .submit(RunSpec {
972                run_id: Some("fixed-id".to_string()),
973                input: b"y".to_vec(),
974                ..Default::default()
975            })
976            .await
977            .unwrap_err();
978        assert!(matches!(err, Error::DuplicateRun(id) if id == "fixed-id"));
979
980        let _ = shutdown.send(());
981    }
982
983    #[tokio::test]
984    async fn reserved_header_on_submit_is_rejected() {
985        let queue = fresh_queue().await;
986        let runtime: WorkflowRuntime<ScriptedRunner, NoopTerminalHook> =
987            WorkflowRuntime::builder(queue, ScriptedRunner::new(vec![]), NoopTerminalHook).build();
988        let mut headers = HashMap::new();
989        headers.insert("workflow.run_id".to_string(), "evil".to_string());
990
991        let err = runtime
992            .submit(RunSpec {
993                input: b"x".to_vec(),
994                headers,
995                ..Default::default()
996            })
997            .await
998            .unwrap_err();
999        assert!(
1000            matches!(&err, Error::ReservedHeaderInSubmit(k) if k == "workflow.run_id"),
1001            "got: {err:?}"
1002        );
1003    }
1004
1005    #[tokio::test]
1006    async fn user_headers_thread_through_to_terminal_hook() {
1007        let queue = fresh_queue().await;
1008        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1009        let runtime = WorkflowRuntime::builder(
1010            queue,
1011            ScriptedRunner::new(vec![
1012                StepOutcome::Continue { payload: vec![] },
1013                StepOutcome::Succeed { result: vec![] },
1014            ]),
1015            ChannelHook { tx },
1016        )
1017        .build();
1018        let shutdown = spawn_runtime(runtime.clone());
1019
1020        let mut headers = HashMap::new();
1021        headers.insert("trace_id".to_string(), "abc-123".to_string());
1022        headers.insert("tenant".to_string(), "acme".to_string());
1023
1024        runtime
1025            .submit(RunSpec {
1026                input: b"x".to_vec(),
1027                headers,
1028                ..Default::default()
1029            })
1030            .await
1031            .unwrap();
1032        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1033            .await
1034            .unwrap()
1035            .unwrap();
1036
1037        assert_eq!(outcome.headers.get("trace_id").unwrap(), "abc-123");
1038        assert_eq!(outcome.headers.get("tenant").unwrap(), "acme");
1039        // Reserved keys must not leak through.
1040        assert!(!outcome.headers.contains_key(HEADER_RUN_ID));
1041        assert!(!outcome.headers.contains_key(HEADER_STEP));
1042
1043        let _ = shutdown.send(());
1044    }
1045
    #[tokio::test]
    async fn restart_resumes_at_next_step() {
        // Headline durability test: after step 0 has acked and step 1 is in
        // the queue, kill runtime A entirely and spawn runtime B on the same
        // Queue handle. B should claim and complete step 1 without re-running
        // step 0.
        //
        // To make this race-free we gate step 0's runner: the test holds the
        // gate while signalling shutdown to A so A enters drain mode without
        // ever claiming step 1. Then the gate is opened, A's spawned step-0
        // task finishes (enqueueing step 1 + acking step 0) and A exits.
        struct GatedRunner {
            // Oneshot payload doubles as the gate: step 0 blocks until the
            // test sends the bytes that become step 1's payload.
            gate: tokio::sync::Mutex<Option<oneshot::Receiver<Vec<u8>>>>,
        }

        impl StepRunner for GatedRunner {
            async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
                match step.step_number {
                    0 => {
                        let rx = self.gate.lock().await.take().expect("gate consumed twice");
                        let payload = rx.await.expect("gate sender dropped");
                        Ok(StepOutcome::Continue { payload })
                    }
                    // Any later step would mean A claimed past the gate;
                    // hang forever so the test times out loudly.
                    _ => std::future::pending().await,
                }
            }
        }

        struct CompleteOnStep1;
        impl StepRunner for CompleteOnStep1 {
            async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
                assert_eq!(step.step_number, 1, "runtime B should only ever see step 1");
                assert_eq!(step.payload.as_slice(), b"step1-payload");
                Ok(StepOutcome::Succeed {
                    result: b"resumed".to_vec(),
                })
            }
        }

        let queue = fresh_queue().await;

        // Runtime A gets a single worker slot so it cannot claim step 1
        // while the gated step 0 is still in flight.
        let (gate_tx, gate_rx) = oneshot::channel::<Vec<u8>>();
        let runtime_a = WorkflowRuntime::builder(
            queue.clone(),
            GatedRunner {
                gate: tokio::sync::Mutex::new(Some(gate_rx)),
            },
            NoopTerminalHook,
        )
        .max_concurrent_steps(1)
        .build();

        let (shutdown_a_tx, shutdown_a_rx) = oneshot::channel::<()>();
        let worker_a = {
            let runtime_a = runtime_a.clone();
            tokio::spawn(async move {
                let _ = runtime_a
                    .run(async move {
                        let _ = shutdown_a_rx.await;
                    })
                    .await;
            })
        };

        let handle = runtime_a
            .submit(RunSpec {
                input: b"input".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();

        // Wait for runtime A to claim step 0 and reach the gate (registry
        // shows Running for step 0).
        for _ in 0..80 {
            if let Some(s) = runtime_a.status(&handle.run_id).await {
                if s.state == RunState::Running && s.current_step == 0 {
                    break;
                }
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
        let s = runtime_a.status(&handle.run_id).await.expect("status");
        assert_eq!(s.state, RunState::Running);
        assert_eq!(s.current_step, 0);

        // A's worker is in the at-capacity select-loop. Signal shutdown
        // first, then open the gate so step 0 finishes processing inside
        // drain mode (A will not claim step 1).
        let _ = shutdown_a_tx.send(());
        let _ = gate_tx.send(b"step1-payload".to_vec());

        // A exits only after its step-0 task completes, i.e. step 1 has
        // been enqueued and step 0 acked.
        worker_a.await.expect("runtime A drained cleanly");

        // Bring up runtime B on the same Queue handle. It should pick up
        // step 1 from where A left off.
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let runtime_b =
            WorkflowRuntime::builder(queue, CompleteOnStep1, ChannelHook { tx }).build();
        let shutdown_b = spawn_runtime(runtime_b.clone());

        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("hook fired in time")
            .expect("hook channel open");

        assert_eq!(outcome.run_id, handle.run_id);
        assert_eq!(outcome.status, TerminalStatus::Succeeded);
        assert_eq!(outcome.result.as_deref(), Some(b"resumed".as_slice()));
        assert_eq!(outcome.final_step, 1);

        let _ = shutdown_b.send(());
    }
1159
1160    /// Submits a run whose runner always returns
1161    /// [`StepError::transient`], capped at `max_attempts`. Asserts the
1162    /// runner is invoked exactly `max_attempts` times (per-step max-attempts
1163    /// propagation) and that the terminal hook fires Failed exactly once on
1164    /// the final attempt (fire-once-on-last-attempt logic).
1165    async fn assert_transient_retries_until_max(max_attempts: u32) {
1166        struct AlwaysTransient {
1167            calls: Arc<AtomicU32>,
1168        }
1169        impl StepRunner for AlwaysTransient {
1170            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1171                self.calls.fetch_add(1, Ordering::SeqCst);
1172                Err(StepError::transient("flaky"))
1173            }
1174        }
1175
1176        let queue = fresh_queue_fast_retry().await;
1177        let calls = Arc::new(AtomicU32::new(0));
1178        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1179        let runtime = WorkflowRuntime::builder(
1180            queue,
1181            AlwaysTransient {
1182                calls: calls.clone(),
1183            },
1184            ChannelHook { tx },
1185        )
1186        .build();
1187        let shutdown = spawn_runtime(runtime.clone());
1188
1189        runtime
1190            .submit(RunSpec {
1191                input: b"x".to_vec(),
1192                max_attempts_per_step: Some(max_attempts),
1193                ..Default::default()
1194            })
1195            .await
1196            .unwrap();
1197
1198        let outcome = tokio::time::timeout(Duration::from_secs(3), rx.recv())
1199            .await
1200            .expect("hook fired in time")
1201            .expect("hook channel open");
1202
1203        assert_eq!(outcome.status, TerminalStatus::Failed);
1204        assert_eq!(outcome.error.as_deref(), Some("flaky"));
1205        assert_eq!(
1206            calls.load(Ordering::SeqCst),
1207            max_attempts,
1208            "runner called once per attempt up to max_attempts"
1209        );
1210
1211        // Settle window: assert no duplicate hook fires after the terminal one.
1212        tokio::time::sleep(Duration::from_millis(50)).await;
1213        assert!(rx.try_recv().is_err(), "hook fired more than once");
1214
1215        let _ = shutdown.send(());
1216    }
1217
1218    #[tokio::test]
1219    async fn cancel_outcome_terminates_run_without_dead_letter() {
1220        // `StepOutcome::Cancel` is the runner's cancellation verdict path:
1221        // the hook fires with Cancelled, the registry is cleaned up, the
1222        // step is acked, and no dead job is left behind.
1223        struct CancellingRunner;
1224        impl StepRunner for CancellingRunner {
1225            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1226                Ok(StepOutcome::Cancel {
1227                    reason: "upstream aborted".to_string(),
1228                })
1229            }
1230        }
1231
1232        let queue = fresh_queue().await;
1233        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1234        let runtime =
1235            WorkflowRuntime::builder(queue.clone(), CancellingRunner, ChannelHook { tx }).build();
1236        let shutdown = spawn_runtime(runtime.clone());
1237
1238        let handle = runtime
1239            .submit(RunSpec {
1240                input: b"x".to_vec(),
1241                ..Default::default()
1242            })
1243            .await
1244            .unwrap();
1245
1246        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1247            .await
1248            .expect("hook fired in time")
1249            .expect("hook channel open");
1250
1251        assert_eq!(outcome.run_id, handle.run_id);
1252        assert_eq!(outcome.status, TerminalStatus::Cancelled);
1253        assert_eq!(outcome.error.as_deref(), Some("upstream aborted"));
1254        assert!(runtime.status(&handle.run_id).await.is_none());
1255
1256        let stats = queue.stats("workflow-steps").await.unwrap();
1257        assert_eq!(stats.dead, 0, "Cancel verdict must not dead-letter");
1258
1259        let _ = shutdown.send(());
1260    }
1261
1262    #[tokio::test]
1263    async fn cancel_pending_run_fires_cancelled_hook() {
1264        // Pending case: a run sits in the queue, we call `cancel()` before
1265        // any worker claims it. The hook fires from `cancel` itself.
1266        struct UnreachableRunner;
1267        impl StepRunner for UnreachableRunner {
1268            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1269                unreachable!("worker must not claim the cancelled step");
1270            }
1271        }
1272
1273        let queue = fresh_queue().await;
1274        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1275        let runtime =
1276            WorkflowRuntime::builder(queue.clone(), UnreachableRunner, ChannelHook { tx }).build();
1277        // Note: deliberately do NOT spawn the worker loop, so the submitted
1278        // step stays Pending in the queue while we cancel it.
1279
1280        let mut headers = HashMap::new();
1281        headers.insert("tenant".to_string(), "acme".to_string());
1282
1283        let handle = runtime
1284            .submit(RunSpec {
1285                input: b"x".to_vec(),
1286                headers,
1287                ..Default::default()
1288            })
1289            .await
1290            .unwrap();
1291        let status = runtime.status(&handle.run_id).await.expect("active");
1292        assert_eq!(status.state, RunState::Pending);
1293
1294        let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1295        assert!(was_cancelled);
1296
1297        let outcome = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1298            .await
1299            .expect("hook fired in time")
1300            .expect("hook channel open");
1301        assert_eq!(outcome.run_id, handle.run_id);
1302        assert_eq!(outcome.status, TerminalStatus::Cancelled);
1303        // External cancellation carries no reason: `error` is `None`.
1304        assert!(outcome.error.is_none());
1305        assert_eq!(outcome.headers.get("tenant").unwrap(), "acme");
1306        assert!(runtime.status(&handle.run_id).await.is_none());
1307
1308        let stats = queue.stats("workflow-steps").await.unwrap();
1309        assert_eq!(stats.dead, 0, "cancel must not dead-letter");
1310        assert_eq!(stats.pending, 0, "cancelled job must be removed");
1311    }
1312
1313    #[tokio::test]
1314    async fn cancel_during_running_step_overrides_outcome() {
1315        // Running case: the step is in-flight when cancel is called. The
1316        // runner's eventual outcome is discarded; the worker fires Cancelled.
1317        struct GatedRunner {
1318            claimed: Arc<tokio::sync::Notify>,
1319            gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
1320        }
1321        impl StepRunner for GatedRunner {
1322            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1323                self.claimed.notify_one();
1324                let rx = self.gate.lock().await.take().expect("gate consumed twice");
1325                let _ = rx.await;
1326                // The runner "successfully completes" the step, but cancel
1327                // was requested mid-flight so the outcome should be ignored
1328                // and the hook should fire Cancelled instead.
1329                Ok(StepOutcome::Succeed {
1330                    result: b"would-have-succeeded".to_vec(),
1331                })
1332            }
1333        }
1334
1335        let queue = fresh_queue().await;
1336        let claimed = Arc::new(tokio::sync::Notify::new());
1337        let (gate_tx, gate_rx) = oneshot::channel::<()>();
1338        let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
1339        let runtime = WorkflowRuntime::builder(
1340            queue.clone(),
1341            GatedRunner {
1342                claimed: claimed.clone(),
1343                gate: tokio::sync::Mutex::new(Some(gate_rx)),
1344            },
1345            ChannelHook { tx: hook_tx },
1346        )
1347        .build();
1348        let shutdown = spawn_runtime(runtime.clone());
1349
1350        let handle = runtime
1351            .submit(RunSpec {
1352                input: b"x".to_vec(),
1353                ..Default::default()
1354            })
1355            .await
1356            .unwrap();
1357        tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1358            .await
1359            .expect("runner reached gate");
1360
1361        let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1362        assert!(was_cancelled);
1363
1364        // Let the runner finish. The worker should observe `cancel_requested`
1365        // and fire Cancelled rather than advancing or firing Succeeded.
1366        let _ = gate_tx.send(());
1367
1368        let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
1369            .await
1370            .expect("hook fired")
1371            .expect("hook channel open");
1372        assert_eq!(outcome.status, TerminalStatus::Cancelled);
1373        assert!(
1374            outcome.result.is_none(),
1375            "succeed payload must be discarded"
1376        );
1377        assert!(runtime.status(&handle.run_id).await.is_none());
1378
1379        let stats = queue.stats("workflow-steps").await.unwrap();
1380        assert_eq!(stats.dead, 0);
1381
1382        let _ = shutdown.send(());
1383    }
1384
    /// Drive a single step that blocks on a gate, calls `cancel(run_id)`
    /// while the step is in-flight, and then has the runner return the
    /// supplied error. Asserts that external cancellation suppresses the
    /// error path entirely: the hook fires `Cancelled` (not `Failed`),
    /// no dead-letter is produced regardless of `permanent`/`transient`,
    /// and the worker returns `Ok` (no retry, no PermanentFailure
    /// propagation).
    async fn assert_cancel_suppresses_runner_error(error: StepError) {
        struct GatedErrRunner {
            // Signals the test the moment the worker has claimed the step.
            claimed: Arc<tokio::sync::Notify>,
            // Holds the runner at a known point until the test releases it.
            gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
            // Counts invocations; cancellation must keep this at exactly 1.
            calls: Arc<AtomicU32>,
            // One-shot error payload, taken on the single expected call.
            error: StdMutex<Option<StepError>>,
        }
        impl StepRunner for GatedErrRunner {
            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
                self.calls.fetch_add(1, Ordering::SeqCst);
                self.claimed.notify_one();
                let rx = self.gate.lock().await.take().expect("gate consumed twice");
                let _ = rx.await;
                Err(self
                    .error
                    .lock()
                    .unwrap()
                    .take()
                    .expect("error consumed twice"))
            }
        }

        // Fast-retry queue: if cancellation failed to suppress the transient
        // retry path, the second attempt would land inside the settle window
        // below and trip the call-count assertion.
        let queue = fresh_queue_fast_retry().await;
        let claimed = Arc::new(tokio::sync::Notify::new());
        let calls = Arc::new(AtomicU32::new(0));
        let (gate_tx, gate_rx) = oneshot::channel::<()>();
        let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
        let runtime = WorkflowRuntime::builder(
            queue.clone(),
            GatedErrRunner {
                claimed: claimed.clone(),
                gate: tokio::sync::Mutex::new(Some(gate_rx)),
                calls: calls.clone(),
                error: StdMutex::new(Some(error)),
            },
            ChannelHook { tx: hook_tx },
        )
        .build();
        let shutdown = spawn_runtime(runtime.clone());

        let handle = runtime
            .submit(RunSpec {
                input: b"x".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();
        tokio::time::timeout(Duration::from_secs(2), claimed.notified())
            .await
            .expect("runner reached gate");

        let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
        assert!(was_cancelled);

        // Release the runner. It returns Err; without cancellation this
        // would either dead-letter (permanent) or nack for retry
        // (transient). Cancellation must suppress both.
        let _ = gate_tx.send(());

        let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
            .await
            .expect("hook fired")
            .expect("hook channel open");
        assert_eq!(outcome.status, TerminalStatus::Cancelled);
        assert!(
            outcome.error.is_none(),
            "external cancel must carry no reason (Some(_) would imply runner-issued StepOutcome::Cancel)",
        );
        assert!(runtime.status(&handle.run_id).await.is_none());

        // Settle window: assert no retry attempt and no dead-letter or
        // duplicate hook fires after the terminal one.
        tokio::time::sleep(Duration::from_millis(100)).await;
        assert_eq!(
            calls.load(Ordering::SeqCst),
            1,
            "cancellation must suppress retries",
        );
        let stats = queue.stats("workflow-steps").await.unwrap();
        assert_eq!(stats.dead, 0, "cancellation must suppress dead-letter");
        assert!(
            hook_rx.try_recv().is_err(),
            "hook must fire exactly once for the cancelled run",
        );

        let _ = shutdown.send(());
    }
1479
1480    #[tokio::test]
1481    async fn cancel_suppresses_permanent_runner_error() {
1482        // Without cancellation, `StepError::permanent` dead-letters the
1483        // step and causes the worker to return `PermanentFailure`. With
1484        // an external cancel in flight, the worker must ack and fire
1485        // `Cancelled` instead.
1486        assert_cancel_suppresses_runner_error(StepError::permanent("would-dead-letter")).await;
1487    }
1488
1489    #[tokio::test]
1490    async fn cancel_suppresses_transient_runner_error() {
1491        // Without cancellation, `StepError::transient` nacks for retry
1492        // (and eventually dead-letters). With an external cancel in
1493        // flight, the worker must ack and fire `Cancelled` without
1494        // re-invoking the runner.
1495        assert_cancel_suppresses_runner_error(StepError::transient("would-retry")).await;
1496    }
1497
1498    #[tokio::test]
1499    async fn cancel_signals_step_token_for_cooperative_short_circuit() {
1500        // A runner that watches `step.cancel_token` should short-circuit
1501        // long after-claim work as soon as `WorkflowRuntime::cancel` is
1502        // called. Without the token, cancellation latency is bounded by
1503        // step duration; with it, the runner returns essentially
1504        // immediately. The test pins this by using a step that would
1505        // otherwise sleep for 30 seconds; if the token didn't fire, the
1506        // test would time out.
1507        struct CooperativeRunner {
1508            claimed: Arc<tokio::sync::Notify>,
1509        }
1510        impl StepRunner for CooperativeRunner {
1511            async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
1512                self.claimed.notify_one();
1513                tokio::select! {
1514                    _ = tokio::time::sleep(Duration::from_secs(30)) => {
1515                        Ok(StepOutcome::Succeed { result: b"slow".to_vec() })
1516                    }
1517                    _ = step.cancel_token.cancelled() => {
1518                        Ok(StepOutcome::Cancel { reason: "cooperative".to_string() })
1519                    }
1520                }
1521            }
1522        }
1523
1524        let queue = fresh_queue().await;
1525        let claimed = Arc::new(tokio::sync::Notify::new());
1526        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1527        let runtime = WorkflowRuntime::builder(
1528            queue.clone(),
1529            CooperativeRunner {
1530                claimed: claimed.clone(),
1531            },
1532            ChannelHook { tx },
1533        )
1534        .build();
1535        let shutdown = spawn_runtime(runtime.clone());
1536
1537        let handle = runtime
1538            .submit(RunSpec {
1539                input: b"x".to_vec(),
1540                ..Default::default()
1541            })
1542            .await
1543            .unwrap();
1544        tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1545            .await
1546            .expect("runner observed token");
1547
1548        let start = std::time::Instant::now();
1549        let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1550        assert!(was_cancelled);
1551
1552        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1553            .await
1554            .expect("hook fired well before the 30s sleep would have")
1555            .expect("hook channel open");
1556        let elapsed = start.elapsed();
1557
1558        assert_eq!(outcome.status, TerminalStatus::Cancelled);
1559        // Runner-issued Cancel wins precedence over external cancel, so
1560        // the runner's reason surfaces.
1561        assert_eq!(outcome.error.as_deref(), Some("cooperative"));
1562        assert!(
1563            elapsed < Duration::from_secs(2),
1564            "cooperative cancel must short-circuit the 30s sleep (took {elapsed:?})",
1565        );
1566        assert!(runtime.status(&handle.run_id).await.is_none());
1567
1568        let stats = queue.stats("workflow-steps").await.unwrap();
1569        assert_eq!(stats.dead, 0);
1570
1571        let _ = shutdown.send(());
1572    }
1573
1574    #[tokio::test]
1575    async fn double_cancel_fires_hook_once_and_second_call_returns_false() {
1576        // Submit a run and cancel twice while it sits pending. The first
1577        // call removes the queued step, fires the hook, and drops the
1578        // registry entry. The second call must see no entry and report
1579        // `Ok(false)`; crucially, the hook must NOT fire a second
1580        // time.
1581        struct UnreachableRunner;
1582        impl StepRunner for UnreachableRunner {
1583            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1584                unreachable!("worker must not claim the cancelled step");
1585            }
1586        }
1587
1588        let queue = fresh_queue().await;
1589        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1590        let runtime =
1591            WorkflowRuntime::builder(queue, UnreachableRunner, ChannelHook { tx }).build();
1592        // Deliberately do not spawn the worker loop, so step 0 stays
1593        // Pending while both cancels race.
1594
1595        let handle = runtime
1596            .submit(RunSpec {
1597                input: b"x".to_vec(),
1598                ..Default::default()
1599            })
1600            .await
1601            .unwrap();
1602
1603        let first = runtime.cancel(&handle.run_id).await.unwrap();
1604        assert!(first, "first cancel initiates termination");
1605
1606        let second = runtime.cancel(&handle.run_id).await.unwrap();
1607        assert!(
1608            !second,
1609            "second cancel must report Ok(false): registry entry is gone after the first fired the hook",
1610        );
1611
1612        // Hook fires exactly once.
1613        let _ = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1614            .await
1615            .expect("hook fired in time")
1616            .expect("hook channel open");
1617        tokio::time::sleep(Duration::from_millis(50)).await;
1618        assert!(
1619            rx.try_recv().is_err(),
1620            "hook must fire exactly once for a double-cancelled run",
1621        );
1622    }
1623
1624    #[tokio::test]
1625    async fn cancel_after_run_already_terminated_returns_false() {
1626        // Submit a run that succeeds normally, wait for the terminal
1627        // hook, then call `cancel`. The registry entry was removed when
1628        // the success hook fired, so `cancel` must report `Ok(false)`
1629        // and must not fire a second hook.
1630        let queue = fresh_queue().await;
1631        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1632        let runtime = WorkflowRuntime::builder(
1633            queue,
1634            ScriptedRunner::new(vec![StepOutcome::Succeed {
1635                result: b"done".to_vec(),
1636            }]),
1637            ChannelHook { tx },
1638        )
1639        .build();
1640        let shutdown = spawn_runtime(runtime.clone());
1641
1642        let handle = runtime
1643            .submit(RunSpec {
1644                input: b"x".to_vec(),
1645                ..Default::default()
1646            })
1647            .await
1648            .unwrap();
1649
1650        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1651            .await
1652            .expect("Succeeded hook fired")
1653            .expect("hook channel open");
1654        assert_eq!(outcome.status, TerminalStatus::Succeeded);
1655        assert!(runtime.status(&handle.run_id).await.is_none());
1656
1657        let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1658        assert!(
1659            !was_cancelled,
1660            "cancel on an already-terminated run must report Ok(false)",
1661        );
1662
1663        tokio::time::sleep(Duration::from_millis(50)).await;
1664        assert!(
1665            rx.try_recv().is_err(),
1666            "no Cancelled hook may fire after the run already terminated as Succeeded",
1667        );
1668
1669        let _ = shutdown.send(());
1670    }
1671
    #[tokio::test]
    async fn status_reports_cancelling_while_termination_in_flight() {
        // Once `cancel()` has been called but the terminal hook hasn't
        // fired yet, `status()` should report `RunState::Cancelling` so
        // external observers can see termination is in progress. A gated
        // runner holds the cancellation window open long enough to
        // observe it deterministically.
        struct GatedRunner {
            // Signals the test the moment the worker has claimed the step.
            claimed: Arc<tokio::sync::Notify>,
            // One-shot gate the step blocks on until the test releases it.
            // Wrapped in Mutex<Option<..>> so the single receiver can be
            // taken out of &self exactly once.
            gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
        }
        impl StepRunner for GatedRunner {
            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
                self.claimed.notify_one();
                let rx = self.gate.lock().await.take().expect("gate consumed twice");
                // Block here until gate_tx fires (or is dropped); the step
                // stays "in flight" for the whole cancellation window.
                let _ = rx.await;
                Ok(StepOutcome::Succeed {
                    result: b"would-have-succeeded".to_vec(),
                })
            }
        }

        let queue = fresh_queue().await;
        let claimed = Arc::new(tokio::sync::Notify::new());
        let (gate_tx, gate_rx) = oneshot::channel::<()>();
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let runtime = WorkflowRuntime::builder(
            queue,
            GatedRunner {
                claimed: claimed.clone(),
                gate: tokio::sync::Mutex::new(Some(gate_rx)),
            },
            ChannelHook { tx },
        )
        .build();
        let shutdown = spawn_runtime(runtime.clone());

        let handle = runtime
            .submit(RunSpec {
                input: b"x".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();
        // Don't call cancel() until the step is actually in flight.
        tokio::time::timeout(Duration::from_secs(2), claimed.notified())
            .await
            .expect("runner reached gate");

        // Before cancel: runner is in flight, state is Running.
        let before = runtime.status(&handle.run_id).await.expect("active");
        assert_eq!(before.state, RunState::Running);

        runtime.cancel(&handle.run_id).await.unwrap();

        // After cancel but before the gate is released: the step is still
        // in flight, but the cancellation overlay must dominate the
        // reported state.
        let during = runtime
            .status(&handle.run_id)
            .await
            .expect("entry retained while termination is in flight");
        assert_eq!(during.state, RunState::Cancelling);

        // Release the runner; the worker observes cancel_requested and
        // settles the run as Cancelled, removing the entry.
        let _ = gate_tx.send(());

        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("hook fired")
            .expect("hook channel open");
        assert_eq!(outcome.status, TerminalStatus::Cancelled);
        assert!(runtime.status(&handle.run_id).await.is_none());

        let _ = shutdown.send(());
    }
1748
1749    #[tokio::test]
1750    async fn cancel_unknown_run_returns_false() {
1751        let queue = fresh_queue().await;
1752        let runtime: WorkflowRuntime<ScriptedRunner, NoopTerminalHook> =
1753            WorkflowRuntime::builder(queue, ScriptedRunner::new(vec![]), NoopTerminalHook).build();
1754
1755        let was_cancelled = runtime.cancel("never-submitted").await.unwrap();
1756        assert!(!was_cancelled);
1757    }
1758
    #[tokio::test]
    async fn transient_fires_once_on_single_attempt() {
        // Boundary case of the shared transient-retry scenario: with only
        // a single attempt allowed, the transient failure is terminal on
        // the first try.
        assert_transient_retries_until_max(1).await;
    }
1763
    #[tokio::test]
    async fn transient_retries_up_to_max_attempts() {
        // General case of the shared transient-retry scenario: the step is
        // retried until the attempt budget (3 here) is exhausted.
        assert_transient_retries_until_max(3).await;
    }
1768}