1use crate::context::WorkflowContext;
2use crate::error::WorkflowError;
3use crate::task::{RetryPolicy, UntypedCoreTask};
4use sha2::{Digest, Sha256};
5use std::collections::HashSet;
6use std::marker::PhantomData;
7use std::sync::Arc;
8
/// Implements a `find_duplicate_id` scan for a continuation enum.
///
/// `WorkflowContinuation` and `SerializableContinuation` share the same
/// four-variant shape, so the duplicate-id walk is generated once here for
/// both. `deref_branch` adapts each enum's branch container (e.g.
/// `Arc<WorkflowContinuation>` vs a plain `SerializableContinuation`) to a
/// `&$name` reference.
macro_rules! impl_find_duplicate_id {
    ($name:ident, task_fields: { $($task_extra:tt)* }, delay_extra: { $($delay_extra:tt)* }, deref_branch: $deref:expr) => {
        impl $name {
            // Returns the first node id that occurs more than once in the
            // continuation graph, or `None` when all ids are unique.
            pub(crate) fn find_duplicate_id(&self) -> Option<String> {
                // Depth-first walk; the first id that fails to insert into
                // `seen` is the duplicate that gets reported.
                fn collect(cont: &$name, seen: &mut HashSet<String>) -> Option<String> {
                    match cont {
                        $name::Task { id, next, $($task_extra)* } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            next.as_ref().and_then(|n| collect(n, seen))
                        }
                        $name::Fork { id, branches, join } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            // Coerce the adapter closure to a fn pointer so each
                            // branch element can be viewed as a `&$name`.
                            let deref_fn: fn(&_) -> &$name = $deref;
                            branches
                                .iter()
                                .find_map(|b| collect(deref_fn(b), seen))
                                .or_else(|| join.as_ref().and_then(|j| collect(j, seen)))
                        }
                        $name::Delay { id, next, $($delay_extra)* }
                        | $name::AwaitSignal { id, next, $($delay_extra)* } => {
                            if !seen.insert(id.clone()) {
                                return Some(id.clone());
                            }
                            next.as_ref().and_then(|n| collect(n, seen))
                        }
                    }
                }
                collect(self, &mut HashSet::new())
            }
        }
    };
}
47
/// A runnable workflow continuation: a linked chain of tasks, forks, delays,
/// and signal waits, hydrated with executable task functions.
pub enum WorkflowContinuation {
    /// A single executable task step.
    Task {
        id: String,
        /// Executable task body; `None` when not yet hydrated from a registry.
        func: Option<UntypedCoreTask>,
        /// Optional per-task execution timeout.
        timeout: Option<std::time::Duration>,
        /// Optional retry policy applied on task failure.
        retry_policy: Option<RetryPolicy>,
        /// The step that follows this task, if any.
        next: Option<Box<WorkflowContinuation>>,
    },
    /// Parallel branches followed by an optional join chain.
    Fork {
        id: String,
        branches: Box<[Arc<WorkflowContinuation>]>,
        join: Option<Box<WorkflowContinuation>>,
    },
    /// A timed pause before the next step.
    Delay {
        id: String,
        duration: std::time::Duration,
        next: Option<Box<WorkflowContinuation>>,
    },
    /// A wait for an external signal, optionally bounded by a timeout.
    AwaitSignal {
        id: String,
        signal_name: String,
        timeout: Option<std::time::Duration>,
        next: Option<Box<WorkflowContinuation>>,
    },
}
82
// Generate `WorkflowContinuation::find_duplicate_id`; branches are stored as
// `Arc`s here, so the deref adapter unwraps the `Arc` into a plain reference.
impl_find_duplicate_id!(
    WorkflowContinuation,
    task_fields: { .. },
    delay_extra: { .. },
    deref_branch: |b: &Arc<WorkflowContinuation>| -> &WorkflowContinuation { b }
);
89
90impl WorkflowContinuation {
91 #[must_use]
95 pub fn derive_fork_id(branch_ids: &[&str]) -> String {
96 branch_ids.join("||")
97 }
98
99 #[must_use]
101 pub fn id(&self) -> &str {
102 match self {
103 WorkflowContinuation::Task { id, .. }
104 | WorkflowContinuation::Fork { id, .. }
105 | WorkflowContinuation::Delay { id, .. }
106 | WorkflowContinuation::AwaitSignal { id, .. } => id,
107 }
108 }
109
110 #[must_use]
115 pub fn first_task_id(&self) -> &str {
116 match self {
117 WorkflowContinuation::Task { id, .. }
118 | WorkflowContinuation::Delay { id, .. }
119 | WorkflowContinuation::AwaitSignal { id, .. } => id,
120 WorkflowContinuation::Fork { branches, .. } => {
121 if let Some(first_branch) = branches.first() {
122 first_branch.first_task_id()
123 } else {
124 "unknown"
125 }
126 }
127 }
128 }
129
130 pub fn set_task_timeout(&mut self, target_id: &str, timeout: Option<std::time::Duration>) {
136 match self {
137 WorkflowContinuation::Task {
138 id,
139 timeout: t,
140 next,
141 ..
142 } => {
143 if id == target_id {
144 *t = timeout;
145 } else if let Some(next) = next {
146 next.set_task_timeout(target_id, timeout);
147 }
148 }
149 WorkflowContinuation::Delay { next, .. }
150 | WorkflowContinuation::AwaitSignal { next, .. } => {
151 if let Some(next) = next {
152 next.set_task_timeout(target_id, timeout);
153 }
154 }
155 WorkflowContinuation::Fork { join, .. } => {
156 if let Some(join) = join {
157 join.set_task_timeout(target_id, timeout);
158 }
159 }
160 }
161 }
162
163 pub fn set_task_retry_policy(&mut self, target_id: &str, policy: Option<RetryPolicy>) {
167 match self {
168 WorkflowContinuation::Task {
169 id,
170 retry_policy,
171 next,
172 ..
173 } => {
174 if id == target_id {
175 *retry_policy = policy;
176 } else if let Some(next) = next {
177 next.set_task_retry_policy(target_id, policy);
178 }
179 }
180 WorkflowContinuation::Delay { next, .. }
181 | WorkflowContinuation::AwaitSignal { next, .. } => {
182 if let Some(next) = next {
183 next.set_task_retry_policy(target_id, policy);
184 }
185 }
186 WorkflowContinuation::Fork { join, .. } => {
187 if let Some(join) = join {
188 join.set_task_retry_policy(target_id, policy);
189 }
190 }
191 }
192 }
193
194 #[must_use]
196 pub fn get_task_retry_policy(&self, task_id: &str) -> Option<&RetryPolicy> {
197 match self {
198 WorkflowContinuation::Task {
199 id,
200 retry_policy,
201 next,
202 ..
203 } => {
204 if id == task_id {
205 return retry_policy.as_ref();
206 }
207 next.as_ref().and_then(|n| n.get_task_retry_policy(task_id))
208 }
209 WorkflowContinuation::Delay { next, .. }
210 | WorkflowContinuation::AwaitSignal { next, .. } => {
211 next.as_ref().and_then(|n| n.get_task_retry_policy(task_id))
212 }
213 WorkflowContinuation::Fork { branches, join, .. } => {
214 for branch in branches {
215 if let Some(p) = branch.get_task_retry_policy(task_id) {
216 return Some(p);
217 }
218 }
219 join.as_ref().and_then(|j| j.get_task_retry_policy(task_id))
220 }
221 }
222 }
223
224 #[must_use]
230 pub fn get_task_timeout(&self, task_id: &str) -> Option<std::time::Duration> {
231 match self {
232 WorkflowContinuation::Task {
233 id, timeout, next, ..
234 } => {
235 if id == task_id {
236 return *timeout;
237 }
238 next.as_ref().and_then(|n| n.get_task_timeout(task_id))
239 }
240 WorkflowContinuation::Delay { next, .. }
241 | WorkflowContinuation::AwaitSignal { next, .. } => {
242 next.as_ref().and_then(|n| n.get_task_timeout(task_id))
243 }
244 WorkflowContinuation::Fork { branches, join, .. } => {
245 for branch in branches {
246 if let Some(t) = branch.get_task_timeout(task_id) {
247 return Some(t);
248 }
249 }
250 join.as_ref().and_then(|j| j.get_task_timeout(task_id))
251 }
252 }
253 }
254
255 #[must_use]
257 pub fn to_serializable(&self) -> SerializableContinuation {
258 match self {
259 #[allow(clippy::cast_possible_truncation)] WorkflowContinuation::Task {
261 id,
262 timeout,
263 retry_policy,
264 next,
265 ..
266 } => SerializableContinuation::Task {
267 id: id.clone(),
268 timeout_ms: timeout.map(|d| d.as_millis() as u64),
269 retry_policy: retry_policy.clone(),
270 next: next.as_ref().map(|n| Box::new(n.to_serializable())),
271 },
272 WorkflowContinuation::Fork { id, branches, join } => SerializableContinuation::Fork {
273 id: id.clone(),
274 branches: branches.iter().map(|b| b.to_serializable()).collect(),
275 join: join.as_ref().map(|j| Box::new(j.to_serializable())),
276 },
277 #[allow(clippy::cast_possible_truncation)] WorkflowContinuation::Delay { id, duration, next } => SerializableContinuation::Delay {
279 id: id.clone(),
280 duration_ms: duration.as_millis() as u64,
281 next: next.as_ref().map(|n| Box::new(n.to_serializable())),
282 },
283 #[allow(clippy::cast_possible_truncation)]
284 WorkflowContinuation::AwaitSignal {
285 id,
286 signal_name,
287 timeout,
288 next,
289 } => SerializableContinuation::AwaitSignal {
290 id: id.clone(),
291 signal_name: signal_name.clone(),
292 timeout_ms: timeout.map(|d| d.as_millis() as u64),
293 next: next.as_ref().map(|n| Box::new(n.to_serializable())),
294 },
295 }
296 }
297}
298
/// The serde-friendly mirror of `WorkflowContinuation`: identical shape, but
/// task functions are dropped and durations are stored as integer
/// milliseconds.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub enum SerializableContinuation {
    /// A task step, referenced by id; the function is re-resolved from a
    /// `TaskRegistry` on rehydration.
    Task {
        id: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        timeout_ms: Option<u64>,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        retry_policy: Option<RetryPolicy>,
        next: Option<Box<SerializableContinuation>>,
    },
    /// Parallel branches with an optional join chain.
    Fork {
        id: String,
        branches: Vec<SerializableContinuation>,
        join: Option<Box<SerializableContinuation>>,
    },
    /// A timed pause, in milliseconds.
    Delay {
        id: String,
        duration_ms: u64,
        next: Option<Box<SerializableContinuation>>,
    },
    /// A wait for an external signal, with an optional timeout in milliseconds.
    AwaitSignal {
        id: String,
        signal_name: String,
        #[serde(default, skip_serializing_if = "Option::is_none")]
        timeout_ms: Option<u64>,
        next: Option<Box<SerializableContinuation>>,
    },
}
368
// Generate `SerializableContinuation::find_duplicate_id`; branches here are
// plain values, so the deref adapter is the identity borrow.
impl_find_duplicate_id!(
    SerializableContinuation,
    task_fields: { .. },
    delay_extra: { .. },
    deref_branch: |b: &SerializableContinuation| -> &SerializableContinuation { b }
);
375
impl SerializableContinuation {
    /// Rehydrates this serialized continuation into a runnable
    /// `WorkflowContinuation`, resolving each task id against `registry`.
    ///
    /// # Errors
    ///
    /// Returns `WorkflowError::DuplicateTaskId` if any node id occurs more
    /// than once (e.g. a tampered or hand-built payload), or
    /// `WorkflowError::TaskNotFound` if a task id is missing from the
    /// registry.
    pub fn to_runnable(
        &self,
        registry: &crate::registry::TaskRegistry,
    ) -> Result<WorkflowContinuation, WorkflowError> {
        if let Some(dup) = self.find_duplicate_id() {
            return Err(WorkflowError::DuplicateTaskId(dup));
        }

        self.to_runnable_unchecked(registry)
    }

    /// Recursive hydration worker; assumes ids were already checked for
    /// duplicates by [`Self::to_runnable`].
    fn to_runnable_unchecked(
        &self,
        registry: &crate::registry::TaskRegistry,
    ) -> Result<WorkflowContinuation, WorkflowError> {
        match self {
            SerializableContinuation::Task {
                id,
                timeout_ms,
                retry_policy,
                next,
            } => {
                // The task function is never serialized; it must be
                // re-resolved from the registry by id.
                let func = registry
                    .get(id)
                    .ok_or_else(|| WorkflowError::TaskNotFound(id.clone()))?;
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Task {
                    id: id.clone(),
                    func: Some(func),
                    timeout: timeout_ms.map(std::time::Duration::from_millis),
                    retry_policy: retry_policy.clone(),
                    next,
                })
            }
            SerializableContinuation::Fork { id, branches, join } => {
                let branches: Result<Vec<_>, _> = branches
                    .iter()
                    .map(|b| b.to_runnable_unchecked(registry).map(Arc::new))
                    .collect();
                let join = join
                    .as_ref()
                    .map(|j| j.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Fork {
                    id: id.clone(),
                    branches: branches?.into_boxed_slice(),
                    join,
                })
            }
            SerializableContinuation::Delay {
                id,
                duration_ms,
                next,
            } => {
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::Delay {
                    id: id.clone(),
                    duration: std::time::Duration::from_millis(*duration_ms),
                    next,
                })
            }
            SerializableContinuation::AwaitSignal {
                id,
                signal_name,
                timeout_ms,
                next,
            } => {
                let next = next
                    .as_ref()
                    .map(|n| n.to_runnable_unchecked(registry).map(Box::new))
                    .transpose()?;
                Ok(WorkflowContinuation::AwaitSignal {
                    id: id.clone(),
                    signal_name: signal_name.clone(),
                    timeout: timeout_ms.map(std::time::Duration::from_millis),
                    next,
                })
            }
        }
    }

    /// Returns every node id in depth-first order: for forks, the fork id
    /// itself, then each branch in order, then the join chain.
    #[must_use]
    pub fn task_ids(&self) -> Vec<&str> {
        fn collect<'a>(cont: &'a SerializableContinuation, ids: &mut Vec<&'a str>) {
            match cont {
                SerializableContinuation::Task { id, next, .. }
                | SerializableContinuation::Delay { id, next, .. }
                | SerializableContinuation::AwaitSignal { id, next, .. } => {
                    ids.push(id.as_str());
                    if let Some(n) = next {
                        collect(n, ids);
                    }
                }
                SerializableContinuation::Fork { id, branches, join } => {
                    ids.push(id.as_str());
                    for b in branches {
                        collect(b, ids);
                    }
                    if let Some(j) = join {
                        collect(j, ids);
                    }
                }
            }
        }
        let mut ids = Vec::new();
        collect(self, &mut ids);
        ids
    }

    /// Computes a SHA-256 hex digest over the structural definition of the
    /// workflow: node kinds, ids, timeouts, retry parameters, delay
    /// durations, and signal names.
    ///
    /// Used to detect when persisted state no longer matches the current
    /// workflow definition. NOTE: the byte framing below (prefixes like
    /// `T:`, separators like `;`) is part of the hash format; changing it
    /// invalidates all previously stored hashes.
    #[must_use]
    pub fn compute_definition_hash(&self) -> String {
        fn hash_continuation(cont: &SerializableContinuation, hasher: &mut Sha256) {
            match cont {
                SerializableContinuation::Task {
                    id,
                    timeout_ms,
                    retry_policy,
                    next,
                } => {
                    hasher.update(b"T:");
                    hasher.update(id.as_bytes());
                    // Optional fields only contribute bytes when present.
                    if let Some(ms) = timeout_ms {
                        hasher.update(b":t:");
                        hasher.update(ms.to_string().as_bytes());
                    }
                    if let Some(rp) = retry_policy {
                        hasher.update(b":r:");
                        hasher.update(rp.max_retries.to_string().as_bytes());
                        hasher.update(b":");
                        hasher.update(rp.initial_delay.as_millis().to_string().as_bytes());
                        hasher.update(b":");
                        hasher.update(rp.backoff_multiplier.to_string().as_bytes());
                    }
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
                SerializableContinuation::Fork { id, branches, join } => {
                    hasher.update(b"F:");
                    hasher.update(id.as_bytes());
                    hasher.update(b"[");
                    for branch in branches {
                        hash_continuation(branch, hasher);
                        hasher.update(b",");
                    }
                    hasher.update(b"]");
                    if let Some(j) = join {
                        hasher.update(b"J:");
                        hash_continuation(j, hasher);
                    }
                }
                SerializableContinuation::Delay {
                    id,
                    duration_ms,
                    next,
                } => {
                    hasher.update(b"D:");
                    hasher.update(id.as_bytes());
                    hasher.update(b":");
                    hasher.update(duration_ms.to_string().as_bytes());
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
                SerializableContinuation::AwaitSignal {
                    id,
                    signal_name,
                    timeout_ms,
                    next,
                } => {
                    hasher.update(b"S:");
                    hasher.update(id.as_bytes());
                    hasher.update(b":");
                    hasher.update(signal_name.as_bytes());
                    if let Some(ms) = timeout_ms {
                        hasher.update(b":t:");
                        hasher.update(ms.to_string().as_bytes());
                    }
                    hasher.update(b";");
                    if let Some(n) = next {
                        hash_continuation(n, hasher);
                    }
                }
            }
        }

        let mut hasher = Sha256::new();
        hash_continuation(self, &mut hasher);
        let result = hasher.finalize();
        format!("{result:x}")
    }
}
594
/// The persistable snapshot of a workflow: its id, the structural hash of
/// its definition, and the serializable continuation state.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct SerializedWorkflowState {
    /// Unique id of the workflow instance this state belongs to.
    pub workflow_id: String,
    /// Hash of the workflow definition (see
    /// `SerializableContinuation::compute_definition_hash`); checked on
    /// rehydration to detect definition drift.
    pub definition_hash: String,
    /// The remaining continuation to execute.
    pub continuation: SerializableContinuation,
}
611
612#[derive(Debug)]
614pub enum WorkflowStatus {
615 InProgress,
617 Completed,
619 Failed(String),
621 Cancelled {
623 reason: Option<String>,
625 cancelled_by: Option<String>,
627 },
628 Paused {
630 reason: Option<String>,
632 paused_by: Option<String>,
634 },
635 Waiting {
637 wake_at: chrono::DateTime<chrono::Utc>,
639 delay_id: String,
641 },
642 AwaitingSignal {
644 signal_id: String,
646 signal_name: String,
648 wake_at: Option<chrono::DateTime<chrono::Utc>>,
650 },
651}
652
653pub use crate::builder::{
655 BranchCollector, ContinuationState, ForkBuilder, NoContinuation, NoRegistry, RegistryBehavior,
656 WorkflowBuilder,
657};
658
659use crate::registry::TaskRegistry;
660
/// A fully built, runnable workflow: a continuation chain plus its execution
/// context and the hash of the definition it was built from.
///
/// `Input` is the input type of the first task and is tracked only at the
/// type level via `PhantomData`.
pub struct Workflow<C, Input, M = ()> {
    pub(crate) definition_hash: String,
    pub(crate) context: WorkflowContext<C, M>,
    pub(crate) continuation: WorkflowContinuation,
    pub(crate) _phantom: PhantomData<Input>,
}
668
669impl<C, Input, M> Workflow<C, Input, M> {
670 #[must_use]
672 pub fn workflow_id(&self) -> &str {
673 &self.context.workflow_id
674 }
675
676 #[must_use]
682 pub fn definition_hash(&self) -> &str {
683 &self.definition_hash
684 }
685
686 #[must_use]
688 pub fn context(&self) -> &WorkflowContext<C, M> {
689 &self.context
690 }
691
692 #[must_use]
694 pub fn codec(&self) -> &Arc<C> {
695 &self.context.codec
696 }
697
698 #[must_use]
700 pub fn continuation(&self) -> &WorkflowContinuation {
701 &self.continuation
702 }
703
704 #[must_use]
706 pub fn metadata(&self) -> &Arc<M> {
707 &self.context.metadata
708 }
709}
710
/// A [`Workflow`] paired with the `TaskRegistry` needed to rehydrate
/// persisted state back into runnable continuations.
pub struct SerializableWorkflow<C, Input, M = ()> {
    pub(crate) inner: Workflow<C, Input, M>,
    pub(crate) registry: TaskRegistry,
}
762
763impl<C, Input, M> SerializableWorkflow<C, Input, M> {
764 #[must_use]
766 pub fn workflow_id(&self) -> &str {
767 self.inner.workflow_id()
768 }
769
770 #[must_use]
772 pub fn definition_hash(&self) -> &str {
773 self.inner.definition_hash()
774 }
775
776 #[must_use]
778 pub fn workflow(&self) -> &Workflow<C, Input, M> {
779 &self.inner
780 }
781
782 #[must_use]
784 pub fn context(&self) -> &WorkflowContext<C, M> {
785 self.inner.context()
786 }
787
788 #[must_use]
790 pub fn codec(&self) -> &Arc<C> {
791 self.inner.codec()
792 }
793
794 #[must_use]
796 pub fn continuation(&self) -> &WorkflowContinuation {
797 self.inner.continuation()
798 }
799
800 #[must_use]
802 pub fn metadata(&self) -> &Arc<M> {
803 self.inner.metadata()
804 }
805
806 #[must_use]
808 pub fn registry(&self) -> &TaskRegistry {
809 &self.registry
810 }
811
812 #[must_use]
818 pub fn to_serializable(&self) -> SerializedWorkflowState {
819 SerializedWorkflowState {
820 workflow_id: self.inner.workflow_id().to_string(),
821 definition_hash: self.inner.definition_hash.clone(),
822 continuation: self.inner.continuation().to_serializable(),
823 }
824 }
825
826 pub fn to_runnable(
836 &self,
837 state: &SerializedWorkflowState,
838 ) -> Result<WorkflowContinuation, WorkflowError> {
839 if state.definition_hash != self.inner.definition_hash {
840 return Err(WorkflowError::DefinitionMismatch {
841 expected: self.inner.definition_hash.clone(),
842 found: state.definition_hash.clone(),
843 });
844 }
845 state.continuation.to_runnable(&self.registry)
846 }
847}
848
849#[cfg(test)]
850mod tests {
851 use crate::codec::{Decoder, Encoder, sealed};
852 use crate::error::BoxError;
853 use crate::workflow::WorkflowBuilder;
854 use bytes::Bytes;
855
    // A no-op codec for the tests below: encoding always succeeds with an
    // empty payload, and decoding always fails (these tests build workflows
    // but never execute tasks, so decode is never exercised).
    struct DummyCodec;

    impl Encoder for DummyCodec {}
    impl Decoder for DummyCodec {}

    impl<Input> sealed::EncodeValue<Input> for DummyCodec {
        fn encode_value(&self, _value: &Input) -> Result<Bytes, BoxError> {
            Ok(Bytes::new())
        }
    }
    impl<Output> sealed::DecodeValue<Output> for DummyCodec {
        fn decode_value(&self, _bytes: Bytes) -> Result<Output, BoxError> {
            Err("Not implemented".into())
        }
    }
871
    // A single-step workflow builds successfully.
    #[test]
    fn test_workflow_build() {
        use crate::context::WorkflowContext;
        use crate::workflow::Workflow;
        use std::sync::Arc;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("test", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        let _workflow_ref = &workflow;
    }

    // Metadata supplied via the context is exposed through `metadata()`.
    #[test]
    fn test_workflow_with_metadata() {
        use crate::context::WorkflowContext;
        use crate::workflow::Workflow;
        use std::sync::Arc;

        let ctx = WorkflowContext::new(
            "test-workflow",
            Arc::new(DummyCodec),
            Arc::new("test_metadata"),
        );
        let workflow: Workflow<DummyCodec, u32, &str> = WorkflowBuilder::new(ctx)
            .then("test", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        assert_eq!(**workflow.metadata(), "test_metadata");
    }

    // Tasks are chained in insertion order; walk the `next` links to verify.
    #[test]
    fn test_task_order() {
        use crate::context::WorkflowContext;
        use crate::workflow::Workflow;
        use std::sync::Arc;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("first", |i: u32| async move { Ok(i + 1) })
            .then("second", |i: u32| async move { Ok(i + 2) })
            .then("third", |i: u32| async move { Ok(i + 3) })
            .build()
            .unwrap();

        let mut current = workflow.continuation();
        let mut task_ids = Vec::new();

        loop {
            match current {
                crate::workflow::WorkflowContinuation::Task { id, next, .. } => {
                    task_ids.push(id.clone());
                    match next {
                        Some(next_box) => current = next_box.as_ref(),
                        None => break,
                    }
                }
                _ => break,
            }
        }

        assert_eq!(
            task_ids,
            vec!["first", "second", "third"],
            "Tasks should execute in the order they were added"
        );
    }
946
    // Branches with different output types (u32, String, f64) type-check
    // through a fork/join; this test is primarily a compile-time check.
    #[test]
    fn test_heterogeneous_fork_join_compiles() {
        use crate::context::WorkflowContext;
        use crate::task::BranchOutputs;
        use crate::workflow::Workflow;
        use std::sync::Arc;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("count", |i: u32| async move { Ok(i * 2) });
                b.add("name", |i: u32| async move { Ok(format!("item_{}", i)) });
                b.add("ratio", |i: u32| async move { Ok(i as f64 / 100.0) });
            })
            .join("combine", |outputs: BranchOutputs<DummyCodec>| async move {
                let _ = outputs.len();
                Ok(format!("combined {} branches", outputs.len()))
            })
            .then("final", |s: String| async move { Ok(s.len() as u32) })
            .build()
            .unwrap();

        let _workflow_ref = &workflow;
    }

    // Reusing a branch id within the same fork is rejected at build time.
    #[test]
    fn test_duplicate_branch_id_returns_error() {
        use crate::context::WorkflowContext;
        use crate::error::WorkflowError;
        use std::sync::Arc;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let result = WorkflowBuilder::<_, u32, _>::new(ctx)
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("count", |i: u32| async move { Ok(i * 2) });
                b.add("count", |i: u32| async move { Ok(i * 3) });
            })
            .join("combine", |_outputs| async move { Ok(0u32) })
            .build();

        assert!(matches!(
            result,
            Err(WorkflowError::DuplicateTaskId(id)) if id == "count"
        ));
    }
1004
    // Serialize a linear workflow, then rehydrate it: missing tasks surface
    // as TaskNotFound, and a fully populated registry succeeds.
    #[test]
    fn test_serializable_continuation() {
        use crate::context::WorkflowContext;
        use crate::error::WorkflowError;
        use crate::registry::TaskRegistry;
        use std::sync::Arc;

        let codec = Arc::new(DummyCodec);
        let ctx = WorkflowContext::new("test-workflow", codec.clone(), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let serializable = workflow.continuation().to_serializable();

        let task_ids = serializable.task_ids();
        assert_eq!(task_ids, vec!["step1", "step2"]);

        let empty_registry = TaskRegistry::new();
        let result = serializable.to_runnable(&empty_registry);
        assert!(matches!(result, Err(WorkflowError::TaskNotFound(id)) if id == "step1"));

        let mut registry = TaskRegistry::new();
        registry.register_fn("step1", codec.clone(), |i: u32| async move { Ok(i + 1) });
        registry.register_fn("step2", codec.clone(), |i: u32| async move { Ok(i * 2) });

        let hydrated = serializable.to_runnable(&registry);
        assert!(hydrated.is_ok());
    }

    // `task_ids` on a fork includes the derived fork id ("a||b"), each
    // branch id, and the join id.
    #[test]
    fn test_serializable_fork_join() {
        use crate::context::WorkflowContext;
        use crate::task::BranchOutputs;
        use std::sync::Arc;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .then("prepare", |i: u32| async move { Ok(i) })
            .branches(|b| {
                b.add("branch_a", |i: u32| async move { Ok(i * 2) });
                b.add("branch_b", |i: u32| async move { Ok(i + 10) });
            })
            .join(
                "merge",
                |_: BranchOutputs<DummyCodec>| async move { Ok(0u32) },
            )
            .build()
            .unwrap();

        let serializable = workflow.continuation().to_serializable();
        let task_ids = serializable.task_ids();

        assert!(task_ids.contains(&"prepare"));
        assert!(task_ids.contains(&"branch_a||branch_b"));
        assert!(task_ids.contains(&"branch_a"));
        assert!(task_ids.contains(&"branch_b"));
        assert!(task_ids.contains(&"merge"));
        assert_eq!(task_ids.len(), 5);
    }
1073
    // `with_registry()` auto-registers inline tasks and yields a
    // SerializableWorkflow that round-trips through its own state.
    #[test]
    fn test_serializable_workflow_builder() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;

        let codec = Arc::new(DummyCodec);
        let ctx = WorkflowContext::new("test-workflow", codec, Arc::new(()));

        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        assert!(workflow.registry().contains("step1"));
        assert!(workflow.registry().contains("step2"));
        assert_eq!(workflow.registry().len(), 2);

        let serializable = workflow.to_serializable();
        assert_eq!(serializable.continuation.task_ids(), vec!["step1", "step2"]);

        let hydrated = workflow.to_runnable(&serializable);
        assert!(hydrated.is_ok());
    }

    // Tasks registered ahead of time can be referenced by id via
    // `then_registered`, and the result still round-trips.
    #[test]
    fn test_with_existing_registry_and_then_registered() {
        use crate::context::WorkflowContext;
        use crate::registry::TaskRegistry;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;

        let codec = Arc::new(DummyCodec);

        let mut registry = TaskRegistry::new();
        registry.register_fn("double", codec.clone(), |i: u32| async move { Ok(i * 2) });
        registry.register_fn("add_ten", codec.clone(), |i: u32| async move { Ok(i + 10) });

        let ctx = WorkflowContext::new("test-workflow", codec.clone(), Arc::new(()));
        let workflow: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx)
            .with_existing_registry(registry)
            .then_registered::<u32>("double")
            .unwrap()
            .then_registered::<u32>("add_ten")
            .unwrap()
            .build()
            .unwrap();

        assert!(workflow.registry().contains("double"));
        assert!(workflow.registry().contains("add_ten"));

        let serializable = workflow.to_serializable();
        assert_eq!(
            serializable.continuation.task_ids(),
            vec!["double", "add_ten"]
        );

        let hydrated = workflow.to_runnable(&serializable);
        assert!(hydrated.is_ok());
    }

    // Pre-registered and inline tasks can be mixed; inline tasks are added
    // to the registry automatically.
    #[test]
    fn test_mixed_inline_and_registered_tasks() {
        use crate::context::WorkflowContext;
        use crate::registry::TaskRegistry;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;

        let codec = Arc::new(DummyCodec);

        let mut registry = TaskRegistry::new();
        registry.register_fn(
            "preregistered",
            codec.clone(),
            |i: u32| async move { Ok(i * 2) },
        );

        let ctx = WorkflowContext::new("test-workflow", codec.clone(), Arc::new(()));
        let workflow: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx)
            .with_existing_registry(registry)
            .then_registered::<u32>("preregistered")
            .unwrap()
            .then("inline", |i: u32| async move { Ok(i + 5) })
            .build()
            .unwrap();

        assert!(workflow.registry().contains("preregistered"));
        assert!(workflow.registry().contains("inline"));
        assert_eq!(workflow.registry().len(), 2);
    }
1177
    // The id and definition hash survive into the serialized snapshot.
    #[test]
    fn test_workflow_id_and_definition_hash() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;

        let ctx = WorkflowContext::new("my-workflow-id", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        assert_eq!(workflow.workflow_id(), "my-workflow-id");

        assert!(!workflow.definition_hash().is_empty());

        let state = workflow.to_serializable();
        assert_eq!(state.workflow_id, "my-workflow-id");
        assert_eq!(state.definition_hash, workflow.definition_hash());
    }

    // Adding a step changes the definition hash.
    #[test]
    fn test_definition_hash_changes_with_structure() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;

        let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow1 = WorkflowBuilder::new(ctx1)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow2 = WorkflowBuilder::new(ctx2)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        assert_ne!(workflow1.definition_hash(), workflow2.definition_hash());
    }

    // A tampered definition hash is rejected on rehydration.
    #[test]
    fn test_definition_mismatch_error() {
        use crate::context::WorkflowContext;
        use crate::error::WorkflowError;
        use std::sync::Arc;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        let mut state = workflow.to_serializable();
        state.definition_hash = "wrong-hash".to_string();

        let result = workflow.to_runnable(&state);
        assert!(matches!(
            result,
            Err(WorkflowError::DefinitionMismatch { .. })
        ));
    }

    // A hand-built payload with a duplicate node id is rejected by
    // `to_runnable` even when every id resolves in the registry.
    #[test]
    fn test_duplicate_id_tampering_detection() {
        use crate::error::WorkflowError;
        use crate::registry::TaskRegistry;
        use crate::workflow::SerializableContinuation;
        use std::sync::Arc;

        let codec = Arc::new(DummyCodec);

        let mut registry = TaskRegistry::new();
        registry.register_fn("step1", codec.clone(), |i: u32| async move { Ok(i + 1) });
        registry.register_fn("step2", codec.clone(), |i: u32| async move { Ok(i * 2) });

        let tampered = SerializableContinuation::Task {
            id: "step1".to_string(),
            timeout_ms: None,
            retry_policy: None,
            next: Some(Box::new(SerializableContinuation::Task {
                id: "step1".to_string(),
                timeout_ms: None,
                retry_policy: None,
                next: None,
            })),
        };

        let result = tampered.to_runnable(&registry);
        assert!(matches!(
            result,
            Err(WorkflowError::DuplicateTaskId(id)) if id == "step1"
        ));
    }
1286
    // `.delay()` inserts a Delay node between tasks, preserving order and
    // duration.
    #[test]
    fn test_delay_builder() {
        use crate::context::WorkflowContext;
        use crate::workflow::{Workflow, WorkflowContinuation};
        use std::sync::Arc;
        use std::time::Duration;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_1s", Duration::from_secs(1))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let mut ids = Vec::new();
        let mut current = workflow.continuation();
        loop {
            match current {
                WorkflowContinuation::Task { id, next, .. } => {
                    ids.push(format!("task:{id}"));
                    match next {
                        Some(n) => current = n,
                        None => break,
                    }
                }
                WorkflowContinuation::Delay {
                    id, duration, next, ..
                } => {
                    ids.push(format!("delay:{id}:{}ms", duration.as_millis()));
                    match next {
                        Some(n) => current = n,
                        None => break,
                    }
                }
                _ => break,
            }
        }

        assert_eq!(
            ids,
            vec!["task:step1", "delay:wait_1s:1000ms", "task:step2"]
        );
    }

    // Delay nodes serialize with their duration in ms and rehydrate cleanly.
    #[test]
    fn test_delay_serialization_roundtrip() {
        use crate::context::WorkflowContext;
        use crate::workflow::SerializableContinuation;
        use std::sync::Arc;
        use std::time::Duration;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait_5s", Duration::from_secs(5))
            .then("step2", |i: u32| async move { Ok(i * 2) })
            .build()
            .unwrap();

        let serializable = workflow.to_serializable();

        let task_ids = serializable.continuation.task_ids();
        assert_eq!(task_ids, vec!["step1", "wait_5s", "step2"]);

        match &serializable.continuation {
            SerializableContinuation::Task { next, .. } => {
                let next = next.as_ref().unwrap();
                match next.as_ref() {
                    SerializableContinuation::Delay {
                        id, duration_ms, ..
                    } => {
                        assert_eq!(id, "wait_5s");
                        assert_eq!(*duration_ms, 5000);
                    }
                    other => panic!("Expected Delay, got {other:?}"),
                }
            }
            other => panic!("Expected Task, got {other:?}"),
        }

        let hydrated = workflow.to_runnable(&serializable);
        assert!(hydrated.is_ok());
    }

    // A leading Delay node is itself the "first task".
    #[test]
    fn test_delay_first_task_id() {
        use crate::context::WorkflowContext;
        use std::sync::Arc;
        use std::time::Duration;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let workflow = WorkflowBuilder::new(ctx)
            .delay("initial_delay", Duration::from_secs(10))
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .build()
            .unwrap();

        assert_eq!(workflow.continuation().first_task_id(), "initial_delay");
    }

    // Delay ids participate in duplicate-id detection like task ids.
    #[test]
    fn test_delay_duplicate_id_detection() {
        use crate::context::WorkflowContext;
        use crate::error::WorkflowError;
        use std::sync::Arc;
        use std::time::Duration;

        let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
        let result = WorkflowBuilder::<_, u32, _>::new(ctx)
            .then("dup", |i: u32| async move { Ok(i + 1) })
            .delay("dup", Duration::from_secs(1))
            .build();

        assert!(matches!(
            result,
            Err(WorkflowError::DuplicateTaskId(id)) if id == "dup"
        ));
    }

    // Changing only a delay's duration changes the definition hash.
    #[test]
    fn test_delay_definition_hash_includes_duration() {
        use crate::context::WorkflowContext;
        use crate::workflow::SerializableWorkflow;
        use std::sync::Arc;
        use std::time::Duration;

        let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf1: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx1)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait", Duration::from_secs(1))
            .build()
            .unwrap();

        let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
        let wf2: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx2)
            .with_registry()
            .then("step1", |i: u32| async move { Ok(i + 1) })
            .delay("wait", Duration::from_secs(60))
            .build()
            .unwrap();

        assert_ne!(wf1.definition_hash(), wf2.definition_hash());
    }
1445
1446 #[test]
1447 fn test_delay_definition_hash_differs_from_task() {
1448 use crate::context::WorkflowContext;
1449 use crate::workflow::SerializableWorkflow;
1450 use std::sync::Arc;
1451 use std::time::Duration;
1452
1453 let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
1455 let wf1: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx1)
1456 .with_registry()
1457 .then("step1", |i: u32| async move { Ok(i + 1) })
1458 .build()
1459 .unwrap();
1460
1461 let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
1463 let wf2: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx2)
1464 .with_registry()
1465 .delay("step1", Duration::from_secs(1))
1466 .build()
1467 .unwrap();
1468
1469 assert_ne!(wf1.definition_hash(), wf2.definition_hash());
1471 }
1472
1473 #[test]
1474 fn test_delay_task_ids() {
1475 use crate::context::WorkflowContext;
1476 use std::sync::Arc;
1477 use std::time::Duration;
1478
1479 let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
1480 let workflow = WorkflowBuilder::new(ctx)
1481 .then("fetch", |i: u32| async move { Ok(i) })
1482 .delay("wait_24h", Duration::from_secs(86400))
1483 .then("process", |i: u32| async move { Ok(i + 1) })
1484 .build()
1485 .unwrap();
1486
1487 let serializable = workflow.continuation().to_serializable();
1488 let ids = serializable.task_ids();
1489 assert_eq!(ids, vec!["fetch", "wait_24h", "process"]);
1490 }
1491
1492 #[test]
1493 fn test_delay_only_workflow() {
1494 use crate::context::WorkflowContext;
1495 use std::sync::Arc;
1496 use std::time::Duration;
1497
1498 use crate::workflow::Workflow;
1499
1500 let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
1501 let workflow: Workflow<DummyCodec, u32> = WorkflowBuilder::new(ctx)
1502 .delay("just_wait", Duration::from_millis(10))
1503 .build()
1504 .unwrap();
1505
1506 assert_eq!(workflow.continuation().first_task_id(), "just_wait");
1507
1508 let serializable = workflow.continuation().to_serializable();
1509 assert_eq!(serializable.task_ids(), vec!["just_wait"]);
1510 }
1511
1512 #[test]
1513 fn test_delay_to_runnable_no_registry_needed() {
1514 use crate::registry::TaskRegistry;
1515 use crate::workflow::SerializableContinuation;
1516
1517 let delay = SerializableContinuation::Delay {
1519 id: "wait".to_string(),
1520 duration_ms: 5000,
1521 next: None,
1522 };
1523
1524 let empty_registry = TaskRegistry::new();
1525 let result = delay.to_runnable(&empty_registry);
1526 assert!(result.is_ok());
1527
1528 let runnable = result.unwrap();
1529 match runnable {
1530 crate::workflow::WorkflowContinuation::Delay {
1531 id, duration, next, ..
1532 } => {
1533 assert_eq!(id, "wait");
1534 assert_eq!(duration, std::time::Duration::from_millis(5000));
1535 assert!(next.is_none());
1536 }
1537 _ => panic!("Expected Delay variant"),
1538 }
1539 }
1540
1541 #[test]
1546 fn test_timeout_serialization_roundtrip() {
1547 use crate::context::WorkflowContext;
1548 use crate::task::TaskMetadata;
1549 use crate::workflow::SerializableContinuation;
1550 use std::sync::Arc;
1551 use std::time::Duration;
1552
1553 let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
1554 let workflow = WorkflowBuilder::new(ctx)
1555 .with_registry()
1556 .then("step1", |i: u32| async move { Ok(i + 1) })
1557 .with_metadata(TaskMetadata {
1558 timeout: Some(Duration::from_secs(30)),
1559 ..Default::default()
1560 })
1561 .then("step2", |i: u32| async move { Ok(i * 2) })
1562 .build()
1563 .unwrap();
1564
1565 let serializable = workflow.to_serializable();
1567
1568 match &serializable.continuation {
1570 SerializableContinuation::Task { id, timeout_ms, .. } => {
1571 assert_eq!(id, "step1");
1572 assert_eq!(*timeout_ms, Some(30_000));
1573 }
1574 other => panic!("Expected Task, got {other:?}"),
1575 }
1576
1577 let hydrated = workflow.to_runnable(&serializable).unwrap();
1579 match &hydrated {
1580 crate::workflow::WorkflowContinuation::Task { id, timeout, .. } => {
1581 assert_eq!(id, "step1");
1582 assert_eq!(*timeout, Some(Duration::from_secs(30)));
1583 }
1584 _ => panic!("Expected Task variant"),
1585 }
1586 }
1587
1588 #[test]
1589 fn test_timeout_changes_definition_hash() {
1590 use crate::context::WorkflowContext;
1591 use crate::task::TaskMetadata;
1592 use crate::workflow::SerializableWorkflow;
1593 use std::sync::Arc;
1594 use std::time::Duration;
1595
1596 let ctx1 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
1598 let wf1: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx1)
1599 .with_registry()
1600 .then("step1", |i: u32| async move { Ok(i + 1) })
1601 .build()
1602 .unwrap();
1603
1604 let ctx2 = WorkflowContext::new("workflow", Arc::new(DummyCodec), Arc::new(()));
1606 let wf2: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx2)
1607 .with_registry()
1608 .then("step1", |i: u32| async move { Ok(i + 1) })
1609 .with_metadata(TaskMetadata {
1610 timeout: Some(Duration::from_secs(30)),
1611 ..Default::default()
1612 })
1613 .build()
1614 .unwrap();
1615
1616 assert_ne!(wf1.definition_hash(), wf2.definition_hash());
1618 }
1619
1620 #[test]
1621 fn test_no_timeout_field_absent_in_serialization() {
1622 use crate::context::WorkflowContext;
1623 use std::sync::Arc;
1624
1625 let ctx = WorkflowContext::new("test-workflow", Arc::new(DummyCodec), Arc::new(()));
1626 let workflow = WorkflowBuilder::new(ctx)
1627 .with_registry()
1628 .then("step1", |i: u32| async move { Ok(i + 1) })
1629 .build()
1630 .unwrap();
1631
1632 let serializable = workflow.to_serializable();
1633 let json = serde_json::to_string(&serializable.continuation).unwrap();
1635 assert!(
1636 !json.contains("timeout_ms"),
1637 "timeout_ms should be absent when None: {json}"
1638 );
1639 }
1640}
1641
#[cfg(test)]
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing
)]
// Property-based tests for `SerializableContinuation`: hash determinism,
// serde round-trip stability, and duplicate-id detection over randomly
// generated continuation trees.
mod proptests {
    use super::SerializableContinuation;
    use proptest::prelude::*;

    /// Strategy producing short lowercase alphanumeric ids (1-8 chars).
    /// Note: ids from this strategy are NOT guaranteed unique across nodes.
    fn arb_id() -> impl Strategy<Value = String> {
        "[a-z0-9]{1,8}"
    }

    /// Strategy producing an arbitrary continuation tree of at most `depth`
    /// levels. At depth 0 it bottoms out in a leaf `Task` with no `next`;
    /// otherwise it picks uniformly among Task / Fork / Delay / AwaitSignal,
    /// recursing with `depth - 1` for children. Ids may repeat.
    fn arb_continuation(depth: usize) -> BoxedStrategy<SerializableContinuation> {
        let leaf = arb_id().prop_map(|id| SerializableContinuation::Task {
            id,
            timeout_ms: None,
            retry_policy: None,
            next: None,
        });

        if depth == 0 {
            return leaf.boxed();
        }

        prop_oneof![
            // Task: optional timeout, optional tail.
            (
                arb_id(),
                prop::option::of(any::<u64>()),
                prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
            )
                .prop_map(|(id, timeout_ms, next)| SerializableContinuation::Task {
                    id,
                    timeout_ms,
                    retry_policy: None,
                    next,
                }),
            // Fork: 0-2 branches plus an optional join continuation.
            (
                arb_id(),
                prop::collection::vec(arb_continuation(depth - 1), 0..3),
                prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
            )
                .prop_map(|(id, branches, join)| SerializableContinuation::Fork {
                    id,
                    branches,
                    join,
                }),
            // Delay: arbitrary millisecond duration, optional tail.
            (
                arb_id(),
                any::<u64>(),
                prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
            )
                .prop_map(|(id, duration_ms, next)| SerializableContinuation::Delay {
                    id,
                    duration_ms,
                    next,
                }),
            // AwaitSignal: signal name, optional timeout, optional tail.
            (
                arb_id(),
                arb_id(),
                prop::option::of(any::<u64>()),
                prop::option::of(arb_continuation(depth - 1).prop_map(Box::new)),
            )
                .prop_map(|(id, signal_name, timeout_ms, next)| {
                    SerializableContinuation::AwaitSignal {
                        id,
                        signal_name,
                        timeout_ms,
                        next,
                    }
                }),
        ]
        .boxed()
    }

    /// Strategy producing a continuation tree whose node ids are all unique
    /// by construction: each node's id is `{prefix}n`, and every child is
    /// generated with a distinct extended prefix ("0_" for a Task tail,
    /// "b{i}_" per Fork branch, "j_" for a join, "d_" for a Delay tail,
    /// "s_" for an AwaitSignal tail), so no two nodes can collide.
    fn arb_unique_continuation(
        depth: usize,
        prefix: &str,
    ) -> BoxedStrategy<SerializableContinuation> {
        let id = format!("{prefix}n");

        if depth == 0 {
            return Just(SerializableContinuation::Task {
                id,
                timeout_ms: None,
                retry_policy: None,
                next: None,
            })
            .boxed();
        }

        let id_clone = id.clone();
        prop_oneof![
            // Task with an optional uniquely-prefixed tail.
            prop::option::of(
                arb_unique_continuation(depth - 1, &format!("{prefix}0_")).prop_map(Box::new),
            )
            .prop_map(move |next| SerializableContinuation::Task {
                id: id_clone.clone(),
                timeout_ms: None,
                retry_policy: None,
                next,
            }),
            // Fork: each branch and the join get their own prefix so ids
            // stay disjoint across subtrees.
            {
                let id_f = id.clone();
                let prefix_f = prefix.to_string();
                (0..3u8)
                    .prop_flat_map(move |branch_count| {
                        let id_inner = id_f.clone();
                        let prefix_inner = prefix_f.clone();
                        let branches: Vec<BoxedStrategy<SerializableContinuation>> = (0
                            ..branch_count)
                            .map(|i| {
                                arb_unique_continuation(depth - 1, &format!("{prefix_inner}b{i}_"))
                            })
                            .collect();
                        let join = prop::option::of(
                            arb_unique_continuation(depth - 1, &format!("{prefix_inner}j_"))
                                .prop_map(Box::new),
                        );
                        (branches, join).prop_map(move |(branches, join)| {
                            SerializableContinuation::Fork {
                                id: id_inner.clone(),
                                branches,
                                join,
                            }
                        })
                    })
                    .boxed()
            },
            // Delay with an optional uniquely-prefixed tail.
            {
                let id_d = id.clone();
                let prefix_d = prefix.to_string();
                (
                    any::<u64>(),
                    prop::option::of(
                        arb_unique_continuation(depth - 1, &format!("{prefix_d}d_"))
                            .prop_map(Box::new),
                    ),
                )
                    .prop_map(move |(duration_ms, next)| {
                        SerializableContinuation::Delay {
                            id: id_d.clone(),
                            duration_ms,
                            next,
                        }
                    })
            },
            // AwaitSignal: signal_name comes from arb_id (it is not a node
            // id, so it does not need the uniqueness prefix).
            {
                let id_s = id;
                let prefix_s = prefix.to_string();
                (
                    arb_id(),
                    prop::option::of(any::<u64>()),
                    prop::option::of(
                        arb_unique_continuation(depth - 1, &format!("{prefix_s}s_"))
                            .prop_map(Box::new),
                    ),
                )
                    .prop_map(move |(signal_name, timeout_ms, next)| {
                        SerializableContinuation::AwaitSignal {
                            id: id_s.clone(),
                            signal_name,
                            timeout_ms,
                            next,
                        }
                    })
            },
        ]
        .boxed()
    }

    /// Collects every node id in the tree, depth-first: linear nodes push
    /// their id then recurse into `next`; Fork pushes its id, then each
    /// branch, then the join.
    fn collect_ids(cont: &SerializableContinuation) -> Vec<String> {
        let mut ids = Vec::new();
        fn walk(c: &SerializableContinuation, out: &mut Vec<String>) {
            match c {
                SerializableContinuation::Task { id, next, .. }
                | SerializableContinuation::Delay { id, next, .. }
                | SerializableContinuation::AwaitSignal { id, next, .. } => {
                    out.push(id.clone());
                    if let Some(n) = next {
                        walk(n, out);
                    }
                }
                SerializableContinuation::Fork { id, branches, join } => {
                    out.push(id.clone());
                    for b in branches {
                        walk(b, out);
                    }
                    if let Some(j) = join {
                        walk(j, out);
                    }
                }
            }
        }
        walk(cont, &mut ids);
        ids
    }

    /// Returns a copy of `cont` whose ROOT id is replaced with `dup_id`,
    /// leaving all descendants untouched. Used to manufacture a guaranteed
    /// duplicate when `dup_id` is the id of an existing descendant.
    fn inject_duplicate(cont: &SerializableContinuation, dup_id: &str) -> SerializableContinuation {
        match cont {
            SerializableContinuation::Task {
                timeout_ms,
                retry_policy,
                next,
                ..
            } => SerializableContinuation::Task {
                id: dup_id.to_string(),
                timeout_ms: *timeout_ms,
                retry_policy: retry_policy.clone(),
                next: next.clone(),
            },
            SerializableContinuation::Fork { branches, join, .. } => {
                SerializableContinuation::Fork {
                    id: dup_id.to_string(),
                    branches: branches.clone(),
                    join: join.clone(),
                }
            }
            SerializableContinuation::Delay {
                duration_ms, next, ..
            } => SerializableContinuation::Delay {
                id: dup_id.to_string(),
                duration_ms: *duration_ms,
                next: next.clone(),
            },
            SerializableContinuation::AwaitSignal {
                signal_name,
                timeout_ms,
                next,
                ..
            } => SerializableContinuation::AwaitSignal {
                id: dup_id.to_string(),
                signal_name: signal_name.clone(),
                timeout_ms: *timeout_ms,
                next: next.clone(),
            },
        }
    }

    proptest! {
        // Hashing the same tree twice yields the same value.
        #[test]
        fn hash_is_deterministic(cont in arb_continuation(3)) {
            let h1 = cont.compute_definition_hash();
            let h2 = cont.compute_definition_hash();
            prop_assert_eq!(h1, h2);
        }

        // A JSON serialize/deserialize cycle preserves the definition hash.
        #[test]
        fn serde_roundtrip_preserves_hash(cont in arb_continuation(3)) {
            let original_hash = cont.compute_definition_hash();
            let json = serde_json::to_string(&cont).unwrap();
            let recovered: SerializableContinuation = serde_json::from_str(&json).unwrap();
            prop_assert_eq!(original_hash, recovered.compute_definition_hash());
        }

        // Trees built with unique-by-construction ids report no duplicate.
        #[test]
        fn unique_ids_means_none(cont in arb_unique_continuation(3, "r_")) {
            prop_assert!(cont.find_duplicate_id().is_none());
        }

        // Overwriting the root id with a descendant's id (ids[1]) must be
        // flagged. Trees with fewer than 2 nodes cannot host a duplicate,
        // so those cases are skipped.
        #[test]
        fn injected_duplicate_is_detected(cont in arb_unique_continuation(3, "r_")) {
            let ids = collect_ids(&cont);
            if ids.len() >= 2 {
                let dup_id = &ids[1];
                let tampered = inject_duplicate(&cont, dup_id);
                prop_assert!(tampered.find_duplicate_id().is_some());
            }
        }
    }
}