swiftide_agents/tasks/
task.rs

//! Tasks enable you to define a graph of interacting nodes.
2//!
3//! The nodes can be any type that implements the `TaskNode` trait, which defines how the node
4//! will be evaluated with its input and output.
5//!
6//! Most swiftide primitives implement `TaskNode`, and it's easy to implement your own. Since how
7//! agents interact is subject to taste, we recommend implementing your own.
8//!
//! WARN: Here be dragons! This API is not stable yet. We are using it in production, but it is
//! subject to rapid change. However, do not hesitate to open an issue if you find anything.
11use std::{any::Any, pin::Pin, sync::Arc};
12
13use crate::tasks::{errors::NodeError, transition::TransitionFn};
14
15use super::{
16    errors::TaskError,
17    node::{NodeArg, NodeId, NoopNode, TaskNode},
18    transition::{AnyNodeTransition, MarkedTransitionPayload, Transition, TransitionPayload},
19};
20
21#[derive(Debug)]
22pub struct Task<Input: NodeArg, Output: NodeArg> {
23    nodes: Vec<Box<dyn AnyNodeTransition>>,
24    current_node: usize,
25    start_node: usize,
26    current_context: Option<Arc<dyn Any + Send + Sync>>,
27    _marker: std::marker::PhantomData<(Input, Output)>,
28}
29
30impl<Input: NodeArg, Output: NodeArg> Clone for Task<Input, Output> {
31    fn clone(&self) -> Self {
32        Self {
33            nodes: self.nodes.clone(),
34            current_node: 0,
35            start_node: self.start_node,
36            current_context: None,
37            _marker: std::marker::PhantomData,
38        }
39    }
40}
41
42impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Default for Task<Input, Output> {
43    fn default() -> Self {
44        Self::new()
45    }
46}
47
48impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Task<Input, Output> {
49    pub fn new() -> Self {
50        let noop = NoopNode::<Output>::default();
51
52        let node_id = NodeId::new(0, &noop).as_dyn();
53
54        let noop_executor = Box::new(Transition {
55            node: Box::new(noop),
56            node_id: Box::new(node_id),
57            r#fn: Arc::new(|_output| {
58                Box::pin(async { unreachable!("Done node should never be evaluated.") })
59            }),
60            is_set: false,
61        });
62        Self {
63            nodes: vec![noop_executor],
64            current_node: 0,
65            start_node: 0,
66            current_context: None,
67            _marker: std::marker::PhantomData,
68        }
69    }
70
71    pub fn done(&self) -> NodeId<NoopNode<Output>> {
72        NodeId::new(0, &NoopNode::default())
73    }
74
75    /// Creates a transition to the done node
76    pub fn transitions_to_done(
77        &self,
78    ) -> impl Fn(Output) -> MarkedTransitionPayload<NoopNode<Output>> + Send + Sync + 'static {
79        let done = self.done();
80        move |context| done.transitions_with(context)
81    }
82
83    /// Defines the start node of the task
84    pub fn starts_with<T: TaskNode<Input = Input> + Clone + 'static>(
85        &mut self,
86        node_id: NodeId<T>,
87    ) {
88        self.current_node = node_id.id;
89        self.start_node = node_id.id;
90    }
91
92    /// Validates that all nodes have transitions set
93    ///
94    /// # Errors
95    ///
96    /// Errors if a node is missing a transition
97    pub fn validate_transitions(&self) -> Result<(), TaskError> {
98        // TODO: Validate that the task can complete
99        for node_executor in &self.nodes {
100            // Skip the done node (index 0)
101            if node_executor.node_id() == 0 {
102                continue;
103            }
104
105            if !node_executor.transition_is_set() {
106                return Err(TaskError::missing_transition(node_executor.node_id()));
107            }
108        }
109
110        Ok(())
111    }
112
113    /// Runs the task with the given input
114    ///
115    /// # Errors
116    ///
117    /// Errors if the task fails
118    pub async fn run(&mut self, input: impl Into<Input>) -> Result<Option<Output>, TaskError> {
119        self.validate_transitions()?;
120
121        self.current_context = Some(Arc::new(input.into()) as Arc<dyn Any + Send + Sync>);
122
123        self.resume().await
124    }
125
126    /// Resets the task to the start node
127    ///
128    /// WARN: This **will** lead to a type mismatch if the previous context is not the same as the
129    /// input of the start node
130    pub fn reset(&mut self) {
131        self.current_node = self.start_node;
132    }
133
134    /// Resumes the task from the current node
135    ///
136    /// # Errors
137    ///
138    /// Errors if the task fails
139    pub async fn resume(&mut self) -> Result<Option<Output>, TaskError> {
140        self.validate_transitions()?;
141
142        loop {
143            if self.current_node == 0 {
144                break;
145            }
146            let node_transition = self
147                .nodes
148                .get(self.current_node)
149                .ok_or_else(|| TaskError::missing_node(self.current_node))?;
150
151            let input = self
152                .current_context
153                .clone()
154                .ok_or_else(|| TaskError::missing_input(self.current_node))?;
155
156            tracing::debug!("Running node {}", self.current_node);
157            let transition_payload = node_transition.evaluate_next(input).await?;
158
159            match transition_payload {
160                TransitionPayload::Pause => {
161                    tracing::info!("Task paused at node {}", self.current_node);
162                    return Ok(None);
163                }
164                TransitionPayload::NextNode(transition_payload) => {
165                    self.current_node = transition_payload.node_id;
166                    self.current_context = Some(transition_payload.context);
167                }
168                TransitionPayload::Error(error) => {
169                    return Err(TaskError::NodeError(NodeError::new(
170                        error,
171                        self.current_node,
172                        None,
173                    )));
174                }
175            }
176        }
177
178        let output = self
179            .current_context
180            .clone()
181            .ok_or_else(|| TaskError::missing_output(self.current_node))?;
182        let output = output
183            .downcast::<Output>()
184            .map_err(|e| TaskError::type_error(&e))?
185            .as_ref()
186            .clone();
187
188        Ok(Some(output))
189    }
190
191    /// Gets the current node of the task
192    pub fn current_node<T: TaskNode + 'static>(&self) -> Option<&T> {
193        self.node_at_index(self.current_node)
194    }
195
196    /// Gets the node at the given `NodeId`
197    pub fn node_at<T: TaskNode + 'static>(&self, node_id: NodeId<T>) -> Option<&T> {
198        self.node_at_index(node_id.id)
199    }
200
201    /// Gets the node at the given index
202    pub fn node_at_index<T: TaskNode + 'static>(&self, index: usize) -> Option<&T> {
203        let transition = self.transition_at_index::<T>(index)?;
204
205        let node = &*transition.node;
206
207        (node as &dyn Any).downcast_ref::<T>()
208    }
209
210    /// Gets the current transition of the task
211    #[allow(dead_code)]
212    fn current_transition<T: TaskNode + 'static>(
213        &self,
214    ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
215        self.transition_at_index::<T>(self.current_node)
216    }
217
218    /// Gets the transition at the given `NodeId`
219    fn transition_at_index<T: TaskNode + 'static>(
220        &self,
221        index: usize,
222    ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
223        tracing::debug!("Getting transition at index {}", index);
224        let transition = self.nodes.get(index)?;
225
226        dbg!(&transition);
227
228        (&**transition as &dyn Any).downcast_ref::<Transition<T::Input, T::Output, T::Error>>()
229    }
230
231    /// Registers a new node in the task
232    pub fn register_node<T>(&mut self, node: T) -> NodeId<T>
233    where
234        T: TaskNode + 'static + Clone,
235        <T as TaskNode>::Input: Clone,
236        <T as TaskNode>::Output: Clone,
237    {
238        let id = self.nodes.len();
239        let node_id = NodeId::new(id, &node);
240        let node_executor = Box::new(Transition::<T::Input, T::Output, T::Error> {
241            node_id: Box::new(node_id.as_dyn()),
242            node: Box::new(node),
243            r#fn: Arc::new(move |_output| unreachable!("No transition for node {}.", node_id.id)),
244            is_set: false,
245        });
246        // Debug the type name
247        tracing::debug!(node_id = ?node_id, type_name = std::any::type_name_of_val(&node_executor), "Registering node");
248
249        self.nodes.push(node_executor);
250
251        node_id
252    }
253
254    /// Registers a transition from one node to another
255    ///
256    /// Note that there are various helpers and conversions for the `MarkedTransitionPayload`
257    ///
258    /// # Errors
259    ///
260    /// Errors if the node does not exist
261    pub fn register_transition<'a, From, To, F>(
262        &mut self,
263        from: NodeId<From>,
264        transition: F,
265    ) -> Result<(), TaskError>
266    where
267        From: TaskNode + 'static + ?Sized,
268        To: TaskNode<Input = From::Output> + 'a + ?Sized,
269        F: Fn(To::Input) -> MarkedTransitionPayload<To> + Send + Sync + 'static,
270    {
271        let node_executor = self
272            .nodes
273            .get_mut(from.id)
274            .ok_or_else(|| TaskError::missing_node(from.id))?;
275
276        let any_executor: &mut dyn Any = node_executor.as_mut();
277
278        let Some(exec) =
279            any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
280        else {
281            let expected =
282                std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
283            let actual = std::any::type_name_of_val(node_executor);
284
285            unreachable!(
286                "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
287                from.id
288            );
289        };
290        let transition = Arc::new(transition);
291        let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
292            let transition = transition.clone();
293            Box::pin(async move {
294                let output = transition(output);
295                output.into_inner()
296            })
297        });
298
299        exec.r#fn = wrapped;
300        exec.is_set = true;
301        // set function as before
302
303        Ok(())
304    }
305
306    /// Registers a transition from one node to another asynchronously
307    ///
308    /// Note that there are various helpers and conversions for the `MarkedTransitionPayload`
309    ///
310    /// # Errors
311    ///
312    /// Errors if the node does not exist
313    ///
314    /// NOTE: `AsyncFn` traits' returned future are not 'Send' and the inner type is unstable.
315    /// When they are, we can update Fn to `AsyncFn`
316    pub fn register_transition_async<'a, From, To, F>(
317        &mut self,
318        from: NodeId<From>,
319        transition: F,
320    ) -> Result<(), TaskError>
321    where
322        From: TaskNode + 'static + ?Sized,
323        To: TaskNode<Input = From::Output> + 'a + ?Sized,
324        F: Fn(To::Input) -> Pin<Box<dyn Future<Output = MarkedTransitionPayload<To>> + Send>>
325            + Send
326            + Sync
327            + 'static,
328    {
329        let node_executor = self
330            .nodes
331            .get_mut(from.id)
332            .ok_or_else(|| TaskError::missing_node(from.id))?;
333
334        let any_executor: &mut dyn Any = node_executor.as_mut();
335
336        let Some(exec) =
337            any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
338        else {
339            let expected =
340                std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
341            let actual = std::any::type_name_of_val(node_executor);
342
343            unreachable!(
344                "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
345                from.id
346            );
347        };
348        let transition = Arc::new(transition);
349        let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
350            let transition = transition.clone();
351
352            Box::pin(async move {
353                let output = transition(output).await;
354                output.into_inner()
355            })
356        });
357
358        exec.r#fn = wrapped;
359        exec.is_set = true;
360        // set function as before
361
362        Ok(())
363    }
364}
365
#[cfg(test)]
mod tests {
    use async_trait::async_trait;

    use super::*;

    /// Minimal error type for the test node; `Display` is implemented by hand below.
    #[derive(thiserror::Error, Debug)]
    struct Error(String);

    impl std::fmt::Display for Error {
        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            write!(f, "{}", self.0)
        }
    }

    /// Test node that adds one to its input.
    #[derive(Clone, Default, Debug)]
    struct IntNode;

    #[async_trait]
    impl TaskNode for IntNode {
        type Input = i32;
        type Output = i32;
        type Error = Error;

        async fn evaluate(
            &self,
            _node_id: &NodeId<
                dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
            >,
            input: &Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input + 1)
        }
    }

    #[test_log::test(tokio::test)]
    async fn sequential_3_node_task_reset_works() {
        let mut task: Task<i32, i32> = Task::new();

        // Register three nodes
        let node1 = task.register_node(IntNode);
        let node2 = task.register_node(IntNode);
        let node3 = task.register_node(IntNode);

        // Set start node
        task.starts_with(node1);

        // Register transitions (node1 → node2 → node3 → done)
        task.register_transition::<_, _, _>(node1, move |input| node2.transitions_with(input))
            .unwrap();
        task.register_transition::<_, _, _>(node2, move |input| node3.transitions_with(input))
            .unwrap();
        task.register_transition::<_, _, _>(node3, task.transitions_to_done())
            .unwrap();

        // Run the task to completion
        let res = task.run(1).await.unwrap();
        assert_eq!(res, Some(4)); // 1 + 1 + 1 + 1

        // Reset the task
        task.reset();

        // After the reset, the current node is node1 (index 1) again
        let n1_transition = task.transition_at_index::<IntNode>(1);
        assert!(n1_transition.is_some());

        let n1_transition = task.current_transition::<IntNode>();
        assert!(n1_transition.is_some());

        let n1_ref = task.current_node::<IntNode>();
        assert!(n1_ref.is_some());

        // A reset task can be run again from the start node
        let res = task.run(10).await.unwrap();
        assert_eq!(res, Some(13)); // 10 + 1 + 1 + 1
    }
}