1use crate::{Chunk, ChunkId, Error, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::{HashMap, HashSet};
6
7const DEFAULT_EMBEDDING_DIM: usize = 384;
9
/// A sparse (keyword-based) retrieval index over chunks.
pub trait SparseIndex: Send + Sync {
    /// Indexes a single chunk.
    fn add(&mut self, chunk: &Chunk);

    /// Indexes every chunk in `chunks`.
    fn add_batch(&mut self, chunks: &[Chunk]);

    /// Returns up to `k` `(chunk id, relevance score)` pairs, best first.
    fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)>;

    /// Removes a chunk from the index.
    fn remove(&mut self, chunk_id: ChunkId);

    /// Number of indexed chunks.
    fn len(&self) -> usize;

    /// `true` when the index holds no chunks.
    fn is_empty(&self) -> bool {
        self.len() == 0
    }
}
32
/// An Okapi BM25 ranked keyword index backed by an inverted index.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BM25Index {
    /// term -> postings list of `(document id, term frequency in that document)`.
    inverted_index: HashMap<String, Vec<(ChunkId, u32)>>,
    /// term -> number of documents containing that term.
    doc_freqs: HashMap<String, u32>,
    /// document id -> token count after tokenization and filtering.
    doc_lengths: HashMap<ChunkId, u32>,
    /// Mean of `doc_lengths`; 0.0 when the index is empty.
    avg_doc_length: f32,
    /// Number of indexed documents.
    doc_count: u32,
    /// BM25 term-frequency saturation parameter (default 1.2).
    k1: f32,
    /// BM25 document-length normalization parameter (default 0.75).
    b: f32,
    /// Whether tokens are lowercased during tokenization.
    lowercase: bool,
    /// Tokens dropped during tokenization.
    stopwords: HashSet<String>,
}
55
impl Default for BM25Index {
    /// Equivalent to [`BM25Index::new`].
    fn default() -> Self {
        Self::new()
    }
}
61
62impl BM25Index {
63 #[must_use]
65 pub fn new() -> Self {
66 Self {
67 inverted_index: HashMap::new(),
68 doc_freqs: HashMap::new(),
69 doc_lengths: HashMap::new(),
70 avg_doc_length: 0.0,
71 doc_count: 0,
72 k1: 1.2,
73 b: 0.75,
74 lowercase: true,
75 stopwords: Self::default_stopwords(),
76 }
77 }
78
79 #[must_use]
81 pub fn with_params(k1: f32, b: f32) -> Self {
82 Self { k1, b, ..Self::new() }
83 }
84
85 #[must_use]
87 pub fn with_stopwords(mut self, stopwords: HashSet<String>) -> Self {
88 self.stopwords = stopwords;
89 self
90 }
91
92 fn default_stopwords() -> HashSet<String> {
93 [
94 "a", "an", "the", "is", "are", "was", "were", "be", "been", "being", "have", "has",
95 "had", "do", "does", "did", "will", "would", "could", "should", "may", "might", "must",
96 "shall", "can", "need", "dare", "ought", "used", "to", "of", "in", "for", "on", "with",
97 "at", "by", "from", "as", "into", "through", "during", "before", "after", "above",
98 "below", "between", "under", "again", "further", "then", "once", "here", "there",
99 "when", "where", "why", "how", "all", "each", "few", "more", "most", "other", "some",
100 "such", "no", "nor", "not", "only", "own", "same", "so", "than", "too", "very", "just",
101 "and", "but", "if", "or", "because", "until", "while", "this", "that", "these",
102 "those", "it", "its",
103 ]
104 .iter()
105 .map(|s| (*s).to_string())
106 .collect()
107 }
108
109 pub fn tokenize(&self, text: &str) -> Vec<String> {
111 text.split(|c: char| !c.is_alphanumeric())
112 .filter(|s| !s.is_empty())
113 .map(|s| if self.lowercase { s.to_lowercase() } else { s.to_string() })
114 .filter(|s| !self.stopwords.contains(s))
115 .filter(|s| s.len() >= 2) .collect()
117 }
118
119 fn term_frequency(&self, term: &str, chunk_id: ChunkId) -> u32 {
121 self.inverted_index
122 .get(term)
123 .and_then(|postings| postings.iter().find(|(id, _)| *id == chunk_id))
124 .map(|(_, freq)| *freq)
125 .unwrap_or(0)
126 }
127
128 fn score_term(&self, term: &str, chunk_id: ChunkId) -> f32 {
130 let tf = self.term_frequency(term, chunk_id) as f32;
131 if tf == 0.0 {
132 return 0.0;
133 }
134
135 let df = self.doc_freqs.get(term).copied().unwrap_or(0) as f32;
136 let n = self.doc_count as f32;
137 let doc_len = self.doc_lengths.get(&chunk_id).copied().unwrap_or(0) as f32;
138
139 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).max(f32::EPSILON).ln();
141
142 let tf_norm = (tf * (self.k1 + 1.0))
144 / (tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avg_doc_length));
145
146 idf * tf_norm
147 }
148
149 fn update_avg_doc_length(&mut self) {
151 if self.doc_count == 0 {
152 self.avg_doc_length = 0.0;
153 } else {
154 let total: u32 = self.doc_lengths.values().sum();
155 self.avg_doc_length = total as f32 / self.doc_count as f32;
156 }
157 }
158
159 fn get_chunks_for_term(&self, term: &str) -> Vec<ChunkId> {
161 self.inverted_index
162 .get(term)
163 .map(|postings| postings.iter().map(|(id, _)| *id).collect())
164 .unwrap_or_default()
165 }
166}
167
168impl SparseIndex for BM25Index {
169 fn add(&mut self, chunk: &Chunk) {
170 let tokens = self.tokenize(&chunk.content);
171 let doc_len = tokens.len() as u32;
172
173 self.doc_lengths.insert(chunk.id, doc_len);
175 self.doc_count += 1;
176
177 let mut term_freqs: HashMap<String, u32> = HashMap::new();
179 for token in &tokens {
180 *term_freqs.entry(token.clone()).or_insert(0) += 1;
181 }
182
183 let mut seen_terms: HashSet<String> = HashSet::new();
185 for (term, freq) in term_freqs {
186 self.inverted_index.entry(term.clone()).or_default().push((chunk.id, freq));
187
188 if seen_terms.insert(term.clone()) {
189 *self.doc_freqs.entry(term).or_insert(0) += 1;
190 }
191 }
192
193 self.update_avg_doc_length();
194 }
195
196 fn add_batch(&mut self, chunks: &[Chunk]) {
197 for chunk in chunks {
198 self.add(chunk);
199 }
200 }
201
202 fn search(&self, query: &str, k: usize) -> Vec<(ChunkId, f32)> {
203 let query_terms = self.tokenize(query);
204 if query_terms.is_empty() {
205 return Vec::new();
206 }
207
208 let mut candidates: HashSet<ChunkId> = HashSet::new();
210 for term in &query_terms {
211 for chunk_id in self.get_chunks_for_term(term) {
212 candidates.insert(chunk_id);
213 }
214 }
215
216 let mut scores: Vec<(ChunkId, f32)> = candidates
218 .into_iter()
219 .map(|chunk_id| {
220 let score: f32 =
221 query_terms.iter().map(|term| self.score_term(term, chunk_id)).sum();
222 (chunk_id, score)
223 })
224 .filter(|(_, score)| *score > 0.0)
225 .collect();
226
227 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
229 scores.truncate(k);
230 scores
231 }
232
233 fn remove(&mut self, chunk_id: ChunkId) {
234 if self.doc_lengths.remove(&chunk_id).is_some() {
236 self.doc_count = self.doc_count.saturating_sub(1);
237 }
238
239 let mut terms_to_remove: Vec<String> = Vec::new();
241 for (term, postings) in &mut self.inverted_index {
242 let original_len = postings.len();
243 postings.retain(|(id, _)| *id != chunk_id);
244
245 if postings.len() < original_len {
246 if let Some(df) = self.doc_freqs.get_mut(term) {
248 *df = df.saturating_sub(1);
249 if *df == 0 {
250 terms_to_remove.push(term.clone());
251 }
252 }
253 }
254 }
255
256 for term in terms_to_remove {
258 self.inverted_index.remove(&term);
259 self.doc_freqs.remove(&term);
260 }
261
262 self.update_avg_doc_length();
263 }
264
265 fn len(&self) -> usize {
266 self.doc_count as usize
267 }
268}
269
/// Configuration for a [`VectorStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VectorStoreConfig {
    /// Required embedding length; mismatched inserts/searches are rejected.
    pub dimension: usize,
    /// Metric used to score vectors during search.
    pub metric: DistanceMetric,
    /// HNSW graph degree. NOTE(review): the hnsw_* fields are not read by the
    /// current exhaustive-search store — presumably reserved for an
    /// HNSW-backed implementation; confirm.
    pub hnsw_m: usize,
    /// HNSW construction-time beam width (currently unused, see above).
    pub hnsw_ef_construction: usize,
    /// HNSW query-time beam width (currently unused, see above).
    pub hnsw_ef_search: usize,
}
284
impl Default for VectorStoreConfig {
    /// 384-dimensional cosine-similarity store with HNSW parameters
    /// m = 16, ef_construction = 100, ef_search = 50.
    fn default() -> Self {
        Self {
            dimension: DEFAULT_EMBEDDING_DIM,
            metric: DistanceMetric::Cosine,
            hnsw_m: 16,
            hnsw_ef_construction: 100,
            hnsw_ef_search: 50,
        }
    }
}
296
/// Scoring metric for vector search. Scores are ranked descending, so the
/// store negates Euclidean distances to keep "higher is better".
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
pub enum DistanceMetric {
    /// Cosine similarity (angle between vectors; magnitude-invariant).
    #[default]
    Cosine,
    /// Euclidean (L2) distance, negated before ranking.
    Euclidean,
    /// Raw dot product (magnitude-sensitive).
    DotProduct,
}
308
/// Exhaustive (brute-force) in-memory vector store keyed by chunk id.
#[derive(Debug, Clone)]
pub struct VectorStore {
    /// Dimension, metric, and (currently unused) HNSW settings.
    config: VectorStoreConfig,
    /// chunk id -> embedding vector, validated to `config.dimension`.
    vectors: HashMap<ChunkId, Vec<f32>>,
    /// chunk id -> original chunk, kept for retrieval after search.
    chunks: HashMap<ChunkId, Chunk>,
}
319
320impl VectorStore {
321 #[must_use]
323 pub fn new(config: VectorStoreConfig) -> Self {
324 Self { config, vectors: HashMap::new(), chunks: HashMap::new() }
325 }
326
327 #[must_use]
329 pub fn with_dimension(dimension: usize) -> Self {
330 Self::new(VectorStoreConfig { dimension, ..Default::default() })
331 }
332
333 #[must_use]
335 pub fn config(&self) -> &VectorStoreConfig {
336 &self.config
337 }
338
339 pub fn insert(&mut self, chunk: Chunk) -> Result<()> {
341 let embedding = chunk
342 .embedding
343 .as_ref()
344 .ok_or_else(|| Error::InvalidConfig("chunk must have embedding".to_string()))?;
345
346 if embedding.len() != self.config.dimension {
347 return Err(Error::DimensionMismatch {
348 expected: self.config.dimension,
349 actual: embedding.len(),
350 });
351 }
352
353 self.vectors.insert(chunk.id, embedding.clone());
354 self.chunks.insert(chunk.id, chunk);
355 Ok(())
356 }
357
358 pub fn insert_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
360 for chunk in chunks {
361 self.insert(chunk)?;
362 }
363 Ok(())
364 }
365
366 pub fn search(&self, query_vector: &[f32], k: usize) -> Result<Vec<(ChunkId, f32)>> {
368 if query_vector.len() != self.config.dimension {
369 return Err(Error::DimensionMismatch {
370 expected: self.config.dimension,
371 actual: query_vector.len(),
372 });
373 }
374
375 let mut scores: Vec<(ChunkId, f32)> = self
376 .vectors
377 .iter()
378 .map(|(id, vec)| {
379 let score = match self.config.metric {
380 DistanceMetric::Cosine => cosine_similarity(query_vector, vec),
381 DistanceMetric::Euclidean => -euclidean_distance(query_vector, vec),
382 DistanceMetric::DotProduct => dot_product(query_vector, vec),
383 };
384 (*id, score)
385 })
386 .collect();
387
388 scores.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
390 scores.truncate(k);
391
392 Ok(scores)
393 }
394
395 #[must_use]
397 pub fn get(&self, chunk_id: ChunkId) -> Option<&Chunk> {
398 self.chunks.get(&chunk_id)
399 }
400
401 pub fn remove(&mut self, chunk_id: ChunkId) -> Option<Chunk> {
403 self.vectors.remove(&chunk_id);
404 self.chunks.remove(&chunk_id)
405 }
406
407 #[must_use]
409 pub fn len(&self) -> usize {
410 self.vectors.len()
411 }
412
413 #[must_use]
415 pub fn is_empty(&self) -> bool {
416 self.vectors.is_empty()
417 }
418}
419
/// Cosine similarity of two vectors; 0.0 when either vector has zero norm.
/// Extra elements of the longer slice are ignored (zip truncation).
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }

    if sq_a == 0.0 || sq_b == 0.0 {
        return 0.0;
    }
    dot / (sq_a.sqrt() * sq_b.sqrt())
}
432
/// Euclidean (L2) distance between two vectors, truncated to the shorter one.
fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    let sum_sq = a.iter().zip(b.iter()).fold(0.0f32, |acc, (x, y)| {
        let d = x - y;
        acc + d * d
    });
    sum_sq.sqrt()
}
436
/// Dot product of two vectors, truncated to the shorter one.
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let mut acc = 0.0f32;
    for (x, y) in a.iter().zip(b.iter()) {
        acc += x * y;
    }
    acc
}
440
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DocumentId;
    use proptest::prelude::*;

    /// Builds a chunk with no embedding, spanning `0..content.len()`.
    fn create_test_chunk(content: &str) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
    }

    /// Builds a chunk and attaches the given embedding.
    fn create_test_chunk_with_embedding(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = create_test_chunk(content);
        chunk.set_embedding(embedding);
        chunk
    }

    // --- BM25 index ---

    #[test]
    fn test_bm25_index_new() {
        let index = BM25Index::new();
        assert_eq!(index.len(), 0);
        assert!(index.is_empty());
        assert!((index.k1 - 1.2).abs() < 0.01);
        assert!((index.b - 0.75).abs() < 0.01);
    }

    #[test]
    fn test_bm25_index_with_params() {
        let index = BM25Index::with_params(1.5, 0.5);
        assert!((index.k1 - 1.5).abs() < 0.01);
        assert!((index.b - 0.5).abs() < 0.01);
    }

    #[test]
    fn test_bm25_tokenize() {
        let index = BM25Index::new();
        let tokens = index.tokenize("Hello World! This is a test.");

        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(tokens.contains(&"test".to_string()));
        // Stopwords are filtered out.
        assert!(!tokens.contains(&"this".to_string()));
        assert!(!tokens.contains(&"is".to_string()));
        assert!(!tokens.contains(&"a".to_string()));
    }

    #[test]
    fn test_bm25_tokenize_lowercase() {
        let index = BM25Index::new();
        let tokens = index.tokenize("HELLO World");
        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
    }

    #[test]
    fn test_bm25_add_chunk() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Machine learning is fascinating");

        index.add(&chunk);

        assert_eq!(index.len(), 1);
        assert!(!index.is_empty());
        assert!(index.inverted_index.contains_key("machine"));
        assert!(index.inverted_index.contains_key("learning"));
    }

    #[test]
    fn test_bm25_add_batch() {
        let mut index = BM25Index::new();
        let chunks = vec![
            create_test_chunk("First document about AI"),
            create_test_chunk("Second document about ML"),
            create_test_chunk("Third document about deep learning"),
        ];

        index.add_batch(&chunks);

        assert_eq!(index.len(), 3);
    }

    #[test]
    fn test_bm25_search_basic() {
        let mut index = BM25Index::new();
        let chunk1 = create_test_chunk("Machine learning algorithms");
        let chunk2 = create_test_chunk("Deep learning neural networks");
        let chunk3 = create_test_chunk("Natural language processing");

        index.add(&chunk1);
        index.add(&chunk2);
        index.add(&chunk3);

        let results = index.search("machine learning", 10);

        assert!(!results.is_empty());
        assert!(results.iter().any(|(id, _)| *id == chunk1.id));
    }

    #[test]
    fn test_bm25_search_empty_query() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        let results = index.search("", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_stopwords_only() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Test document"));

        let results = index.search("the a an", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_no_match() {
        let mut index = BM25Index::new();
        index.add(&create_test_chunk("Cats and dogs"));

        let results = index.search("quantum physics", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_search_ranking() {
        let mut index = BM25Index::new();

        // chunk2 repeats "python", so it should rank first.
        let chunk1 = create_test_chunk("python programming language");
        let chunk2 = create_test_chunk("python python python programming");

        index.add(&chunk1);
        index.add(&chunk2);

        let results = index.search("python programming", 10);

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, chunk2.id);
    }

    #[test]
    fn test_bm25_search_top_k() {
        let mut index = BM25Index::new();
        for i in 0..10 {
            index.add(&create_test_chunk(&format!("document {i} about rust")));
        }

        let results = index.search("rust", 3);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_bm25_remove() {
        let mut index = BM25Index::new();
        let chunk = create_test_chunk("Test document");
        let chunk_id = chunk.id;

        index.add(&chunk);
        assert_eq!(index.len(), 1);

        index.remove(chunk_id);
        assert_eq!(index.len(), 0);

        let results = index.search("test", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_bm25_avg_doc_length() {
        let mut index = BM25Index::new();

        index.add(&create_test_chunk("short text"));
        index.add(&create_test_chunk("this is a longer piece of text about programming"));
        assert!(index.avg_doc_length > 0.0);
    }

    #[test]
    fn test_bm25_idf_calculation() {
        let mut index = BM25Index::new();

        index.add(&create_test_chunk("common rare"));
        index.add(&create_test_chunk("common word"));
        index.add(&create_test_chunk("common term"));

        let rare_results = index.search("rare", 10);
        let common_results = index.search("common", 10);

        assert!(!rare_results.is_empty());
        assert!(!common_results.is_empty());
    }

    // --- Vector store ---

    #[test]
    fn test_vector_store_new() {
        let store = VectorStore::with_dimension(384);
        assert_eq!(store.config().dimension, 384);
        assert!(store.is_empty());
    }

    #[test]
    fn test_vector_store_config() {
        let config = VectorStoreConfig {
            dimension: 768,
            metric: DistanceMetric::DotProduct,
            hnsw_m: 32,
            hnsw_ef_construction: 200,
            hnsw_ef_search: 100,
        };
        let store = VectorStore::new(config.clone());

        assert_eq!(store.config().dimension, 768);
        assert_eq!(store.config().metric, DistanceMetric::DotProduct);
    }

    #[test]
    fn test_vector_store_insert() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);

        store.insert(chunk.clone()).unwrap();

        assert_eq!(store.len(), 1);
        assert!(!store.is_empty());
        assert!(store.get(chunk.id).is_some());
    }

    #[test]
    fn test_vector_store_insert_no_embedding() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk("no embedding");

        let result = store.insert(chunk);
        assert!(result.is_err());
    }

    #[test]
    fn test_vector_store_insert_wrong_dimension() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0]);

        let result = store.insert(chunk);
        assert!(result.is_err());
        match result {
            Err(Error::DimensionMismatch { expected, actual }) => {
                assert_eq!(expected, 3);
                assert_eq!(actual, 2);
            }
            _ => panic!("Expected DimensionMismatch error"),
        }
    }

    #[test]
    fn test_vector_store_insert_batch() {
        let mut store = VectorStore::with_dimension(3);
        let chunks = vec![
            create_test_chunk_with_embedding("a", vec![1.0, 0.0, 0.0]),
            create_test_chunk_with_embedding("b", vec![0.0, 1.0, 0.0]),
            create_test_chunk_with_embedding("c", vec![0.0, 0.0, 1.0]),
        ];

        store.insert_batch(chunks).unwrap();
        assert_eq!(store.len(), 3);
    }

    #[test]
    fn test_vector_store_search_cosine() {
        let mut store = VectorStore::with_dimension(3);

        let chunk1 = create_test_chunk_with_embedding("north", vec![1.0, 0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("east", vec![0.0, 1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding(
            "diagonal",
            vec![std::f32::consts::FRAC_1_SQRT_2, std::f32::consts::FRAC_1_SQRT_2, 0.0],
        );

        let id1 = chunk1.id;
        let id3 = chunk3.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        // Query points almost due "north": expect north, then diagonal.
        let query = vec![0.9, 0.1, 0.0];
        let results = store.search(&query, 10).unwrap();

        assert_eq!(results.len(), 3);
        assert_eq!(results[0].0, id1);
        assert_eq!(results[1].0, id3);
    }

    #[test]
    fn test_vector_store_search_top_k() {
        let mut store = VectorStore::with_dimension(3);

        for i in 0..10 {
            let embedding = vec![i as f32, 0.0, 0.0];
            store
                .insert(create_test_chunk_with_embedding(&format!("chunk {i}"), embedding))
                .unwrap();
        }

        let results = store.search(&[9.0, 0.0, 0.0], 3).unwrap();
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_vector_store_search_wrong_dimension() {
        let store = VectorStore::with_dimension(3);
        let result = store.search(&[1.0, 0.0], 10);
        assert!(result.is_err());
    }

    #[test]
    fn test_vector_store_remove() {
        let mut store = VectorStore::with_dimension(3);
        let chunk = create_test_chunk_with_embedding("test", vec![1.0, 0.0, 0.0]);
        let chunk_id = chunk.id;

        store.insert(chunk).unwrap();
        assert_eq!(store.len(), 1);

        let removed = store.remove(chunk_id);
        assert!(removed.is_some());
        assert_eq!(store.len(), 0);
        assert!(store.get(chunk_id).is_none());
    }

    #[test]
    fn test_vector_store_remove_nonexistent() {
        let mut store = VectorStore::with_dimension(3);
        let removed = store.remove(ChunkId::new());
        assert!(removed.is_none());
    }

    #[test]
    fn test_distance_metric_euclidean() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::Euclidean,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("origin", vec![0.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("near", vec![1.0, 0.0]);
        let chunk3 = create_test_chunk_with_embedding("far", vec![10.0, 0.0]);

        let id2 = chunk2.id;
        let id1 = chunk1.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();
        store.insert(chunk3).unwrap();

        // Smaller distance ranks first (distances are negated).
        let results = store.search(&[0.0, 0.0], 10).unwrap();
        assert_eq!(results[0].0, id1);
        assert_eq!(results[1].0, id2);
    }

    #[test]
    fn test_distance_metric_dot_product() {
        let config = VectorStoreConfig {
            dimension: 2,
            metric: DistanceMetric::DotProduct,
            ..Default::default()
        };
        let mut store = VectorStore::new(config);

        let chunk1 = create_test_chunk_with_embedding("small", vec![1.0, 0.0]);
        let chunk2 = create_test_chunk_with_embedding("large", vec![10.0, 0.0]);

        let id2 = chunk2.id;

        store.insert(chunk1).unwrap();
        store.insert(chunk2).unwrap();

        // Dot product is magnitude-sensitive: larger vector wins.
        let results = store.search(&[1.0, 0.0], 10).unwrap();
        assert_eq!(results[0].0, id2);
    }

    // --- Property-based tests ---

    proptest! {
        #[test]
        fn prop_bm25_add_increases_count(content in "[a-zA-Z ]{10,100}") {
            let mut index = BM25Index::new();
            let initial = index.len();
            index.add(&create_test_chunk(&content));
            prop_assert_eq!(index.len(), initial + 1);
        }

        #[test]
        fn prop_bm25_search_results_within_k(
            content in prop::collection::vec("[a-zA-Z]{3,10}", 5..20),
            k in 1usize..10
        ) {
            let mut index = BM25Index::new();
            for c in &content {
                index.add(&create_test_chunk(c));
            }

            let results = index.search("test", k);
            prop_assert!(results.len() <= k);
        }

        #[test]
        fn prop_bm25_scores_non_negative(
            docs in prop::collection::vec("[a-zA-Z ]{5,50}", 3..10),
            query in "[a-zA-Z]{3,10}"
        ) {
            let mut index = BM25Index::new();
            for doc in &docs {
                index.add(&create_test_chunk(doc));
            }

            let results = index.search(&query, 100);
            for (_, score) in results {
                prop_assert!(score >= 0.0);
            }
        }

        #[test]
        fn prop_vector_store_search_returns_stored(
            dim in 2usize..10,
            n_chunks in 1usize..20
        ) {
            let mut store = VectorStore::with_dimension(dim);
            let mut ids = Vec::new();

            for i in 0..n_chunks {
                let mut embedding = vec![0.0f32; dim];
                embedding[i % dim] = 1.0;
                let chunk = create_test_chunk_with_embedding(&format!("chunk {i}"), embedding);
                ids.push(chunk.id);
                store.insert(chunk).unwrap();
            }

            let query = vec![1.0f32; dim];
            let results = store.search(&query, n_chunks).unwrap();

            // Every returned id must be one we actually inserted.
            for (id, _) in results {
                prop_assert!(ids.contains(&id));
            }
        }
    }
}