// shodh_memory/memory/segmentation.rs
//! Hebbian-Friendly Segmentation Engine
//!
//! Segments raw input into atomic memory units optimized for Hebbian learning.
//! Key principle: entities that belong together should form edges, unrelated ones shouldn't.
//!
//! Pipeline (see `SegmentationEngine::segment`):
//! 1. Sentence splitting - break input into sentence-level units
//! 2. Type detection - classify each sentence by `ExperienceType`
//! 3. Same-type merging - consecutive sentences of same type become one memory
//! 4. Max-length splitting - over-long merged segments are re-split on sentence boundaries
//!
//! Deduplication (preventing duplicate edges in the knowledge graph) is handled
//! separately by `DeduplicationEngine`, defined at the bottom of this module.
12
13use crate::memory::types::ExperienceType;
14use regex::Regex;
15use std::collections::HashSet;
16
/// Result of segmenting input text.
///
/// One `AtomicMemory` is a single typed segment produced by
/// `SegmentationEngine::segment`, ready to be stored as a memory unit.
#[derive(Debug, Clone)]
pub struct AtomicMemory {
    /// Detected experience type
    pub experience_type: ExperienceType,
    /// The segmented content
    pub content: String,
    /// Extracted entities (for Hebbian edge formation).
    /// Lowercased, deduplicated tokens longer than 2 bytes, stopwords removed.
    pub entities: Vec<String>,
    /// Confidence in type detection (0.0 - 1.0)
    pub type_confidence: f32,
    /// Source indicator (which part of input this came from):
    /// the index of the first sentence that contributed to this segment.
    pub source_offset: usize,
}
31
/// Input source for segmentation context.
///
/// Used by `SegmentationEngine::detect_type` as a fallback hint when no
/// content pattern matches (currently only `Codebase` carries a hint).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InputSource {
    /// From Cortex proxy (Claude API)
    Cortex,
    /// Direct user input via remember API
    UserApi,
    /// From codebase indexing (biases detection toward FileAccess)
    Codebase,
    /// From streaming ingestion
    Streaming,
    /// From auto-ingest (proactive_context)
    AutoIngest,
}
46
/// Type detection pattern with priority.
struct TypePattern {
    /// Compiled regex that signals this experience type.
    pattern: Regex,
    /// Type assigned when `pattern` matches.
    experience_type: ExperienceType,
    /// Confidence reported on a match (0.0 - 1.0).
    confidence: f32,
    /// Higher-priority patterns are tried first (sorted in `build_type_patterns`).
    priority: u8,
}
54
/// Segmentation engine for Hebbian-optimal memory formation.
pub struct SegmentationEngine {
    /// Type detection patterns ordered by priority (descending).
    type_patterns: Vec<TypePattern>,
    /// Minimum content length (in bytes) for a valid segment;
    /// shorter sentences are dropped as noise.
    min_segment_length: usize,
    /// Maximum content length (in bytes) before forced split.
    max_segment_length: usize,
}
64
65impl Default for SegmentationEngine {
66    fn default() -> Self {
67        Self::new()
68    }
69}
70
71impl SegmentationEngine {
72    /// Create a new segmentation engine with default patterns
73    pub fn new() -> Self {
74        let type_patterns = Self::build_type_patterns();
75        Self {
76            type_patterns,
77            min_segment_length: 20,
78            max_segment_length: 2000,
79        }
80    }
81
82    /// Build type detection patterns
83    /// Priority: higher = checked first
84    /// Confidence: how certain we are when pattern matches
85    fn build_type_patterns() -> Vec<TypePattern> {
86        let mut patterns = vec![
87            // === HIGH PRIORITY: Explicit markers ===
88            TypePattern {
89                pattern: Regex::new(r"(?i)\b(decided|chose|chosen|went with|picked|selected|opted for|decision to)\b").unwrap(),
90                experience_type: ExperienceType::Decision,
91                confidence: 0.95,
92                priority: 100,
93            },
94            TypePattern {
95                pattern: Regex::new(r"(?i)\b(learned|realized|understood|figured out|now I know|insight)\b").unwrap(),
96                experience_type: ExperienceType::Learning,
97                confidence: 0.90,
98                priority: 95,
99            },
100            TypePattern {
101                pattern: Regex::new(r"(?i)(error:|bug:|exception:|failed:|broke:|crash|traceback|stacktrace|\bfixed\b)").unwrap(),
102                experience_type: ExperienceType::Error,
103                confidence: 0.95,
104                priority: 98,
105            },
106            TypePattern {
107                pattern: Regex::new(r"(?i)\b(discovered|found that|noticed|stumbled upon|turns out)\b").unwrap(),
108                experience_type: ExperienceType::Discovery,
109                confidence: 0.85,
110                priority: 90,
111            },
112            TypePattern {
113                pattern: Regex::new(r"(?i)(pattern:|always|every time|whenever|consistently|tends to)\b").unwrap(),
114                experience_type: ExperienceType::Pattern,
115                confidence: 0.80,
116                priority: 85,
117            },
118
119            // === MEDIUM PRIORITY: Action-based ===
120            TypePattern {
121                pattern: Regex::new(r"(?i)\b(will|tomorrow|later|remind me|don't forget|need to remember|scheduled|need to fix)\b").unwrap(),
122                experience_type: ExperienceType::Intention,
123                confidence: 0.85,
124                priority: 88,
125            },
126            TypePattern {
127                pattern: Regex::new(r"(?i)\b(edited|changed|modified|updated|refactored|renamed|moved)\b.*\b(file|code|function|class|module)\b").unwrap(),
128                experience_type: ExperienceType::CodeEdit,
129                confidence: 0.90,
130                priority: 80,
131            },
132            TypePattern {
133                pattern: Regex::new(r"(?i)\b(opened|read|accessed|viewed|looked at)\b.*\b(file|document|page)\b").unwrap(),
134                experience_type: ExperienceType::FileAccess,
135                confidence: 0.85,
136                priority: 75,
137            },
138            TypePattern {
139                pattern: Regex::new(r"(?i)\b(searched|looked for|found|grep|rg|find)\b").unwrap(),
140                experience_type: ExperienceType::Search,
141                confidence: 0.80,
142                priority: 70,
143            },
144            TypePattern {
145                pattern: Regex::new(r"(?i)\b(ran|executed|command:|terminal|shell|bash|npm|cargo|git)\b").unwrap(),
146                experience_type: ExperienceType::Command,
147                confidence: 0.85,
148                priority: 72,
149            },
150            TypePattern {
151                pattern: Regex::new(r"(?i)\b(task:|todo:|need to|should|must|have to|working on)\b").unwrap(),
152                experience_type: ExperienceType::Task,
153                confidence: 0.75,
154                priority: 65,
155            },
156
157            // === LOWER PRIORITY: Context indicators ===
158            TypePattern {
159                pattern: Regex::new(r"(?i)(context:|background:|for reference|fyi|note:)\b").unwrap(),
160                experience_type: ExperienceType::Context,
161                confidence: 0.80,
162                priority: 60,
163            },
164            TypePattern {
165                pattern: Regex::new(r"(?i)\b(said|told|asked|replied|mentioned|discussed|conversation)\b").unwrap(),
166                experience_type: ExperienceType::Conversation,
167                confidence: 0.70,
168                priority: 50,
169            },
170        ];
171
172        // Sort by priority descending
173        patterns.sort_by(|a, b| b.priority.cmp(&a.priority));
174        patterns
175    }
176
177    /// Main entry point: segment input into atomic memories
178    pub fn segment(&self, input: &str, source: InputSource) -> Vec<AtomicMemory> {
179        let input = input.trim();
180        if input.is_empty() {
181            return Vec::new();
182        }
183
184        // Step 1: Split into sentences
185        let sentences = self.split_sentences(input);
186        if sentences.is_empty() {
187            return Vec::new();
188        }
189
190        // Step 2: Classify each sentence
191        let typed_sentences: Vec<(ExperienceType, f32, String, usize)> = sentences
192            .into_iter()
193            .enumerate()
194            .filter(|(_, s)| s.len() >= self.min_segment_length)
195            .map(|(offset, s)| {
196                let (exp_type, confidence) = self.detect_type(&s, source);
197                (exp_type, confidence, s, offset)
198            })
199            .collect();
200
201        if typed_sentences.is_empty() {
202            // If all sentences were too short, treat entire input as one
203            let (exp_type, confidence) = self.detect_type(input, source);
204            return vec![AtomicMemory {
205                experience_type: exp_type,
206                content: input.to_string(),
207                entities: self.extract_simple_entities(input),
208                type_confidence: confidence,
209                source_offset: 0,
210            }];
211        }
212
213        // Step 3: Merge consecutive same-type sentences
214        let merged = self.merge_consecutive_same_type(typed_sentences);
215
216        // Step 4: Apply max length splitting if needed
217        let split = self.apply_max_length_splits(merged);
218
219        // Step 5: Extract entities for each segment
220        split
221            .into_iter()
222            .map(|(exp_type, confidence, content, offset)| AtomicMemory {
223                experience_type: exp_type,
224                entities: self.extract_simple_entities(&content),
225                content,
226                type_confidence: confidence,
227                source_offset: offset,
228            })
229            .collect()
230    }
231
232    /// Split input into sentences
233    fn split_sentences(&self, input: &str) -> Vec<String> {
234        // Split on sentence boundaries: . ! ? followed by space or newline
235        // But preserve abbreviations like "e.g." "i.e." "Dr." etc.
236        let mut sentences = Vec::new();
237        let mut current = String::new();
238        let mut chars = input.chars().peekable();
239
240        while let Some(c) = chars.next() {
241            current.push(c);
242
243            // Check for sentence boundary
244            if matches!(c, '.' | '!' | '?') {
245                // Look ahead to see if this is end of sentence
246                if let Some(&next) = chars.peek() {
247                    if next.is_whitespace() || next == '\n' {
248                        // Check if this looks like an abbreviation
249                        let trimmed = current.trim();
250                        let is_abbreviation = Self::is_likely_abbreviation(trimmed);
251
252                        if !is_abbreviation {
253                            let sentence = current.trim().to_string();
254                            if !sentence.is_empty() {
255                                sentences.push(sentence);
256                            }
257                            current = String::new();
258                            // Skip the whitespace
259                            chars.next();
260                        }
261                    }
262                }
263            }
264
265            // Also split on double newlines (paragraph boundaries)
266            if c == '\n' {
267                if let Some(&next) = chars.peek() {
268                    if next == '\n' {
269                        let sentence = current.trim().to_string();
270                        if !sentence.is_empty() {
271                            sentences.push(sentence);
272                        }
273                        current = String::new();
274                        chars.next(); // Skip second newline
275                    }
276                }
277            }
278        }
279
280        // Don't forget the last sentence
281        let final_sentence = current.trim().to_string();
282        if !final_sentence.is_empty() {
283            sentences.push(final_sentence);
284        }
285
286        sentences
287    }
288
289    /// Check if a string ending looks like an abbreviation
290    fn is_likely_abbreviation(s: &str) -> bool {
291        let lower = s.to_lowercase();
292        let abbreviations = [
293            "e.g.", "i.e.", "etc.", "vs.", "dr.", "mr.", "mrs.", "ms.", "jr.", "sr.", "inc.",
294            "ltd.", "corp.", "co.", "st.", "ave.", "rd.", "blvd.", "fig.", "ref.", "vol.", "no.",
295            "pp.", "ed.", "rev.",
296        ];
297
298        for abbr in &abbreviations {
299            if lower.ends_with(abbr) {
300                return true;
301            }
302        }
303
304        // Single letter followed by period (initials)
305        if s.len() >= 2 {
306            let chars: Vec<char> = s.chars().collect();
307            let last_two = &chars[chars.len() - 2..];
308            if last_two[0].is_alphabetic() && last_two[1] == '.' {
309                // Check if the letter before is whitespace or start
310                if chars.len() == 2 || chars[chars.len() - 3].is_whitespace() {
311                    return true;
312                }
313            }
314        }
315
316        false
317    }
318
319    /// Detect experience type from content
320    fn detect_type(&self, content: &str, source: InputSource) -> (ExperienceType, f32) {
321        // Source-based hints
322        let source_type = match source {
323            InputSource::Codebase => Some((ExperienceType::FileAccess, 0.6)),
324            InputSource::AutoIngest => None, // Need to detect from content
325            _ => None,
326        };
327
328        // Try pattern matching
329        for pattern in &self.type_patterns {
330            if pattern.pattern.is_match(content) {
331                return (pattern.experience_type.clone(), pattern.confidence);
332            }
333        }
334
335        // Fall back to source hint or default
336        source_type.unwrap_or((ExperienceType::Observation, 0.5))
337    }
338
339    /// Merge consecutive sentences of the same type
340    fn merge_consecutive_same_type(
341        &self,
342        sentences: Vec<(ExperienceType, f32, String, usize)>,
343    ) -> Vec<(ExperienceType, f32, String, usize)> {
344        if sentences.is_empty() {
345            return Vec::new();
346        }
347
348        let mut result = Vec::new();
349        let mut current_type = sentences[0].0.clone();
350        let mut current_confidence = sentences[0].1;
351        let mut current_content = sentences[0].2.clone();
352        let mut current_offset = sentences[0].3;
353
354        for (exp_type, confidence, content, offset) in sentences.into_iter().skip(1) {
355            if exp_type == current_type {
356                // Merge: append content, take max confidence
357                current_content.push(' ');
358                current_content.push_str(&content);
359                current_confidence = current_confidence.max(confidence);
360            } else {
361                // Different type: save current, start new
362                result.push((
363                    current_type,
364                    current_confidence,
365                    current_content,
366                    current_offset,
367                ));
368                current_type = exp_type;
369                current_confidence = confidence;
370                current_content = content;
371                current_offset = offset;
372            }
373        }
374
375        // Don't forget the last one
376        result.push((
377            current_type,
378            current_confidence,
379            current_content,
380            current_offset,
381        ));
382
383        result
384    }
385
386    /// Split segments that exceed max length
387    fn apply_max_length_splits(
388        &self,
389        segments: Vec<(ExperienceType, f32, String, usize)>,
390    ) -> Vec<(ExperienceType, f32, String, usize)> {
391        let mut result = Vec::new();
392
393        for (exp_type, confidence, content, offset) in segments {
394            if content.len() <= self.max_segment_length {
395                result.push((exp_type, confidence, content, offset));
396            } else {
397                // Split on sentence boundaries within the long content
398                let sub_sentences = self.split_sentences(&content);
399                let mut current_chunk = String::new();
400
401                for sentence in sub_sentences {
402                    if current_chunk.len() + sentence.len() + 1 > self.max_segment_length {
403                        if !current_chunk.is_empty() {
404                            result.push((
405                                exp_type.clone(),
406                                confidence,
407                                current_chunk.clone(),
408                                offset,
409                            ));
410                        }
411                        current_chunk = sentence;
412                    } else {
413                        if !current_chunk.is_empty() {
414                            current_chunk.push(' ');
415                        }
416                        current_chunk.push_str(&sentence);
417                    }
418                }
419
420                if !current_chunk.is_empty() {
421                    result.push((exp_type, confidence, current_chunk, offset));
422                }
423            }
424        }
425
426        result
427    }
428
429    /// Simple entity extraction (words > 2 chars, excluding stopwords)
430    fn extract_simple_entities(&self, content: &str) -> Vec<String> {
431        let stopwords: HashSet<&str> = [
432            "the", "a", "an", "is", "are", "was", "were", "be", "been", "being", "have", "has",
433            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
434            "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
435            "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
436            "below", "between", "under", "again", "further", "then", "once", "here", "there",
437            "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
438            "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
439            "and", "but", "or", "if", "because", "while", "although", "this", "that", "these",
440            "those", "i", "you", "he", "she", "it", "we", "they", "what", "which", "who", "whom",
441            "its", "his", "her", "their", "my", "your", "our",
442        ]
443        .into_iter()
444        .collect();
445
446        content
447            .to_lowercase()
448            .split(|c: char| !c.is_alphanumeric() && c != '_' && c != '-')
449            .filter(|word| word.len() > 2 && !stopwords.contains(word))
450            .map(|s| s.to_string())
451            .collect::<HashSet<_>>()
452            .into_iter()
453            .collect()
454    }
455}
456
/// Deduplication result
#[derive(Debug, Clone)]
pub enum DeduplicationResult {
    /// Store as new memory
    New,
    /// Exact duplicate - skip storage
    Duplicate { existing_id: String },
    /// Semantic near-duplicate - consider merging
    SemanticMatch {
        existing_id: String,
        similarity: f32,
    },
    /// Same entities but different content - link as related
    EntityOverlap { existing_id: String, overlap: f32 },
}

/// Deduplication engine to prevent duplicate Hebbian edges.
///
/// Keeps the normalized-content hash of everything registered so far, so
/// exact (case/whitespace-insensitive) repeats can be skipped before they
/// create redundant graph edges.
pub struct DeduplicationEngine {
    /// Content hash -> membership index (hash only; no memory IDs stored here).
    content_hashes: HashSet<u64>,
}

impl Default for DeduplicationEngine {
    /// Delegates to [`DeduplicationEngine::new`].
    fn default() -> Self {
        DeduplicationEngine::new()
    }
}

impl DeduplicationEngine {
    /// Create an engine with no registered content.
    pub fn new() -> Self {
        Self {
            content_hashes: HashSet::new(),
        }
    }

    /// Compute content hash for exact duplicate detection.
    ///
    /// Content is normalized first — lowercased, with whitespace runs
    /// collapsed to single spaces — so trivially different copies hash the
    /// same. NOTE(review): `DefaultHasher` output is not guaranteed stable
    /// across Rust releases; do not persist these hashes.
    pub fn content_hash(content: &str) -> u64 {
        use std::hash::{Hash, Hasher};

        // Normalize: lowercase, collapse whitespace (single reused buffer).
        let lowered = content.to_lowercase();
        let mut normalized = String::with_capacity(lowered.len());
        for word in lowered.split_whitespace() {
            if !normalized.is_empty() {
                normalized.push(' ');
            }
            normalized.push_str(word);
        }

        let mut hasher = std::collections::hash_map::DefaultHasher::new();
        normalized.hash(&mut hasher);
        hasher.finish()
    }

    /// Check whether `content` was already registered (after normalization).
    pub fn is_duplicate(&self, content: &str) -> bool {
        self.content_hashes.contains(&Self::content_hash(content))
    }

    /// Record `content`'s normalized hash so later copies are flagged.
    pub fn register(&mut self, content: &str) {
        self.content_hashes.insert(Self::content_hash(content));
    }

    /// Jaccard overlap between two entity lists, case-insensitively.
    ///
    /// Returns |A ∩ B| / |A ∪ B| in [0.0, 1.0]; 0.0 when either side is empty.
    pub fn calculate_entity_overlap(entities1: &[String], entities2: &[String]) -> f32 {
        if entities1.is_empty() || entities2.is_empty() {
            return 0.0;
        }

        let set1: HashSet<String> = entities1.iter().map(|s| s.to_lowercase()).collect();
        let set2: HashSet<String> = entities2.iter().map(|s| s.to_lowercase()).collect();

        // |A ∪ B| = |A| + |B| - |A ∩ B| — one traversal instead of two.
        let shared = set1.iter().filter(|e| set2.contains(*e)).count();
        let total = set1.len() + set2.len() - shared;

        if total == 0 {
            0.0
        } else {
            shared as f32 / total as f32
        }
    }
}
537
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh engine with default thresholds.
    fn engine() -> SegmentationEngine {
        SegmentationEngine::new()
    }

    /// Classify `text` as if it arrived via the remember API.
    fn classify(text: &str) -> (ExperienceType, f32) {
        engine().detect_type(text, InputSource::UserApi)
    }

    #[test]
    fn test_sentence_splitting() {
        // Explicit paragraph breaks give deterministic splitting.
        let input = "I decided to use Rust.\n\nIt has great performance.\n\nThe memory safety is excellent.";
        let sentences = engine().split_sentences(input);

        assert_eq!(sentences.len(), 3);
        for (sentence, needle) in sentences.iter().zip(["Rust", "performance", "memory safety"]) {
            assert!(sentence.contains(needle));
        }
    }

    #[test]
    fn test_abbreviation_preservation() {
        let sentences = engine().split_sentences("E.g. this is an example.\n\nDr. Smith said so.");

        // Abbreviations must not trigger a sentence break.
        assert_eq!(sentences.len(), 2);
        assert!(sentences[0].contains("E.g."));
        assert!(sentences[1].contains("Dr."));
    }

    #[test]
    fn test_type_detection_decision() {
        let (exp_type, confidence) = classify("I decided to use Rust for this project");
        assert!(matches!(exp_type, ExperienceType::Decision));
        assert!(confidence > 0.9);
    }

    #[test]
    fn test_type_detection_error() {
        let (exp_type, confidence) = classify("error: cannot find module 'foo'");
        assert!(matches!(exp_type, ExperienceType::Error));
        assert!(confidence > 0.9);
    }

    #[test]
    fn test_type_detection_learning() {
        let (exp_type, confidence) = classify("I learned that async functions need await");
        assert!(matches!(exp_type, ExperienceType::Learning));
        assert!(confidence > 0.8);
    }

    #[test]
    fn test_type_detection_intention() {
        let (exp_type, confidence) = classify("Tomorrow I will review the PR");
        assert!(matches!(exp_type, ExperienceType::Intention));
        assert!(confidence > 0.8);
    }

    #[test]
    fn test_segmentation_mixed_types() {
        // Each paragraph carries an explicit marker for a different type.
        let input = "I decided to use Rust.\n\nerror: found a bug in the auth module.\n\nTomorrow need to fix it.";
        let segments = engine().segment(input, InputSource::UserApi);

        assert_eq!(segments.len(), 3);
        assert!(matches!(segments[0].experience_type, ExperienceType::Decision));
        assert!(matches!(segments[1].experience_type, ExperienceType::Error));
        assert!(matches!(segments[2].experience_type, ExperienceType::Intention));
    }

    #[test]
    fn test_same_type_merging() {
        // Every sentence matches the Decision pattern, so all collapse into one.
        let input =
            "I decided to use Rust.\n\nI also decided to use Axum.\n\nWe chose RocksDB for storage.";
        let segments = engine().segment(input, InputSource::UserApi);

        assert_eq!(segments.len(), 1);
        let merged = &segments[0];
        assert!(matches!(merged.experience_type, ExperienceType::Decision));
        assert!(merged.content.contains("Rust"));
        assert!(merged.content.contains("Axum"));
    }

    #[test]
    fn test_entity_extraction() {
        let entities =
            engine().extract_simple_entities("I decided to use Rust for the shodh-memory project");

        for expected in ["rust", "shodh-memory", "project"] {
            assert!(entities.contains(&expected.to_string()));
        }
        // Stopwords must never appear as entities.
        for stopword in ["the", "to"] {
            assert!(!entities.contains(&stopword.to_string()));
        }
    }

    #[test]
    fn test_deduplication_hash() {
        // Case and whitespace variants must normalize to the same hash.
        let hashes: Vec<u64> = ["Hello World", "hello world", "Hello  World"]
            .iter()
            .map(|v| DeduplicationEngine::content_hash(v))
            .collect();

        assert_eq!(hashes[0], hashes[1]);
        assert_eq!(hashes[1], hashes[2]);
    }

    #[test]
    fn test_entity_overlap() {
        let entities1: Vec<String> = ["rust", "memory", "project"]
            .iter()
            .map(|s| s.to_string())
            .collect();
        let entities2: Vec<String> = ["rust", "memory", "performance"]
            .iter()
            .map(|s| s.to_string())
            .collect();

        // Jaccard: 2 common (rust, memory) / 4 total unique = 0.5.
        let overlap = DeduplicationEngine::calculate_entity_overlap(&entities1, &entities2);
        assert!((overlap - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_max_length_split() {
        let mut engine = SegmentationEngine::new();
        engine.max_segment_length = 100;

        let long_input =
            "This is a very long decision that I made about using Rust for the backend. \
            I also decided to use Axum for the web framework because it has great performance. \
            Additionally I chose RocksDB for storage due to its reliability and speed.";

        let segments = engine.segment(long_input, InputSource::UserApi);

        // Length forces multiple segments...
        assert!(segments.len() > 1);
        // ...and each stays near the limit (a lone long sentence may overshoot).
        for segment in &segments {
            assert!(segment.content.len() <= engine.max_segment_length + 50);
        }
    }
}