1use std::collections::{HashMap, VecDeque};
44use std::sync::{Arc, RwLock};
45use std::time::{SystemTime, UNIX_EPOCH};
46
/// Configuration for hierarchical memory compaction.
///
/// Memory is organized in three tiers: L0 (raw episodes), L1 (summaries of
/// episode groups), and L2 (abstractions over groups of summaries).
#[derive(Debug, Clone)]
pub struct MemoryCompactionConfig {
    /// Maximum number of raw episodes kept in L0 before compaction triggers.
    pub l0_max_episodes: usize,

    /// Maximum number of summaries kept in L1 before compaction triggers.
    pub l1_max_summaries: usize,

    /// Episodes older than this many seconds are eligible for L0 -> L1 compaction.
    pub l0_age_threshold_secs: u64,

    /// Summaries older than this many seconds are eligible for L1 -> L2 compaction.
    pub l1_age_threshold_secs: u64,

    /// Number of episodes (or summaries) folded into one summary (or abstraction).
    pub group_size: usize,

    /// Similarity threshold for grouping.
    /// NOTE(review): not read anywhere in this file — presumably consumed by a
    /// caller or a planned similarity-based grouper; verify before relying on it.
    pub similarity_threshold: f32,

    /// Token budget for a generated summary.
    /// NOTE(review): unused in this file — the summarizer does not enforce it.
    pub max_summary_tokens: usize,

    /// Whether summaries should be re-embedded after creation.
    /// NOTE(review): unused in this file.
    pub reembed_summaries: bool,

    /// Interval between periodic compaction checks, in seconds.
    /// NOTE(review): unused in this file — no background task is started here.
    pub check_interval_secs: u64,
}
81
82impl Default for MemoryCompactionConfig {
83 fn default() -> Self {
84 Self {
85 l0_max_episodes: 1000,
86 l1_max_summaries: 100,
87 l0_age_threshold_secs: 3600, l1_age_threshold_secs: 86400 * 7, group_size: 10,
90 similarity_threshold: 0.7,
91 max_summary_tokens: 200,
92 reembed_summaries: true,
93 check_interval_secs: 300, }
95 }
96}
97
98impl MemoryCompactionConfig {
99 pub fn aggressive() -> Self {
101 Self {
102 l0_max_episodes: 100,
103 l1_max_summaries: 20,
104 l0_age_threshold_secs: 60,
105 l1_age_threshold_secs: 3600,
106 group_size: 5,
107 ..Default::default()
108 }
109 }
110
111 pub fn long_running() -> Self {
113 Self {
114 l0_max_episodes: 5000,
115 l1_max_summaries: 500,
116 l0_age_threshold_secs: 3600 * 6, l1_age_threshold_secs: 86400 * 30, group_size: 20,
119 ..Default::default()
120 }
121 }
122}
123
/// A single raw interaction record — the L0 (finest) memory tier.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Unique id of the form `mem_<n>` (assigned by `HierarchicalMemory`).
    pub id: String,

    /// Creation time, seconds since the Unix epoch.
    pub timestamp: f64,

    /// Raw text content of the episode.
    pub content: String,

    /// What kind of event this episode records.
    pub episode_type: EpisodeType,

    /// Free-form key/value annotations (not read by the compactor itself).
    pub metadata: HashMap<String, String>,

    /// Optional embedding vector supplied by the caller.
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count (content bytes / 4 heuristic).
    pub token_count: usize,
}
152
/// Discriminates what kind of event an `Episode` records.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpisodeType {
    /// A message sent by the user.
    UserMessage,
    /// A reply produced by the assistant.
    AssistantResponse,
    /// An outgoing tool/function invocation.
    ToolCall,
    /// The result returned by a tool.
    ToolResult,
    /// An internal system event.
    SystemEvent,
    /// A passive observation.
    Observation,
}
169
/// A condensed record of a group of episodes — the L1 memory tier.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Unique id of the form `mem_<n>`.
    pub id: String,

    /// Summary text produced by the `Summarizer`.
    pub content: String,

    /// Ids of the episodes this summary was built from.
    pub source_episode_ids: Vec<String>,

    /// (earliest, latest) source-episode timestamps, Unix seconds.
    pub time_range: (f64, f64),

    /// Optional embedding (not populated by this module).
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count (content bytes / 4 heuristic).
    pub token_count: usize,

    /// When the summary was created, Unix seconds.
    pub created_at: f64,

    /// Topic keywords extracted from the summary text.
    pub topics: Vec<String>,
}
197
/// A high-level condensation of a group of summaries — the L2 memory tier.
#[derive(Debug, Clone)]
pub struct Abstraction {
    /// Unique id of the form `mem_<n>`.
    pub id: String,

    /// Abstraction text produced by the `Summarizer`.
    pub content: String,

    /// Ids of the summaries this abstraction was built from.
    pub source_summary_ids: Vec<String>,

    /// (earliest, latest) timestamps covered, Unix seconds.
    pub time_range: (f64, f64),

    /// Optional embedding (not populated by this module).
    pub embedding: Option<Vec<f32>>,

    /// Estimated token count (content bytes / 4 heuristic).
    pub token_count: usize,

    /// When the abstraction was created, Unix seconds.
    pub created_at: f64,

    /// Distinct topics carried forward from the source summaries.
    pub insights: Vec<String>,
}
225
/// Strategy for condensing memory between tiers.
///
/// `Send + Sync` so a `HierarchicalMemory` holding one can be shared across
/// threads.
pub trait Summarizer: Send + Sync {
    /// Condenses a group of L0 episodes into one L1 summary text.
    fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError>;

    /// Condenses a group of L1 summaries into one L2 abstraction text.
    fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError>;

    /// Extracts topic keywords from a piece of text.
    fn extract_topics(&self, content: &str) -> Vec<String>;
}
241
/// A cheap, model-free `Summarizer` that quotes leading sentences and
/// extracts frequency-based keywords.
pub struct ExtractiveSummarizer {
    /// Maximum number of lead sentences quoted per episode type.
    pub max_sentences: usize,

    /// Whether to prepend an "[N episodes over S seconds]" header.
    pub include_timestamps: bool,
}
250
251impl Default for ExtractiveSummarizer {
252 fn default() -> Self {
253 Self {
254 max_sentences: 5,
255 include_timestamps: true,
256 }
257 }
258}
259
260impl Summarizer for ExtractiveSummarizer {
261 fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError> {
262 if episodes.is_empty() {
263 return Ok(String::new());
264 }
265
266 let mut summary_parts = Vec::new();
267
268 let first_ts = episodes.first().map(|e| e.timestamp).unwrap_or(0.0);
270 let last_ts = episodes.last().map(|e| e.timestamp).unwrap_or(0.0);
271
272 if self.include_timestamps {
273 summary_parts.push(format!(
274 "[{} episodes over {:.0} seconds]",
275 episodes.len(),
276 last_ts - first_ts
277 ));
278 }
279
280 let mut by_type: HashMap<EpisodeType, Vec<&Episode>> = HashMap::new();
282 for episode in episodes {
283 by_type.entry(episode.episode_type).or_default().push(episode);
284 }
285
286 for (ep_type, eps) in by_type {
288 let type_name = match ep_type {
289 EpisodeType::UserMessage => "User messages",
290 EpisodeType::AssistantResponse => "Responses",
291 EpisodeType::ToolCall => "Tool calls",
292 EpisodeType::ToolResult => "Tool results",
293 EpisodeType::SystemEvent => "Events",
294 EpisodeType::Observation => "Observations",
295 };
296
297 let sentences: Vec<String> = eps
299 .iter()
300 .take(self.max_sentences)
301 .filter_map(|e| e.content.split('.').next().map(|s| s.trim().to_string()))
302 .filter(|s| !s.is_empty())
303 .collect();
304
305 if !sentences.is_empty() {
306 summary_parts.push(format!("{}: {}", type_name, sentences.join("; ")));
307 }
308 }
309
310 Ok(summary_parts.join("\n"))
311 }
312
313 fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError> {
314 if summaries.is_empty() {
315 return Ok(String::new());
316 }
317
318 let mut abstraction_parts = Vec::new();
319
320 let first_ts = summaries.iter().map(|s| s.time_range.0).fold(f64::MAX, f64::min);
322 let last_ts = summaries.iter().map(|s| s.time_range.1).fold(f64::MIN, f64::max);
323
324 abstraction_parts.push(format!(
325 "[{} summaries, {:.1} hours span]",
326 summaries.len(),
327 (last_ts - first_ts) / 3600.0
328 ));
329
330 let all_topics: Vec<&str> = summaries
332 .iter()
333 .flat_map(|s| s.topics.iter().map(|t| t.as_str()))
334 .collect();
335
336 if !all_topics.is_empty() {
337 let unique_topics: Vec<_> = all_topics.iter()
338 .cloned()
339 .collect::<std::collections::HashSet<_>>()
340 .into_iter()
341 .take(10)
342 .collect();
343 abstraction_parts.push(format!("Topics: {}", unique_topics.join(", ")));
344 }
345
346 let key_points: Vec<String> = summaries
348 .iter()
349 .take(5)
350 .filter_map(|s| s.content.lines().next().map(|l| l.to_string()))
351 .collect();
352
353 if !key_points.is_empty() {
354 abstraction_parts.push(format!("Key points:\n- {}", key_points.join("\n- ")));
355 }
356
357 Ok(abstraction_parts.join("\n"))
358 }
359
360 fn extract_topics(&self, content: &str) -> Vec<String> {
361 let stopwords = ["the", "a", "an", "is", "are", "was", "were", "to", "from", "in", "on", "at", "for", "and", "or"];
363
364 let words: Vec<&str> = content
365 .split_whitespace()
366 .filter(|w| w.len() > 3)
367 .filter(|w| !stopwords.contains(&w.to_lowercase().as_str()))
368 .collect();
369
370 let mut freq: HashMap<String, usize> = HashMap::new();
372 for word in words {
373 let normalized = word.to_lowercase().trim_matches(|c: char| !c.is_alphanumeric()).to_string();
374 if normalized.len() > 3 {
375 *freq.entry(normalized).or_insert(0) += 1;
376 }
377 }
378
379 let mut sorted: Vec<_> = freq.into_iter().collect();
381 sorted.sort_by(|a, b| b.1.cmp(&a.1));
382
383 sorted.into_iter().take(5).map(|(w, _)| w).collect()
384 }
385}
386
/// Three-tier hierarchical memory store with progressive compaction:
/// L0 raw episodes -> L1 summaries -> L2 abstractions.
///
/// Every tier sits behind its own `RwLock`, so the store can be shared across
/// threads; locks are taken one at a time throughout this module.
pub struct HierarchicalMemory<S: Summarizer> {
    /// Tier capacities, age thresholds, and grouping parameters.
    config: MemoryCompactionConfig,

    /// L0: raw episodes, oldest at the front of the deque.
    l0_episodes: RwLock<VecDeque<Episode>>,

    /// L1: summaries of compacted episode groups, oldest at the front.
    l1_summaries: RwLock<VecDeque<Summary>>,

    /// L2: abstractions over compacted summary groups, oldest at the front.
    l2_abstractions: RwLock<VecDeque<Abstraction>>,

    /// Strategy used to produce summaries, abstractions, and topics.
    summarizer: Arc<S>,

    /// Running counters, updated under their own lock.
    stats: RwLock<CompactionStats>,

    /// Monotonic counter backing `next_id()`.
    next_id: std::sync::atomic::AtomicU64,
}
414
/// Running counters describing compaction activity since construction.
#[derive(Debug, Clone, Default)]
pub struct CompactionStats {
    /// Episodes ever added to L0 (not decremented on compaction).
    pub total_episodes: usize,

    /// Summaries ever created in L1.
    pub total_summaries: usize,

    /// Abstractions ever created in L2.
    pub total_abstractions: usize,

    /// Episodes consumed by L0 -> L1 compaction.
    pub episodes_compacted: usize,

    /// Summaries consumed by L1 -> L2 compaction.
    pub summaries_compacted: usize,

    /// Unix timestamp (seconds) of the most recent compaction pass, if any.
    pub last_compaction: Option<f64>,

    /// Estimated tokens saved by compaction (source tokens minus output tokens).
    pub token_savings: usize,
}
439
440impl<S: Summarizer> HierarchicalMemory<S> {
441 pub fn new(config: MemoryCompactionConfig, summarizer: Arc<S>) -> Self {
443 Self {
444 config,
445 l0_episodes: RwLock::new(VecDeque::new()),
446 l1_summaries: RwLock::new(VecDeque::new()),
447 l2_abstractions: RwLock::new(VecDeque::new()),
448 summarizer,
449 stats: RwLock::new(CompactionStats::default()),
450 next_id: std::sync::atomic::AtomicU64::new(1),
451 }
452 }
453
454 fn next_id(&self) -> String {
456 let id = self.next_id.fetch_add(1, std::sync::atomic::Ordering::SeqCst);
457 format!("mem_{}", id)
458 }
459
460 pub fn add_episode(&self, content: String, episode_type: EpisodeType) -> String {
462 let id = self.next_id();
463 let timestamp = SystemTime::now()
464 .duration_since(UNIX_EPOCH)
465 .unwrap_or_default()
466 .as_secs_f64();
467
468 let token_count = content.len() / 4; let episode = Episode {
471 id: id.clone(),
472 timestamp,
473 content,
474 episode_type,
475 metadata: HashMap::new(),
476 embedding: None,
477 token_count,
478 };
479
480 {
481 let mut l0 = self.l0_episodes.write().unwrap();
482 l0.push_back(episode);
483 }
484
485 {
486 let mut stats = self.stats.write().unwrap();
487 stats.total_episodes += 1;
488 }
489
490 id
491 }
492
493 pub fn add_episode_with_embedding(
495 &self,
496 content: String,
497 episode_type: EpisodeType,
498 embedding: Vec<f32>,
499 ) -> String {
500 let id = self.next_id();
501 let timestamp = SystemTime::now()
502 .duration_since(UNIX_EPOCH)
503 .unwrap_or_default()
504 .as_secs_f64();
505
506 let token_count = content.len() / 4;
507
508 let episode = Episode {
509 id: id.clone(),
510 timestamp,
511 content,
512 episode_type,
513 metadata: HashMap::new(),
514 embedding: Some(embedding),
515 token_count,
516 };
517
518 {
519 let mut l0 = self.l0_episodes.write().unwrap();
520 l0.push_back(episode);
521 }
522
523 {
524 let mut stats = self.stats.write().unwrap();
525 stats.total_episodes += 1;
526 }
527
528 id
529 }
530
531 pub fn maybe_compact(&self) -> Result<bool, CompactionError> {
533 let needs_l0 = {
534 let l0 = self.l0_episodes.read().unwrap();
535 l0.len() >= self.config.l0_max_episodes
536 };
537
538 let needs_l1 = {
539 let l1 = self.l1_summaries.read().unwrap();
540 l1.len() >= self.config.l1_max_summaries
541 };
542
543 if needs_l0 || needs_l1 {
544 self.run_compaction()?;
545 return Ok(true);
546 }
547
548 Ok(false)
549 }
550
551 pub fn run_compaction(&self) -> Result<(), CompactionError> {
553 self.compact_l0_to_l1()?;
555
556 self.compact_l1_to_l2()?;
558
559 {
561 let mut stats = self.stats.write().unwrap();
562 stats.last_compaction = Some(
563 SystemTime::now()
564 .duration_since(UNIX_EPOCH)
565 .unwrap_or_default()
566 .as_secs_f64()
567 );
568 }
569
570 Ok(())
571 }
572
573 fn compact_l0_to_l1(&self) -> Result<(), CompactionError> {
575 let now = SystemTime::now()
576 .duration_since(UNIX_EPOCH)
577 .unwrap_or_default()
578 .as_secs_f64();
579
580 let age_threshold = now - self.config.l0_age_threshold_secs as f64;
581
582 let to_compact: Vec<Episode> = {
584 let l0 = self.l0_episodes.read().unwrap();
585 l0.iter()
586 .filter(|e| e.timestamp < age_threshold)
587 .cloned()
588 .collect()
589 };
590
591 if to_compact.is_empty() {
592 return Ok(());
593 }
594
595 let groups = self.group_episodes(&to_compact);
597
598 for group in groups {
600 if group.is_empty() {
601 continue;
602 }
603
604 let content = self.summarizer.summarize_episodes(&group)?;
605 let topics = self.summarizer.extract_topics(&content);
606
607 let first_ts = group.iter().map(|e| e.timestamp).fold(f64::MAX, f64::min);
608 let last_ts = group.iter().map(|e| e.timestamp).fold(f64::MIN, f64::max);
609
610 let episode_ids: Vec<String> = group.iter().map(|e| e.id.clone()).collect();
611 let original_tokens: usize = group.iter().map(|e| e.token_count).sum();
612 let summary_tokens = content.len() / 4;
613
614 let summary = Summary {
615 id: self.next_id(),
616 content,
617 source_episode_ids: episode_ids,
618 time_range: (first_ts, last_ts),
619 embedding: None, token_count: summary_tokens,
621 created_at: now,
622 topics,
623 };
624
625 {
627 let mut l1 = self.l1_summaries.write().unwrap();
628 l1.push_back(summary);
629 }
630
631 {
633 let mut stats = self.stats.write().unwrap();
634 stats.total_summaries += 1;
635 stats.episodes_compacted += group.len();
636 stats.token_savings += original_tokens.saturating_sub(summary_tokens);
637 }
638 }
639
640 {
642 let mut l0 = self.l0_episodes.write().unwrap();
643 l0.retain(|e| e.timestamp >= age_threshold);
644 }
645
646 Ok(())
647 }
648
649 fn compact_l1_to_l2(&self) -> Result<(), CompactionError> {
651 let now = SystemTime::now()
652 .duration_since(UNIX_EPOCH)
653 .unwrap_or_default()
654 .as_secs_f64();
655
656 let age_threshold = now - self.config.l1_age_threshold_secs as f64;
657
658 let to_compact: Vec<Summary> = {
660 let l1 = self.l1_summaries.read().unwrap();
661 l1.iter()
662 .filter(|s| s.created_at < age_threshold)
663 .cloned()
664 .collect()
665 };
666
667 if to_compact.len() < self.config.group_size {
668 return Ok(());
669 }
670
671 let groups = self.group_summaries(&to_compact);
673
674 for group in groups {
675 if group.is_empty() {
676 continue;
677 }
678
679 let content = self.summarizer.abstract_summaries(&group)?;
680
681 let first_ts = group.iter().map(|s| s.time_range.0).fold(f64::MAX, f64::min);
682 let last_ts = group.iter().map(|s| s.time_range.1).fold(f64::MIN, f64::max);
683
684 let summary_ids: Vec<String> = group.iter().map(|s| s.id.clone()).collect();
685 let original_tokens: usize = group.iter().map(|s| s.token_count).sum();
686 let abstraction_tokens = content.len() / 4;
687
688 let insights: Vec<String> = group
690 .iter()
691 .flat_map(|s| s.topics.clone())
692 .collect::<std::collections::HashSet<_>>()
693 .into_iter()
694 .take(5)
695 .collect();
696
697 let abstraction = Abstraction {
698 id: self.next_id(),
699 content,
700 source_summary_ids: summary_ids,
701 time_range: (first_ts, last_ts),
702 embedding: None,
703 token_count: abstraction_tokens,
704 created_at: now,
705 insights,
706 };
707
708 {
710 let mut l2 = self.l2_abstractions.write().unwrap();
711 l2.push_back(abstraction);
712 }
713
714 {
716 let mut stats = self.stats.write().unwrap();
717 stats.total_abstractions += 1;
718 stats.summaries_compacted += group.len();
719 stats.token_savings += original_tokens.saturating_sub(abstraction_tokens);
720 }
721 }
722
723 {
725 let mut l1 = self.l1_summaries.write().unwrap();
726 l1.retain(|s| s.created_at >= age_threshold);
727 }
728
729 Ok(())
730 }
731
732 fn group_episodes(&self, episodes: &[Episode]) -> Vec<Vec<Episode>> {
734 episodes
736 .chunks(self.config.group_size)
737 .map(|chunk| chunk.to_vec())
738 .collect()
739 }
740
741 fn group_summaries(&self, summaries: &[Summary]) -> Vec<Vec<Summary>> {
743 summaries
744 .chunks(self.config.group_size)
745 .map(|chunk| chunk.to_vec())
746 .collect()
747 }
748
749 pub fn total_tokens(&self) -> usize {
751 let l0: usize = self.l0_episodes.read().unwrap().iter().map(|e| e.token_count).sum();
752 let l1: usize = self.l1_summaries.read().unwrap().iter().map(|s| s.token_count).sum();
753 let l2: usize = self.l2_abstractions.read().unwrap().iter().map(|a| a.token_count).sum();
754
755 l0 + l1 + l2
756 }
757
758 pub fn get_context(&self, max_tokens: usize) -> Vec<MemoryEntry> {
760 let mut entries = Vec::new();
761 let mut tokens_used = 0;
762
763 let l0 = self.l0_episodes.read().unwrap();
765 for episode in l0.iter().rev() {
766 if tokens_used + episode.token_count > max_tokens {
767 break;
768 }
769 entries.push(MemoryEntry::Episode(episode.clone()));
770 tokens_used += episode.token_count;
771 }
772
773 let l1 = self.l1_summaries.read().unwrap();
775 for summary in l1.iter().rev() {
776 if tokens_used + summary.token_count > max_tokens {
777 break;
778 }
779 entries.push(MemoryEntry::Summary(summary.clone()));
780 tokens_used += summary.token_count;
781 }
782
783 let l2 = self.l2_abstractions.read().unwrap();
785 for abstraction in l2.iter().rev() {
786 if tokens_used + abstraction.token_count > max_tokens {
787 break;
788 }
789 entries.push(MemoryEntry::Abstraction(abstraction.clone()));
790 tokens_used += abstraction.token_count;
791 }
792
793 entries
794 }
795
796 pub fn stats(&self) -> CompactionStats {
798 self.stats.read().unwrap().clone()
799 }
800
801 pub fn tier_counts(&self) -> (usize, usize, usize) {
803 let l0 = self.l0_episodes.read().unwrap().len();
804 let l1 = self.l1_summaries.read().unwrap().len();
805 let l2 = self.l2_abstractions.read().unwrap().len();
806 (l0, l1, l2)
807 }
808}
809
/// A tier-tagged memory item returned by `HierarchicalMemory::get_context`.
#[derive(Debug, Clone)]
pub enum MemoryEntry {
    /// Raw L0 episode.
    Episode(Episode),
    /// L1 summary.
    Summary(Summary),
    /// L2 abstraction.
    Abstraction(Abstraction),
}
817
818impl MemoryEntry {
819 pub fn content(&self) -> &str {
821 match self {
822 Self::Episode(e) => &e.content,
823 Self::Summary(s) => &s.content,
824 Self::Abstraction(a) => &a.content,
825 }
826 }
827
828 pub fn token_count(&self) -> usize {
830 match self {
831 Self::Episode(e) => e.token_count,
832 Self::Summary(s) => s.token_count,
833 Self::Abstraction(a) => a.token_count,
834 }
835 }
836
837 pub fn tier(&self) -> usize {
839 match self {
840 Self::Episode(_) => 0,
841 Self::Summary(_) => 1,
842 Self::Abstraction(_) => 2,
843 }
844 }
845}
846
/// Errors that can occur during memory compaction.
#[derive(Debug, Clone)]
pub enum CompactionError {
    /// The summarizer failed to condense a group.
    SummarizationFailed(String),
    /// Embedding generation failed.
    EmbeddingFailed(String),
    /// Persisting or reading a tier failed.
    StorageError(String),
}
861
862impl std::fmt::Display for CompactionError {
863 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
864 match self {
865 Self::SummarizationFailed(msg) => write!(f, "Summarization failed: {}", msg),
866 Self::EmbeddingFailed(msg) => write!(f, "Embedding failed: {}", msg),
867 Self::StorageError(msg) => write!(f, "Storage error: {}", msg),
868 }
869 }
870}
871
// Marker impl so CompactionError works with `Box<dyn Error>` and `?`.
impl std::error::Error for CompactionError {}
873
874pub fn create_hierarchical_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
880 HierarchicalMemory::new(
881 MemoryCompactionConfig::default(),
882 Arc::new(ExtractiveSummarizer::default()),
883 )
884}
885
886pub fn create_test_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
888 HierarchicalMemory::new(
889 MemoryCompactionConfig::aggressive(),
890 Arc::new(ExtractiveSummarizer::default()),
891 )
892}
893
#[cfg(test)]
mod tests {
    use super::*;

    /// A single added episode lands in L0 and gets a `mem_`-prefixed id.
    #[test]
    fn test_add_episode() {
        let memory = create_test_memory();

        let id = memory.add_episode(
            "User asked about weather".to_string(),
            EpisodeType::UserMessage,
        );

        assert!(id.starts_with("mem_"));

        let (l0, l1, l2) = memory.tier_counts();
        assert_eq!(l0, 1);
        assert_eq!(l1, 0);
        assert_eq!(l2, 0);
    }

    /// Summarizing mixed episode types yields non-empty, type-labelled text.
    #[test]
    fn test_extractive_summarizer() {
        let summarizer = ExtractiveSummarizer::default();

        let episodes = vec![
            Episode {
                id: "1".to_string(),
                timestamp: 0.0,
                content: "User asked about the weather forecast.".to_string(),
                episode_type: EpisodeType::UserMessage,
                metadata: HashMap::new(),
                embedding: None,
                token_count: 10,
            },
            Episode {
                id: "2".to_string(),
                timestamp: 1.0,
                content: "Assistant provided weather information for NYC.".to_string(),
                episode_type: EpisodeType::AssistantResponse,
                metadata: HashMap::new(),
                embedding: None,
                token_count: 12,
            },
        ];

        let summary = summarizer.summarize_episodes(&episodes).unwrap();

        assert!(!summary.is_empty());
        // Loose check: either the header or one of the type labels must appear.
        assert!(summary.contains("episodes") || summary.contains("User") || summary.contains("Responses"));
    }

    /// Keyword extraction on realistic prose produces at least one topic.
    #[test]
    fn test_topic_extraction() {
        let summarizer = ExtractiveSummarizer::default();

        let content = "The weather forecast shows sunny conditions with temperatures around 75 degrees. Tomorrow expects rain and thunderstorms across the region.";

        let topics = summarizer.extract_topics(content);

        assert!(!topics.is_empty());
    }

    /// Fresh (uncompacted) episodes come back from `get_context` as tier 0.
    #[test]
    fn test_memory_context_retrieval() {
        let memory = create_test_memory();

        for i in 0..5 {
            memory.add_episode(
                format!("Episode {} content here with some text.", i),
                EpisodeType::UserMessage,
            );
        }

        let context = memory.get_context(1000);

        assert!(!context.is_empty());

        // Nothing has been compacted, so every entry is still an L0 episode.
        for entry in &context {
            assert_eq!(entry.tier(), 0);
        }
    }

    /// Token totals reflect the len/4 estimate and are non-zero after adds.
    #[test]
    fn test_token_tracking() {
        let memory = create_test_memory();

        memory.add_episode(
            "Short message".to_string(),
            EpisodeType::UserMessage,
        );

        memory.add_episode(
            "A much longer message with more content that should have more tokens estimated".to_string(),
            EpisodeType::AssistantResponse,
        );

        let total = memory.total_tokens();
        assert!(total > 0);
    }

    /// `total_episodes` counts every add, independent of compaction.
    #[test]
    fn test_stats_tracking() {
        let memory = create_test_memory();

        for _ in 0..10 {
            memory.add_episode("Test episode".to_string(), EpisodeType::UserMessage);
        }

        let stats = memory.stats();
        assert_eq!(stats.total_episodes, 10);
    }
}