swiftide_agents/tasks/
task.rs

//! Tasks enable you to define a graph of interacting nodes.
//!
//! The nodes can be any type that implements the `TaskNode` trait, which defines how the node
//! will be evaluated with its input and output.
//!
//! Most swiftide primitives implement `TaskNode`, and it's easy to implement your own. Since how
//! agents interact is subject to taste, we recommend implementing your own.
//!
//! WARN: Here be dragons! This API is not stable yet. We are using it in production, but it is
//! subject to rapid change. Please do not hesitate to open an issue if you find anything.
11use std::{any::Any, pin::Pin, sync::Arc};
12
13use tracing::Instrument as _;
14
15use crate::tasks::{errors::NodeError, transition::TransitionFn};
16
17use super::{
18    errors::TaskError,
19    node::{NodeArg, NodeId, NoopNode, TaskNode},
20    transition::{AnyNodeTransition, MarkedTransitionPayload, Transition, TransitionPayload},
21};
22
23#[derive(Debug)]
24pub struct Task<Input: NodeArg, Output: NodeArg> {
25    nodes: Vec<Box<dyn AnyNodeTransition>>,
26    current_node: usize,
27    start_node: usize,
28    current_context: Option<Arc<dyn Any + Send + Sync>>,
29    _marker: std::marker::PhantomData<(Input, Output)>,
30}
31
32impl<Input: NodeArg, Output: NodeArg> Clone for Task<Input, Output> {
33    fn clone(&self) -> Self {
34        Self {
35            nodes: self.nodes.clone(),
36            current_node: 0,
37            start_node: self.start_node,
38            current_context: None,
39            _marker: std::marker::PhantomData,
40        }
41    }
42}
43
44impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Default for Task<Input, Output> {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl<Input: NodeArg + Clone, Output: NodeArg + Clone> Task<Input, Output> {
51    pub fn new() -> Self {
52        let noop = NoopNode::<Output>::default();
53
54        let node_id = NodeId::new(0, &noop).as_dyn();
55
56        let noop_executor = Box::new(Transition {
57            node: Box::new(noop),
58            node_id: Box::new(node_id),
59            r#fn: Arc::new(|_output| {
60                Box::pin(async { unreachable!("Done node should never be evaluated.") })
61            }),
62            is_set: false,
63        });
64        Self {
65            nodes: vec![noop_executor],
66            current_node: 0,
67            start_node: 0,
68            current_context: None,
69            _marker: std::marker::PhantomData,
70        }
71    }
72
73    /// Returns the current context as the input type, if it matches
74    pub fn current_input(&self) -> Option<&Input> {
75        let input = self.current_context.as_ref()?;
76
77        input.downcast_ref::<Input>()
78    }
79
80    /// Returns the current context as the output type, if it matches
81    pub fn current_output(&self) -> Option<&Output> {
82        let input = self.current_context.as_ref()?;
83
84        input.downcast_ref::<Output>()
85    }
86
87    /// Returns the `done` node for this task
88    pub fn done(&self) -> NodeId<NoopNode<Output>> {
89        NodeId::new(0, &NoopNode::default())
90    }
91
92    /// Creates a transition to the done node
93    pub fn transitions_to_done(
94        &self,
95    ) -> impl Fn(Output) -> MarkedTransitionPayload<NoopNode<Output>> + Send + Sync + 'static {
96        let done = self.done();
97        move |context| done.transitions_with(context)
98    }
99
100    /// Defines the start node of the task
101    pub fn starts_with<T: TaskNode<Input = Input> + Clone + 'static>(
102        &mut self,
103        node_id: NodeId<T>,
104    ) {
105        self.current_node = node_id.id;
106        self.start_node = node_id.id;
107    }
108
109    /// Validates that all nodes have transitions set
110    ///
111    /// # Errors
112    ///
113    /// Errors if a node is missing a transition
114    pub fn validate_transitions(&self) -> Result<(), TaskError> {
115        // TODO: Validate that the task can complete
116        for node_executor in &self.nodes {
117            // Skip the done node (index 0)
118            if node_executor.node_id() == 0 {
119                continue;
120            }
121
122            if !node_executor.transition_is_set() {
123                return Err(TaskError::missing_transition(node_executor.node_id()));
124            }
125        }
126
127        Ok(())
128    }
129
130    /// Runs the task with the given input
131    ///
132    /// # Errors
133    ///
134    /// Errors if the task fails
135    #[tracing::instrument(skip(self, input), name = "task.run", err)]
136    pub async fn run(&mut self, input: impl Into<Input>) -> Result<Option<Output>, TaskError> {
137        self.validate_transitions()?;
138
139        self.current_context = Some(Arc::new(input.into()) as Arc<dyn Any + Send + Sync>);
140
141        self.start_task().await
142    }
143
144    /// Resets the task to the start node
145    ///
146    /// WARN: This **will** lead to a type mismatch if the previous context is not the same as the
147    /// input of the start node
148    pub fn reset(&mut self) {
149        self.current_node = self.start_node;
150    }
151
152    /// Resumes the task from the current node
153    ///
154    /// # Errors
155    ///
156    /// Errors if the task fails
157    #[tracing::instrument(skip(self), name = "task.resume", err)]
158    pub async fn resume(&mut self) -> Result<Option<Output>, TaskError> {
159        self.start_task().await
160    }
161
162    async fn start_task(&mut self) -> Result<Option<Output>, TaskError> {
163        self.validate_transitions()?;
164
165        let mut span = tracing::info_span!("task.step", node = self.current_node);
166        loop {
167            if self.current_node == 0 {
168                break;
169            }
170            let node_transition = self
171                .nodes
172                .get(self.current_node)
173                .ok_or_else(|| TaskError::missing_node(self.current_node))?;
174
175            let input = self
176                .current_context
177                .clone()
178                .ok_or_else(|| TaskError::missing_input(self.current_node))?;
179
180            tracing::debug!("Running node {}", self.current_node);
181
182            let span_id = span.id().clone();
183            let transition_payload = node_transition
184                .evaluate_next(input)
185                .instrument(span.or_current())
186                .await?;
187
188            match transition_payload {
189                TransitionPayload::Pause => {
190                    tracing::info!("Task paused at node {}", self.current_node);
191                    return Ok(None);
192                }
193                TransitionPayload::NextNode(transition_payload) => {
194                    self.current_node = transition_payload.node_id;
195                    self.current_context = Some(transition_payload.context);
196                }
197                TransitionPayload::Error(error) => {
198                    return Err(TaskError::NodeError(NodeError::new(
199                        error,
200                        self.current_node,
201                        None,
202                    )));
203                }
204            }
205            if self.current_node == 0 {
206                tracing::debug!("Task completed at node {}", self.current_node);
207                break;
208            }
209
210            span = tracing::info_span!("task.step", node = self.current_node).or_current();
211            span.follows_from(span_id);
212        }
213
214        let output = self
215            .current_context
216            .clone()
217            .ok_or_else(|| TaskError::missing_output(self.current_node))?;
218        let output = output
219            .downcast::<Output>()
220            .map_err(|e| TaskError::type_error(&e))?
221            .as_ref()
222            .clone();
223
224        Ok(Some(output))
225    }
226
227    /// Gets the current node of the task
228    pub fn current_node<T: TaskNode + 'static>(&self) -> Option<&T> {
229        self.node_at_index(self.current_node)
230    }
231
232    /// Gets the node at the given `NodeId`
233    pub fn node_at<T: TaskNode + 'static>(&self, node_id: NodeId<T>) -> Option<&T> {
234        self.node_at_index(node_id.id)
235    }
236
237    /// Gets the node at the given index
238    pub fn node_at_index<T: TaskNode + 'static>(&self, index: usize) -> Option<&T> {
239        let transition = self.transition_at_index::<T>(index)?;
240
241        let node = &*transition.node;
242
243        (node as &dyn Any).downcast_ref::<T>()
244    }
245
246    /// Gets the current transition of the task
247    #[allow(dead_code)]
248    fn current_transition<T: TaskNode + 'static>(
249        &self,
250    ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
251        self.transition_at_index::<T>(self.current_node)
252    }
253
254    /// Gets the transition at the given `NodeId`
255    fn transition_at_index<T: TaskNode + 'static>(
256        &self,
257        index: usize,
258    ) -> Option<&Transition<T::Input, T::Output, T::Error>> {
259        tracing::debug!("Getting transition at index {}", index);
260        let transition = self.nodes.get(index)?;
261
262        dbg!(&transition);
263
264        (&**transition as &dyn Any).downcast_ref::<Transition<T::Input, T::Output, T::Error>>()
265    }
266
267    /// Registers a new node in the task
268    pub fn register_node<T>(&mut self, node: T) -> NodeId<T>
269    where
270        T: TaskNode + 'static + Clone,
271        <T as TaskNode>::Input: Clone,
272        <T as TaskNode>::Output: Clone,
273    {
274        let id = self.nodes.len();
275        let node_id = NodeId::new(id, &node);
276        let node_executor = Box::new(Transition::<T::Input, T::Output, T::Error> {
277            node_id: Box::new(node_id.as_dyn()),
278            node: Box::new(node),
279            r#fn: Arc::new(move |_output| unreachable!("No transition for node {}.", node_id.id)),
280            is_set: false,
281        });
282        // Debug the type name
283        tracing::debug!(node_id = ?node_id, type_name = std::any::type_name_of_val(&node_executor), "Registering node");
284
285        self.nodes.push(node_executor);
286
287        node_id
288    }
289
290    /// Registers a transition from one node to another
291    ///
292    /// Note that there are various helpers and conversions for the `MarkedTransitionPayload`
293    ///
294    /// # Errors
295    ///
296    /// Errors if the node does not exist
297    pub fn register_transition<'a, From, To, F>(
298        &mut self,
299        from: NodeId<From>,
300        transition: F,
301    ) -> Result<(), TaskError>
302    where
303        From: TaskNode + 'static + ?Sized,
304        To: TaskNode<Input = From::Output> + 'a + ?Sized,
305        F: Fn(To::Input) -> MarkedTransitionPayload<To> + Send + Sync + 'static,
306    {
307        let node_executor = self
308            .nodes
309            .get_mut(from.id)
310            .ok_or_else(|| TaskError::missing_node(from.id))?;
311
312        let any_executor: &mut dyn Any = node_executor.as_mut();
313
314        let Some(exec) =
315            any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
316        else {
317            let expected =
318                std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
319            let actual = std::any::type_name_of_val(node_executor);
320
321            unreachable!(
322                "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
323                from.id
324            );
325        };
326        let transition = Arc::new(transition);
327        let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
328            let transition = transition.clone();
329            Box::pin(async move {
330                let output = transition(output);
331                output.into_inner()
332            })
333        });
334
335        exec.r#fn = wrapped;
336        exec.is_set = true;
337        // set function as before
338
339        Ok(())
340    }
341
342    /// Registers a transition from one node to another asynchronously
343    ///
344    /// Note that there are various helpers and conversions for the `MarkedTransitionPayload`
345    ///
346    /// # Errors
347    ///
348    /// Errors if the node does not exist
349    ///
350    /// NOTE: `AsyncFn` traits' returned future are not 'Send' and the inner type is unstable.
351    /// When they are, we can update Fn to `AsyncFn`
352    pub fn register_transition_async<'a, From, To, F>(
353        &mut self,
354        from: NodeId<From>,
355        transition: F,
356    ) -> Result<(), TaskError>
357    where
358        From: TaskNode + 'static + ?Sized,
359        To: TaskNode<Input = From::Output> + 'a + ?Sized,
360        F: Fn(To::Input) -> Pin<Box<dyn Future<Output = MarkedTransitionPayload<To>> + Send>>
361            + Send
362            + Sync
363            + 'static,
364    {
365        let node_executor = self
366            .nodes
367            .get_mut(from.id)
368            .ok_or_else(|| TaskError::missing_node(from.id))?;
369
370        let any_executor: &mut dyn Any = node_executor.as_mut();
371
372        let Some(exec) =
373            any_executor.downcast_mut::<Transition<From::Input, From::Output, From::Error>>()
374        else {
375            let expected =
376                std::any::type_name::<Transition<From::Input, From::Output, From::Error>>();
377            let actual = std::any::type_name_of_val(node_executor);
378
379            unreachable!(
380                "Transition at index {:?} is not a {expected:?}; Mismatched types, should not never happen. Actual: {actual:?}",
381                from.id
382            );
383        };
384        let transition = Arc::new(transition);
385        let wrapped: Arc<dyn TransitionFn<From::Output>> = Arc::new(move |output: From::Output| {
386            let transition = transition.clone();
387
388            Box::pin(async move {
389                let output = transition(output).await;
390                output.into_inner()
391            })
392        });
393
394        exec.r#fn = wrapped;
395        exec.is_set = true;
396        // set function as before
397
398        Ok(())
399    }
400}
401
#[cfg(test)]
mod tests {
    use async_trait::async_trait;

    use super::*;

    #[derive(thiserror::Error, Debug)]
    #[error("{0}")]
    struct Error(String);

    /// Test node that adds one to its input.
    #[derive(Clone, Default, Debug)]
    struct IntNode;

    #[async_trait]
    impl TaskNode for IntNode {
        type Input = i32;
        type Output = i32;
        type Error = Error;

        async fn evaluate(
            &self,
            _node_id: &NodeId<
                dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
            >,
            input: &Self::Input,
        ) -> Result<Self::Output, Self::Error> {
            Ok(input + 1)
        }
    }

    #[test_log::test(tokio::test)]
    async fn sequential_3_node_task_reset_works() {
        let mut task: Task<i32, i32> = Task::new();

        // Register three nodes
        let node1 = task.register_node(IntNode);
        let node2 = task.register_node(IntNode);
        let node3 = task.register_node(IntNode);

        // Set start node
        task.starts_with(node1);

        // Register transitions (node1 → node2 → node3 → done)
        task.register_transition(node1, move |input| node2.transitions_with(input))
            .unwrap();
        task.register_transition(node2, move |input| node3.transitions_with(input))
            .unwrap();
        task.register_transition(node3, task.transitions_to_done())
            .unwrap();

        // Run the task to completion
        let res = task.run(1).await.unwrap();
        assert_eq!(res, Some(4)); // 1 + 1 + 1 + 1

        // Reset the task
        task.reset();

        // After the reset, the current node is node1 again
        let n1_transition = task.transition_at_index::<IntNode>(1);
        assert!(n1_transition.is_some());

        let n1_transition = task.current_transition::<IntNode>();
        assert!(n1_transition.is_some());

        let n1_ref = task.current_node::<IntNode>();
        assert!(n1_ref.is_some());

        // Running again after the reset goes through all three nodes again
        let res = task.run(1).await.unwrap();
        assert_eq!(res, Some(4));
    }
}