Skip to main content

sayiir_core/
builder.rs

1use crate::codec::Codec;
2use crate::codec::sealed;
3use crate::context::WorkflowContext;
4use crate::error::WorkflowError;
5use crate::registry::TaskRegistry;
6use crate::task::{
7    BranchOutputs, ErasedBranch, branch, to_core_task_arc, to_heterogeneous_join_task_arc,
8};
9use crate::workflow::{SerializableWorkflow, Workflow, WorkflowContinuation};
10use std::marker::PhantomData;
11use std::sync::Arc;
12
/// Marker type for an empty continuation (no tasks have been added yet).
///
/// This is the default `Cont` parameter of [`WorkflowBuilder`]. Its
/// [`ContinuationState`] implementation returns the first appended node
/// unchanged, so that node becomes the head of the chain.
///
/// The type carries no data, so it derives the cheap standard traits to stay
/// usable in `Debug` output and generic code that needs `Copy`/`Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoContinuation;
15
/// Marker type for "no registry" (non-serializable workflow).
///
/// This is the default `R` parameter of [`WorkflowBuilder`]. Its
/// [`RegistryBehavior`] implementation turns every registration into a no-op.
///
/// The type carries no data, so it derives the cheap standard traits to stay
/// usable in `Debug` output and generic code that needs `Copy`/`Default`.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoRegistry;
18
/// Trait for continuation state - allows unified handling of empty vs existing continuation.
///
/// Implemented in this file for [`NoContinuation`] (the builder holds no nodes
/// yet) and for [`WorkflowContinuation`] (the builder already holds a chain), so
/// builder methods can append a node without knowing which case they are in.
pub trait ContinuationState {
    /// Append a new task to this continuation state, returning a `WorkflowContinuation`.
    ///
    /// Consumes `self`; the returned value is the continuation that contains
    /// `new_task`.
    fn append(self, new_task: WorkflowContinuation) -> WorkflowContinuation;
}
24
/// Empty builder case: there is no chain to extend yet.
impl ContinuationState for NoContinuation {
    /// Returns `new_task` unchanged, making it the head of the continuation.
    fn append(self, new_task: WorkflowContinuation) -> WorkflowContinuation {
        new_task
    }
}
30
31impl ContinuationState for WorkflowContinuation {
32    fn append(mut self, new_task: WorkflowContinuation) -> WorkflowContinuation {
33        append_to_chain(&mut self, new_task);
34        self
35    }
36}
37
/// Trait for registry behavior - allows unified implementation of builder methods.
///
/// Builder methods call these hooks unconditionally. The [`NoRegistry`]
/// implementation ignores the call, while the [`TaskRegistry`] implementation
/// stores the task so the workflow can later be built as a serializable one.
pub trait RegistryBehavior {
    /// Register a task (no-op for `NoRegistry`, actual registration for `TaskRegistry`).
    ///
    /// * `_id` - identifier the task is registered under.
    /// * `_codec` - codec able to decode the task input `I` and encode its output `O`.
    /// * `_func` - shared handle to the task function; only borrowed here, so the
    ///   caller keeps its own `Arc`.
    fn maybe_register<I, O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
    where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        I: Send + 'static,
        O: Send + 'static,
        Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static;

    /// Register a join task (no-op for `NoRegistry`, actual registration for `TaskRegistry`).
    ///
    /// A join task differs from a regular task in that its function takes
    /// [`BranchOutputs`] rather than a single decoded input, and the codec must
    /// be able to decode `NamedBranchResults`.
    fn maybe_register_join<O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
    where
        F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
        O: Send + 'static,
        Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
        C: Codec
            + sealed::EncodeValue<O>
            + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
            + Send
            + Sync
            + 'static;
}
62
/// Registry hooks for builders without a registry: both methods do nothing.
impl RegistryBehavior for NoRegistry {
    /// Ignores the task; all arguments are unused.
    fn maybe_register<I, O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
    where
        F: Fn(I) -> Fut + Send + Sync + 'static,
        I: Send + 'static,
        O: Send + 'static,
        Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
    {
        // No-op for non-serializable workflows
    }

    /// Ignores the join task; all arguments are unused.
    fn maybe_register_join<O, F, Fut, C>(&mut self, _id: &str, _codec: Arc<C>, _func: &Arc<F>)
    where
        F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
        O: Send + 'static,
        Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
        C: Codec
            + sealed::EncodeValue<O>
            + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
            + Send
            + Sync
            + 'static,
    {
        // No-op for non-serializable workflows
    }
}
90
91impl RegistryBehavior for TaskRegistry {
92    fn maybe_register<I, O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: &Arc<F>)
93    where
94        F: Fn(I) -> Fut + Send + Sync + 'static,
95        I: Send + 'static,
96        O: Send + 'static,
97        Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
98        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O> + 'static,
99    {
100        use crate::task::TaskMetadata;
101        self.register_fn_arc(id, codec, Arc::clone(func), TaskMetadata::default());
102    }
103
104    fn maybe_register_join<O, F, Fut, C>(&mut self, id: &str, codec: Arc<C>, func: &Arc<F>)
105    where
106        F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
107        O: Send + 'static,
108        Fut: std::future::Future<Output = Result<O, crate::error::BoxError>> + Send + 'static,
109        C: Codec
110            + sealed::EncodeValue<O>
111            + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
112            + Send
113            + Sync
114            + 'static,
115    {
116        use crate::task::TaskMetadata;
117        self.register_arc_join(id, codec, Arc::clone(func), TaskMetadata::default());
118    }
119}
120
/// Fluent, type-state builder for workflows.
///
/// Type parameters:
/// * `C` - codec type stored in the workflow context.
/// * `Input` - input type of the whole workflow.
/// * `Output` - type produced by the most recently added step; it is what the
///   next task added with `then` receives.
/// * `M` - metadata type stored in the workflow context.
/// * `Cont` - continuation state: [`NoContinuation`] until the first node is
///   added, `WorkflowContinuation` afterwards.
/// * `R` - registry state: [`NoRegistry`] by default, `TaskRegistry` once
///   registry tracking has been enabled.
pub struct WorkflowBuilder<C, Input, Output, M = (), Cont = NoContinuation, R = NoRegistry> {
    // Workflow context (holds the codec cloned for every task that is added).
    context: WorkflowContext<C, M>,
    // Nodes added so far, or the `NoContinuation` marker.
    continuation: Cont,
    // Task registry, or the `NoRegistry` marker.
    registry: R,
    // ID of the node added most recently; read by `with_metadata`.
    last_task_id: Option<String>,
    // Carries the `Input`/`Output` types, which are not stored as values.
    _phantom: PhantomData<(Input, Output)>,
}
128
129#[allow(clippy::mismatching_type_param_order)] // Input used for both Input and Output initially
130impl<C, Input, M> WorkflowBuilder<C, Input, Input, M, NoContinuation, NoRegistry> {
131    /// Create a new workflow builder with a context object.
132    ///
133    /// The context contains the workflow ID, codec and metadata that will be available
134    /// at any execution point via the `sayiir_ctx!` macro.
135    #[must_use]
136    pub fn new(ctx: WorkflowContext<C, M>) -> Self
137    where
138        C: Codec,
139        M: Send + Sync + 'static,
140    {
141        Self {
142            context: ctx,
143            continuation: NoContinuation,
144            registry: NoRegistry,
145            last_task_id: None,
146            _phantom: PhantomData,
147        }
148    }
149
150    /// Enable registry tracking for serializable workflows with a new empty registry.
151    ///
152    /// # Example
153    ///
154    /// ```rust
155    /// # use sayiir_core::prelude::*;
156    /// # use sayiir_core::codec::{Encoder, Decoder, sealed};
157    /// # use bytes::Bytes;
158    /// # use std::sync::Arc;
159    /// # struct MyCodec;
160    /// # impl Encoder for MyCodec {}
161    /// # impl Decoder for MyCodec {}
162    /// # impl<T> sealed::EncodeValue<T> for MyCodec {
163    /// #     fn encode_value(&self, _: &T) -> Result<Bytes, BoxError> { Ok(Bytes::new()) }
164    /// # }
165    /// # impl<T> sealed::DecodeValue<T> for MyCodec {
166    /// #     fn decode_value(&self, _: Bytes) -> Result<T, BoxError> { Err("dummy".into()) }
167    /// # }
168    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
169    /// # let codec = Arc::new(MyCodec);
170    /// # let metadata = Arc::new(());
171    /// use sayiir_core::prelude::*;
172    /// use std::time::Duration;
173    ///
174    /// let ctx = WorkflowContext::new("my-workflow", codec, metadata);
175    /// let workflow = WorkflowBuilder::new(ctx)
176    ///     .with_registry()  // Enable serialization
177    ///     .then("step1", |i: u32| async move { Ok(i + 1) })
178    ///     .with_metadata(TaskMetadata {
179    ///         display_name: Some("Increment".into()),
180    ///         timeout: Some(Duration::from_secs(30)),
181    ///         ..Default::default()
182    ///     })
183    ///     .build()?;  // Returns SerializableWorkflow
184    /// # Ok(())
185    /// # }
186    /// ```
187    #[must_use]
188    pub fn with_registry(
189        self,
190    ) -> WorkflowBuilder<C, Input, Input, M, NoContinuation, TaskRegistry> {
191        self.with_existing_registry(TaskRegistry::new())
192    }
193
194    /// Enable registry tracking with an existing registry.
195    ///
196    /// Use this to reference pre-registered tasks via [`then_registered`] or to
197    /// compose workflows from task libraries.
198    ///
199    /// **Note**: Takes ownership of the registry. For deserialization/hydration,
200    /// rebuild the same registry from code on the deserializing side.
201    /// See [`TaskRegistry`](crate::registry::TaskRegistry) docs for the pattern.
202    ///
203    /// # Example
204    ///
205    /// ```rust
206    /// # use sayiir_core::prelude::*;
207    /// # use sayiir_core::codec::{Encoder, Decoder, sealed};
208    /// # use sayiir_core::workflow::SerializableContinuation;
209    /// # use bytes::Bytes;
210    /// # use std::sync::Arc;
211    /// # struct MyCodec;
212    /// # impl Encoder for MyCodec {}
213    /// # impl Decoder for MyCodec {}
214    /// # impl<T> sealed::EncodeValue<T> for MyCodec {
215    /// #     fn encode_value(&self, _: &T) -> Result<Bytes, BoxError> { Ok(Bytes::new()) }
216    /// # }
217    /// # impl<T> sealed::DecodeValue<T> for MyCodec {
218    /// #     fn decode_value(&self, _: Bytes) -> Result<T, BoxError> { Err("dummy".into()) }
219    /// # }
220    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
221    /// # let codec = Arc::new(MyCodec);
222    /// # let metadata = Arc::new(());
223    /// // Shared function for building registry (called on both sides)
224    /// fn build_registry(codec: Arc<MyCodec>) -> TaskRegistry {
225    ///     let mut registry = TaskRegistry::new();
226    ///     registry.register_fn("step1", codec.clone(), |i: u32| async move { Ok(i + 1) });
227    ///     registry
228    /// }
229    ///
230    /// // Build workflow
231    /// let registry = build_registry(codec.clone());
232    /// let ctx = WorkflowContext::new("my-workflow", codec.clone(), metadata);
233    /// let workflow: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx)
234    ///     .with_existing_registry(registry)
235    ///     .then_registered::<u32>("step1")?
236    ///     .build()?;
237    ///
238    /// // Deserialize (on another side): rebuild registry, then convert to runnable
239    /// let registry = build_registry(codec.clone());
240    /// let serializable = workflow.continuation().to_serializable();
241    /// let runnable = serializable.to_runnable(&registry)?;
242    /// # Ok(())
243    /// # }
244    /// ```
245    #[must_use]
246    pub fn with_existing_registry(
247        self,
248        registry: TaskRegistry,
249    ) -> WorkflowBuilder<C, Input, Input, M, NoContinuation, TaskRegistry> {
250        WorkflowBuilder {
251            context: self.context,
252            continuation: NoContinuation,
253            registry,
254            last_task_id: None,
255            _phantom: PhantomData,
256        }
257    }
258}
259
260/// Methods for adding tasks - unified implementation using `RegistryBehavior` and `ContinuationState`.
261impl<C, Input, Output, M, Cont, R> WorkflowBuilder<C, Input, Output, M, Cont, R>
262where
263    R: RegistryBehavior,
264    Cont: ContinuationState,
265{
266    /// Add a sequential task to the workflow.
267    pub fn then<F, Fut, NewOutput>(
268        mut self,
269        id: &str,
270        func: F,
271    ) -> WorkflowBuilder<C, Input, NewOutput, M, WorkflowContinuation, R>
272    where
273        F: Fn(Output) -> Fut + Send + Sync + 'static,
274        Output: Send + 'static,
275        NewOutput: Send + 'static,
276        Fut: std::future::Future<Output = Result<NewOutput, crate::error::BoxError>>
277            + Send
278            + 'static,
279        C: Codec + sealed::DecodeValue<Output> + sealed::EncodeValue<NewOutput> + 'static,
280    {
281        let codec = Arc::clone(&self.context.codec);
282        let func = Arc::new(func);
283
284        // Register if registry is enabled (no-op for NoRegistry)
285        self.registry
286            .maybe_register::<Output, NewOutput, _, _, _>(id, codec.clone(), &func);
287
288        let task = to_core_task_arc(func, codec);
289
290        let new_task = WorkflowContinuation::Task {
291            id: id.to_string(),
292            func: Some(task),
293            timeout: None,
294            retry_policy: None,
295            next: None,
296        };
297
298        let continuation = self.continuation.append(new_task);
299
300        WorkflowBuilder {
301            continuation,
302            context: self.context,
303            registry: self.registry,
304            last_task_id: Some(id.to_string()),
305            _phantom: PhantomData,
306        }
307    }
308}
309
310/// Delay method — available for all registry/continuation combinations.
311impl<C, Input, Output, M, Cont, R> WorkflowBuilder<C, Input, Output, M, Cont, R>
312where
313    Cont: ContinuationState,
314{
315    /// Add a durable delay to the workflow.
316    ///
317    /// The delay is transparent to data flow — the input passes through unchanged.
318    /// In non-durable runners the delay is a simple sleep. In durable runners
319    /// the workflow parks at the delay, persists `wake_at`, and returns
320    /// `WorkflowStatus::Waiting`. A later `resume()` call advances past the
321    /// delay once the wall clock reaches `wake_at`.
322    #[must_use]
323    pub fn delay(
324        self,
325        id: &str,
326        duration: std::time::Duration,
327    ) -> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, R> {
328        let new_node = WorkflowContinuation::Delay {
329            id: id.to_string(),
330            duration,
331            next: None,
332        };
333        let continuation = self.continuation.append(new_node);
334        WorkflowBuilder {
335            continuation,
336            context: self.context,
337            registry: self.registry,
338            last_task_id: Some(id.to_string()),
339            _phantom: PhantomData,
340        }
341    }
342
343    /// Wait for a named external signal before continuing.
344    ///
345    /// The signal payload (if any) becomes the input to the next step.
346    /// If a timeout is specified and expires before a signal arrives,
347    /// `None` is passed as the payload.
348    #[must_use]
349    pub fn wait_for_signal(
350        self,
351        id: &str,
352        signal_name: &str,
353        timeout: Option<std::time::Duration>,
354    ) -> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, R> {
355        let new_node = WorkflowContinuation::AwaitSignal {
356            id: id.to_string(),
357            signal_name: signal_name.to_string(),
358            timeout,
359            next: None,
360        };
361        let continuation = self.continuation.append(new_node);
362        WorkflowBuilder {
363            continuation,
364            context: self.context,
365            registry: self.registry,
366            last_task_id: Some(id.to_string()),
367            _phantom: PhantomData,
368        }
369    }
370}
371
372/// Methods for referencing pre-registered tasks (only available with `TaskRegistry`).
373impl<C, Input, Output, M, Cont> WorkflowBuilder<C, Input, Output, M, Cont, TaskRegistry>
374where
375    Cont: ContinuationState,
376{
377    /// Reference a pre-registered task by ID.
378    ///
379    /// The task must have been registered in the registry before calling this method.
380    /// Type safety is maintained through the `NewOutput` type parameter - ensure it
381    /// matches the registered task's output type.
382    ///
383    /// # Errors
384    ///
385    /// Returns `WorkflowError::TaskNotFound` if the task ID is not in the registry.
386    ///
387    /// # Example
388    ///
389    /// ```rust
390    /// # use sayiir_core::prelude::*;
391    /// # use sayiir_core::codec::{Encoder, Decoder, sealed};
392    /// # use bytes::Bytes;
393    /// # use std::sync::Arc;
394    /// # struct MyCodec;
395    /// # impl Encoder for MyCodec {}
396    /// # impl Decoder for MyCodec {}
397    /// # impl<T> sealed::EncodeValue<T> for MyCodec {
398    /// #     fn encode_value(&self, _: &T) -> Result<Bytes, BoxError> { Ok(Bytes::new()) }
399    /// # }
400    /// # impl<T> sealed::DecodeValue<T> for MyCodec {
401    /// #     fn decode_value(&self, _: Bytes) -> Result<T, BoxError> { Err("dummy".into()) }
402    /// # }
403    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
404    /// # let codec = Arc::new(MyCodec);
405    /// # let ctx = WorkflowContext::new("my-workflow", codec.clone(), Arc::new(()));
406    /// use sayiir_core::prelude::*;
407    ///
408    /// let mut registry = TaskRegistry::new();
409    /// registry.register_fn("double", codec.clone(), |i: u32| async move { Ok(i * 2) });
410    ///
411    /// let workflow: SerializableWorkflow<_, u32> = WorkflowBuilder::new(ctx)
412    ///     .with_existing_registry(registry)
413    ///     .then_registered::<u32>("double")?
414    ///     .with_metadata(TaskMetadata {
415    ///         display_name: Some("Double".into()),
416    ///         ..Default::default()
417    ///     })
418    ///     .build()?;
419    /// # Ok(())
420    /// # }
421    /// ```
422    pub fn then_registered<NewOutput>(
423        self,
424        id: &str,
425    ) -> Result<
426        WorkflowBuilder<C, Input, NewOutput, M, WorkflowContinuation, TaskRegistry>,
427        WorkflowError,
428    >
429    where
430        Output: Send + 'static,
431        NewOutput: Send + 'static,
432    {
433        let func = self
434            .registry
435            .get(id)
436            .ok_or_else(|| WorkflowError::TaskNotFound(id.to_string()))?;
437        let meta = self.registry.get_metadata(id);
438        let timeout = meta.and_then(|m| m.timeout);
439        let retry_policy = self
440            .registry
441            .get_metadata(id)
442            .and_then(|m| m.retries.clone());
443
444        let new_task = WorkflowContinuation::Task {
445            id: id.to_string(),
446            func: Some(func),
447            timeout,
448            retry_policy,
449            next: None,
450        };
451
452        let continuation = self.continuation.append(new_task);
453
454        Ok(WorkflowBuilder {
455            continuation,
456            context: self.context,
457            registry: self.registry,
458            last_task_id: Some(id.to_string()),
459            _phantom: PhantomData,
460        })
461    }
462}
463
464/// Metadata attachment — only available after a task has been added.
465impl<C, Input, Output, M> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, TaskRegistry> {
466    /// Attach metadata to the most recently added task.
467    ///
468    /// This method allows chaining metadata after `then()`, `then_registered()`,
469    /// or `join()` calls.
470    ///
471    /// # Example
472    ///
473    /// ```rust
474    /// # use sayiir_core::prelude::*;
475    /// # use sayiir_core::codec::{Encoder, Decoder, sealed};
476    /// # use bytes::Bytes;
477    /// # use std::sync::Arc;
478    /// # struct MyCodec;
479    /// # impl Encoder for MyCodec {}
480    /// # impl Decoder for MyCodec {}
481    /// # impl<T> sealed::EncodeValue<T> for MyCodec {
482    /// #     fn encode_value(&self, _: &T) -> Result<Bytes, BoxError> { Ok(Bytes::new()) }
483    /// # }
484    /// # impl<T> sealed::DecodeValue<T> for MyCodec {
485    /// #     fn decode_value(&self, _: Bytes) -> Result<T, BoxError> { Err("dummy".into()) }
486    /// # }
487    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
488    /// # let codec = Arc::new(MyCodec);
489    /// # let ctx = WorkflowContext::new("my-workflow", codec, Arc::new(()));
490    /// use sayiir_core::prelude::*;
491    /// use sayiir_core::task::RetryPolicy;
492    /// use std::time::Duration;
493    ///
494    /// let workflow = WorkflowBuilder::new(ctx)
495    ///     .with_registry()
496    ///     .then("double", |i: u32| async move { Ok(i * 2) })
497    ///     .with_metadata(TaskMetadata {
498    ///         display_name: Some("Double".into()),
499    ///         timeout: Some(Duration::from_secs(30)),
500    ///         ..Default::default()
501    ///     })
502    ///     .then("add_ten", |i: u32| async move { Ok(i + 10) })
503    ///     .build()?;
504    /// # Ok(())
505    /// # }
506    /// ```
507    #[must_use]
508    pub fn with_metadata(mut self, metadata: crate::task::TaskMetadata) -> Self {
509        if let Some(ref id) = self.last_task_id {
510            let timeout = metadata.timeout;
511            let retry_policy = metadata.retries.clone();
512            self.registry.set_metadata(id, metadata);
513            // Also update the timeout and retry policy on the continuation node
514            // so they're available for direct execution (not just the serializable roundtrip path).
515            self.continuation.set_task_timeout(id, timeout);
516            self.continuation.set_task_retry_policy(id, retry_policy);
517        }
518        self
519    }
520}
521
522/// Fork methods - unified implementation.
523impl<C, Input, Output, M, Cont, R> WorkflowBuilder<C, Input, Output, M, Cont, R> {
524    /// Fork the workflow into multiple parallel branches with heterogeneous outputs.
525    ///
526    /// Each branch receives the same input (the current workflow's output) and executes in parallel.
527    /// Branches can return different types. After all branches complete, use `join()` to combine
528    /// the results using `BranchOutputs` for type-safe named access.
529    ///
530    /// # Example
531    ///
532    /// ```rust,ignore
533    /// use sayiir_core::task::TaskMetadata;
534    ///
535    /// workflow
536    ///     .then("prepare", |input| async { Ok(input) })
537    ///     .with_metadata(TaskMetadata {
538    ///         display_name: Some("Prepare Input".into()),
539    ///         ..Default::default()
540    ///     })
541    ///     .branches(|b| {
542    ///         b.add("count", |i: u32| async move { Ok(i * 2) });
543    ///         b.add("name", |i: u32| async move { Ok(format!("item_{}", i)) });
544    ///     })
545    ///     .join("combine", |outputs: BranchOutputs<_>| async move {
546    ///         let count: u32 = outputs.get("count")?;
547    ///         let name: String = outputs.get("name")?;
548    ///         Ok(format!("{}: {}", name, count))
549    ///     })
550    ///     .with_metadata(TaskMetadata {
551    ///         display_name: Some("Combine Results".into()),
552    ///         ..Default::default()
553    ///     })
554    /// ```
555    pub fn branches<F>(self, f: F) -> ForkBuilder<C, Input, Output, M, Cont, R>
556    where
557        F: FnOnce(&mut BranchCollector<C, Output>),
558        C: Codec,
559    {
560        let codec = Arc::clone(&self.context.codec);
561        let mut collector = BranchCollector {
562            codec,
563            branches: Vec::new(),
564            _phantom: PhantomData,
565        };
566        f(&mut collector);
567
568        ForkBuilder {
569            context: self.context,
570            continuation: self.continuation,
571            branches: collector.branches,
572            registry: self.registry,
573            _phantom: PhantomData,
574        }
575    }
576
577    /// Fork the workflow into multiple parallel branches (low-level API).
578    pub fn fork(self) -> ForkBuilder<C, Input, Output, M, Cont, R> {
579        ForkBuilder {
580            context: self.context,
581            continuation: self.continuation,
582            branches: Vec::new(),
583            registry: self.registry,
584            _phantom: PhantomData,
585        }
586    }
587}
588
589/// Build method for `WorkflowBuilder` without registry - returns Workflow.
590impl<C, Input, Output, M> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, NoRegistry> {
591    /// Build the workflow into an executable workflow.
592    ///
593    /// # Errors
594    ///
595    /// Returns an error if duplicate task IDs are found.
596    pub fn build(self) -> Result<Workflow<C, Input, M>, WorkflowError>
597    where
598        Input: Send + 'static,
599        Output: Send + 'static,
600        M: Send + Sync + 'static,
601        C: Codec
602            + sealed::DecodeValue<Input>
603            + sealed::DecodeValue<Output>
604            + sealed::EncodeValue<Input>
605            + sealed::EncodeValue<Output>,
606    {
607        if let Some(dup) = self.continuation.find_duplicate_id() {
608            return Err(WorkflowError::DuplicateTaskId(dup));
609        }
610
611        let definition_hash = self
612            .continuation
613            .to_serializable()
614            .compute_definition_hash();
615
616        Ok(Workflow {
617            definition_hash,
618            continuation: self.continuation,
619            context: self.context,
620            _phantom: PhantomData,
621        })
622    }
623}
624
625/// Build method for `WorkflowBuilder` with registry - returns `SerializableWorkflow`.
626impl<C, Input, Output, M> WorkflowBuilder<C, Input, Output, M, WorkflowContinuation, TaskRegistry> {
627    /// Build the workflow into a serializable workflow.
628    ///
629    /// # Errors
630    ///
631    /// Returns an error if duplicate task IDs are found.
632    pub fn build(self) -> Result<SerializableWorkflow<C, Input, M>, WorkflowError>
633    where
634        Input: Send + 'static,
635        Output: Send + 'static,
636        M: Send + Sync + 'static,
637        C: Codec
638            + sealed::DecodeValue<Input>
639            + sealed::DecodeValue<Output>
640            + sealed::EncodeValue<Input>
641            + sealed::EncodeValue<Output>,
642    {
643        if let Some(dup) = self.continuation.find_duplicate_id() {
644            return Err(WorkflowError::DuplicateTaskId(dup));
645        }
646
647        let definition_hash = self
648            .continuation
649            .to_serializable()
650            .compute_definition_hash();
651
652        let inner = Workflow {
653            definition_hash,
654            continuation: self.continuation,
655            context: self.context,
656            _phantom: PhantomData,
657        };
658
659        Ok(SerializableWorkflow {
660            inner,
661            registry: self.registry,
662        })
663    }
664}
665
666/// Helper function to append a task to the continuation chain.
667fn append_to_chain(continuation: &mut WorkflowContinuation, new_task: WorkflowContinuation) {
668    match continuation {
669        WorkflowContinuation::Task { next, .. }
670        | WorkflowContinuation::Delay { next, .. }
671        | WorkflowContinuation::AwaitSignal { next, .. } => match next {
672            Some(next_box) => append_to_chain(next_box, new_task),
673            None => *next = Some(Box::new(new_task)),
674        },
675        WorkflowContinuation::Fork { join, .. } => match join {
676            Some(join_box) => append_to_chain(join_box, new_task),
677            None => *join = Some(Box::new(new_task)),
678        },
679    }
680}
681
/// Collector for adding branches in a closure.
///
/// Used by [`WorkflowBuilder::branches`] to collect multiple branches.
///
/// `Input` is the type every collected branch function receives; `C` is the
/// codec type used to erase each branch.
pub struct BranchCollector<C, Input> {
    // Codec handed to each branch when it is type-erased.
    codec: Arc<C>,
    // Branches added so far, in insertion order.
    branches: Vec<ErasedBranch>,
    // Carries the `Input` type, which is not stored as a value.
    _phantom: PhantomData<Input>,
}
690
691impl<C, Input> BranchCollector<C, Input> {
692    /// Add a branch to the fork.
693    ///
694    /// Each branch receives the same input and can return a different output type.
695    /// Duplicate IDs are checked at `build()` time.
696    pub fn add<F, Fut, BranchOutput>(&mut self, id: &str, func: F)
697    where
698        F: Fn(Input) -> Fut + Send + Sync + 'static,
699        Input: Send + 'static,
700        BranchOutput: Send + 'static,
701        Fut: std::future::Future<Output = Result<BranchOutput, crate::error::BoxError>>
702            + Send
703            + 'static,
704        C: Codec + sealed::DecodeValue<Input> + sealed::EncodeValue<BranchOutput>,
705    {
706        let erased = branch(id, func).erase(Arc::clone(&self.codec));
707        self.branches.push(erased);
708    }
709}
710
/// Builder for constructing fork branches fluently.
///
/// Created by calling `.fork()` on a `WorkflowBuilder`. Add branches with `.branch()`,
/// then complete with `.join()`.
///
/// The type parameters mirror those of `WorkflowBuilder`; `Output` is the type
/// each branch function receives.
pub struct ForkBuilder<C, Input, Output, M, Cont = NoContinuation, R = NoRegistry> {
    // Workflow context carried over from the originating builder.
    context: WorkflowContext<C, M>,
    // Continuation built before the fork; the fork node is appended to it on join.
    continuation: Cont,
    // Type-erased branches collected so far, in insertion order.
    branches: Vec<ErasedBranch>,
    // Task registry, or the `NoRegistry` marker.
    registry: R,
    // Carries the `Input`/`Output` types, which are not stored as values.
    _phantom: PhantomData<(Input, Output)>,
}
722
723/// For `ForkBuilder` methods - unified implementation using `RegistryBehavior` and `ContinuationState`.
724impl<C, Input, Output, M, Cont, R> ForkBuilder<C, Input, Output, M, Cont, R>
725where
726    R: RegistryBehavior,
727    Cont: ContinuationState,
728{
729    /// Add a branch to the fork.
730    ///
731    /// # Returns
732    ///
733    /// Returns a new `ForkBuilder` with the branch added.
734    ///
735    #[must_use]
736    pub fn branch<F, Fut, BranchOutput>(mut self, id: &str, func: F) -> Self
737    where
738        F: Fn(Output) -> Fut + Send + Sync + 'static,
739        Output: Send + 'static,
740        BranchOutput: Send + 'static,
741        Fut: std::future::Future<Output = Result<BranchOutput, crate::error::BoxError>>
742            + Send
743            + 'static,
744        C: Codec + sealed::DecodeValue<Output> + sealed::EncodeValue<BranchOutput> + 'static,
745    {
746        let codec = Arc::clone(&self.context.codec);
747        let func = Arc::new(func);
748
749        // Register if registry is enabled (no-op for NoRegistry)
750        self.registry
751            .maybe_register::<Output, BranchOutput, _, _, _>(id, codec.clone(), &func);
752
753        // Create branch using a closure that calls through the Arc
754        let func_clone = Arc::clone(&func);
755        let erased = branch(id, move |input| func_clone(input)).erase(codec);
756        self.branches.push(erased);
757        self
758    }
759
760    /// Join the results from all branches.
761    pub fn join<F, Fut, JoinOutput>(
762        mut self,
763        id: &str,
764        func: F,
765    ) -> WorkflowBuilder<C, Input, JoinOutput, M, WorkflowContinuation, R>
766    where
767        F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
768        JoinOutput: Send + 'static,
769        Fut: std::future::Future<Output = Result<JoinOutput, crate::error::BoxError>>
770            + Send
771            + 'static,
772        C: Codec
773            + sealed::EncodeValue<JoinOutput>
774            + sealed::DecodeValue<crate::branch_results::NamedBranchResults>
775            + Send
776            + Sync
777            + 'static,
778    {
779        let codec = Arc::clone(&self.context.codec);
780        let func = Arc::new(func);
781
782        // Register if registry is enabled (no-op for NoRegistry)
783        self.registry
784            .maybe_register_join::<JoinOutput, _, _, _>(id, codec.clone(), &func);
785
786        let join_task_fn = to_heterogeneous_join_task_arc(func, codec);
787
788        let fork_id = WorkflowContinuation::derive_fork_id(
789            &self
790                .branches
791                .iter()
792                .map(|b| b.id.as_str())
793                .collect::<Vec<_>>(),
794        );
795
796        let branch_continuations: Vec<Arc<WorkflowContinuation>> = self
797            .branches
798            .into_iter()
799            .map(|b| {
800                Arc::new(WorkflowContinuation::Task {
801                    id: b.id,
802                    func: Some(b.task),
803                    timeout: None,
804                    retry_policy: None,
805                    next: None,
806                })
807            })
808            .collect();
809
810        //
811
812        let join_task = WorkflowContinuation::Task {
813            id: id.to_string(),
814            func: Some(join_task_fn),
815            timeout: None,
816            retry_policy: None,
817            next: None,
818        };
819
820        let fork_continuation = WorkflowContinuation::Fork {
821            id: fork_id,
822            branches: branch_continuations.into_boxed_slice(),
823            join: Some(Box::new(join_task)),
824        };
825
826        let continuation = self.continuation.append(fork_continuation);
827
828        WorkflowBuilder {
829            continuation,
830            context: self.context,
831            registry: self.registry,
832            last_task_id: Some(id.to_string()),
833            _phantom: PhantomData,
834        }
835    }
836}
837
838/// For `ForkBuilder` methods for referencing pre-registered tasks (only available with `TaskRegistry`).
839impl<C, Input, Output, M, Cont> ForkBuilder<C, Input, Output, M, Cont, TaskRegistry>
840where
841    Cont: ContinuationState,
842{
843    /// Add a pre-registered branch task by ID.
844    ///
845    /// # Errors
846    ///
847    /// Returns `WorkflowError::TaskNotFound` if the task ID is not in the registry.
848    pub fn branch_registered(mut self, id: &str) -> Result<Self, WorkflowError>
849    where
850        Output: Send + 'static,
851    {
852        let task = self
853            .registry
854            .get(id)
855            .ok_or_else(|| WorkflowError::TaskNotFound(id.to_string()))?;
856
857        self.branches.push(ErasedBranch {
858            id: id.to_string(),
859            task,
860        });
861        Ok(self)
862    }
863
864    /// Join using a pre-registered join task by ID.
865    ///
866    /// # Errors
867    ///
868    /// Returns `WorkflowError::TaskNotFound` if the task ID is not in the registry.
869    pub fn join_registered<JoinOutput>(
870        self,
871        id: &str,
872    ) -> Result<
873        WorkflowBuilder<C, Input, JoinOutput, M, WorkflowContinuation, TaskRegistry>,
874        WorkflowError,
875    >
876    where
877        Output: Send + 'static,
878        JoinOutput: Send + 'static,
879    {
880        let join_task_fn = self
881            .registry
882            .get(id)
883            .ok_or_else(|| WorkflowError::TaskNotFound(id.to_string()))?;
884        let join_timeout = self.registry.get_metadata(id).and_then(|m| m.timeout);
885
886        let fork_id = WorkflowContinuation::derive_fork_id(
887            &self
888                .branches
889                .iter()
890                .map(|b| b.id.as_str())
891                .collect::<Vec<_>>(),
892        );
893
894        let branch_continuations: Vec<Arc<WorkflowContinuation>> = self
895            .branches
896            .into_iter()
897            .map(|b| {
898                let meta = self.registry.get_metadata(&b.id);
899                let timeout = meta.and_then(|m| m.timeout);
900                let retry_policy = self
901                    .registry
902                    .get_metadata(&b.id)
903                    .and_then(|m| m.retries.clone());
904                Arc::new(WorkflowContinuation::Task {
905                    id: b.id,
906                    func: Some(b.task),
907                    timeout,
908                    retry_policy,
909                    next: None,
910                })
911            })
912            .collect();
913
914        let join_retry_policy = self
915            .registry
916            .get_metadata(id)
917            .and_then(|m| m.retries.clone());
918        let join_task = WorkflowContinuation::Task {
919            id: id.to_string(),
920            func: Some(join_task_fn),
921            timeout: join_timeout,
922            retry_policy: join_retry_policy,
923            next: None,
924        };
925
926        let fork_continuation = WorkflowContinuation::Fork {
927            id: fork_id,
928            branches: branch_continuations.into_boxed_slice(),
929            join: Some(Box::new(join_task)),
930        };
931
932        let continuation = self.continuation.append(fork_continuation);
933
934        Ok(WorkflowBuilder {
935            continuation,
936            context: self.context,
937            registry: self.registry,
938            last_task_id: Some(id.to_string()),
939            _phantom: PhantomData,
940        })
941    }
942}