1use std::collections::HashMap;
8use std::path::Path;
9
10use crate::error::IndexError;
11use crate::index::{IndexEntry, SearchHit};
12use crate::parser::CodeChunk;
13use crate::INDEX_VERSION;
14
15#[derive(Clone, Debug)]
24struct EmbeddingPoint(Vec<f32>);
25
26impl instant_distance::Point for EmbeddingPoint {
27 fn distance(&self, other: &Self) -> f32 {
28 let dot: f32 = self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum();
31 1.0 - dot
32 }
33}
34
35#[derive(serde::Serialize, serde::Deserialize)]
40pub struct SeekrIndex {
41 pub version: u32,
43
44 pub vectors: HashMap<u64, Vec<f32>>,
46
47 pub inverted_index: HashMap<String, Vec<(u64, u32)>>,
49
50 pub chunks: HashMap<u64, CodeChunk>,
52
53 pub embedding_dim: usize,
55
56 pub chunk_count: usize,
58
59 #[serde(skip)]
62 hnsw: Option<instant_distance::HnswMap<EmbeddingPoint, u64>>,
63}
64
65impl std::fmt::Debug for SeekrIndex {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("SeekrIndex")
68 .field("version", &self.version)
69 .field("embedding_dim", &self.embedding_dim)
70 .field("chunk_count", &self.chunk_count)
71 .field("vectors_len", &self.vectors.len())
72 .field("hnsw", &self.hnsw.as_ref().map(|_| "Some(<HnswMap>)"))
73 .finish()
74 }
75}
76
77impl SeekrIndex {
78 pub fn new(embedding_dim: usize) -> Self {
80 Self {
81 version: INDEX_VERSION,
82 vectors: HashMap::new(),
83 inverted_index: HashMap::new(),
84 chunks: HashMap::new(),
85 embedding_dim,
86 chunk_count: 0,
87 hnsw: None,
88 }
89 }
90
91 pub fn add_entry(&mut self, entry: IndexEntry, chunk: CodeChunk) {
93 let chunk_id = entry.chunk_id;
94
95 self.vectors.insert(chunk_id, entry.embedding);
97
98 for token in &entry.text_tokens {
100 let posting_list = self.inverted_index.entry(token.clone()).or_default();
101 if let Some(existing) = posting_list.iter_mut().find(|(id, _)| *id == chunk_id) {
102 existing.1 += 1;
103 } else {
104 posting_list.push((chunk_id, 1));
105 }
106 }
107
108 self.chunks.insert(chunk_id, chunk);
110 self.chunk_count = self.chunks.len();
111 }
112
113 pub fn remove_chunk(&mut self, chunk_id: u64) {
117 self.vectors.remove(&chunk_id);
119
120 self.inverted_index.retain(|_token, posting_list| {
122 posting_list.retain(|(id, _)| *id != chunk_id);
123 !posting_list.is_empty()
124 });
125
126 self.chunks.remove(&chunk_id);
128 self.chunk_count = self.chunks.len();
129 }
130
131 pub fn remove_chunks(&mut self, chunk_ids: &[u64]) {
133 for &chunk_id in chunk_ids {
134 self.remove_chunk(chunk_id);
135 }
136 }
137
138 pub fn build_from(
142 chunks: &[CodeChunk],
143 embeddings: &[Vec<f32>],
144 embedding_dim: usize,
145 ) -> Self {
146 let mut index = Self::new(embedding_dim);
147
148 for (chunk, embedding) in chunks.iter().zip(embeddings.iter()) {
149 let text_tokens = tokenize_for_index(&chunk.body);
150
151 let entry = IndexEntry {
152 chunk_id: chunk.id,
153 embedding: embedding.clone(),
154 text_tokens,
155 };
156
157 index.add_entry(entry, chunk.clone());
158 }
159
160 index.rebuild_hnsw();
162
163 index
164 }
165
166 pub fn rebuild_hnsw(&mut self) {
170 if self.vectors.is_empty() {
171 self.hnsw = None;
172 return;
173 }
174
175 let mut points = Vec::with_capacity(self.vectors.len());
176 let mut values = Vec::with_capacity(self.vectors.len());
177
178 for (&chunk_id, embedding) in &self.vectors {
179 points.push(EmbeddingPoint(embedding.clone()));
180 values.push(chunk_id);
181 }
182
183 let hnsw_map = instant_distance::Builder::default().build(points, values);
184 self.hnsw = Some(hnsw_map);
185
186 tracing::debug!(
187 chunks = self.vectors.len(),
188 "HNSW graph built"
189 );
190 }
191
192 pub fn search_vector(
200 &self,
201 query_embedding: &[f32],
202 top_k: usize,
203 score_threshold: f32,
204 ) -> Vec<SearchHit> {
205 if let Some(ref hnsw) = self.hnsw {
206 self.search_vector_hnsw(hnsw, query_embedding, top_k, score_threshold)
208 } else {
209 self.search_vector_brute_force(query_embedding, top_k, score_threshold)
211 }
212 }
213
214 fn search_vector_hnsw(
216 &self,
217 hnsw: &instant_distance::HnswMap<EmbeddingPoint, u64>,
218 query_embedding: &[f32],
219 top_k: usize,
220 score_threshold: f32,
221 ) -> Vec<SearchHit> {
222 let query_point = EmbeddingPoint(query_embedding.to_vec());
223 let mut search = instant_distance::Search::default();
224
225 let results: Vec<SearchHit> = hnsw
226 .search(&query_point, &mut search)
227 .take(top_k)
228 .filter_map(|item| {
229 let chunk_id = *item.value;
230 let score = 1.0 - item.distance;
232 if score >= score_threshold {
233 Some(SearchHit { chunk_id, score })
234 } else {
235 None
236 }
237 })
238 .collect();
239
240 results
241 }
242
243 fn search_vector_brute_force(
247 &self,
248 query_embedding: &[f32],
249 top_k: usize,
250 score_threshold: f32,
251 ) -> Vec<SearchHit> {
252 let mut scores: Vec<(u64, f32)> = self
253 .vectors
254 .iter()
255 .map(|(&chunk_id, embedding)| {
256 let score = cosine_similarity(query_embedding, embedding);
257 (chunk_id, score)
258 })
259 .filter(|(_, score)| *score >= score_threshold)
260 .collect();
261
262 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
264
265 scores
266 .into_iter()
267 .take(top_k)
268 .map(|(chunk_id, score)| SearchHit { chunk_id, score })
269 .collect()
270 }
271
272 pub fn search_text(&self, query: &str, top_k: usize) -> Vec<SearchHit> {
276 let query_tokens = tokenize_for_index(query);
277
278 if query_tokens.is_empty() {
279 return Vec::new();
280 }
281
282 let mut scores: HashMap<u64, f32> = HashMap::new();
284
285 for token in &query_tokens {
286 if let Some(posting_list) = self.inverted_index.get(token) {
287 for &(chunk_id, frequency) in posting_list {
288 *scores.entry(chunk_id).or_default() += frequency as f32;
289 }
290 }
291 }
292
293 let num_tokens = query_tokens.len() as f32;
295 let mut results: Vec<(u64, f32)> = scores
296 .into_iter()
297 .map(|(id, score)| (id, score / num_tokens))
298 .collect();
299
300 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
301
302 results
303 .into_iter()
304 .take(top_k)
305 .map(|(chunk_id, score)| SearchHit { chunk_id, score })
306 .collect()
307 }
308
309 pub fn get_chunk(&self, chunk_id: u64) -> Option<&CodeChunk> {
311 self.chunks.get(&chunk_id)
312 }
313
314 pub fn save(&self, dir: &Path) -> Result<(), IndexError> {
318 std::fs::create_dir_all(dir)?;
319
320 let index_path = dir.join("index.bin");
321 let data = bincode::serialize(self)
322 .map_err(|e| IndexError::Serialization(e.to_string()))?;
323 std::fs::write(&index_path, data)?;
324
325 let old_json_path = dir.join("index.json");
327 if old_json_path.exists() {
328 let _ = std::fs::remove_file(&old_json_path);
329 }
330
331 tracing::info!(
332 chunks = self.chunk_count,
333 path = %dir.display(),
334 "Index saved (bincode v2)"
335 );
336
337 Ok(())
338 }
339
340 pub fn load(dir: &Path) -> Result<Self, IndexError> {
345 let bin_path = dir.join("index.bin");
346 let json_path = dir.join("index.json");
347
348 let mut index: SeekrIndex = if bin_path.exists() {
349 let data = std::fs::read(&bin_path)?;
351 bincode::deserialize(&data)
352 .map_err(|e| IndexError::Serialization(e.to_string()))?
353 } else if json_path.exists() {
354 let data = std::fs::read(&json_path)?;
356 serde_json::from_slice(&data)
357 .map_err(|e| IndexError::Serialization(e.to_string()))?
358 } else {
359 return Err(IndexError::NotFound(bin_path));
360 };
361
362 if index.version != INDEX_VERSION {
364 return Err(IndexError::VersionMismatch {
365 file_version: index.version,
366 expected_version: INDEX_VERSION,
367 });
368 }
369
370 index.rebuild_hnsw();
372
373 tracing::info!(
374 chunks = index.chunk_count,
375 path = %dir.display(),
376 "Index loaded (HNSW rebuilt)"
377 );
378
379 Ok(index)
380 }
381}
382
/// Cosine similarity of `a` and `b`.
///
/// Returns `0.0` when the slices are empty, differ in length, or either one
/// has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (norm_a, norm_b) = (magnitude(a), magnitude(b));

    // Guard the division: a zero vector has no direction.
    if norm_a == 0.0 || norm_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (norm_a * norm_b)
}
399
/// Splits `text` into lowercase tokens for the inverted index.
///
/// Any character that is neither alphanumeric nor `_` is a separator.
/// Tokens shorter than two bytes after lowercasing are discarded; the
/// length is a byte length, so a single multi-byte character is kept.
/// Duplicates are preserved, in order of appearance.
fn tokenize_for_index(text: &str) -> Vec<String> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        let lowered = piece.to_lowercase();
        if lowered.len() >= 2 {
            tokens.push(lowered);
        }
    }
    tokens
}
409
410pub fn tokenize_for_index_pub(text: &str) -> Vec<String> {
412 tokenize_for_index(text)
413}
414
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ChunkKind;
    use std::path::PathBuf;

    /// Builds a minimal function-kind chunk in `test.rs` with the given id, name and body.
    fn make_test_chunk(id: u64, name: &str, body: &str) -> CodeChunk {
        let body = body.to_string();
        CodeChunk {
            id,
            file_path: PathBuf::from("test.rs"),
            language: "rust".to_string(),
            kind: ChunkKind::Function,
            name: Some(name.to_string()),
            signature: None,
            doc_comment: None,
            byte_range: 0..body.len(),
            line_range: 0..1,
            body,
        }
    }

    /// Orthogonal vectors score ~0, identical vectors score ~1.
    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];

        let orthogonal = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&a, &orthogonal)).abs() < 0.01);

        let identical = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &identical) - 1.0).abs() < 0.01);
    }

    /// Keyword search ranks the chunk containing the query terms first.
    #[test]
    fn test_build_and_search_text() {
        let chunks = vec![
            make_test_chunk(
                1,
                "authenticate",
                "fn authenticate(user: &str, password: &str) -> Result<Token, Error>",
            ),
            make_test_chunk(2, "calculate", "fn calculate_total(items: &[Item]) -> f64"),
        ];
        let embeddings = vec![vec![0.1; 8], vec![0.2; 8]];

        let index = SeekrIndex::build_from(&chunks, &embeddings, 8);
        assert_eq!(index.chunk_count, 2);

        let hits = index.search_text("authenticate user password", 10);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].chunk_id, 1);
    }

    /// Vector search ranks the chunk whose embedding is closest to the query first.
    #[test]
    fn test_build_and_search_vector() {
        let chunks = vec![
            make_test_chunk(1, "foo", "fn foo()"),
            make_test_chunk(2, "bar", "fn bar()"),
        ];
        let embeddings = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];

        let index = SeekrIndex::build_from(&chunks, &embeddings, 3);

        let query = vec![0.9, 0.1, 0.0];
        let hits = index.search_vector(&query, 2, 0.0);
        assert!(!hits.is_empty());
        assert_eq!(hits[0].chunk_id, 1, "Should find the most similar chunk first");
    }

    /// An index written to a temp directory loads back with the same count and version.
    #[test]
    fn test_save_and_load() {
        let index = SeekrIndex::build_from(
            &[make_test_chunk(1, "test", "fn test() {}")],
            &[vec![0.5; 4]],
            4,
        );

        let dir = tempfile::tempdir().unwrap();
        index.save(dir.path()).unwrap();

        let loaded = SeekrIndex::load(dir.path()).unwrap();
        assert_eq!(loaded.chunk_count, 1);
        assert_eq!(loaded.version, INDEX_VERSION);
    }

    /// The tokenizer lowercases, keeps underscores, and splits on punctuation.
    #[test]
    fn test_tokenize_for_index() {
        let tokens = tokenize_for_index("fn authenticate_user(username: &str) -> Result<String>");

        for expected in &["fn", "authenticate_user", "username", "result", "string"] {
            assert!(tokens.contains(&expected.to_string()));
        }
    }
}