Skip to main content

sayiir_core/
task.rs

1use crate::codec::{Codec, sealed};
2use crate::error::{BoxError, WorkflowError};
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::future::Future;
7use std::marker::PhantomData;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::time::Duration;
12
13/// Metadata associated with a task definition.
14///
15/// This provides optional configuration for task execution behavior,
16/// including display information, timeouts, and retry policies.
17#[derive(Clone, Debug, Default, Serialize, Deserialize)]
18pub struct TaskMetadata {
19    /// Human-readable name for the task (for UI/logging).
20    pub display_name: Option<String>,
21    /// Description of what the task does.
22    pub description: Option<String>,
23    /// Maximum time the task is allowed to run.
24    pub timeout: Option<Duration>,
25    /// Retry policy for failed task executions.
26    pub retries: Option<RetryPolicy>,
27    /// Tags for categorization and filtering.
28    pub tags: Vec<String>,
29}
30
31/// Configuration for retrying failed task executions.
32#[derive(Clone, Debug, Serialize, Deserialize)]
33pub struct RetryPolicy {
34    /// Maximum number of retries after the initial attempt.
35    #[serde(alias = "max_attempts")]
36    pub max_retries: u32,
37    /// Initial delay before the first retry.
38    pub initial_delay: Duration,
39    /// Multiplier applied to delay after each retry (for exponential backoff).
40    pub backoff_multiplier: f32,
41    /// Maximum delay between retries (caps exponential growth).
42    #[serde(default)]
43    pub max_delay: Option<Duration>,
44}
45
46pub use crate::branch_results::NamedBranchResults;
47
48/// A type-safe map of branch outputs for heterogeneous fork-join.
49///
50/// Each branch can return a different type. Use `get::<T>(name)` to retrieve
51/// a branch's output with the correct type.
52///
53/// # Example
54///
55/// ```rust,ignore
56/// .join("combine", |outputs: BranchOutputs<MyCodec>| async move {
57///     let count: u32 = outputs.get("counter")?;
58///     let name: String = outputs.get("fetch_name")?;
59///     let items: Vec<Item> = outputs.get("load_items")?;
60///     Ok(format!("{}: {} items for {}", count, items.len(), name))
61/// })
62/// ```
63pub struct BranchOutputs<C> {
64    outputs: HashMap<String, Bytes>,
65    codec: Arc<C>,
66}
67
68impl<C> BranchOutputs<C> {
69    /// Create a new `BranchOutputs` from raw data.
70    pub fn new(outputs: HashMap<String, Bytes>, codec: Arc<C>) -> Self {
71        Self { outputs, codec }
72    }
73
74    /// Get the names of all branches.
75    pub fn branch_names(&self) -> impl Iterator<Item = &str> {
76        self.outputs.keys().map(std::string::String::as_str)
77    }
78
79    /// Check if a branch exists.
80    #[must_use]
81    pub fn contains(&self, name: &str) -> bool {
82        self.outputs.contains_key(name)
83    }
84
85    /// Get the number of branches.
86    #[must_use]
87    pub fn len(&self) -> usize {
88        self.outputs.len()
89    }
90
91    /// Check if there are no branches.
92    #[must_use]
93    pub fn is_empty(&self) -> bool {
94        self.outputs.is_empty()
95    }
96}
97
98impl<C: Codec> BranchOutputs<C> {
99    /// Get a branch output by name, deserializing to the requested type.
100    ///
101    /// # Errors
102    ///
103    /// Returns an error if the branch doesn't exist or deserialization fails.
104    pub fn get<T>(&self, name: &str) -> Result<T, BoxError>
105    where
106        C: sealed::DecodeValue<T>,
107    {
108        let bytes = self
109            .outputs
110            .get(name)
111            .ok_or_else(|| WorkflowError::BranchNotFound(name.to_string()))?;
112
113        self.codec.decode(bytes.clone())
114    }
115}
116
117/// A core task is a task that can be run by the workflow runtime.
118///
119/// Tasks can be defined either as closures (via `WorkflowBuilder::then`) or as
120/// structs implementing this trait directly. Struct-based tasks are useful for:
121/// - Reusable task logic across workflows
122/// - Tasks with configuration/state
123/// - Serializable workflows (tasks can be registered by ID)
124///
125/// # Example
126///
127/// ```rust
128/// use sayiir_core::prelude::*;
129/// use std::pin::Pin;
130/// use std::future::Future;
131///
132/// /// A task that doubles its input.
133/// struct DoubleTask;
134///
135/// impl CoreTask for DoubleTask {
136///     type Input = u32;
137///     type Output = u32;
138///     type Future = Pin<Box<dyn Future<Output = Result<u32, BoxError>> + Send>>;
139///
140///     fn run(&self, input: u32) -> Self::Future {
141///         Box::pin(async move { Ok(input * 2) })
142///     }
143/// }
144///
145/// /// A configurable task with state.
146/// struct MultiplyTask {
147///     factor: u32,
148/// }
149///
150/// impl CoreTask for MultiplyTask {
151///     type Input = u32;
152///     type Output = u32;
153///     type Future = Pin<Box<dyn Future<Output = Result<u32, BoxError>> + Send>>;
154///
155///     fn run(&self, input: u32) -> Self::Future {
156///         let factor = self.factor;
157///         Box::pin(async move { Ok(input * factor) })
158///     }
159/// }
160/// ```
161pub trait CoreTask: Send + Sync {
162    type Input;
163    type Output;
164    type Future: Future<Output = Result<Self::Output, BoxError>> + Send;
165
166    /// Run the task with the given input and return the output.
167    fn run(&self, input: Self::Input) -> Self::Future;
168}
169
/// Adapter that lets a closure be used wherever a `CoreTask` is expected.
///
/// Field `0` is the wrapped closure. The `PhantomData` marker only records the
/// input, output and future types; no value of those types is stored.
/// Prefer the [`fn_task`] helper over naming the type parameters by hand.
///
/// # Example
///
/// ```rust,ignore
/// use sayiir_core::task::fn_task;
///
/// // Both work with the same `register` method:
/// registry.register("closure", codec.clone(), fn_task(|i: u32| async move { Ok(i * 2) }));
/// registry.register("struct", codec.clone(), MyTask::new());
/// ```
pub struct FnTask<F, I, O, Fut>(F, PhantomData<fn(I) -> (O, Fut)>);
184
185impl<F, I, O, Fut> CoreTask for FnTask<F, I, O, Fut>
186where
187    F: Fn(I) -> Fut + Send + Sync,
188    I: Send,
189    O: Send,
190    Fut: Future<Output = Result<O, BoxError>> + Send,
191{
192    type Input = I;
193    type Output = O;
194    type Future = Fut;
195
196    fn run(&self, input: I) -> Self::Future {
197        (self.0)(input)
198    }
199}
200
201/// Create a `FnTask` from a closure with inferred types.
202///
203/// This is the preferred way to wrap closures for use with the unified `register` API.
204///
205/// # Example
206///
207/// ```rust,ignore
208/// use sayiir_core::task::fn_task;
209///
210/// registry.register("double", codec, fn_task(|i: u32| async move { Ok(i * 2) }));
211/// ```
212pub fn fn_task<F, I, O, Fut>(f: F) -> FnTask<F, I, O, Fut>
213where
214    F: Fn(I) -> Fut,
215{
216    FnTask(f, PhantomData)
217}
218
219/// A type-erased future that outputs `Result<Bytes>`.
220///
221/// This is a newtype around a pinned boxed future, providing a concrete type
222/// for the `Future` associated type in `UntypedCoreTask`. While it still uses
223/// boxing internally (necessary for type erasure), the named type provides:
224/// - Better error messages and stack traces
225/// - A concrete type instead of `dyn Future`
226/// - Clearer API boundaries
227pub struct BytesFuture(Pin<Box<dyn Future<Output = Result<Bytes, BoxError>> + Send>>);
228
229impl Future for BytesFuture {
230    type Output = Result<Bytes, BoxError>;
231
232    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233        self.0.as_mut().poll(cx)
234    }
235}
236
237impl BytesFuture {
238    /// Create a new `BytesFuture` from any future that outputs `Result<Bytes>`.
239    pub fn new<F>(fut: F) -> Self
240    where
241        F: Future<Output = Result<Bytes, BoxError>> + Send + 'static,
242    {
243        BytesFuture(Box::pin(fut))
244    }
245}
246
247/// A boxed core task that can be used to run a task without knowing the input and output types.
248///
249/// Uses `BytesFuture` as the concrete future type, which internally boxes the future.
250/// This boxing is necessary for type erasure when storing heterogeneous tasks.
251pub type UntypedCoreTask =
252    Box<dyn CoreTask<Input = Bytes, Output = Bytes, Future = BytesFuture> + Send + Sync>;
253
/// Generate a bytes-in/bytes-out `CoreTask` impl for a wrapper struct that has
/// `func: Arc<_>` and `codec: Arc<_>` fields.
///
/// The generated `run` clones both `Arc`s into a `BytesFuture` which:
/// 1. decodes the incoming bytes into the function's input type,
/// 2. calls the function and awaits its result,
/// 3. encodes the output back into `Bytes`.
///
/// An error at any step ends the future with that error.
macro_rules! impl_codec_task {
    (
        $wrapper:ident < $($gen:ident),+ >
        where $func_type:ty : Fn($input:ty) -> $fut_type:ty,
              $($bound:tt)+
    ) => {
        impl< $($gen),+ > CoreTask for $wrapper < $($gen),+ >
        where
            $func_type : Fn($input) -> $fut_type + Send + Sync + 'static,
            $($bound)+
        {
            type Input = Bytes;
            type Output = Bytes;
            type Future = BytesFuture;

            fn run(&self, input: Bytes) -> Self::Future {
                // The future must be `'static`, so it owns its own handles.
                let codec = Arc::clone(&self.codec);
                let func = Arc::clone(&self.func);
                BytesFuture::new(async move {
                    let typed_input = codec.decode::<$input>(input)?;
                    let typed_output = func(typed_input).await?;
                    codec.encode(&typed_output)
                })
            }
        }
    };
}
285
286/// Internal wrapper that implements `CoreTask<Input = Bytes, Output = Bytes>` for async functions.
287struct UntypedTaskFnWrapper<F, I, O, Fut, C> {
288    func: Arc<F>,
289    codec: Arc<C>,
290    _phantom: std::marker::PhantomData<fn(I) -> (O, Fut)>,
291}
292
293impl_codec_task!(
294    UntypedTaskFnWrapper<F, I, O, Fut, C>
295    where F: Fn(I) -> Fut,
296          I: Send + 'static,
297          O: Send + 'static,
298          Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
299          C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
300);
301
302/// Create a new untyped task from any function using a codec.
303///
304/// The function must be Send + Sync + 'static and return a Future that resolves to a Result.
305/// Both input and output types must be Send for type erasure to work.
306/// The codec must be able to decode the input type and encode the output type.
307pub fn to_core_task<F, I, O, Fut, C>(func: F, codec: Arc<C>) -> UntypedCoreTask
308where
309    F: Fn(I) -> Fut + Send + Sync + 'static,
310    I: Send + 'static,
311    O: Send + 'static,
312    Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
313    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
314{
315    to_core_task_arc(Arc::new(func), codec)
316}
317
318/// Create a new untyped task from an Arc-wrapped function.
319///
320/// This variant accepts an already-Arc'd function, avoiding the need
321/// for the function to implement Clone.
322pub fn to_core_task_arc<F, I, O, Fut, C>(func: Arc<F>, codec: Arc<C>) -> UntypedCoreTask
323where
324    F: Fn(I) -> Fut + Send + Sync + 'static,
325    I: Send + 'static,
326    O: Send + 'static,
327    Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
328    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
329{
330    Box::new(UntypedTaskFnWrapper {
331        func,
332        codec,
333        _phantom: std::marker::PhantomData,
334    })
335}
336
337/// A boxed async function for use in fork branches (internal).
338type BoxedBranchFn<I, O> = Box<
339    dyn Fn(I) -> std::pin::Pin<Box<dyn Future<Output = Result<O, BoxError>> + Send>> + Send + Sync,
340>;
341
342/// A branch for use with `fork()` (internal).
343pub(crate) struct Branch<I, O> {
344    id: String,
345    func: BoxedBranchFn<I, O>,
346}
347
348/// Create a branch (internal helper used by `ForkBuilder`).
349pub(crate) fn branch<F, Fut, I, O>(id: &str, f: F) -> Branch<I, O>
350where
351    F: Fn(I) -> Fut + Send + Sync + 'static,
352    Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
353    I: 'static,
354    O: 'static,
355{
356    Branch {
357        id: id.to_string(),
358        func: Box::new(move |i| Box::pin(f(i))),
359    }
360}
361
362/// A type-erased branch for heterogeneous fork operations (internal).
363pub(crate) struct ErasedBranch {
364    pub(crate) id: String,
365    pub(crate) task: UntypedCoreTask,
366}
367
368impl<I, O> Branch<I, O> {
369    /// Convert this branch to a type-erased branch.
370    ///
371    /// This is used internally by `fork()` to allow heterogeneous output types.
372    pub fn erase<C>(self, codec: Arc<C>) -> ErasedBranch
373    where
374        I: Send + 'static,
375        O: Send + 'static,
376        C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
377    {
378        ErasedBranch {
379            id: self.id.clone(),
380            task: branch_to_core_task(self, codec),
381        }
382    }
383}
384
385/// Convert a Branch to an `UntypedCoreTask` (internal).
386#[allow(clippy::items_after_statements)]
387pub(crate) fn branch_to_core_task<I, O, C>(branch: Branch<I, O>, codec: Arc<C>) -> UntypedCoreTask
388where
389    I: Send + 'static,
390    O: Send + 'static,
391    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
392{
393    // Wrap the boxed function in Arc so it can be cloned into the future
394    let func = Arc::new(branch.func);
395
396    struct ArcBranchWrapper<I, O, C> {
397        func: Arc<BoxedBranchFn<I, O>>,
398        codec: Arc<C>,
399        _phantom: PhantomData<fn(I) -> O>,
400    }
401
402    impl_codec_task!(
403        ArcBranchWrapper<I, O, C>
404        where BoxedBranchFn<I, O>: Fn(I) -> std::pin::Pin<Box<dyn Future<Output = Result<O, BoxError>> + Send>>,
405              I: Send + 'static,
406              O: Send + 'static,
407              C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
408    );
409
410    Box::new(ArcBranchWrapper {
411        func,
412        codec,
413        _phantom: PhantomData,
414    })
415}
416
417/// Join task wrapper for heterogeneous branch outputs.
418///
419/// This wrapper receives serialized named branch results and passes a
420/// `BranchOutputs` map to the user function for type-safe access.
421#[allow(clippy::type_complexity)]
422struct HeterogeneousJoinTaskWrapper<F, JoinOutput, Fut, C> {
423    func: Arc<F>,
424    codec: Arc<C>,
425    _phantom: PhantomData<fn(BranchOutputs<C>) -> (JoinOutput, Fut)>,
426}
427
428impl<F, JoinOutput, Fut, C> CoreTask for HeterogeneousJoinTaskWrapper<F, JoinOutput, Fut, C>
429where
430    F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
431    JoinOutput: Send + 'static,
432    Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
433    C: Codec
434        + sealed::EncodeValue<JoinOutput>
435        + sealed::DecodeValue<NamedBranchResults>
436        + Send
437        + Sync
438        + 'static,
439{
440    type Input = Bytes;
441    type Output = Bytes;
442    type Future = BytesFuture;
443
444    fn run(&self, input: Bytes) -> Self::Future {
445        let func = Arc::clone(&self.func);
446        let codec = Arc::clone(&self.codec);
447        BytesFuture::new(async move {
448            let named_results: NamedBranchResults = codec.decode(input)?;
449            let branch_outputs = BranchOutputs::new(named_results.into_map(), codec.clone());
450
451            let output = func(branch_outputs).await?;
452            codec.encode(&output)
453        })
454    }
455}
456
457/// Create a join task for heterogeneous branch outputs.
458///
459/// The join function receives `BranchOutputs<C>` which allows type-safe
460/// retrieval of each branch's output by name.
461///
462/// # Example
463///
464/// ```rust,ignore
465/// .join("combine", |outputs: BranchOutputs<MyCodec>| async move {
466///     let count: u32 = outputs.get("counter")?;
467///     let name: String = outputs.get("fetch_name")?;
468///     Ok(format!("{} - {}", name, count))
469/// })
470/// ```
471pub fn to_heterogeneous_join_task<F, JoinOutput, Fut, C>(func: F, codec: Arc<C>) -> UntypedCoreTask
472where
473    F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
474    JoinOutput: Send + 'static,
475    Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
476    C: Codec
477        + sealed::EncodeValue<JoinOutput>
478        + sealed::DecodeValue<NamedBranchResults>
479        + Send
480        + Sync
481        + 'static,
482{
483    to_heterogeneous_join_task_arc(Arc::new(func), codec)
484}
485
486/// Create a join task from an Arc-wrapped function.
487///
488/// This variant accepts an already-Arc'd function, avoiding the need
489/// for the function to implement Clone.
490pub fn to_heterogeneous_join_task_arc<F, JoinOutput, Fut, C>(
491    func: Arc<F>,
492    codec: Arc<C>,
493) -> UntypedCoreTask
494where
495    F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
496    JoinOutput: Send + 'static,
497    Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
498    C: Codec
499        + sealed::EncodeValue<JoinOutput>
500        + sealed::DecodeValue<NamedBranchResults>
501        + Send
502        + Sync
503        + 'static,
504{
505    Box::new(HeterogeneousJoinTaskWrapper {
506        func,
507        codec,
508        _phantom: PhantomData,
509    })
510}