// rust_agent/memory/summary.rs

1// Summary memory implementation, generates summaries when token count exceeds threshold
2use std::path::PathBuf;
3use anyhow::{Error, Result};
4use serde::{Serialize, Deserialize};
5use std::collections::HashMap;
6use serde_json::{json, Value};
7use std::pin::Pin;
8use std::future::Future;
9use log::info;
10use uuid;
11use chrono;
12use crate::ChatMessage;
13use std::sync::Arc;
14
15// Import FileChatMessageHistory
16use crate::memory::message_history::{FileChatMessageHistory, ChatMessageRecord, MessageHistoryMemory};
17// Import utility functions
18use crate::memory::utils::estimate_text_tokens;
19// Import common models
20use crate::{ChatModel, OpenAIChatModel, ModelChatMessage, ChatMessageContent};
21
/// Persisted summary record for one chat session.
///
/// Serialized to `<session_id>_summary.json` inside the memory data
/// directory (see `SummaryMemory::get_summary_file_path`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SummaryData {
    /// ID of the session this summary belongs to.
    pub session_id: String,
    /// Sequence number of the last message covered by the summary;
    /// used so the next update only summarizes newer messages.
    pub sequence_number: u64,
    /// Summary text, or `None` when no summary has been generated yet.
    pub summary: Option<String>,
    /// Estimated token count of the summary text itself
    /// (computed with `estimate_text_tokens` when the summary is saved).
    pub token_count: usize,
    /// RFC 3339 timestamp (UTC) of the last update.
    pub last_updated: String,
}
36
37impl Default for SummaryData {
38    fn default() -> Self {
39        Self {
40            session_id: String::new(),
41            sequence_number: 0,
42            summary: None,
43            token_count: 0,
44            last_updated: chrono::Utc::now().to_rfc3339(),
45        }
46    }
47}
48
/// Summary memory implementation.
///
/// Generates and manages conversation summaries: once the accumulated
/// conversation exceeds `summary_threshold` tokens it produces a summary
/// via an OpenAI-compatible chat model, persists it alongside the message
/// history, and reloads it as context for later turns.
#[derive(Debug)]
pub struct SummaryMemory {
    /// Session ID; also used to derive summary/history file names.
    session_id: String,
    /// Directory where summary and history files are stored.
    data_dir: PathBuf,
    /// Summary threshold in tokens (1 token ≈ 4 English characters, 1 token ≈ 1 Chinese character).
    summary_threshold: usize,
    /// Prompt template for summary generation; the `{chat_history}`
    /// placeholder is replaced with the formatted conversation text.
    summary_prompt_template: String,
    /// Number of most-recent messages kept verbatim (not folded into the summary).
    recent_messages_count: usize,
    /// Optional shared message history; when `None`, a file-backed history
    /// is opened on demand from `data_dir`.
    message_history: Option<Arc<MessageHistoryMemory>>,
}
69
70impl Clone for SummaryMemory {
71    fn clone(&self) -> Self {
72        Self {
73            session_id: self.session_id.clone(),
74            data_dir: self.data_dir.clone(),
75            summary_threshold: self.summary_threshold,
76            summary_prompt_template: self.summary_prompt_template.clone(),
77            recent_messages_count: self.recent_messages_count,
78            message_history: self.message_history.clone(),
79        }
80    }
81}
82
83impl SummaryMemory {
84    /// Create a new summary memory instance
85    pub async fn new(session_id: String, data_dir: PathBuf, summary_threshold: usize) -> Result<Self> {
86        // Ensure data directory exists
87        tokio::fs::create_dir_all(&data_dir).await?;
88        
89        Ok(Self {
90            session_id,
91            data_dir,
92            summary_threshold,
93            summary_prompt_template: "Please provide a concise summary of the following conversation. Focus on the main topics discussed, key decisions made, and any important outcomes.\n\nConversation:\n{chat_history}\n\nSummary:".to_string(),
94            recent_messages_count: crate::memory::utils::get_recent_messages_count_from_env(),
95            message_history: None,
96        })
97    }
98    
99    /// Create a new summary memory instance with shared message history
100    pub async fn new_with_shared_history(
101        session_id: String, 
102        data_dir: PathBuf, 
103        summary_threshold: usize,
104        message_history: Arc<MessageHistoryMemory>
105    ) -> Result<Self> {
106        // Ensure data directory exists
107        tokio::fs::create_dir_all(&data_dir).await?;
108        
109        Ok(Self {
110            session_id,
111            data_dir,
112            summary_threshold,
113            summary_prompt_template: "Please provide a concise summary of the following conversation. Focus on the main topics discussed, key decisions made, and any important outcomes.\n\nConversation:\n{chat_history}\n\nSummary:".to_string(),
114            recent_messages_count: crate::memory::utils::get_recent_messages_count_from_env(),
115            message_history: Some(message_history),
116        })
117    }
118    
119    /// Set summary prompt template
120    pub fn with_summary_prompt_template(mut self, template: String) -> Self {
121        self.summary_prompt_template = template;
122        self
123    }
124    
125    /// Set the number of recent messages to keep
126    pub fn with_recent_messages_count(mut self, count: usize) -> Self {
127        self.recent_messages_count = count;
128        self
129    }
130    
131    /// Get summary file path
132    fn get_summary_file_path(&self) -> PathBuf {
133        self.data_dir.join(format!("{}_summary.json", self.session_id))
134    }
135    
136    /// Load context from memory
137    pub async fn load_context(&self) -> Result<Vec<String>> {
138        // Load summary
139        let summary_data = self.load_summary().await?;
140        
141        // Load message history
142        let messages = if let Some(ref history) = self.message_history {
143            // Use shared message history
144            history.get_recent_messages(self.recent_messages_count).await?
145        } else {
146            // Create new FileChatMessageHistory instance
147            let file_path = self.data_dir.join(format!("{}_history.jsonl", self.session_id));
148            let chat_history = FileChatMessageHistory::new(self.session_id.clone(), file_path).await?;
149            chat_history.get_messages().await?
150        };
151        
152        // Build context vector
153        let mut context = Vec::new();
154        
155        // Add summary (if exists)
156        if let Some(summary) = summary_data.summary {
157            context.push(format!("Previous conversation summary: {}", summary));
158        }
159        
160        // Add recent messages
161        for msg in messages {
162            context.push(format!("{}: {}", msg.role, msg.content));
163        }
164        
165        Ok(context)
166    }
167    
168    /// Load summary
169    pub async fn load_summary(&self) -> Result<SummaryData> {
170        let file_path = self.get_summary_file_path();
171        
172        if !tokio::fs::metadata(&file_path).await.is_ok() {
173            return Ok(SummaryData {
174                session_id: self.session_id.clone(),
175                sequence_number: 0,
176                summary: None,
177                token_count: 0,
178                last_updated: chrono::Utc::now().to_rfc3339(),
179            });
180        }
181        
182        let contents = tokio::fs::read_to_string(&file_path).await?;
183        let summary_data: SummaryData = serde_json::from_str(&contents)?;
184        
185        Ok(summary_data)
186    }
187    
188    /// Save summary
189    async fn save_summary(&self, summary: &str, sequence_number: u64) -> Result<()> {
190        let file_path = self.get_summary_file_path();
191        
192        // Calculate token count for the summary
193        let token_count = estimate_text_tokens(summary);
194        
195        let summary_data = SummaryData {
196            session_id: self.session_id.clone(),
197            sequence_number,
198            summary: Some(summary.to_string()),
199            token_count,
200            last_updated: chrono::Utc::now().to_rfc3339(),
201        };
202        
203        let json = serde_json::to_string(&summary_data)?;
204        tokio::fs::write(&file_path, json).await?;
205        
206        Ok(())
207    }
208    
    /// Generate a summary of `messages` with an OpenAI-compatible chat model
    /// and persist it via `save_summary`.
    ///
    /// Returns the summary text together with the sequence number of the
    /// last summarized message (0 when `messages` is empty).
    ///
    /// # Errors
    /// Fails when the model call fails or returns a non-AI message.
    async fn generate_summary(&self, messages: &[ChatMessageRecord]) -> Result<(String, u64)> {
        info!("Generating summary for {} messages", messages.len());

        // Flatten records into "Role: content" lines for the prompt.
        // NOTE(review): every non-"user" role (including system/tool) is
        // labeled "Assistant" here — confirm that is intended.
        let mut chat_text = String::new();
        for msg in messages {
            let role = if msg.role == "user" { "User" } else { "Assistant" };
            chat_text.push_str(&format!("{}: {}\n", role, msg.content));
        }
        
        // Substitute the conversation into the configured prompt template.
        let summary_prompt = self.summary_prompt_template.replace("{chat_history}", &chat_text);
        
        // Model credentials/endpoint come from environment variables. When
        // OPENAI_API_KEY is unset, the literal placeholder string is used;
        // an explicit key-validation check was deliberately removed, so any
        // auth failure surfaces from the API call itself.
        let api_key = std::env::var("OPENAI_API_KEY").unwrap_or_else(|_| "OPENAI_API_KEY".to_string());
        let base_url = std::env::var("OPENAI_API_URL").ok();
        
        // Low temperature + bounded output keeps summaries short and stable.
        let model = crate::OpenAIChatModel::new(api_key.clone(), base_url)
            .with_model(std::env::var("OPENAI_API_MODEL").unwrap_or_else(|_| "gpt-3.5-turbo".to_string()))
            .with_temperature(0.3)
            .with_max_tokens(1024);
        
        // System instruction plus the filled-in summary prompt.
        let model_messages = vec![
            crate::ModelChatMessage::System(crate::ChatMessageContent {
                content: "You are a helpful assistant that creates concise summaries of conversations.".to_string(),
                name: None,
                additional_kwargs: std::collections::HashMap::new(),
            }),
            crate::ModelChatMessage::Human(crate::ChatMessageContent {
                content: summary_prompt,
                name: None,
                additional_kwargs: std::collections::HashMap::new(),
            }),
        ];
        
        // Call model to generate summary.
        let response = model.invoke(model_messages).await?;
        
        // Only an AI message counts as a valid summary response.
        let summary = match response.message {
            crate::ModelChatMessage::AIMessage(content) => content.content,
            _ => return Err(anyhow::anyhow!("Expected AI message response")),
        };

        // The last message's sequence number marks how far this summary covers.
        let last_sequence_number = messages.last()
            .map(|msg| msg.sequence_number)
            .unwrap_or(0);

        // Persist immediately so callers see a consistent on-disk state.
        self.save_summary(&summary, last_sequence_number).await?;
        
        Ok((summary, last_sequence_number))
    }
271    
272    /// Check if summary needs to be generated and generate if needed
273    pub async fn check_and_generate_summary(&self) -> Result<bool> {
274        // Load current summary data to get the last sequence number
275        let summary_data = self.load_summary().await?;
276        let last_summary_sequence = summary_data.sequence_number;
277        
278        // Load message history
279        let messages = if let Some(ref message_history) = self.message_history {
280            // Get all messages
281            message_history.get_recent_messages(usize::MAX).await?
282        } else {
283            return Ok(false);
284        };
285        
286        // If no messages, no need to generate summary
287        if messages.is_empty() {
288            return Ok(false);
289        }
290        
291        // Filter messages to only include those after the last summary sequence number
292        let messages_to_summarize: Vec<ChatMessageRecord> = messages
293            .into_iter()
294            .filter(|msg| msg.sequence_number > last_summary_sequence)
295            .collect();
296        
297        // If no new messages since last summary, no need to generate summary
298        if messages_to_summarize.is_empty() {
299            return Ok(false);
300        }
301        
302        // Calculate total tokens in new messages
303        let mut chat_text = String::new();
304        for msg in &messages_to_summarize {
305            chat_text.push_str(&format!("{}: {}\n", msg.role, msg.content));
306        }
307
308        let total_tokens = estimate_text_tokens(&chat_text);
309        
310        // If token count exceeds threshold, generate summary
311        if total_tokens > self.summary_threshold {
312            info!("[SummaryMemory] Generating summary... ({} new messages, {} tokens)", messages_to_summarize.len(), total_tokens);
313            
314            // Generate summary
315            let (summary, _) = self.generate_summary(&messages_to_summarize).await?;
316            
317            // Get the sequence number of the last message
318            let last_sequence = messages_to_summarize.last().map(|m| m.sequence_number).unwrap_or(0);
319            
320            // Save summary
321            self.save_summary(&summary, last_sequence).await?;
322            
323            // Keep only recent messages
324            if let Some(ref message_history) = self.message_history {
325                message_history.keep_recent_messages(self.recent_messages_count).await?;
326            }
327            
328            Ok(true)
329        } else {
330            Ok(false)
331        }
332    }
333    
334    /// Get session ID
335    pub fn get_session_id(&self) -> &str {
336        &self.session_id
337    }
338    
339    /// Get memory statistics
340    pub async fn get_memory_stats(&self) -> Result<Value> {
341        // Load summary
342        let summary_data = self.load_summary().await?;
343        
344        // Load message history
345        let file_path = self.data_dir.join(format!("{}_history.jsonl", self.session_id.clone()));
346        let chat_history = FileChatMessageHistory::new(self.session_id.clone(), file_path).await?;
347        let messages = chat_history.get_messages().await?;
348        
349        // Calculate total tokens in messages
350        let mut chat_text = String::new();
351        for msg in &messages {
352            chat_text.push_str(&format!("{}: {}\n", msg.role, msg.content));
353        }
354        let token_count = estimate_text_tokens(&chat_text);
355        
356        let stats = json!({
357            "session_id": self.session_id,
358            "summary_threshold": self.summary_threshold,
359            "recent_messages_count": self.recent_messages_count,
360            "message_count": messages.len(),
361            "token_count": token_count,
362            "has_summary": summary_data.summary.is_some(),
363            "summary_token_count": summary_data.token_count,
364            "last_updated": summary_data.last_updated
365        });
366        
367        Ok(stats)
368    }
369}
370
371// Implement BaseMemory trait, compatible with existing system
372use crate::memory::base::BaseMemory;
373
374impl BaseMemory for SummaryMemory {
375    fn memory_variables(&self) -> Vec<String> {
376        vec!["chat_history".to_string()]
377    }
378    
379    fn load_memory_variables<'a>(&'a self, _inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>> {
380        let session_id = self.session_id.clone();
381        let data_dir = self.data_dir.clone();
382        let summary_threshold = self.summary_threshold;
383        let recent_messages_count = self.recent_messages_count;
384        let use_shared_history = self.message_history.is_some();
385        
386        Box::pin(async move {
387            // Load summary
388            let summary_memory = SummaryMemory {
389                session_id: session_id.clone(),
390                data_dir: data_dir.clone(),
391                summary_threshold,
392                summary_prompt_template: String::new(),
393                recent_messages_count,
394                message_history: None, // We'll handle this separately
395            };
396            
397            let summary_data = summary_memory.load_summary().await?;
398            
399            // Load message history
400            let messages = if use_shared_history {
401                // This is a simplified approach - in a real implementation, we would need to pass the shared instance
402                // For now, we'll create a new instance but this should be improved
403                let file_path = data_dir.join(format!("{}_history.jsonl", session_id.clone()));
404                let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
405                chat_history.get_messages().await?
406            } else {
407                let file_path = data_dir.join(format!("{}_history.jsonl", session_id.clone()));
408                let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
409                chat_history.get_messages().await?
410            };
411            
412            // Convert to new format: system_prompt + chat_message
413            let mut history_array = Vec::new();
414            
415            // Build system prompt: basic_system_prompt + user_system_prompt + summary_prompt
416            let mut system_prompt_parts = Vec::new();
417            
418            // Add basic system prompt
419            system_prompt_parts.push("You are a helpful assistant that provides accurate and concise answers.".to_string());
420            
421            // Add user system prompt (if any)
422            if let Some(user_system_prompt) = std::env::var("USER_SYSTEM_PROMPT").ok() {
423                system_prompt_parts.push(user_system_prompt);
424            }
425            
426            // Add summary (if any)
427            if let Some(summary) = summary_data.summary {
428                system_prompt_parts.push(format!("Previous conversation summary: {}", summary));
429            }
430            
431            // Combine system prompt
432            let combined_system_prompt = system_prompt_parts.join("\n\n");
433            
434            // Add system prompt to history
435            let mut system_msg_obj = serde_json::Map::new();
436            system_msg_obj.insert("role".to_string(), serde_json::Value::String("system".to_string()));
437            system_msg_obj.insert("content".to_string(), serde_json::Value::String(combined_system_prompt));
438            history_array.push(serde_json::Value::Object(system_msg_obj));
439            
440            // Add recent messages (chat_message)
441            let len = messages.len();
442            let start = if len > recent_messages_count {
443                len - recent_messages_count
444            } else {
445                0
446            };
447            
448            for msg in &messages[start..] {
449                let mut msg_obj = serde_json::Map::new();
450                msg_obj.insert("role".to_string(), serde_json::Value::String(msg.role.clone()));
451                msg_obj.insert("content".to_string(), serde_json::Value::String(msg.content.clone()));
452                
453                if let Some(name) = &msg.name {
454                    msg_obj.insert("name".to_string(), serde_json::Value::String(name.clone()));
455                }
456                
457                if let Some(kwargs) = &msg.additional_kwargs {
458                    for (k, v) in kwargs {
459                        msg_obj.insert(k.clone(), v.clone());
460                    }
461                }
462                
463                history_array.push(serde_json::Value::Object(msg_obj));
464            }
465            
466            let mut result = HashMap::new();
467            result.insert("chat_history".to_string(), serde_json::Value::Array(history_array));
468            
469            Ok(result)
470        })
471    }
472    
473    fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
474        Box::pin(async move {
475            // Extract user and assistant messages
476            let mut user_message: Option<String> = None;
477            let mut assistant_message: Option<String> = None;
478            
479            // Check inputs for user message
480            if let Some(input_value) = inputs.get("input") {
481                if let Some(s) = input_value.as_str() {
482                    user_message = Some(s.to_string());
483                }
484            }
485            
486            // Check outputs for assistant message
487            if let Some(output_value) = outputs.get("output") {
488                if let Some(s) = output_value.as_str() {
489                    assistant_message = Some(s.to_string());
490                }
491            }
492            
493            // Add messages to shared message history if available
494            if let Some(ref message_history) = self.message_history {
495                if let Some(user_msg) = user_message {
496                    let chat_msg = ChatMessage {
497                        id: uuid::Uuid::new_v4().to_string(),
498                        role: "user".to_string(),
499                        content: user_msg,
500                        timestamp: chrono::Utc::now().to_rfc3339(),
501                        metadata: None,
502                    };
503                    message_history.add_message(&chat_msg).await?;
504                }
505                
506                if let Some(assistant_msg) = assistant_message {
507                    let chat_msg = ChatMessage {
508                        id: uuid::Uuid::new_v4().to_string(),
509                        role: "assistant".to_string(),
510                        content: assistant_msg,
511                        timestamp: chrono::Utc::now().to_rfc3339(),
512                        metadata: None,
513                    };
514                    message_history.add_message(&chat_msg).await?;
515                }
516                info!("save_context");
517                // Note: Removed check_and_generate_summary() call to avoid duplicate summary generation
518                // Summary generation is now handled by CompositeMemory::add_message
519            }
520            
521            Ok(())
522        })
523    }
524    
525    fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
526        let session_id = self.session_id.clone();
527        let data_dir = self.data_dir.clone();
528        
529        Box::pin(async move {
530            // Clear message history
531            let file_path = data_dir.join(format!("{}_history.jsonl", session_id.clone()));
532            let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
533            chat_history.clear().await?;
534            
535            // Clear summary file
536            let summary_path = data_dir.join(format!("{}_summary.json", session_id.clone()));
537            if tokio::fs::metadata(&summary_path).await.is_ok() {
538                tokio::fs::remove_file(&summary_path).await?;
539            }
540            
541            Ok(())
542        })
543    }
544    
545    fn clone_box(&self) -> Box<dyn BaseMemory> {
546        Box::new(self.clone())
547    }
548    
549    fn get_session_id(&self) -> Option<&str> {
550        Some(&self.session_id)
551    }
552    
    /// Rebind this memory to a different session ID.
    /// Subsequent history/summary file paths are derived from the new ID.
    fn set_session_id(&mut self, session_id: String) {
        self.session_id = session_id;
    }
556    
    /// Estimate a token count for this memory.
    ///
    /// NOTE(review): this estimates tokens of the "session_id:data_dir"
    /// identifier string, not of the stored conversation content, so it
    /// does not reflect actual memory size — confirm whether callers expect
    /// a content-based count (see `get_memory_stats` for that).
    fn get_token_count(&self) -> Result<usize, Error> {
        // Use common function to estimate token count
        let text = format!("{}:{}", self.session_id, self.data_dir.to_string_lossy());
        Ok(estimate_text_tokens(&text))
    }
562    
    /// Downcast support for callers that need the concrete `SummaryMemory`.
    fn as_any(&self) -> &dyn std::any::Any {
        self
    }
566}