Skip to main content

vecstore/
multi_vector.rs

1//! Multi-Vector Document Storage (ColBERT-style)
2//!
3//! This module supports documents with multiple embeddings per document,
4//! enabling late interaction models like ColBERT.
5//!
6//! ## Key Concepts
7//!
8//! - **Token-level embeddings**: Each token gets its own embedding
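//! - **MaxSim**: Relevance score = maximum similarity over all (query token, document token) pairs
//!   (ColBERT's sum of per-query-token maxima is provided by `colbert::ColBERTQuery::score`)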
10//! - **Late interaction**: Similarity computed at query time, not indexing time
11//!
12//! ## Architecture
13//!
14//! ```text
15//! Document: "machine learning"
16//!     │
17//!     ▼
18//! ┌─────────┬─────────┐
19//! │ machine │ learning│
20//! └────┬────┴────┬────┘
21//!      │         │
22//!   embed()   embed()
23//!      │         │
24//!      ▼         ▼
25//!   [0.1,…]  [0.2,…]
26//!
27//! Query: "deep learning"
//!   MaxSim = max of sim(q, d) for q in {deep, learning}, d in {machine, learning}
29//! ```
30//!
31//! ## Example
32//!
33//! ```no_run
//! use vecstore::multi_vector::{MultiVectorDoc, MultiVectorIndex};
35//!
36//! # fn main() -> anyhow::Result<()> {
37//! let mut index = MultiVectorIndex::new(128); // 128-dim embeddings
38//!
39//! // Add document with multiple token embeddings
40//! let doc = MultiVectorDoc::new(
41//!     "doc1",
42//!     vec![
43//!         vec![0.1; 128],  // "machine" embedding
44//!         vec![0.2; 128],  // "learning" embedding
45//!     ],
46//!     serde_json::json!({"title": "ML Guide"}),
47//! );
48//!
49//! index.add(doc)?;
50//!
51//! // Query with MaxSim aggregation
52//! let query_tokens = vec![vec![0.15; 128]];
53//! let results = index.search(&query_tokens, 10)?;
54//!
55//! println!("Found {} results", results.len());
56//! # Ok(())
57//! # }
58//! ```
59
60use anyhow::{anyhow, Result};
61use serde::{Deserialize, Serialize};
62use std::collections::HashMap;
63
64/// Multi-vector document
65#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct MultiVectorDoc {
67    /// Document ID
68    pub id: String,
69    /// Multiple embeddings (one per token/chunk)
70    pub vectors: Vec<Vec<f32>>,
71    /// Metadata
72    pub metadata: serde_json::Value,
73}
74
75impl MultiVectorDoc {
76    /// Create a new multi-vector document
77    pub fn new(id: impl Into<String>, vectors: Vec<Vec<f32>>, metadata: serde_json::Value) -> Self {
78        Self {
79            id: id.into(),
80            vectors,
81            metadata,
82        }
83    }
84
85    /// Get number of vectors
86    pub fn num_vectors(&self) -> usize {
87        self.vectors.len()
88    }
89
90    /// Get vector dimension
91    pub fn dimension(&self) -> usize {
92        self.vectors.first().map(|v| v.len()).unwrap_or(0)
93    }
94
95    /// Validate that all vectors have the same dimension
96    pub fn validate(&self) -> Result<()> {
97        if self.vectors.is_empty() {
98            return Err(anyhow!("Document has no vectors"));
99        }
100
101        let dim = self.dimension();
102        for (i, vec) in self.vectors.iter().enumerate() {
103            if vec.len() != dim {
104                return Err(anyhow!(
105                    "Vector {} has dimension {}, expected {}",
106                    i,
107                    vec.len(),
108                    dim
109                ));
110            }
111        }
112
113        Ok(())
114    }
115}
116
117/// Aggregation method for multi-vector scores
118#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
119pub enum AggregationMethod {
120    /// Maximum similarity (ColBERT)
121    MaxSim,
122    /// Average similarity
123    AvgSim,
124    /// Sum of similarities
125    SumSim,
126    /// First token only
127    FirstToken,
128}
129
130/// Multi-vector index
131pub struct MultiVectorIndex {
132    /// Expected vector dimension
133    dimension: usize,
134    /// Documents indexed by ID
135    documents: HashMap<String, MultiVectorDoc>,
136    /// Flattened token index for fast retrieval
137    /// Maps flat token ID -> (doc_id, token_index)
138    token_index: Vec<(String, usize)>,
139    /// All token vectors (flattened)
140    token_vectors: Vec<Vec<f32>>,
141    /// Aggregation method
142    aggregation: AggregationMethod,
143}
144
145impl MultiVectorIndex {
146    /// Create a new multi-vector index
147    pub fn new(dimension: usize) -> Self {
148        Self {
149            dimension,
150            documents: HashMap::new(),
151            token_index: Vec::new(),
152            token_vectors: Vec::new(),
153            aggregation: AggregationMethod::MaxSim,
154        }
155    }
156
157    /// Set aggregation method
158    pub fn with_aggregation(mut self, aggregation: AggregationMethod) -> Self {
159        self.aggregation = aggregation;
160        self
161    }
162
163    /// Add a document
164    pub fn add(&mut self, doc: MultiVectorDoc) -> Result<()> {
165        doc.validate()?;
166
167        if doc.dimension() != self.dimension {
168            return Err(anyhow!(
169                "Document dimension {} doesn't match index dimension {}",
170                doc.dimension(),
171                self.dimension
172            ));
173        }
174
175        let doc_id = doc.id.clone();
176
177        // Add all token vectors to flat index
178        for (token_idx, vector) in doc.vectors.iter().enumerate() {
179            self.token_index.push((doc_id.clone(), token_idx));
180            self.token_vectors.push(vector.clone());
181        }
182
183        self.documents.insert(doc_id, doc);
184
185        Ok(())
186    }
187
188    /// Search using multi-vector query
189    pub fn search(&self, query_vectors: &[Vec<f32>], k: usize) -> Result<Vec<(String, f32)>> {
190        if query_vectors.is_empty() {
191            return Err(anyhow!("Query has no vectors"));
192        }
193
194        // Validate query dimensions
195        for qv in query_vectors {
196            if qv.len() != self.dimension {
197                return Err(anyhow!(
198                    "Query dimension {} doesn't match index dimension {}",
199                    qv.len(),
200                    self.dimension
201                ));
202            }
203        }
204
205        // Compute scores for each document
206        let mut doc_scores: HashMap<String, Vec<f32>> = HashMap::new();
207
208        // For each query vector
209        for query_vec in query_vectors {
210            // Compute similarity with all document tokens
211            for (token_id, (doc_id, _token_idx)) in self.token_index.iter().enumerate() {
212                let token_vec = &self.token_vectors[token_id];
213                let sim = cosine_similarity(query_vec, token_vec);
214
215                doc_scores
216                    .entry(doc_id.clone())
217                    .or_insert_with(Vec::new)
218                    .push(sim);
219            }
220        }
221
222        // Aggregate scores per document
223        let mut results: Vec<(String, f32)> = doc_scores
224            .into_iter()
225            .map(|(doc_id, sims)| {
226                let score = match self.aggregation {
227                    AggregationMethod::MaxSim => {
228                        sims.iter().copied().fold(f32::NEG_INFINITY, f32::max)
229                    }
230                    AggregationMethod::AvgSim => sims.iter().sum::<f32>() / sims.len() as f32,
231                    AggregationMethod::SumSim => sims.iter().sum(),
232                    AggregationMethod::FirstToken => sims.first().copied().unwrap_or(0.0),
233                };
234                (doc_id, score)
235            })
236            .collect();
237
238        // Sort by score descending
239        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
240        results.truncate(k);
241
242        Ok(results)
243    }
244
245    /// Get a document by ID
246    pub fn get(&self, doc_id: &str) -> Option<&MultiVectorDoc> {
247        self.documents.get(doc_id)
248    }
249
250    /// Get number of documents
251    pub fn num_documents(&self) -> usize {
252        self.documents.len()
253    }
254
255    /// Get total number of token vectors
256    pub fn num_tokens(&self) -> usize {
257        self.token_vectors.len()
258    }
259
260    /// Get index statistics
261    pub fn stats(&self) -> MultiVectorStats {
262        let avg_tokens_per_doc = if !self.documents.is_empty() {
263            self.num_tokens() as f32 / self.num_documents() as f32
264        } else {
265            0.0
266        };
267
268        MultiVectorStats {
269            num_documents: self.num_documents(),
270            num_tokens: self.num_tokens(),
271            dimension: self.dimension,
272            avg_tokens_per_doc,
273            aggregation: self.aggregation,
274        }
275    }
276}
277
278/// Index statistics
279#[derive(Debug, Clone)]
280pub struct MultiVectorStats {
281    pub num_documents: usize,
282    pub num_tokens: usize,
283    pub dimension: usize,
284    pub avg_tokens_per_doc: f32,
285    pub aggregation: AggregationMethod,
286}
287
/// Cosine similarity of two equal-length vectors.
///
/// Returns `0.0` instead of dividing by zero when either vector has zero
/// magnitude.
///
/// # Panics
///
/// Panics if `a` and `b` have different lengths.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    let magnitude = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (mag_a, mag_b) = (magnitude(a), magnitude(b));

    if mag_a == 0.0 || mag_b == 0.0 {
        return 0.0;
    }

    let dot = a.iter().zip(b).map(|(x, y)| x * y).sum::<f32>();
    dot / (mag_a * mag_b)
}
302
303/// ColBERT-specific utilities
304pub mod colbert {
305    use super::*;
306
307    /// ColBERT query encoder (wraps multi-vector with MaxSim)
308    pub struct ColBERTQuery {
309        /// Query token embeddings
310        pub tokens: Vec<Vec<f32>>,
311    }
312
313    impl ColBERTQuery {
314        /// Create a new ColBERT query
315        pub fn new(tokens: Vec<Vec<f32>>) -> Self {
316            Self { tokens }
317        }
318
319        /// Compute MaxSim score against a document
320        pub fn score(&self, doc: &MultiVectorDoc) -> f32 {
321            if self.tokens.is_empty() || doc.vectors.is_empty() {
322                return 0.0;
323            }
324
325            let mut total_score = 0.0;
326
327            // For each query token, find max similarity with any doc token
328            for query_token in &self.tokens {
329                let max_sim = doc
330                    .vectors
331                    .iter()
332                    .map(|doc_token| cosine_similarity(query_token, doc_token))
333                    .fold(f32::NEG_INFINITY, f32::max);
334
335                total_score += max_sim;
336            }
337
338            total_score
339        }
340    }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multi_vector_doc_creation() {
        let doc = MultiVectorDoc::new(
            "doc1",
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            serde_json::json!({}),
        );

        assert_eq!(doc.id, "doc1");
        assert_eq!(doc.num_vectors(), 2);
        assert_eq!(doc.dimension(), 2);
    }

    #[test]
    fn test_doc_validation() {
        let valid_doc = MultiVectorDoc::new(
            "doc1",
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            serde_json::json!({}),
        );
        assert!(valid_doc.validate().is_ok());

        let invalid_doc = MultiVectorDoc::new(
            "doc2",
            vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0]], // Different dimensions
            serde_json::json!({}),
        );
        assert!(invalid_doc.validate().is_err());
    }

    #[test]
    fn test_index_add_and_get() {
        let mut index = MultiVectorIndex::new(2);

        let doc = MultiVectorDoc::new(
            "doc1",
            vec![vec![1.0, 2.0], vec![3.0, 4.0]],
            serde_json::json!({}),
        );

        assert!(index.add(doc.clone()).is_ok());
        assert_eq!(index.num_documents(), 1);
        assert_eq!(index.num_tokens(), 2);

        let retrieved = index.get("doc1").unwrap();
        assert_eq!(retrieved.id, "doc1");
    }

    #[test]
    fn test_add_rejects_invalid_documents() {
        let mut index = MultiVectorIndex::new(2);

        // Wrong dimension for this index
        let wrong_dim = MultiVectorDoc::new("doc1", vec![vec![1.0, 2.0, 3.0]], serde_json::json!({}));
        assert!(index.add(wrong_dim).is_err());

        // No vectors at all
        let empty = MultiVectorDoc::new("doc2", vec![], serde_json::json!({}));
        assert!(index.add(empty).is_err());

        // A failed add must not leave anything behind
        assert_eq!(index.num_documents(), 0);
        assert_eq!(index.num_tokens(), 0);
    }

    #[test]
    fn test_multi_vector_search_maxsim() {
        let mut index = MultiVectorIndex::new(2).with_aggregation(AggregationMethod::MaxSim);

        // Add documents
        let doc1 = MultiVectorDoc::new(
            "doc1",
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            serde_json::json!({}),
        );
        let doc2 = MultiVectorDoc::new(
            "doc2",
            vec![vec![0.5, 0.5], vec![0.5, 0.5]],
            serde_json::json!({}),
        );

        index.add(doc1).unwrap();
        index.add(doc2).unwrap();

        // Query
        let query = vec![vec![1.0, 0.0]];
        let results = index.search(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        // doc1 should rank higher (exact match with first token)
        assert_eq!(results[0].0, "doc1");
    }

    #[test]
    fn test_search_rejects_invalid_queries() {
        let index = MultiVectorIndex::new(2);

        let empty_query: Vec<Vec<f32>> = Vec::new();
        assert!(index.search(&empty_query, 5).is_err());

        let wrong_dim_query = vec![vec![1.0, 0.0, 0.0]];
        assert!(index.search(&wrong_dim_query, 5).is_err());
    }

    #[test]
    fn test_search_empty_index_and_truncation() {
        let mut index = MultiVectorIndex::new(2);
        let query = vec![vec![1.0, 0.0]];

        // Nothing indexed yet: a valid query simply finds nothing
        assert!(index.search(&query, 5).unwrap().is_empty());

        for id in ["a", "b", "c"].iter() {
            index
                .add(MultiVectorDoc::new(*id, vec![vec![1.0, 1.0]], serde_json::json!({})))
                .unwrap();
        }
        assert_eq!(index.search(&query, 2).unwrap().len(), 2);
        assert!(index.search(&query, 0).unwrap().is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        let a = vec![1.0, 0.0, 0.0];
        let b = vec![1.0, 0.0, 0.0];
        assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);

        let c = vec![1.0, 0.0, 0.0];
        let d = vec![0.0, 1.0, 0.0];
        assert!((cosine_similarity(&c, &d) - 0.0).abs() < 0.001);
    }

    #[test]
    fn test_colbert_query() {
        use colbert::*;

        let query = ColBERTQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);

        let doc = MultiVectorDoc::new(
            "doc1",
            vec![vec![1.0, 0.0], vec![0.5, 0.5]],
            serde_json::json!({}),
        );

        let score = query.score(&doc);
        assert!(score > 0.0);
        // Sum of per-query-token maxima: 1.0 (exact match) + cos([0,1],[0.5,0.5]) = 1/sqrt(2)
        assert!((score - (1.0 + std::f32::consts::FRAC_1_SQRT_2)).abs() < 1e-5);

        // An empty query scores 0.0
        assert_eq!(ColBERTQuery::new(Vec::new()).score(&doc), 0.0);
    }

    #[test]
    fn test_index_stats() {
        let mut index = MultiVectorIndex::new(128);

        let doc1 = MultiVectorDoc::new(
            "doc1",
            vec![vec![0.0; 128], vec![0.1; 128]],
            serde_json::json!({}),
        );
        let doc2 = MultiVectorDoc::new(
            "doc2",
            vec![vec![0.2; 128], vec![0.3; 128], vec![0.4; 128]],
            serde_json::json!({}),
        );

        index.add(doc1).unwrap();
        index.add(doc2).unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_documents, 2);
        assert_eq!(stats.num_tokens, 5); // 2 + 3
        assert_eq!(stats.dimension, 128);
        assert!((stats.avg_tokens_per_doc - 2.5).abs() < 0.01);
    }

    #[test]
    fn test_aggregation_methods() {
        let mut index = MultiVectorIndex::new(2);

        let doc = MultiVectorDoc::new(
            "doc1",
            vec![vec![1.0, 0.0], vec![0.0, 1.0]],
            serde_json::json!({}),
        );
        index.add(doc).unwrap();

        // Test MaxSim
        index.aggregation = AggregationMethod::MaxSim;
        let query = vec![vec![1.0, 0.0]];
        let results = index.search(&query, 1).unwrap();
        assert!(results[0].1 > 0.9); // Should be close to 1.0

        // Test AvgSim
        index.aggregation = AggregationMethod::AvgSim;
        let results = index.search(&query, 1).unwrap();
        assert!(results[0].1 > 0.0 && results[0].1 < 1.0); // Average of 1.0 and 0.0
    }

    #[test]
    fn test_sum_and_first_token_aggregation() {
        let mut index = MultiVectorIndex::new(2);
        index
            .add(MultiVectorDoc::new(
                "doc1",
                vec![vec![1.0, 0.0], vec![0.0, 1.0]],
                serde_json::json!({}),
            ))
            .unwrap();

        // Orthogonal to the first document token, identical to the second.
        let query = vec![vec![0.0, 1.0]];

        index.aggregation = AggregationMethod::SumSim;
        let results = index.search(&query, 1).unwrap();
        assert!((results[0].1 - 1.0).abs() < 1e-6); // 0.0 + 1.0

        index.aggregation = AggregationMethod::FirstToken;
        let results = index.search(&query, 1).unwrap();
        assert!(results[0].1.abs() < 1e-6); // similarity with the first token only
    }
}
497}