1use crate::error::MemoryError;
8use crate::search::{HybridSearchEngine, SearchConfig};
9use crate::types::{Content, Message, Role};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::path::Path;
14use uuid::Uuid;
15
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
18pub struct WorkingMemory {
19 pub current_goal: Option<String>,
20 pub sub_tasks: Vec<String>,
21 pub scratchpad: HashMap<String, String>,
22 pub active_files: Vec<String>,
23}
24
25impl WorkingMemory {
26 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn set_goal(&mut self, goal: impl Into<String>) {
31 self.current_goal = Some(goal.into());
32 }
33
34 pub fn add_sub_task(&mut self, task: impl Into<String>) {
35 self.sub_tasks.push(task.into());
36 }
37
38 pub fn note(&mut self, key: impl Into<String>, value: impl Into<String>) {
39 self.scratchpad.insert(key.into(), value.into());
40 }
41
42 pub fn add_active_file(&mut self, path: impl Into<String>) {
43 let path = path.into();
44 if !self.active_files.contains(&path) {
45 self.active_files.push(path);
46 }
47 }
48
49 pub fn clear(&mut self) {
50 *self = Self::default();
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct ShortTermMemory {
57 messages: VecDeque<Message>,
58 window_size: usize,
59 summarized_prefix: Option<String>,
60 total_messages_seen: usize,
61 pinned: std::collections::HashSet<usize>,
63 compressed_offset: usize,
65}
66
67impl ShortTermMemory {
68 pub fn new(window_size: usize) -> Self {
69 Self {
70 messages: VecDeque::new(),
71 window_size,
72 summarized_prefix: None,
73 total_messages_seen: 0,
74 pinned: std::collections::HashSet::new(),
75 compressed_offset: 0,
76 }
77 }
78
79 pub fn add(&mut self, message: Message) {
81 self.messages.push_back(message);
82 self.total_messages_seen += 1;
83 }
84
85 pub fn to_messages(&self) -> Vec<Message> {
87 let mut result = Vec::new();
88
89 if let Some(ref summary) = self.summarized_prefix {
91 result.push(Message::system(format!(
92 "[Summary of earlier conversation]\n{}",
93 summary
94 )));
95 }
96
97 let start = if self.messages.len() > self.window_size {
100 self.messages.len() - self.window_size
101 } else {
102 0
103 };
104
105 for (i, msg) in self.messages.iter().enumerate() {
106 if i >= start || self.is_pinned(i) {
107 result.push(msg.clone());
108 }
109 }
110
111 result
112 }
113
114 pub fn needs_compression(&self) -> bool {
116 self.messages.len() >= self.window_size * 2
117 }
118
119 pub fn compress(&mut self, summary: String) -> usize {
125 if self.messages.len() <= self.window_size {
126 return 0;
127 }
128
129 let to_remove = self.messages.len() - self.window_size;
130
131 let mut preserve_indices: HashSet<usize> = HashSet::new();
133 for i in 0..to_remove {
134 let abs_idx = self.compressed_offset + i;
135 if self.pinned.contains(&abs_idx) {
136 preserve_indices.insert(i);
137 }
138 }
139
140 let mut extra_preserves: Vec<usize> = Vec::new();
142 for &i in &preserve_indices {
143 if let Some(msg) = self.messages.get(i) {
144 match &msg.content {
145 Content::ToolResult { call_id, .. } => {
147 if let Some(pair_idx) =
148 Self::find_tool_call_for_result(call_id, &self.messages, to_remove)
149 {
150 if !preserve_indices.contains(&pair_idx) {
151 extra_preserves.push(pair_idx);
152 }
153 }
154 }
155 Content::ToolCall { id, .. } => {
157 if let Some(pair_idx) =
158 Self::find_tool_result_for_call(id, &self.messages, to_remove)
159 {
160 if !preserve_indices.contains(&pair_idx) {
161 extra_preserves.push(pair_idx);
162 }
163 }
164 }
165 Content::MultiPart { parts } => {
167 for part in parts {
168 match part {
169 Content::ToolCall { id, .. } => {
170 if let Some(pair_idx) = Self::find_tool_result_for_call(
171 id,
172 &self.messages,
173 to_remove,
174 ) {
175 if !preserve_indices.contains(&pair_idx) {
176 extra_preserves.push(pair_idx);
177 }
178 }
179 }
180 Content::ToolResult { call_id, .. } => {
181 if let Some(pair_idx) = Self::find_tool_call_for_result(
182 call_id,
183 &self.messages,
184 to_remove,
185 ) {
186 if !preserve_indices.contains(&pair_idx) {
187 extra_preserves.push(pair_idx);
188 }
189 }
190 }
191 _ => {}
192 }
193 }
194 }
195 _ => {}
196 }
197 }
198 }
199 for idx in extra_preserves {
200 preserve_indices.insert(idx);
201 }
202
203 let mut preserved = Vec::new();
205 let mut removed_count = 0;
206
207 for i in 0..to_remove {
208 if preserve_indices.contains(&i) {
209 if let Some(msg) = self.messages.get(i) {
210 preserved.push(msg.clone());
211 }
212 } else {
213 removed_count += 1;
214 }
215 }
216
217 let mut surviving_pinned: Vec<usize> = Vec::new();
220 for i in to_remove..self.messages.len() {
221 let abs_idx = self.compressed_offset + i;
222 if self.pinned.contains(&abs_idx) {
223 surviving_pinned.push(i - to_remove);
224 }
225 }
226
227 self.messages.drain(..to_remove);
229 self.compressed_offset += to_remove;
230
231 let preserved_count = preserved.len();
234 if !preserved.is_empty() {
235 let mut new_messages = VecDeque::with_capacity(preserved_count + self.messages.len());
236 for msg in preserved {
237 new_messages.push_back(msg);
238 }
239 new_messages.append(&mut self.messages);
240 self.messages = new_messages;
241 }
242
243 let mut new_pinned = HashSet::new();
247 for i in 0..preserved_count {
248 new_pinned.insert(self.compressed_offset + i);
249 }
250 for pos in surviving_pinned {
251 new_pinned.insert(self.compressed_offset + preserved_count + pos);
252 }
253 self.pinned = new_pinned;
254
255 if let Some(ref existing) = self.summarized_prefix {
257 self.summarized_prefix = Some(format!("{}\n\n{}", existing, summary));
258 } else {
259 self.summarized_prefix = Some(summary);
260 }
261
262 removed_count
263 }
264
265 fn find_tool_call_for_result(
268 call_id: &str,
269 messages: &VecDeque<Message>,
270 limit: usize,
271 ) -> Option<usize> {
272 messages
273 .iter()
274 .enumerate()
275 .take(limit)
276 .find(|(_, msg)| {
277 msg.role == Role::Assistant
278 && Self::content_contains_tool_call_id(&msg.content, call_id)
279 })
280 .map(|(i, _)| i)
281 }
282
283 fn find_tool_result_for_call(
286 tool_call_id: &str,
287 messages: &VecDeque<Message>,
288 limit: usize,
289 ) -> Option<usize> {
290 messages
291 .iter()
292 .enumerate()
293 .take(limit)
294 .find(|(_, msg)| {
295 (msg.role == Role::Tool || msg.role == Role::User)
296 && Self::content_contains_tool_result_id(&msg.content, tool_call_id)
297 })
298 .map(|(i, _)| i)
299 }
300
301 fn content_contains_tool_call_id(content: &Content, target_id: &str) -> bool {
303 match content {
304 Content::ToolCall { id, .. } => id == target_id,
305 Content::MultiPart { parts } => parts
306 .iter()
307 .any(|p| Self::content_contains_tool_call_id(p, target_id)),
308 _ => false,
309 }
310 }
311
312 fn content_contains_tool_result_id(content: &Content, target_id: &str) -> bool {
314 match content {
315 Content::ToolResult { call_id, .. } => call_id == target_id,
316 Content::MultiPart { parts } => parts
317 .iter()
318 .any(|p| Self::content_contains_tool_result_id(p, target_id)),
319 _ => false,
320 }
321 }
322
323 pub fn pin(&mut self, position: usize) -> bool {
326 if position >= self.messages.len() {
327 return false;
328 }
329 let abs_idx = self.compressed_offset + position;
330 self.pinned.insert(abs_idx);
331 true
332 }
333
334 pub fn unpin(&mut self, position: usize) -> bool {
336 if position >= self.messages.len() {
337 return false;
338 }
339 let abs_idx = self.compressed_offset + position;
340 self.pinned.remove(&abs_idx)
341 }
342
343 pub fn is_pinned(&self, position: usize) -> bool {
345 let abs_idx = self.compressed_offset + position;
346 self.pinned.contains(&abs_idx)
347 }
348
349 pub fn pinned_count(&self) -> usize {
351 (0..self.messages.len())
353 .filter(|&i| self.is_pinned(i))
354 .count()
355 }
356
357 pub fn messages_to_summarize(&self) -> Vec<&Message> {
359 if self.messages.len() <= self.window_size {
360 return Vec::new();
361 }
362 let to_summarize = self.messages.len() - self.window_size;
363 self.messages.iter().take(to_summarize).collect()
364 }
365
366 pub fn len(&self) -> usize {
368 self.messages.len()
369 }
370
371 pub fn is_empty(&self) -> bool {
373 self.messages.is_empty()
374 }
375
376 pub fn total_messages_seen(&self) -> usize {
378 self.total_messages_seen
379 }
380
381 pub fn clear(&mut self) {
383 self.messages.clear();
384 self.summarized_prefix = None;
385 self.total_messages_seen = 0;
386 self.pinned.clear();
387 self.compressed_offset = 0;
388 }
389
390 pub fn messages(&self) -> &VecDeque<Message> {
392 &self.messages
393 }
394
395 pub fn summary(&self) -> Option<&str> {
397 self.summarized_prefix.as_deref()
398 }
399}
400
401#[derive(Debug, Clone, Serialize, Deserialize)]
403pub struct Fact {
404 pub id: Uuid,
405 pub content: String,
406 pub source: String,
407 pub created_at: DateTime<Utc>,
408 pub tags: Vec<String>,
409}
410
411impl Fact {
412 pub fn new(content: impl Into<String>, source: impl Into<String>) -> Self {
413 Self {
414 id: Uuid::new_v4(),
415 content: content.into(),
416 source: source.into(),
417 created_at: Utc::now(),
418 tags: Vec::new(),
419 }
420 }
421
422 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
423 self.tags = tags;
424 self
425 }
426}
427
428#[derive(Debug, Clone, Serialize, Deserialize)]
430pub struct LongTermMemory {
431 pub facts: Vec<Fact>,
432 pub preferences: HashMap<String, String>,
433 pub corrections: Vec<Correction>,
434 #[serde(default = "LongTermMemory::default_max_facts")]
436 pub max_facts: usize,
437 #[serde(default = "LongTermMemory::default_max_corrections")]
439 pub max_corrections: usize,
440}
441
442impl Default for LongTermMemory {
443 fn default() -> Self {
444 Self {
445 facts: Vec::new(),
446 preferences: HashMap::new(),
447 corrections: Vec::new(),
448 max_facts: Self::default_max_facts(),
449 max_corrections: Self::default_max_corrections(),
450 }
451 }
452}
453
454#[derive(Debug, Clone, Serialize, Deserialize)]
456pub struct Correction {
457 pub id: Uuid,
458 pub original: String,
459 pub corrected: String,
460 pub context: String,
461 pub timestamp: DateTime<Utc>,
462}
463
464impl LongTermMemory {
465 pub fn new() -> Self {
466 Self::default()
467 }
468
469 fn default_max_facts() -> usize {
470 10_000
471 }
472
473 fn default_max_corrections() -> usize {
474 1_000
475 }
476
477 pub fn add_fact(&mut self, fact: Fact) {
478 if self.facts.len() >= self.max_facts {
479 self.facts.remove(0);
480 }
481 self.facts.push(fact);
482 }
483
484 pub fn set_preference(&mut self, key: impl Into<String>, value: impl Into<String>) {
485 self.preferences.insert(key.into(), value.into());
486 }
487
488 pub fn get_preference(&self, key: &str) -> Option<&str> {
489 self.preferences.get(key).map(|s| s.as_str())
490 }
491
492 pub fn add_correction(&mut self, original: String, corrected: String, context: String) {
493 if self.corrections.len() >= self.max_corrections {
494 self.corrections.remove(0);
495 }
496 self.corrections.push(Correction {
497 id: Uuid::new_v4(),
498 original,
499 corrected,
500 context,
501 timestamp: Utc::now(),
502 });
503 }
504
505 pub fn search_facts(&self, query: &str) -> Vec<&Fact> {
507 let query_lower = query.to_lowercase();
508 self.facts
509 .iter()
510 .filter(|f| {
511 f.content.to_lowercase().contains(&query_lower)
512 || f.tags
513 .iter()
514 .any(|t| t.to_lowercase().contains(&query_lower))
515 })
516 .collect()
517 }
518}
519
520pub struct MemorySystem {
522 pub working: WorkingMemory,
523 pub short_term: ShortTermMemory,
524 pub long_term: LongTermMemory,
525 search_engine: Option<HybridSearchEngine>,
527 flusher: Option<MemoryFlusher>,
529}
530
531impl MemorySystem {
532 pub fn new(window_size: usize) -> Self {
533 Self {
534 working: WorkingMemory::new(),
535 short_term: ShortTermMemory::new(window_size),
536 long_term: LongTermMemory::new(),
537 search_engine: None,
538 flusher: None,
539 }
540 }
541
542 pub fn with_search(
544 window_size: usize,
545 search_config: SearchConfig,
546 ) -> Result<Self, crate::search::SearchError> {
547 let engine = HybridSearchEngine::open(search_config)?;
548 Ok(Self {
549 working: WorkingMemory::new(),
550 short_term: ShortTermMemory::new(window_size),
551 long_term: LongTermMemory::new(),
552 search_engine: Some(engine),
553 flusher: None,
554 })
555 }
556
557 pub fn with_flusher(mut self, config: FlushConfig) -> Self {
559 self.flusher = Some(MemoryFlusher::new(config));
560 self
561 }
562
563 pub fn context_messages(&self) -> Vec<Message> {
565 self.short_term.to_messages()
566 }
567
568 pub fn add_message(&mut self, message: Message) {
570 self.short_term.add(message);
571 if let Some(ref mut flusher) = self.flusher {
573 flusher.on_message_added();
574 }
575 }
576
577 pub fn add_fact(&mut self, fact: Fact) {
579 if let Some(ref mut engine) = self.search_engine {
580 let _ = engine.index_fact(&fact.id.to_string(), &fact.content);
581 }
582 self.long_term.add_fact(fact);
583 }
584
585 pub fn search_facts_hybrid(&self, query: &str) -> Vec<&Fact> {
587 if let Some(ref engine) = self.search_engine {
588 if let Ok(results) = engine.search(query) {
589 let ids: Vec<String> = results.iter().map(|r| r.fact_id.clone()).collect();
590 let found: Vec<&Fact> = self
591 .long_term
592 .facts
593 .iter()
594 .filter(|f| ids.contains(&f.id.to_string()))
595 .collect();
596 if !found.is_empty() {
597 return found;
598 }
599 }
600 }
601 self.long_term.search_facts(query)
603 }
604
605 pub fn check_auto_flush(&mut self) -> Result<bool, MemoryError> {
609 let mut flusher = match self.flusher.take() {
610 Some(f) => f,
611 None => return Ok(false),
612 };
613 let result = if flusher.should_flush() {
614 flusher.flush(self)?;
615 Ok(true)
616 } else {
617 Ok(false)
618 };
619 self.flusher = Some(flusher);
620 result
621 }
622
623 pub fn force_flush(&mut self) -> Result<(), MemoryError> {
625 let mut flusher = match self.flusher.take() {
626 Some(f) => f,
627 None => return Ok(()),
628 };
629 let result = flusher.force_flush(self);
630 self.flusher = Some(flusher);
631 result
632 }
633
634 pub fn flusher_is_dirty(&self) -> bool {
636 self.flusher.as_ref().is_some_and(|f| f.is_dirty())
637 }
638
639 pub fn start_new_task(&mut self, goal: impl Into<String>) {
641 self.working.clear();
642 self.working.set_goal(goal);
643 }
644
645 pub fn clear_session(&mut self) {
647 self.working.clear();
648 self.short_term.clear();
649 }
650
651 pub fn context_breakdown(&self, context_window: usize) -> ContextBreakdown {
653 let summary_chars = self.short_term.summary().map(|s| s.len()).unwrap_or(0);
654 let message_chars: usize = self
655 .short_term
656 .messages()
657 .iter()
658 .map(|m| m.content_length())
659 .sum();
660 let total_chars = summary_chars + message_chars;
661
662 let summary_tokens = summary_chars / 4;
664 let message_tokens = message_chars / 4;
665 let total_tokens = total_chars / 4;
666 let remaining_tokens = context_window.saturating_sub(total_tokens);
667
668 ContextBreakdown {
669 summary_tokens,
670 message_tokens,
671 total_tokens,
672 context_window,
673 remaining_tokens,
674 message_count: self.short_term.len(),
675 total_messages_seen: self.short_term.total_messages_seen(),
676 pinned_count: self.short_term.pinned_count(),
677 has_summary: self.short_term.summary().is_some(),
678 facts_count: self.long_term.facts.len(),
679 rules_count: 0, }
681 }
682
683 pub fn pin_message(&mut self, position: usize) -> bool {
685 self.short_term.pin(position)
686 }
687
688 pub fn unpin_message(&mut self, position: usize) -> bool {
690 self.short_term.unpin(position)
691 }
692}
693
/// Snapshot of how much of the model's context window is in use.
#[derive(Debug, Clone, Default)]
pub struct ContextBreakdown {
    /// Estimated tokens taken by the compressed-history summary.
    pub summary_tokens: usize,
    /// Estimated tokens taken by buffered messages.
    pub message_tokens: usize,
    /// Estimated total tokens (summary + messages).
    pub total_tokens: usize,
    /// Size of the model's context window, in tokens.
    pub context_window: usize,
    /// `context_window - total_tokens`, floored at zero.
    pub remaining_tokens: usize,
    /// Messages currently buffered in short-term memory.
    pub message_count: usize,
    /// Messages ever added to short-term memory.
    pub total_messages_seen: usize,
    /// Buffered messages that are pinned.
    pub pinned_count: usize,
    /// Whether a compression summary exists.
    pub has_summary: bool,
    /// Facts held in long-term memory.
    pub facts_count: usize,
    /// Learned behavioral rules.
    pub rules_count: usize,
}

impl ContextBreakdown {
    /// Fraction of the window in use, clamped to `[0.0, 1.0]`.
    ///
    /// Returns 0.0 for an empty (zero-sized) window.
    pub fn usage_ratio(&self) -> f32 {
        match self.context_window {
            0 => 0.0,
            window => {
                let used = self.total_tokens as f32;
                (used / window as f32).clamp(0.0, 1.0)
            }
        }
    }

    /// `true` once usage reaches 70% of the window.
    pub fn is_warning(&self) -> bool {
        let threshold: f32 = 0.7;
        self.usage_ratio() >= threshold
    }
}
735}
736
737#[derive(Debug, Clone, Serialize, Deserialize)]
739pub struct SessionMetadata {
740 pub id: Uuid,
741 pub created_at: DateTime<Utc>,
742 pub updated_at: DateTime<Utc>,
743 pub task_summary: Option<String>,
744}
745
746impl SessionMetadata {
747 pub fn new() -> Self {
748 let now = Utc::now();
749 Self {
750 id: Uuid::new_v4(),
751 created_at: now,
752 updated_at: now,
753 task_summary: None,
754 }
755 }
756}
757
758impl Default for SessionMetadata {
759 fn default() -> Self {
760 Self::new()
761 }
762}
763
764#[derive(Debug, Clone, Serialize, Deserialize)]
766pub struct Session {
767 pub metadata: SessionMetadata,
768 pub working: WorkingMemory,
769 pub long_term: LongTermMemory,
770 pub messages: Vec<Message>,
771 pub window_size: usize,
772}
773
774impl MemorySystem {
775 pub fn save_session(&self, path: &Path) -> Result<(), MemoryError> {
777 let session = Session {
778 metadata: SessionMetadata {
779 id: Uuid::new_v4(),
780 created_at: Utc::now(),
781 updated_at: Utc::now(),
782 task_summary: self.working.current_goal.clone(),
783 },
784 working: self.working.clone(),
785 long_term: self.long_term.clone(),
786 messages: self.short_term.messages().iter().cloned().collect(),
787 window_size: self.short_term.window_size(),
788 };
789
790 let json =
791 serde_json::to_string_pretty(&session).map_err(|e| MemoryError::PersistenceError {
792 message: format!("Failed to serialize session: {}", e),
793 })?;
794
795 if let Some(parent) = path.parent() {
796 std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
797 message: format!("Failed to create directory: {}", e),
798 })?;
799 }
800
801 std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
802 message: format!("Failed to write session file: {}", e),
803 })?;
804
805 Ok(())
806 }
807
808 pub fn load_session(path: &Path) -> Result<Self, MemoryError> {
810 let json = std::fs::read_to_string(path).map_err(|e| MemoryError::SessionLoadFailed {
811 message: format!("Failed to read session file: {}", e),
812 })?;
813
814 let session: Session =
815 serde_json::from_str(&json).map_err(|e| MemoryError::SessionLoadFailed {
816 message: format!("Failed to deserialize session: {}", e),
817 })?;
818
819 let mut memory = MemorySystem::new(session.window_size);
820 memory.working = session.working;
821 memory.long_term = session.long_term;
822 for msg in session.messages {
823 memory.short_term.add(msg);
824 }
825
826 Ok(memory)
827 }
828}
829
830impl ShortTermMemory {
831 pub fn window_size(&self) -> usize {
833 self.window_size
834 }
835}
836
/// Outcome of a compression pass.
///
/// Field meanings are inferred from their names. This type is not used in the
/// visible code, so confirm them against callers.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// Presumably the buffered message count before compressing.
    pub messages_before: usize,
    /// Presumably the buffered message count after compressing.
    pub messages_after: usize,
    /// Presumably the number of messages folded into the summary.
    pub compressed_count: usize,
}
844
845#[derive(Debug, Clone, Serialize, Deserialize)]
851pub struct FlushConfig {
852 pub enabled: bool,
854 pub interval_secs: u64,
856 pub flush_on_message_count: usize,
858 pub flush_path: Option<std::path::PathBuf>,
860}
861
862impl Default for FlushConfig {
863 fn default() -> Self {
864 Self {
865 enabled: false,
866 interval_secs: 300, flush_on_message_count: 50,
868 flush_path: None,
869 }
870 }
871}
872
873#[derive(Debug, Clone)]
875pub struct MemoryFlusher {
876 config: FlushConfig,
877 dirty: bool,
878 messages_since_flush: usize,
879 last_flush: DateTime<Utc>,
880 total_flushes: usize,
881}
882
883impl MemoryFlusher {
884 pub fn new(config: FlushConfig) -> Self {
886 Self {
887 config,
888 dirty: false,
889 messages_since_flush: 0,
890 last_flush: Utc::now(),
891 total_flushes: 0,
892 }
893 }
894
895 pub fn on_message_added(&mut self) {
897 self.dirty = true;
898 self.messages_since_flush += 1;
899 }
900
901 pub fn should_flush(&self) -> bool {
903 if !self.config.enabled || !self.dirty {
904 return false;
905 }
906
907 if self.config.flush_on_message_count > 0
909 && self.messages_since_flush >= self.config.flush_on_message_count
910 {
911 return true;
912 }
913
914 if self.config.interval_secs > 0 {
916 let elapsed = (Utc::now() - self.last_flush).num_seconds();
917 if elapsed >= self.config.interval_secs as i64 {
918 return true;
919 }
920 }
921
922 false
923 }
924
925 pub fn flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
927 let path =
928 self.config
929 .flush_path
930 .as_ref()
931 .ok_or_else(|| MemoryError::PersistenceError {
932 message: "No flush path configured".to_string(),
933 })?;
934
935 memory.save_session(path)?;
936 self.mark_flushed();
937 Ok(())
938 }
939
940 pub fn force_flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
942 if !self.dirty {
943 return Ok(()); }
945 self.flush(memory)
946 }
947
948 pub fn is_dirty(&self) -> bool {
950 self.dirty
951 }
952
953 pub fn messages_since_flush(&self) -> usize {
955 self.messages_since_flush
956 }
957
958 pub fn total_flushes(&self) -> usize {
960 self.total_flushes
961 }
962
963 fn mark_flushed(&mut self) {
965 self.dirty = false;
966 self.messages_since_flush = 0;
967 self.last_flush = Utc::now();
968 self.total_flushes += 1;
969 }
970}
971
972#[derive(Debug, Clone, Serialize, Deserialize)]
978pub struct BehavioralRule {
979 pub id: Uuid,
981 pub rule: String,
983 pub source_ids: Vec<Uuid>,
985 pub support_count: usize,
987 pub created_at: DateTime<Utc>,
989}
990
991#[derive(Debug, Clone, Default, Serialize, Deserialize)]
993pub struct KnowledgeStore {
994 pub rules: Vec<BehavioralRule>,
995 pub processed_correction_ids: Vec<Uuid>,
997 pub processed_fact_ids: Vec<Uuid>,
999}
1000
1001impl KnowledgeStore {
1002 pub fn new() -> Self {
1003 Self::default()
1004 }
1005
1006 pub fn load(path: &std::path::Path) -> Result<Self, MemoryError> {
1008 if !path.exists() {
1009 return Ok(Self::new());
1010 }
1011 let json = std::fs::read_to_string(path).map_err(|e| MemoryError::PersistenceError {
1012 message: format!("Failed to read knowledge store: {}", e),
1013 })?;
1014 serde_json::from_str(&json).map_err(|e| MemoryError::PersistenceError {
1015 message: format!("Failed to parse knowledge store: {}", e),
1016 })
1017 }
1018
1019 pub fn save(&self, path: &std::path::Path) -> Result<(), MemoryError> {
1021 if let Some(parent) = path.parent() {
1022 std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
1023 message: format!("Failed to create knowledge directory: {}", e),
1024 })?;
1025 }
1026 let json =
1027 serde_json::to_string_pretty(self).map_err(|e| MemoryError::PersistenceError {
1028 message: format!("Failed to serialize knowledge store: {}", e),
1029 })?;
1030 std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
1031 message: format!("Failed to write knowledge store: {}", e),
1032 })
1033 }
1034}
1035
1036pub struct KnowledgeDistiller {
1041 store: KnowledgeStore,
1042 max_rules: usize,
1043 min_entries: usize,
1044 store_path: Option<std::path::PathBuf>,
1045}
1046
1047impl KnowledgeDistiller {
1048 pub fn new(config: Option<&crate::config::KnowledgeConfig>) -> Self {
1051 match config {
1052 Some(cfg) if cfg.enabled => {
1053 let store = cfg
1054 .knowledge_path
1055 .as_ref()
1056 .and_then(|p| KnowledgeStore::load(p).ok())
1057 .unwrap_or_default();
1058 Self {
1059 store,
1060 max_rules: cfg.max_rules,
1061 min_entries: cfg.min_entries_for_distillation,
1062 store_path: cfg.knowledge_path.clone(),
1063 }
1064 }
1065 _ => Self {
1066 store: KnowledgeStore::new(),
1067 max_rules: 0,
1068 min_entries: usize::MAX,
1069 store_path: None,
1070 },
1071 }
1072 }
1073
1074 pub fn distill(&mut self, long_term: &LongTermMemory) {
1079 let new_corrections: Vec<&Correction> = long_term
1081 .corrections
1082 .iter()
1083 .filter(|c| !self.store.processed_correction_ids.contains(&c.id))
1084 .collect();
1085
1086 let new_facts: Vec<&Fact> = long_term
1088 .facts
1089 .iter()
1090 .filter(|f| !self.store.processed_fact_ids.contains(&f.id))
1091 .collect();
1092
1093 let total_new = new_corrections.len() + new_facts.len();
1094 if total_new < self.min_entries {
1095 return; }
1097
1098 let mut context_groups: HashMap<String, Vec<&Correction>> = HashMap::new();
1101 for correction in &new_corrections {
1102 let key = correction
1104 .context
1105 .chars()
1106 .take(50)
1107 .collect::<String>()
1108 .to_lowercase();
1109 context_groups.entry(key).or_default().push(correction);
1110 }
1111
1112 for group in context_groups.values() {
1114 if group.len() >= 2 {
1115 let corrected_patterns: Vec<&str> =
1117 group.iter().map(|c| c.corrected.as_str()).collect();
1118 let rule_text = format!(
1119 "Based on {} previous corrections: prefer {}",
1120 group.len(),
1121 corrected_patterns.join("; ")
1122 );
1123 let source_ids: Vec<Uuid> = group.iter().map(|c| c.id).collect();
1124 self.store.rules.push(BehavioralRule {
1125 id: Uuid::new_v4(),
1126 rule: rule_text,
1127 source_ids,
1128 support_count: group.len(),
1129 created_at: Utc::now(),
1130 });
1131 } else {
1132 for c in group {
1134 self.store.rules.push(BehavioralRule {
1135 id: Uuid::new_v4(),
1136 rule: format!("Instead of '{}', prefer '{}'", c.original, c.corrected),
1137 source_ids: vec![c.id],
1138 support_count: 1,
1139 created_at: Utc::now(),
1140 });
1141 }
1142 }
1143 }
1144
1145 for fact in &new_facts {
1149 let is_preference = fact.tags.iter().any(|t| t == "preference")
1150 || fact.content.starts_with("Prefer")
1151 || fact.content.starts_with("Always")
1152 || fact.content.starts_with("Never")
1153 || fact.content.starts_with("Don't")
1154 || fact.content.starts_with("Use ");
1155 if is_preference {
1156 self.store.rules.push(BehavioralRule {
1157 id: Uuid::new_v4(),
1158 rule: fact.content.clone(),
1159 source_ids: vec![fact.id],
1160 support_count: 1,
1161 created_at: Utc::now(),
1162 });
1163 }
1164 }
1165
1166 for c in &new_corrections {
1168 self.store.processed_correction_ids.push(c.id);
1169 }
1170 for f in &new_facts {
1171 self.store.processed_fact_ids.push(f.id);
1172 }
1173
1174 if self.store.rules.len() > self.max_rules {
1176 self.store
1177 .rules
1178 .sort_by(|a, b| b.support_count.cmp(&a.support_count));
1179 self.store.rules.truncate(self.max_rules);
1180 }
1181
1182 if let Some(ref path) = self.store_path {
1184 let _ = self.store.save(path);
1185 }
1186 }
1187
1188 pub fn rules_for_prompt(&self) -> String {
1192 if self.store.rules.is_empty() {
1193 return String::new();
1194 }
1195 let mut prompt = String::from(
1196 "\n\n## Learned Behavioral Rules\n\
1197 The following rules were distilled from previous sessions. Follow them:\n",
1198 );
1199 for (i, rule) in self.store.rules.iter().enumerate() {
1200 prompt.push_str(&format!("{}. {}\n", i + 1, rule.rule));
1201 }
1202 prompt
1203 }
1204
1205 pub fn rule_count(&self) -> usize {
1207 self.store.rules.len()
1208 }
1209
1210 pub fn store(&self) -> &KnowledgeStore {
1212 &self.store
1213 }
1214}
1215
1216#[cfg(test)]
1217mod tests {
1218 use super::*;
1219 use crate::types::{Content, Role};
1220
1221 #[test]
1222 fn test_working_memory_lifecycle() {
1223 let mut wm = WorkingMemory::new();
1224 assert!(wm.current_goal.is_none());
1225
1226 wm.set_goal("refactor auth module");
1227 assert_eq!(wm.current_goal.as_deref(), Some("refactor auth module"));
1228
1229 wm.add_sub_task("read current implementation");
1230 wm.add_sub_task("design new structure");
1231 assert_eq!(wm.sub_tasks.len(), 2);
1232
1233 wm.note("finding", "uses basic auth currently");
1234 assert_eq!(
1235 wm.scratchpad.get("finding").map(|s| s.as_str()),
1236 Some("uses basic auth currently")
1237 );
1238
1239 wm.add_active_file("src/auth/mod.rs");
1240 wm.add_active_file("src/auth/mod.rs"); assert_eq!(wm.active_files.len(), 1);
1242
1243 wm.clear();
1244 assert!(wm.current_goal.is_none());
1245 assert!(wm.sub_tasks.is_empty());
1246 }
1247
1248 #[test]
1249 fn test_short_term_memory_basic() {
1250 let mut stm = ShortTermMemory::new(5);
1251 assert!(stm.is_empty());
1252 assert_eq!(stm.len(), 0);
1253
1254 stm.add(Message::user("hello"));
1255 stm.add(Message::assistant("hi there"));
1256 assert_eq!(stm.len(), 2);
1257 assert_eq!(stm.total_messages_seen(), 2);
1258
1259 let messages = stm.to_messages();
1260 assert_eq!(messages.len(), 2);
1261 }
1262
1263 #[test]
1264 fn test_short_term_memory_window() {
1265 let mut stm = ShortTermMemory::new(3);
1266
1267 for i in 0..6 {
1268 stm.add(Message::user(format!("message {}", i)));
1269 }
1270
1271 assert_eq!(stm.len(), 6);
1272 let messages = stm.to_messages();
1273 assert_eq!(messages.len(), 3);
1275 assert_eq!(messages[0].content.as_text(), Some("message 3"));
1276 assert_eq!(messages[2].content.as_text(), Some("message 5"));
1277 }
1278
1279 #[test]
1280 fn test_short_term_memory_compression() {
1281 let mut stm = ShortTermMemory::new(3);
1282
1283 for i in 0..6 {
1284 stm.add(Message::user(format!("message {}", i)));
1285 }
1286
1287 assert!(stm.needs_compression());
1288
1289 let to_summarize = stm.messages_to_summarize();
1290 assert_eq!(to_summarize.len(), 3); let compressed = stm.compress("Summary of messages 0-2.".to_string());
1293 assert_eq!(compressed, 3);
1294 assert_eq!(stm.len(), 3); let messages = stm.to_messages();
1297 assert_eq!(messages.len(), 4); assert!(messages[0]
1300 .content
1301 .as_text()
1302 .unwrap()
1303 .contains("Summary of"));
1304 assert_eq!(messages[0].role, Role::System);
1305 }
1306
1307 #[test]
1308 fn test_short_term_memory_double_compression() {
1309 let mut stm = ShortTermMemory::new(2);
1310
1311 for i in 0..5 {
1313 stm.add(Message::user(format!("msg {}", i)));
1314 }
1315 stm.compress("First summary.".to_string());
1316 assert_eq!(stm.len(), 2);
1317
1318 for i in 5..8 {
1320 stm.add(Message::user(format!("msg {}", i)));
1321 }
1322 stm.compress("Second summary.".to_string());
1323 assert_eq!(stm.len(), 2);
1324
1325 let summary = stm.summary().unwrap();
1327 assert!(summary.contains("First summary."));
1328 assert!(summary.contains("Second summary."));
1329 }
1330
1331 #[test]
1332 fn test_short_term_memory_clear() {
1333 let mut stm = ShortTermMemory::new(5);
1334 stm.add(Message::user("test"));
1335 stm.compress("summary".to_string());
1336
1337 stm.clear();
1338 assert!(stm.is_empty());
1339 assert!(stm.summary().is_none());
1340 assert_eq!(stm.total_messages_seen(), 0);
1341 }
1342
1343 #[test]
1344 fn test_fact_creation() {
1345 let fact = Fact::new("Project uses JWT auth", "code analysis")
1346 .with_tags(vec!["auth".to_string(), "jwt".to_string()]);
1347 assert_eq!(fact.content, "Project uses JWT auth");
1348 assert_eq!(fact.source, "code analysis");
1349 assert_eq!(fact.tags.len(), 2);
1350 }
1351
1352 #[test]
1353 fn test_long_term_memory() {
1354 let mut ltm = LongTermMemory::new();
1355
1356 ltm.add_fact(Fact::new("Uses Rust 2021 edition", "Cargo.toml"));
1357 ltm.set_preference("code_style", "rustfmt defaults");
1358 ltm.add_correction(
1359 "wrong import".to_string(),
1360 "correct import".to_string(),
1361 "editing main.rs".to_string(),
1362 );
1363
1364 assert_eq!(ltm.facts.len(), 1);
1365 assert_eq!(ltm.get_preference("code_style"), Some("rustfmt defaults"));
1366 assert_eq!(ltm.corrections.len(), 1);
1367 }
1368
1369 #[test]
1370 fn test_long_term_memory_search() {
1371 let mut ltm = LongTermMemory::new();
1372 ltm.add_fact(Fact::new("Project uses JWT authentication", "analysis"));
1373 ltm.add_fact(
1374 Fact::new("Database is PostgreSQL", "config").with_tags(vec!["database".to_string()]),
1375 );
1376 ltm.add_fact(Fact::new("Frontend uses React", "package.json"));
1377
1378 let results = ltm.search_facts("JWT");
1379 assert_eq!(results.len(), 1);
1380 assert!(results[0].content.contains("JWT"));
1381
1382 let results = ltm.search_facts("database");
1383 assert_eq!(results.len(), 1);
1384
1385 let results = ltm.search_facts("nonexistent");
1386 assert!(results.is_empty());
1387 }
1388
/// End-to-end smoke test: set a goal, add messages, read them back as
/// context, then clear the session.
#[test]
fn test_memory_system() {
    let mut system = MemorySystem::new(5);

    system.start_new_task("fix bug #42");
    assert_eq!(system.working.current_goal.as_deref(), Some("fix bug #42"));

    for msg in [
        Message::user("fix the null pointer bug"),
        Message::assistant("I'll look into that."),
    ] {
        system.add_message(msg);
    }
    assert_eq!(system.context_messages().len(), 2);

    // Clearing the session resets both short-term history and the working goal.
    system.clear_session();
    assert!(system.short_term.is_empty());
    assert!(system.working.current_goal.is_none());
}
1406
/// While the history still fits in the window, compression is a no-op that
/// reports zero removed messages.
#[test]
fn test_compression_no_op_when_within_window() {
    let mut stm = ShortTermMemory::new(10);
    stm.add(Message::user("hello"));
    stm.add(Message::assistant("hi"));

    // Two messages in a window of ten: nothing to summarise.
    assert!(!stm.needs_compression());
    assert!(stm.messages_to_summarize().is_empty());

    assert_eq!(stm.compress("should not matter".to_string()), 0);
}
1419
/// Starting a new task replaces the goal but leaves long-term facts intact.
#[test]
fn test_memory_system_new_task_preserves_long_term() {
    let mut system = MemorySystem::new(5);
    system.long_term.add_fact(Fact::new("important fact", "test"));
    system.add_message(Message::user("task 1"));

    system.start_new_task("task 2");

    let goal = system.working.current_goal.as_deref();
    assert_eq!(goal, Some("task 2"));
    assert_eq!(system.long_term.facts.len(), 1);
}
1430
/// Saving and reloading a session preserves the goal, the short-term history
/// (in order), facts and preferences.
#[test]
fn test_session_save_load_roundtrip() {
    let tmp = tempfile::tempdir().unwrap();
    let file = tmp.path().join("session.json");

    let mut original = MemorySystem::new(10);
    original.start_new_task("fix bug #42");
    original.add_message(Message::user("fix the bug"));
    original.add_message(Message::assistant("Looking into it."));
    original.long_term.add_fact(Fact::new("Uses Rust", "analysis"));
    original.long_term.set_preference("style", "concise");

    original.save_session(&file).unwrap();
    assert!(file.exists());

    let restored = MemorySystem::load_session(&file).unwrap();
    assert_eq!(restored.working.current_goal.as_deref(), Some("fix bug #42"));
    assert_eq!(restored.short_term.len(), 2);
    assert_eq!(restored.long_term.facts.len(), 1);
    assert_eq!(restored.long_term.get_preference("style"), Some("concise"));

    let messages = restored.context_messages();
    assert_eq!(messages.len(), 2);
    let texts: Vec<_> = messages.iter().map(|m| m.content.as_text()).collect();
    assert_eq!(texts, vec![Some("fix the bug"), Some("Looking into it.")]);
}
1463
/// Loading a session from a path that does not exist must return an error.
///
/// The missing path lives under a fresh temporary directory (its parent
/// directory is also absent). The test therefore never depends on the host
/// filesystem's layout, unlike a hard-coded absolute path such as
/// `/nonexistent/...`.
#[test]
fn test_session_load_missing_file() {
    let dir = tempfile::tempdir().unwrap();
    let missing = dir.path().join("nonexistent").join("session.json");

    let result = MemorySystem::load_session(&missing);
    assert!(result.is_err());
}
1469
/// A file that is not valid JSON must be rejected by `load_session`.
#[test]
fn test_session_load_corrupt_json() {
    let tmp = tempfile::tempdir().unwrap();
    let bad_file = tmp.path().join("bad.json");
    std::fs::write(&bad_file, "not valid json").unwrap();

    assert!(MemorySystem::load_session(&bad_file).is_err());
}
1479
/// `save_session` succeeds even when intermediate directories are missing,
/// and the file ends up at the requested location.
#[test]
fn test_session_save_creates_directories() {
    let tmp = tempfile::tempdir().unwrap();
    let target = tmp.path().join("nested").join("dir").join("session.json");

    MemorySystem::new(5).save_session(&target).unwrap();

    assert!(target.exists());
}
1489
/// Both `new` and `default` produce metadata without a task summary;
/// `new` stamps a creation time that is not in the future.
#[test]
fn test_session_metadata() {
    let fresh = SessionMetadata::new();
    assert!(fresh.task_summary.is_none());
    assert!(fresh.created_at <= Utc::now());

    assert!(SessionMetadata::default().task_summary.is_none());
}
1499
/// The configured window size is reported back unchanged.
#[test]
fn test_short_term_window_size() {
    assert_eq!(ShortTermMemory::new(7).window_size(), 7);
}
1505
/// Pinning a valid index succeeds and affects only that index.
#[test]
fn test_pin_message() {
    let mut stm = ShortTermMemory::new(5);
    for text in ["msg 0", "msg 1", "msg 2"] {
        stm.add(Message::user(text));
    }

    assert!(stm.pin(1));

    assert!(stm.is_pinned(1));
    assert!(!stm.is_pinned(0));
    assert_eq!(stm.pinned_count(), 1);
}
1520
/// Pinning an index past the end of the history is rejected.
#[test]
fn test_pin_out_of_bounds() {
    let mut stm = ShortTermMemory::new(5);
    stm.add(Message::user("msg 0"));

    // Only index 0 exists.
    let accepted = stm.pin(5);
    assert!(!accepted);
}
1527
/// A pinned message can be unpinned again.
#[test]
fn test_unpin_message() {
    let mut stm = ShortTermMemory::new(5);
    for text in ["msg 0", "msg 1"] {
        stm.add(Message::user(text));
    }

    stm.pin(0);
    assert!(stm.is_pinned(0));

    assert!(stm.unpin(0));
    assert!(!stm.is_pinned(0));
}
1539
/// A pinned message among the oldest entries is kept by `compress`, and it
/// reduces the number of messages that get removed.
#[test]
fn test_pinned_survives_compression() {
    let mut stm = ShortTermMemory::new(3);
    let texts = ["old 0", "old 1", "important pinned", "msg 3", "msg 4", "msg 5"];
    for text in texts {
        stm.add(Message::user(text));
    }

    // Pin "important pinned" (index 2).
    stm.pin(2);
    assert!(stm.needs_compression());

    let removed = stm.compress("Summary of old messages".to_string());
    assert!(removed < 3);

    let survivors = stm.to_messages();
    let has_pinned = survivors
        .iter()
        .any(|m| matches!(&m.content, Content::Text { text } if text == "important pinned"));
    assert!(has_pinned, "Pinned message should survive compression");
}
1566
/// `clear` also forgets every pin.
#[test]
fn test_clear_resets_pins() {
    let mut stm = ShortTermMemory::new(5);
    stm.add(Message::user("msg 0"));
    stm.pin(0);
    assert_eq!(stm.pinned_count(), 1);

    stm.clear();

    assert_eq!(stm.pinned_count(), 0);
}
1577
/// `context_breakdown` reports message counts and tokens for a short,
/// unsummarised, unpinned conversation.
#[test]
fn test_context_breakdown() {
    let mut system = MemorySystem::new(10);
    system.add_message(Message::user("hello world"));
    system.add_message(Message::assistant("hi there!"));

    let breakdown = system.context_breakdown(8000);

    assert_eq!(breakdown.context_window, 8000);
    assert_eq!(breakdown.message_count, 2);
    assert_eq!(breakdown.pinned_count, 0);
    assert!(breakdown.message_tokens > 0);
    assert!(breakdown.remaining_tokens > 0);
    assert!(!breakdown.has_summary);
}
1594
/// `usage_ratio` is `total_tokens / context_window`.
/// `is_warning` is off at 50% usage and on at 87.5% usage.
#[test]
fn test_context_breakdown_ratio() {
    let half_full = ContextBreakdown {
        total_tokens: 4000,
        context_window: 8000,
        ..Default::default()
    };
    let ratio = half_full.usage_ratio();
    assert!((ratio - 0.5).abs() < 0.01);
    assert!(!half_full.is_warning());

    let nearly_full = ContextBreakdown {
        total_tokens: 7000,
        context_window: 8000,
        ..Default::default()
    };
    assert!(nearly_full.is_warning());
}
1612
/// `MemorySystem::pin_message` pins the message in its short-term store.
#[test]
fn test_pin_message_via_memory_system() {
    let mut system = MemorySystem::new(10);
    for text in ["msg 0", "msg 1"] {
        system.add_message(Message::user(text));
    }

    assert!(system.pin_message(0));
    assert!(system.short_term.is_pinned(0));
}
1622
/// The default flush configuration is disabled, uses a 300 s interval and a
/// 50-message threshold, and has no target path.
#[test]
fn test_flusher_default_config() {
    let FlushConfig {
        enabled,
        interval_secs,
        flush_on_message_count,
        flush_path,
    } = FlushConfig::default();

    assert!(!enabled);
    assert_eq!(interval_secs, 300);
    assert_eq!(flush_on_message_count, 50);
    assert!(flush_path.is_none());
}
1633
/// A new flusher starts clean, with zero pending messages and zero flushes.
#[test]
fn test_flusher_not_dirty_by_default() {
    let fresh = MemoryFlusher::new(FlushConfig::default());

    assert!(!fresh.is_dirty());
    assert_eq!(fresh.messages_since_flush(), 0);
    assert_eq!(fresh.total_flushes(), 0);
}
1641
/// Recording a message marks the flusher dirty and bumps the pending count.
#[test]
fn test_flusher_marks_dirty_on_message() {
    let mut tracker = MemoryFlusher::new(FlushConfig::default());

    tracker.on_message_added();

    assert!(tracker.is_dirty());
    assert_eq!(tracker.messages_since_flush(), 1);
}
1649
/// A disabled flusher never asks to flush, however many messages arrive.
#[test]
fn test_flusher_disabled_never_triggers() {
    let config = FlushConfig {
        enabled: false,
        ..FlushConfig::default()
    };
    let mut flusher = MemoryFlusher::new(config);

    (0..100).for_each(|_| flusher.on_message_added());

    assert!(!flusher.should_flush());
}
1661
/// An enabled flusher requests a flush exactly when the message threshold
/// is reached, and not one message earlier.
#[test]
fn test_flusher_message_count_trigger() {
    let mut flusher = MemoryFlusher::new(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 5,
        flush_path: None,
    });

    // One below the threshold.
    (0..4).for_each(|_| flusher.on_message_added());
    assert!(!flusher.should_flush());

    // The fifth message reaches it.
    flusher.on_message_added();
    assert!(flusher.should_flush());
}
1679
/// Even with a threshold of one, a flusher with no recorded messages does
/// not request a flush.
#[test]
fn test_flusher_not_dirty_no_trigger() {
    let config = FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 1,
        flush_path: None,
    };

    assert!(!MemoryFlusher::new(config).should_flush());
}
1691
/// A successful `flush` writes the target file, clears the dirty state and
/// counts one flush.
#[test]
fn test_flusher_flush_resets_state() {
    let tmp = tempfile::tempdir().unwrap();
    let target = tmp.path().join("flush.json");

    let mut flusher = MemoryFlusher::new(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 2,
        flush_path: Some(target.clone()),
    });
    let mut system = MemorySystem::new(10);
    system.add_message(Message::user("test"));

    for _ in 0..2 {
        flusher.on_message_added();
    }
    assert!(flusher.should_flush());

    flusher.flush(&system).unwrap();

    assert!(!flusher.is_dirty());
    assert_eq!(flusher.messages_since_flush(), 0);
    assert_eq!(flusher.total_flushes(), 1);
    assert!(target.exists());
}
1717
/// `force_flush` on a clean flusher succeeds without counting a flush.
/// Once dirty, it flushes even though the count threshold is far off.
#[test]
fn test_flusher_force_flush() {
    let tmp = tempfile::tempdir().unwrap();
    let target = tmp.path().join("force.json");

    let mut flusher = MemoryFlusher::new(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 100,
        flush_path: Some(target.clone()),
    });
    let system = MemorySystem::new(10);

    // Nothing recorded yet: forced flush is a successful no-op.
    flusher.force_flush(&system).unwrap();
    assert_eq!(flusher.total_flushes(), 0);

    // One pending message is enough for a forced flush to count.
    flusher.on_message_added();
    flusher.force_flush(&system).unwrap();
    assert_eq!(flusher.total_flushes(), 1);
    assert!(!flusher.is_dirty());
}
1742
/// Flushing without a configured target path is an error.
#[test]
fn test_flusher_no_path_error() {
    let mut flusher = MemoryFlusher::new(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 1,
        flush_path: None,
    });
    flusher.on_message_added();

    let system = MemorySystem::new(10);
    assert!(flusher.flush(&system).is_err());
}
1757
/// `FlushConfig` survives a JSON round-trip with every field intact,
/// including the optional target path.
#[test]
fn test_flush_config_serialization() {
    let config = FlushConfig {
        enabled: true,
        interval_secs: 120,
        flush_on_message_count: 25,
        flush_path: Some(std::path::PathBuf::from("/tmp/flush.json")),
    };
    let json = serde_json::to_string(&config).unwrap();
    let restored: FlushConfig = serde_json::from_str(&json).unwrap();

    assert!(restored.enabled);
    assert_eq!(restored.interval_secs, 120);
    assert_eq!(restored.flush_on_message_count, 25);
    // Previously unchecked: a serde attribute that dropped the path would
    // otherwise go unnoticed.
    assert_eq!(restored.flush_path, config.flush_path);
}
1772
/// Without a search engine, hybrid search falls back to keyword matching
/// and returns only the matching fact.
#[test]
fn test_memory_system_without_search_uses_keyword_fallback() {
    let mut system = MemorySystem::new(10);
    system.add_fact(Fact::new("Rust uses ownership for memory safety", "docs"));
    system.add_fact(Fact::new("Python uses garbage collection", "docs"));

    let hits = system.search_facts_hybrid("ownership");

    assert_eq!(hits.len(), 1);
    assert!(hits[0].content.contains("ownership"));
}
1785
/// With a hybrid search engine attached, a query for "ownership" returns at
/// least the fact containing that word.
#[test]
fn test_memory_system_with_search_engine() {
    let tmp = tempfile::tempdir().unwrap();
    let root = tmp.path();
    let mut system = MemorySystem::with_search(
        10,
        SearchConfig {
            index_path: root.join("idx"),
            db_path: root.join("vec.db"),
            vector_dimensions: 64,
            full_text_weight: 0.5,
            vector_weight: 0.5,
            max_results: 10,
        },
    )
    .unwrap();

    system.add_fact(Fact::new("Rust uses ownership model", "analysis"));
    system.add_fact(Fact::new("Python garbage collector", "analysis"));

    let hits = system.search_facts_hybrid("ownership");
    assert!(!hits.is_empty());
    assert!(hits.iter().any(|fact| fact.content.contains("ownership")));
}
1807
/// An empty query still returns stored facts rather than nothing.
#[test]
fn test_memory_system_search_empty_query() {
    let mut system = MemorySystem::new(10);
    system.add_fact(Fact::new("some fact", "source"));

    assert!(!system.search_facts_hybrid("").is_empty());
}
1818
/// Searching an empty memory yields no results.
#[test]
fn test_memory_system_search_no_facts() {
    let empty = MemorySystem::new(10);
    assert!(empty.search_facts_hybrid("anything").is_empty());
}
1825
/// `add_fact` on a search-enabled system stores every fact and makes it
/// reachable through `search_facts_hybrid`.
///
/// The previous version only checked the length of `long_term.facts`. That
/// never exercised the search path the test name refers to.
#[test]
fn test_add_fact_indexes_into_search_engine() {
    let dir = tempfile::tempdir().unwrap();
    let config = SearchConfig {
        index_path: dir.path().join("idx"),
        db_path: dir.path().join("vec.db"),
        vector_dimensions: 64,
        full_text_weight: 0.5,
        vector_weight: 0.5,
        max_results: 10,
    };
    let mut mem = MemorySystem::with_search(10, config).unwrap();

    for i in 0..5 {
        mem.add_fact(Fact::new(format!("fact number {}", i), "test"));
    }

    assert_eq!(mem.long_term.facts.len(), 5);

    // The added facts must actually be findable through hybrid search.
    let hits = mem.search_facts_hybrid("number");
    assert!(!hits.is_empty(), "indexed facts should be searchable");
    assert!(hits.iter().all(|f| f.content.starts_with("fact number")));
}
1847
/// Attaching a flusher does not mark anything dirty by itself.
#[test]
fn test_memory_system_with_flusher() {
    let system = MemorySystem::new(10).with_flusher(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 5,
        flush_path: None,
    });

    assert!(!system.flusher_is_dirty());
}
1861
/// Adding a message through `MemorySystem` marks its flusher dirty.
#[test]
fn test_memory_system_add_message_notifies_flusher() {
    let mut system = MemorySystem::new(10).with_flusher(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 5,
        flush_path: None,
    });

    system.add_message(Message::user("hello"));

    assert!(system.flusher_is_dirty());
}
1875
/// Without a flusher, `check_auto_flush` succeeds and reports that no flush
/// happened.
#[test]
fn test_memory_system_check_auto_flush_no_flusher() {
    let mut system = MemorySystem::new(10);
    let flushed = system.check_auto_flush().unwrap();
    assert!(!flushed);
}
1883
/// `check_auto_flush` writes nothing below the message threshold. Once the
/// threshold is reached it flushes to disk and clears the dirty flag.
#[test]
fn test_memory_system_check_auto_flush_triggers() {
    let tmp = tempfile::tempdir().unwrap();
    let target = tmp.path().join("auto_flush.json");

    let mut system = MemorySystem::new(10).with_flusher(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 3,
        flush_path: Some(target.clone()),
    });

    // Two messages: below the threshold of three.
    system.add_message(Message::user("msg 1"));
    system.add_message(Message::user("msg 2"));
    assert!(!system.check_auto_flush().unwrap());
    assert!(!target.exists());

    // Third message reaches the threshold.
    system.add_message(Message::user("msg 3"));
    assert!(system.check_auto_flush().unwrap());
    assert!(target.exists());

    assert!(!system.flusher_is_dirty());
}
1911
/// `MemorySystem::force_flush` writes pending data immediately, regardless
/// of the (high) message threshold.
#[test]
fn test_memory_system_force_flush() {
    let tmp = tempfile::tempdir().unwrap();
    let target = tmp.path().join("force_flush.json");

    let mut system = MemorySystem::new(10).with_flusher(FlushConfig {
        enabled: true,
        interval_secs: 0,
        flush_on_message_count: 100,
        flush_path: Some(target.clone()),
    });

    system.add_message(Message::user("important data"));
    assert!(system.flusher_is_dirty());

    system.force_flush().unwrap();

    assert!(!system.flusher_is_dirty());
    assert!(target.exists());
}
1932
/// `force_flush` on a system without a flusher is a successful no-op.
#[test]
fn test_memory_system_force_flush_no_flusher() {
    let mut system = MemorySystem::new(10);
    let outcome = system.force_flush();
    outcome.unwrap();
}
1939
/// A distiller built without a `KnowledgeConfig` has no rules and an empty
/// prompt section.
#[test]
fn test_knowledge_distiller_disabled() {
    let inert = KnowledgeDistiller::new(None);

    assert_eq!(inert.rule_count(), 0);
    assert!(inert.rules_for_prompt().is_empty());
}
1948
/// Distilling an empty long-term memory produces no rules.
///
/// Fix: the argument had been corrupted to `<m` (the `&ltm` borrow was
/// mangled as an HTML entity). The binding is renamed to `memory` so the
/// borrow cannot be misread as `&lt;` again.
#[test]
fn test_knowledge_distiller_no_data() {
    let config = crate::config::KnowledgeConfig::default();
    let mut distiller = KnowledgeDistiller::new(Some(&config));
    let memory = LongTermMemory::new();

    distiller.distill(&memory);

    assert_eq!(distiller.rule_count(), 0);
}
1957
/// Fewer corrections than `min_entries_for_distillation` produce no rules.
///
/// Fix: restores the `&` borrow that had been garbled to `<m`.
#[test]
fn test_knowledge_distiller_corrections_below_threshold() {
    let config = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 5,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&config));

    // Two corrections against a threshold of five.
    let mut memory = LongTermMemory::new();
    memory.add_correction(
        "unwrap()".into(),
        "? operator".into(),
        "error handling".into(),
    );
    memory.add_correction("println!".into(), "tracing::info!".into(), "logging".into());

    distiller.distill(&memory);

    assert_eq!(distiller.rule_count(), 0);
}
1978
/// Once the threshold is met, corrections with distinct contexts each become
/// their own rule. Each rule's corrected form appears in the prompt section.
///
/// Fix: restores the `&` borrow that had been garbled to `<m`.
#[test]
fn test_knowledge_distiller_single_corrections() {
    let config = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 2,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&config));

    let mut memory = LongTermMemory::new();
    memory.add_correction(
        "unwrap()".into(),
        "? operator".into(),
        "error handling".into(),
    );
    memory.add_correction("println!".into(), "tracing::info!".into(), "logging".into());

    distiller.distill(&memory);
    assert_eq!(distiller.rule_count(), 2);

    let prompt = distiller.rules_for_prompt();
    assert!(prompt.contains("Learned Behavioral Rules"));
    assert!(prompt.contains("? operator"));
    assert!(prompt.contains("tracing::info!"));
}
2004
/// Corrections sharing the same context are merged into a single rule whose
/// prompt text reports how many corrections support it.
///
/// Fix: restores the `&` borrow that had been garbled to `<m`.
#[test]
fn test_knowledge_distiller_grouped_corrections() {
    let config = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 2,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&config));

    let mut memory = LongTermMemory::new();
    memory.add_correction(
        "unwrap()".into(),
        "? operator".into(),
        "error handling in Rust code".into(),
    );
    memory.add_correction(
        "expect()".into(),
        "map_err()?".into(),
        "error handling in Rust code".into(),
    );

    distiller.distill(&memory);

    assert_eq!(distiller.rule_count(), 1);
    let prompt = distiller.rules_for_prompt();
    assert!(prompt.contains("2 previous corrections"));
}
2031
/// Of the two facts here, only the preference-style "Prefer ..." fact is
/// turned into a rule; the plain project fact is not.
///
/// Fix: restores the `&` borrow that had been garbled to `<m`.
#[test]
fn test_knowledge_distiller_preference_facts() {
    let config = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 1,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&config));

    let mut memory = LongTermMemory::new();
    memory.add_fact(Fact::new("Prefer async/await over threads", "user"));
    memory.add_fact(Fact::new("Project uses PostgreSQL", "session"));

    distiller.distill(&memory);
    assert_eq!(distiller.rule_count(), 1);

    let prompt = distiller.rules_for_prompt();
    assert!(prompt.contains("async/await"));
}
2050
/// With ten distinct corrections, the distiller produces some rules but never
/// more than `max_rules`.
///
/// Fixes:
/// - restores the `&` borrow that had been garbled to `<m`;
/// - adds a lower bound, because `<= 3` alone would also pass if distillation
///   silently produced nothing.
#[test]
fn test_knowledge_distiller_max_rules_truncation() {
    let config = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 1,
        max_rules: 3,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&config));

    let mut memory = LongTermMemory::new();
    for i in 0..10 {
        memory.add_correction(
            format!("old{}", i),
            format!("new{}", i),
            format!("context{}", i),
        );
    }

    distiller.distill(&memory);

    let count = distiller.rule_count();
    assert!(count > 0, "distillation should produce at least one rule");
    assert!(count <= 3);
}
2072
/// Distilling the same memory twice does not duplicate rules.
///
/// Fixes:
/// - restores the `&` borrow that had been garbled to `<m`;
/// - requires the first pass to produce a rule, because idempotence over an
///   empty rule set proves nothing.
#[test]
fn test_knowledge_distiller_idempotent() {
    let config = crate::config::KnowledgeConfig {
        min_entries_for_distillation: 1,
        ..Default::default()
    };
    let mut distiller = KnowledgeDistiller::new(Some(&config));

    let mut memory = LongTermMemory::new();
    memory.add_correction("old".into(), "new".into(), "ctx".into());

    distiller.distill(&memory);
    let count_after_first = distiller.rule_count();
    assert!(count_after_first > 0, "first distillation should produce a rule");

    distiller.distill(&memory);
    assert_eq!(distiller.rule_count(), count_after_first);
}
2091
/// A `KnowledgeStore` saved to disk reloads with its rules' text and support
/// counts intact.
#[test]
fn test_knowledge_store_save_load_roundtrip() {
    let tmp = tempfile::tempdir().unwrap();
    let file = tmp.path().join("knowledge.json");

    let rule = BehavioralRule {
        id: Uuid::new_v4(),
        rule: "Prefer ? over unwrap".into(),
        source_ids: vec![Uuid::new_v4()],
        support_count: 3,
        created_at: Utc::now(),
    };
    let mut store = KnowledgeStore::new();
    store.rules.push(rule);

    store.save(&file).unwrap();
    let reloaded = KnowledgeStore::load(&file).unwrap();

    assert_eq!(reloaded.rules.len(), 1);
    let first = &reloaded.rules[0];
    assert_eq!(first.rule, "Prefer ? over unwrap");
    assert_eq!(first.support_count, 3);
}
2112
/// Loading a knowledge store from a missing file yields an empty store
/// rather than an error.
///
/// The missing path is built inside a fresh temporary directory (its parent
/// directory is also absent), so the test does not depend on the host
/// filesystem's layout.
#[test]
fn test_knowledge_store_load_nonexistent() {
    let dir = tempfile::tempdir().unwrap();
    let missing = dir.path().join("nonexistent").join("knowledge.json");

    let store = KnowledgeStore::load(&missing).unwrap();
    assert!(store.rules.is_empty());
}
2119
/// `unpin` rejects out-of-range indices and leaves existing pins untouched.
#[test]
fn test_unpin_out_of_bounds_returns_false() {
    let mut stm = ShortTermMemory::new(100);
    stm.add(Message::user("hello"));
    stm.add(Message::assistant("hi"));
    stm.add(Message::user("world"));

    assert!(stm.pin(1));
    assert!(stm.is_pinned(1));

    // A far-out index and the first index past the end are both rejected.
    for bad_index in [999, 3] {
        assert!(!stm.unpin(bad_index));
    }

    // The rejected calls must not disturb the existing pin.
    assert!(stm.is_pinned(1));
}
2138
/// `unpin` returns false both at index == len and for a valid but unpinned
/// index. It returns true only for an index that is actually pinned.
#[test]
fn test_unpin_at_exact_boundary() {
    let mut stm = ShortTermMemory::new(100);
    stm.add(Message::user("msg0"));
    stm.add(Message::user("msg1"));

    // Index 2 equals the length: out of range.
    assert!(!stm.unpin(2));
    // Index 1 is in range but not pinned yet.
    assert!(!stm.unpin(1));

    assert!(stm.pin(1));
    assert!(stm.unpin(1));
    assert!(!stm.is_pinned(1));
}
2156
/// Pinning a tool *result* must also keep the tool *call* that produced it.
/// That way the compressed history never contains an orphaned result.
#[test]
fn test_compress_preserves_tool_chain_pairs() {
    let mut stm = ShortTermMemory::new(3);

    let conversation = vec![
        Message::user("read the file"),
        Message::new(
            Role::Assistant,
            Content::tool_call("call_abc", "file_read", serde_json::json!({"path": "x.rs"})),
        ),
        Message::tool_result("call_abc", "fn main() {}", false),
        Message::assistant("Here is the content."),
        Message::user("thanks"),
        Message::assistant("You're welcome."),
    ];
    for message in conversation {
        stm.add(message);
    }

    // Pin the tool result (index 2).
    stm.pin(2);
    assert!(stm.needs_compression());

    let removed = stm.compress("Summary of earlier conversation".to_string());

    let kept = stm.to_messages();
    let has_tool_call = kept
        .iter()
        .any(|m| matches!(&m.content, Content::ToolCall { name, .. } if name == "file_read"));
    let has_tool_result = kept
        .iter()
        .any(|m| matches!(&m.content, Content::ToolResult { call_id, .. } if call_id == "call_abc"));

    assert!(
        has_tool_result,
        "Pinned tool_result should survive compression"
    );
    assert!(
        has_tool_call,
        "Paired tool_call should also survive compression"
    );
    assert!(removed >= 1, "At least one message should be compressed");
}
2210
/// The mirror case: pinning a tool *call* must also keep its matching tool
/// *result* through compression.
#[test]
fn test_compress_preserves_tool_call_paired_with_result() {
    let mut stm = ShortTermMemory::new(3);

    let conversation = vec![
        Message::user("do something"),
        Message::new(
            Role::Assistant,
            Content::tool_call("call_xyz", "shell_exec", serde_json::json!({"cmd": "ls"})),
        ),
        Message::tool_result("call_xyz", "file1\nfile2", false),
        Message::assistant("Listed files."),
        Message::user("ok"),
        Message::assistant("Done."),
    ];
    for message in conversation {
        stm.add(message);
    }

    // Pin the tool call (index 1).
    stm.pin(1);
    assert!(stm.needs_compression());

    stm.compress("Summary".to_string());

    let kept = stm.to_messages();
    let has_tool_call = kept
        .iter()
        .any(|m| matches!(&m.content, Content::ToolCall { id, .. } if id == "call_xyz"));
    let has_tool_result = kept
        .iter()
        .any(|m| matches!(&m.content, Content::ToolResult { call_id, .. } if call_id == "call_xyz"));

    assert!(has_tool_call, "Pinned tool_call should survive compression");
    assert!(
        has_tool_result,
        "Paired tool_result should also survive compression"
    );
}
2252}