Skip to main content

serdes_ai_graph/
graph.rs

1//! Graph definition and execution.
2
3use crate::edge::Edge;
4use crate::error::{GraphError, GraphResult};
5use crate::executor::ExecutionOptions;
6use crate::node::{BaseNode, Node, NodeDef, NodeResult};
7use crate::state::{generate_run_id, GraphRunContext, GraphRunResult, GraphState};
8use std::collections::HashMap;
9
10/// A graph for multi-agent workflows.
11pub struct Graph<State, Deps = (), End = ()>
12where
13    State: GraphState,
14{
15    name: Option<String>,
16    /// Nodes in the graph.
17    pub nodes: HashMap<String, NodeDef<State, Deps, End>>,
18    edges: Vec<Edge<State>>,
19    entry_node: Option<String>,
20    finish_nodes: Vec<String>,
21    max_steps: u32,
22    auto_instrument: bool,
23}
24
25impl<State, Deps, End> Graph<State, Deps, End>
26where
27    State: GraphState,
28    Deps: Send + Sync + 'static,
29    End: Send + Sync + 'static,
30{
31    /// Create a new empty graph.
32    pub fn new() -> Self {
33        Self {
34            name: None,
35            nodes: HashMap::new(),
36            edges: Vec::new(),
37            entry_node: None,
38            finish_nodes: Vec::new(),
39            max_steps: 100,
40            auto_instrument: true,
41        }
42    }
43
44    /// Set the graph name.
45    pub fn with_name(mut self, name: impl Into<String>) -> Self {
46        self.name = Some(name.into());
47        self
48    }
49
50    /// Set maximum steps.
51    pub fn with_max_steps(mut self, max: u32) -> Self {
52        self.max_steps = max;
53        self
54    }
55
56    /// Disable auto instrumentation.
57    pub fn without_instrumentation(mut self) -> Self {
58        self.auto_instrument = false;
59        self
60    }
61
62    /// Add a node to the graph.
63    pub fn node<N>(mut self, name: impl Into<String>, node: N) -> Self
64    where
65        N: BaseNode<State, Deps, End> + 'static,
66    {
67        let name = name.into();
68        self.nodes.insert(name.clone(), NodeDef::new(name, node));
69        self
70    }
71
72    /// Add an edge with a condition.
73    pub fn edge<F>(mut self, from: impl Into<String>, to: impl Into<String>, condition: F) -> Self
74    where
75        F: Fn(&State) -> bool + Send + Sync + 'static,
76    {
77        self.edges.push(Edge::new(from, to, condition));
78        self
79    }
80
81    /// Add an unconditional edge.
82    pub fn edge_always(mut self, from: impl Into<String>, to: impl Into<String>) -> Self {
83        self.edges.push(Edge::unconditional(from, to));
84        self
85    }
86
87    /// Set the entry node.
88    pub fn entry(mut self, name: impl Into<String>) -> Self {
89        self.entry_node = Some(name.into());
90        self
91    }
92
93    /// Set finish nodes.
94    pub fn finish(mut self, names: &[&str]) -> Self {
95        self.finish_nodes = names.iter().map(|s| s.to_string()).collect();
96        self
97    }
98
99    /// Add a finish node.
100    pub fn add_finish(mut self, name: impl Into<String>) -> Self {
101        self.finish_nodes.push(name.into());
102        self
103    }
104
105    /// Get the graph name.
106    pub fn name(&self) -> Option<&str> {
107        self.name.as_deref()
108    }
109
110    /// Get node names.
111    pub fn node_names(&self) -> impl Iterator<Item = &str> {
112        self.nodes.keys().map(|s| s.as_str())
113    }
114
115    /// Get node count.
116    pub fn node_count(&self) -> usize {
117        self.nodes.len()
118    }
119
120    /// Get edge count.
121    pub fn edge_count(&self) -> usize {
122        self.edges.len()
123    }
124
125    /// Get edges.
126    pub fn edges(&self) -> &[Edge<State>] {
127        &self.edges
128    }
129
130    fn detect_cycle(
131        node: &str,
132        adjacency: &HashMap<String, Vec<String>>,
133        visiting: &mut std::collections::HashSet<String>,
134        visited: &mut std::collections::HashSet<String>,
135    ) -> bool {
136        if visited.contains(node) {
137            return false;
138        }
139        if visiting.contains(node) {
140            return true;
141        }
142
143        visiting.insert(node.to_string());
144        if let Some(neighbors) = adjacency.get(node) {
145            for neighbor in neighbors {
146                if Self::detect_cycle(neighbor, adjacency, visiting, visited) {
147                    return true;
148                }
149            }
150        }
151        visiting.remove(node);
152        visited.insert(node.to_string());
153        false
154    }
155
156    /// Validate the graph configuration.
157    pub fn validate(&self) -> GraphResult<()> {
158        // Check entry node exists
159        if let Some(ref entry) = self.entry_node {
160            if !self.nodes.contains_key(entry) {
161                return Err(GraphError::node_not_found(entry));
162            }
163        } else {
164            return Err(GraphError::NoEntryNode);
165        }
166
167        // Check all edge references exist
168        for edge in &self.edges {
169            if !self.nodes.contains_key(&edge.from) {
170                return Err(GraphError::node_not_found(&edge.from));
171            }
172            if !self.nodes.contains_key(&edge.to) {
173                return Err(GraphError::node_not_found(&edge.to));
174            }
175        }
176
177        // Check finish nodes exist
178        for finish in &self.finish_nodes {
179            if !self.nodes.contains_key(finish) {
180                return Err(GraphError::node_not_found(finish));
181            }
182        }
183
184        let mut adjacency: HashMap<String, Vec<String>> = HashMap::new();
185        for edge in &self.edges {
186            adjacency
187                .entry(edge.from.clone())
188                .or_default()
189                .push(edge.to.clone());
190        }
191
192        let mut visiting = std::collections::HashSet::new();
193        let mut visited = std::collections::HashSet::new();
194        for node in self.nodes.keys() {
195            if Self::detect_cycle(node, &adjacency, &mut visiting, &mut visited) {
196                return Err(GraphError::CycleDetected);
197            }
198        }
199
200        Ok(())
201    }
202
203    /// Build and validate the graph.
204    pub fn build(self) -> GraphResult<Self> {
205        self.validate()?;
206        Ok(self)
207    }
208}
209
210impl<State, Deps, End> Graph<State, Deps, End>
211where
212    State: GraphState,
213    Deps: Clone + Send + Sync + 'static,
214    End: Clone + Send + Sync + 'static,
215{
216    /// Run the graph from the entry node.
217    pub async fn run(&self, state: State, deps: Deps) -> GraphResult<GraphRunResult<State, End>> {
218        let options = ExecutionOptions::new()
219            .max_steps(self.max_steps)
220            .tracing(self.auto_instrument);
221        self.run_with_options(state, deps, options).await
222    }
223
224    /// Run the graph from the entry node with options.
225    pub async fn run_with_options(
226        &self,
227        state: State,
228        deps: Deps,
229        options: ExecutionOptions,
230    ) -> GraphResult<GraphRunResult<State, End>> {
231        let entry = self.entry_node.as_ref().ok_or(GraphError::NoEntryNode)?;
232        let start_node = self
233            .nodes
234            .get(entry)
235            .ok_or_else(|| GraphError::node_not_found(entry))?;
236
237        self.run_from_with_options(&*start_node.node, state, deps, options)
238            .await
239    }
240
241    /// Run the graph from a specific node.
242    pub async fn run_from<N>(
243        &self,
244        start: &N,
245        state: State,
246        deps: Deps,
247    ) -> GraphResult<GraphRunResult<State, End>>
248    where
249        N: BaseNode<State, Deps, End> + ?Sized,
250    {
251        let options = ExecutionOptions::new()
252            .max_steps(self.max_steps)
253            .tracing(self.auto_instrument);
254        self.run_from_with_options(start, state, deps, options)
255            .await
256    }
257
258    /// Run the graph from a specific node with options.
259    pub async fn run_from_with_options<N>(
260        &self,
261        start: &N,
262        state: State,
263        deps: Deps,
264        mut options: ExecutionOptions,
265    ) -> GraphResult<GraphRunResult<State, End>>
266    where
267        N: BaseNode<State, Deps, End> + ?Sized,
268    {
269        let run_id = options.run_id.take().unwrap_or_else(generate_run_id);
270        let max_steps = options.max_steps;
271        let mut ctx = GraphRunContext::new(state, deps, &run_id).with_max_steps(max_steps);
272        let mut history = Vec::new();
273        let mut steps = 0;
274
275        steps += 1;
276        if steps > max_steps {
277            return Err(GraphError::MaxStepsExceeded(max_steps));
278        }
279        ctx.increment_step();
280        let node_name = start.name().to_string();
281        history.push(node_name);
282
283        let mut result = start.run(&mut ctx).await?;
284
285        loop {
286            match result {
287                NodeResult::Next(next) => {
288                    steps += 1;
289                    if steps > max_steps {
290                        return Err(GraphError::MaxStepsExceeded(max_steps));
291                    }
292                    ctx.increment_step();
293                    let name = next.name().to_string();
294                    history.push(name);
295                    result = next.run(&mut ctx).await?;
296                }
297                NodeResult::NextNamed(name) => {
298                    let node = self
299                        .nodes
300                        .get(&name)
301                        .ok_or_else(|| GraphError::node_not_found(&name))?;
302                    steps += 1;
303                    if steps > max_steps {
304                        return Err(GraphError::MaxStepsExceeded(max_steps));
305                    }
306                    ctx.increment_step();
307                    history.push(name);
308                    result = node.node.run(&mut ctx).await?;
309                }
310                NodeResult::End(end) => {
311                    return Ok(
312                        GraphRunResult::new(end, ctx.state, ctx.step, run_id).with_history(history)
313                    );
314                }
315            }
316        }
317    }
318}
319
320impl<State, Deps, End> Default for Graph<State, Deps, End>
321where
322    State: GraphState,
323    Deps: Send + Sync + 'static,
324    End: Send + Sync + 'static,
325{
326    fn default() -> Self {
327        Self::new()
328    }
329}
330
331/// Simple graph using the old API (state-only nodes).
332pub struct SimpleGraph<State: GraphState> {
333    nodes: HashMap<String, Box<dyn Node<State>>>,
334    edges: Vec<Edge<State>>,
335    entry_node: Option<String>,
336    finish_nodes: Vec<String>,
337}
338
339impl<State: GraphState + 'static> SimpleGraph<State> {
340    /// Create a new simple graph.
341    pub fn new() -> Self {
342        Self {
343            nodes: HashMap::new(),
344            edges: Vec::new(),
345            entry_node: None,
346            finish_nodes: Vec::new(),
347        }
348    }
349
350    /// Add a node.
351    pub fn add_node(mut self, name: impl Into<String>, node: impl Node<State> + 'static) -> Self {
352        self.nodes.insert(name.into(), Box::new(node));
353        self
354    }
355
356    /// Add a conditional edge.
357    pub fn add_edge<F>(
358        mut self,
359        from: impl Into<String>,
360        to: impl Into<String>,
361        condition: F,
362    ) -> Self
363    where
364        F: Fn(&State) -> bool + Send + Sync + 'static,
365    {
366        self.edges.push(Edge::new(from, to, condition));
367        self
368    }
369
370    /// Set entry node.
371    pub fn set_entry(mut self, name: impl Into<String>) -> Self {
372        self.entry_node = Some(name.into());
373        self
374    }
375
376    /// Set finish nodes.
377    pub fn set_finish(mut self, names: &[&str]) -> Self {
378        self.finish_nodes = names.iter().map(|s| s.to_string()).collect();
379        self
380    }
381
382    /// Build the graph.
383    pub fn build(self) -> GraphResult<Self> {
384        if self.entry_node.is_none() {
385            return Err(GraphError::NoEntryNode);
386        }
387        Ok(self)
388    }
389
390    /// Run the graph.
391    pub async fn run(&self, mut state: State) -> GraphResult<State> {
392        let entry = self.entry_node.as_ref().ok_or(GraphError::NoEntryNode)?;
393        let mut current = entry.clone();
394
395        loop {
396            if self.finish_nodes.contains(&current) {
397                break;
398            }
399
400            let node = self
401                .nodes
402                .get(&current)
403                .ok_or_else(|| GraphError::node_not_found(&current))?;
404
405            state = node.execute(state).await?;
406
407            // Find next node
408            let next = self
409                .edges
410                .iter()
411                .find(|e| e.from == current && e.matches(&state));
412            match next {
413                Some(edge) => current = edge.to.clone(),
414                None => break,
415            }
416        }
417
418        Ok(state)
419    }
420}
421
422impl<State: GraphState + 'static> Default for SimpleGraph<State> {
423    fn default() -> Self {
424        Self::new()
425    }
426}
427
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    #[derive(Debug, Clone, Default)]
    struct TestState {
        value: i32,
    }

    /// Increments the state and loops back to itself until the value hits 3.
    struct IncrementNode;

    #[async_trait]
    impl BaseNode<TestState, (), i32> for IncrementNode {
        fn name(&self) -> &str {
            "increment"
        }

        async fn run(
            &self,
            ctx: &mut GraphRunContext<TestState, ()>,
        ) -> GraphResult<NodeResult<TestState, (), i32>> {
            ctx.state.value += 1;
            if ctx.state.value >= 3 {
                Ok(NodeResult::end(ctx.state.value))
            } else {
                Ok(NodeResult::next(IncrementNode))
            }
        }
    }

    #[tokio::test]
    async fn test_simple_graph_run() {
        let graph = Graph::<TestState, (), i32>::new()
            .with_name("test")
            .node("start", IncrementNode)
            .entry("start")
            .build()
            .unwrap();

        let result = graph.run(TestState::default(), ()).await.unwrap();
        assert_eq!(result.result, 3);
        assert_eq!(result.steps, 3);
    }

    #[tokio::test]
    async fn test_graph_max_steps_exceeded() {
        // IncrementNode needs three executions; a limit of two must trip.
        let graph = Graph::<TestState, (), i32>::new()
            .node("start", IncrementNode)
            .entry("start")
            .with_max_steps(2)
            .build()
            .unwrap();

        let result = graph.run(TestState::default(), ()).await;
        assert!(matches!(result, Err(GraphError::MaxStepsExceeded(2))));
    }

    #[test]
    fn test_graph_validation() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .entry("missing");

        assert!(graph.build().is_err());
    }

    #[test]
    fn test_graph_no_entry() {
        let graph = Graph::<TestState, (), i32>::new().node("a", IncrementNode);

        assert!(graph.build().is_err());
    }

    #[test]
    fn test_graph_edge_to_unknown_node() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .edge_always("a", "ghost")
            .entry("a");

        assert!(graph.build().is_err());
    }

    #[test]
    fn test_graph_cycle_detected() {
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .node("b", IncrementNode)
            .edge_always("a", "b")
            .edge_always("b", "a")
            .entry("a");

        assert!(matches!(graph.build(), Err(GraphError::CycleDetected)));
    }

    #[test]
    fn test_graph_diamond_is_acyclic() {
        // Two paths converging on the same node are not a cycle.
        let graph = Graph::<TestState, (), i32>::new()
            .node("a", IncrementNode)
            .node("b", IncrementNode)
            .node("c", IncrementNode)
            .edge_always("a", "b")
            .edge_always("a", "c")
            .edge_always("b", "c")
            .entry("a");

        assert!(graph.build().is_ok());
    }
}
488}