vectx_core/
multivector.rs

//! MultiVector support for ColBERT-style late-interaction retrieval.
//!
//! Implements MaxSim (maximum similarity) scoring as described in:
//! <https://arxiv.org/pdf/2112.01488.pdf>
5
6use serde::{Deserialize, Serialize};
7use crate::Vector;
8
9/// Configuration for multivector comparison
10#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
11pub enum MultiVectorComparator {
12    /// MaxSim: For each query vector, find max similarity with any document vector, then sum
13    /// This is the ColBERT algorithm for late interaction retrieval
14    #[default]
15    MaxSim,
16}
17
18/// Configuration for multivector storage and search
19#[derive(Debug, Clone, Default, Serialize, Deserialize)]
20pub struct MultiVectorConfig {
21    pub comparator: MultiVectorComparator,
22}
23
24/// A multivector - multiple dense vectors per point (ColBERT-style)
25/// Each sub-vector typically represents a token embedding
26#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct MultiVector {
28    /// The sub-vectors (each row is a dense vector)
29    vectors: Vec<Vec<f32>>,
30    /// Dimension of each sub-vector (all must be the same)
31    dim: usize,
32}
33
34impl MultiVector {
35    /// Create a new multivector from a list of sub-vectors
36    /// All sub-vectors must have the same dimension
37    pub fn new(vectors: Vec<Vec<f32>>) -> Result<Self, &'static str> {
38        if vectors.is_empty() {
39            return Err("MultiVector cannot be empty");
40        }
41        
42        let dim = vectors[0].len();
43        if dim == 0 {
44            return Err("Sub-vectors cannot be empty");
45        }
46        
47        // Verify all vectors have the same dimension
48        if !vectors.iter().all(|v| v.len() == dim) {
49            return Err("All sub-vectors must have the same dimension");
50        }
51        
52        Ok(Self { vectors, dim })
53    }
54    
55    /// Create from a single dense vector (wraps it as a multivector with one sub-vector)
56    pub fn from_single(vector: Vec<f32>) -> Result<Self, &'static str> {
57        if vector.is_empty() {
58            return Err("Vector cannot be empty");
59        }
60        let dim = vector.len();
61        Ok(Self { vectors: vec![vector], dim })
62    }
63    
64    /// Get the dimension of each sub-vector
65    #[inline]
66    pub fn dim(&self) -> usize {
67        self.dim
68    }
69    
70    /// Get the number of sub-vectors
71    #[inline]
72    pub fn len(&self) -> usize {
73        self.vectors.len()
74    }
75    
76    /// Check if empty
77    #[inline]
78    pub fn is_empty(&self) -> bool {
79        self.vectors.is_empty()
80    }
81    
82    /// Get a reference to the sub-vectors
83    #[inline]
84    pub fn vectors(&self) -> &[Vec<f32>] {
85        &self.vectors
86    }
87    
88    /// Get the first sub-vector (useful for backwards compatibility)
89    #[inline]
90    pub fn first(&self) -> Option<&Vec<f32>> {
91        self.vectors.first()
92    }
93    
94    /// Convert to a single Vector (using first sub-vector)
95    /// Used for backwards compatibility with non-multivector operations
96    pub fn to_single_vector(&self) -> Vector {
97        Vector::new(self.vectors[0].clone())
98    }
99    
100    /// Compute MaxSim score between two multivectors
101    /// 
102    /// For each sub-vector in `self` (query), find the maximum similarity 
103    /// with any sub-vector in `other` (document), then sum all maximums.
104    /// 
105    /// This is the ColBERT scoring algorithm.
106    pub fn max_sim(&self, other: &MultiVector) -> f32 {
107        if self.dim != other.dim {
108            return 0.0;
109        }
110        
111        let mut total_score = 0.0;
112        
113        // For each query sub-vector
114        for query_vec in &self.vectors {
115            let mut max_sim = f32::NEG_INFINITY;
116            
117            // Find max similarity with any document sub-vector
118            for doc_vec in &other.vectors {
119                let sim = dot_product(query_vec, doc_vec);
120                if sim > max_sim {
121                    max_sim = sim;
122                }
123            }
124            
125            // Only add if we found a valid similarity
126            if max_sim > f32::NEG_INFINITY {
127                total_score += max_sim;
128            }
129        }
130        
131        total_score
132    }
133    
134    /// Compute MaxSim with cosine similarity (normalized dot product)
135    pub fn max_sim_cosine(&self, other: &MultiVector) -> f32 {
136        if self.dim != other.dim {
137            return 0.0;
138        }
139        
140        let mut total_score = 0.0;
141        
142        for query_vec in &self.vectors {
143            let query_norm = norm(query_vec);
144            if query_norm < f32::EPSILON {
145                continue;
146            }
147            
148            let mut max_sim = f32::NEG_INFINITY;
149            
150            for doc_vec in &other.vectors {
151                let doc_norm = norm(doc_vec);
152                if doc_norm < f32::EPSILON {
153                    continue;
154                }
155                
156                let sim = dot_product(query_vec, doc_vec) / (query_norm * doc_norm);
157                if sim > max_sim {
158                    max_sim = sim;
159                }
160            }
161            
162            if max_sim > f32::NEG_INFINITY {
163                total_score += max_sim;
164            }
165        }
166        
167        total_score
168    }
169    
170    /// Compute MaxSim with negative L2 distance (for Euclidean)
171    pub fn max_sim_l2(&self, other: &MultiVector) -> f32 {
172        if self.dim != other.dim {
173            return f32::NEG_INFINITY;
174        }
175        
176        let mut total_score = 0.0;
177        
178        for query_vec in &self.vectors {
179            let mut min_dist = f32::INFINITY;
180            
181            for doc_vec in &other.vectors {
182                let dist = l2_distance(query_vec, doc_vec);
183                if dist < min_dist {
184                    min_dist = dist;
185                }
186            }
187            
188            if min_dist < f32::INFINITY {
189                // Negative because we want higher scores for closer vectors
190                total_score -= min_dist;
191            }
192        }
193        
194        total_score
195    }
196}
197
/// Dot product of `a` and `b`: the sum of their component-wise products.
///
/// Components are paired positionally. If the slices differ in length, the extra components
/// of the longer one are ignored. This scalar loop is a candidate for a SIMD replacement.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    std::iter::zip(a, b).map(|(&lhs, &rhs)| lhs * rhs).sum()
}
203
/// Euclidean (L2) norm of `v`: the square root of the sum of squared components.
#[inline]
fn norm(v: &[f32]) -> f32 {
    let squared_len: f32 = v.iter().map(|&component| component * component).sum();
    squared_len.sqrt()
}
209
/// Euclidean (L2) distance between `a` and `b`.
///
/// Components are paired positionally. If the slices differ in length, the extra components
/// of the longer one are ignored.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared_dist: f32 = std::iter::zip(a, b)
        .map(|(&x, &y)| {
            let diff = x - y;
            diff * diff
        })
        .sum();
    squared_dist.sqrt()
}
219
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_multivector_creation() {
        let mv = MultiVector::new(vec![vec![1.0, 0.0, 0.0], vec![0.0, 1.0, 0.0]]).unwrap();

        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.len(), 2);
        assert!(!mv.is_empty());
    }

    #[test]
    fn test_new_rejects_invalid_input() {
        // No sub-vectors at all.
        assert!(MultiVector::new(vec![]).is_err());
        // Zero-length sub-vector.
        assert!(MultiVector::new(vec![vec![]]).is_err());
        // Ragged sub-vectors.
        assert!(MultiVector::new(vec![vec![1.0, 2.0], vec![3.0]]).is_err());
    }

    #[test]
    fn test_from_single() {
        let mv = MultiVector::from_single(vec![1.0, 2.0, 3.0]).unwrap();
        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.len(), 1);
        assert_eq!(mv.first(), Some(&vec![1.0, 2.0, 3.0]));

        assert!(MultiVector::from_single(vec![]).is_err());
    }

    #[test]
    fn test_max_sim_identical() {
        let mv1 = MultiVector::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();
        let mv2 = MultiVector::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]).unwrap();

        // Each query vector matches one doc vector perfectly: 1.0 + 1.0 = 2.0.
        let score = mv1.max_sim(&mv2);
        assert!((score - 2.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_sim_different() {
        let query = MultiVector::new(vec![vec![1.0, 0.0]]).unwrap();
        let doc = MultiVector::new(vec![vec![0.5, 0.5], vec![1.0, 0.0]]).unwrap();

        // The best match for [1, 0] is the doc vector [1, 0], with similarity 1.0.
        let score = query.max_sim(&doc);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_sim_negative_similarity() {
        let query = MultiVector::new(vec![vec![-1.0, 0.0]]).unwrap();
        let doc = MultiVector::new(vec![vec![1.0, 0.0], vec![2.0, 0.0]]).unwrap();

        // Candidates are -1.0 and -2.0; the maximum (-1.0) is kept, not clamped to zero.
        let score = query.max_sim(&doc);
        assert!((score + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_sim_cosine() {
        let query = MultiVector::new(vec![vec![2.0, 0.0]]).unwrap(); // not normalized
        let doc = MultiVector::new(vec![vec![1.0, 0.0]]).unwrap();

        // Cosine similarity is 1.0 regardless of magnitude.
        let score = query.max_sim_cosine(&doc);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_sim_cosine_skips_zero_vectors() {
        let query = MultiVector::new(vec![vec![0.0, 0.0], vec![1.0, 0.0]]).unwrap();
        let doc = MultiVector::new(vec![vec![0.0, 0.0], vec![3.0, 0.0]]).unwrap();

        // The zero query vector contributes nothing, and the zero doc vector is never chosen.
        let score = query.max_sim_cosine(&doc);
        assert!((score - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_max_sim_l2() {
        let query = MultiVector::new(vec![vec![0.0, 0.0], vec![3.0, 4.0]]).unwrap();
        let doc = MultiVector::new(vec![vec![3.0, 4.0], vec![1.0, 0.0]]).unwrap();

        // Nearest distances are 1.0 (to [1, 0]) and 0.0 (exact match); the score is -(1 + 0).
        let score = query.max_sim_l2(&doc);
        assert!((score + 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_dimension_mismatch_scores() {
        let a = MultiVector::new(vec![vec![1.0, 0.0]]).unwrap();
        let b = MultiVector::new(vec![vec![1.0, 0.0, 0.0]]).unwrap();

        assert_eq!(a.max_sim(&b), 0.0);
        assert_eq!(a.max_sim_cosine(&b), 0.0);
        assert_eq!(a.max_sim_l2(&b), f32::NEG_INFINITY);
    }
}