1use anyhow::{anyhow, Result};
61use serde::{Deserialize, Serialize};
62use std::collections::HashMap;
63
/// A document represented by a list of embedding vectors plus free-form JSON
/// metadata. `MultiVectorIndex` stores each vector as one "token" entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiVectorDoc {
    /// Identifier; used as the document key by `MultiVectorIndex`.
    pub id: String,
    /// The document's vectors. `validate` requires at least one vector and
    /// that all vectors share the first vector's length.
    pub vectors: Vec<Vec<f32>>,
    /// Arbitrary JSON metadata, carried along as-is; the indexing and scoring
    /// code in this module never reads it.
    pub metadata: serde_json::Value,
}
74
75impl MultiVectorDoc {
76 pub fn new(id: impl Into<String>, vectors: Vec<Vec<f32>>, metadata: serde_json::Value) -> Self {
78 Self {
79 id: id.into(),
80 vectors,
81 metadata,
82 }
83 }
84
85 pub fn num_vectors(&self) -> usize {
87 self.vectors.len()
88 }
89
90 pub fn dimension(&self) -> usize {
92 self.vectors.first().map(|v| v.len()).unwrap_or(0)
93 }
94
95 pub fn validate(&self) -> Result<()> {
97 if self.vectors.is_empty() {
98 return Err(anyhow!("Document has no vectors"));
99 }
100
101 let dim = self.dimension();
102 for (i, vec) in self.vectors.iter().enumerate() {
103 if vec.len() != dim {
104 return Err(anyhow!(
105 "Vector {} has dimension {}, expected {}",
106 i,
107 vec.len(),
108 dim
109 ));
110 }
111 }
112
113 Ok(())
114 }
115}
116
/// How `MultiVectorIndex::search` turns the similarities gathered for one
/// document into a single score.
///
/// For each document, `search` collects the cosine similarity of every
/// (query vector, document token) pair into one flat list, in query-major
/// order. Each variant reduces that whole list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum AggregationMethod {
    /// Largest value in the list, i.e. the single best-matching pair.
    /// This is the default chosen by `MultiVectorIndex::new`.
    ///
    /// NOTE(review): unlike `colbert::ColBERTQuery::score`, this is not a sum
    /// of per-query-vector maxima; the two agree only for single-vector
    /// queries. Confirm which semantics callers expect.
    MaxSim,
    /// Arithmetic mean of the list.
    AvgSim,
    /// Sum of the list; grows with both query and document vector counts.
    SumSim,
    /// First entry of the list: the similarity between the first query vector
    /// and the first token the index holds for the document.
    FirstToken,
}
129
/// In-memory index of [`MultiVectorDoc`]s. It is searched by computing cosine
/// similarity against every stored token vector (no approximate structure).
pub struct MultiVectorIndex {
    /// Required length of every document vector and query vector.
    dimension: usize,
    /// Stored documents, keyed by `MultiVectorDoc::id`.
    documents: HashMap<String, MultiVectorDoc>,
    /// One entry per stored token vector: `(doc id, position of the vector
    /// within that document)`. Parallel to `token_vectors`: index `i` in both
    /// refers to the same token.
    token_index: Vec<(String, usize)>,
    /// Copies of every added document's vectors, in insertion order.
    token_vectors: Vec<Vec<f32>>,
    /// Reduction applied to each document's similarity list during search.
    aggregation: AggregationMethod,
}
144
145impl MultiVectorIndex {
146 pub fn new(dimension: usize) -> Self {
148 Self {
149 dimension,
150 documents: HashMap::new(),
151 token_index: Vec::new(),
152 token_vectors: Vec::new(),
153 aggregation: AggregationMethod::MaxSim,
154 }
155 }
156
157 pub fn with_aggregation(mut self, aggregation: AggregationMethod) -> Self {
159 self.aggregation = aggregation;
160 self
161 }
162
163 pub fn add(&mut self, doc: MultiVectorDoc) -> Result<()> {
165 doc.validate()?;
166
167 if doc.dimension() != self.dimension {
168 return Err(anyhow!(
169 "Document dimension {} doesn't match index dimension {}",
170 doc.dimension(),
171 self.dimension
172 ));
173 }
174
175 let doc_id = doc.id.clone();
176
177 for (token_idx, vector) in doc.vectors.iter().enumerate() {
179 self.token_index.push((doc_id.clone(), token_idx));
180 self.token_vectors.push(vector.clone());
181 }
182
183 self.documents.insert(doc_id, doc);
184
185 Ok(())
186 }
187
188 pub fn search(&self, query_vectors: &[Vec<f32>], k: usize) -> Result<Vec<(String, f32)>> {
190 if query_vectors.is_empty() {
191 return Err(anyhow!("Query has no vectors"));
192 }
193
194 for qv in query_vectors {
196 if qv.len() != self.dimension {
197 return Err(anyhow!(
198 "Query dimension {} doesn't match index dimension {}",
199 qv.len(),
200 self.dimension
201 ));
202 }
203 }
204
205 let mut doc_scores: HashMap<String, Vec<f32>> = HashMap::new();
207
208 for query_vec in query_vectors {
210 for (token_id, (doc_id, _token_idx)) in self.token_index.iter().enumerate() {
212 let token_vec = &self.token_vectors[token_id];
213 let sim = cosine_similarity(query_vec, token_vec);
214
215 doc_scores
216 .entry(doc_id.clone())
217 .or_insert_with(Vec::new)
218 .push(sim);
219 }
220 }
221
222 let mut results: Vec<(String, f32)> = doc_scores
224 .into_iter()
225 .map(|(doc_id, sims)| {
226 let score = match self.aggregation {
227 AggregationMethod::MaxSim => {
228 sims.iter().copied().fold(f32::NEG_INFINITY, f32::max)
229 }
230 AggregationMethod::AvgSim => sims.iter().sum::<f32>() / sims.len() as f32,
231 AggregationMethod::SumSim => sims.iter().sum(),
232 AggregationMethod::FirstToken => sims.first().copied().unwrap_or(0.0),
233 };
234 (doc_id, score)
235 })
236 .collect();
237
238 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap());
240 results.truncate(k);
241
242 Ok(results)
243 }
244
245 pub fn get(&self, doc_id: &str) -> Option<&MultiVectorDoc> {
247 self.documents.get(doc_id)
248 }
249
250 pub fn num_documents(&self) -> usize {
252 self.documents.len()
253 }
254
255 pub fn num_tokens(&self) -> usize {
257 self.token_vectors.len()
258 }
259
260 pub fn stats(&self) -> MultiVectorStats {
262 let avg_tokens_per_doc = if !self.documents.is_empty() {
263 self.num_tokens() as f32 / self.num_documents() as f32
264 } else {
265 0.0
266 };
267
268 MultiVectorStats {
269 num_documents: self.num_documents(),
270 num_tokens: self.num_tokens(),
271 dimension: self.dimension,
272 avg_tokens_per_doc,
273 aggregation: self.aggregation,
274 }
275 }
276}
277
/// Snapshot of index size and configuration, produced by `MultiVectorIndex::stats`.
#[derive(Debug, Clone)]
pub struct MultiVectorStats {
    /// Number of stored documents.
    pub num_documents: usize,
    /// Total number of token vectors held by the index.
    pub num_tokens: usize,
    /// Vector length the index was created with.
    pub dimension: usize,
    /// `num_tokens / num_documents`, or `0.0` when the index is empty.
    pub avg_tokens_per_doc: f32,
    /// Aggregation method the index applies during search.
    pub aggregation: AggregationMethod,
}
287
/// Cosine similarity of two equal-length vectors.
///
/// If either vector has zero Euclidean norm, returns `0.0` rather than
/// dividing by zero.
///
/// # Panics
/// Panics if `a` and `b` have different lengths.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(a.len(), b.len());

    let l2_norm = |v: &[f32]| v.iter().map(|x| x * x).sum::<f32>().sqrt();
    let (len_a, len_b) = (l2_norm(a), l2_norm(b));

    if len_a != 0.0 && len_b != 0.0 {
        let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
        dot / (len_a * len_b)
    } else {
        0.0
    }
}
302
303pub mod colbert {
305 use super::*;
306
307 pub struct ColBERTQuery {
309 pub tokens: Vec<Vec<f32>>,
311 }
312
313 impl ColBERTQuery {
314 pub fn new(tokens: Vec<Vec<f32>>) -> Self {
316 Self { tokens }
317 }
318
319 pub fn score(&self, doc: &MultiVectorDoc) -> f32 {
321 if self.tokens.is_empty() || doc.vectors.is_empty() {
322 return 0.0;
323 }
324
325 let mut total_score = 0.0;
326
327 for query_token in &self.tokens {
329 let max_sim = doc
330 .vectors
331 .iter()
332 .map(|doc_token| cosine_similarity(query_token, doc_token))
333 .fold(f32::NEG_INFINITY, f32::max);
334
335 total_score += max_sim;
336 }
337
338 total_score
339 }
340 }
341}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a document with empty JSON metadata.
    fn doc(id: &str, vectors: Vec<Vec<f32>>) -> MultiVectorDoc {
        MultiVectorDoc::new(id, vectors, serde_json::json!({}))
    }

    /// Asserts two floats agree within a small tolerance.
    fn assert_close(actual: f32, expected: f32) {
        assert!(
            (actual - expected).abs() < 1e-4,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    #[test]
    fn test_multi_vector_doc_creation() {
        let d = doc("doc1", vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
        assert_eq!(d.id, "doc1");
        assert_eq!(d.num_vectors(), 2);
        assert_eq!(d.dimension(), 2);

        let empty = doc("empty", vec![]);
        assert_eq!(empty.num_vectors(), 0);
        assert_eq!(empty.dimension(), 0);
    }

    #[test]
    fn test_doc_validation() {
        assert!(doc("doc1", vec![vec![1.0, 2.0], vec![3.0, 4.0]]).validate().is_ok());

        // The second vector is one element longer than the first.
        let err = doc("doc2", vec![vec![1.0, 2.0], vec![3.0, 4.0, 5.0]])
            .validate()
            .unwrap_err();
        assert_eq!(err.to_string(), "Vector 1 has dimension 3, expected 2");

        assert!(doc("doc3", vec![]).validate().is_err());
    }

    #[test]
    fn test_index_add_and_get() {
        let mut index = MultiVectorIndex::new(2);
        let d = doc("doc1", vec![vec![1.0, 2.0], vec![3.0, 4.0]]);

        assert!(index.add(d.clone()).is_ok());
        assert_eq!(index.num_documents(), 1);
        assert_eq!(index.num_tokens(), 2);

        let retrieved = index.get("doc1").unwrap();
        assert_eq!(retrieved.id, "doc1");
        assert_eq!(retrieved.vectors, d.vectors);
        assert!(index.get("missing").is_none());
    }

    #[test]
    fn test_add_rejects_dimension_mismatch() {
        let mut index = MultiVectorIndex::new(3);
        assert!(index.add(doc("doc1", vec![vec![1.0, 2.0]])).is_err());
        assert_eq!(index.num_documents(), 0);
        assert_eq!(index.num_tokens(), 0);
    }

    #[test]
    fn test_search_rejects_invalid_queries() {
        let mut index = MultiVectorIndex::new(2);
        index.add(doc("doc1", vec![vec![1.0, 0.0]])).unwrap();

        let no_vectors: Vec<Vec<f32>> = Vec::new();
        assert!(index.search(&no_vectors, 1).is_err());
        assert!(index.search(&[vec![1.0, 0.0, 0.0]], 1).is_err());
    }

    #[test]
    fn test_multi_vector_search_maxsim() {
        let mut index = MultiVectorIndex::new(2).with_aggregation(AggregationMethod::MaxSim);
        index.add(doc("doc1", vec![vec![1.0, 0.0], vec![0.0, 1.0]])).unwrap();
        index.add(doc("doc2", vec![vec![0.5, 0.5], vec![0.5, 0.5]])).unwrap();

        let query = vec![vec![1.0, 0.0]];
        let results = index.search(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        // doc1 contains an exact match; doc2's best token is at 45 degrees.
        assert_eq!(results[0].0, "doc1");
        assert_close(results[0].1, 1.0);
        assert_eq!(results[1].0, "doc2");
        assert_close(results[1].1, std::f32::consts::FRAC_1_SQRT_2);
    }

    #[test]
    fn test_search_truncates_to_k() {
        let mut index = MultiVectorIndex::new(2);
        index.add(doc("doc1", vec![vec![1.0, 0.0]])).unwrap();
        index.add(doc("doc2", vec![vec![0.0, 1.0]])).unwrap();

        let query = vec![vec![1.0, 0.0]];
        let top1 = index.search(&query, 1).unwrap();
        assert_eq!(top1.len(), 1);
        assert_eq!(top1[0].0, "doc1");
        assert!(index.search(&query, 0).unwrap().is_empty());
    }

    #[test]
    fn test_cosine_similarity() {
        assert_close(cosine_similarity(&[1.0, 0.0, 0.0], &[1.0, 0.0, 0.0]), 1.0);
        assert_close(cosine_similarity(&[1.0, 0.0, 0.0], &[0.0, 1.0, 0.0]), 0.0);
        assert_close(cosine_similarity(&[1.0, 1.0], &[-1.0, -1.0]), -1.0);
        // A zero-norm input short-circuits to 0 instead of dividing by zero.
        assert_eq!(cosine_similarity(&[0.0, 0.0], &[1.0, 2.0]), 0.0);
    }

    #[test]
    fn test_colbert_query() {
        use colbert::*;

        let query = ColBERTQuery::new(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
        let d = doc("doc1", vec![vec![1.0, 0.0], vec![0.5, 0.5]]);

        // Best match for [1, 0] is exact (1.0); for [0, 1] it is [0.5, 0.5] (1/sqrt 2).
        assert_close(query.score(&d), 1.0 + std::f32::consts::FRAC_1_SQRT_2);

        assert_eq!(query.score(&doc("empty", vec![])), 0.0);
        assert_eq!(ColBERTQuery::new(vec![]).score(&d), 0.0);
    }

    #[test]
    fn test_index_stats() {
        let empty = MultiVectorIndex::new(128).stats();
        assert_eq!(empty.num_documents, 0);
        assert_eq!(empty.num_tokens, 0);
        assert_eq!(empty.avg_tokens_per_doc, 0.0);
        assert_eq!(empty.aggregation, AggregationMethod::MaxSim);

        let mut index = MultiVectorIndex::new(128);
        index
            .add(doc("doc1", vec![vec![0.0; 128], vec![0.1; 128]]))
            .unwrap();
        index
            .add(doc("doc2", vec![vec![0.2; 128], vec![0.3; 128], vec![0.4; 128]]))
            .unwrap();

        let stats = index.stats();
        assert_eq!(stats.num_documents, 2);
        assert_eq!(stats.num_tokens, 5);
        assert_eq!(stats.dimension, 128);
        assert_close(stats.avg_tokens_per_doc, 2.5);
    }

    #[test]
    fn test_aggregation_methods() {
        let mut index = MultiVectorIndex::new(2);
        index.add(doc("doc1", vec![vec![1.0, 0.0], vec![0.0, 1.0]])).unwrap();

        // With one query vector the collected similarities are [1.0, 0.0].
        let query = vec![vec![1.0, 0.0]];
        let cases = [
            (AggregationMethod::MaxSim, 1.0),
            (AggregationMethod::AvgSim, 0.5),
            (AggregationMethod::SumSim, 1.0),
            (AggregationMethod::FirstToken, 1.0),
        ];
        for &(method, expected) in &cases {
            index.aggregation = method;
            let results = index.search(&query, 1).unwrap();
            assert_eq!(results.len(), 1);
            assert_close(results[0].1, expected);
        }
    }
}