Skip to main content

trueno_rag/multivector/
index.rs

//! WARP index with IVF structure
//!
//! This module implements the WARP index, which organizes compressed token
//! embeddings by centroid for cache-efficient search. The index supports:
//!
//! - Training a residual codec from sample embeddings
//! - Incremental insertion of documents (before the index is built)
//! - Building (compacting) the index for efficient search
//! - MaxSim-based multi-vector search

11use crate::multivector::{
12    codec::ResidualCodec,
13    search::{CandidateScorer, CentroidSelector, ScoreMerger},
14    types::{MultiVectorEmbedding, WarpIndexConfig, WarpSearchConfig},
15};
16use crate::{Chunk, ChunkId, Result};
17use serde::{Deserialize, Serialize};
18use std::collections::HashMap;
19
20/// WARP index for efficient multi-vector retrieval.
21///
22/// The index organizes token embeddings by centroid assignment (IVF structure)
23/// for cache-efficient access during search. Each token embedding is stored as:
24/// - Centroid assignment
25/// - Quantized residual (2-4 bits per dimension)
26///
27/// # Lifecycle
28///
29/// 1. Create index with `new(config)`
30/// 2. Train codec with `train(samples)`
31/// 3. Insert documents with `insert(chunk, embedding)`
32/// 4. Build index with `build()` (compacts for efficient search)
33/// 5. Search with `search(query, config)`
34///
35/// # Memory Layout
36///
37/// After `build()`, data is organized by centroid:
38/// ```text
39/// Centroid 0: [chunk_ids...] [token_indices...] [residuals...]
40/// Centroid 1: [chunk_ids...] [token_indices...] [residuals...]
41/// ...
42/// ```
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct WarpIndex {
45    /// Index configuration
46    config: WarpIndexConfig,
47    /// Trained residual codec (None until trained)
48    codec: Option<ResidualCodec>,
49    /// Number of embeddings per centroid
50    sizes: Vec<usize>,
51    /// Cumulative offset for each centroid's data
52    offsets: Vec<usize>,
53    /// Chunk IDs, sorted by centroid assignment
54    chunk_ids: Vec<ChunkId>,
55    /// Token indices within each chunk
56    token_indices: Vec<u16>,
57    /// Packed residuals, sorted by centroid
58    residuals: Vec<u8>,
59    /// Original chunks for result retrieval
60    #[serde(skip)]
61    chunks: HashMap<ChunkId, Chunk>,
62    /// Pending embeddings before build
63    #[serde(skip)]
64    pending: Vec<(ChunkId, MultiVectorEmbedding)>,
65    /// Whether the index has been built
66    is_built: bool,
67}
68
69impl WarpIndex {
70    /// Create a new WARP index with the given configuration.
71    #[must_use]
72    pub fn new(config: WarpIndexConfig) -> Self {
73        Self {
74            config,
75            codec: None,
76            sizes: Vec::new(),
77            offsets: Vec::new(),
78            chunk_ids: Vec::new(),
79            token_indices: Vec::new(),
80            residuals: Vec::new(),
81            chunks: HashMap::new(),
82            pending: Vec::new(),
83            is_built: false,
84        }
85    }
86
87    /// Get the index configuration.
88    #[must_use]
89    pub fn config(&self) -> &WarpIndexConfig {
90        &self.config
91    }
92
93    /// Get the trained codec (if any).
94    #[must_use]
95    pub fn codec(&self) -> Option<&ResidualCodec> {
96        self.codec.as_ref()
97    }
98
99    /// Check if the codec has been trained.
100    #[must_use]
101    pub fn is_trained(&self) -> bool {
102        self.codec.is_some()
103    }
104
105    /// Check if the index has been built.
106    #[must_use]
107    pub fn is_built(&self) -> bool {
108        self.is_built
109    }
110
111    /// Get the number of indexed chunks.
112    #[must_use]
113    pub fn num_chunks(&self) -> usize {
114        self.chunks.len()
115    }
116
117    /// Get the number of indexed tokens.
118    #[must_use]
119    pub fn num_tokens(&self) -> usize {
120        self.chunk_ids.len()
121    }
122
123    /// Check if the index is empty.
124    #[must_use]
125    pub fn is_empty(&self) -> bool {
126        self.chunks.is_empty()
127    }
128
129    /// Get a chunk by ID.
130    #[must_use]
131    pub fn get_chunk(&self, id: &ChunkId) -> Option<&Chunk> {
132        self.chunks.get(id)
133    }
134
135    /// Get memory usage in bytes (approximate).
136    #[must_use]
137    pub fn memory_usage(&self) -> usize {
138        let codec_size = self
139            .codec
140            .as_ref()
141            .map(|c| {
142                c.centroids().len() * 4 // centroids
143                    + c.dim() * ((1 << c.nbits()) - 1) * 4 // cutoffs
144                    + c.dim() * (1 << c.nbits()) * 4 // weights
145            })
146            .unwrap_or(0);
147
148        let index_size = self.chunk_ids.len() * size_of::<ChunkId>()
149            + self.token_indices.len() * size_of::<u16>()
150            + self.residuals.len()
151            + self.sizes.len() * size_of::<usize>()
152            + self.offsets.len() * size_of::<usize>();
153
154        codec_size + index_size
155    }
156
157    /// Train the codec from sample embeddings.
158    ///
159    /// # Arguments
160    ///
161    /// * `samples` - Sample multi-vector embeddings for training
162    ///
163    /// # Errors
164    ///
165    /// Returns an error if:
166    /// - Not enough samples for training
167    /// - Configuration is invalid
168    pub fn train(&mut self, samples: &[MultiVectorEmbedding]) -> Result<()> {
169        // Collect all token embeddings
170        let total_tokens: usize = samples.iter().map(|s| s.num_tokens()).sum();
171        let min_samples = self.config.effective_min_training_samples();
172
173        if total_tokens < min_samples {
174            return Err(crate::Error::InvalidInput(format!(
175                "Insufficient training tokens: {total_tokens} < {min_samples} required"
176            )));
177        }
178
179        // Flatten all embeddings
180        let mut all_embeddings = Vec::with_capacity(total_tokens * self.config.token_dim);
181        for sample in samples {
182            all_embeddings.extend_from_slice(sample.as_slice());
183        }
184
185        // Train codec
186        let codec = ResidualCodec::train(
187            &all_embeddings,
188            self.config.token_dim,
189            self.config.num_centroids,
190            self.config.nbits,
191            self.config.kmeans_iterations,
192        )?;
193
194        self.codec = Some(codec);
195        Ok(())
196    }
197
198    /// Insert a chunk with its token embeddings.
199    ///
200    /// The chunk will be stored in pending state until `build()` is called.
201    ///
202    /// # Errors
203    ///
204    /// Returns an error if:
205    /// - Codec has not been trained
206    /// - Index has already been built (call `rebuild()` first)
207    pub fn insert(&mut self, chunk: Chunk, embedding: MultiVectorEmbedding) -> Result<()> {
208        if self.codec.is_none() {
209            return Err(crate::Error::InvalidInput(
210                "Codec not trained - call train() first".to_string(),
211            ));
212        }
213
214        if self.is_built {
215            return Err(crate::Error::InvalidInput(
216                "Index already built - cannot insert".to_string(),
217            ));
218        }
219
220        let chunk_id = chunk.id;
221        self.chunks.insert(chunk_id, chunk);
222        self.pending.push((chunk_id, embedding));
223
224        Ok(())
225    }
226
227    /// Build the index for efficient search.
228    ///
229    /// This compacts all pending embeddings into a centroid-organized
230    /// IVF structure optimized for cache-efficient search.
231    ///
232    /// # Errors
233    ///
234    /// Returns an error if the codec has not been trained.
235    pub fn build(&mut self) -> Result<()> {
236        let codec = self.codec.as_ref().ok_or_else(|| {
237            crate::Error::InvalidInput("Codec not trained - call train() first".to_string())
238        })?;
239
240        // Assign each token to its nearest centroid
241        let mut centroid_assignments: Vec<Vec<(ChunkId, u16, Vec<u8>)>> =
242            vec![Vec::new(); self.config.num_centroids];
243
244        for (chunk_id, embedding) in &self.pending {
245            for (token_idx, token) in embedding.tokens().enumerate() {
246                let (centroid_id, residual) = codec.compress(token);
247                centroid_assignments[centroid_id].push((*chunk_id, token_idx as u16, residual));
248            }
249        }
250
251        // Build compacted arrays
252        let bytes_per_residual = self.config.packed_residual_size();
253
254        self.sizes = centroid_assignments.iter().map(|v| v.len()).collect();
255        self.offsets = self
256            .sizes
257            .iter()
258            .scan(0, |acc, &size| {
259                let offset = *acc;
260                *acc += size;
261                Some(offset)
262            })
263            .collect();
264
265        let total_tokens: usize = self.sizes.iter().sum();
266        self.chunk_ids = Vec::with_capacity(total_tokens);
267        self.token_indices = Vec::with_capacity(total_tokens);
268        self.residuals = Vec::with_capacity(total_tokens * bytes_per_residual);
269
270        for assignments in centroid_assignments {
271            for (chunk_id, token_idx, residual) in assignments {
272                self.chunk_ids.push(chunk_id);
273                self.token_indices.push(token_idx);
274                self.residuals.extend(residual);
275            }
276        }
277
278        self.pending.clear();
279        self.is_built = true;
280
281        Ok(())
282    }
283
284    /// Clear the built index to allow new insertions.
285    ///
286    /// Chunks are preserved, but the IVF structure is cleared.
287    /// Call `build()` again after inserting new chunks.
288    pub fn clear_index(&mut self) {
289        self.sizes.clear();
290        self.offsets.clear();
291        self.chunk_ids.clear();
292        self.token_indices.clear();
293        self.residuals.clear();
294        self.is_built = false;
295    }
296
297    /// Search for relevant chunks using MaxSim scoring.
298    ///
299    /// # Arguments
300    ///
301    /// * `query` - Query multi-vector embedding
302    /// * `search_config` - Search parameters
303    ///
304    /// # Returns
305    ///
306    /// Vector of (ChunkId, score) pairs sorted by score descending.
307    ///
308    /// # Errors
309    ///
310    /// Returns an error if the index has not been built.
311    pub fn search(
312        &self,
313        query: &MultiVectorEmbedding,
314        search_config: &WarpSearchConfig,
315    ) -> Result<Vec<(ChunkId, f32)>> {
316        let codec = self
317            .codec
318            .as_ref()
319            .ok_or_else(|| crate::Error::InvalidInput("Codec not trained".to_string()))?;
320
321        if !self.is_built {
322            return Err(crate::Error::InvalidInput(
323                "Index not built - call build() first".to_string(),
324            ));
325        }
326
327        // Phase 1: Select centroids per query token
328        let selected_centroids = CentroidSelector::select(
329            query,
330            codec.centroids(),
331            self.config.token_dim,
332            search_config,
333        );
334
335        // Apply bound: limit total centroids examined
336        let mut total_centroids = 0;
337        let max_tokens = search_config.t_prime.unwrap_or(usize::MAX);
338        let bounded_centroids: Vec<Vec<(usize, f32)>> = selected_centroids
339            .into_iter()
340            .take(max_tokens)
341            .map(|centroids| {
342                let take =
343                    (search_config.bound.saturating_sub(total_centroids)).min(centroids.len());
344                total_centroids += take;
345                centroids.into_iter().take(take).collect()
346            })
347            .collect();
348
349        // Phase 2: Score candidates from selected centroids
350        let bytes_per_residual = self.config.packed_residual_size();
351
352        let token_scores: Vec<Vec<(ChunkId, u16, f32)>> = bounded_centroids
353            .into_iter()
354            .enumerate()
355            .map(|(query_token_idx, centroids)| {
356                let query_token = query.token(query_token_idx);
357
358                centroids
359                    .into_iter()
360                    .flat_map(|(centroid_id, centroid_score)| {
361                        CandidateScorer::score(
362                            query_token,
363                            centroid_id,
364                            centroid_score,
365                            codec,
366                            &self.sizes,
367                            &self.offsets,
368                            &self.chunk_ids,
369                            &self.token_indices,
370                            &self.residuals,
371                            bytes_per_residual,
372                        )
373                    })
374                    .collect()
375            })
376            .collect();
377
378        // Phase 3: Merge via MaxSim
379        Ok(ScoreMerger::merge(token_scores, search_config.k))
380    }
381
382    /// Get centroid size (number of tokens assigned).
383    #[must_use]
384    pub fn centroid_size(&self, centroid_id: usize) -> usize {
385        self.sizes.get(centroid_id).copied().unwrap_or(0)
386    }
387
388    /// Get centroid offset in the compacted arrays.
389    #[must_use]
390    pub fn centroid_offset(&self, centroid_id: usize) -> usize {
391        self.offsets.get(centroid_id).copied().unwrap_or(0)
392    }
393}
394
#[cfg(test)]
mod tests {
    use super::*;
    use crate::DocumentId;
    use proptest::prelude::*;

    /// Token dimension shared by every test index.
    const DIM: usize = 16;
    /// Centroid count shared by the small test indexes.
    const CENTROIDS: usize = 8;

    fn create_test_chunk(content: &str) -> Chunk {
        Chunk::new(DocumentId::new(), content.to_string(), 0, content.len())
    }

    /// Deterministic pseudo-random embedding with values uniform in `[-1, 1)`.
    fn generate_embedding(num_tokens: usize, dim: usize, seed: u64) -> MultiVectorEmbedding {
        let mut rng = seed;
        let embeddings: Vec<f32> = (0..num_tokens * dim)
            .map(|_| {
                rng = rng.wrapping_mul(6364136223846793005).wrapping_add(1);
                // The top 24 bits of the LCG state map exactly onto an f32 mantissa,
                // giving a uniform value in [0, 1) that is then rescaled to [-1, 1).
                let unit = (rng >> 40) as f32 / (1u64 << 24) as f32;
                unit * 2.0 - 1.0
            })
            .collect();

        MultiVectorEmbedding::new(embeddings, num_tokens, dim)
    }

    /// A small 2-bit index (8 centroids, dim 16) trained on 100 x 10 tokens.
    fn trained_index() -> WarpIndex {
        let config = WarpIndexConfig::new(2, CENTROIDS, DIM).with_kmeans_iterations(3);
        let mut index = WarpIndex::new(config);
        let samples: Vec<_> = (0..100).map(|i| generate_embedding(10, DIM, i)).collect();
        index.train(&samples).unwrap();
        index
    }

    /// Insert `count` five-token chunks whose embedding seeds start at `seed_base`.
    fn insert_docs(index: &mut WarpIndex, count: u64, seed_base: u64) {
        for i in 0..count {
            let chunk = create_test_chunk(&format!("doc {i}"));
            let embedding = generate_embedding(5, DIM, seed_base + i);
            index.insert(chunk, embedding).unwrap();
        }
    }

    fn is_sorted_descending(results: &[(ChunkId, f32)]) -> bool {
        results.windows(2).all(|pair| pair[0].1 >= pair[1].1)
    }

    // ============ Basic Index Tests ============

    #[test]
    fn test_index_new() {
        let index = WarpIndex::new(WarpIndexConfig::new(2, 16, 32));

        assert!(!index.is_trained());
        assert!(!index.is_built());
        assert!(index.is_empty());
    }

    #[test]
    fn test_index_config() {
        let index = WarpIndex::new(WarpIndexConfig::new(4, 32, 64));

        assert_eq!(index.config().nbits, 4);
        assert_eq!(index.config().num_centroids, 32);
        assert_eq!(index.config().token_dim, 64);
    }

    // ============ Training Tests ============

    #[test]
    fn test_index_train() {
        let index = trained_index();

        assert!(index.is_trained());
        assert!(index.codec().is_some());
    }

    #[test]
    fn test_index_train_insufficient_samples() {
        // 100 centroids need far more than the 5 x 10 = 50 tokens supplied here.
        let mut index = WarpIndex::new(WarpIndexConfig::new(2, 100, DIM));
        let samples: Vec<_> = (0..5).map(|i| generate_embedding(10, DIM, i)).collect();

        assert!(index.train(&samples).is_err());
    }

    // ============ Insert Tests ============

    #[test]
    fn test_index_insert() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test content");
        index.insert(chunk, generate_embedding(5, DIM, 999)).unwrap();

        assert_eq!(index.num_chunks(), 1);
    }

    #[test]
    fn test_index_insert_without_training() {
        let mut index = WarpIndex::new(WarpIndexConfig::new(2, CENTROIDS, DIM));

        let chunk = create_test_chunk("test");
        let result = index.insert(chunk, generate_embedding(5, DIM, 0));

        assert!(result.is_err());
    }

    // ============ Build Tests ============

    #[test]
    fn test_index_build() {
        let mut index = trained_index();
        insert_docs(&mut index, 10, 1000);

        index.build().unwrap();

        assert!(index.is_built());
        assert_eq!(index.num_chunks(), 10);
        assert_eq!(index.num_tokens(), 50); // 10 chunks x 5 tokens
    }

    #[test]
    fn test_index_cannot_insert_after_build() {
        let mut index = trained_index();
        index
            .insert(create_test_chunk("test"), generate_embedding(5, DIM, 0))
            .unwrap();
        index.build().unwrap();

        let result = index.insert(create_test_chunk("test2"), generate_embedding(5, DIM, 1));

        assert!(result.is_err());
    }

    // ============ Search Tests ============

    #[test]
    fn test_index_search() {
        let mut index = trained_index();
        insert_docs(&mut index, 20, 1000);
        index.build().unwrap();

        let query = generate_embedding(3, DIM, 9999);
        let results = index.search(&query, &WarpSearchConfig::with_k(5)).unwrap();

        assert!(results.len() <= 5);
        assert!(!results.is_empty());
        assert!(is_sorted_descending(&results));
    }

    #[test]
    fn test_index_search_without_build() {
        let index = trained_index();

        let query = generate_embedding(3, DIM, 0);
        let result = index.search(&query, &WarpSearchConfig::with_k(5));

        assert!(result.is_err());
    }

    // ============ Memory & Stats Tests ============

    #[test]
    fn test_index_memory_usage() {
        let mut index = trained_index();
        insert_docs(&mut index, 10, 1000);
        index.build().unwrap();

        assert!(index.memory_usage() > 0);
    }

    #[test]
    fn test_index_centroid_stats() {
        let mut index = trained_index();
        insert_docs(&mut index, 10, 1000);
        index.build().unwrap();

        // Total tokens across centroids should equal num_tokens.
        let total: usize = (0..CENTROIDS).map(|c| index.centroid_size(c)).sum();
        assert_eq!(total, index.num_tokens());
    }

    // ============ Clear & Rebuild Tests ============

    #[test]
    fn test_index_clear_and_rebuild() {
        let mut index = trained_index();
        index
            .insert(create_test_chunk("test"), generate_embedding(5, DIM, 0))
            .unwrap();
        index.build().unwrap();
        assert!(index.is_built());

        index.clear_index();

        assert!(!index.is_built());
        assert_eq!(index.num_tokens(), 0);
        // Chunks are preserved.
        assert_eq!(index.num_chunks(), 1);
    }

    #[test]
    fn test_index_reinsert_after_clear() {
        let mut index = trained_index();
        let chunk = create_test_chunk("test");
        let chunk_id = chunk.id;
        index.insert(chunk.clone(), generate_embedding(5, DIM, 0)).unwrap();
        index.build().unwrap();

        // Token data is dropped by clear_index, so the chunk must be re-inserted.
        index.clear_index();
        index.insert(chunk, generate_embedding(5, DIM, 0)).unwrap();
        index.build().unwrap();

        assert!(index.is_built());
        assert_eq!(index.num_chunks(), 1);
        assert_eq!(index.num_tokens(), 5);

        let query = generate_embedding(3, DIM, 0);
        let results = index.search(&query, &WarpSearchConfig::with_k(1)).unwrap();
        assert!(results.iter().all(|(id, _)| *id == chunk_id));
    }

    // ============ Get Chunk Tests ============

    #[test]
    fn test_index_get_chunk() {
        let mut index = trained_index();

        let chunk = create_test_chunk("test content");
        let chunk_id = chunk.id;
        index.insert(chunk, generate_embedding(5, DIM, 0)).unwrap();

        let retrieved = index.get_chunk(&chunk_id);
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "test content");
    }

    // ============ Property-Based Tests ============

    proptest! {
        #[test]
        fn prop_search_returns_at_most_k(k in 1usize..20) {
            let mut index = trained_index();
            insert_docs(&mut index, 30, 1000);
            index.build().unwrap();

            let query = generate_embedding(3, DIM, 9999);
            let results = index.search(&query, &WarpSearchConfig::with_k(k)).unwrap();

            prop_assert!(results.len() <= k);
        }

        #[test]
        fn prop_search_results_sorted_descending(seed in 0u64..1000) {
            let mut index = trained_index();
            insert_docs(&mut index, 20, seed);
            index.build().unwrap();

            let query = generate_embedding(3, DIM, seed + 1000);
            let results = index.search(&query, &WarpSearchConfig::with_k(10)).unwrap();

            prop_assert!(is_sorted_descending(&results));
        }
    }
}