swiftide_agents/tasks/closures.rs

1use std::pin::Pin;
2
3use async_trait::async_trait;
4
5use super::{
6    errors::NodeError,
7    node::{NodeArg, NodeId, TaskNode},
8};
9
/// A task node backed by a synchronous closure.
///
/// `f` receives a reference to the node's input and returns either the output value or a
/// [`NodeError`]. `I` and `O` are never stored; they are tracked only at the type level via
/// `PhantomData` so the struct can name the closure's input and output types.
///
/// Trait bounds are deliberately kept off the type definition and live on the `impl` blocks
/// that actually need them (construction, conversion and [`TaskNode`] evaluation). Bounds on
/// the struct itself would add no guarantees those impls don't already enforce. They would,
/// however, force every generic mention of `SyncFn` to restate the full closure signature.
#[derive(Clone)]
pub struct SyncFn<F, I, O> {
    /// The wrapped closure; `evaluate` calls it once with the node input.
    pub f: F,
    _phantom: std::marker::PhantomData<(I, O)>,
}
18
/// A task node backed by a closure that returns a boxed, pinned future.
///
/// When the node is evaluated, `f` is called with a reference to the input. The returned
/// future is then awaited to obtain either the output value or a [`NodeError`]. `I` and `O`
/// are tracked only at the type level via `PhantomData`.
///
/// As with `SyncFn`, trait bounds (including the higher-ranked closure signature) live on
/// the `impl` blocks that need them rather than on the type definition. Repeating that
/// signature on the struct adds nothing the impls don't already enforce.
#[derive(Clone)]
pub struct AsyncFn<F, I, O> {
    /// The wrapped closure; `evaluate` calls it once and awaits the future it returns.
    pub f: F,
    _phantom: std::marker::PhantomData<(I, O)>,
}
31
32impl<F, I, O> SyncFn<F, I, O>
33where
34    F: Fn(&I) -> Result<O, NodeError> + Send + Sync + Clone + 'static,
35    I: NodeArg + Clone,
36    O: NodeArg + Clone,
37{
38    pub fn new(f: F) -> Self {
39        SyncFn {
40            f,
41            _phantom: std::marker::PhantomData,
42        }
43    }
44}
45
46impl<F, I, O> AsyncFn<F, I, O>
47where
48    F: for<'a> Fn(&'a I) -> Pin<Box<dyn Future<Output = Result<O, NodeError>> + Send + 'a>>
49        + Send
50        + Sync
51        + Clone
52        + 'static,
53    I: NodeArg + Clone,
54    O: NodeArg + Clone,
55{
56    pub fn new(f: F) -> Self {
57        AsyncFn {
58            f,
59            _phantom: std::marker::PhantomData,
60        }
61    }
62}
63
64impl<F> From<F> for SyncFn<F, (), ()>
65where
66    F: Fn(&()) -> Result<(), NodeError> + Send + Sync + Clone + 'static,
67{
68    fn from(f: F) -> Self {
69        SyncFn::new(f)
70    }
71}
72
73impl<F> From<F> for AsyncFn<F, (), ()>
74where
75    F: for<'a> Fn(&'a ()) -> Pin<Box<dyn Future<Output = Result<(), NodeError>> + Send + 'a>>
76        + Send
77        + Sync
78        + Clone
79        + 'static,
80{
81    fn from(f: F) -> Self {
82        AsyncFn::new(f)
83    }
84}
85
86#[async_trait]
87impl<F, I, O> TaskNode for SyncFn<F, I, O>
88where
89    F: Fn(&I) -> Result<O, NodeError> + Clone + Send + Sync + 'static,
90    I: NodeArg + Clone,
91    O: NodeArg + Clone,
92{
93    type Input = I;
94    type Output = O;
95    type Error = NodeError;
96
97    async fn evaluate(
98        &self,
99        _node_id: &NodeId<
100            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
101        >,
102        input: &Self::Input,
103    ) -> Result<Self::Output, Self::Error> {
104        (self.f)(input)
105    }
106}
107
108#[async_trait]
109impl<F, I, O> TaskNode for AsyncFn<F, I, O>
110where
111    F: for<'a> Fn(&'a I) -> Pin<Box<dyn Future<Output = Result<O, NodeError>> + Send + 'a>>
112        + Clone
113        + Send
114        + Sync
115        + 'static,
116    I: NodeArg + Clone,
117    O: NodeArg + Clone,
118{
119    type Input = I;
120    type Output = O;
121    type Error = NodeError;
122
123    async fn evaluate(
124        &self,
125        _node_id: &NodeId<
126            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
127        >,
128        input: &Self::Input,
129    ) -> Result<Self::Output, Self::Error> {
130        (self.f)(input).await
131    }
132}