1use std::collections::HashMap;
3use std::path::PathBuf;
4use std::sync::Arc;
5use std::pin::Pin;
6use anyhow::{Error, Result};
7use serde::{Serialize, Deserialize};
8use serde_json::{json, Value};
9use tokio::sync::RwLock;
10use async_trait::async_trait;
11use log::{info, warn, error};
12use std::future::Future;
13
14use crate::memory::base::{BaseMemory, MemoryVariables};
15use crate::memory::message_history::{MessageHistoryMemory, ChatMessage};
16use crate::memory::summary::SummaryMemory;
17use crate::memory::utils::{
18 ensure_data_dir_exists, get_data_dir_from_env, get_summary_threshold_from_env,
19 get_recent_messages_count_from_env, generate_session_id
20};
21
22#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct CompositeMemoryConfig {
25 pub data_dir: PathBuf,
27 pub session_id: Option<String>,
29 pub summary_threshold: usize,
31 pub recent_messages_count: usize,
33 pub auto_generate_summary: bool,
35}
36
37impl Default for CompositeMemoryConfig {
38 fn default() -> Self {
39 Self {
40 data_dir: get_data_dir_from_env(),
41 session_id: None, summary_threshold: get_summary_threshold_from_env(),
43 recent_messages_count: get_recent_messages_count_from_env(),
44 auto_generate_summary: true,
45 }
46 }
47}
48
49#[derive(Debug, Clone)]
54pub struct CompositeMemory {
55 config: CompositeMemoryConfig,
57 message_history: Option<Arc<MessageHistoryMemory>>,
59 summary_memory: Option<Arc<SummaryMemory>>,
61 memory_variables: Arc<RwLock<MemoryVariables>>,
63}
64
65impl CompositeMemory {
66 pub async fn new() -> Result<Self> {
68 Self::with_config(CompositeMemoryConfig::default()).await
69 }
70
71 pub async fn with_basic_params(
75 data_dir: PathBuf,
76 summary_threshold: usize,
77 recent_messages_count: usize,
78 ) -> Result<Self> {
79 let config = CompositeMemoryConfig {
80 data_dir,
81 session_id: None, summary_threshold,
83 recent_messages_count,
84 auto_generate_summary: true,
85 };
86 Self::with_config(config).await
87 }
88
89 pub async fn with_config(config: CompositeMemoryConfig) -> Result<Self> {
91 ensure_data_dir_exists(&config.data_dir).await?;
93
94 let session_id = config.session_id.clone()
96 .unwrap_or_else(|| generate_session_id());
97
98 let history = MessageHistoryMemory::new_with_recent_count(
100 session_id.clone(),
101 config.data_dir.clone(),
102 config.recent_messages_count
103 ).await?;
104 let message_history = Some(Arc::new(history));
105
106 let summary = SummaryMemory::new_with_shared_history(
108 session_id.clone(),
109 config.data_dir.clone(),
110 config.summary_threshold,
111 message_history.clone().unwrap() ).await?;
113 let summary_memory = Some(Arc::new(summary));
114
115 Ok(Self {
116 config,
117 message_history,
118 summary_memory,
119 memory_variables: Arc::new(RwLock::new(HashMap::new())),
120 })
121 }
122
123 pub async fn with_session_id(session_id: String) -> Result<Self> {
125 let mut config = CompositeMemoryConfig::default();
126 config.session_id = Some(session_id);
127 Self::with_config(config).await
128 }
129
130 pub async fn add_message(&self, message: ChatMessage) -> Result<()> {
132 if let Some(ref history) = self.message_history {
134 history.add_message(&message).await?;
135 }
136
137 if self.config.auto_generate_summary {
139 info!("Checking if summary generation is needed...");
140 if let Some(ref summary) = self.summary_memory {
143 summary.check_and_generate_summary().await?;
144
145 if let Some(ref history) = self.message_history {
147 let keep_count = self.config.recent_messages_count;
148 history.keep_recent_messages(keep_count).await?;
149 }
150 }
151 }
152
153 Ok(())
154 }
155
156 pub async fn get_message_count(&self) -> Result<usize> {
158 if let Some(ref history) = self.message_history {
159 history.get_message_count().await
160 } else {
161 Ok(0)
162 }
163 }
164
165 pub async fn get_recent_messages(&self, count: usize) -> Result<Vec<ChatMessage>> {
167 if let Some(ref history) = self.message_history {
168 history.get_recent_chat_messages(count).await
169 } else {
170 Ok(Vec::new())
171 }
172 }
173
174 pub async fn cleanup_old_messages(&self) -> Result<()> {
176 if let Some(ref history) = self.message_history {
177 history.keep_recent_messages(self.config.recent_messages_count).await?;
178 }
179 Ok(())
180 }
181
182 pub async fn get_memory_stats(&self) -> Result<Value> {
184 let mut stats = json!({
185 "config": {
186 "summary_threshold": self.config.summary_threshold,
187 "recent_messages_count": self.config.recent_messages_count,
188 "auto_generate_summary": self.config.auto_generate_summary,
189 }
190 });
191
192 if let Some(ref history) = self.message_history {
194 let message_count: usize = history.get_message_count().await?;
195 stats["message_history"] = json!({
196 "enabled": true,
197 "message_count": message_count,
198 });
199 }
200
201 if let Some(ref summary) = self.summary_memory {
203 let summary_data = summary.load_summary().await?;
204 stats["summary_memory"] = json!({
205 "enabled": true,
206 "has_summary": summary_data.summary.is_some(),
207 "token_count": summary_data.token_count,
208 "last_updated": summary_data.last_updated,
209 });
210 }
211
212 Ok(stats)
213 }
214
215 pub async fn get_summary(&self) -> Result<Option<String>> {
217 if let Some(ref summary) = self.summary_memory {
218 let summary_data = summary.load_summary().await?;
219 Ok(summary_data.summary)
220 } else {
221 Ok(None)
222 }
223 }
224}
225
226impl CompositeMemory {
228 pub fn as_any(&self) -> &dyn std::any::Any {
230 self
231 }
232}
233
234#[async_trait]
235impl BaseMemory for CompositeMemory {
236 fn memory_variables(&self) -> Vec<String> {
237 let mut vars = Vec::new();
239
240 vars.extend_from_slice(&["chat_history".to_string(), "summary".to_string(), "input".to_string(), "output".to_string()]);
242
243 vars.push("config".to_string());
245
246 vars
247 }
248
249 fn load_memory_variables<'a>(&'a self, inputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<HashMap<String, Value>>> + Send + 'a>> {
250 Box::pin(async move {
251 let mut result = HashMap::new();
252
253 if let Some(ref history) = self.message_history {
255 let messages = history.get_recent_chat_messages(
256 self.config.recent_messages_count
257 ).await?;
258
259 let history_json = serde_json::to_value(&messages)?;
260 result.insert("chat_history".to_string(), history_json);
261 }
262
263 if let Some(ref summary) = self.summary_memory {
265 let summary_data = summary.load_summary().await?;
266
267 if let Some(summary_text) = summary_data.summary {
268 result.insert("summary".to_string(), json!(summary_text));
269 }
270 }
271
272 if let Some(input) = inputs.get("input") {
274 result.insert("input".to_string(), input.clone());
275 }
276
277 if let Some(output) = inputs.get("output") {
279 result.insert("output".to_string(), output.clone());
280 }
281
282 result.insert("config".to_string(), serde_json::to_value(&self.config)?);
284
285 *self.memory_variables.write().await = result.clone();
287
288 Ok(result)
289 })
290 }
291
292 fn save_context<'a>(&'a self, inputs: &'a HashMap<String, Value>, outputs: &'a HashMap<String, Value>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
293 Box::pin(async move {
294 let input = inputs.get("input")
296 .and_then(|v| v.as_str())
297 .unwrap_or("");
298
299 let output = outputs.get("output")
300 .and_then(|v| v.as_str())
301 .unwrap_or("");
302
303 if !input.is_empty() {
305 let user_message = ChatMessage {
306 id: uuid::Uuid::new_v4().to_string(),
307 role: "user".to_string(),
308 content: input.to_string(),
309 timestamp: chrono::Utc::now().to_rfc3339(),
310 metadata: None,
311 };
312
313 if let Some(ref history) = self.message_history {
315 history.add_message(&user_message).await?;
316 }
317 }
318
319 if !output.is_empty() {
321 let assistant_message = ChatMessage {
322 id: uuid::Uuid::new_v4().to_string(),
323 role: "assistant".to_string(),
324 content: output.to_string(),
325 timestamp: chrono::Utc::now().to_rfc3339(),
326 metadata: None,
327 };
328
329 if let Some(ref history) = self.message_history {
331 history.add_message(&assistant_message).await?;
332 }
333 }
334
335 if self.config.auto_generate_summary {
337 info!("Checking if summary generation is needed...");
338 if let Some(ref summary) = self.summary_memory {
339 summary.check_and_generate_summary().await?;
340
341 if let Some(ref history) = self.message_history {
343 let keep_count = self.config.recent_messages_count;
344 history.keep_recent_messages(keep_count).await?;
345 }
346 }
347 }
348
349 let mut memory_vars = self.memory_variables.write().await;
351
352 if let Some(input_val) = inputs.get("input") {
353 memory_vars.insert("input".to_string(), input_val.clone());
354 }
355
356 if let Some(output_val) = outputs.get("output") {
357 memory_vars.insert("output".to_string(), output_val.clone());
358 }
359
360 Ok(())
361 })
362 }
363
364 fn clear<'a>(&'a self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>> {
365 Box::pin(async move {
366 if let Some(ref history) = self.message_history {
368 history.clear().await?;
369 }
370
371 if let Some(ref summary) = self.summary_memory {
373 summary.clear().await?;
374 }
375
376 self.memory_variables.write().await.clear();
378
379 Ok(())
380 })
381 }
382
383 fn clone_box(&self) -> Box<dyn BaseMemory> {
384 Box::new(self.clone())
385 }
386
387 fn get_session_id(&self) -> Option<&str> {
388 self.config.session_id.as_deref()
389 }
390
391 fn set_session_id(&mut self, session_id: String) {
392 self.config.session_id = Some(session_id);
393 }
394
395 fn get_token_count(&self) -> Result<usize, Error> {
396 let mut count = 0;
398
399 if let Ok(config_json) = serde_json::to_value(&self.config) {
401 count += crate::memory::utils::estimate_json_token_count(&config_json);
402 }
403
404 if let Ok(memory_vars) = self.memory_variables.try_read() {
406 if let Ok(vars_json) = serde_json::to_value(&*memory_vars) {
407 count += crate::memory::utils::estimate_json_token_count(&vars_json);
408 }
409 }
410
411 Ok(count)
412 }
413
414 fn as_any(&self) -> &dyn std::any::Any {
415 self
416 }
417}
418
#[cfg(test)]
mod tests {
    use super::*;
    use crate::memory::message_history::ChatMessage;
    use tempfile::TempDir;

    /// Builds a memory whose data directory is a fresh temporary directory and
    /// whose automatic summarisation is disabled.
    ///
    /// The `TempDir` is returned alongside the memory. Callers must keep it
    /// bound (not `_`) for the whole test so the directory outlives the memory.
    async fn temp_memory() -> (TempDir, CompositeMemory) {
        let temp_dir = TempDir::new().unwrap();
        let config = CompositeMemoryConfig {
            data_dir: temp_dir.path().to_path_buf(),
            auto_generate_summary: false,
            ..CompositeMemoryConfig::default()
        };
        let memory = CompositeMemory::with_config(config).await.unwrap();
        (temp_dir, memory)
    }

    /// One user/assistant exchange in the shape `save_context` expects.
    fn greeting_exchange() -> (HashMap<String, Value>, HashMap<String, Value>) {
        let inputs = HashMap::from([("input".to_string(), json!("Hello"))]);
        let outputs = HashMap::from([("output".to_string(), json!("Hi there!"))]);
        (inputs, outputs)
    }

    // NOTE(review): the next two tests build from the default config, so they
    // use the data directory configured through the environment rather than a
    // temporary directory.
    #[tokio::test]
    async fn test_composite_memory_new() {
        assert!(CompositeMemory::new().await.is_ok());
    }

    #[tokio::test]
    async fn test_composite_memory_with_session_id() {
        let session_id = "test_session";
        let memory = CompositeMemory::with_session_id(session_id.to_string())
            .await
            .expect("creating a memory with an explicit session id failed");
        assert_eq!(memory.get_session_id(), Some(session_id));
    }

    #[tokio::test]
    async fn test_add_message() {
        let (_temp_dir, memory) = temp_memory().await;
        let message = ChatMessage {
            id: "test_id".to_string(),
            role: "user".to_string(),
            content: "Hello, world!".to_string(),
            timestamp: chrono::Utc::now().to_rfc3339(),
            metadata: None,
        };

        assert!(memory.add_message(message).await.is_ok());
        assert_eq!(memory.get_message_count().await.unwrap(), 1);
    }

    #[tokio::test]
    async fn test_save_context() {
        let (_temp_dir, memory) = temp_memory().await;
        let (inputs, outputs) = greeting_exchange();

        assert!(memory.save_context(&inputs, &outputs).await.is_ok());
        // One user message plus one assistant message.
        assert_eq!(memory.get_message_count().await.unwrap(), 2);
    }

    #[tokio::test]
    async fn test_clear() {
        let (_temp_dir, memory) = temp_memory().await;
        let (inputs, outputs) = greeting_exchange();

        memory.save_context(&inputs, &outputs).await.unwrap();
        assert_eq!(memory.get_message_count().await.unwrap(), 2);

        assert!(memory.clear().await.is_ok());
        assert_eq!(memory.get_message_count().await.unwrap(), 0);
    }
}