1use std::collections::{HashMap, VecDeque};
44use std::sync::{Arc, RwLock};
45use std::time::{SystemTime, UNIX_EPOCH};
46
/// Thresholds and sizes that steer compaction of the memory tiers.
///
/// Field prefixes name the tier they apply to: `l0` is the raw episode tier
/// and `l1` is the summary tier.
#[derive(Debug, Clone)]
pub struct MemoryCompactionConfig {
    /// Episode count at or above which `maybe_compact` starts a compaction pass.
    pub l0_max_episodes: usize,

    /// Summary count at or above which `maybe_compact` starts a compaction pass.
    pub l1_max_summaries: usize,

    /// Age in seconds an episode must exceed before it is folded into a summary.
    pub l0_age_threshold_secs: u64,

    /// Age in seconds (measured from a summary's `created_at`) a summary must
    /// exceed before it is folded into an abstraction.
    pub l1_age_threshold_secs: u64,

    /// Number of consecutive items handed to the summarizer per call.
    pub group_size: usize,

    /// NOTE(review): not read by any code in this file; presumably a
    /// similarity cutoff for grouping related items — confirm against callers.
    pub similarity_threshold: f32,

    /// NOTE(review): not read by any code in this file; presumably an upper
    /// bound on generated summary length in tokens — confirm against callers.
    pub max_summary_tokens: usize,

    /// NOTE(review): not read by any code in this file; presumably requests
    /// fresh embeddings for generated summaries — confirm against callers.
    pub reembed_summaries: bool,

    /// NOTE(review): not read by any code in this file; presumably the period,
    /// in seconds, at which a caller should invoke `maybe_compact` — confirm.
    pub check_interval_secs: u64,
}
81
82impl Default for MemoryCompactionConfig {
83 fn default() -> Self {
84 Self {
85 l0_max_episodes: 1000,
86 l1_max_summaries: 100,
87 l0_age_threshold_secs: 3600, l1_age_threshold_secs: 86400 * 7, group_size: 10,
90 similarity_threshold: 0.7,
91 max_summary_tokens: 200,
92 reembed_summaries: true,
93 check_interval_secs: 300, }
95 }
96}
97
98impl MemoryCompactionConfig {
99 pub fn aggressive() -> Self {
101 Self {
102 l0_max_episodes: 100,
103 l1_max_summaries: 20,
104 l0_age_threshold_secs: 60,
105 l1_age_threshold_secs: 3600,
106 group_size: 5,
107 ..Default::default()
108 }
109 }
110
111 pub fn long_running() -> Self {
113 Self {
114 l0_max_episodes: 5000,
115 l1_max_summaries: 500,
116 l0_age_threshold_secs: 3600 * 6, l1_age_threshold_secs: 86400 * 30, group_size: 20,
119 ..Default::default()
120 }
121 }
122}
123
124#[derive(Debug, Clone)]
130pub struct Episode {
131 pub id: String,
133
134 pub timestamp: f64,
136
137 pub content: String,
139
140 pub episode_type: EpisodeType,
142
143 pub metadata: HashMap<String, String>,
145
146 pub embedding: Option<Vec<f32>>,
148
149 pub token_count: usize,
151}
152
/// Kind tag attached to every [`Episode`].
///
/// `ExtractiveSummarizer` groups episodes by this tag and prints one labelled
/// section per kind; the label used for each variant is noted below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpisodeType {
    /// Labelled "User messages" by the extractive summarizer.
    UserMessage,
    /// Labelled "Responses" by the extractive summarizer.
    AssistantResponse,
    /// Labelled "Tool calls" by the extractive summarizer.
    ToolCall,
    /// Labelled "Tool results" by the extractive summarizer.
    ToolResult,
    /// Labelled "Events" by the extractive summarizer.
    SystemEvent,
    /// Labelled "Observations" by the extractive summarizer.
    Observation,
}
169
/// A condensed record covering a group of episodes, stored in the middle (L1) tier.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Identifier of the summary.
    pub id: String,

    /// Summary text produced by a `Summarizer`.
    pub content: String,

    /// Ids of the episodes this summary was built from.
    pub source_episode_ids: Vec<String>,

    /// `(earliest, latest)` timestamps of the source episodes.
    pub time_range: (f64, f64),

    /// Optional embedding vector; `HierarchicalMemory` leaves it as `None`.
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count of `content`.
    pub token_count: usize,

    /// Time the summary was created, in seconds since the Unix epoch.
    pub created_at: f64,

    /// Topic keywords extracted from `content`.
    pub topics: Vec<String>,
}
197
/// A high-level record covering a group of summaries, stored in the top (L2) tier.
#[derive(Debug, Clone)]
pub struct Abstraction {
    /// Identifier of the abstraction.
    pub id: String,

    /// Abstraction text produced by a `Summarizer`.
    pub content: String,

    /// Ids of the summaries this abstraction was built from.
    pub source_summary_ids: Vec<String>,

    /// `(earliest start, latest end)` over the source summaries' time ranges.
    pub time_range: (f64, f64),

    /// Optional embedding vector; `HierarchicalMemory` leaves it as `None`.
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count of `content`.
    pub token_count: usize,

    /// Time the abstraction was created, in seconds since the Unix epoch.
    pub created_at: f64,

    /// Distinct topics carried over from the source summaries
    /// (`HierarchicalMemory` keeps at most five).
    pub insights: Vec<String>,
}
225
226pub trait Summarizer: Send + Sync {
232 fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError>;
234
235 fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError>;
237
238 fn extract_topics(&self, content: &str) -> Vec<String>;
240}
241
/// Rule-based summarizer that stitches together leading sentences instead of
/// generating new text.
pub struct ExtractiveSummarizer {
    /// Maximum number of episodes per episode kind whose first sentence is
    /// considered for a summary.
    pub max_sentences: usize,

    /// When `true`, episode summaries start with a header line stating the
    /// episode count and the covered time span.
    pub include_timestamps: bool,
}
250
251impl Default for ExtractiveSummarizer {
252 fn default() -> Self {
253 Self {
254 max_sentences: 5,
255 include_timestamps: true,
256 }
257 }
258}
259
260impl Summarizer for ExtractiveSummarizer {
261 fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError> {
262 if episodes.is_empty() {
263 return Ok(String::new());
264 }
265
266 let mut summary_parts = Vec::new();
267
268 let first_ts = episodes.first().map(|e| e.timestamp).unwrap_or(0.0);
270 let last_ts = episodes.last().map(|e| e.timestamp).unwrap_or(0.0);
271
272 if self.include_timestamps {
273 summary_parts.push(format!(
274 "[{} episodes over {:.0} seconds]",
275 episodes.len(),
276 last_ts - first_ts
277 ));
278 }
279
280 let mut by_type: HashMap<EpisodeType, Vec<&Episode>> = HashMap::new();
282 for episode in episodes {
283 by_type
284 .entry(episode.episode_type)
285 .or_default()
286 .push(episode);
287 }
288
289 for (ep_type, eps) in by_type {
291 let type_name = match ep_type {
292 EpisodeType::UserMessage => "User messages",
293 EpisodeType::AssistantResponse => "Responses",
294 EpisodeType::ToolCall => "Tool calls",
295 EpisodeType::ToolResult => "Tool results",
296 EpisodeType::SystemEvent => "Events",
297 EpisodeType::Observation => "Observations",
298 };
299
300 let sentences: Vec<String> = eps
302 .iter()
303 .take(self.max_sentences)
304 .filter_map(|e| e.content.split('.').next().map(|s| s.trim().to_string()))
305 .filter(|s| !s.is_empty())
306 .collect();
307
308 if !sentences.is_empty() {
309 summary_parts.push(format!("{}: {}", type_name, sentences.join("; ")));
310 }
311 }
312
313 Ok(summary_parts.join("\n"))
314 }
315
316 fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError> {
317 if summaries.is_empty() {
318 return Ok(String::new());
319 }
320
321 let mut abstraction_parts = Vec::new();
322
323 let first_ts = summaries
325 .iter()
326 .map(|s| s.time_range.0)
327 .fold(f64::MAX, f64::min);
328 let last_ts = summaries
329 .iter()
330 .map(|s| s.time_range.1)
331 .fold(f64::MIN, f64::max);
332
333 abstraction_parts.push(format!(
334 "[{} summaries, {:.1} hours span]",
335 summaries.len(),
336 (last_ts - first_ts) / 3600.0
337 ));
338
339 let all_topics: Vec<&str> = summaries
341 .iter()
342 .flat_map(|s| s.topics.iter().map(|t| t.as_str()))
343 .collect();
344
345 if !all_topics.is_empty() {
346 let unique_topics: Vec<_> = all_topics
347 .iter()
348 .cloned()
349 .collect::<std::collections::HashSet<_>>()
350 .into_iter()
351 .take(10)
352 .collect();
353 abstraction_parts.push(format!("Topics: {}", unique_topics.join(", ")));
354 }
355
356 let key_points: Vec<String> = summaries
358 .iter()
359 .take(5)
360 .filter_map(|s| s.content.lines().next().map(|l| l.to_string()))
361 .collect();
362
363 if !key_points.is_empty() {
364 abstraction_parts.push(format!("Key points:\n- {}", key_points.join("\n- ")));
365 }
366
367 Ok(abstraction_parts.join("\n"))
368 }
369
370 fn extract_topics(&self, content: &str) -> Vec<String> {
371 let stopwords = [
373 "the", "a", "an", "is", "are", "was", "were", "to", "from", "in", "on", "at", "for",
374 "and", "or",
375 ];
376
377 let words: Vec<&str> = content
378 .split_whitespace()
379 .filter(|w| w.len() > 3)
380 .filter(|w| !stopwords.contains(&w.to_lowercase().as_str()))
381 .collect();
382
383 let mut freq: HashMap<String, usize> = HashMap::new();
385 for word in words {
386 let normalized = word
387 .to_lowercase()
388 .trim_matches(|c: char| !c.is_alphanumeric())
389 .to_string();
390 if normalized.len() > 3 {
391 *freq.entry(normalized).or_insert(0) += 1;
392 }
393 }
394
395 let mut sorted: Vec<_> = freq.into_iter().collect();
397 sorted.sort_by(|a, b| b.1.cmp(&a.1));
398
399 sorted.into_iter().take(5).map(|(w, _)| w).collect()
400 }
401}
402
403pub struct HierarchicalMemory<S: Summarizer> {
409 config: MemoryCompactionConfig,
411
412 l0_episodes: RwLock<VecDeque<Episode>>,
414
415 l1_summaries: RwLock<VecDeque<Summary>>,
417
418 l2_abstractions: RwLock<VecDeque<Abstraction>>,
420
421 summarizer: Arc<S>,
423
424 stats: RwLock<CompactionStats>,
426
427 next_id: std::sync::atomic::AtomicU64,
429}
430
/// Running counters maintained by `HierarchicalMemory`.
#[derive(Debug, Clone, Default)]
pub struct CompactionStats {
    /// Number of episodes ever added.
    pub total_episodes: usize,

    /// Number of summaries ever created.
    pub total_summaries: usize,

    /// Number of abstractions ever created.
    pub total_abstractions: usize,

    /// Number of episodes that have been folded into summaries.
    pub episodes_compacted: usize,

    /// Number of summaries that have been folded into abstractions.
    pub summaries_compacted: usize,

    /// Time of the most recent completed compaction run, in seconds since the
    /// Unix epoch; `None` until the first run completes.
    pub last_compaction: Option<f64>,

    /// Accumulated difference between the estimated tokens of compacted items
    /// and the estimated tokens of the text that replaced them.
    pub token_savings: usize,
}
455
456impl<S: Summarizer> HierarchicalMemory<S> {
457 pub fn new(config: MemoryCompactionConfig, summarizer: Arc<S>) -> Self {
459 Self {
460 config,
461 l0_episodes: RwLock::new(VecDeque::new()),
462 l1_summaries: RwLock::new(VecDeque::new()),
463 l2_abstractions: RwLock::new(VecDeque::new()),
464 summarizer,
465 stats: RwLock::new(CompactionStats::default()),
466 next_id: std::sync::atomic::AtomicU64::new(1),
467 }
468 }
469
470 fn next_id(&self) -> String {
472 let id = self
473 .next_id
474 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
475 format!("mem_{}", id)
476 }
477
478 pub fn add_episode(&self, content: String, episode_type: EpisodeType) -> String {
480 let id = self.next_id();
481 let timestamp = SystemTime::now()
482 .duration_since(UNIX_EPOCH)
483 .unwrap_or_default()
484 .as_secs_f64();
485
486 let token_count = content.len() / 4; let episode = Episode {
489 id: id.clone(),
490 timestamp,
491 content,
492 episode_type,
493 metadata: HashMap::new(),
494 embedding: None,
495 token_count,
496 };
497
498 {
499 let mut l0 = self.l0_episodes.write().unwrap();
500 l0.push_back(episode);
501 }
502
503 {
504 let mut stats = self.stats.write().unwrap();
505 stats.total_episodes += 1;
506 }
507
508 id
509 }
510
511 pub fn add_episode_with_embedding(
513 &self,
514 content: String,
515 episode_type: EpisodeType,
516 embedding: Vec<f32>,
517 ) -> String {
518 let id = self.next_id();
519 let timestamp = SystemTime::now()
520 .duration_since(UNIX_EPOCH)
521 .unwrap_or_default()
522 .as_secs_f64();
523
524 let token_count = content.len() / 4;
525
526 let episode = Episode {
527 id: id.clone(),
528 timestamp,
529 content,
530 episode_type,
531 metadata: HashMap::new(),
532 embedding: Some(embedding),
533 token_count,
534 };
535
536 {
537 let mut l0 = self.l0_episodes.write().unwrap();
538 l0.push_back(episode);
539 }
540
541 {
542 let mut stats = self.stats.write().unwrap();
543 stats.total_episodes += 1;
544 }
545
546 id
547 }
548
549 pub fn maybe_compact(&self) -> Result<bool, CompactionError> {
551 let needs_l0 = {
552 let l0 = self.l0_episodes.read().unwrap();
553 l0.len() >= self.config.l0_max_episodes
554 };
555
556 let needs_l1 = {
557 let l1 = self.l1_summaries.read().unwrap();
558 l1.len() >= self.config.l1_max_summaries
559 };
560
561 if needs_l0 || needs_l1 {
562 self.run_compaction()?;
563 return Ok(true);
564 }
565
566 Ok(false)
567 }
568
569 pub fn run_compaction(&self) -> Result<(), CompactionError> {
571 self.compact_l0_to_l1()?;
573
574 self.compact_l1_to_l2()?;
576
577 {
579 let mut stats = self.stats.write().unwrap();
580 stats.last_compaction = Some(
581 SystemTime::now()
582 .duration_since(UNIX_EPOCH)
583 .unwrap_or_default()
584 .as_secs_f64(),
585 );
586 }
587
588 Ok(())
589 }
590
591 fn compact_l0_to_l1(&self) -> Result<(), CompactionError> {
593 let now = SystemTime::now()
594 .duration_since(UNIX_EPOCH)
595 .unwrap_or_default()
596 .as_secs_f64();
597
598 let age_threshold = now - self.config.l0_age_threshold_secs as f64;
599
600 let to_compact: Vec<Episode> = {
602 let l0 = self.l0_episodes.read().unwrap();
603 l0.iter()
604 .filter(|e| e.timestamp < age_threshold)
605 .cloned()
606 .collect()
607 };
608
609 if to_compact.is_empty() {
610 return Ok(());
611 }
612
613 let groups = self.group_episodes(&to_compact);
615
616 for group in groups {
618 if group.is_empty() {
619 continue;
620 }
621
622 let content = self.summarizer.summarize_episodes(&group)?;
623 let topics = self.summarizer.extract_topics(&content);
624
625 let first_ts = group.iter().map(|e| e.timestamp).fold(f64::MAX, f64::min);
626 let last_ts = group.iter().map(|e| e.timestamp).fold(f64::MIN, f64::max);
627
628 let episode_ids: Vec<String> = group.iter().map(|e| e.id.clone()).collect();
629 let original_tokens: usize = group.iter().map(|e| e.token_count).sum();
630 let summary_tokens = content.len() / 4;
631
632 let summary = Summary {
633 id: self.next_id(),
634 content,
635 source_episode_ids: episode_ids,
636 time_range: (first_ts, last_ts),
637 embedding: None, token_count: summary_tokens,
639 created_at: now,
640 topics,
641 };
642
643 {
645 let mut l1 = self.l1_summaries.write().unwrap();
646 l1.push_back(summary);
647 }
648
649 {
651 let mut stats = self.stats.write().unwrap();
652 stats.total_summaries += 1;
653 stats.episodes_compacted += group.len();
654 stats.token_savings += original_tokens.saturating_sub(summary_tokens);
655 }
656 }
657
658 {
660 let mut l0 = self.l0_episodes.write().unwrap();
661 l0.retain(|e| e.timestamp >= age_threshold);
662 }
663
664 Ok(())
665 }
666
667 fn compact_l1_to_l2(&self) -> Result<(), CompactionError> {
669 let now = SystemTime::now()
670 .duration_since(UNIX_EPOCH)
671 .unwrap_or_default()
672 .as_secs_f64();
673
674 let age_threshold = now - self.config.l1_age_threshold_secs as f64;
675
676 let to_compact: Vec<Summary> = {
678 let l1 = self.l1_summaries.read().unwrap();
679 l1.iter()
680 .filter(|s| s.created_at < age_threshold)
681 .cloned()
682 .collect()
683 };
684
685 if to_compact.len() < self.config.group_size {
686 return Ok(());
687 }
688
689 let groups = self.group_summaries(&to_compact);
691
692 for group in groups {
693 if group.is_empty() {
694 continue;
695 }
696
697 let content = self.summarizer.abstract_summaries(&group)?;
698
699 let first_ts = group
700 .iter()
701 .map(|s| s.time_range.0)
702 .fold(f64::MAX, f64::min);
703 let last_ts = group
704 .iter()
705 .map(|s| s.time_range.1)
706 .fold(f64::MIN, f64::max);
707
708 let summary_ids: Vec<String> = group.iter().map(|s| s.id.clone()).collect();
709 let original_tokens: usize = group.iter().map(|s| s.token_count).sum();
710 let abstraction_tokens = content.len() / 4;
711
712 let insights: Vec<String> = group
714 .iter()
715 .flat_map(|s| s.topics.clone())
716 .collect::<std::collections::HashSet<_>>()
717 .into_iter()
718 .take(5)
719 .collect();
720
721 let abstraction = Abstraction {
722 id: self.next_id(),
723 content,
724 source_summary_ids: summary_ids,
725 time_range: (first_ts, last_ts),
726 embedding: None,
727 token_count: abstraction_tokens,
728 created_at: now,
729 insights,
730 };
731
732 {
734 let mut l2 = self.l2_abstractions.write().unwrap();
735 l2.push_back(abstraction);
736 }
737
738 {
740 let mut stats = self.stats.write().unwrap();
741 stats.total_abstractions += 1;
742 stats.summaries_compacted += group.len();
743 stats.token_savings += original_tokens.saturating_sub(abstraction_tokens);
744 }
745 }
746
747 {
749 let mut l1 = self.l1_summaries.write().unwrap();
750 l1.retain(|s| s.created_at >= age_threshold);
751 }
752
753 Ok(())
754 }
755
756 fn group_episodes(&self, episodes: &[Episode]) -> Vec<Vec<Episode>> {
758 episodes
760 .chunks(self.config.group_size)
761 .map(|chunk| chunk.to_vec())
762 .collect()
763 }
764
765 fn group_summaries(&self, summaries: &[Summary]) -> Vec<Vec<Summary>> {
767 summaries
768 .chunks(self.config.group_size)
769 .map(|chunk| chunk.to_vec())
770 .collect()
771 }
772
773 pub fn total_tokens(&self) -> usize {
775 let l0: usize = self
776 .l0_episodes
777 .read()
778 .unwrap()
779 .iter()
780 .map(|e| e.token_count)
781 .sum();
782 let l1: usize = self
783 .l1_summaries
784 .read()
785 .unwrap()
786 .iter()
787 .map(|s| s.token_count)
788 .sum();
789 let l2: usize = self
790 .l2_abstractions
791 .read()
792 .unwrap()
793 .iter()
794 .map(|a| a.token_count)
795 .sum();
796
797 l0 + l1 + l2
798 }
799
800 pub fn get_context(&self, max_tokens: usize) -> Vec<MemoryEntry> {
802 let mut entries = Vec::new();
803 let mut tokens_used = 0;
804
805 let l0 = self.l0_episodes.read().unwrap();
807 for episode in l0.iter().rev() {
808 if tokens_used + episode.token_count > max_tokens {
809 break;
810 }
811 entries.push(MemoryEntry::Episode(episode.clone()));
812 tokens_used += episode.token_count;
813 }
814
815 let l1 = self.l1_summaries.read().unwrap();
817 for summary in l1.iter().rev() {
818 if tokens_used + summary.token_count > max_tokens {
819 break;
820 }
821 entries.push(MemoryEntry::Summary(summary.clone()));
822 tokens_used += summary.token_count;
823 }
824
825 let l2 = self.l2_abstractions.read().unwrap();
827 for abstraction in l2.iter().rev() {
828 if tokens_used + abstraction.token_count > max_tokens {
829 break;
830 }
831 entries.push(MemoryEntry::Abstraction(abstraction.clone()));
832 tokens_used += abstraction.token_count;
833 }
834
835 entries
836 }
837
838 pub fn stats(&self) -> CompactionStats {
840 self.stats.read().unwrap().clone()
841 }
842
843 pub fn tier_counts(&self) -> (usize, usize, usize) {
845 let l0 = self.l0_episodes.read().unwrap().len();
846 let l1 = self.l1_summaries.read().unwrap().len();
847 let l2 = self.l2_abstractions.read().unwrap().len();
848 (l0, l1, l2)
849 }
850}
851
852#[derive(Debug, Clone)]
854pub enum MemoryEntry {
855 Episode(Episode),
856 Summary(Summary),
857 Abstraction(Abstraction),
858}
859
860impl MemoryEntry {
861 pub fn content(&self) -> &str {
863 match self {
864 Self::Episode(e) => &e.content,
865 Self::Summary(s) => &s.content,
866 Self::Abstraction(a) => &a.content,
867 }
868 }
869
870 pub fn token_count(&self) -> usize {
872 match self {
873 Self::Episode(e) => e.token_count,
874 Self::Summary(s) => s.token_count,
875 Self::Abstraction(a) => a.token_count,
876 }
877 }
878
879 pub fn tier(&self) -> usize {
881 match self {
882 Self::Episode(_) => 0,
883 Self::Summary(_) => 1,
884 Self::Abstraction(_) => 2,
885 }
886 }
887}
888
/// Errors reported by compaction; each variant carries a human-readable message.
#[derive(Debug, Clone)]
pub enum CompactionError {
    /// Displayed as `Summarization failed: <message>`.
    SummarizationFailed(String),
    /// Displayed as `Embedding failed: <message>`.
    EmbeddingFailed(String),
    /// Displayed as `Storage error: <message>`.
    StorageError(String),
}
903
904impl std::fmt::Display for CompactionError {
905 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
906 match self {
907 Self::SummarizationFailed(msg) => write!(f, "Summarization failed: {}", msg),
908 Self::EmbeddingFailed(msg) => write!(f, "Embedding failed: {}", msg),
909 Self::StorageError(msg) => write!(f, "Storage error: {}", msg),
910 }
911 }
912}
913
914impl std::error::Error for CompactionError {}
915
916pub fn create_hierarchical_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
922 HierarchicalMemory::new(
923 MemoryCompactionConfig::default(),
924 Arc::new(ExtractiveSummarizer::default()),
925 )
926}
927
928pub fn create_test_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
930 HierarchicalMemory::new(
931 MemoryCompactionConfig::aggressive(),
932 Arc::new(ExtractiveSummarizer::default()),
933 )
934}
935
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare episode fixture without metadata or embedding.
    fn fixture_episode(
        id: &str,
        timestamp: f64,
        content: &str,
        episode_type: EpisodeType,
        token_count: usize,
    ) -> Episode {
        Episode {
            id: id.to_string(),
            timestamp,
            content: content.to_string(),
            episode_type,
            metadata: HashMap::new(),
            embedding: None,
            token_count,
        }
    }

    /// A freshly added episode gets a `mem_` id and lands in L0 only.
    #[test]
    fn test_add_episode() {
        let memory = create_test_memory();

        let id = memory.add_episode(
            String::from("User asked about weather"),
            EpisodeType::UserMessage,
        );
        assert!(id.starts_with("mem_"));

        let (l0, l1, l2) = memory.tier_counts();
        assert_eq!((l0, l1, l2), (1, 0, 0));
    }

    /// The extractive summarizer produces non-empty, recognizable output.
    #[test]
    fn test_extractive_summarizer() {
        let summarizer = ExtractiveSummarizer::default();

        let episodes = vec![
            fixture_episode(
                "1",
                0.0,
                "User asked about the weather forecast.",
                EpisodeType::UserMessage,
                10,
            ),
            fixture_episode(
                "2",
                1.0,
                "Assistant provided weather information for NYC.",
                EpisodeType::AssistantResponse,
                12,
            ),
        ];

        let summary = summarizer.summarize_episodes(&episodes).unwrap();

        assert!(!summary.is_empty());
        let recognizable = ["episodes", "User", "Responses"]
            .iter()
            .any(|marker| summary.contains(marker));
        assert!(recognizable);
    }

    /// Topic extraction finds at least one topic in ordinary prose.
    #[test]
    fn test_topic_extraction() {
        let summarizer = ExtractiveSummarizer::default();

        let content = "The weather forecast shows sunny conditions with temperatures around 75 degrees. Tomorrow expects rain and thunderstorms across the region.";

        let topics = summarizer.extract_topics(content);

        assert!(!topics.is_empty());
    }

    /// Without compaction, context retrieval returns tier-0 entries only.
    #[test]
    fn test_memory_context_retrieval() {
        let memory = create_test_memory();

        (0..5).for_each(|i| {
            memory.add_episode(
                format!("Episode {} content here with some text.", i),
                EpisodeType::UserMessage,
            );
        });

        let context = memory.get_context(1000);

        assert!(!context.is_empty());
        assert!(context.iter().all(|entry| entry.tier() == 0));
    }

    /// Stored episodes contribute to the total token estimate.
    #[test]
    fn test_token_tracking() {
        let memory = create_test_memory();

        let messages = [
            ("Short message", EpisodeType::UserMessage),
            (
                "A much longer message with more content that should have more tokens estimated",
                EpisodeType::AssistantResponse,
            ),
        ];
        for (text, kind) in messages.iter() {
            memory.add_episode(text.to_string(), *kind);
        }

        assert!(memory.total_tokens() > 0);
    }

    /// Every added episode is counted in the statistics.
    #[test]
    fn test_stats_tracking() {
        let memory = create_test_memory();

        let mut remaining = 10;
        while remaining > 0 {
            memory.add_episode("Test episode".to_string(), EpisodeType::UserMessage);
            remaining -= 1;
        }

        assert_eq!(memory.stats().total_episodes, 10);
    }
}