Skip to main content

rustant_core/
memory.rs

//! Three-tier memory system for the Rustant agent.
//!
//! - **Working Memory**: Current task state and scratch data (single task lifetime).
//! - **Short-Term Memory**: Sliding window of recent conversation with summarization.
//! - **Long-Term Memory**: Persistent facts and preferences across sessions.
6
7use crate::error::MemoryError;
8use crate::search::{HybridSearchEngine, SearchConfig};
9use crate::types::{Content, Message, Role};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::path::Path;
14use uuid::Uuid;
15
16/// Working memory for the currently executing task.
17#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18pub struct WorkingMemory {
19    pub current_goal: Option<String>,
20    pub sub_tasks: Vec<String>,
21    pub scratchpad: HashMap<String, String>,
22    pub active_files: Vec<String>,
23}
24
25impl WorkingMemory {
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    pub fn set_goal(&mut self, goal: impl Into<String>) {
31        self.current_goal = Some(goal.into());
32    }
33
34    pub fn add_sub_task(&mut self, task: impl Into<String>) {
35        self.sub_tasks.push(task.into());
36    }
37
38    pub fn note(&mut self, key: impl Into<String>, value: impl Into<String>) {
39        self.scratchpad.insert(key.into(), value.into());
40    }
41
42    pub fn add_active_file(&mut self, path: impl Into<String>) {
43        let path = path.into();
44        if !self.active_files.contains(&path) {
45            self.active_files.push(path);
46        }
47    }
48
49    pub fn clear(&mut self) {
50        *self = Self::default();
51    }
52}
53
54/// Short-term memory: sliding window of recent messages with summarization support.
55#[derive(Debug, Clone)]
56pub struct ShortTermMemory {
57    messages: VecDeque<Message>,
58    window_size: usize,
59    summarized_prefix: Option<String>,
60    total_messages_seen: usize,
61    /// Set of absolute message indices that are pinned (survive compression).
62    pinned: std::collections::HashSet<usize>,
63    /// Running count of messages removed by compression (for index mapping).
64    compressed_offset: usize,
65}
66
67impl ShortTermMemory {
68    pub fn new(window_size: usize) -> Self {
69        Self {
70            messages: VecDeque::new(),
71            window_size,
72            summarized_prefix: None,
73            total_messages_seen: 0,
74            pinned: std::collections::HashSet::new(),
75            compressed_offset: 0,
76        }
77    }
78
79    /// Add a message to short-term memory.
80    pub fn add(&mut self, message: Message) {
81        self.messages.push_back(message);
82        self.total_messages_seen += 1;
83    }
84
85    /// Get all messages that should be sent to the LLM, including any summary.
86    pub fn to_messages(&self) -> Vec<Message> {
87        let mut result = Vec::new();
88
89        // Include summary of older context if available
90        if let Some(ref summary) = self.summarized_prefix {
91            result.push(Message::system(format!(
92                "[Summary of earlier conversation]\n{}",
93                summary
94            )));
95        }
96
97        // Include recent messages within the window.
98        // Pinned messages that fall before the window start are still included.
99        let start = if self.messages.len() > self.window_size {
100            self.messages.len() - self.window_size
101        } else {
102            0
103        };
104
105        for (i, msg) in self.messages.iter().enumerate() {
106            if i >= start || self.is_pinned(i) {
107                result.push(msg.clone());
108            }
109        }
110
111        result
112    }
113
114    /// Check whether compression is needed based on message count.
115    pub fn needs_compression(&self) -> bool {
116        self.messages.len() >= self.window_size * 2
117    }
118
119    /// Compress older messages by replacing them with a summary.
120    /// Pinned messages are preserved and moved to the front of the window.
121    /// When a pinned message is a tool_result, its corresponding tool_call
122    /// message is also preserved (and vice versa) to maintain valid sequences.
123    /// Returns the number of messages that were compressed.
124    pub fn compress(&mut self, summary: String) -> usize {
125        if self.messages.len() <= self.window_size {
126            return 0;
127        }
128
129        let to_remove = self.messages.len() - self.window_size;
130
131        // First pass: find indices that are pinned
132        let mut preserve_indices: HashSet<usize> = HashSet::new();
133        for i in 0..to_remove {
134            let abs_idx = self.compressed_offset + i;
135            if self.pinned.contains(&abs_idx) {
136                preserve_indices.insert(i);
137            }
138        }
139
140        // Second pass: for each preserved message, also preserve its tool pair
141        let mut extra_preserves: Vec<usize> = Vec::new();
142        for &i in &preserve_indices {
143            if let Some(msg) = self.messages.get(i) {
144                match &msg.content {
145                    // If pinned message is a tool_result, find its tool_call
146                    Content::ToolResult { call_id, .. } => {
147                        if let Some(pair_idx) =
148                            Self::find_tool_call_for_result(call_id, &self.messages, to_remove)
149                        {
150                            if !preserve_indices.contains(&pair_idx) {
151                                extra_preserves.push(pair_idx);
152                            }
153                        }
154                    }
155                    // If pinned message is a tool_call, find its tool_result
156                    Content::ToolCall { id, .. } => {
157                        if let Some(pair_idx) =
158                            Self::find_tool_result_for_call(id, &self.messages, to_remove)
159                        {
160                            if !preserve_indices.contains(&pair_idx) {
161                                extra_preserves.push(pair_idx);
162                            }
163                        }
164                    }
165                    // MultiPart: check both tool_calls and tool_results
166                    Content::MultiPart { parts } => {
167                        for part in parts {
168                            match part {
169                                Content::ToolCall { id, .. } => {
170                                    if let Some(pair_idx) = Self::find_tool_result_for_call(
171                                        id,
172                                        &self.messages,
173                                        to_remove,
174                                    ) {
175                                        if !preserve_indices.contains(&pair_idx) {
176                                            extra_preserves.push(pair_idx);
177                                        }
178                                    }
179                                }
180                                Content::ToolResult { call_id, .. } => {
181                                    if let Some(pair_idx) = Self::find_tool_call_for_result(
182                                        call_id,
183                                        &self.messages,
184                                        to_remove,
185                                    ) {
186                                        if !preserve_indices.contains(&pair_idx) {
187                                            extra_preserves.push(pair_idx);
188                                        }
189                                    }
190                                }
191                                _ => {}
192                            }
193                        }
194                    }
195                    _ => {}
196                }
197            }
198        }
199        for idx in extra_preserves {
200            preserve_indices.insert(idx);
201        }
202
203        // Collect preserved messages in order, count removed
204        let mut preserved = Vec::new();
205        let mut removed_count = 0;
206
207        for i in 0..to_remove {
208            if preserve_indices.contains(&i) {
209                if let Some(msg) = self.messages.get(i) {
210                    preserved.push(msg.clone());
211                }
212            } else {
213                removed_count += 1;
214            }
215        }
216
217        // Collect absolute indices that are pinned and remain in the window
218        // (i.e. messages beyond to_remove that were pinned).
219        let mut surviving_pinned: Vec<usize> = Vec::new();
220        for i in to_remove..self.messages.len() {
221            let abs_idx = self.compressed_offset + i;
222            if self.pinned.contains(&abs_idx) {
223                surviving_pinned.push(i - to_remove);
224            }
225        }
226
227        // Remove the to_remove oldest messages
228        self.messages.drain(..to_remove);
229        self.compressed_offset += to_remove;
230
231        // Re-insert pinned messages at the front of the window using batch operation.
232        // This avoids O(p^2) sequential insert() calls by building a new VecDeque.
233        let preserved_count = preserved.len();
234        if !preserved.is_empty() {
235            let mut new_messages = VecDeque::with_capacity(preserved_count + self.messages.len());
236            for msg in preserved {
237                new_messages.push_back(msg);
238            }
239            new_messages.append(&mut self.messages);
240            self.messages = new_messages;
241        }
242
243        // Rebuild the pinned set with correct absolute indices.
244        // Preserved messages are at positions 0..preserved_count.
245        // Surviving pinned messages shifted by preserved_count.
246        let mut new_pinned = HashSet::new();
247        for i in 0..preserved_count {
248            new_pinned.insert(self.compressed_offset + i);
249        }
250        for pos in surviving_pinned {
251            new_pinned.insert(self.compressed_offset + preserved_count + pos);
252        }
253        self.pinned = new_pinned;
254
255        // Merge with existing summary
256        if let Some(ref existing) = self.summarized_prefix {
257            self.summarized_prefix = Some(format!("{}\n\n{}", existing, summary));
258        } else {
259            self.summarized_prefix = Some(summary);
260        }
261
262        removed_count
263    }
264
265    /// Find the index (within `range 0..limit`) of an assistant message containing
266    /// a tool_call with the given ID, paired with a tool_result's `call_id`.
267    fn find_tool_call_for_result(
268        call_id: &str,
269        messages: &VecDeque<Message>,
270        limit: usize,
271    ) -> Option<usize> {
272        messages
273            .iter()
274            .enumerate()
275            .take(limit)
276            .find(|(_, msg)| {
277                msg.role == Role::Assistant
278                    && Self::content_contains_tool_call_id(&msg.content, call_id)
279            })
280            .map(|(i, _)| i)
281    }
282
283    /// Find the index (within `range 0..limit`) of a tool/user message containing
284    /// a tool_result with the given `call_id`, paired with a tool_call's ID.
285    fn find_tool_result_for_call(
286        tool_call_id: &str,
287        messages: &VecDeque<Message>,
288        limit: usize,
289    ) -> Option<usize> {
290        messages
291            .iter()
292            .enumerate()
293            .take(limit)
294            .find(|(_, msg)| {
295                (msg.role == Role::Tool || msg.role == Role::User)
296                    && Self::content_contains_tool_result_id(&msg.content, tool_call_id)
297            })
298            .map(|(i, _)| i)
299    }
300
301    /// Check if a Content contains a tool_call with the given ID.
302    fn content_contains_tool_call_id(content: &Content, target_id: &str) -> bool {
303        match content {
304            Content::ToolCall { id, .. } => id == target_id,
305            Content::MultiPart { parts } => parts
306                .iter()
307                .any(|p| Self::content_contains_tool_call_id(p, target_id)),
308            _ => false,
309        }
310    }
311
312    /// Check if a Content contains a tool_result with the given call_id.
313    fn content_contains_tool_result_id(content: &Content, target_id: &str) -> bool {
314        match content {
315            Content::ToolResult { call_id, .. } => call_id == target_id,
316            Content::MultiPart { parts } => parts
317                .iter()
318                .any(|p| Self::content_contains_tool_result_id(p, target_id)),
319            _ => false,
320        }
321    }
322
323    /// Pin a message by its position in the current message list (0-based).
324    /// Pinned messages survive compression.
325    pub fn pin(&mut self, position: usize) -> bool {
326        if position >= self.messages.len() {
327            return false;
328        }
329        let abs_idx = self.compressed_offset + position;
330        self.pinned.insert(abs_idx);
331        true
332    }
333
334    /// Unpin a message by its position in the current message list.
335    pub fn unpin(&mut self, position: usize) -> bool {
336        if position >= self.messages.len() {
337            return false;
338        }
339        let abs_idx = self.compressed_offset + position;
340        self.pinned.remove(&abs_idx)
341    }
342
343    /// Check if a message at the given position is pinned.
344    pub fn is_pinned(&self, position: usize) -> bool {
345        let abs_idx = self.compressed_offset + position;
346        self.pinned.contains(&abs_idx)
347    }
348
349    /// Number of currently pinned messages.
350    pub fn pinned_count(&self) -> usize {
351        // Count pinned messages that are still in the current window
352        (0..self.messages.len())
353            .filter(|&i| self.is_pinned(i))
354            .count()
355    }
356
357    /// Get messages that should be summarized (older than window).
358    pub fn messages_to_summarize(&self) -> Vec<&Message> {
359        if self.messages.len() <= self.window_size {
360            return Vec::new();
361        }
362        let to_summarize = self.messages.len() - self.window_size;
363        self.messages.iter().take(to_summarize).collect()
364    }
365
366    /// Get the total number of messages currently held.
367    pub fn len(&self) -> usize {
368        self.messages.len()
369    }
370
371    /// Check if memory is empty.
372    pub fn is_empty(&self) -> bool {
373        self.messages.is_empty()
374    }
375
376    /// Get total messages seen across the session.
377    pub fn total_messages_seen(&self) -> usize {
378        self.total_messages_seen
379    }
380
381    /// Clear all messages and summary.
382    pub fn clear(&mut self) {
383        self.messages.clear();
384        self.summarized_prefix = None;
385        self.total_messages_seen = 0;
386        self.pinned.clear();
387        self.compressed_offset = 0;
388    }
389
390    /// Get a reference to all current messages.
391    pub fn messages(&self) -> &VecDeque<Message> {
392        &self.messages
393    }
394
395    /// Get the summary prefix, if any.
396    pub fn summary(&self) -> Option<&str> {
397        self.summarized_prefix.as_deref()
398    }
399}
400
401/// A fact extracted from conversation for long-term storage.
402#[derive(Debug, Clone, Serialize, Deserialize)]
403pub struct Fact {
404    pub id: Uuid,
405    pub content: String,
406    pub source: String,
407    pub created_at: DateTime<Utc>,
408    pub tags: Vec<String>,
409}
410
411impl Fact {
412    pub fn new(content: impl Into<String>, source: impl Into<String>) -> Self {
413        Self {
414            id: Uuid::new_v4(),
415            content: content.into(),
416            source: source.into(),
417            created_at: Utc::now(),
418            tags: Vec::new(),
419        }
420    }
421
422    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
423        self.tags = tags;
424        self
425    }
426}
427
428/// Long-term memory persisted across sessions.
429#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct LongTermMemory {
431    pub facts: Vec<Fact>,
432    pub preferences: HashMap<String, String>,
433    pub corrections: Vec<Correction>,
434    /// Maximum number of facts to retain. When exceeded, the oldest fact is evicted.
435    #[serde(default = "LongTermMemory::default_max_facts")]
436    pub max_facts: usize,
437    /// Maximum number of corrections to retain. When exceeded, the oldest correction is evicted.
438    #[serde(default = "LongTermMemory::default_max_corrections")]
439    pub max_corrections: usize,
440}
441
442impl Default for LongTermMemory {
443    fn default() -> Self {
444        Self {
445            facts: Vec::new(),
446            preferences: HashMap::new(),
447            corrections: Vec::new(),
448            max_facts: Self::default_max_facts(),
449            max_corrections: Self::default_max_corrections(),
450        }
451    }
452}
453
454/// A correction recorded from user feedback.
455#[derive(Debug, Clone, Serialize, Deserialize)]
456pub struct Correction {
457    pub id: Uuid,
458    pub original: String,
459    pub corrected: String,
460    pub context: String,
461    pub timestamp: DateTime<Utc>,
462}
463
464impl LongTermMemory {
465    pub fn new() -> Self {
466        Self::default()
467    }
468
469    fn default_max_facts() -> usize {
470        10_000
471    }
472
473    fn default_max_corrections() -> usize {
474        1_000
475    }
476
477    pub fn add_fact(&mut self, fact: Fact) {
478        if self.facts.len() >= self.max_facts {
479            self.facts.remove(0);
480        }
481        self.facts.push(fact);
482    }
483
484    pub fn set_preference(&mut self, key: impl Into<String>, value: impl Into<String>) {
485        self.preferences.insert(key.into(), value.into());
486    }
487
488    pub fn get_preference(&self, key: &str) -> Option<&str> {
489        self.preferences.get(key).map(|s| s.as_str())
490    }
491
492    pub fn add_correction(&mut self, original: String, corrected: String, context: String) {
493        if self.corrections.len() >= self.max_corrections {
494            self.corrections.remove(0);
495        }
496        self.corrections.push(Correction {
497            id: Uuid::new_v4(),
498            original,
499            corrected,
500            context,
501            timestamp: Utc::now(),
502        });
503    }
504
505    /// Search facts by keyword.
506    pub fn search_facts(&self, query: &str) -> Vec<&Fact> {
507        let query_lower = query.to_lowercase();
508        self.facts
509            .iter()
510            .filter(|f| {
511                f.content.to_lowercase().contains(&query_lower)
512                    || f.tags
513                        .iter()
514                        .any(|t| t.to_lowercase().contains(&query_lower))
515            })
516            .collect()
517    }
518}
519
520/// The unified memory system combining all three tiers.
521pub struct MemorySystem {
522    pub working: WorkingMemory,
523    pub short_term: ShortTermMemory,
524    pub long_term: LongTermMemory,
525    /// Optional hybrid search engine for fact retrieval.
526    search_engine: Option<HybridSearchEngine>,
527    /// Optional automatic flusher for periodic persistence.
528    flusher: Option<MemoryFlusher>,
529}
530
531impl MemorySystem {
532    pub fn new(window_size: usize) -> Self {
533        Self {
534            working: WorkingMemory::new(),
535            short_term: ShortTermMemory::new(window_size),
536            long_term: LongTermMemory::new(),
537            search_engine: None,
538            flusher: None,
539        }
540    }
541
542    /// Create a memory system with hybrid search enabled.
543    pub fn with_search(
544        window_size: usize,
545        search_config: SearchConfig,
546    ) -> Result<Self, crate::search::SearchError> {
547        let engine = HybridSearchEngine::open(search_config)?;
548        Ok(Self {
549            working: WorkingMemory::new(),
550            short_term: ShortTermMemory::new(window_size),
551            long_term: LongTermMemory::new(),
552            search_engine: Some(engine),
553            flusher: None,
554        })
555    }
556
557    /// Attach an automatic flusher to this memory system (builder pattern).
558    pub fn with_flusher(mut self, config: FlushConfig) -> Self {
559        self.flusher = Some(MemoryFlusher::new(config));
560        self
561    }
562
563    /// Get all messages for the LLM context.
564    pub fn context_messages(&self) -> Vec<Message> {
565        self.short_term.to_messages()
566    }
567
568    /// Add a message to the conversation.
569    pub fn add_message(&mut self, message: Message) {
570        self.short_term.add(message);
571        // Notify flusher
572        if let Some(ref mut flusher) = self.flusher {
573            flusher.on_message_added();
574        }
575    }
576
577    /// Add a fact to long-term memory, also indexing it in the search engine.
578    pub fn add_fact(&mut self, fact: Fact) {
579        if let Some(ref mut engine) = self.search_engine {
580            let _ = engine.index_fact(&fact.id.to_string(), &fact.content);
581        }
582        self.long_term.add_fact(fact);
583    }
584
585    /// Search facts using the hybrid engine (falls back to keyword search).
586    pub fn search_facts_hybrid(&self, query: &str) -> Vec<&Fact> {
587        if let Some(ref engine) = self.search_engine {
588            if let Ok(results) = engine.search(query) {
589                let ids: Vec<String> = results.iter().map(|r| r.fact_id.clone()).collect();
590                let found: Vec<&Fact> = self
591                    .long_term
592                    .facts
593                    .iter()
594                    .filter(|f| ids.contains(&f.id.to_string()))
595                    .collect();
596                if !found.is_empty() {
597                    return found;
598                }
599            }
600        }
601        // Fallback to keyword search
602        self.long_term.search_facts(query)
603    }
604
605    /// Check if auto-flush should happen, and flush if needed.
606    ///
607    /// Uses the `Option::take()` pattern to avoid borrow conflicts.
608    pub fn check_auto_flush(&mut self) -> Result<bool, MemoryError> {
609        let mut flusher = match self.flusher.take() {
610            Some(f) => f,
611            None => return Ok(false),
612        };
613        let result = if flusher.should_flush() {
614            flusher.flush(self)?;
615            Ok(true)
616        } else {
617            Ok(false)
618        };
619        self.flusher = Some(flusher);
620        result
621    }
622
623    /// Force a flush regardless of triggers.
624    pub fn force_flush(&mut self) -> Result<(), MemoryError> {
625        let mut flusher = match self.flusher.take() {
626            Some(f) => f,
627            None => return Ok(()),
628        };
629        let result = flusher.force_flush(self);
630        self.flusher = Some(flusher);
631        result
632    }
633
634    /// Whether the flusher has unflushed data.
635    pub fn flusher_is_dirty(&self) -> bool {
636        self.flusher.as_ref().is_some_and(|f| f.is_dirty())
637    }
638
639    /// Reset working memory for a new task.
640    pub fn start_new_task(&mut self, goal: impl Into<String>) {
641        self.working.clear();
642        self.working.set_goal(goal);
643    }
644
645    /// Clear everything except long-term memory.
646    pub fn clear_session(&mut self) {
647        self.working.clear();
648        self.short_term.clear();
649    }
650
651    /// Get a breakdown of context usage for the UI.
652    pub fn context_breakdown(&self, context_window: usize) -> ContextBreakdown {
653        let summary_chars = self.short_term.summary().map(|s| s.len()).unwrap_or(0);
654        let message_chars: usize = self
655            .short_term
656            .messages()
657            .iter()
658            .map(|m| m.content_length())
659            .sum();
660        let total_chars = summary_chars + message_chars;
661
662        // Rough token estimate: ~4 chars per token
663        let summary_tokens = summary_chars / 4;
664        let message_tokens = message_chars / 4;
665        let total_tokens = total_chars / 4;
666        let remaining_tokens = context_window.saturating_sub(total_tokens);
667
668        ContextBreakdown {
669            summary_tokens,
670            message_tokens,
671            total_tokens,
672            context_window,
673            remaining_tokens,
674            message_count: self.short_term.len(),
675            total_messages_seen: self.short_term.total_messages_seen(),
676            pinned_count: self.short_term.pinned_count(),
677            has_summary: self.short_term.summary().is_some(),
678            facts_count: self.long_term.facts.len(),
679            rules_count: 0, // Populated separately if knowledge distiller is available
680        }
681    }
682
683    /// Pin a message in short-term memory by position.
684    pub fn pin_message(&mut self, position: usize) -> bool {
685        self.short_term.pin(position)
686    }
687
688    /// Unpin a message in short-term memory by position.
689    pub fn unpin_message(&mut self, position: usize) -> bool {
690        self.short_term.unpin(position)
691    }
692}
693
/// Snapshot of how the context window is being used, for display in the UI.
#[derive(Debug, Clone, Default)]
pub struct ContextBreakdown {
    /// Estimated tokens taken by the summarized prefix.
    pub summary_tokens: usize,
    /// Estimated tokens taken by the active messages.
    pub message_tokens: usize,
    /// Estimated tokens in use overall.
    pub total_tokens: usize,
    /// Size of the context window (from config).
    pub context_window: usize,
    /// Tokens still available.
    pub remaining_tokens: usize,
    /// Messages currently in the window.
    pub message_count: usize,
    /// Messages seen over the whole session.
    pub total_messages_seen: usize,
    /// Pinned messages.
    pub pinned_count: usize,
    /// Whether a summary prefix exists.
    pub has_summary: bool,
    /// Facts held in long-term memory.
    pub facts_count: usize,
    /// Active behavioral rules.
    pub rules_count: usize,
}

impl ContextBreakdown {
    /// Fraction of the context window in use, clamped to `0.0..=1.0`.
    ///
    /// A zero-sized window reports `0.0`.
    pub fn usage_ratio(&self) -> f32 {
        match self.context_window {
            0 => 0.0,
            window => {
                let raw = self.total_tokens as f32 / window as f32;
                raw.clamp(0.0, 1.0)
            }
        }
    }

    /// Whether usage has reached the warning level (70% or more).
    /// Aligned with the agent's `ContextHealthEvent::Warning` threshold.
    pub fn is_warning(&self) -> bool {
        const WARNING_THRESHOLD: f32 = 0.7;
        self.usage_ratio() >= WARNING_THRESHOLD
    }
}
736
737/// Metadata about a saved session.
738#[derive(Debug, Clone, Serialize, Deserialize)]
739pub struct SessionMetadata {
740    pub id: Uuid,
741    pub created_at: DateTime<Utc>,
742    pub updated_at: DateTime<Utc>,
743    pub task_summary: Option<String>,
744}
745
746impl SessionMetadata {
747    pub fn new() -> Self {
748        let now = Utc::now();
749        Self {
750            id: Uuid::new_v4(),
751            created_at: now,
752            updated_at: now,
753            task_summary: None,
754        }
755    }
756}
757
758impl Default for SessionMetadata {
759    fn default() -> Self {
760        Self::new()
761    }
762}
763
764/// A persistable session containing the memory state.
765#[derive(Debug, Clone, Serialize, Deserialize)]
766pub struct Session {
767    pub metadata: SessionMetadata,
768    pub working: WorkingMemory,
769    pub long_term: LongTermMemory,
770    pub messages: Vec<Message>,
771    pub window_size: usize,
772}
773
774impl MemorySystem {
775    /// Save the current memory state to a JSON file.
776    pub fn save_session(&self, path: &Path) -> Result<(), MemoryError> {
777        let session = Session {
778            metadata: SessionMetadata {
779                id: Uuid::new_v4(),
780                created_at: Utc::now(),
781                updated_at: Utc::now(),
782                task_summary: self.working.current_goal.clone(),
783            },
784            working: self.working.clone(),
785            long_term: self.long_term.clone(),
786            messages: self.short_term.messages().iter().cloned().collect(),
787            window_size: self.short_term.window_size(),
788        };
789
790        let json =
791            serde_json::to_string_pretty(&session).map_err(|e| MemoryError::PersistenceError {
792                message: format!("Failed to serialize session: {}", e),
793            })?;
794
795        if let Some(parent) = path.parent() {
796            std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
797                message: format!("Failed to create directory: {}", e),
798            })?;
799        }
800
801        std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
802            message: format!("Failed to write session file: {}", e),
803        })?;
804
805        Ok(())
806    }
807
808    /// Load a session from a JSON file.
809    pub fn load_session(path: &Path) -> Result<Self, MemoryError> {
810        let json = std::fs::read_to_string(path).map_err(|e| MemoryError::SessionLoadFailed {
811            message: format!("Failed to read session file: {}", e),
812        })?;
813
814        let session: Session =
815            serde_json::from_str(&json).map_err(|e| MemoryError::SessionLoadFailed {
816                message: format!("Failed to deserialize session: {}", e),
817            })?;
818
819        let mut memory = MemorySystem::new(session.window_size);
820        memory.working = session.working;
821        memory.long_term = session.long_term;
822        for msg in session.messages {
823            memory.short_term.add(msg);
824        }
825
826        Ok(memory)
827    }
828}
829
830impl ShortTermMemory {
831    /// Get the window size.
832    pub fn window_size(&self) -> usize {
833        self.window_size
834    }
835}
836
/// Outcome of one context-compression operation.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// Message count before compression.
    pub messages_before: usize,
    /// Message count after compression.
    pub messages_after: usize,
    /// Number of messages that were compressed away.
    pub compressed_count: usize,
}
844
845// ---------------------------------------------------------------------------
846// Memory Flusher
847// ---------------------------------------------------------------------------
848
849/// Configuration for automatic memory flushing.
850#[derive(Debug, Clone, Serialize, Deserialize)]
851pub struct FlushConfig {
852    /// Whether automatic flushing is enabled.
853    pub enabled: bool,
854    /// Flush interval in seconds (0 = disabled).
855    pub interval_secs: u64,
856    /// Number of messages that triggers an auto-flush (0 = disabled).
857    pub flush_on_message_count: usize,
858    /// Path where flushed data is written.
859    pub flush_path: Option<std::path::PathBuf>,
860}
861
862impl Default for FlushConfig {
863    fn default() -> Self {
864        Self {
865            enabled: false,
866            interval_secs: 300, // 5 minutes
867            flush_on_message_count: 50,
868            flush_path: None,
869        }
870    }
871}
872
/// Tracks dirty state and triggers for automatic memory persistence.
///
/// Nothing in this type writes on its own schedule: activity is reported via
/// `on_message_added`, `should_flush` evaluates the configured triggers, and
/// `flush` / `force_flush` must be called explicitly to persist.
#[derive(Debug, Clone)]
pub struct MemoryFlusher {
    /// Trigger thresholds and the target path.
    config: FlushConfig,
    /// Set by `on_message_added`; cleared after a successful flush.
    dirty: bool,
    /// Messages reported since the last successful flush.
    messages_since_flush: usize,
    /// Time of the last successful flush; initialised to the construction time.
    last_flush: DateTime<Utc>,
    /// Number of successful flushes performed by this instance.
    total_flushes: usize,
}
882
883impl MemoryFlusher {
884    /// Create a new flusher with the given configuration.
885    pub fn new(config: FlushConfig) -> Self {
886        Self {
887            config,
888            dirty: false,
889            messages_since_flush: 0,
890            last_flush: Utc::now(),
891            total_flushes: 0,
892        }
893    }
894
895    /// Notify the flusher that a message was added.
896    pub fn on_message_added(&mut self) {
897        self.dirty = true;
898        self.messages_since_flush += 1;
899    }
900
901    /// Check whether a flush should happen based on the configured triggers.
902    pub fn should_flush(&self) -> bool {
903        if !self.config.enabled || !self.dirty {
904            return false;
905        }
906
907        // Message-count trigger
908        if self.config.flush_on_message_count > 0
909            && self.messages_since_flush >= self.config.flush_on_message_count
910        {
911            return true;
912        }
913
914        // Time-based trigger
915        if self.config.interval_secs > 0 {
916            let elapsed = (Utc::now() - self.last_flush).num_seconds();
917            if elapsed >= self.config.interval_secs as i64 {
918                return true;
919            }
920        }
921
922        false
923    }
924
925    /// Perform a flush of the memory system to disk.
926    pub fn flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
927        let path =
928            self.config
929                .flush_path
930                .as_ref()
931                .ok_or_else(|| MemoryError::PersistenceError {
932                    message: "No flush path configured".to_string(),
933                })?;
934
935        memory.save_session(path)?;
936        self.mark_flushed();
937        Ok(())
938    }
939
940    /// Force a flush regardless of triggers.
941    pub fn force_flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
942        if !self.dirty {
943            return Ok(()); // nothing to flush
944        }
945        self.flush(memory)
946    }
947
948    /// Whether there is unflushed data.
949    pub fn is_dirty(&self) -> bool {
950        self.dirty
951    }
952
953    /// Messages added since the last flush.
954    pub fn messages_since_flush(&self) -> usize {
955        self.messages_since_flush
956    }
957
958    /// Total number of flushes performed.
959    pub fn total_flushes(&self) -> usize {
960        self.total_flushes
961    }
962
963    /// Mark the flush as completed (reset counters).
964    fn mark_flushed(&mut self) {
965        self.dirty = false;
966        self.messages_since_flush = 0;
967        self.last_flush = Utc::now();
968        self.total_flushes += 1;
969    }
970}
971
972// ---------------------------------------------------------------------------
973// Knowledge Distiller — Cross-Session Learning
974// ---------------------------------------------------------------------------
975
/// A distilled behavioral rule generated from accumulated corrections and facts.
///
/// Rules are created by `KnowledgeDistiller::distill` and rendered into prompt
/// text by `KnowledgeDistiller::rules_for_prompt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BehavioralRule {
    /// Unique identifier.
    pub id: Uuid,
    /// Human-readable rule description (injected into system prompt).
    pub rule: String,
    /// Source entries (correction/fact IDs) that contributed to this rule.
    pub source_ids: Vec<Uuid>,
    /// How many source entries support this rule (higher = more confidence).
    /// Also used as the ranking key when the distiller trims to its rule limit.
    pub support_count: usize,
    /// When this rule was distilled.
    pub created_at: DateTime<Utc>,
}
990
/// Persistent knowledge store containing distilled behavioral rules.
///
/// Serialized to and from JSON by `KnowledgeStore::save` / `KnowledgeStore::load`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeStore {
    /// Distilled rules, in the order they are emitted into the prompt.
    pub rules: Vec<BehavioralRule>,
    /// Correction IDs already processed by the distiller.
    pub processed_correction_ids: Vec<Uuid>,
    /// Fact IDs already processed by the distiller.
    pub processed_fact_ids: Vec<Uuid>,
}
1000
1001impl KnowledgeStore {
1002    pub fn new() -> Self {
1003        Self::default()
1004    }
1005
1006    /// Load a knowledge store from a JSON file.
1007    pub fn load(path: &std::path::Path) -> Result<Self, MemoryError> {
1008        if !path.exists() {
1009            return Ok(Self::new());
1010        }
1011        let json = std::fs::read_to_string(path).map_err(|e| MemoryError::PersistenceError {
1012            message: format!("Failed to read knowledge store: {}", e),
1013        })?;
1014        serde_json::from_str(&json).map_err(|e| MemoryError::PersistenceError {
1015            message: format!("Failed to parse knowledge store: {}", e),
1016        })
1017    }
1018
1019    /// Save the knowledge store to a JSON file.
1020    pub fn save(&self, path: &std::path::Path) -> Result<(), MemoryError> {
1021        if let Some(parent) = path.parent() {
1022            std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
1023                message: format!("Failed to create knowledge directory: {}", e),
1024            })?;
1025        }
1026        let json =
1027            serde_json::to_string_pretty(self).map_err(|e| MemoryError::PersistenceError {
1028                message: format!("Failed to serialize knowledge store: {}", e),
1029            })?;
1030        std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
1031            message: format!("Failed to write knowledge store: {}", e),
1032        })
1033    }
1034}
1035
/// The `KnowledgeDistiller` processes accumulated corrections and facts from
/// `LongTermMemory` to generate compressed behavioral rules. These rules
/// are injected into the system prompt to influence future agent behavior,
/// creating a cross-session learning loop.
pub struct KnowledgeDistiller {
    /// Rules plus processed-ID bookkeeping; loaded from `store_path` when possible.
    store: KnowledgeStore,
    /// Upper bound on retained rules; `distill` truncates the rule list to this length.
    max_rules: usize,
    /// Minimum number of unprocessed corrections + facts required before
    /// `distill` does any work.
    min_entries: usize,
    /// Where `distill` persists the store; `None` disables persistence.
    store_path: Option<std::path::PathBuf>,
}
1046
1047impl KnowledgeDistiller {
1048    /// Create a new distiller from config. If `config` is None, creates a
1049    /// disabled distiller that returns no rules.
1050    pub fn new(config: Option<&crate::config::KnowledgeConfig>) -> Self {
1051        match config {
1052            Some(cfg) if cfg.enabled => {
1053                let store = cfg
1054                    .knowledge_path
1055                    .as_ref()
1056                    .and_then(|p| KnowledgeStore::load(p).ok())
1057                    .unwrap_or_default();
1058                Self {
1059                    store,
1060                    max_rules: cfg.max_rules,
1061                    min_entries: cfg.min_entries_for_distillation,
1062                    store_path: cfg.knowledge_path.clone(),
1063                }
1064            }
1065            _ => Self {
1066                store: KnowledgeStore::new(),
1067                max_rules: 0,
1068                min_entries: usize::MAX,
1069                store_path: None,
1070            },
1071        }
1072    }
1073
1074    /// Run the distillation process over long-term memory.
1075    ///
1076    /// Groups corrections by context patterns and generates compressed rules.
1077    /// Only processes entries not yet seen by the distiller.
1078    pub fn distill(&mut self, long_term: &LongTermMemory) {
1079        // Collect new (unprocessed) corrections
1080        let new_corrections: Vec<&Correction> = long_term
1081            .corrections
1082            .iter()
1083            .filter(|c| !self.store.processed_correction_ids.contains(&c.id))
1084            .collect();
1085
1086        // Collect new (unprocessed) facts
1087        let new_facts: Vec<&Fact> = long_term
1088            .facts
1089            .iter()
1090            .filter(|f| !self.store.processed_fact_ids.contains(&f.id))
1091            .collect();
1092
1093        let total_new = new_corrections.len() + new_facts.len();
1094        if total_new < self.min_entries {
1095            return; // Not enough new data to distill
1096        }
1097
1098        // --- Distill corrections into rules ---
1099        // Group corrections by common patterns in their context field.
1100        let mut context_groups: HashMap<String, Vec<&Correction>> = HashMap::new();
1101        for correction in &new_corrections {
1102            // Normalize the context to a group key (first 50 chars, lowercased)
1103            let key = correction
1104                .context
1105                .chars()
1106                .take(50)
1107                .collect::<String>()
1108                .to_lowercase();
1109            context_groups.entry(key).or_default().push(correction);
1110        }
1111
1112        // Generate a rule per group that has enough entries
1113        for group in context_groups.values() {
1114            if group.len() >= 2 {
1115                // Multiple corrections with similar context → a behavioral rule
1116                let corrected_patterns: Vec<&str> =
1117                    group.iter().map(|c| c.corrected.as_str()).collect();
1118                let rule_text = format!(
1119                    "Based on {} previous corrections: prefer {}",
1120                    group.len(),
1121                    corrected_patterns.join("; ")
1122                );
1123                let source_ids: Vec<Uuid> = group.iter().map(|c| c.id).collect();
1124                self.store.rules.push(BehavioralRule {
1125                    id: Uuid::new_v4(),
1126                    rule: rule_text,
1127                    source_ids,
1128                    support_count: group.len(),
1129                    created_at: Utc::now(),
1130                });
1131            } else {
1132                // Single correction → direct rule
1133                for c in group {
1134                    self.store.rules.push(BehavioralRule {
1135                        id: Uuid::new_v4(),
1136                        rule: format!("Instead of '{}', prefer '{}'", c.original, c.corrected),
1137                        source_ids: vec![c.id],
1138                        support_count: 1,
1139                        created_at: Utc::now(),
1140                    });
1141                }
1142            }
1143        }
1144
1145        // --- Distill preferences from facts ---
1146        // Facts tagged with "preference" or containing directive language get
1147        // turned into rules directly.
1148        for fact in &new_facts {
1149            let is_preference = fact.tags.iter().any(|t| t == "preference")
1150                || fact.content.starts_with("Prefer")
1151                || fact.content.starts_with("Always")
1152                || fact.content.starts_with("Never")
1153                || fact.content.starts_with("Don't")
1154                || fact.content.starts_with("Use ");
1155            if is_preference {
1156                self.store.rules.push(BehavioralRule {
1157                    id: Uuid::new_v4(),
1158                    rule: fact.content.clone(),
1159                    source_ids: vec![fact.id],
1160                    support_count: 1,
1161                    created_at: Utc::now(),
1162                });
1163            }
1164        }
1165
1166        // Mark all new entries as processed
1167        for c in &new_corrections {
1168            self.store.processed_correction_ids.push(c.id);
1169        }
1170        for f in &new_facts {
1171            self.store.processed_fact_ids.push(f.id);
1172        }
1173
1174        // Trim rules to max_rules, keeping highest support_count
1175        if self.store.rules.len() > self.max_rules {
1176            self.store
1177                .rules
1178                .sort_by(|a, b| b.support_count.cmp(&a.support_count));
1179            self.store.rules.truncate(self.max_rules);
1180        }
1181
1182        // Persist if a store path is configured
1183        if let Some(ref path) = self.store_path {
1184            let _ = self.store.save(path);
1185        }
1186    }
1187
1188    /// Get the current distilled rules formatted for system prompt injection.
1189    ///
1190    /// Returns an empty string if no rules exist.
1191    pub fn rules_for_prompt(&self) -> String {
1192        if self.store.rules.is_empty() {
1193            return String::new();
1194        }
1195        let mut prompt = String::from(
1196            "\n\n## Learned Behavioral Rules\n\
1197             The following rules were distilled from previous sessions. Follow them:\n",
1198        );
1199        for (i, rule) in self.store.rules.iter().enumerate() {
1200            prompt.push_str(&format!("{}. {}\n", i + 1, rule.rule));
1201        }
1202        prompt
1203    }
1204
1205    /// Get the number of distilled rules.
1206    pub fn rule_count(&self) -> usize {
1207        self.store.rules.len()
1208    }
1209
1210    /// Get the knowledge store reference (for diagnostics/REPL).
1211    pub fn store(&self) -> &KnowledgeStore {
1212        &self.store
1213    }
1214}
1215
1216#[cfg(test)]
1217mod tests {
1218    use super::*;
1219    use crate::types::{Content, Role};
1220
1221    #[test]
1222    fn test_working_memory_lifecycle() {
1223        let mut wm = WorkingMemory::new();
1224        assert!(wm.current_goal.is_none());
1225
1226        wm.set_goal("refactor auth module");
1227        assert_eq!(wm.current_goal.as_deref(), Some("refactor auth module"));
1228
1229        wm.add_sub_task("read current implementation");
1230        wm.add_sub_task("design new structure");
1231        assert_eq!(wm.sub_tasks.len(), 2);
1232
1233        wm.note("finding", "uses basic auth currently");
1234        assert_eq!(
1235            wm.scratchpad.get("finding").map(|s| s.as_str()),
1236            Some("uses basic auth currently")
1237        );
1238
1239        wm.add_active_file("src/auth/mod.rs");
1240        wm.add_active_file("src/auth/mod.rs"); // duplicate
1241        assert_eq!(wm.active_files.len(), 1);
1242
1243        wm.clear();
1244        assert!(wm.current_goal.is_none());
1245        assert!(wm.sub_tasks.is_empty());
1246    }
1247
1248    #[test]
1249    fn test_short_term_memory_basic() {
1250        let mut stm = ShortTermMemory::new(5);
1251        assert!(stm.is_empty());
1252        assert_eq!(stm.len(), 0);
1253
1254        stm.add(Message::user("hello"));
1255        stm.add(Message::assistant("hi there"));
1256        assert_eq!(stm.len(), 2);
1257        assert_eq!(stm.total_messages_seen(), 2);
1258
1259        let messages = stm.to_messages();
1260        assert_eq!(messages.len(), 2);
1261    }
1262
1263    #[test]
1264    fn test_short_term_memory_window() {
1265        let mut stm = ShortTermMemory::new(3);
1266
1267        for i in 0..6 {
1268            stm.add(Message::user(format!("message {}", i)));
1269        }
1270
1271        assert_eq!(stm.len(), 6);
1272        let messages = stm.to_messages();
1273        // Should only include last 3 messages (window size)
1274        assert_eq!(messages.len(), 3);
1275        assert_eq!(messages[0].content.as_text(), Some("message 3"));
1276        assert_eq!(messages[2].content.as_text(), Some("message 5"));
1277    }
1278
1279    #[test]
1280    fn test_short_term_memory_compression() {
1281        let mut stm = ShortTermMemory::new(3);
1282
1283        for i in 0..6 {
1284            stm.add(Message::user(format!("message {}", i)));
1285        }
1286
1287        assert!(stm.needs_compression());
1288
1289        let to_summarize = stm.messages_to_summarize();
1290        assert_eq!(to_summarize.len(), 3); // messages 0, 1, 2
1291
1292        let compressed = stm.compress("Summary of messages 0-2.".to_string());
1293        assert_eq!(compressed, 3);
1294        assert_eq!(stm.len(), 3); // only window remains
1295
1296        let messages = stm.to_messages();
1297        // First message should be the summary
1298        assert_eq!(messages.len(), 4); // 1 summary + 3 recent
1299        assert!(messages[0]
1300            .content
1301            .as_text()
1302            .unwrap()
1303            .contains("Summary of"));
1304        assert_eq!(messages[0].role, Role::System);
1305    }
1306
1307    #[test]
1308    fn test_short_term_memory_double_compression() {
1309        let mut stm = ShortTermMemory::new(2);
1310
1311        // First batch
1312        for i in 0..5 {
1313            stm.add(Message::user(format!("msg {}", i)));
1314        }
1315        stm.compress("First summary.".to_string());
1316        assert_eq!(stm.len(), 2);
1317
1318        // Second batch
1319        for i in 5..8 {
1320            stm.add(Message::user(format!("msg {}", i)));
1321        }
1322        stm.compress("Second summary.".to_string());
1323        assert_eq!(stm.len(), 2);
1324
1325        // Summary should merge
1326        let summary = stm.summary().unwrap();
1327        assert!(summary.contains("First summary."));
1328        assert!(summary.contains("Second summary."));
1329    }
1330
1331    #[test]
1332    fn test_short_term_memory_clear() {
1333        let mut stm = ShortTermMemory::new(5);
1334        stm.add(Message::user("test"));
1335        stm.compress("summary".to_string());
1336
1337        stm.clear();
1338        assert!(stm.is_empty());
1339        assert!(stm.summary().is_none());
1340        assert_eq!(stm.total_messages_seen(), 0);
1341    }
1342
1343    #[test]
1344    fn test_fact_creation() {
1345        let fact = Fact::new("Project uses JWT auth", "code analysis")
1346            .with_tags(vec!["auth".to_string(), "jwt".to_string()]);
1347        assert_eq!(fact.content, "Project uses JWT auth");
1348        assert_eq!(fact.source, "code analysis");
1349        assert_eq!(fact.tags.len(), 2);
1350    }
1351
1352    #[test]
1353    fn test_long_term_memory() {
1354        let mut ltm = LongTermMemory::new();
1355
1356        ltm.add_fact(Fact::new("Uses Rust 2021 edition", "Cargo.toml"));
1357        ltm.set_preference("code_style", "rustfmt defaults");
1358        ltm.add_correction(
1359            "wrong import".to_string(),
1360            "correct import".to_string(),
1361            "editing main.rs".to_string(),
1362        );
1363
1364        assert_eq!(ltm.facts.len(), 1);
1365        assert_eq!(ltm.get_preference("code_style"), Some("rustfmt defaults"));
1366        assert_eq!(ltm.corrections.len(), 1);
1367    }
1368
1369    #[test]
1370    fn test_long_term_memory_search() {
1371        let mut ltm = LongTermMemory::new();
1372        ltm.add_fact(Fact::new("Project uses JWT authentication", "analysis"));
1373        ltm.add_fact(
1374            Fact::new("Database is PostgreSQL", "config").with_tags(vec!["database".to_string()]),
1375        );
1376        ltm.add_fact(Fact::new("Frontend uses React", "package.json"));
1377
1378        let results = ltm.search_facts("JWT");
1379        assert_eq!(results.len(), 1);
1380        assert!(results[0].content.contains("JWT"));
1381
1382        let results = ltm.search_facts("database");
1383        assert_eq!(results.len(), 1);
1384
1385        let results = ltm.search_facts("nonexistent");
1386        assert!(results.is_empty());
1387    }
1388
1389    #[test]
1390    fn test_memory_system() {
1391        let mut mem = MemorySystem::new(5);
1392
1393        mem.start_new_task("fix bug #42");
1394        assert_eq!(mem.working.current_goal.as_deref(), Some("fix bug #42"));
1395
1396        mem.add_message(Message::user("fix the null pointer bug"));
1397        mem.add_message(Message::assistant("I'll look into that."));
1398
1399        let ctx = mem.context_messages();
1400        assert_eq!(ctx.len(), 2);
1401
1402        mem.clear_session();
1403        assert!(mem.short_term.is_empty());
1404        assert!(mem.working.current_goal.is_none());
1405    }
1406
1407    #[test]
1408    fn test_compression_no_op_when_within_window() {
1409        let mut stm = ShortTermMemory::new(10);
1410        stm.add(Message::user("hello"));
1411        stm.add(Message::assistant("hi"));
1412
1413        assert!(!stm.needs_compression());
1414        assert!(stm.messages_to_summarize().is_empty());
1415
1416        let compressed = stm.compress("should not matter".to_string());
1417        assert_eq!(compressed, 0);
1418    }
1419
1420    #[test]
1421    fn test_memory_system_new_task_preserves_long_term() {
1422        let mut mem = MemorySystem::new(5);
1423        mem.long_term.add_fact(Fact::new("important fact", "test"));
1424        mem.add_message(Message::user("task 1"));
1425
1426        mem.start_new_task("task 2");
1427        assert_eq!(mem.working.current_goal.as_deref(), Some("task 2"));
1428        assert_eq!(mem.long_term.facts.len(), 1); // preserved
1429    }
1430
1431    // --- Session persistence tests ---
1432
1433    #[test]
1434    fn test_session_save_load_roundtrip() {
1435        let dir = tempfile::tempdir().unwrap();
1436        let session_path = dir.path().join("session.json");
1437
1438        // Build a memory system with data
1439        let mut mem = MemorySystem::new(10);
1440        mem.start_new_task("fix bug #42");
1441        mem.add_message(Message::user("fix the bug"));
1442        mem.add_message(Message::assistant("Looking into it."));
1443        mem.long_term.add_fact(Fact::new("Uses Rust", "analysis"));
1444        mem.long_term.set_preference("style", "concise");
1445
1446        // Save
1447        mem.save_session(&session_path).unwrap();
1448        assert!(session_path.exists());
1449
1450        // Load
1451        let loaded = MemorySystem::load_session(&session_path).unwrap();
1452        assert_eq!(loaded.working.current_goal.as_deref(), Some("fix bug #42"));
1453        assert_eq!(loaded.short_term.len(), 2);
1454        assert_eq!(loaded.long_term.facts.len(), 1);
1455        assert_eq!(loaded.long_term.get_preference("style"), Some("concise"));
1456
1457        // Verify message content
1458        let messages = loaded.context_messages();
1459        assert_eq!(messages.len(), 2);
1460        assert_eq!(messages[0].content.as_text(), Some("fix the bug"));
1461        assert_eq!(messages[1].content.as_text(), Some("Looking into it."));
1462    }
1463
1464    #[test]
1465    fn test_session_load_missing_file() {
1466        let result = MemorySystem::load_session(Path::new("/nonexistent/session.json"));
1467        assert!(result.is_err());
1468    }
1469
1470    #[test]
1471    fn test_session_load_corrupt_json() {
1472        let dir = tempfile::tempdir().unwrap();
1473        let path = dir.path().join("bad.json");
1474        std::fs::write(&path, "not valid json").unwrap();
1475
1476        let result = MemorySystem::load_session(&path);
1477        assert!(result.is_err());
1478    }
1479
1480    #[test]
1481    fn test_session_save_creates_directories() {
1482        let dir = tempfile::tempdir().unwrap();
1483        let session_path = dir.path().join("nested").join("dir").join("session.json");
1484
1485        let mem = MemorySystem::new(5);
1486        mem.save_session(&session_path).unwrap();
1487        assert!(session_path.exists());
1488    }
1489
1490    #[test]
1491    fn test_session_metadata() {
1492        let meta = SessionMetadata::new();
1493        assert!(meta.task_summary.is_none());
1494        assert!(meta.created_at <= Utc::now());
1495
1496        let default_meta = SessionMetadata::default();
1497        assert!(default_meta.task_summary.is_none());
1498    }
1499
1500    #[test]
1501    fn test_short_term_window_size() {
1502        let stm = ShortTermMemory::new(7);
1503        assert_eq!(stm.window_size(), 7);
1504    }
1505
1506    // --- Pinning tests ---
1507
1508    #[test]
1509    fn test_pin_message() {
1510        let mut stm = ShortTermMemory::new(5);
1511        stm.add(Message::user("msg 0"));
1512        stm.add(Message::user("msg 1"));
1513        stm.add(Message::user("msg 2"));
1514
1515        assert!(stm.pin(1));
1516        assert!(stm.is_pinned(1));
1517        assert!(!stm.is_pinned(0));
1518        assert_eq!(stm.pinned_count(), 1);
1519    }
1520
1521    #[test]
1522    fn test_pin_out_of_bounds() {
1523        let mut stm = ShortTermMemory::new(5);
1524        stm.add(Message::user("msg 0"));
1525        assert!(!stm.pin(5)); // out of bounds
1526    }
1527
1528    #[test]
1529    fn test_unpin_message() {
1530        let mut stm = ShortTermMemory::new(5);
1531        stm.add(Message::user("msg 0"));
1532        stm.add(Message::user("msg 1"));
1533
1534        stm.pin(0);
1535        assert!(stm.is_pinned(0));
1536        assert!(stm.unpin(0));
1537        assert!(!stm.is_pinned(0));
1538    }
1539
1540    #[test]
1541    fn test_pinned_survives_compression() {
1542        let mut stm = ShortTermMemory::new(3);
1543        // Add enough messages to trigger compression
1544        stm.add(Message::user("old 0"));
1545        stm.add(Message::user("old 1"));
1546        stm.add(Message::user("important pinned"));
1547        stm.add(Message::user("msg 3"));
1548        stm.add(Message::user("msg 4"));
1549        stm.add(Message::user("msg 5"));
1550
1551        // Pin the third message (index 2)
1552        stm.pin(2);
1553        assert!(stm.needs_compression());
1554
1555        let removed = stm.compress("Summary of old messages".to_string());
1556        // The pinned message should be preserved
1557        assert!(removed < 3); // Not all 3 were removed because one was pinned
1558
1559        // The pinned message should still be in the window
1560        let msgs = stm.to_messages();
1561        let has_pinned = msgs
1562            .iter()
1563            .any(|m| matches!(&m.content, Content::Text { text } if text == "important pinned"));
1564        assert!(has_pinned, "Pinned message should survive compression");
1565    }
1566
1567    #[test]
1568    fn test_clear_resets_pins() {
1569        let mut stm = ShortTermMemory::new(5);
1570        stm.add(Message::user("msg 0"));
1571        stm.pin(0);
1572        assert_eq!(stm.pinned_count(), 1);
1573
1574        stm.clear();
1575        assert_eq!(stm.pinned_count(), 0);
1576    }
1577
1578    // --- Context Breakdown tests ---
1579
1580    #[test]
1581    fn test_context_breakdown() {
1582        let mut memory = MemorySystem::new(10);
1583        memory.add_message(Message::user("hello world"));
1584        memory.add_message(Message::assistant("hi there!"));
1585
1586        let ctx = memory.context_breakdown(8000);
1587        assert!(ctx.message_tokens > 0);
1588        assert_eq!(ctx.message_count, 2);
1589        assert_eq!(ctx.context_window, 8000);
1590        assert!(ctx.remaining_tokens > 0);
1591        assert!(!ctx.has_summary);
1592        assert_eq!(ctx.pinned_count, 0);
1593    }
1594
1595    #[test]
1596    fn test_context_breakdown_ratio() {
1597        let ctx = ContextBreakdown {
1598            total_tokens: 4000,
1599            context_window: 8000,
1600            ..Default::default()
1601        };
1602        assert!((ctx.usage_ratio() - 0.5).abs() < 0.01);
1603        assert!(!ctx.is_warning());
1604
1605        let ctx_high = ContextBreakdown {
1606            total_tokens: 7000,
1607            context_window: 8000,
1608            ..Default::default()
1609        };
1610        assert!(ctx_high.is_warning());
1611    }
1612
1613    #[test]
1614    fn test_pin_message_via_memory_system() {
1615        let mut memory = MemorySystem::new(10);
1616        memory.add_message(Message::user("msg 0"));
1617        memory.add_message(Message::user("msg 1"));
1618
1619        assert!(memory.pin_message(0));
1620        assert!(memory.short_term.is_pinned(0));
1621    }
1622
1623    // --- Memory Flusher tests ---
1624
1625    #[test]
1626    fn test_flusher_default_config() {
1627        let config = FlushConfig::default();
1628        assert!(!config.enabled);
1629        assert_eq!(config.interval_secs, 300);
1630        assert_eq!(config.flush_on_message_count, 50);
1631        assert!(config.flush_path.is_none());
1632    }
1633
1634    #[test]
1635    fn test_flusher_not_dirty_by_default() {
1636        let flusher = MemoryFlusher::new(FlushConfig::default());
1637        assert!(!flusher.is_dirty());
1638        assert_eq!(flusher.messages_since_flush(), 0);
1639        assert_eq!(flusher.total_flushes(), 0);
1640    }
1641
1642    #[test]
1643    fn test_flusher_marks_dirty_on_message() {
1644        let mut flusher = MemoryFlusher::new(FlushConfig::default());
1645        flusher.on_message_added();
1646        assert!(flusher.is_dirty());
1647        assert_eq!(flusher.messages_since_flush(), 1);
1648    }
1649
1650    #[test]
1651    fn test_flusher_disabled_never_triggers() {
1652        let mut flusher = MemoryFlusher::new(FlushConfig {
1653            enabled: false,
1654            ..FlushConfig::default()
1655        });
1656        for _ in 0..100 {
1657            flusher.on_message_added();
1658        }
1659        assert!(!flusher.should_flush());
1660    }
1661
1662    #[test]
1663    fn test_flusher_message_count_trigger() {
1664        let mut flusher = MemoryFlusher::new(FlushConfig {
1665            enabled: true,
1666            flush_on_message_count: 5,
1667            interval_secs: 0,
1668            flush_path: None,
1669        });
1670
1671        for _ in 0..4 {
1672            flusher.on_message_added();
1673        }
1674        assert!(!flusher.should_flush());
1675
1676        flusher.on_message_added(); // 5th message
1677        assert!(flusher.should_flush());
1678    }
1679
1680    #[test]
1681    fn test_flusher_not_dirty_no_trigger() {
1682        let flusher = MemoryFlusher::new(FlushConfig {
1683            enabled: true,
1684            flush_on_message_count: 1,
1685            interval_secs: 0,
1686            flush_path: None,
1687        });
1688        // Not dirty, so should_flush is false even though threshold is 1
1689        assert!(!flusher.should_flush());
1690    }
1691
1692    #[test]
1693    fn test_flusher_flush_resets_state() {
1694        let dir = tempfile::tempdir().unwrap();
1695        let flush_path = dir.path().join("flush.json");
1696
1697        let mut flusher = MemoryFlusher::new(FlushConfig {
1698            enabled: true,
1699            flush_on_message_count: 2,
1700            interval_secs: 0,
1701            flush_path: Some(flush_path.clone()),
1702        });
1703
1704        let mut mem = MemorySystem::new(10);
1705        mem.add_message(Message::user("test"));
1706
1707        flusher.on_message_added();
1708        flusher.on_message_added();
1709        assert!(flusher.should_flush());
1710
1711        flusher.flush(&mem).unwrap();
1712        assert!(!flusher.is_dirty());
1713        assert_eq!(flusher.messages_since_flush(), 0);
1714        assert_eq!(flusher.total_flushes(), 1);
1715        assert!(flush_path.exists());
1716    }
1717
1718    #[test]
1719    fn test_flusher_force_flush() {
1720        let dir = tempfile::tempdir().unwrap();
1721        let flush_path = dir.path().join("force.json");
1722
1723        let mut flusher = MemoryFlusher::new(FlushConfig {
1724            enabled: true,
1725            flush_on_message_count: 100,
1726            interval_secs: 0,
1727            flush_path: Some(flush_path.clone()),
1728        });
1729
1730        let mem = MemorySystem::new(10);
1731
1732        // Not dirty - force_flush is a no-op
1733        flusher.force_flush(&mem).unwrap();
1734        assert_eq!(flusher.total_flushes(), 0);
1735
1736        // Make dirty, then force
1737        flusher.on_message_added();
1738        flusher.force_flush(&mem).unwrap();
1739        assert_eq!(flusher.total_flushes(), 1);
1740        assert!(!flusher.is_dirty());
1741    }
1742
/// Flushing with no `flush_path` configured must surface an error.
#[test]
fn test_flusher_no_path_error() {
    let cfg = FlushConfig {
        enabled: true,
        flush_on_message_count: 1,
        interval_secs: 0,
        flush_path: None,
    };
    let mut flusher = MemoryFlusher::new(cfg);
    flusher.on_message_added();

    let memory = MemorySystem::new(10);
    assert!(flusher.flush(&memory).is_err());
}
1757
/// A `FlushConfig` must survive a JSON round-trip with every field intact.
///
/// The path is checked as well as the scalar fields: it is the only
/// non-trivial value in the config, and a dropped or mangled path would
/// otherwise go unnoticed.
#[test]
fn test_flush_config_serialization() {
    let expected_path = std::path::PathBuf::from("/tmp/flush.json");
    let config = FlushConfig {
        enabled: true,
        interval_secs: 120,
        flush_on_message_count: 25,
        flush_path: Some(expected_path.clone()),
    };

    let json = serde_json::to_string(&config).unwrap();
    let restored: FlushConfig = serde_json::from_str(&json).unwrap();

    assert!(restored.enabled);
    assert_eq!(restored.interval_secs, 120);
    assert_eq!(restored.flush_on_message_count, 25);
    assert_eq!(restored.flush_path, Some(expected_path));
}
1772
1773    // --- A4: HybridSearchEngine → MemorySystem integration tests ---
1774
/// A memory system built with `new` (no search config) should still answer
/// `search_facts_hybrid`, returning only the fact that mentions the query.
#[test]
fn test_memory_system_without_search_uses_keyword_fallback() {
    let mut memory = MemorySystem::new(10);
    memory.add_fact(Fact::new("Rust uses ownership for memory safety", "docs"));
    memory.add_fact(Fact::new("Python uses garbage collection", "docs"));

    let hits = memory.search_facts_hybrid("ownership");

    assert_eq!(hits.len(), 1);
    assert!(hits[0].content.contains("ownership"));
}
1785
/// With a search configuration attached, `search_facts_hybrid` must return
/// at least one hit that contains the query term.
#[test]
fn test_memory_system_with_search_engine() {
    let tmp = tempfile::tempdir().unwrap();
    let index_path = tmp.path().join("idx");
    let db_path = tmp.path().join("vec.db");
    let search_cfg = SearchConfig {
        index_path,
        db_path,
        vector_dimensions: 64,
        full_text_weight: 0.5,
        vector_weight: 0.5,
        max_results: 10,
    };
    let mut memory = MemorySystem::with_search(10, search_cfg).unwrap();

    memory.add_fact(Fact::new("Rust uses ownership model", "analysis"));
    memory.add_fact(Fact::new("Python garbage collector", "analysis"));

    // Either the hybrid engine or a keyword fallback may produce the hit;
    // the test only requires that the matching fact is found.
    let hits = memory.search_facts_hybrid("ownership");
    assert!(!hits.is_empty());
    assert!(hits.iter().any(|fact| fact.content.contains("ownership")));
}
1807
/// An empty query against a system without a search engine is expected to
/// return the stored fact rather than nothing.
#[test]
fn test_memory_system_search_empty_query() {
    let mut memory = MemorySystem::new(10);
    memory.add_fact(Fact::new("some fact", "source"));

    let hits = memory.search_facts_hybrid("");

    // The empty string is a substring of every string, so a substring-based
    // keyword match (presumably what the fallback path does) hits every fact.
    assert!(!hits.is_empty());
}
1818
/// Searching a memory system that holds no facts yields no results.
#[test]
fn test_memory_system_search_no_facts() {
    let memory = MemorySystem::new(10);
    assert!(memory.search_facts_hybrid("anything").is_empty());
}
1825
/// Adding facts to a search-enabled memory system must still record every
/// one of them in long-term memory.
#[test]
fn test_add_fact_indexes_into_search_engine() {
    let tmp = tempfile::tempdir().unwrap();
    let index_path = tmp.path().join("idx");
    let db_path = tmp.path().join("vec.db");
    let search_cfg = SearchConfig {
        index_path,
        db_path,
        vector_dimensions: 64,
        full_text_weight: 0.5,
        vector_weight: 0.5,
        max_results: 10,
    };
    let mut memory = MemorySystem::with_search(10, search_cfg).unwrap();

    // Insert five distinct facts.
    (0..5).for_each(|n| {
        memory.add_fact(Fact::new(format!("fact number {}", n), "test"));
    });

    // Every inserted fact must be present in the long-term store.
    assert_eq!(memory.long_term.facts.len(), 5);
}
1847
1848    // --- A5: MemoryFlusher → MemorySystem integration tests ---
1849
/// A memory system that has just had a flusher attached reports it as clean.
#[test]
fn test_memory_system_with_flusher() {
    let flush_cfg = FlushConfig {
        enabled: true,
        flush_on_message_count: 5,
        interval_secs: 0,
        flush_path: None,
    };

    let memory = MemorySystem::new(10).with_flusher(flush_cfg);

    assert!(!memory.flusher_is_dirty());
}
1861
/// Adding a message through the memory system must mark the attached
/// flusher as dirty.
#[test]
fn test_memory_system_add_message_notifies_flusher() {
    let flush_cfg = FlushConfig {
        enabled: true,
        flush_on_message_count: 5,
        interval_secs: 0,
        flush_path: None,
    };
    let mut memory = MemorySystem::new(10).with_flusher(flush_cfg);

    memory.add_message(Message::user("hello"));

    assert!(memory.flusher_is_dirty());
}
1875
/// With no flusher attached, `check_auto_flush` must succeed and report
/// that nothing was flushed.
#[test]
fn test_memory_system_check_auto_flush_no_flusher() {
    let mut memory = MemorySystem::new(10);

    let flushed = memory.check_auto_flush().unwrap();

    assert!(!flushed);
}
1883
/// `check_auto_flush` must stay idle below the message threshold and flush
/// to disk once the threshold is reached.
#[test]
fn test_memory_system_check_auto_flush_triggers() {
    let tmp = tempfile::tempdir().unwrap();
    let target = tmp.path().join("auto_flush.json");

    let flush_cfg = FlushConfig {
        enabled: true,
        flush_on_message_count: 3,
        interval_secs: 0,
        flush_path: Some(target.clone()),
    };
    let mut memory = MemorySystem::new(10).with_flusher(flush_cfg);

    // Two messages: one short of the threshold, so nothing is written yet.
    for text in ["msg 1", "msg 2"] {
        memory.add_message(Message::user(text));
    }
    assert!(!memory.check_auto_flush().unwrap());
    assert!(!target.exists());

    // The third message reaches the threshold and the flush happens.
    memory.add_message(Message::user("msg 3"));
    assert!(memory.check_auto_flush().unwrap());
    assert!(target.exists());

    // A completed flush leaves the flusher clean again.
    assert!(!memory.flusher_is_dirty());
}
1911
/// `force_flush` on the memory system must write the file and clear the
/// dirty flag even when the message threshold is far from reached.
#[test]
fn test_memory_system_force_flush() {
    let tmp = tempfile::tempdir().unwrap();
    let target = tmp.path().join("force_flush.json");

    let flush_cfg = FlushConfig {
        enabled: true,
        // Deliberately high so that only a forced flush can trigger.
        flush_on_message_count: 100,
        interval_secs: 0,
        flush_path: Some(target.clone()),
    };
    let mut memory = MemorySystem::new(10).with_flusher(flush_cfg);

    memory.add_message(Message::user("important data"));
    assert!(memory.flusher_is_dirty());

    memory.force_flush().unwrap();

    assert!(!memory.flusher_is_dirty());
    assert!(target.exists());
}
1932
/// Forcing a flush without any flusher attached must succeed rather than
/// return an error.
#[test]
fn test_memory_system_force_flush_no_flusher() {
    let mut memory = MemorySystem::new(10);
    memory.force_flush().unwrap();
}
1939
1940    // --- Knowledge Distiller Tests ---
1941
/// A distiller constructed without a config holds no rules and renders an
/// empty prompt section.
#[test]
fn test_knowledge_distiller_disabled() {
    let disabled = KnowledgeDistiller::new(None);

    assert_eq!(disabled.rule_count(), 0);
    assert!(disabled.rules_for_prompt().is_empty());
}
1948
/// Distilling an empty long-term memory produces no rules.
#[test]
fn test_knowledge_distiller_no_data() {
    let knowledge_cfg = crate::config::KnowledgeConfig::default();
    let mut distiller = KnowledgeDistiller::new(Some(&knowledge_cfg));
    let empty_memory = LongTermMemory::new();

    distiller.distill(&empty_memory);

    assert_eq!(distiller.rule_count(), 0);
}
1957
/// With fewer corrections than `min_entries_for_distillation`, distillation
/// must not produce any rule.
#[test]
fn test_knowledge_distiller_corrections_below_threshold() {
    let knowledge_cfg = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 5,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&knowledge_cfg));

    let mut long_term = LongTermMemory::new();
    long_term.add_correction(
        String::from("unwrap()"),
        String::from("? operator"),
        String::from("error handling"),
    );
    long_term.add_correction(
        String::from("println!"),
        String::from("tracing::info!"),
        String::from("logging"),
    );

    distiller.distill(&long_term);

    // Two entries recorded against a minimum of five: nothing is distilled.
    assert_eq!(distiller.rule_count(), 0);
}
1978
/// Two corrections recorded under different contexts are expected to yield
/// two separate rules, both of which show up in the rendered prompt.
#[test]
fn test_knowledge_distiller_single_corrections() {
    let knowledge_cfg = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 2,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&knowledge_cfg));

    // The contexts differ ("error handling" vs "logging"), so the two
    // corrections should not be merged into one rule.
    let mut long_term = LongTermMemory::new();
    long_term.add_correction(
        String::from("unwrap()"),
        String::from("? operator"),
        String::from("error handling"),
    );
    long_term.add_correction(
        String::from("println!"),
        String::from("tracing::info!"),
        String::from("logging"),
    );

    distiller.distill(&long_term);
    assert_eq!(distiller.rule_count(), 2);

    let rendered = distiller.rules_for_prompt();
    assert!(rendered.contains("Learned Behavioral Rules"));
    assert!(rendered.contains("? operator"));
    assert!(rendered.contains("tracing::info!"));
}
2004
/// Corrections that share the same context are expected to collapse into a
/// single rule that mentions how many corrections back it.
#[test]
fn test_knowledge_distiller_grouped_corrections() {
    let knowledge_cfg = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 2,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&knowledge_cfg));

    // Both corrections carry an identical context string.
    let mut long_term = LongTermMemory::new();
    long_term.add_correction(
        String::from("unwrap()"),
        String::from("? operator"),
        String::from("error handling in Rust code"),
    );
    long_term.add_correction(
        String::from("expect()"),
        String::from("map_err()?"),
        String::from("error handling in Rust code"),
    );

    distiller.distill(&long_term);
    assert_eq!(distiller.rule_count(), 1);

    let rendered = distiller.rules_for_prompt();
    assert!(rendered.contains("2 previous corrections"));
}
2031
/// Of two stored facts, only the preference-style one ("Prefer ...") is
/// expected to be turned into a rule.
#[test]
fn test_knowledge_distiller_preference_facts() {
    let knowledge_cfg = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 1,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&knowledge_cfg));

    let mut long_term = LongTermMemory::new();
    long_term.add_fact(Fact::new("Prefer async/await over threads", "user"));
    long_term.add_fact(Fact::new("Project uses PostgreSQL", "session"));

    distiller.distill(&long_term);

    // The PostgreSQL fact is not expected to produce a rule of its own.
    assert_eq!(distiller.rule_count(), 1);
    let rendered = distiller.rules_for_prompt();
    assert!(rendered.contains("async/await"));
}
2050
/// The number of distilled rules must never exceed the configured
/// `max_rules`, however many corrections are available.
#[test]
fn test_knowledge_distiller_max_rules_truncation() {
    let knowledge_cfg = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 1,
        max_rules: 3,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&knowledge_cfg));

    // Ten corrections, each with its own original/corrected/context strings.
    let mut long_term = LongTermMemory::new();
    for n in 0..10 {
        let original = format!("old{}", n);
        let corrected = format!("new{}", n);
        let context = format!("context{}", n);
        long_term.add_correction(original, corrected, context);
    }

    distiller.distill(&long_term);

    assert!(distiller.rule_count() <= 3);
}
2072
/// Running `distill` twice over unchanged data must leave the rule count
/// where the first run put it.
#[test]
fn test_knowledge_distiller_idempotent() {
    let knowledge_cfg = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 1,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&knowledge_cfg));

    let mut long_term = LongTermMemory::new();
    long_term.add_correction(String::from("old"), String::from("new"), String::from("ctx"));

    distiller.distill(&long_term);
    let baseline = distiller.rule_count();

    // Second pass over the very same memory: no duplicates may appear.
    distiller.distill(&long_term);
    assert_eq!(distiller.rule_count(), baseline);
}
2091
/// A store saved to disk and loaded back must contain the same rule text
/// and support count.
#[test]
fn test_knowledge_store_save_load_roundtrip() {
    let tmp = tempfile::tempdir().unwrap();
    let file = tmp.path().join("knowledge.json");

    let mut store = KnowledgeStore::new();
    let rule = BehavioralRule {
        id: Uuid::new_v4(),
        rule: String::from("Prefer ? over unwrap"),
        source_ids: vec![Uuid::new_v4()],
        support_count: 3,
        created_at: Utc::now(),
    };
    store.rules.push(rule);

    store.save(&file).unwrap();
    let reloaded = KnowledgeStore::load(&file).unwrap();

    assert_eq!(reloaded.rules.len(), 1);
    let restored = &reloaded.rules[0];
    assert_eq!(restored.rule, "Prefer ? over unwrap");
    assert_eq!(restored.support_count, 3);
}
2112
/// Loading from a path that does not exist must succeed and yield a store
/// with no rules.
#[test]
fn test_knowledge_store_load_nonexistent() {
    let missing = std::path::Path::new("/nonexistent/knowledge.json");

    let store = KnowledgeStore::load(missing).unwrap();

    assert!(store.rules.is_empty());
}
2119
/// `unpin` with an index past the end must return `false` and must not
/// disturb pins that already exist.
#[test]
fn test_unpin_out_of_bounds_returns_false() {
    let mut short_term = ShortTermMemory::new(100);
    short_term.add(Message::user("hello"));
    short_term.add(Message::assistant("hi"));
    short_term.add(Message::user("world"));

    // Establish one valid pin to check against later.
    assert!(short_term.pin(1));
    assert!(short_term.is_pinned(1));

    // Far out of range first, then the first invalid index (len == 3).
    for idx in [999, 3] {
        assert!(!short_term.unpin(idx));
    }

    // The earlier pin is untouched by the failed unpins.
    assert!(short_term.is_pinned(1));
}
2138
/// Exercises `unpin` right at the edge of the valid index range.
#[test]
fn test_unpin_at_exact_boundary() {
    let mut short_term = ShortTermMemory::new(100);
    short_term.add(Message::user("msg0"));
    short_term.add(Message::user("msg1"));

    // Index 2 equals the message count, so it is out of range.
    assert!(!short_term.unpin(2));

    // Index 1 is the last valid index, but nothing is pinned there yet,
    // so unpin still reports false.
    assert!(!short_term.unpin(1));

    // Once pinned, the same index can be unpinned and ends up unpinned.
    assert!(short_term.pin(1));
    assert!(short_term.unpin(1));
    assert!(!short_term.is_pinned(1));
}
2156
2157    // --- Tool chain pair preservation tests ---
2158
/// Pinning a tool_result must keep both it and the tool_call it answers
/// through compression, while at least one other message is removed.
#[test]
fn test_compress_preserves_tool_chain_pairs() {
    let mut short_term = ShortTermMemory::new(3);

    // Six messages against a window of three; the tool_call / tool_result
    // pair sits at indices 1 and 2.
    let call = Message::new(
        Role::Assistant,
        Content::tool_call("call_abc", "file_read", serde_json::json!({"path": "x.rs"})),
    );
    short_term.add(Message::user("read the file")); // idx 0
    short_term.add(call); // idx 1 — tool_call
    short_term.add(Message::tool_result("call_abc", "fn main() {}", false)); // idx 2 — tool_result
    short_term.add(Message::assistant("Here is the content.")); // idx 3
    short_term.add(Message::user("thanks")); // idx 4
    short_term.add(Message::assistant("You're welcome.")); // idx 5

    // Only the result half of the pair is pinned.
    short_term.pin(2);
    assert!(short_term.needs_compression());

    let removed = short_term.compress("Summary of earlier conversation".to_string());

    // Look for each half of the pair among the surviving messages.
    let survivors = short_term.to_messages();
    let has_tool_call = survivors.iter().any(|msg| match &msg.content {
        Content::ToolCall { name, .. } => name == "file_read",
        _ => false,
    });
    let has_tool_result = survivors.iter().any(|msg| match &msg.content {
        Content::ToolResult { call_id, .. } => call_id == "call_abc",
        _ => false,
    });

    assert!(
        has_tool_result,
        "Pinned tool_result should survive compression"
    );
    assert!(
        has_tool_call,
        "Paired tool_call should also survive compression"
    );

    // Something unpinned and unpaired (such as the user message at idx 0)
    // is expected to have been folded into the summary.
    assert!(removed >= 1, "At least one message should be compressed");
}
2210
/// Mirror of the tool_result case: pinning the tool_call must keep both
/// it and its matching tool_result through compression.
#[test]
fn test_compress_preserves_tool_call_paired_with_result() {
    let mut short_term = ShortTermMemory::new(3);

    let call = Message::new(
        Role::Assistant,
        Content::tool_call("call_xyz", "shell_exec", serde_json::json!({"cmd": "ls"})),
    );
    short_term.add(Message::user("do something")); // idx 0
    short_term.add(call); // idx 1 — tool_call
    short_term.add(Message::tool_result("call_xyz", "file1\nfile2", false)); // idx 2
    short_term.add(Message::assistant("Listed files.")); // idx 3
    short_term.add(Message::user("ok")); // idx 4
    short_term.add(Message::assistant("Done.")); // idx 5

    // This time only the call half of the pair is pinned.
    short_term.pin(1);
    assert!(short_term.needs_compression());

    short_term.compress("Summary".to_string());

    // Both halves are identified by the shared call id.
    let survivors = short_term.to_messages();
    let has_tool_call = survivors.iter().any(|msg| match &msg.content {
        Content::ToolCall { id, .. } => id == "call_xyz",
        _ => false,
    });
    let has_tool_result = survivors.iter().any(|msg| match &msg.content {
        Content::ToolResult { call_id, .. } => call_id == "call_xyz",
        _ => false,
    });

    assert!(has_tool_call, "Pinned tool_call should survive compression");
    assert!(
        has_tool_result,
        "Paired tool_result should also survive compression"
    );
}
2252}