1use std::collections::HashMap;
8use std::path::Path;
9
10use crate::INDEX_VERSION;
11use crate::error::IndexError;
12use crate::index::{IndexEntry, SearchHit};
13use crate::parser::CodeChunk;
14
15#[derive(Clone, Debug)]
24struct EmbeddingPoint(Vec<f32>);
25
26impl instant_distance::Point for EmbeddingPoint {
27 fn distance(&self, other: &Self) -> f32 {
28 let dot: f32 = self.0.iter().zip(other.0.iter()).map(|(a, b)| a * b).sum();
31 1.0 - dot
32 }
33}
34
35#[derive(serde::Serialize, serde::Deserialize)]
40pub struct SeekrIndex {
41 pub version: u32,
43
44 pub vectors: HashMap<u64, Vec<f32>>,
46
47 pub inverted_index: HashMap<String, Vec<(u64, u32)>>,
49
50 pub chunks: HashMap<u64, CodeChunk>,
52
53 pub embedding_dim: usize,
55
56 pub chunk_count: usize,
58
59 #[serde(skip)]
62 hnsw: Option<instant_distance::HnswMap<EmbeddingPoint, u64>>,
63}
64
65impl std::fmt::Debug for SeekrIndex {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 f.debug_struct("SeekrIndex")
68 .field("version", &self.version)
69 .field("embedding_dim", &self.embedding_dim)
70 .field("chunk_count", &self.chunk_count)
71 .field("vectors_len", &self.vectors.len())
72 .field("hnsw", &self.hnsw.as_ref().map(|_| "Some(<HnswMap>)"))
73 .finish()
74 }
75}
76
77impl SeekrIndex {
78 pub fn new(embedding_dim: usize) -> Self {
80 Self {
81 version: INDEX_VERSION,
82 vectors: HashMap::new(),
83 inverted_index: HashMap::new(),
84 chunks: HashMap::new(),
85 embedding_dim,
86 chunk_count: 0,
87 hnsw: None,
88 }
89 }
90
91 pub fn add_entry(&mut self, entry: IndexEntry, chunk: CodeChunk) {
93 let chunk_id = entry.chunk_id;
94
95 self.vectors.insert(chunk_id, entry.embedding);
97
98 for token in &entry.text_tokens {
100 let posting_list = self.inverted_index.entry(token.clone()).or_default();
101 if let Some(existing) = posting_list.iter_mut().find(|(id, _)| *id == chunk_id) {
102 existing.1 += 1;
103 } else {
104 posting_list.push((chunk_id, 1));
105 }
106 }
107
108 self.chunks.insert(chunk_id, chunk);
110 self.chunk_count = self.chunks.len();
111 }
112
113 pub fn remove_chunk(&mut self, chunk_id: u64) {
117 self.vectors.remove(&chunk_id);
119
120 self.inverted_index.retain(|_token, posting_list| {
122 posting_list.retain(|(id, _)| *id != chunk_id);
123 !posting_list.is_empty()
124 });
125
126 self.chunks.remove(&chunk_id);
128 self.chunk_count = self.chunks.len();
129 }
130
131 pub fn remove_chunks(&mut self, chunk_ids: &[u64]) {
133 for &chunk_id in chunk_ids {
134 self.remove_chunk(chunk_id);
135 }
136 }
137
138 pub fn build_from(chunks: &[CodeChunk], embeddings: &[Vec<f32>], embedding_dim: usize) -> Self {
142 let mut index = Self::new(embedding_dim);
143
144 for (chunk, embedding) in chunks.iter().zip(embeddings.iter()) {
145 let text_tokens = tokenize_for_index(&chunk.body);
146
147 let entry = IndexEntry {
148 chunk_id: chunk.id,
149 embedding: embedding.clone(),
150 text_tokens,
151 };
152
153 index.add_entry(entry, chunk.clone());
154 }
155
156 index.rebuild_hnsw();
158
159 index
160 }
161
162 pub fn rebuild_hnsw(&mut self) {
166 if self.vectors.is_empty() {
167 self.hnsw = None;
168 return;
169 }
170
171 let mut points = Vec::with_capacity(self.vectors.len());
172 let mut values = Vec::with_capacity(self.vectors.len());
173
174 for (&chunk_id, embedding) in &self.vectors {
175 points.push(EmbeddingPoint(embedding.clone()));
176 values.push(chunk_id);
177 }
178
179 let hnsw_map = instant_distance::Builder::default().build(points, values);
180 self.hnsw = Some(hnsw_map);
181
182 tracing::debug!(chunks = self.vectors.len(), "HNSW graph built");
183 }
184
185 pub fn search_vector(
193 &self,
194 query_embedding: &[f32],
195 top_k: usize,
196 score_threshold: f32,
197 ) -> Vec<SearchHit> {
198 if let Some(ref hnsw) = self.hnsw {
199 self.search_vector_hnsw(hnsw, query_embedding, top_k, score_threshold)
201 } else {
202 self.search_vector_brute_force(query_embedding, top_k, score_threshold)
204 }
205 }
206
207 fn search_vector_hnsw(
209 &self,
210 hnsw: &instant_distance::HnswMap<EmbeddingPoint, u64>,
211 query_embedding: &[f32],
212 top_k: usize,
213 score_threshold: f32,
214 ) -> Vec<SearchHit> {
215 let query_point = EmbeddingPoint(query_embedding.to_vec());
216 let mut search = instant_distance::Search::default();
217
218 let results: Vec<SearchHit> = hnsw
219 .search(&query_point, &mut search)
220 .take(top_k)
221 .filter_map(|item| {
222 let chunk_id = *item.value;
223 let score = 1.0 - item.distance;
225 if score >= score_threshold {
226 Some(SearchHit { chunk_id, score })
227 } else {
228 None
229 }
230 })
231 .collect();
232
233 results
234 }
235
236 fn search_vector_brute_force(
240 &self,
241 query_embedding: &[f32],
242 top_k: usize,
243 score_threshold: f32,
244 ) -> Vec<SearchHit> {
245 let mut scores: Vec<(u64, f32)> = self
246 .vectors
247 .iter()
248 .map(|(&chunk_id, embedding)| {
249 let score = cosine_similarity(query_embedding, embedding);
250 (chunk_id, score)
251 })
252 .filter(|(_, score)| *score >= score_threshold)
253 .collect();
254
255 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
257
258 scores
259 .into_iter()
260 .take(top_k)
261 .map(|(chunk_id, score)| SearchHit { chunk_id, score })
262 .collect()
263 }
264
265 pub fn search_text(&self, query: &str, top_k: usize) -> Vec<SearchHit> {
269 let query_tokens = tokenize_for_index(query);
270
271 if query_tokens.is_empty() {
272 return Vec::new();
273 }
274
275 let mut scores: HashMap<u64, f32> = HashMap::new();
277
278 for token in &query_tokens {
279 if let Some(posting_list) = self.inverted_index.get(token) {
280 for &(chunk_id, frequency) in posting_list {
281 *scores.entry(chunk_id).or_default() += frequency as f32;
282 }
283 }
284 }
285
286 let num_tokens = query_tokens.len() as f32;
288 let mut results: Vec<(u64, f32)> = scores
289 .into_iter()
290 .map(|(id, score)| (id, score / num_tokens))
291 .collect();
292
293 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
294
295 results
296 .into_iter()
297 .take(top_k)
298 .map(|(chunk_id, score)| SearchHit { chunk_id, score })
299 .collect()
300 }
301
302 pub fn get_chunk(&self, chunk_id: u64) -> Option<&CodeChunk> {
304 self.chunks.get(&chunk_id)
305 }
306
307 pub fn save(&self, dir: &Path) -> Result<(), IndexError> {
311 std::fs::create_dir_all(dir)?;
312
313 let index_path = dir.join("index.bin");
314 let data =
315 bincode::serialize(self).map_err(|e| IndexError::Serialization(e.to_string()))?;
316 std::fs::write(&index_path, data)?;
317
318 let old_json_path = dir.join("index.json");
320 if old_json_path.exists() {
321 let _ = std::fs::remove_file(&old_json_path);
322 }
323
324 tracing::info!(
325 chunks = self.chunk_count,
326 path = %dir.display(),
327 "Index saved (bincode v2)"
328 );
329
330 Ok(())
331 }
332
333 pub fn load(dir: &Path) -> Result<Self, IndexError> {
338 let bin_path = dir.join("index.bin");
339 let json_path = dir.join("index.json");
340
341 let mut index: SeekrIndex = if bin_path.exists() {
342 let data = std::fs::read(&bin_path)?;
344 bincode::deserialize(&data).map_err(|e| IndexError::Serialization(e.to_string()))?
345 } else if json_path.exists() {
346 let data = std::fs::read(&json_path)?;
348 serde_json::from_slice(&data).map_err(|e| IndexError::Serialization(e.to_string()))?
349 } else {
350 return Err(IndexError::NotFound(bin_path));
351 };
352
353 if index.version != INDEX_VERSION {
355 return Err(IndexError::VersionMismatch {
356 file_version: index.version,
357 expected_version: INDEX_VERSION,
358 });
359 }
360
361 index.rebuild_hnsw();
363
364 tracing::info!(
365 chunks = index.chunk_count,
366 path = %dir.display(),
367 "Index loaded (HNSW rebuilt)"
368 );
369
370 Ok(index)
371 }
372}
373
/// Cosine similarity of two equally sized vectors.
///
/// Returns `0.0` when the slices are empty, differ in length, or when either
/// vector has zero magnitude.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    /// Sum of element-wise products of two slices.
    fn inner(x: &[f32], y: &[f32]) -> f32 {
        x.iter().zip(y).map(|(p, q)| p * q).sum()
    }

    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }

    let magnitude_a = inner(a, a).sqrt();
    let magnitude_b = inner(b, b).sqrt();
    if magnitude_a == 0.0 || magnitude_b == 0.0 {
        return 0.0;
    }

    inner(a, b) / (magnitude_a * magnitude_b)
}
390
/// Splits `text` into lowercase index tokens.
///
/// Any character that is neither alphanumeric nor `_` acts as a separator.
/// Each piece is lowercased, and pieces whose lowercased form is shorter than
/// two bytes are dropped. The length check is in bytes, not characters, so a
/// single multi-byte character is kept.
fn tokenize_for_index(text: &str) -> Vec<String> {
    let is_separator = |c: char| !(c.is_alphanumeric() || c == '_');

    let mut tokens = Vec::new();
    for piece in text.split(is_separator) {
        let lowered = piece.to_lowercase();
        if lowered.len() >= 2 {
            tokens.push(lowered);
        }
    }
    tokens
}

/// Public entry point to the index tokenizer, so callers outside this module
/// can tokenise with exactly the same rules the index uses.
pub fn tokenize_for_index_pub(text: &str) -> Vec<String> {
    tokenize_for_index(text)
}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parser::ChunkKind;
    use std::path::PathBuf;

    /// Builds a minimal function chunk in `test.rs` whose byte range spans
    /// the whole `body`.
    fn make_test_chunk(id: u64, name: &str, body: &str) -> CodeChunk {
        let body = body.to_string();
        CodeChunk {
            id,
            file_path: PathBuf::from("test.rs"),
            language: "rust".to_string(),
            kind: ChunkKind::Function,
            name: Some(name.to_string()),
            signature: None,
            doc_comment: None,
            byte_range: 0..body.len(),
            line_range: 0..1,
            body,
        }
    }

    /// Orthogonal vectors score ~0, identical vectors score ~1.
    #[test]
    fn test_cosine_similarity() {
        let x_axis = vec![1.0, 0.0, 0.0];
        let y_axis = vec![0.0, 1.0, 0.0];
        let orthogonal = cosine_similarity(&x_axis, &y_axis);
        assert!(orthogonal.abs() < 0.01);

        let x_axis_again = vec![1.0, 0.0, 0.0];
        let identical = cosine_similarity(&x_axis, &x_axis_again);
        assert!((identical - 1.0).abs() < 0.01);
    }

    /// Keyword search ranks the chunk containing the query terms first.
    #[test]
    fn test_build_and_search_text() {
        let auth_chunk = make_test_chunk(
            1,
            "authenticate",
            "fn authenticate(user: &str, password: &str) -> Result<Token, Error>",
        );
        let calc_chunk = make_test_chunk(2, "calculate", "fn calculate_total(items: &[Item]) -> f64");
        let chunks = vec![auth_chunk, calc_chunk];
        let embeddings = vec![vec![0.1; 8], vec![0.2; 8]];

        let index = SeekrIndex::build_from(&chunks, &embeddings, 8);
        assert_eq!(index.chunk_count, 2);

        let results = index.search_text("authenticate user password", 10);
        assert!(!results.is_empty());
        assert_eq!(results[0].chunk_id, 1);
    }

    /// Vector search returns the nearest embedding first.
    #[test]
    fn test_build_and_search_vector() {
        let chunks = vec![
            make_test_chunk(1, "foo", "fn foo()"),
            make_test_chunk(2, "bar", "fn bar()"),
        ];
        let embeddings = vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]];
        let index = SeekrIndex::build_from(&chunks, &embeddings, 3);

        let query = vec![0.9, 0.1, 0.0];
        let results = index.search_vector(&query, 2, 0.0);

        assert!(!results.is_empty());
        assert_eq!(
            results[0].chunk_id, 1,
            "Should find the most similar chunk first"
        );
    }

    /// An index written with `save` can be read back with `load`.
    #[test]
    fn test_save_and_load() {
        let chunks = vec![make_test_chunk(1, "test", "fn test() {}")];
        let embeddings = vec![vec![0.5; 4]];
        let index = SeekrIndex::build_from(&chunks, &embeddings, 4);

        let dir = tempfile::tempdir().unwrap();
        index.save(dir.path()).unwrap();
        let loaded = SeekrIndex::load(dir.path()).unwrap();

        assert_eq!(loaded.chunk_count, 1);
        assert_eq!(loaded.version, INDEX_VERSION);
    }

    /// Identifiers are lowercased and split on punctuation, keeping `_`.
    #[test]
    fn test_tokenize_for_index() {
        let tokens = tokenize_for_index("fn authenticate_user(username: &str) -> Result<String>");

        for expected in ["fn", "authenticate_user", "username", "result", "string"] {
            assert!(tokens.iter().any(|token| token.as_str() == expected));
        }
    }
}