siftdb_core/
inverted_index.rs

use std::collections::{HashMap, HashSet};
use std::fs::File;
use std::io::{BufReader, BufWriter};
use std::path::Path;

use anyhow::{Context, Result};
use fst::{Map, MapBuilder};
use memmap2::Mmap;
use serde::{Deserialize, Serialize};

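// Index layout: the sorted term dictionary lives in a memory-mapped FST that
// maps each term to its ordinal, while the posting lists (lowercased term ->
// set of file handles) are serialized to a JSON sidecar. The query methods
// below answer from the in-memory posting lists; the FST presumably serves
// ordered/prefix term scans elsewhere in the crate.
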
#[derive(Debug)]
pub struct InvertedIndex {
    /// Memory-mapped FST over the sorted terms (term -> ordinal).
    term_map: Option<Map<Mmap>>,
    /// Lowercased term -> handles of the files that contain it.
    posting_lists: HashMap<String, HashSet<u32>>,
}

#[derive(Serialize, Deserialize)]
struct PostingListData {
    posting_lists: HashMap<String, HashSet<u32>>,
}

impl InvertedIndex {
    pub fn new() -> Self {
        Self {
            term_map: None,
            posting_lists: HashMap::new(),
        }
    }

    pub fn build_from_content<P: AsRef<Path>>(
        file_contents: HashMap<u32, String>,
        output_fst_path: P,
        output_json_path: P,
    ) -> Result<Self> {
        let mut posting_lists: HashMap<String, HashSet<u32>> = HashMap::new();

        for (file_handle, content) in file_contents {
            for token in tokenize(&content) {
                posting_lists
                    .entry(token.to_lowercase())
                    .or_insert_with(HashSet::new)
                    .insert(file_handle);
            }
        }

        // FST construction requires keys in sorted order; each term's value
        // is simply its ordinal in that order.
        let mut sorted_terms: Vec<_> = posting_lists.keys().collect();
        sorted_terms.sort();

        let file = File::create(&output_fst_path)?;
        let mut builder = MapBuilder::new(BufWriter::new(file))?;
        for (i, term) in sorted_terms.iter().enumerate() {
            builder.insert(term.as_bytes(), i as u64)?;
        }
        builder.finish()?;

        // Persist the posting lists next to the FST. Moving the map into the
        // serializable wrapper avoids cloning it.
        let data = PostingListData { posting_lists };
        let file = File::create(&output_json_path)?;
        serde_json::to_writer(BufWriter::new(file), &data)?;

        let file = File::open(&output_fst_path)?;
        // SAFETY: the FST file was just written and is not expected to be
        // mutated while mapped.
        let mmap = unsafe { Mmap::map(&file)? };
        let term_map = Map::new(mmap)?;

        Ok(Self {
            term_map: Some(term_map),
            posting_lists: data.posting_lists,
        })
    }

    pub fn load_from_files<P: AsRef<Path>>(fst_path: P, json_path: P) -> Result<Self> {
        let file = File::open(&fst_path)?;
        // SAFETY: index files are assumed not to be mutated while mapped.
        let mmap = unsafe { Mmap::map(&file)? };
        let term_map = Map::new(mmap)?;

        let file = File::open(&json_path)?;
        let reader = BufReader::new(file);
        let data: PostingListData =
            serde_json::from_reader(reader).context("Failed to parse posting lists")?;

        Ok(Self {
            term_map: Some(term_map),
            posting_lists: data.posting_lists,
        })
    }

    pub fn find_files_with_term(&self, term: &str) -> HashSet<u32> {
        let term_lower = term.to_lowercase();
        self.posting_lists
            .get(&term_lower)
            .cloned()
            .unwrap_or_default()
    }

    pub fn find_files_with_all_terms(&self, terms: &[&str]) -> HashSet<u32> {
        if terms.is_empty() {
            return HashSet::new();
        }

        let mut result = self.find_files_with_term(terms[0]);

        for &term in &terms[1..] {
            let term_files = self.find_files_with_term(term);
            result = result.intersection(&term_files).cloned().collect();

            // An empty intersection can never grow again, so stop early.
            if result.is_empty() {
                break;
            }
        }

        result
    }

    pub fn find_files_with_any_terms(&self, terms: &[&str]) -> HashSet<u32> {
        let mut result = HashSet::new();
        for &term in terms {
            // Union in place rather than rebuilding the set on every term.
            result.extend(self.find_files_with_term(term));
        }
        result
    }

    pub fn contains_term(&self, term: &str) -> bool {
        let term_lower = term.to_lowercase();
        self.posting_lists.contains_key(&term_lower)
    }

    /// Number of files the term occurs in (i.e. its document frequency).
    pub fn term_frequency(&self, term: &str) -> usize {
        let term_lower = term.to_lowercase();
        self.posting_lists
            .get(&term_lower)
            .map(|files| files.len())
            .unwrap_or(0)
    }

    pub fn term_count(&self) -> usize {
        self.posting_lists.len()
    }
}

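// Example: building and querying an index (a sketch; the paths and file
// handles are illustrative, and `?` assumes an `anyhow::Result` context):
//
//     let mut contents = HashMap::new();
//     contents.insert(1u32, "fn main() { run(); }".to_string());
//     let index = InvertedIndex::build_from_content(contents, "terms.fst", "postings.json")?;
//     assert!(index.find_files_with_all_terms(&["fn", "main"]).contains(&1));
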
fn tokenize(content: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    let mut current_token = String::new();

    // Split on any character that is not alphanumeric or '_', so snake_case
    // identifiers are kept whole.
    for ch in content.chars() {
        if ch.is_alphanumeric() || ch == '_' {
            current_token.push(ch);
        } else if !current_token.is_empty() {
            tokens.push(std::mem::take(&mut current_token));
        }
    }

    if !current_token.is_empty() {
        tokens.push(current_token);
    }

    // Drop very short tokens (byte length < 2) and stop words.
    tokens
        .into_iter()
        .filter(|t| t.len() >= 2 && !is_stop_word(t))
        .collect()
}

fn is_stop_word(word: &str) -> bool {
    matches!(
        word.to_lowercase().as_str(),
        "the" | "a" | "an" | "and" | "or" | "but" | "in" | "on" | "at" | "to" | "for" | "of"
            | "with" | "by"
    )
}

#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_tokenize() {
        let content = "fn main() { println!(\"Hello, world!\"); }";
        let tokens = tokenize(content);
        assert!(tokens.contains(&"fn".to_string()));
        assert!(tokens.contains(&"main".to_string()));
        assert!(tokens.contains(&"println".to_string()));
        assert!(tokens.contains(&"Hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

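    // An extra check added as a sketch of the tokenizer's edge cases:
    // underscores count as token characters, so snake_case identifiers stay
    // whole, and stop words are filtered case-insensitively.
    #[test]
    fn test_tokenize_identifiers_and_stop_words() {
        let tokens = tokenize("The parse_input function AND the lexer");
        assert!(tokens.contains(&"parse_input".to_string()));
        assert!(tokens.contains(&"function".to_string()));
        assert!(tokens.contains(&"lexer".to_string()));
        // "The"/"the" and "AND" are stop words after lowercasing.
        assert!(!tokens.iter().any(|t| t.eq_ignore_ascii_case("the")));
        assert!(!tokens.iter().any(|t| t.eq_ignore_ascii_case("and")));
    }
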
    #[test]
    fn test_inverted_index() -> Result<()> {
        let temp_dir = TempDir::new()?;
        let fst_path = temp_dir.path().join("terms.fst");
        let json_path = temp_dir.path().join("posting_lists.json");

        let mut contents = HashMap::new();
        contents.insert(1, "fn main() { println!(\"Hello\"); }".to_string());
        contents.insert(2, "fn test() { assert_eq!(1, 1); }".to_string());
        contents.insert(3, "struct Point { x: i32, y: i32 }".to_string());

        let index = InvertedIndex::build_from_content(contents, &fst_path, &json_path)?;

        let fn_files = index.find_files_with_term("fn");
        assert_eq!(fn_files.len(), 2);
        assert!(fn_files.contains(&1));
        assert!(fn_files.contains(&2));

        let struct_files = index.find_files_with_term("struct");
        assert_eq!(struct_files.len(), 1);
        assert!(struct_files.contains(&3));

        let main_fn = index.find_files_with_all_terms(&["fn", "main"]);
        assert_eq!(main_fn.len(), 1);
        assert!(main_fn.contains(&1));

        Ok(())
    }
}