1use std::ops::ControlFlow;
14use std::sync::Arc;
15
16use bytes::Bytes;
17use sayiir_core::codec::sealed;
18use sayiir_core::codec::{Codec, EnvelopeCodec};
19use sayiir_core::context::WorkflowContext;
20use sayiir_core::error::WorkflowError;
21use sayiir_core::snapshot::{ExecutionPosition, TaskHint, WorkflowSnapshot};
22use sayiir_core::workflow::{ConflictPolicy, Workflow, WorkflowContinuation, WorkflowStatus};
23use sayiir_persistence::PersistentBackend;
24
25use crate::error::RuntimeError;
26use crate::execution::control_flow::{
27 ParkReason, StepOutcome, StepResult, compute_signal_timeout, compute_wake_at,
28 save_branch_park_checkpoint, save_park_checkpoint,
29};
30use crate::execution::loop_runner::{
31 CheckpointingLoopHooks, LoopConfig, LoopExit, LoopNext, resolve_loop_iteration, run_loop_async,
32};
33use crate::execution::{
34 ForkBranchOutcome, JoinResolution, ResumeParkedPosition, branch_execute_or_skip_task,
35 check_guards, collect_cached_branches, execute_or_skip_task, finalize_execution,
36 get_resume_input, resolve_join, retry_with_checkpoint, set_deadline_if_needed,
37 settle_fork_outcome,
38};
39
/// Workflow runner that persists a [`WorkflowSnapshot`] to the backing store
/// after each completed step, so execution can survive a crash and be resumed
/// later (including from a parked delay or signal wait).
pub struct CheckpointingRunner<B> {
    // Shared persistence backend used for snapshots, signals and guard checks.
    backend: Arc<B>,
    // Policy applied when an instance with the same id already exists.
    conflict_policy: ConflictPolicy,
}
80
81impl<B> CheckpointingRunner<B>
82where
83 B: PersistentBackend,
84{
85 pub fn new(backend: B) -> Self {
87 Self {
88 backend: Arc::new(backend),
89 conflict_policy: ConflictPolicy::default(),
90 }
91 }
92
93 pub fn from_shared(backend: Arc<B>) -> Self {
97 Self {
98 backend,
99 conflict_policy: ConflictPolicy::default(),
100 }
101 }
102
103 #[must_use]
105 pub fn with_conflict_policy(mut self, policy: ConflictPolicy) -> Self {
106 self.conflict_policy = policy;
107 self
108 }
109
110 #[must_use]
112 pub fn backend(&self) -> &Arc<B> {
113 &self.backend
114 }
115}
116
117impl<B> CheckpointingRunner<B>
118where
119 B: PersistentBackend + 'static,
120{
121 pub async fn run<C, Input, M>(
132 &self,
133 workflow: &Workflow<C, Input, M>,
134 instance_id: impl Into<String>,
135 input: Input,
136 ) -> Result<WorkflowStatus, RuntimeError>
137 where
138 Input: Send + 'static,
139 M: Send + Sync + 'static,
140 C: Codec
141 + EnvelopeCodec
142 + sealed::EncodeValue<Input>
143 + sealed::DecodeValue<Input>
144 + 'static,
145 {
146 use crate::{PrepareRunOutcome, check_existing_instance, prepare_run};
147
148 let instance_id = instance_id.into();
149 let definition_hash = workflow.definition_hash().to_string();
150 let conflict_policy = self.conflict_policy;
151
152 if let Some((status, _output)) = check_existing_instance(
154 &instance_id,
155 &definition_hash,
156 self.backend.as_ref(),
157 conflict_policy,
158 )
159 .await?
160 {
161 return Ok(status);
162 }
163
164 let input_bytes = workflow.context().codec.encode(&input)?;
166 let first_task = workflow.continuation().first_task_hint();
167
168 let mut snapshot = match prepare_run(
169 instance_id,
170 definition_hash,
171 input_bytes.clone(),
172 first_task,
173 self.backend.as_ref(),
174 conflict_policy,
175 true, )
177 .await?
178 {
179 PrepareRunOutcome::Fresh(s) => *s,
180 PrepareRunOutcome::ExistingStatus(status, _output) => {
181 return Ok(status);
182 }
183 };
184
185 let context = workflow.context().clone();
187 let continuation = workflow.continuation();
188 let backend = Arc::clone(&self.backend);
189
190 let result = Self::execute_with_checkpointing(
191 continuation,
192 input_bytes,
193 &mut snapshot,
194 Arc::clone(&backend),
195 context,
196 )
197 .await;
198
199 let (status, _output) = finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
200 Ok(status)
201 }
202
203 #[allow(clippy::needless_lifetimes)]
215 pub async fn resume<'w, C, Input, M>(
216 &self,
217 workflow: &'w Workflow<C, Input, M>,
218 instance_id: &str,
219 ) -> Result<WorkflowStatus, RuntimeError>
220 where
221 Input: Send + 'static,
222 M: Send + Sync + 'static,
223 C: Codec
224 + EnvelopeCodec
225 + sealed::DecodeValue<Input>
226 + sealed::EncodeValue<Input>
227 + 'static,
228 {
229 let mut snapshot = self.backend.load_snapshot(instance_id).await?;
231
232 if snapshot.definition_hash != workflow.definition_hash() {
234 return Err(WorkflowError::DefinitionMismatch {
235 expected: workflow.definition_hash().to_string(),
236 found: snapshot.definition_hash.clone(),
237 }
238 .into());
239 }
240
241 if let Some(status) = snapshot.state.as_terminal_status() {
243 return Ok(status);
244 }
245
246 let parked = ResumeParkedPosition::extract(&snapshot);
248 if let Some(status) = parked
249 .resolve(&mut snapshot, instance_id, self.backend.as_ref())
250 .await?
251 {
252 return Ok(status);
253 }
254
255 let context = workflow.context().clone();
257 let continuation = workflow.continuation();
258 let backend = Arc::clone(&self.backend);
259
260 let input_bytes = get_resume_input(&snapshot)?;
262
263 let result = Self::execute_with_checkpointing(
264 continuation,
265 input_bytes,
266 &mut snapshot,
267 Arc::clone(&backend),
268 context,
269 )
270 .await;
271
272 let (status, _output) = finalize_execution(result, &mut snapshot, backend.as_ref()).await?;
273 Ok(status)
274 }
275
    /// Drives a continuation chain to completion, saving a snapshot to
    /// `backend` after every completed step.
    ///
    /// Walks the chain one node at a time, feeding each node's output into
    /// the next as `current_input`. Returns the final output bytes on
    /// success. Parking (a pending delay or an unconsumed signal) is
    /// surfaced as the `Err` produced by `save_park_checkpoint`, which the
    /// caller's `finalize_execution` is expected to interpret.
    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
    async fn execute_with_checkpointing<'a, C, M>(
        continuation: &'a WorkflowContinuation,
        input: Bytes,
        snapshot: &'a mut WorkflowSnapshot,
        backend: Arc<B>,
        context: WorkflowContext<C, M>,
    ) -> Result<Bytes, RuntimeError>
    where
        B: 'static,
        C: Codec + EnvelopeCodec + 'static,
        M: Send + Sync + 'static,
    {
        let mut current = continuation;
        let mut current_input = input;

        loop {
            // Each step yields Continue(output) to advance, or Break to
            // finish early (Done) or park the instance (Park).
            let step: StepResult = match current {
                // Executable task: run under retry/timeout, then checkpoint.
                WorkflowContinuation::Task {
                    id,
                    func: Some(func),
                    timeout,
                    retry_policy,
                    ..
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;
                    set_deadline_if_needed(id, timeout.as_ref(), snapshot, backend.as_ref())
                        .await?;

                    // execute_or_skip_task reuses a checkpointed result if
                    // this task already completed on a previous attempt.
                    let output = retry_with_checkpoint(
                        id,
                        retry_policy.as_ref(),
                        timeout.as_ref(),
                        snapshot,
                        Some(backend.as_ref()),
                        async |snap| {
                            execute_or_skip_task(id, current_input.clone(), |i| func.run(i), snap)
                                .await
                        },
                    )
                    .await?;

                    // Record where execution should resume (the next task)
                    // before persisting the post-task snapshot.
                    if let Some(next_cont) = current.get_next() {
                        let next_id = next_cont.first_task_id().to_string();
                        snapshot.set_task_hint(&TaskHint {
                            id: next_id.clone(),
                            priority: continuation.get_task_priority(&next_id),
                            tags: continuation.get_task_tags(&next_id),
                        });
                        snapshot.update_position(ExecutionPosition::AtTask { task_id: next_id });
                    }
                    backend.save_snapshot(snapshot).await?;
                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;

                    Ok(ControlFlow::Continue(output))
                }
                // Declared but unimplemented task: hard error.
                WorkflowContinuation::Task { func: None, id, .. } => {
                    return Err(WorkflowError::TaskNotImplemented(id.clone()).into());
                }
                WorkflowContinuation::Delay { id, duration, next } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if snapshot.get_task_result(id).is_some() {
                        // Delay already elapsed on a previous run: pass through.
                        Ok(ControlFlow::Continue(current_input.clone()))
                    } else {
                        // Park until wake_at; the input is stashed so it can
                        // be replayed when the instance wakes.
                        let wake_at = compute_wake_at(duration)?;
                        Ok(ControlFlow::Break(StepOutcome::Park(ParkReason::Delay {
                            delay_id: id.clone(),
                            wake_at,
                            next_task: next.as_deref().map(WorkflowContinuation::first_task_hint),
                            passthrough: current_input.clone(),
                        })))
                    }
                }
                WorkflowContinuation::AwaitSignal {
                    id,
                    signal_name,
                    timeout,
                    next,
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if snapshot.get_task_result(id).is_some() {
                        // Signal already consumed on a previous run: replay its
                        // payload (falling back to the current input).
                        let payload = snapshot
                            .get_task_result_bytes(id)
                            .unwrap_or(current_input.clone());
                        Ok(ControlFlow::Continue(payload))
                    } else {
                        // Try to consume a pending event for this signal.
                        match backend
                            .consume_event(&snapshot.instance_id, signal_name)
                            .await
                        {
                            Ok(Some(payload)) => {
                                snapshot.mark_task_completed(id.clone(), payload);
                                if let Some(next_cont) = next.as_deref() {
                                    let next_id = next_cont.first_task_id().to_string();
                                    snapshot.set_task_hint(&TaskHint {
                                        id: next_id.clone(),
                                        priority: continuation.get_task_priority(&next_id),
                                        tags: continuation.get_task_tags(&next_id),
                                    });
                                    snapshot.update_position(ExecutionPosition::AtTask {
                                        task_id: next_id,
                                    });
                                }
                                backend.save_snapshot(snapshot).await?;
                                let output = snapshot
                                    .get_task_result_bytes(id)
                                    .unwrap_or(current_input.clone());
                                Ok(ControlFlow::Continue(output))
                            }
                            // No event yet: park until the signal arrives or
                            // the (optional) timeout elapses.
                            Ok(None) => Ok(ControlFlow::Break(StepOutcome::Park(
                                ParkReason::AwaitingSignal {
                                    signal_id: id.clone(),
                                    signal_name: signal_name.clone(),
                                    timeout: compute_signal_timeout(timeout.as_ref()),
                                    next_task: next
                                        .as_deref()
                                        .map(WorkflowContinuation::first_task_hint),
                                },
                            ))),
                            Err(e) => Err(RuntimeError::from(e)),
                        }
                    }
                }
                WorkflowContinuation::Fork {
                    id: fork_id,
                    branches,
                    join,
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, None).await?;

                    // If every branch already has a checkpointed result, skip
                    // re-execution entirely; otherwise run them in parallel
                    // and settle (persist) the combined outcome.
                    let branch_results =
                        if let Some(cached) = collect_cached_branches(branches, snapshot) {
                            cached
                        } else {
                            let outcome = Self::execute_fork_branches_parallel(
                                branches,
                                &current_input,
                                snapshot,
                                &backend,
                                &context,
                            )
                            .await?;
                            settle_fork_outcome(
                                fork_id,
                                outcome,
                                join.as_deref(),
                                snapshot,
                                backend.as_ref(),
                            )
                            .await?
                        };

                    match resolve_join(join.as_deref(), &branch_results, context.codec.as_ref())? {
                        JoinResolution::Continue { input, .. } => Ok(ControlFlow::Continue(input)),
                        JoinResolution::Done(output) => {
                            Ok(ControlFlow::Break(StepOutcome::Done(output)))
                        }
                    }
                }
                WorkflowContinuation::Branch {
                    id,
                    key_fn: Some(key_fn),
                    branches,
                    default,
                    ..
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if let Some(result) = snapshot.get_task_result(id) {
                        // Branch decision already made and checkpointed.
                        Ok(ControlFlow::Continue(result.output.clone()))
                    } else {
                        // Evaluate the key function and decode the routing key.
                        let key_bytes = key_fn
                            .run(current_input.clone())
                            .await
                            .map_err(RuntimeError::from)?;
                        let key: String = context
                            .codec
                            .decode_string(&key_bytes)
                            .map_err(RuntimeError::from)?;

                        // Unknown key with no default branch is an error.
                        let chosen = branches.get(&key).or(default.as_ref()).ok_or_else(|| {
                            WorkflowError::BranchKeyNotFound {
                                branch_id: id.clone(),
                                key: key.clone(),
                            }
                        })?;

                        let branch_output = Self::execute_branch_with_checkpoint(
                            chosen,
                            current_input.clone(),
                            Arc::clone(&backend),
                            snapshot.instance_id.clone(),
                            context.clone(),
                        )
                        .await?;

                        // The output is wrapped in an envelope carrying the
                        // chosen key, so downstream code can tell which arm ran.
                        let envelope_bytes = context
                            .codec
                            .encode_branch_envelope(&key, &branch_output)
                            .map_err(RuntimeError::from)?;

                        snapshot.mark_task_completed(id.clone(), envelope_bytes.clone());
                        backend.save_snapshot(snapshot).await?;

                        Ok(ControlFlow::Continue(envelope_bytes))
                    }
                }
                // Branch without a key function: cannot route.
                WorkflowContinuation::Branch {
                    key_fn: None, id, ..
                } => {
                    return Err(WorkflowError::TaskNotImplemented(
                        sayiir_core::workflow::key_fn_id(id),
                    )
                    .into());
                }
                WorkflowContinuation::Loop {
                    id,
                    body,
                    max_iterations,
                    on_max,
                    ..
                } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if let Some(result) = snapshot.get_task_result(id) {
                        // Loop already completed and checkpointed.
                        Ok(ControlFlow::Continue(result.output.clone()))
                    } else {
                        // start_iteration resumes a partially-run loop from
                        // the iteration recorded in the snapshot.
                        let cfg = LoopConfig {
                            id,
                            body,
                            max_iterations: *max_iterations,
                            on_max: *on_max,
                            start_iteration: snapshot.loop_iteration(id),
                        };
                        let mut loop_input = current_input.clone();
                        let mut final_output = None;

                        for iteration in cfg.start_iteration..cfg.max_iterations {
                            let output = Box::pin(Self::execute_with_checkpointing(
                                body,
                                loop_input.clone(),
                                snapshot,
                                Arc::clone(&backend),
                                context.clone(),
                            ))
                            .await?;

                            // Clear body-task results so the next iteration
                            // re-executes them instead of replaying cache.
                            let body_ser = body.to_serializable();
                            for tid in &body_ser.task_ids() {
                                snapshot.remove_task_result(tid);
                            }

                            match resolve_loop_iteration(&output, iteration, &cfg)? {
                                // Loop condition satisfied: persist the final
                                // value and exit.
                                ControlFlow::Break(LoopExit(inner)) => {
                                    snapshot.clear_loop_iteration(id);
                                    snapshot.mark_task_completed(id.clone(), inner.clone());
                                    backend.save_snapshot(snapshot).await?;
                                    final_output = Some(inner);
                                    break;
                                }
                                // Another iteration: checkpoint progress first.
                                ControlFlow::Continue(LoopNext(inner)) => {
                                    snapshot.set_loop_iteration(id, iteration + 1);
                                    snapshot.update_position(ExecutionPosition::InLoop {
                                        loop_id: id.clone(),
                                        iteration: iteration + 1,
                                        next_task_id: Some(body.first_task_id().to_string()),
                                    });
                                    backend.save_snapshot(snapshot).await?;
                                    loop_input = inner;
                                }
                            }
                        }

                        // Falling out of the for-loop without a break means the
                        // iteration budget was exhausted.
                        match final_output {
                            Some(output) => Ok(ControlFlow::Continue(output)),
                            None => Err(RuntimeError::from(WorkflowError::MaxIterationsExceeded {
                                loop_id: id.clone(),
                                max_iterations: *max_iterations,
                            })),
                        }
                    }
                }
                WorkflowContinuation::ChildWorkflow { id, child, .. } => {
                    check_guards(backend.as_ref(), &snapshot.instance_id, Some(id)).await?;

                    if let Some(result) = snapshot.get_task_result(id) {
                        // Child already completed and checkpointed.
                        Ok(ControlFlow::Continue(result.output.clone()))
                    } else {
                        // Recurse into the child chain with the same snapshot.
                        let output = Box::pin(Self::execute_with_checkpointing(
                            child,
                            current_input.clone(),
                            snapshot,
                            Arc::clone(&backend),
                            context.clone(),
                        ))
                        .await?;

                        snapshot.mark_task_completed(id.clone(), output.clone());
                        backend.save_snapshot(snapshot).await?;

                        Ok(ControlFlow::Continue(output))
                    }
                }
            };

            // Advance the cursor, finish, or park depending on the step.
            match step? {
                ControlFlow::Continue(output) => match current.get_next() {
                    Some(next) => {
                        current = next;
                        current_input = output;
                    }
                    None => return Ok(output),
                },
                ControlFlow::Break(StepOutcome::Done(output)) => return Ok(output),
                ControlFlow::Break(StepOutcome::Park(reason)) => {
                    // Persist the parked position; the returned error encodes
                    // the park for the caller to interpret.
                    return Err(save_park_checkpoint(reason, snapshot, backend.as_ref()).await);
                }
            }
        }
    }
599
    /// Runs all fork branches concurrently, reusing checkpointed results for
    /// branches that already completed on a previous attempt.
    ///
    /// Branches that park surface `WorkflowError::Waiting`; those are not
    /// treated as failures — the furthest-out `wake_at` across all waiting
    /// branches is aggregated into the returned outcome so the fork can be
    /// re-driven later. Results are collected in completion order, not
    /// declaration order.
    async fn execute_fork_branches_parallel<C, M>(
        branches: &[Arc<WorkflowContinuation>],
        input: &Bytes,
        snapshot: &WorkflowSnapshot,
        backend: &Arc<B>,
        context: &WorkflowContext<C, M>,
    ) -> Result<ForkBranchOutcome, RuntimeError>
    where
        B: 'static,
        C: Codec + EnvelopeCodec + 'static,
        M: Send + Sync + 'static,
    {
        let mut branch_results = Vec::with_capacity(branches.len());
        let mut set = tokio::task::JoinSet::new();
        let instance_id = snapshot.instance_id.clone();

        for branch in branches {
            let branch_id = branch.id().to_string();

            if let Some(result) = snapshot.get_task_result(&branch_id) {
                // Completed previously: reuse the checkpointed output.
                branch_results.push((branch_id, result.output.clone()));
            } else {
                // Spawned tasks need owned copies of everything they capture.
                let branch = Arc::clone(branch);
                let branch_input = input.clone();
                let branch_backend = Arc::clone(backend);
                let branch_instance_id = instance_id.clone();
                let ctx_for_work = context.clone();

                set.spawn(async move {
                    let result = Self::execute_branch_with_checkpoint(
                        &branch,
                        branch_input,
                        branch_backend,
                        branch_instance_id,
                        ctx_for_work,
                    )
                    .await?;
                    Ok((branch_id, result))
                });
            }
        }

        let mut max_wake_at: Option<chrono::DateTime<chrono::Utc>> = None;

        while let Some(result) = set.join_next().await {
            match result {
                Ok(Ok((branch_id, output))) => {
                    branch_results.push((branch_id, output));
                }
                // Parked branch: keep the latest wake time seen so far.
                Ok(Err(RuntimeError::Workflow(WorkflowError::Waiting { wake_at }))) => {
                    max_wake_at = Some(match max_wake_at {
                        Some(existing) => existing.max(wake_at),
                        None => wake_at,
                    });
                }
                // Any other branch error fails the fork (remaining tasks are
                // aborted when the JoinSet is dropped).
                Ok(Err(e)) => return Err(e),
                Err(join_err) => return Err(RuntimeError::from(join_err)),
            }
        }

        Ok(ForkBranchOutcome {
            results: branch_results,
            max_wake_at,
        })
    }
666
667 async fn execute_nested_fork_branches<C, M>(
672 branches: &[Arc<WorkflowContinuation>],
673 input: &Bytes,
674 backend: &Arc<B>,
675 instance_id: &str,
676 context: &WorkflowContext<C, M>,
677 ) -> Result<Vec<(String, Bytes)>, RuntimeError>
678 where
679 B: 'static,
680 C: Codec + EnvelopeCodec + 'static,
681 M: Send + Sync + 'static,
682 {
683 let mut set: tokio::task::JoinSet<Result<(String, Bytes), RuntimeError>> =
684 tokio::task::JoinSet::new();
685 for branch in branches {
686 let id = branch.id().to_string();
687 let branch = Arc::clone(branch);
688 let branch_input = input.clone();
689 let branch_backend = Arc::clone(backend);
690 let branch_instance_id = instance_id.to_string();
691 let ctx_for_work = context.clone();
692
693 set.spawn(async move {
694 let result = Self::execute_branch_with_checkpoint(
695 &branch,
696 branch_input,
697 branch_backend,
698 branch_instance_id,
699 ctx_for_work,
700 )
701 .await?;
702 Ok((id, result))
703 });
704 }
705
706 let mut branch_results: Vec<(String, Bytes)> = Vec::with_capacity(set.len());
707 while let Some(res) = set.join_next().await {
708 branch_results.push(res??);
709 }
710 Ok(branch_results)
711 }
712
    /// Executes a single fork/branch arm against its own freshly-loaded copy
    /// of the instance snapshot.
    ///
    /// Mirrors `execute_with_checkpointing` but for branch context: retries
    /// are handled inline with an in-process sleep rather than via
    /// `retry_with_checkpoint`, parks carry no `next_task`, and parking is
    /// persisted through `save_branch_park_checkpoint`.
    ///
    /// Returns a named future (manual `async move`) so the `Send + '_` bound
    /// can be stated explicitly — hence the `manual_async_fn` allow.
    #[allow(clippy::manual_async_fn, clippy::too_many_lines)]
    fn execute_branch_with_checkpoint<C, M>(
        continuation: &WorkflowContinuation,
        input: Bytes,
        backend: Arc<B>,
        instance_id: String,
        context: WorkflowContext<C, M>,
    ) -> impl std::future::Future<Output = Result<Bytes, RuntimeError>> + Send + '_
    where
        B: 'static,
        C: Codec + EnvelopeCodec + 'static,
        M: Send + Sync + 'static,
    {
        async move {
            // Each branch works on its own snapshot loaded from the backend;
            // concurrent branches therefore do not share in-memory state.
            let mut snapshot = backend.load_snapshot(&instance_id).await?;

            let mut current = continuation;
            let mut current_input = input;

            loop {
                let step: StepResult = match current {
                    WorkflowContinuation::Task {
                        id,
                        func,
                        timeout,
                        retry_policy,
                        ..
                    } => {
                        let func = func
                            .as_ref()
                            .ok_or_else(|| WorkflowError::TaskNotImplemented(id.clone()))?;

                        // Inner loop: retry in-process until success, retry
                        // exhaustion, or a non-retryable error.
                        let output = loop {
                            match branch_execute_or_skip_task(
                                id,
                                current_input.clone(),
                                |i| func.run(i),
                                timeout.as_ref(),
                                &mut snapshot,
                                &instance_id,
                                backend.as_ref(),
                            )
                            .await
                            {
                                Ok(output) => {
                                    // Success clears any accumulated retry
                                    // bookkeeping for this task.
                                    snapshot.clear_retry_state(id);
                                    break output;
                                }
                                Err(e) => {
                                    if let Some(rp) = retry_policy
                                        && !snapshot.retries_exhausted(id)
                                    {
                                        // Record the attempt, then sleep in
                                        // place until the computed retry time.
                                        let next_retry_at =
                                            snapshot.record_retry(id, rp, &e.to_string(), None);
                                        snapshot.clear_task_deadline();
                                        tracing::info!(
                                            task_id = %id,
                                            attempt = snapshot.get_retry_state(id).map_or(0, |rs| rs.attempts),
                                            max_retries = rp.max_retries,
                                            %next_retry_at,
                                            error = %e,
                                            "Retrying task (branch)"
                                        );
                                        let delay = (next_retry_at - chrono::Utc::now())
                                            .to_std()
                                            .unwrap_or_default();
                                        tokio::time::sleep(delay).await;
                                        continue;
                                    }
                                    // No policy or retries exhausted: fail.
                                    return Err(e);
                                }
                            }
                        };
                        Ok(ControlFlow::Continue(output))
                    }
                    WorkflowContinuation::Delay { id, duration, .. } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            tracing::debug!(delay_id = %id, "delay already completed in branch, skipping");
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Park the branch; next_task is None because the
                            // branch chain is re-driven from its own root.
                            let wake_at = compute_wake_at(duration)?;
                            Ok(ControlFlow::Break(StepOutcome::Park(ParkReason::Delay {
                                delay_id: id.clone(),
                                wake_at,
                                next_task: None,
                                passthrough: current_input.clone(),
                            })))
                        }
                    }
                    WorkflowContinuation::AwaitSignal {
                        id,
                        signal_name,
                        timeout,
                        ..
                    } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            tracing::debug!(signal_id = %id, %signal_name, "signal already consumed in branch, skipping");
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Branch signals always park; unlike the top-level
                            // path, no consume_event attempt is made here.
                            let wake_at = compute_signal_timeout(timeout.as_ref());
                            Ok(ControlFlow::Break(StepOutcome::Park(
                                ParkReason::AwaitingSignal {
                                    signal_id: id.clone(),
                                    signal_name: signal_name.clone(),
                                    timeout: wake_at,
                                    next_task: None,
                                },
                            )))
                        }
                    }
                    WorkflowContinuation::Fork { branches, join, .. } => {
                        // Nested fork inside a branch: run its arms in
                        // parallel, then resolve the join.
                        let branch_results = Self::execute_nested_fork_branches(
                            branches,
                            &current_input,
                            &backend,
                            &instance_id,
                            &context,
                        )
                        .await?;

                        match resolve_join(
                            join.as_deref(),
                            &branch_results,
                            context.codec.as_ref(),
                        )? {
                            JoinResolution::Continue { input, .. } => {
                                Ok(ControlFlow::Continue(input))
                            }
                            JoinResolution::Done(output) => {
                                Ok(ControlFlow::Break(StepOutcome::Done(output)))
                            }
                        }
                    }
                    WorkflowContinuation::Branch {
                        id,
                        key_fn: Some(key_fn),
                        branches,
                        default,
                        ..
                    } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Decision already checkpointed.
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            // Evaluate and decode the routing key.
                            let key_bytes = key_fn
                                .run(current_input.clone())
                                .await
                                .map_err(RuntimeError::from)?;
                            let key: String = context
                                .codec
                                .decode_string(&key_bytes)
                                .map_err(RuntimeError::from)?;

                            let chosen =
                                branches.get(&key).or(default.as_ref()).ok_or_else(|| {
                                    WorkflowError::BranchKeyNotFound {
                                        branch_id: id.clone(),
                                        key: key.clone(),
                                    }
                                })?;

                            // Recurse into the selected arm.
                            let branch_output = Box::pin(Self::execute_branch_with_checkpoint(
                                chosen,
                                current_input.clone(),
                                Arc::clone(&backend),
                                instance_id.clone(),
                                context.clone(),
                            ))
                            .await?;

                            // Wrap in an envelope carrying the chosen key.
                            let envelope_bytes = context
                                .codec
                                .encode_branch_envelope(&key, &branch_output)
                                .map_err(RuntimeError::from)?;

                            snapshot.mark_task_completed(id.clone(), envelope_bytes.clone());
                            backend.save_snapshot(&snapshot).await?;

                            Ok(ControlFlow::Continue(envelope_bytes))
                        }
                    }
                    // Branch without a key function: cannot route.
                    WorkflowContinuation::Branch {
                        key_fn: None, id, ..
                    } => {
                        return Err(WorkflowError::TaskNotImplemented(
                            sayiir_core::workflow::key_fn_id(id),
                        )
                        .into());
                    }
                    WorkflowContinuation::Loop {
                        id,
                        body,
                        max_iterations,
                        on_max,
                        ..
                    } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Loop already completed and checkpointed.
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            let cfg = LoopConfig {
                                id,
                                body,
                                max_iterations: *max_iterations,
                                on_max: *on_max,
                                start_iteration: snapshot.loop_iteration(id),
                            };
                            // Branch loops go through the shared loop runner;
                            // track_position is off because the branch does
                            // not own the instance-level position.
                            let mut hooks = CheckpointingLoopHooks {
                                snapshot: &mut snapshot,
                                backend: backend.as_ref(),
                                track_position: false,
                            };
                            let output = run_loop_async(
                                &cfg,
                                current_input.clone(),
                                |input| {
                                    Box::pin(Self::execute_branch_with_checkpoint(
                                        body,
                                        input,
                                        Arc::clone(&backend),
                                        instance_id.clone(),
                                        context.clone(),
                                    ))
                                },
                                &mut hooks,
                            )
                            .await?;
                            Ok(ControlFlow::Continue(output))
                        }
                    }
                    WorkflowContinuation::ChildWorkflow { id, child, .. } => {
                        if let Some(result) = snapshot.get_task_result(id) {
                            // Child already completed and checkpointed.
                            Ok(ControlFlow::Continue(result.output.clone()))
                        } else {
                            let output = Box::pin(Self::execute_branch_with_checkpoint(
                                child,
                                current_input.clone(),
                                Arc::clone(&backend),
                                instance_id.clone(),
                                context.clone(),
                            ))
                            .await?;

                            snapshot.mark_task_completed(id.clone(), output.clone());
                            backend.save_snapshot(&snapshot).await?;

                            Ok(ControlFlow::Continue(output))
                        }
                    }
                };

                // Advance, finish, or persist the branch-level park.
                match step? {
                    ControlFlow::Continue(output) => match current.get_next() {
                        Some(next) => {
                            current = next;
                            current_input = output;
                        }
                        None => return Ok(output),
                    },
                    ControlFlow::Break(StepOutcome::Done(output)) => return Ok(output),
                    ControlFlow::Break(StepOutcome::Park(reason)) => {
                        return Err(save_branch_park_checkpoint(
                            reason,
                            &instance_id,
                            backend.as_ref(),
                        )
                        .await);
                    }
                }
            }
        }
    }
990 }
991}
992
993#[cfg(test)]
994#[allow(
995 clippy::unwrap_used,
996 clippy::expect_used,
997 clippy::panic,
998 clippy::indexing_slicing,
999 clippy::too_many_lines,
1000 clippy::manual_let_else
1001)]
1002mod tests {
1003 use super::*;
1004 use crate::serialization::JsonCodec;
1005 use sayiir_core::codec::Encoder;
1006 use sayiir_core::context::WorkflowContext;
1007 use sayiir_core::error::BoxError;
1008 use sayiir_core::snapshot::SignalKind;
1009 use sayiir_core::snapshot::WorkflowSnapshotState;
1010 use sayiir_core::task::BranchOutputs;
1011 use sayiir_core::workflow::WorkflowBuilder;
1012 use sayiir_macros::BranchKey;
1013 use sayiir_persistence::InMemoryBackend;
1014 use sayiir_persistence::{SignalStore, SnapshotStore};
1015
    // Branch-key fixture for routing tests (presumably used by tests further
    // down this module, outside the visible chunk — confirm before removing).
    #[derive(BranchKey)]
    enum RouteKey {
        Billing,
        Tech,
    }
1021
    // A/B branch-key fixture (presumably used by tests further down this
    // module, outside the visible chunk — confirm before removing).
    #[derive(BranchKey)]
    enum AbKey {
        A,
        B,
    }
1027
    // Builds a fresh JSON-codec workflow context (no metadata) for tests.
    fn ctx() -> WorkflowContext<JsonCodec, ()> {
        WorkflowContext::new("test-workflow", Arc::new(JsonCodec), Arc::new(()))
    }
1031
1032 #[tokio::test]
1037 async fn test_run_single_task() {
1038 let backend = InMemoryBackend::new();
1039 let runner = CheckpointingRunner::new(backend);
1040
1041 let workflow = WorkflowBuilder::new(ctx())
1042 .then("add_one", |i: u32| async move { Ok(i + 1) })
1043 .build()
1044 .unwrap();
1045
1046 let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1047 assert!(matches!(status, WorkflowStatus::Completed));
1048
1049 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1051 assert!(snapshot.state.is_completed());
1052 }
1053
1054 #[tokio::test]
1055 async fn test_run_chained_tasks() {
1056 let backend = InMemoryBackend::new();
1057 let runner = CheckpointingRunner::new(backend);
1058
1059 let workflow = WorkflowBuilder::new(ctx())
1060 .then("add_one", |i: u32| async move { Ok(i + 1) })
1061 .then("double", |i: u32| async move { Ok(i * 2) })
1062 .build()
1063 .unwrap();
1064
1065 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1066 assert!(matches!(status, WorkflowStatus::Completed));
1067
1068 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1069 assert!(snapshot.state.is_completed());
1070 }
1071
1072 #[tokio::test]
1073 async fn test_run_three_task_chain() {
1074 let backend = InMemoryBackend::new();
1075 let runner = CheckpointingRunner::new(backend);
1076
1077 let workflow = WorkflowBuilder::new(ctx())
1078 .then("step1", |i: u32| async move { Ok(i + 1) })
1079 .then("step2", |i: u32| async move { Ok(i * 3) })
1080 .then("step3", |i: u32| async move { Ok(i - 2) })
1081 .build()
1082 .unwrap();
1083
1084 let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1085 assert!(matches!(status, WorkflowStatus::Completed));
1087 }
1088
1089 #[tokio::test]
1090 async fn test_run_task_failure() {
1091 let backend = InMemoryBackend::new();
1092 let runner = CheckpointingRunner::new(backend);
1093
1094 let workflow = WorkflowBuilder::new(ctx())
1095 .then("fail", |_i: u32| async move {
1096 Err::<u32, BoxError>("intentional failure".into())
1097 })
1098 .build()
1099 .unwrap();
1100
1101 let status = runner.run(&workflow, "inst-1", 1u32).await.unwrap();
1102 match status {
1103 WorkflowStatus::Failed(e) => {
1104 assert!(e.contains("intentional failure"));
1105 }
1106 _ => panic!("Expected Failed status"),
1107 }
1108
1109 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1111 assert!(snapshot.state.is_failed());
1112 }
1113
1114 #[tokio::test]
1115 async fn test_run_fork_join() {
1116 let backend = InMemoryBackend::new();
1117 let runner = CheckpointingRunner::new(backend);
1118
1119 let workflow = WorkflowBuilder::new(ctx())
1120 .then("prepare", |i: u32| async move { Ok(i) })
1121 .branches(|b| {
1122 b.add("double", |i: u32| async move { Ok(i * 2) });
1123 b.add("add_ten", |i: u32| async move { Ok(i + 10) });
1124 })
1125 .join("combine", |outputs: BranchOutputs<JsonCodec>| async move {
1126 let doubled: u32 = outputs.get_by_id("double")?;
1127 let added: u32 = outputs.get_by_id("add_ten")?;
1128 Ok(doubled + added)
1129 })
1130 .build()
1131 .unwrap();
1132
1133 let status = runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1134 assert!(matches!(status, WorkflowStatus::Completed));
1135
1136 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1137 assert!(snapshot.state.is_completed());
1138 }
1139
1140 #[tokio::test]
1141 async fn test_run_checkpoints_intermediate_tasks() {
1142 let backend = InMemoryBackend::new();
1143 let runner = CheckpointingRunner::new(backend);
1144
1145 let workflow = WorkflowBuilder::new(ctx())
1146 .then("step1", |i: u32| async move { Ok(i + 1) })
1147 .then("step2", |i: u32| async move { Ok(i * 2) })
1148 .build()
1149 .unwrap();
1150
1151 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1152 assert!(matches!(status, WorkflowStatus::Completed));
1153
1154 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1157 assert!(snapshot.state.is_completed());
1158 }
1159
1160 #[tokio::test]
1165 async fn test_resume_completed_workflow() {
1166 let backend = InMemoryBackend::new();
1167 let runner = CheckpointingRunner::new(backend);
1168
1169 let workflow = WorkflowBuilder::new(ctx())
1170 .then("step1", |i: u32| async move { Ok(i + 1) })
1171 .build()
1172 .unwrap();
1173
1174 runner.run(&workflow, "inst-1", 5u32).await.unwrap();
1176
1177 let status = runner.resume(&workflow, "inst-1").await.unwrap();
1179 assert!(matches!(status, WorkflowStatus::Completed));
1180 }
1181
1182 #[tokio::test]
1183 async fn test_resume_failed_workflow() {
1184 let backend = InMemoryBackend::new();
1185 let runner = CheckpointingRunner::new(backend);
1186
1187 let workflow = WorkflowBuilder::new(ctx())
1188 .then("fail", |_i: u32| async move {
1189 Err::<u32, BoxError>("failure".into())
1190 })
1191 .build()
1192 .unwrap();
1193
1194 runner.run(&workflow, "inst-1", 1u32).await.unwrap();
1195
1196 let status = runner.resume(&workflow, "inst-1").await.unwrap();
1197 match status {
1198 WorkflowStatus::Failed(_) => {}
1199 _ => panic!("Expected Failed status"),
1200 }
1201 }
1202
    // Resuming an instance whose stored snapshot was produced by a different
    // workflow definition must fail with a definition-mismatch error.
    #[tokio::test]
    async fn test_resume_definition_hash_mismatch() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow1 = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        runner.run(&workflow1, "inst-1", 5u32).await.unwrap();

        // Hand-craft a snapshot for "inst-2" carrying workflow1's hash.
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-2".into(),
            workflow1.definition_hash().to_string(),
            Bytes::from(serde_json::to_vec(&5u32).unwrap()),
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "step1".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        // workflow2 has an extra step, so its definition hash differs.
        let workflow2 = WorkflowBuilder::new(ctx())
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let result = runner.resume(&workflow2, "inst-2").await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }
1239
    // Cancelling through the WorkflowClient records a Cancel signal (with the
    // supplied reason) that a running instance would observe at its next
    // guard check.
    #[tokio::test]
    async fn test_cancel_running_workflow() {
        let backend = InMemoryBackend::new();
        let runner = CheckpointingRunner::new(backend);

        let workflow = WorkflowBuilder::new(ctx())
            .then("slow_task", |i: u32| async move {
                tokio::time::sleep(std::time::Duration::from_secs(10)).await;
                Ok(i)
            })
            .build()
            .unwrap();

        // Seed a snapshot positioned at the slow task, as if it were running.
        let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
        let mut snapshot = WorkflowSnapshot::with_initial_input(
            "inst-cancel".into(),
            workflow.definition_hash().to_string(),
            input_bytes,
        );
        snapshot.update_position(ExecutionPosition::AtTask {
            task_id: "slow_task".into(),
        });
        runner.backend().save_snapshot(&snapshot).await.unwrap();

        let client = crate::WorkflowClient::from_shared(Arc::clone(runner.backend()));
        client
            .cancel(
                "inst-cancel",
                Some("testing".into()),
                Some("test-suite".into()),
            )
            .await
            .unwrap();

        // The cancel request is stored as a signal carrying the reason.
        let req = runner
            .backend()
            .get_signal("inst-cancel", SignalKind::Cancel)
            .await
            .unwrap();
        assert!(req.is_some());
        assert_eq!(req.unwrap().reason, Some("testing".into()));
    }
1290
1291 #[tokio::test]
1292 async fn test_run_with_pre_cancellation() {
1293 let backend = InMemoryBackend::new();
1294 let runner = CheckpointingRunner::new(backend);
1295
1296 let workflow = WorkflowBuilder::new(ctx())
1297 .then("task1", |i: u32| async move { Ok(i + 1) })
1298 .then("task2", |i: u32| async move { Ok(i * 2) })
1299 .build()
1300 .unwrap();
1301
1302 let input_bytes = Arc::new(JsonCodec).encode(&1u32).unwrap();
1304 let mut snapshot = WorkflowSnapshot::with_initial_input(
1305 "inst-precancel".into(),
1306 workflow.definition_hash().to_string(),
1307 input_bytes,
1308 );
1309 snapshot.update_position(ExecutionPosition::AtTask {
1310 task_id: "task1".into(),
1311 });
1312 runner.backend().save_snapshot(&snapshot).await.unwrap();
1313
1314 let client = crate::WorkflowClient::from_shared(Arc::clone(runner.backend()));
1315 client
1316 .cancel("inst-precancel", Some("pre-cancel".into()), None)
1317 .await
1318 .unwrap();
1319
1320 let status = runner.resume(&workflow, "inst-precancel").await.unwrap();
1322 match status {
1323 WorkflowStatus::Cancelled { reason, .. } => {
1324 assert_eq!(reason, Some("pre-cancel".into()));
1325 }
1326 _ => panic!("Expected Cancelled status, got: {status:?}"),
1327 }
1328 }
1329
1330 #[tokio::test]
1335 async fn test_resume_nonexistent_instance() {
1336 let backend = InMemoryBackend::new();
1337 let runner = CheckpointingRunner::new(backend);
1338
1339 let workflow = WorkflowBuilder::new(ctx())
1340 .then("task", |i: u32| async move { Ok(i) })
1341 .build()
1342 .unwrap();
1343
1344 let result = runner.resume(&workflow, "nonexistent").await;
1345 assert!(result.is_err());
1346 }
1347
1348 #[tokio::test]
1349 async fn test_run_failure_in_chain_saves_snapshot() {
1350 let backend = InMemoryBackend::new();
1351 let runner = CheckpointingRunner::new(backend);
1352
1353 let workflow = WorkflowBuilder::new(ctx())
1354 .then("step1", |i: u32| async move { Ok(i + 1) })
1355 .then("fail_step", |_i: u32| async move {
1356 Err::<u32, BoxError>("mid-chain failure".into())
1357 })
1358 .then("step3", |i: u32| async move { Ok(i * 2) })
1359 .build()
1360 .unwrap();
1361
1362 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1363 match status {
1364 WorkflowStatus::Failed(e) => {
1365 assert!(e.contains("mid-chain failure"));
1366 }
1367 _ => panic!("Expected Failed"),
1368 }
1369
1370 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1372 assert!(snapshot.state.is_failed());
1373 }
1374
1375 #[tokio::test]
1380 async fn test_run_workflow_with_delay_returns_waiting() {
1381 let backend = InMemoryBackend::new();
1382 let runner = CheckpointingRunner::new(backend);
1383
1384 let workflow = WorkflowBuilder::new(ctx())
1385 .then("step1", |i: u32| async move { Ok(i + 1) })
1386 .delay("wait_1h", std::time::Duration::from_secs(3600))
1387 .then("step2", |i: u32| async move { Ok(i * 2) })
1388 .build()
1389 .unwrap();
1390
1391 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1392
1393 match &status {
1395 WorkflowStatus::Waiting { delay_id, .. } => {
1396 assert_eq!(delay_id, "wait_1h");
1397 }
1398 _ => panic!("Expected Waiting status, got {status:?}"),
1399 }
1400
1401 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1403 assert!(snapshot.state.is_in_progress());
1404 match &snapshot.state {
1405 WorkflowSnapshotState::InProgress { position, .. } => match position {
1406 ExecutionPosition::AtDelay {
1407 delay_id,
1408 next_task_id,
1409 ..
1410 } => {
1411 assert_eq!(delay_id, "wait_1h");
1412 assert_eq!(next_task_id.as_deref(), Some("step2"));
1413 }
1414 other => panic!("Expected AtDelay, got {other:?}"),
1415 },
1416 _ => panic!("Expected InProgress"),
1417 }
1418
1419 assert!(snapshot.get_task_result("step1").is_some());
1421 assert!(snapshot.get_task_result("wait_1h").is_some());
1423 }
1424
1425 #[tokio::test]
1426 async fn test_resume_before_delay_expires_returns_waiting() {
1427 let backend = InMemoryBackend::new();
1428 let runner = CheckpointingRunner::new(backend);
1429
1430 let workflow = WorkflowBuilder::new(ctx())
1431 .then("step1", |i: u32| async move { Ok(i + 1) })
1432 .delay("wait_1h", std::time::Duration::from_secs(3600))
1433 .then("step2", |i: u32| async move { Ok(i * 2) })
1434 .build()
1435 .unwrap();
1436
1437 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1439 assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1440
1441 let status = runner.resume(&workflow, "inst-1").await.unwrap();
1443 match &status {
1444 WorkflowStatus::Waiting { delay_id, .. } => {
1445 assert_eq!(delay_id, "wait_1h");
1446 }
1447 _ => panic!("Expected Waiting on resume, got {status:?}"),
1448 }
1449 }
1450
1451 #[tokio::test]
1452 async fn test_resume_after_delay_expires_completes() {
1453 let backend = InMemoryBackend::new();
1454 let runner = CheckpointingRunner::new(backend);
1455
1456 let workflow = WorkflowBuilder::new(ctx())
1458 .then("step1", |i: u32| async move { Ok(i + 1) })
1459 .delay("wait_short", std::time::Duration::from_millis(1))
1460 .then("step2", |i: u32| async move { Ok(i * 2) })
1461 .build()
1462 .unwrap();
1463
1464 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1466 assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1467
1468 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1470
1471 let status = runner.resume(&workflow, "inst-1").await.unwrap();
1473 assert!(
1474 matches!(status, WorkflowStatus::Completed),
1475 "Expected Completed after delay expired, got {status:?}"
1476 );
1477
1478 let snapshot = runner.backend().load_snapshot("inst-1").await.unwrap();
1480 assert!(snapshot.state.is_completed());
1481 }
1482
1483 #[tokio::test]
1484 async fn test_cancel_during_delay() {
1485 let backend = InMemoryBackend::new();
1486 let runner = CheckpointingRunner::new(backend);
1487
1488 let workflow = WorkflowBuilder::new(ctx())
1489 .then("step1", |i: u32| async move { Ok(i + 1) })
1490 .delay("wait_1h", std::time::Duration::from_secs(3600))
1491 .then("step2", |i: u32| async move { Ok(i * 2) })
1492 .build()
1493 .unwrap();
1494
1495 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1497 assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1498
1499 let client = crate::WorkflowClient::from_shared(Arc::clone(runner.backend()));
1501 client
1502 .cancel(
1503 "inst-1",
1504 Some("no longer needed".into()),
1505 Some("admin".into()),
1506 )
1507 .await
1508 .unwrap();
1509
1510 let status = runner.resume(&workflow, "inst-1").await.unwrap();
1512 match status {
1513 WorkflowStatus::Cancelled {
1514 reason,
1515 cancelled_by,
1516 } => {
1517 assert_eq!(reason, Some("no longer needed".into()));
1518 assert_eq!(cancelled_by, Some("admin".into()));
1519 }
1520 _ => panic!("Expected Cancelled status, got {status:?}"),
1521 }
1522 }
1523
1524 #[tokio::test]
1525 async fn test_delay_as_last_node() {
1526 let backend = InMemoryBackend::new();
1527 let runner = CheckpointingRunner::new(backend);
1528
1529 let workflow = WorkflowBuilder::new(ctx())
1530 .then("step1", |i: u32| async move { Ok(i + 1) })
1531 .delay("final_wait", std::time::Duration::from_millis(1))
1532 .build()
1533 .unwrap();
1534
1535 let status = runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1537 assert!(matches!(status, WorkflowStatus::Waiting { .. }));
1538
1539 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1541
1542 let status = runner.resume(&workflow, "inst-1").await.unwrap();
1544 assert!(
1545 matches!(status, WorkflowStatus::Completed),
1546 "Expected Completed when delay is last node, got {status:?}"
1547 );
1548 }
1549
1550 #[tokio::test]
1551 async fn test_delay_data_passthrough() {
1552 let backend = InMemoryBackend::new();
1553 let runner = CheckpointingRunner::new(backend);
1554
1555 let workflow = WorkflowBuilder::new(ctx())
1557 .then("step1", |i: u32| async move { Ok(i + 1) })
1558 .delay("wait", std::time::Duration::from_millis(1))
1559 .then("step2", |i: u32| async move {
1560 assert_eq!(i, 11);
1562 Ok(i * 2)
1563 })
1564 .build()
1565 .unwrap();
1566
1567 runner.run(&workflow, "inst-1", 10u32).await.unwrap();
1569
1570 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1572 let status = runner.resume(&workflow, "inst-1").await.unwrap();
1573 assert!(matches!(status, WorkflowStatus::Completed));
1574 }
1575
1576 #[tokio::test]
1581 async fn test_run_task_timeout_fails_workflow() {
1582 use sayiir_core::task::TaskMetadata;
1583
1584 let backend = InMemoryBackend::new();
1585 let runner = CheckpointingRunner::new(backend);
1586
1587 let workflow = WorkflowBuilder::new(ctx())
1588 .with_registry()
1589 .then("slow_task", |i: u32| async move {
1590 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
1591 Ok(i)
1592 })
1593 .with_metadata(TaskMetadata {
1594 timeout: Some(std::time::Duration::from_millis(5)),
1595 ..Default::default()
1596 })
1597 .build()
1598 .unwrap();
1599
1600 let status = runner
1601 .run(workflow.workflow(), "inst-timeout", 5u32)
1602 .await
1603 .unwrap();
1604 match status {
1605 WorkflowStatus::Failed(msg) => {
1606 assert!(
1607 msg.contains("timed out"),
1608 "Expected timeout error, got: {msg}"
1609 );
1610 assert!(
1611 msg.contains("slow_task"),
1612 "Expected task id in error, got: {msg}"
1613 );
1614 }
1615 other => panic!("Expected Failed status, got {other:?}"),
1616 }
1617 }
1618
1619 #[tokio::test]
1620 async fn test_run_task_within_timeout_succeeds() {
1621 use sayiir_core::task::TaskMetadata;
1622
1623 let backend = InMemoryBackend::new();
1624 let runner = CheckpointingRunner::new(backend);
1625
1626 let workflow = WorkflowBuilder::new(ctx())
1627 .with_registry()
1628 .then("fast_task", |i: u32| async move { Ok(i + 1) })
1629 .with_metadata(TaskMetadata {
1630 timeout: Some(std::time::Duration::from_secs(5)),
1631 ..Default::default()
1632 })
1633 .build()
1634 .unwrap();
1635
1636 let status = runner
1637 .run(workflow.workflow(), "inst-fast", 5u32)
1638 .await
1639 .unwrap();
1640 assert!(matches!(status, WorkflowStatus::Completed));
1641 }
1642
1643 #[tokio::test]
1644 async fn test_route_selects_correct_branch() {
1645 let backend = InMemoryBackend::new();
1646 let runner = CheckpointingRunner::new(backend.clone());
1647
1648 let workflow = WorkflowBuilder::new(ctx())
1649 .then("classify", |input: String| async move {
1650 Ok(serde_json::json!({ "intent": input }))
1651 })
1652 .route::<u32, RouteKey, _, _>(|data: serde_json::Value| async move {
1653 match data["intent"].as_str().unwrap_or("unknown") {
1654 "billing" => Ok(RouteKey::Billing),
1655 "tech" => Ok(RouteKey::Tech),
1656 other => Err(format!("unknown intent: {other}").into()),
1657 }
1658 })
1659 .branch(RouteKey::Billing, |sub| {
1660 sub.then("handle_billing", |_data: serde_json::Value| async move {
1661 Ok(100u32)
1662 })
1663 })
1664 .branch(RouteKey::Tech, |sub| {
1665 sub.then("handle_tech", |_data: serde_json::Value| async move {
1666 Ok(200u32)
1667 })
1668 })
1669 .done()
1670 .build()
1671 .unwrap();
1672
1673 let status = runner
1675 .run(&workflow, "inst-branch-1", "billing".to_string())
1676 .await
1677 .unwrap();
1678 assert!(matches!(status, WorkflowStatus::Completed));
1679
1680 let snapshot = backend.load_snapshot("inst-branch-1").await.unwrap();
1681 match &snapshot.state {
1684 WorkflowSnapshotState::Completed { final_output } => {
1685 let envelope: serde_json::Value = serde_json::from_slice(final_output).unwrap();
1686 assert_eq!(envelope["branch"], "billing");
1687 assert_eq!(envelope["result"], 100);
1688 }
1689 other => panic!("Expected Completed, got: {other:?}"),
1690 }
1691 }
1692
1693 #[tokio::test]
1694 async fn test_route_with_default() {
1695 let backend = InMemoryBackend::new();
1696 let runner = CheckpointingRunner::new(backend.clone());
1697
1698 let workflow = WorkflowBuilder::new(ctx())
1702 .route::<String, AbKey, _, _>(|input: String| async move {
1703 match input.as_str() {
1704 "a" => Ok(AbKey::A),
1705 "b" => Ok(AbKey::B),
1706 other => Err(format!("unknown: {other}").into()),
1707 }
1708 })
1709 .branch(AbKey::A, |sub| {
1710 sub.then("handle_a", |_data: String| async move {
1711 Ok("matched".to_string())
1712 })
1713 })
1714 .default_branch(|sub| {
1715 sub.then("handle_fallback", |_data: String| async move {
1716 Ok("fallback".to_string())
1717 })
1718 })
1719 .done()
1720 .build()
1721 .unwrap();
1722
1723 let status = runner
1725 .run(&workflow, "inst-branch-default", "b".to_string())
1726 .await
1727 .unwrap();
1728 assert!(matches!(status, WorkflowStatus::Completed));
1729
1730 let snapshot = backend.load_snapshot("inst-branch-default").await.unwrap();
1731 match &snapshot.state {
1732 WorkflowSnapshotState::Completed { final_output } => {
1733 let envelope: serde_json::Value = serde_json::from_slice(final_output).unwrap();
1734 assert_eq!(envelope["branch"], "b");
1735 assert_eq!(envelope["result"], "fallback");
1736 }
1737 other => panic!("Expected Completed, got: {other:?}"),
1738 }
1739 }
1740
1741 #[tokio::test]
1742 async fn test_route_missing_branches_detected() {
1743 let result = WorkflowBuilder::new(ctx())
1746 .route::<String, RouteKey, _, _>(|input: String| async move {
1747 match input.as_str() {
1748 "billing" => Ok(RouteKey::Billing),
1749 _ => Ok(RouteKey::Tech),
1750 }
1751 })
1752 .branch(RouteKey::Billing, |sub| {
1753 sub.then("handle_billing", |_data: String| async move {
1754 Ok("ok".to_string())
1755 })
1756 })
1757 .done()
1758 .build();
1759
1760 let errors = match result {
1761 Err(e) => e,
1762 Ok(_) => panic!("expected build error"),
1763 };
1764 let has_missing = errors.iter().any(|e| {
1765 matches!(
1766 e,
1767 sayiir_core::error::BuildError::MissingBranches {
1768 branch_id,
1769 missing_keys,
1770 } if branch_id == "branch_1" && missing_keys.contains(&"tech".to_string())
1771 )
1772 });
1773 assert!(has_missing, "Expected MissingBranches error in: {errors:?}");
1774 }
1775
1776 #[tokio::test]
1777 async fn test_route_then_next_step() {
1778 use sayiir_core::task::BranchEnvelope;
1779
1780 let backend = InMemoryBackend::new();
1781 let runner = CheckpointingRunner::new(backend.clone());
1782
1783 let workflow = WorkflowBuilder::new(ctx())
1784 .route::<u32, AbKey, _, _>(|input: String| async move {
1785 match input.as_str() {
1786 "a" => Ok(AbKey::A),
1787 "b" => Ok(AbKey::B),
1788 other => Err(format!("unknown: {other}").into()),
1789 }
1790 })
1791 .branch(AbKey::A, |sub| {
1792 sub.then("handle_a", |_data: String| async move { Ok(10u32) })
1793 })
1794 .branch(AbKey::B, |sub| {
1795 sub.then("handle_b", |_data: String| async move { Ok(20u32) })
1796 })
1797 .done()
1798 .then("finalize", |env: BranchEnvelope<u32>| async move {
1799 Ok(env.result + 1)
1800 })
1801 .build()
1802 .unwrap();
1803
1804 let status = runner
1805 .run(&workflow, "inst-branch-next", "a".to_string())
1806 .await
1807 .unwrap();
1808 assert!(matches!(status, WorkflowStatus::Completed));
1809
1810 let snapshot = backend.load_snapshot("inst-branch-next").await.unwrap();
1811 match &snapshot.state {
1812 WorkflowSnapshotState::Completed { final_output } => {
1813 let val: u32 = serde_json::from_slice(final_output).unwrap();
1814 assert_eq!(val, 11); }
1816 other => panic!("Expected Completed, got: {other:?}"),
1817 }
1818 }
1819}