1use crate::errors::{Error, Result};
8use serde::{Deserialize, Serialize};
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12pub trait State: Clone + Send + Sync + Serialize + for<'de> Deserialize<'de> + std::fmt::Debug + 'static {
35 fn merge(&mut self, other: Self) -> Result<()>;
40
41 fn to_value(&self) -> Result<serde_json::Value> {
43 serde_json::to_value(self).map_err(Error::from)
44 }
45
46 fn from_value(value: serde_json::Value) -> Result<Self> {
48 serde_json::from_value(value).map_err(Error::from)
49 }
50}
51
52#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
56pub struct Message {
57 pub role: String,
59
60 pub content: String,
62
63 pub name: Option<String>,
65
66 pub tool_calls: Option<Vec<ToolCall>>,
68
69 pub tool_call_id: Option<String>,
71
72 pub metadata: HashMap<String, serde_json::Value>,
74}
75
76impl Message {
77 pub fn user(content: impl Into<String>) -> Self {
79 Self {
80 role: "user".to_string(),
81 content: content.into(),
82 name: None,
83 tool_calls: None,
84 tool_call_id: None,
85 metadata: HashMap::new(),
86 }
87 }
88
89 pub fn assistant(content: impl Into<String>) -> Self {
91 Self {
92 role: "assistant".to_string(),
93 content: content.into(),
94 name: None,
95 tool_calls: None,
96 tool_call_id: None,
97 metadata: HashMap::new(),
98 }
99 }
100
101 pub fn system(content: impl Into<String>) -> Self {
103 Self {
104 role: "system".to_string(),
105 content: content.into(),
106 name: None,
107 tool_calls: None,
108 tool_call_id: None,
109 metadata: HashMap::new(),
110 }
111 }
112
113 pub fn tool(content: impl Into<String>, tool_call_id: impl Into<String>) -> Self {
115 Self {
116 role: "tool".to_string(),
117 content: content.into(),
118 name: None,
119 tool_calls: None,
120 tool_call_id: Some(tool_call_id.into()),
121 metadata: HashMap::new(),
122 }
123 }
124
125 pub fn with_tool_calls(mut self, tool_calls: Vec<ToolCall>) -> Self {
127 self.tool_calls = Some(tool_calls);
128 self
129 }
130
131 pub fn with_name(mut self, name: impl Into<String>) -> Self {
133 self.name = Some(name.into());
134 self
135 }
136}
137
/// A tool call: an identifier, the name of the tool, and its JSON arguments.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Identifier of this call.
    // NOTE(review): presumably the value echoed back in `Message::tool_call_id` —
    // confirm against callers.
    pub id: String,

    /// Name of the tool being called.
    pub name: String,

    /// Arguments of the call as an arbitrary JSON value.
    pub arguments: serde_json::Value,
}
150
151impl ToolCall {
152 pub fn new(
154 id: impl Into<String>,
155 name: impl Into<String>,
156 arguments: serde_json::Value,
157 ) -> Self {
158 Self {
159 id: id.into(),
160 name: name.into(),
161 arguments,
162 }
163 }
164}
165
166#[derive(Debug, Clone, Serialize, Deserialize)]
188pub struct MessagesState {
189 pub messages: Vec<Message>,
191}
192
193impl State for MessagesState {
194 fn merge(&mut self, other: Self) -> Result<()> {
195 add_messages(&mut self.messages, other.messages);
196 Ok(())
197 }
198}
199
200pub fn add_messages(existing: &mut Vec<Message>, new: Vec<Message>) {
209 existing.extend(new);
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
221pub struct DictState {
222 pub data: HashMap<String, serde_json::Value>,
224}
225
226impl State for DictState {
227 fn merge(&mut self, other: Self) -> Result<()> {
228 self.data.extend(other.data);
230 Ok(())
231 }
232}
233
234impl DictState {
235 pub fn new() -> Self {
237 Self {
238 data: HashMap::new(),
239 }
240 }
241
242 pub fn with_data(data: HashMap<String, serde_json::Value>) -> Self {
244 Self { data }
245 }
246
247 pub fn get(&self, key: &str) -> Option<&serde_json::Value> {
249 self.data.get(key)
250 }
251
252 pub fn set(&mut self, key: impl Into<String>, value: serde_json::Value) {
254 self.data.insert(key.into(), value);
255 }
256}
257
258impl Default for DictState {
259 fn default() -> Self {
260 Self::new()
261 }
262}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal `State` implementor whose merge adds the two counters.
    #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
    struct TestState {
        count: i32,
    }

    impl State for TestState {
        fn merge(&mut self, other: Self) -> Result<()> {
            self.count += other.count;
            Ok(())
        }
    }

    #[test]
    fn test_state_merge() {
        let mut state = TestState { count: 5 };
        let other = TestState { count: 3 };

        state.merge(other).unwrap();
        assert_eq!(state.count, 8);
    }

    /// The default JSON conversions must round-trip a state unchanged.
    #[test]
    fn test_state_value_round_trip() {
        let state = TestState { count: 5 };

        let value = state.to_value().unwrap();
        assert_eq!(value, serde_json::json!({"count": 5}));

        let restored = TestState::from_value(value).unwrap();
        assert_eq!(restored, state);
    }

    /// A value of the wrong shape must surface as an error, not a panic.
    #[test]
    fn test_state_from_value_rejects_wrong_shape() {
        let result = TestState::from_value(serde_json::json!({"count": "five"}));
        assert!(result.is_err());
    }

    #[test]
    fn test_message_creation() {
        let msg = Message::user("Hello");
        assert_eq!(msg.role, "user");
        assert_eq!(msg.content, "Hello");

        let msg = Message::assistant("Hi").with_name("bot");
        assert_eq!(msg.name.as_deref(), Some("bot"));
    }

    #[test]
    fn test_message_system_and_tool() {
        let msg = Message::system("Be brief");
        assert_eq!(msg.role, "system");
        assert_eq!(msg.content, "Be brief");
        assert_eq!(msg.tool_call_id, None);

        let msg = Message::tool("42", "call-1");
        assert_eq!(msg.role, "tool");
        assert_eq!(msg.content, "42");
        assert_eq!(msg.tool_call_id.as_deref(), Some("call-1"));
        assert!(msg.metadata.is_empty());
    }

    #[test]
    fn test_message_with_tool_calls() {
        let call = ToolCall::new("call-1", "search", serde_json::json!({"query": "rust"}));
        let msg = Message::assistant("").with_tool_calls(vec![call.clone()]);

        assert_eq!(msg.role, "assistant");
        assert_eq!(msg.tool_calls, Some(vec![call]));
    }

    #[test]
    fn test_messages_state() {
        let mut state = MessagesState {
            messages: vec![Message::user("Hello")],
        };

        let update = MessagesState {
            messages: vec![Message::assistant("Hi there!")],
        };

        state.merge(update).unwrap();
        assert_eq!(state.messages.len(), 2);
        assert_eq!(state.messages[0].role, "user");
        assert_eq!(state.messages[1].role, "assistant");
    }

    #[test]
    fn test_dict_state() {
        let mut state = DictState::new();
        state.set("key1", serde_json::json!("value1"));

        let mut other = DictState::new();
        other.set("key2", serde_json::json!(42));

        state.merge(other).unwrap();

        assert_eq!(state.data.len(), 2);
        assert_eq!(state.get("key1").unwrap(), &serde_json::json!("value1"));
        assert_eq!(state.get("key2").unwrap(), &serde_json::json!(42));
    }

    /// On a key collision the incoming value wins and no extra entry is created.
    #[test]
    fn test_dict_state_merge_overwrites_existing_key() {
        let mut state = DictState::new();
        state.set("key", serde_json::json!("old"));

        let mut other = DictState::new();
        other.set("key", serde_json::json!("new"));

        state.merge(other).unwrap();

        assert_eq!(state.data.len(), 1);
        assert_eq!(state.get("key").unwrap(), &serde_json::json!("new"));
        assert!(state.get("missing").is_none());
    }

    #[test]
    fn test_dict_state_default_is_empty() {
        assert!(DictState::default().data.is_empty());
    }

    #[test]
    fn test_tool_call() {
        let tool_call = ToolCall::new("call-1", "search", serde_json::json!({"query": "rust"}));
        assert_eq!(tool_call.id, "call-1");
        assert_eq!(tool_call.name, "search");
        assert_eq!(tool_call.arguments, serde_json::json!({"query": "rust"}));
    }
}