vectx_core/
multivector.rs1use serde::{Deserialize, Serialize};
7use crate::Vector;
8
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Default, Serialize, Deserialize)]
11pub enum MultiVectorComparator {
12 #[default]
15 MaxSim,
16}
17
18#[derive(Debug, Clone, Default, Serialize, Deserialize)]
20pub struct MultiVectorConfig {
21 pub comparator: MultiVectorComparator,
22}
23
24#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
27pub struct MultiVector {
28 vectors: Vec<Vec<f32>>,
30 dim: usize,
32}
33
34impl MultiVector {
35 pub fn new(vectors: Vec<Vec<f32>>) -> Result<Self, &'static str> {
38 if vectors.is_empty() {
39 return Err("MultiVector cannot be empty");
40 }
41
42 let dim = vectors[0].len();
43 if dim == 0 {
44 return Err("Sub-vectors cannot be empty");
45 }
46
47 if !vectors.iter().all(|v| v.len() == dim) {
49 return Err("All sub-vectors must have the same dimension");
50 }
51
52 Ok(Self { vectors, dim })
53 }
54
55 pub fn from_single(vector: Vec<f32>) -> Result<Self, &'static str> {
57 if vector.is_empty() {
58 return Err("Vector cannot be empty");
59 }
60 let dim = vector.len();
61 Ok(Self { vectors: vec![vector], dim })
62 }
63
64 #[inline]
66 pub fn dim(&self) -> usize {
67 self.dim
68 }
69
70 #[inline]
72 pub fn len(&self) -> usize {
73 self.vectors.len()
74 }
75
76 #[inline]
78 pub fn is_empty(&self) -> bool {
79 self.vectors.is_empty()
80 }
81
82 #[inline]
84 pub fn vectors(&self) -> &[Vec<f32>] {
85 &self.vectors
86 }
87
88 #[inline]
90 pub fn first(&self) -> Option<&Vec<f32>> {
91 self.vectors.first()
92 }
93
94 pub fn to_single_vector(&self) -> Vector {
97 Vector::new(self.vectors[0].clone())
98 }
99
100 pub fn max_sim(&self, other: &MultiVector) -> f32 {
107 if self.dim != other.dim {
108 return 0.0;
109 }
110
111 let mut total_score = 0.0;
112
113 for query_vec in &self.vectors {
115 let mut max_sim = f32::NEG_INFINITY;
116
117 for doc_vec in &other.vectors {
119 let sim = dot_product(query_vec, doc_vec);
120 if sim > max_sim {
121 max_sim = sim;
122 }
123 }
124
125 if max_sim > f32::NEG_INFINITY {
127 total_score += max_sim;
128 }
129 }
130
131 total_score
132 }
133
134 pub fn max_sim_cosine(&self, other: &MultiVector) -> f32 {
136 if self.dim != other.dim {
137 return 0.0;
138 }
139
140 let mut total_score = 0.0;
141
142 for query_vec in &self.vectors {
143 let query_norm = norm(query_vec);
144 if query_norm < f32::EPSILON {
145 continue;
146 }
147
148 let mut max_sim = f32::NEG_INFINITY;
149
150 for doc_vec in &other.vectors {
151 let doc_norm = norm(doc_vec);
152 if doc_norm < f32::EPSILON {
153 continue;
154 }
155
156 let sim = dot_product(query_vec, doc_vec) / (query_norm * doc_norm);
157 if sim > max_sim {
158 max_sim = sim;
159 }
160 }
161
162 if max_sim > f32::NEG_INFINITY {
163 total_score += max_sim;
164 }
165 }
166
167 total_score
168 }
169
170 pub fn max_sim_l2(&self, other: &MultiVector) -> f32 {
172 if self.dim != other.dim {
173 return f32::NEG_INFINITY;
174 }
175
176 let mut total_score = 0.0;
177
178 for query_vec in &self.vectors {
179 let mut min_dist = f32::INFINITY;
180
181 for doc_vec in &other.vectors {
182 let dist = l2_distance(query_vec, doc_vec);
183 if dist < min_dist {
184 min_dist = dist;
185 }
186 }
187
188 if min_dist < f32::INFINITY {
189 total_score -= min_dist;
191 }
192 }
193
194 total_score
195 }
196}
197
/// Dot product of `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length, the
/// extra elements of the longer one are ignored.
#[inline]
fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    let products = a.iter().zip(b).map(|(&lhs, &rhs)| lhs * rhs);
    products.sum()
}
203
/// Euclidean (L2) norm of `v`: the square root of the sum of squared
/// components.
#[inline]
fn norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
209
/// Euclidean (L2) distance between `a` and `b`.
///
/// Elements are paired positionally; if the slices differ in length, the
/// extra elements of the longer one are ignored.
#[inline]
fn l2_distance(a: &[f32], b: &[f32]) -> f32 {
    let squared: f32 = a
        .iter()
        .zip(b)
        .map(|(&lhs, &rhs)| {
            let delta = lhs - rhs;
            delta * delta
        })
        .sum();
    squared.sqrt()
}
219
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing `f32` scores.
    const TOLERANCE: f32 = 1e-6;

    /// Builds a `MultiVector` from borrowed rows, panicking on invalid input.
    fn multivector(rows: &[&[f32]]) -> MultiVector {
        let owned: Vec<Vec<f32>> = rows.iter().map(|row| row.to_vec()).collect();
        MultiVector::new(owned).unwrap()
    }

    /// Asserts that `actual` is within `TOLERANCE` of `expected`.
    fn assert_close(actual: f32, expected: f32) {
        assert!((actual - expected).abs() < TOLERANCE);
    }

    // Construction records the shared dimension and the sub-vector count.
    #[test]
    fn test_multivector_creation() {
        let mv = multivector(&[&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]]);

        assert_eq!(mv.dim(), 3);
        assert_eq!(mv.len(), 2);
    }

    // Two identical orthonormal pairs: each query row finds a perfect match.
    #[test]
    fn test_max_sim_identical() {
        let mv1 = multivector(&[&[1.0, 0.0], &[0.0, 1.0]]);
        let mv2 = multivector(&[&[1.0, 0.0], &[0.0, 1.0]]);

        assert_close(mv1.max_sim(&mv2), 2.0);
    }

    // The single query row picks the better of the two document rows.
    #[test]
    fn test_max_sim_different() {
        let query = multivector(&[&[1.0, 0.0]]);
        let doc = multivector(&[&[0.5, 0.5], &[1.0, 0.0]]);

        assert_close(query.max_sim(&doc), 1.0);
    }

    // Cosine scoring ignores magnitude: parallel vectors score 1.
    #[test]
    fn test_max_sim_cosine() {
        let query = multivector(&[&[2.0, 0.0]]);
        let doc = multivector(&[&[1.0, 0.0]]);

        assert_close(query.max_sim_cosine(&doc), 1.0);
    }
}