swiftide_agents/tasks/node.rs

1use std::any::Any;
2
3use async_trait::async_trait;
4use dyn_clone::DynClone;
5
6use super::{
7    errors::NodeError,
8    transition::{MarkedTransitionPayload, TransitionPayload},
9};
10
11pub trait NodeArg: Send + Sync + DynClone + 'static {}
12
13impl<T: Send + Sync + std::fmt::Debug + 'static + Clone> NodeArg for T {}
14
15#[derive(Debug, Clone)]
16pub struct NoopNode<Context: NodeArg> {
17    _marker: std::marker::PhantomData<(Context, Box<dyn std::error::Error + Send + Sync>)>,
18}
19
20impl<Context> Default for NoopNode<Context>
21where
22    Context: NodeArg,
23{
24    fn default() -> Self {
25        NoopNode {
26            _marker: std::marker::PhantomData,
27        }
28    }
29}
30
31#[async_trait]
32impl<Context: NodeArg + Clone> TaskNode for NoopNode<Context> {
33    type Output = ();
34    type Input = Context;
35    type Error = NodeError;
36
37    async fn evaluate(
38        &self,
39        _node_id: &DynNodeId<Self>,
40        _context: &Context,
41    ) -> Result<Self::Output, Self::Error> {
42        Ok(())
43    }
44}
45
/// A node of a task: an async unit of work that reads an `Input` and produces an
/// `Output`, or fails with an `Error`.
///
/// Implementors must be `Send + Sync`, [`DynClone`] (so boxed trait objects can be
/// cloned) and [`Any`].
#[async_trait]
pub trait TaskNode: Send + Sync + DynClone + Any {
    /// Value the node receives when evaluated.
    type Input: NodeArg;
    /// Value the node returns on success.
    type Output: NodeArg;
    /// Error the node returns on failure.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Evaluates the node against a borrowed `input`.
    ///
    /// `node_id` is a type-erased [`NodeId`] supplied by the caller (presumably
    /// the id of this node — confirm against the task runner).
    ///
    /// # Errors
    ///
    /// Returns `Self::Error` when evaluation fails.
    async fn evaluate(
        &self,
        node_id: &DynNodeId<Self>,
        input: &Self::Input,
    ) -> Result<Self::Output, Self::Error>;
}
58
/// Type-erased [`NodeId`] for node type `T`: the id of a `dyn TaskNode` whose
/// `Input`, `Output` and `Error` match those of `T`.
///
/// This is the type of the `node_id` argument of [`TaskNode::evaluate`].
// The projections are fully qualified because bounds on a type alias's generic
// parameters are not enforced, so `T::Input` alone would not resolve.
pub type DynNodeId<T> = NodeId<
    dyn TaskNode<
            Input = <T as TaskNode>::Input,
            Output = <T as TaskNode>::Output,
            Error = <T as TaskNode>::Error,
        >,
>;
66
// Implements `Clone` (via `dyn_clone`) for boxed `TaskNode` trait objects whose
// associated types are bound to the trait objects `dyn NodeArg` and
// `dyn std::error::Error + Send + Sync`.
// NOTE(review): associated types are implicitly `Sized`, so it looks like no
// concrete `TaskNode` can have `Input = dyn NodeArg`, and this impl may be
// unreachable. The separate generic `clone_trait_object!` invocation in this file
// covers sized `Input`/`Output`/`Error`. Confirm before relying on or removing this.
dyn_clone::clone_trait_object!(
    TaskNode<
        Input = dyn NodeArg,
        Output = dyn NodeArg,
        Error = dyn std::error::Error + Send + Sync,
    >
);
74
75#[async_trait]
76impl<Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static> TaskNode
77    for Box<dyn TaskNode<Input = Input, Output = Output, Error = Error>>
78{
79    type Input = Input;
80    type Output = Output;
81    type Error = Error;
82
83    async fn evaluate(
84        &self,
85        node_id: &NodeId<
86            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
87        >,
88        input: &Self::Input,
89    ) -> Result<Self::Output, Self::Error> {
90        self.as_ref().evaluate(node_id, input).await
91    }
92}
93
94dyn_clone::clone_trait_object!(<Input, Output, Error> TaskNode<Input = Input, Output = Output, Error = Error>);
95
96#[derive(PartialEq, Eq)]
97pub struct NodeId<T: TaskNode + ?Sized> {
98    pub id: usize,
99    _marker: std::marker::PhantomData<T>,
100}
101
102impl<T: TaskNode + ?Sized> std::fmt::Debug for NodeId<T> {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        let type_name = std::any::type_name::<T>();
105
106        write!(f, "NodeId<{type_name}>({})", self.id)
107    }
108}
109
/// Untyped node identifier: the bare `usize` wrapped by a [`NodeId`], as returned
/// by `NodeId::as_any`.
pub type AnyNodeId = usize;
111
112impl<T: TaskNode + ?Sized> NodeId<T> {
113    pub fn id(&self) -> usize {
114        self.id
115    }
116
117    /// Returns a closure that can be used as a transition function
118    pub fn as_transition(&self) -> impl Fn(T::Input) -> MarkedTransitionPayload<T> + 'static {
119        let node_id = *self;
120
121        Box::new(move |context| node_id.transitions_with(context))
122    }
123
124    /// Returns a transition payload suitable for inside a task transition
125    ///
126    /// You can also get the closure version with `as_transition`
127    pub fn transitions_with(&self, context: T::Input) -> MarkedTransitionPayload<T> {
128        MarkedTransitionPayload::new(TransitionPayload::next_node(self, context))
129    }
130}
131
132impl<T: TaskNode + 'static + ?Sized> NodeId<T> {
133    pub fn new(id: usize, _node: &T) -> Self {
134        NodeId {
135            id,
136            _marker: std::marker::PhantomData,
137        }
138    }
139
140    /// Returns the internal id of the node without the type information.
141    pub fn as_any(&self) -> AnyNodeId {
142        self.id
143    }
144
145    pub fn as_dyn(
146        self,
147    ) -> NodeId<dyn TaskNode<Input = T::Input, Output = T::Output, Error = T::Error>> {
148        NodeId {
149            id: self.id,
150            _marker: std::marker::PhantomData,
151        }
152    }
153}
154
155impl<T: TaskNode + ?Sized> Clone for NodeId<T> {
156    fn clone(&self) -> Self {
157        *self
158    }
159}
160impl<T: TaskNode + ?Sized> Copy for NodeId<T> {}