// rust_agent/memory/base.rs

1// Basic memory interface definition
2use anyhow::Error;
3use std::collections::HashMap;
4use std::sync::Arc;
5use tokio::sync::RwLock;
6use serde_json::Value;
7use std::pin::Pin;
8use std::future::Future;
9use log::info;
10
11// Memory variable type alias
12pub type MemoryVariables = HashMap<String, Value>;
13
14// Minimal memory abstraction interface
15pub trait BaseMemory: Send + Sync {
16    // Get memory variable names
17    fn memory_variables(&self) -> Vec<String>;
18    
19    // Core method: load memory variables
20    fn load_memory_variables<'a>(&'a self, inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>>;
21    
22    // Core method: save context
23    fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
24    
25    // Optional method: clear memory
26    fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>>;
27    
28    // Clone method
29    fn clone_box(&self) -> Box<dyn BaseMemory>;
30    
31    // New method: get session ID
32    fn get_session_id(&self) -> Option<&str>;
33    
34    // New method: set session ID
35    fn set_session_id(&mut self, session_id: String);
36    
37    // New method: get token count
38    fn get_token_count(&self) -> Result<usize, Error>;
39    
40    // New method: get Any reference for type conversion
41    fn as_any(&self) -> &dyn std::any::Any;
42}
43
44// Simple memory implementation, similar to Langchain's ConversationBufferMemory
45#[derive(Debug)]
46pub struct SimpleMemory {
47    memories: Arc<RwLock<HashMap<String, Value>>>,
48    memory_key: String,
49    session_id: Option<String>,
50}
51
52impl Clone for SimpleMemory {
53    fn clone(&self) -> Self {
54        Self {
55            memories: Arc::clone(&self.memories),
56            memory_key: self.memory_key.clone(),
57            session_id: self.session_id.clone(),
58        }
59    }
60}
61
62impl SimpleMemory {
63    pub fn new() -> Self {
64        Self {
65            memories: Arc::new(RwLock::new(HashMap::new())),
66            memory_key: "chat_history".to_string(),
67            session_id: None,
68        }
69    }
70    
71    pub fn with_memory_key(memory_key: String) -> Self {
72        Self {
73            memories: Arc::new(RwLock::new(HashMap::new())),
74            memory_key,
75            session_id: None,
76        }
77    }
78    
79    pub fn with_memories(memories: HashMap<String, Value>) -> Self {
80        Self {
81            memories: Arc::new(RwLock::new(memories)),
82            memory_key: "chat_history".to_string(),
83            session_id: None,
84        }
85    }
86    
87    pub async fn add_message(&self, message: Value) -> Result<(), Error> {
88        let mut memories = self.memories.write().await;
89        let chat_history = memories.entry(self.memory_key.clone()).or_insert_with(|| Value::Array(vec![]));
90        
91        if let Value::Array(ref mut arr) = chat_history {
92            arr.push(message);
93        } else {
94            *chat_history = Value::Array(vec![message]);
95        }
96        
97        Ok(())
98    }
99    
100    pub fn get_memory_key(&self) -> String {
101        self.memory_key.clone()
102    }
103}
104
105impl Default for SimpleMemory {
106    fn default() -> Self {
107        Self::new()
108    }
109}
110
111impl BaseMemory for SimpleMemory {
112    fn memory_variables(&self) -> Vec<String> {
113        vec![self.memory_key.clone()]
114    }
115    
116    fn load_memory_variables<'a>(&'a self, _inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>> {
117        let memories = Arc::clone(&self.memories);
118        Box::pin(async move {
119            let memories = memories.read().await;
120            Ok(memories.clone())
121        })
122    }
123    
124    fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
125        let memories = Arc::clone(&self.memories);
126        let input_clone = inputs.clone();
127        let output_clone = outputs.clone();
128        let memory_key = self.memory_key.clone();
129        
130        Box::pin(async move {
131            let mut memories = memories.write().await;
132            
133            // Get or create chat history array
134            let chat_history = memories.entry(memory_key.clone()).or_insert_with(|| Value::Array(vec![]));
135            
136            // Ensure chat_history is an array type
137            if !chat_history.is_array() {
138                *chat_history = Value::Array(vec![]);
139            }
140            
141            // Add input as human message or tool message to chat history
142            if let Some(input_value) = input_clone.get("input") {
143                let user_message = serde_json::json!({
144                        "role": "human",
145                        "content": input_value
146                    });
147                
148                if let Value::Array(ref mut arr) = chat_history {
149                    info!("Adding to chat history: {:?}", user_message);
150                    arr.push(user_message);
151                }
152            }
153            
154            // Add output as AI message to chat history
155            if let Some(output_value) = output_clone.get("output") {
156                let ai_message = serde_json::json!({
157                    "role": "ai",
158                    "content": output_value
159                });
160                
161                if let Value::Array(ref mut arr) = chat_history {
162                    arr.push(ai_message);
163                }
164            }
165            
166            Ok(())
167        })
168    }
169    
170    fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
171        let memories = Arc::clone(&self.memories);
172        Box::pin(async move {
173            let mut memories = memories.write().await;
174            memories.clear();
175            Ok(())
176        })
177    }
178    
179    fn clone_box(&self) -> Box<dyn BaseMemory> {
180        Box::new(self.clone())
181    }
182    
183    fn get_session_id(&self) -> Option<&str> {
184        self.session_id.as_deref()
185    }
186    
187    fn set_session_id(&mut self, session_id: String) {
188        self.session_id = Some(session_id);
189    }
190    
191    fn get_token_count(&self) -> Result<usize, Error> {
192        // Simplified implementation: estimate token count based on character count
193        let count = self.memory_key.len() + self.session_id.as_ref().map(|s| s.len()).unwrap_or(0);
194        Ok(count)
195    }
196    
197    fn as_any(&self) -> &dyn std::any::Any {
198        self
199    }
200}
201
202// Implement Clone trait for Box<dyn BaseMemory>
203impl Clone for Box<dyn BaseMemory> {
204    fn clone(&self) -> Self {
205        self.as_ref().clone_box()
206    }
207}