Skip to main content

rust_langgraph/
state.rs

1//! State management and message handling.
2//!
3//! This module defines the `State` trait which is the core abstraction for
4//! graph state, as well as built-in state types like `MessagesState` for
5//! chat applications.
6
7use crate::errors::{Error, Result};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12/// The core trait for graph state.
13///
14/// Implementors of this trait define how state is merged when multiple
15/// nodes produce updates that need to be combined.
16///
17/// # Example
18///
19/// ```rust
20/// use rust_langgraph::{State, Error};
21///
22/// #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
23/// struct CounterState {
24///     count: i32,
25/// }
26///
27/// impl State for CounterState {
28///     fn merge(&mut self, other: Self) -> Result<(), Error> {
29///         self.count += other.count;
30///         Ok(())
31///     }
32/// }
33/// ```
34pub trait State: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + std::fmt::Debug + 'static {
35    /// Merge another state into this one.
36    ///
37    /// This is called when multiple nodes write to the same state,
38    /// or when resuming from a checkpoint.
39    fn merge(&mut self, other: Self) -> Result<()>;
40
41    /// Convert state to JSON value (default implementation)
42    fn to_value(&self) -> Result<serde_json::Value> {
43        serde_json::to_value(self).map_err(Error::from)
44    }
45
46    /// Create state from JSON value (default implementation)
47    fn from_value(value: serde_json::Value) -> Result<Self> {
48        serde_json::from_value(value).map_err(Error::from)
49    }
50}
51
52/// A message in a conversation.
53///
54/// Messages are the core unit of communication in chat applications.
55#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
56pub struct Message {
57    /// The role of the message sender (e.g., "user", "assistant", "system")
58    pub role: String,
59    
60    /// The content of the message
61    pub content: String,
62    
63    /// Optional message name/identifier
64    pub name: Option<String>,
65    
66    /// Optional function/tool call information
67    pub tool_calls: Option<Vec<ToolCall>>,
68    
69    /// Optional tool call ID (for tool responses)
70    pub tool_call_id: Option<String>,
71    
72    /// Additional metadata
73    pub metadata: HashMap<String, serde_json::Value>,
74}
75
76impl Message {
77    /// Create a new user message
78    pub fn user(content: impl Into<String>) -> Self {
79        Self {
80            role: "user".to_string(),
81            content: content.into(),
82            name: None,
83            tool_calls: None,
84            tool_call_id: None,
85            metadata: HashMap::new(),
86        }
87    }
88
89    /// Create a new assistant message
90    pub fn assistant(content: impl Into<String>) -> Self {
91        Self {
92            role: "assistant".to_string(),
93            content: content.into(),
94            name: None,
95            tool_calls: None,
96            tool_call_id: None,
97            metadata: HashMap::new(),
98        }
99    }
100
101    /// Create a new system message
102    pub fn system(content: impl Into<String>) -> Self {
103        Self {
104            role: "system".to_string(),
105            content: content.into(),
106            name: None,
107            tool_calls: None,
108            tool_call_id: None,
109            metadata: HashMap::new(),
110        }
111    }
112
113    /// Create a new tool message (response from a tool)
114    pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
115        Self {
116            role: "tool".to_string(),
117            content: content.into(),
118            name: None,
119            tool_calls: None,
120            tool_call_id: Some(tool_call_id.into()),
121            metadata: HashMap::new(),
122        }
123    }
124
125    /// Add tool calls to this message
126    pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
127        self.tool_calls = Some(tool_calls);
128        self
129    }
130
131    /// Add a name to this message
132    pub fn with_name(mut self, name: impl Into<String>) -> Self {
133        self.name = Some(name.into());
134        self
135    }
136}
137
138/// A tool call in a message.
139#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
140pub struct ToolCall {
141    /// Unique identifier for this tool call
142    pub id: String,
143    
144    /// The name of the tool to call
145    pub name: String,
146    
147    /// Arguments to pass to the tool (as JSON)
148    pub arguments: serde_json::Value,
149}
150
151impl ToolCall {
152    /// Create a new tool call
153    pub fn new(
154        id: impl Into<String>,
155        name: impl Into<String>,
156        arguments: serde_json::Value,
157    ) -> Self {
158        Self {
159            id: id.into(),
160            name: name.into(),
161            arguments,
162        }
163    }
164}
165
166/// State that contains a list of messages.
167///
168/// This is the standard state type for chat applications and agent workflows.
169/// The `add_messages` function provides the reducer logic.
170///
171/// # Example
172///
173/// ```rust
174/// use rust_langgraph::{State, MessagesState, Message};
175///
176/// let mut state = MessagesState {
177///     messages: vec![Message::user("Hello!")],
178/// };
179///
180/// let update = MessagesState {
181///     messages: vec![Message::assistant("Hi there!")],
182/// };
183///
184/// state.merge(update).unwrap();
185/// assert_eq!(state.messages.len(), 2);
186/// ```
187#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct MessagesState {
189    /// The list of messages
190    pub messages: Vec<Message>,
191}
192
193impl State for MessagesState {
194    fn merge(&mut self, other: Self) -> Result<()> {
195        add_messages(&mut self.messages, other.messages);
196        Ok(())
197    }
198}
199
200/// Add messages to an existing list with smart merging.
201///
202/// This function implements the message reduction logic:
203/// - Appends new messages
204/// - Updates existing messages if they have the same ID
205/// - Handles tool calls and responses properly
206///
207/// This is used as the default reducer for `MessagesState`.
208pub fn add_messages(existing: &mut Vec<Message>, new: Vec<Message>) {
209    // For now, simple append logic
210    // In a full implementation, this would handle:
211    // - Deduplication by message ID
212    // - Updating tool call responses
213    // - Merging metadata
214    existing.extend(new);
215}
216
217/// A simple dictionary-based state.
218///
219/// Useful for quick prototypes or when you don't need custom merge logic.
220#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct DictState {
222    /// The state data
223    pub data: HashMap<String, serde_json::Value>,
224}
225
226impl State for DictState {
227    fn merge(&mut self, other: Self) -> Result<()> {
228        // Later values overwrite earlier ones
229        self.data.extend(other.data);
230        Ok(())
231    }
232}
233
234impl DictState {
235    /// Create a new empty dict state
236    pub fn new() -> Self {
237        Self {
238            data: HashMap::new(),
239        }
240    }
241
242    /// Create dict state with initial data
243    pub fn with_data(data: HashMap<String, serde_json::Value>) -> Self {
244        Self { data }
245    }
246
247    /// Get a value
248    pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
249        self.data.get(key)
250    }
251
252    /// Set a value
253    pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
254        self.data.insert(key.into(), value);
255    }
256}
257
258impl Default for DictState {
259    fn default() -> Self {
260        Self::new()
261    }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `State` implementor: merging adds the two counters.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl State for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    #[test]
    fn test_state_merge() {
        let mut state = TestState { count: 5 };
        let other = TestState { count: 3 };

        state.merge(other).unwrap();
        assert_eq!(state.count, 8);
    }

    /// The provided `to_value` / `from_value` methods must round-trip.
    #[test]
    fn test_state_value_round_trip() {
        let state = TestState { count: 7 };

        let value = state.to_value().unwrap();
        assert_eq!(value, serde_json::json!({ "count": 7 }));

        let restored = TestState::from_value(value).unwrap();
        assert_eq!(restored, state);
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Hello");
        assert!(msg.name.is_none());
        assert!(msg.tool_calls.is_none());
        assert!(msg.tool_call_id.is_none());
        assert!(msg.metadata.is_empty());

        let msg = Message::assistant("Hi").with_name("bot");
        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.name.as_deref(), Some("bot"));

        let msg = Message::system("Be brief");
        assert_eq!(msg.role, "system");
        assert_eq!(msg.content, "Be brief");
    }

    #[test]
    fn test_tool_message_creation() {
        let msg = Message::tool("42", "call-1");
        assert_eq!(msg.role, "tool");
        assert_eq!(msg.content, "42");
        assert_eq!(msg.tool_call_id.as_deref(), Some("call-1"));
        assert!(msg.tool_calls.is_none());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let call = ToolCall::new("call-1", "search", serde_json::json!({"query": "rust"}));
        let msg = Message::assistant("").with_tool_calls(vec![call.clone()]);

        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.tool_calls, Some(vec![call]));
    }

    #[test]
    fn test_messages_state() {
        let mut state = MessagesState {
            messages: vec![Message::user("Hello")],
        };

        let update = MessagesState {
            messages: vec![Message::assistant("Hi there!")],
        };

        state.merge(update).unwrap();
        assert_eq!(state.messages.len(), 2);
        assert_eq!(state.messages[0].role, "user");
        assert_eq!(state.messages[1].role, "assistant");
    }

    #[test]
    fn test_dict_state() {
        let mut state = DictState::new();
        state.set("key1", serde_json::json!("value1"));

        let mut other = DictState::new();
        other.set("key2", serde_json::json!(42));

        state.merge(other).unwrap();

        assert_eq!(state.data.len(), 2);
        assert_eq!(state.get("key1").unwrap(), &serde_json::json!("value1"));
        assert_eq!(state.get("key2").unwrap(), &serde_json::json!(42));
        assert!(state.get("missing").is_none());
    }

    /// On a key collision the merged-in value replaces the existing one.
    #[test]
    fn test_dict_state_merge_overwrites() {
        let mut state = DictState::new();
        state.set("key", serde_json::json!(1));

        let mut other = DictState::new();
        other.set("key", serde_json::json!(2));

        state.merge(other).unwrap();

        assert_eq!(state.data.len(), 1);
        assert_eq!(state.get("key"), Some(&serde_json::json!(2)));
    }

    #[test]
    fn test_dict_state_default_is_empty() {
        assert!(DictState::default().data.is_empty());
        assert!(DictState::new().data.is_empty());
    }

    #[test]
    fn test_tool_call() {
        let tool_call = ToolCall::new("call-1", "search", serde_json::json!({"query": "rust"}));
        assert_eq!(tool_call.id, "call-1");
        assert_eq!(tool_call.name, "search");
        assert_eq!(tool_call.arguments, serde_json::json!({"query": "rust"}));
    }
}
336}