Skip to main content

rust_langgraph/
config.rs

1//! Configuration for graph execution.
2//!
3//! The `Config` type contains settings that control how graphs execute,
4//! including checkpointing, recursion limits, and metadata.
5
6use serde::{Deserialize, Serialize};
7use std::collections::HashMap;
8
9/// Configuration for graph execution.
10///
11/// Config controls various aspects of graph execution including which
12/// checkpoint to load, recursion limits, and custom metadata.
13///
14/// # Example
15///
16/// ```rust
17/// use rust_langgraph::Config;
18///
19/// let config = Config::new()
20///     .with_thread_id("user-123")
21///     .with_recursion_limit(100)
22///     .with_metadata("user_name", "Alice");
23/// ```
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct Config {
26    /// Thread ID for checkpoint isolation
27    pub thread_id: Option<String>,
28    
29    /// Specific checkpoint ID to load (for time travel)
30    pub checkpoint_id: Option<String>,
31    
32    /// Maximum recursion depth before error
33    pub recursion_limit: usize,
34    
35    /// Custom metadata
36    pub metadata: HashMap<String, serde_json::Value>,
37    
38    /// Tags for categorizing runs
39    pub tags: Vec<String>,
40}
41
42impl Default for Config {
43    fn default() -> Self {
44        Self {
45            thread_id: None,
46            checkpoint_id: None,
47            recursion_limit: 25,
48            metadata: HashMap::new(),
49            tags: Vec::new(),
50        }
51    }
52}
53
54impl Config {
55    /// Create a new default configuration
56    pub fn new() -> Self {
57        Self::default()
58    }
59
60    /// Set the thread ID for checkpoint isolation
61    pub fn with_thread_id(mut self, thread_id: impl Into<String>) -> Self {
62        self.thread_id = Some(thread_id.into());
63        self
64    }
65
66    /// Set a specific checkpoint ID to load (for time travel)
67    pub fn with_checkpoint_id(mut self, checkpoint_id: impl Into<String>) -> Self {
68        self.checkpoint_id = Some(checkpoint_id.into());
69        self
70    }
71
72    /// Set the recursion limit
73    pub fn with_recursion_limit(mut self, limit: usize) -> Self {
74        self.recursion_limit = limit;
75        self
76    }
77
78    /// Add metadata
79    pub fn with_metadata(
80        mut self,
81        key: impl Into<String>,
82        value: impl Into<serde_json::Value>,
83    ) -> Self {
84        self.metadata.insert(key.into(), value.into());
85        self
86    }
87
88    /// Add a tag
89    pub fn with_tag(mut self, tag: impl Into<String>) -> Self {
90        self.tags.push(tag.into());
91        self
92    }
93
94    /// Get or create a thread ID
95    pub fn ensure_thread_id(&mut self) -> &str {
96        if self.thread_id.is_none() {
97            self.thread_id = Some(uuid::Uuid::new_v4().to_string());
98        }
99        self.thread_id.as_ref().unwrap()
100    }
101}
102
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained builder calls should each leave their value on the result.
    #[test]
    fn test_config_builder() {
        let built = Config::new()
            .with_thread_id("test-thread")
            .with_recursion_limit(100)
            .with_metadata("key", "value")
            .with_tag("test");

        assert_eq!(Some("test-thread"), built.thread_id.as_deref());
        assert_eq!(100, built.recursion_limit);
        assert_eq!(1, built.metadata.len());
        assert_eq!(1, built.tags.len());
    }

    /// The default configuration has limit 25, no thread ID, and no metadata.
    #[test]
    fn test_default_config() {
        let Config {
            recursion_limit,
            thread_id,
            metadata,
            ..
        } = Config::default();

        assert_eq!(25, recursion_limit);
        assert!(thread_id.is_none());
        assert!(metadata.is_empty());
    }

    /// `ensure_thread_id` fills in a missing ID and then keeps returning it.
    #[test]
    fn test_ensure_thread_id() {
        let mut cfg = Config::new();
        assert!(cfg.thread_id.is_none());

        let first = cfg.ensure_thread_id().to_owned();
        assert!(!first.is_empty());

        // A second call must hand back the ID generated by the first one.
        let second = cfg.ensure_thread_id().to_owned();
        assert_eq!(first, second);
    }

    /// A config survives a JSON round trip with its fields intact.
    #[test]
    fn test_config_serialization() {
        let original = Config::new()
            .with_thread_id("test")
            .with_recursion_limit(50);

        let encoded = serde_json::to_string(&original).unwrap();
        let decoded: Config = serde_json::from_str(&encoded).unwrap();

        assert_eq!(decoded.thread_id, original.thread_id);
        assert_eq!(decoded.recursion_limit, original.recursion_limit);
    }
}