// trueno_rag/index.rs
1//! Indexing for RAG pipelines (BM25 sparse index and vector store)
2
3use crate::{Chunk, ChunkId, Error, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6
/// Default embedding dimension (all-MiniLM-L6-v2 / BGE-small-en-v1.5).
/// Used as the default `VectorStoreConfig::dimension`.
const DEFAULT_EMBEDDING_DIM: usize = 384;
9
/// Sparse index trait for lexical retrieval.
///
/// Implementors maintain a term-based index over [`Chunk`]s and must be
/// shareable across threads (`Send + Sync`).
pub trait SparseIndex: Send + Sync {
    /// Index a chunk
    fn add(&mut self, chunk: &Chunk);

    /// Index multiple chunks
    fn add_batch(&mut self, chunks: &[Chunk]);

    /// Search for matching chunks, returning up to `k` `(chunk_id, score)` pairs
    fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)>;

    /// Remove a chunk from the index
    fn remove(&mut self, chunk_id: ChunkId);

    /// Get the number of indexed documents
    fn len(&self) -> usize;

    /// Check if the index is empty (default: `len() == 0`)
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
32
/// BM25 index implementation.
///
/// Serializable, in-memory inverted index for lexical retrieval; all
/// statistics needed for scoring (document frequencies, lengths, averages)
/// are maintained incrementally on add/remove.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Index {
    /// Inverted index: term -> [(chunk_id, term_freq)]
    inverted_index: HashMap<String, Vec<(ChunkId, u32)>>,
    /// Document frequencies: term -> number of documents containing the term
    doc_freqs: HashMap<String, u32>,
    /// Document lengths in tokens: chunk_id -> length
    doc_lengths: HashMap<ChunkId, u32>,
    /// Average document length (recomputed on every add/remove)
    avg_doc_length: f32,
    /// Total document count
    doc_count: u32,
    /// BM25 k1 parameter (term frequency saturation; 1.2 by default)
    k1: f32,
    /// BM25 b parameter (length normalization; 0.75 by default)
    b: f32,
    /// Whether tokens are lowercased during tokenization
    lowercase: bool,
    /// Stopwords dropped during tokenization
    stopwords: HashSet<String>,
}
55
56impl Default for BM25Index {
57    fn default() -> Self {
58        Self::new()
59    }
60}
61
62impl BM25Index {
63    /// Create a new BM25 index with default parameters
64    #[must_use]
65    pub fn new() -> Self {
66        Self {
67            inverted_index: HashMap::new(),
68            doc_freqs: HashMap::new(),
69            doc_lengths: HashMap::new(),
70            avg_doc_length: 0.0,
71            doc_count: 0,
72            k1: 1.2,
73            b: 0.75,
74            lowercase: true,
75            stopwords: Self::default_stopwords(),
76        }
77    }
78
79    /// Create with custom BM25 parameters
80    #[must_use]
81    pub fn with_params(k1: f32, b: f32) -> Self {
82        Self { k1, b, ..Self::new() }
83    }
84
85    /// Set stopwords
86    #[must_use]
87    pub fn with_stopwords(mut self, stopwords: HashSet<String>) -> Self {
88        self.stopwords = stopwords;
89        self
90    }
91
92    fn default_stopwords() -> HashSet<String> {
93        [
94            "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
95            "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
96            "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
97            "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
98            "below", "between", "under", "again", "further", "then", "once", "here", "there",
99            "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
100            "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
101            "and", "but", "if", "or", "because", "until", "while", "this", "that", "these",
102            "those", "it", "its",
103        ]
104        .iter()
105        .map(|s| (*s).to_string())
106        .collect()
107    }
108
109    /// Tokenize text
110    pub fn tokenize(&self, text: &str) -> Vec<String> {
111        text.split(|c: char| !c.is_alphanumeric())
112            .filter(|s| !s.is_empty())
113            .map(|s| if self.lowercase { s.to_lowercase() } else { s.to_string() })
114            .filter(|s| !self.stopwords.contains(s))
115            .filter(|s| s.len() >= 2) // Filter very short tokens
116            .collect()
117    }
118
119    /// Compute term frequency in a document
120    fn term_frequency(&self, term: &str, chunk_id: ChunkId) -> u32 {
121        self.inverted_index
122            .get(term)
123            .and_then(|postings| postings.iter().find(|(id, _)| *id == chunk_id))
124            .map(|(_, freq)| *freq)
125            .unwrap_or(0)
126    }
127
128    /// Compute BM25 score for a single term
129    fn score_term(&self, term: &str, chunk_id: ChunkId) -> f32 {
130        let tf = self.term_frequency(term, chunk_id) as f32;
131        if tf == 0.0 {
132            return 0.0;
133        }
134
135        let df = self.doc_freqs.get(term).copied().unwrap_or(0) as f32;
136        let n = self.doc_count as f32;
137        let doc_len = self.doc_lengths.get(&chunk_id).copied().unwrap_or(0) as f32;
138
139        // IDF component: log((N - df + 0.5) / (df + 0.5) + 1)
140        let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).max(f32::EPSILON).ln();
141
142        // TF component with length normalization
143        let tf_norm = (tf * (self.k1 + 1.0))
144            / (tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length));
145
146        idf * tf_norm
147    }
148
149    /// Update average document length
150    fn update_avg_doc_length(&mut self) {
151        if self.doc_count == 0 {
152            self.avg_doc_length = 0.0;
153        } else {
154            let total: u32 = self.doc_lengths.values().sum();
155            self.avg_doc_length = total as f32 / self.doc_count as f32;
156        }
157    }
158
159    /// Get chunks containing a term
160    fn get_chunks_for_term(&self, term: &str) -> Vec<ChunkId> {
161        self.inverted_index
162            .get(term)
163            .map(|postings| postings.iter().map(|(id, _)| *id).collect())
164            .unwrap_or_default()
165    }
166}
167
168impl SparseIndex for BM25Index {
169    fn add(&mut self, chunk: &Chunk) {
170        let tokens = self.tokenize(&chunk.content);
171        let doc_len = tokens.len() as u32;
172
173        // Update document length
174        self.doc_lengths.insert(chunk.id, doc_len);
175        self.doc_count += 1;
176
177        // Count term frequencies
178        let mut term_freqs: HashMap<String, u32> = HashMap::new();
179        for token in &tokens {
180            *term_freqs.entry(token.clone()).or_insert(0) += 1;
181        }
182
183        // Update inverted index and document frequencies
184        let mut seen_terms: HashSet<String> = HashSet::new();
185        for (term, freq) in term_freqs {
186            self.inverted_index.entry(term.clone()).or_default().push((chunk.id, freq));
187
188            if seen_terms.insert(term.clone()) {
189                *self.doc_freqs.entry(term).or_insert(0) += 1;
190            }
191        }
192
193        self.update_avg_doc_length();
194    }
195
196    fn add_batch(&mut self, chunks: &[Chunk]) {
197        for chunk in chunks {
198            self.add(chunk);
199        }
200    }
201
202    fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)> {
203        let query_terms = self.tokenize(query);
204        if query_terms.is_empty() {
205            return Vec::new();
206        }
207
208        // Collect candidate documents
209        let mut candidates: HashSet<ChunkId> = HashSet::new();
210        for term in &query_terms {
211            for chunk_id in self.get_chunks_for_term(term) {
212                candidates.insert(chunk_id);
213            }
214        }
215
216        // Score candidates
217        let mut scores: Vec<(ChunkId, f32)> = candidates
218            .into_iter()
219            .map(|chunk_id| {
220                let score: f32 =
221                    query_terms.iter().map(|term| self.score_term(term, chunk_id)).sum();
222                (chunk_id, score)
223            })
224            .filter(|(_, score)| *score > 0.0)
225            .collect();
226
227        // Sort by score descending
228        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
229        scores.truncate(k);
230        scores
231    }
232
233    fn remove(&mut self, chunk_id: ChunkId) {
234        // Remove from document lengths
235        if self.doc_lengths.remove(&chunk_id).is_some() {
236            self.doc_count = self.doc_count.saturating_sub(1);
237        }
238
239        // Remove from inverted index
240        let mut terms_to_remove: Vec<String> = Vec::new();
241        for (term, postings) in &mut self.inverted_index {
242            let original_len = postings.len();
243            postings.retain(|(id, _)| *id != chunk_id);
244
245            if postings.len() < original_len {
246                // Document contained this term
247                if let Some(df) = self.doc_freqs.get_mut(term) {
248                    *df = df.saturating_sub(1);
249                    if *df == 0 {
250                        terms_to_remove.push(term.clone());
251                    }
252                }
253            }
254        }
255
256        // Clean up empty terms
257        for term in terms_to_remove {
258            self.inverted_index.remove(&term);
259            self.doc_freqs.remove(&term);
260        }
261
262        self.update_avg_doc_length();
263    }
264
265    fn len(&self) -> usize {
266        self.doc_count as usize
267    }
268}
269
/// Vector store configuration
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Embedding dimension (every stored/query vector must match this length)
    pub dimension: usize,
    /// Distance metric
    pub metric: DistanceMetric,
    /// HNSW M parameter (connections per node)
    // NOTE(review): the hnsw_* parameters are not consulted by the brute-force
    // `VectorStore` in this file — presumably reserved for an ANN backend.
    pub hnsw_m: usize,
    /// HNSW ef_construction parameter
    pub hnsw_ef_construction: usize,
    /// HNSW ef_search parameter
    pub hnsw_ef_search: usize,
}
284
285impl Default for VectorStoreConfig {
286    fn default() -> Self {
287        Self {
288            dimension: DEFAULT_EMBEDDING_DIM,
289            metric: DistanceMetric::Cosine,
290            hnsw_m: 16,
291            hnsw_ef_construction: 100,
292            hnsw_ef_search: 50,
293        }
294    }
295}
296
/// Distance metric for vector search
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine similarity (default; higher is more similar)
    #[default]
    Cosine,
    /// Euclidean distance (negated during search so higher is better)
    Euclidean,
    /// Dot product (favors larger-magnitude vectors)
    DotProduct,
}
308
/// Vector store for dense retrieval
///
/// Brute-force store: search scans every vector. Embeddings are held both in
/// `vectors` (for scanning) and inside the cached `Chunk`.
#[derive(Debug, Clone)]
pub struct VectorStore {
    /// Configuration
    config: VectorStoreConfig,
    /// Stored vectors: chunk_id -> embedding
    vectors: HashMap<ChunkId, Vec<f32>>,
    /// Chunk content cache: chunk_id -> content
    chunks: HashMap<ChunkId, Chunk>,
}
319
320impl VectorStore {
321    /// Create a new vector store
322    #[must_use]
323    pub fn new(config: VectorStoreConfig) -> Self {
324        Self { config, vectors: HashMap::new(), chunks: HashMap::new() }
325    }
326
327    /// Create with default configuration
328    #[must_use]
329    pub fn with_dimension(dimension: usize) -> Self {
330        Self::new(VectorStoreConfig { dimension, ..Default::default() })
331    }
332
333    /// Get the configuration
334    #[must_use]
335    pub fn config(&self) -> &VectorStoreConfig {
336        &self.config
337    }
338
339    /// Insert a chunk with its embedding
340    pub fn insert(&mut self, chunk: Chunk) -> Result<()> {
341        let embedding = chunk
342            .embedding
343            .as_ref()
344            .ok_or_else(|| Error::InvalidConfig("chunk must have embedding".to_string()))?;
345
346        if embedding.len() != self.config.dimension {
347            return Err(Error::DimensionMismatch {
348                expected: self.config.dimension,
349                actual: embedding.len(),
350            });
351        }
352
353        self.vectors.insert(chunk.id, embedding.clone());
354        self.chunks.insert(chunk.id, chunk);
355        Ok(())
356    }
357
358    /// Insert multiple chunks
359    pub fn insert_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
360        for chunk in chunks {
361            self.insert(chunk)?;
362        }
363        Ok(())
364    }
365
366    /// Search for similar vectors
367    pub fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<(ChunkId, f32)>> {
368        if query_vector.len() != self.config.dimension {
369            return Err(Error::DimensionMismatch {
370                expected: self.config.dimension,
371                actual: query_vector.len(),
372            });
373        }
374
375        let mut scores: Vec<(ChunkId, f32)> = self
376            .vectors
377            .iter()
378            .map(|(id, vec)| {
379                let score = match self.config.metric {
380                    DistanceMetric::Cosine => cosine_similarity(query_vector, vec),
381                    DistanceMetric::Euclidean => -euclidean_distance(query_vector, vec),
382                    DistanceMetric::DotProduct => dot_product(query_vector, vec),
383                };
384                (*id, score)
385            })
386            .collect();
387
388        // Sort by score descending (higher is better)
389        scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
390        scores.truncate(k);
391
392        Ok(scores)
393    }
394
395    /// Get a chunk by ID
396    #[must_use]
397    pub fn get(&self, chunk_id: ChunkId) -> Option<&Chunk> {
398        self.chunks.get(&chunk_id)
399    }
400
401    /// Remove a chunk
402    pub fn remove(&mut self, chunk_id: ChunkId) -> Option<Chunk> {
403        self.vectors.remove(&chunk_id);
404        self.chunks.remove(&chunk_id)
405    }
406
407    /// Get the number of stored vectors
408    #[must_use]
409    pub fn len(&self) -> usize {
410        self.vectors.len()
411    }
412
413    /// Check if the store is empty
414    #[must_use]
415    pub fn is_empty(&self) -> bool {
416        self.vectors.is_empty()
417    }
418}
419
420// Helper functions from embed module
/// Cosine similarity between `a` and `b`; 0.0 if either vector has zero norm.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b.iter()) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let norm_a = sq_a.sqrt();
    let norm_b = sq_b.sqrt();

    // Zero-norm vectors have undefined direction; define similarity as 0.
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
432
/// Euclidean (L2) distance between `a` and `b`.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq: f32 = a
        .iter()
        .zip(b.iter())
        .map(|(x, y)| {
            let d = x - y;
            d * d
        })
        .sum();
    sum_sq.sqrt()
}
436
/// Dot product of `a` and `b`.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).fold(0.0, |acc, (x, y)| acc + x * y)
}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DocumentId;

    /// Build a chunk with no embedding whose span covers all of `content`.
    fn create_test_chunk(content: &str) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
    }

    /// Build a chunk and attach `embedding` to it.
    fn create_test_chunk_with_embedding(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = create_test_chunk(content);
        chunk.set_embedding(embedding);
        chunk
    }

    // ============ BM25Index Tests ============

    #[test]
    fn test_bm25_index_new() {
        let index = BM25Index::new();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert!((index.k1 - 1.2).abs() < 0.01);
        assert!((index.b - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_bm25_index_with_params() {
        let index = BM25Index::with_params(1.5, 0.5);
        assert!((index.k1 - 1.5).abs() < 0.01);
        assert!((index.b - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_bm25_tokenize() {
        let index = BM25Index::new();
        let tokens = index.tokenize("Hello World! This is a test.");

        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"test".to_string()));
        // Stopwords should be removed
        assert!(!tokens.contains(&"this".to_string()));
        assert!(!tokens.contains(&"is".to_string()));
        assert!(!tokens.contains(&"a".to_string()));
    }

    #[test]
    fn test_bm25_tokenize_lowercase() {
        let index = BM25Index::new();
        let tokens = index.tokenize("HELLO World");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

    #[test]
    fn test_bm25_add_chunk() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Machine learning is fascinating");

        index.add(&chunk);

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.inverted_index.contains_key("machine"));
        assert!(index.inverted_index.contains_key("learning"));
    }

    #[test]
    fn test_bm25_add_batch() {
        let mut index = BM25Index::new();
        let chunks = vec![
            create_test_chunk("First document about AI"),
            create_test_chunk("Second document about ML"),
            create_test_chunk("Third document about deep learning"),
        ];

        index.add_batch(&chunks);

        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_bm25_search_basic() {
        let mut index = BM25Index::new();
        let chunk1 = create_test_chunk("Machine learning algorithms");
        let chunk2 = create_test_chunk("Deep learning neural networks");
        let chunk3 = create_test_chunk("Natural language processing");

        index.add(&chunk1);
        index.add(&chunk2);
        index.add(&chunk3);

        let results = index.search("machine learning", 10);

        assert!(!results.is_empty());
        // Chunk with "machine learning" should score highest
        assert!(results.iter().any(|(id, _)| *id == chunk1.id));
    }

    #[test]
    fn test_bm25_search_empty_query() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        let results = index.search("", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_stopwords_only() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        // A query consisting only of stopwords tokenizes to nothing.
        let results = index.search("the a an", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_no_match() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Cats and dogs"));

        let results = index.search("quantum physics", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_ranking() {
        let mut index = BM25Index::new();

        // Document with more term matches should rank higher
        let chunk1 = create_test_chunk("python programming language");
        let chunk2 = create_test_chunk("python python python programming");

        index.add(&chunk1);
        index.add(&chunk2);

        let results = index.search("python programming", 10);

        assert_eq!(results.len(), 2);
        // Chunk2 should rank higher due to more "python" occurrences
        assert_eq!(results[0].0, chunk2.id);
    }

    #[test]
    fn test_bm25_search_top_k() {
        let mut index = BM25Index::new();
        for i in 0..10 {
            index.add(&create_test_chunk(&format!("document {i} about rust")));
        }

        let results = index.search("rust", 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_bm25_remove() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Test document");
        let chunk_id = chunk.id;

        index.add(&chunk);
        assert_eq!(index.len(), 1);

        index.remove(chunk_id);
        assert_eq!(index.len(), 0);

        let results = index.search("test", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_avg_doc_length() {
        let mut index = BM25Index::new();

        index.add(&create_test_chunk("short text")); // ~2 tokens
        index.add(&create_test_chunk("this is a longer piece of text about programming")); // ~5 tokens
        assert!(index.avg_doc_length > 0.0);
    }

    #[test]
    fn test_bm25_idf_calculation() {
        let mut index = BM25Index::new();

        // Add documents with varying term frequencies
        index.add(&create_test_chunk("common rare"));
        index.add(&create_test_chunk("common word"));
        index.add(&create_test_chunk("common term"));

        // Search for rare term should give higher score
        let rare_results = index.search("rare", 10);
        let common_results = index.search("common", 10);

        // "rare" appears in 1 doc, "common" in 3 docs
        // IDF of "rare" should be higher
        assert!(!rare_results.is_empty());
        assert!(!common_results.is_empty());
    }

    // ============ VectorStore Tests ============

    #[test]
    fn test_vector_store_new() {
        let store = VectorStore::with_dimension(384);
        assert_eq!(store.config().dimension, 384);
        assert!(store.is_empty());
    }

    #[test]
    fn test_vector_store_config() {
        let config = VectorStoreConfig {
            dimension: 768,
            metric: DistanceMetric::DotProduct,
            hnsw_m: 32,
            hnsw_ef_construction: 200,
            hnsw_ef_search: 100,
        };
        let store = VectorStore::new(config.clone());

        assert_eq!(store.config().dimension, 768);
        assert_eq!(store.config().metric, DistanceMetric::DotProduct);
    }

    #[test]
    fn test_vector_store_insert() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);

        store.insert(chunk.clone()).unwrap();

        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert!(store.get(chunk.id).is_some());
    }

    #[test]
    fn test_vector_store_insert_no_embedding() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk("no embedding");

        let result = store.insert(chunk);
        assert!(result.is_err());
    }

    #[test]
    fn test_vector_store_insert_wrong_dimension() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0]); // Wrong dimension

        let result = store.insert(chunk);
        assert!(result.is_err());
        match result {
            Err(Error::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 3);
                assert_eq!(actual, 2);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    #[test]
    fn test_vector_store_insert_batch() {
        let mut store = VectorStore::with_dimension(3);
        let chunks = vec![
            create_test_chunk_with_embedding("a", vec![1.0, 0.0, 0.0]),
            create_test_chunk_with_embedding("b", vec![0.0, 1.0, 0.0]),
            create_test_chunk_with_embedding("c", vec![0.0, 0.0, 1.0]),
        ];

        store.insert_batch(chunks).unwrap();
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_vector_store_search_cosine() {
        let mut store = VectorStore::with_dimension(3);

        let chunk1 = create_test_chunk_with_embedding("north", vec![1.0, 0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("east", vec![0.0, 1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding(
            "diagonal",
            vec![std::f32::consts::FRAC_1_SQRT_2, std::f32::consts::FRAC_1_SQRT_2, 0.0],
        );

        let id1 = chunk1.id;
        let id3 = chunk3.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        // Search for vector pointing mostly north
        let query = vec![0.9, 0.1, 0.0];
        let results = store.search(&query, 10).unwrap();

        assert_eq!(results.len(), 3);
        // chunk1 (north) should be most similar
        assert_eq!(results[0].0, id1);
        // chunk3 (diagonal) should be second
        assert_eq!(results[1].0, id3);
    }

    #[test]
    fn test_vector_store_search_top_k() {
        let mut store = VectorStore::with_dimension(3);

        for i in 0..10 {
            let embedding = vec![i as f32, 0.0, 0.0];
            store
                .insert(create_test_chunk_with_embedding(&format!("chunk {i}"), embedding))
                .unwrap();
        }

        let results = store.search(&[9.0, 0.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_store_search_wrong_dimension() {
        let store = VectorStore::with_dimension(3);
        let result = store.search(&[1.0, 0.0], 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_vector_store_remove() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);
        let chunk_id = chunk.id;

        store.insert(chunk).unwrap();
        assert_eq!(store.len(), 1);

        let removed = store.remove(chunk_id);
        assert!(removed.is_some());
        assert_eq!(store.len(), 0);
        assert!(store.get(chunk_id).is_none());
    }

    #[test]
    fn test_vector_store_remove_nonexistent() {
        let mut store = VectorStore::with_dimension(3);
        let removed = store.remove(ChunkId::new());
        assert!(removed.is_none());
    }

    #[test]
    fn test_distance_metric_euclidean() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("origin", vec![0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("near", vec![1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding("far", vec![10.0, 0.0]);

        let id2 = chunk2.id;
        let id1 = chunk1.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        // Search from origin - near should be closest
        let results = store.search(&[0.0, 0.0], 10).unwrap();
        assert_eq!(results[0].0, id1); // Exact match
        assert_eq!(results[1].0, id2); // Nearest neighbor
    }

    #[test]
    fn test_distance_metric_dot_product() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::DotProduct,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("small", vec![1.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("large", vec![10.0, 0.0]);

        let id2 = chunk2.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();

        // Dot product prefers larger magnitude vectors
        let results = store.search(&[1.0, 0.0], 10).unwrap();
        assert_eq!(results[0].0, id2);
    }

    // ============ Property-Based Tests ============
    // These exercise invariants (count growth, result bounds, score sign,
    // membership) over randomly generated inputs via proptest.

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_bm25_add_increases_count(content in "[a-zA-Z ]{10,100}") {
            let mut index = BM25Index::new();
            let initial = index.len();
            index.add(&create_test_chunk(&content));
            prop_assert_eq!(index.len(), initial + 1);
        }

        #[test]
        fn prop_bm25_search_results_within_k(
            content in prop::collection::vec("[a-zA-Z]{3,10}", 5..20),
            k in 1usize..10
        ) {
            let mut index = BM25Index::new();
            for c in &content {
                index.add(&create_test_chunk(c));
            }

            let results = index.search("test", k);
            prop_assert!(results.len() <= k);
        }

        #[test]
        fn prop_bm25_scores_non_negative(
            docs in prop::collection::vec("[a-zA-Z ]{5,50}", 3..10),
            query in "[a-zA-Z]{3,10}"
        ) {
            let mut index = BM25Index::new();
            for doc in &docs {
                index.add(&create_test_chunk(doc));
            }

            let results = index.search(&query, 100);
            for (_, score) in results {
                prop_assert!(score >= 0.0);
            }
        }

        #[test]
        fn prop_vector_store_search_returns_stored(
            dim in 2usize..10,
            n_chunks in 1usize..20
        ) {
            let mut store = VectorStore::with_dimension(dim);
            let mut ids = Vec::new();

            for i in 0..n_chunks {
                let mut embedding = vec![0.0f32; dim];
                embedding[i % dim] = 1.0;
                let chunk = create_test_chunk_with_embedding(&format!("chunk {i}"), embedding);
                ids.push(chunk.id);
                store.insert(chunk).unwrap();
            }

            let query = vec![1.0f32; dim];
            let results = store.search(&query, n_chunks).unwrap();

            // All results should be from stored chunks
            for (id, _) in results {
                prop_assert!(ids.contains(&id));
            }
        }
    }
}