1use crate::error::GraphResult;
4use crate::state::{GraphRunContext, GraphState};
5use async_trait::async_trait;
6use std::future::Future;
7use std::marker::PhantomData;
8use std::sync::Arc;
9
10pub enum NodeResult<State, Deps, End> {
12 Next(Box<dyn BaseNode<State, Deps, End>>),
14 NextNamed(String),
16 End(End),
18}
19
20impl<State, Deps, End> NodeResult<State, Deps, End> {
21 pub fn next<N: BaseNode<State, Deps, End> + 'static>(node: N) -> Self {
23 Self::Next(Box::new(node))
24 }
25
26 pub fn next_named(name: impl Into<String>) -> Self {
28 Self::NextNamed(name.into())
29 }
30
31 pub fn end(value: End) -> Self {
33 Self::End(value)
34 }
35}
36
/// Thin wrapper around a final value `T` (presumably the output of a finished
/// graph run — the wrapper itself attaches no behavior).
///
/// `Default` is derived, so `End<T>: Default` exactly when `T: Default`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub struct End<T>(pub T);

impl<T> End<T> {
    /// Wraps `value`. Usable in `const` contexts.
    #[must_use]
    pub const fn new(value: T) -> Self {
        Self(value)
    }

    /// Consumes the wrapper and returns the inner value.
    #[must_use]
    pub fn into_inner(self) -> T {
        self.0
    }

    /// Borrows the inner value.
    pub const fn value(&self) -> &T {
        &self.0
    }
}
63
64#[async_trait]
66pub trait BaseNode<State, Deps = (), End = ()>: Send + Sync {
67 fn type_name(&self) -> &'static str {
69 std::any::type_name::<Self>()
70 }
71
72 fn name(&self) -> &str {
74 self.type_name()
75 }
76
77 async fn run(
79 &self,
80 ctx: &mut GraphRunContext<State, Deps>,
81 ) -> GraphResult<NodeResult<State, Deps, End>>;
82}
83
84#[async_trait]
86pub trait Node<State: GraphState>: Send + Sync {
87 async fn execute(&self, state: State) -> GraphResult<State>;
89
90 fn name(&self) -> &str;
92}
93
/// A [`Node`] backed by a closure `F: Fn(State) -> Fut` whose future resolves to
/// the next state.
///
/// Trait bounds live on the `impl` blocks rather than on the struct itself.
pub struct FunctionNode<State, F, Fut> {
    name: String,
    func: F,
    // `Fut` is only mentioned in `F`'s bound, so it must appear in a field or
    // rustc rejects the struct (E0392). A function-pointer type marks both
    // `State` and `Fut` as used while staying `Send + Sync` unconditionally —
    // neither a `State` nor a `Fut` is ever stored here, so they should not
    // affect the node's auto traits (required by `Node: Send + Sync`).
    _phantom: PhantomData<fn(State) -> Fut>,
}
104
105impl<State, F, Fut> FunctionNode<State, F, Fut>
106where
107 F: Fn(State) -> Fut + Send + Sync,
108 Fut: Future<Output = GraphResult<State>> + Send,
109{
110 pub fn new(name: impl Into<String>, func: F) -> Self {
112 Self {
113 name: name.into(),
114 func,
115 _phantom: PhantomData,
116 }
117 }
118}
119
120#[async_trait]
121impl<State, F, Fut> Node<State> for FunctionNode<State, F, Fut>
122where
123 State: GraphState,
124 F: Fn(State) -> Fut + Send + Sync,
125 Fut: Future<Output = GraphResult<State>> + Send,
126{
127 async fn execute(&self, state: State) -> GraphResult<State> {
128 (self.func)(state).await
129 }
130
131 fn name(&self) -> &str {
132 &self.name
133 }
134}
135
/// Pairs an agent with a function that combines the agent and the current state
/// into a new state.
///
/// The agent is held behind an [`Arc`]. Trait bounds live on the `impl` block.
pub struct AgentNode<State, Agent, UpdateFn> {
    name: String,
    agent: Arc<Agent>,
    update_state: UpdateFn,
    _phantom: PhantomData<State>,
}

impl<State, Agent, UpdateFn> AgentNode<State, Agent, UpdateFn>
where
    UpdateFn: Fn(State, &Agent) -> State + Send + Sync,
{
    /// Creates a node labelled `name`, taking ownership of `agent`.
    pub fn new(name: impl Into<String>, agent: Agent, update_state: UpdateFn) -> Self {
        Self {
            name: name.into(),
            agent: Arc::new(agent),
            update_state,
            _phantom: PhantomData,
        }
    }

    /// Returns the node's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Borrows the wrapped agent.
    pub fn agent(&self) -> &Agent {
        &self.agent
    }

    /// Calls the stored update function with `state` and the wrapped agent and
    /// returns its result.
    pub fn apply(&self, state: State) -> State {
        (self.update_state)(state, self.agent())
    }
}
167
/// Chooses where to go next by computing a target name from the current state.
///
/// Trait bounds live on the `impl` block rather than on the struct.
pub struct RouterNode<State, F> {
    name: String,
    router: F,
    _phantom: PhantomData<State>,
}

impl<State, F> RouterNode<State, F>
where
    F: Fn(&State) -> String + Send + Sync,
{
    /// Creates a router labelled `name` that uses `router` to pick a target.
    pub fn new(name: impl Into<String>, router: F) -> Self {
        Self {
            name: name.into(),
            router,
            _phantom: PhantomData,
        }
    }

    /// Returns the router's name.
    pub fn name(&self) -> &str {
        &self.name
    }

    /// Evaluates the routing function on `state` and returns the chosen name.
    pub fn route(&self, state: &State) -> String {
        (self.router)(state)
    }
}
197
198#[allow(dead_code)]
200pub struct ConditionalNode<State, Cond, Then, Else>
201where
202 Cond: Fn(&State) -> bool + Send + Sync,
203 Then: BaseNode<State> + 'static,
204 Else: BaseNode<State> + 'static,
205{
206 name: String,
207 condition: Cond,
208 then_node: Box<Then>,
209 else_node: Box<Else>,
210 _phantom: PhantomData<State>,
211}
212
213impl<State, Cond, Then, Else> ConditionalNode<State, Cond, Then, Else>
214where
215 Cond: Fn(&State) -> bool + Send + Sync,
216 Then: BaseNode<State> + 'static,
217 Else: BaseNode<State> + 'static,
218{
219 pub fn new(name: impl Into<String>, condition: Cond, then_node: Then, else_node: Else) -> Self {
221 Self {
222 name: name.into(),
223 condition,
224 then_node: Box::new(then_node),
225 else_node: Box::new(else_node),
226 _phantom: PhantomData,
227 }
228 }
229}
230
231pub struct NodeDef<State, Deps = (), End = ()> {
233 pub name: String,
235 pub node: Box<dyn BaseNode<State, Deps, End>>,
237}
238
239impl<State, Deps, End> NodeDef<State, Deps, End> {
240 pub fn new<N: BaseNode<State, Deps, End> + 'static>(name: impl Into<String>, node: N) -> Self {
242 Self {
243 name: name.into(),
244 node: Box::new(node),
245 }
246 }
247}
248
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Default)]
    struct TestState {
        value: i32,
    }

    #[test]
    fn test_end_marker() {
        let end = End::new(42);
        assert_eq!(end.value(), &42);
        assert_eq!(end.into_inner(), 42);
    }

    #[test]
    fn test_end_default() {
        let end: End<i32> = End::default();
        assert_eq!(end, End(0));
    }

    #[test]
    fn test_node_result_variants() {
        // The constructors must produce the matching variants with their payloads.
        let next_named: NodeResult<TestState, (), i32> = NodeResult::next_named("next");
        assert!(matches!(next_named, NodeResult::NextNamed(ref name) if name == "next"));

        let end: NodeResult<TestState, (), i32> = NodeResult::end(42);
        assert!(matches!(end, NodeResult::End(42)));
    }

    #[test]
    fn test_router_node() {
        let router = RouterNode::new("router", |state: &TestState| {
            if state.value > 0 {
                "positive".to_string()
            } else {
                "negative".to_string()
            }
        });

        assert_eq!(router.route(&TestState { value: 1 }), "positive");
        assert_eq!(router.route(&TestState { value: -1 }), "negative");
    }
}