rust_agent/memory/
composite_memory.rs

1// Composite memory module
2use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::pin::Pin;
6use anyhow::{Error, Result};
7use serde::{Serialize, Deserialize};
8use serde_json::{json, Value};
9use tokio::sync::RwLock;
10use async_trait::async_trait;
11use log::{info, warn, error};
12use std::future::Future;
13
14use crate::memory::base::{BaseMemory, MemoryVariables};
15use crate::memory::message_history::{MessageHistoryMemory, ChatMessage};
16use crate::memory::summary::SummaryMemory;
17use crate::memory::utils::{
18    ensure_data_dir_exists, get_data_dir_from_env, get_summary_threshold_from_env,
19    get_recent_messages_count_from_env, generate_session_id
20};
21
22/// Composite memory configuration
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct CompositeMemoryConfig {
25    /// Data directory
26    pub data_dir: PathBuf,
27    /// Session ID (automatically generated internally)
28    pub session_id: Option<String>,
29    /// Summary threshold (in token count, 1 token ≈ 4 English characters, 1 token ≈ 1 Chinese character)
30    pub summary_threshold: usize,
31    /// Number of recent messages to keep (in message count)
32    pub recent_messages_count: usize,
33    /// Whether to automatically generate summaries
34    pub auto_generate_summary: bool,
35}
36
37impl Default for CompositeMemoryConfig {
38    fn default() -> Self {
39        Self {
40            data_dir: get_data_dir_from_env(),
41            session_id: None, // Will be automatically generated internally
42            summary_threshold: get_summary_threshold_from_env(),
43            recent_messages_count: get_recent_messages_count_from_env(),
44            auto_generate_summary: true,
45        }
46    }
47}
48
49/// Composite memory implementation
50/// 
51/// This struct combines multiple memory types, providing a unified interface to manage different types of memory.
52/// It can simultaneously manage message history and summary memory, and provide intelligent summary generation functionality.
53#[derive(Debug, Clone)]
54pub struct CompositeMemory {
55    /// Configuration
56    config: CompositeMemoryConfig,
57    /// Message history memory
58    message_history: Option<Arc<MessageHistoryMemory>>,
59    /// Summary memory
60    summary_memory: Option<Arc<SummaryMemory>>,
61    /// In-memory memory variables
62    memory_variables: Arc<RwLock<MemoryVariables>>,
63}
64
65impl CompositeMemory {
66    /// Create a new composite memory instance
67    pub async fn new() -> Result<Self> {
68        Self::with_config(CompositeMemoryConfig::default()).await
69    }
70
71    /// Create a composite memory instance with basic parameters
72    /// This is the recommended constructor, only requires necessary parameters
73    /// session_id will be automatically generated internally
74    pub async fn with_basic_params(
75        data_dir: PathBuf,
76        summary_threshold: usize,
77        recent_messages_count: usize,
78    ) -> Result<Self> {
79        let config = CompositeMemoryConfig {
80            data_dir,
81            session_id: None, // Will be automatically generated internally
82            summary_threshold,
83            recent_messages_count,
84            auto_generate_summary: true,
85        };
86        Self::with_config(config).await
87    }
88
89    /// Create a composite memory instance with configuration
90    pub async fn with_config(config: CompositeMemoryConfig) -> Result<Self> {
91        // Ensure data directory exists
92        ensure_data_dir_exists(&config.data_dir).await?;
93
94        // Automatically generate session ID (if not provided)
95        let session_id = config.session_id.clone()
96            .unwrap_or_else(|| generate_session_id());
97
98        // Always create message history memory
99        let history = MessageHistoryMemory::new_with_recent_count(
100            session_id.clone(),
101            config.data_dir.clone(),
102            config.recent_messages_count
103        ).await?;
104        let message_history = Some(Arc::new(history));
105
106        // Always create summary memory with shared message history
107        let summary = SummaryMemory::new_with_shared_history(
108            session_id.clone(),
109            config.data_dir.clone(),
110            config.summary_threshold,
111            message_history.clone().unwrap() // We just created it, so it's safe to unwrap
112        ).await?;
113        let summary_memory = Some(Arc::new(summary));
114
115        Ok(Self {
116            config,
117            message_history,
118            summary_memory,
119            memory_variables: Arc::new(RwLock::new(HashMap::new())),
120        })
121    }
122
123    /// Create a composite memory instance with session ID
124    pub async fn with_session_id(session_id: String) -> Result<Self> {
125        let mut config = CompositeMemoryConfig::default();
126        config.session_id = Some(session_id);
127        Self::with_config(config).await
128    }
129
130    /// Add message to memory
131    pub async fn add_message(&self, message: ChatMessage) -> Result<()> {
132        // Add to message history (always enabled)
133        if let Some(ref history) = self.message_history {
134            history.add_message(&message).await?;
135        }
136
137        // Check if summary generation is needed (always enabled)
138        if self.config.auto_generate_summary {
139            info!("Checking if summary generation is needed...");
140            // Directly call SummaryMemory's check_and_generate_summary method
141            // This avoids duplicate implementation of summary generation logic and simplifies the call chain
142            if let Some(ref summary) = self.summary_memory {
143                summary.check_and_generate_summary().await?;
144                
145                // Clean up old messages
146                if let Some(ref history) = self.message_history {
147                    let keep_count = self.config.recent_messages_count;
148                    history.keep_recent_messages(keep_count).await?;
149                }
150            }
151        }
152
153        Ok(())
154    }
155
156    /// Get message count
157    pub async fn get_message_count(&self) -> Result<usize> {
158        if let Some(ref history) = self.message_history {
159            history.get_message_count().await
160        } else {
161            Ok(0)
162        }
163    }
164
165    /// Get the most recent N messages
166    pub async fn get_recent_messages(&self, count: usize) -> Result<Vec<ChatMessage>> {
167        if let Some(ref history) = self.message_history {
168            history.get_recent_chat_messages(count).await
169        } else {
170            Ok(Vec::new())
171        }
172    }
173
174    /// Clean up old messages
175    pub async fn cleanup_old_messages(&self) -> Result<()> {
176        if let Some(ref history) = self.message_history {
177            history.keep_recent_messages(self.config.recent_messages_count).await?;
178        }
179        Ok(())
180    }
181
182    /// Get memory statistics
183    pub async fn get_memory_stats(&self) -> Result<Value> {
184        let mut stats = json!({
185            "config": {
186                "summary_threshold": self.config.summary_threshold,
187                "recent_messages_count": self.config.recent_messages_count,
188                "auto_generate_summary": self.config.auto_generate_summary,
189            }
190        });
191
192        // Add message history statistics (always enabled)
193        if let Some(ref history) = self.message_history {
194            let message_count: usize = history.get_message_count().await?;
195            stats["message_history"] = json!({
196                "enabled": true,
197                "message_count": message_count,
198            });
199        }
200
201        // Add summary memory statistics (always enabled)
202        if let Some(ref summary) = self.summary_memory {
203            let summary_data = summary.load_summary().await?;
204            stats["summary_memory"] = json!({
205                "enabled": true,
206                "has_summary": summary_data.summary.is_some(),
207                "token_count": summary_data.token_count,
208                "last_updated": summary_data.last_updated,
209            });
210        }
211
212        Ok(stats)
213    }
214
215    /// Get summary content
216    pub async fn get_summary(&self) -> Result<Option<String>> {
217        if let Some(ref summary) = self.summary_memory {
218            let summary_data = summary.load_summary().await?;
219            Ok(summary_data.summary)
220        } else {
221            Ok(None)
222        }
223    }
224}
225
226// Implement as_any method for CompositeMemory's BaseMemory trait
227impl CompositeMemory {
228    /// Get Any reference for type conversion
229    pub fn as_any(&self) -> &dyn std::any::Any {
230        self
231    }
232}
233
234#[async_trait]
235impl BaseMemory for CompositeMemory {
236    fn memory_variables(&self) -> Vec<String> {
237        // Return all memory variables
238        let mut vars = Vec::new();
239        
240        // Add base memory variables
241        vars.extend_from_slice(&["chat_history".to_string(), "summary".to_string(), "input".to_string(), "output".to_string()]);
242        
243        // Add configuration related variables
244        vars.push("config".to_string());
245        
246        vars
247    }
248
249    fn load_memory_variables<'a>(&'a self, inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>>> + Send + 'a>> {
250        Box::pin(async move {
251            let mut result = HashMap::new();
252
253            // Load chat history (always enabled)
254            if let Some(ref history) = self.message_history {
255                let messages = history.get_recent_chat_messages(
256                    self.config.recent_messages_count
257                ).await?;
258                
259                let history_json = serde_json::to_value(&messages)?;
260                result.insert("chat_history".to_string(), history_json);
261            }
262
263            // Load summary (always enabled)
264            if let Some(ref summary) = self.summary_memory {
265                let summary_data = summary.load_summary().await?;
266                
267                if let Some(summary_text) = summary_data.summary {
268                    result.insert("summary".to_string(), json!(summary_text));
269                }
270            }
271
272            // Add input
273            if let Some(input) = inputs.get("input") {
274                result.insert("input".to_string(), input.clone());
275            }
276
277            // Add output
278            if let Some(output) = inputs.get("output") {
279                result.insert("output".to_string(), output.clone());
280            }
281
282            // Add configuration information
283            result.insert("config".to_string(), serde_json::to_value(&self.config)?);
284
285            // Update internal memory variables
286            *self.memory_variables.write().await = result.clone();
287
288            Ok(result)
289        })
290    }
291
292    fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
293        Box::pin(async move {
294            // Extract input and output
295            let input = inputs.get("input")
296                .and_then(|v| v.as_str())
297                .unwrap_or("");
298                
299            let output = outputs.get("output")
300                .and_then(|v| v.as_str())
301                .unwrap_or("");
302
303            // Create user message
304            if !input.is_empty() {
305                let user_message = ChatMessage {
306                    id: uuid::Uuid::new_v4().to_string(),
307                    role: "user".to_string(),
308                    content: input.to_string(),
309                    timestamp: chrono::Utc::now().to_rfc3339(),
310                    metadata: None,
311                };
312                
313                // Add directly to message history without triggering summary generation
314                if let Some(ref history) = self.message_history {
315                    history.add_message(&user_message).await?;
316                }
317            }
318
319            // Create assistant message
320            if !output.is_empty() {
321                let assistant_message = ChatMessage {
322                    id: uuid::Uuid::new_v4().to_string(),
323                    role: "assistant".to_string(),
324                    content: output.to_string(),
325                    timestamp: chrono::Utc::now().to_rfc3339(),
326                    metadata: None,
327                };
328                
329                // Add directly to message history without triggering summary generation
330                if let Some(ref history) = self.message_history {
331                    history.add_message(&assistant_message).await?;
332                }
333            }
334
335            // Check for summary generation only once after all messages are added
336            if self.config.auto_generate_summary {
337                info!("Checking if summary generation is needed...");
338                if let Some(ref summary) = self.summary_memory {
339                    summary.check_and_generate_summary().await?;
340                    
341                    // Keep only recent messages after summary generation
342                    if let Some(ref history) = self.message_history {
343                        let keep_count = self.config.recent_messages_count;
344                        history.keep_recent_messages(keep_count).await?;
345                    }
346                }
347            }
348
349            // Update internal memory variables
350            let mut memory_vars = self.memory_variables.write().await;
351            
352            if let Some(input_val) = inputs.get("input") {
353                memory_vars.insert("input".to_string(), input_val.clone());
354            }
355            
356            if let Some(output_val) = outputs.get("output") {
357                memory_vars.insert("output".to_string(), output_val.clone());
358            }
359
360            Ok(())
361        })
362    }
363
364    fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
365        Box::pin(async move {
366            // Clear message history (always enabled)
367            if let Some(ref history) = self.message_history {
368                history.clear().await?;
369            }
370
371            // Clear summary memory (always enabled)
372            if let Some(ref summary) = self.summary_memory {
373                summary.clear().await?;
374            }
375
376            // Clear internal memory variables
377            self.memory_variables.write().await.clear();
378
379            Ok(())
380        })
381    }
382
383    fn clone_box(&self) -> Box<dyn BaseMemory> {
384        Box::new(self.clone())
385    }
386
387    fn get_session_id(&self) -> Option<&str> {
388        self.config.session_id.as_deref()
389    }
390
391    fn set_session_id(&mut self, session_id: String) {
392        self.config.session_id = Some(session_id);
393    }
394
395    fn get_token_count(&self) -> Result<usize, Error> {
396        // This is a simplified implementation, actual applications may need more precise calculation
397        let mut count = 0;
398        
399        // Estimate configuration token count
400        if let Ok(config_json) = serde_json::to_value(&self.config) {
401            count += crate::memory::utils::estimate_json_token_count(&config_json);
402        }
403        
404        // Estimate memory variables token count
405        if let Ok(memory_vars) = self.memory_variables.try_read() {
406            if let Ok(vars_json) = serde_json::to_value(&*memory_vars) {
407                count += crate::memory::utils::estimate_json_token_count(&vars_json);
408            }
409        }
410        
411        Ok(count)
412    }
413    
414    fn as_any(&self) -> &dyn std::any::Any {
415        self
416    }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::message_history::ChatMessage;
    use tempfile::TempDir;

    /// Builds a config rooted in a fresh temporary directory, with automatic summarization
    /// disabled so tests exercise storage only.
    ///
    /// The `TempDir` guard is returned too and must be kept alive for the whole test.
    fn temp_config() -> (TempDir, CompositeMemoryConfig) {
        let dir = TempDir::new().unwrap();
        let config = CompositeMemoryConfig {
            data_dir: dir.path().to_path_buf(),
            auto_generate_summary: false,
            ..CompositeMemoryConfig::default()
        };
        (dir, config)
    }

    /// One user/assistant exchange in the shape `save_context` expects.
    fn greeting_exchange() -> (HashMap<String, Value>, HashMap<String, Value>) {
        let inputs: HashMap<String, Value> =
            std::iter::once(("input".to_string(), json!("Hello"))).collect();
        let outputs: HashMap<String, Value> =
            std::iter::once(("output".to_string(), json!("Hi there!"))).collect();
        (inputs, outputs)
    }

    #[tokio::test]
    async fn test_composite_memory_new() {
        // Uses the environment-derived default data directory rather than a temp dir.
        let outcome = CompositeMemory::new().await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_composite_memory_with_session_id() {
        const SESSION: &str = "test_session";
        let outcome = CompositeMemory::with_session_id(SESSION.to_owned()).await;
        assert!(outcome.is_ok());

        let memory = outcome.unwrap();
        assert_eq!(memory.get_session_id(), Some(SESSION));
    }

    #[tokio::test]
    async fn test_add_message() {
        let (_dir, config) = temp_config();
        let memory = CompositeMemory::with_config(config).await.unwrap();

        let greeting = ChatMessage {
            id: "test_id".to_string(),
            role: "user".to_string(),
            content: "Hello, world!".to_string(),
            timestamp: chrono::Utc::now().to_rfc3339(),
            metadata: None,
        };

        assert!(memory.add_message(greeting).await.is_ok());
        assert_eq!(memory.get_message_count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_save_context() {
        let (_dir, config) = temp_config();
        let memory = CompositeMemory::with_config(config).await.unwrap();
        let (inputs, outputs) = greeting_exchange();

        assert!(memory.save_context(&inputs, &outputs).await.is_ok());
        // One user message plus one assistant message.
        assert_eq!(memory.get_message_count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_clear() {
        let (_dir, config) = temp_config();
        let memory = CompositeMemory::with_config(config).await.unwrap();
        let (inputs, outputs) = greeting_exchange();

        memory.save_context(&inputs, &outputs).await.unwrap();
        assert_eq!(memory.get_message_count().await.unwrap(), 2);

        assert!(memory.clear().await.is_ok());
        assert_eq!(memory.get_message_count().await.unwrap(), 0);
    }
}