Skip to main content

rust_langgraph/graph/
state_graph.rs

1//! StateGraph builder and CompiledGraph.
2//!
3//! StateGraph provides an ergonomic builder API for creating graphs,
4//! which then compile into executable CompiledGraph instances.
5
6use crate::channels::{BaseChannel, LastValue};
7use crate::checkpoint::{BaseCheckpointSaver, CheckpointMetadata, StateSnapshot};
8use crate::config::Config;
9use crate::errors::{Error, Result};
10use crate::graph::{START, END};
11use crate::nodes::{Node, PregelNode, ChannelWrite, NodeArc};
12use crate::pregel::{Branch, Pregel};
13use crate::state::State;
14use crate::types::{StreamEvent, StreamMode};
15use futures::stream::Stream;
16use std::collections::{HashMap, HashSet};
17use std::pin::Pin;
18use std::sync::Arc;
19
/// Builder for creating state-based graphs.
///
/// StateGraph provides a declarative API for building graphs where nodes
/// communicate through shared state. It compiles into a CompiledGraph
/// which can be executed.
///
/// The builder only records what it is given; validation happens in
/// `compile`, which consumes the builder.
///
/// # Example
///
/// ```rust
/// use rust_langgraph::{StateGraph, Config};
/// # use rust_langgraph::{State, Error};
/// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
/// # struct MyState { count: i32 }
/// # impl State for MyState {
/// #     fn merge(&mut self, other: Self) -> Result<(), Error> {
/// #         self.count += other.count;
/// #         Ok(())
/// #     }
/// # }
///
/// let mut graph = StateGraph::new();
///
/// graph.add_node("process", |mut state: MyState, _config: &Config| async move {
///     state.count += 1;
///     Ok(state)
/// });
///
/// graph.set_entry_point("process");
/// graph.set_finish_point("process");
///
/// let app = graph.compile(None).unwrap();
/// ```
pub struct StateGraph<S: State> {
    /// Node implementations keyed by the name passed to `add_node`.
    nodes: HashMap<String, Box<dyn Node<S>>>,
    /// Static edges: source node name -> target node names, in the order
    /// `add_edge` was called.
    edges: HashMap<String, Vec<String>>,
    /// Branch logic keyed by source node name (one branch per source).
    conditional_edges: HashMap<String, Box<dyn Branch<S>>>,
    /// Name of the entry node; `None` until `set_entry_point` is called.
    entry_point: Option<String>,
    /// Names registered through `set_finish_point` / `add_finish_points`.
    finish_points: HashSet<String>,
}
59
60impl<S: State> StateGraph<S> {
61    /// Create a new StateGraph
62    pub fn new() -> Self {
63        Self {
64            nodes: HashMap::new(),
65            edges: HashMap::new(),
66            conditional_edges: HashMap::new(),
67            entry_point: None,
68            finish_points: HashSet::new(),
69        }
70    }
71
72    /// Add a node to the graph
73    ///
74    /// # Arguments
75    ///
76    /// * `name` - Unique identifier for this node
77    /// * `node` - The node implementation (function or struct implementing Node)
78    ///
79    /// # Example
80    ///
81    /// ```rust
82    /// # use rust_langgraph::{StateGraph, Config, State, Error};
83    /// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
84    /// # struct MyState { count: i32 }
85    /// # impl State for MyState {
86    /// #     fn merge(&mut self, other: Self) -> Result<(), Error> { Ok(()) }
87    /// # }
88    /// let mut graph = StateGraph::new();
89    ///
90    /// graph.add_node("increment", |mut state: MyState, _config: &Config| async move {
91    ///     state.count += 1;
92    ///     Ok(state)
93    /// });
94    /// ```
95    pub fn add_node(&mut self, name: impl Into<String>, node: impl Node<S> + 'static) -> &mut Self {
96        self.nodes.insert(name.into(), Box::new(node));
97        self
98    }
99
100    /// Add a static edge from one node to another
101    ///
102    /// # Arguments
103    ///
104    /// * `from` - Source node name
105    /// * `to` - Target node name
106    pub fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
107        let from = from.into();
108        let to = to.into();
109
110        self.edges.entry(from).or_default().push(to);
111        self
112    }
113
114    /// Add conditional edges from a source node
115    ///
116    /// The branch function examines state and returns which node(s) to route to next.
117    ///
118    /// # Arguments
119    ///
120    /// * `source` - Source node name
121    /// * `branch` - Branch logic that determines routing
122    ///
123    /// # Example
124    ///
125    /// ```rust
126    /// # use rust_langgraph::{StateGraph, Config, State, Error};
127    /// # use rust_langgraph::pregel::BranchResult;
128    /// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
129    /// # struct MyState { value: i32 }
130    /// # impl State for MyState {
131    /// #     fn merge(&mut self, other: Self) -> Result<(), Error> { Ok(()) }
132    /// # }
133    /// let mut graph = StateGraph::new();
134    ///
135    /// graph.add_conditional_edges(
136    ///     "check",
137    ///     |state: &MyState| async move {
138    ///         if state.value > 0 {
139    ///             Ok(BranchResult::single("positive"))
140    ///         } else {
141    ///             Ok(BranchResult::single("negative"))
142    ///         }
143    ///     }
144    /// );
145    /// ```
146    pub fn add_conditional_edges(
147        &mut self,
148        source: impl Into<String>,
149        branch: impl Branch<S> + 'static,
150    ) -> &mut Self {
151        self.conditional_edges.insert(source.into(), Box::new(branch));
152        self
153    }
154
155    /// Set the entry point for the graph
156    ///
157    /// This is the first node to execute when the graph is invoked.
158    pub fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self {
159        self.entry_point = Some(node.into());
160        self
161    }
162
163    /// Set a finish point for the graph
164    ///
165    /// When execution reaches a finish point, the graph completes.
166    pub fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self {
167        self.finish_points.insert(node.into());
168        self
169    }
170
171    /// Add multiple finish points
172    pub fn add_finish_points(&mut self, nodes: Vec<impl Into<String>>) -> &mut Self {
173        for node in nodes {
174            self.finish_points.insert(node.into());
175        }
176        self
177    }
178
179    /// Compile the graph into an executable CompiledGraph
180    ///
181    /// # Arguments
182    ///
183    /// * `checkpointer` - Optional checkpoint saver for persistence
184    ///
185    /// # Returns
186    ///
187    /// A CompiledGraph ready to execute
188    ///
189    /// # Errors
190    ///
191    /// Returns an error if the graph configuration is invalid (e.g., no entry point)
192    pub fn compile(
193        self,
194        checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
195    ) -> Result<CompiledGraph<S>> {
196        // Validate graph
197        if self.entry_point.is_none() {
198            return Err(Error::invalid_graph("No entry point set"));
199        }
200
201        let entry_point = self.entry_point.unwrap();
202
203        if !self.nodes.contains_key(&entry_point) {
204            return Err(Error::invalid_graph(format!(
205                "Entry point '{}' is not a valid node",
206                entry_point
207            )));
208        }
209
210        // Build PregelNodes from our nodes
211        let mut pregel_nodes = HashMap::new();
212
213        for (name, node) in self.nodes {
214            // Determine triggers: nodes triggered by their dependencies
215            let mut triggers = vec![];
216
217            // If this is the entry point, trigger on START
218            if name == entry_point {
219                triggers.push(START.to_string());
220            }
221
222            // Add triggers from incoming edges
223            for (source, targets) in &self.edges {
224                if targets.contains(&name) {
225                    triggers.push(format!("{}_output", source));
226                }
227            }
228
229            // If no triggers, add a default
230            if triggers.is_empty() {
231                triggers.push(format!("{}_input", name));
232            }
233
234            let pregel_node = PregelNode::new(
235                name.clone(),
236                vec![format!("{}_input", name)],
237                triggers,
238                Arc::from(node) as NodeArc<S>,
239                vec![ChannelWrite::new(format!("{}_output", name))],
240            );
241
242            pregel_nodes.insert(name, pregel_node);
243        }
244
245        // Create channels
246        let mut channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
247
248        // Add START and END channels
249        channels.insert(START.to_string(), Box::new(LastValue::<S>::new()));
250        channels.insert(END.to_string(), Box::new(LastValue::<S>::new()));
251
252        // Add channels for each node
253        for node_name in pregel_nodes.keys() {
254            channels.insert(
255                format!("{}_input", node_name),
256                Box::new(LastValue::<S>::new()),
257            );
258            channels.insert(
259                format!("{}_output", node_name),
260                Box::new(LastValue::<S>::new()),
261            );
262        }
263
264        // Create the Pregel engine
265        let pregel = Pregel::new(
266            pregel_nodes,
267            channels,
268            checkpointer.clone(),
269            entry_point.clone(),
270            self.finish_points.clone(),
271            self.edges.clone(),
272        );
273
274        Ok(CompiledGraph {
275            pregel,
276            entry_point,
277            finish_points: self.finish_points,
278            checkpointer,
279        })
280    }
281}
282
283impl<S: State> Default for StateGraph<S> {
284    fn default() -> Self {
285        Self::new()
286    }
287}
288
/// A compiled, executable graph.
///
/// CompiledGraph is the result of compiling a StateGraph. It provides
/// methods to execute the graph with different invocation patterns.
pub struct CompiledGraph<S: State> {
    /// Execution engine that `invoke`, `stream`, `get_state` and
    /// `get_state_history` delegate to.
    pregel: Pregel<S>,
    /// Name of the entry node, as validated by `StateGraph::compile`.
    entry_point: String,
    /// Finish-point names carried over from the builder.
    finish_points: HashSet<String>,
    /// Optional checkpoint saver; required by `update_state`.
    checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
}
299
300impl<S: State> CompiledGraph<S> {
301    /// Execute the graph with the given input
302    ///
303    /// This runs the graph to completion and returns the final state.
304    ///
305    /// # Arguments
306    ///
307    /// * `input` - Initial state
308    /// * `config` - Execution configuration
309    ///
310    /// # Returns
311    ///
312    /// The final state after graph execution
313    ///
314    /// # Example
315    ///
316    /// ```rust,no_run
317    /// # use rust_langgraph::{StateGraph, Config, State, Error};
318    /// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
319    /// # struct MyState { count: i32 }
320    /// # impl State for MyState {
321    /// #     fn merge(&mut self, other: Self) -> Result<(), Error> { Ok(()) }
322    /// # }
323    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
324    /// # let mut graph = StateGraph::new();
325    /// # graph.add_node("test", |s: MyState, _| async move { Ok(s) });
326    /// # graph.set_entry_point("test");
327    /// # graph.set_finish_point("test");
328    /// let app = graph.compile(None)?;
329    ///
330    /// let result = app.invoke(
331    ///     MyState { count: 0 },
332    ///     Config::default()
333    /// ).await?;
334    ///
335    /// println!("Final count: {}", result.count);
336    /// # Ok(())
337    /// # }
338    /// ```
339    pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
340        self.pregel.invoke(input, config).await
341    }
342
343    /// Stream execution events
344    ///
345    /// Returns a stream of events as the graph executes, allowing
346    /// real-time observation of progress.
347    ///
348    /// # Arguments
349    ///
350    /// * `input` - Initial state
351    /// * `config` - Execution configuration
352    /// * `mode` - Type of events to stream
353    pub async fn stream(
354        &mut self,
355        input: S,
356        config: Config,
357        mode: StreamMode,
358    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
359        self.pregel.stream(input, config, mode).await
360    }
361
362    /// Get the current state for a given configuration
363    ///
364    /// This retrieves the most recent checkpoint for the thread.
365    pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
366        self.pregel.get_state(config).await
367    }
368
369    /// Get the state history for a thread
370    ///
371    /// Returns past checkpoints in reverse chronological order.
372    pub async fn get_state_history(
373        &self,
374        config: &Config,
375        limit: Option<usize>,
376    ) -> Result<Vec<StateSnapshot<S>>> {
377        self.pregel.get_state_history(config, limit).await
378    }
379
380    /// Update the state for a thread
381    ///
382    /// This allows modifying the checkpoint state, useful for
383    /// human-in-the-loop patterns.
384    pub async fn update_state(&mut self, config: Config, values: S) -> Result<Config> {
385        if let Some(checkpointer) = &self.checkpointer {
386            // Get the current checkpoint
387            let mut tuple = checkpointer
388                .get_tuple(&config)
389                .await?
390                .ok_or_else(|| Error::checkpoint("No checkpoint found for config"))?;
391
392            // Update the state
393            let mut current_state = S::from_value(
394                tuple
395                    .checkpoint
396                    .get_channel("__start__")
397                    .ok_or_else(|| Error::checkpoint("No state in checkpoint"))?
398                    .clone(),
399            )?;
400
401            current_state.merge(values)?;
402
403            // Create new checkpoint
404            tuple.checkpoint.set_channel("__start__", current_state.to_value()?);
405
406            // Save the updated checkpoint
407            let metadata = CheckpointMetadata {
408                step: tuple.metadata.step + 1,
409                source: "update_state".to_string(),
410                created_at: chrono::Utc::now(),
411                extra: HashMap::new(),
412            };
413
414            checkpointer.put(&tuple.checkpoint, &metadata, &config).await
415        } else {
416            Err(Error::checkpoint("No checkpointer configured"))
417        }
418    }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint_backends::memory::MemorySaver;
    use serde::{Deserialize, Serialize};

    /// Minimal state used by the tests: a single counter.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl crate::state::State for TestState {
        /// Merging adds the other counter onto this one.
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    /// A single node that is both entry and finish point runs once.
    #[tokio::test]
    async fn test_state_graph_basic() {
        let mut graph = StateGraph::new();
        graph
            .add_node("increment", |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            })
            .set_entry_point("increment")
            .set_finish_point("increment");

        let mut app = graph.compile(None).unwrap();

        let initial = TestState { count: 0 };
        let result = app.invoke(initial, Config::default()).await.unwrap();
        assert_eq!(result.count, 1);
    }

    /// Two nodes joined by a static edge run in sequence.
    #[tokio::test]
    async fn test_state_graph_chain() {
        let mut graph = StateGraph::new();
        graph
            .add_node("add_one", |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            })
            .add_node("multiply_two", |mut state: TestState, _config: &Config| async move {
                state.count *= 2;
                Ok(state)
            })
            .set_entry_point("add_one")
            .add_edge("add_one", "multiply_two")
            .set_finish_point("multiply_two");

        let mut app = graph.compile(None).unwrap();

        let initial = TestState { count: 5 };
        let result = app.invoke(initial, Config::default()).await.unwrap();
        assert_eq!(result.count, 12); // (5 + 1) * 2
    }

    /// With a checkpointer attached, a run leaves a retrievable snapshot.
    #[tokio::test]
    async fn test_state_graph_with_checkpointer() {
        let mut graph = StateGraph::new();
        graph
            .add_node("increment", |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            })
            .set_entry_point("increment")
            .set_finish_point("increment");

        let checkpointer = Arc::new(MemorySaver::new());
        let mut app = graph.compile(Some(checkpointer)).unwrap();

        let config = Config::new().with_thread_id("test-123");
        let result = app.invoke(TestState { count: 0 }, config.clone()).await.unwrap();
        assert_eq!(result.count, 1);

        // Check that checkpoint was saved
        let snapshot = app.get_state(&config).await.unwrap();
        assert!(snapshot.is_some());
    }

    /// Compiling without an entry point is rejected with a descriptive error.
    #[test]
    fn test_state_graph_no_entry_point() {
        let mut graph = StateGraph::<TestState>::new();
        graph.add_node("test", |s: TestState, _config: &Config| async move { Ok(s) });

        let err = graph
            .compile(None)
            .err()
            .expect("compile must fail when no entry point is set");
        assert!(err.to_string().contains("entry point"));
    }
}