swiftide_agents/tasks/
node.rs1use std::any::Any;
2
3use async_trait::async_trait;
4use dyn_clone::DynClone;
5
6use super::{
7 errors::NodeError,
8 transition::{MarkedTransitionPayload, TransitionPayload},
9};
10
11pub trait NodeArg: Send + Sync + DynClone + 'static {}
12
13impl<T: Send + Sync + std::fmt::Debug + 'static + Clone> NodeArg for T {}
14
15#[derive(Debug, Clone)]
16pub struct NoopNode<Context: NodeArg> {
17 _marker: std::marker::PhantomData<(Context, Box<dyn std::error::Error + Send + Sync>)>,
18}
19
20impl<Context> Default for NoopNode<Context>
21where
22 Context: NodeArg,
23{
24 fn default() -> Self {
25 NoopNode {
26 _marker: std::marker::PhantomData,
27 }
28 }
29}
30
31#[async_trait]
32impl<Context: NodeArg + Clone> TaskNode for NoopNode<Context> {
33 type Output = ();
34 type Input = Context;
35 type Error = NodeError;
36
37 async fn evaluate(
38 &self,
39 _node_id: &DynNodeId<Self>,
40 _context: &Context,
41 ) -> Result<Self::Output, Self::Error> {
42 Ok(())
43 }
44}
45
46#[async_trait]
47pub trait TaskNode: Send + Sync + DynClone + Any {
48 type Input: NodeArg;
49 type Output: NodeArg;
50 type Error: std::error::Error + Send + Sync + 'static;
51
52 async fn evaluate(
53 &self,
54 node_id: &DynNodeId<Self>,
55 input: &Self::Input,
56 ) -> Result<Self::Output, Self::Error>;
57}
58
59pub type DynNodeId<T> = NodeId<
60 dyn TaskNode<
61 Input = <T as TaskNode>::Input,
62 Output = <T as TaskNode>::Output,
63 Error = <T as TaskNode>::Error,
64 >,
65>;
66
67dyn_clone::clone_trait_object!(
68 TaskNode<
69 Input = dyn NodeArg,
70 Output = dyn NodeArg,
71 Error = dyn std::error::Error + Send + Sync,
72 >
73);
74
75#[async_trait]
76impl<Input: NodeArg, Output: NodeArg, Error: std::error::Error + Send + Sync + 'static> TaskNode
77 for Box<dyn TaskNode<Input = Input, Output = Output, Error = Error>>
78{
79 type Input = Input;
80 type Output = Output;
81 type Error = Error;
82
83 async fn evaluate(
84 &self,
85 node_id: &NodeId<
86 dyn TaskNode<Input = Self::Input, Output = Self::Output, Error = Self::Error>,
87 >,
88 input: &Self::Input,
89 ) -> Result<Self::Output, Self::Error> {
90 self.as_ref().evaluate(node_id, input).await
91 }
92}
93
94dyn_clone::clone_trait_object!(<Input, Output, Error> TaskNode<Input = Input, Output = Output, Error = Error>);
95
96#[derive(PartialEq, Eq)]
97pub struct NodeId<T: TaskNode + ?Sized> {
98 pub id: usize,
99 _marker: std::marker::PhantomData<T>,
100}
101
102impl<T: TaskNode + ?Sized> std::fmt::Debug for NodeId<T> {
103 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
104 let type_name = std::any::type_name::<T>();
105
106 write!(f, "NodeId<{type_name}>({})", self.id)
107 }
108}
109
/// Untyped form of a node id: the bare numeric identifier returned by
/// `NodeId::as_any`.
pub type AnyNodeId = usize;
111
112impl<T: TaskNode + 'static + ?Sized> NodeId<T> {
113 pub fn new(id: usize, _node: &T) -> Self {
114 NodeId {
115 id,
116 _marker: std::marker::PhantomData,
117 }
118 }
119
120 pub fn transitions_with(&self, context: T::Input) -> MarkedTransitionPayload<T> {
121 MarkedTransitionPayload::new(TransitionPayload::next_node(self, context))
122 }
123
124 pub fn as_any(&self) -> AnyNodeId {
126 self.id
127 }
128
129 pub fn as_dyn(
130 self,
131 ) -> NodeId<dyn TaskNode<Input = T::Input, Output = T::Output, Error = T::Error>> {
132 NodeId {
133 id: self.id,
134 _marker: std::marker::PhantomData,
135 }
136 }
137}
138
139impl<T: TaskNode + ?Sized> Clone for NodeId<T> {
140 fn clone(&self) -> Self {
141 *self
142 }
143}
144impl<T: TaskNode + ?Sized> Copy for NodeId<T> {}