1use std::collections::{HashMap, HashSet};
2use std::path::Path;
3use std::fs::File;
4use std::io::{BufWriter, BufReader};
5use serde::{Serialize, Deserialize};
6use anyhow::Result;
7use regex::Regex;
8
9#[derive(Debug, Serialize, Deserialize)]
11pub struct TrigramIndex {
12 trigrams: HashMap<String, HashSet<u32>>,
14 total_trigrams: usize,
16}
17
18impl TrigramIndex {
19 pub fn new() -> Self {
21 Self {
22 trigrams: HashMap::new(),
23 total_trigrams: 0,
24 }
25 }
26
27 pub fn add_file_content(&mut self, file_handle: u32, content: &str) {
29 let trigrams = extract_trigrams(content);
30
31 for trigram in trigrams {
32 self.trigrams
33 .entry(trigram)
34 .or_insert_with(HashSet::new)
35 .insert(file_handle);
36 }
37
38 self.total_trigrams += 1;
39 }
40
41 pub fn save_to_file<P: AsRef<Path>>(&self, path: P) -> Result<()> {
43 let file = File::create(path)?;
44 let writer = BufWriter::new(file);
45 serde_json::to_writer_pretty(writer, self)?;
46 Ok(())
47 }
48
49 pub fn load_from_file<P: AsRef<Path>>(path: P) -> Result<Self> {
51 let file = File::open(path)?;
52 let reader = BufReader::new(file);
53 let index: Self = serde_json::from_reader(reader)?;
54 Ok(index)
55 }
56
57 pub fn extract_regex_trigrams(&self, regex_pattern: &str) -> Result<Vec<String>> {
59 let _regex = Regex::new(regex_pattern)?;
61
62 let literals = extract_literal_parts(regex_pattern);
66
67 let mut trigrams = HashSet::new();
68 for literal in literals {
69 if literal.len() >= 3 {
70 trigrams.extend(extract_trigrams(&literal));
71 }
72 }
73
74 Ok(trigrams.into_iter().collect())
75 }
76
77 pub fn get_regex_candidates(&self, regex_pattern: &str) -> Result<HashSet<u32>> {
79 let required_trigrams = self.extract_regex_trigrams(regex_pattern)?;
80
81 if required_trigrams.is_empty() {
82 let all_handles: HashSet<u32> = self.trigrams
85 .values()
86 .flat_map(|handles| handles.iter())
87 .cloned()
88 .collect();
89 return Ok(all_handles);
90 }
91
92 let mut candidates: Option<HashSet<u32>> = None;
94
95 for trigram in required_trigrams {
96 if let Some(handles) = self.trigrams.get(&trigram) {
97 match candidates {
98 None => candidates = Some(handles.clone()),
99 Some(ref mut current) => {
100 current.retain(|h| handles.contains(h));
101 }
102 }
103 } else {
104 return Ok(HashSet::new());
106 }
107 }
108
109 Ok(candidates.unwrap_or_default())
110 }
111
112 pub fn stats(&self) -> TrigramStats {
114 let total_file_references: usize = self.trigrams
115 .values()
116 .map(|handles| handles.len())
117 .sum();
118
119 TrigramStats {
120 unique_trigrams: self.trigrams.len(),
121 total_file_references,
122 avg_files_per_trigram: if self.trigrams.is_empty() {
123 0.0
124 } else {
125 total_file_references as f64 / self.trigrams.len() as f64
126 },
127 }
128 }
129}
130
/// Size summary of a `TrigramIndex`, as produced by `TrigramIndex::stats`.
#[derive(Debug)]
pub struct TrigramStats {
    /// Number of distinct trigrams stored in the index.
    pub unique_trigrams: usize,
    /// Sum, over all trigrams, of the number of files listed for that trigram.
    pub total_file_references: usize,
    /// `total_file_references` divided by `unique_trigrams`; `0.0` when the
    /// index holds no trigrams.
    pub avg_files_per_trigram: f64,
}
137
/// Collects the distinct trigrams of `text`.
///
/// The text is lowercased first. A trigram is any run of three consecutive
/// characters in which every character is alphanumeric or `_`; windows that
/// touch any other character (spaces, punctuation, ...) are skipped. Input
/// shorter than three characters yields an empty set.
fn extract_trigrams(text: &str) -> HashSet<String> {
    let is_word_char = |c: &char| c.is_alphanumeric() || *c == '_';

    let lowered: Vec<char> = text.to_lowercase().chars().collect();

    lowered
        .windows(3)
        .filter(|triple| triple.iter().all(is_word_char))
        .map(|triple| triple.iter().collect::<String>())
        .collect()
}
157
/// Splits a regex pattern into the literal runs that a match must contain.
///
/// Only text that is guaranteed to appear verbatim is reported:
///
/// * metacharacters (`.`, `+`, `^`, `$`, `|`, brackets, braces) end the current run;
/// * `\` followed by an ASCII letter or digit (`\s`, `\w`, `\d`, `\b`,
///   `\x41`, `\p{L}`, ...) is a class, anchor or code-point escape, so it ends
///   the run instead of contributing the letter; `\` followed by anything else
///   contributes that character literally (`\.` gives `.`);
/// * a character or group followed by `?`, `*` or `{..}` may be absent from a
///   match and is dropped;
/// * the contents of character classes `[...]`, repetition counts `{n,m}` and
///   group headers (`(?:`, `(?i)`, `(?P<name>`, `(?<name>`) are never text.
///
/// The literals of every alternation branch are all returned, in pattern
/// order. A match needs only one branch, so callers must not treat literals
/// from a pattern containing `|` as jointly required.
///
/// An empty list is returned when the pattern enables verbose mode (`x` flag)
/// or uses group syntax this scanner does not recognise, because the pattern
/// text can then no longer be mapped to required literals safely.
fn extract_literal_parts(pattern: &str) -> Vec<String> {
    let chars: Vec<char> = pattern.chars().collect();
    let mut literals: Vec<String> = Vec::new();
    let mut current = String::new();
    // For every group still open: how many literals existed when it opened.
    let mut open_groups: Vec<usize> = Vec::new();
    // Set only for the token right after a `)`: where that group's literals start.
    let mut just_closed: Option<usize> = None;

    let mut i = 0;
    while i < chars.len() {
        let c = chars[i];
        let closed_group = just_closed.take();

        match c {
            '?' | '*' | '{' => {
                // The preceding atom may occur zero times, so it is not required.
                // `{n,m}` is treated the same way to stay conservative.
                match closed_group {
                    Some(start) => literals.truncate(start),
                    None => {
                        current.pop();
                    }
                }
                flush_literal(&mut current, &mut literals);
                i = if c == '{' {
                    skip_past(&chars, i + 1, '}')
                } else {
                    i + 1
                };
            }
            '(' => {
                flush_literal(&mut current, &mut literals);
                match read_group_header(&chars, i) {
                    GroupHeader::Opens(next) => {
                        open_groups.push(literals.len());
                        i = next;
                    }
                    GroupHeader::FlagsOnly(next) => i = next,
                    GroupHeader::Unsupported => return Vec::new(),
                }
            }
            ')' => {
                flush_literal(&mut current, &mut literals);
                just_closed = open_groups.pop();
                i += 1;
            }
            '[' => {
                flush_literal(&mut current, &mut literals);
                i = skip_char_class(&chars, i);
            }
            '\\' => match chars.get(i + 1) {
                // Dangling backslash: nothing to take.
                None => i += 1,
                Some(escaped) if escaped.is_ascii_alphanumeric() => {
                    flush_literal(&mut current, &mut literals);
                    i = skip_escape(&chars, i);
                }
                Some(&escaped) => {
                    current.push(escaped);
                    i += 2;
                }
            },
            '.' | '+' | '^' | '$' | '|' | ']' | '}' => {
                flush_literal(&mut current, &mut literals);
                i += 1;
            }
            literal => {
                current.push(literal);
                i += 1;
            }
        }
    }

    flush_literal(&mut current, &mut literals);
    literals
}

/// Result of reading the text that follows a `(`.
enum GroupHeader {
    /// A group was opened; scanning resumes at the given index.
    Opens(usize),
    /// A flag directive such as `(?i)` was consumed; no group was opened.
    FlagsOnly(usize),
    /// Verbose mode or unrecognised `(?...` syntax.
    Unsupported,
}

/// Moves the pending literal, if any, into `literals`.
fn flush_literal(current: &mut String, literals: &mut Vec<String>) {
    if !current.is_empty() {
        literals.push(std::mem::take(current));
    }
}

/// Returns the index just after the first `close` at or after `from`, or the
/// end of `chars` when there is none.
fn skip_past(chars: &[char], from: usize, close: char) -> usize {
    chars
        .iter()
        .skip(from)
        .position(|&c| c == close)
        .map_or(chars.len(), |offset| from + offset + 1)
}

/// Classifies the group starting at the `(` found at index `open`.
fn read_group_header(chars: &[char], open: usize) -> GroupHeader {
    if chars.get(open + 1) != Some(&'?') {
        return GroupHeader::Opens(open + 1);
    }

    let rest = &chars[open + 2..];

    // Named groups: `(?P<name>` and `(?<name>`; the name itself is not text.
    let name_offset = match rest {
        ['P', '<', ..] => Some(2),
        ['<', next, ..] if *next != '=' && *next != '!' => Some(1),
        _ => None,
    };
    if let Some(offset) = name_offset {
        return GroupHeader::Opens(skip_past(chars, open + 2 + offset, '>'));
    }

    let flag_count = rest
        .iter()
        .take_while(|f| f.is_ascii_alphabetic() || **f == '-')
        .count();
    // Verbose mode makes whitespace and `#` comments in the pattern
    // insignificant, which this scanner does not model.
    if rest[..flag_count].contains(&'x') {
        return GroupHeader::Unsupported;
    }

    let after_header = open + 2 + flag_count + 1;
    match rest.get(flag_count) {
        Some(':') => GroupHeader::Opens(after_header),
        Some(')') => GroupHeader::FlagsOnly(after_header),
        _ => GroupHeader::Unsupported,
    }
}

/// Returns the index just after the character class opened by the `[` at
/// index `open`, honouring escapes, nested classes and a leading literal `]`.
fn skip_char_class(chars: &[char], open: usize) -> usize {
    let mut i = open + 1;
    if chars.get(i) == Some(&'^') {
        i += 1;
    }
    // A `]` in first position is a member of the class, not its end.
    if chars.get(i) == Some(&']') {
        i += 1;
    }

    let mut depth = 1usize;
    while i < chars.len() {
        match chars[i] {
            // Step over the escaped character as well.
            '\\' => i += 1,
            '[' => depth += 1,
            ']' => {
                depth -= 1;
                if depth == 0 {
                    return i + 1;
                }
            }
            _ => {}
        }
        i += 1;
    }
    chars.len()
}

/// Returns the index just after the alphanumeric escape whose backslash sits at
/// index `backslash`, including the payload of `\x`, `\u`, `\U`, `\p` and `\P`.
fn skip_escape(chars: &[char], backslash: usize) -> usize {
    let after = backslash + 2;
    let payload_len = match chars.get(backslash + 1) {
        Some('x') => 2,
        Some('u') => 4,
        Some('U') => 8,
        Some('p') | Some('P') => 1,
        _ => return after.min(chars.len()),
    };

    if chars.get(after) == Some(&'{') {
        skip_past(chars, after, '}')
    } else {
        (after + payload_len).min(chars.len())
    }
}
200
#[cfg(test)]
mod tests {
    use super::*;

    /// Both words of the sample text contribute all of their trigrams.
    #[test]
    fn test_extract_trigrams() {
        let found = extract_trigrams("hello world");

        for expected in ["hel", "ell", "llo", "wor", "orl", "rld"].iter() {
            assert!(found.contains(*expected));
        }
    }

    /// Literal runs are split at metacharacters and escapes.
    #[test]
    fn test_extract_literal_parts() {
        let cases: [(&str, &[&str]); 4] = [
            ("hello", &["hello"]),
            ("hello.*world", &["hello", "world"]),
            ("fn\\s+\\w+", &["fn"]),
            ("(test|demo)", &["test", "demo"]),
        ];

        for (pattern, expected) in cases.iter() {
            let literals = extract_literal_parts(pattern);
            assert_eq!(literals, *expected, "Failed for pattern: {}", pattern);
        }
    }

    /// A trigram maps to exactly the files whose content contained it.
    #[test]
    fn test_trigram_index() {
        let mut index = TrigramIndex::new();
        index.add_file_content(1, "hello world");
        index.add_file_content(2, "world peace");

        // "wor" occurs in both files.
        let shared = index.trigrams.get("wor").unwrap();
        assert!(shared.contains(&1));
        assert!(shared.contains(&2));

        // "hel" occurs in the first file only.
        let first_only = index.trigrams.get("hel").unwrap();
        assert!(first_only.contains(&1));
        assert!(!first_only.contains(&2));
    }
}