trueno_rag/multivector/mod.rs

//! Multi-vector retrieval with WARP algorithm
//!
//! This module provides ColBERT-style multi-vector retrieval using the WARP
//! (Weighted Approximate Residual Product) algorithm. Unlike single-vector
//! dense retrieval, multi-vector approaches represent each document and query
//! as multiple token embeddings, enabling fine-grained "late interaction" scoring.
//!
//! # Overview
//!
//! The WARP algorithm provides memory-efficient multi-vector search by:
//!
//! 1. **Residual Quantization** - Compress token embeddings from 32-bit floats
//!    to 2-4 bits per dimension using centroid-based encoding (sketched below)
//! 2. **IVF Indexing** - Organize embeddings by centroid for cache-efficient access
//! 3. **Deferred Decompression** - Score directly from compressed representations
//!
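//! The compression step can be pictured with a small conceptual sketch
//! (illustrative only; the helper below is hypothetical and the real
//! [`ResidualCodec`] API may differ):
//!
//! ```ignore
//! /// Quantize one token embedding: pick the nearest centroid, then encode
//! /// each residual dimension with `nbits` bits (assumes residuals in ~[-1, 1]).
//! fn quantize_token(token: &[f32], centroids: &[Vec<f32>], nbits: u32) -> (usize, Vec<u8>) {
//!     // 1. IVF assignment: nearest centroid by squared L2 distance.
//!     let cid = centroids
//!         .iter()
//!         .enumerate()
//!         .min_by(|(_, a), (_, b)| {
//!             let da: f32 = a.iter().zip(token).map(|(c, t)| (c - t).powi(2)).sum();
//!             let db: f32 = b.iter().zip(token).map(|(c, t)| (c - t).powi(2)).sum();
//!             da.total_cmp(&db)
//!         })
//!         .map(|(i, _)| i)
//!         .unwrap();
//!     // 2. Bucket each residual dimension into 2^nbits levels.
//!     let levels = (1u32 << nbits) as f32 - 1.0;
//!     let codes = token
//!         .iter()
//!         .zip(&centroids[cid])
//!         .map(|(t, c)| ((((t - c).clamp(-1.0, 1.0) + 1.0) / 2.0) * levels).round() as u8)
//!         .collect();
//!     (cid, codes)
//! }
//! ```
//!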
//! # Key Components
//!
//! - [`MultiVectorEmbedding`] - A document/query represented as multiple token embeddings
//! - [`WarpIndex`] - The main index structure with train/insert/build/search methods
//! - [`WarpIndexConfig`] - Configuration for index construction
//! - [`WarpSearchConfig`] - Configuration for search parameters
//! - [`ResidualCodec`] - Compression codec for token embeddings
//! - [`MultiVectorEmbedder`] - Trait for token-level embedding models
//! - [`MultiVectorRetriever`] - High-level retriever combining embedder and index
//!
//! # Quick Start
//!
//! ```ignore
//! use trueno_rag::multivector::{
//!     WarpIndex, WarpIndexConfig, WarpSearchConfig,
//!     MockMultiVectorEmbedder, MultiVectorEmbedder,
//!     MultiVectorRetriever,
//! };
//!
//! // Create retriever with mock embedder
//! // (2-bit residuals, 256 centroids, 128-dimensional token embeddings)
//! let config = WarpIndexConfig::new(2, 256, 128);
//! let embedder = MockMultiVectorEmbedder::new(128, 512);
//! let mut retriever = MultiVectorRetriever::new(config, embedder);
//!
//! // Train on sample documents
//! retriever.train(&sample_chunks)?;
//!
//! // Index documents
//! for chunk in chunks {
//!     retriever.index(chunk)?;
//! }
//! retriever.build()?;
//!
//! // Search
//! let results = retriever.retrieve("What is machine learning?", 10)?;
//! ```
//!
//! # Theory: MaxSim Scoring
//!
//! ColBERT uses MaxSim scoring, which, for a query Q with tokens {q₁...qₘ}
//! and a document D with tokens {d₁...dₙ}, computes:
//!
//! ```text
//! MaxSim(Q, D) = Σᵢ maxⱼ(qᵢ · dⱼ)
//! ```
//!
//! For each query token, find the maximum similarity with any document token,
//! then sum across query tokens. This captures soft alignment between query and
//! document tokens without requiring explicit term matching.
//!
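//! A naive reference version of this computation might look as follows
//! (illustrative only; the crate's [`exact_maxsim`] operates on
//! [`MultiVectorEmbedding`] values rather than raw slices):
//!
//! ```ignore
//! /// Sum, over query tokens, of the best dot product against any doc token.
//! /// Assumes non-empty inputs; token embeddings are typically L2-normalized.
//! fn naive_maxsim(query: &[Vec<f32>], doc: &[Vec<f32>]) -> f32 {
//!     query
//!         .iter()
//!         .map(|q| {
//!             doc.iter()
//!                 .map(|d| q.iter().zip(d).map(|(a, b)| a * b).sum::<f32>())
//!                 .fold(f32::NEG_INFINITY, f32::max)
//!         })
//!         .sum()
//! }
//! ```
//!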
//! # Feature Flag
//!
//! This module is only available with the `multivector` feature:
//!
//! ```toml
//! [dependencies]
//! trueno-rag = { version = "0.1", features = ["multivector"] }
//! ```
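//!
//! Equivalently, from the command line:
//!
//! ```text
//! cargo add trueno-rag --features multivector
//! ```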
//!
//! # References
//!
//! - Khattab & Zaharia (2020). "ColBERT: Efficient and Effective Passage Search
//!   via Contextualized Late Interaction over BERT." SIGIR 2020.
//! - Santhanam et al. (2022). "ColBERTv2: Effective and Efficient Retrieval via
//!   Lightweight Late Interaction." NAACL 2022.

pub mod codec;
pub mod embedder;
#[cfg(test)]
pub mod falsification;
pub mod index;
pub mod search;
pub mod types;

// Re-export main types at module level
pub use codec::{ResidualCodec, ResidualCodecBuilder};
pub use embedder::{MockMultiVectorEmbedder, MultiVectorEmbedder};
pub use index::WarpIndex;
pub use search::{exact_maxsim, CandidateScorer, CentroidSelector, ScoreMerger};
pub use types::{MultiVectorEmbedding, WarpIndexConfig, WarpSearchConfig};

// Note: `MultiVectorRetriever` (the high-level retriever built on this module)
// is defined in retrieve.rs as part of the same feature; it is not re-exported here.

#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Chunk, DocumentId};

    /// Integration test: full pipeline from training to search
    #[test]
    fn test_full_pipeline() {
        // 1. Create embedder
        let embedder = MockMultiVectorEmbedder::new(32, 128);

        // 2. Configure index - use fewer centroids to match training data size
        // (need at least 10 tokens per centroid for training)
        let config = WarpIndexConfig::new(2, 4, 32).with_kmeans_iterations(5);

        // 3. Create index
        let mut index = WarpIndex::new(config);

        // 4. Generate training data with enough tokens
        let training_texts = [
            "machine learning algorithms are powerful tools for data science",
            "deep neural networks have revolutionized computer vision tasks",
            "natural language processing enables machines to understand text",
            "computer vision systems detect objects in images and video",
            "reinforcement learning agents learn through trial and error",
            "transformer architectures power modern large language models",
            "attention mechanisms allow models to focus on relevant inputs",
            "gradient descent optimization updates neural network parameters",
        ];

        let training_embeddings: Vec<_> =
            training_texts.iter().map(|t| embedder.embed_tokens(t).unwrap()).collect();

        // 5. Train codec
        index.train(&training_embeddings).unwrap();

        // 6. Insert documents
        for text in training_texts.iter() {
            let chunk = Chunk::new(DocumentId::new(), text.to_string(), 0, text.len());
            let embedding = embedder.embed_tokens(text).unwrap();
            index.insert(chunk, embedding).unwrap();
        }

        // 7. Build index
        index.build().unwrap();

        // 8. Search
        let query_text = "neural network learning";
        let query_embedding = embedder.embed_tokens(query_text).unwrap();
        let search_config = WarpSearchConfig::with_k(3);

        let results = index.search(&query_embedding, &search_config).unwrap();

        // Verify results
        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        // Results should be sorted by score descending
        for i in 1..results.len() {
            assert!(results[i - 1].1 >= results[i].1);
        }

        // Can retrieve chunks by ID
        for (chunk_id, _score) in &results {
            let chunk = index.get_chunk(chunk_id);
            assert!(chunk.is_some());
        }
    }

    /// Test exact MaxSim matches expected values
    #[test]
    fn test_exact_maxsim_calculation() {
        // Query: 2 tokens
        let query = MultiVectorEmbedding::new(
            vec![
                1.0, 0.0, 0.0, 0.0, // q1
                0.0, 1.0, 0.0, 0.0, // q2
            ],
            2, // num_tokens
            4, // dim
        );

        // Doc: 3 tokens
        let doc = MultiVectorEmbedding::new(
            vec![
                0.5, 0.5, 0.0, 0.0, // d1: q1·d1=0.5, q2·d1=0.5
                1.0, 0.0, 0.0, 0.0, // d2: q1·d2=1.0, q2·d2=0.0
                0.0, 0.0, 1.0, 0.0, // d3: q1·d3=0.0, q2·d3=0.0
            ],
            3, // num_tokens
            4, // dim
        );

        let score = exact_maxsim(&query, &doc);

        // MaxSim = max(0.5, 1.0, 0.0) + max(0.5, 0.0, 0.0) = 1.0 + 0.5 = 1.5
        assert!((score - 1.5).abs() < 1e-6);
    }

    /// Test that compression preserves relative ordering
    #[test]
    fn test_compression_preserves_ordering() {
        let embedder = MockMultiVectorEmbedder::new(32, 128);

        // Create documents with varying relevance
        let query = embedder.embed_tokens("machine learning").unwrap();
        let doc_relevant = embedder.embed_tokens("machine learning algorithms").unwrap();
        let doc_partial = embedder.embed_tokens("learning systems").unwrap();
        let doc_irrelevant = embedder.embed_tokens("cooking recipes").unwrap();

        // Exact scores
        let exact_relevant = exact_maxsim(&query, &doc_relevant);
        let _exact_partial = exact_maxsim(&query, &doc_partial);
        let exact_irrelevant = exact_maxsim(&query, &doc_irrelevant);

        // Verify relative ordering makes sense
        // (relevant should score higher than irrelevant)
        assert!(
            exact_relevant > exact_irrelevant,
            "Relevant doc should score higher: {} vs {}",
            exact_relevant,
            exact_irrelevant
        );
    }

    /// Test search with various nprobe settings
    #[test]
    fn test_search_nprobe_variations() {
        let embedder = MockMultiVectorEmbedder::new(16, 64);
        let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
        let mut index = WarpIndex::new(config);

        // Train and build
        let texts: Vec<String> = (0..50).map(|i| format!("document number {}", i)).collect();
        let embeddings: Vec<_> = texts.iter().map(|t| embedder.embed_tokens(t).unwrap()).collect();
        index.train(&embeddings).unwrap();

        for (i, text) in texts.iter().enumerate() {
            let chunk = Chunk::new(DocumentId::new(), text.clone(), 0, text.len());
            index.insert(chunk, embeddings[i].clone()).unwrap();
        }
        index.build().unwrap();

        let query = embedder.embed_tokens("document number").unwrap();

        // Test with different nprobe values
        for nprobe in [1, 2, 4, 8] {
            let config = WarpSearchConfig::with_k(5).nprobe(nprobe);
            let results = index.search(&query, &config).unwrap();

            assert!(results.len() <= 5, "nprobe={}: got {} results", nprobe, results.len());
        }
    }

    /// Test memory usage is reasonable
    #[test]
    fn test_memory_efficiency() {
        let embedder = MockMultiVectorEmbedder::new(128, 512);
        // Use fewer centroids - need 10 tokens per centroid for training
        let config = WarpIndexConfig::new(2, 8, 128).with_kmeans_iterations(5);
        let mut index = WarpIndex::new(config);

        // Train with more tokens per document (8 centroids * 10 = 80 tokens needed)
        let texts: Vec<String> = (0..50)
            .map(|i| {
                format!("document number {} contains important information about topic {}", i, i)
            })
            .collect();
        let embeddings: Vec<_> = texts.iter().map(|t| embedder.embed_tokens(t).unwrap()).collect();
        index.train(&embeddings).unwrap();

        // Insert
        for (i, text) in texts.iter().enumerate() {
            let chunk = Chunk::new(DocumentId::new(), text.clone(), 0, text.len());
            index.insert(chunk, embeddings[i].clone()).unwrap();
        }
        index.build().unwrap();

        let memory = index.memory_usage();
        let num_tokens = index.num_tokens();

        // With 2-bit compression: 128 dims × 2 bits = 32 bytes per token
        // Plus overhead for chunk_ids, token_indices, etc.
        let theoretical_min = num_tokens * 32;
        let overhead_factor = 3.0; // Allow 3× overhead for metadata

        assert!(
            memory < (theoretical_min as f64 * overhead_factor) as usize,
            "Memory {} too high for {} tokens (theoretical min {})",
            memory,
            num_tokens,
            theoretical_min
        );
    }
}