Skip to main content

sochdb_query/
memory_compaction.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Hierarchical Memory Compaction (Task 5)
19//!
20//! This module implements semantic memory compaction inspired by LSM-trees.
21//! It manages tiered storage where older memories are summarized to maintain
22//! bounded context while preserving semantic continuity.
23//!
24//! ## Architecture
25//!
26//! ```text
27//! L0: Raw Episodes (recent, full detail)
28//!     │
29//!     ▼ Summarization
30//! L1: Summaries (older, compressed)
31//!     │
32//!     ▼ Abstraction
33//! L2: Abstractions (oldest, highly compressed)
34//! ```
35//!
36//! ## Compaction Strategy
37//!
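//! - Episodes older than the tier's age threshold are grouped (currently into
//!   fixed-size chunks of `group_size`; similarity-based grouping via
//!   `similarity_threshold` is not implemented yet)
//! - Each group is summarized (via LLM or extractive methods)
//! - Summaries are intended to be re-embedded for retrieval (`reembed_summaries`);
//!   embedding generation is not wired up in this module yet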
41//! - Growth is O(log_c T) where c = compaction ratio, T = total events
42
43use std::collections::{HashMap, VecDeque};
44use std::sync::{Arc, RwLock};
45use std::time::{SystemTime, UNIX_EPOCH};
46
47// ============================================================================
48// Configuration
49// ============================================================================
50
/// Tuning knobs for hierarchical memory compaction.
///
/// The count limits (`l0_max_episodes`, `l1_max_summaries`) decide *when* a
/// compaction pass is triggered; the age thresholds decide *which* entries are
/// old enough to be folded into the next tier during that pass.
#[derive(Debug, Clone)]
pub struct MemoryCompactionConfig {
    /// L0 size at which a compaction pass is triggered
    pub l0_max_episodes: usize,

    /// L1 size at which a compaction pass is triggered
    pub l1_max_summaries: usize,

    /// Minimum age (seconds) before an episode may move from L0 to L1
    pub l0_age_threshold_secs: u64,

    /// Minimum age (seconds) before a summary may move from L1 to L2
    pub l1_age_threshold_secs: u64,

    /// How many entries are folded into one summary / abstraction
    pub group_size: usize,

    /// Similarity threshold for grouping (0.0 to 1.0); not yet consulted by the
    /// grouping code in this module
    pub similarity_threshold: f32,

    /// Upper bound on tokens per generated summary
    pub max_summary_tokens: usize,

    /// Whether summaries should be re-embedded for retrieval
    pub reembed_summaries: bool,

    /// How often (seconds) a caller should check whether compaction is due
    pub check_interval_secs: u64,
}

impl Default for MemoryCompactionConfig {
    /// Balanced defaults: 1 h before episodes age out of L0, 1 week before
    /// summaries age out of L1, groups of 10.
    fn default() -> Self {
        const MINUTE: u64 = 60;
        const HOUR: u64 = 60 * MINUTE;
        const DAY: u64 = 24 * HOUR;

        Self {
            l0_max_episodes: 1_000,
            l1_max_summaries: 100,
            l0_age_threshold_secs: HOUR,
            l1_age_threshold_secs: 7 * DAY,
            group_size: 10,
            similarity_threshold: 0.7,
            max_summary_tokens: 200,
            reembed_summaries: true,
            check_interval_secs: 5 * MINUTE,
        }
    }
}

impl MemoryCompactionConfig {
    /// Preset that compacts early and often; meant for tests and demos.
    ///
    /// Fields not listed here keep their [`Default`] values.
    pub fn aggressive() -> Self {
        Self {
            l0_max_episodes: 100,
            l1_max_summaries: 20,
            l0_age_threshold_secs: 60,
            l1_age_threshold_secs: 60 * 60,
            group_size: 5,
            ..Self::default()
        }
    }

    /// Preset for long-lived agents: bigger tiers and slower aging
    /// (6 h for L0, 30 days for L1).
    ///
    /// Fields not listed here keep their [`Default`] values.
    pub fn long_running() -> Self {
        let hours = |n: u64| n * 3_600;
        Self {
            l0_max_episodes: 5_000,
            l1_max_summaries: 500,
            l0_age_threshold_secs: hours(6),
            l1_age_threshold_secs: hours(24 * 30),
            group_size: 20,
            ..Self::default()
        }
    }
}
123
124// ============================================================================
125// Memory Types
126// ============================================================================
127
/// One raw interaction record, stored in tier L0.
#[derive(Debug, Clone)]
pub struct Episode {
    /// Unique identifier (the store assigns `mem_<n>`)
    pub id: String,

    /// Creation time in seconds since the Unix epoch
    pub timestamp: f64,

    /// Raw text of the episode (e.g. a user message or a tool call)
    pub content: String,

    /// Which kind of interaction this episode records
    pub episode_type: EpisodeType,

    /// Free-form key/value annotations
    pub metadata: HashMap<String, String>,

    /// Optional embedding vector, intended for similarity grouping
    pub embedding: Option<Vec<f32>>,

    /// Token count, exact or estimated
    pub token_count: usize,
}

/// The kind of interaction an [`Episode`] records.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum EpisodeType {
    /// A message written by the user
    UserMessage,
    /// A reply produced by the assistant
    AssistantResponse,
    /// An invocation of a tool
    ToolCall,
    /// The output returned by a tool
    ToolResult,
    /// An event raised by the system
    SystemEvent,
    /// An observation recorded by the agent
    Observation,
}
169
/// A condensed record of several episodes, stored in tier L1.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Unique identifier
    pub id: String,

    /// Summary text
    pub content: String,

    /// IDs of the episodes folded into this summary
    pub source_episode_ids: Vec<String>,

    /// `(earliest, latest)` timestamps of the source episodes
    pub time_range: (f64, f64),

    /// Optional embedding of the summary text
    pub embedding: Option<Vec<f32>>,

    /// Token count of `content`
    pub token_count: usize,

    /// When this summary was produced (seconds since the Unix epoch)
    pub created_at: f64,

    /// Topics / themes extracted from the summary
    pub topics: Vec<String>,
}
197
/// A high-level condensation of several summaries, stored in tier L2.
#[derive(Debug, Clone)]
pub struct Abstraction {
    /// Unique identifier
    pub id: String,

    /// Abstraction text
    pub content: String,

    /// IDs of the summaries folded into this abstraction
    pub source_summary_ids: Vec<String>,

    /// `(earliest, latest)` time covered by the source summaries
    pub time_range: (f64, f64),

    /// Optional embedding of the abstraction text
    pub embedding: Option<Vec<f32>>,

    /// Token count of `content`
    pub token_count: usize,

    /// When this abstraction was produced (seconds since the Unix epoch)
    pub created_at: f64,

    /// Key insights carried over from the source summaries
    pub insights: Vec<String>,
}
225
226// ============================================================================
227// Summarizer Trait
228// ============================================================================
229
230/// Trait for summarization backends
231pub trait Summarizer: Send + Sync {
232    /// Summarize a group of episodes into a single summary
233    fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError>;
234
235    /// Summarize a group of summaries into an abstraction
236    fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError>;
237
238    /// Extract topics/themes from content
239    fn extract_topics(&self, content: &str) -> Vec<String>;
240}
241
/// Summarizer that builds its output by extracting text from the inputs;
/// no LLM is involved.
pub struct ExtractiveSummarizer {
    /// How many episodes of each type are sampled for a summary
    pub max_sentences: usize,

    /// Whether episode summaries start with a count/duration header line
    pub include_timestamps: bool,
}

impl Default for ExtractiveSummarizer {
    /// Sample up to five episodes per type and include the header line.
    fn default() -> Self {
        ExtractiveSummarizer {
            include_timestamps: true,
            max_sentences: 5,
        }
    }
}
259
260impl Summarizer for ExtractiveSummarizer {
261    fn summarize_episodes(&self, episodes: &[Episode]) -> Result<String, CompactionError> {
262        if episodes.is_empty() {
263            return Ok(String::new());
264        }
265
266        let mut summary_parts = Vec::new();
267
268        // Time range
269        let first_ts = episodes.first().map(|e| e.timestamp).unwrap_or(0.0);
270        let last_ts = episodes.last().map(|e| e.timestamp).unwrap_or(0.0);
271
272        if self.include_timestamps {
273            summary_parts.push(format!(
274                "[{} episodes over {:.0} seconds]",
275                episodes.len(),
276                last_ts - first_ts
277            ));
278        }
279
280        // Group by type and summarize
281        let mut by_type: HashMap<EpisodeType, Vec<&Episode>> = HashMap::new();
282        for episode in episodes {
283            by_type
284                .entry(episode.episode_type)
285                .or_default()
286                .push(episode);
287        }
288
289        // Extract key content from each type
290        for (ep_type, eps) in by_type {
291            let type_name = match ep_type {
292                EpisodeType::UserMessage => "User messages",
293                EpisodeType::AssistantResponse => "Responses",
294                EpisodeType::ToolCall => "Tool calls",
295                EpisodeType::ToolResult => "Tool results",
296                EpisodeType::SystemEvent => "Events",
297                EpisodeType::Observation => "Observations",
298            };
299
300            // Take first sentence from each, up to max_sentences
301            let sentences: Vec<String> = eps
302                .iter()
303                .take(self.max_sentences)
304                .filter_map(|e| e.content.split('.').next().map(|s| s.trim().to_string()))
305                .filter(|s| !s.is_empty())
306                .collect();
307
308            if !sentences.is_empty() {
309                summary_parts.push(format!("{}: {}", type_name, sentences.join("; ")));
310            }
311        }
312
313        Ok(summary_parts.join("\n"))
314    }
315
316    fn abstract_summaries(&self, summaries: &[Summary]) -> Result<String, CompactionError> {
317        if summaries.is_empty() {
318            return Ok(String::new());
319        }
320
321        let mut abstraction_parts = Vec::new();
322
323        // Time range
324        let first_ts = summaries
325            .iter()
326            .map(|s| s.time_range.0)
327            .fold(f64::MAX, f64::min);
328        let last_ts = summaries
329            .iter()
330            .map(|s| s.time_range.1)
331            .fold(f64::MIN, f64::max);
332
333        abstraction_parts.push(format!(
334            "[{} summaries, {:.1} hours span]",
335            summaries.len(),
336            (last_ts - first_ts) / 3600.0
337        ));
338
339        // Collect all topics
340        let all_topics: Vec<&str> = summaries
341            .iter()
342            .flat_map(|s| s.topics.iter().map(|t| t.as_str()))
343            .collect();
344
345        if !all_topics.is_empty() {
346            let unique_topics: Vec<_> = all_topics
347                .iter()
348                .cloned()
349                .collect::<std::collections::HashSet<_>>()
350                .into_iter()
351                .take(10)
352                .collect();
353            abstraction_parts.push(format!("Topics: {}", unique_topics.join(", ")));
354        }
355
356        // Take first line of each summary
357        let key_points: Vec<String> = summaries
358            .iter()
359            .take(5)
360            .filter_map(|s| s.content.lines().next().map(|l| l.to_string()))
361            .collect();
362
363        if !key_points.is_empty() {
364            abstraction_parts.push(format!("Key points:\n- {}", key_points.join("\n- ")));
365        }
366
367        Ok(abstraction_parts.join("\n"))
368    }
369
370    fn extract_topics(&self, content: &str) -> Vec<String> {
371        // Simple keyword extraction (would use NLP in production)
372        let stopwords = [
373            "the", "a", "an", "is", "are", "was", "were", "to", "from", "in", "on", "at", "for",
374            "and", "or",
375        ];
376
377        let words: Vec<&str> = content
378            .split_whitespace()
379            .filter(|w| w.len() > 3)
380            .filter(|w| !stopwords.contains(&w.to_lowercase().as_str()))
381            .collect();
382
383        // Count word frequencies
384        let mut freq: HashMap<String, usize> = HashMap::new();
385        for word in words {
386            let normalized = word
387                .to_lowercase()
388                .trim_matches(|c: char| !c.is_alphanumeric())
389                .to_string();
390            if normalized.len() > 3 {
391                *freq.entry(normalized).or_insert(0) += 1;
392            }
393        }
394
395        // Return top 5 by frequency
396        let mut sorted: Vec<_> = freq.into_iter().collect();
397        sorted.sort_by(|a, b| b.1.cmp(&a.1));
398
399        sorted.into_iter().take(5).map(|(w, _)| w).collect()
400    }
401}
402
403// ============================================================================
404// Memory Store
405// ============================================================================
406
407/// Hierarchical memory store with compaction
408pub struct HierarchicalMemory<S: Summarizer> {
409    /// Configuration
410    config: MemoryCompactionConfig,
411
412    /// L0: Raw episodes
413    l0_episodes: RwLock<VecDeque<Episode>>,
414
415    /// L1: Summaries
416    l1_summaries: RwLock<VecDeque<Summary>>,
417
418    /// L2: Abstractions
419    l2_abstractions: RwLock<VecDeque<Abstraction>>,
420
421    /// Summarizer backend
422    summarizer: Arc<S>,
423
424    /// Compaction statistics
425    stats: RwLock<CompactionStats>,
426
427    /// ID counter
428    next_id: std::sync::atomic::AtomicU64,
429}
430
/// Cumulative counters describing what a store has done.
#[derive(Debug, Clone, Default)]
pub struct CompactionStats {
    /// Episodes ever added to L0
    pub total_episodes: usize,

    /// Summaries ever created in L1
    pub total_summaries: usize,

    /// Abstractions ever created in L2
    pub total_abstractions: usize,

    /// Episodes removed from L0 by compaction
    pub episodes_compacted: usize,

    /// Summaries removed from L1 by compaction
    pub summaries_compacted: usize,

    /// Time (seconds since the Unix epoch) of the last compaction pass
    pub last_compaction: Option<f64>,

    /// Estimated tokens saved by compaction
    pub token_savings: usize,
}
455
456impl<S: Summarizer> HierarchicalMemory<S> {
457    /// Create a new hierarchical memory store
458    pub fn new(config: MemoryCompactionConfig, summarizer: Arc<S>) -> Self {
459        Self {
460            config,
461            l0_episodes: RwLock::new(VecDeque::new()),
462            l1_summaries: RwLock::new(VecDeque::new()),
463            l2_abstractions: RwLock::new(VecDeque::new()),
464            summarizer,
465            stats: RwLock::new(CompactionStats::default()),
466            next_id: std::sync::atomic::AtomicU64::new(1),
467        }
468    }
469
470    /// Generate next ID
471    fn next_id(&self) -> String {
472        let id = self
473            .next_id
474            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
475        format!("mem_{}", id)
476    }
477
478    /// Add an episode to L0
479    pub fn add_episode(&self, content: String, episode_type: EpisodeType) -> String {
480        let id = self.next_id();
481        let timestamp = SystemTime::now()
482            .duration_since(UNIX_EPOCH)
483            .unwrap_or_default()
484            .as_secs_f64();
485
486        let token_count = content.len() / 4; // Rough estimate
487
488        let episode = Episode {
489            id: id.clone(),
490            timestamp,
491            content,
492            episode_type,
493            metadata: HashMap::new(),
494            embedding: None,
495            token_count,
496        };
497
498        {
499            let mut l0 = self.l0_episodes.write().unwrap();
500            l0.push_back(episode);
501        }
502
503        {
504            let mut stats = self.stats.write().unwrap();
505            stats.total_episodes += 1;
506        }
507
508        id
509    }
510
511    /// Add episode with embedding
512    pub fn add_episode_with_embedding(
513        &self,
514        content: String,
515        episode_type: EpisodeType,
516        embedding: Vec<f32>,
517    ) -> String {
518        let id = self.next_id();
519        let timestamp = SystemTime::now()
520            .duration_since(UNIX_EPOCH)
521            .unwrap_or_default()
522            .as_secs_f64();
523
524        let token_count = content.len() / 4;
525
526        let episode = Episode {
527            id: id.clone(),
528            timestamp,
529            content,
530            episode_type,
531            metadata: HashMap::new(),
532            embedding: Some(embedding),
533            token_count,
534        };
535
536        {
537            let mut l0 = self.l0_episodes.write().unwrap();
538            l0.push_back(episode);
539        }
540
541        {
542            let mut stats = self.stats.write().unwrap();
543            stats.total_episodes += 1;
544        }
545
546        id
547    }
548
549    /// Check if compaction is needed and run if so
550    pub fn maybe_compact(&self) -> Result<bool, CompactionError> {
551        let needs_l0 = {
552            let l0 = self.l0_episodes.read().unwrap();
553            l0.len() >= self.config.l0_max_episodes
554        };
555
556        let needs_l1 = {
557            let l1 = self.l1_summaries.read().unwrap();
558            l1.len() >= self.config.l1_max_summaries
559        };
560
561        if needs_l0 || needs_l1 {
562            self.run_compaction()?;
563            return Ok(true);
564        }
565
566        Ok(false)
567    }
568
569    /// Run compaction cycle
570    pub fn run_compaction(&self) -> Result<(), CompactionError> {
571        // L0 → L1 compaction
572        self.compact_l0_to_l1()?;
573
574        // L1 → L2 compaction
575        self.compact_l1_to_l2()?;
576
577        // Update stats
578        {
579            let mut stats = self.stats.write().unwrap();
580            stats.last_compaction = Some(
581                SystemTime::now()
582                    .duration_since(UNIX_EPOCH)
583                    .unwrap_or_default()
584                    .as_secs_f64(),
585            );
586        }
587
588        Ok(())
589    }
590
591    /// Compact L0 episodes to L1 summaries
592    fn compact_l0_to_l1(&self) -> Result<(), CompactionError> {
593        let now = SystemTime::now()
594            .duration_since(UNIX_EPOCH)
595            .unwrap_or_default()
596            .as_secs_f64();
597
598        let age_threshold = now - self.config.l0_age_threshold_secs as f64;
599
600        // Collect episodes to compact
601        let to_compact: Vec<Episode> = {
602            let l0 = self.l0_episodes.read().unwrap();
603            l0.iter()
604                .filter(|e| e.timestamp < age_threshold)
605                .cloned()
606                .collect()
607        };
608
609        if to_compact.is_empty() {
610            return Ok(());
611        }
612
613        // Group episodes by similarity or time
614        let groups = self.group_episodes(&to_compact);
615
616        // Summarize each group
617        for group in groups {
618            if group.is_empty() {
619                continue;
620            }
621
622            let content = self.summarizer.summarize_episodes(&group)?;
623            let topics = self.summarizer.extract_topics(&content);
624
625            let first_ts = group.iter().map(|e| e.timestamp).fold(f64::MAX, f64::min);
626            let last_ts = group.iter().map(|e| e.timestamp).fold(f64::MIN, f64::max);
627
628            let episode_ids: Vec<String> = group.iter().map(|e| e.id.clone()).collect();
629            let original_tokens: usize = group.iter().map(|e| e.token_count).sum();
630            let summary_tokens = content.len() / 4;
631
632            let summary = Summary {
633                id: self.next_id(),
634                content,
635                source_episode_ids: episode_ids,
636                time_range: (first_ts, last_ts),
637                embedding: None, // Would be generated if reembed_summaries is true
638                token_count: summary_tokens,
639                created_at: now,
640                topics,
641            };
642
643            // Add summary to L1
644            {
645                let mut l1 = self.l1_summaries.write().unwrap();
646                l1.push_back(summary);
647            }
648
649            // Update stats
650            {
651                let mut stats = self.stats.write().unwrap();
652                stats.total_summaries += 1;
653                stats.episodes_compacted += group.len();
654                stats.token_savings += original_tokens.saturating_sub(summary_tokens);
655            }
656        }
657
658        // Remove compacted episodes from L0
659        {
660            let mut l0 = self.l0_episodes.write().unwrap();
661            l0.retain(|e| e.timestamp >= age_threshold);
662        }
663
664        Ok(())
665    }
666
667    /// Compact L1 summaries to L2 abstractions
668    fn compact_l1_to_l2(&self) -> Result<(), CompactionError> {
669        let now = SystemTime::now()
670            .duration_since(UNIX_EPOCH)
671            .unwrap_or_default()
672            .as_secs_f64();
673
674        let age_threshold = now - self.config.l1_age_threshold_secs as f64;
675
676        // Collect summaries to compact
677        let to_compact: Vec<Summary> = {
678            let l1 = self.l1_summaries.read().unwrap();
679            l1.iter()
680                .filter(|s| s.created_at < age_threshold)
681                .cloned()
682                .collect()
683        };
684
685        if to_compact.len() < self.config.group_size {
686            return Ok(());
687        }
688
689        // Group summaries
690        let groups = self.group_summaries(&to_compact);
691
692        for group in groups {
693            if group.is_empty() {
694                continue;
695            }
696
697            let content = self.summarizer.abstract_summaries(&group)?;
698
699            let first_ts = group
700                .iter()
701                .map(|s| s.time_range.0)
702                .fold(f64::MAX, f64::min);
703            let last_ts = group
704                .iter()
705                .map(|s| s.time_range.1)
706                .fold(f64::MIN, f64::max);
707
708            let summary_ids: Vec<String> = group.iter().map(|s| s.id.clone()).collect();
709            let original_tokens: usize = group.iter().map(|s| s.token_count).sum();
710            let abstraction_tokens = content.len() / 4;
711
712            // Extract insights from topics
713            let insights: Vec<String> = group
714                .iter()
715                .flat_map(|s| s.topics.clone())
716                .collect::<std::collections::HashSet<_>>()
717                .into_iter()
718                .take(5)
719                .collect();
720
721            let abstraction = Abstraction {
722                id: self.next_id(),
723                content,
724                source_summary_ids: summary_ids,
725                time_range: (first_ts, last_ts),
726                embedding: None,
727                token_count: abstraction_tokens,
728                created_at: now,
729                insights,
730            };
731
732            // Add abstraction to L2
733            {
734                let mut l2 = self.l2_abstractions.write().unwrap();
735                l2.push_back(abstraction);
736            }
737
738            // Update stats
739            {
740                let mut stats = self.stats.write().unwrap();
741                stats.total_abstractions += 1;
742                stats.summaries_compacted += group.len();
743                stats.token_savings += original_tokens.saturating_sub(abstraction_tokens);
744            }
745        }
746
747        // Remove compacted summaries from L1
748        {
749            let mut l1 = self.l1_summaries.write().unwrap();
750            l1.retain(|s| s.created_at >= age_threshold);
751        }
752
753        Ok(())
754    }
755
756    /// Group episodes by time windows (simplified)
757    fn group_episodes(&self, episodes: &[Episode]) -> Vec<Vec<Episode>> {
758        // Simple grouping by fixed size
759        episodes
760            .chunks(self.config.group_size)
761            .map(|chunk| chunk.to_vec())
762            .collect()
763    }
764
765    /// Group summaries by time windows
766    fn group_summaries(&self, summaries: &[Summary]) -> Vec<Vec<Summary>> {
767        summaries
768            .chunks(self.config.group_size)
769            .map(|chunk| chunk.to_vec())
770            .collect()
771    }
772
773    /// Get total token count across all tiers
774    pub fn total_tokens(&self) -> usize {
775        let l0: usize = self
776            .l0_episodes
777            .read()
778            .unwrap()
779            .iter()
780            .map(|e| e.token_count)
781            .sum();
782        let l1: usize = self
783            .l1_summaries
784            .read()
785            .unwrap()
786            .iter()
787            .map(|s| s.token_count)
788            .sum();
789        let l2: usize = self
790            .l2_abstractions
791            .read()
792            .unwrap()
793            .iter()
794            .map(|a| a.token_count)
795            .sum();
796
797        l0 + l1 + l2
798    }
799
800    /// Get memory for context assembly (most recent first)
801    pub fn get_context(&self, max_tokens: usize) -> Vec<MemoryEntry> {
802        let mut entries = Vec::new();
803        let mut tokens_used = 0;
804
805        // Start with L0 (most recent)
806        let l0 = self.l0_episodes.read().unwrap();
807        for episode in l0.iter().rev() {
808            if tokens_used + episode.token_count > max_tokens {
809                break;
810            }
811            entries.push(MemoryEntry::Episode(episode.clone()));
812            tokens_used += episode.token_count;
813        }
814
815        // Add L1 summaries if space
816        let l1 = self.l1_summaries.read().unwrap();
817        for summary in l1.iter().rev() {
818            if tokens_used + summary.token_count > max_tokens {
819                break;
820            }
821            entries.push(MemoryEntry::Summary(summary.clone()));
822            tokens_used += summary.token_count;
823        }
824
825        // Add L2 abstractions if space
826        let l2 = self.l2_abstractions.read().unwrap();
827        for abstraction in l2.iter().rev() {
828            if tokens_used + abstraction.token_count > max_tokens {
829                break;
830            }
831            entries.push(MemoryEntry::Abstraction(abstraction.clone()));
832            tokens_used += abstraction.token_count;
833        }
834
835        entries
836    }
837
838    /// Get statistics
839    pub fn stats(&self) -> CompactionStats {
840        self.stats.read().unwrap().clone()
841    }
842
843    /// Get tier counts
844    pub fn tier_counts(&self) -> (usize, usize, usize) {
845        let l0 = self.l0_episodes.read().unwrap().len();
846        let l1 = self.l1_summaries.read().unwrap().len();
847        let l2 = self.l2_abstractions.read().unwrap().len();
848        (l0, l1, l2)
849    }
850}
851
852/// Entry from hierarchical memory
853#[derive(Debug, Clone)]
854pub enum MemoryEntry {
855    Episode(Episode),
856    Summary(Summary),
857    Abstraction(Abstraction),
858}
859
860impl MemoryEntry {
861    /// Get content
862    pub fn content(&self) -> &str {
863        match self {
864            Self::Episode(e) => &e.content,
865            Self::Summary(s) => &s.content,
866            Self::Abstraction(a) => &a.content,
867        }
868    }
869
870    /// Get token count
871    pub fn token_count(&self) -> usize {
872        match self {
873            Self::Episode(e) => e.token_count,
874            Self::Summary(s) => s.token_count,
875            Self::Abstraction(a) => a.token_count,
876        }
877    }
878
879    /// Get tier level
880    pub fn tier(&self) -> usize {
881        match self {
882            Self::Episode(_) => 0,
883            Self::Summary(_) => 1,
884            Self::Abstraction(_) => 2,
885        }
886    }
887}
888
889// ============================================================================
890// Errors
891// ============================================================================
892
/// Errors that can occur while compacting memory tiers.
#[derive(Debug, Clone)]
pub enum CompactionError {
    /// The summarizer backend could not produce its text
    SummarizationFailed(String),
    /// Computing an embedding failed
    EmbeddingFailed(String),
    /// A storage-layer failure
    StorageError(String),
}

impl std::fmt::Display for CompactionError {
    /// Renders as `<kind>: <detail>`, e.g. `Storage error: disk full`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (kind, detail) = match self {
            CompactionError::SummarizationFailed(detail) => ("Summarization failed", detail),
            CompactionError::EmbeddingFailed(detail) => ("Embedding failed", detail),
            CompactionError::StorageError(detail) => ("Storage error", detail),
        };
        write!(f, "{}: {}", kind, detail)
    }
}

impl std::error::Error for CompactionError {}
915
916// ============================================================================
917// Convenience Functions
918// ============================================================================
919
920/// Create a hierarchical memory with extractive summarizer
921pub fn create_hierarchical_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
922    HierarchicalMemory::new(
923        MemoryCompactionConfig::default(),
924        Arc::new(ExtractiveSummarizer::default()),
925    )
926}
927
928/// Create with aggressive compaction for testing
929pub fn create_test_memory() -> HierarchicalMemory<ExtractiveSummarizer> {
930    HierarchicalMemory::new(
931        MemoryCompactionConfig::aggressive(),
932        Arc::new(ExtractiveSummarizer::default()),
933    )
934}
935
936// ============================================================================
937// Tests
938// ============================================================================
939
#[cfg(test)]
mod tests {
    use super::*;

    /// Shift every L0 episode `secs` seconds into the past so that it passes
    /// the L0 → L1 age filter.
    fn age_l0(memory: &HierarchicalMemory<ExtractiveSummarizer>, secs: f64) {
        for episode in memory.l0_episodes.write().unwrap().iter_mut() {
            episode.timestamp -= secs;
        }
    }

    /// Shift every L1 summary's `created_at` `secs` seconds into the past so
    /// that it passes the L1 → L2 age filter.
    fn age_l1(memory: &HierarchicalMemory<ExtractiveSummarizer>, secs: f64) {
        for summary in memory.l1_summaries.write().unwrap().iter_mut() {
            summary.created_at -= secs;
        }
    }

    #[test]
    fn test_add_episode() {
        let memory = create_test_memory();

        let id = memory.add_episode(
            "User asked about weather".to_string(),
            EpisodeType::UserMessage,
        );

        assert!(id.starts_with("mem_"));

        let (l0, l1, l2) = memory.tier_counts();
        assert_eq!(l0, 1);
        assert_eq!(l1, 0);
        assert_eq!(l2, 0);
    }

    #[test]
    fn test_add_episode_with_embedding() {
        let memory = create_test_memory();

        let id = memory.add_episode_with_embedding(
            "Embedded observation".to_string(),
            EpisodeType::Observation,
            vec![0.1, 0.2],
        );

        let l0 = memory.l0_episodes.read().unwrap();
        let episode = l0.back().unwrap();
        assert_eq!(episode.id, id);
        assert_eq!(episode.episode_type, EpisodeType::Observation);
        assert_eq!(episode.embedding.as_deref(), Some(&[0.1f32, 0.2][..]));
    }

    #[test]
    fn test_extractive_summarizer() {
        let summarizer = ExtractiveSummarizer::default();

        let episodes = vec![
            Episode {
                id: "1".to_string(),
                timestamp: 0.0,
                content: "User asked about the weather forecast.".to_string(),
                episode_type: EpisodeType::UserMessage,
                metadata: HashMap::new(),
                embedding: None,
                token_count: 10,
            },
            Episode {
                id: "2".to_string(),
                timestamp: 1.0,
                content: "Assistant provided weather information for NYC.".to_string(),
                episode_type: EpisodeType::AssistantResponse,
                metadata: HashMap::new(),
                embedding: None,
                token_count: 12,
            },
        ];

        let summary = summarizer.summarize_episodes(&episodes).unwrap();

        assert!(!summary.is_empty());
        assert!(
            summary.contains("episodes")
                || summary.contains("User")
                || summary.contains("Responses")
        );
    }

    #[test]
    fn test_topic_extraction() {
        let summarizer = ExtractiveSummarizer::default();

        let content = "The weather forecast shows sunny conditions with temperatures around 75 degrees. Tomorrow expects rain and thunderstorms across the region.";

        let topics = summarizer.extract_topics(content);

        assert!(!topics.is_empty());
        // Should extract meaningful words like "weather", "forecast", "temperatures", etc.
    }

    #[test]
    fn test_memory_context_retrieval() {
        let memory = create_test_memory();

        // Add some episodes
        for i in 0..5 {
            memory.add_episode(
                format!("Episode {} content here with some text.", i),
                EpisodeType::UserMessage,
            );
        }

        let context = memory.get_context(1000);

        assert!(!context.is_empty());

        // All should be L0 episodes
        for entry in &context {
            assert_eq!(entry.tier(), 0);
        }
    }

    #[test]
    fn test_token_tracking() {
        let memory = create_test_memory();

        memory.add_episode("Short message".to_string(), EpisodeType::UserMessage);

        memory.add_episode(
            "A much longer message with more content that should have more tokens estimated"
                .to_string(),
            EpisodeType::AssistantResponse,
        );

        let total = memory.total_tokens();
        assert!(total > 0);
    }

    #[test]
    fn test_stats_tracking() {
        let memory = create_test_memory();

        for _ in 0..10 {
            memory.add_episode("Test episode".to_string(), EpisodeType::UserMessage);
        }

        let stats = memory.stats();
        assert_eq!(stats.total_episodes, 10);
    }

    #[test]
    fn test_maybe_compact_below_limits() {
        let memory = create_test_memory();
        memory.add_episode("Only one".to_string(), EpisodeType::UserMessage);

        assert!(!memory.maybe_compact().unwrap());
        assert_eq!(memory.tier_counts(), (1, 0, 0));
    }

    #[test]
    fn test_l0_to_l1_compaction() {
        // Aggressive preset: group_size 5, L0 age threshold 60 s.
        let memory = create_test_memory();
        for i in 0..10 {
            memory.add_episode(
                format!("Episode {} about deployment.", i),
                EpisodeType::UserMessage,
            );
        }

        // Nothing is old enough yet, so nothing moves.
        memory.run_compaction().unwrap();
        assert_eq!(memory.tier_counts(), (10, 0, 0));

        age_l0(&memory, 120.0);
        memory.run_compaction().unwrap();
        assert_eq!(memory.tier_counts(), (0, 2, 0));

        let stats = memory.stats();
        assert_eq!(stats.total_summaries, 2);
        assert_eq!(stats.episodes_compacted, 10);
        assert!(stats.last_compaction.is_some());
    }

    #[test]
    fn test_l1_to_l2_compaction() {
        // Aggressive preset: group_size 5, L1 age threshold 3600 s.
        let memory = create_test_memory();
        for i in 0..25 {
            memory.add_episode(
                format!("Episode {} about deployment.", i),
                EpisodeType::UserMessage,
            );
        }

        age_l0(&memory, 120.0);
        memory.run_compaction().unwrap();
        assert_eq!(memory.tier_counts(), (0, 5, 0));

        age_l1(&memory, 7200.0);
        memory.run_compaction().unwrap();
        assert_eq!(memory.tier_counts(), (0, 0, 1));

        let stats = memory.stats();
        assert_eq!(stats.total_abstractions, 1);
        assert_eq!(stats.summaries_compacted, 5);

        let context = memory.get_context(10_000);
        assert_eq!(context.len(), 1);
        assert_eq!(context[0].tier(), 2);
        assert!(!context[0].content().is_empty());
    }
}