// trueno_rag/multivector/types.rs
//! Core data structures for multi-vector retrieval.
//!
//! This module defines the fundamental types used in WARP-based multi-vector
//! retrieval, including embeddings, index configuration, and search parameters.

use serde::{Deserialize, Serialize};

/// A document or query represented as multiple token embeddings.
///
/// In ColBERT-style retrieval, each document and query is represented not by
/// a single embedding vector, but by multiple vectors—one per token. This
/// enables fine-grained "late interaction" scoring via MaxSim.
///
/// # Memory Layout
///
/// Embeddings are stored in a flattened contiguous array for cache efficiency:
/// `[token_0_dim_0, token_0_dim_1, ..., token_1_dim_0, token_1_dim_1, ...]`
///
/// # Example
///
/// ```
/// use trueno_rag::multivector::MultiVectorEmbedding;
///
/// // Create a 3-token embedding with 128 dimensions per token
/// let embeddings = vec![0.0f32; 3 * 128];
/// let mv = MultiVectorEmbedding::new(embeddings, 3, 128);
///
/// assert_eq!(mv.num_tokens(), 3);
/// assert_eq!(mv.dim(), 128);
/// assert_eq!(mv.token(0).len(), 128);
/// ```
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorEmbedding {
    /// Flattened embeddings, row-major: length is `num_tokens * dim`.
    embeddings: Vec<f32>,
    /// Number of token embeddings (rows).
    num_tokens: usize,
    /// Dimension per token embedding (columns).
    dim: usize,
}

42impl MultiVectorEmbedding {
43    /// Create a new multi-vector embedding.
44    ///
45    /// # Panics
46    ///
47    /// Panics if `embeddings.len() != num_tokens * dim`.
48    #[must_use]
49    pub fn new(embeddings: Vec<f32>, num_tokens: usize, dim: usize) -> Self {
50        assert_eq!(
51            embeddings.len(),
52            num_tokens * dim,
53            "Embedding size mismatch: expected {} ({}×{}), got {}",
54            num_tokens * dim,
55            num_tokens,
56            dim,
57            embeddings.len()
58        );
59        Self { embeddings, num_tokens, dim }
60    }
61
62    /// Create from a vector of token embeddings.
63    #[must_use]
64    pub fn from_tokens(tokens: &[Vec<f32>]) -> Self {
65        if tokens.is_empty() {
66            return Self { embeddings: Vec::new(), num_tokens: 0, dim: 0 };
67        }
68
69        let dim = tokens[0].len();
70        let num_tokens = tokens.len();
71        let mut embeddings = Vec::with_capacity(num_tokens * dim);
72
73        for token in tokens {
74            assert_eq!(token.len(), dim, "All tokens must have the same dimension");
75            embeddings.extend_from_slice(token);
76        }
77
78        Self { embeddings, num_tokens, dim }
79    }
80
81    /// Get the number of token embeddings.
82    #[must_use]
83    pub fn num_tokens(&self) -> usize {
84        self.num_tokens
85    }
86
87    /// Get the dimension of each token embedding.
88    #[must_use]
89    pub fn dim(&self) -> usize {
90        self.dim
91    }
92
93    /// Get the i-th token embedding as a slice.
94    ///
95    /// # Panics
96    ///
97    /// Panics if `i >= num_tokens`.
98    #[must_use]
99    pub fn token(&self, i: usize) -> &[f32] {
100        assert!(i < self.num_tokens, "Token index out of bounds");
101        let start = i * self.dim;
102        &self.embeddings[start..start + self.dim]
103    }
104
105    /// Iterate over token embeddings.
106    pub fn tokens(&self) -> impl Iterator<Item = &[f32]> {
107        self.embeddings.chunks_exact(self.dim)
108    }
109
110    /// Get the raw flattened embeddings.
111    #[must_use]
112    pub fn as_slice(&self) -> &[f32] {
113        &self.embeddings
114    }
115
116    /// Get the raw flattened embeddings mutably.
117    pub fn as_mut_slice(&mut self) -> &mut [f32] {
118        &mut self.embeddings
119    }
120
121    /// Memory size in bytes (uncompressed).
122    #[must_use]
123    pub fn size_bytes(&self) -> usize {
124        self.embeddings.len() * size_of::<f32>()
125    }
126
127    /// Check if the embedding is empty (no tokens).
128    #[must_use]
129    pub fn is_empty(&self) -> bool {
130        self.num_tokens == 0
131    }
132}
133
/// Configuration for WARP index construction.
///
/// These parameters control the compression quality and index structure.
/// The default values provide a good balance of memory efficiency and
/// retrieval quality for most use cases.
///
/// # Parameter Guidance
///
/// | Corpus Size  | nbits | num_centroids |
/// |--------------|-------|---------------|
/// | < 100K docs  | 4     | 256           |
/// | 100K - 1M    | 2     | 1024          |
/// | > 1M docs    | 2     | 4096          |
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarpIndexConfig {
    /// Bits per dimension for residual quantization (2 or 4).
    ///
    /// - 2-bit: 16× compression, ~3-5% MRR loss
    /// - 4-bit: 8× compression, ~1-2% MRR loss
    ///
    /// Other values are rejected by [`WarpIndexConfig::validate`].
    pub nbits: u8,

    /// Number of centroids for IVF clustering.
    ///
    /// More centroids provide finer-grained partitioning but require
    /// more memory for centroid storage. Typical values: 256-4096.
    pub num_centroids: usize,

    /// Token embedding dimension (e.g., 128 for ColBERT).
    pub token_dim: usize,

    /// Minimum training samples for codec training.
    ///
    /// Should be at least 10 × num_centroids for stable clustering.
    /// If None, defaults to 10 × num_centroids.
    pub min_training_samples: Option<usize>,

    /// K-means iterations for centroid training.
    pub kmeans_iterations: usize,
}

174impl Default for WarpIndexConfig {
175    fn default() -> Self {
176        Self {
177            nbits: 2,
178            num_centroids: 1024,
179            token_dim: 128,
180            min_training_samples: None,
181            kmeans_iterations: 20,
182        }
183    }
184}
185
186impl WarpIndexConfig {
187    /// Create a new configuration with the specified parameters.
188    #[must_use]
189    pub fn new(nbits: u8, num_centroids: usize, token_dim: usize) -> Self {
190        Self { nbits, num_centroids, token_dim, ..Default::default() }
191    }
192
193    /// Set the minimum training samples.
194    #[must_use]
195    pub fn with_min_training_samples(mut self, samples: usize) -> Self {
196        self.min_training_samples = Some(samples);
197        self
198    }
199
200    /// Set the k-means iterations.
201    #[must_use]
202    pub fn with_kmeans_iterations(mut self, iterations: usize) -> Self {
203        self.kmeans_iterations = iterations;
204        self
205    }
206
207    /// Get the effective minimum training samples.
208    #[must_use]
209    pub fn effective_min_training_samples(&self) -> usize {
210        self.min_training_samples.unwrap_or(10 * self.num_centroids)
211    }
212
213    /// Calculate packed residual size in bytes.
214    #[must_use]
215    pub fn packed_residual_size(&self) -> usize {
216        (self.token_dim * self.nbits as usize + 7) / 8
217    }
218
219    /// Validate the configuration.
220    pub fn validate(&self) -> Result<(), &'static str> {
221        if self.nbits != 2 && self.nbits != 4 {
222            return Err("nbits must be 2 or 4");
223        }
224        if self.num_centroids == 0 {
225            return Err("num_centroids must be > 0");
226        }
227        if self.token_dim == 0 {
228            return Err("token_dim must be > 0");
229        }
230        if self.kmeans_iterations == 0 {
231            return Err("kmeans_iterations must be > 0");
232        }
233        Ok(())
234    }
235}
236
/// Configuration for WARP search.
///
/// These parameters control the trade-off between search speed and
/// recall quality. The defaults are tuned for high recall (>95%).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WarpSearchConfig {
    /// Number of results to return (top-k).
    pub k: usize,

    /// Centroids to probe per query token.
    ///
    /// Higher values increase recall but also latency.
    /// Default: 4 (provides ~95% recall on most datasets).
    pub nprobe: u32,

    /// Maximum total centroids examined across all tokens.
    ///
    /// Acts as an upper bound on computation. Default: 128.
    pub bound: usize,

    /// Early termination: skip tokens after this many.
    ///
    /// For very long queries, processing all tokens may be wasteful.
    /// Setting this limits which tokens contribute to scoring.
    /// `None` (the default) means all tokens are processed.
    pub t_prime: Option<usize>,

    /// Skip tokens with centroid score below threshold.
    ///
    /// Tokens that don't match any centroid well are unlikely to
    /// contribute meaningful scores. Default: 0.4.
    pub centroid_score_threshold: f32,
}

270impl Default for WarpSearchConfig {
271    fn default() -> Self {
272        Self { k: 10, nprobe: 4, bound: 128, t_prime: None, centroid_score_threshold: 0.4 }
273    }
274}
275
276impl WarpSearchConfig {
277    /// Create a search config with the specified k.
278    #[must_use]
279    pub fn with_k(k: usize) -> Self {
280        Self { k, ..Default::default() }
281    }
282
283    /// Set nprobe (centroids per token).
284    #[must_use]
285    pub fn nprobe(mut self, nprobe: u32) -> Self {
286        self.nprobe = nprobe;
287        self
288    }
289
290    /// Set the centroid bound.
291    #[must_use]
292    pub fn bound(mut self, bound: usize) -> Self {
293        self.bound = bound;
294        self
295    }
296
297    /// Set early termination threshold.
298    #[must_use]
299    pub fn t_prime(mut self, t_prime: usize) -> Self {
300        self.t_prime = Some(t_prime);
301        self
302    }
303
304    /// Set centroid score threshold.
305    #[must_use]
306    pub fn centroid_score_threshold(mut self, threshold: f32) -> Self {
307        self.centroid_score_threshold = threshold;
308        self
309    }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    // ============ MultiVectorEmbedding Tests ============

    #[test]
    fn test_multivector_new() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 3);

        assert_eq!(mv.num_tokens(), 2);
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.token(0), &[1.0, 2.0, 3.0]);
        assert_eq!(mv.token(1), &[4.0, 5.0, 6.0]);
    }

    #[test]
    #[should_panic(expected = "Embedding size mismatch")]
    fn test_multivector_size_mismatch() {
        let embeddings = vec![1.0, 2.0, 3.0];
        let _ = MultiVectorEmbedding::new(embeddings, 2, 3); // Should panic
    }

    #[test]
    fn test_multivector_from_tokens() {
        let tokens = vec![vec![1.0, 2.0], vec![3.0, 4.0], vec![5.0, 6.0]];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 3);
        assert_eq!(mv.dim(), 2);
    }

    // Ragged input must be rejected, not silently flattened.
    #[test]
    #[should_panic(expected = "All tokens must have the same dimension")]
    fn test_multivector_from_tokens_ragged() {
        let tokens = vec![vec![1.0, 2.0], vec![3.0]];
        let _ = MultiVectorEmbedding::from_tokens(&tokens); // Should panic
    }

    #[test]
    fn test_multivector_from_tokens_empty() {
        let tokens: Vec<Vec<f32>> = vec![];
        let mv = MultiVectorEmbedding::from_tokens(&tokens);

        assert_eq!(mv.num_tokens(), 0);
        assert!(mv.is_empty());
    }

    #[test]
    fn test_multivector_tokens_iterator() {
        let embeddings = vec![1.0, 2.0, 3.0, 4.0];
        let mv = MultiVectorEmbedding::new(embeddings, 2, 2);

        let tokens: Vec<&[f32]> = mv.tokens().collect();
        assert_eq!(tokens.len(), 2);
        assert_eq!(tokens[0], &[1.0, 2.0]);
        assert_eq!(tokens[1], &[3.0, 4.0]);
    }

    // Out-of-range token access must panic rather than slice garbage.
    #[test]
    #[should_panic(expected = "Token index out of bounds")]
    fn test_multivector_token_out_of_bounds() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0], 1, 2);
        let _ = mv.token(1); // Should panic: only token 0 exists
    }

    #[test]
    fn test_multivector_size_bytes() {
        let embeddings = vec![0.0; 100];
        let mv = MultiVectorEmbedding::new(embeddings, 10, 10);

        assert_eq!(mv.size_bytes(), 100 * 4); // 100 f32s × 4 bytes
    }

    #[test]
    fn test_multivector_as_slice() {
        let embeddings = vec![1.0, 2.0, 3.0];
        let mv = MultiVectorEmbedding::new(embeddings.clone(), 1, 3);

        assert_eq!(mv.as_slice(), &[1.0, 2.0, 3.0]);
    }

    // Mutation through as_mut_slice must be visible via the read APIs.
    #[test]
    fn test_multivector_as_mut_slice() {
        let mut mv = MultiVectorEmbedding::new(vec![1.0, 2.0], 1, 2);
        mv.as_mut_slice()[1] = 9.0;

        assert_eq!(mv.token(0), &[1.0, 9.0]);
    }

    #[test]
    fn test_multivector_serialization() {
        let mv = MultiVectorEmbedding::new(vec![1.0, 2.0, 3.0, 4.0], 2, 2);
        let json = serde_json::to_string(&mv).unwrap();
        let deserialized: MultiVectorEmbedding = serde_json::from_str(&json).unwrap();

        assert_eq!(mv.num_tokens(), deserialized.num_tokens());
        assert_eq!(mv.dim(), deserialized.dim());
        assert_eq!(mv.as_slice(), deserialized.as_slice());
    }

    // ============ WarpIndexConfig Tests ============

    #[test]
    fn test_index_config_default() {
        let config = WarpIndexConfig::default();

        assert_eq!(config.nbits, 2);
        assert_eq!(config.num_centroids, 1024);
        assert_eq!(config.token_dim, 128);
        assert_eq!(config.kmeans_iterations, 20);
    }

    #[test]
    fn test_index_config_new() {
        let config = WarpIndexConfig::new(4, 256, 64);

        assert_eq!(config.nbits, 4);
        assert_eq!(config.num_centroids, 256);
        assert_eq!(config.token_dim, 64);
    }

    #[test]
    fn test_index_config_builders() {
        let config = WarpIndexConfig::new(2, 512, 128)
            .with_min_training_samples(5000)
            .with_kmeans_iterations(30);

        assert_eq!(config.min_training_samples, Some(5000));
        assert_eq!(config.kmeans_iterations, 30);
    }

    #[test]
    fn test_index_config_effective_min_samples() {
        let config = WarpIndexConfig::new(2, 100, 128);
        assert_eq!(config.effective_min_training_samples(), 1000); // 10 × 100

        let config = config.with_min_training_samples(500);
        assert_eq!(config.effective_min_training_samples(), 500);
    }

    #[test]
    fn test_index_config_packed_size() {
        // 128 dims × 2 bits = 256 bits = 32 bytes
        let config = WarpIndexConfig::new(2, 1024, 128);
        assert_eq!(config.packed_residual_size(), 32);

        // 128 dims × 4 bits = 512 bits = 64 bytes
        let config = WarpIndexConfig::new(4, 1024, 128);
        assert_eq!(config.packed_residual_size(), 64);

        // Non-multiple of 8 rounds up: 3 dims × 2 bits = 6 bits → 1 byte
        let config = WarpIndexConfig::new(2, 1024, 3);
        assert_eq!(config.packed_residual_size(), 1);
    }

    // Exercise every rejection branch of validate(), not just the first two.
    #[test]
    fn test_index_config_validate() {
        let config = WarpIndexConfig::default();
        assert!(config.validate().is_ok());

        let bad_nbits = WarpIndexConfig { nbits: 3, ..Default::default() };
        assert!(bad_nbits.validate().is_err());

        let bad_centroids = WarpIndexConfig { num_centroids: 0, ..Default::default() };
        assert!(bad_centroids.validate().is_err());

        let bad_dim = WarpIndexConfig { token_dim: 0, ..Default::default() };
        assert!(bad_dim.validate().is_err());

        let bad_iters = WarpIndexConfig { kmeans_iterations: 0, ..Default::default() };
        assert!(bad_iters.validate().is_err());
    }

    #[test]
    fn test_index_config_serialization() {
        let config = WarpIndexConfig::new(4, 512, 64);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpIndexConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.nbits, deserialized.nbits);
        assert_eq!(config.num_centroids, deserialized.num_centroids);
        assert_eq!(config.token_dim, deserialized.token_dim);
    }

    // ============ WarpSearchConfig Tests ============

    #[test]
    fn test_search_config_default() {
        let config = WarpSearchConfig::default();

        assert_eq!(config.k, 10);
        assert_eq!(config.nprobe, 4);
        assert_eq!(config.bound, 128);
        assert!(config.t_prime.is_none());
        assert!((config.centroid_score_threshold - 0.4).abs() < 0.001);
    }

    #[test]
    fn test_search_config_with_k() {
        let config = WarpSearchConfig::with_k(20);
        assert_eq!(config.k, 20);
    }

    #[test]
    fn test_search_config_builders() {
        let config = WarpSearchConfig::with_k(5)
            .nprobe(8)
            .bound(256)
            .t_prime(10)
            .centroid_score_threshold(0.5);

        assert_eq!(config.k, 5);
        assert_eq!(config.nprobe, 8);
        assert_eq!(config.bound, 256);
        assert_eq!(config.t_prime, Some(10));
        assert!((config.centroid_score_threshold - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_search_config_serialization() {
        let config = WarpSearchConfig::with_k(15).nprobe(6);
        let json = serde_json::to_string(&config).unwrap();
        let deserialized: WarpSearchConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(config.k, deserialized.k);
        assert_eq!(config.nprobe, deserialized.nprobe);
    }

    // ============ Property-Based Tests ============

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_multivector_tokens_count_matches(
            num_tokens in 1usize..20,
            dim in 1usize..64
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            prop_assert_eq!(mv.num_tokens(), num_tokens);
            prop_assert_eq!(mv.dim(), dim);
            prop_assert_eq!(mv.tokens().count(), num_tokens);
        }

        #[test]
        fn prop_multivector_token_slices_correct_size(
            num_tokens in 1usize..10,
            dim in 1usize..32
        ) {
            let embeddings = vec![0.0f32; num_tokens * dim];
            let mv = MultiVectorEmbedding::new(embeddings, num_tokens, dim);

            for i in 0..num_tokens {
                prop_assert_eq!(mv.token(i).len(), dim);
            }
        }

        #[test]
        fn prop_index_config_packed_size_formula(
            nbits in prop::sample::select(vec![2u8, 4]),
            dim in 1usize..256
        ) {
            let config = WarpIndexConfig::new(nbits, 1024, dim);
            let expected = (dim * nbits as usize + 7) / 8;
            prop_assert_eq!(config.packed_residual_size(), expected);
        }
    }
}