1use crate::codec::Codec;
2use crate::codec::sealed;
3use crate::context::WorkflowContext;
4use crate::error::WorkflowError;
5use crate::registry::TaskRegistry;
6use crate::task::{
7 BranchOutputs, ErasedBranch, branch, to_core_task_arc, to_heterogeneous_join_task_arc,
8};
9use crate::workflow::{SerializableWorkflow, Workflow, WorkflowContinuation};
10use std::marker::PhantomData;
11use std::sync::Arc;
12
/// Type-state marker for a [`WorkflowBuilder`] to which no step has been appended yet.
///
/// Its [`ContinuationState::append`] implementation returns the new node unchanged, so the
/// first appended node becomes the head of the chain.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoContinuation;
15
/// Type-state marker for a builder that does not record tasks in a [`TaskRegistry`].
///
/// Its [`RegistryBehavior`] implementation ignores every registration request. Calling
/// `build` on a builder in this state produces a plain `Workflow` rather than a
/// `SerializableWorkflow`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoRegistry;
18
19pub trait ContinuationState {
21 fn append(self, new_task: WorkflowContinuation) -> WorkflowContinuation;
23}
24
25impl ContinuationState for NoContinuation {
26 fn append(self, new_task: WorkflowContinuation) -> WorkflowContinuation {
27 new_task
28 }
29}
30
31impl ContinuationState for WorkflowContinuation {
32 fn append(mut self, new_task: WorkflowContinuation) -> WorkflowContinuation {
33 append_to_chain(&mut self, new_task);
34 self
35 }
36}
37
38pub trait RegistryBehavior {
40 fn maybe_register<I, O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
42 where
43 F: Fn(I) -> Fut + Send + Sync + 'static,
44 I: Send + 'static,
45 O: Send + 'static,
46 Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
47 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static;
48
49 fn maybe_register_join<O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
51 where
52 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
53 O: Send + 'static,
54 Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
55 C: Codec
56 + sealed::EncodeValue<O>
57 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
58 + Send
59 + Sync
60 + 'static;
61}
62
63impl RegistryBehavior for NoRegistry {
64 fn maybe_register<I, O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
65 where
66 F: Fn(I) -> Fut + Send + Sync + 'static,
67 I: Send + 'static,
68 O: Send + 'static,
69 Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
70 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
71 {
72 }
74
75 fn maybe_register_join<O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
76 where
77 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
78 O: Send + 'static,
79 Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
80 C: Codec
81 + sealed::EncodeValue<O>
82 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
83 + Send
84 + Sync
85 + 'static,
86 {
87 }
89}
90
91impl RegistryBehavior for TaskRegistry {
92 fn maybe_register<I, O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: &Arc<F>)
93 where
94 F: Fn(I) -> Fut + Send + Sync + 'static,
95 I: Send + 'static,
96 O: Send + 'static,
97 Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
98 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
99 {
100 use crate::task::TaskMetadata;
101 self.register_fn_arc(id, codec, Arc::clone(func), TaskMetadata::default());
102 }
103
104 fn maybe_register_join<O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: &Arc<F>)
105 where
106 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
107 O: Send + 'static,
108 Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
109 C: Codec
110 + sealed::EncodeValue<O>
111 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
112 + Send
113 + Sync
114 + 'static,
115 {
116 use crate::task::TaskMetadata;
117 self.register_arc_join(id, codec, Arc::clone(func), TaskMetadata::default());
118 }
119}
120
/// Type-state builder that assembles a workflow's continuation chain one step at a time.
///
/// Type parameters:
/// - `C`: codec held by the context; task inputs and outputs must be decodable/encodable by it.
/// - `Input`: the workflow's input type.
/// - `Output`: the output type of the most recently appended step.
/// - `M`: second type parameter of [`WorkflowContext`] (presumably user metadata — confirm).
/// - `Cont`: [`NoContinuation`] until the first step is appended, then [`WorkflowContinuation`].
/// - `R`: [`NoRegistry`] or [`TaskRegistry`]. It decides whether closure tasks are registered
///   and whether `build` returns a `Workflow` or a `SerializableWorkflow`.
pub struct WorkflowBuilder<C, Input, Output, M = (), Cont = NoContinuation, R = NoRegistry> {
    // Shared context; its `codec` is cloned into every task that is added.
    context: WorkflowContext<C, M>,
    // The chain built so far (or the `NoContinuation` marker).
    continuation: Cont,
    // Registry (or `NoRegistry` marker) that receives closure registrations.
    registry: R,
    // Id of the most recently appended node; `with_metadata` targets this id.
    last_task_id: Option<String>,
    // Tracks `Input` and `Output` at the type level only.
    _phantom: PhantomData<(Input, Output)>,
}
128
129#[allow(clippy::mismatching_type_param_order)] impl<C, Input, M> WorkflowBuilder<C, Input, Input, M, NoContinuation, NoRegistry> {
131 #[must_use]
136 pub fn new(ctx: WorkflowContext<C, M>) -> Self
137 where
138 C: Codec,
139 M: Send + Sync + 'static,
140 {
141 Self {
142 context: ctx,
143 continuation: NoContinuation,
144 registry: NoRegistry,
145 last_task_id: None,
146 _phantom: PhantomData,
147 }
148 }
149
150 #[must_use]
188 pub fn with_registry(
189 self,
190 ) -> WorkflowBuilder<C, Input, Input, M, NoContinuation, TaskRegistry> {
191 self.with_existing_registry(TaskRegistry::new())
192 }
193
194 #[must_use]
246 pub fn with_existing_registry(
247 self,
248 registry: TaskRegistry,
249 ) -> WorkflowBuilder<C, Input, Input, M, NoContinuation, TaskRegistry> {
250 WorkflowBuilder {
251 context: self.context,
252 continuation: NoContinuation,
253 registry,
254 last_task_id: None,
255 _phantom: PhantomData,
256 }
257 }
258}
259
260impl<C, Input, Output, M, Cont, R> WorkflowBuilder<C, Input, Output, M, Cont, R>
262where
263 R: RegistryBehavior,
264 Cont: ContinuationState,
265{
266 pub fn then<F, Fut, NewOutput>(
268 mut self,
269 id: &str,
270 func: F,
271 ) -> WorkflowBuilder<C, Input, NewOutput, M, WorkflowContinuation, R>
272 where
273 F: Fn(Output) -> Fut + Send + Sync + 'static,
274 Output: Send + 'static,
275 NewOutput: Send + 'static,
276 Fut: std::future::Future<Output = Result<NewOutput, crate::error::BoxError>>
277 + Send
278 + 'static,
279 C: Codec + sealed::DecodeValue<Output> + sealed::EncodeValue<NewOutput> + 'static,
280 {
281 let codec = Arc::clone(&self.context.codec);
282 let func = Arc::new(func);
283
284 self.registry
286 .maybe_register::<Output, NewOutput, _, _, _>(id, codec.clone(), &func);
287
288 let task = to_core_task_arc(func, codec);
289
290 let new_task = WorkflowContinuation::Task {
291 id: id.to_string(),
292 func: Some(task),
293 timeout: None,
294 retry_policy: None,
295 next: None,
296 };
297
298 let continuation = self.continuation.append(new_task);
299
300 WorkflowBuilder {
301 continuation,
302 context: self.context,
303 registry: self.registry,
304 last_task_id: Some(id.to_string()),
305 _phantom: PhantomData,
306 }
307 }
308}
309
310impl<C, Input, Output, M, Cont, R> WorkflowBuilder<C, Input, Output, M, Cont, R>
312where
313 Cont: ContinuationState,
314{
315 #[must_use]
323 pub fn delay(
324 self,
325 id: &str,
326 duration: std::time::Duration,
327 ) -> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, R> {
328 let new_node = WorkflowContinuation::Delay {
329 id: id.to_string(),
330 duration,
331 next: None,
332 };
333 let continuation = self.continuation.append(new_node);
334 WorkflowBuilder {
335 continuation,
336 context: self.context,
337 registry: self.registry,
338 last_task_id: Some(id.to_string()),
339 _phantom: PhantomData,
340 }
341 }
342
343 #[must_use]
349 pub fn wait_for_signal(
350 self,
351 id: &str,
352 signal_name: &str,
353 timeout: Option<std::time::Duration>,
354 ) -> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, R> {
355 let new_node = WorkflowContinuation::AwaitSignal {
356 id: id.to_string(),
357 signal_name: signal_name.to_string(),
358 timeout,
359 next: None,
360 };
361 let continuation = self.continuation.append(new_node);
362 WorkflowBuilder {
363 continuation,
364 context: self.context,
365 registry: self.registry,
366 last_task_id: Some(id.to_string()),
367 _phantom: PhantomData,
368 }
369 }
370}
371
372impl<C, Input, Output, M, Cont> WorkflowBuilder<C, Input, Output, M, Cont, TaskRegistry>
374where
375 Cont: ContinuationState,
376{
377 pub fn then_registered<NewOutput>(
423 self,
424 id: &str,
425 ) -> Result<
426 WorkflowBuilder<C, Input, NewOutput, M, WorkflowContinuation, TaskRegistry>,
427 WorkflowError,
428 >
429 where
430 Output: Send + 'static,
431 NewOutput: Send + 'static,
432 {
433 let func = self
434 .registry
435 .get(id)
436 .ok_or_else(|| WorkflowError::TaskNotFound(id.to_string()))?;
437 let meta = self.registry.get_metadata(id);
438 let timeout = meta.and_then(|m| m.timeout);
439 let retry_policy = self
440 .registry
441 .get_metadata(id)
442 .and_then(|m| m.retries.clone());
443
444 let new_task = WorkflowContinuation::Task {
445 id: id.to_string(),
446 func: Some(func),
447 timeout,
448 retry_policy,
449 next: None,
450 };
451
452 let continuation = self.continuation.append(new_task);
453
454 Ok(WorkflowBuilder {
455 continuation,
456 context: self.context,
457 registry: self.registry,
458 last_task_id: Some(id.to_string()),
459 _phantom: PhantomData,
460 })
461 }
462}
463
464impl<C, Input, Output, M> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, TaskRegistry> {
466 #[must_use]
508 pub fn with_metadata(mut self, metadata: crate::task::TaskMetadata) -> Self {
509 if let Some(ref id) = self.last_task_id {
510 let timeout = metadata.timeout;
511 let retry_policy = metadata.retries.clone();
512 self.registry.set_metadata(id, metadata);
513 self.continuation.set_task_timeout(id, timeout);
516 self.continuation.set_task_retry_policy(id, retry_policy);
517 }
518 self
519 }
520}
521
522impl<C, Input, Output, M, Cont, R> WorkflowBuilder<C, Input, Output, M, Cont, R> {
524 pub fn branches<F>(self, f: F) -> ForkBuilder<C, Input, Output, M, Cont, R>
556 where
557 F: FnOnce(&mut BranchCollector<C, Output>),
558 C: Codec,
559 {
560 let codec = Arc::clone(&self.context.codec);
561 let mut collector = BranchCollector {
562 codec,
563 branches: Vec::new(),
564 _phantom: PhantomData,
565 };
566 f(&mut collector);
567
568 ForkBuilder {
569 context: self.context,
570 continuation: self.continuation,
571 branches: collector.branches,
572 registry: self.registry,
573 _phantom: PhantomData,
574 }
575 }
576
577 pub fn fork(self) -> ForkBuilder<C, Input, Output, M, Cont, R> {
579 ForkBuilder {
580 context: self.context,
581 continuation: self.continuation,
582 branches: Vec::new(),
583 registry: self.registry,
584 _phantom: PhantomData,
585 }
586 }
587}
588
589impl<C, Input, Output, M> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, NoRegistry> {
591 pub fn build(self) -> Result<Workflow<C, Input, M>, WorkflowError>
597 where
598 Input: Send + 'static,
599 Output: Send + 'static,
600 M: Send + Sync + 'static,
601 C: Codec
602 + sealed::DecodeValue<Input>
603 + sealed::DecodeValue<Output>
604 + sealed::EncodeValue<Input>
605 + sealed::EncodeValue<Output>,
606 {
607 if let Some(dup) = self.continuation.find_duplicate_id() {
608 return Err(WorkflowError::DuplicateTaskId(dup));
609 }
610
611 let definition_hash = self
612 .continuation
613 .to_serializable()
614 .compute_definition_hash();
615
616 Ok(Workflow {
617 definition_hash,
618 continuation: self.continuation,
619 context: self.context,
620 _phantom: PhantomData,
621 })
622 }
623}
624
625impl<C, Input, Output, M> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, TaskRegistry> {
627 pub fn build(self) -> Result<SerializableWorkflow<C, Input, M>, WorkflowError>
633 where
634 Input: Send + 'static,
635 Output: Send + 'static,
636 M: Send + Sync + 'static,
637 C: Codec
638 + sealed::DecodeValue<Input>
639 + sealed::DecodeValue<Output>
640 + sealed::EncodeValue<Input>
641 + sealed::EncodeValue<Output>,
642 {
643 if let Some(dup) = self.continuation.find_duplicate_id() {
644 return Err(WorkflowError::DuplicateTaskId(dup));
645 }
646
647 let definition_hash = self
648 .continuation
649 .to_serializable()
650 .compute_definition_hash();
651
652 let inner = Workflow {
653 definition_hash,
654 continuation: self.continuation,
655 context: self.context,
656 _phantom: PhantomData,
657 };
658
659 Ok(SerializableWorkflow {
660 inner,
661 registry: self.registry,
662 })
663 }
664}
665
666fn append_to_chain(continuation: &mut WorkflowContinuation, new_task: WorkflowContinuation) {
668 match continuation {
669 WorkflowContinuation::Task { next, .. }
670 | WorkflowContinuation::Delay { next, .. }
671 | WorkflowContinuation::AwaitSignal { next, .. } => match next {
672 Some(next_box) => append_to_chain(next_box, new_task),
673 None => *next = Some(Box::new(new_task)),
674 },
675 WorkflowContinuation::Fork { join, .. } => match join {
676 Some(join_box) => append_to_chain(join_box, new_task),
677 None => *join = Some(Box::new(new_task)),
678 },
679 }
680}
681
/// Collects fork branches inside the closure passed to `WorkflowBuilder::branches`.
///
/// `Input` is the type every branch function receives. Each added branch is type-erased
/// immediately with the shared codec.
pub struct BranchCollector<C, Input> {
    // Codec cloned from the workflow context; used to erase each branch.
    codec: Arc<C>,
    // Erased branches, in the order they were added.
    branches: Vec<ErasedBranch>,
    // Type-level marker for the branch input type.
    _phantom: PhantomData<Input>,
}
690
691impl<C, Input> BranchCollector<C, Input> {
692 pub fn add<F, Fut, BranchOutput>(&mut self, id: &str, func: F)
697 where
698 F: Fn(Input) -> Fut + Send + Sync + 'static,
699 Input: Send + 'static,
700 BranchOutput: Send + 'static,
701 Fut: std::future::Future<Output = Result<BranchOutput, crate::error::BoxError>>
702 + Send
703 + 'static,
704 C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<BranchOutput>,
705 {
706 let erased = branch(id, func).erase(Arc::clone(&self.codec));
707 self.branches.push(erased);
708 }
709}
710
/// Builder for a fork: a set of branches that each receive the current `Output`, closed
/// by a join step.
///
/// Created by `WorkflowBuilder::branches` or `WorkflowBuilder::fork`. `join` and
/// `join_registered` turn it back into a [`WorkflowBuilder`]. `Cont` and `R` carry the
/// parent builder's continuation state and registry.
pub struct ForkBuilder<C, Input, Output, M, Cont = NoContinuation, R = NoRegistry> {
    // Context taken from the parent builder.
    context: WorkflowContext<C, M>,
    // Chain built before the fork; the fork node is appended to it on join.
    continuation: Cont,
    // Branches added so far, in insertion order.
    branches: Vec<ErasedBranch>,
    // Registry (or marker) inherited from the parent builder.
    registry: R,
    // Type-level marker for the workflow input and the value fed to branches.
    _phantom: PhantomData<(Input, Output)>,
}
722
723impl<C, Input, Output, M, Cont, R> ForkBuilder<C, Input, Output, M, Cont, R>
725where
726 R: RegistryBehavior,
727 Cont: ContinuationState,
728{
729 #[must_use]
736 pub fn branch<F, Fut, BranchOutput>(mut self, id: &str, func: F) -> Self
737 where
738 F: Fn(Output) -> Fut + Send + Sync + 'static,
739 Output: Send + 'static,
740 BranchOutput: Send + 'static,
741 Fut: std::future::Future<Output = Result<BranchOutput, crate::error::BoxError>>
742 + Send
743 + 'static,
744 C: Codec + sealed::DecodeValue<Output> + sealed::EncodeValue<BranchOutput> + 'static,
745 {
746 let codec = Arc::clone(&self.context.codec);
747 let func = Arc::new(func);
748
749 self.registry
751 .maybe_register::<Output, BranchOutput, _, _, _>(id, codec.clone(), &func);
752
753 let func_clone = Arc::clone(&func);
755 let erased = branch(id, move |input| func_clone(input)).erase(codec);
756 self.branches.push(erased);
757 self
758 }
759
760 pub fn join<F, Fut, JoinOutput>(
762 mut self,
763 id: &str,
764 func: F,
765 ) -> WorkflowBuilder<C, Input, JoinOutput, M, WorkflowContinuation, R>
766 where
767 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
768 JoinOutput: Send + 'static,
769 Fut: std::future::Future<Output = Result<JoinOutput, crate::error::BoxError>>
770 + Send
771 + 'static,
772 C: Codec
773 + sealed::EncodeValue<JoinOutput>
774 + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
775 + Send
776 + Sync
777 + 'static,
778 {
779 let codec = Arc::clone(&self.context.codec);
780 let func = Arc::new(func);
781
782 self.registry
784 .maybe_register_join::<JoinOutput, _, _, _>(id, codec.clone(), &func);
785
786 let join_task_fn = to_heterogeneous_join_task_arc(func, codec);
787
788 let fork_id = WorkflowContinuation::derive_fork_id(
789 &self
790 .branches
791 .iter()
792 .map(|b| b.id.as_str())
793 .collect::<Vec<_>>(),
794 );
795
796 let branch_continuations: Vec<Arc<WorkflowContinuation>> = self
797 .branches
798 .into_iter()
799 .map(|b| {
800 Arc::new(WorkflowContinuation::Task {
801 id: b.id,
802 func: Some(b.task),
803 timeout: None,
804 retry_policy: None,
805 next: None,
806 })
807 })
808 .collect();
809
810 let join_task = WorkflowContinuation::Task {
813 id: id.to_string(),
814 func: Some(join_task_fn),
815 timeout: None,
816 retry_policy: None,
817 next: None,
818 };
819
820 let fork_continuation = WorkflowContinuation::Fork {
821 id: fork_id,
822 branches: branch_continuations.into_boxed_slice(),
823 join: Some(Box::new(join_task)),
824 };
825
826 let continuation = self.continuation.append(fork_continuation);
827
828 WorkflowBuilder {
829 continuation,
830 context: self.context,
831 registry: self.registry,
832 last_task_id: Some(id.to_string()),
833 _phantom: PhantomData,
834 }
835 }
836}
837
838impl<C, Input, Output, M, Cont> ForkBuilder<C, Input, Output, M, Cont, TaskRegistry>
840where
841 Cont: ContinuationState,
842{
843 pub fn branch_registered(mut self, id: &str) -> Result<Self, WorkflowError>
849 where
850 Output: Send + 'static,
851 {
852 let task = self
853 .registry
854 .get(id)
855 .ok_or_else(|| WorkflowError::TaskNotFound(id.to_string()))?;
856
857 self.branches.push(ErasedBranch {
858 id: id.to_string(),
859 task,
860 });
861 Ok(self)
862 }
863
864 pub fn join_registered<JoinOutput>(
870 self,
871 id: &str,
872 ) -> Result<
873 WorkflowBuilder<C, Input, JoinOutput, M, WorkflowContinuation, TaskRegistry>,
874 WorkflowError,
875 >
876 where
877 Output: Send + 'static,
878 JoinOutput: Send + 'static,
879 {
880 let join_task_fn = self
881 .registry
882 .get(id)
883 .ok_or_else(|| WorkflowError::TaskNotFound(id.to_string()))?;
884 let join_timeout = self.registry.get_metadata(id).and_then(|m| m.timeout);
885
886 let fork_id = WorkflowContinuation::derive_fork_id(
887 &self
888 .branches
889 .iter()
890 .map(|b| b.id.as_str())
891 .collect::<Vec<_>>(),
892 );
893
894 let branch_continuations: Vec<Arc<WorkflowContinuation>> = self
895 .branches
896 .into_iter()
897 .map(|b| {
898 let meta = self.registry.get_metadata(&b.id);
899 let timeout = meta.and_then(|m| m.timeout);
900 let retry_policy = self
901 .registry
902 .get_metadata(&b.id)
903 .and_then(|m| m.retries.clone());
904 Arc::new(WorkflowContinuation::Task {
905 id: b.id,
906 func: Some(b.task),
907 timeout,
908 retry_policy,
909 next: None,
910 })
911 })
912 .collect();
913
914 let join_retry_policy = self
915 .registry
916 .get_metadata(id)
917 .and_then(|m| m.retries.clone());
918 let join_task = WorkflowContinuation::Task {
919 id: id.to_string(),
920 func: Some(join_task_fn),
921 timeout: join_timeout,
922 retry_policy: join_retry_policy,
923 next: None,
924 };
925
926 let fork_continuation = WorkflowContinuation::Fork {
927 id: fork_id,
928 branches: branch_continuations.into_boxed_slice(),
929 join: Some(Box::new(join_task)),
930 };
931
932 let continuation = self.continuation.append(fork_continuation);
933
934 Ok(WorkflowBuilder {
935 continuation,
936 context: self.context,
937 registry: self.registry,
938 last_task_id: Some(id.to_string()),
939 _phantom: PhantomData,
940 })
941 }
942}