1use std::collections::HashMap;
2use std::future::Future;
3use std::sync::Arc;
4use std::time::{Duration, SystemTime};
5
6use taquba::{EnqueueOptions, JobRecord, PermanentFailure, Queue, Worker, WorkerError};
7use tokio::sync::Mutex;
8use tokio_util::sync::CancellationToken;
9use tracing::{debug, instrument, warn};
10
11use crate::error::{Error, Result};
12use crate::runner::{Step, StepError, StepErrorKind, StepOutcome, StepRunner};
13use crate::terminal::{RunOutcome, TerminalHook, TerminalStatus};
14
/// Header key under which the runtime stores the run id on every step job.
pub const HEADER_RUN_ID: &str = "workflow.run_id";
/// Header key under which the runtime stores the step number (as a decimal string).
pub const HEADER_STEP: &str = "workflow.step";
/// Prefix of runtime-owned header keys; `submit` rejects user headers that start with it.
pub const RESERVED_HEADER_PREFIX: &str = "workflow.";

/// Prefix of the per-step dedup key, which has the shape `run:<run_id>:<step>`.
const DEDUP_PREFIX: &str = "run:";
25
/// Per-step enqueue settings that are copied into the queue's `EnqueueOptions`.
#[derive(Debug, Default)]
struct StepEnqueueOpts {
    /// Earliest time the step may run; set when a step returns `ContinueAfter`.
    run_at: Option<SystemTime>,
    /// Queue priority for the step job.
    priority: Option<u32>,
    /// Attempt limit for the step job.
    max_attempts: Option<u32>,
}
39
/// Description of a run handed to `WorkflowRuntime::submit`.
#[derive(Debug, Clone, Default)]
pub struct RunSpec {
    /// Caller-chosen run id; when `None`, `submit` generates a ULID.
    pub run_id: Option<String>,
    /// Payload of step 0.
    pub input: Vec<u8>,
    /// User headers attached to the step jobs and reported in the terminal outcome.
    /// Keys must not start with the reserved `workflow.` prefix.
    pub headers: HashMap<String, String>,
    /// Queue priority for step 0; later steps reuse the priority of the job that
    /// produced them.
    pub priority: Option<u32>,
    /// Attempt limit for step 0; later steps reuse the limit of the job that
    /// produced them.
    pub max_attempts_per_step: Option<u32>,
}
59
/// Returned by `submit` once the first step has been enqueued.
#[derive(Debug, Clone)]
pub struct RunHandle {
    /// Id of the run (caller-supplied or generated).
    pub run_id: String,
    /// Queue job id of step 0.
    pub first_job_id: String,
}
68
/// Snapshot of a run that is tracked in this runtime's in-memory registry.
#[derive(Debug, Clone)]
pub struct RunStatus {
    /// Id of the run.
    pub run_id: String,
    /// Current lifecycle state.
    pub state: RunState,
    /// Number of the step that is currently enqueued or executing.
    pub current_step: u32,
}
81
/// Lifecycle state of a run while it is present in the registry.
///
/// Runs are removed from the registry when they terminate, so there is no
/// terminal variant here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum RunState {
    /// The current step has been enqueued but no worker has started it yet.
    Pending,
    /// A worker is executing the current step.
    Running,
    /// `cancel` was requested; reported by `status` until the run terminates.
    Cancelling,
}
104
/// Builder for [`WorkflowRuntime`]; obtained from `WorkflowRuntime::builder`.
pub struct WorkflowRuntimeBuilder<R, H> {
    /// Queue that step jobs are enqueued on and consumed from.
    queue: Arc<Queue>,
    /// Name of the queue used for step jobs.
    queue_name: String,
    /// Executes individual steps.
    runner: R,
    /// Notified once per terminated run.
    terminal_hook: H,
    /// Concurrency limit passed to the worker loop.
    max_concurrent_steps: usize,
    /// Poll interval passed to the worker loop.
    poll_interval: Duration,
}
118
119impl<R: StepRunner, H: TerminalHook> WorkflowRuntimeBuilder<R, H> {
120 pub fn queue_name(mut self, name: impl Into<String>) -> Self {
124 self.queue_name = name.into();
125 self
126 }
127
128 pub fn max_concurrent_steps(mut self, n: usize) -> Self {
131 assert!(n > 0, "max_concurrent_steps must be at least 1");
132 self.max_concurrent_steps = n;
133 self
134 }
135
136 pub fn poll_interval(mut self, interval: Duration) -> Self {
139 self.poll_interval = interval;
140 self
141 }
142
143 pub fn build(self) -> WorkflowRuntime<R, H> {
145 let inner = RuntimeInner {
146 queue: self.queue,
147 queue_name: self.queue_name,
148 runner: self.runner,
149 terminal_hook: self.terminal_hook,
150 max_concurrent_steps: self.max_concurrent_steps,
151 poll_interval: self.poll_interval,
152 registry: Mutex::new(HashMap::new()),
153 };
154 WorkflowRuntime {
155 inner: Arc::new(inner),
156 }
157 }
158}
159
/// Handle to the workflow runtime. Cloning is cheap: all clones share the same
/// `Arc`-wrapped state.
pub struct WorkflowRuntime<R, H> {
    /// Shared state: queue, runner, terminal hook, settings and run registry.
    inner: Arc<RuntimeInner<R, H>>,
}
164
165impl<R, H> Clone for WorkflowRuntime<R, H> {
166 fn clone(&self) -> Self {
167 Self {
168 inner: self.inner.clone(),
169 }
170 }
171}
172
/// State shared by every clone of a [`WorkflowRuntime`] and by its worker.
struct RuntimeInner<R, H> {
    /// Queue that step jobs are enqueued on and consumed from.
    queue: Arc<Queue>,
    /// Name of the queue used for step jobs.
    queue_name: String,
    /// Executes individual steps.
    runner: R,
    /// Notified once per terminated run.
    terminal_hook: H,
    /// Concurrency limit passed to the worker loop.
    max_concurrent_steps: usize,
    /// Poll interval passed to the worker loop.
    poll_interval: Duration,
    /// In-memory bookkeeping for active runs, keyed by run id.
    registry: Mutex<HashMap<String, RegistryEntry>>,
}
182
/// Registry record for one active run.
struct RegistryEntry {
    /// Status as last recorded by `submit`, `registry_mark_running` or `advance`.
    status: RunStatus,
    /// Queue job id of the step that is currently enqueued or executing.
    current_job_id: String,
    /// User headers of the run, reported to the terminal hook when `cancel`
    /// terminates the run directly.
    user_headers: HashMap<String, String>,
    /// Set by `cancel`; checked after each step returns.
    cancel_requested: bool,
    /// Token handed to each step; cancelled by `cancel`.
    cancel_token: CancellationToken,
}
198
199impl<R: StepRunner, H: TerminalHook> WorkflowRuntime<R, H> {
200 pub fn builder(queue: Arc<Queue>, runner: R, terminal_hook: H) -> WorkflowRuntimeBuilder<R, H> {
208 WorkflowRuntimeBuilder {
209 queue,
210 queue_name: "workflow-steps".to_string(),
211 runner,
212 terminal_hook,
213 max_concurrent_steps: 16,
214 poll_interval: Duration::from_millis(250),
215 }
216 }
217
218 #[instrument(skip(self, spec), fields(run_id))]
224 pub async fn submit(&self, spec: RunSpec) -> Result<RunHandle> {
225 let run_id = spec.run_id.unwrap_or_else(|| ulid::Ulid::new().to_string());
226 tracing::Span::current().record("run_id", run_id.as_str());
227
228 for k in spec.headers.keys() {
229 if k.starts_with(RESERVED_HEADER_PREFIX) {
230 return Err(Error::ReservedHeaderInSubmit(k.clone()));
231 }
232 }
233
234 let mut registry = self.inner.registry.lock().await;
239 if registry.contains_key(&run_id) {
240 return Err(Error::DuplicateRun(run_id));
241 }
242
243 let job_id = self
244 .inner
245 .enqueue_step(
246 &run_id,
247 0,
248 spec.input,
249 &spec.headers,
250 StepEnqueueOpts {
251 priority: spec.priority,
252 max_attempts: spec.max_attempts_per_step,
253 ..Default::default()
254 },
255 )
256 .await?;
257
258 registry.insert(
259 run_id.clone(),
260 RegistryEntry {
261 status: RunStatus {
262 run_id: run_id.clone(),
263 state: RunState::Pending,
264 current_step: 0,
265 },
266 current_job_id: job_id.clone(),
267 user_headers: spec.headers.clone(),
268 cancel_requested: false,
269 cancel_token: CancellationToken::new(),
270 },
271 );
272 drop(registry);
273
274 debug!(run_id = %run_id, job_id = %job_id, "run submitted");
275 Ok(RunHandle {
276 run_id,
277 first_job_id: job_id,
278 })
279 }
280
281 pub async fn status(&self, run_id: &str) -> Option<RunStatus> {
289 self.inner.registry.lock().await.get(run_id).map(|e| {
290 let mut status = e.status.clone();
291 if e.cancel_requested {
292 status.state = RunState::Cancelling;
293 }
294 status
295 })
296 }
297
298 pub async fn cancel(&self, run_id: &str) -> Result<bool> {
324 let (job_id, headers, current_step) = {
325 let mut registry = self.inner.registry.lock().await;
326 let Some(entry) = registry.get_mut(run_id) else {
327 return Ok(false);
328 };
329 entry.cancel_requested = true;
330 entry.cancel_token.cancel();
336 (
337 entry.current_job_id.clone(),
338 entry.user_headers.clone(),
339 entry.status.current_step,
340 )
341 };
342
343 let cancelled_in_queue = self.inner.queue.cancel(&job_id).await?;
344 if cancelled_in_queue {
345 self.inner
349 .terminate(RunOutcome {
350 run_id: run_id.to_string(),
351 status: TerminalStatus::Cancelled,
352 result: None,
353 error: None,
354 headers,
355 final_step: current_step,
356 })
357 .await;
358 }
359 Ok(true)
362 }
363
364 pub async fn run<F>(&self, shutdown: F) -> Result<()>
367 where
368 F: Future<Output = ()>,
369 R: 'static,
370 H: 'static,
371 {
372 let worker = Arc::new(StepWorker {
373 inner: self.inner.clone(),
374 });
375 taquba::run_worker_concurrent(
376 &self.inner.queue,
377 &self.inner.queue_name,
378 worker,
379 self.inner.max_concurrent_steps,
380 self.inner.poll_interval,
381 shutdown,
382 )
383 .await?;
384 Ok(())
385 }
386}
387
/// Adapter that lets the queue's worker loop drive [`RuntimeInner::process_step`].
struct StepWorker<R, H> {
    /// Shared runtime state.
    inner: Arc<RuntimeInner<R, H>>,
}

impl<R: StepRunner + 'static, H: TerminalHook + 'static> Worker for StepWorker<R, H> {
    /// Processes one claimed job by delegating to `process_step`.
    async fn process(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
        self.inner.process_step(job).await
    }
}
397
398impl<R: StepRunner, H: TerminalHook> RuntimeInner<R, H> {
399 async fn enqueue_step(
400 &self,
401 run_id: &str,
402 step_number: u32,
403 payload: Vec<u8>,
404 user_headers: &HashMap<String, String>,
405 opts: StepEnqueueOpts,
406 ) -> Result<String> {
407 let mut headers = user_headers.clone();
408 headers.insert(HEADER_RUN_ID.to_string(), run_id.to_string());
409 headers.insert(HEADER_STEP.to_string(), step_number.to_string());
410
411 let enqueue_opts = EnqueueOptions {
412 headers,
413 run_at: opts.run_at,
414 priority: opts.priority,
415 max_attempts: opts.max_attempts,
416 dedup_key: Some(format!("{DEDUP_PREFIX}{run_id}:{step_number}")),
417 };
418 Ok(self
419 .queue
420 .enqueue_with(&self.queue_name, payload, enqueue_opts)
421 .await?)
422 }
423
424 fn split_headers(headers: &HashMap<String, String>) -> HashMap<String, String> {
425 headers
426 .iter()
427 .filter(|(k, _)| !k.starts_with(RESERVED_HEADER_PREFIX))
428 .map(|(k, v)| (k.clone(), v.clone()))
429 .collect()
430 }
431
432 fn parse_step_headers(job: &JobRecord) -> std::result::Result<(String, u32), Error> {
433 let run_id = job
434 .headers
435 .get(HEADER_RUN_ID)
436 .ok_or(Error::MissingHeader(HEADER_RUN_ID))?
437 .to_string();
438 let step_str = job
439 .headers
440 .get(HEADER_STEP)
441 .ok_or(Error::MissingHeader(HEADER_STEP))?;
442 let step_number: u32 = step_str.parse().map_err(|_| Error::InvalidStepHeader {
443 header: HEADER_STEP,
444 value: step_str.clone(),
445 })?;
446 Ok((run_id, step_number))
447 }
448
449 async fn terminate(&self, outcome: RunOutcome) {
455 self.registry.lock().await.remove(&outcome.run_id);
456 self.terminal_hook.on_termination(&outcome).await;
457 }
458
459 async fn registry_mark_running(
467 &self,
468 run_id: &str,
469 step_number: u32,
470 job_id: &str,
471 user_headers: &HashMap<String, String>,
472 ) -> CancellationToken {
473 let mut registry = self.registry.lock().await;
474 match registry.get_mut(run_id) {
475 Some(entry) => {
476 entry.status.state = RunState::Running;
477 entry.status.current_step = step_number;
478 entry.current_job_id = job_id.to_string();
479 entry.cancel_token.clone()
480 }
481 None => {
482 let cancel_token = CancellationToken::new();
483 registry.insert(
484 run_id.to_string(),
485 RegistryEntry {
486 status: RunStatus {
487 run_id: run_id.to_string(),
488 state: RunState::Running,
489 current_step: step_number,
490 },
491 current_job_id: job_id.to_string(),
492 user_headers: user_headers.clone(),
493 cancel_requested: false,
494 cancel_token: cancel_token.clone(),
495 },
496 );
497 cancel_token
498 }
499 }
500 }
501
502 async fn process_step(&self, job: &JobRecord) -> std::result::Result<(), WorkerError> {
503 let (run_id, step_number) = match Self::parse_step_headers(job) {
504 Ok(v) => v,
505 Err(e) => {
506 warn!(job_id = %job.id, error = %e, "workflow step has malformed headers");
507 if e.is_permanent() {
508 return Err(PermanentFailure::new(e.to_string()).into());
509 }
510 return Err(e.to_string().into());
511 }
512 };
513
514 let user_headers = Self::split_headers(&job.headers);
515
516 let cancel_token = self
517 .registry_mark_running(&run_id, step_number, &job.id, &user_headers)
518 .await;
519
520 let step = Step {
521 run_id: run_id.clone(),
522 step_number,
523 payload: job.payload.clone(),
524 headers: user_headers.clone(),
525 job_id: job.id.clone(),
526 attempts: job.attempts,
527 cancel_token,
528 };
529
530 let inherit_opts = || StepEnqueueOpts {
533 run_at: None,
534 priority: Some(job.priority),
535 max_attempts: Some(job.max_attempts),
536 };
537
538 let outcome = self.runner.run_step(&step).await;
539 let external_cancel = self
540 .registry
541 .lock()
542 .await
543 .get(&run_id)
544 .is_some_and(|e| e.cancel_requested);
545
546 match outcome {
554 Ok(StepOutcome::Cancel { reason }) => {
555 self.terminate(RunOutcome {
556 run_id: run_id.clone(),
557 status: TerminalStatus::Cancelled,
558 result: None,
559 error: Some(reason),
560 headers: user_headers,
561 final_step: step_number,
562 })
563 .await;
564 Ok(())
565 }
566 _ if external_cancel => {
567 self.terminate(RunOutcome {
568 run_id: run_id.clone(),
569 status: TerminalStatus::Cancelled,
570 result: None,
571 error: None,
572 headers: user_headers,
573 final_step: step_number,
574 })
575 .await;
576 Ok(())
577 }
578 Ok(StepOutcome::Continue { payload }) => {
579 self.advance(
580 &run_id,
581 step_number + 1,
582 payload,
583 &user_headers,
584 inherit_opts(),
585 )
586 .await
587 }
588 Ok(StepOutcome::ContinueAfter { payload, delay }) => {
589 let opts = StepEnqueueOpts {
590 run_at: Some(SystemTime::now() + delay),
591 ..inherit_opts()
592 };
593 self.advance(&run_id, step_number + 1, payload, &user_headers, opts)
594 .await
595 }
596 Ok(StepOutcome::Succeed { result }) => {
597 self.terminate(RunOutcome {
598 run_id: run_id.clone(),
599 status: TerminalStatus::Succeeded,
600 result: Some(result),
601 error: None,
602 headers: user_headers,
603 final_step: step_number,
604 })
605 .await;
606 Ok(())
607 }
608 Ok(StepOutcome::Fail { reason }) => {
609 self.terminate(RunOutcome {
613 run_id: run_id.clone(),
614 status: TerminalStatus::Failed,
615 result: None,
616 error: Some(reason),
617 headers: user_headers,
618 final_step: step_number,
619 })
620 .await;
621 Ok(())
622 }
623 Err(StepError {
624 message,
625 kind: StepErrorKind::Permanent,
626 }) => {
627 self.terminate(RunOutcome {
628 run_id: run_id.clone(),
629 status: TerminalStatus::Failed,
630 result: None,
631 error: Some(message.clone()),
632 headers: user_headers,
633 final_step: step_number,
634 })
635 .await;
636 Err(PermanentFailure::new(message).into())
637 }
638 Err(StepError {
639 message,
640 kind: StepErrorKind::Transient,
641 }) => {
642 if job.attempts >= job.max_attempts {
646 self.terminate(RunOutcome {
647 run_id: run_id.clone(),
648 status: TerminalStatus::Failed,
649 result: None,
650 error: Some(message.clone()),
651 headers: user_headers,
652 final_step: step_number,
653 })
654 .await;
655 }
656 Err(message.into())
657 }
658 }
659 }
660
661 async fn advance(
662 &self,
663 run_id: &str,
664 next_step: u32,
665 payload: Vec<u8>,
666 user_headers: &HashMap<String, String>,
667 opts: StepEnqueueOpts,
668 ) -> std::result::Result<(), WorkerError> {
669 match self
670 .enqueue_step(run_id, next_step, payload, user_headers, opts)
671 .await
672 {
673 Ok(new_job_id) => {
674 if let Some(entry) = self.registry.lock().await.get_mut(run_id) {
676 entry.status.state = RunState::Pending;
677 entry.status.current_step = next_step;
678 entry.current_job_id = new_job_id;
679 }
680 Ok(())
681 }
682 Err(e) => Err(e.to_string().into()),
686 }
687 }
688}
689
690#[cfg(test)]
691mod tests {
692 use super::*;
693 use crate::terminal::NoopTerminalHook;
694 use std::sync::Mutex as StdMutex;
695 use std::sync::atomic::{AtomicU32, Ordering};
696 use taquba::object_store::memory::InMemory;
697 use taquba::{OpenOptions, QueueConfig};
698 use tokio::sync::oneshot;
699
    /// Terminal hook that forwards a clone of every outcome into an unbounded channel.
    struct ChannelHook {
        tx: tokio::sync::mpsc::UnboundedSender<RunOutcome>,
    }

    impl TerminalHook for ChannelHook {
        async fn on_termination(&self, outcome: &RunOutcome) {
            // A send error (receiver dropped) is deliberately ignored.
            let _ = self.tx.send(outcome.clone());
        }
    }
710
711 struct ScriptedRunner {
713 script: Arc<StdMutex<Vec<StepOutcome>>>,
714 }
715
716 impl ScriptedRunner {
717 fn new(steps: Vec<StepOutcome>) -> Self {
718 Self {
719 script: Arc::new(StdMutex::new(steps)),
720 }
721 }
722 }
723
724 impl StepRunner for ScriptedRunner {
725 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
726 let next = self.script.lock().unwrap().remove(0);
727 Ok(next)
728 }
729 }
730
731 async fn fresh_queue() -> Arc<Queue> {
732 Arc::new(
733 Queue::open(Arc::new(InMemory::new()), "test")
734 .await
735 .unwrap(),
736 )
737 }
738
739 async fn fresh_queue_fast_retry() -> Arc<Queue> {
742 let opts = OpenOptions {
743 default_queue_config: QueueConfig {
744 retry_backoff_base: Duration::ZERO,
745 ..QueueConfig::default()
746 },
747 reaper_interval: Duration::from_millis(50),
748 scheduler_interval: Duration::from_millis(50),
749 ..OpenOptions::default()
750 };
751 Arc::new(
752 Queue::open_with_options(Arc::new(InMemory::new()), "test", opts)
753 .await
754 .unwrap(),
755 )
756 }
757
758 fn spawn_runtime<R, H>(runtime: WorkflowRuntime<R, H>) -> oneshot::Sender<()>
759 where
760 R: StepRunner + 'static,
761 H: TerminalHook + 'static,
762 {
763 let (tx, rx) = oneshot::channel::<()>();
764 tokio::spawn(async move {
765 let _ = runtime
766 .run(async move {
767 let _ = rx.await;
768 })
769 .await;
770 });
771 tx
772 }
773
774 #[tokio::test]
775 async fn single_step_succeeds_and_fires_hook() {
776 let queue = fresh_queue().await;
777 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
778 let runtime = WorkflowRuntime::builder(
779 queue,
780 ScriptedRunner::new(vec![StepOutcome::Succeed {
781 result: b"done".to_vec(),
782 }]),
783 ChannelHook { tx },
784 )
785 .build();
786 let shutdown = spawn_runtime(runtime.clone());
787
788 let handle = runtime
789 .submit(RunSpec {
790 input: b"in".to_vec(),
791 ..Default::default()
792 })
793 .await
794 .unwrap();
795 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
796 .await
797 .unwrap()
798 .unwrap();
799
800 assert_eq!(outcome.run_id, handle.run_id);
801 assert_eq!(outcome.status, TerminalStatus::Succeeded);
802 assert_eq!(outcome.result.as_deref(), Some(b"done".as_slice()));
803 assert_eq!(outcome.final_step, 0);
804 assert!(runtime.status(&handle.run_id).await.is_none());
805
806 let _ = shutdown.send(());
807 }
808
809 #[tokio::test]
810 async fn multi_step_run_advances_through_continue() {
811 let queue = fresh_queue().await;
812 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
813 let runtime = WorkflowRuntime::builder(
814 queue,
815 ScriptedRunner::new(vec![
816 StepOutcome::Continue {
817 payload: b"step1".to_vec(),
818 },
819 StepOutcome::Continue {
820 payload: b"step2".to_vec(),
821 },
822 StepOutcome::Succeed {
823 result: b"final".to_vec(),
824 },
825 ]),
826 ChannelHook { tx },
827 )
828 .build();
829 let shutdown = spawn_runtime(runtime.clone());
830
831 let handle = runtime
832 .submit(RunSpec {
833 input: b"start".to_vec(),
834 ..Default::default()
835 })
836 .await
837 .unwrap();
838 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
839 .await
840 .unwrap()
841 .unwrap();
842
843 assert_eq!(outcome.run_id, handle.run_id);
844 assert_eq!(outcome.final_step, 2);
845 assert_eq!(outcome.status, TerminalStatus::Succeeded);
846 assert_eq!(outcome.result.as_deref(), Some(b"final".as_slice()));
847
848 let _ = shutdown.send(());
849 }
850
851 #[tokio::test]
852 async fn permanent_failure_dead_letters_and_fires_hook() {
853 struct FailingRunner;
854 impl StepRunner for FailingRunner {
855 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
856 Err(StepError::permanent("nope"))
857 }
858 }
859
860 let queue = fresh_queue().await;
861 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
862 let runtime =
863 WorkflowRuntime::builder(queue.clone(), FailingRunner, ChannelHook { tx }).build();
864 let shutdown = spawn_runtime(runtime.clone());
865
866 let handle = runtime
867 .submit(RunSpec {
868 input: b"x".to_vec(),
869 ..Default::default()
870 })
871 .await
872 .unwrap();
873 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
874 .await
875 .unwrap()
876 .unwrap();
877
878 assert_eq!(outcome.run_id, handle.run_id);
879 assert_eq!(outcome.status, TerminalStatus::Failed);
880 assert_eq!(outcome.error.as_deref(), Some("nope"));
881 assert!(runtime.status(&handle.run_id).await.is_none());
882
883 let stats = queue.stats("workflow-steps").await.unwrap();
885 assert_eq!(stats.dead, 1, "permanent error should dead-letter");
886
887 let _ = shutdown.send(());
888 }
889
890 #[tokio::test]
891 async fn fail_outcome_terminates_run_without_dead_letter() {
892 struct VerdictRunner;
897 impl StepRunner for VerdictRunner {
898 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
899 Ok(StepOutcome::Fail {
900 reason: "agent declined the task".to_string(),
901 })
902 }
903 }
904
905 let queue = fresh_queue().await;
906 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
907 let runtime =
908 WorkflowRuntime::builder(queue.clone(), VerdictRunner, ChannelHook { tx }).build();
909 let shutdown = spawn_runtime(runtime.clone());
910
911 let handle = runtime
912 .submit(RunSpec {
913 input: b"x".to_vec(),
914 ..Default::default()
915 })
916 .await
917 .unwrap();
918
919 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
920 .await
921 .expect("hook fired in time")
922 .expect("hook channel open");
923
924 assert_eq!(outcome.run_id, handle.run_id);
925 assert_eq!(outcome.status, TerminalStatus::Failed);
926 assert_eq!(outcome.error.as_deref(), Some("agent declined the task"));
927 assert!(runtime.status(&handle.run_id).await.is_none());
928
929 let stats = queue.stats("workflow-steps").await.unwrap();
932 assert_eq!(stats.dead, 0, "Fail verdict must not dead-letter");
933
934 let _ = shutdown.send(());
935 }
936
    /// Submitting the same explicit run id twice to one runtime yields
    /// `Error::DuplicateRun` while the first run is still tracked.
    #[tokio::test]
    async fn duplicate_submit_in_process_is_rejected() {
        // Never completes, so the first run stays in the registry for the whole test.
        struct PauseRunner;
        impl StepRunner for PauseRunner {
            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
                std::future::pending().await
            }
        }

        let queue = fresh_queue().await;
        let runtime = WorkflowRuntime::builder(queue, PauseRunner, NoopTerminalHook).build();
        let shutdown = spawn_runtime(runtime.clone());

        let handle = runtime
            .submit(RunSpec {
                run_id: Some("fixed-id".to_string()),
                input: b"x".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();
        // Poll for up to ~1 s (40 x 25 ms) until the run is visible in the registry.
        for _ in 0..40 {
            if runtime.status(&handle.run_id).await.is_some() {
                break;
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
        assert!(runtime.status(&handle.run_id).await.is_some());

        // Same run id with a different payload: must be rejected.
        let err = runtime
            .submit(RunSpec {
                run_id: Some("fixed-id".to_string()),
                input: b"y".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap_err();
        assert!(matches!(err, Error::DuplicateRun(id) if id == "fixed-id"));

        let _ = shutdown.send(());
    }
982
983 #[tokio::test]
984 async fn reserved_header_on_submit_is_rejected() {
985 let queue = fresh_queue().await;
986 let runtime: WorkflowRuntime<ScriptedRunner, NoopTerminalHook> =
987 WorkflowRuntime::builder(queue, ScriptedRunner::new(vec![]), NoopTerminalHook).build();
988 let mut headers = HashMap::new();
989 headers.insert("workflow.run_id".to_string(), "evil".to_string());
990
991 let err = runtime
992 .submit(RunSpec {
993 input: b"x".to_vec(),
994 headers,
995 ..Default::default()
996 })
997 .await
998 .unwrap_err();
999 assert!(
1000 matches!(&err, Error::ReservedHeaderInSubmit(k) if k == "workflow.run_id"),
1001 "got: {err:?}"
1002 );
1003 }
1004
1005 #[tokio::test]
1006 async fn user_headers_thread_through_to_terminal_hook() {
1007 let queue = fresh_queue().await;
1008 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1009 let runtime = WorkflowRuntime::builder(
1010 queue,
1011 ScriptedRunner::new(vec![
1012 StepOutcome::Continue { payload: vec![] },
1013 StepOutcome::Succeed { result: vec![] },
1014 ]),
1015 ChannelHook { tx },
1016 )
1017 .build();
1018 let shutdown = spawn_runtime(runtime.clone());
1019
1020 let mut headers = HashMap::new();
1021 headers.insert("trace_id".to_string(), "abc-123".to_string());
1022 headers.insert("tenant".to_string(), "acme".to_string());
1023
1024 runtime
1025 .submit(RunSpec {
1026 input: b"x".to_vec(),
1027 headers,
1028 ..Default::default()
1029 })
1030 .await
1031 .unwrap();
1032 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1033 .await
1034 .unwrap()
1035 .unwrap();
1036
1037 assert_eq!(outcome.headers.get("trace_id").unwrap(), "abc-123");
1038 assert_eq!(outcome.headers.get("tenant").unwrap(), "acme");
1039 assert!(!outcome.headers.contains_key(HEADER_RUN_ID));
1041 assert!(!outcome.headers.contains_key(HEADER_STEP));
1042
1043 let _ = shutdown.send(());
1044 }
1045
    /// A second runtime sharing the queue picks up a run at step 1 after the first
    /// runtime finished step 0 and shut down.
    #[tokio::test]
    async fn restart_resumes_at_next_step() {
        // Runner for runtime A: step 0 blocks on a one-shot gate and then continues
        // with the payload sent through it; any later step never completes.
        struct GatedRunner {
            gate: tokio::sync::Mutex<Option<oneshot::Receiver<Vec<u8>>>>,
        }

        impl StepRunner for GatedRunner {
            async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
                match step.step_number {
                    0 => {
                        let rx = self.gate.lock().await.take().expect("gate consumed twice");
                        let payload = rx.await.expect("gate sender dropped");
                        Ok(StepOutcome::Continue { payload })
                    }
                    _ => std::future::pending().await,
                }
            }
        }

        // Runner for runtime B: asserts it only ever sees step 1 with the payload
        // produced by runtime A, then succeeds.
        struct CompleteOnStep1;
        impl StepRunner for CompleteOnStep1 {
            async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
                assert_eq!(step.step_number, 1, "runtime B should only ever see step 1");
                assert_eq!(step.payload.as_slice(), b"step1-payload");
                Ok(StepOutcome::Succeed {
                    result: b"resumed".to_vec(),
                })
            }
        }

        let queue = fresh_queue().await;

        let (gate_tx, gate_rx) = oneshot::channel::<Vec<u8>>();
        let runtime_a = WorkflowRuntime::builder(
            queue.clone(),
            GatedRunner {
                gate: tokio::sync::Mutex::new(Some(gate_rx)),
            },
            NoopTerminalHook,
        )
        .max_concurrent_steps(1)
        .build();

        // Spawned by hand (not via `spawn_runtime`) because the join handle is
        // needed to wait for runtime A to stop.
        let (shutdown_a_tx, shutdown_a_rx) = oneshot::channel::<()>();
        let worker_a = {
            let runtime_a = runtime_a.clone();
            tokio::spawn(async move {
                let _ = runtime_a
                    .run(async move {
                        let _ = shutdown_a_rx.await;
                    })
                    .await;
            })
        };

        let handle = runtime_a
            .submit(RunSpec {
                input: b"input".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();

        // Poll for up to ~2 s (80 x 25 ms) until runtime A reports step 0 as running.
        for _ in 0..80 {
            if let Some(s) = runtime_a.status(&handle.run_id).await {
                if s.state == RunState::Running && s.current_step == 0 {
                    break;
                }
            }
            tokio::time::sleep(Duration::from_millis(25)).await;
        }
        let s = runtime_a.status(&handle.run_id).await.expect("status");
        assert_eq!(s.state, RunState::Running);
        assert_eq!(s.current_step, 0);

        // Request shutdown first, then open the gate so step 0 can finish.
        // NOTE(review): this relies on the worker loop letting the in-flight step
        // complete after shutdown is signalled — confirm against taquba.
        let _ = shutdown_a_tx.send(());
        let _ = gate_tx.send(b"step1-payload".to_vec());

        worker_a.await.expect("runtime A drained cleanly");

        // Runtime B shares the queue but starts with an empty registry.
        let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
        let runtime_b =
            WorkflowRuntime::builder(queue, CompleteOnStep1, ChannelHook { tx }).build();
        let shutdown_b = spawn_runtime(runtime_b.clone());

        let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
            .await
            .expect("hook fired in time")
            .expect("hook channel open");

        assert_eq!(outcome.run_id, handle.run_id);
        assert_eq!(outcome.status, TerminalStatus::Succeeded);
        assert_eq!(outcome.result.as_deref(), Some(b"resumed".as_slice()));
        assert_eq!(outcome.final_step, 1);

        let _ = shutdown_b.send(());
    }
1159
1160 async fn assert_transient_retries_until_max(max_attempts: u32) {
1166 struct AlwaysTransient {
1167 calls: Arc<AtomicU32>,
1168 }
1169 impl StepRunner for AlwaysTransient {
1170 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1171 self.calls.fetch_add(1, Ordering::SeqCst);
1172 Err(StepError::transient("flaky"))
1173 }
1174 }
1175
1176 let queue = fresh_queue_fast_retry().await;
1177 let calls = Arc::new(AtomicU32::new(0));
1178 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1179 let runtime = WorkflowRuntime::builder(
1180 queue,
1181 AlwaysTransient {
1182 calls: calls.clone(),
1183 },
1184 ChannelHook { tx },
1185 )
1186 .build();
1187 let shutdown = spawn_runtime(runtime.clone());
1188
1189 runtime
1190 .submit(RunSpec {
1191 input: b"x".to_vec(),
1192 max_attempts_per_step: Some(max_attempts),
1193 ..Default::default()
1194 })
1195 .await
1196 .unwrap();
1197
1198 let outcome = tokio::time::timeout(Duration::from_secs(3), rx.recv())
1199 .await
1200 .expect("hook fired in time")
1201 .expect("hook channel open");
1202
1203 assert_eq!(outcome.status, TerminalStatus::Failed);
1204 assert_eq!(outcome.error.as_deref(), Some("flaky"));
1205 assert_eq!(
1206 calls.load(Ordering::SeqCst),
1207 max_attempts,
1208 "runner called once per attempt up to max_attempts"
1209 );
1210
1211 tokio::time::sleep(Duration::from_millis(50)).await;
1213 assert!(rx.try_recv().is_err(), "hook fired more than once");
1214
1215 let _ = shutdown.send(());
1216 }
1217
1218 #[tokio::test]
1219 async fn cancel_outcome_terminates_run_without_dead_letter() {
1220 struct CancellingRunner;
1224 impl StepRunner for CancellingRunner {
1225 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1226 Ok(StepOutcome::Cancel {
1227 reason: "upstream aborted".to_string(),
1228 })
1229 }
1230 }
1231
1232 let queue = fresh_queue().await;
1233 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1234 let runtime =
1235 WorkflowRuntime::builder(queue.clone(), CancellingRunner, ChannelHook { tx }).build();
1236 let shutdown = spawn_runtime(runtime.clone());
1237
1238 let handle = runtime
1239 .submit(RunSpec {
1240 input: b"x".to_vec(),
1241 ..Default::default()
1242 })
1243 .await
1244 .unwrap();
1245
1246 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1247 .await
1248 .expect("hook fired in time")
1249 .expect("hook channel open");
1250
1251 assert_eq!(outcome.run_id, handle.run_id);
1252 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1253 assert_eq!(outcome.error.as_deref(), Some("upstream aborted"));
1254 assert!(runtime.status(&handle.run_id).await.is_none());
1255
1256 let stats = queue.stats("workflow-steps").await.unwrap();
1257 assert_eq!(stats.dead, 0, "Cancel verdict must not dead-letter");
1258
1259 let _ = shutdown.send(());
1260 }
1261
1262 #[tokio::test]
1263 async fn cancel_pending_run_fires_cancelled_hook() {
1264 struct UnreachableRunner;
1267 impl StepRunner for UnreachableRunner {
1268 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1269 unreachable!("worker must not claim the cancelled step");
1270 }
1271 }
1272
1273 let queue = fresh_queue().await;
1274 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1275 let runtime =
1276 WorkflowRuntime::builder(queue.clone(), UnreachableRunner, ChannelHook { tx }).build();
1277 let mut headers = HashMap::new();
1281 headers.insert("tenant".to_string(), "acme".to_string());
1282
1283 let handle = runtime
1284 .submit(RunSpec {
1285 input: b"x".to_vec(),
1286 headers,
1287 ..Default::default()
1288 })
1289 .await
1290 .unwrap();
1291 let status = runtime.status(&handle.run_id).await.expect("active");
1292 assert_eq!(status.state, RunState::Pending);
1293
1294 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1295 assert!(was_cancelled);
1296
1297 let outcome = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1298 .await
1299 .expect("hook fired in time")
1300 .expect("hook channel open");
1301 assert_eq!(outcome.run_id, handle.run_id);
1302 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1303 assert!(outcome.error.is_none());
1305 assert_eq!(outcome.headers.get("tenant").unwrap(), "acme");
1306 assert!(runtime.status(&handle.run_id).await.is_none());
1307
1308 let stats = queue.stats("workflow-steps").await.unwrap();
1309 assert_eq!(stats.dead, 0, "cancel must not dead-letter");
1310 assert_eq!(stats.pending, 0, "cancelled job must be removed");
1311 }
1312
    /// A cancel issued while a step is executing wins over the step's own
    /// `Succeed` outcome: the hook reports `Cancelled` and the result is dropped.
    #[tokio::test]
    async fn cancel_during_running_step_overrides_outcome() {
        // Signals `claimed` once the step starts, then blocks on the gate before
        // returning `Succeed`.
        struct GatedRunner {
            claimed: Arc<tokio::sync::Notify>,
            gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
        }
        impl StepRunner for GatedRunner {
            async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
                self.claimed.notify_one();
                let rx = self.gate.lock().await.take().expect("gate consumed twice");
                let _ = rx.await;
                Ok(StepOutcome::Succeed {
                    result: b"would-have-succeeded".to_vec(),
                })
            }
        }

        let queue = fresh_queue().await;
        let claimed = Arc::new(tokio::sync::Notify::new());
        let (gate_tx, gate_rx) = oneshot::channel::<()>();
        let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
        let runtime = WorkflowRuntime::builder(
            queue.clone(),
            GatedRunner {
                claimed: claimed.clone(),
                gate: tokio::sync::Mutex::new(Some(gate_rx)),
            },
            ChannelHook { tx: hook_tx },
        )
        .build();
        let shutdown = spawn_runtime(runtime.clone());

        let handle = runtime
            .submit(RunSpec {
                input: b"x".to_vec(),
                ..Default::default()
            })
            .await
            .unwrap();
        // Wait until the runner is inside `run_step`, so the cancel below hits a
        // step that is already executing.
        tokio::time::timeout(Duration::from_secs(2), claimed.notified())
            .await
            .expect("runner reached gate");

        let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
        assert!(was_cancelled);

        // Open the gate only after the cancel, so the step returns `Succeed` with
        // the cancel flag already set.
        let _ = gate_tx.send(());

        let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
            .await
            .expect("hook fired")
            .expect("hook channel open");
        assert_eq!(outcome.status, TerminalStatus::Cancelled);
        assert!(
            outcome.result.is_none(),
            "succeed payload must be discarded"
        );
        assert!(runtime.status(&handle.run_id).await.is_none());

        let stats = queue.stats("workflow-steps").await.unwrap();
        assert_eq!(stats.dead, 0);

        let _ = shutdown.send(());
    }
1384
1385 async fn assert_cancel_suppresses_runner_error(error: StepError) {
1393 struct GatedErrRunner {
1394 claimed: Arc<tokio::sync::Notify>,
1395 gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
1396 calls: Arc<AtomicU32>,
1397 error: StdMutex<Option<StepError>>,
1398 }
1399 impl StepRunner for GatedErrRunner {
1400 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1401 self.calls.fetch_add(1, Ordering::SeqCst);
1402 self.claimed.notify_one();
1403 let rx = self.gate.lock().await.take().expect("gate consumed twice");
1404 let _ = rx.await;
1405 Err(self
1406 .error
1407 .lock()
1408 .unwrap()
1409 .take()
1410 .expect("error consumed twice"))
1411 }
1412 }
1413
1414 let queue = fresh_queue_fast_retry().await;
1415 let claimed = Arc::new(tokio::sync::Notify::new());
1416 let calls = Arc::new(AtomicU32::new(0));
1417 let (gate_tx, gate_rx) = oneshot::channel::<()>();
1418 let (hook_tx, mut hook_rx) = tokio::sync::mpsc::unbounded_channel();
1419 let runtime = WorkflowRuntime::builder(
1420 queue.clone(),
1421 GatedErrRunner {
1422 claimed: claimed.clone(),
1423 gate: tokio::sync::Mutex::new(Some(gate_rx)),
1424 calls: calls.clone(),
1425 error: StdMutex::new(Some(error)),
1426 },
1427 ChannelHook { tx: hook_tx },
1428 )
1429 .build();
1430 let shutdown = spawn_runtime(runtime.clone());
1431
1432 let handle = runtime
1433 .submit(RunSpec {
1434 input: b"x".to_vec(),
1435 ..Default::default()
1436 })
1437 .await
1438 .unwrap();
1439 tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1440 .await
1441 .expect("runner reached gate");
1442
1443 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1444 assert!(was_cancelled);
1445
1446 let _ = gate_tx.send(());
1450
1451 let outcome = tokio::time::timeout(Duration::from_secs(2), hook_rx.recv())
1452 .await
1453 .expect("hook fired")
1454 .expect("hook channel open");
1455 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1456 assert!(
1457 outcome.error.is_none(),
1458 "external cancel must carry no reason (Some(_) would imply runner-issued StepOutcome::Cancel)",
1459 );
1460 assert!(runtime.status(&handle.run_id).await.is_none());
1461
1462 tokio::time::sleep(Duration::from_millis(100)).await;
1465 assert_eq!(
1466 calls.load(Ordering::SeqCst),
1467 1,
1468 "cancellation must suppress retries",
1469 );
1470 let stats = queue.stats("workflow-steps").await.unwrap();
1471 assert_eq!(stats.dead, 0, "cancellation must suppress dead-letter");
1472 assert!(
1473 hook_rx.try_recv().is_err(),
1474 "hook must fire exactly once for the cancelled run",
1475 );
1476
1477 let _ = shutdown.send(());
1478 }
1479
1480 #[tokio::test]
1481 async fn cancel_suppresses_permanent_runner_error() {
1482 assert_cancel_suppresses_runner_error(StepError::permanent("would-dead-letter")).await;
1487 }
1488
1489 #[tokio::test]
1490 async fn cancel_suppresses_transient_runner_error() {
1491 assert_cancel_suppresses_runner_error(StepError::transient("would-retry")).await;
1496 }
1497
1498 #[tokio::test]
1499 async fn cancel_signals_step_token_for_cooperative_short_circuit() {
1500 struct CooperativeRunner {
1508 claimed: Arc<tokio::sync::Notify>,
1509 }
1510 impl StepRunner for CooperativeRunner {
1511 async fn run_step(&self, step: &Step) -> std::result::Result<StepOutcome, StepError> {
1512 self.claimed.notify_one();
1513 tokio::select! {
1514 _ = tokio::time::sleep(Duration::from_secs(30)) => {
1515 Ok(StepOutcome::Succeed { result: b"slow".to_vec() })
1516 }
1517 _ = step.cancel_token.cancelled() => {
1518 Ok(StepOutcome::Cancel { reason: "cooperative".to_string() })
1519 }
1520 }
1521 }
1522 }
1523
1524 let queue = fresh_queue().await;
1525 let claimed = Arc::new(tokio::sync::Notify::new());
1526 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1527 let runtime = WorkflowRuntime::builder(
1528 queue.clone(),
1529 CooperativeRunner {
1530 claimed: claimed.clone(),
1531 },
1532 ChannelHook { tx },
1533 )
1534 .build();
1535 let shutdown = spawn_runtime(runtime.clone());
1536
1537 let handle = runtime
1538 .submit(RunSpec {
1539 input: b"x".to_vec(),
1540 ..Default::default()
1541 })
1542 .await
1543 .unwrap();
1544 tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1545 .await
1546 .expect("runner observed token");
1547
1548 let start = std::time::Instant::now();
1549 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1550 assert!(was_cancelled);
1551
1552 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1553 .await
1554 .expect("hook fired well before the 30s sleep would have")
1555 .expect("hook channel open");
1556 let elapsed = start.elapsed();
1557
1558 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1559 assert_eq!(outcome.error.as_deref(), Some("cooperative"));
1562 assert!(
1563 elapsed < Duration::from_secs(2),
1564 "cooperative cancel must short-circuit the 30s sleep (took {elapsed:?})",
1565 );
1566 assert!(runtime.status(&handle.run_id).await.is_none());
1567
1568 let stats = queue.stats("workflow-steps").await.unwrap();
1569 assert_eq!(stats.dead, 0);
1570
1571 let _ = shutdown.send(());
1572 }
1573
1574 #[tokio::test]
1575 async fn double_cancel_fires_hook_once_and_second_call_returns_false() {
1576 struct UnreachableRunner;
1582 impl StepRunner for UnreachableRunner {
1583 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1584 unreachable!("worker must not claim the cancelled step");
1585 }
1586 }
1587
1588 let queue = fresh_queue().await;
1589 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1590 let runtime =
1591 WorkflowRuntime::builder(queue, UnreachableRunner, ChannelHook { tx }).build();
1592 let handle = runtime
1596 .submit(RunSpec {
1597 input: b"x".to_vec(),
1598 ..Default::default()
1599 })
1600 .await
1601 .unwrap();
1602
1603 let first = runtime.cancel(&handle.run_id).await.unwrap();
1604 assert!(first, "first cancel initiates termination");
1605
1606 let second = runtime.cancel(&handle.run_id).await.unwrap();
1607 assert!(
1608 !second,
1609 "second cancel must report Ok(false): registry entry is gone after the first fired the hook",
1610 );
1611
1612 let _ = tokio::time::timeout(Duration::from_secs(1), rx.recv())
1614 .await
1615 .expect("hook fired in time")
1616 .expect("hook channel open");
1617 tokio::time::sleep(Duration::from_millis(50)).await;
1618 assert!(
1619 rx.try_recv().is_err(),
1620 "hook must fire exactly once for a double-cancelled run",
1621 );
1622 }
1623
1624 #[tokio::test]
1625 async fn cancel_after_run_already_terminated_returns_false() {
1626 let queue = fresh_queue().await;
1631 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1632 let runtime = WorkflowRuntime::builder(
1633 queue,
1634 ScriptedRunner::new(vec![StepOutcome::Succeed {
1635 result: b"done".to_vec(),
1636 }]),
1637 ChannelHook { tx },
1638 )
1639 .build();
1640 let shutdown = spawn_runtime(runtime.clone());
1641
1642 let handle = runtime
1643 .submit(RunSpec {
1644 input: b"x".to_vec(),
1645 ..Default::default()
1646 })
1647 .await
1648 .unwrap();
1649
1650 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1651 .await
1652 .expect("Succeeded hook fired")
1653 .expect("hook channel open");
1654 assert_eq!(outcome.status, TerminalStatus::Succeeded);
1655 assert!(runtime.status(&handle.run_id).await.is_none());
1656
1657 let was_cancelled = runtime.cancel(&handle.run_id).await.unwrap();
1658 assert!(
1659 !was_cancelled,
1660 "cancel on an already-terminated run must report Ok(false)",
1661 );
1662
1663 tokio::time::sleep(Duration::from_millis(50)).await;
1664 assert!(
1665 rx.try_recv().is_err(),
1666 "no Cancelled hook may fire after the run already terminated as Succeeded",
1667 );
1668
1669 let _ = shutdown.send(());
1670 }
1671
1672 #[tokio::test]
1673 async fn status_reports_cancelling_while_termination_in_flight() {
1674 struct GatedRunner {
1680 claimed: Arc<tokio::sync::Notify>,
1681 gate: tokio::sync::Mutex<Option<oneshot::Receiver<()>>>,
1682 }
1683 impl StepRunner for GatedRunner {
1684 async fn run_step(&self, _step: &Step) -> std::result::Result<StepOutcome, StepError> {
1685 self.claimed.notify_one();
1686 let rx = self.gate.lock().await.take().expect("gate consumed twice");
1687 let _ = rx.await;
1688 Ok(StepOutcome::Succeed {
1689 result: b"would-have-succeeded".to_vec(),
1690 })
1691 }
1692 }
1693
1694 let queue = fresh_queue().await;
1695 let claimed = Arc::new(tokio::sync::Notify::new());
1696 let (gate_tx, gate_rx) = oneshot::channel::<()>();
1697 let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel();
1698 let runtime = WorkflowRuntime::builder(
1699 queue,
1700 GatedRunner {
1701 claimed: claimed.clone(),
1702 gate: tokio::sync::Mutex::new(Some(gate_rx)),
1703 },
1704 ChannelHook { tx },
1705 )
1706 .build();
1707 let shutdown = spawn_runtime(runtime.clone());
1708
1709 let handle = runtime
1710 .submit(RunSpec {
1711 input: b"x".to_vec(),
1712 ..Default::default()
1713 })
1714 .await
1715 .unwrap();
1716 tokio::time::timeout(Duration::from_secs(2), claimed.notified())
1717 .await
1718 .expect("runner reached gate");
1719
1720 let before = runtime.status(&handle.run_id).await.expect("active");
1722 assert_eq!(before.state, RunState::Running);
1723
1724 runtime.cancel(&handle.run_id).await.unwrap();
1725
1726 let during = runtime
1730 .status(&handle.run_id)
1731 .await
1732 .expect("entry retained while termination is in flight");
1733 assert_eq!(during.state, RunState::Cancelling);
1734
1735 let _ = gate_tx.send(());
1738
1739 let outcome = tokio::time::timeout(Duration::from_secs(2), rx.recv())
1740 .await
1741 .expect("hook fired")
1742 .expect("hook channel open");
1743 assert_eq!(outcome.status, TerminalStatus::Cancelled);
1744 assert!(runtime.status(&handle.run_id).await.is_none());
1745
1746 let _ = shutdown.send(());
1747 }
1748
1749 #[tokio::test]
1750 async fn cancel_unknown_run_returns_false() {
1751 let queue = fresh_queue().await;
1752 let runtime: WorkflowRuntime<ScriptedRunner, NoopTerminalHook> =
1753 WorkflowRuntime::builder(queue, ScriptedRunner::new(vec![]), NoopTerminalHook).build();
1754
1755 let was_cancelled = runtime.cancel("never-submitted").await.unwrap();
1756 assert!(!was_cancelled);
1757 }
1758
1759 #[tokio::test]
1760 async fn transient_fires_once_on_single_attempt() {
1761 assert_transient_retries_until_max(1).await;
1762 }
1763
1764 #[tokio::test]
1765 async fn transient_retries_up_to_max_attempts() {
1766 assert_transient_retries_until_max(3).await;
1767 }
1768}