1use std::collections::HashMap;
7use std::path::Path;
8
9use crate::error::IndexError;
10use crate::index::{IndexEntry, SearchHit};
11use crate::parser::CodeChunk;
12use crate::INDEX_VERSION;
13
/// Serializable search index pairing per-chunk embedding vectors with an
/// inverted text index and the chunks themselves.
///
/// `vectors` and `chunks` are keyed by chunk id; posting lists in
/// `inverted_index` reference chunks by the same id.
#[derive(Debug, serde::Serialize, serde::Deserialize)]
pub struct SeekrIndex {
    /// Format version; set from `INDEX_VERSION` on construction and compared
    /// against it when loading.
    pub version: u32,

    /// Embedding vector stored for each chunk id.
    pub vectors: HashMap<u64, Vec<f32>>,

    /// Token -> posting list of `(chunk_id, occurrence count)` pairs.
    pub inverted_index: HashMap<String, Vec<(u64, u32)>>,

    /// The indexed chunk for each chunk id.
    pub chunks: HashMap<u64, CodeChunk>,

    /// Embedding dimension supplied by the caller at construction.
    /// NOTE(review): nothing in this file validates stored vectors against it.
    pub embedding_dim: usize,

    /// Number of entries in `chunks`; refreshed after every add/remove.
    pub chunk_count: usize,
}
35
36impl SeekrIndex {
37 pub fn new(embedding_dim: usize) -> Self {
39 Self {
40 version: INDEX_VERSION,
41 vectors: HashMap::new(),
42 inverted_index: HashMap::new(),
43 chunks: HashMap::new(),
44 embedding_dim,
45 chunk_count: 0,
46 }
47 }
48
49 pub fn add_entry(&mut self, entry: IndexEntry, chunk: CodeChunk) {
51 let chunk_id = entry.chunk_id;
52
53 self.vectors.insert(chunk_id, entry.embedding);
55
56 for token in &entry.text_tokens {
58 let posting_list = self.inverted_index.entry(token.clone()).or_default();
59 if let Some(existing) = posting_list.iter_mut().find(|(id, _)| *id == chunk_id) {
60 existing.1 += 1;
61 } else {
62 posting_list.push((chunk_id, 1));
63 }
64 }
65
66 self.chunks.insert(chunk_id, chunk);
68 self.chunk_count = self.chunks.len();
69 }
70
71 pub fn remove_chunk(&mut self, chunk_id: u64) {
75 self.vectors.remove(&chunk_id);
77
78 self.inverted_index.retain(|_token, posting_list| {
80 posting_list.retain(|(id, _)| *id != chunk_id);
81 !posting_list.is_empty()
82 });
83
84 self.chunks.remove(&chunk_id);
86 self.chunk_count = self.chunks.len();
87 }
88
89 pub fn remove_chunks(&mut self, chunk_ids: &[u64]) {
91 for &chunk_id in chunk_ids {
92 self.remove_chunk(chunk_id);
93 }
94 }
95
96 pub fn build_from(
98 chunks: &[CodeChunk],
99 embeddings: &[Vec<f32>],
100 embedding_dim: usize,
101 ) -> Self {
102 let mut index = Self::new(embedding_dim);
103
104 for (chunk, embedding) in chunks.iter().zip(embeddings.iter()) {
105 let text_tokens = tokenize_for_index(&chunk.body);
106
107 let entry = IndexEntry {
108 chunk_id: chunk.id,
109 embedding: embedding.clone(),
110 text_tokens,
111 };
112
113 index.add_entry(entry, chunk.clone());
114 }
115
116 index
117 }
118
119 pub fn search_vector(
123 &self,
124 query_embedding: &[f32],
125 top_k: usize,
126 score_threshold: f32,
127 ) -> Vec<SearchHit> {
128 let mut scores: Vec<(u64, f32)> = self
129 .vectors
130 .iter()
131 .map(|(&chunk_id, embedding)| {
132 let score = cosine_similarity(query_embedding, embedding);
133 (chunk_id, score)
134 })
135 .filter(|(_, score)| *score >= score_threshold)
136 .collect();
137
138 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
140
141 scores
142 .into_iter()
143 .take(top_k)
144 .map(|(chunk_id, score)| SearchHit { chunk_id, score })
145 .collect()
146 }
147
148 pub fn search_text(&self, query: &str, top_k: usize) -> Vec<SearchHit> {
152 let query_tokens = tokenize_for_index(query);
153
154 if query_tokens.is_empty() {
155 return Vec::new();
156 }
157
158 let mut scores: HashMap<u64, f32> = HashMap::new();
160
161 for token in &query_tokens {
162 if let Some(posting_list) = self.inverted_index.get(token) {
163 for &(chunk_id, frequency) in posting_list {
164 *scores.entry(chunk_id).or_default() += frequency as f32;
165 }
166 }
167 }
168
169 let num_tokens = query_tokens.len() as f32;
171 let mut results: Vec<(u64, f32)> = scores
172 .into_iter()
173 .map(|(id, score)| (id, score / num_tokens))
174 .collect();
175
176 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
177
178 results
179 .into_iter()
180 .take(top_k)
181 .map(|(chunk_id, score)| SearchHit { chunk_id, score })
182 .collect()
183 }
184
    /// Looks up the stored chunk for `chunk_id`, or `None` if it is not indexed.
    pub fn get_chunk(&self, chunk_id: u64) -> Option<&CodeChunk> {
        self.chunks.get(&chunk_id)
    }
189
190 pub fn save(&self, dir: &Path) -> Result<(), IndexError> {
192 std::fs::create_dir_all(dir)?;
193
194 let index_path = dir.join("index.json");
195 let data = serde_json::to_vec(self)
196 .map_err(|e| IndexError::Serialization(e.to_string()))?;
197 std::fs::write(&index_path, data)?;
198
199 tracing::info!(
200 chunks = self.chunk_count,
201 path = %dir.display(),
202 "Index saved"
203 );
204
205 Ok(())
206 }
207
208 pub fn load(dir: &Path) -> Result<Self, IndexError> {
210 let index_path = dir.join("index.json");
211
212 if !index_path.exists() {
213 return Err(IndexError::NotFound(index_path));
214 }
215
216 let data = std::fs::read(&index_path)?;
217 let index: SeekrIndex = serde_json::from_slice(&data)
218 .map_err(|e| IndexError::Serialization(e.to_string()))?;
219
220 if index.version != INDEX_VERSION {
222 return Err(IndexError::VersionMismatch {
223 file_version: index.version,
224 expected_version: INDEX_VERSION,
225 });
226 }
227
228 tracing::info!(
229 chunks = index.chunk_count,
230 path = %dir.display(),
231 "Index loaded"
232 );
233
234 Ok(index)
235 }
236}
237
/// Cosine similarity between two vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or either has
/// zero magnitude; otherwise the dot product divided by the product of the
/// Euclidean norms.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| -> f32 { v.iter().map(|c| c * c).sum::<f32>().sqrt() };

    let (mag_a, mag_b) = (magnitude(a), magnitude(b));
    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let inner: f32 = a.iter().zip(b).map(|(lhs, rhs)| lhs * rhs).sum();
    inner / (mag_a * mag_b)
}
254
/// Splits `text` into lowercase index tokens.
///
/// Any character that is neither alphanumeric nor `_` acts as a separator.
/// Each fragment is lowercased, and fragments whose lowercased form is
/// shorter than two bytes are discarded.
fn tokenize_for_index(text: &str) -> Vec<String> {
    let is_separator = |ch: char| !(ch.is_alphanumeric() || ch == '_');

    let mut tokens = Vec::new();
    for fragment in text.split(is_separator) {
        let lowered = fragment.to_lowercase();
        if lowered.len() >= 2 {
            tokens.push(lowered);
        }
    }
    tokens
}
264
/// Public entry point to the index tokenizer; forwards `text` to
/// `tokenize_for_index` and returns its tokens unchanged.
pub fn tokenize_for_index_pub(text: &str) -> Vec<String> {
    tokenize_for_index(text)
}
269
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ChunkKind;
    use std::path::PathBuf;

    /// Builds a minimal function chunk located in `test.rs` with the given
    /// id, name and body.
    fn make_test_chunk(id: u64, name: &str, body: &str) -> CodeChunk {
        let body = body.to_owned();
        CodeChunk {
            id,
            file_path: PathBuf::from("test.rs"),
            language: String::from("rust"),
            kind: ChunkKind::Function,
            name: Some(name.to_owned()),
            signature: None,
            doc_comment: None,
            byte_range: 0..body.len(),
            line_range: 0..1,
            body,
        }
    }

    /// Orthogonal vectors score ~0; identical vectors score ~1.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = [1.0_f32, 0.0, 0.0];
        let y_axis = [0.0_f32, 1.0, 0.0];
        let orthogonal = cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < 0.01);

        let x_axis_again = [1.0_f32, 0.0, 0.0];
        let identical = cosine_similarity(&x_axis, &x_axis_again);
        assert!((identical - 1.0).abs() < 0.01);
    }

    /// Text search ranks the chunk containing the query terms first.
    #[test]
    fn test_build_and_search_text() {
        let auth_chunk = make_test_chunk(
            1,
            "authenticate",
            "fn authenticate(user: &str, password: &str) -> Result<Token, Error>",
        );
        let calc_chunk = make_test_chunk(
            2,
            "calculate",
            "fn calculate_total(items: &[Item]) -> f64",
        );
        let chunks = vec![auth_chunk, calc_chunk];
        let embeddings = vec![vec![0.1; 8], vec![0.2; 8]];

        let index = SeekrIndex::build_from(&chunks, &embeddings, 8);
        assert_eq!(index.chunk_count, 2);

        let hits = index.search_text("authenticate user password", 10);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].chunk_id, 1);
    }

    /// Vector search ranks the nearest embedding first.
    #[test]
    fn test_build_and_search_vector() {
        let chunks = vec![
            make_test_chunk(1, "foo", "fn foo()"),
            make_test_chunk(2, "bar", "fn bar()"),
        ];
        let embeddings = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let index = SeekrIndex::build_from(&chunks, &embeddings, 3);

        let probe = vec![0.9, 0.1, 0.0];
        let hits = index.search_vector(&probe, 2, 0.0);

        assert!(!hits.is_empty());
        assert_eq!(hits[0].chunk_id, 1, "Should find the most similar chunk first");
    }

    /// An index written with `save` can be read back with `load`.
    #[test]
    fn test_save_and_load() {
        let chunks = vec![make_test_chunk(1, "test", "fn test() {}")];
        let embeddings = vec![vec![0.5; 4]];
        let original = SeekrIndex::build_from(&chunks, &embeddings, 4);

        let scratch = tempfile::tempdir().unwrap();
        original.save(scratch.path()).unwrap();
        let restored = SeekrIndex::load(scratch.path()).unwrap();

        assert_eq!(restored.chunk_count, 1);
        assert_eq!(restored.version, INDEX_VERSION);
    }

    /// Identifiers are split on punctuation, lowercased, and keep underscores.
    #[test]
    fn test_tokenize_for_index() {
        let tokens = tokenize_for_index("fn authenticate_user(username: &str) -> Result<String>");

        for expected in ["fn", "authenticate_user", "username", "result", "string"] {
            assert!(tokens.contains(&expected.to_string()));
        }
    }
}