1use std::collections::{HashMap, VecDeque};
41use std::sync::{Arc, RwLock};
42use std::time::{SystemTime, UNIX_EPOCH};
43
/// Tuning knobs for the three-tier memory compaction process.
#[derive(Debug, Clone)]
pub struct MemoryCompactionConfig {
    /// Max raw episodes held in L0 before `maybe_compact` triggers a compaction.
    pub l0_max_episodes: usize,

    /// Max summaries held in L1 before `maybe_compact` triggers a compaction.
    pub l1_max_summaries: usize,

    /// Age in seconds after which an L0 episode is eligible for summarization.
    pub l0_age_threshold_secs: u64,

    /// Age in seconds after which an L1 summary is eligible for abstraction.
    pub l1_age_threshold_secs: u64,

    /// Number of episodes (or summaries) folded into one summary (or abstraction).
    pub group_size: usize,

    /// Similarity cutoff for grouping. NOTE(review): not consulted by the
    /// chunk-based grouping in this file — presumably for a semantic grouping
    /// strategy elsewhere; confirm before relying on it.
    pub similarity_threshold: f32,

    /// Target cap on summary size in tokens. NOTE(review): not enforced by the
    /// code in this file; confirm the consumer.
    pub max_summary_tokens: usize,

    /// Whether summaries should be re-embedded after creation. NOTE(review):
    /// unused by this file's code paths; confirm the consumer.
    pub reembed_summaries: bool,

    /// Seconds between periodic compaction checks. NOTE(review): not consulted
    /// by `maybe_compact` here — presumably used by an external scheduler.
    pub check_interval_secs: u64,
}
78
79impl Default for MemoryCompactionConfig {
80 fn default() -> Self {
81 Self {
82 l0_max_episodes: 1000,
83 l1_max_summaries: 100,
84 l0_age_threshold_secs: 3600, l1_age_threshold_secs: 86400 * 7, group_size: 10,
87 similarity_threshold: 0.7,
88 max_summary_tokens: 200,
89 reembed_summaries: true,
90 check_interval_secs: 300, }
92 }
93}
94
95impl MemoryCompactionConfig {
96 pub fn aggressive() -> Self {
98 Self {
99 l0_max_episodes: 100,
100 l1_max_summaries: 20,
101 l0_age_threshold_secs: 60,
102 l1_age_threshold_secs: 3600,
103 group_size: 5,
104 ..Default::default()
105 }
106 }
107
108 pub fn long_running() -> Self {
110 Self {
111 l0_max_episodes: 5000,
112 l1_max_summaries: 500,
113 l0_age_threshold_secs: 3600 * 6, l1_age_threshold_secs: 86400 * 30, group_size: 20,
116 ..Default::default()
117 }
118 }
119}
120
/// One raw memory event — the unit stored in tier L0.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Unique id minted by the owning store (format `mem_<n>`).
    pub id: String,

    /// Creation time, fractional seconds since the Unix epoch.
    pub timestamp: f64,

    /// Raw textual content of the event.
    pub content: String,

    /// Kind of event this episode records.
    pub episode_type: EpisodeType,

    /// Free-form key/value metadata (always empty as created in this file).
    pub metadata: HashMap<String, String>,

    /// Optional embedding vector; `None` when added without one.
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count (content length / 4 heuristic).
    pub token_count: usize,
}
149
/// Category of a raw episode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpisodeType {
    /// A message authored by the user.
    UserMessage,
    /// A response produced by the assistant.
    AssistantResponse,
    /// An outgoing tool invocation.
    ToolCall,
    /// A result returned from a tool.
    ToolResult,
    /// A system-level event.
    SystemEvent,
    /// An observation recorded by the agent.
    Observation,
}
166
/// A condensed record of a group of episodes — the unit stored in tier L1.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Unique id minted by the owning store (format `mem_<n>`).
    pub id: String,

    /// Summary text produced by a `Summarizer`.
    pub content: String,

    /// Ids of the L0 episodes this summary was derived from.
    pub source_episode_ids: Vec<String>,

    /// (earliest, latest) episode timestamps covered, Unix seconds.
    pub time_range: (f64, f64),

    /// Optional embedding; created as `None` by compaction in this file.
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count (content length / 4 heuristic).
    pub token_count: usize,

    /// When this summary was created, Unix seconds; used for L1 aging.
    pub created_at: f64,

    /// Topic keywords extracted from the summary text.
    pub topics: Vec<String>,
}
194
/// A high-level condensation of a group of summaries — the unit stored in tier L2.
#[derive(Debug, Clone)]
pub struct Abstraction {
    /// Unique id minted by the owning store (format `mem_<n>`).
    pub id: String,

    /// Abstraction text produced by a `Summarizer`.
    pub content: String,

    /// Ids of the L1 summaries this abstraction was derived from.
    pub source_summary_ids: Vec<String>,

    /// (earliest, latest) timestamps covered by the source summaries, Unix seconds.
    pub time_range: (f64, f64),

    /// Optional embedding; created as `None` by compaction in this file.
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count (content length / 4 heuristic).
    pub token_count: usize,

    /// When this abstraction was created, Unix seconds.
    pub created_at: f64,

    /// Insight keywords — as built here, the union of source summary topics.
    pub insights: Vec<String>,
}
222
/// Strategy for condensing memory between tiers.
///
/// Bound `Send + Sync` so implementations can be shared across threads via `Arc`.
pub trait Summarizer: Send + Sync {
    /// Condenses a batch of L0 episodes into one L1 summary text.
    fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError>;

    /// Condenses a batch of L1 summaries into one L2 abstraction text.
    fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError>;

    /// Extracts topic keywords from free text (best-effort, may be empty).
    fn extract_topics(&self, content: &str) -> Vec<String>;
}
238
/// Simple extractive `Summarizer`: selects leading sentences rather than
/// generating new text (no model calls).
pub struct ExtractiveSummarizer {
    /// Max first-sentences kept per episode-type section of a summary.
    pub max_sentences: usize,

    /// Whether to prepend a bracketed episode-count/time-span header line.
    pub include_timestamps: bool,
}
247
248impl Default for ExtractiveSummarizer {
249 fn default() -> Self {
250 Self {
251 max_sentences: 5,
252 include_timestamps: true,
253 }
254 }
255}
256
257impl Summarizer for ExtractiveSummarizer {
258 fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError> {
259 if episodes.is_empty() {
260 return Ok(String::new());
261 }
262
263 let mut summary_parts = Vec::new();
264
265 let first_ts = episodes.first().map(|e| e.timestamp).unwrap_or(0.0);
267 let last_ts = episodes.last().map(|e| e.timestamp).unwrap_or(0.0);
268
269 if self.include_timestamps {
270 summary_parts.push(format!(
271 "[{} episodes over {:.0} seconds]",
272 episodes.len(),
273 last_ts - first_ts
274 ));
275 }
276
277 let mut by_type: HashMap<EpisodeType, Vec<&Episode>> = HashMap::new();
279 for episode in episodes {
280 by_type.entry(episode.episode_type).or_default().push(episode);
281 }
282
283 for (ep_type, eps) in by_type {
285 let type_name = match ep_type {
286 EpisodeType::UserMessage => "User messages",
287 EpisodeType::AssistantResponse => "Responses",
288 EpisodeType::ToolCall => "Tool calls",
289 EpisodeType::ToolResult => "Tool results",
290 EpisodeType::SystemEvent => "Events",
291 EpisodeType::Observation => "Observations",
292 };
293
294 let sentences: Vec<String> = eps
296 .iter()
297 .take(self.max_sentences)
298 .filter_map(|e| e.content.split('.').next().map(|s| s.trim().to_string()))
299 .filter(|s| !s.is_empty())
300 .collect();
301
302 if !sentences.is_empty() {
303 summary_parts.push(format!("{}: {}", type_name, sentences.join("; ")));
304 }
305 }
306
307 Ok(summary_parts.join("\n"))
308 }
309
310 fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError> {
311 if summaries.is_empty() {
312 return Ok(String::new());
313 }
314
315 let mut abstraction_parts = Vec::new();
316
317 let first_ts = summaries.iter().map(|s| s.time_range.0).fold(f64::MAX, f64::min);
319 let last_ts = summaries.iter().map(|s| s.time_range.1).fold(f64::MIN, f64::max);
320
321 abstraction_parts.push(format!(
322 "[{} summaries, {:.1} hours span]",
323 summaries.len(),
324 (last_ts - first_ts) / 3600.0
325 ));
326
327 let all_topics: Vec<&str> = summaries
329 .iter()
330 .flat_map(|s| s.topics.iter().map(|t| t.as_str()))
331 .collect();
332
333 if !all_topics.is_empty() {
334 let unique_topics: Vec<_> = all_topics.iter()
335 .cloned()
336 .collect::<std::collections::HashSet<_>>()
337 .into_iter()
338 .take(10)
339 .collect();
340 abstraction_parts.push(format!("Topics: {}", unique_topics.join(", ")));
341 }
342
343 let key_points: Vec<String> = summaries
345 .iter()
346 .take(5)
347 .filter_map(|s| s.content.lines().next().map(|l| l.to_string()))
348 .collect();
349
350 if !key_points.is_empty() {
351 abstraction_parts.push(format!("Key points:\n- {}", key_points.join("\n- ")));
352 }
353
354 Ok(abstraction_parts.join("\n"))
355 }
356
357 fn extract_topics(&self, content: &str) -> Vec<String> {
358 let stopwords = ["the", "a", "an", "is", "are", "was", "were", "to", "from", "in", "on", "at", "for", "and", "or"];
360
361 let words: Vec<&str> = content
362 .split_whitespace()
363 .filter(|w| w.len() > 3)
364 .filter(|w| !stopwords.contains(&w.to_lowercase().as_str()))
365 .collect();
366
367 let mut freq: HashMap<String, usize> = HashMap::new();
369 for word in words {
370 let normalized = word.to_lowercase().trim_matches(|c: char| !c.is_alphanumeric()).to_string();
371 if normalized.len() > 3 {
372 *freq.entry(normalized).or_insert(0) += 1;
373 }
374 }
375
376 let mut sorted: Vec<_> = freq.into_iter().collect();
378 sorted.sort_by(|a, b| b.1.cmp(&a.1));
379
380 sorted.into_iter().take(5).map(|(w, _)| w).collect()
381 }
382}
383
/// Three-tier hierarchical memory store.
///
/// L0 holds raw `Episode`s, L1 holds `Summary`s of episode groups, L2 holds
/// `Abstraction`s over summary groups. Compaction moves aged data upward,
/// trading detail for token savings. All tiers are behind `RwLock`s, so the
/// store is usable behind a shared reference from multiple threads.
pub struct HierarchicalMemory<S: Summarizer> {
    /// Sizing and aging thresholds driving compaction.
    config: MemoryCompactionConfig,

    /// L0: raw episodes; pushed at the back, so the front is oldest.
    l0_episodes: RwLock<VecDeque<Episode>>,

    /// L1: summaries produced by compacting aged L0 groups.
    l1_summaries: RwLock<VecDeque<Summary>>,

    /// L2: abstractions produced by compacting aged L1 groups.
    l2_abstractions: RwLock<VecDeque<Abstraction>>,

    /// Strategy used to condense episodes and summaries.
    summarizer: Arc<S>,

    /// Running counters describing compaction activity.
    stats: RwLock<CompactionStats>,

    /// Monotonic counter backing `next_id()` id generation.
    next_id: std::sync::atomic::AtomicU64,
}
411
/// Cumulative counters describing memory growth and compaction activity.
#[derive(Debug, Clone, Default)]
pub struct CompactionStats {
    /// Episodes ever added (not decremented on compaction).
    pub total_episodes: usize,

    /// Summaries ever created by L0→L1 compaction.
    pub total_summaries: usize,

    /// Abstractions ever created by L1→L2 compaction.
    pub total_abstractions: usize,

    /// Episodes consumed by L0→L1 compaction.
    pub episodes_compacted: usize,

    /// Summaries consumed by L1→L2 compaction.
    pub summaries_compacted: usize,

    /// Timestamp (Unix seconds) of the last completed compaction run, if any.
    pub last_compaction: Option<f64>,

    /// Estimated tokens saved: original token counts minus summary/abstraction
    /// token counts, accumulated with saturating subtraction.
    pub token_savings: usize,
}
436
437impl<S: Summarizer> HierarchicalMemory<S> {
438 pub fn new(config: MemoryCompactionConfig, summarizer: Arc<S>) -> Self {
440 Self {
441 config,
442 l0_episodes: RwLock::new(VecDeque::new()),
443 l1_summaries: RwLock::new(VecDeque::new()),
444 l2_abstractions: RwLock::new(VecDeque::new()),
445 summarizer,
446 stats: RwLock::new(CompactionStats::default()),
447 next_id: std::sync::atomic::AtomicU64::new(1),
448 }
449 }
450
451 fn next_id(&self) -> String {
453 let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
454 format!("mem_{}", id)
455 }
456
457 pub fn add_episode(&self, content: String, episode_type: EpisodeType) -> String {
459 let id = self.next_id();
460 let timestamp = SystemTime::now()
461 .duration_since(UNIX_EPOCH)
462 .unwrap_or_default()
463 .as_secs_f64();
464
465 let token_count = content.len() / 4; let episode = Episode {
468 id: id.clone(),
469 timestamp,
470 content,
471 episode_type,
472 metadata: HashMap::new(),
473 embedding: None,
474 token_count,
475 };
476
477 {
478 let mut l0 = self.l0_episodes.write().unwrap();
479 l0.push_back(episode);
480 }
481
482 {
483 let mut stats = self.stats.write().unwrap();
484 stats.total_episodes += 1;
485 }
486
487 id
488 }
489
490 pub fn add_episode_with_embedding(
492 &self,
493 content: String,
494 episode_type: EpisodeType,
495 embedding: Vec<f32>,
496 ) -> String {
497 let id = self.next_id();
498 let timestamp = SystemTime::now()
499 .duration_since(UNIX_EPOCH)
500 .unwrap_or_default()
501 .as_secs_f64();
502
503 let token_count = content.len() / 4;
504
505 let episode = Episode {
506 id: id.clone(),
507 timestamp,
508 content,
509 episode_type,
510 metadata: HashMap::new(),
511 embedding: Some(embedding),
512 token_count,
513 };
514
515 {
516 let mut l0 = self.l0_episodes.write().unwrap();
517 l0.push_back(episode);
518 }
519
520 {
521 let mut stats = self.stats.write().unwrap();
522 stats.total_episodes += 1;
523 }
524
525 id
526 }
527
528 pub fn maybe_compact(&self) -> Result<bool, CompactionError> {
530 let needs_l0 = {
531 let l0 = self.l0_episodes.read().unwrap();
532 l0.len() >= self.config.l0_max_episodes
533 };
534
535 let needs_l1 = {
536 let l1 = self.l1_summaries.read().unwrap();
537 l1.len() >= self.config.l1_max_summaries
538 };
539
540 if needs_l0 || needs_l1 {
541 self.run_compaction()?;
542 return Ok(true);
543 }
544
545 Ok(false)
546 }
547
548 pub fn run_compaction(&self) -> Result<(), CompactionError> {
550 self.compact_l0_to_l1()?;
552
553 self.compact_l1_to_l2()?;
555
556 {
558 let mut stats = self.stats.write().unwrap();
559 stats.last_compaction = Some(
560 SystemTime::now()
561 .duration_since(UNIX_EPOCH)
562 .unwrap_or_default()
563 .as_secs_f64()
564 );
565 }
566
567 Ok(())
568 }
569
570 fn compact_l0_to_l1(&self) -> Result<(), CompactionError> {
572 let now = SystemTime::now()
573 .duration_since(UNIX_EPOCH)
574 .unwrap_or_default()
575 .as_secs_f64();
576
577 let age_threshold = now - self.config.l0_age_threshold_secs as f64;
578
579 let to_compact: Vec<Episode> = {
581 let l0 = self.l0_episodes.read().unwrap();
582 l0.iter()
583 .filter(|e| e.timestamp < age_threshold)
584 .cloned()
585 .collect()
586 };
587
588 if to_compact.is_empty() {
589 return Ok(());
590 }
591
592 let groups = self.group_episodes(&to_compact);
594
595 for group in groups {
597 if group.is_empty() {
598 continue;
599 }
600
601 let content = self.summarizer.summarize_episodes(&group)?;
602 let topics = self.summarizer.extract_topics(&content);
603
604 let first_ts = group.iter().map(|e| e.timestamp).fold(f64::MAX, f64::min);
605 let last_ts = group.iter().map(|e| e.timestamp).fold(f64::MIN, f64::max);
606
607 let episode_ids: Vec<String> = group.iter().map(|e| e.id.clone()).collect();
608 let original_tokens: usize = group.iter().map(|e| e.token_count).sum();
609 let summary_tokens = content.len() / 4;
610
611 let summary = Summary {
612 id: self.next_id(),
613 content,
614 source_episode_ids: episode_ids,
615 time_range: (first_ts, last_ts),
616 embedding: None, token_count: summary_tokens,
618 created_at: now,
619 topics,
620 };
621
622 {
624 let mut l1 = self.l1_summaries.write().unwrap();
625 l1.push_back(summary);
626 }
627
628 {
630 let mut stats = self.stats.write().unwrap();
631 stats.total_summaries += 1;
632 stats.episodes_compacted += group.len();
633 stats.token_savings += original_tokens.saturating_sub(summary_tokens);
634 }
635 }
636
637 {
639 let mut l0 = self.l0_episodes.write().unwrap();
640 l0.retain(|e| e.timestamp >= age_threshold);
641 }
642
643 Ok(())
644 }
645
646 fn compact_l1_to_l2(&self) -> Result<(), CompactionError> {
648 let now = SystemTime::now()
649 .duration_since(UNIX_EPOCH)
650 .unwrap_or_default()
651 .as_secs_f64();
652
653 let age_threshold = now - self.config.l1_age_threshold_secs as f64;
654
655 let to_compact: Vec<Summary> = {
657 let l1 = self.l1_summaries.read().unwrap();
658 l1.iter()
659 .filter(|s| s.created_at < age_threshold)
660 .cloned()
661 .collect()
662 };
663
664 if to_compact.len() < self.config.group_size {
665 return Ok(());
666 }
667
668 let groups = self.group_summaries(&to_compact);
670
671 for group in groups {
672 if group.is_empty() {
673 continue;
674 }
675
676 let content = self.summarizer.abstract_summaries(&group)?;
677
678 let first_ts = group.iter().map(|s| s.time_range.0).fold(f64::MAX, f64::min);
679 let last_ts = group.iter().map(|s| s.time_range.1).fold(f64::MIN, f64::max);
680
681 let summary_ids: Vec<String> = group.iter().map(|s| s.id.clone()).collect();
682 let original_tokens: usize = group.iter().map(|s| s.token_count).sum();
683 let abstraction_tokens = content.len() / 4;
684
685 let insights: Vec<String> = group
687 .iter()
688 .flat_map(|s| s.topics.clone())
689 .collect::<std::collections::HashSet<_>>()
690 .into_iter()
691 .take(5)
692 .collect();
693
694 let abstraction = Abstraction {
695 id: self.next_id(),
696 content,
697 source_summary_ids: summary_ids,
698 time_range: (first_ts, last_ts),
699 embedding: None,
700 token_count: abstraction_tokens,
701 created_at: now,
702 insights,
703 };
704
705 {
707 let mut l2 = self.l2_abstractions.write().unwrap();
708 l2.push_back(abstraction);
709 }
710
711 {
713 let mut stats = self.stats.write().unwrap();
714 stats.total_abstractions += 1;
715 stats.summaries_compacted += group.len();
716 stats.token_savings += original_tokens.saturating_sub(abstraction_tokens);
717 }
718 }
719
720 {
722 let mut l1 = self.l1_summaries.write().unwrap();
723 l1.retain(|s| s.created_at >= age_threshold);
724 }
725
726 Ok(())
727 }
728
729 fn group_episodes(&self, episodes: &[Episode]) -> Vec<Vec<Episode>> {
731 episodes
733 .chunks(self.config.group_size)
734 .map(|chunk| chunk.to_vec())
735 .collect()
736 }
737
738 fn group_summaries(&self, summaries: &[Summary]) -> Vec<Vec<Summary>> {
740 summaries
741 .chunks(self.config.group_size)
742 .map(|chunk| chunk.to_vec())
743 .collect()
744 }
745
746 pub fn total_tokens(&self) -> usize {
748 let l0: usize = self.l0_episodes.read().unwrap().iter().map(|e| e.token_count).sum();
749 let l1: usize = self.l1_summaries.read().unwrap().iter().map(|s| s.token_count).sum();
750 let l2: usize = self.l2_abstractions.read().unwrap().iter().map(|a| a.token_count).sum();
751
752 l0 + l1 + l2
753 }
754
755 pub fn get_context(&self, max_tokens: usize) -> Vec<MemoryEntry> {
757 let mut entries = Vec::new();
758 let mut tokens_used = 0;
759
760 let l0 = self.l0_episodes.read().unwrap();
762 for episode in l0.iter().rev() {
763 if tokens_used + episode.token_count > max_tokens {
764 break;
765 }
766 entries.push(MemoryEntry::Episode(episode.clone()));
767 tokens_used += episode.token_count;
768 }
769
770 let l1 = self.l1_summaries.read().unwrap();
772 for summary in l1.iter().rev() {
773 if tokens_used + summary.token_count > max_tokens {
774 break;
775 }
776 entries.push(MemoryEntry::Summary(summary.clone()));
777 tokens_used += summary.token_count;
778 }
779
780 let l2 = self.l2_abstractions.read().unwrap();
782 for abstraction in l2.iter().rev() {
783 if tokens_used + abstraction.token_count > max_tokens {
784 break;
785 }
786 entries.push(MemoryEntry::Abstraction(abstraction.clone()));
787 tokens_used += abstraction.token_count;
788 }
789
790 entries
791 }
792
793 pub fn stats(&self) -> CompactionStats {
795 self.stats.read().unwrap().clone()
796 }
797
798 pub fn tier_counts(&self) -> (usize, usize, usize) {
800 let l0 = self.l0_episodes.read().unwrap().len();
801 let l1 = self.l1_summaries.read().unwrap().len();
802 let l2 = self.l2_abstractions.read().unwrap().len();
803 (l0, l1, l2)
804 }
805}
806
/// An owned snapshot of one entry from any memory tier, as returned by
/// `HierarchicalMemory::get_context`.
#[derive(Debug, Clone)]
pub enum MemoryEntry {
    Episode(Episode),
    Summary(Summary),
    Abstraction(Abstraction),
}
814
815impl MemoryEntry {
816 pub fn content(&self) -> &str {
818 match self {
819 Self::Episode(e) => &e.content,
820 Self::Summary(s) => &s.content,
821 Self::Abstraction(a) => &a.content,
822 }
823 }
824
825 pub fn token_count(&self) -> usize {
827 match self {
828 Self::Episode(e) => e.token_count,
829 Self::Summary(s) => s.token_count,
830 Self::Abstraction(a) => a.token_count,
831 }
832 }
833
834 pub fn tier(&self) -> usize {
836 match self {
837 Self::Episode(_) => 0,
838 Self::Summary(_) => 1,
839 Self::Abstraction(_) => 2,
840 }
841 }
842}
843
/// Errors surfaced by compaction and summarization.
#[derive(Debug, Clone)]
pub enum CompactionError {
    /// A `Summarizer` failed to produce summary/abstraction text.
    SummarizationFailed(String),
    /// Embedding generation failed. NOTE(review): no producer of this variant
    /// is visible in this file — confirm it is used elsewhere.
    EmbeddingFailed(String),
    /// Persistence-layer failure. NOTE(review): no producer of this variant is
    /// visible in this file — confirm it is used elsewhere.
    StorageError(String),
}
858
859impl std::fmt::Display for CompactionError {
860 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
861 match self {
862 Self::SummarizationFailed(msg) => write!(f, "Summarization failed: {}", msg),
863 Self::EmbeddingFailed(msg) => write!(f, "Embedding failed: {}", msg),
864 Self::StorageError(msg) => write!(f, "Storage error: {}", msg),
865 }
866 }
867}
868
869impl std::error::Error for CompactionError {}
870
871pub fn create_hierarchical_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
877 HierarchicalMemory::new(
878 MemoryCompactionConfig::default(),
879 Arc::new(ExtractiveSummarizer::default()),
880 )
881}
882
883pub fn create_test_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
885 HierarchicalMemory::new(
886 MemoryCompactionConfig::aggressive(),
887 Arc::new(ExtractiveSummarizer::default()),
888 )
889}
890
#[cfg(test)]
mod tests {
    use super::*;

    /// Ids come from the `mem_<n>` counter and new episodes land only in L0.
    #[test]
    fn test_add_episode() {
        let memory = create_test_memory();

        let episode_id = memory.add_episode(
            "User asked about weather".to_string(),
            EpisodeType::UserMessage,
        );

        assert!(episode_id.starts_with("mem_"));

        let (l0, l1, l2) = memory.tier_counts();
        assert_eq!((l0, l1, l2), (1, 0, 0));
    }

    /// The extractive summarizer produces non-empty output mentioning the
    /// episode count or a type section.
    #[test]
    fn test_extractive_summarizer() {
        let summarizer = ExtractiveSummarizer::default();

        let make_episode = |id: &str, ts: f64, text: &str, kind: EpisodeType, tokens: usize| {
            Episode {
                id: id.to_string(),
                timestamp: ts,
                content: text.to_string(),
                episode_type: kind,
                metadata: HashMap::new(),
                embedding: None,
                token_count: tokens,
            }
        };

        let episodes = vec![
            make_episode(
                "1",
                0.0,
                "User asked about the weather forecast.",
                EpisodeType::UserMessage,
                10,
            ),
            make_episode(
                "2",
                1.0,
                "Assistant provided weather information for NYC.",
                EpisodeType::AssistantResponse,
                12,
            ),
        ];

        let summary = summarizer.summarize_episodes(&episodes).unwrap();

        assert!(!summary.is_empty());
        assert!(
            summary.contains("episodes")
                || summary.contains("User")
                || summary.contains("Responses")
        );
    }

    /// Topic extraction finds at least one keyword in ordinary prose.
    #[test]
    fn test_topic_extraction() {
        let summarizer = ExtractiveSummarizer::default();

        let text = "The weather forecast shows sunny conditions with temperatures around 75 degrees. Tomorrow expects rain and thunderstorms across the region.";
        let topics = summarizer.extract_topics(text);

        assert!(!topics.is_empty());
    }

    /// Fresh episodes are retrievable as context and all sit in tier 0.
    #[test]
    fn test_memory_context_retrieval() {
        let memory = create_test_memory();

        for idx in 0..5 {
            memory.add_episode(
                format!("Episode {} content here with some text.", idx),
                EpisodeType::UserMessage,
            );
        }

        let context = memory.get_context(1000);
        assert!(!context.is_empty());
        assert!(context.iter().all(|entry| entry.tier() == 0));
    }

    /// Token estimates accumulate to a positive total.
    #[test]
    fn test_token_tracking() {
        let memory = create_test_memory();

        memory.add_episode("Short message".to_string(), EpisodeType::UserMessage);
        memory.add_episode(
            "A much longer message with more content that should have more tokens estimated".to_string(),
            EpisodeType::AssistantResponse,
        );

        assert!(memory.total_tokens() > 0);
    }

    /// The episode counter in stats tracks every add.
    #[test]
    fn test_stats_tracking() {
        let memory = create_test_memory();

        for _ in 0..10 {
            memory.add_episode("Test episode".to_string(), EpisodeType::UserMessage);
        }

        assert_eq!(memory.stats().total_episodes, 10);
    }
}