// rust_agent/memory/message_history.rs

// Long-term memory implementation, persisting conversation history to file
use std::collections::HashMap;
use std::future::Future;
use std::path::{Path, PathBuf};
use std::pin::Pin;
use std::sync::Arc;

use anyhow::{Context, Error, Result};
use chrono::Utc;
use log::{info, warn};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use tokio::fs::File;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::sync::RwLock;

16// Chat message structure
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ChatMessage {
19    pub id: String,
20    pub role: String,
21    pub content: String,
22    pub timestamp: String,
23    pub metadata: Option<Value>,
24}
25
26/// Single message record
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ChatMessageRecord {
29    /// Message role: "system", "user", "assistant", "tool"
30    pub role: String,
31    /// Message content (without historical conversation)
32    pub content: String,
33    /// Optional message name
34    pub name: Option<String>,
35    /// Additional metadata
36    pub additional_kwargs: Option<HashMap<String, serde_json::Value>>,
37    /// Timestamp (ISO 8601 format)
38    pub timestamp: String,
39    /// Message sequence number (to ensure order)
40    pub sequence_number: u64,
41}
42
43/// Session-level message history structure
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ChatSessionHistory {
46    /// Session ID (identity identifier for the entire conversation)
47    pub session_id: String,
48    /// Session creation time
49    pub created_at: String,
50    /// Session last update time
51    pub updated_at: String,
52    /// Message list (in chronological order)
53    pub messages: Vec<ChatMessageRecord>,
54    /// Session-level metadata
55    pub metadata: Option<HashMap<String, serde_json::Value>>,
56}
57
58/// File message history implementation, aligned with LangChain's FileChatMessageHistory
59/// Uses JSONL format, one JSON object per line
60#[derive(Debug)]
61pub struct FileChatMessageHistory {
62    /// Session ID
63    session_id: String,
64    /// File path
65    file_path: PathBuf,
66    /// In-memory session history
67    session_history: Arc<RwLock<ChatSessionHistory>>,
68    /// Next message sequence number
69    next_sequence_number: Arc<RwLock<u64>>,
70}
71
72impl Clone for FileChatMessageHistory {
73    fn clone(&self) -> Self {
74        Self {
75            session_id: self.session_id.clone(),
76            file_path: self.file_path.clone(),
77            session_history: Arc::clone(&self.session_history),
78            next_sequence_number: Arc::clone(&self.next_sequence_number),
79        }
80    }
81}
82
83impl FileChatMessageHistory {
84    /// Create a new file message history instance
85    pub async fn new(session_id: String, file_path: PathBuf) -> Result<Self> {
86        // Ensure parent directory exists
87        if let Some(parent) = file_path.parent() {
88            tokio::fs::create_dir_all(parent).await?;
89        }
90        
91        // Initialize session history
92        let now = Utc::now().to_rfc3339();
93        let session_history = ChatSessionHistory {
94            session_id: session_id.clone(),
95            created_at: now.clone(),
96            updated_at: now,
97            messages: Vec::new(),
98            metadata: None,
99        };
100        
101        let instance = Self {
102            session_id: session_id.clone(),
103            file_path: file_path.clone(),
104            session_history: Arc::new(RwLock::new(session_history)),
105            next_sequence_number: Arc::new(RwLock::new(1)),
106        };
107        
108        // Try to load existing session history
109        instance.load_session_history().await?;
110        
111        Ok(instance)
112    }
113    
114    /// Load session history from file
115    async fn load_session_history(&self) -> Result<()> {
116        if !tokio::fs::metadata(&self.file_path).await.is_ok() {
117            // File doesn't exist, use default session history
118            return Ok(());
119        }
120        
121        let mut file = File::open(&self.file_path).await?;
122        let mut contents = String::new();
123        file.read_to_string(&mut contents).await?;
124        
125        if contents.trim().is_empty() {
126            return Ok(());
127        }
128        
129        // Try to parse as JSON format session history (entire file is a JSON object)
130        match serde_json::from_str::<ChatSessionHistory>(&contents) {
131            Ok(session_history) => {
132                // Update in-memory session history
133                {
134                    let mut history = self.session_history.write().await;
135                    *history = session_history;
136                }
137                
138                // Update next message sequence number
139                {
140                    let history = self.session_history.read().await;
141                    let next_seq = history.messages.len() as u64 + 1;
142                    let mut next_sequence = self.next_sequence_number.write().await;
143                    *next_sequence = next_seq;
144                }
145                
146                info!("[FileChatMessageHistory] Loaded session history with {} messages from JSONL format", {
147                    let history = self.session_history.read().await;
148                    history.messages.len()
149                });
150            },
151            Err(e) => {
152                // If parsing fails, try to parse as old format (one JSON object per line)
153                warn!("Failed to parse as session history JSON, trying as old format: {}", e);
154                
155                let mut messages = Vec::new();
156                let mut max_sequence_number = 0u64;
157                
158                for line in contents.lines() {
159                    if line.trim().is_empty() {
160                        continue;
161                    }
162                    
163                    // Try to parse JSONL format message
164                    match serde_json::from_str::<serde_json::Value>(line) {
165                        Ok(msg_value) => {
166                            // Check if it's an old format message (no sequence_number field)
167                            if msg_value.get("sequence_number").is_none() {
168                                // Try to migrate from old format
169                                if let (Some(role), Some(content)) = (
170                                    msg_value.get("role").and_then(|v| v.as_str()),
171                                    msg_value.get("content").and_then(|v| v.as_str())
172                                ) {
173                                    // Skip assistant messages containing complete history
174                                    if role == "assistant" && content.contains("user:") && content.contains("assistant:") {
175                                        continue;
176                                    }
177                                    
178                                    // Create new format message
179                                    let message = ChatMessageRecord {
180                                        role: role.to_string(),
181                                        content: content.to_string(),
182                                        name: msg_value.get("name").and_then(|v| v.as_str()).map(|s| s.to_string()),
183                                        additional_kwargs: msg_value.get("additional_kwargs").cloned().and_then(|v| {
184                                            if v.is_null() {
185                                                None
186                                            } else {
187                                                Some(serde_json::from_value(v).unwrap_or_default())
188                                            }
189                                        }),
190                                        timestamp: msg_value.get("timestamp")
191                                            .and_then(|v| v.as_str())
192                                            .unwrap_or(&Utc::now().to_rfc3339())
193                                            .to_string(),
194                                        sequence_number: max_sequence_number + 1,
195                                    };
196                                    
197                                    max_sequence_number += 1;
198                                    messages.push(message);
199                                }
200                            } else {
201                                // Directly parse as new format message
202                                if let Ok(message) = serde_json::from_value::<ChatMessageRecord>(msg_value) {
203                                    max_sequence_number = max_sequence_number.max(message.sequence_number);
204                                    messages.push(message);
205                                }
206                            }
207                        },
208                        Err(e) => {
209                            warn!("Failed to parse line in JSONL file: {}, error: {}", line, e);
210                        }
211                    }
212                }
213                
214                if !messages.is_empty() {
215                    // Sort messages by sequence_number
216                    messages.sort_by_key(|m| m.sequence_number);
217                    
218                    // Update session history
219                    {
220                        let mut history = self.session_history.write().await;
221                        history.messages = messages;
222                        history.updated_at = Utc::now().to_rfc3339();
223                    }
224                    
225                    // Update next message sequence number
226                    {
227                        let mut next_sequence = self.next_sequence_number.write().await;
228                        *next_sequence = max_sequence_number + 1;
229                    }
230                    
231                    info!("[FileChatMessageHistory] Loaded session history with {} messages from old JSONL format", {
232                        let history = self.session_history.read().await;
233                        history.messages.len()
234                    });
235                    
236                    // Save as new format
237                    self.save_session_history().await?;
238                } else {
239                    return Err(anyhow::anyhow!("Failed to parse file as either session history JSON or old JSONL format"));
240                }
241            }
242        }
243        
244        Ok(())
245    }
246    
247    /// Save session history to file (entire session as a JSON object)
248    pub async fn save_session_history(&self) -> Result<()> {
249        // Get current session history
250        let history = {
251            let history_guard = self.session_history.read().await;
252            history_guard.clone()
253        };
254        
255        // Create temporary file
256        let temp_path = self.file_path.with_extension("tmp");
257        {
258            let mut file = File::create(&temp_path).await?;
259            
260            // Write entire session history as a JSON object to file
261            let json_content = serde_json::to_string_pretty(&history)?;
262            file.write_all(json_content.as_bytes()).await?;
263            
264            file.flush().await?;
265        }
266        
267        // Atomically replace original file
268        tokio::fs::rename(&temp_path, &self.file_path).await?;
269        
270        Ok(())
271    }
272    
273    /// Add user message, aligned with LangChain's add_user_message
274    pub async fn add_user_message(&self, content: String) -> Result<()> {
275        // Check if message content is empty
276        if content.trim().is_empty() {
277            return Ok(());
278        }
279        
280        let sequence_number = {
281            let mut seq = self.next_sequence_number.write().await;
282            let current = *seq;
283            *seq += 1;
284            current
285        };
286        
287        let message = ChatMessageRecord {
288            role: "user".to_string(),
289            content,
290            name: None,
291            additional_kwargs: None,
292            timestamp: Utc::now().to_rfc3339(),
293            sequence_number,
294        };
295        
296        self.add_message(message).await?;
297        Ok(())
298    }
299    
300    /// Add AI message to history
301    pub async fn add_ai_message(&self, content: &str) -> Result<()> {
302        // Preprocess content, if it's JSON string format, extract the content field
303        let processed_content = if content.starts_with('"') && content.ends_with('"') {
304            // Try to parse as JSON string
305            match serde_json::from_str::<serde_json::Value>(content) {
306                Ok(serde_json::Value::String(s)) => s,
307                _ => content.to_string(),
308            }
309        } else if content.starts_with('{') && content.ends_with('}') {
310            // Try to parse as JSON object
311            match serde_json::from_str::<serde_json::Value>(content) {
312                Ok(json_obj) => {
313                    // If it's a JSON object, try to extract the content field
314                    if let Some(content_value) = json_obj.get("content") {
315                        if let Some(content_str) = content_value.as_str() {
316                            content_str.to_string()
317                        } else {
318                            content.to_string()
319                        }
320                    } else {
321                        content.to_string()
322                    }
323                },
324                _ => content.to_string(),
325            }
326        } else {
327            content.to_string()
328        };
329        
330        let sequence_number = {
331            let mut seq = self.next_sequence_number.write().await;
332            let current = *seq;
333            *seq += 1;
334            current
335        };
336        
337        let message = ChatMessageRecord {
338            role: "assistant".to_string(),
339            content: processed_content,
340            name: None,
341            additional_kwargs: None,
342            timestamp: Utc::now().to_rfc3339(),
343            sequence_number,
344        };
345        
346        self.add_message(message).await?;
347        Ok(())
348    }
349    
350    /// Add message to memory and save to file
351    async fn add_message(&self, message: ChatMessageRecord) -> Result<()> {
352        // Add to memory
353        {
354            let mut history = self.session_history.write().await;
355            history.messages.push(message.clone());
356            history.updated_at = Utc::now().to_rfc3339();
357        }
358        
359        // Save to file
360        self.save_session_history().await?;
361        
362        Ok(())
363    }
364    
365    /// Get all messages, aligned with LangChain's get_messages
366    pub async fn get_messages(&self) -> Result<Vec<ChatMessageRecord>> {
367        let history = self.session_history.read().await;
368        Ok(history.messages.clone())
369    }
370    
371    /// Clear all messages
372    pub async fn clear(&self) -> Result<()> {
373        // Reset session history
374        {
375            let mut history = self.session_history.write().await;
376            history.messages.clear();
377            history.updated_at = Utc::now().to_rfc3339();
378        }
379        
380        // Reset message sequence number
381        {
382            let mut next_sequence = self.next_sequence_number.write().await;
383            *next_sequence = 1;
384        }
385        
386        // Save to file
387        self.save_session_history().await?;
388        
389        Ok(())
390    }
391}
392
393/// MessageHistoryMemory implementation, implementing BaseMemory trait
394#[derive(Debug)]
395pub struct MessageHistoryMemory {
396    /// Session ID
397    session_id: String,
398    /// Data directory
399    data_dir: PathBuf,
400    /// File message history
401    chat_history: FileChatMessageHistory,
402    /// Default number of recent messages to get
403    default_recent_count: usize,
404}
405
406impl Clone for MessageHistoryMemory {
407    fn clone(&self) -> Self {
408        Self {
409            session_id: self.session_id.clone(),
410            data_dir: self.data_dir.clone(),
411            chat_history: self.chat_history.clone(),
412            default_recent_count: self.default_recent_count,
413        }
414    }
415}
416
417impl MessageHistoryMemory {
418    /// Create a new MessageHistoryMemory instance
419    pub async fn new(session_id: String, data_dir: PathBuf) -> Result<Self> {
420        // Use default recent message count
421        let default_recent_count = crate::memory::utils::get_recent_messages_count_from_env();
422        Self::new_with_recent_count(session_id, data_dir, default_recent_count).await
423    }
424    
425    /// Create a new MessageHistoryMemory instance with specified recent message count
426    pub async fn new_with_recent_count(session_id: String, data_dir: PathBuf, recent_count: usize) -> Result<Self> {
427        // Ensure data directory exists
428        tokio::fs::create_dir_all(&data_dir).await?;
429        
430        // Create file message history, using JSONL format
431        let file_path = data_dir.join(format!("{}_history.jsonl", session_id));
432        let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
433        
434        Ok(Self {
435            session_id,
436            data_dir,
437            chat_history,
438            default_recent_count: recent_count,
439        })
440    }
441    
442    /// Get session ID
443    pub fn get_session_id(&self) -> &str {
444        &self.session_id
445    }
446    
447    /// Get recent messages
448    pub async fn get_recent_messages(&self, count: usize) -> Result<Vec<ChatMessageRecord>> {
449        let messages = self.chat_history.get_messages().await?;
450        
451        // Get recent messages
452        let messages_len = messages.len();
453        let recent_messages: Vec<ChatMessageRecord> = if messages_len > count {
454            messages.into_iter().skip(messages_len - count).collect()
455        } else {
456            messages
457        };
458        
459        Ok(recent_messages)
460    }
461    
462    /// Get recent messages using default count
463    pub async fn get_default_recent_messages(&self) -> Result<Vec<ChatMessageRecord>> {
464        self.get_recent_messages(self.default_recent_count).await
465    }
466    
467    /// Get total message count
468    pub async fn get_message_count(&self) -> Result<usize> {
469        let messages = self.chat_history.get_messages().await?;
470        Ok(messages.len())
471    }
472    
473    /// Keep only the most recent N messages
474    pub async fn keep_recent_messages(&self, count: usize) -> Result<()> {
475        let messages = self.chat_history.get_messages().await?;
476        
477        if messages.len() <= count {
478            return Ok(());
479        }
480        
481        // Get recent messages
482        let messages_len = messages.len();
483        let recent_messages: Vec<ChatMessageRecord> = if messages_len > count {
484            messages.into_iter().skip(messages_len - count).collect()
485        } else {
486            messages
487        };
488        
489        // Update session history
490        {
491            let mut history = self.chat_history.session_history.write().await;
492            history.messages = recent_messages;
493            history.updated_at = Utc::now().to_rfc3339();
494        }
495        
496        // Save to file
497        self.chat_history.save_session_history().await?;
498        
499        Ok(())
500    }
501    
502    /// Add ChatMessage to history
503    pub async fn add_message(&self, message: &ChatMessage) -> Result<()> {
504        // Check if message content is empty
505        if message.content.trim().is_empty() {
506            return Ok(());
507        }
508        
509        let sequence_number = {
510            let mut seq = self.chat_history.next_sequence_number.write().await;
511            let current = *seq;
512            *seq += 1;
513            current
514        };
515        
516        let record = ChatMessageRecord {
517            role: message.role.clone(),
518            content: message.content.clone(),
519            name: None,
520            additional_kwargs: if let Some(metadata) = &message.metadata {
521                let filtered_kwargs: HashMap<String, serde_json::Value> = metadata.as_object()
522                    .unwrap_or(&serde_json::Map::new())
523                    .iter()
524                    .filter(|(k, _)| k != &"type") // Filter out special fields
525                    .map(|(k, v)| (k.clone(), v.clone()))
526                    .collect();
527                Some(filtered_kwargs)
528            } else {
529                None
530            },
531            timestamp: message.timestamp.clone(),
532            sequence_number,
533        };
534        
535        self.chat_history.add_message(record).await?;
536        Ok(())
537    }
538    
539    /// Get the most recent N messages, return ChatMessage type
540    pub async fn get_recent_chat_messages(&self, count: usize) -> Result<Vec<ChatMessage>> {
541        let records = self.get_recent_messages(count).await?;
542        
543        // Convert to ChatMessage
544        let messages: Result<Vec<ChatMessage>> = records.into_iter().map(|record| {
545            Ok(ChatMessage {
546                id: uuid::Uuid::new_v4().to_string(), // Generate new ID
547                role: record.role,
548                content: record.content,
549                timestamp: record.timestamp,
550                metadata: record.additional_kwargs.map(|kwargs| {
551                    let mut map = serde_json::Map::new();
552                    for (k, v) in kwargs {
553                        map.insert(k, v);
554                    }
555                    serde_json::Value::Object(map)
556                }),
557            })
558        }).collect();
559        
560        messages
561    }
562    
563}
564
565// Implement BaseMemory trait, compatible with existing system
566use crate::memory::base::BaseMemory;
567
568impl BaseMemory for MessageHistoryMemory {
569    fn memory_variables(&self) -> Vec<String> {
570        vec!["chat_history".to_string()]
571    }
572    
573    fn load_memory_variables<'a>(&'a self, _inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>> {
574        Box::pin(async move {
575            // Load messages from file, but only return recent messages
576            // Use default configured message count
577            let messages = self.get_default_recent_messages().await?;
578            
579            // Convert to format compatible with SimpleMemory
580            let mut history_array = Vec::new();
581            for msg in messages {
582                let mut msg_obj = serde_json::Map::new();
583                msg_obj.insert("role".to_string(), serde_json::Value::String(msg.role));
584                msg_obj.insert("content".to_string(), serde_json::Value::String(msg.content));
585                
586                if let Some(kwargs) = msg.additional_kwargs {
587                    for (k, v) in kwargs {
588                        msg_obj.insert(k, v);
589                    }
590                }
591                
592                history_array.push(serde_json::Value::Object(msg_obj));
593            }
594            
595            let mut result = HashMap::new();
596            result.insert("chat_history".to_string(), serde_json::Value::Array(history_array));
597            
598            Ok(result)
599        })
600    }
601    
602    fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
603        Box::pin(async move {
604            // Save user message
605            if let Some(input_value) = inputs.get("input") {
606                if let Some(content) = input_value.as_str() {
607                    self.chat_history.add_user_message(content.to_string()).await?;
608                }
609            }
610            
611            // Save AI response
612            if let Some(output_value) = outputs.get("output") {
613                if let Some(content) = output_value.as_str() {
614                    // Preprocess content, if it's JSON string format, extract the content field
615                    let processed_content = if content.starts_with('"') && content.ends_with('"') {
616                        // Try to parse as JSON string
617                        match serde_json::from_str::<serde_json::Value>(content) {
618                            Ok(serde_json::Value::String(s)) => s,
619                            _ => content.to_string(),
620                        }
621                    } else if content.starts_with('{') && content.ends_with('}') {
622                        // Try to parse as JSON object
623                        match serde_json::from_str::<serde_json::Value>(content) {
624                            Ok(json_obj) => {
625                                // If it's a JSON object, try to extract the content field
626                                if let Some(content_value) = json_obj.get("content") {
627                                    if let Some(content_str) = content_value.as_str() {
628                                        content_str.to_string()
629                                    } else {
630                                        content.to_string()
631                                    }
632                                } else {
633                                    content.to_string()
634                                }
635                            },
636                            _ => content.to_string(),
637                        }
638                    } else {
639                        content.to_string()
640                    };
641                    
642                    self.chat_history.add_ai_message(&processed_content).await?;
643                }
644            }
645            
646            Ok(())
647        })
648    }
649    
650    fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
651        Box::pin(async move {
652            self.chat_history.clear().await?;
653            Ok(())
654        })
655    }
656    
657    fn clone_box(&self) -> Box<dyn BaseMemory> {
658        Box::new(self.clone())
659    }
660    
661    fn get_session_id(&self) -> Option<&str> {
662        Some(&self.session_id)
663    }
664    
665    fn set_session_id(&mut self, session_id: String) {
666        self.session_id = session_id;
667    }
668    
669    fn get_token_count(&self) -> Result<usize, Error> {
670        // Simplified implementation: estimate token count based on character count
671        // In actual applications, a more precise token calculator can be used
672        let count = self.session_id.len() + self.data_dir.to_string_lossy().len();
673        Ok(count)
674    }
675    
676    fn as_any(&self) -> &dyn std::any::Any {
677        self
678    }
679}