//! Retrieval module for RAG pipelines

use crate::{
    embed::Embedder,
    fusion::FusionStrategy,
    index::{BM25Index, SparseIndex, VectorStore},
    Chunk, ChunkId, Result,
};
use serde::{Deserialize, Serialize};
/// Result of a retrieval operation
///
/// Each score field corresponds to one retrieval/ranking stage; a field is
/// `None` when that stage was not applied to this result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetrievalResult {
    /// The retrieved chunk
    pub chunk: Chunk,
    /// Dense retrieval score (if applicable)
    pub dense_score: Option<f32>,
    /// Sparse retrieval score (if applicable)
    pub sparse_score: Option<f32>,
    /// Multi-vector retrieval score (if applicable, ColBERT-style MaxSim)
    #[cfg(feature = "multivector")]
    pub multivector_score: Option<f32>,
    /// Fused score (if hybrid retrieval)
    pub fused_score: Option<f32>,
    /// Reranking score (if reranking applied)
    pub rerank_score: Option<f32>,
}
28
29impl RetrievalResult {
30    /// Create a new retrieval result from a chunk
31    #[must_use]
32    pub fn new(chunk: Chunk) -> Self {
33        Self {
34            chunk,
35            dense_score: None,
36            sparse_score: None,
37            #[cfg(feature = "multivector")]
38            multivector_score: None,
39            fused_score: None,
40            rerank_score: None,
41        }
42    }
43
44    /// Set the dense score
45    #[must_use]
46    pub fn with_dense_score(mut self, score: f32) -> Self {
47        self.dense_score = Some(score);
48        self
49    }
50
51    /// Set the sparse score
52    #[must_use]
53    pub fn with_sparse_score(mut self, score: f32) -> Self {
54        self.sparse_score = Some(score);
55        self
56    }
57
58    /// Set the fused score
59    #[must_use]
60    pub fn with_fused_score(mut self, score: f32) -> Self {
61        self.fused_score = Some(score);
62        self
63    }
64
65    /// Set the rerank score
66    #[must_use]
67    pub fn with_rerank_score(mut self, score: f32) -> Self {
68        self.rerank_score = Some(score);
69        self
70    }
71
72    /// Set the multi-vector (ColBERT-style) score
73    #[cfg(feature = "multivector")]
74    #[must_use]
75    pub fn with_multivector_score(mut self, score: f32) -> Self {
76        self.multivector_score = Some(score);
77        self
78    }
79
80    /// Get the best available score (rerank > fused > multivector > dense > sparse)
81    #[must_use]
82    pub fn best_score(&self) -> f32 {
83        self.rerank_score
84            .or(self.fused_score)
85            .or(self.dense_score)
86            .or(self.sparse_score)
87            .unwrap_or(0.0)
88    }
89
90    /// Get the best available score including multi-vector score
91    #[cfg(feature = "multivector")]
92    #[must_use]
93    pub fn best_score_with_multivector(&self) -> f32 {
94        self.rerank_score
95            .or(self.fused_score)
96            .or(self.multivector_score)
97            .or(self.dense_score)
98            .or(self.sparse_score)
99            .unwrap_or(0.0)
100    }
101}
102
/// Configuration for hybrid retrieval
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HybridRetrieverConfig {
    /// Number of candidates to retrieve from each source (before fusion)
    pub candidates_per_source: usize,
    /// Fusion strategy
    pub fusion: FusionStrategy,
    /// Whether to use dense retrieval
    pub use_dense: bool,
    /// Whether to use sparse retrieval
    pub use_sparse: bool,
}
115
116impl Default for HybridRetrieverConfig {
117    fn default() -> Self {
118        Self {
119            candidates_per_source: 50,
120            fusion: FusionStrategy::default(),
121            use_dense: true,
122            use_sparse: true,
123        }
124    }
125}
126
/// Hybrid retriever combining dense and sparse retrieval
///
/// Chunks are indexed into both stores; query-time result lists are merged
/// by the configured fusion strategy.
pub struct HybridRetriever<E: Embedder> {
    /// Dense vector store
    dense: VectorStore,
    /// Sparse BM25 index
    sparse: BM25Index,
    /// Embedder for query embedding
    embedder: E,
    /// Configuration
    config: HybridRetrieverConfig,
}
138
139impl<E: Embedder> HybridRetriever<E> {
140    /// Create a new hybrid retriever
141    #[must_use]
142    pub fn new(dense: VectorStore, sparse: BM25Index, embedder: E) -> Self {
143        Self { dense, sparse, embedder, config: HybridRetrieverConfig::default() }
144    }
145
146    /// Set the configuration
147    #[must_use]
148    pub fn with_config(mut self, config: HybridRetrieverConfig) -> Self {
149        self.config = config;
150        self
151    }
152
153    /// Get the dense store
154    #[must_use]
155    pub fn dense_store(&self) -> &VectorStore {
156        &self.dense
157    }
158
159    /// Get the dense store mutably
160    pub fn dense_store_mut(&mut self) -> &mut VectorStore {
161        &mut self.dense
162    }
163
164    /// Get the sparse index
165    #[must_use]
166    pub fn sparse_index(&self) -> &BM25Index {
167        &self.sparse
168    }
169
170    /// Get the sparse index mutably
171    pub fn sparse_index_mut(&mut self) -> &mut BM25Index {
172        &mut self.sparse
173    }
174
175    /// Index a chunk (adds to both dense and sparse indices)
176    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
177        // Add to sparse index
178        self.sparse.add(&chunk);
179
180        // Add to dense index (requires embedding)
181        self.dense.insert(chunk)?;
182
183        Ok(())
184    }
185
186    /// Index multiple chunks
187    pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
188        for chunk in chunks {
189            self.index(chunk)?;
190        }
191        Ok(())
192    }
193
194    /// Retrieve relevant chunks for a query
195    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
196        let candidates = self.config.candidates_per_source;
197
198        // Dense retrieval
199        let dense_results = if self.config.use_dense {
200            let query_embedding = self.embedder.embed_query(query)?;
201            self.dense.search(&query_embedding, candidates)?
202        } else {
203            Vec::new()
204        };
205
206        // Sparse retrieval
207        let sparse_results =
208            if self.config.use_sparse { self.sparse.search(query, candidates) } else { Vec::new() };
209
210        // Fuse results
211        let fused = self.config.fusion.fuse(&dense_results, &sparse_results);
212
213        // Build score maps for lookup
214        let dense_scores: std::collections::HashMap<ChunkId, f32> =
215            dense_results.into_iter().collect();
216        let sparse_scores: std::collections::HashMap<ChunkId, f32> =
217            sparse_results.into_iter().collect();
218
219        // Build retrieval results
220        let mut results = Vec::with_capacity(k.min(fused.len()));
221        for (chunk_id, fused_score) in fused.into_iter().take(k) {
222            if let Some(chunk) = self.dense.get(chunk_id) {
223                let mut result = RetrievalResult::new(chunk.clone()).with_fused_score(fused_score);
224
225                if let Some(&score) = dense_scores.get(&chunk_id) {
226                    result = result.with_dense_score(score);
227                }
228                if let Some(&score) = sparse_scores.get(&chunk_id) {
229                    result = result.with_sparse_score(score);
230                }
231
232                results.push(result);
233            }
234        }
235
236        Ok(results)
237    }
238
239    /// Retrieve using only dense (vector) search
240    pub fn retrieve_dense(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
241        let query_embedding = self.embedder.embed_query(query)?;
242        let results = self.dense.search(&query_embedding, k)?;
243
244        let mut retrieval_results = Vec::with_capacity(results.len());
245        for (chunk_id, score) in results {
246            if let Some(chunk) = self.dense.get(chunk_id) {
247                retrieval_results.push(RetrievalResult::new(chunk.clone()).with_dense_score(score));
248            }
249        }
250
251        Ok(retrieval_results)
252    }
253
254    /// Retrieve using only sparse (BM25) search
255    pub fn retrieve_sparse(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
256        let results = self.sparse.search(query, k);
257
258        let mut retrieval_results = Vec::with_capacity(results.len());
259        for (chunk_id, score) in results {
260            if let Some(chunk) = self.dense.get(chunk_id) {
261                retrieval_results
262                    .push(RetrievalResult::new(chunk.clone()).with_sparse_score(score));
263            }
264        }
265
266        Ok(retrieval_results)
267    }
268
269    /// Get the number of indexed chunks
270    #[must_use]
271    pub fn len(&self) -> usize {
272        self.dense.len()
273    }
274
275    /// Check if the retriever is empty
276    #[must_use]
277    pub fn is_empty(&self) -> bool {
278        self.dense.is_empty()
279    }
280}
281
/// Dense-only retriever (simpler API for vector-only search)
pub struct DenseRetriever<E: Embedder> {
    /// Vector store holding the indexed chunks
    store: VectorStore,
    /// Embedder used for query embedding
    embedder: E,
}
287
288impl<E: Embedder> DenseRetriever<E> {
289    /// Create a new dense retriever
290    #[must_use]
291    pub fn new(store: VectorStore, embedder: E) -> Self {
292        Self { store, embedder }
293    }
294
295    /// Index a chunk
296    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
297        self.store.insert(chunk)
298    }
299
300    /// Retrieve relevant chunks
301    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
302        let query_embedding = self.embedder.embed_query(query)?;
303        let results = self.store.search(&query_embedding, k)?;
304
305        let mut retrieval_results = Vec::with_capacity(results.len());
306        for (chunk_id, score) in results {
307            if let Some(chunk) = self.store.get(chunk_id) {
308                retrieval_results.push(RetrievalResult::new(chunk.clone()).with_dense_score(score));
309            }
310        }
311
312        Ok(retrieval_results)
313    }
314}
315
/// Sparse-only retriever (BM25)
pub struct SparseRetriever {
    /// BM25 inverted index over chunk contents
    index: BM25Index,
    /// Owned chunks keyed by id, used to materialize results
    chunks: std::collections::HashMap<ChunkId, Chunk>,
}
321
322impl SparseRetriever {
323    /// Create a new sparse retriever
324    #[must_use]
325    pub fn new() -> Self {
326        Self { index: BM25Index::new(), chunks: std::collections::HashMap::new() }
327    }
328
329    /// Index a chunk
330    pub fn index(&mut self, chunk: Chunk) {
331        self.index.add(&chunk);
332        self.chunks.insert(chunk.id, chunk);
333    }
334
335    /// Retrieve relevant chunks
336    #[must_use]
337    pub fn retrieve(&self, query: &str, k: usize) -> Vec<RetrievalResult> {
338        let results = self.index.search(query, k);
339
340        results
341            .into_iter()
342            .filter_map(|(chunk_id, score)| {
343                self.chunks
344                    .get(&chunk_id)
345                    .map(|chunk| RetrievalResult::new(chunk.clone()).with_sparse_score(score))
346            })
347            .collect()
348    }
349}
350
351impl Default for SparseRetriever {
352    fn default() -> Self {
353        Self::new()
354    }
355}
356
357// ============ Multi-Vector Retriever (WARP) ============
358
/// Multi-vector retriever using WARP index for ColBERT-style late interaction.
///
/// This retriever uses token-level embeddings and MaxSim scoring for fine-grained
/// semantic matching. Unlike single-vector dense retrieval, multi-vector approaches
/// represent documents and queries as multiple token embeddings.
///
/// # Example
///
/// ```ignore
/// use trueno_rag::multivector::{
///     WarpIndexConfig, WarpSearchConfig,
///     MockMultiVectorEmbedder, MultiVectorRetriever,
/// };
///
/// let config = WarpIndexConfig::new(2, 256, 128);
/// let embedder = MockMultiVectorEmbedder::new(128, 512);
/// let mut retriever = MultiVectorRetriever::new(config, embedder);
///
/// // Train on sample documents
/// retriever.train(&sample_chunks)?;
///
/// // Index documents
/// for chunk in chunks {
///     retriever.index(chunk)?;
/// }
/// retriever.build()?;
///
/// // Search
/// let results = retriever.retrieve("What is machine learning?", 10)?;
/// ```
#[cfg(feature = "multivector")]
pub struct MultiVectorRetriever<E: crate::multivector::MultiVectorEmbedder> {
    /// WARP index for compressed multi-vector storage and search
    index: crate::multivector::WarpIndex,
    /// Multi-vector embedder for token-level embeddings
    embedder: E,
    /// Search configuration (its `k` is overridden per call by `retrieve`)
    search_config: crate::multivector::WarpSearchConfig,
}
398
#[cfg(feature = "multivector")]
impl<E: crate::multivector::MultiVectorEmbedder> MultiVectorRetriever<E> {
    /// Create a multi-vector retriever from a WARP index configuration and an
    /// embedder, using the default search configuration.
    ///
    /// # Arguments
    ///
    /// * `config` - WARP index configuration (nbits, num_centroids, token_dim)
    /// * `embedder` - Multi-vector embedder for generating token embeddings
    #[must_use]
    pub fn new(config: crate::multivector::WarpIndexConfig, embedder: E) -> Self {
        let index = crate::multivector::WarpIndex::new(config);
        let search_config = crate::multivector::WarpSearchConfig::default();
        Self { index, embedder, search_config }
    }

    /// Override the search configuration (builder style).
    #[must_use]
    pub fn with_search_config(mut self, config: crate::multivector::WarpSearchConfig) -> Self {
        self.search_config = config;
        self
    }

    /// Train the WARP index on sample chunks.
    ///
    /// Builds the residual quantization codec by learning centroids from the
    /// sample embeddings. Must be called before indexing.
    ///
    /// # Arguments
    ///
    /// * `sample_chunks` - Representative chunks for training the codec
    pub fn train(&mut self, sample_chunks: &[Chunk]) -> Result<()> {
        let texts: Vec<&str> = sample_chunks.iter().map(|c| c.content.as_str()).collect();
        let sample_embeddings = self.embedder.embed_tokens_batch(&texts)?;
        self.index.train(&sample_embeddings)?;
        Ok(())
    }

    /// Embed and insert a single chunk.
    ///
    /// The chunk's token embeddings are compressed with the trained codec, so
    /// `train()` must have been called first.
    pub fn index(&mut self, chunk: Chunk) -> Result<()> {
        let token_embeddings = self.embedder.embed_tokens(&chunk.content)?;
        self.index.insert(chunk, token_embeddings)?;
        Ok(())
    }

    /// Index multiple chunks, stopping at the first failure.
    pub fn index_batch(&mut self, chunks: Vec<Chunk>) -> Result<()> {
        chunks.into_iter().try_for_each(|chunk| self.index(chunk))
    }

    /// Compact the index into its IVF layout for efficient search.
    ///
    /// Call once after all chunks have been indexed.
    pub fn build(&mut self) -> Result<()> {
        self.index.build()
    }

    /// Retrieve relevant chunks for a query using multi-vector MaxSim scoring.
    ///
    /// # Arguments
    ///
    /// * `query` - Query text
    /// * `k` - Number of results to return
    pub fn retrieve(&self, query: &str, k: usize) -> Result<Vec<RetrievalResult>> {
        let query_tokens = self.embedder.embed_tokens(query)?;
        // Rebuild a search config so the caller's `k` wins while the remaining
        // knobs are carried over from the stored configuration.
        let per_call_config = crate::multivector::WarpSearchConfig::with_k(k)
            .nprobe(self.search_config.nprobe)
            .bound(self.search_config.bound)
            .centroid_score_threshold(self.search_config.centroid_score_threshold);
        let hits = self.index.search(&query_tokens, &per_call_config)?;

        Ok(hits
            .into_iter()
            .filter_map(|(chunk_id, score)| {
                self.index
                    .get_chunk(&chunk_id)
                    .map(|chunk| RetrievalResult::new(chunk.clone()).with_multivector_score(score))
            })
            .collect())
    }

    /// Get the number of indexed chunks.
    #[must_use]
    pub fn len(&self) -> usize {
        self.index.num_chunks()
    }

    /// Check if the retriever is empty.
    #[must_use]
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Get the underlying WARP index.
    #[must_use]
    pub fn warp_index(&self) -> &crate::multivector::WarpIndex {
        &self.index
    }

    /// Get the embedder.
    #[must_use]
    pub fn embedder(&self) -> &E {
        &self.embedder
    }

    /// Get memory usage statistics (bytes, as reported by the index).
    #[must_use]
    pub fn memory_usage(&self) -> usize {
        self.index.memory_usage()
    }
}
519
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{embed::MockEmbedder, DocumentId};

    // Helper: build a chunk with a pre-set embedding so it can be inserted
    // into a VectorStore without running a real embedder.
    fn create_test_chunk(content: &str, embedding: Vec<f32>) -> Chunk {
        let mut chunk = Chunk::new(DocumentId::new(), content.to_string(), 0, content.len());
        chunk.set_embedding(embedding);
        chunk
    }

    // ============ RetrievalResult Tests ============

    #[test]
    fn test_retrieval_result_new() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result = RetrievalResult::new(chunk);

        // A fresh result carries no scores from any stage.
        assert!(result.dense_score.is_none());
        assert!(result.sparse_score.is_none());
        assert!(result.fused_score.is_none());
        assert!(result.rerank_score.is_none());
    }

    #[test]
    fn test_retrieval_result_with_scores() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result = RetrievalResult::new(chunk)
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);

        assert_eq!(result.dense_score, Some(0.9));
        assert_eq!(result.sparse_score, Some(0.8));
        assert_eq!(result.fused_score, Some(0.85));
        assert_eq!(result.rerank_score, Some(0.95));
    }

    #[test]
    fn test_retrieval_result_best_score_priority() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);

        // Rerank takes priority
        let result =
            RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_rerank_score(0.9);
        assert!((result.best_score() - 0.9).abs() < 0.001);

        // Fused takes priority over dense/sparse
        let result =
            RetrievalResult::new(chunk.clone()).with_dense_score(0.5).with_fused_score(0.7);
        assert!((result.best_score() - 0.7).abs() < 0.001);

        // Dense used when nothing else available
        let result = RetrievalResult::new(chunk).with_dense_score(0.5);
        assert!((result.best_score() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_retrieval_result_best_score_default() {
        // No scores at all -> best_score falls back to 0.0.
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
        let result = RetrievalResult::new(chunk);
        assert!((result.best_score() - 0.0).abs() < 0.001);
    }

    // ============ HybridRetrieverConfig Tests ============

    #[test]
    fn test_hybrid_config_default() {
        let config = HybridRetrieverConfig::default();
        assert_eq!(config.candidates_per_source, 50);
        assert!(config.use_dense);
        assert!(config.use_sparse);
    }

    // ============ HybridRetriever Tests ============

    #[test]
    fn test_hybrid_retriever_new() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let retriever = HybridRetriever::new(dense, sparse, embedder);
        assert!(retriever.is_empty());
    }

    #[test]
    fn test_hybrid_retriever_index() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        let chunk = create_test_chunk("machine learning is great", vec![0.0; 64]);
        retriever.index(chunk).unwrap();

        assert_eq!(retriever.len(), 1);
    }

    #[test]
    fn test_hybrid_retriever_index_batch() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        let chunks = vec![
            create_test_chunk("first document", vec![1.0; 64]),
            create_test_chunk("second document", vec![0.5; 64]),
        ];
        retriever.index_batch(chunks).unwrap();

        assert_eq!(retriever.len(), 2);
    }

    #[test]
    fn test_hybrid_retriever_retrieve() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        // Index some chunks
        retriever
            .index(create_test_chunk("machine learning algorithms", vec![1.0, 0.0, 0.0]))
            .unwrap();
        retriever
            .index(create_test_chunk("deep learning neural networks", vec![0.9, 0.1, 0.0]))
            .unwrap();
        retriever.index(create_test_chunk("cooking recipes", vec![0.0, 0.0, 1.0])).unwrap();

        let results = retriever.retrieve("machine learning", 2).unwrap();

        assert!(!results.is_empty());
        assert!(results.len() <= 2);
    }

    #[test]
    fn test_hybrid_retriever_retrieve_dense_only() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        retriever.index(create_test_chunk("test doc", vec![1.0, 0.0, 0.0])).unwrap();

        // retrieve_dense must populate only the dense score.
        let results = retriever.retrieve_dense("test", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].dense_score.is_some());
        assert!(results[0].sparse_score.is_none());
    }

    #[test]
    fn test_hybrid_retriever_retrieve_sparse_only() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();

        // retrieve_sparse must populate only the sparse score.
        let results = retriever.retrieve_sparse("machine", 10).unwrap();
        assert!(!results.is_empty());
        assert!(results[0].sparse_score.is_some());
        assert!(results[0].dense_score.is_none());
    }

    #[test]
    fn test_hybrid_retriever_config() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::Linear { dense_weight: 0.7 },
            use_dense: true,
            use_sparse: true,
        };

        let retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);

        assert_eq!(retriever.config.candidates_per_source, 100);
    }

    // ============ DenseRetriever Tests ============

    #[test]
    fn test_dense_retriever() {
        let embedder = MockEmbedder::new(3);
        let store = VectorStore::with_dimension(3);
        let mut retriever = DenseRetriever::new(store, embedder);

        retriever.index(create_test_chunk("test document", vec![1.0, 0.0, 0.0])).unwrap();

        let results = retriever.retrieve("test", 10).unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].dense_score.is_some());
    }

    // ============ SparseRetriever Tests ============

    #[test]
    fn test_sparse_retriever_new() {
        let retriever = SparseRetriever::new();
        let results = retriever.retrieve("test", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_sparse_retriever_index() {
        let mut retriever = SparseRetriever::new();
        let chunk = Chunk::new(DocumentId::new(), "machine learning test".to_string(), 0, 20);

        retriever.index(chunk);
        let results = retriever.retrieve("machine", 10);

        assert_eq!(results.len(), 1);
        assert!(results[0].sparse_score.is_some());
    }

    #[test]
    fn test_sparse_retriever_multiple() {
        let mut retriever = SparseRetriever::new();

        retriever.index(Chunk::new(
            DocumentId::new(),
            "rust programming language".to_string(),
            0,
            24,
        ));
        retriever.index(Chunk::new(
            DocumentId::new(),
            "python programming language".to_string(),
            0,
            26,
        ));

        // Both documents contain the query term.
        let results = retriever.retrieve("programming", 10);
        assert_eq!(results.len(), 2);
    }

    // ============ Additional Coverage Tests ============

    #[test]
    fn test_hybrid_retriever_store_accessors() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);

        // Test immutable accessors
        let _dense_store = retriever.dense_store();
        let _sparse_index = retriever.sparse_index();

        // Test mutable accessors
        let dense_mut = retriever.dense_store_mut();
        assert!(dense_mut.is_empty());

        let sparse_mut = retriever.sparse_index_mut();
        let _ = sparse_mut; // Just verify it compiles and works
    }

    #[test]
    fn test_hybrid_retriever_is_empty() {
        let embedder = MockEmbedder::new(64);
        let dense = VectorStore::with_dimension(64);
        let sparse = BM25Index::new();

        let mut retriever = HybridRetriever::new(dense, sparse, embedder);
        assert!(retriever.is_empty());

        retriever.index(create_test_chunk("test", vec![0.0; 64])).unwrap();
        assert!(!retriever.is_empty());
    }

    #[test]
    fn test_sparse_retriever_default() {
        let retriever = SparseRetriever::default();
        let results = retriever.retrieve("test", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_retrieval_result_best_score_sparse_fallback() {
        let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);

        // Only sparse score available
        let result = RetrievalResult::new(chunk).with_sparse_score(0.75);
        assert!((result.best_score() - 0.75).abs() < 0.001);
    }

    #[test]
    fn test_hybrid_retriever_with_dense_disabled() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: false,
            use_sparse: true,
        };

        let mut retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);

        retriever.index(create_test_chunk("machine learning test", vec![1.0, 0.0, 0.0])).unwrap();

        // Should still work, using only sparse
        let results = retriever.retrieve("machine", 10).unwrap();
        // Results depend on sparse-only fusion
        assert!(results.len() <= 10);
    }

    #[test]
    fn test_hybrid_retriever_with_sparse_disabled() {
        let embedder = MockEmbedder::new(3);
        let dense = VectorStore::with_dimension(3);
        let sparse = BM25Index::new();

        let config = HybridRetrieverConfig {
            candidates_per_source: 50,
            fusion: FusionStrategy::default(),
            use_dense: true,
            use_sparse: false,
        };

        let mut retriever = HybridRetriever::new(dense, sparse, embedder).with_config(config);

        retriever.index(create_test_chunk("test content", vec![1.0, 0.0, 0.0])).unwrap();

        // Should still work, using only dense
        let results = retriever.retrieve("test", 10).unwrap();
        assert!(results.len() <= 10);
    }

    #[test]
    fn test_hybrid_retriever_config_serialization() {
        let config = HybridRetrieverConfig {
            candidates_per_source: 100,
            fusion: FusionStrategy::RRF { k: 60.0 },
            use_dense: true,
            use_sparse: false,
        };

        // Round-trip through JSON must preserve every field.
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: HybridRetrieverConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.candidates_per_source, deserialized.candidates_per_source);
        assert_eq!(config.use_dense, deserialized.use_dense);
        assert_eq!(config.use_sparse, deserialized.use_sparse);
    }

    #[test]
    fn test_retrieval_result_serialization() {
        let chunk = Chunk::new(DocumentId::new(), "test content".to_string(), 0, 12);
        let result = RetrievalResult::new(chunk)
            .with_dense_score(0.9)
            .with_sparse_score(0.8)
            .with_fused_score(0.85)
            .with_rerank_score(0.95);

        let json = serde_json::to_string(&result).unwrap();
        let deserialized: RetrievalResult = serde_json::from_str(&json).unwrap();

        assert_eq!(result.dense_score, deserialized.dense_score);
        assert_eq!(result.sparse_score, deserialized.sparse_score);
        assert_eq!(result.fused_score, deserialized.fused_score);
        assert_eq!(result.rerank_score, deserialized.rerank_score);
    }

    // ============ Property-Based Tests ============

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_retrieval_result_scores_preserved(
            dense in 0.0f32..1.0,
            sparse in 0.0f32..1.0,
            fused in 0.0f32..1.0
        ) {
            let chunk = Chunk::new(DocumentId::new(), "test".to_string(), 0, 4);
            let result = RetrievalResult::new(chunk)
                .with_dense_score(dense)
                .with_sparse_score(sparse)
                .with_fused_score(fused);

            prop_assert!((result.dense_score.unwrap() - dense).abs() < 0.0001);
            prop_assert!((result.sparse_score.unwrap() - sparse).abs() < 0.0001);
            prop_assert!((result.fused_score.unwrap() - fused).abs() < 0.0001);
        }

        #[test]
        fn prop_hybrid_retriever_respects_k(k in 1usize..10) {
            let embedder = MockEmbedder::new(3);
            let dense = VectorStore::with_dimension(3);
            let sparse = BM25Index::new();

            let mut retriever = HybridRetriever::new(dense, sparse, embedder);

            // Add more chunks than k
            for i in 0..20 {
                let mut emb = vec![0.0; 3];
                emb[i % 3] = 1.0;
                retriever.index(create_test_chunk(
                    &format!("document number {i} about testing"),
                    emb,
                )).unwrap();
            }

            let results = retriever.retrieve("testing", k).unwrap();
            prop_assert!(results.len() <= k);
        }
    }
}