rust_langgraph/graph/state_graph.rs
1//! StateGraph builder and CompiledGraph.
2//!
3//! StateGraph provides an ergonomic builder API for creating graphs,
4//! which then compile into executable CompiledGraph instances.
5
6use crate::channels::{BaseChannel, LastValue};
7use crate::checkpoint::{BaseCheckpointSaver, CheckpointMetadata, StateSnapshot};
8use crate::config::Config;
9use crate::errors::{Error, Result};
10use crate::graph::{START, END};
11use crate::nodes::{Node, PregelNode, ChannelWrite, NodeArc};
12use crate::pregel::{Branch, Pregel};
13use crate::state::State;
14use crate::types::{StreamEvent, StreamMode};
15use futures::stream::Stream;
16use std::collections::{HashMap, HashSet};
17use std::pin::Pin;
18use std::sync::Arc;
19
20/// Builder for creating state-based graphs.
21///
22/// StateGraph provides a declarative API for building graphs where nodes
23/// communicate through shared state. It compiles into a CompiledGraph
24/// which can be executed.
25///
26/// # Example
27///
28/// ```rust
29/// use rust_langgraph::{StateGraph, Config};
30/// # use rust_langgraph::{State, Error};
31/// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
32/// # struct MyState { count: i32 }
33/// # impl State for MyState {
34/// # fn merge(&mut self, other: Self) -> Result<(), Error> {
35/// # self.count += other.count;
36/// # Ok(())
37/// # }
38/// # }
39///
40/// let mut graph = StateGraph::new();
41///
42/// graph.add_node("process", |mut state: MyState, _config: &Config| async move {
43/// state.count += 1;
44/// Ok(state)
45/// });
46///
47/// graph.set_entry_point("process");
48/// graph.set_finish_point("process");
49///
50/// let app = graph.compile(None).unwrap();
51/// ```
52pub struct StateGraph<S: State> {
53 nodes: HashMap<String, Box<dyn Node<S>>>,
54 edges: HashMap<String, Vec<String>>,
55 conditional_edges: HashMap<String, Box<dyn Branch<S>>>,
56 entry_point: Option<String>,
57 finish_points: HashSet<String>,
58}
59
60impl<S: State> StateGraph<S> {
61 /// Create a new StateGraph
62 pub fn new() -> Self {
63 Self {
64 nodes: HashMap::new(),
65 edges: HashMap::new(),
66 conditional_edges: HashMap::new(),
67 entry_point: None,
68 finish_points: HashSet::new(),
69 }
70 }
71
72 /// Add a node to the graph
73 ///
74 /// # Arguments
75 ///
76 /// * `name` - Unique identifier for this node
77 /// * `node` - The node implementation (function or struct implementing Node)
78 ///
79 /// # Example
80 ///
81 /// ```rust
82 /// # use rust_langgraph::{StateGraph, Config, State, Error};
83 /// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
84 /// # struct MyState { count: i32 }
85 /// # impl State for MyState {
86 /// # fn merge(&mut self, other: Self) -> Result<(), Error> { Ok(()) }
87 /// # }
88 /// let mut graph = StateGraph::new();
89 ///
90 /// graph.add_node("increment", |mut state: MyState, _config: &Config| async move {
91 /// state.count += 1;
92 /// Ok(state)
93 /// });
94 /// ```
95 pub fn add_node(&mut self, name: impl Into<String>, node: impl Node<S> + 'static) -> &mut Self {
96 self.nodes.insert(name.into(), Box::new(node));
97 self
98 }
99
100 /// Add a static edge from one node to another
101 ///
102 /// # Arguments
103 ///
104 /// * `from` - Source node name
105 /// * `to` - Target node name
106 pub fn add_edge(&mut self, from: impl Into<String>, to: impl Into<String>) -> &mut Self {
107 let from = from.into();
108 let to = to.into();
109
110 self.edges.entry(from).or_default().push(to);
111 self
112 }
113
114 /// Add conditional edges from a source node
115 ///
116 /// The branch function examines state and returns which node(s) to route to next.
117 ///
118 /// # Arguments
119 ///
120 /// * `source` - Source node name
121 /// * `branch` - Branch logic that determines routing
122 ///
123 /// # Example
124 ///
125 /// ```rust
126 /// # use rust_langgraph::{StateGraph, Config, State, Error};
127 /// # use rust_langgraph::pregel::BranchResult;
128 /// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
129 /// # struct MyState { value: i32 }
130 /// # impl State for MyState {
131 /// # fn merge(&mut self, other: Self) -> Result<(), Error> { Ok(()) }
132 /// # }
133 /// let mut graph = StateGraph::new();
134 ///
135 /// graph.add_conditional_edges(
136 /// "check",
137 /// |state: &MyState| async move {
138 /// if state.value > 0 {
139 /// Ok(BranchResult::single("positive"))
140 /// } else {
141 /// Ok(BranchResult::single("negative"))
142 /// }
143 /// }
144 /// );
145 /// ```
146 pub fn add_conditional_edges(
147 &mut self,
148 source: impl Into<String>,
149 branch: impl Branch<S> + 'static,
150 ) -> &mut Self {
151 self.conditional_edges.insert(source.into(), Box::new(branch));
152 self
153 }
154
155 /// Set the entry point for the graph
156 ///
157 /// This is the first node to execute when the graph is invoked.
158 pub fn set_entry_point(&mut self, node: impl Into<String>) -> &mut Self {
159 self.entry_point = Some(node.into());
160 self
161 }
162
163 /// Set a finish point for the graph
164 ///
165 /// When execution reaches a finish point, the graph completes.
166 pub fn set_finish_point(&mut self, node: impl Into<String>) -> &mut Self {
167 self.finish_points.insert(node.into());
168 self
169 }
170
171 /// Add multiple finish points
172 pub fn add_finish_points(&mut self, nodes: Vec<impl Into<String>>) -> &mut Self {
173 for node in nodes {
174 self.finish_points.insert(node.into());
175 }
176 self
177 }
178
179 /// Compile the graph into an executable CompiledGraph
180 ///
181 /// # Arguments
182 ///
183 /// * `checkpointer` - Optional checkpoint saver for persistence
184 ///
185 /// # Returns
186 ///
187 /// A CompiledGraph ready to execute
188 ///
189 /// # Errors
190 ///
191 /// Returns an error if the graph configuration is invalid (e.g., no entry point)
192 pub fn compile(
193 self,
194 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
195 ) -> Result<CompiledGraph<S>> {
196 // Validate graph
197 if self.entry_point.is_none() {
198 return Err(Error::invalid_graph("No entry point set"));
199 }
200
201 let entry_point = self.entry_point.unwrap();
202
203 if !self.nodes.contains_key(&entry_point) {
204 return Err(Error::invalid_graph(format!(
205 "Entry point '{}' is not a valid node",
206 entry_point
207 )));
208 }
209
210 // Build PregelNodes from our nodes
211 let mut pregel_nodes = HashMap::new();
212
213 for (name, node) in self.nodes {
214 // Determine triggers: nodes triggered by their dependencies
215 let mut triggers = vec![];
216
217 // If this is the entry point, trigger on START
218 if name == entry_point {
219 triggers.push(START.to_string());
220 }
221
222 // Add triggers from incoming edges
223 for (source, targets) in &self.edges {
224 if targets.contains(&name) {
225 triggers.push(format!("{}_output", source));
226 }
227 }
228
229 // If no triggers, add a default
230 if triggers.is_empty() {
231 triggers.push(format!("{}_input", name));
232 }
233
234 let pregel_node = PregelNode::new(
235 name.clone(),
236 vec![format!("{}_input", name)],
237 triggers,
238 Arc::from(node) as NodeArc<S>,
239 vec![ChannelWrite::new(format!("{}_output", name))],
240 );
241
242 pregel_nodes.insert(name, pregel_node);
243 }
244
245 // Create channels
246 let mut channels: HashMap<String, Box<dyn BaseChannel>> = HashMap::new();
247
248 // Add START and END channels
249 channels.insert(START.to_string(), Box::new(LastValue::<S>::new()));
250 channels.insert(END.to_string(), Box::new(LastValue::<S>::new()));
251
252 // Add channels for each node
253 for node_name in pregel_nodes.keys() {
254 channels.insert(
255 format!("{}_input", node_name),
256 Box::new(LastValue::<S>::new()),
257 );
258 channels.insert(
259 format!("{}_output", node_name),
260 Box::new(LastValue::<S>::new()),
261 );
262 }
263
264 // Create the Pregel engine
265 let pregel = Pregel::new(
266 pregel_nodes,
267 channels,
268 checkpointer.clone(),
269 entry_point.clone(),
270 self.finish_points.clone(),
271 self.edges.clone(),
272 );
273
274 Ok(CompiledGraph {
275 pregel,
276 entry_point,
277 finish_points: self.finish_points,
278 checkpointer,
279 })
280 }
281}
282
283impl<S: State> Default for StateGraph<S> {
284 fn default() -> Self {
285 Self::new()
286 }
287}
288
289/// A compiled, executable graph.
290///
291/// CompiledGraph is the result of compiling a StateGraph. It provides
292/// methods to execute the graph with different invocation patterns.
293pub struct CompiledGraph<S: State> {
294 pregel: Pregel<S>,
295 entry_point: String,
296 finish_points: HashSet<String>,
297 checkpointer: Option<Arc<dyn BaseCheckpointSaver>>,
298}
299
300impl<S: State> CompiledGraph<S> {
301 /// Execute the graph with the given input
302 ///
303 /// This runs the graph to completion and returns the final state.
304 ///
305 /// # Arguments
306 ///
307 /// * `input` - Initial state
308 /// * `config` - Execution configuration
309 ///
310 /// # Returns
311 ///
312 /// The final state after graph execution
313 ///
314 /// # Example
315 ///
316 /// ```rust,no_run
317 /// # use rust_langgraph::{StateGraph, Config, State, Error};
318 /// # #[derive(Clone, serde::Serialize, serde::Deserialize)]
319 /// # struct MyState { count: i32 }
320 /// # impl State for MyState {
321 /// # fn merge(&mut self, other: Self) -> Result<(), Error> { Ok(()) }
322 /// # }
323 /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
324 /// # let mut graph = StateGraph::new();
325 /// # graph.add_node("test", |s: MyState, _| async move { Ok(s) });
326 /// # graph.set_entry_point("test");
327 /// # graph.set_finish_point("test");
328 /// let app = graph.compile(None)?;
329 ///
330 /// let result = app.invoke(
331 /// MyState { count: 0 },
332 /// Config::default()
333 /// ).await?;
334 ///
335 /// println!("Final count: {}", result.count);
336 /// # Ok(())
337 /// # }
338 /// ```
339 pub async fn invoke(&mut self, input: S, config: Config) -> Result<S> {
340 self.pregel.invoke(input, config).await
341 }
342
343 /// Stream execution events
344 ///
345 /// Returns a stream of events as the graph executes, allowing
346 /// real-time observation of progress.
347 ///
348 /// # Arguments
349 ///
350 /// * `input` - Initial state
351 /// * `config` - Execution configuration
352 /// * `mode` - Type of events to stream
353 pub async fn stream(
354 &mut self,
355 input: S,
356 config: Config,
357 mode: StreamMode,
358 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
359 self.pregel.stream(input, config, mode).await
360 }
361
362 /// Get the current state for a given configuration
363 ///
364 /// This retrieves the most recent checkpoint for the thread.
365 pub async fn get_state(&self, config: &Config) -> Result<Option<StateSnapshot<S>>> {
366 self.pregel.get_state(config).await
367 }
368
369 /// Get the state history for a thread
370 ///
371 /// Returns past checkpoints in reverse chronological order.
372 pub async fn get_state_history(
373 &self,
374 config: &Config,
375 limit: Option<usize>,
376 ) -> Result<Vec<StateSnapshot<S>>> {
377 self.pregel.get_state_history(config, limit).await
378 }
379
380 /// Update the state for a thread
381 ///
382 /// This allows modifying the checkpoint state, useful for
383 /// human-in-the-loop patterns.
384 pub async fn update_state(&mut self, config: Config, values: S) -> Result<Config> {
385 if let Some(checkpointer) = &self.checkpointer {
386 // Get the current checkpoint
387 let mut tuple = checkpointer
388 .get_tuple(&config)
389 .await?
390 .ok_or_else(|| Error::checkpoint("No checkpoint found for config"))?;
391
392 // Update the state
393 let mut current_state = S::from_value(
394 tuple
395 .checkpoint
396 .get_channel("__start__")
397 .ok_or_else(|| Error::checkpoint("No state in checkpoint"))?
398 .clone(),
399 )?;
400
401 current_state.merge(values)?;
402
403 // Create new checkpoint
404 tuple.checkpoint.set_channel("__start__", current_state.to_value()?);
405
406 // Save the updated checkpoint
407 let metadata = CheckpointMetadata {
408 step: tuple.metadata.step + 1,
409 source: "update_state".to_string(),
410 created_at: chrono::Utc::now(),
411 extra: HashMap::new(),
412 };
413
414 checkpointer.put(&tuple.checkpoint, &metadata, &config).await
415 } else {
416 Err(Error::checkpoint("No checkpointer configured"))
417 }
418 }
419}
420
#[cfg(test)]
mod tests {
    use super::*;
    use crate::checkpoint_backends::memory::MemorySaver;
    use serde::{Deserialize, Serialize};

    /// Minimal state for these tests: a counter whose merge adds the counts.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl crate::state::State for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    /// A single node that is both entry and finish point runs exactly once.
    #[tokio::test]
    async fn test_state_graph_basic() {
        let mut builder = StateGraph::new();
        builder
            .add_node("increment", |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            })
            .set_entry_point("increment")
            .set_finish_point("increment");

        let mut compiled = builder.compile(None).unwrap();

        let outcome = compiled
            .invoke(TestState { count: 0 }, Config::default())
            .await
            .unwrap();
        assert_eq!(outcome.count, 1);
    }

    /// Two nodes joined by a static edge run in sequence.
    #[tokio::test]
    async fn test_state_graph_chain() {
        let mut builder = StateGraph::new();
        builder
            .add_node("add_one", |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            })
            .add_node("multiply_two", |mut state: TestState, _config: &Config| async move {
                state.count *= 2;
                Ok(state)
            })
            .set_entry_point("add_one")
            .add_edge("add_one", "multiply_two")
            .set_finish_point("multiply_two");

        let mut compiled = builder.compile(None).unwrap();

        let outcome = compiled
            .invoke(TestState { count: 5 }, Config::default())
            .await
            .unwrap();
        assert_eq!(outcome.count, 12); // (5 + 1) * 2
    }

    /// With a checkpointer attached, a snapshot exists after the run.
    #[tokio::test]
    async fn test_state_graph_with_checkpointer() {
        let mut builder = StateGraph::new();
        builder
            .add_node("increment", |mut state: TestState, _config: &Config| async move {
                state.count += 1;
                Ok(state)
            })
            .set_entry_point("increment")
            .set_finish_point("increment");

        let saver = Arc::new(MemorySaver::new());
        let mut compiled = builder.compile(Some(saver)).unwrap();

        let thread_config = Config::new().with_thread_id("test-123");
        let outcome = compiled
            .invoke(TestState { count: 0 }, thread_config.clone())
            .await
            .unwrap();
        assert_eq!(outcome.count, 1);

        // Check that checkpoint was saved
        let snapshot = compiled.get_state(&thread_config).await.unwrap();
        assert!(snapshot.is_some());
    }

    /// Compiling without an entry point is rejected with a descriptive error.
    #[test]
    fn test_state_graph_no_entry_point() {
        let mut builder = StateGraph::<TestState>::new();
        builder.add_node("test", |s: TestState, _config: &Config| async move { Ok(s) });

        // `unwrap_err` is unavailable because CompiledGraph has no Debug impl.
        match builder.compile(None) {
            Ok(_) => panic!("compile must fail when no entry point is set"),
            Err(err) => assert!(err.to_string().contains("entry point")),
        }
    }
}
516}