1use crate::codec::{Codec, sealed};
2use crate::error::{BoxError, WorkflowError};
3use bytes::Bytes;
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::future::Future;
7use std::marker::PhantomData;
8use std::pin::Pin;
9use std::sync::Arc;
10use std::task::{Context, Poll};
11use std::time::Duration;
12
13#[derive(Clone, Debug, Default, Serialize, Deserialize)]
18pub struct TaskMetadata {
19 pub display_name: Option<String>,
21 pub description: Option<String>,
23 pub timeout: Option<Duration>,
25 pub retries: Option<RetryPolicy>,
27 pub tags: Vec<String>,
29}
30
31#[derive(Clone, Debug, Serialize, Deserialize)]
33pub struct RetryPolicy {
34 #[serde(alias = "max_attempts")]
36 pub max_retries: u32,
37 pub initial_delay: Duration,
39 pub backoff_multiplier: f32,
41 #[serde(default)]
43 pub max_delay: Option<Duration>,
44}
45
46pub use crate::branch_results::NamedBranchResults;
47
48pub struct BranchOutputs<C> {
64 outputs: HashMap<String, Bytes>,
65 codec: Arc<C>,
66}
67
68impl<C> BranchOutputs<C> {
69 pub fn new(outputs: HashMap<String, Bytes>, codec: Arc<C>) -> Self {
71 Self { outputs, codec }
72 }
73
74 pub fn branch_names(&self) -> impl Iterator<Item = &str> {
76 self.outputs.keys().map(std::string::String::as_str)
77 }
78
79 #[must_use]
81 pub fn contains(&self, name: &str) -> bool {
82 self.outputs.contains_key(name)
83 }
84
85 #[must_use]
87 pub fn len(&self) -> usize {
88 self.outputs.len()
89 }
90
91 #[must_use]
93 pub fn is_empty(&self) -> bool {
94 self.outputs.is_empty()
95 }
96}
97
98impl<C: Codec> BranchOutputs<C> {
99 pub fn get<T>(&self, name: &str) -> Result<T, BoxError>
105 where
106 C: sealed::DecodeValue<T>,
107 {
108 let bytes = self
109 .outputs
110 .get(name)
111 .ok_or_else(|| WorkflowError::BranchNotFound(name.to_string()))?;
112
113 self.codec.decode(bytes.clone())
114 }
115}
116
117pub trait CoreTask: Send + Sync {
162 type Input;
163 type Output;
164 type Future: Future<Output = Result<Self::Output, BoxError>> + Send;
165
166 fn run(&self, input: Self::Input) -> Self::Future;
168}
169
/// Adapts a closure or function of shape `Fn(I) -> Fut` into a [`CoreTask`].
///
/// The `PhantomData<fn(I) -> (O, Fut)>` marker only records the type
/// parameters. Using a function-pointer type keeps the marker `Send + Sync`
/// regardless of `I`, `O` and `Fut`.
pub struct FnTask<F, I, O, Fut>(F, PhantomData<fn(I) -> (O, Fut)>);
184
185impl<F, I, O, Fut> CoreTask for FnTask<F, I, O, Fut>
186where
187 F: Fn(I) -> Fut + Send + Sync,
188 I: Send,
189 O: Send,
190 Fut: Future<Output = Result<O, BoxError>> + Send,
191{
192 type Input = I;
193 type Output = O;
194 type Future = Fut;
195
196 fn run(&self, input: I) -> Self::Future {
197 (self.0)(input)
198 }
199}
200
201pub fn fn_task<F, I, O, Fut>(f: F) -> FnTask<F, I, O, Fut>
213where
214 F: Fn(I) -> Fut,
215{
216 FnTask(f, PhantomData)
217}
218
219pub struct BytesFuture(Pin<Box<dyn Future<Output = Result<Bytes, BoxError>> + Send>>);
228
229impl Future for BytesFuture {
230 type Output = Result<Bytes, BoxError>;
231
232 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
233 self.0.as_mut().poll(cx)
234 }
235}
236
237impl BytesFuture {
238 pub fn new<F>(fut: F) -> Self
240 where
241 F: Future<Output = Result<Bytes, BoxError>> + Send + 'static,
242 {
243 BytesFuture(Box::pin(fut))
244 }
245}
246
247pub type UntypedCoreTask =
252 Box<dyn CoreTask<Input = Bytes, Output = Bytes, Future = BytesFuture> + Send + Sync>;
253
/// Implements [`CoreTask`] (bytes in, bytes out) for a wrapper struct that
/// stores the wrapped function in a `func: Arc<_>` field and the codec in a
/// `codec: Arc<_>` field.
///
/// Invocation shape:
/// `impl_codec_task!(Wrapper<G1, G2, ...> where FuncTy: Fn(InputTy) -> FutTy, <extra bounds>);`
///
/// The `FuncTy: Fn(InputTy) -> FutTy` clause is re-emitted with
/// `+ Send + Sync + 'static` appended. The extra bounds are copied verbatim
/// into the generated `where` clause.
macro_rules! impl_codec_task {
    (
        $wrapper:ident < $($gen:ident),+ >
        where $func_type:ty : Fn($input:ty) -> $fut_type:ty,
        $($bound:tt)+
    ) => {
        impl< $($gen),+ > CoreTask for $wrapper < $($gen),+ >
        where
            $func_type : Fn($input) -> $fut_type + Send + Sync + 'static,
            $($bound)+
        {
            type Input = Bytes;
            type Output = Bytes;
            type Future = BytesFuture;

            fn run(&self, input: Bytes) -> Self::Future {
                // `BytesFuture::new` requires a `'static` future, so clone the
                // Arcs and move owned handles into it instead of borrowing `self`.
                let func = Arc::clone(&self.func);
                let codec = Arc::clone(&self.codec);
                BytesFuture::new(async move {
                    // decode -> call -> encode; any error short-circuits via `?`.
                    let decoded_input = codec.decode::<$input>(input)?;
                    let output = func(decoded_input).await?;
                    codec.encode(&output)
                })
            }
        }
    };
}
285
/// Type-erasing adapter used by [`to_core_task_arc`].
///
/// Holds a typed async function and the codec that translates its input and
/// output to and from [`Bytes`]. Its `CoreTask` impl is generated by
/// `impl_codec_task!`.
struct UntypedTaskFnWrapper<F, I, O, Fut, C> {
    /// The wrapped function, shared so each `run` can move a handle into its future.
    func: Arc<F>,
    /// Codec used to decode `I` and encode `O`.
    codec: Arc<C>,
    /// Records `I`, `O` and `Fut` without owning them. The function-pointer
    /// form keeps the marker `Send + Sync`.
    _phantom: PhantomData<fn(I) -> (O, Fut)>,
}
292
// Bytes-in/bytes-out `CoreTask` for a plain async function: decode the input
// as `I` with `C`, await `F`, then encode the resulting `O` with `C`.
impl_codec_task!(
    UntypedTaskFnWrapper<F, I, O, Fut, C>
    where F: Fn(I) -> Fut,
    I: Send + 'static,
    O: Send + 'static,
    Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
    C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
);
301
302pub fn to_core_task<F, I, O, Fut, C>(func: F, codec: Arc<C>) -> UntypedCoreTask
308where
309 F: Fn(I) -> Fut + Send + Sync + 'static,
310 I: Send + 'static,
311 O: Send + 'static,
312 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
313 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
314{
315 to_core_task_arc(Arc::new(func), codec)
316}
317
318pub fn to_core_task_arc<F, I, O, Fut, C>(func: Arc<F>, codec: Arc<C>) -> UntypedCoreTask
323where
324 F: Fn(I) -> Fut + Send + Sync + 'static,
325 I: Send + 'static,
326 O: Send + 'static,
327 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
328 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
329{
330 Box::new(UntypedTaskFnWrapper {
331 func,
332 codec,
333 _phantom: std::marker::PhantomData,
334 })
335}
336
337type BoxedBranchFn<I, O> = Box<
339 dyn Fn(I) -> std::pin::Pin<Box<dyn Future<Output = Result<O, BoxError>> + Send>> + Send + Sync,
340>;
341
342pub(crate) struct Branch<I, O> {
344 id: String,
345 func: BoxedBranchFn<I, O>,
346}
347
348pub(crate) fn branch<F, Fut, I, O>(id: &str, f: F) -> Branch<I, O>
350where
351 F: Fn(I) -> Fut + Send + Sync + 'static,
352 Fut: Future<Output = Result<O, BoxError>> + Send + 'static,
353 I: 'static,
354 O: 'static,
355{
356 Branch {
357 id: id.to_string(),
358 func: Box::new(move |i| Box::pin(f(i))),
359 }
360}
361
362pub(crate) struct ErasedBranch {
364 pub(crate) id: String,
365 pub(crate) task: UntypedCoreTask,
366}
367
368impl<I, O> Branch<I, O> {
369 pub fn erase<C>(self, codec: Arc<C>) -> ErasedBranch
373 where
374 I: Send + 'static,
375 O: Send + 'static,
376 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
377 {
378 ErasedBranch {
379 id: self.id.clone(),
380 task: branch_to_core_task(self, codec),
381 }
382 }
383}
384
385#[allow(clippy::items_after_statements)]
387pub(crate) fn branch_to_core_task<I, O, C>(branch: Branch<I, O>, codec: Arc<C>) -> UntypedCoreTask
388where
389 I: Send + 'static,
390 O: Send + 'static,
391 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
392{
393 let func = Arc::new(branch.func);
395
396 struct ArcBranchWrapper<I, O, C> {
397 func: Arc<BoxedBranchFn<I, O>>,
398 codec: Arc<C>,
399 _phantom: PhantomData<fn(I) -> O>,
400 }
401
402 impl_codec_task!(
403 ArcBranchWrapper<I, O, C>
404 where BoxedBranchFn<I, O>: Fn(I) -> std::pin::Pin<Box<dyn Future<Output = Result<O, BoxError>> + Send>>,
405 I: Send + 'static,
406 O: Send + 'static,
407 C: Codec + sealed::DecodeValue<I> + sealed::EncodeValue<O>,
408 );
409
410 Box::new(ArcBranchWrapper {
411 func,
412 codec,
413 _phantom: PhantomData,
414 })
415}
416
417#[allow(clippy::type_complexity)]
422struct HeterogeneousJoinTaskWrapper<F, JoinOutput, Fut, C> {
423 func: Arc<F>,
424 codec: Arc<C>,
425 _phantom: PhantomData<fn(BranchOutputs<C>) -> (JoinOutput, Fut)>,
426}
427
428impl<F, JoinOutput, Fut, C> CoreTask for HeterogeneousJoinTaskWrapper<F, JoinOutput, Fut, C>
429where
430 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
431 JoinOutput: Send + 'static,
432 Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
433 C: Codec
434 + sealed::EncodeValue<JoinOutput>
435 + sealed::DecodeValue<NamedBranchResults>
436 + Send
437 + Sync
438 + 'static,
439{
440 type Input = Bytes;
441 type Output = Bytes;
442 type Future = BytesFuture;
443
444 fn run(&self, input: Bytes) -> Self::Future {
445 let func = Arc::clone(&self.func);
446 let codec = Arc::clone(&self.codec);
447 BytesFuture::new(async move {
448 let named_results: NamedBranchResults = codec.decode(input)?;
449 let branch_outputs = BranchOutputs::new(named_results.into_map(), codec.clone());
450
451 let output = func(branch_outputs).await?;
452 codec.encode(&output)
453 })
454 }
455}
456
457pub fn to_heterogeneous_join_task<F, JoinOutput, Fut, C>(func: F, codec: Arc<C>) -> UntypedCoreTask
472where
473 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
474 JoinOutput: Send + 'static,
475 Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
476 C: Codec
477 + sealed::EncodeValue<JoinOutput>
478 + sealed::DecodeValue<NamedBranchResults>
479 + Send
480 + Sync
481 + 'static,
482{
483 to_heterogeneous_join_task_arc(Arc::new(func), codec)
484}
485
486pub fn to_heterogeneous_join_task_arc<F, JoinOutput, Fut, C>(
491 func: Arc<F>,
492 codec: Arc<C>,
493) -> UntypedCoreTask
494where
495 F: Fn(BranchOutputs<C>) -> Fut + Send + Sync + 'static,
496 JoinOutput: Send + 'static,
497 Fut: Future<Output = Result<JoinOutput, BoxError>> + Send + 'static,
498 C: Codec
499 + sealed::EncodeValue<JoinOutput>
500 + sealed::DecodeValue<NamedBranchResults>
501 + Send
502 + Sync
503 + 'static,
504{
505 Box::new(HeterogeneousJoinTaskWrapper {
506 func,
507 codec,
508 _phantom: PhantomData,
509 })
510}