1use bytes::Bytes;
14use sayiir_core::codec::Codec;
15use sayiir_core::codec::sealed;
16use sayiir_core::context::{WorkflowContext, with_context};
17use sayiir_core::error::WorkflowError;
18use sayiir_core::snapshot::{ExecutionPosition, SignalKind, SignalRequest, WorkflowSnapshot};
19use sayiir_core::workflow::{Workflow, WorkflowContinuation, WorkflowStatus};
20use sayiir_persistence::PersistentBackend;
21use std::sync::Arc;
22
23use crate::error::RuntimeError;
24use crate::execution::{
25 ForkBranchOutcome, JoinResolution, ResumeParkedPosition, branch_execute_or_skip_task,
26 check_guards, collect_cached_branches, execute_or_skip_task, finalize_execution,
27 get_resume_input, park_at_delay, park_at_signal, park_branch_at_delay, park_branch_at_signal,
28 resolve_join, retry_with_checkpoint, set_deadline_if_needed, settle_fork_outcome,
29};
30
/// Workflow runner that persists a snapshot checkpoint to the backing store
/// after each completed step, so an instance can be resumed, paused, or
/// cancelled across process restarts.
pub struct CheckpointingRunner<B> {
    // Shared handle to the persistence layer; cloned into spawned branch tasks.
    backend: Arc<B>,
}
70
71impl<B> CheckpointingRunner<B>
72where
73 B: PersistentBackend,
74{
75 pub fn new(backend: B) -> Self {
77 Self {
78 backend: Arc::new(backend),
79 }
80 }
81
82 pub async fn cancel(
97 &self,
98 instance_id: &str,
99 reason: Option<String>,
100 cancelled_by: Option<String>,
101 ) -> Result<(), RuntimeError> {
102 self.backend
103 .store_signal(
104 instance_id,
105 SignalKind::Cancel,
106 SignalRequest::new(reason, cancelled_by),
107 )
108 .await?;
109
110 Ok(())
111 }
112
113 pub async fn pause(
121 &self,
122 instance_id: &str,
123 reason: Option<String>,
124 paused_by: Option<String>,
125 ) -> Result<(), RuntimeError> {
126 self.backend
127 .store_signal(
128 instance_id,
129 SignalKind::Pause,
130 SignalRequest::new(reason, paused_by),
131 )
132 .await?;
133 Ok(())
134 }
135
136 pub async fn unpause(&self, instance_id: &str) -> Result<WorkflowSnapshot, RuntimeError> {
144 let snapshot = self.backend.unpause(instance_id).await?;
145 Ok(snapshot)
146 }
147
148 #[must_use]
150 pub fn backend(&self) -> &Arc<B> {
151 &self.backend
152 }
153}
154
155impl<B> CheckpointingRunner<B>
156where
157 B: PersistentBackend + 'static,
158{
159 pub async fn run<C, Input, M>(
169 &self,
170 workflow: &Workflow<C, Input, M>,
171 instance_id: impl Into<String>,
172 input: Input,
173 ) -> Result<WorkflowStatus, RuntimeError>
174 where
175 Input: Send + 'static,
176 M: Send + Sync + 'static,
177 C: Codec + sealed::EncodeValue<Input> + sealed::DecodeValue<Input> + 'static,
178 {
179 let instance_id = instance_id.into();
180 let definition_hash = workflow.definition_hash().to_string();
181
182 let input_bytes = workflow.context().codec.encode(&input)?;
184
185 let mut snapshot = WorkflowSnapshot::with_initial_input(
187 instance_id.clone(),
188 definition_hash.clone(),
189 input_bytes.clone(),
190 );
191 snapshot.update_position(ExecutionPosition::AtTask {
192 task_id: workflow.continuation().first_task_id().to_string(),
193 });
194
195 self.backend.save_snapshot(&snapshot).await?;
197
198 let context = workflow.context().clone();
200 let continuation = workflow.continuation();
201 let backend = Arc::clone(&self.backend);
202
203 with_context(context.clone(), || async move {
204 let result = Self::execute_with_checkpointing(
205 continuation,
206 input_bytes,
207 &mut snapshot,
208 Arc::clone(&backend),
209 context,
210 )
211 .await;
212
213 let (status, _output) =
214 finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
215 Ok(status)
216 })
217 .await
218 }
219
220 #[allow(clippy::needless_lifetimes)]
232 pub async fn resume<'w, C, Input, M>(
233 &self,
234 workflow: &'w Workflow<C, Input, M>,
235 instance_id: &str,
236 ) -> Result<WorkflowStatus, RuntimeError>
237 where
238 Input: Send + 'static,
239 M: Send + Sync + 'static,
240 C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<Input> + 'static,
241 {
242 let mut snapshot = self.backend.load_snapshot(instance_id).await?;
244
245 if snapshot.definition_hash != workflow.definition_hash() {
247 return Err(WorkflowError::DefinitionMismatch {
248 expected: workflow.definition_hash().to_string(),
249 found: snapshot.definition_hash.clone(),
250 }
251 .into());
252 }
253
254 if let Some(status) = snapshot.state.as_terminal_status() {
256 return Ok(status);
257 }
258
259 let parked = ResumeParkedPosition::extract(&snapshot);
261 if let Some(status) = parked
262 .resolve(&mut snapshot, instance_id, self.backend.as_ref())
263 .await?
264 {
265 return Ok(status);
266 }
267
268 let context = workflow.context().clone();
270 let continuation = workflow.continuation();
271 let backend = Arc::clone(&self.backend);
272
273 with_context(context.clone(), || async move {
274 let input_bytes = get_resume_input(&snapshot)?;
276
277 let result = Self::execute_with_checkpointing(
278 continuation,
279 input_bytes,
280 &mut snapshot,
281 Arc::clone(&backend),
282 context,
283 )
284 .await;
285
286 let (status, _output) =
287 finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
288 Ok(status)
289 })
290 .await
291 }
292
    /// Drives the top-level continuation chain, saving a snapshot after each
    /// completed node and honouring cancel/pause guards between nodes.
    ///
    /// Returns the final output bytes on completion; parking (delay/signal
    /// wait) is reported via `Err`, which the caller turns into a non-terminal
    /// status in `finalize_execution`.
    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
    async fn execute_with_checkpointing<'a, C, M>(
        continuation: &'a WorkflowContinuation,
        input: Bytes,
        snapshot: &'a mut WorkflowSnapshot,
        backend: Arc<B>,
        context: WorkflowContext<C, M>,
    ) -> Result<Bytes, RuntimeError>
    where
        B: 'static,
        C: Codec + 'static,
        M: Send + Sync + 'static,
    {
        // Walk the linked continuation list iteratively; `current_input` is the
        // previous node's output (or the initial input).
        let mut current = continuation;
        let mut current_input = input;

        loop {
            match current {
                WorkflowContinuation::Task {
                    id,
                    func: Some(func),
                    timeout,
                    retry_policy,
                    next,
                } => {
                    // Check for cancel/pause before starting the task, then arm
                    // the task deadline (if a timeout is configured).
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;
                    set_deadline_if_needed(id, timeout.as_ref(), snapshot, backend.as_ref())
                        .await?;

                    // Retries are checkpointed; completed tasks are skipped via
                    // `execute_or_skip_task` (replay uses the cached result).
                    let output = retry_with_checkpoint(
                        id,
                        retry_policy.as_ref(),
                        timeout.as_ref(),
                        snapshot,
                        Some(backend.as_ref()),
                        async |snap| {
                            execute_or_skip_task(id, current_input.clone(), |i| func.run(i), snap)
                                .await
                        },
                    )
                    .await?;

                    // Advance the recorded position before persisting, so a
                    // crash after this save resumes at the *next* task.
                    if let Some(next_cont) = next {
                        snapshot.update_position(ExecutionPosition::AtTask {
                            task_id: next_cont.first_task_id().to_string(),
                        });
                    }
                    backend.save_snapshot(snapshot).await?;
                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;

                    match next {
                        Some(next_continuation) => {
                            current = next_continuation;
                            current_input = output;
                        }
                        None => return Ok(output),
                    }
                }
                WorkflowContinuation::Task { func: None, id, .. } => {
                    // Declared task without an implementation (e.g. registry gap).
                    return Err(WorkflowError::TaskNotImplemented(id.clone()).into());
                }
                WorkflowContinuation::Delay { id, duration, next } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    // A recorded result means the delay already elapsed on a
                    // previous run; the input passes through unchanged.
                    if snapshot.get_task_result(id).is_some() {
                        match next {
                            Some(next_continuation) => {
                                current = next_continuation;
                                continue;
                            }
                            None => return Ok(current_input),
                        }
                    }

                    // Park: persists the wake-up position and surfaces a
                    // Waiting error for the caller to translate.
                    return Err(park_at_delay(
                        id,
                        duration,
                        next.as_deref(),
                        current_input,
                        snapshot,
                        backend.as_ref(),
                    )
                    .await);
                }
                WorkflowContinuation::AwaitSignal {
                    id,
                    signal_name,
                    timeout,
                    next,
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    // Signal already consumed on a previous run: continue with
                    // the stored payload (falling back to the current input).
                    // NOTE(review): the terminal (`next == None`) arm returns
                    // `current_input` rather than the stored result bytes,
                    // unlike the branch-side equivalent — confirm intentional.
                    if snapshot.get_task_result(id).is_some() {
                        match next {
                            Some(n) => {
                                current = n;
                                current_input =
                                    snapshot.get_task_result_bytes(id).unwrap_or(current_input);
                                continue;
                            }
                            None => return Ok(current_input),
                        }
                    }

                    let err = park_at_signal(
                        id,
                        signal_name,
                        timeout.as_ref(),
                        next.as_deref(),
                        snapshot,
                        backend.as_ref(),
                    )
                    .await;

                    // `SignalConsumed` means the signal arrived while parking:
                    // treat the node as completed and keep going inline.
                    if matches!(err, RuntimeError::Workflow(WorkflowError::SignalConsumed)) {
                        if let Some(n) = next {
                            current = n;
                            current_input =
                                snapshot.get_task_result_bytes(id).unwrap_or(current_input);
                            continue;
                        }
                        let output = snapshot.get_task_result_bytes(id).unwrap_or(current_input);
                        return Ok(output);
                    }

                    return Err(err);
                }
                WorkflowContinuation::Fork {
                    id: fork_id,
                    branches,
                    join,
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;

                    // Use cached branch outputs when all branches already ran;
                    // otherwise execute outstanding branches in parallel and
                    // settle the outcome (persisting results / wake-at).
                    let branch_results =
                        if let Some(cached) = collect_cached_branches(branches, snapshot) {
                            cached
                        } else {
                            let outcome = Self::execute_fork_branches_parallel(
                                branches,
                                &current_input,
                                snapshot,
                                &backend,
                                &context,
                            )
                            .await?;
                            settle_fork_outcome(
                                fork_id,
                                outcome,
                                join.as_deref(),
                                snapshot,
                                backend.as_ref(),
                            )
                            .await?
                        };

                    match resolve_join(join.as_deref(), &branch_results)? {
                        JoinResolution::Continue { next, input } => {
                            current = next;
                            current_input = input;
                        }
                        JoinResolution::Done(output) => return Ok(output),
                    }
                }
            }
        }
    }
461
    /// Runs all not-yet-completed fork branches concurrently on a
    /// [`tokio::task::JoinSet`], reusing cached outputs for branches already
    /// recorded in the snapshot.
    ///
    /// Branches that park on a delay report `Waiting { wake_at }`; those are
    /// aggregated into the *latest* wake time (`max_wake_at`) rather than
    /// failing the fork, so the whole fork re-wakes once every branch is
    /// eligible. Any other branch error aborts immediately.
    async fn execute_fork_branches_parallel<C, M>(
        branches: &[Arc<WorkflowContinuation>],
        input: &Bytes,
        snapshot: &WorkflowSnapshot,
        backend: &Arc<B>,
        context: &WorkflowContext<C, M>,
    ) -> Result<ForkBranchOutcome, RuntimeError>
    where
        B: 'static,
        C: Codec + 'static,
        M: Send + Sync + 'static,
    {
        let mut branch_results = Vec::with_capacity(branches.len());
        let mut set = tokio::task::JoinSet::new();
        let instance_id = snapshot.instance_id.clone();

        for branch in branches {
            let branch_id = branch.id().to_string();

            if let Some(result) = snapshot.get_task_result(&branch_id) {
                // Completed on an earlier run — reuse the checkpointed output.
                branch_results.push((branch_id, result.output.clone()));
            } else {
                // Clone everything the spawned task needs to own ('static task).
                let branch = Arc::clone(branch);
                let branch_input = input.clone();
                let branch_backend = Arc::clone(backend);
                let branch_instance_id = instance_id.clone();
                let ctx_for_work = context.clone();

                set.spawn(with_context(context.clone(), || async move {
                    let result = Self::execute_branch_with_checkpoint(
                        &branch,
                        branch_input,
                        branch_backend,
                        branch_instance_id,
                        ctx_for_work,
                    )
                    .await?;
                    Ok((branch_id, result))
                }));
            }
        }

        let mut max_wake_at: Option<chrono::DateTime<chrono::Utc>> = None;

        while let Some(result) = set.join_next().await {
            match result {
                Ok(Ok((branch_id, output))) => {
                    branch_results.push((branch_id, output));
                }
                Ok(Err(RuntimeError::Workflow(WorkflowError::Waiting { wake_at }))) => {
                    // Keep collecting other branches; remember the latest wake.
                    max_wake_at = Some(match max_wake_at {
                        Some(existing) => existing.max(wake_at),
                        None => wake_at,
                    });
                }
                Ok(Err(e)) => return Err(e),
                // Task panicked or was aborted.
                Err(join_err) => return Err(RuntimeError::from(join_err)),
            }
        }

        Ok(ForkBranchOutcome {
            results: branch_results,
            max_wake_at,
        })
    }
528
529 async fn execute_nested_fork_branches<C, M>(
534 branches: &[Arc<WorkflowContinuation>],
535 input: &Bytes,
536 backend: &Arc<B>,
537 instance_id: &str,
538 context: &WorkflowContext<C, M>,
539 ) -> Result<Vec<(String, Bytes)>, RuntimeError>
540 where
541 B: 'static,
542 C: Codec + 'static,
543 M: Send + Sync + 'static,
544 {
545 let mut set: tokio::task::JoinSet<Result<(String, Bytes), RuntimeError>> =
546 tokio::task::JoinSet::new();
547 for branch in branches {
548 let id = branch.id().to_string();
549 let branch = Arc::clone(branch);
550 let branch_input = input.clone();
551 let branch_backend = Arc::clone(backend);
552 let branch_instance_id = instance_id.to_string();
553 let ctx_for_work = context.clone();
554
555 set.spawn(with_context(context.clone(), || async move {
556 let result = Self::execute_branch_with_checkpoint(
557 &branch,
558 branch_input,
559 branch_backend,
560 branch_instance_id,
561 ctx_for_work,
562 )
563 .await?;
564 Ok((id, result))
565 }));
566 }
567
568 let mut branch_results: Vec<(String, Bytes)> = Vec::with_capacity(set.len());
569 while let Some(res) = set.join_next().await {
570 branch_results.push(res??);
571 }
572 Ok(branch_results)
573 }
574
    /// Executes a single fork branch's continuation chain with checkpointing.
    ///
    /// Runs inside a spawned task, so it loads its own copy of the snapshot
    /// from the backend instead of sharing the caller's `&mut` snapshot.
    /// Written as `fn -> impl Future` (not `async fn`) so the returned future
    /// can be named `Send` for `JoinSet::spawn`.
    ///
    /// NOTE(review): this branch-local snapshot is loaded once up front;
    /// concurrent sibling branches each hold their own copy — confirm the
    /// backend reconciles branch results on save.
    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
    fn execute_branch_with_checkpoint<C, M>(
        continuation: &WorkflowContinuation,
        input: Bytes,
        backend: Arc<B>,
        instance_id: String,
        context: WorkflowContext<C, M>,
    ) -> impl std::future::Future<Output = Result<Bytes, RuntimeError>> + Send + '_
    where
        B: 'static,
        C: Codec + 'static,
        M: Send + Sync + 'static,
    {
        async move {
            let mut snapshot = backend.load_snapshot(&instance_id).await?;

            let mut current = continuation;
            let mut current_input = input;

            loop {
                match current {
                    WorkflowContinuation::Task {
                        id,
                        func,
                        timeout,
                        retry_policy,
                        next,
                    } => {
                        let func = func
                            .as_ref()
                            .ok_or_else(|| WorkflowError::TaskNotImplemented(id.clone()))?;

                        // Inner retry loop: unlike the top-level path (which
                        // delegates to `retry_with_checkpoint`), branches
                        // record retries and sleep inline.
                        let output = loop {
                            match branch_execute_or_skip_task(
                                id,
                                current_input.clone(),
                                |i| func.run(i),
                                timeout.as_ref(),
                                &mut snapshot,
                                &instance_id,
                                backend.as_ref(),
                            )
                            .await
                            {
                                Ok(output) => {
                                    // Success clears any accumulated retry state.
                                    snapshot.clear_retry_state(id);
                                    break output;
                                }
                                Err(e) => {
                                    if let Some(rp) = retry_policy
                                        && !snapshot.retries_exhausted(id)
                                    {
                                        let next_retry_at =
                                            snapshot.record_retry(id, rp, &e.to_string(), None);
                                        snapshot.clear_task_deadline();
                                        tracing::info!(
                                            task_id = %id,
                                            attempt = snapshot.get_retry_state(id).map_or(0, |rs| rs.attempts),
                                            max_retries = rp.max_retries,
                                            %next_retry_at,
                                            error = %e,
                                            "Retrying task (branch)"
                                        );
                                        // Backoff in-process until the recorded
                                        // retry time; a past timestamp yields a
                                        // zero-duration sleep.
                                        let delay = (next_retry_at - chrono::Utc::now())
                                            .to_std()
                                            .unwrap_or_default();
                                        tokio::time::sleep(delay).await;
                                        continue;
                                    }
                                    // No policy or retries exhausted: fail the branch.
                                    return Err(e);
                                }
                            }
                        };

                        match next {
                            Some(next_continuation) => {
                                current = next_continuation;
                                current_input = output;
                            }
                            None => return Ok(output),
                        }
                    }
                    WorkflowContinuation::Delay { id, duration, next } => {
                        // Already elapsed on a previous run: continue with the
                        // stored output instead of re-parking.
                        if let Some(result) = snapshot.get_task_result(id) {
                            tracing::debug!(delay_id = %id, "delay already completed in branch, skipping");
                            match next {
                                Some(next_cont) => {
                                    current = next_cont;
                                    current_input = result.output.clone();
                                    continue;
                                }
                                None => return Ok(result.output.clone()),
                            }
                        }

                        // Park the branch; surfaces as an error the fork
                        // aggregator treats as Waiting.
                        return Err(park_branch_at_delay(
                            id,
                            duration,
                            current_input,
                            &instance_id,
                            backend.as_ref(),
                        )
                        .await);
                    }
                    WorkflowContinuation::AwaitSignal {
                        id,
                        signal_name,
                        timeout,
                        next,
                    } => {
                        // Signal already consumed on a previous run: continue
                        // with its stored payload.
                        if let Some(result) = snapshot.get_task_result(id) {
                            tracing::debug!(signal_id = %id, %signal_name, "signal already consumed in branch, skipping");
                            match next {
                                Some(next_cont) => {
                                    current = next_cont;
                                    current_input = result.output.clone();
                                    continue;
                                }
                                None => return Ok(result.output.clone()),
                            }
                        }

                        return Err(park_branch_at_signal(
                            id,
                            signal_name,
                            timeout.as_ref(),
                            current_input,
                            &instance_id,
                            backend.as_ref(),
                        )
                        .await);
                    }
                    WorkflowContinuation::Fork { branches, join, .. } => {
                        // Forks can nest arbitrarily; recurse via the nested
                        // (non-cached) parallel executor.
                        let branch_results = Self::execute_nested_fork_branches(
                            branches,
                            &current_input,
                            &backend,
                            &instance_id,
                            &context,
                        )
                        .await?;

                        match resolve_join(join.as_deref(), &branch_results)? {
                            JoinResolution::Continue { next, input } => {
                                current = next;
                                current_input = input;
                            }
                            JoinResolution::Done(output) => return Ok(output),
                        }
                    }
                }
            }
        }
    }
740}
741
// Integration-style tests against the in-memory backend, covering run/resume,
// fork/join, control signals (cancel/pause), delays, and task timeouts.
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing,
    clippy::too_many_lines
)]
mod tests {
    use super::*;
    use crate::serialization::JsonCodec;
    use sayiir_core::codec::Encoder;
    use sayiir_core::context::WorkflowContext;
    use sayiir_core::error::BoxError;
    use sayiir_core::snapshot::WorkflowSnapshotState;
    use sayiir_core::task::BranchOutputs;
    use sayiir_core::workflow::WorkflowBuilder;
    use sayiir_persistence::InMemoryBackend;
    use sayiir_persistence::{SignalStore, SnapshotStore};

    // Minimal JSON-codec context shared by every test workflow.
    fn ctx() -> WorkflowContext<JsonCodec, ()> {
        WorkflowContext::new("test-workflow", Arc::new(JsonCodec), Arc::new(()))
    }

    // --- run(): happy paths and failures -----------------------------------

    #[tokio::test]
    async fn test_run_single_task() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("add_one", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));

        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_completed());
    }

    #[tokio::test]
    async fn test_run_chained_tasks() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("add_one", |i: u32| async move { Ok(i + 1) })
            .then("double", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));

        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_completed());
    }

    #[tokio::test]
    async fn test_run_three_task_chain() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 3) })
            .then("step3", |i: u32| async move { Ok(i - 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));
    }

    #[tokio::test]
    async fn test_run_task_failure() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("fail", |_i: u32| async move {
                Err::<u32, BoxError>("intentional failure".into())
            })
            .build()
            .unwrap();

        // A task error surfaces as Failed status, not as an Err from run().
        let status = runner.run(&workflow, "inst-1", 1u32).await.unwrap();
        match status {
            WorkflowStatus::Failed(e) => {
                assert!(e.contains("intentional failure"));
            }
            _ => panic!("Expected Failed status"),
        }

        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_failed());
    }

    #[tokio::test]
    async fn test_run_fork_join() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("double", |i: u32| async move { Ok(i * 2) });
                b.add("add_ten", |i: u32| async move { Ok(i + 10) });
            })
            .join("combine", |outputs: BranchOutputs<JsonCodec>| async move {
                let doubled: u32 = outputs.get("double")?;
                let added: u32 = outputs.get("add_ten")?;
                Ok(doubled + added)
            })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));

        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_completed());
    }

    #[tokio::test]
    async fn test_run_checkpoints_intermediate_tasks() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));

        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_completed());
    }

    // --- resume(): terminal states and definition checks --------------------

    #[tokio::test]
    async fn test_resume_completed_workflow() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        runner.run(&workflow, "inst-1", 5u32).await.unwrap();

        // Resuming a completed instance is a no-op that reports Completed.
        let status = runner.resume(&workflow, "inst-1").await.unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));
    }

    #[tokio::test]
    async fn test_resume_failed_workflow() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("fail", |_i: u32| async move {
                Err::<u32, BoxError>("failure".into())
            })
            .build()
            .unwrap();

        runner.run(&workflow, "inst-1", 1u32).await.unwrap();

        let status = runner.resume(&workflow, "inst-1").await.unwrap();
        match status {
            WorkflowStatus::Failed(_) => {}
            _ => panic!("Expected Failed status"),
        }
    }

    #[tokio::test]
    async fn test_resume_definition_hash_mismatch() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow1 = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        runner.run(&workflow1, "inst-1", 5u32).await.unwrap();

        // Hand-craft an in-progress snapshot recorded under workflow1's hash.
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-2".into(),
            workflow1.definition_hash().to_string(),
            Bytes::from(serde_json::to_vec(&5u32).unwrap()),
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "step1".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        // A structurally different workflow must be rejected on resume.
        let workflow2 = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let result = runner.resume(&workflow2, "inst-2").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }

    // --- cancel / pause signals ---------------------------------------------

    #[tokio::test]
    async fn test_cancel_running_workflow() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("slow_task", |i: u32| async move {
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
                Ok(i)
            })
            .build()
            .unwrap();

        // Persist an in-progress snapshot manually (without running the slow task).
        let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-cancel".into(),
            workflow.definition_hash().to_string(),
            input_bytes,
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "slow_task".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        runner
            .cancel(
                "inst-cancel",
                Some("testing".into()),
                Some("test-suite".into()),
            )
            .await
            .unwrap();

        // The cancel request is stored for the instance to pick up.
        let req = runner
            .backend()
            .get_signal("inst-cancel", SignalKind::Cancel)
            .await
            .unwrap();
        assert!(req.is_some());
        assert_eq!(req.unwrap().reason, Some("testing".into()));
    }

    #[tokio::test]
    async fn test_run_with_pre_cancellation() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("task1", |i: u32| async move { Ok(i + 1) })
            .then("task2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-precancel".into(),
            workflow.definition_hash().to_string(),
            input_bytes,
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "task1".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        // Cancel before any execution; resume should honour it immediately.
        runner
            .cancel("inst-precancel", Some("pre-cancel".into()), None)
            .await
            .unwrap();

        let status = runner.resume(&workflow, "inst-precancel").await.unwrap();
        match status {
            WorkflowStatus::Cancelled { reason, .. } => {
                assert_eq!(reason, Some("pre-cancel".into()));
            }
            _ => panic!("Expected Cancelled status, got: {status:?}"),
        }
    }

    #[tokio::test]
    async fn test_resume_nonexistent_instance() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("task", |i: u32| async move { Ok(i) })
            .build()
            .unwrap();

        let result = runner.resume(&workflow, "nonexistent").await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_run_failure_in_chain_saves_snapshot() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("fail_step", |_i: u32| async move {
                Err::<u32, BoxError>("mid-chain failure".into())
            })
            .then("step3", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
        match status {
            WorkflowStatus::Failed(e) => {
                assert!(e.contains("mid-chain failure"));
            }
            _ => panic!("Expected Failed"),
        }

        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_failed());
    }

    // --- delays: parking, resuming, cancel-while-parked ----------------------

    #[tokio::test]
    async fn test_run_workflow_with_delay_returns_waiting() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_1h", std::time::Duration::from_secs(3600))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();

        match &status {
            WorkflowStatus::Waiting { delay_id, .. } => {
                assert_eq!(delay_id, "wait_1h");
            }
            _ => panic!("Expected Waiting status, got {status:?}"),
        }

        // The snapshot must record the parked position and next task.
        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_in_progress());
        match &snapshot.state {
            WorkflowSnapshotState::InProgress { position, .. } => match position {
                ExecutionPosition::AtDelay {
                    delay_id,
                    next_task_id,
                    ..
                } => {
                    assert_eq!(delay_id, "wait_1h");
                    assert_eq!(next_task_id.as_deref(), Some("step2"));
                }
                other => panic!("Expected AtDelay, got {other:?}"),
            },
            _ => panic!("Expected InProgress"),
        }

        assert!(snapshot.get_task_result("step1").is_some());
        assert!(snapshot.get_task_result("wait_1h").is_some());
    }

    #[tokio::test]
    async fn test_resume_before_delay_expires_returns_waiting() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_1h", std::time::Duration::from_secs(3600))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Waiting { .. }));

        // The 1-hour delay cannot have elapsed — still Waiting.
        let status = runner.resume(&workflow, "inst-1").await.unwrap();
        match &status {
            WorkflowStatus::Waiting { delay_id, .. } => {
                assert_eq!(delay_id, "wait_1h");
            }
            _ => panic!("Expected Waiting on resume, got {status:?}"),
        }
    }

    #[tokio::test]
    async fn test_resume_after_delay_expires_completes() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_short", std::time::Duration::from_millis(1))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Waiting { .. }));

        // Sleep past the 1ms delay so the resume can proceed.
        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let status = runner.resume(&workflow, "inst-1").await.unwrap();
        assert!(
            matches!(status, WorkflowStatus::Completed),
            "Expected Completed after delay expired, got {status:?}"
        );

        let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
        assert!(snapshot.state.is_completed());
    }

    #[tokio::test]
    async fn test_cancel_during_delay() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_1h", std::time::Duration::from_secs(3600))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Waiting { .. }));

        // Cancelling a parked instance takes effect on the next resume.
        runner
            .cancel(
                "inst-1",
                Some("no longer needed".into()),
                Some("admin".into()),
            )
            .await
            .unwrap();

        let status = runner.resume(&workflow, "inst-1").await.unwrap();
        match status {
            WorkflowStatus::Cancelled {
                reason,
                cancelled_by,
            } => {
                assert_eq!(reason, Some("no longer needed".into()));
                assert_eq!(cancelled_by, Some("admin".into()));
            }
            _ => panic!("Expected Cancelled status, got {status:?}"),
        }
    }

    #[tokio::test]
    async fn test_delay_as_last_node() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("final_wait", std::time::Duration::from_millis(1))
            .build()
            .unwrap();

        let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
        assert!(matches!(status, WorkflowStatus::Waiting { .. }));

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;

        let status = runner.resume(&workflow, "inst-1").await.unwrap();
        assert!(
            matches!(status, WorkflowStatus::Completed),
            "Expected Completed when delay is last node, got {status:?}"
        );
    }

    #[tokio::test]
    async fn test_delay_data_passthrough() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        // The value flowing into the delay (10 + 1) must flow out unchanged.
        let workflow = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait", std::time::Duration::from_millis(1))
            .then("step2", |i: u32| async move {
                assert_eq!(i, 11);
                Ok(i * 2)
            })
            .build()
            .unwrap();

        runner.run(&workflow, "inst-1", 10u32).await.unwrap();

        tokio::time::sleep(std::time::Duration::from_millis(10)).await;
        let status = runner.resume(&workflow, "inst-1").await.unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));
    }

    // --- task timeouts -------------------------------------------------------

    #[tokio::test]
    async fn test_run_task_timeout_fails_workflow() {
        use sayiir_core::task::TaskMetadata;

        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        // 50ms task with a 5ms timeout — must fail with a timeout error.
        let workflow = WorkflowBuilder::new(ctx())
            .with_registry()
            .then("slow_task", |i: u32| async move {
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
                Ok(i)
            })
            .with_metadata(TaskMetadata {
                timeout: Some(std::time::Duration::from_millis(5)),
                ..Default::default()
            })
            .build()
            .unwrap();

        let status = runner
            .run(workflow.workflow(), "inst-timeout", 5u32)
            .await
            .unwrap();
        match status {
            WorkflowStatus::Failed(msg) => {
                assert!(
                    msg.contains("timed out"),
                    "Expected timeout error, got: {msg}"
                );
                assert!(
                    msg.contains("slow_task"),
                    "Expected task id in error, got: {msg}"
                );
            }
            other => panic!("Expected Failed status, got {other:?}"),
        }
    }

    #[tokio::test]
    async fn test_run_task_within_timeout_succeeds() {
        use sayiir_core::task::TaskMetadata;

        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .with_registry()
            .then("fast_task", |i: u32| async move { Ok(i + 1) })
            .with_metadata(TaskMetadata {
                timeout: Some(std::time::Duration::from_secs(5)),
                ..Default::default()
            })
            .build()
            .unwrap();

        let status = runner
            .run(workflow.workflow(), "inst-fast", 5u32)
            .await
            .unwrap();
        assert!(matches!(status, WorkflowStatus::Completed));
    }
}