// swiftide_agents/tasks/node.rs

1use std::any::Any;
2
3use async_trait::async_trait;
4use dyn_clone::DynClone;
5
6use super::{
7    errors::NodeError,
8    transition::{MarkedTransitionPayload, TransitionPayload},
9};
10
11pub trait NodeArg: Send + Sync + DynClone + 'static {}
12
13impl<T: Send + Sync + std::fmt::Debug + 'static + Clone> NodeArg for T {}
14
15#[derive(Debug, Clone)]
16pub struct NoopNode<Context: NodeArg> {
17    _marker: std::marker::PhantomData<(Context, Box<dyn std::error::Error + Send + Sync>)>,
18}
19
20impl<Context> Default for NoopNode<Context>
21where
22    Context: NodeArg,
23{
24    fn default() -> Self {
25        NoopNode {
26            _marker: std::marker::PhantomData,
27        }
28    }
29}
30
31#[async_trait]
32impl<Context: NodeArg + Clone> TaskNode for NoopNode<Context> {
33    type Output = ();
34    type Input = Context;
35    type Error = NodeError;
36
37    async fn evaluate(
38        &self,
39        _node_id: &DynNodeId<Self>,
40        _context: &Context,
41    ) -> Result<Self::Output, Self::Error> {
42        Ok(())
43    }
44}
45
46#[async_trait]
47pub trait TaskNode: Send + Sync + DynClone + Any {
48    type Input: NodeArg;
49    type Output: NodeArg;
50    type Error: std::error::Error + Send + Sync + 'static;
51
52    async fn evaluate(
53        &self,
54        node_id: &DynNodeId<Self>,
55        input: &Self::Input,
56    ) -> Result<Self::Output, Self::Error>;
57}
58
59pub type DynNodeId<T> = NodeId<
60    dyn TaskNode<
61            Input = <T as TaskNode>::Input,
62            Output = <T as TaskNode>::Output,
63            Error = <T as TaskNode>::Error,
64        >,
65>;
66
67dyn_clone::clone_trait_object!(
68    TaskNode<
69        Input = dyn NodeArg,
70        Output = dyn NodeArg,
71        Error = dyn std::error::Error + Send + Sync,
72    >
73);
74
75#[async_trait]
76impl<Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static> TaskNode
77    for Box<dyn TaskNode<Input = Input, Output = Output, Error = Error>>
78{
79    type Input = Input;
80    type Output = Output;
81    type Error = Error;
82
83    async fn evaluate(
84        &self,
85        node_id: &NodeId<
86            dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
87        >,
88        input: &Self::Input,
89    ) -> Result<Self::Output, Self::Error> {
90        self.as_ref().evaluate(node_id, input).await
91    }
92}
93
94dyn_clone::clone_trait_object!(<Input, Output, Error> TaskNode<Input = Input, Output = Output, Error = Error>);
95
96#[derive(PartialEq, Eq)]
97pub struct NodeId<T: TaskNode + ?Sized> {
98    pub id: usize,
99    _marker: std::marker::PhantomData<T>,
100}
101
102impl<T: TaskNode + ?Sized> std::fmt::Debug for NodeId<T> {
103    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104        let type_name = std::any::type_name::<T>();
105
106        write!(f, "NodeId<{type_name}>({})", self.id)
107    }
108}
109
110pub type AnyNodeId = usize;
111
112impl<T: TaskNode + 'static + ?Sized> NodeId<T> {
113    pub fn new(id: usize, _node: &T) -> Self {
114        NodeId {
115            id,
116            _marker: std::marker::PhantomData,
117        }
118    }
119
120    pub fn transitions_with(&self, context: T::Input) -> MarkedTransitionPayload<T> {
121        MarkedTransitionPayload::new(TransitionPayload::next_node(self, context))
122    }
123
124    /// Returns the internal id of the node without the type information.
125    pub fn as_any(&self) -> AnyNodeId {
126        self.id
127    }
128
129    pub fn as_dyn(
130        self,
131    ) -> NodeId<dyn TaskNode<Input = T::Input, Output = T::Output, Error = T::Error>> {
132        NodeId {
133            id: self.id,
134            _marker: std::marker::PhantomData,
135        }
136    }
137}
138
139impl<T: TaskNode + ?Sized> Clone for NodeId<T> {
140    fn clone(&self) -> Self {
141        *self
142    }
143}
144impl<T: TaskNode + ?Sized> Copy for NodeId<T> {}