// trueno_rag/multivector/mod.rs
//! Multi-vector retrieval with WARP algorithm
//!
//! This module provides ColBERT-style multi-vector retrieval using the WARP
//! (Weighted Approximate Residual Product) algorithm. Unlike single-vector
//! dense retrieval, multi-vector approaches represent each document and query
//! as multiple token embeddings, enabling fine-grained "late interaction" scoring.
//!
//! # Overview
//!
//! The WARP algorithm provides memory-efficient multi-vector search by:
//!
//! 1. **Residual Quantization** - Compress token embeddings from 32-bit floats
//!    to 2-4 bits per dimension using centroid-based encoding
//! 2. **IVF Indexing** - Organize embeddings by centroid for cache-efficient access
//! 3. **Deferred Decompression** - Score directly from compressed representations
//!
//! # Key Components
//!
//! - [`MultiVectorEmbedding`] - A document/query represented as multiple token embeddings
//! - [`WarpIndex`] - The main index structure with train/insert/build/search methods
//! - [`WarpIndexConfig`] - Configuration for index construction
//! - [`WarpSearchConfig`] - Configuration for search parameters
//! - [`ResidualCodec`] - Compression codec for token embeddings
//! - [`MultiVectorEmbedder`] - Trait for token-level embedding models
//! - [`MultiVectorRetriever`] - High-level retriever combining embedder and index
//!
//! # Quick Start
//!
//! ```ignore
//! use trueno_rag::multivector::{
//!     WarpIndex, WarpIndexConfig, WarpSearchConfig,
//!     MockMultiVectorEmbedder, MultiVectorEmbedder,
//!     MultiVectorRetriever,
//! };
//!
//! // Create retriever with mock embedder
//! let config = WarpIndexConfig::new(2, 256, 128);
//! let embedder = MockMultiVectorEmbedder::new(128, 512);
//! let mut retriever = MultiVectorRetriever::new(config, embedder);
//!
//! // Train on sample documents
//! retriever.train(&sample_chunks)?;
//!
//! // Index documents
//! for chunk in chunks {
//!     retriever.index(chunk)?;
//! }
//! retriever.build()?;
//!
//! // Search
//! let results = retriever.retrieve("What is machine learning?", 10)?;
//! ```
//!
//! # Theory: MaxSim Scoring
//!
//! ColBERT uses MaxSim scoring which computes, for query Q with tokens {q₁...qₘ}
//! and document D with tokens {d₁...dₙ}:
//!
//! ```text
//! MaxSim(Q, D) = Σᵢ maxⱼ(qᵢ · dⱼ)
//! ```
//!
//! For each query token, find the maximum similarity with any document token,
//! then sum across query tokens. This captures soft alignment without explicit
//! matching.
//!
//! # Feature Flag
//!
//! This module is only available with the `multivector` feature:
//!
//! ```toml
//! [dependencies]
//! trueno-rag = { version = "0.1", features = ["multivector"] }
//! ```
//!
//! # References
//!
//! - Khattab & Zaharia (2020). "ColBERT: Efficient and Effective Passage Search
//!   via Contextualized Late Interaction over BERT." SIGIR 2020.
//! - Santhanam et al. (2022). "ColBERTv2: Effective and Efficient Retrieval via
//!   Lightweight Late Interaction." NAACL 2022.

83pub mod codec;
84pub mod embedder;
85#[cfg(test)]
86pub mod falsification;
87pub mod index;
88pub mod search;
89pub mod types;
90
91// Re-export main types at module level
92pub use codec::{ResidualCodec, ResidualCodecBuilder};
93pub use embedder::{MockMultiVectorEmbedder, MultiVectorEmbedder};
94pub use index::WarpIndex;
95pub use search::{exact_maxsim, CandidateScorer, CentroidSelector, ScoreMerger};
96pub use types::{MultiVectorEmbedding, WarpIndexConfig, WarpSearchConfig};
97
98// Re-export retriever (defined in retrieve.rs but part of this feature)
99// Note: MultiVectorRetriever is in retrieve.rs, not here
100
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{Chunk, DocumentId};

    /// Integration test: full pipeline from training to search.
    #[test]
    fn test_full_pipeline() {
        // 1. Create embedder (32-dim token embeddings, up to 128 tokens per text)
        let embedder = MockMultiVectorEmbedder::new(32, 128);

        // 2. Configure index - use fewer centroids to match training data size
        //    (need at least 10 tokens per centroid for training)
        let config = WarpIndexConfig::new(2, 4, 32).with_kmeans_iterations(5);

        // 3. Create index
        let mut index = WarpIndex::new(config);

        // 4. Generate training data with enough tokens
        let training_texts = [
            "machine learning algorithms are powerful tools for data science",
            "deep neural networks have revolutionized computer vision tasks",
            "natural language processing enables machines to understand text",
            "computer vision systems detect objects in images and video",
            "reinforcement learning agents learn through trial and error",
            "transformer architectures power modern large language models",
            "attention mechanisms allow models to focus on relevant inputs",
            "gradient descent optimization updates neural network parameters",
        ];

        let training_embeddings: Vec<_> =
            training_texts.iter().map(|t| embedder.embed_tokens(t).unwrap()).collect();

        // 5. Train codec
        index.train(&training_embeddings).unwrap();

        // 6. Insert documents
        for text in training_texts.iter() {
            let chunk = Chunk::new(DocumentId::new(), text.to_string(), 0, text.len());
            let embedding = embedder.embed_tokens(text).unwrap();
            index.insert(chunk, embedding).unwrap();
        }

        // 7. Build index
        index.build().unwrap();

        // 8. Search
        let query_text = "neural network learning";
        let query_embedding = embedder.embed_tokens(query_text).unwrap();
        let search_config = WarpSearchConfig::with_k(3);

        let results = index.search(&query_embedding, &search_config).unwrap();

        // Verify results: non-empty and bounded by k
        assert!(!results.is_empty());
        assert!(results.len() <= 3);

        // Results should be sorted by score descending
        for i in 1..results.len() {
            assert!(
                results[i - 1].1 >= results[i].1,
                "results not sorted by score descending at position {}: {} < {}",
                i,
                results[i - 1].1,
                results[i].1
            );
        }

        // Every returned chunk ID must resolve back to a stored chunk
        for (chunk_id, _score) in &results {
            let chunk = index.get_chunk(chunk_id);
            assert!(chunk.is_some());
        }
    }

    /// Test exact MaxSim matches expected values computed by hand.
    #[test]
    fn test_exact_maxsim_calculation() {
        // Query: 2 tokens, 4 dims each (row-major flat layout)
        let query = MultiVectorEmbedding::new(
            vec![
                1.0, 0.0, 0.0, 0.0, // q1
                0.0, 1.0, 0.0, 0.0, // q2
            ],
            2,
            4,
        );

        // Doc: 3 tokens, 4 dims each
        let doc = MultiVectorEmbedding::new(
            vec![
                0.5, 0.5, 0.0, 0.0, // d1: q1·d1=0.5, q2·d1=0.5
                1.0, 0.0, 0.0, 0.0, // d2: q1·d2=1.0, q2·d2=0.0
                0.0, 0.0, 1.0, 0.0, // d3: q1·d3=0.0, q2·d3=0.0
            ],
            3,
            4,
        );

        let score = exact_maxsim(&query, &doc);

        // MaxSim = max(0.5, 1.0, 0.0) + max(0.5, 0.0, 0.0) = 1.0 + 0.5 = 1.5
        assert!((score - 1.5).abs() < 1e-6);
    }

    /// Test that compression preserves relative ordering of document scores.
    #[test]
    fn test_compression_preserves_ordering() {
        let embedder = MockMultiVectorEmbedder::new(32, 128);

        // Create documents with varying relevance to the query
        let query = embedder.embed_tokens("machine learning").unwrap();
        let doc_relevant = embedder.embed_tokens("machine learning algorithms").unwrap();
        let doc_partial = embedder.embed_tokens("learning systems").unwrap();
        let doc_irrelevant = embedder.embed_tokens("cooking recipes").unwrap();

        // Exact scores (partial score computed but not asserted on: the mock
        // embedder does not guarantee a strict relevant > partial ordering)
        let exact_relevant = exact_maxsim(&query, &doc_relevant);
        let _exact_partial = exact_maxsim(&query, &doc_partial);
        let exact_irrelevant = exact_maxsim(&query, &doc_irrelevant);

        // Verify relative ordering makes sense
        // (relevant should score higher than irrelevant)
        assert!(
            exact_relevant > exact_irrelevant,
            "Relevant doc should score higher: {} vs {}",
            exact_relevant,
            exact_irrelevant
        );
    }

    /// Test search behaves sanely across a range of nprobe settings.
    #[test]
    fn test_search_nprobe_variations() {
        let embedder = MockMultiVectorEmbedder::new(16, 64);
        let config = WarpIndexConfig::new(2, 8, 16).with_kmeans_iterations(3);
        let mut index = WarpIndex::new(config);

        // Train and build on 50 small documents
        let texts: Vec<String> = (0..50).map(|i| format!("document number {}", i)).collect();
        let embeddings: Vec<_> = texts.iter().map(|t| embedder.embed_tokens(t).unwrap()).collect();
        index.train(&embeddings).unwrap();

        for (i, text) in texts.iter().enumerate() {
            let chunk = Chunk::new(DocumentId::new(), text.clone(), 0, text.len());
            index.insert(chunk, embeddings[i].clone()).unwrap();
        }
        index.build().unwrap();

        let query = embedder.embed_tokens("document number").unwrap();

        // Test with different nprobe values; result count must never exceed k
        for nprobe in [1, 2, 4, 8] {
            let config = WarpSearchConfig::with_k(5).nprobe(nprobe);
            let results = index.search(&query, &config).unwrap();

            assert!(results.len() <= 5, "nprobe={}: got {} results", nprobe, results.len());
        }
    }

    /// Test memory usage stays within a small constant factor of the
    /// theoretical compressed size.
    #[test]
    fn test_memory_efficiency() {
        let embedder = MockMultiVectorEmbedder::new(128, 512);
        // Use fewer centroids - need 10 tokens per centroid for training
        let config = WarpIndexConfig::new(2, 8, 128).with_kmeans_iterations(5);
        let mut index = WarpIndex::new(config);

        // Train with more tokens per document (8 centroids * 10 = 80 tokens needed)
        let texts: Vec<String> = (0..50)
            .map(|i| {
                format!("document number {} contains important information about topic {}", i, i)
            })
            .collect();
        let embeddings: Vec<_> = texts.iter().map(|t| embedder.embed_tokens(t).unwrap()).collect();
        index.train(&embeddings).unwrap();

        // Insert all documents, then build
        for (i, text) in texts.iter().enumerate() {
            let chunk = Chunk::new(DocumentId::new(), text.clone(), 0, text.len());
            index.insert(chunk, embeddings[i].clone()).unwrap();
        }
        index.build().unwrap();

        let memory = index.memory_usage();
        let num_tokens = index.num_tokens();

        // With 2-bit compression: 128 dims × 2 bits = 32 bytes per token
        // Plus overhead for chunk_ids, token_indices, etc.
        let theoretical_min = num_tokens * 32;
        let overhead_factor = 3.0; // Allow 3× overhead for metadata

        assert!(
            memory < (theoretical_min as f64 * overhead_factor) as usize,
            "Memory {} too high for {} tokens (theoretical min {})",
            memory,
            num_tokens,
            theoretical_min
        );
    }
}
295}