1use std::collections::HashMap;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime, UNIX_EPOCH};
5
6use serde::{Deserialize, Serialize};
7use taquba::{
8 EnqueueOptions, EnqueueResult, JobRecord, PermanentFailure, Queue, Worker, WorkerError,
9};
10use tokio::sync::Mutex;
11use tokio_util::sync::CancellationToken;
12use tracing::{debug, instrument, warn};
13
14use crate::error::{Error, Result};
15use crate::runner::{Step, StepError, StepErrorKind, StepOutcome, StepRunner};
16use crate::terminal::{RunOutcome, TerminalHook, TerminalStatus};
17
/// Header key under which the runtime stores the owning run's id on every step job.
pub const HEADER_RUN_ID: &str = "workflow.run_id";
/// Header key under which the runtime stores the step number (as a decimal string).
pub const HEADER_STEP: &str = "workflow.step";
/// Prefix of header keys owned by the runtime; `submit` rejects caller headers that start with it.
pub const RESERVED_HEADER_PREFIX: &str = "workflow.";

/// Leading part of every step's dedup key, which is formatted as `run:{run_id}:{step_number}`.
const DEDUP_PREFIX: &str = "run:";
28
/// Key prefix of the durable per-run records kept in the queue's key-value store.
const RUN_KV_PREFIX: &[u8] = b"workflow/runs/";

/// Builds the key-value store key for `run_id`: [`RUN_KV_PREFIX`] followed by the
/// id's UTF-8 bytes.
fn run_kv_key(run_id: &str) -> Vec<u8> {
    [RUN_KV_PREFIX, run_id.as_bytes()].concat()
}
38
/// Per-run record persisted in the queue's key-value store under `run_kv_key(run_id)`.
///
/// `submit` writes it together with the step-0 job, and reads it back to recognise a
/// duplicate submission whose run is not in the in-memory registry. `terminate`
/// deletes it.
#[derive(Debug, Clone, Serialize, Deserialize)]
struct DurableRunRecord {
    /// Id of the run this record belongs to.
    run_id: String,
    /// Wall-clock submission time, in milliseconds since the Unix epoch.
    submitted_at_ms: u64,
    /// Hash of the submitted input (see `hash_input`); compared on duplicate submits.
    input_hash: [u8; 32],
}
57
58fn hash_input(input: &[u8]) -> [u8; 32] {
59 use sha2::{Digest, Sha256};
60 let mut hasher = Sha256::new();
61 hasher.update(input);
62 hasher.finalize().into()
63}
64
/// Per-step scheduling options that `enqueue_step` copies into the queue's
/// `EnqueueOptions`.
#[derive(Debug, Default)]
struct StepEnqueueOpts {
    /// Forwarded as `EnqueueOptions::run_at`; `process_step` sets it to
    /// `now + delay` for `ContinueAfter` outcomes and leaves it `None` otherwise.
    run_at: Option<SystemTime>,
    /// Forwarded as `EnqueueOptions::priority`.
    priority: Option<u32>,
    /// Forwarded as `EnqueueOptions::max_attempts`.
    max_attempts: Option<u32>,
}
78
/// Description of a run handed to `WorkflowRuntime::submit`.
#[derive(Debug, Clone, Default)]
pub struct RunSpec {
    /// Caller-chosen run id. When `None`, `submit` generates a fresh ULID string.
    pub run_id: Option<String>,
    /// Payload of the first step (step 0); its hash is used to detect duplicate
    /// submissions with a different input.
    pub input: Vec<u8>,
    /// Caller headers attached to the step job. Keys starting with
    /// `RESERVED_HEADER_PREFIX` are rejected by `submit`.
    pub headers: HashMap<String, String>,
    /// Forwarded as the priority of the step-0 job.
    pub priority: Option<u32>,
    /// Forwarded as the `max_attempts` of the step-0 job; later steps inherit the
    /// value carried by the job that produced them.
    pub max_attempts_per_step: Option<u32>,
}
98
/// Result of a successful `WorkflowRuntime::submit` call.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct SubmitOutcome {
    /// Id of the run, either the one supplied in the spec or the generated one.
    pub run_id: String,
    /// `true` when this call enqueued the run; `false` when an existing run with the
    /// same id (and the same input hash, where one was available to compare) was found.
    pub newly_submitted: bool,
}
115
/// Snapshot of a run as tracked by the in-memory registry; returned by
/// `WorkflowRuntime::status`.
#[derive(Debug, Clone)]
pub struct RunStatus {
    /// Id of the run.
    pub run_id: String,
    /// Current lifecycle state of the run.
    pub state: RunState,
    /// Number of the step that is currently pending or running.
    pub current_step: u32,
}
128
/// Lifecycle state of a run that is still present in the registry.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RunState {
    /// The current step has been enqueued but no worker has picked it up yet.
    Pending,
    /// A worker is executing the current step.
    Running,
    /// Cancellation has been requested; reported by `status` while the registry
    /// entry still exists.
    Cancelling,
}
151
/// Builder for [`WorkflowRuntime`], created by `WorkflowRuntime::builder`.
pub struct WorkflowRuntimeBuilder<R, H> {
    /// Queue that stores step jobs and the durable run records.
    queue: Arc<Queue>,
    /// Name of the queue that step jobs are enqueued to and consumed from.
    queue_name: String,
    /// Step runner invoked for every step job.
    runner: R,
    /// Hook invoked once when a run terminates.
    terminal_hook: H,
    /// Concurrency limit handed to the worker loop.
    max_concurrent_steps: usize,
    /// Poll interval handed to the worker loop.
    poll_interval: Duration,
}
165
166impl<R: StepRunner, H: TerminalHook> WorkflowRuntimeBuilder<R, H> {
167 pub fn queue_name(mut self, name: impl Into<String>) -> Self {
171 self.queue_name = name.into();
172 self
173 }
174
175 pub fn max_concurrent_steps(mut self, n: usize) -> Self {
178 assert!(n > 0, "max_concurrent_steps must be at least 1");
179 self.max_concurrent_steps = n;
180 self
181 }
182
183 pub fn poll_interval(mut self, interval: Duration) -> Self {
186 self.poll_interval = interval;
187 self
188 }
189
190 pub fn build(self) -> WorkflowRuntime<R, H> {
192 let inner = RuntimeInner {
193 queue: self.queue,
194 queue_name: self.queue_name,
195 runner: self.runner,
196 terminal_hook: self.terminal_hook,
197 max_concurrent_steps: self.max_concurrent_steps,
198 poll_interval: self.poll_interval,
199 registry: Mutex::new(HashMap::new()),
200 };
201 WorkflowRuntime {
202 inner: Arc::new(inner),
203 }
204 }
205}
206
/// Handle to the workflow runtime. Cloning it is cheap: all clones share the same
/// reference-counted inner state.
pub struct WorkflowRuntime<R, H> {
    /// Shared state: queue, runner, terminal hook, settings and run registry.
    inner: Arc<RuntimeInner<R, H>>,
}
211
212impl<R, H> Clone for WorkflowRuntime<R, H> {
213 fn clone(&self) -> Self {
214 Self {
215 inner: self.inner.clone(),
216 }
217 }
218}
219
/// State shared by every [`WorkflowRuntime`] handle and by the step worker.
struct RuntimeInner<R, H> {
    /// Queue that stores step jobs and the durable run records.
    queue: Arc<Queue>,
    /// Name of the queue that step jobs are enqueued to and consumed from.
    queue_name: String,
    /// Step runner invoked for every step job.
    runner: R,
    /// Hook invoked once when a run terminates.
    terminal_hook: H,
    /// Concurrency limit handed to the worker loop.
    max_concurrent_steps: usize,
    /// Poll interval handed to the worker loop.
    poll_interval: Duration,
    /// In-memory view of the runs this process knows about, keyed by run id.
    registry: Mutex<HashMap<String, RegistryEntry>>,
}
229
/// In-memory bookkeeping for one run that has not terminated yet.
struct RegistryEntry {
    /// Last known status; `status()` overlays `Cancelling` when `cancel_requested` is set.
    status: RunStatus,
    /// Id of the job for the current step; this is what `cancel` asks the queue to cancel.
    current_job_id: String,
    /// Caller-supplied headers (reserved keys excluded), reported to the terminal hook.
    user_headers: HashMap<String, String>,
    /// Set by `cancel`; checked by `process_step` once the runner returns.
    cancel_requested: bool,
    /// Token handed to the step runner through `Step::cancel_token`; cancelled by `cancel`.
    cancel_token: CancellationToken,
    /// Hash of the submitted input. `Some` for entries created by `submit`, `None` for
    /// entries created by `registry_mark_running` for a run not found in this registry.
    input_hash: Option<[u8; 32]>,
}
252
253impl<R: StepRunner, H: TerminalHook> WorkflowRuntime<R, H> {
254 pub fn builder(queue: Arc<Queue>, runner: R, terminal_hook: H) -> WorkflowRuntimeBuilder<R, H> {
262 WorkflowRuntimeBuilder {
263 queue,
264 queue_name: "workflow-steps".to_string(),
265 runner,
266 terminal_hook,
267 max_concurrent_steps: 16,
268 poll_interval: Duration::from_millis(250),
269 }
270 }
271
272 #[instrument(skip(self, spec), fields(run_id))]
283 pub async fn submit(&self, spec: RunSpec) -> Result<SubmitOutcome> {
284 let run_id = spec.run_id.unwrap_or_else(|| ulid::Ulid::new().to_string());
285 tracing::Span::current().record("run_id", run_id.as_str());
286
287 for k in spec.headers.keys() {
288 if k.starts_with(RESERVED_HEADER_PREFIX) {
289 return Err(Error::ReservedHeaderInSubmit(k.clone()));
290 }
291 }
292
293 let input_hash = hash_input(&spec.input);
294
295 let mut registry = self.inner.registry.lock().await;
300 if let Some(entry) = registry.get(&run_id) {
301 if let Some(existing) = entry.input_hash {
304 if existing != input_hash {
305 return Err(Error::InputMismatch(run_id));
306 }
307 return Ok(SubmitOutcome {
308 run_id,
309 newly_submitted: false,
310 });
311 }
312 }
313
314 if let Some(bytes) = self.inner.queue.kv_get(&run_kv_key(&run_id)).await? {
318 let existing: DurableRunRecord =
319 rmp_serde::from_slice(&bytes).map_err(taquba::Error::from)?;
320 if existing.input_hash != input_hash {
321 return Err(Error::InputMismatch(run_id));
322 }
323 return Ok(SubmitOutcome {
324 run_id,
325 newly_submitted: false,
326 });
327 }
328
329 let mut headers = spec.headers.clone();
330 headers.insert(HEADER_RUN_ID.to_string(), run_id.clone());
331 headers.insert(HEADER_STEP.to_string(), "0".to_string());
332 let enqueue_opts = EnqueueOptions {
333 headers,
334 run_at: None,
335 priority: spec.priority,
336 max_attempts: spec.max_attempts_per_step,
337 dedup_key: Some(format!("{DEDUP_PREFIX}{run_id}:0")),
338 };
339
340 let record_bytes = rmp_serde::to_vec_named(&DurableRunRecord {
341 run_id: run_id.clone(),
342 submitted_at_ms: SystemTime::now()
343 .duration_since(UNIX_EPOCH)
344 .unwrap_or_default()
345 .as_millis() as u64,
346 input_hash,
347 })
348 .map_err(taquba::Error::from)?;
349 let kv = HashMap::from([(run_kv_key(&run_id), record_bytes)]);
350
351 let job_id = match self
352 .inner
353 .queue
354 .enqueue_with_kv(&self.inner.queue_name, spec.input, enqueue_opts, kv)
355 .await?
356 {
357 EnqueueResult::New(id) => id,
358 EnqueueResult::AlreadyEnqueued(_) => {
365 return Ok(SubmitOutcome {
366 run_id,
367 newly_submitted: false,
368 });
369 }
370 };
371
372 registry.insert(
373 run_id.clone(),
374 RegistryEntry {
375 status: RunStatus {
376 run_id: run_id.clone(),
377 state: RunState::Pending,
378 current_step: 0,
379 },
380 current_job_id: job_id.clone(),
381 user_headers: spec.headers.clone(),
382 cancel_requested: false,
383 cancel_token: CancellationToken::new(),
384 input_hash: Some(input_hash),
385 },
386 );
387 drop(registry);
388
389 debug!(run_id = %run_id, job_id = %job_id, "run submitted");
390 Ok(SubmitOutcome {
391 run_id,
392 newly_submitted: true,
393 })
394 }
395
396 pub async fn status(&self, run_id: &str) -> Option<RunStatus> {
404 self.inner.registry.lock().await.get(run_id).map(|e| {
405 let mut status = e.status.clone();
406 if e.cancel_requested {
407 status.state = RunState::Cancelling;
408 }
409 status
410 })
411 }
412
413 pub async fn cancel(&self, run_id: &str) -> Result<bool> {
439 let (job_id, headers, current_step) = {
440 let mut registry = self.inner.registry.lock().await;
441 let Some(entry) = registry.get_mut(run_id) else {
442 return Ok(false);
443 };
444 entry.cancel_requested = true;
445 entry.cancel_token.cancel();
451 (
452 entry.current_job_id.clone(),
453 entry.user_headers.clone(),
454 entry.status.current_step,
455 )
456 };
457
458 match self.inner.queue.cancel(&job_id).await? {
459 taquba::CancelOutcome::Removed => {
460 self.inner
464 .terminate(RunOutcome {
465 run_id: run_id.to_string(),
466 status: TerminalStatus::Cancelled,
467 result: None,
468 error: None,
469 headers,
470 final_step: current_step,
471 })
472 .await;
473 }
474 taquba::CancelOutcome::Requested => {
475 }
479 taquba::CancelOutcome::NotFound => {
480 }
486 }
487 Ok(true)
488 }
489
490 pub async fn run<F>(&self, shutdown: F) -> Result<()>
493 where
494 F: Future<Output = ()>,
495 R: 'static,
496 H: 'static,
497 {
498 let worker = Arc::new(StepWorker {
499 inner: self.inner.clone(),
500 });
501 taquba::run_worker_concurrent(
502 &self.inner.queue,
503 &self.inner.queue_name,
504 worker,
505 self.inner.max_concurrent_steps,
506 self.inner.poll_interval,
507 shutdown,
508 )
509 .await?;
510 Ok(())
511 }
512}
513
/// Adapter that implements the queue's `Worker` trait on top of the shared runtime state.
struct StepWorker<R, H> {
    /// Shared runtime state that step jobs are dispatched to.
    inner: Arc<RuntimeInner<R, H>>,
}
517
518impl<R: StepRunner + 'static, H: TerminalHook + 'static> Worker for StepWorker<R, H> {
519 async fn process(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
520 self.inner.process_step(job).await
521 }
522}
523
524impl<R: StepRunner, H: TerminalHook> RuntimeInner<R, H> {
525 async fn enqueue_step(
526 &self,
527 run_id: &str,
528 step_number: u32,
529 payload: Vec<u8>,
530 user_headers: &HashMap<String, String>,
531 opts: StepEnqueueOpts,
532 ) -> Result<String> {
533 let mut headers = user_headers.clone();
534 headers.insert(HEADER_RUN_ID.to_string(), run_id.to_string());
535 headers.insert(HEADER_STEP.to_string(), step_number.to_string());
536
537 let enqueue_opts = EnqueueOptions {
538 headers,
539 run_at: opts.run_at,
540 priority: opts.priority,
541 max_attempts: opts.max_attempts,
542 dedup_key: Some(format!("{DEDUP_PREFIX}{run_id}:{step_number}")),
543 };
544 Ok(self
545 .queue
546 .enqueue_with(&self.queue_name, payload, enqueue_opts)
547 .await?)
548 }
549
550 fn split_headers(headers: &HashMap<String, String>) -> HashMap<String, String> {
551 headers
552 .iter()
553 .filter(|(k, _)| !k.starts_with(RESERVED_HEADER_PREFIX))
554 .map(|(k, v)| (k.clone(), v.clone()))
555 .collect()
556 }
557
558 fn parse_step_headers(job: &JobRecord) -> std::result::Result<(String, u32), Error> {
559 let run_id = job
560 .headers
561 .get(HEADER_RUN_ID)
562 .ok_or(Error::MissingHeader(HEADER_RUN_ID))?
563 .to_string();
564 let step_str = job
565 .headers
566 .get(HEADER_STEP)
567 .ok_or(Error::MissingHeader(HEADER_STEP))?;
568 let step_number: u32 = step_str.parse().map_err(|_| Error::InvalidStepHeader {
569 header: HEADER_STEP,
570 value: step_str.clone(),
571 })?;
572 Ok((run_id, step_number))
573 }
574
575 async fn terminate(&self, outcome: RunOutcome) {
585 self.registry.lock().await.remove(&outcome.run_id);
586 if let Err(err) = self.queue.kv_delete(&run_kv_key(&outcome.run_id)).await {
587 warn!(
588 run_id = %outcome.run_id,
589 "failed to clear durable run record: {err}"
590 );
591 }
592 self.terminal_hook.on_termination(&outcome).await;
593 }
594
595 async fn registry_mark_running(
603 &self,
604 run_id: &str,
605 step_number: u32,
606 job_id: &str,
607 user_headers: &HashMap<String, String>,
608 ) -> CancellationToken {
609 let mut registry = self.registry.lock().await;
610 match registry.get_mut(run_id) {
611 Some(entry) => {
612 entry.status.state = RunState::Running;
613 entry.status.current_step = step_number;
614 entry.current_job_id = job_id.to_string();
615 entry.cancel_token.clone()
616 }
617 None => {
618 let cancel_token = CancellationToken::new();
619 registry.insert(
620 run_id.to_string(),
621 RegistryEntry {
622 status: RunStatus {
623 run_id: run_id.to_string(),
624 state: RunState::Running,
625 current_step: step_number,
626 },
627 current_job_id: job_id.to_string(),
628 user_headers: user_headers.clone(),
629 cancel_requested: false,
630 cancel_token: cancel_token.clone(),
631 input_hash: None,
632 },
633 );
634 cancel_token
635 }
636 }
637 }
638
639 async fn process_step(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
640 let (run_id, step_number) = match Self::parse_step_headers(job) {
641 Ok(v) => v,
642 Err(e) => {
643 warn!(job_id = %job.id, error = %e, "workflow step has malformed headers");
644 if e.is_permanent() {
645 return Err(PermanentFailure::new(e.to_string()).into());
646 }
647 return Err(e.to_string().into());
648 }
649 };
650
651 let user_headers = Self::split_headers(&job.headers);
652
653 let cancel_token = self
654 .registry_mark_running(&run_id, step_number, &job.id, &user_headers)
655 .await;
656
657 let step = Step {
658 run_id: run_id.clone(),
659 step_number,
660 payload: job.payload.clone(),
661 headers: user_headers.clone(),
662 job_id: job.id.clone(),
663 attempts: job.attempts,
664 cancel_token,
665 };
666
667 let inherit_opts = || StepEnqueueOpts {
670 run_at: None,
671 priority: Some(job.priority),
672 max_attempts: Some(job.max_attempts),
673 };
674
675 let outcome = self.runner.run_step(&step).await;
676 let external_cancel = self
677 .registry
678 .lock()
679 .await
680 .get(&run_id)
681 .is_some_and(|e| e.cancel_requested);
682
683 match outcome {
691 Ok(StepOutcome::Cancel { reason }) => {
692 self.terminate(RunOutcome {
693 run_id: run_id.clone(),
694 status: TerminalStatus::Cancelled,
695 result: None,
696 error: Some(reason),
697 headers: user_headers,
698 final_step: step_number,
699 })
700 .await;
701 Ok(())
702 }
703 _ if external_cancel => {
704 self.terminate(RunOutcome {
705 run_id: run_id.clone(),
706 status: TerminalStatus::Cancelled,
707 result: None,
708 error: None,
709 headers: user_headers,
710 final_step: step_number,
711 })
712 .await;
713 Ok(())
714 }
715 Ok(StepOutcome::Continue { payload }) => {
716 self.advance(
717 &run_id,
718 step_number + 1,
719 payload,
720 &user_headers,
721 inherit_opts(),
722 )
723 .await
724 }
725 Ok(StepOutcome::ContinueAfter { payload, delay }) => {
726 let opts = StepEnqueueOpts {
727 run_at: Some(SystemTime::now() + delay),
728 ..inherit_opts()
729 };
730 self.advance(&run_id, step_number + 1, payload, &user_headers, opts)
731 .await
732 }
733 Ok(StepOutcome::Succeed { result }) => {
734 self.terminate(RunOutcome {
735 run_id: run_id.clone(),
736 status: TerminalStatus::Succeeded,
737 result: Some(result),
738 error: None,
739 headers: user_headers,
740 final_step: step_number,
741 })
742 .await;
743 Ok(())
744 }
745 Ok(StepOutcome::Fail { reason }) => {
746 self.terminate(RunOutcome {
750 run_id: run_id.clone(),
751 status: TerminalStatus::Failed,
752 result: None,
753 error: Some(reason),
754 headers: user_headers,
755 final_step: step_number,
756 })
757 .await;
758 Ok(())
759 }
760 Err(StepError {
761 message,
762 kind: StepErrorKind::Permanent,
763 }) => {
764 self.terminate(RunOutcome {
765 run_id: run_id.clone(),
766 status: TerminalStatus::Failed,
767 result: None,
768 error: Some(message.clone()),
769 headers: user_headers,
770 final_step: step_number,
771 })
772 .await;
773 Err(PermanentFailure::new(message).into())
774 }
775 Err(StepError {
776 message,
777 kind: StepErrorKind::Transient,
778 }) => {
779 if job.attempts >= job.max_attempts {
783 self.terminate(RunOutcome {
784 run_id: run_id.clone(),
785 status: TerminalStatus::Failed,
786 result: None,
787 error: Some(message.clone()),
788 headers: user_headers,
789 final_step: step_number,
790 })
791 .await;
792 }
793 Err(message.into())
794 }
795 }
796 }
797
798 async fn advance(
799 &self,
800 run_id: &str,
801 next_step: u32,
802 payload: Vec<u8>,
803 user_headers: &HashMap<String, String>,
804 opts: StepEnqueueOpts,
805 ) -> std::result::Result<(), WorkerError> {
806 match self
807 .enqueue_step(run_id, next_step, payload, user_headers, opts)
808 .await
809 {
810 Ok(new_job_id) => {
811 if let Some(entry) = self.registry.lock().await.get_mut(run_id) {
813 entry.status.state = RunState::Pending;
814 entry.status.current_step = next_step;
815 entry.current_job_id = new_job_id;
816 }
817 Ok(())
818 }
819 Err(e) => Err(e.to_string().into()),
823 }
824 }
825}
826
827#[cfg(test)]
828mod tests {
829 use super::*;
830 use crate::terminal::NoopTerminalHook;
831 use std::sync::Mutex as StdMutex;
832 use std::sync::atomic::{AtomicU32, Ordering};
833 use taquba::object_store::memory::InMemory;
834 use taquba::{OpenOptions, QueueConfig};
835 use tokio::sync::oneshot;
836
837 struct ChannelHook {
839 tx: tokio::sync::mpsc::UnboundedSender<RunOutcome>,
840 }
841
842 impl TerminalHook for ChannelHook {
843 async fn on_termination(&self, outcome: &RunOutcome) {
844 let _ = self.tx.send(outcome.clone());
845 }
846 }
847
848 struct ScriptedRunner {
850 script: Arc<StdMutex<Vec<StepOutcome>>>,
851 }
852
853 impl ScriptedRunner {
854 fn new(steps: Vec<StepOutcome>) -> Self {
855 Self {
856 script: Arc::new(StdMutex::new(steps)),
857 }
858 }
859 }
860
861 impl StepRunner for ScriptedRunner {
862 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
863 let next = self.script.lock().unwrap().remove(0);
864 Ok(next)
865 }
866 }
867
868 async fn fresh_queue() -> Arc<Queue> {
869 Arc::new(
870 Queue::open(Arc::new(InMemory::new()), "test")
871 .await
872 .unwrap(),
873 )
874 }
875
876 async fn fresh_queue_fast_retry() -> Arc<Queue> {
879 let opts = OpenOptions {
880 default_queue_config: QueueConfig {
881 retry_backoff_base: Duration::ZERO,
882 ..QueueConfig::default()
883 },
884 reaper_interval: Duration::from_millis(50),
885 scheduler_interval: Duration::from_millis(50),
886 ..OpenOptions::default()
887 };
888 Arc::new(
889 Queue::open_with_options(Arc::new(InMemory::new()), "test", opts)
890 .await
891 .unwrap(),
892 )
893 }
894
895 fn spawn_runtime<R, H>(runtime: WorkflowRuntime<R, H>) -> oneshot::Sender<()>
896 where
897 R: StepRunner + 'static,
898 H: TerminalHook + 'static,
899 {
900 let (tx, rx) = oneshot::channel::<()>();
901 tokio::spawn(async move {
902 let _ = runtime
903 .run(async move {
904 let _ = rx.await;
905 })
906 .await;
907 });
908 tx
909 }
910
911 #[tokio::test(start_paused = true)]
912 async fn single_step_succeeds_and_fires_hook() {
913 let queue = fresh_queue().await;
914 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
915 let runtime = WorkflowRuntime::builder(
916 queue,
917 ScriptedRunner::new(vec![StepOutcome::Succeed {
918 result: b"done".to_vec(),
919 }]),
920 ChannelHook { tx },
921 )
922 .build();
923 let shutdown = spawn_runtime(runtime.clone());
924
925 let handle = runtime
926 .submit(RunSpec {
927 input: b"in".to_vec(),
928 ..Default::default()
929 })
930 .await
931 .unwrap();
932 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
933 .await
934 .unwrap()
935 .unwrap();
936
937 assert_eq!(outcome.run_id, handle.run_id);
938 assert_eq!(outcome.status, TerminalStatus::Succeeded);
939 assert_eq!(outcome.result.as_deref(), Some(b"done".as_slice()));
940 assert_eq!(outcome.final_step, 0);
941 assert!(runtime.status(&handle.run_id).await.is_none());
942
943 let _ = shutdown.send(());
944 }
945
946 #[tokio::test(start_paused = true)]
947 async fn multi_step_run_advances_through_continue() {
948 let queue = fresh_queue().await;
949 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
950 let runtime = WorkflowRuntime::builder(
951 queue,
952 ScriptedRunner::new(vec![
953 StepOutcome::Continue {
954 payload: b"step1".to_vec(),
955 },
956 StepOutcome::Continue {
957 payload: b"step2".to_vec(),
958 },
959 StepOutcome::Succeed {
960 result: b"final".to_vec(),
961 },
962 ]),
963 ChannelHook { tx },
964 )
965 .build();
966 let shutdown = spawn_runtime(runtime.clone());
967
968 let handle = runtime
969 .submit(RunSpec {
970 input: b"start".to_vec(),
971 ..Default::default()
972 })
973 .await
974 .unwrap();
975 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
976 .await
977 .unwrap()
978 .unwrap();
979
980 assert_eq!(outcome.run_id, handle.run_id);
981 assert_eq!(outcome.final_step, 2);
982 assert_eq!(outcome.status, TerminalStatus::Succeeded);
983 assert_eq!(outcome.result.as_deref(), Some(b"final".as_slice()));
984
985 let _ = shutdown.send(());
986 }
987
988 #[tokio::test(start_paused = true)]
989 async fn permanent_failure_dead_letters_and_fires_hook() {
990 struct FailingRunner;
991 impl StepRunner for FailingRunner {
992 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
993 Err(StepError::permanent("nope"))
994 }
995 }
996
997 let queue = fresh_queue().await;
998 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
999 let runtime =
1000 WorkflowRuntime::builder(queue.clone(), FailingRunner, ChannelHook { tx }).build();
1001 let shutdown = spawn_runtime(runtime.clone());
1002
1003 let handle = runtime
1004 .submit(RunSpec {
1005 input: b"x".to_vec(),
1006 ..Default::default()
1007 })
1008 .await
1009 .unwrap();
1010 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1011 .await
1012 .unwrap()
1013 .unwrap();
1014
1015 assert_eq!(outcome.run_id, handle.run_id);
1016 assert_eq!(outcome.status, TerminalStatus::Failed);
1017 assert_eq!(outcome.error.as_deref(), Some("nope"));
1018 assert!(runtime.status(&handle.run_id).await.is_none());
1019
1020 let stats = queue.stats("workflow-steps").await.unwrap();
1022 assert_eq!(stats.dead, 1, "permanent error should dead-letter");
1023
1024 let _ = shutdown.send(());
1025 }
1026
1027 #[tokio::test(start_paused = true)]
1028 async fn fail_outcome_terminates_run_without_dead_letter() {
1029 struct VerdictRunner;
1034 impl StepRunner for VerdictRunner {
1035 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1036 Ok(StepOutcome::Fail {
1037 reason: "agent declined the task".to_string(),
1038 })
1039 }
1040 }
1041
1042 let queue = fresh_queue().await;
1043 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1044 let runtime =
1045 WorkflowRuntime::builder(queue.clone(), VerdictRunner, ChannelHook { tx }).build();
1046 let shutdown = spawn_runtime(runtime.clone());
1047
1048 let handle = runtime
1049 .submit(RunSpec {
1050 input: b"x".to_vec(),
1051 ..Default::default()
1052 })
1053 .await
1054 .unwrap();
1055
1056 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1057 .await
1058 .expect("hook fired in time")
1059 .expect("hook channel open");
1060
1061 assert_eq!(outcome.run_id, handle.run_id);
1062 assert_eq!(outcome.status, TerminalStatus::Failed);
1063 assert_eq!(outcome.error.as_deref(), Some("agent declined the task"));
1064 assert!(runtime.status(&handle.run_id).await.is_none());
1065
1066 let stats = queue.stats("workflow-steps").await.unwrap();
1069 assert_eq!(stats.dead, 0, "Fail verdict must not dead-letter");
1070
1071 let _ = shutdown.send(());
1072 }
1073
1074 #[tokio::test(start_paused = true)]
1075 async fn duplicate_submit_in_process_with_same_input_is_idempotent() {
1076 struct PauseRunner;
1079 impl StepRunner for PauseRunner {
1080 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1081 std::future::pending().await
1082 }
1083 }
1084
1085 let queue = fresh_queue().await;
1086 let runtime = WorkflowRuntime::builder(queue, PauseRunner, NoopTerminalHook).build();
1087 let shutdown = spawn_runtime(runtime.clone());
1088
1089 let handle = runtime
1090 .submit(RunSpec {
1091 run_id: Some("fixed-id".to_string()),
1092 input: b"x".to_vec(),
1093 ..Default::default()
1094 })
1095 .await
1096 .unwrap();
1097 for _ in 0..40 {
1100 if runtime.status(&handle.run_id).await.is_some() {
1101 break;
1102 }
1103 tokio::time::sleep(Duration::from_millis(25)).await;
1104 }
1105 assert!(runtime.status(&handle.run_id).await.is_some());
1106
1107 let outcome = runtime
1108 .submit(RunSpec {
1109 run_id: Some("fixed-id".to_string()),
1110 input: b"x".to_vec(),
1111 ..Default::default()
1112 })
1113 .await
1114 .unwrap();
1115 assert_eq!(outcome.run_id, "fixed-id");
1116 assert!(!outcome.newly_submitted);
1117
1118 let _ = shutdown.send(());
1119 }
1120
1121 #[tokio::test(start_paused = true)]
1122 async fn duplicate_submit_in_process_with_different_input_errors() {
1123 struct PauseRunner;
1124 impl StepRunner for PauseRunner {
1125 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1126 std::future::pending().await
1127 }
1128 }
1129
1130 let queue = fresh_queue().await;
1131 let runtime = WorkflowRuntime::builder(queue, PauseRunner, NoopTerminalHook).build();
1132 let shutdown = spawn_runtime(runtime.clone());
1133
1134 runtime
1135 .submit(RunSpec {
1136 run_id: Some("fixed-id".to_string()),
1137 input: b"x".to_vec(),
1138 ..Default::default()
1139 })
1140 .await
1141 .unwrap();
1142
1143 let err = runtime
1144 .submit(RunSpec {
1145 run_id: Some("fixed-id".to_string()),
1146 input: b"y".to_vec(),
1147 ..Default::default()
1148 })
1149 .await
1150 .unwrap_err();
1151 assert!(matches!(&err, Error::InputMismatch(id) if id == "fixed-id"));
1152 assert!(err.is_permanent());
1153
1154 let _ = shutdown.send(());
1155 }
1156
1157 #[tokio::test(start_paused = true)]
1158 async fn duplicate_submit_across_runtime_restart_with_same_input_is_idempotent() {
1159 struct PauseRunner;
1166 impl StepRunner for PauseRunner {
1167 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1168 std::future::pending().await
1169 }
1170 }
1171
1172 let queue = fresh_queue().await;
1173
1174 {
1177 let runtime =
1178 WorkflowRuntime::builder(queue.clone(), PauseRunner, NoopTerminalHook).build();
1179 runtime
1180 .submit(RunSpec {
1181 run_id: Some("durable-id".to_string()),
1182 input: b"x".to_vec(),
1183 ..Default::default()
1184 })
1185 .await
1186 .unwrap();
1187 }
1188
1189 assert!(
1191 queue
1192 .kv_get(&run_kv_key("durable-id"))
1193 .await
1194 .unwrap()
1195 .is_some(),
1196 "durable run record must persist past runtime drop"
1197 );
1198
1199 let runtime2 =
1202 WorkflowRuntime::builder(queue.clone(), PauseRunner, NoopTerminalHook).build();
1203 let outcome = runtime2
1204 .submit(RunSpec {
1205 run_id: Some("durable-id".to_string()),
1206 input: b"x".to_vec(),
1207 ..Default::default()
1208 })
1209 .await
1210 .unwrap();
1211 assert_eq!(outcome.run_id, "durable-id");
1212 assert!(!outcome.newly_submitted);
1213 }
1214
1215 #[tokio::test(start_paused = true)]
1216 async fn duplicate_submit_across_runtime_restart_with_different_input_errors() {
1217 struct PauseRunner;
1221 impl StepRunner for PauseRunner {
1222 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1223 std::future::pending().await
1224 }
1225 }
1226
1227 let queue = fresh_queue().await;
1228
1229 {
1230 let runtime =
1231 WorkflowRuntime::builder(queue.clone(), PauseRunner, NoopTerminalHook).build();
1232 runtime
1233 .submit(RunSpec {
1234 run_id: Some("durable-id".to_string()),
1235 input: b"x".to_vec(),
1236 ..Default::default()
1237 })
1238 .await
1239 .unwrap();
1240 }
1241
1242 let runtime2 =
1243 WorkflowRuntime::builder(queue.clone(), PauseRunner, NoopTerminalHook).build();
1244 let err = runtime2
1245 .submit(RunSpec {
1246 run_id: Some("durable-id".to_string()),
1247 input: b"y".to_vec(),
1248 ..Default::default()
1249 })
1250 .await
1251 .unwrap_err();
1252 assert!(matches!(&err, Error::InputMismatch(id) if id == "durable-id"));
1253 }
1254
1255 #[tokio::test(start_paused = true)]
1256 async fn reserved_header_on_submit_is_rejected() {
1257 let queue = fresh_queue().await;
1258 let runtime: WorkflowRuntime<ScriptedRunner, NoopTerminalHook> =
1259 WorkflowRuntime::builder(queue, ScriptedRunner::new(vec![]), NoopTerminalHook).build();
1260 let mut headers = HashMap::new();
1261 headers.insert("workflow.run_id".to_string(), "evil".to_string());
1262
1263 let err = runtime
1264 .submit(RunSpec {
1265 input: b"x".to_vec(),
1266 headers,
1267 ..Default::default()
1268 })
1269 .await
1270 .unwrap_err();
1271 assert!(
1272 matches!(&err, Error::ReservedHeaderInSubmit(k) if k == "workflow.run_id"),
1273 "got: {err:?}"
1274 );
1275 }
1276
1277 #[tokio::test(start_paused = true)]
1278 async fn user_headers_thread_through_to_terminal_hook() {
1279 let queue = fresh_queue().await;
1280 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1281 let runtime = WorkflowRuntime::builder(
1282 queue,
1283 ScriptedRunner::new(vec![
1284 StepOutcome::Continue { payload: vec![] },
1285 StepOutcome::Succeed { result: vec![] },
1286 ]),
1287 ChannelHook { tx },
1288 )
1289 .build();
1290 let shutdown = spawn_runtime(runtime.clone());
1291
1292 let mut headers = HashMap::new();
1293 headers.insert("trace_id".to_string(), "abc-123".to_string());
1294 headers.insert("tenant".to_string(), "acme".to_string());
1295
1296 runtime
1297 .submit(RunSpec {
1298 input: b"x".to_vec(),
1299 headers,
1300 ..Default::default()
1301 })
1302 .await
1303 .unwrap();
1304 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1305 .await
1306 .unwrap()
1307 .unwrap();
1308
1309 assert_eq!(outcome.headers.get("trace_id").unwrap(), "abc-123");
1310 assert_eq!(outcome.headers.get("tenant").unwrap(), "acme");
1311 assert!(!outcome.headers.contains_key(HEADER_RUN_ID));
1313 assert!(!outcome.headers.contains_key(HEADER_STEP));
1314
1315 let _ = shutdown.send(());
1316 }
1317
    // Scenario: runtime A starts step 0 and is shut down while the step is in
    // flight; the step then completes with `Continue`. Runtime B, started later on
    // the same queue, must pick the run up at step 1 with the payload step 0 produced.
    #[tokio::test(start_paused = true)]
    async fn restart_resumes_at_next_step() {
        // Runner for runtime A: step 0 blocks until the test sends a payload through
        // the gate and then continues with it; any other step never completes.
        struct GatedRunner {
            gate: tokio::sync::Mutex<Option<oneshot::Receiver<Vec<u8>>>>,
        }

        impl StepRunner for GatedRunner {
            async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
                match step.step_number {
                    0 => {
                        let rx = self.gate.lock().await.take().expect("gate consumed twice");
                        let payload = rx.await.expect("gate sender dropped");
                        Ok(StepOutcome::Continue { payload })
                    }
                    _ => std::future::pending().await,
                }
            }
        }

        // Runner for runtime B: asserts it only sees step 1 with the expected
        // payload, then succeeds.
        struct CompleteOnStep1;
        impl StepRunner for CompleteOnStep1 {
            async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
                assert_eq!(step.step_number, 1, "runtime B should only ever see step 1");
                assert_eq!(step.payload.as_slice(), b"step1-payload");
                Ok(StepOutcome::Succeed {
                    result: b"resumed".to_vec(),
                })
            }
        }

        let queue = fresh_queue().await;

        let (gate_tx, gate_rx) = oneshot::channel::<Vec<u8>>();
        let runtime_a = WorkflowRuntime::builder(
            queue.clone(),
            GatedRunner {
                gate: tokio::sync::Mutex::new(Some(gate_rx)),
            },
            NoopTerminalHook,
        )
        .max_concurrent_steps(1)
        .build();

        // Runtime A is spawned by hand (not via `spawn_runtime`) so its task handle
        // can be awaited below.
        let (shutdown_a_tx, shutdown_a_rx) = oneshot::channel::<()>();
        let worker_a = {
            let runtime_a = runtime_a.clone();
            tokio::spawn(async move {
                let _ = runtime_a
                    .run(async move {
                        let _ = shutdown_a_rx.await;
                    })
                    .await;
            })
        };

        let handle = runtime_a
            .submit(RunSpec {
                input: b"input".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();

        // Poll (up to 80 times, 25 ms apart) until step 0 is reported as running.
        for _ in 0..80 {
            if let Some(s) = runtime_a.status(&handle.run_id).await {
                if s.state == RunState::Running && s.current_step == 0 {
                    break;
                }
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
        let s = runtime_a.status(&handle.run_id).await.expect("status");
        assert_eq!(s.state, RunState::Running);
        assert_eq!(s.current_step, 0);

        // Order matters: shutdown is requested first, and only then is the gate
        // opened so that step 0 can finish.
        let _ = shutdown_a_tx.send(());
        let _ = gate_tx.send(b"step1-payload".to_vec());

        worker_a.await.expect("runtime A drained cleanly");

        // Runtime B shares the queue but starts with an empty registry.
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let runtime_b =
            WorkflowRuntime::builder(queue, CompleteOnStep1, ChannelHook { tx }).build();
        let shutdown_b = spawn_runtime(runtime_b.clone());

        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("hook fired in time")
            .expect("hook channel open");

        assert_eq!(outcome.run_id, handle.run_id);
        assert_eq!(outcome.status, TerminalStatus::Succeeded);
        assert_eq!(outcome.result.as_deref(), Some(b"resumed".as_slice()));
        assert_eq!(outcome.final_step, 1);

        let _ = shutdown_b.send(());
    }
1431
1432 async fn assert_transient_retries_until_max(max_attempts: u32) {
1438 struct AlwaysTransient {
1439 calls: Arc<AtomicU32>,
1440 }
1441 impl StepRunner for AlwaysTransient {
1442 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1443 self.calls.fetch_add(1, Ordering::SeqCst);
1444 Err(StepError::transient("flaky"))
1445 }
1446 }
1447
1448 let queue = fresh_queue_fast_retry().await;
1449 let calls = Arc::new(AtomicU32::new(0));
1450 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1451 let runtime = WorkflowRuntime::builder(
1452 queue,
1453 AlwaysTransient {
1454 calls: calls.clone(),
1455 },
1456 ChannelHook { tx },
1457 )
1458 .build();
1459 let shutdown = spawn_runtime(runtime.clone());
1460
1461 runtime
1462 .submit(RunSpec {
1463 input: b"x".to_vec(),
1464 max_attempts_per_step: Some(max_attempts),
1465 ..Default::default()
1466 })
1467 .await
1468 .unwrap();
1469
1470 let outcome = tokio::time::timeout(Duration::from_secs(3), rx.recv())
1471 .await
1472 .expect("hook fired in time")
1473 .expect("hook channel open");
1474
1475 assert_eq!(outcome.status, TerminalStatus::Failed);
1476 assert_eq!(outcome.error.as_deref(), Some("flaky"));
1477 assert_eq!(
1478 calls.load(Ordering::SeqCst),
1479 max_attempts,
1480 "runner called once per attempt up to max_attempts"
1481 );
1482
1483 tokio::time::sleep(Duration::from_millis(50)).await;
1485 assert!(rx.try_recv().is_err(), "hook fired more than once");
1486
1487 let _ = shutdown.send(());
1488 }
1489
1490 #[tokio::test(start_paused = true)]
1491 async fn cancel_outcome_terminates_run_without_dead_letter() {
1492 struct CancellingRunner;
1496 impl StepRunner for CancellingRunner {
1497 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1498 Ok(StepOutcome::Cancel {
1499 reason: "upstream aborted".to_string(),
1500 })
1501 }
1502 }
1503
1504 let queue = fresh_queue().await;
1505 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1506 let runtime =
1507 WorkflowRuntime::builder(queue.clone(), CancellingRunner, ChannelHook { tx }).build();
1508 let shutdown = spawn_runtime(runtime.clone());
1509
1510 let handle = runtime
1511 .submit(RunSpec {
1512 input: b"x".to_vec(),
1513 ..Default::default()
1514 })
1515 .await
1516 .unwrap();
1517
1518 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1519 .await
1520 .expect("hook fired in time")
1521 .expect("hook channel open");
1522
1523 assert_eq!(outcome.run_id, handle.run_id);
1524 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1525 assert_eq!(outcome.error.as_deref(), Some("upstream aborted"));
1526 assert!(runtime.status(&handle.run_id).await.is_none());
1527
1528 let stats = queue.stats("workflow-steps").await.unwrap();
1529 assert_eq!(stats.dead, 0, "Cancel verdict must not dead-letter");
1530
1531 let _ = shutdown.send(());
1532 }
1533
1534 #[tokio::test(start_paused = true)]
1535 async fn cancel_pending_run_fires_cancelled_hook() {
1536 struct UnreachableRunner;
1539 impl StepRunner for UnreachableRunner {
1540 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1541 unreachable!("worker must not claim the cancelled step");
1542 }
1543 }
1544
1545 let queue = fresh_queue().await;
1546 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1547 let runtime =
1548 WorkflowRuntime::builder(queue.clone(), UnreachableRunner, ChannelHook { tx }).build();
1549 let mut headers = HashMap::new();
1553 headers.insert("tenant".to_string(), "acme".to_string());
1554
1555 let handle = runtime
1556 .submit(RunSpec {
1557 input: b"x".to_vec(),
1558 headers,
1559 ..Default::default()
1560 })
1561 .await
1562 .unwrap();
1563 let status = runtime.status(&handle.run_id).await.expect("active");
1564 assert_eq!(status.state, RunState::Pending);
1565
1566 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1567 assert!(was_cancelled);
1568
1569 let outcome = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1570 .await
1571 .expect("hook fired in time")
1572 .expect("hook channel open");
1573 assert_eq!(outcome.run_id, handle.run_id);
1574 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1575 assert!(outcome.error.is_none());
1577 assert_eq!(outcome.headers.get("tenant").unwrap(), "acme");
1578 assert!(runtime.status(&handle.run_id).await.is_none());
1579
1580 let stats = queue.stats("workflow-steps").await.unwrap();
1581 assert_eq!(stats.dead, 0, "cancel must not dead-letter");
1582 assert_eq!(stats.pending, 0, "cancelled job must be removed");
1583 }
1584
1585 #[tokio::test(start_paused = true)]
1586 async fn cancel_during_running_step_overrides_outcome() {
1587 struct GatedRunner {
1590 claimed: Arc<tokio::sync::Notify>,
1591 gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
1592 }
1593 impl StepRunner for GatedRunner {
1594 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1595 self.claimed.notify_one();
1596 let rx = self.gate.lock().await.take().expect("gate consumed twice");
1597 let _ = rx.await;
1598 Ok(StepOutcome::Succeed {
1602 result: b"would-have-succeeded".to_vec(),
1603 })
1604 }
1605 }
1606
1607 let queue = fresh_queue().await;
1608 let claimed = Arc::new(tokio::sync::Notify::new());
1609 let (gate_tx, gate_rx) = oneshot::channel::<()>();
1610 let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
1611 let runtime = WorkflowRuntime::builder(
1612 queue.clone(),
1613 GatedRunner {
1614 claimed: claimed.clone(),
1615 gate: tokio::sync::Mutex::new(Some(gate_rx)),
1616 },
1617 ChannelHook { tx: hook_tx },
1618 )
1619 .build();
1620 let shutdown = spawn_runtime(runtime.clone());
1621
1622 let handle = runtime
1623 .submit(RunSpec {
1624 input: b"x".to_vec(),
1625 ..Default::default()
1626 })
1627 .await
1628 .unwrap();
1629 tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1630 .await
1631 .expect("runner reached gate");
1632
1633 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1634 assert!(was_cancelled);
1635
1636 let _ = gate_tx.send(());
1639
1640 let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
1641 .await
1642 .expect("hook fired")
1643 .expect("hook channel open");
1644 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1645 assert!(
1646 outcome.result.is_none(),
1647 "succeed payload must be discarded"
1648 );
1649 assert!(runtime.status(&handle.run_id).await.is_none());
1650
1651 let stats = queue.stats("workflow-steps").await.unwrap();
1652 assert_eq!(stats.dead, 0);
1653
1654 let _ = shutdown.send(());
1655 }
1656
1657 async fn assert_cancel_suppresses_runner_error(error: StepError) {
1665 struct GatedErrRunner {
1666 claimed: Arc<tokio::sync::Notify>,
1667 gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
1668 calls: Arc<AtomicU32>,
1669 error: StdMutex<Option<StepError>>,
1670 }
1671 impl StepRunner for GatedErrRunner {
1672 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1673 self.calls.fetch_add(1, Ordering::SeqCst);
1674 self.claimed.notify_one();
1675 let rx = self.gate.lock().await.take().expect("gate consumed twice");
1676 let _ = rx.await;
1677 Err(self
1678 .error
1679 .lock()
1680 .unwrap()
1681 .take()
1682 .expect("error consumed twice"))
1683 }
1684 }
1685
1686 let queue = fresh_queue_fast_retry().await;
1687 let claimed = Arc::new(tokio::sync::Notify::new());
1688 let calls = Arc::new(AtomicU32::new(0));
1689 let (gate_tx, gate_rx) = oneshot::channel::<()>();
1690 let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
1691 let runtime = WorkflowRuntime::builder(
1692 queue.clone(),
1693 GatedErrRunner {
1694 claimed: claimed.clone(),
1695 gate: tokio::sync::Mutex::new(Some(gate_rx)),
1696 calls: calls.clone(),
1697 error: StdMutex::new(Some(error)),
1698 },
1699 ChannelHook { tx: hook_tx },
1700 )
1701 .build();
1702 let shutdown = spawn_runtime(runtime.clone());
1703
1704 let handle = runtime
1705 .submit(RunSpec {
1706 input: b"x".to_vec(),
1707 ..Default::default()
1708 })
1709 .await
1710 .unwrap();
1711 tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1712 .await
1713 .expect("runner reached gate");
1714
1715 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1716 assert!(was_cancelled);
1717
1718 let _ = gate_tx.send(());
1722
1723 let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
1724 .await
1725 .expect("hook fired")
1726 .expect("hook channel open");
1727 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1728 assert!(
1729 outcome.error.is_none(),
1730 "external cancel must carry no reason (Some(_) would imply runner-issued StepOutcome::Cancel)",
1731 );
1732 assert!(runtime.status(&handle.run_id).await.is_none());
1733
1734 tokio::time::sleep(Duration::from_millis(100)).await;
1737 assert_eq!(
1738 calls.load(Ordering::SeqCst),
1739 1,
1740 "cancellation must suppress retries",
1741 );
1742 let stats = queue.stats("workflow-steps").await.unwrap();
1743 assert_eq!(stats.dead, 0, "cancellation must suppress dead-letter");
1744 assert!(
1745 hook_rx.try_recv().is_err(),
1746 "hook must fire exactly once for the cancelled run",
1747 );
1748
1749 let _ = shutdown.send(());
1750 }
1751
1752 #[tokio::test(start_paused = true)]
1753 async fn cancel_suppresses_permanent_runner_error() {
1754 assert_cancel_suppresses_runner_error(StepError::permanent("would-dead-letter")).await;
1759 }
1760
1761 #[tokio::test(start_paused = true)]
1762 async fn cancel_suppresses_transient_runner_error() {
1763 assert_cancel_suppresses_runner_error(StepError::transient("would-retry")).await;
1768 }
1769
1770 #[tokio::test(start_paused = true)]
1771 async fn cancel_signals_step_token_for_cooperative_short_circuit() {
1772 struct CooperativeRunner {
1780 claimed: Arc<tokio::sync::Notify>,
1781 }
1782 impl StepRunner for CooperativeRunner {
1783 async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
1784 self.claimed.notify_one();
1785 tokio::select! {
1786 _ = tokio::time::sleep(Duration::from_secs(30)) => {
1787 Ok(StepOutcome::Succeed { result: b"slow".to_vec() })
1788 }
1789 _ = step.cancel_token.cancelled() => {
1790 Ok(StepOutcome::Cancel { reason: "cooperative".to_string() })
1791 }
1792 }
1793 }
1794 }
1795
1796 let queue = fresh_queue().await;
1797 let claimed = Arc::new(tokio::sync::Notify::new());
1798 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1799 let runtime = WorkflowRuntime::builder(
1800 queue.clone(),
1801 CooperativeRunner {
1802 claimed: claimed.clone(),
1803 },
1804 ChannelHook { tx },
1805 )
1806 .build();
1807 let shutdown = spawn_runtime(runtime.clone());
1808
1809 let handle = runtime
1810 .submit(RunSpec {
1811 input: b"x".to_vec(),
1812 ..Default::default()
1813 })
1814 .await
1815 .unwrap();
1816 tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1817 .await
1818 .expect("runner observed token");
1819
1820 let start = std::time::Instant::now();
1821 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1822 assert!(was_cancelled);
1823
1824 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1825 .await
1826 .expect("hook fired well before the 30s sleep would have")
1827 .expect("hook channel open");
1828 let elapsed = start.elapsed();
1829
1830 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1831 assert_eq!(outcome.error.as_deref(), Some("cooperative"));
1834 assert!(
1835 elapsed < Duration::from_secs(2),
1836 "cooperative cancel must short-circuit the 30s sleep (took {elapsed:?})",
1837 );
1838 assert!(runtime.status(&handle.run_id).await.is_none());
1839
1840 let stats = queue.stats("workflow-steps").await.unwrap();
1841 assert_eq!(stats.dead, 0);
1842
1843 let _ = shutdown.send(());
1844 }
1845
1846 #[tokio::test(start_paused = true)]
1847 async fn double_cancel_fires_hook_once_and_second_call_returns_false() {
1848 struct UnreachableRunner;
1854 impl StepRunner for UnreachableRunner {
1855 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1856 unreachable!("worker must not claim the cancelled step");
1857 }
1858 }
1859
1860 let queue = fresh_queue().await;
1861 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1862 let runtime =
1863 WorkflowRuntime::builder(queue, UnreachableRunner, ChannelHook { tx }).build();
1864 let handle = runtime
1868 .submit(RunSpec {
1869 input: b"x".to_vec(),
1870 ..Default::default()
1871 })
1872 .await
1873 .unwrap();
1874
1875 let first = runtime.cancel(&handle.run_id).await.unwrap();
1876 assert!(first, "first cancel initiates termination");
1877
1878 let second = runtime.cancel(&handle.run_id).await.unwrap();
1879 assert!(
1880 !second,
1881 "second cancel must report Ok(false): registry entry is gone after the first fired the hook",
1882 );
1883
1884 let _ = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1886 .await
1887 .expect("hook fired in time")
1888 .expect("hook channel open");
1889 tokio::time::sleep(Duration::from_millis(50)).await;
1890 assert!(
1891 rx.try_recv().is_err(),
1892 "hook must fire exactly once for a double-cancelled run",
1893 );
1894 }
1895
1896 #[tokio::test(start_paused = true)]
1897 async fn cancel_after_run_already_terminated_returns_false() {
1898 let queue = fresh_queue().await;
1903 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1904 let runtime = WorkflowRuntime::builder(
1905 queue,
1906 ScriptedRunner::new(vec![StepOutcome::Succeed {
1907 result: b"done".to_vec(),
1908 }]),
1909 ChannelHook { tx },
1910 )
1911 .build();
1912 let shutdown = spawn_runtime(runtime.clone());
1913
1914 let handle = runtime
1915 .submit(RunSpec {
1916 input: b"x".to_vec(),
1917 ..Default::default()
1918 })
1919 .await
1920 .unwrap();
1921
1922 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1923 .await
1924 .expect("Succeeded hook fired")
1925 .expect("hook channel open");
1926 assert_eq!(outcome.status, TerminalStatus::Succeeded);
1927 assert!(runtime.status(&handle.run_id).await.is_none());
1928
1929 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1930 assert!(
1931 !was_cancelled,
1932 "cancel on an already-terminated run must report Ok(false)",
1933 );
1934
1935 tokio::time::sleep(Duration::from_millis(50)).await;
1936 assert!(
1937 rx.try_recv().is_err(),
1938 "no Cancelled hook may fire after the run already terminated as Succeeded",
1939 );
1940
1941 let _ = shutdown.send(());
1942 }
1943
1944 #[tokio::test(start_paused = true)]
1945 async fn status_reports_cancelling_while_termination_in_flight() {
1946 struct GatedRunner {
1952 claimed: Arc<tokio::sync::Notify>,
1953 gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
1954 }
1955 impl StepRunner for GatedRunner {
1956 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1957 self.claimed.notify_one();
1958 let rx = self.gate.lock().await.take().expect("gate consumed twice");
1959 let _ = rx.await;
1960 Ok(StepOutcome::Succeed {
1961 result: b"would-have-succeeded".to_vec(),
1962 })
1963 }
1964 }
1965
1966 let queue = fresh_queue().await;
1967 let claimed = Arc::new(tokio::sync::Notify::new());
1968 let (gate_tx, gate_rx) = oneshot::channel::<()>();
1969 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1970 let runtime = WorkflowRuntime::builder(
1971 queue,
1972 GatedRunner {
1973 claimed: claimed.clone(),
1974 gate: tokio::sync::Mutex::new(Some(gate_rx)),
1975 },
1976 ChannelHook { tx },
1977 )
1978 .build();
1979 let shutdown = spawn_runtime(runtime.clone());
1980
1981 let handle = runtime
1982 .submit(RunSpec {
1983 input: b"x".to_vec(),
1984 ..Default::default()
1985 })
1986 .await
1987 .unwrap();
1988 tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1989 .await
1990 .expect("runner reached gate");
1991
1992 let before = runtime.status(&handle.run_id).await.expect("active");
1994 assert_eq!(before.state, RunState::Running);
1995
1996 runtime.cancel(&handle.run_id).await.unwrap();
1997
1998 let during = runtime
2002 .status(&handle.run_id)
2003 .await
2004 .expect("entry retained while termination is in flight");
2005 assert_eq!(during.state, RunState::Cancelling);
2006
2007 let _ = gate_tx.send(());
2010
2011 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
2012 .await
2013 .expect("hook fired")
2014 .expect("hook channel open");
2015 assert_eq!(outcome.status, TerminalStatus::Cancelled);
2016 assert!(runtime.status(&handle.run_id).await.is_none());
2017
2018 let _ = shutdown.send(());
2019 }
2020
2021 #[tokio::test(start_paused = true)]
2022 async fn cancel_unknown_run_returns_false() {
2023 let queue = fresh_queue().await;
2024 let runtime: WorkflowRuntime<ScriptedRunner, NoopTerminalHook> =
2025 WorkflowRuntime::builder(queue, ScriptedRunner::new(vec![]), NoopTerminalHook).build();
2026
2027 let was_cancelled = runtime.cancel("never-submitted").await.unwrap();
2028 assert!(!was_cancelled);
2029 }
2030
2031 #[tokio::test(start_paused = true)]
2032 async fn transient_fires_once_on_single_attempt() {
2033 assert_transient_retries_until_max(1).await;
2034 }
2035
2036 #[tokio::test(start_paused = true)]
2037 async fn transient_retries_up_to_max_attempts() {
2038 assert_transient_retries_until_max(3).await;
2039 }
2040}