Skip to main content

rustant_core/
memory.rs

1//! Three-tier memory system for the Rustant agent.
2//!
3//! - **Working Memory**: Current task state and scratch data (single task lifetime).
4//! - **Short-Term Memory**: Sliding window of recent conversation with summarization.
5//! - **Long-Term Memory**: Persistent facts and preferences across sessions.
6
7use crate::error::MemoryError;
8use crate::search::{HybridSearchEngine, SearchConfig};
9use crate::types::{Content, Message, Role};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::path::Path;
14use uuid::Uuid;
15
/// Working memory for the currently executing task.
///
/// Holds the goal, sub-tasks, scratch notes and active-file list of a single
/// task. `WorkingMemory::clear` resets every field to its default value.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkingMemory {
    /// Goal of the task in progress, if one has been set.
    pub current_goal: Option<String>,
    /// Sub-tasks, in the order they were added.
    pub sub_tasks: Vec<String>,
    /// Free-form key/value notes; writing an existing key overwrites its value.
    pub scratchpad: HashMap<String, String>,
    /// Paths of files in use; `add_active_file` keeps this list duplicate-free.
    pub active_files: Vec<String>,
}
24
25impl WorkingMemory {
26    pub fn new() -> Self {
27        Self::default()
28    }
29
30    pub fn set_goal(&mut self, goal: impl Into<String>) {
31        self.current_goal = Some(goal.into());
32    }
33
34    pub fn add_sub_task(&mut self, task: impl Into<String>) {
35        self.sub_tasks.push(task.into());
36    }
37
38    pub fn note(&mut self, key: impl Into<String>, value: impl Into<String>) {
39        self.scratchpad.insert(key.into(), value.into());
40    }
41
42    pub fn add_active_file(&mut self, path: impl Into<String>) {
43        let path = path.into();
44        if !self.active_files.contains(&path) {
45            self.active_files.push(path);
46        }
47    }
48
49    pub fn clear(&mut self) {
50        *self = Self::default();
51    }
52}
53
/// Short-term memory: sliding window of recent messages with summarization support.
///
/// Messages accumulate in `messages`; `compress` replaces the ones that fall
/// outside the window with a textual summary while keeping pinned messages.
#[derive(Debug, Clone)]
pub struct ShortTermMemory {
    /// Messages currently held, oldest first.
    messages: VecDeque<Message>,
    /// Number of most-recent messages that make up the window.
    window_size: usize,
    /// Accumulated summary of messages removed by compression, if any.
    summarized_prefix: Option<String>,
    /// Count of messages passed to `add` since creation or the last `clear`.
    total_messages_seen: usize,
    /// Set of absolute message indices that are pinned (survive compression).
    /// An absolute index is `compressed_offset + position in messages`.
    pinned: std::collections::HashSet<usize>,
    /// Running count of messages removed by compression (for index mapping).
    compressed_offset: usize,
}
66
67impl ShortTermMemory {
68    pub fn new(window_size: usize) -> Self {
69        Self {
70            messages: VecDeque::new(),
71            window_size,
72            summarized_prefix: None,
73            total_messages_seen: 0,
74            pinned: std::collections::HashSet::new(),
75            compressed_offset: 0,
76        }
77    }
78
79    /// Add a message to short-term memory.
80    pub fn add(&mut self, message: Message) {
81        self.messages.push_back(message);
82        self.total_messages_seen += 1;
83    }
84
85    /// Get all messages that should be sent to the LLM, including any summary.
86    pub fn to_messages(&self) -> Vec<Message> {
87        let mut result = Vec::new();
88
89        // Include summary of older context if available
90        if let Some(ref summary) = self.summarized_prefix {
91            result.push(Message::system(format!(
92                "[Summary of earlier conversation]\n{}",
93                summary
94            )));
95        }
96
97        // Include recent messages within the window.
98        // Pinned messages that fall before the window start are still included.
99        let start = if self.messages.len() > self.window_size {
100            self.messages.len() - self.window_size
101        } else {
102            0
103        };
104
105        for (i, msg) in self.messages.iter().enumerate() {
106            if i >= start || self.is_pinned(i) {
107                result.push(msg.clone());
108            }
109        }
110
111        result
112    }
113
114    /// Check whether compression is needed based on message count.
115    pub fn needs_compression(&self) -> bool {
116        self.messages.len() >= self.window_size * 2
117    }
118
119    /// Compress older messages by replacing them with a summary.
120    /// Pinned messages are preserved and moved to the front of the window.
121    /// When a pinned message is a tool_result, its corresponding tool_call
122    /// message is also preserved (and vice versa) to maintain valid sequences.
123    /// Returns the number of messages that were compressed.
124    pub fn compress(&mut self, summary: String) -> usize {
125        if self.messages.len() <= self.window_size {
126            return 0;
127        }
128
129        let to_remove = self.messages.len() - self.window_size;
130
131        // First pass: find indices that are pinned
132        let mut preserve_indices: HashSet<usize> = HashSet::new();
133        for i in 0..to_remove {
134            let abs_idx = self.compressed_offset + i;
135            if self.pinned.contains(&abs_idx) {
136                preserve_indices.insert(i);
137            }
138        }
139
140        // Second pass: for each preserved message, also preserve its tool pair
141        let mut extra_preserves: Vec<usize> = Vec::new();
142        for &i in &preserve_indices {
143            if let Some(msg) = self.messages.get(i) {
144                match &msg.content {
145                    // If pinned message is a tool_result, find its tool_call
146                    Content::ToolResult { call_id, .. } => {
147                        if let Some(pair_idx) =
148                            Self::find_tool_call_for_result(call_id, &self.messages, to_remove)
149                            && !preserve_indices.contains(&pair_idx)
150                        {
151                            extra_preserves.push(pair_idx);
152                        }
153                    }
154                    // If pinned message is a tool_call, find its tool_result
155                    Content::ToolCall { id, .. } => {
156                        if let Some(pair_idx) =
157                            Self::find_tool_result_for_call(id, &self.messages, to_remove)
158                            && !preserve_indices.contains(&pair_idx)
159                        {
160                            extra_preserves.push(pair_idx);
161                        }
162                    }
163                    // MultiPart: check both tool_calls and tool_results
164                    Content::MultiPart { parts } => {
165                        for part in parts {
166                            match part {
167                                Content::ToolCall { id, .. } => {
168                                    if let Some(pair_idx) = Self::find_tool_result_for_call(
169                                        id,
170                                        &self.messages,
171                                        to_remove,
172                                    ) && !preserve_indices.contains(&pair_idx)
173                                    {
174                                        extra_preserves.push(pair_idx);
175                                    }
176                                }
177                                Content::ToolResult { call_id, .. } => {
178                                    if let Some(pair_idx) = Self::find_tool_call_for_result(
179                                        call_id,
180                                        &self.messages,
181                                        to_remove,
182                                    ) && !preserve_indices.contains(&pair_idx)
183                                    {
184                                        extra_preserves.push(pair_idx);
185                                    }
186                                }
187                                _ => {}
188                            }
189                        }
190                    }
191                    _ => {}
192                }
193            }
194        }
195        for idx in extra_preserves {
196            preserve_indices.insert(idx);
197        }
198
199        // Collect preserved messages in order, count removed
200        let mut preserved = Vec::new();
201        let mut removed_count = 0;
202
203        for i in 0..to_remove {
204            if preserve_indices.contains(&i) {
205                if let Some(msg) = self.messages.get(i) {
206                    preserved.push(msg.clone());
207                }
208            } else {
209                removed_count += 1;
210            }
211        }
212
213        // Collect absolute indices that are pinned and remain in the window
214        // (i.e. messages beyond to_remove that were pinned).
215        let mut surviving_pinned: Vec<usize> = Vec::new();
216        for i in to_remove..self.messages.len() {
217            let abs_idx = self.compressed_offset + i;
218            if self.pinned.contains(&abs_idx) {
219                surviving_pinned.push(i - to_remove);
220            }
221        }
222
223        // Remove the to_remove oldest messages
224        self.messages.drain(..to_remove);
225        self.compressed_offset += to_remove;
226
227        // Re-insert pinned messages at the front of the window using batch operation.
228        // This avoids O(p^2) sequential insert() calls by building a new VecDeque.
229        let preserved_count = preserved.len();
230        if !preserved.is_empty() {
231            let mut new_messages = VecDeque::with_capacity(preserved_count + self.messages.len());
232            for msg in preserved {
233                new_messages.push_back(msg);
234            }
235            new_messages.append(&mut self.messages);
236            self.messages = new_messages;
237        }
238
239        // Rebuild the pinned set with correct absolute indices.
240        // Preserved messages are at positions 0..preserved_count.
241        // Surviving pinned messages shifted by preserved_count.
242        let mut new_pinned = HashSet::new();
243        for i in 0..preserved_count {
244            new_pinned.insert(self.compressed_offset + i);
245        }
246        for pos in surviving_pinned {
247            new_pinned.insert(self.compressed_offset + preserved_count + pos);
248        }
249        self.pinned = new_pinned;
250
251        // Merge with existing summary
252        if let Some(ref existing) = self.summarized_prefix {
253            self.summarized_prefix = Some(format!("{}\n\n{}", existing, summary));
254        } else {
255            self.summarized_prefix = Some(summary);
256        }
257
258        removed_count
259    }
260
261    /// Find the index (within `range 0..limit`) of an assistant message containing
262    /// a tool_call with the given ID, paired with a tool_result's `call_id`.
263    fn find_tool_call_for_result(
264        call_id: &str,
265        messages: &VecDeque<Message>,
266        limit: usize,
267    ) -> Option<usize> {
268        messages
269            .iter()
270            .enumerate()
271            .take(limit)
272            .find(|(_, msg)| {
273                msg.role == Role::Assistant
274                    && Self::content_contains_tool_call_id(&msg.content, call_id)
275            })
276            .map(|(i, _)| i)
277    }
278
279    /// Find the index (within `range 0..limit`) of a tool/user message containing
280    /// a tool_result with the given `call_id`, paired with a tool_call's ID.
281    fn find_tool_result_for_call(
282        tool_call_id: &str,
283        messages: &VecDeque<Message>,
284        limit: usize,
285    ) -> Option<usize> {
286        messages
287            .iter()
288            .enumerate()
289            .take(limit)
290            .find(|(_, msg)| {
291                (msg.role == Role::Tool || msg.role == Role::User)
292                    && Self::content_contains_tool_result_id(&msg.content, tool_call_id)
293            })
294            .map(|(i, _)| i)
295    }
296
297    /// Check if a Content contains a tool_call with the given ID.
298    fn content_contains_tool_call_id(content: &Content, target_id: &str) -> bool {
299        match content {
300            Content::ToolCall { id, .. } => id == target_id,
301            Content::MultiPart { parts } => parts
302                .iter()
303                .any(|p| Self::content_contains_tool_call_id(p, target_id)),
304            _ => false,
305        }
306    }
307
308    /// Check if a Content contains a tool_result with the given call_id.
309    fn content_contains_tool_result_id(content: &Content, target_id: &str) -> bool {
310        match content {
311            Content::ToolResult { call_id, .. } => call_id == target_id,
312            Content::MultiPart { parts } => parts
313                .iter()
314                .any(|p| Self::content_contains_tool_result_id(p, target_id)),
315            _ => false,
316        }
317    }
318
319    /// Pin a message by its position in the current message list (0-based).
320    /// Pinned messages survive compression.
321    pub fn pin(&mut self, position: usize) -> bool {
322        if position >= self.messages.len() {
323            return false;
324        }
325        let abs_idx = self.compressed_offset + position;
326        self.pinned.insert(abs_idx);
327        true
328    }
329
330    /// Unpin a message by its position in the current message list.
331    pub fn unpin(&mut self, position: usize) -> bool {
332        if position >= self.messages.len() {
333            return false;
334        }
335        let abs_idx = self.compressed_offset + position;
336        self.pinned.remove(&abs_idx)
337    }
338
339    /// Check if a message at the given position is pinned.
340    pub fn is_pinned(&self, position: usize) -> bool {
341        let abs_idx = self.compressed_offset + position;
342        self.pinned.contains(&abs_idx)
343    }
344
345    /// Number of currently pinned messages.
346    pub fn pinned_count(&self) -> usize {
347        // Count pinned messages that are still in the current window
348        (0..self.messages.len())
349            .filter(|&i| self.is_pinned(i))
350            .count()
351    }
352
353    /// Get messages that should be summarized (older than window).
354    pub fn messages_to_summarize(&self) -> Vec<&Message> {
355        if self.messages.len() <= self.window_size {
356            return Vec::new();
357        }
358        let to_summarize = self.messages.len() - self.window_size;
359        self.messages.iter().take(to_summarize).collect()
360    }
361
    /// Get the total number of messages currently held.
    ///
    /// This counts every stored message, including any that are older than
    /// the window but have not been compressed away yet.
    pub fn len(&self) -> usize {
        self.messages.len()
    }
366
    /// Check if memory is empty, i.e. no messages are currently held.
    ///
    /// A summary from earlier compression does not count as a message.
    pub fn is_empty(&self) -> bool {
        self.messages.is_empty()
    }
371
    /// Get total messages seen across the session.
    ///
    /// Incremented by every `add` call and reset to zero by `clear`;
    /// compression does not reduce it.
    pub fn total_messages_seen(&self) -> usize {
        self.total_messages_seen
    }
376
377    /// Clear all messages and summary.
378    pub fn clear(&mut self) {
379        self.messages.clear();
380        self.summarized_prefix = None;
381        self.total_messages_seen = 0;
382        self.pinned.clear();
383        self.compressed_offset = 0;
384    }
385
    /// Get a reference to all current messages, oldest first.
    ///
    /// Unlike `to_messages`, this neither adds the summary message nor
    /// filters by window or pin state.
    pub fn messages(&self) -> &VecDeque<Message> {
        &self.messages
    }
390
    /// Get the summary prefix, if any.
    ///
    /// Returns `None` until `compress` has stored a summary, and again after `clear`.
    pub fn summary(&self) -> Option<&str> {
        self.summarized_prefix.as_deref()
    }
395}
396
/// A fact extracted from conversation for long-term storage.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fact {
    /// Unique identifier; `Fact::new` assigns a random (v4) UUID.
    pub id: Uuid,
    /// Text of the fact; matched case-insensitively by keyword search.
    pub content: String,
    /// Caller-supplied description of where the fact came from.
    pub source: String,
    /// Creation time (UTC); `Fact::new` sets it to the current time.
    pub created_at: DateTime<Utc>,
    /// Optional labels; keyword search also matches against these.
    pub tags: Vec<String>,
}
406
407impl Fact {
408    pub fn new(content: impl Into<String>, source: impl Into<String>) -> Self {
409        Self {
410            id: Uuid::new_v4(),
411            content: content.into(),
412            source: source.into(),
413            created_at: Utc::now(),
414            tags: Vec::new(),
415        }
416    }
417
418    pub fn with_tags(mut self, tags: Vec<String>) -> Self {
419        self.tags = tags;
420        self
421    }
422}
423
/// Long-term memory persisted across sessions.
///
/// Serializable with serde; the two `max_*` limits fall back to their
/// defaults when absent from the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongTermMemory {
    /// Stored facts, oldest first.
    pub facts: Vec<Fact>,
    /// User preferences as key/value pairs.
    pub preferences: HashMap<String, String>,
    /// Recorded corrections, oldest first.
    pub corrections: Vec<Correction>,
    /// Maximum number of facts to retain. When exceeded, the oldest fact is evicted.
    #[serde(default = "LongTermMemory::default_max_facts")]
    pub max_facts: usize,
    /// Maximum number of corrections to retain. When exceeded, the oldest correction is evicted.
    #[serde(default = "LongTermMemory::default_max_corrections")]
    pub max_corrections: usize,
}
437
438impl Default for LongTermMemory {
439    fn default() -> Self {
440        Self {
441            facts: Vec::new(),
442            preferences: HashMap::new(),
443            corrections: Vec::new(),
444            max_facts: Self::default_max_facts(),
445            max_corrections: Self::default_max_corrections(),
446        }
447    }
448}
449
/// A correction recorded from user feedback.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Correction {
    /// Unique identifier; `LongTermMemory::add_correction` assigns a random (v4) UUID.
    pub id: Uuid,
    /// The text that was corrected.
    pub original: String,
    /// The replacement text.
    pub corrected: String,
    /// Caller-supplied context in which the correction was made.
    pub context: String,
    /// Time (UTC) at which the correction was recorded.
    pub timestamp: DateTime<Utc>,
}
459
460impl LongTermMemory {
461    pub fn new() -> Self {
462        Self::default()
463    }
464
465    fn default_max_facts() -> usize {
466        10_000
467    }
468
469    fn default_max_corrections() -> usize {
470        1_000
471    }
472
473    pub fn add_fact(&mut self, fact: Fact) {
474        if self.facts.len() >= self.max_facts {
475            self.facts.remove(0);
476        }
477        self.facts.push(fact);
478    }
479
480    pub fn set_preference(&mut self, key: impl Into<String>, value: impl Into<String>) {
481        self.preferences.insert(key.into(), value.into());
482    }
483
484    pub fn get_preference(&self, key: &str) -> Option<&str> {
485        self.preferences.get(key).map(|s| s.as_str())
486    }
487
488    pub fn add_correction(&mut self, original: String, corrected: String, context: String) {
489        if self.corrections.len() >= self.max_corrections {
490            self.corrections.remove(0);
491        }
492        self.corrections.push(Correction {
493            id: Uuid::new_v4(),
494            original,
495            corrected,
496            context,
497            timestamp: Utc::now(),
498        });
499    }
500
501    /// Search facts by keyword.
502    pub fn search_facts(&self, query: &str) -> Vec<&Fact> {
503        let query_lower = query.to_lowercase();
504        self.facts
505            .iter()
506            .filter(|f| {
507                f.content.to_lowercase().contains(&query_lower)
508                    || f.tags
509                        .iter()
510                        .any(|t| t.to_lowercase().contains(&query_lower))
511            })
512            .collect()
513    }
514}
515
/// The unified memory system combining all three tiers.
pub struct MemorySystem {
    /// Per-task scratch state.
    pub working: WorkingMemory,
    /// Sliding window over the conversation.
    pub short_term: ShortTermMemory,
    /// Facts, preferences and corrections kept across sessions.
    pub long_term: LongTermMemory,
    /// Optional hybrid search engine for fact retrieval.
    search_engine: Option<HybridSearchEngine>,
    /// Optional automatic flusher for periodic persistence.
    flusher: Option<MemoryFlusher>,
}
526
527impl MemorySystem {
528    pub fn new(window_size: usize) -> Self {
529        Self {
530            working: WorkingMemory::new(),
531            short_term: ShortTermMemory::new(window_size),
532            long_term: LongTermMemory::new(),
533            search_engine: None,
534            flusher: None,
535        }
536    }
537
538    /// Create a memory system with hybrid search enabled.
539    pub fn with_search(
540        window_size: usize,
541        search_config: SearchConfig,
542    ) -> Result<Self, crate::search::SearchError> {
543        let engine = HybridSearchEngine::open(search_config)?;
544        Ok(Self {
545            working: WorkingMemory::new(),
546            short_term: ShortTermMemory::new(window_size),
547            long_term: LongTermMemory::new(),
548            search_engine: Some(engine),
549            flusher: None,
550        })
551    }
552
553    /// Attach an automatic flusher to this memory system (builder pattern).
554    pub fn with_flusher(mut self, config: FlushConfig) -> Self {
555        self.flusher = Some(MemoryFlusher::new(config));
556        self
557    }
558
    /// Get all messages for the LLM context.
    ///
    /// Delegates to `ShortTermMemory::to_messages`.
    pub fn context_messages(&self) -> Vec<Message> {
        self.short_term.to_messages()
    }
563
564    /// Add a message to the conversation.
565    pub fn add_message(&mut self, message: Message) {
566        self.short_term.add(message);
567        // Notify flusher
568        if let Some(ref mut flusher) = self.flusher {
569            flusher.on_message_added();
570        }
571    }
572
573    /// Add a fact to long-term memory, also indexing it in the search engine.
574    pub fn add_fact(&mut self, fact: Fact) {
575        if let Some(ref mut engine) = self.search_engine {
576            let _ = engine.index_fact(&fact.id.to_string(), &fact.content);
577        }
578        self.long_term.add_fact(fact);
579    }
580
581    /// Search facts using the hybrid engine (falls back to keyword search).
582    pub fn search_facts_hybrid(&self, query: &str) -> Vec<&Fact> {
583        if let Some(ref engine) = self.search_engine
584            && let Ok(results) = engine.search(query)
585        {
586            let ids: Vec<String> = results.iter().map(|r| r.fact_id.clone()).collect();
587            let found: Vec<&Fact> = self
588                .long_term
589                .facts
590                .iter()
591                .filter(|f| ids.contains(&f.id.to_string()))
592                .collect();
593            if !found.is_empty() {
594                return found;
595            }
596        }
597        // Fallback to keyword search
598        self.long_term.search_facts(query)
599    }
600
601    /// Check if auto-flush should happen, and flush if needed.
602    ///
603    /// Uses the `Option::take()` pattern to avoid borrow conflicts.
604    pub fn check_auto_flush(&mut self) -> Result<bool, MemoryError> {
605        let mut flusher = match self.flusher.take() {
606            Some(f) => f,
607            None => return Ok(false),
608        };
609        let result = if flusher.should_flush() {
610            flusher.flush(self)?;
611            Ok(true)
612        } else {
613            Ok(false)
614        };
615        self.flusher = Some(flusher);
616        result
617    }
618
619    /// Force a flush regardless of triggers.
620    pub fn force_flush(&mut self) -> Result<(), MemoryError> {
621        let mut flusher = match self.flusher.take() {
622            Some(f) => f,
623            None => return Ok(()),
624        };
625        let result = flusher.force_flush(self);
626        self.flusher = Some(flusher);
627        result
628    }
629
630    /// Whether the flusher has unflushed data.
631    pub fn flusher_is_dirty(&self) -> bool {
632        self.flusher.as_ref().is_some_and(|f| f.is_dirty())
633    }
634
635    /// Reset working memory for a new task.
636    pub fn start_new_task(&mut self, goal: impl Into<String>) {
637        self.working.clear();
638        self.working.set_goal(goal);
639    }
640
    /// Clear everything except long-term memory.
    ///
    /// Resets working memory and short-term memory; the search engine and
    /// flusher are left as they are.
    pub fn clear_session(&mut self) {
        self.working.clear();
        self.short_term.clear();
    }
646
647    /// Get a breakdown of context usage for the UI.
648    pub fn context_breakdown(&self, context_window: usize) -> ContextBreakdown {
649        let summary_chars = self.short_term.summary().map(|s| s.len()).unwrap_or(0);
650        let message_chars: usize = self
651            .short_term
652            .messages()
653            .iter()
654            .map(|m| m.content_length())
655            .sum();
656        let total_chars = summary_chars + message_chars;
657
658        // Rough token estimate: ~4 chars per token
659        let summary_tokens = summary_chars / 4;
660        let message_tokens = message_chars / 4;
661        let total_tokens = total_chars / 4;
662        let remaining_tokens = context_window.saturating_sub(total_tokens);
663
664        ContextBreakdown {
665            summary_tokens,
666            message_tokens,
667            total_tokens,
668            context_window,
669            remaining_tokens,
670            message_count: self.short_term.len(),
671            total_messages_seen: self.short_term.total_messages_seen(),
672            pinned_count: self.short_term.pinned_count(),
673            has_summary: self.short_term.summary().is_some(),
674            facts_count: self.long_term.facts.len(),
675            rules_count: 0, // Populated separately if knowledge distiller is available
676        }
677    }
678
    /// Pin a message in short-term memory by position.
    ///
    /// Delegates to `ShortTermMemory::pin`; returns `false` when `position`
    /// is out of range.
    pub fn pin_message(&mut self, position: usize) -> bool {
        self.short_term.pin(position)
    }
683
    /// Unpin a message in short-term memory by position.
    ///
    /// Delegates to `ShortTermMemory::unpin`; returns `true` only when the
    /// position is in range and the message was pinned.
    pub fn unpin_message(&mut self, position: usize) -> bool {
        self.short_term.unpin(position)
    }
688}
689
/// Snapshot of how the context window is being used, intended for the UI.
#[derive(Debug, Clone, Default)]
pub struct ContextBreakdown {
    /// Estimated token cost of the summarized prefix.
    pub summary_tokens: usize,
    /// Estimated token cost of the messages currently held.
    pub message_tokens: usize,
    /// Estimated tokens in use overall.
    pub total_tokens: usize,
    /// Size of the context window (taken from configuration).
    pub context_window: usize,
    /// Tokens still available in the window.
    pub remaining_tokens: usize,
    /// How many messages are currently held.
    pub message_count: usize,
    /// How many messages the session has seen in total.
    pub total_messages_seen: usize,
    /// How many of the held messages are pinned.
    pub pinned_count: usize,
    /// `true` when a summary prefix exists.
    pub has_summary: bool,
    /// How many facts long-term memory holds.
    pub facts_count: usize,
    /// How many behavioral rules are active.
    pub rules_count: usize,
}

impl ContextBreakdown {
    /// Fraction of the context window in use, clamped to `0.0..=1.0`.
    ///
    /// A zero-sized window reports `0.0` rather than dividing by zero.
    pub fn usage_ratio(&self) -> f32 {
        match self.context_window {
            0 => 0.0,
            window => (self.total_tokens as f32 / window as f32).clamp(0.0, 1.0),
        }
    }

    /// `true` once usage reaches the warning level of 70%.
    /// Aligned with agent's `ContextHealthEvent::Warning` threshold.
    pub fn is_warning(&self) -> bool {
        const WARNING_RATIO: f32 = 0.7;
        self.usage_ratio() >= WARNING_RATIO
    }
}
732
/// Metadata about a saved session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Unique identifier of the saved session.
    pub id: Uuid,
    /// Time (UTC) the metadata was created.
    pub created_at: DateTime<Utc>,
    /// Time (UTC) of the last update; equal to `created_at` for fresh metadata.
    pub updated_at: DateTime<Utc>,
    /// Optional short description of the task; `save_session` stores the
    /// current working-memory goal here.
    pub task_summary: Option<String>,
}
741
742impl SessionMetadata {
743    pub fn new() -> Self {
744        let now = Utc::now();
745        Self {
746            id: Uuid::new_v4(),
747            created_at: now,
748            updated_at: now,
749            task_summary: None,
750        }
751    }
752}
753
impl Default for SessionMetadata {
    /// Same as [`SessionMetadata::new`]: every call yields a new random ID
    /// and the current time, so two defaults are not equal to each other.
    fn default() -> Self {
        Self::new()
    }
}
759
/// A persistable session containing the memory state.
///
/// This is the on-disk (JSON) shape written by `MemorySystem::save_session`
/// and read by `MemorySystem::load_session`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Identification and timestamps of the saved session.
    pub metadata: SessionMetadata,
    /// Copy of the working memory at save time.
    pub working: WorkingMemory,
    /// Copy of the long-term memory at save time.
    pub long_term: LongTermMemory,
    /// Short-term messages held at save time, oldest first.
    pub messages: Vec<Message>,
    /// Short-term window size to restore on load.
    pub window_size: usize,
}
769
770impl MemorySystem {
771    /// Save the current memory state to a JSON file.
772    pub fn save_session(&self, path: &Path) -> Result<(), MemoryError> {
773        let session = Session {
774            metadata: SessionMetadata {
775                id: Uuid::new_v4(),
776                created_at: Utc::now(),
777                updated_at: Utc::now(),
778                task_summary: self.working.current_goal.clone(),
779            },
780            working: self.working.clone(),
781            long_term: self.long_term.clone(),
782            messages: self.short_term.messages().iter().cloned().collect(),
783            window_size: self.short_term.window_size(),
784        };
785
786        let json =
787            serde_json::to_string_pretty(&session).map_err(|e| MemoryError::PersistenceError {
788                message: format!("Failed to serialize session: {}", e),
789            })?;
790
791        if let Some(parent) = path.parent() {
792            std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
793                message: format!("Failed to create directory: {}", e),
794            })?;
795        }
796
797        std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
798            message: format!("Failed to write session file: {}", e),
799        })?;
800
801        Ok(())
802    }
803
804    /// Load a session from a JSON file.
805    pub fn load_session(path: &Path) -> Result<Self, MemoryError> {
806        let json = std::fs::read_to_string(path).map_err(|e| MemoryError::SessionLoadFailed {
807            message: format!("Failed to read session file: {}", e),
808        })?;
809
810        let session: Session =
811            serde_json::from_str(&json).map_err(|e| MemoryError::SessionLoadFailed {
812                message: format!("Failed to deserialize session: {}", e),
813            })?;
814
815        let mut memory = MemorySystem::new(session.window_size);
816        memory.working = session.working;
817        memory.long_term = session.long_term;
818        for msg in session.messages {
819            memory.short_term.add(msg);
820        }
821
822        Ok(memory)
823    }
824}
825
impl ShortTermMemory {
    /// Get the window size, i.e. the number of most-recent messages that
    /// make up the sliding window.
    pub fn window_size(&self) -> usize {
        self.window_size
    }
}
832
/// Result of a context compression operation.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// Message count before compression.
    pub messages_before: usize,
    /// Message count after compression.
    pub messages_after: usize,
    /// Number of messages compressed away.
    // NOTE(review): presumably the value returned by `ShortTermMemory::compress`;
    // no producer of this struct is visible here, so confirm against the caller.
    pub compressed_count: usize,
}
840
841// ---------------------------------------------------------------------------
842// Memory Flusher
843// ---------------------------------------------------------------------------
844
/// Configuration for automatic memory flushing.
///
/// A flush is considered only when `enabled` is set; it then fires on
/// whichever of the message-count and time triggers is reached first.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlushConfig {
    /// Whether automatic flushing is enabled.
    pub enabled: bool,
    /// Flush interval in seconds (0 = disabled).
    pub interval_secs: u64,
    /// Number of messages that triggers an auto-flush (0 = disabled).
    pub flush_on_message_count: usize,
    /// Path where flushed data is written. Flushing fails with a
    /// persistence error while this is `None`.
    pub flush_path: Option<std::path::PathBuf>,
}
857
858impl Default for FlushConfig {
859    fn default() -> Self {
860        Self {
861            enabled: false,
862            interval_secs: 300, // 5 minutes
863            flush_on_message_count: 50,
864            flush_path: None,
865        }
866    }
867}
868
/// Tracks dirty state and triggers for automatic memory persistence.
#[derive(Debug, Clone)]
pub struct MemoryFlusher {
    /// Triggers and target path.
    config: FlushConfig,
    /// Set when a message is added; no flush is attempted while it is `false`.
    dirty: bool,
    /// Messages added since the last flush; compared with
    /// `config.flush_on_message_count`.
    messages_since_flush: usize,
    /// Reference point for the time trigger; initialised to the creation time.
    last_flush: DateTime<Utc>,
    /// Starts at zero.
    // NOTE(review): presumably incremented by `mark_flushed`, which is not
    // visible in this part of the file — confirm.
    total_flushes: usize,
}
878
879impl MemoryFlusher {
880    /// Create a new flusher with the given configuration.
881    pub fn new(config: FlushConfig) -> Self {
882        Self {
883            config,
884            dirty: false,
885            messages_since_flush: 0,
886            last_flush: Utc::now(),
887            total_flushes: 0,
888        }
889    }
890
891    /// Notify the flusher that a message was added.
892    pub fn on_message_added(&mut self) {
893        self.dirty = true;
894        self.messages_since_flush += 1;
895    }
896
897    /// Check whether a flush should happen based on the configured triggers.
898    pub fn should_flush(&self) -> bool {
899        if !self.config.enabled || !self.dirty {
900            return false;
901        }
902
903        // Message-count trigger
904        if self.config.flush_on_message_count > 0
905            && self.messages_since_flush >= self.config.flush_on_message_count
906        {
907            return true;
908        }
909
910        // Time-based trigger
911        if self.config.interval_secs > 0 {
912            let elapsed = (Utc::now() - self.last_flush).num_seconds();
913            if elapsed >= self.config.interval_secs as i64 {
914                return true;
915            }
916        }
917
918        false
919    }
920
921    /// Perform a flush of the memory system to disk.
922    pub fn flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
923        let path =
924            self.config
925                .flush_path
926                .as_ref()
927                .ok_or_else(|| MemoryError::PersistenceError {
928                    message: "No flush path configured".to_string(),
929                })?;
930
931        memory.save_session(path)?;
932        self.mark_flushed();
933        Ok(())
934    }
935
936    /// Force a flush regardless of triggers.
937    pub fn force_flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
938        if !self.dirty {
939            return Ok(()); // nothing to flush
940        }
941        self.flush(memory)
942    }
943
944    /// Whether there is unflushed data.
945    pub fn is_dirty(&self) -> bool {
946        self.dirty
947    }
948
949    /// Messages added since the last flush.
950    pub fn messages_since_flush(&self) -> usize {
951        self.messages_since_flush
952    }
953
954    /// Total number of flushes performed.
955    pub fn total_flushes(&self) -> usize {
956        self.total_flushes
957    }
958
959    /// Mark the flush as completed (reset counters).
960    fn mark_flushed(&mut self) {
961        self.dirty = false;
962        self.messages_since_flush = 0;
963        self.last_flush = Utc::now();
964        self.total_flushes += 1;
965    }
966}
967
968// ---------------------------------------------------------------------------
969// Knowledge Distiller — Cross-Session Learning
970// ---------------------------------------------------------------------------
971
972/// A distilled behavioral rule generated from accumulated corrections and facts.
973#[derive(Debug, Clone, Serialize, Deserialize)]
974pub struct BehavioralRule {
975    /// Unique identifier.
976    pub id: Uuid,
977    /// Human-readable rule description (injected into system prompt).
978    pub rule: String,
979    /// Source entries (correction/fact IDs) that contributed to this rule.
980    pub source_ids: Vec<Uuid>,
981    /// How many source entries support this rule (higher = more confidence).
982    pub support_count: usize,
983    /// When this rule was distilled.
984    pub created_at: DateTime<Utc>,
985}
986
987/// Persistent knowledge store containing distilled behavioral rules.
988#[derive(Debug, Clone, Default, Serialize, Deserialize)]
989pub struct KnowledgeStore {
990    pub rules: Vec<BehavioralRule>,
991    /// Correction IDs already processed by the distiller.
992    pub processed_correction_ids: Vec<Uuid>,
993    /// Fact IDs already processed by the distiller.
994    pub processed_fact_ids: Vec<Uuid>,
995}
996
997impl KnowledgeStore {
998    pub fn new() -> Self {
999        Self::default()
1000    }
1001
1002    /// Load a knowledge store from a JSON file.
1003    pub fn load(path: &std::path::Path) -> Result<Self, MemoryError> {
1004        if !path.exists() {
1005            return Ok(Self::new());
1006        }
1007        let json = std::fs::read_to_string(path).map_err(|e| MemoryError::PersistenceError {
1008            message: format!("Failed to read knowledge store: {}", e),
1009        })?;
1010        serde_json::from_str(&json).map_err(|e| MemoryError::PersistenceError {
1011            message: format!("Failed to parse knowledge store: {}", e),
1012        })
1013    }
1014
1015    /// Save the knowledge store to a JSON file.
1016    pub fn save(&self, path: &std::path::Path) -> Result<(), MemoryError> {
1017        if let Some(parent) = path.parent() {
1018            std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
1019                message: format!("Failed to create knowledge directory: {}", e),
1020            })?;
1021        }
1022        let json =
1023            serde_json::to_string_pretty(self).map_err(|e| MemoryError::PersistenceError {
1024                message: format!("Failed to serialize knowledge store: {}", e),
1025            })?;
1026        std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
1027            message: format!("Failed to write knowledge store: {}", e),
1028        })
1029    }
1030}
1031
1032/// The `KnowledgeDistiller` processes accumulated corrections and facts from
1033/// `LongTermMemory` to generate compressed behavioral rules. These rules
1034/// are injected into the system prompt to influence future agent behavior,
1035/// creating a cross-session learning loop.
1036pub struct KnowledgeDistiller {
1037    store: KnowledgeStore,
1038    max_rules: usize,
1039    min_entries: usize,
1040    store_path: Option<std::path::PathBuf>,
1041}
1042
1043impl KnowledgeDistiller {
1044    /// Create a new distiller from config. If `config` is None, creates a
1045    /// disabled distiller that returns no rules.
1046    pub fn new(config: Option<&crate::config::KnowledgeConfig>) -> Self {
1047        match config {
1048            Some(cfg) if cfg.enabled => {
1049                let store = cfg
1050                    .knowledge_path
1051                    .as_ref()
1052                    .and_then(|p| KnowledgeStore::load(p).ok())
1053                    .unwrap_or_default();
1054                Self {
1055                    store,
1056                    max_rules: cfg.max_rules,
1057                    min_entries: cfg.min_entries_for_distillation,
1058                    store_path: cfg.knowledge_path.clone(),
1059                }
1060            }
1061            _ => Self {
1062                store: KnowledgeStore::new(),
1063                max_rules: 0,
1064                min_entries: usize::MAX,
1065                store_path: None,
1066            },
1067        }
1068    }
1069
1070    /// Run the distillation process over long-term memory.
1071    ///
1072    /// Groups corrections by context patterns and generates compressed rules.
1073    /// Only processes entries not yet seen by the distiller.
1074    pub fn distill(&mut self, long_term: &LongTermMemory) {
1075        // Collect new (unprocessed) corrections
1076        let new_corrections: Vec<&Correction> = long_term
1077            .corrections
1078            .iter()
1079            .filter(|c| !self.store.processed_correction_ids.contains(&c.id))
1080            .collect();
1081
1082        // Collect new (unprocessed) facts
1083        let new_facts: Vec<&Fact> = long_term
1084            .facts
1085            .iter()
1086            .filter(|f| !self.store.processed_fact_ids.contains(&f.id))
1087            .collect();
1088
1089        let total_new = new_corrections.len() + new_facts.len();
1090        if total_new < self.min_entries {
1091            return; // Not enough new data to distill
1092        }
1093
1094        // --- Distill corrections into rules ---
1095        // Group corrections by common patterns in their context field.
1096        let mut context_groups: HashMap<String, Vec<&Correction>> = HashMap::new();
1097        for correction in &new_corrections {
1098            // Normalize the context to a group key (first 50 chars, lowercased)
1099            let key = correction
1100                .context
1101                .chars()
1102                .take(50)
1103                .collect::<String>()
1104                .to_lowercase();
1105            context_groups.entry(key).or_default().push(correction);
1106        }
1107
1108        // Generate a rule per group that has enough entries
1109        for group in context_groups.values() {
1110            if group.len() >= 2 {
1111                // Multiple corrections with similar context → a behavioral rule
1112                let corrected_patterns: Vec<&str> =
1113                    group.iter().map(|c| c.corrected.as_str()).collect();
1114                let rule_text = format!(
1115                    "Based on {} previous corrections: prefer {}",
1116                    group.len(),
1117                    corrected_patterns.join("; ")
1118                );
1119                let source_ids: Vec<Uuid> = group.iter().map(|c| c.id).collect();
1120                self.store.rules.push(BehavioralRule {
1121                    id: Uuid::new_v4(),
1122                    rule: rule_text,
1123                    source_ids,
1124                    support_count: group.len(),
1125                    created_at: Utc::now(),
1126                });
1127            } else {
1128                // Single correction → direct rule
1129                for c in group {
1130                    self.store.rules.push(BehavioralRule {
1131                        id: Uuid::new_v4(),
1132                        rule: format!("Instead of '{}', prefer '{}'", c.original, c.corrected),
1133                        source_ids: vec![c.id],
1134                        support_count: 1,
1135                        created_at: Utc::now(),
1136                    });
1137                }
1138            }
1139        }
1140
1141        // --- Distill preferences from facts ---
1142        // Facts tagged with "preference" or containing directive language get
1143        // turned into rules directly.
1144        for fact in &new_facts {
1145            let is_preference = fact.tags.iter().any(|t| t == "preference")
1146                || fact.content.starts_with("Prefer")
1147                || fact.content.starts_with("Always")
1148                || fact.content.starts_with("Never")
1149                || fact.content.starts_with("Don't")
1150                || fact.content.starts_with("Use ");
1151            if is_preference {
1152                self.store.rules.push(BehavioralRule {
1153                    id: Uuid::new_v4(),
1154                    rule: fact.content.clone(),
1155                    source_ids: vec![fact.id],
1156                    support_count: 1,
1157                    created_at: Utc::now(),
1158                });
1159            }
1160        }
1161
1162        // Mark all new entries as processed
1163        for c in &new_corrections {
1164            self.store.processed_correction_ids.push(c.id);
1165        }
1166        for f in &new_facts {
1167            self.store.processed_fact_ids.push(f.id);
1168        }
1169
1170        // Trim rules to max_rules, keeping highest support_count
1171        if self.store.rules.len() > self.max_rules {
1172            self.store
1173                .rules
1174                .sort_by(|a, b| b.support_count.cmp(&a.support_count));
1175            self.store.rules.truncate(self.max_rules);
1176        }
1177
1178        // Persist if a store path is configured
1179        if let Some(ref path) = self.store_path {
1180            let _ = self.store.save(path);
1181        }
1182    }
1183
1184    /// Get the current distilled rules formatted for system prompt injection.
1185    ///
1186    /// Returns an empty string if no rules exist.
1187    pub fn rules_for_prompt(&self) -> String {
1188        if self.store.rules.is_empty() {
1189            return String::new();
1190        }
1191        let mut prompt = String::from(
1192            "\n\n## Learned Behavioral Rules\n\
1193             The following rules were distilled from previous sessions. Follow them:\n",
1194        );
1195        for (i, rule) in self.store.rules.iter().enumerate() {
1196            prompt.push_str(&format!("{}. {}\n", i + 1, rule.rule));
1197        }
1198        prompt
1199    }
1200
1201    /// Get the number of distilled rules.
1202    pub fn rule_count(&self) -> usize {
1203        self.store.rules.len()
1204    }
1205
1206    /// Get the knowledge store reference (for diagnostics/REPL).
1207    pub fn store(&self) -> &KnowledgeStore {
1208        &self.store
1209    }
1210}
1211
1212#[cfg(test)]
1213mod tests {
1214    use super::*;
1215    use crate::types::{Content, Role};
1216
1217    #[test]
1218    fn test_working_memory_lifecycle() {
1219        let mut wm = WorkingMemory::new();
1220        assert!(wm.current_goal.is_none());
1221
1222        wm.set_goal("refactor auth module");
1223        assert_eq!(wm.current_goal.as_deref(), Some("refactor auth module"));
1224
1225        wm.add_sub_task("read current implementation");
1226        wm.add_sub_task("design new structure");
1227        assert_eq!(wm.sub_tasks.len(), 2);
1228
1229        wm.note("finding", "uses basic auth currently");
1230        assert_eq!(
1231            wm.scratchpad.get("finding").map(|s| s.as_str()),
1232            Some("uses basic auth currently")
1233        );
1234
1235        wm.add_active_file("src/auth/mod.rs");
1236        wm.add_active_file("src/auth/mod.rs"); // duplicate
1237        assert_eq!(wm.active_files.len(), 1);
1238
1239        wm.clear();
1240        assert!(wm.current_goal.is_none());
1241        assert!(wm.sub_tasks.is_empty());
1242    }
1243
1244    #[test]
1245    fn test_short_term_memory_basic() {
1246        let mut stm = ShortTermMemory::new(5);
1247        assert!(stm.is_empty());
1248        assert_eq!(stm.len(), 0);
1249
1250        stm.add(Message::user("hello"));
1251        stm.add(Message::assistant("hi there"));
1252        assert_eq!(stm.len(), 2);
1253        assert_eq!(stm.total_messages_seen(), 2);
1254
1255        let messages = stm.to_messages();
1256        assert_eq!(messages.len(), 2);
1257    }
1258
1259    #[test]
1260    fn test_short_term_memory_window() {
1261        let mut stm = ShortTermMemory::new(3);
1262
1263        for i in 0..6 {
1264            stm.add(Message::user(format!("message {}", i)));
1265        }
1266
1267        assert_eq!(stm.len(), 6);
1268        let messages = stm.to_messages();
1269        // Should only include last 3 messages (window size)
1270        assert_eq!(messages.len(), 3);
1271        assert_eq!(messages[0].content.as_text(), Some("message 3"));
1272        assert_eq!(messages[2].content.as_text(), Some("message 5"));
1273    }
1274
1275    #[test]
1276    fn test_short_term_memory_compression() {
1277        let mut stm = ShortTermMemory::new(3);
1278
1279        for i in 0..6 {
1280            stm.add(Message::user(format!("message {}", i)));
1281        }
1282
1283        assert!(stm.needs_compression());
1284
1285        let to_summarize = stm.messages_to_summarize();
1286        assert_eq!(to_summarize.len(), 3); // messages 0, 1, 2
1287
1288        let compressed = stm.compress("Summary of messages 0-2.".to_string());
1289        assert_eq!(compressed, 3);
1290        assert_eq!(stm.len(), 3); // only window remains
1291
1292        let messages = stm.to_messages();
1293        // First message should be the summary
1294        assert_eq!(messages.len(), 4); // 1 summary + 3 recent
1295        assert!(
1296            messages[0]
1297                .content
1298                .as_text()
1299                .unwrap()
1300                .contains("Summary of")
1301        );
1302        assert_eq!(messages[0].role, Role::System);
1303    }
1304
1305    #[test]
1306    fn test_short_term_memory_double_compression() {
1307        let mut stm = ShortTermMemory::new(2);
1308
1309        // First batch
1310        for i in 0..5 {
1311            stm.add(Message::user(format!("msg {}", i)));
1312        }
1313        stm.compress("First summary.".to_string());
1314        assert_eq!(stm.len(), 2);
1315
1316        // Second batch
1317        for i in 5..8 {
1318            stm.add(Message::user(format!("msg {}", i)));
1319        }
1320        stm.compress("Second summary.".to_string());
1321        assert_eq!(stm.len(), 2);
1322
1323        // Summary should merge
1324        let summary = stm.summary().unwrap();
1325        assert!(summary.contains("First summary."));
1326        assert!(summary.contains("Second summary."));
1327    }
1328
1329    #[test]
1330    fn test_short_term_memory_clear() {
1331        let mut stm = ShortTermMemory::new(5);
1332        stm.add(Message::user("test"));
1333        stm.compress("summary".to_string());
1334
1335        stm.clear();
1336        assert!(stm.is_empty());
1337        assert!(stm.summary().is_none());
1338        assert_eq!(stm.total_messages_seen(), 0);
1339    }
1340
1341    #[test]
1342    fn test_fact_creation() {
1343        let fact = Fact::new("Project uses JWT auth", "code analysis")
1344            .with_tags(vec!["auth".to_string(), "jwt".to_string()]);
1345        assert_eq!(fact.content, "Project uses JWT auth");
1346        assert_eq!(fact.source, "code analysis");
1347        assert_eq!(fact.tags.len(), 2);
1348    }
1349
1350    #[test]
1351    fn test_long_term_memory() {
1352        let mut ltm = LongTermMemory::new();
1353
1354        ltm.add_fact(Fact::new("Uses Rust 2021 edition", "Cargo.toml"));
1355        ltm.set_preference("code_style", "rustfmt defaults");
1356        ltm.add_correction(
1357            "wrong import".to_string(),
1358            "correct import".to_string(),
1359            "editing main.rs".to_string(),
1360        );
1361
1362        assert_eq!(ltm.facts.len(), 1);
1363        assert_eq!(ltm.get_preference("code_style"), Some("rustfmt defaults"));
1364        assert_eq!(ltm.corrections.len(), 1);
1365    }
1366
1367    #[test]
1368    fn test_long_term_memory_search() {
1369        let mut ltm = LongTermMemory::new();
1370        ltm.add_fact(Fact::new("Project uses JWT authentication", "analysis"));
1371        ltm.add_fact(
1372            Fact::new("Database is PostgreSQL", "config").with_tags(vec!["database".to_string()]),
1373        );
1374        ltm.add_fact(Fact::new("Frontend uses React", "package.json"));
1375
1376        let results = ltm.search_facts("JWT");
1377        assert_eq!(results.len(), 1);
1378        assert!(results[0].content.contains("JWT"));
1379
1380        let results = ltm.search_facts("database");
1381        assert_eq!(results.len(), 1);
1382
1383        let results = ltm.search_facts("nonexistent");
1384        assert!(results.is_empty());
1385    }
1386
1387    #[test]
1388    fn test_memory_system() {
1389        let mut mem = MemorySystem::new(5);
1390
1391        mem.start_new_task("fix bug #42");
1392        assert_eq!(mem.working.current_goal.as_deref(), Some("fix bug #42"));
1393
1394        mem.add_message(Message::user("fix the null pointer bug"));
1395        mem.add_message(Message::assistant("I'll look into that."));
1396
1397        let ctx = mem.context_messages();
1398        assert_eq!(ctx.len(), 2);
1399
1400        mem.clear_session();
1401        assert!(mem.short_term.is_empty());
1402        assert!(mem.working.current_goal.is_none());
1403    }
1404
1405    #[test]
1406    fn test_compression_no_op_when_within_window() {
1407        let mut stm = ShortTermMemory::new(10);
1408        stm.add(Message::user("hello"));
1409        stm.add(Message::assistant("hi"));
1410
1411        assert!(!stm.needs_compression());
1412        assert!(stm.messages_to_summarize().is_empty());
1413
1414        let compressed = stm.compress("should not matter".to_string());
1415        assert_eq!(compressed, 0);
1416    }
1417
1418    #[test]
1419    fn test_memory_system_new_task_preserves_long_term() {
1420        let mut mem = MemorySystem::new(5);
1421        mem.long_term.add_fact(Fact::new("important fact", "test"));
1422        mem.add_message(Message::user("task 1"));
1423
1424        mem.start_new_task("task 2");
1425        assert_eq!(mem.working.current_goal.as_deref(), Some("task 2"));
1426        assert_eq!(mem.long_term.facts.len(), 1); // preserved
1427    }
1428
1429    // --- Session persistence tests ---
1430
1431    #[test]
1432    fn test_session_save_load_roundtrip() {
1433        let dir = tempfile::tempdir().unwrap();
1434        let session_path = dir.path().join("session.json");
1435
1436        // Build a memory system with data
1437        let mut mem = MemorySystem::new(10);
1438        mem.start_new_task("fix bug #42");
1439        mem.add_message(Message::user("fix the bug"));
1440        mem.add_message(Message::assistant("Looking into it."));
1441        mem.long_term.add_fact(Fact::new("Uses Rust", "analysis"));
1442        mem.long_term.set_preference("style", "concise");
1443
1444        // Save
1445        mem.save_session(&session_path).unwrap();
1446        assert!(session_path.exists());
1447
1448        // Load
1449        let loaded = MemorySystem::load_session(&session_path).unwrap();
1450        assert_eq!(loaded.working.current_goal.as_deref(), Some("fix bug #42"));
1451        assert_eq!(loaded.short_term.len(), 2);
1452        assert_eq!(loaded.long_term.facts.len(), 1);
1453        assert_eq!(loaded.long_term.get_preference("style"), Some("concise"));
1454
1455        // Verify message content
1456        let messages = loaded.context_messages();
1457        assert_eq!(messages.len(), 2);
1458        assert_eq!(messages[0].content.as_text(), Some("fix the bug"));
1459        assert_eq!(messages[1].content.as_text(), Some("Looking into it."));
1460    }
1461
1462    #[test]
1463    fn test_session_load_missing_file() {
1464        let result = MemorySystem::load_session(Path::new("/nonexistent/session.json"));
1465        assert!(result.is_err());
1466    }
1467
1468    #[test]
1469    fn test_session_load_corrupt_json() {
1470        let dir = tempfile::tempdir().unwrap();
1471        let path = dir.path().join("bad.json");
1472        std::fs::write(&path, "not valid json").unwrap();
1473
1474        let result = MemorySystem::load_session(&path);
1475        assert!(result.is_err());
1476    }
1477
1478    #[test]
1479    fn test_session_save_creates_directories() {
1480        let dir = tempfile::tempdir().unwrap();
1481        let session_path = dir.path().join("nested").join("dir").join("session.json");
1482
1483        let mem = MemorySystem::new(5);
1484        mem.save_session(&session_path).unwrap();
1485        assert!(session_path.exists());
1486    }
1487
1488    #[test]
1489    fn test_session_metadata() {
1490        let meta = SessionMetadata::new();
1491        assert!(meta.task_summary.is_none());
1492        assert!(meta.created_at <= Utc::now());
1493
1494        let default_meta = SessionMetadata::default();
1495        assert!(default_meta.task_summary.is_none());
1496    }
1497
1498    #[test]
1499    fn test_short_term_window_size() {
1500        let stm = ShortTermMemory::new(7);
1501        assert_eq!(stm.window_size(), 7);
1502    }
1503
1504    // --- Pinning tests ---
1505
1506    #[test]
1507    fn test_pin_message() {
1508        let mut stm = ShortTermMemory::new(5);
1509        stm.add(Message::user("msg 0"));
1510        stm.add(Message::user("msg 1"));
1511        stm.add(Message::user("msg 2"));
1512
1513        assert!(stm.pin(1));
1514        assert!(stm.is_pinned(1));
1515        assert!(!stm.is_pinned(0));
1516        assert_eq!(stm.pinned_count(), 1);
1517    }
1518
1519    #[test]
1520    fn test_pin_out_of_bounds() {
1521        let mut stm = ShortTermMemory::new(5);
1522        stm.add(Message::user("msg 0"));
1523        assert!(!stm.pin(5)); // out of bounds
1524    }
1525
1526    #[test]
1527    fn test_unpin_message() {
1528        let mut stm = ShortTermMemory::new(5);
1529        stm.add(Message::user("msg 0"));
1530        stm.add(Message::user("msg 1"));
1531
1532        stm.pin(0);
1533        assert!(stm.is_pinned(0));
1534        assert!(stm.unpin(0));
1535        assert!(!stm.is_pinned(0));
1536    }
1537
1538    #[test]
1539    fn test_pinned_survives_compression() {
1540        let mut stm = ShortTermMemory::new(3);
1541        // Add enough messages to trigger compression
1542        stm.add(Message::user("old 0"));
1543        stm.add(Message::user("old 1"));
1544        stm.add(Message::user("important pinned"));
1545        stm.add(Message::user("msg 3"));
1546        stm.add(Message::user("msg 4"));
1547        stm.add(Message::user("msg 5"));
1548
1549        // Pin the third message (index 2)
1550        stm.pin(2);
1551        assert!(stm.needs_compression());
1552
1553        let removed = stm.compress("Summary of old messages".to_string());
1554        // The pinned message should be preserved
1555        assert!(removed < 3); // Not all 3 were removed because one was pinned
1556
1557        // The pinned message should still be in the window
1558        let msgs = stm.to_messages();
1559        let has_pinned = msgs
1560            .iter()
1561            .any(|m| matches!(&m.content, Content::Text { text } if text == "important pinned"));
1562        assert!(has_pinned, "Pinned message should survive compression");
1563    }
1564
1565    #[test]
1566    fn test_clear_resets_pins() {
1567        let mut stm = ShortTermMemory::new(5);
1568        stm.add(Message::user("msg 0"));
1569        stm.pin(0);
1570        assert_eq!(stm.pinned_count(), 1);
1571
1572        stm.clear();
1573        assert_eq!(stm.pinned_count(), 0);
1574    }
1575
1576    // --- Context Breakdown tests ---
1577
1578    #[test]
1579    fn test_context_breakdown() {
1580        let mut memory = MemorySystem::new(10);
1581        memory.add_message(Message::user("hello world"));
1582        memory.add_message(Message::assistant("hi there!"));
1583
1584        let ctx = memory.context_breakdown(8000);
1585        assert!(ctx.message_tokens > 0);
1586        assert_eq!(ctx.message_count, 2);
1587        assert_eq!(ctx.context_window, 8000);
1588        assert!(ctx.remaining_tokens > 0);
1589        assert!(!ctx.has_summary);
1590        assert_eq!(ctx.pinned_count, 0);
1591    }
1592
1593    #[test]
1594    fn test_context_breakdown_ratio() {
1595        let ctx = ContextBreakdown {
1596            total_tokens: 4000,
1597            context_window: 8000,
1598            ..Default::default()
1599        };
1600        assert!((ctx.usage_ratio() - 0.5).abs() < 0.01);
1601        assert!(!ctx.is_warning());
1602
1603        let ctx_high = ContextBreakdown {
1604            total_tokens: 7000,
1605            context_window: 8000,
1606            ..Default::default()
1607        };
1608        assert!(ctx_high.is_warning());
1609    }
1610
1611    #[test]
1612    fn test_pin_message_via_memory_system() {
1613        let mut memory = MemorySystem::new(10);
1614        memory.add_message(Message::user("msg 0"));
1615        memory.add_message(Message::user("msg 1"));
1616
1617        assert!(memory.pin_message(0));
1618        assert!(memory.short_term.is_pinned(0));
1619    }
1620
1621    // --- Memory Flusher tests ---
1622
1623    #[test]
1624    fn test_flusher_default_config() {
1625        let config = FlushConfig::default();
1626        assert!(!config.enabled);
1627        assert_eq!(config.interval_secs, 300);
1628        assert_eq!(config.flush_on_message_count, 50);
1629        assert!(config.flush_path.is_none());
1630    }
1631
1632    #[test]
1633    fn test_flusher_not_dirty_by_default() {
1634        let flusher = MemoryFlusher::new(FlushConfig::default());
1635        assert!(!flusher.is_dirty());
1636        assert_eq!(flusher.messages_since_flush(), 0);
1637        assert_eq!(flusher.total_flushes(), 0);
1638    }
1639
1640    #[test]
1641    fn test_flusher_marks_dirty_on_message() {
1642        let mut flusher = MemoryFlusher::new(FlushConfig::default());
1643        flusher.on_message_added();
1644        assert!(flusher.is_dirty());
1645        assert_eq!(flusher.messages_since_flush(), 1);
1646    }
1647
1648    #[test]
1649    fn test_flusher_disabled_never_triggers() {
1650        let mut flusher = MemoryFlusher::new(FlushConfig {
1651            enabled: false,
1652            ..FlushConfig::default()
1653        });
1654        for _ in 0..100 {
1655            flusher.on_message_added();
1656        }
1657        assert!(!flusher.should_flush());
1658    }
1659
1660    #[test]
1661    fn test_flusher_message_count_trigger() {
1662        let mut flusher = MemoryFlusher::new(FlushConfig {
1663            enabled: true,
1664            flush_on_message_count: 5,
1665            interval_secs: 0,
1666            flush_path: None,
1667        });
1668
1669        for _ in 0..4 {
1670            flusher.on_message_added();
1671        }
1672        assert!(!flusher.should_flush());
1673
1674        flusher.on_message_added(); // 5th message
1675        assert!(flusher.should_flush());
1676    }
1677
1678    #[test]
1679    fn test_flusher_not_dirty_no_trigger() {
1680        let flusher = MemoryFlusher::new(FlushConfig {
1681            enabled: true,
1682            flush_on_message_count: 1,
1683            interval_secs: 0,
1684            flush_path: None,
1685        });
1686        // Not dirty, so should_flush is false even though threshold is 1
1687        assert!(!flusher.should_flush());
1688    }
1689
1690    #[test]
1691    fn test_flusher_flush_resets_state() {
1692        let dir = tempfile::tempdir().unwrap();
1693        let flush_path = dir.path().join("flush.json");
1694
1695        let mut flusher = MemoryFlusher::new(FlushConfig {
1696            enabled: true,
1697            flush_on_message_count: 2,
1698            interval_secs: 0,
1699            flush_path: Some(flush_path.clone()),
1700        });
1701
1702        let mut mem = MemorySystem::new(10);
1703        mem.add_message(Message::user("test"));
1704
1705        flusher.on_message_added();
1706        flusher.on_message_added();
1707        assert!(flusher.should_flush());
1708
1709        flusher.flush(&mem).unwrap();
1710        assert!(!flusher.is_dirty());
1711        assert_eq!(flusher.messages_since_flush(), 0);
1712        assert_eq!(flusher.total_flushes(), 1);
1713        assert!(flush_path.exists());
1714    }
1715
1716    #[test]
1717    fn test_flusher_force_flush() {
1718        let dir = tempfile::tempdir().unwrap();
1719        let flush_path = dir.path().join("force.json");
1720
1721        let mut flusher = MemoryFlusher::new(FlushConfig {
1722            enabled: true,
1723            flush_on_message_count: 100,
1724            interval_secs: 0,
1725            flush_path: Some(flush_path.clone()),
1726        });
1727
1728        let mem = MemorySystem::new(10);
1729
1730        // Not dirty - force_flush is a no-op
1731        flusher.force_flush(&mem).unwrap();
1732        assert_eq!(flusher.total_flushes(), 0);
1733
1734        // Make dirty, then force
1735        flusher.on_message_added();
1736        flusher.force_flush(&mem).unwrap();
1737        assert_eq!(flusher.total_flushes(), 1);
1738        assert!(!flusher.is_dirty());
1739    }
1740
1741    #[test]
1742    fn test_flusher_no_path_error() {
1743        let mut flusher = MemoryFlusher::new(FlushConfig {
1744            enabled: true,
1745            flush_on_message_count: 1,
1746            interval_secs: 0,
1747            flush_path: None,
1748        });
1749        flusher.on_message_added();
1750
1751        let mem = MemorySystem::new(10);
1752        let result = flusher.flush(&mem);
1753        assert!(result.is_err());
1754    }
1755
1756    #[test]
1757    fn test_flush_config_serialization() {
1758        let config = FlushConfig {
1759            enabled: true,
1760            interval_secs: 120,
1761            flush_on_message_count: 25,
1762            flush_path: Some(std::path::PathBuf::from("/tmp/flush.json")),
1763        };
1764        let json = serde_json::to_string(&config).unwrap();
1765        let restored: FlushConfig = serde_json::from_str(&json).unwrap();
1766        assert!(restored.enabled);
1767        assert_eq!(restored.interval_secs, 120);
1768        assert_eq!(restored.flush_on_message_count, 25);
1769    }
1770
1771    // --- A4: HybridSearchEngine → MemorySystem integration tests ---
1772
1773    #[test]
1774    fn test_memory_system_without_search_uses_keyword_fallback() {
1775        let mut mem = MemorySystem::new(10);
1776        mem.add_fact(Fact::new("Rust uses ownership for memory safety", "docs"));
1777        mem.add_fact(Fact::new("Python uses garbage collection", "docs"));
1778
1779        let results = mem.search_facts_hybrid("ownership");
1780        assert_eq!(results.len(), 1);
1781        assert!(results[0].content.contains("ownership"));
1782    }
1783
1784    #[test]
1785    fn test_memory_system_with_search_engine() {
1786        let dir = tempfile::tempdir().unwrap();
1787        let config = SearchConfig {
1788            index_path: dir.path().join("idx"),
1789            db_path: dir.path().join("vec.db"),
1790            vector_dimensions: 64,
1791            full_text_weight: 0.5,
1792            vector_weight: 0.5,
1793            max_results: 10,
1794        };
1795        let mut mem = MemorySystem::with_search(10, config).unwrap();
1796
1797        mem.add_fact(Fact::new("Rust uses ownership model", "analysis"));
1798        mem.add_fact(Fact::new("Python garbage collector", "analysis"));
1799
1800        // The hybrid engine should find results (or fall back to keyword)
1801        let results = mem.search_facts_hybrid("ownership");
1802        assert!(!results.is_empty());
1803        assert!(results.iter().any(|f| f.content.contains("ownership")));
1804    }
1805
1806    #[test]
1807    fn test_memory_system_search_empty_query() {
1808        let mut mem = MemorySystem::new(10);
1809        mem.add_fact(Fact::new("some fact", "source"));
1810        let results = mem.search_facts_hybrid("");
1811        // Empty query falls back to keyword search, which matches nothing
1812        // (no content matches empty substring in the keywords path)
1813        // Actually, empty string is contained in everything
1814        assert!(!results.is_empty());
1815    }
1816
1817    #[test]
1818    fn test_memory_system_search_no_facts() {
1819        let mem = MemorySystem::new(10);
1820        let results = mem.search_facts_hybrid("anything");
1821        assert!(results.is_empty());
1822    }
1823
1824    #[test]
1825    fn test_add_fact_indexes_into_search_engine() {
1826        let dir = tempfile::tempdir().unwrap();
1827        let config = SearchConfig {
1828            index_path: dir.path().join("idx"),
1829            db_path: dir.path().join("vec.db"),
1830            vector_dimensions: 64,
1831            full_text_weight: 0.5,
1832            vector_weight: 0.5,
1833            max_results: 10,
1834        };
1835        let mut mem = MemorySystem::with_search(10, config).unwrap();
1836
1837        // Add multiple facts
1838        for i in 0..5 {
1839            mem.add_fact(Fact::new(format!("fact number {}", i), "test"));
1840        }
1841
1842        // All 5 should be in long-term memory
1843        assert_eq!(mem.long_term.facts.len(), 5);
1844    }
1845
1846    // --- A5: MemoryFlusher → MemorySystem integration tests ---
1847
1848    #[test]
1849    fn test_memory_system_with_flusher() {
1850        let config = FlushConfig {
1851            enabled: true,
1852            flush_on_message_count: 5,
1853            interval_secs: 0,
1854            flush_path: None,
1855        };
1856        let mem = MemorySystem::new(10).with_flusher(config);
1857        assert!(!mem.flusher_is_dirty());
1858    }
1859
1860    #[test]
1861    fn test_memory_system_add_message_notifies_flusher() {
1862        let config = FlushConfig {
1863            enabled: true,
1864            flush_on_message_count: 5,
1865            interval_secs: 0,
1866            flush_path: None,
1867        };
1868        let mut mem = MemorySystem::new(10).with_flusher(config);
1869
1870        mem.add_message(Message::user("hello"));
1871        assert!(mem.flusher_is_dirty());
1872    }
1873
1874    #[test]
1875    fn test_memory_system_check_auto_flush_no_flusher() {
1876        let mut mem = MemorySystem::new(10);
1877        // No flusher attached — should be a no-op returning Ok(false)
1878        let result = mem.check_auto_flush().unwrap();
1879        assert!(!result);
1880    }
1881
1882    #[test]
1883    fn test_memory_system_check_auto_flush_triggers() {
1884        let dir = tempfile::tempdir().unwrap();
1885        let flush_path = dir.path().join("auto_flush.json");
1886
1887        let config = FlushConfig {
1888            enabled: true,
1889            flush_on_message_count: 3,
1890            interval_secs: 0,
1891            flush_path: Some(flush_path.clone()),
1892        };
1893        let mut mem = MemorySystem::new(10).with_flusher(config);
1894
1895        // Add messages below threshold
1896        mem.add_message(Message::user("msg 1"));
1897        mem.add_message(Message::user("msg 2"));
1898        assert!(!mem.check_auto_flush().unwrap());
1899        assert!(!flush_path.exists());
1900
1901        // Hit the threshold
1902        mem.add_message(Message::user("msg 3"));
1903        assert!(mem.check_auto_flush().unwrap());
1904        assert!(flush_path.exists());
1905
1906        // After flush, flusher should not be dirty
1907        assert!(!mem.flusher_is_dirty());
1908    }
1909
1910    #[test]
1911    fn test_memory_system_force_flush() {
1912        let dir = tempfile::tempdir().unwrap();
1913        let flush_path = dir.path().join("force_flush.json");
1914
1915        let config = FlushConfig {
1916            enabled: true,
1917            flush_on_message_count: 100, // high threshold
1918            interval_secs: 0,
1919            flush_path: Some(flush_path.clone()),
1920        };
1921        let mut mem = MemorySystem::new(10).with_flusher(config);
1922
1923        mem.add_message(Message::user("important data"));
1924        assert!(mem.flusher_is_dirty());
1925
1926        mem.force_flush().unwrap();
1927        assert!(!mem.flusher_is_dirty());
1928        assert!(flush_path.exists());
1929    }
1930
1931    #[test]
1932    fn test_memory_system_force_flush_no_flusher() {
1933        let mut mem = MemorySystem::new(10);
1934        // Should be a no-op, not an error
1935        mem.force_flush().unwrap();
1936    }
1937
1938    // --- Knowledge Distiller Tests ---
1939
1940    #[test]
1941    fn test_knowledge_distiller_disabled() {
1942        let distiller = KnowledgeDistiller::new(None);
1943        assert_eq!(distiller.rule_count(), 0);
1944        assert!(distiller.rules_for_prompt().is_empty());
1945    }
1946
1947    #[test]
1948    fn test_knowledge_distiller_no_data() {
1949        let config = crate::config::KnowledgeConfig::default();
1950        let mut distiller = KnowledgeDistiller::new(Some(&config));
1951        let ltm = LongTermMemory::new();
1952        distiller.distill(&ltm);
1953        assert_eq!(distiller.rule_count(), 0);
1954    }
1955
1956    #[test]
1957    fn test_knowledge_distiller_corrections_below_threshold() {
1958        let config = crate::config::KnowledgeConfig {
1959            min_entries_for_distillation: 5,
1960            ..Default::default()
1961        };
1962        let mut distiller = KnowledgeDistiller::new(Some(&config));
1963
1964        let mut ltm = LongTermMemory::new();
1965        ltm.add_correction(
1966            "unwrap()".into(),
1967            "? operator".into(),
1968            "error handling".into(),
1969        );
1970        ltm.add_correction("println!".into(), "tracing::info!".into(), "logging".into());
1971
1972        distiller.distill(&ltm);
1973        // Only 2 entries, threshold is 5
1974        assert_eq!(distiller.rule_count(), 0);
1975    }
1976
1977    #[test]
1978    fn test_knowledge_distiller_single_corrections() {
1979        let config = crate::config::KnowledgeConfig {
1980            min_entries_for_distillation: 2,
1981            ..Default::default()
1982        };
1983        let mut distiller = KnowledgeDistiller::new(Some(&config));
1984
1985        let mut ltm = LongTermMemory::new();
1986        ltm.add_correction(
1987            "unwrap()".into(),
1988            "? operator".into(),
1989            "error handling".into(),
1990        );
1991        ltm.add_correction("println!".into(), "tracing::info!".into(), "logging".into());
1992        // Both have different contexts → separate rules
1993
1994        distiller.distill(&ltm);
1995        assert_eq!(distiller.rule_count(), 2);
1996
1997        let prompt = distiller.rules_for_prompt();
1998        assert!(prompt.contains("Learned Behavioral Rules"));
1999        assert!(prompt.contains("? operator"));
2000        assert!(prompt.contains("tracing::info!"));
2001    }
2002
2003    #[test]
2004    fn test_knowledge_distiller_grouped_corrections() {
2005        let config = crate::config::KnowledgeConfig {
2006            min_entries_for_distillation: 2,
2007            ..Default::default()
2008        };
2009        let mut distiller = KnowledgeDistiller::new(Some(&config));
2010
2011        let mut ltm = LongTermMemory::new();
2012        // Two corrections with same context prefix → should group
2013        ltm.add_correction(
2014            "unwrap()".into(),
2015            "? operator".into(),
2016            "error handling in Rust code".into(),
2017        );
2018        ltm.add_correction(
2019            "expect()".into(),
2020            "map_err()?".into(),
2021            "error handling in Rust code".into(),
2022        );
2023
2024        distiller.distill(&ltm);
2025        assert_eq!(distiller.rule_count(), 1);
2026        let prompt = distiller.rules_for_prompt();
2027        assert!(prompt.contains("2 previous corrections"));
2028    }
2029
2030    #[test]
2031    fn test_knowledge_distiller_preference_facts() {
2032        let config = crate::config::KnowledgeConfig {
2033            min_entries_for_distillation: 1,
2034            ..Default::default()
2035        };
2036        let mut distiller = KnowledgeDistiller::new(Some(&config));
2037
2038        let mut ltm = LongTermMemory::new();
2039        ltm.add_fact(Fact::new("Prefer async/await over threads", "user"));
2040        ltm.add_fact(Fact::new("Project uses PostgreSQL", "session"));
2041
2042        distiller.distill(&ltm);
2043        // Only the "Prefer..." fact becomes a rule
2044        assert_eq!(distiller.rule_count(), 1);
2045        let prompt = distiller.rules_for_prompt();
2046        assert!(prompt.contains("async/await"));
2047    }
2048
2049    #[test]
2050    fn test_knowledge_distiller_max_rules_truncation() {
2051        let config = crate::config::KnowledgeConfig {
2052            min_entries_for_distillation: 1,
2053            max_rules: 3,
2054            ..Default::default()
2055        };
2056        let mut distiller = KnowledgeDistiller::new(Some(&config));
2057
2058        let mut ltm = LongTermMemory::new();
2059        for i in 0..10 {
2060            ltm.add_correction(
2061                format!("old{}", i),
2062                format!("new{}", i),
2063                format!("context{}", i),
2064            );
2065        }
2066
2067        distiller.distill(&ltm);
2068        assert!(distiller.rule_count() <= 3);
2069    }
2070
2071    #[test]
2072    fn test_knowledge_distiller_idempotent() {
2073        let config = crate::config::KnowledgeConfig {
2074            min_entries_for_distillation: 1,
2075            ..Default::default()
2076        };
2077        let mut distiller = KnowledgeDistiller::new(Some(&config));
2078
2079        let mut ltm = LongTermMemory::new();
2080        ltm.add_correction("old".into(), "new".into(), "ctx".into());
2081
2082        distiller.distill(&ltm);
2083        let count_after_first = distiller.rule_count();
2084
2085        // Distill again with same data — should not add duplicate rules
2086        distiller.distill(&ltm);
2087        assert_eq!(distiller.rule_count(), count_after_first);
2088    }
2089
2090    #[test]
2091    fn test_knowledge_store_save_load_roundtrip() {
2092        let dir = tempfile::tempdir().unwrap();
2093        let path = dir.path().join("knowledge.json");
2094
2095        let mut store = KnowledgeStore::new();
2096        store.rules.push(BehavioralRule {
2097            id: Uuid::new_v4(),
2098            rule: "Prefer ? over unwrap".into(),
2099            source_ids: vec![Uuid::new_v4()],
2100            support_count: 3,
2101            created_at: Utc::now(),
2102        });
2103
2104        store.save(&path).unwrap();
2105        let loaded = KnowledgeStore::load(&path).unwrap();
2106        assert_eq!(loaded.rules.len(), 1);
2107        assert_eq!(loaded.rules[0].rule, "Prefer ? over unwrap");
2108        assert_eq!(loaded.rules[0].support_count, 3);
2109    }
2110
2111    #[test]
2112    fn test_knowledge_store_load_nonexistent() {
2113        let store =
2114            KnowledgeStore::load(std::path::Path::new("/nonexistent/knowledge.json")).unwrap();
2115        assert!(store.rules.is_empty());
2116    }
2117
2118    #[test]
2119    fn test_unpin_out_of_bounds_returns_false() {
2120        let mut stm = ShortTermMemory::new(100);
2121        stm.add(Message::user("hello"));
2122        stm.add(Message::assistant("hi"));
2123        stm.add(Message::user("world"));
2124
2125        // Pin a valid message
2126        assert!(stm.pin(1));
2127        assert!(stm.is_pinned(1));
2128
2129        // Unpin out-of-bounds should return false
2130        assert!(!stm.unpin(999));
2131        assert!(!stm.unpin(3));
2132
2133        // Original pin should still be intact
2134        assert!(stm.is_pinned(1));
2135    }
2136
2137    #[test]
2138    fn test_unpin_at_exact_boundary() {
2139        let mut stm = ShortTermMemory::new(100);
2140        stm.add(Message::user("msg0"));
2141        stm.add(Message::user("msg1"));
2142
2143        // Unpin at exactly len (2) should fail
2144        assert!(!stm.unpin(2));
2145
2146        // Unpin at len-1 (1) should succeed even if not pinned (returns false from remove)
2147        assert!(!stm.unpin(1)); // not pinned, so remove returns false
2148
2149        // Pin and then unpin at boundary
2150        assert!(stm.pin(1));
2151        assert!(stm.unpin(1));
2152        assert!(!stm.is_pinned(1));
2153    }
2154
2155    // --- Tool chain pair preservation tests ---
2156
2157    #[test]
2158    fn test_compress_preserves_tool_chain_pairs() {
2159        let mut stm = ShortTermMemory::new(3);
2160
2161        // Build a conversation with tool_call → tool_result pair in the compression zone
2162        stm.add(Message::user("read the file")); // idx 0
2163        stm.add(Message::new(
2164            // idx 1 — tool_call
2165            Role::Assistant,
2166            Content::tool_call("call_abc", "file_read", serde_json::json!({"path": "x.rs"})),
2167        ));
2168        stm.add(Message::tool_result("call_abc", "fn main() {}", false)); // idx 2 — tool_result
2169        stm.add(Message::assistant("Here is the content.")); // idx 3
2170        stm.add(Message::user("thanks")); // idx 4
2171        stm.add(Message::assistant("You're welcome.")); // idx 5
2172
2173        // Pin the tool_result at index 2
2174        stm.pin(2);
2175        assert!(stm.needs_compression());
2176
2177        let removed = stm.compress("Summary of earlier conversation".to_string());
2178
2179        // The tool_result (pinned) should be preserved, AND its paired
2180        // tool_call should ALSO be preserved
2181        let msgs = stm.to_messages();
2182        let has_tool_call = msgs.iter().any(|m| {
2183            matches!(
2184                &m.content,
2185                Content::ToolCall { name, .. } if name == "file_read"
2186            )
2187        });
2188        let has_tool_result = msgs.iter().any(|m| {
2189            matches!(
2190                &m.content,
2191                Content::ToolResult { call_id, .. } if call_id == "call_abc"
2192            )
2193        });
2194
2195        assert!(
2196            has_tool_result,
2197            "Pinned tool_result should survive compression"
2198        );
2199        assert!(
2200            has_tool_call,
2201            "Paired tool_call should also survive compression"
2202        );
2203
2204        // The user message at idx 0 should have been compressed away
2205        // (it was not pinned and not a tool pair)
2206        assert!(removed >= 1, "At least one message should be compressed");
2207    }
2208
2209    #[test]
2210    fn test_compress_preserves_tool_call_paired_with_result() {
2211        let mut stm = ShortTermMemory::new(3);
2212
2213        stm.add(Message::user("do something")); // idx 0
2214        stm.add(Message::new(
2215            // idx 1 — tool_call
2216            Role::Assistant,
2217            Content::tool_call("call_xyz", "shell_exec", serde_json::json!({"cmd": "ls"})),
2218        ));
2219        stm.add(Message::tool_result("call_xyz", "file1\nfile2", false)); // idx 2
2220        stm.add(Message::assistant("Listed files.")); // idx 3
2221        stm.add(Message::user("ok")); // idx 4
2222        stm.add(Message::assistant("Done.")); // idx 5
2223
2224        // Pin the tool_call at index 1
2225        stm.pin(1);
2226        assert!(stm.needs_compression());
2227
2228        stm.compress("Summary".to_string());
2229
2230        let msgs = stm.to_messages();
2231        let has_tool_call = msgs.iter().any(|m| {
2232            matches!(
2233                &m.content,
2234                Content::ToolCall { id, .. } if id == "call_xyz"
2235            )
2236        });
2237        let has_tool_result = msgs.iter().any(|m| {
2238            matches!(
2239                &m.content,
2240                Content::ToolResult { call_id, .. } if call_id == "call_xyz"
2241            )
2242        });
2243
2244        assert!(has_tool_call, "Pinned tool_call should survive compression");
2245        assert!(
2246            has_tool_result,
2247            "Paired tool_result should also survive compression"
2248        );
2249    }
2250}