1use crate::error::MemoryError;
8use crate::search::{HybridSearchEngine, SearchConfig};
9use crate::types::{Content, Message, Role};
10use chrono::{DateTime, Utc};
11use serde::{Deserialize, Serialize};
12use std::collections::{HashMap, HashSet, VecDeque};
13use std::path::Path;
14use uuid::Uuid;
15
/// Scratch state for the task currently in progress.
///
/// Reset wholesale by `WorkingMemory::clear`, which both
/// `MemorySystem::start_new_task` and `MemorySystem::clear_session` call.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct WorkingMemory {
    /// The goal set via `set_goal`, if any.
    pub current_goal: Option<String>,
    /// Sub-tasks in insertion order (duplicates are allowed).
    pub sub_tasks: Vec<String>,
    /// Free-form key/value notes; writing an existing key overwrites it.
    pub scratchpad: HashMap<String, String>,
    /// Files touched during the task, deduplicated, in first-seen order.
    pub active_files: Vec<String>,
}
24
25impl WorkingMemory {
26 pub fn new() -> Self {
27 Self::default()
28 }
29
30 pub fn set_goal(&mut self, goal: impl Into<String>) {
31 self.current_goal = Some(goal.into());
32 }
33
34 pub fn add_sub_task(&mut self, task: impl Into<String>) {
35 self.sub_tasks.push(task.into());
36 }
37
38 pub fn note(&mut self, key: impl Into<String>, value: impl Into<String>) {
39 self.scratchpad.insert(key.into(), value.into());
40 }
41
42 pub fn add_active_file(&mut self, path: impl Into<String>) {
43 let path = path.into();
44 if !self.active_files.contains(&path) {
45 self.active_files.push(path);
46 }
47 }
48
49 pub fn clear(&mut self) {
50 *self = Self::default();
51 }
52}
53
54#[derive(Debug, Clone)]
56pub struct ShortTermMemory {
57 messages: VecDeque<Message>,
58 window_size: usize,
59 summarized_prefix: Option<String>,
60 total_messages_seen: usize,
61 pinned: std::collections::HashSet<usize>,
63 compressed_offset: usize,
65}
66
67impl ShortTermMemory {
68 pub fn new(window_size: usize) -> Self {
69 Self {
70 messages: VecDeque::new(),
71 window_size,
72 summarized_prefix: None,
73 total_messages_seen: 0,
74 pinned: std::collections::HashSet::new(),
75 compressed_offset: 0,
76 }
77 }
78
79 pub fn add(&mut self, message: Message) {
81 self.messages.push_back(message);
82 self.total_messages_seen += 1;
83 }
84
85 pub fn to_messages(&self) -> Vec<Message> {
87 let mut result = Vec::new();
88
89 if let Some(ref summary) = self.summarized_prefix {
91 result.push(Message::system(format!(
92 "[Summary of earlier conversation]\n{}",
93 summary
94 )));
95 }
96
97 let start = if self.messages.len() > self.window_size {
100 self.messages.len() - self.window_size
101 } else {
102 0
103 };
104
105 for (i, msg) in self.messages.iter().enumerate() {
106 if i >= start || self.is_pinned(i) {
107 result.push(msg.clone());
108 }
109 }
110
111 result
112 }
113
114 pub fn needs_compression(&self) -> bool {
116 self.messages.len() >= self.window_size * 2
117 }
118
119 pub fn compress(&mut self, summary: String) -> usize {
125 if self.messages.len() <= self.window_size {
126 return 0;
127 }
128
129 let to_remove = self.messages.len() - self.window_size;
130
131 let mut preserve_indices: HashSet<usize> = HashSet::new();
133 for i in 0..to_remove {
134 let abs_idx = self.compressed_offset + i;
135 if self.pinned.contains(&abs_idx) {
136 preserve_indices.insert(i);
137 }
138 }
139
140 let mut extra_preserves: Vec<usize> = Vec::new();
142 for &i in &preserve_indices {
143 if let Some(msg) = self.messages.get(i) {
144 match &msg.content {
145 Content::ToolResult { call_id, .. } => {
147 if let Some(pair_idx) =
148 Self::find_tool_call_for_result(call_id, &self.messages, to_remove)
149 && !preserve_indices.contains(&pair_idx)
150 {
151 extra_preserves.push(pair_idx);
152 }
153 }
154 Content::ToolCall { id, .. } => {
156 if let Some(pair_idx) =
157 Self::find_tool_result_for_call(id, &self.messages, to_remove)
158 && !preserve_indices.contains(&pair_idx)
159 {
160 extra_preserves.push(pair_idx);
161 }
162 }
163 Content::MultiPart { parts } => {
165 for part in parts {
166 match part {
167 Content::ToolCall { id, .. } => {
168 if let Some(pair_idx) = Self::find_tool_result_for_call(
169 id,
170 &self.messages,
171 to_remove,
172 ) && !preserve_indices.contains(&pair_idx)
173 {
174 extra_preserves.push(pair_idx);
175 }
176 }
177 Content::ToolResult { call_id, .. } => {
178 if let Some(pair_idx) = Self::find_tool_call_for_result(
179 call_id,
180 &self.messages,
181 to_remove,
182 ) && !preserve_indices.contains(&pair_idx)
183 {
184 extra_preserves.push(pair_idx);
185 }
186 }
187 _ => {}
188 }
189 }
190 }
191 _ => {}
192 }
193 }
194 }
195 for idx in extra_preserves {
196 preserve_indices.insert(idx);
197 }
198
199 let mut preserved = Vec::new();
201 let mut removed_count = 0;
202
203 for i in 0..to_remove {
204 if preserve_indices.contains(&i) {
205 if let Some(msg) = self.messages.get(i) {
206 preserved.push(msg.clone());
207 }
208 } else {
209 removed_count += 1;
210 }
211 }
212
213 let mut surviving_pinned: Vec<usize> = Vec::new();
216 for i in to_remove..self.messages.len() {
217 let abs_idx = self.compressed_offset + i;
218 if self.pinned.contains(&abs_idx) {
219 surviving_pinned.push(i - to_remove);
220 }
221 }
222
223 self.messages.drain(..to_remove);
225 self.compressed_offset += to_remove;
226
227 let preserved_count = preserved.len();
230 if !preserved.is_empty() {
231 let mut new_messages = VecDeque::with_capacity(preserved_count + self.messages.len());
232 for msg in preserved {
233 new_messages.push_back(msg);
234 }
235 new_messages.append(&mut self.messages);
236 self.messages = new_messages;
237 }
238
239 let mut new_pinned = HashSet::new();
243 for i in 0..preserved_count {
244 new_pinned.insert(self.compressed_offset + i);
245 }
246 for pos in surviving_pinned {
247 new_pinned.insert(self.compressed_offset + preserved_count + pos);
248 }
249 self.pinned = new_pinned;
250
251 if let Some(ref existing) = self.summarized_prefix {
253 self.summarized_prefix = Some(format!("{}\n\n{}", existing, summary));
254 } else {
255 self.summarized_prefix = Some(summary);
256 }
257
258 removed_count
259 }
260
261 fn find_tool_call_for_result(
264 call_id: &str,
265 messages: &VecDeque<Message>,
266 limit: usize,
267 ) -> Option<usize> {
268 messages
269 .iter()
270 .enumerate()
271 .take(limit)
272 .find(|(_, msg)| {
273 msg.role == Role::Assistant
274 && Self::content_contains_tool_call_id(&msg.content, call_id)
275 })
276 .map(|(i, _)| i)
277 }
278
279 fn find_tool_result_for_call(
282 tool_call_id: &str,
283 messages: &VecDeque<Message>,
284 limit: usize,
285 ) -> Option<usize> {
286 messages
287 .iter()
288 .enumerate()
289 .take(limit)
290 .find(|(_, msg)| {
291 (msg.role == Role::Tool || msg.role == Role::User)
292 && Self::content_contains_tool_result_id(&msg.content, tool_call_id)
293 })
294 .map(|(i, _)| i)
295 }
296
297 fn content_contains_tool_call_id(content: &Content, target_id: &str) -> bool {
299 match content {
300 Content::ToolCall { id, .. } => id == target_id,
301 Content::MultiPart { parts } => parts
302 .iter()
303 .any(|p| Self::content_contains_tool_call_id(p, target_id)),
304 _ => false,
305 }
306 }
307
308 fn content_contains_tool_result_id(content: &Content, target_id: &str) -> bool {
310 match content {
311 Content::ToolResult { call_id, .. } => call_id == target_id,
312 Content::MultiPart { parts } => parts
313 .iter()
314 .any(|p| Self::content_contains_tool_result_id(p, target_id)),
315 _ => false,
316 }
317 }
318
319 pub fn pin(&mut self, position: usize) -> bool {
322 if position >= self.messages.len() {
323 return false;
324 }
325 let abs_idx = self.compressed_offset + position;
326 self.pinned.insert(abs_idx);
327 true
328 }
329
330 pub fn unpin(&mut self, position: usize) -> bool {
332 if position >= self.messages.len() {
333 return false;
334 }
335 let abs_idx = self.compressed_offset + position;
336 self.pinned.remove(&abs_idx)
337 }
338
339 pub fn is_pinned(&self, position: usize) -> bool {
341 let abs_idx = self.compressed_offset + position;
342 self.pinned.contains(&abs_idx)
343 }
344
345 pub fn pinned_count(&self) -> usize {
347 (0..self.messages.len())
349 .filter(|&i| self.is_pinned(i))
350 .count()
351 }
352
353 pub fn messages_to_summarize(&self) -> Vec<&Message> {
355 if self.messages.len() <= self.window_size {
356 return Vec::new();
357 }
358 let to_summarize = self.messages.len() - self.window_size;
359 self.messages.iter().take(to_summarize).collect()
360 }
361
362 pub fn len(&self) -> usize {
364 self.messages.len()
365 }
366
367 pub fn is_empty(&self) -> bool {
369 self.messages.is_empty()
370 }
371
372 pub fn total_messages_seen(&self) -> usize {
374 self.total_messages_seen
375 }
376
377 pub fn clear(&mut self) {
379 self.messages.clear();
380 self.summarized_prefix = None;
381 self.total_messages_seen = 0;
382 self.pinned.clear();
383 self.compressed_offset = 0;
384 }
385
386 pub fn messages(&self) -> &VecDeque<Message> {
388 &self.messages
389 }
390
391 pub fn summary(&self) -> Option<&str> {
393 self.summarized_prefix.as_deref()
394 }
395}
396
/// A piece of knowledge stored in long-term memory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Fact {
    /// Random v4 id assigned by `Fact::new`; also the key used by the search index.
    pub id: Uuid,
    /// The fact text.
    pub content: String,
    /// Free-form description of where the fact came from.
    pub source: String,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Tags; matched case-insensitively by `LongTermMemory::search_facts`.
    pub tags: Vec<String>,
}
406
407impl Fact {
408 pub fn new(content: impl Into<String>, source: impl Into<String>) -> Self {
409 Self {
410 id: Uuid::new_v4(),
411 content: content.into(),
412 source: source.into(),
413 created_at: Utc::now(),
414 tags: Vec::new(),
415 }
416 }
417
418 pub fn with_tags(mut self, tags: Vec<String>) -> Self {
419 self.tags = tags;
420 self
421 }
422}
423
/// Knowledge that outlives a single task: facts, preferences and corrections.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LongTermMemory {
    /// Stored facts, oldest first; capped at `max_facts`.
    pub facts: Vec<Fact>,
    /// Key/value user preferences; later writes overwrite earlier ones.
    pub preferences: HashMap<String, String>,
    /// Recorded corrections, oldest first; capped at `max_corrections`.
    pub corrections: Vec<Correction>,
    /// Upper bound on `facts`; the oldest facts are evicted first.
    /// Defaults when missing from serialized data, so older files still load.
    #[serde(default = "LongTermMemory::default_max_facts")]
    pub max_facts: usize,
    /// Upper bound on `corrections`; the oldest corrections are evicted first.
    /// Defaults when missing from serialized data, so older files still load.
    #[serde(default = "LongTermMemory::default_max_corrections")]
    pub max_corrections: usize,
}
437
438impl Default for LongTermMemory {
439 fn default() -> Self {
440 Self {
441 facts: Vec::new(),
442 preferences: HashMap::new(),
443 corrections: Vec::new(),
444 max_facts: Self::default_max_facts(),
445 max_corrections: Self::default_max_corrections(),
446 }
447 }
448}
449
/// A record of the agent being corrected: what it did, what was wanted instead,
/// and in what context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Correction {
    /// Random v4 id assigned by `LongTermMemory::add_correction`.
    pub id: Uuid,
    /// The original (rejected) output or behaviour.
    pub original: String,
    /// The corrected output or behaviour.
    pub corrected: String,
    /// Where the correction happened; `KnowledgeDistiller` groups corrections
    /// by the lower-cased first 50 characters of this.
    pub context: String,
    /// When the correction was recorded (UTC).
    pub timestamp: DateTime<Utc>,
}
459
460impl LongTermMemory {
461 pub fn new() -> Self {
462 Self::default()
463 }
464
465 fn default_max_facts() -> usize {
466 10_000
467 }
468
469 fn default_max_corrections() -> usize {
470 1_000
471 }
472
473 pub fn add_fact(&mut self, fact: Fact) {
474 if self.facts.len() >= self.max_facts {
475 self.facts.remove(0);
476 }
477 self.facts.push(fact);
478 }
479
480 pub fn set_preference(&mut self, key: impl Into<String>, value: impl Into<String>) {
481 self.preferences.insert(key.into(), value.into());
482 }
483
484 pub fn get_preference(&self, key: &str) -> Option<&str> {
485 self.preferences.get(key).map(|s| s.as_str())
486 }
487
488 pub fn add_correction(&mut self, original: String, corrected: String, context: String) {
489 if self.corrections.len() >= self.max_corrections {
490 self.corrections.remove(0);
491 }
492 self.corrections.push(Correction {
493 id: Uuid::new_v4(),
494 original,
495 corrected,
496 context,
497 timestamp: Utc::now(),
498 });
499 }
500
501 pub fn search_facts(&self, query: &str) -> Vec<&Fact> {
503 let query_lower = query.to_lowercase();
504 self.facts
505 .iter()
506 .filter(|f| {
507 f.content.to_lowercase().contains(&query_lower)
508 || f.tags
509 .iter()
510 .any(|t| t.to_lowercase().contains(&query_lower))
511 })
512 .collect()
513 }
514}
515
/// Facade over the three memory tiers plus optional search and auto-flush.
pub struct MemorySystem {
    /// Per-task scratch state.
    pub working: WorkingMemory,
    /// Windowed conversation history.
    pub short_term: ShortTermMemory,
    /// Facts, preferences and corrections.
    pub long_term: LongTermMemory,
    /// Hybrid fact index; `None` unless built via `MemorySystem::with_search`.
    search_engine: Option<HybridSearchEngine>,
    /// Auto-flush state; set by `MemorySystem::with_flusher`. It is
    /// temporarily taken out while a flush runs so the flusher can borrow
    /// `self`.
    flusher: Option<MemoryFlusher>,
}
526
527impl MemorySystem {
528 pub fn new(window_size: usize) -> Self {
529 Self {
530 working: WorkingMemory::new(),
531 short_term: ShortTermMemory::new(window_size),
532 long_term: LongTermMemory::new(),
533 search_engine: None,
534 flusher: None,
535 }
536 }
537
538 pub fn with_search(
540 window_size: usize,
541 search_config: SearchConfig,
542 ) -> Result<Self, crate::search::SearchError> {
543 let engine = HybridSearchEngine::open(search_config)?;
544 Ok(Self {
545 working: WorkingMemory::new(),
546 short_term: ShortTermMemory::new(window_size),
547 long_term: LongTermMemory::new(),
548 search_engine: Some(engine),
549 flusher: None,
550 })
551 }
552
553 pub fn with_flusher(mut self, config: FlushConfig) -> Self {
555 self.flusher = Some(MemoryFlusher::new(config));
556 self
557 }
558
559 pub fn context_messages(&self) -> Vec<Message> {
561 self.short_term.to_messages()
562 }
563
564 pub fn add_message(&mut self, message: Message) {
566 self.short_term.add(message);
567 if let Some(ref mut flusher) = self.flusher {
569 flusher.on_message_added();
570 }
571 }
572
573 pub fn add_fact(&mut self, fact: Fact) {
575 if let Some(ref mut engine) = self.search_engine {
576 let _ = engine.index_fact(&fact.id.to_string(), &fact.content);
577 }
578 self.long_term.add_fact(fact);
579 }
580
581 pub fn search_facts_hybrid(&self, query: &str) -> Vec<&Fact> {
583 if let Some(ref engine) = self.search_engine
584 && let Ok(results) = engine.search(query)
585 {
586 let ids: Vec<String> = results.iter().map(|r| r.fact_id.clone()).collect();
587 let found: Vec<&Fact> = self
588 .long_term
589 .facts
590 .iter()
591 .filter(|f| ids.contains(&f.id.to_string()))
592 .collect();
593 if !found.is_empty() {
594 return found;
595 }
596 }
597 self.long_term.search_facts(query)
599 }
600
601 pub fn check_auto_flush(&mut self) -> Result<bool, MemoryError> {
605 let mut flusher = match self.flusher.take() {
606 Some(f) => f,
607 None => return Ok(false),
608 };
609 let result = if flusher.should_flush() {
610 flusher.flush(self)?;
611 Ok(true)
612 } else {
613 Ok(false)
614 };
615 self.flusher = Some(flusher);
616 result
617 }
618
619 pub fn force_flush(&mut self) -> Result<(), MemoryError> {
621 let mut flusher = match self.flusher.take() {
622 Some(f) => f,
623 None => return Ok(()),
624 };
625 let result = flusher.force_flush(self);
626 self.flusher = Some(flusher);
627 result
628 }
629
630 pub fn flusher_is_dirty(&self) -> bool {
632 self.flusher.as_ref().is_some_and(|f| f.is_dirty())
633 }
634
635 pub fn start_new_task(&mut self, goal: impl Into<String>) {
637 self.working.clear();
638 self.working.set_goal(goal);
639 }
640
641 pub fn clear_session(&mut self) {
643 self.working.clear();
644 self.short_term.clear();
645 }
646
647 pub fn context_breakdown(&self, context_window: usize) -> ContextBreakdown {
649 let summary_chars = self.short_term.summary().map(|s| s.len()).unwrap_or(0);
650 let message_chars: usize = self
651 .short_term
652 .messages()
653 .iter()
654 .map(|m| m.content_length())
655 .sum();
656 let total_chars = summary_chars + message_chars;
657
658 let summary_tokens = summary_chars / 4;
660 let message_tokens = message_chars / 4;
661 let total_tokens = total_chars / 4;
662 let remaining_tokens = context_window.saturating_sub(total_tokens);
663
664 ContextBreakdown {
665 summary_tokens,
666 message_tokens,
667 total_tokens,
668 context_window,
669 remaining_tokens,
670 message_count: self.short_term.len(),
671 total_messages_seen: self.short_term.total_messages_seen(),
672 pinned_count: self.short_term.pinned_count(),
673 has_summary: self.short_term.summary().is_some(),
674 facts_count: self.long_term.facts.len(),
675 rules_count: 0, }
677 }
678
679 pub fn pin_message(&mut self, position: usize) -> bool {
681 self.short_term.pin(position)
682 }
683
684 pub fn unpin_message(&mut self, position: usize) -> bool {
686 self.short_term.unpin(position)
687 }
688}
689
/// Snapshot of estimated context usage produced by `MemorySystem::context_breakdown`.
#[derive(Debug, Clone, Default)]
pub struct ContextBreakdown {
    /// Estimated tokens in the conversation summary.
    pub summary_tokens: usize,
    /// Estimated tokens across buffered messages.
    pub message_tokens: usize,
    /// Estimated tokens in summary plus messages.
    pub total_tokens: usize,
    /// The model's context window size, in tokens.
    pub context_window: usize,
    /// `context_window - total_tokens`, floored at 0.
    pub remaining_tokens: usize,
    /// Messages currently buffered.
    pub message_count: usize,
    /// Messages added over the session.
    pub total_messages_seen: usize,
    /// Pinned messages currently buffered.
    pub pinned_count: usize,
    /// Whether a compression summary exists.
    pub has_summary: bool,
    /// Facts in long-term memory.
    pub facts_count: usize,
    /// Learned behavioural rules.
    pub rules_count: usize,
}

impl ContextBreakdown {
    /// Share of the context window in use, clamped to `[0.0, 1.0]`.
    ///
    /// A zero-sized window reports 0.0.
    pub fn usage_ratio(&self) -> f32 {
        match self.context_window {
            0 => 0.0,
            window => (self.total_tokens as f32 / window as f32).clamp(0.0, 1.0),
        }
    }

    /// Whether usage has reached the 70 % warning threshold.
    pub fn is_warning(&self) -> bool {
        const WARNING_THRESHOLD: f32 = 0.7;
        self.usage_ratio() >= WARNING_THRESHOLD
    }
}
732
/// Identifying information stored alongside a saved session.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SessionMetadata {
    /// Session id.
    pub id: Uuid,
    /// Creation time (UTC).
    pub created_at: DateTime<Utc>,
    /// Last update time (UTC).
    pub updated_at: DateTime<Utc>,
    /// Short description of the task; `save_session` fills it with the
    /// working-memory goal.
    pub task_summary: Option<String>,
}
741
742impl SessionMetadata {
743 pub fn new() -> Self {
744 let now = Utc::now();
745 Self {
746 id: Uuid::new_v4(),
747 created_at: now,
748 updated_at: now,
749 task_summary: None,
750 }
751 }
752}
753
754impl Default for SessionMetadata {
755 fn default() -> Self {
756 Self::new()
757 }
758}
759
/// On-disk representation of a `MemorySystem` written by `save_session`.
///
/// Only the fields below are persisted. The short-term summary, pins and the
/// search index are not part of this format.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Session {
    /// Session id, timestamps and task summary.
    pub metadata: SessionMetadata,
    /// Working memory at save time.
    pub working: WorkingMemory,
    /// Long-term memory at save time.
    pub long_term: LongTermMemory,
    /// All buffered short-term messages, oldest first.
    pub messages: Vec<Message>,
    /// Short-term window size to restore.
    pub window_size: usize,
}
769
770impl MemorySystem {
771 pub fn save_session(&self, path: &Path) -> Result<(), MemoryError> {
773 let session = Session {
774 metadata: SessionMetadata {
775 id: Uuid::new_v4(),
776 created_at: Utc::now(),
777 updated_at: Utc::now(),
778 task_summary: self.working.current_goal.clone(),
779 },
780 working: self.working.clone(),
781 long_term: self.long_term.clone(),
782 messages: self.short_term.messages().iter().cloned().collect(),
783 window_size: self.short_term.window_size(),
784 };
785
786 let json =
787 serde_json::to_string_pretty(&session).map_err(|e| MemoryError::PersistenceError {
788 message: format!("Failed to serialize session: {}", e),
789 })?;
790
791 if let Some(parent) = path.parent() {
792 std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
793 message: format!("Failed to create directory: {}", e),
794 })?;
795 }
796
797 std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
798 message: format!("Failed to write session file: {}", e),
799 })?;
800
801 Ok(())
802 }
803
804 pub fn load_session(path: &Path) -> Result<Self, MemoryError> {
806 let json = std::fs::read_to_string(path).map_err(|e| MemoryError::SessionLoadFailed {
807 message: format!("Failed to read session file: {}", e),
808 })?;
809
810 let session: Session =
811 serde_json::from_str(&json).map_err(|e| MemoryError::SessionLoadFailed {
812 message: format!("Failed to deserialize session: {}", e),
813 })?;
814
815 let mut memory = MemorySystem::new(session.window_size);
816 memory.working = session.working;
817 memory.long_term = session.long_term;
818 for msg in session.messages {
819 memory.short_term.add(msg);
820 }
821
822 Ok(memory)
823 }
824}
825
826impl ShortTermMemory {
827 pub fn window_size(&self) -> usize {
829 self.window_size
830 }
831}
832
/// Before/after message counts describing one compression pass.
///
/// NOTE(review): nothing in this part of the file constructs this type;
/// `ShortTermMemory::compress` returns only a removed-message count.
#[derive(Debug, Clone)]
pub struct CompressionResult {
    /// Messages buffered before compression.
    pub messages_before: usize,
    /// Messages buffered after compression.
    pub messages_after: usize,
    /// Messages folded into the summary.
    pub compressed_count: usize,
}
840
/// Settings for periodically saving the session to disk.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FlushConfig {
    /// Master switch; when false, `MemoryFlusher::should_flush` always returns false.
    pub enabled: bool,
    /// Flush once this many seconds have passed since the last flush (0 disables the timer).
    pub interval_secs: u64,
    /// Flush once this many messages have been added since the last flush (0 disables).
    pub flush_on_message_count: usize,
    /// Session file to write; flushing fails with an error while this is `None`.
    pub flush_path: Option<std::path::PathBuf>,
}
857
858impl Default for FlushConfig {
859 fn default() -> Self {
860 Self {
861 enabled: false,
862 interval_secs: 300, flush_on_message_count: 50,
864 flush_path: None,
865 }
866 }
867}
868
/// Tracks unsaved changes and decides when the session should be written to disk.
#[derive(Debug, Clone)]
pub struct MemoryFlusher {
    /// Thresholds and destination.
    config: FlushConfig,
    /// Set by `on_message_added`, cleared by a successful flush.
    dirty: bool,
    /// Messages added since the last successful flush.
    messages_since_flush: usize,
    /// Time of the last successful flush, or of construction.
    last_flush: DateTime<Utc>,
    /// Number of successful flushes.
    total_flushes: usize,
}
878
879impl MemoryFlusher {
880 pub fn new(config: FlushConfig) -> Self {
882 Self {
883 config,
884 dirty: false,
885 messages_since_flush: 0,
886 last_flush: Utc::now(),
887 total_flushes: 0,
888 }
889 }
890
891 pub fn on_message_added(&mut self) {
893 self.dirty = true;
894 self.messages_since_flush += 1;
895 }
896
897 pub fn should_flush(&self) -> bool {
899 if !self.config.enabled || !self.dirty {
900 return false;
901 }
902
903 if self.config.flush_on_message_count > 0
905 && self.messages_since_flush >= self.config.flush_on_message_count
906 {
907 return true;
908 }
909
910 if self.config.interval_secs > 0 {
912 let elapsed = (Utc::now() - self.last_flush).num_seconds();
913 if elapsed >= self.config.interval_secs as i64 {
914 return true;
915 }
916 }
917
918 false
919 }
920
921 pub fn flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
923 let path =
924 self.config
925 .flush_path
926 .as_ref()
927 .ok_or_else(|| MemoryError::PersistenceError {
928 message: "No flush path configured".to_string(),
929 })?;
930
931 memory.save_session(path)?;
932 self.mark_flushed();
933 Ok(())
934 }
935
936 pub fn force_flush(&mut self, memory: &MemorySystem) -> Result<(), MemoryError> {
938 if !self.dirty {
939 return Ok(()); }
941 self.flush(memory)
942 }
943
944 pub fn is_dirty(&self) -> bool {
946 self.dirty
947 }
948
949 pub fn messages_since_flush(&self) -> usize {
951 self.messages_since_flush
952 }
953
954 pub fn total_flushes(&self) -> usize {
956 self.total_flushes
957 }
958
959 fn mark_flushed(&mut self) {
961 self.dirty = false;
962 self.messages_since_flush = 0;
963 self.last_flush = Utc::now();
964 self.total_flushes += 1;
965 }
966}
967
/// A rule distilled from corrections or preference-like facts and injected into
/// the system prompt by `KnowledgeDistiller::rules_for_prompt`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BehavioralRule {
    /// Random v4 id.
    pub id: Uuid,
    /// Rule text as it appears in the prompt.
    pub rule: String,
    /// Ids of the corrections or facts this rule was derived from.
    pub source_ids: Vec<Uuid>,
    /// Number of sources supporting the rule; higher counts survive trimming.
    pub support_count: usize,
    /// When the rule was created (UTC).
    pub created_at: DateTime<Utc>,
}
986
/// Persisted output of knowledge distillation.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct KnowledgeStore {
    /// Learned rules, capped at `max_rules` by the distiller.
    pub rules: Vec<BehavioralRule>,
    /// Corrections already examined, so they are not distilled twice.
    /// NOTE(review): only ever appended to in this file; never pruned.
    pub processed_correction_ids: Vec<Uuid>,
    /// Facts already examined, so they are not distilled twice.
    /// NOTE(review): only ever appended to in this file; never pruned.
    pub processed_fact_ids: Vec<Uuid>,
}
996
997impl KnowledgeStore {
998 pub fn new() -> Self {
999 Self::default()
1000 }
1001
1002 pub fn load(path: &std::path::Path) -> Result<Self, MemoryError> {
1004 if !path.exists() {
1005 return Ok(Self::new());
1006 }
1007 let json = std::fs::read_to_string(path).map_err(|e| MemoryError::PersistenceError {
1008 message: format!("Failed to read knowledge store: {}", e),
1009 })?;
1010 serde_json::from_str(&json).map_err(|e| MemoryError::PersistenceError {
1011 message: format!("Failed to parse knowledge store: {}", e),
1012 })
1013 }
1014
1015 pub fn save(&self, path: &std::path::Path) -> Result<(), MemoryError> {
1017 if let Some(parent) = path.parent() {
1018 std::fs::create_dir_all(parent).map_err(|e| MemoryError::PersistenceError {
1019 message: format!("Failed to create knowledge directory: {}", e),
1020 })?;
1021 }
1022 let json =
1023 serde_json::to_string_pretty(self).map_err(|e| MemoryError::PersistenceError {
1024 message: format!("Failed to serialize knowledge store: {}", e),
1025 })?;
1026 std::fs::write(path, json).map_err(|e| MemoryError::PersistenceError {
1027 message: format!("Failed to write knowledge store: {}", e),
1028 })
1029 }
1030}
1031
/// Turns long-term corrections and preference-like facts into behavioural
/// rules, optionally persisted to disk.
pub struct KnowledgeDistiller {
    /// Rules plus the ids already processed.
    store: KnowledgeStore,
    /// Maximum number of rules kept after each distillation.
    max_rules: usize,
    /// Minimum number of unprocessed entries before distillation runs
    /// (`usize::MAX` when disabled).
    min_entries: usize,
    /// Where to persist the store after distilling; `None` means in-memory only.
    store_path: Option<std::path::PathBuf>,
}
1042
1043impl KnowledgeDistiller {
1044 pub fn new(config: Option<&crate::config::KnowledgeConfig>) -> Self {
1047 match config {
1048 Some(cfg) if cfg.enabled => {
1049 let store = cfg
1050 .knowledge_path
1051 .as_ref()
1052 .and_then(|p| KnowledgeStore::load(p).ok())
1053 .unwrap_or_default();
1054 Self {
1055 store,
1056 max_rules: cfg.max_rules,
1057 min_entries: cfg.min_entries_for_distillation,
1058 store_path: cfg.knowledge_path.clone(),
1059 }
1060 }
1061 _ => Self {
1062 store: KnowledgeStore::new(),
1063 max_rules: 0,
1064 min_entries: usize::MAX,
1065 store_path: None,
1066 },
1067 }
1068 }
1069
1070 pub fn distill(&mut self, long_term: &LongTermMemory) {
1075 let new_corrections: Vec<&Correction> = long_term
1077 .corrections
1078 .iter()
1079 .filter(|c| !self.store.processed_correction_ids.contains(&c.id))
1080 .collect();
1081
1082 let new_facts: Vec<&Fact> = long_term
1084 .facts
1085 .iter()
1086 .filter(|f| !self.store.processed_fact_ids.contains(&f.id))
1087 .collect();
1088
1089 let total_new = new_corrections.len() + new_facts.len();
1090 if total_new < self.min_entries {
1091 return; }
1093
1094 let mut context_groups: HashMap<String, Vec<&Correction>> = HashMap::new();
1097 for correction in &new_corrections {
1098 let key = correction
1100 .context
1101 .chars()
1102 .take(50)
1103 .collect::<String>()
1104 .to_lowercase();
1105 context_groups.entry(key).or_default().push(correction);
1106 }
1107
1108 for group in context_groups.values() {
1110 if group.len() >= 2 {
1111 let corrected_patterns: Vec<&str> =
1113 group.iter().map(|c| c.corrected.as_str()).collect();
1114 let rule_text = format!(
1115 "Based on {} previous corrections: prefer {}",
1116 group.len(),
1117 corrected_patterns.join("; ")
1118 );
1119 let source_ids: Vec<Uuid> = group.iter().map(|c| c.id).collect();
1120 self.store.rules.push(BehavioralRule {
1121 id: Uuid::new_v4(),
1122 rule: rule_text,
1123 source_ids,
1124 support_count: group.len(),
1125 created_at: Utc::now(),
1126 });
1127 } else {
1128 for c in group {
1130 self.store.rules.push(BehavioralRule {
1131 id: Uuid::new_v4(),
1132 rule: format!("Instead of '{}', prefer '{}'", c.original, c.corrected),
1133 source_ids: vec![c.id],
1134 support_count: 1,
1135 created_at: Utc::now(),
1136 });
1137 }
1138 }
1139 }
1140
1141 for fact in &new_facts {
1145 let is_preference = fact.tags.iter().any(|t| t == "preference")
1146 || fact.content.starts_with("Prefer")
1147 || fact.content.starts_with("Always")
1148 || fact.content.starts_with("Never")
1149 || fact.content.starts_with("Don't")
1150 || fact.content.starts_with("Use ");
1151 if is_preference {
1152 self.store.rules.push(BehavioralRule {
1153 id: Uuid::new_v4(),
1154 rule: fact.content.clone(),
1155 source_ids: vec![fact.id],
1156 support_count: 1,
1157 created_at: Utc::now(),
1158 });
1159 }
1160 }
1161
1162 for c in &new_corrections {
1164 self.store.processed_correction_ids.push(c.id);
1165 }
1166 for f in &new_facts {
1167 self.store.processed_fact_ids.push(f.id);
1168 }
1169
1170 if self.store.rules.len() > self.max_rules {
1172 self.store
1173 .rules
1174 .sort_by(|a, b| b.support_count.cmp(&a.support_count));
1175 self.store.rules.truncate(self.max_rules);
1176 }
1177
1178 if let Some(ref path) = self.store_path {
1180 let _ = self.store.save(path);
1181 }
1182 }
1183
1184 pub fn rules_for_prompt(&self) -> String {
1188 if self.store.rules.is_empty() {
1189 return String::new();
1190 }
1191 let mut prompt = String::from(
1192 "\n\n## Learned Behavioral Rules\n\
1193 The following rules were distilled from previous sessions. Follow them:\n",
1194 );
1195 for (i, rule) in self.store.rules.iter().enumerate() {
1196 prompt.push_str(&format!("{}. {}\n", i + 1, rule.rule));
1197 }
1198 prompt
1199 }
1200
1201 pub fn rule_count(&self) -> usize {
1203 self.store.rules.len()
1204 }
1205
1206 pub fn store(&self) -> &KnowledgeStore {
1208 &self.store
1209 }
1210}
1211
1212#[cfg(test)]
1213mod tests {
1214 use super::*;
1215 use crate::types::{Content, Role};
1216
1217 #[test]
1218 fn test_working_memory_lifecycle() {
1219 let mut wm = WorkingMemory::new();
1220 assert!(wm.current_goal.is_none());
1221
1222 wm.set_goal("refactor auth module");
1223 assert_eq!(wm.current_goal.as_deref(), Some("refactor auth module"));
1224
1225 wm.add_sub_task("read current implementation");
1226 wm.add_sub_task("design new structure");
1227 assert_eq!(wm.sub_tasks.len(), 2);
1228
1229 wm.note("finding", "uses basic auth currently");
1230 assert_eq!(
1231 wm.scratchpad.get("finding").map(|s| s.as_str()),
1232 Some("uses basic auth currently")
1233 );
1234
1235 wm.add_active_file("src/auth/mod.rs");
1236 wm.add_active_file("src/auth/mod.rs"); assert_eq!(wm.active_files.len(), 1);
1238
1239 wm.clear();
1240 assert!(wm.current_goal.is_none());
1241 assert!(wm.sub_tasks.is_empty());
1242 }
1243
1244 #[test]
1245 fn test_short_term_memory_basic() {
1246 let mut stm = ShortTermMemory::new(5);
1247 assert!(stm.is_empty());
1248 assert_eq!(stm.len(), 0);
1249
1250 stm.add(Message::user("hello"));
1251 stm.add(Message::assistant("hi there"));
1252 assert_eq!(stm.len(), 2);
1253 assert_eq!(stm.total_messages_seen(), 2);
1254
1255 let messages = stm.to_messages();
1256 assert_eq!(messages.len(), 2);
1257 }
1258
1259 #[test]
1260 fn test_short_term_memory_window() {
1261 let mut stm = ShortTermMemory::new(3);
1262
1263 for i in 0..6 {
1264 stm.add(Message::user(format!("message {}", i)));
1265 }
1266
1267 assert_eq!(stm.len(), 6);
1268 let messages = stm.to_messages();
1269 assert_eq!(messages.len(), 3);
1271 assert_eq!(messages[0].content.as_text(), Some("message 3"));
1272 assert_eq!(messages[2].content.as_text(), Some("message 5"));
1273 }
1274
1275 #[test]
1276 fn test_short_term_memory_compression() {
1277 let mut stm = ShortTermMemory::new(3);
1278
1279 for i in 0..6 {
1280 stm.add(Message::user(format!("message {}", i)));
1281 }
1282
1283 assert!(stm.needs_compression());
1284
1285 let to_summarize = stm.messages_to_summarize();
1286 assert_eq!(to_summarize.len(), 3); let compressed = stm.compress("Summary of messages 0-2.".to_string());
1289 assert_eq!(compressed, 3);
1290 assert_eq!(stm.len(), 3); let messages = stm.to_messages();
1293 assert_eq!(messages.len(), 4); assert!(
1296 messages[0]
1297 .content
1298 .as_text()
1299 .unwrap()
1300 .contains("Summary of")
1301 );
1302 assert_eq!(messages[0].role, Role::System);
1303 }
1304
1305 #[test]
1306 fn test_short_term_memory_double_compression() {
1307 let mut stm = ShortTermMemory::new(2);
1308
1309 for i in 0..5 {
1311 stm.add(Message::user(format!("msg {}", i)));
1312 }
1313 stm.compress("First summary.".to_string());
1314 assert_eq!(stm.len(), 2);
1315
1316 for i in 5..8 {
1318 stm.add(Message::user(format!("msg {}", i)));
1319 }
1320 stm.compress("Second summary.".to_string());
1321 assert_eq!(stm.len(), 2);
1322
1323 let summary = stm.summary().unwrap();
1325 assert!(summary.contains("First summary."));
1326 assert!(summary.contains("Second summary."));
1327 }
1328
1329 #[test]
1330 fn test_short_term_memory_clear() {
1331 let mut stm = ShortTermMemory::new(5);
1332 stm.add(Message::user("test"));
1333 stm.compress("summary".to_string());
1334
1335 stm.clear();
1336 assert!(stm.is_empty());
1337 assert!(stm.summary().is_none());
1338 assert_eq!(stm.total_messages_seen(), 0);
1339 }
1340
1341 #[test]
1342 fn test_fact_creation() {
1343 let fact = Fact::new("Project uses JWT auth", "code analysis")
1344 .with_tags(vec!["auth".to_string(), "jwt".to_string()]);
1345 assert_eq!(fact.content, "Project uses JWT auth");
1346 assert_eq!(fact.source, "code analysis");
1347 assert_eq!(fact.tags.len(), 2);
1348 }
1349
1350 #[test]
1351 fn test_long_term_memory() {
1352 let mut ltm = LongTermMemory::new();
1353
1354 ltm.add_fact(Fact::new("Uses Rust 2021 edition", "Cargo.toml"));
1355 ltm.set_preference("code_style", "rustfmt defaults");
1356 ltm.add_correction(
1357 "wrong import".to_string(),
1358 "correct import".to_string(),
1359 "editing main.rs".to_string(),
1360 );
1361
1362 assert_eq!(ltm.facts.len(), 1);
1363 assert_eq!(ltm.get_preference("code_style"), Some("rustfmt defaults"));
1364 assert_eq!(ltm.corrections.len(), 1);
1365 }
1366
1367 #[test]
1368 fn test_long_term_memory_search() {
1369 let mut ltm = LongTermMemory::new();
1370 ltm.add_fact(Fact::new("Project uses JWT authentication", "analysis"));
1371 ltm.add_fact(
1372 Fact::new("Database is PostgreSQL", "config").with_tags(vec!["database".to_string()]),
1373 );
1374 ltm.add_fact(Fact::new("Frontend uses React", "package.json"));
1375
1376 let results = ltm.search_facts("JWT");
1377 assert_eq!(results.len(), 1);
1378 assert!(results[0].content.contains("JWT"));
1379
1380 let results = ltm.search_facts("database");
1381 assert_eq!(results.len(), 1);
1382
1383 let results = ltm.search_facts("nonexistent");
1384 assert!(results.is_empty());
1385 }
1386
1387 #[test]
1388 fn test_memory_system() {
1389 let mut mem = MemorySystem::new(5);
1390
1391 mem.start_new_task("fix bug #42");
1392 assert_eq!(mem.working.current_goal.as_deref(), Some("fix bug #42"));
1393
1394 mem.add_message(Message::user("fix the null pointer bug"));
1395 mem.add_message(Message::assistant("I'll look into that."));
1396
1397 let ctx = mem.context_messages();
1398 assert_eq!(ctx.len(), 2);
1399
1400 mem.clear_session();
1401 assert!(mem.short_term.is_empty());
1402 assert!(mem.working.current_goal.is_none());
1403 }
1404
1405 #[test]
1406 fn test_compression_no_op_when_within_window() {
1407 let mut stm = ShortTermMemory::new(10);
1408 stm.add(Message::user("hello"));
1409 stm.add(Message::assistant("hi"));
1410
1411 assert!(!stm.needs_compression());
1412 assert!(stm.messages_to_summarize().is_empty());
1413
1414 let compressed = stm.compress("should not matter".to_string());
1415 assert_eq!(compressed, 0);
1416 }
1417
1418 #[test]
1419 fn test_memory_system_new_task_preserves_long_term() {
1420 let mut mem = MemorySystem::new(5);
1421 mem.long_term.add_fact(Fact::new("important fact", "test"));
1422 mem.add_message(Message::user("task 1"));
1423
1424 mem.start_new_task("task 2");
1425 assert_eq!(mem.working.current_goal.as_deref(), Some("task 2"));
1426 assert_eq!(mem.long_term.facts.len(), 1); }
1428
1429 #[test]
1432 fn test_session_save_load_roundtrip() {
1433 let dir = tempfile::tempdir().unwrap();
1434 let session_path = dir.path().join("session.json");
1435
1436 let mut mem = MemorySystem::new(10);
1438 mem.start_new_task("fix bug #42");
1439 mem.add_message(Message::user("fix the bug"));
1440 mem.add_message(Message::assistant("Looking into it."));
1441 mem.long_term.add_fact(Fact::new("Uses Rust", "analysis"));
1442 mem.long_term.set_preference("style", "concise");
1443
1444 mem.save_session(&session_path).unwrap();
1446 assert!(session_path.exists());
1447
1448 let loaded = MemorySystem::load_session(&session_path).unwrap();
1450 assert_eq!(loaded.working.current_goal.as_deref(), Some("fix bug #42"));
1451 assert_eq!(loaded.short_term.len(), 2);
1452 assert_eq!(loaded.long_term.facts.len(), 1);
1453 assert_eq!(loaded.long_term.get_preference("style"), Some("concise"));
1454
1455 let messages = loaded.context_messages();
1457 assert_eq!(messages.len(), 2);
1458 assert_eq!(messages[0].content.as_text(), Some("fix the bug"));
1459 assert_eq!(messages[1].content.as_text(), Some("Looking into it."));
1460 }
1461
1462 #[test]
1463 fn test_session_load_missing_file() {
1464 let result = MemorySystem::load_session(Path::new("/nonexistent/session.json"));
1465 assert!(result.is_err());
1466 }
1467
1468 #[test]
1469 fn test_session_load_corrupt_json() {
1470 let dir = tempfile::tempdir().unwrap();
1471 let path = dir.path().join("bad.json");
1472 std::fs::write(&path, "not valid json").unwrap();
1473
1474 let result = MemorySystem::load_session(&path);
1475 assert!(result.is_err());
1476 }
1477
1478 #[test]
1479 fn test_session_save_creates_directories() {
1480 let dir = tempfile::tempdir().unwrap();
1481 let session_path = dir.path().join("nested").join("dir").join("session.json");
1482
1483 let mem = MemorySystem::new(5);
1484 mem.save_session(&session_path).unwrap();
1485 assert!(session_path.exists());
1486 }
1487
1488 #[test]
1489 fn test_session_metadata() {
1490 let meta = SessionMetadata::new();
1491 assert!(meta.task_summary.is_none());
1492 assert!(meta.created_at <= Utc::now());
1493
1494 let default_meta = SessionMetadata::default();
1495 assert!(default_meta.task_summary.is_none());
1496 }
1497
1498 #[test]
1499 fn test_short_term_window_size() {
1500 let stm = ShortTermMemory::new(7);
1501 assert_eq!(stm.window_size(), 7);
1502 }
1503
1504 #[test]
1507 fn test_pin_message() {
1508 let mut stm = ShortTermMemory::new(5);
1509 stm.add(Message::user("msg 0"));
1510 stm.add(Message::user("msg 1"));
1511 stm.add(Message::user("msg 2"));
1512
1513 assert!(stm.pin(1));
1514 assert!(stm.is_pinned(1));
1515 assert!(!stm.is_pinned(0));
1516 assert_eq!(stm.pinned_count(), 1);
1517 }
1518
1519 #[test]
1520 fn test_pin_out_of_bounds() {
1521 let mut stm = ShortTermMemory::new(5);
1522 stm.add(Message::user("msg 0"));
1523 assert!(!stm.pin(5)); }
1525
1526 #[test]
1527 fn test_unpin_message() {
1528 let mut stm = ShortTermMemory::new(5);
1529 stm.add(Message::user("msg 0"));
1530 stm.add(Message::user("msg 1"));
1531
1532 stm.pin(0);
1533 assert!(stm.is_pinned(0));
1534 assert!(stm.unpin(0));
1535 assert!(!stm.is_pinned(0));
1536 }
1537
1538 #[test]
1539 fn test_pinned_survives_compression() {
1540 let mut stm = ShortTermMemory::new(3);
1541 stm.add(Message::user("old 0"));
1543 stm.add(Message::user("old 1"));
1544 stm.add(Message::user("important pinned"));
1545 stm.add(Message::user("msg 3"));
1546 stm.add(Message::user("msg 4"));
1547 stm.add(Message::user("msg 5"));
1548
1549 stm.pin(2);
1551 assert!(stm.needs_compression());
1552
1553 let removed = stm.compress("Summary of old messages".to_string());
1554 assert!(removed < 3); let msgs = stm.to_messages();
1559 let has_pinned = msgs
1560 .iter()
1561 .any(|m| matches!(&m.content, Content::Text { text } if text == "important pinned"));
1562 assert!(has_pinned, "Pinned message should survive compression");
1563 }
1564
1565 #[test]
1566 fn test_clear_resets_pins() {
1567 let mut stm = ShortTermMemory::new(5);
1568 stm.add(Message::user("msg 0"));
1569 stm.pin(0);
1570 assert_eq!(stm.pinned_count(), 1);
1571
1572 stm.clear();
1573 assert_eq!(stm.pinned_count(), 0);
1574 }
1575
1576 #[test]
1579 fn test_context_breakdown() {
1580 let mut memory = MemorySystem::new(10);
1581 memory.add_message(Message::user("hello world"));
1582 memory.add_message(Message::assistant("hi there!"));
1583
1584 let ctx = memory.context_breakdown(8000);
1585 assert!(ctx.message_tokens > 0);
1586 assert_eq!(ctx.message_count, 2);
1587 assert_eq!(ctx.context_window, 8000);
1588 assert!(ctx.remaining_tokens > 0);
1589 assert!(!ctx.has_summary);
1590 assert_eq!(ctx.pinned_count, 0);
1591 }
1592
1593 #[test]
1594 fn test_context_breakdown_ratio() {
1595 let ctx = ContextBreakdown {
1596 total_tokens: 4000,
1597 context_window: 8000,
1598 ..Default::default()
1599 };
1600 assert!((ctx.usage_ratio() - 0.5).abs() < 0.01);
1601 assert!(!ctx.is_warning());
1602
1603 let ctx_high = ContextBreakdown {
1604 total_tokens: 7000,
1605 context_window: 8000,
1606 ..Default::default()
1607 };
1608 assert!(ctx_high.is_warning());
1609 }
1610
1611 #[test]
1612 fn test_pin_message_via_memory_system() {
1613 let mut memory = MemorySystem::new(10);
1614 memory.add_message(Message::user("msg 0"));
1615 memory.add_message(Message::user("msg 1"));
1616
1617 assert!(memory.pin_message(0));
1618 assert!(memory.short_term.is_pinned(0));
1619 }
1620
1621 #[test]
1624 fn test_flusher_default_config() {
1625 let config = FlushConfig::default();
1626 assert!(!config.enabled);
1627 assert_eq!(config.interval_secs, 300);
1628 assert_eq!(config.flush_on_message_count, 50);
1629 assert!(config.flush_path.is_none());
1630 }
1631
1632 #[test]
1633 fn test_flusher_not_dirty_by_default() {
1634 let flusher = MemoryFlusher::new(FlushConfig::default());
1635 assert!(!flusher.is_dirty());
1636 assert_eq!(flusher.messages_since_flush(), 0);
1637 assert_eq!(flusher.total_flushes(), 0);
1638 }
1639
1640 #[test]
1641 fn test_flusher_marks_dirty_on_message() {
1642 let mut flusher = MemoryFlusher::new(FlushConfig::default());
1643 flusher.on_message_added();
1644 assert!(flusher.is_dirty());
1645 assert_eq!(flusher.messages_since_flush(), 1);
1646 }
1647
1648 #[test]
1649 fn test_flusher_disabled_never_triggers() {
1650 let mut flusher = MemoryFlusher::new(FlushConfig {
1651 enabled: false,
1652 ..FlushConfig::default()
1653 });
1654 for _ in 0..100 {
1655 flusher.on_message_added();
1656 }
1657 assert!(!flusher.should_flush());
1658 }
1659
1660 #[test]
1661 fn test_flusher_message_count_trigger() {
1662 let mut flusher = MemoryFlusher::new(FlushConfig {
1663 enabled: true,
1664 flush_on_message_count: 5,
1665 interval_secs: 0,
1666 flush_path: None,
1667 });
1668
1669 for _ in 0..4 {
1670 flusher.on_message_added();
1671 }
1672 assert!(!flusher.should_flush());
1673
1674 flusher.on_message_added(); assert!(flusher.should_flush());
1676 }
1677
1678 #[test]
1679 fn test_flusher_not_dirty_no_trigger() {
1680 let flusher = MemoryFlusher::new(FlushConfig {
1681 enabled: true,
1682 flush_on_message_count: 1,
1683 interval_secs: 0,
1684 flush_path: None,
1685 });
1686 assert!(!flusher.should_flush());
1688 }
1689
1690 #[test]
1691 fn test_flusher_flush_resets_state() {
1692 let dir = tempfile::tempdir().unwrap();
1693 let flush_path = dir.path().join("flush.json");
1694
1695 let mut flusher = MemoryFlusher::new(FlushConfig {
1696 enabled: true,
1697 flush_on_message_count: 2,
1698 interval_secs: 0,
1699 flush_path: Some(flush_path.clone()),
1700 });
1701
1702 let mut mem = MemorySystem::new(10);
1703 mem.add_message(Message::user("test"));
1704
1705 flusher.on_message_added();
1706 flusher.on_message_added();
1707 assert!(flusher.should_flush());
1708
1709 flusher.flush(&mem).unwrap();
1710 assert!(!flusher.is_dirty());
1711 assert_eq!(flusher.messages_since_flush(), 0);
1712 assert_eq!(flusher.total_flushes(), 1);
1713 assert!(flush_path.exists());
1714 }
1715
1716 #[test]
1717 fn test_flusher_force_flush() {
1718 let dir = tempfile::tempdir().unwrap();
1719 let flush_path = dir.path().join("force.json");
1720
1721 let mut flusher = MemoryFlusher::new(FlushConfig {
1722 enabled: true,
1723 flush_on_message_count: 100,
1724 interval_secs: 0,
1725 flush_path: Some(flush_path.clone()),
1726 });
1727
1728 let mem = MemorySystem::new(10);
1729
1730 flusher.force_flush(&mem).unwrap();
1732 assert_eq!(flusher.total_flushes(), 0);
1733
1734 flusher.on_message_added();
1736 flusher.force_flush(&mem).unwrap();
1737 assert_eq!(flusher.total_flushes(), 1);
1738 assert!(!flusher.is_dirty());
1739 }
1740
1741 #[test]
1742 fn test_flusher_no_path_error() {
1743 let mut flusher = MemoryFlusher::new(FlushConfig {
1744 enabled: true,
1745 flush_on_message_count: 1,
1746 interval_secs: 0,
1747 flush_path: None,
1748 });
1749 flusher.on_message_added();
1750
1751 let mem = MemorySystem::new(10);
1752 let result = flusher.flush(&mem);
1753 assert!(result.is_err());
1754 }
1755
1756 #[test]
1757 fn test_flush_config_serialization() {
1758 let config = FlushConfig {
1759 enabled: true,
1760 interval_secs: 120,
1761 flush_on_message_count: 25,
1762 flush_path: Some(std::path::PathBuf::from("/tmp/flush.json")),
1763 };
1764 let json = serde_json::to_string(&config).unwrap();
1765 let restored: FlushConfig = serde_json::from_str(&json).unwrap();
1766 assert!(restored.enabled);
1767 assert_eq!(restored.interval_secs, 120);
1768 assert_eq!(restored.flush_on_message_count, 25);
1769 }
1770
1771 #[test]
1774 fn test_memory_system_without_search_uses_keyword_fallback() {
1775 let mut mem = MemorySystem::new(10);
1776 mem.add_fact(Fact::new("Rust uses ownership for memory safety", "docs"));
1777 mem.add_fact(Fact::new("Python uses garbage collection", "docs"));
1778
1779 let results = mem.search_facts_hybrid("ownership");
1780 assert_eq!(results.len(), 1);
1781 assert!(results[0].content.contains("ownership"));
1782 }
1783
1784 #[test]
1785 fn test_memory_system_with_search_engine() {
1786 let dir = tempfile::tempdir().unwrap();
1787 let config = SearchConfig {
1788 index_path: dir.path().join("idx"),
1789 db_path: dir.path().join("vec.db"),
1790 vector_dimensions: 64,
1791 full_text_weight: 0.5,
1792 vector_weight: 0.5,
1793 max_results: 10,
1794 };
1795 let mut mem = MemorySystem::with_search(10, config).unwrap();
1796
1797 mem.add_fact(Fact::new("Rust uses ownership model", "analysis"));
1798 mem.add_fact(Fact::new("Python garbage collector", "analysis"));
1799
1800 let results = mem.search_facts_hybrid("ownership");
1802 assert!(!results.is_empty());
1803 assert!(results.iter().any(|f| f.content.contains("ownership")));
1804 }
1805
1806 #[test]
1807 fn test_memory_system_search_empty_query() {
1808 let mut mem = MemorySystem::new(10);
1809 mem.add_fact(Fact::new("some fact", "source"));
1810 let results = mem.search_facts_hybrid("");
1811 assert!(!results.is_empty());
1815 }
1816
1817 #[test]
1818 fn test_memory_system_search_no_facts() {
1819 let mem = MemorySystem::new(10);
1820 let results = mem.search_facts_hybrid("anything");
1821 assert!(results.is_empty());
1822 }
1823
1824 #[test]
1825 fn test_add_fact_indexes_into_search_engine() {
1826 let dir = tempfile::tempdir().unwrap();
1827 let config = SearchConfig {
1828 index_path: dir.path().join("idx"),
1829 db_path: dir.path().join("vec.db"),
1830 vector_dimensions: 64,
1831 full_text_weight: 0.5,
1832 vector_weight: 0.5,
1833 max_results: 10,
1834 };
1835 let mut mem = MemorySystem::with_search(10, config).unwrap();
1836
1837 for i in 0..5 {
1839 mem.add_fact(Fact::new(format!("fact number {}", i), "test"));
1840 }
1841
1842 assert_eq!(mem.long_term.facts.len(), 5);
1844 }
1845
1846 #[test]
1849 fn test_memory_system_with_flusher() {
1850 let config = FlushConfig {
1851 enabled: true,
1852 flush_on_message_count: 5,
1853 interval_secs: 0,
1854 flush_path: None,
1855 };
1856 let mem = MemorySystem::new(10).with_flusher(config);
1857 assert!(!mem.flusher_is_dirty());
1858 }
1859
1860 #[test]
1861 fn test_memory_system_add_message_notifies_flusher() {
1862 let config = FlushConfig {
1863 enabled: true,
1864 flush_on_message_count: 5,
1865 interval_secs: 0,
1866 flush_path: None,
1867 };
1868 let mut mem = MemorySystem::new(10).with_flusher(config);
1869
1870 mem.add_message(Message::user("hello"));
1871 assert!(mem.flusher_is_dirty());
1872 }
1873
1874 #[test]
1875 fn test_memory_system_check_auto_flush_no_flusher() {
1876 let mut mem = MemorySystem::new(10);
1877 let result = mem.check_auto_flush().unwrap();
1879 assert!(!result);
1880 }
1881
1882 #[test]
1883 fn test_memory_system_check_auto_flush_triggers() {
1884 let dir = tempfile::tempdir().unwrap();
1885 let flush_path = dir.path().join("auto_flush.json");
1886
1887 let config = FlushConfig {
1888 enabled: true,
1889 flush_on_message_count: 3,
1890 interval_secs: 0,
1891 flush_path: Some(flush_path.clone()),
1892 };
1893 let mut mem = MemorySystem::new(10).with_flusher(config);
1894
1895 mem.add_message(Message::user("msg 1"));
1897 mem.add_message(Message::user("msg 2"));
1898 assert!(!mem.check_auto_flush().unwrap());
1899 assert!(!flush_path.exists());
1900
1901 mem.add_message(Message::user("msg 3"));
1903 assert!(mem.check_auto_flush().unwrap());
1904 assert!(flush_path.exists());
1905
1906 assert!(!mem.flusher_is_dirty());
1908 }
1909
1910 #[test]
1911 fn test_memory_system_force_flush() {
1912 let dir = tempfile::tempdir().unwrap();
1913 let flush_path = dir.path().join("force_flush.json");
1914
1915 let config = FlushConfig {
1916 enabled: true,
1917 flush_on_message_count: 100, interval_secs: 0,
1919 flush_path: Some(flush_path.clone()),
1920 };
1921 let mut mem = MemorySystem::new(10).with_flusher(config);
1922
1923 mem.add_message(Message::user("important data"));
1924 assert!(mem.flusher_is_dirty());
1925
1926 mem.force_flush().unwrap();
1927 assert!(!mem.flusher_is_dirty());
1928 assert!(flush_path.exists());
1929 }
1930
1931 #[test]
1932 fn test_memory_system_force_flush_no_flusher() {
1933 let mut mem = MemorySystem::new(10);
1934 mem.force_flush().unwrap();
1936 }
1937
1938 #[test]
1941 fn test_knowledge_distiller_disabled() {
1942 let distiller = KnowledgeDistiller::new(None);
1943 assert_eq!(distiller.rule_count(), 0);
1944 assert!(distiller.rules_for_prompt().is_empty());
1945 }
1946
1947 #[test]
1948 fn test_knowledge_distiller_no_data() {
1949 let config = crate::config::KnowledgeConfig::default();
1950 let mut distiller = KnowledgeDistiller::new(Some(&config));
1951 let ltm = LongTermMemory::new();
1952 distiller.distill(<m);
1953 assert_eq!(distiller.rule_count(), 0);
1954 }
1955
1956 #[test]
1957 fn test_knowledge_distiller_corrections_below_threshold() {
1958 let config = crate::config::KnowledgeConfig {
1959 min_entries_for_distillation: 5,
1960 ..Default::default()
1961 };
1962 let mut distiller = KnowledgeDistiller::new(Some(&config));
1963
1964 let mut ltm = LongTermMemory::new();
1965 ltm.add_correction(
1966 "unwrap()".into(),
1967 "? operator".into(),
1968 "error handling".into(),
1969 );
1970 ltm.add_correction("println!".into(), "tracing::info!".into(), "logging".into());
1971
1972 distiller.distill(<m);
1973 assert_eq!(distiller.rule_count(), 0);
1975 }
1976
1977 #[test]
1978 fn test_knowledge_distiller_single_corrections() {
1979 let config = crate::config::KnowledgeConfig {
1980 min_entries_for_distillation: 2,
1981 ..Default::default()
1982 };
1983 let mut distiller = KnowledgeDistiller::new(Some(&config));
1984
1985 let mut ltm = LongTermMemory::new();
1986 ltm.add_correction(
1987 "unwrap()".into(),
1988 "? operator".into(),
1989 "error handling".into(),
1990 );
1991 ltm.add_correction("println!".into(), "tracing::info!".into(), "logging".into());
1992 distiller.distill(<m);
1995 assert_eq!(distiller.rule_count(), 2);
1996
1997 let prompt = distiller.rules_for_prompt();
1998 assert!(prompt.contains("Learned Behavioral Rules"));
1999 assert!(prompt.contains("? operator"));
2000 assert!(prompt.contains("tracing::info!"));
2001 }
2002
2003 #[test]
2004 fn test_knowledge_distiller_grouped_corrections() {
2005 let config = crate::config::KnowledgeConfig {
2006 min_entries_for_distillation: 2,
2007 ..Default::default()
2008 };
2009 let mut distiller = KnowledgeDistiller::new(Some(&config));
2010
2011 let mut ltm = LongTermMemory::new();
2012 ltm.add_correction(
2014 "unwrap()".into(),
2015 "? operator".into(),
2016 "error handling in Rust code".into(),
2017 );
2018 ltm.add_correction(
2019 "expect()".into(),
2020 "map_err()?".into(),
2021 "error handling in Rust code".into(),
2022 );
2023
2024 distiller.distill(<m);
2025 assert_eq!(distiller.rule_count(), 1);
2026 let prompt = distiller.rules_for_prompt();
2027 assert!(prompt.contains("2 previous corrections"));
2028 }
2029
2030 #[test]
2031 fn test_knowledge_distiller_preference_facts() {
2032 let config = crate::config::KnowledgeConfig {
2033 min_entries_for_distillation: 1,
2034 ..Default::default()
2035 };
2036 let mut distiller = KnowledgeDistiller::new(Some(&config));
2037
2038 let mut ltm = LongTermMemory::new();
2039 ltm.add_fact(Fact::new("Prefer async/await over threads", "user"));
2040 ltm.add_fact(Fact::new("Project uses PostgreSQL", "session"));
2041
2042 distiller.distill(<m);
2043 assert_eq!(distiller.rule_count(), 1);
2045 let prompt = distiller.rules_for_prompt();
2046 assert!(prompt.contains("async/await"));
2047 }
2048
2049 #[test]
2050 fn test_knowledge_distiller_max_rules_truncation() {
2051 let config = crate::config::KnowledgeConfig {
2052 min_entries_for_distillation: 1,
2053 max_rules: 3,
2054 ..Default::default()
2055 };
2056 let mut distiller = KnowledgeDistiller::new(Some(&config));
2057
2058 let mut ltm = LongTermMemory::new();
2059 for i in 0..10 {
2060 ltm.add_correction(
2061 format!("old{}", i),
2062 format!("new{}", i),
2063 format!("context{}", i),
2064 );
2065 }
2066
2067 distiller.distill(<m);
2068 assert!(distiller.rule_count() <= 3);
2069 }
2070
2071 #[test]
2072 fn test_knowledge_distiller_idempotent() {
2073 let config = crate::config::KnowledgeConfig {
2074 min_entries_for_distillation: 1,
2075 ..Default::default()
2076 };
2077 let mut distiller = KnowledgeDistiller::new(Some(&config));
2078
2079 let mut ltm = LongTermMemory::new();
2080 ltm.add_correction("old".into(), "new".into(), "ctx".into());
2081
2082 distiller.distill(<m);
2083 let count_after_first = distiller.rule_count();
2084
2085 distiller.distill(<m);
2087 assert_eq!(distiller.rule_count(), count_after_first);
2088 }
2089
2090 #[test]
2091 fn test_knowledge_store_save_load_roundtrip() {
2092 let dir = tempfile::tempdir().unwrap();
2093 let path = dir.path().join("knowledge.json");
2094
2095 let mut store = KnowledgeStore::new();
2096 store.rules.push(BehavioralRule {
2097 id: Uuid::new_v4(),
2098 rule: "Prefer ? over unwrap".into(),
2099 source_ids: vec![Uuid::new_v4()],
2100 support_count: 3,
2101 created_at: Utc::now(),
2102 });
2103
2104 store.save(&path).unwrap();
2105 let loaded = KnowledgeStore::load(&path).unwrap();
2106 assert_eq!(loaded.rules.len(), 1);
2107 assert_eq!(loaded.rules[0].rule, "Prefer ? over unwrap");
2108 assert_eq!(loaded.rules[0].support_count, 3);
2109 }
2110
2111 #[test]
2112 fn test_knowledge_store_load_nonexistent() {
2113 let store =
2114 KnowledgeStore::load(std::path::Path::new("/nonexistent/knowledge.json")).unwrap();
2115 assert!(store.rules.is_empty());
2116 }
2117
2118 #[test]
2119 fn test_unpin_out_of_bounds_returns_false() {
2120 let mut stm = ShortTermMemory::new(100);
2121 stm.add(Message::user("hello"));
2122 stm.add(Message::assistant("hi"));
2123 stm.add(Message::user("world"));
2124
2125 assert!(stm.pin(1));
2127 assert!(stm.is_pinned(1));
2128
2129 assert!(!stm.unpin(999));
2131 assert!(!stm.unpin(3));
2132
2133 assert!(stm.is_pinned(1));
2135 }
2136
2137 #[test]
2138 fn test_unpin_at_exact_boundary() {
2139 let mut stm = ShortTermMemory::new(100);
2140 stm.add(Message::user("msg0"));
2141 stm.add(Message::user("msg1"));
2142
2143 assert!(!stm.unpin(2));
2145
2146 assert!(!stm.unpin(1)); assert!(stm.pin(1));
2151 assert!(stm.unpin(1));
2152 assert!(!stm.is_pinned(1));
2153 }
2154
2155 #[test]
2158 fn test_compress_preserves_tool_chain_pairs() {
2159 let mut stm = ShortTermMemory::new(3);
2160
2161 stm.add(Message::user("read the file")); stm.add(Message::new(
2164 Role::Assistant,
2166 Content::tool_call("call_abc", "file_read", serde_json::json!({"path": "x.rs"})),
2167 ));
2168 stm.add(Message::tool_result("call_abc", "fn main() {}", false)); stm.add(Message::assistant("Here is the content.")); stm.add(Message::user("thanks")); stm.add(Message::assistant("You're welcome.")); stm.pin(2);
2175 assert!(stm.needs_compression());
2176
2177 let removed = stm.compress("Summary of earlier conversation".to_string());
2178
2179 let msgs = stm.to_messages();
2182 let has_tool_call = msgs.iter().any(|m| {
2183 matches!(
2184 &m.content,
2185 Content::ToolCall { name, .. } if name == "file_read"
2186 )
2187 });
2188 let has_tool_result = msgs.iter().any(|m| {
2189 matches!(
2190 &m.content,
2191 Content::ToolResult { call_id, .. } if call_id == "call_abc"
2192 )
2193 });
2194
2195 assert!(
2196 has_tool_result,
2197 "Pinned tool_result should survive compression"
2198 );
2199 assert!(
2200 has_tool_call,
2201 "Paired tool_call should also survive compression"
2202 );
2203
2204 assert!(removed >= 1, "At least one message should be compressed");
2207 }
2208
2209 #[test]
2210 fn test_compress_preserves_tool_call_paired_with_result() {
2211 let mut stm = ShortTermMemory::new(3);
2212
2213 stm.add(Message::user("do something")); stm.add(Message::new(
2215 Role::Assistant,
2217 Content::tool_call("call_xyz", "shell_exec", serde_json::json!({"cmd": "ls"})),
2218 ));
2219 stm.add(Message::tool_result("call_xyz", "file1\nfile2", false)); stm.add(Message::assistant("Listed files.")); stm.add(Message::user("ok")); stm.add(Message::assistant("Done.")); stm.pin(1);
2226 assert!(stm.needs_compression());
2227
2228 stm.compress("Summary".to_string());
2229
2230 let msgs = stm.to_messages();
2231 let has_tool_call = msgs.iter().any(|m| {
2232 matches!(
2233 &m.content,
2234 Content::ToolCall { id, .. } if id == "call_xyz"
2235 )
2236 });
2237 let has_tool_result = msgs.iter().any(|m| {
2238 matches!(
2239 &m.content,
2240 Content::ToolResult { call_id, .. } if call_id == "call_xyz"
2241 )
2242 });
2243
2244 assert!(has_tool_call, "Pinned tool_call should survive compression");
2245 assert!(
2246 has_tool_result,
2247 "Paired tool_result should also survive compression"
2248 );
2249 }
2250}