Skip to main content

sochdb_vector/
multi_vector.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Multi-Vector Documents with Stable Aggregation Semantics (Task 5)
19//!
20//! This module enables documents to have multiple vectors (e.g., for chunks/paragraphs)
21//! with deterministic aggregation during search.
22//!
23//! ## Design
24//!
25//! ```text
26//! Document (doc_id=123)
27//! ├── Chunk 0 → Vector 0 (internal_id=1000)
28//! ├── Chunk 1 → Vector 1 (internal_id=1001)
29//! ├── Chunk 2 → Vector 2 (internal_id=1002)
30//! └── Chunk 3 → Vector 3 (internal_id=1003)
31//!
32//! Search: query → [1001, 1003, 1002] (internal IDs with scores)
33//!        → Aggregate by doc_id → doc_123: max(score(1001), score(1003), score(1002))
34//! ```
35//!
36//! ## Aggregation Methods
37//!
//! - **Max**: Use the best-matching chunk's score (ColBERT-like late interaction)
//! - **Mean**: Average all chunk scores (good for comprehensive coverage)
//! - **First**: Use the first chunk's score (for ordered content)
//! - **Last**: Use the last chunk's score
//! - **Sum**: Add up all chunk scores (for sparse-like behavior)
41//!
42//! ## API
43//!
44//! ```ignore
45//! // Insert multi-vector document
46//! collection.insert_multi(
47//!     doc_id="doc_123",
48//!     vectors=[v1, v2, v3, v4],
49//!     metadata={...},
50//! )
51//!
52//! // Search with aggregation
53//! collection.search(
54//!     query,
//!     aggregate="max",  // max|mean|first|last|sum
56//! )
57//! ```
58
59use std::collections::HashMap;
60use std::sync::Arc;
61
62use parking_lot::RwLock;
63
64// ============================================================================
65// Types
66// ============================================================================
67
/// Stable, caller-supplied identifier of a document.
pub type DocId = String;

/// Identifier of a single stored vector, assigned by the storage layer.
pub type InternalId = u64;

/// Position of a chunk (part) inside its document.
pub type ChunkIndex = u32;
76
77// ============================================================================
78// Aggregation
79// ============================================================================
80
/// Strategy for collapsing the per-chunk scores of one document into a single
/// document-level score.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum AggregationMethod {
    /// Take the highest chunk score (the default; recommended for most use cases).
    ///
    /// Answers the question "did any single chunk match well?".
    #[default]
    Max,

    /// Take the arithmetic mean of the chunk scores.
    ///
    /// Useful as a measure of overall document relevance.
    Mean,

    /// Take the score of the first chunk.
    ///
    /// Useful when chunk order is meaningful (e.g. the abstract comes first).
    First,

    /// Take the score of the last chunk.
    Last,

    /// Add up all chunk scores (sparse-like behavior).
    Sum,
}
103
104impl AggregationMethod {
105    /// Parse from string
106    pub fn from_str(s: &str) -> Option<Self> {
107        match s.to_lowercase().as_str() {
108            "max" => Some(Self::Max),
109            "mean" | "avg" | "average" => Some(Self::Mean),
110            "first" => Some(Self::First),
111            "last" => Some(Self::Last),
112            "sum" => Some(Self::Sum),
113            _ => None,
114        }
115    }
116}
117
118/// Aggregate scores for a document
119#[derive(Debug, Clone)]
120pub struct DocumentScore {
121    /// Document ID
122    pub doc_id: DocId,
123
124    /// Aggregated score
125    pub score: f32,
126
127    /// Best matching chunk index (for max aggregation)
128    pub best_chunk: Option<ChunkIndex>,
129
130    /// Number of chunks that matched
131    pub matched_chunks: usize,
132
133    /// All chunk scores (optional, for debugging)
134    pub chunk_scores: Option<Vec<(ChunkIndex, f32)>>,
135}
136
137impl DocumentScore {
138    /// Create from chunk scores with aggregation
139    pub fn aggregate(
140        doc_id: DocId,
141        chunk_scores: Vec<(ChunkIndex, f32)>,
142        method: AggregationMethod,
143        keep_details: bool,
144    ) -> Self {
145        if chunk_scores.is_empty() {
146            return Self {
147                doc_id,
148                score: 0.0,
149                best_chunk: None,
150                matched_chunks: 0,
151                chunk_scores: if keep_details { Some(Vec::new()) } else { None },
152            };
153        }
154
155        let matched_chunks = chunk_scores.len();
156
157        let (score, best_chunk) = match method {
158            AggregationMethod::Max => {
159                let (_idx, &(chunk, score)) = chunk_scores
160                    .iter()
161                    .enumerate()
162                    .max_by(|(_, a), (_, b)| a.1.partial_cmp(&b.1).unwrap())
163                    .unwrap();
164                (score, Some(chunk))
165            }
166            AggregationMethod::Mean => {
167                let sum: f32 = chunk_scores.iter().map(|(_, s)| s).sum();
168                (sum / chunk_scores.len() as f32, None)
169            }
170            AggregationMethod::First => {
171                let (chunk, score) = chunk_scores
172                    .iter()
173                    .min_by_key(|(idx, _)| *idx)
174                    .copied()
175                    .unwrap();
176                (score, Some(chunk))
177            }
178            AggregationMethod::Last => {
179                let (chunk, score) = chunk_scores
180                    .iter()
181                    .max_by_key(|(idx, _)| *idx)
182                    .copied()
183                    .unwrap();
184                (score, Some(chunk))
185            }
186            AggregationMethod::Sum => {
187                let sum: f32 = chunk_scores.iter().map(|(_, s)| s).sum();
188                (sum, None)
189            }
190        };
191
192        Self {
193            doc_id,
194            score,
195            best_chunk,
196            matched_chunks,
197            chunk_scores: if keep_details {
198                Some(chunk_scores)
199            } else {
200                None
201            },
202        }
203    }
204}
205
206// ============================================================================
207// Multi-Vector Index Mapping
208// ============================================================================
209
210/// Mapping from internal vector IDs to document IDs and chunk indices
211#[derive(Debug, Clone)]
212pub struct MultiVectorMapping {
213    /// Map from internal ID to (doc_id, chunk_index)
214    internal_to_doc: HashMap<InternalId, (DocId, ChunkIndex)>,
215
216    /// Map from doc_id to list of internal IDs (ordered by chunk index)
217    doc_to_internal: HashMap<DocId, Vec<InternalId>>,
218
219    /// Next internal ID to assign
220    next_internal_id: InternalId,
221}
222
223impl MultiVectorMapping {
224    /// Create a new empty mapping
225    pub fn new() -> Self {
226        Self {
227            internal_to_doc: HashMap::new(),
228            doc_to_internal: HashMap::new(),
229            next_internal_id: 0,
230        }
231    }
232
233    /// Insert a multi-vector document, returning the internal IDs
234    pub fn insert_document(&mut self, doc_id: DocId, num_chunks: usize) -> Vec<InternalId> {
235        // Remove existing if present
236        self.remove_document(&doc_id);
237
238        let mut internal_ids = Vec::with_capacity(num_chunks);
239
240        for chunk_idx in 0..num_chunks {
241            let internal_id = self.next_internal_id;
242            self.next_internal_id += 1;
243
244            self.internal_to_doc
245                .insert(internal_id, (doc_id.clone(), chunk_idx as ChunkIndex));
246            internal_ids.push(internal_id);
247        }
248
249        self.doc_to_internal.insert(doc_id, internal_ids.clone());
250
251        internal_ids
252    }
253
254    /// Remove a document and its vectors
255    pub fn remove_document(&mut self, doc_id: &str) -> Option<Vec<InternalId>> {
256        if let Some(internal_ids) = self.doc_to_internal.remove(doc_id) {
257            for id in &internal_ids {
258                self.internal_to_doc.remove(id);
259            }
260            Some(internal_ids)
261        } else {
262            None
263        }
264    }
265
266    /// Lookup document ID and chunk index for an internal ID
267    #[inline]
268    pub fn get_doc(&self, internal_id: InternalId) -> Option<(&DocId, ChunkIndex)> {
269        self.internal_to_doc.get(&internal_id).map(|(d, c)| (d, *c))
270    }
271
272    /// Get all internal IDs for a document
273    pub fn get_internal_ids(&self, doc_id: &str) -> Option<&[InternalId]> {
274        self.doc_to_internal.get(doc_id).map(|v| v.as_slice())
275    }
276
277    /// Check if a document exists
278    pub fn has_document(&self, doc_id: &str) -> bool {
279        self.doc_to_internal.contains_key(doc_id)
280    }
281
282    /// Get the number of documents
283    pub fn num_documents(&self) -> usize {
284        self.doc_to_internal.len()
285    }
286
287    /// Get the total number of vectors
288    pub fn num_vectors(&self) -> usize {
289        self.internal_to_doc.len()
290    }
291}
292
293impl Default for MultiVectorMapping {
294    fn default() -> Self {
295        Self::new()
296    }
297}
298
299// ============================================================================
300// Multi-Vector Aggregator
301// ============================================================================
302
303/// Aggregates search results from vector level to document level
304pub struct MultiVectorAggregator {
305    /// Mapping from internal IDs to documents
306    mapping: Arc<RwLock<MultiVectorMapping>>,
307
308    /// Default aggregation method
309    default_method: AggregationMethod,
310}
311
312impl MultiVectorAggregator {
313    /// Create a new aggregator
314    pub fn new(mapping: Arc<RwLock<MultiVectorMapping>>) -> Self {
315        Self {
316            mapping,
317            default_method: AggregationMethod::Max,
318        }
319    }
320
321    /// Set the default aggregation method
322    pub fn with_default_method(mut self, method: AggregationMethod) -> Self {
323        self.default_method = method;
324        self
325    }
326
327    /// Aggregate vector search results to document results
328    ///
329    /// Input: Vec<(internal_id, score)> from vector search
330    /// Output: Vec<DocumentScore> sorted by aggregated score
331    pub fn aggregate(
332        &self,
333        vector_results: &[(InternalId, f32)],
334        method: Option<AggregationMethod>,
335        limit: usize,
336    ) -> Vec<DocumentScore> {
337        let method = method.unwrap_or(self.default_method);
338        let mapping = self.mapping.read();
339
340        // Group by document
341        let mut doc_chunks: HashMap<&DocId, Vec<(ChunkIndex, f32)>> = HashMap::new();
342
343        for &(internal_id, score) in vector_results {
344            if let Some((doc_id, chunk_idx)) = mapping.get_doc(internal_id) {
345                doc_chunks
346                    .entry(doc_id)
347                    .or_default()
348                    .push((chunk_idx, score));
349            }
350        }
351
352        // Aggregate each document
353        let mut results: Vec<DocumentScore> = doc_chunks
354            .into_iter()
355            .map(|(doc_id, chunks)| DocumentScore::aggregate(doc_id.clone(), chunks, method, false))
356            .collect();
357
358        // Sort by score descending
359        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
360
361        // Limit
362        results.truncate(limit);
363
364        results
365    }
366
367    /// Aggregate with detailed chunk information
368    pub fn aggregate_detailed(
369        &self,
370        vector_results: &[(InternalId, f32)],
371        method: Option<AggregationMethod>,
372        limit: usize,
373    ) -> Vec<DocumentScore> {
374        let method = method.unwrap_or(self.default_method);
375        let mapping = self.mapping.read();
376
377        // Group by document
378        let mut doc_chunks: HashMap<&DocId, Vec<(ChunkIndex, f32)>> = HashMap::new();
379
380        for &(internal_id, score) in vector_results {
381            if let Some((doc_id, chunk_idx)) = mapping.get_doc(internal_id) {
382                doc_chunks
383                    .entry(doc_id)
384                    .or_default()
385                    .push((chunk_idx, score));
386            }
387        }
388
389        // Aggregate each document with details
390        let mut results: Vec<DocumentScore> = doc_chunks
391            .into_iter()
392            .map(|(doc_id, chunks)| DocumentScore::aggregate(doc_id.clone(), chunks, method, true))
393            .collect();
394
395        // Sort by score descending
396        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
397
398        // Limit
399        results.truncate(limit);
400
401        results
402    }
403}
404
405// ============================================================================
406// Multi-Vector Collection (High-Level API)
407// ============================================================================
408
409/// Configuration for multi-vector storage
410#[derive(Debug, Clone)]
411pub struct MultiVectorConfig {
412    /// Maximum chunks per document
413    pub max_chunks_per_doc: usize,
414
415    /// Default aggregation method
416    pub default_aggregation: AggregationMethod,
417
418    /// Over-fetch factor for ensuring enough unique documents
419    pub overfetch_factor: f32,
420}
421
422impl Default for MultiVectorConfig {
423    fn default() -> Self {
424        Self {
425            max_chunks_per_doc: 1000,
426            default_aggregation: AggregationMethod::Max,
427            overfetch_factor: 2.0,
428        }
429    }
430}
431
432/// Multi-vector document for insertion
433#[derive(Debug, Clone)]
434pub struct MultiVectorDocument {
435    /// Document ID (stable, user-provided)
436    pub id: DocId,
437
438    /// Vectors for each chunk
439    pub vectors: Vec<Vec<f32>>,
440
441    /// Optional: text content for each chunk (for hybrid search)
442    pub chunks_text: Option<Vec<String>>,
443
444    /// Document-level metadata
445    pub metadata: HashMap<String, serde_json::Value>,
446}
447
448impl MultiVectorDocument {
449    /// Create a new multi-vector document
450    pub fn new(id: impl Into<DocId>, vectors: Vec<Vec<f32>>) -> Self {
451        Self {
452            id: id.into(),
453            vectors,
454            chunks_text: None,
455            metadata: HashMap::new(),
456        }
457    }
458
459    /// Add chunk text content
460    pub fn with_text(mut self, chunks: Vec<String>) -> Self {
461        self.chunks_text = Some(chunks);
462        self
463    }
464
465    /// Add metadata
466    pub fn with_metadata(
467        mut self,
468        key: impl Into<String>,
469        value: impl Into<serde_json::Value>,
470    ) -> Self {
471        self.metadata.insert(key.into(), value.into());
472        self
473    }
474
475    /// Number of chunks
476    pub fn num_chunks(&self) -> usize {
477        self.vectors.len()
478    }
479
480    /// Validate the document
481    pub fn validate(&self, expected_dim: usize) -> Result<(), MultiVectorError> {
482        if self.vectors.is_empty() {
483            return Err(MultiVectorError::NoVectors);
484        }
485
486        for (i, v) in self.vectors.iter().enumerate() {
487            if v.len() != expected_dim {
488                return Err(MultiVectorError::DimensionMismatch {
489                    chunk: i,
490                    expected: expected_dim,
491                    actual: v.len(),
492                });
493            }
494        }
495
496        if let Some(ref texts) = self.chunks_text {
497            if texts.len() != self.vectors.len() {
498                return Err(MultiVectorError::ChunkCountMismatch {
499                    vectors: self.vectors.len(),
500                    texts: texts.len(),
501                });
502            }
503        }
504
505        Ok(())
506    }
507}
508
509/// Errors for multi-vector operations
510#[derive(Debug, thiserror::Error)]
511pub enum MultiVectorError {
512    #[error("document must have at least one vector")]
513    NoVectors,
514
515    #[error("dimension mismatch in chunk {chunk}: expected {expected}, got {actual}")]
516    DimensionMismatch {
517        chunk: usize,
518        expected: usize,
519        actual: usize,
520    },
521
522    #[error("chunk count mismatch: {vectors} vectors but {texts} texts")]
523    ChunkCountMismatch { vectors: usize, texts: usize },
524
525    #[error("too many chunks: {count} exceeds limit of {limit}")]
526    TooManyChunks { count: usize, limit: usize },
527
528    #[error("document not found: {0}")]
529    NotFound(DocId),
530}
531
532// ============================================================================
533// Tests
534// ============================================================================
535
#[cfg(test)]
mod tests {
    //! Unit tests for score aggregation, the ID mapping, the aggregator and
    //! document validation.

    use super::*;

    #[test]
    fn test_aggregation_max() {
        let chunks = vec![(0, 0.5), (1, 0.9), (2, 0.3)];

        let result =
            DocumentScore::aggregate("doc1".to_string(), chunks, AggregationMethod::Max, false);

        assert_eq!(result.score, 0.9);
        assert_eq!(result.best_chunk, Some(1));
        assert_eq!(result.matched_chunks, 3);
    }

    #[test]
    fn test_aggregation_mean() {
        let chunks = vec![(0, 0.6), (1, 0.9), (2, 0.3)];

        let result =
            DocumentScore::aggregate("doc1".to_string(), chunks, AggregationMethod::Mean, false);

        assert!((result.score - 0.6).abs() < 0.001); // (0.6 + 0.9 + 0.3) / 3 = 0.6
    }

    #[test]
    fn test_aggregation_first() {
        let chunks = vec![(2, 0.3), (0, 0.5), (1, 0.9)];

        let result =
            DocumentScore::aggregate("doc1".to_string(), chunks, AggregationMethod::First, false);

        assert_eq!(result.score, 0.5); // Chunk 0 has score 0.5
        assert_eq!(result.best_chunk, Some(0));
    }

    #[test]
    fn test_aggregation_last() {
        let chunks = vec![(2, 0.3), (0, 0.5), (1, 0.9)];

        let result =
            DocumentScore::aggregate("doc1".to_string(), chunks, AggregationMethod::Last, false);

        assert_eq!(result.score, 0.3); // Chunk 2 is the highest index present
        assert_eq!(result.best_chunk, Some(2));
    }

    #[test]
    fn test_aggregation_sum() {
        // Values chosen to be exactly representable so the sum can be compared exactly.
        let chunks = vec![(0, 0.25), (1, 0.5)];

        let result =
            DocumentScore::aggregate("doc1".to_string(), chunks, AggregationMethod::Sum, false);

        assert_eq!(result.score, 0.75);
        assert_eq!(result.best_chunk, None);
        assert_eq!(result.matched_chunks, 2);
    }

    #[test]
    fn test_aggregation_empty() {
        let result =
            DocumentScore::aggregate("doc1".to_string(), Vec::new(), AggregationMethod::Mean, true);

        assert_eq!(result.score, 0.0);
        assert_eq!(result.best_chunk, None);
        assert_eq!(result.matched_chunks, 0);
        assert_eq!(result.chunk_scores, Some(Vec::new()));
    }

    #[test]
    fn test_method_from_str() {
        assert_eq!(AggregationMethod::from_str("MAX"), Some(AggregationMethod::Max));
        assert_eq!(AggregationMethod::from_str("avg"), Some(AggregationMethod::Mean));
        assert_eq!(AggregationMethod::from_str("Average"), Some(AggregationMethod::Mean));
        assert_eq!(AggregationMethod::from_str("last"), Some(AggregationMethod::Last));
        assert_eq!(AggregationMethod::from_str("sum"), Some(AggregationMethod::Sum));
        assert_eq!(AggregationMethod::from_str("median"), None);
    }

    #[test]
    fn test_mapping_insert() {
        let mut mapping = MultiVectorMapping::new();

        let ids = mapping.insert_document("doc1".to_string(), 3);
        assert_eq!(ids.len(), 3);

        // Check reverse lookup
        for (i, &id) in ids.iter().enumerate() {
            let (doc_id, chunk) = mapping.get_doc(id).unwrap();
            assert_eq!(doc_id, "doc1");
            assert_eq!(chunk as usize, i);
        }
    }

    #[test]
    fn test_mapping_remove() {
        let mut mapping = MultiVectorMapping::new();

        let ids = mapping.insert_document("doc1".to_string(), 3);

        let removed = mapping.remove_document("doc1").unwrap();
        assert_eq!(removed, ids);

        // Should not be found
        assert!(mapping.get_doc(ids[0]).is_none());
        assert!(!mapping.has_document("doc1"));
    }

    #[test]
    fn test_mapping_reinsert_replaces() {
        let mut mapping = MultiVectorMapping::new();

        let old_ids = mapping.insert_document("doc1".to_string(), 3);
        let new_ids = mapping.insert_document("doc1".to_string(), 2);

        // Old vectors are gone and fresh IDs were handed out.
        assert_eq!(new_ids, vec![3, 4]);
        for id in old_ids {
            assert!(mapping.get_doc(id).is_none());
        }
        assert_eq!(mapping.get_internal_ids("doc1"), Some(&new_ids[..]));
        assert_eq!(mapping.num_documents(), 1);
        assert_eq!(mapping.num_vectors(), 2);
    }

    #[test]
    fn test_aggregator() {
        let mapping = Arc::new(RwLock::new(MultiVectorMapping::new()));

        // Insert two documents
        {
            let mut m = mapping.write();
            m.insert_document("doc1".to_string(), 3); // IDs 0, 1, 2
            m.insert_document("doc2".to_string(), 2); // IDs 3, 4
        }

        let aggregator = MultiVectorAggregator::new(mapping);

        // Simulate search results
        let vector_results = vec![
            (1, 0.95), // doc1, chunk 1
            (3, 0.90), // doc2, chunk 0
            (0, 0.85), // doc1, chunk 0
            (4, 0.80), // doc2, chunk 1
        ];

        let doc_results = aggregator.aggregate(&vector_results, Some(AggregationMethod::Max), 10);

        assert_eq!(doc_results.len(), 2);
        assert_eq!(doc_results[0].doc_id, "doc1");
        assert_eq!(doc_results[0].score, 0.95);
        assert_eq!(doc_results[1].doc_id, "doc2");
        assert_eq!(doc_results[1].score, 0.90);
    }

    #[test]
    fn test_aggregator_limit_and_details() {
        let mapping = Arc::new(RwLock::new(MultiVectorMapping::new()));
        {
            let mut m = mapping.write();
            m.insert_document("doc1".to_string(), 2); // IDs 0, 1
            m.insert_document("doc2".to_string(), 1); // ID 2
        }

        let aggregator =
            MultiVectorAggregator::new(mapping).with_default_method(AggregationMethod::Sum);

        let vector_results = vec![
            (0, 0.5),  // doc1, chunk 0
            (1, 0.25), // doc1, chunk 1
            (2, 0.5),  // doc2, chunk 0
            (99, 1.0), // unknown internal ID, must be ignored
        ];

        // The default method (Sum) applies when none is given; limit truncates.
        let top = aggregator.aggregate(&vector_results, None, 1);
        assert_eq!(top.len(), 1);
        assert_eq!(top[0].doc_id, "doc1");
        assert_eq!(top[0].score, 0.75);
        assert!(top[0].chunk_scores.is_none());

        let detailed = aggregator.aggregate_detailed(&vector_results, None, 10);
        assert_eq!(detailed.len(), 2);
        assert_eq!(detailed[0].doc_id, "doc1");
        assert_eq!(detailed[0].chunk_scores, Some(vec![(0, 0.5), (1, 0.25)]));
        assert_eq!(detailed[1].doc_id, "doc2");
    }

    #[test]
    fn test_multi_vector_document() {
        let doc = MultiVectorDocument::new("doc1", vec![vec![0.1, 0.2, 0.3], vec![0.4, 0.5, 0.6]])
            .with_text(vec!["chunk 1".to_string(), "chunk 2".to_string()])
            .with_metadata("author", serde_json::json!("Alice"));

        assert_eq!(doc.num_chunks(), 2);
        assert!(doc.validate(3).is_ok());
        assert!(doc.validate(4).is_err()); // Wrong dimension
    }

    #[test]
    fn test_document_validation_errors() {
        let empty = MultiVectorDocument::new("doc1", Vec::new());
        assert!(matches!(empty.validate(3), Err(MultiVectorError::NoVectors)));

        let wrong_dim = MultiVectorDocument::new("doc1", vec![vec![0.0; 3], vec![0.0; 2]]);
        assert!(matches!(
            wrong_dim.validate(3),
            Err(MultiVectorError::DimensionMismatch {
                chunk: 1,
                expected: 3,
                actual: 2,
            })
        ));

        let wrong_texts = MultiVectorDocument::new("doc1", vec![vec![0.0; 3], vec![0.0; 3]])
            .with_text(vec!["only one".to_string()]);
        assert!(matches!(
            wrong_texts.validate(3),
            Err(MultiVectorError::ChunkCountMismatch {
                vectors: 2,
                texts: 1,
            })
        ));
    }
}