1use std::path::{Path, PathBuf};
3use std::sync::Arc;
4use tokio::sync::RwLock;
5use tokio::fs::File;
6use tokio::io::{AsyncWriteExt, AsyncReadExt};
7use anyhow::{Error, Result};
8use serde::{Serialize, Deserialize};
9use std::collections::HashMap;
10use serde_json::Value;
11use std::pin::Pin;
12use std::future::Future;
13use log::{info, warn};
14use chrono::Utc;
15
16#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct ChatMessage {
19 pub id: String,
20 pub role: String,
21 pub content: String,
22 pub timestamp: String,
23 pub metadata: Option<Value>,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct ChatMessageRecord {
29 pub role: String,
31 pub content: String,
33 pub name: Option<String>,
35 pub additional_kwargs: Option<HashMap<String, serde_json::Value>>,
37 pub timestamp: String,
39 pub sequence_number: u64,
41}
42
43#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct ChatSessionHistory {
46 pub session_id: String,
48 pub created_at: String,
50 pub updated_at: String,
52 pub messages: Vec<ChatMessageRecord>,
54 pub metadata: Option<HashMap<String, serde_json::Value>>,
56}
57
58#[derive(Debug)]
61pub struct FileChatMessageHistory {
62 session_id: String,
64 file_path: PathBuf,
66 session_history: Arc<RwLock<ChatSessionHistory>>,
68 next_sequence_number: Arc<RwLock<u64>>,
70}
71
72impl Clone for FileChatMessageHistory {
73 fn clone(&self) -> Self {
74 Self {
75 session_id: self.session_id.clone(),
76 file_path: self.file_path.clone(),
77 session_history: Arc::clone(&self.session_history),
78 next_sequence_number: Arc::clone(&self.next_sequence_number),
79 }
80 }
81}
82
83impl FileChatMessageHistory {
84 pub async fn new(session_id: String, file_path: PathBuf) -> Result<Self> {
86 if let Some(parent) = file_path.parent() {
88 tokio::fs::create_dir_all(parent).await?;
89 }
90
91 let now = Utc::now().to_rfc3339();
93 let session_history = ChatSessionHistory {
94 session_id: session_id.clone(),
95 created_at: now.clone(),
96 updated_at: now,
97 messages: Vec::new(),
98 metadata: None,
99 };
100
101 let instance = Self {
102 session_id: session_id.clone(),
103 file_path: file_path.clone(),
104 session_history: Arc::new(RwLock::new(session_history)),
105 next_sequence_number: Arc::new(RwLock::new(1)),
106 };
107
108 instance.load_session_history().await?;
110
111 Ok(instance)
112 }
113
114 async fn load_session_history(&self) -> Result<()> {
116 if !tokio::fs::metadata(&self.file_path).await.is_ok() {
117 return Ok(());
119 }
120
121 let mut file = File::open(&self.file_path).await?;
122 let mut contents = String::new();
123 file.read_to_string(&mut contents).await?;
124
125 if contents.trim().is_empty() {
126 return Ok(());
127 }
128
129 match serde_json::from_str::<ChatSessionHistory>(&contents) {
131 Ok(session_history) => {
132 {
134 let mut history = self.session_history.write().await;
135 *history = session_history;
136 }
137
138 {
140 let history = self.session_history.read().await;
141 let next_seq = history.messages.len() as u64 + 1;
142 let mut next_sequence = self.next_sequence_number.write().await;
143 *next_sequence = next_seq;
144 }
145
146 info!("[FileChatMessageHistory] Loaded session history with {} messages from JSONL format", {
147 let history = self.session_history.read().await;
148 history.messages.len()
149 });
150 },
151 Err(e) => {
152 warn!("Failed to parse as session history JSON, trying as old format: {}", e);
154
155 let mut messages = Vec::new();
156 let mut max_sequence_number = 0u64;
157
158 for line in contents.lines() {
159 if line.trim().is_empty() {
160 continue;
161 }
162
163 match serde_json::from_str::<serde_json::Value>(line) {
165 Ok(msg_value) => {
166 if msg_value.get("sequence_number").is_none() {
168 if let (Some(role), Some(content)) = (
170 msg_value.get("role").and_then(|v| v.as_str()),
171 msg_value.get("content").and_then(|v| v.as_str())
172 ) {
173 if role == "assistant" && content.contains("user:") && content.contains("assistant:") {
175 continue;
176 }
177
178 let message = ChatMessageRecord {
180 role: role.to_string(),
181 content: content.to_string(),
182 name: msg_value.get("name").and_then(|v| v.as_str()).map(|s| s.to_string()),
183 additional_kwargs: msg_value.get("additional_kwargs").cloned().and_then(|v| {
184 if v.is_null() {
185 None
186 } else {
187 Some(serde_json::from_value(v).unwrap_or_default())
188 }
189 }),
190 timestamp: msg_value.get("timestamp")
191 .and_then(|v| v.as_str())
192 .unwrap_or(&Utc::now().to_rfc3339())
193 .to_string(),
194 sequence_number: max_sequence_number + 1,
195 };
196
197 max_sequence_number += 1;
198 messages.push(message);
199 }
200 } else {
201 if let Ok(message) = serde_json::from_value::<ChatMessageRecord>(msg_value) {
203 max_sequence_number = max_sequence_number.max(message.sequence_number);
204 messages.push(message);
205 }
206 }
207 },
208 Err(e) => {
209 warn!("Failed to parse line in JSONL file: {}, error: {}", line, e);
210 }
211 }
212 }
213
214 if !messages.is_empty() {
215 messages.sort_by_key(|m| m.sequence_number);
217
218 {
220 let mut history = self.session_history.write().await;
221 history.messages = messages;
222 history.updated_at = Utc::now().to_rfc3339();
223 }
224
225 {
227 let mut next_sequence = self.next_sequence_number.write().await;
228 *next_sequence = max_sequence_number + 1;
229 }
230
231 info!("[FileChatMessageHistory] Loaded session history with {} messages from old JSONL format", {
232 let history = self.session_history.read().await;
233 history.messages.len()
234 });
235
236 self.save_session_history().await?;
238 } else {
239 return Err(anyhow::anyhow!("Failed to parse file as either session history JSON or old JSONL format"));
240 }
241 }
242 }
243
244 Ok(())
245 }
246
247 pub async fn save_session_history(&self) -> Result<()> {
249 let history = {
251 let history_guard = self.session_history.read().await;
252 history_guard.clone()
253 };
254
255 let temp_path = self.file_path.with_extension("tmp");
257 {
258 let mut file = File::create(&temp_path).await?;
259
260 let json_content = serde_json::to_string_pretty(&history)?;
262 file.write_all(json_content.as_bytes()).await?;
263
264 file.flush().await?;
265 }
266
267 tokio::fs::rename(&temp_path, &self.file_path).await?;
269
270 Ok(())
271 }
272
273 pub async fn add_user_message(&self, content: String) -> Result<()> {
275 if content.trim().is_empty() {
277 return Ok(());
278 }
279
280 let sequence_number = {
281 let mut seq = self.next_sequence_number.write().await;
282 let current = *seq;
283 *seq += 1;
284 current
285 };
286
287 let message = ChatMessageRecord {
288 role: "user".to_string(),
289 content,
290 name: None,
291 additional_kwargs: None,
292 timestamp: Utc::now().to_rfc3339(),
293 sequence_number,
294 };
295
296 self.add_message(message).await?;
297 Ok(())
298 }
299
300 pub async fn add_ai_message(&self, content: &str) -> Result<()> {
302 let processed_content = if content.starts_with('"') && content.ends_with('"') {
304 match serde_json::from_str::<serde_json::Value>(content) {
306 Ok(serde_json::Value::String(s)) => s,
307 _ => content.to_string(),
308 }
309 } else if content.starts_with('{') && content.ends_with('}') {
310 match serde_json::from_str::<serde_json::Value>(content) {
312 Ok(json_obj) => {
313 if let Some(content_value) = json_obj.get("content") {
315 if let Some(content_str) = content_value.as_str() {
316 content_str.to_string()
317 } else {
318 content.to_string()
319 }
320 } else {
321 content.to_string()
322 }
323 },
324 _ => content.to_string(),
325 }
326 } else {
327 content.to_string()
328 };
329
330 let sequence_number = {
331 let mut seq = self.next_sequence_number.write().await;
332 let current = *seq;
333 *seq += 1;
334 current
335 };
336
337 let message = ChatMessageRecord {
338 role: "assistant".to_string(),
339 content: processed_content,
340 name: None,
341 additional_kwargs: None,
342 timestamp: Utc::now().to_rfc3339(),
343 sequence_number,
344 };
345
346 self.add_message(message).await?;
347 Ok(())
348 }
349
350 async fn add_message(&self, message: ChatMessageRecord) -> Result<()> {
352 {
354 let mut history = self.session_history.write().await;
355 history.messages.push(message.clone());
356 history.updated_at = Utc::now().to_rfc3339();
357 }
358
359 self.save_session_history().await?;
361
362 Ok(())
363 }
364
365 pub async fn get_messages(&self) -> Result<Vec<ChatMessageRecord>> {
367 let history = self.session_history.read().await;
368 Ok(history.messages.clone())
369 }
370
371 pub async fn clear(&self) -> Result<()> {
373 {
375 let mut history = self.session_history.write().await;
376 history.messages.clear();
377 history.updated_at = Utc::now().to_rfc3339();
378 }
379
380 {
382 let mut next_sequence = self.next_sequence_number.write().await;
383 *next_sequence = 1;
384 }
385
386 self.save_session_history().await?;
388
389 Ok(())
390 }
391}
392
393#[derive(Debug)]
395pub struct MessageHistoryMemory {
396 session_id: String,
398 data_dir: PathBuf,
400 chat_history: FileChatMessageHistory,
402 default_recent_count: usize,
404}
405
406impl Clone for MessageHistoryMemory {
407 fn clone(&self) -> Self {
408 Self {
409 session_id: self.session_id.clone(),
410 data_dir: self.data_dir.clone(),
411 chat_history: self.chat_history.clone(),
412 default_recent_count: self.default_recent_count,
413 }
414 }
415}
416
417impl MessageHistoryMemory {
418 pub async fn new(session_id: String, data_dir: PathBuf) -> Result<Self> {
420 let default_recent_count = crate::memory::utils::get_recent_messages_count_from_env();
422 Self::new_with_recent_count(session_id, data_dir, default_recent_count).await
423 }
424
425 pub async fn new_with_recent_count(session_id: String, data_dir: PathBuf, recent_count: usize) -> Result<Self> {
427 tokio::fs::create_dir_all(&data_dir).await?;
429
430 let file_path = data_dir.join(format!("{}_history.jsonl", session_id));
432 let chat_history = FileChatMessageHistory::new(session_id.clone(), file_path).await?;
433
434 Ok(Self {
435 session_id,
436 data_dir,
437 chat_history,
438 default_recent_count: recent_count,
439 })
440 }
441
442 pub fn get_session_id(&self) -> &str {
444 &self.session_id
445 }
446
447 pub async fn get_recent_messages(&self, count: usize) -> Result<Vec<ChatMessageRecord>> {
449 let messages = self.chat_history.get_messages().await?;
450
451 let messages_len = messages.len();
453 let recent_messages: Vec<ChatMessageRecord> = if messages_len > count {
454 messages.into_iter().skip(messages_len - count).collect()
455 } else {
456 messages
457 };
458
459 Ok(recent_messages)
460 }
461
462 pub async fn get_default_recent_messages(&self) -> Result<Vec<ChatMessageRecord>> {
464 self.get_recent_messages(self.default_recent_count).await
465 }
466
467 pub async fn get_message_count(&self) -> Result<usize> {
469 let messages = self.chat_history.get_messages().await?;
470 Ok(messages.len())
471 }
472
473 pub async fn keep_recent_messages(&self, count: usize) -> Result<()> {
475 let messages = self.chat_history.get_messages().await?;
476
477 if messages.len() <= count {
478 return Ok(());
479 }
480
481 let messages_len = messages.len();
483 let recent_messages: Vec<ChatMessageRecord> = if messages_len > count {
484 messages.into_iter().skip(messages_len - count).collect()
485 } else {
486 messages
487 };
488
489 {
491 let mut history = self.chat_history.session_history.write().await;
492 history.messages = recent_messages;
493 history.updated_at = Utc::now().to_rfc3339();
494 }
495
496 self.chat_history.save_session_history().await?;
498
499 Ok(())
500 }
501
502 pub async fn add_message(&self, message: &ChatMessage) -> Result<()> {
504 if message.content.trim().is_empty() {
506 return Ok(());
507 }
508
509 let sequence_number = {
510 let mut seq = self.chat_history.next_sequence_number.write().await;
511 let current = *seq;
512 *seq += 1;
513 current
514 };
515
516 let record = ChatMessageRecord {
517 role: message.role.clone(),
518 content: message.content.clone(),
519 name: None,
520 additional_kwargs: if let Some(metadata) = &message.metadata {
521 let filtered_kwargs: HashMap<String, serde_json::Value> = metadata.as_object()
522 .unwrap_or(&serde_json::Map::new())
523 .iter()
524 .filter(|(k, _)| k != &"type") .map(|(k, v)| (k.clone(), v.clone()))
526 .collect();
527 Some(filtered_kwargs)
528 } else {
529 None
530 },
531 timestamp: message.timestamp.clone(),
532 sequence_number,
533 };
534
535 self.chat_history.add_message(record).await?;
536 Ok(())
537 }
538
539 pub async fn get_recent_chat_messages(&self, count: usize) -> Result<Vec<ChatMessage>> {
541 let records = self.get_recent_messages(count).await?;
542
543 let messages: Result<Vec<ChatMessage>> = records.into_iter().map(|record| {
545 Ok(ChatMessage {
546 id: uuid::Uuid::new_v4().to_string(), role: record.role,
548 content: record.content,
549 timestamp: record.timestamp,
550 metadata: record.additional_kwargs.map(|kwargs| {
551 let mut map = serde_json::Map::new();
552 for (k, v) in kwargs {
553 map.insert(k, v);
554 }
555 serde_json::Value::Object(map)
556 }),
557 })
558 }).collect();
559
560 messages
561 }
562
563}
564
565use crate::memory::base::BaseMemory;
567
568impl BaseMemory for MessageHistoryMemory {
569 fn memory_variables(&self) -> Vec<String> {
570 vec!["chat_history".to_string()]
571 }
572
573 fn load_memory_variables<'a>(&'a self, _inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>, Error>> + Send + 'a>> {
574 Box::pin(async move {
575 let messages = self.get_default_recent_messages().await?;
578
579 let mut history_array = Vec::new();
581 for msg in messages {
582 let mut msg_obj = serde_json::Map::new();
583 msg_obj.insert("role".to_string(), serde_json::Value::String(msg.role));
584 msg_obj.insert("content".to_string(), serde_json::Value::String(msg.content));
585
586 if let Some(kwargs) = msg.additional_kwargs {
587 for (k, v) in kwargs {
588 msg_obj.insert(k, v);
589 }
590 }
591
592 history_array.push(serde_json::Value::Object(msg_obj));
593 }
594
595 let mut result = HashMap::new();
596 result.insert("chat_history".to_string(), serde_json::Value::Array(history_array));
597
598 Ok(result)
599 })
600 }
601
602 fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
603 Box::pin(async move {
604 if let Some(input_value) = inputs.get("input") {
606 if let Some(content) = input_value.as_str() {
607 self.chat_history.add_user_message(content.to_string()).await?;
608 }
609 }
610
611 if let Some(output_value) = outputs.get("output") {
613 if let Some(content) = output_value.as_str() {
614 let processed_content = if content.starts_with('"') && content.ends_with('"') {
616 match serde_json::from_str::<serde_json::Value>(content) {
618 Ok(serde_json::Value::String(s)) => s,
619 _ => content.to_string(),
620 }
621 } else if content.starts_with('{') && content.ends_with('}') {
622 match serde_json::from_str::<serde_json::Value>(content) {
624 Ok(json_obj) => {
625 if let Some(content_value) = json_obj.get("content") {
627 if let Some(content_str) = content_value.as_str() {
628 content_str.to_string()
629 } else {
630 content.to_string()
631 }
632 } else {
633 content.to_string()
634 }
635 },
636 _ => content.to_string(),
637 }
638 } else {
639 content.to_string()
640 };
641
642 self.chat_history.add_ai_message(&processed_content).await?;
643 }
644 }
645
646 Ok(())
647 })
648 }
649
650 fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<(), Error>> + Send + 'a>> {
651 Box::pin(async move {
652 self.chat_history.clear().await?;
653 Ok(())
654 })
655 }
656
657 fn clone_box(&self) -> Box<dyn BaseMemory> {
658 Box::new(self.clone())
659 }
660
661 fn get_session_id(&self) -> Option<&str> {
662 Some(&self.session_id)
663 }
664
665 fn set_session_id(&mut self, session_id: String) {
666 self.session_id = session_id;
667 }
668
669 fn get_token_count(&self) -> Result<usize, Error> {
670 let count = self.session_id.len() + self.data_dir.to_string_lossy().len();
673 Ok(count)
674 }
675
676 fn as_any(&self) -> &dyn std::any::Any {
677 self
678 }
679}