1use std::collections::HashMap;
60use std::sync::Arc;
61
62use parking_lot::RwLock;
63
/// Identifier of a logical document, which may consist of several chunk vectors.
pub type DocId = String;

/// Identifier that [`MultiVectorMapping`] assigns to each individual chunk vector.
pub type InternalId = u64;

/// Zero-based position of a chunk within its document.
pub type ChunkIndex = u32;
76
/// Strategy used to collapse the per-chunk scores of one document into a single score.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum AggregationMethod {
    /// Highest chunk score (the default).
    #[default]
    Max,

    /// Arithmetic mean of the matched chunk scores.
    Mean,

    /// Score of the matched chunk with the lowest chunk index.
    First,

    /// Score of the matched chunk with the highest chunk index.
    Last,

    /// Sum of the matched chunk scores.
    Sum,
}

impl AggregationMethod {
    /// Parses a method name case-insensitively; `"avg"` and `"average"` are accepted as
    /// aliases for [`AggregationMethod::Mean`].
    ///
    /// Returns `None` for unrecognised names. Use `str::parse` (via the
    /// [`std::str::FromStr`] impl) when an error value is preferred.
    // Kept as an inherent `Option`-returning function for existing callers; inherent
    // associated functions take precedence over the trait method of the same name.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::parse_name(s)
    }

    /// Shared name lookup used by both the inherent `from_str` and the `FromStr` impl.
    fn parse_name(s: &str) -> Option<Self> {
        match s.to_lowercase().as_str() {
            "max" => Some(Self::Max),
            "mean" | "avg" | "average" => Some(Self::Mean),
            "first" => Some(Self::First),
            "last" => Some(Self::Last),
            "sum" => Some(Self::Sum),
            _ => None,
        }
    }
}

/// Error returned when a string does not name a known [`AggregationMethod`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseAggregationMethodError(String);

impl std::fmt::Display for ParseAggregationMethodError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "unknown aggregation method {:?} (expected max, mean/avg/average, first, last or sum)",
            self.0
        )
    }
}

impl std::error::Error for ParseAggregationMethodError {}

impl std::str::FromStr for AggregationMethod {
    type Err = ParseAggregationMethodError;

    /// Same rules as the inherent `from_str`, but reports the rejected input on failure.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Self::parse_name(s).ok_or_else(|| ParseAggregationMethodError(s.to_owned()))
    }
}
117
118#[derive(Debug, Clone)]
120pub struct DocumentScore {
121 pub doc_id: DocId,
123
124 pub score: f32,
126
127 pub best_chunk: Option<ChunkIndex>,
129
130 pub matched_chunks: usize,
132
133 pub chunk_scores: Option<Vec<(ChunkIndex, f32)>>,
135}
136
137impl DocumentScore {
138 pub fn aggregate(
140 doc_id: DocId,
141 chunk_scores: Vec<(ChunkIndex, f32)>,
142 method: AggregationMethod,
143 keep_details: bool,
144 ) -> Self {
145 if chunk_scores.is_empty() {
146 return Self {
147 doc_id,
148 score: 0.0,
149 best_chunk: None,
150 matched_chunks: 0,
151 chunk_scores: if keep_details { Some(Vec::new()) } else { None },
152 };
153 }
154
155 let matched_chunks = chunk_scores.len();
156
157 let (score, best_chunk) = match method {
158 AggregationMethod::Max => {
159 let (_idx, &(chunk, score)) = chunk_scores
160 .iter()
161 .enumerate()
162 .max_by(|(_, a), (_, b)| a.1.partial_cmp(&b.1).unwrap())
163 .unwrap();
164 (score, Some(chunk))
165 }
166 AggregationMethod::Mean => {
167 let sum: f32 = chunk_scores.iter().map(|(_, s)| s).sum();
168 (sum / chunk_scores.len() as f32, None)
169 }
170 AggregationMethod::First => {
171 let (chunk, score) = chunk_scores
172 .iter()
173 .min_by_key(|(idx, _)| *idx)
174 .copied()
175 .unwrap();
176 (score, Some(chunk))
177 }
178 AggregationMethod::Last => {
179 let (chunk, score) = chunk_scores
180 .iter()
181 .max_by_key(|(idx, _)| *idx)
182 .copied()
183 .unwrap();
184 (score, Some(chunk))
185 }
186 AggregationMethod::Sum => {
187 let sum: f32 = chunk_scores.iter().map(|(_, s)| s).sum();
188 (sum, None)
189 }
190 };
191
192 Self {
193 doc_id,
194 score,
195 best_chunk,
196 matched_chunks,
197 chunk_scores: if keep_details {
198 Some(chunk_scores)
199 } else {
200 None
201 },
202 }
203 }
204}
205
206#[derive(Debug, Clone)]
212pub struct MultiVectorMapping {
213 internal_to_doc: HashMap<InternalId, (DocId, ChunkIndex)>,
215
216 doc_to_internal: HashMap<DocId, Vec<InternalId>>,
218
219 next_internal_id: InternalId,
221}
222
223impl MultiVectorMapping {
224 pub fn new() -> Self {
226 Self {
227 internal_to_doc: HashMap::new(),
228 doc_to_internal: HashMap::new(),
229 next_internal_id: 0,
230 }
231 }
232
233 pub fn insert_document(&mut self, doc_id: DocId, num_chunks: usize) -> Vec<InternalId> {
235 self.remove_document(&doc_id);
237
238 let mut internal_ids = Vec::with_capacity(num_chunks);
239
240 for chunk_idx in 0..num_chunks {
241 let internal_id = self.next_internal_id;
242 self.next_internal_id += 1;
243
244 self.internal_to_doc
245 .insert(internal_id, (doc_id.clone(), chunk_idx as ChunkIndex));
246 internal_ids.push(internal_id);
247 }
248
249 self.doc_to_internal.insert(doc_id, internal_ids.clone());
250
251 internal_ids
252 }
253
254 pub fn remove_document(&mut self, doc_id: &str) -> Option<Vec<InternalId>> {
256 if let Some(internal_ids) = self.doc_to_internal.remove(doc_id) {
257 for id in &internal_ids {
258 self.internal_to_doc.remove(id);
259 }
260 Some(internal_ids)
261 } else {
262 None
263 }
264 }
265
266 #[inline]
268 pub fn get_doc(&self, internal_id: InternalId) -> Option<(&DocId, ChunkIndex)> {
269 self.internal_to_doc.get(&internal_id).map(|(d, c)| (d, *c))
270 }
271
272 pub fn get_internal_ids(&self, doc_id: &str) -> Option<&[InternalId]> {
274 self.doc_to_internal.get(doc_id).map(|v| v.as_slice())
275 }
276
277 pub fn has_document(&self, doc_id: &str) -> bool {
279 self.doc_to_internal.contains_key(doc_id)
280 }
281
282 pub fn num_documents(&self) -> usize {
284 self.doc_to_internal.len()
285 }
286
287 pub fn num_vectors(&self) -> usize {
289 self.internal_to_doc.len()
290 }
291}
292
293impl Default for MultiVectorMapping {
294 fn default() -> Self {
295 Self::new()
296 }
297}
298
299pub struct MultiVectorAggregator {
305 mapping: Arc<RwLock<MultiVectorMapping>>,
307
308 default_method: AggregationMethod,
310}
311
312impl MultiVectorAggregator {
313 pub fn new(mapping: Arc<RwLock<MultiVectorMapping>>) -> Self {
315 Self {
316 mapping,
317 default_method: AggregationMethod::Max,
318 }
319 }
320
321 pub fn with_default_method(mut self, method: AggregationMethod) -> Self {
323 self.default_method = method;
324 self
325 }
326
327 pub fn aggregate(
332 &self,
333 vector_results: &[(InternalId, f32)],
334 method: Option<AggregationMethod>,
335 limit: usize,
336 ) -> Vec<DocumentScore> {
337 let method = method.unwrap_or(self.default_method);
338 let mapping = self.mapping.read();
339
340 let mut doc_chunks: HashMap<&DocId, Vec<(ChunkIndex, f32)>> = HashMap::new();
342
343 for &(internal_id, score) in vector_results {
344 if let Some((doc_id, chunk_idx)) = mapping.get_doc(internal_id) {
345 doc_chunks
346 .entry(doc_id)
347 .or_default()
348 .push((chunk_idx, score));
349 }
350 }
351
352 let mut results: Vec<DocumentScore> = doc_chunks
354 .into_iter()
355 .map(|(doc_id, chunks)| DocumentScore::aggregate(doc_id.clone(), chunks, method, false))
356 .collect();
357
358 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
360
361 results.truncate(limit);
363
364 results
365 }
366
367 pub fn aggregate_detailed(
369 &self,
370 vector_results: &[(InternalId, f32)],
371 method: Option<AggregationMethod>,
372 limit: usize,
373 ) -> Vec<DocumentScore> {
374 let method = method.unwrap_or(self.default_method);
375 let mapping = self.mapping.read();
376
377 let mut doc_chunks: HashMap<&DocId, Vec<(ChunkIndex, f32)>> = HashMap::new();
379
380 for &(internal_id, score) in vector_results {
381 if let Some((doc_id, chunk_idx)) = mapping.get_doc(internal_id) {
382 doc_chunks
383 .entry(doc_id)
384 .or_default()
385 .push((chunk_idx, score));
386 }
387 }
388
389 let mut results: Vec<DocumentScore> = doc_chunks
391 .into_iter()
392 .map(|(doc_id, chunks)| DocumentScore::aggregate(doc_id.clone(), chunks, method, true))
393 .collect();
394
395 results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
397
398 results.truncate(limit);
400
401 results
402 }
403}
404
405#[derive(Debug, Clone)]
411pub struct MultiVectorConfig {
412 pub max_chunks_per_doc: usize,
414
415 pub default_aggregation: AggregationMethod,
417
418 pub overfetch_factor: f32,
420}
421
422impl Default for MultiVectorConfig {
423 fn default() -> Self {
424 Self {
425 max_chunks_per_doc: 1000,
426 default_aggregation: AggregationMethod::Max,
427 overfetch_factor: 2.0,
428 }
429 }
430}
431
432#[derive(Debug, Clone)]
434pub struct MultiVectorDocument {
435 pub id: DocId,
437
438 pub vectors: Vec<Vec<f32>>,
440
441 pub chunks_text: Option<Vec<String>>,
443
444 pub metadata: HashMap<String, serde_json::Value>,
446}
447
448impl MultiVectorDocument {
449 pub fn new(id: impl Into<DocId>, vectors: Vec<Vec<f32>>) -> Self {
451 Self {
452 id: id.into(),
453 vectors,
454 chunks_text: None,
455 metadata: HashMap::new(),
456 }
457 }
458
459 pub fn with_text(mut self, chunks: Vec<String>) -> Self {
461 self.chunks_text = Some(chunks);
462 self
463 }
464
465 pub fn with_metadata(
467 mut self,
468 key: impl Into<String>,
469 value: impl Into<serde_json::Value>,
470 ) -> Self {
471 self.metadata.insert(key.into(), value.into());
472 self
473 }
474
475 pub fn num_chunks(&self) -> usize {
477 self.vectors.len()
478 }
479
480 pub fn validate(&self, expected_dim: usize) -> Result<(), MultiVectorError> {
482 if self.vectors.is_empty() {
483 return Err(MultiVectorError::NoVectors);
484 }
485
486 for (i, v) in self.vectors.iter().enumerate() {
487 if v.len() != expected_dim {
488 return Err(MultiVectorError::DimensionMismatch {
489 chunk: i,
490 expected: expected_dim,
491 actual: v.len(),
492 });
493 }
494 }
495
496 if let Some(ref texts) = self.chunks_text {
497 if texts.len() != self.vectors.len() {
498 return Err(MultiVectorError::ChunkCountMismatch {
499 vectors: self.vectors.len(),
500 texts: texts.len(),
501 });
502 }
503 }
504
505 Ok(())
506 }
507}
508
/// Errors reported for multi-vector documents.
#[derive(Debug, thiserror::Error)]
pub enum MultiVectorError {
    /// The document has no chunk vectors (returned by `MultiVectorDocument::validate`).
    #[error("document must have at least one vector")]
    NoVectors,

    /// A chunk vector's length differs from the expected embedding dimension.
    #[error("dimension mismatch in chunk {chunk}: expected {expected}, got {actual}")]
    DimensionMismatch {
        /// Index of the offending chunk.
        chunk: usize,
        /// Required number of components.
        expected: usize,
        /// Number of components the chunk actually has.
        actual: usize,
    },

    /// Chunk text was supplied but its count differs from the number of vectors.
    #[error("chunk count mismatch: {vectors} vectors but {texts} texts")]
    ChunkCountMismatch { vectors: usize, texts: usize },

    /// NOTE(review): not constructed in the code shown here; presumably raised by callers
    /// enforcing `MultiVectorConfig::max_chunks_per_doc`. Confirm.
    #[error("too many chunks: {count} exceeds limit of {limit}")]
    TooManyChunks { count: usize, limit: usize },

    /// The referenced document is unknown. Not constructed in the code shown here.
    #[error("document not found: {0}")]
    NotFound(DocId),
}
531
#[cfg(test)]
mod tests {
    use super::*;

    /// Aggregates `chunks` for the fixed document "doc1" without keeping details.
    fn score_doc1(chunks: Vec<(ChunkIndex, f32)>, method: AggregationMethod) -> DocumentScore {
        DocumentScore::aggregate("doc1".to_string(), chunks, method, false)
    }

    #[test]
    fn test_aggregation_max() {
        let scored = score_doc1(vec![(0, 0.5), (1, 0.9), (2, 0.3)], AggregationMethod::Max);

        // Max reports the winning chunk alongside its score.
        assert_eq!(scored.score, 0.9);
        assert_eq!(scored.best_chunk, Some(1));
        assert_eq!(scored.matched_chunks, 3);
    }

    #[test]
    fn test_aggregation_mean() {
        let scored = score_doc1(vec![(0, 0.6), (1, 0.9), (2, 0.3)], AggregationMethod::Mean);

        // (0.6 + 0.9 + 0.3) / 3 = 0.6, compared with a tolerance for f32 rounding.
        let expected = 0.6;
        assert!((scored.score - expected).abs() < 0.001);
    }

    #[test]
    fn test_aggregation_first() {
        // Hits arrive out of chunk order: "first" means lowest chunk index, not first hit.
        let scored = score_doc1(vec![(2, 0.3), (0, 0.5), (1, 0.9)], AggregationMethod::First);

        assert_eq!(scored.score, 0.5);
        assert_eq!(scored.best_chunk, Some(0));
    }

    #[test]
    fn test_mapping_insert() {
        let mut mapping = MultiVectorMapping::new();
        let assigned = mapping.insert_document("doc1".to_string(), 3);
        assert_eq!(assigned.len(), 3);

        // Every internal id resolves back to doc1 and to its chunk position.
        for (expected_chunk, id) in assigned.into_iter().enumerate() {
            let (owner, chunk) = mapping.get_doc(id).unwrap();
            assert_eq!(owner, "doc1");
            assert_eq!(chunk as usize, expected_chunk);
        }
    }

    #[test]
    fn test_mapping_remove() {
        let mut mapping = MultiVectorMapping::new();
        let assigned = mapping.insert_document("doc1".to_string(), 3);

        // Removal hands back exactly the ids that insertion assigned...
        assert_eq!(mapping.remove_document("doc1"), Some(assigned.clone()));

        // ...and leaves neither direction of the mapping behind.
        assert!(mapping.get_doc(assigned[0]).is_none());
        assert!(!mapping.has_document("doc1"));
    }

    #[test]
    fn test_aggregator() {
        let mapping = Arc::new(RwLock::new(MultiVectorMapping::new()));
        {
            let mut guard = mapping.write();
            guard.insert_document("doc1".to_string(), 3); // internal ids 0, 1, 2
            guard.insert_document("doc2".to_string(), 2); // internal ids 3, 4
        }

        let aggregator = MultiVectorAggregator::new(mapping);

        let hits: Vec<(InternalId, f32)> = vec![
            (1, 0.95), // doc1, chunk 1
            (3, 0.90), // doc2, chunk 0
            (0, 0.85), // doc1, chunk 0
            (4, 0.80), // doc2, chunk 1
        ];
        let ranked = aggregator.aggregate(&hits, Some(AggregationMethod::Max), 10);

        // With Max, each document is scored by its best chunk.
        assert_eq!(ranked.len(), 2);
        assert_eq!(ranked[0].doc_id, "doc1");
        assert_eq!(ranked[0].score, 0.95);
        assert_eq!(ranked[1].doc_id, "doc2");
        assert_eq!(ranked[1].score, 0.90);
    }

    #[test]
    fn test_multi_vector_document() {
        let doc = MultiVectorDocument::new("doc1", vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]])
            .with_text(vec!["chunk 1".to_string(), "chunk 2".to_string()])
            .with_metadata("author", serde_json::json!("Alice"));

        assert_eq!(doc.num_chunks(), 2);

        // Both vectors are 3-dimensional, so only a dimension of 3 validates.
        assert!(doc.validate(3).is_ok());
        assert!(doc.validate(4).is_err());
    }
}