// siftdb_core/trigram.rs

use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;

use anyhow::Result;
use regex::Regex;
use serde::{Deserialize, Serialize};

9/// Trigram index for fast regex preprocessing and filtering
10#[derive(Debug, Serialize, Deserialize)]
11pub struct TrigramIndex {
12    /// Maps trigram -> set of file handles that contain it
13    trigrams: HashMap<String, HashSet<u32>>,
14    /// Total number of trigrams indexed
15    total_trigrams: usize,
16}
17
18impl TrigramIndex {
19    /// Create new empty trigram index
20    pub fn new() -> Self {
21        Self {
22            trigrams: HashMap::new(),
23            total_trigrams: 0,
24        }
25    }
26
27    /// Add content from a file to the trigram index
28    pub fn add_file_content(&mut self, file_handle: u32, content: &str) {
29        let trigrams = extract_trigrams(content);
30        
31        for trigram in trigrams {
32            self.trigrams
33                .entry(trigram)
34                .or_insert_with(HashSet::new)
35                .insert(file_handle);
36        }
37        
38        self.total_trigrams += 1;
39    }
40
41    /// Save trigram index to file
42    pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
43        let file = File::create(path)?;
44        let writer = BufWriter::new(file);
45        serde_json::to_writer_pretty(writer, self)?;
46        Ok(())
47    }
48
49    /// Load trigram index from file
50    pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
51        let file = File::open(path)?;
52        let reader = BufReader::new(file);
53        let index: Self = serde_json::from_reader(reader)?;
54        Ok(index)
55    }
56
57    /// Extract trigrams from a regex pattern for fast filtering
58    pub fn extract_regex_trigrams(&self, regex_pattern: &str) -> Result<Vec<String>> {
59        // Try to compile the regex to validate it
60        let _regex = Regex::new(regex_pattern)?;
61        
62        // For now, extract literal trigrams from the pattern
63        // This is a simplified approach - a full implementation would
64        // analyze the regex AST to find guaranteed literal sequences
65        let literals = extract_literal_parts(regex_pattern);
66        
67        let mut trigrams = HashSet::new();
68        for literal in literals {
69            if literal.len() >= 3 {
70                trigrams.extend(extract_trigrams(&literal));
71            }
72        }
73        
74        Ok(trigrams.into_iter().collect())
75    }
76
77    /// Get candidate file handles that might match a regex based on trigrams
78    pub fn get_regex_candidates(&self, regex_pattern: &str) -> Result<HashSet<u32>> {
79        let required_trigrams = self.extract_regex_trigrams(regex_pattern)?;
80        
81        if required_trigrams.is_empty() {
82            // If no trigrams can be extracted, we need to check all files
83            // Return all file handles from all trigrams
84            let all_handles: HashSet<u32> = self.trigrams
85                .values()
86                .flat_map(|handles| handles.iter())
87                .cloned()
88                .collect();
89            return Ok(all_handles);
90        }
91
92        // Find intersection of files containing all required trigrams
93        let mut candidates: Option<HashSet<u32>> = None;
94        
95        for trigram in required_trigrams {
96            if let Some(handles) = self.trigrams.get(&trigram) {
97                match candidates {
98                    None => candidates = Some(handles.clone()),
99                    Some(ref mut current) => {
100                        current.retain(|h| handles.contains(h));
101                    }
102                }
103            } else {
104                // If any required trigram is missing, no matches possible
105                return Ok(HashSet::new());
106            }
107        }
108        
109        Ok(candidates.unwrap_or_default())
110    }
111
112    /// Get statistics about the trigram index
113    pub fn stats(&self) -> TrigramStats {
114        let total_file_references: usize = self.trigrams
115            .values()
116            .map(|handles| handles.len())
117            .sum();
118
119        TrigramStats {
120            unique_trigrams: self.trigrams.len(),
121            total_file_references,
122            avg_files_per_trigram: if self.trigrams.is_empty() {
123                0.0
124            } else {
125                total_file_references as f64 / self.trigrams.len() as f64
126            },
127        }
128    }
129}
130
/// Summary figures describing the contents of a trigram index.
#[derive(Debug)]
pub struct TrigramStats {
    /// How many distinct trigrams the index holds.
    pub unique_trigrams: usize,
    /// Total size of all posting lists, i.e. the number of (trigram, file) pairs.
    pub total_file_references: usize,
    /// Mean posting-list size; `0.0` when the index holds no trigrams.
    pub avg_files_per_trigram: f64,
}

/// Extract every distinct trigram (window of three consecutive characters) from `text`.
///
/// Characters are lowercased first so lookups are case-insensitive. Only windows
/// made entirely of alphanumeric characters or `_` are kept, which keeps the
/// index small.
fn extract_trigrams(text: &str) -> HashSet<String> {
    // Lowercase character by character instead of using `str::to_lowercase`.
    // The latter is context-sensitive: a capital sigma at the end of a word
    // becomes `ς`, elsewhere `σ`. A short pattern fragment and the same letters
    // inside a longer file would then normalize differently and never match.
    let normalized: Vec<char> = text.chars().flat_map(char::to_lowercase).collect();

    normalized
        .windows(3)
        // Check the characters before building a String, so rejected windows cost nothing.
        .filter(|window| window.iter().all(|&c| c.is_alphanumeric() || c == '_'))
        .map(|window| window.iter().collect())
        .collect()
}

/// Extract the literal runs of a regex pattern.
///
/// Each returned string is a run of characters the pattern matches verbatim and
/// in order. Runs are split at every metacharacter (including `|`, `(` and
/// `)`), so callers derive trigrams from each run separately.
///
/// This is a lightweight scanner, not a full regex parser, and it is
/// deliberately conservative. Anything it cannot read as plain literal text is
/// skipped rather than mistaken for it:
/// - a character (or a whole group) followed by `?`, `*` or `{m,n}` is dropped,
///   since the repetition may match it zero times;
/// - character classes (`[...]`, including nested ones), repetition counts,
///   group names (`(?P<name>` / `(?<name>`) and flag groups are skipped;
/// - alphanumeric escapes (`\d`, `\w`, `\b`, `\x41`, `\p{L}`, ...) and `\<` /
///   `\>` are non-literal, while other escaped punctuation (`\.`, `\*`) is
///   literal;
/// - if the verbose flag (`x`) appears, no runs are returned at all, because
///   whitespace and `#` comments are then not literal text.
///
/// Alternation is NOT resolved: `(test|demo)` yields both `test` and `demo`,
/// even though a match only needs one of them.
fn extract_literal_parts(pattern: &str) -> Vec<String> {
    let chars: Vec<char> = pattern.chars().collect();
    let mut literals: Vec<String> = Vec::new();
    let mut current = String::new();
    // `literals.len()` at each open group, so that a group followed by an
    // optional quantifier can discard every run found inside it.
    let mut group_starts: Vec<usize> = Vec::new();
    let mut i = 0;

    while i < chars.len() {
        match chars[i] {
            // Zero repetitions are allowed, so the preceding character is optional.
            '?' | '*' => {
                current.pop();
                flush_literal(&mut literals, &mut current);
                i += 1;
            }
            '{' => {
                current.pop();
                flush_literal(&mut literals, &mut current);
                // Skip the repetition count so its digits are not read as text.
                i = skip_past(&chars, i + 1, '}');
            }
            '[' => {
                flush_literal(&mut literals, &mut current);
                i = skip_char_class(&chars, i);
            }
            '(' => {
                flush_literal(&mut literals, &mut current);
                i += 1;
                if chars.get(i) == Some(&'?') {
                    i += 1;
                    let named = match chars.get(i) {
                        Some('<') => true,
                        Some('P') => chars.get(i + 1) == Some(&'<'),
                        _ => false,
                    };
                    if named {
                        // Skip the group name so it is not read as literal text.
                        i = skip_past(&chars, i, '>');
                    } else {
                        // Flag syntax: `(?flags)` or `(?flags:...)`.
                        let end = chars[i..]
                            .iter()
                            .position(|&c| c == ':' || c == ')')
                            .map_or(chars.len(), |offset| i + offset);
                        if chars[i..end].contains(&'x') {
                            // Verbose mode ignores whitespace and allows `#`
                            // comments, so literal runs can no longer be trusted.
                            return Vec::new();
                        }
                        i = end + 1;
                        if chars.get(end) == Some(&')') {
                            // `(?flags)` only changes flags; it opens no group.
                            continue;
                        }
                    }
                }
                group_starts.push(literals.len());
            }
            ')' => {
                flush_literal(&mut literals, &mut current);
                i += 1;
                if let Some(start) = group_starts.pop() {
                    // An optional group contributes nothing guaranteed.
                    if let Some('?' | '*' | '{') = chars.get(i) {
                        literals.truncate(start);
                    }
                }
            }
            '\\' => match chars.get(i + 1) {
                // Escaped punctuation such as `\.` or `\*` is a literal character.
                Some(&c) if !(c.is_alphanumeric() || c == '<' || c == '>') => {
                    current.push(c);
                    i += 2;
                }
                // Classes, assertions and code-point escapes are not visible text.
                Some(_) => {
                    flush_literal(&mut literals, &mut current);
                    i = skip_escape(&chars, i + 1);
                }
                None => i += 1,
            },
            '+' | '.' | '^' | '$' | '|' | ']' | '}' => {
                flush_literal(&mut literals, &mut current);
                i += 1;
            }
            c => {
                current.push(c);
                i += 1;
            }
        }
    }

    flush_literal(&mut literals, &mut current);
    literals
}

/// Move `current` into `literals` if it is non-empty, leaving `current` empty.
fn flush_literal(literals: &mut Vec<String>, current: &mut String) {
    if !current.is_empty() {
        literals.push(std::mem::take(current));
    }
}

/// Index just past the first `close` at or after `from`, or the end of input.
fn skip_past(chars: &[char], from: usize, close: char) -> usize {
    chars
        .get(from..)
        .and_then(|rest| rest.iter().position(|&c| c == close))
        .map_or(chars.len(), |offset| from + offset + 1)
}

/// Index just past the `]` that closes the character class opened at `open`,
/// accounting for nested classes, escapes, and a leading literal `]`.
fn skip_char_class(chars: &[char], open: usize) -> usize {
    let mut depth = 0usize;
    let mut i = open;
    while i < chars.len() {
        match chars[i] {
            '\\' => i += 2,
            '[' => {
                depth += 1;
                i += 1;
                // A `]` directly after `[` or `[^` is a class member, not the end.
                if chars.get(i) == Some(&'^') {
                    i += 1;
                }
                if chars.get(i) == Some(&']') {
                    i += 1;
                }
            }
            ']' => {
                depth -= 1;
                i += 1;
                if depth == 0 {
                    return i;
                }
            }
            _ => i += 1,
        }
    }
    chars.len()
}

/// Index just past an escape sequence whose first character after the
/// backslash is at `at` (which must be in bounds).
fn skip_escape(chars: &[char], at: usize) -> usize {
    let next = at + 1;
    // Up to `max` hex digits starting at `next`, as in `\x41` or `\u0041`.
    let hex_digits = |max: usize| {
        chars[next..]
            .iter()
            .take(max)
            .take_while(|c| c.is_ascii_hexdigit())
            .count()
    };
    let braced = chars.get(next) == Some(&'{');
    match chars[at] {
        'x' | 'u' | 'U' | 'p' | 'P' if braced => skip_past(chars, next, '}'),
        'x' => next + hex_digits(2),
        'u' => next + hex_digits(4),
        'U' => next + hex_digits(8),
        // Single-letter Unicode class such as `\pL`.
        'p' | 'P' => (next + 1).min(chars.len()),
        d if d.is_ascii_digit() => {
            next + chars[next..].iter().take_while(|c| c.is_ascii_digit()).count()
        }
        _ => next,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_trigrams() {
        let found = extract_trigrams("hello world");
        let expected = ["hel", "ell", "llo", "wor", "orl", "rld"];
        assert!(expected.iter().all(|trigram| found.contains(*trigram)));
    }

    #[test]
    fn test_extract_literal_parts() {
        // `\s+` and `\w+` are character-class escapes, so they add no literals.
        let cases: [(&str, &[&str]); 4] = [
            ("hello", &["hello"]),
            ("hello.*world", &["hello", "world"]),
            ("fn\\s+\\w+", &["fn"]),
            ("(test|demo)", &["test", "demo"]),
        ];

        for &(pattern, expected) in &cases {
            let literals = extract_literal_parts(pattern);
            assert_eq!(literals, expected, "Failed for pattern: {}", pattern);
        }
    }

    #[test]
    fn test_trigram_index() {
        let mut index = TrigramIndex::new();
        index.add_file_content(1, "hello world");
        index.add_file_content(2, "world peace");

        // "wor" is shared by both files.
        let wor = &index.trigrams["wor"];
        assert!(wor.contains(&1));
        assert!(wor.contains(&2));

        // "hel" belongs to file 1 only.
        let hel = &index.trigrams["hel"];
        assert!(hel.contains(&1));
        assert!(!hel.contains(&2));
    }
}
248        assert!(!candidates.contains(&2));
249    }
250}