vectx_core/
collection.rs

1use crate::{Error, Point, Result, Vector, HnswIndex, BM25Index, Filter, MultiVector};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
/// Static settings describing a collection: its name, vector shape, metric and indexes.
#[derive(Debug, Clone)]
pub struct CollectionConfig {
    /// Collection name.
    pub name: String,
    /// Dense vector dimensionality. `0` disables the dimension check on upsert
    /// (used for sparse-only collections).
    pub vector_dim: usize,
    /// Metric used to score search results.
    pub distance: Distance,
    /// Whether an HNSW index is allocated for approximate search.
    pub use_hnsw: bool,
    /// Whether a BM25 index is allocated over the payload `"text"` field.
    pub enable_bm25: bool,
}

impl Default for CollectionConfig {
    /// An unnamed, 128-dimensional cosine collection with HNSW enabled and BM25 disabled.
    fn default() -> Self {
        let name = String::default();
        let (vector_dim, distance) = (128, Distance::Cosine);
        let (use_hnsw, enable_bm25) = (true, false);
        Self {
            name,
            vector_dim,
            distance,
            use_hnsw,
            enable_bm25,
        }
    }
}

/// Similarity metric of a collection. Search scores are arranged so that higher is better.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Distance {
    /// Cosine similarity.
    Cosine,
    /// Euclidean (L2) distance; search reports it negated so closer points score higher.
    Euclidean,
    /// Raw dot product.
    Dot,
}
35
/// Kind of index requested for a payload field through `Collection::create_payload_index`.
///
/// The type is a plain fieldless tag, so it is `Copy` and `Hash` and can be passed
/// by value or used as a map/set key.
///
/// NOTE(review): inside `collection.rs` the value is only recorded in a map; the exact
/// semantics of each variant are defined by whatever consumes it — confirm there.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum PayloadIndexType {
    Keyword,
    Integer,
    Float,
    Bool,
    Geo,
    Text,
}
46
47/// A collection of vectors with metadata
48pub struct Collection {
49    config: CollectionConfig,
50    points: Arc<RwLock<HashMap<String, Point>>>,
51    hnsw: Option<Arc<RwLock<HnswIndex>>>,
52    bm25: Option<Arc<RwLock<BM25Index>>>,
53    hnsw_built: Arc<RwLock<bool>>,
54    hnsw_rebuilding: Arc<AtomicBool>,
55    batch_mode: Arc<RwLock<bool>>,
56    pending_points: Arc<RwLock<Vec<Point>>>,
57    /// Payload field indexes
58    payload_indexes: Arc<RwLock<HashMap<String, PayloadIndexType>>>,
59    /// Operation counter for tracking write operations
60    operation_counter: Arc<std::sync::atomic::AtomicU64>,
61}
62
63impl Collection {
64    pub fn new(config: CollectionConfig) -> Self {
65        let hnsw = if config.use_hnsw {
66            Some(Arc::new(RwLock::new(HnswIndex::new(16, 3))))
67        } else {
68            None
69        };
70
71        let bm25 = if config.enable_bm25 {
72            Some(Arc::new(RwLock::new(BM25Index::new())))
73        } else {
74            None
75        };
76
77        Self {
78            config,
79            points: Arc::new(RwLock::new(HashMap::new())),
80            hnsw,
81            bm25,
82            hnsw_built: Arc::new(RwLock::new(false)),
83            hnsw_rebuilding: Arc::new(AtomicBool::new(false)),
84            batch_mode: Arc::new(RwLock::new(false)),
85            pending_points: Arc::new(RwLock::new(Vec::new())),
86            payload_indexes: Arc::new(RwLock::new(HashMap::new())),
87            operation_counter: Arc::new(std::sync::atomic::AtomicU64::new(0)),
88        }
89    }
90    
91    /// Get next operation ID (atomically increments counter)
92    #[inline]
93    pub fn next_operation_id(&self) -> u64 {
94        self.operation_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
95    }
96
97    #[inline]
98    #[must_use]
99    pub fn name(&self) -> &str {
100        &self.config.name
101    }
102
103    #[inline]
104    #[must_use]
105    pub fn vector_dim(&self) -> usize {
106        self.config.vector_dim
107    }
108
109    #[inline]
110    #[must_use]
111    pub fn distance(&self) -> Distance {
112        self.config.distance
113    }
114
115    #[inline]
116    #[must_use]
117    pub fn use_hnsw(&self) -> bool {
118        self.config.use_hnsw
119    }
120
121    #[inline]
122    #[must_use]
123    pub fn enable_bm25(&self) -> bool {
124        self.config.enable_bm25
125    }
126
127    #[inline]
128    #[must_use]
129    pub fn count(&self) -> usize {
130        self.points.read().len()
131    }
132
133    #[inline]
134    #[must_use]
135    pub fn is_empty(&self) -> bool {
136        self.points.read().is_empty()
137    }
138
139    /// Get all points in the collection
140    pub fn get_all_points(&self) -> Vec<Point> {
141        self.points.read().values().cloned().collect()
142    }
143
144    /// Insert or update a point
145    pub fn upsert(&self, point: Point) -> Result<()> {
146        // Skip dimension check for sparse-only collections (vector_dim == 0)
147        if self.config.vector_dim > 0 && point.vector.dim() != self.config.vector_dim {
148            return Err(Error::InvalidDimension {
149                expected: self.config.vector_dim,
150                actual: point.vector.dim(),
151            });
152        }
153
154        let id_str = point.id.to_string();
155        
156        // Check if point exists and get its current version
157        let new_version = {
158            let points = self.points.read();
159            if let Some(existing) = points.get(&id_str) {
160                existing.version + 1
161            } else {
162                0
163            }
164        };
165        
166        // Create point with updated version
167        let mut versioned_point = point;
168        versioned_point.version = new_version;
169        
170        let in_batch = *self.batch_mode.read();
171        if in_batch {
172            self.points.write().insert(id_str.clone(), versioned_point.clone());
173            self.pending_points.write().push(versioned_point);
174            return Ok(());
175        }
176        
177        if let Some(hnsw) = &self.hnsw {
178            let built = *self.hnsw_built.read();
179            if built {
180                let mut normalized_point = versioned_point.clone();
181                normalized_point.vector.normalize();
182                
183                let mut index = hnsw.write();
184                index.insert(normalized_point);
185            }
186        }
187
188        if let Some(bm25) = &self.bm25 {
189            if let Some(payload) = &versioned_point.payload {
190                if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
191                    let mut index = bm25.write();
192                    index.insert_doc(&id_str, text);
193                }
194            }
195        }
196
197        self.points.write().insert(id_str, versioned_point);
198        Ok(())
199    }
200
201    /// Start batch insert mode
202    pub fn start_batch(&self) {
203        *self.batch_mode.write() = true;
204        self.pending_points.write().clear();
205    }
206
207    /// End batch insert mode
208    pub fn end_batch(&self) -> Result<()> {
209        *self.batch_mode.write() = false;
210        
211        if let Some(hnsw) = &self.hnsw {
212            let points = self.points.read();
213            let point_count = points.len();
214            
215            const HNSW_REBUILD_THRESHOLD: usize = 10_000;
216            
217            if point_count > HNSW_REBUILD_THRESHOLD && !self.hnsw_rebuilding.load(Ordering::Acquire) {
218                self.hnsw_rebuilding.store(true, Ordering::Release);
219                let points_clone: Vec<Point> = points.values().cloned().collect();
220                let hnsw_clone = hnsw.clone();
221                let built_flag = self.hnsw_built.clone();
222                let rebuilding_flag = self.hnsw_rebuilding.clone();
223                
224                let job = crate::background::HnswRebuildJob::new(
225                    points_clone,
226                    hnsw_clone,
227                    built_flag,
228                    rebuilding_flag,
229                );
230                crate::background::get_background_system().submit(Box::new(job));
231            }
232        }
233        
234        self.pending_points.write().clear();
235        Ok(())
236    }
237
238    /// Batch insert multiple points
239    pub fn batch_upsert(&self, points: Vec<Point>) -> Result<()> {
240        self.start_batch();
241        for point in points {
242            self.upsert(point)?;
243        }
244        self.end_batch()?;
245        Ok(())
246    }
247
248    /// Batch insert with optional pre-warming
249    pub fn batch_upsert_with_prewarm(&self, points: Vec<Point>, prewarm: bool) -> Result<()> {
250        self.batch_upsert(points)?;
251        if prewarm {
252            self.prewarm_index()?;
253        }
254        Ok(())
255    }
256
257    /// Get a point by ID
258    #[inline]
259    pub fn get(&self, id: &str) -> Option<Point> {
260        self.points.read().get(id).cloned()
261    }
262
263    /// Delete a point by ID
264    pub fn delete(&self, id: &str) -> Result<bool> {
265        if let Some(hnsw) = &self.hnsw {
266            let mut index = hnsw.write();
267            index.remove(id);
268        }
269
270        if let Some(bm25) = &self.bm25 {
271            let mut index = bm25.write();
272            index.delete_doc(id);
273        }
274
275        let mut points = self.points.write();
276        Ok(points.remove(id).is_some())
277    }
278
279    /// Set payload values for a point (merge with existing)
280    pub fn set_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
281        let mut points = self.points.write();
282        if let Some(point) = points.get_mut(id) {
283            if let Some(existing) = &mut point.payload {
284                if let (Some(existing_obj), Some(new_obj)) = (existing.as_object_mut(), payload.as_object()) {
285                    for (key, value) in new_obj {
286                        existing_obj.insert(key.clone(), value.clone());
287                    }
288                }
289            } else {
290                point.payload = Some(payload);
291            }
292            Ok(true)
293        } else {
294            Ok(false)
295        }
296    }
297
298    /// Overwrite entire payload for a point
299    pub fn overwrite_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
300        let mut points = self.points.write();
301        if let Some(point) = points.get_mut(id) {
302            point.payload = Some(payload);
303            Ok(true)
304        } else {
305            Ok(false)
306        }
307    }
308
309    /// Delete specific payload keys from a point
310    pub fn delete_payload_keys(&self, id: &str, keys: &[String]) -> Result<bool> {
311        let mut points = self.points.write();
312        if let Some(point) = points.get_mut(id) {
313            if let Some(payload) = &mut point.payload {
314                if let Some(obj) = payload.as_object_mut() {
315                    for key in keys {
316                        obj.remove(key);
317                    }
318                }
319            }
320            Ok(true)
321        } else {
322            Ok(false)
323        }
324    }
325
326    /// Clear all payload from a point
327    pub fn clear_payload(&self, id: &str) -> Result<bool> {
328        let mut points = self.points.write();
329        if let Some(point) = points.get_mut(id) {
330            point.payload = None;
331            Ok(true)
332        } else {
333            Ok(false)
334        }
335    }
336
337    /// Update vector for a point
338    pub fn update_vector(&self, id: &str, vector: Vector) -> Result<bool> {
339        let mut points = self.points.write();
340        if let Some(point) = points.get_mut(id) {
341            point.vector = vector.clone();
342            
343            // Update HNSW index if present
344            if let Some(hnsw) = &self.hnsw {
345                let mut index = hnsw.write();
346                index.remove(id);
347                // Insert the updated point
348                index.insert(point.clone());
349            }
350            Ok(true)
351        } else {
352            Ok(false)
353        }
354    }
355
356    /// Update multivector for a point
357    pub fn update_multivector(&self, id: &str, multivector: Option<MultiVector>) -> Result<bool> {
358        let mut points = self.points.write();
359        if let Some(point) = points.get_mut(id) {
360            point.multivector = multivector;
361            Ok(true)
362        } else {
363            Ok(false)
364        }
365    }
366
367    /// Delete vector (set to empty) - for named vectors this would delete specific vector
368    pub fn delete_vector(&self, id: &str) -> Result<bool> {
369        // For now, deleting a vector means deleting the point
370        // In full implementation, named vectors would be individually deletable
371        self.delete(id)
372    }
373
374    /// Create a payload field index
375    pub fn create_payload_index(&self, field_name: &str, index_type: PayloadIndexType) -> Result<bool> {
376        let mut indexes = self.payload_indexes.write();
377        indexes.insert(field_name.to_string(), index_type);
378        Ok(true)
379    }
380
381    /// Delete a payload field index
382    pub fn delete_payload_index(&self, field_name: &str) -> Result<bool> {
383        let mut indexes = self.payload_indexes.write();
384        Ok(indexes.remove(field_name).is_some())
385    }
386
387    /// Get all payload indexes
388    pub fn get_payload_indexes(&self) -> HashMap<String, PayloadIndexType> {
389        self.payload_indexes.read().clone()
390    }
391
392    /// Check if a field is indexed
393    pub fn is_field_indexed(&self, field_name: &str) -> bool {
394        self.payload_indexes.read().contains_key(field_name)
395    }
396
397    /// Pre-warm HNSW index
398    pub fn prewarm_index(&self) -> Result<()> {
399        if let Some(hnsw) = &self.hnsw {
400            let mut built = self.hnsw_built.write();
401            if !*built {
402                let points = self.points.read();
403                if !points.is_empty() {
404                    let mut index = hnsw.write();
405                    *index = HnswIndex::new(16, 3);
406                    for point in points.values() {
407                        index.insert(point.clone());
408                    }
409                    *built = true;
410                }
411            }
412        }
413        Ok(())
414    }
415
416    /// Fast brute-force search using SIMD - optimal for small datasets
417    fn brute_force_search(&self, query: &Vector, limit: usize, filter: Option<&dyn Filter>) -> Vec<(Point, f32)> {
418        use rayon::prelude::*;
419        
420        let points = self.points.read();
421        let query_slice = query.as_slice();
422        let distance = self.config.distance.clone();
423        
424        // Collect points to a Vec for indexing
425        let point_vec: Vec<_> = points.values().collect();
426        
427        // Parallel scoring - compute scores without cloning points
428        // Only clone the final top-k results
429        // Use parallel only for 10K+ vectors (rayon has overhead)
430        let scored: Vec<(usize, f32)> = if point_vec.len() >= 10000 && filter.is_none() {
431            // Parallel path for larger datasets without filter
432            point_vec
433                .par_iter()
434                .enumerate()
435                .map(|(idx, point)| {
436                    let score = match distance {
437                        Distance::Cosine => {
438                            crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
439                        }
440                        Distance::Euclidean => {
441                            -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
442                        }
443                        Distance::Dot => {
444                            crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
445                        }
446                    };
447                    (idx, score)
448                })
449                .collect()
450        } else {
451            // Sequential path - optimized for common case (Cosine without filter)
452            let mut results = Vec::with_capacity(point_vec.len());
453            
454            if filter.is_none() && matches!(distance, Distance::Cosine) {
455                // Hot path: Cosine without filter - avoid branching
456                for (idx, point) in point_vec.iter().enumerate() {
457                    let score = crate::simd::dot_product_simd(query_slice, point.vector.as_slice());
458                    results.push((idx, score));
459                }
460            } else {
461                // General path with filter/distance checks
462                for (idx, point) in point_vec.iter().enumerate() {
463                    if let Some(f) = filter {
464                        if !f.matches(point) {
465                            continue;
466                        }
467                    }
468                    
469                    let score = match distance {
470                        Distance::Cosine => {
471                            crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
472                        }
473                        Distance::Euclidean => {
474                            -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
475                        }
476                        Distance::Dot => {
477                            crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
478                        }
479                    };
480                    
481                    results.push((idx, score));
482                }
483            }
484            results
485        };
486        
487        // Get top-k using partial sort
488        let mut scored = scored;
489        if scored.len() > limit {
490            scored.select_nth_unstable_by(limit, |a, b| {
491                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
492            });
493            scored.truncate(limit);
494        }
495        scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
496        
497        // Only clone the top-k points (avoiding cloning all points)
498        scored
499            .into_iter()
500            .map(|(idx, score)| (point_vec[idx].clone(), score))
501            .collect()
502    }
503
504    /// Search for similar vectors
505    /// Uses brute-force for small datasets (<1000), HNSW for larger ones
506    pub fn search(
507        &self,
508        query: &Vector,
509        limit: usize,
510        filter: Option<&dyn Filter>,
511    ) -> Vec<(Point, f32)> {
512        let normalized_query = query.normalized();
513        let point_count = self.points.read().len();
514        
515        // Use brute-force for datasets up to 10K - SIMD is very fast and avoids HNSW overhead
516        const BRUTE_FORCE_THRESHOLD: usize = 10000;
517        if point_count < BRUTE_FORCE_THRESHOLD {
518            return self.brute_force_search(&normalized_query, limit, filter);
519        }
520        
521        if let Some(hnsw) = &self.hnsw {
522            // Check if we need to build the index first
523            {
524                let mut built = self.hnsw_built.write();
525                if !*built {
526                    let points = self.points.read();
527                    if !points.is_empty() {
528                        let mut index = hnsw.write();
529                        *index = HnswIndex::new(16, 3);
530                        for point in points.values() {
531                            index.insert(point.clone());
532                        }
533                        *built = true;
534                    }
535                }
536            }
537            
538            // Use write lock for search (HNSW search is now mutable for performance)
539            let mut index = hnsw.write();
540            let mut results = index.search(&normalized_query, limit, None);
541            
542            if let Some(f) = filter {
543                results.retain(|(point, _)| f.matches(point));
544            }
545            
546            results
547        } else {
548            let points = self.points.read();
549            let results: Vec<(Point, f32)> = points
550                .values()
551                .filter(|point| {
552                    filter.map(|f| f.matches(point)).unwrap_or(true)
553                })
554                .map(|point| {
555                    let score = match self.config.distance {
556                        Distance::Cosine => point.vector.cosine_similarity(query),
557                        Distance::Euclidean => -point.vector.l2_distance(query),
558                        Distance::Dot => {
559                            point.vector.as_slice()
560                                .iter()
561                                .zip(query.as_slice().iter())
562                                .map(|(a, b)| a * b)
563                                .sum()
564                        }
565                    };
566                    (point.clone(), score)
567                })
568                .collect();
569
570            let mut sorted = results;
571            sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
572            sorted.truncate(limit);
573            sorted
574        }
575    }
576
577    /// BM25 text search
578    pub fn search_text(&self, query: &str, limit: usize) -> Vec<(String, f32)> {
579        if let Some(bm25) = &self.bm25 {
580            let index = bm25.read();
581            index.search(query, limit)
582        } else {
583            Vec::new()
584        }
585    }
586    
587    /// Search using multivector MaxSim scoring (ColBERT-style)
588    /// 
589    /// For each sub-vector in the query, finds the maximum similarity 
590    /// with any sub-vector in each document, then sums all maximums.
591    pub fn search_multivector(
592        &self,
593        query: &MultiVector,
594        limit: usize,
595        filter: Option<&dyn Filter>,
596    ) -> Vec<(Point, f32)> {
597        let points = self.points.read();
598        
599        let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
600        
601        for point in points.values() {
602            if let Some(f) = filter {
603                if !f.matches(point) {
604                    continue;
605                }
606            }
607            
608            // Calculate MaxSim score
609            let score = if let Some(doc_mv) = &point.multivector {
610                // Both query and document have multivectors - use MaxSim
611                match self.config.distance {
612                    Distance::Cosine => query.max_sim_cosine(doc_mv),
613                    Distance::Euclidean => query.max_sim_l2(doc_mv),
614                    Distance::Dot => query.max_sim(doc_mv),
615                }
616            } else {
617                // Document has single vector - wrap it as multivector
618                let doc_mv = MultiVector::from_single(point.vector.as_slice().to_vec())
619                    .unwrap_or_else(|_| MultiVector::new(vec![vec![0.0; query.dim()]]).unwrap());
620                match self.config.distance {
621                    Distance::Cosine => query.max_sim_cosine(&doc_mv),
622                    Distance::Euclidean => query.max_sim_l2(&doc_mv),
623                    Distance::Dot => query.max_sim(&doc_mv),
624                }
625            };
626            
627            results.push((point.clone(), score));
628        }
629        
630        // Sort by score descending
631        if results.len() > limit {
632            results.select_nth_unstable_by(limit, |a, b| {
633                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
634            });
635            results.truncate(limit);
636        }
637        
638        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
639        results
640    }
641
642    /// Get all points
643    pub fn iter(&self) -> Vec<Point> {
644        self.points.read().values().cloned().collect()
645    }
646    
647    /// Search using sparse vectors (dot product on matching indices)
648    pub fn search_sparse(
649        &self,
650        query: &crate::point::SparseVector,
651        vector_name: &str,
652        limit: usize,
653        filter: Option<&dyn Filter>,
654    ) -> Vec<(Point, f32)> {
655        let points = self.points.read();
656        
657        let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
658        
659        for point in points.values() {
660            // Apply filter if provided
661            if let Some(f) = filter {
662                if !f.matches(point) {
663                    continue;
664                }
665            }
666            
667            // Get the named sparse vector from the point
668            if let Some(point_sparse) = point.sparse_vectors.get(vector_name) {
669                // Calculate dot product score
670                let score = query.dot(point_sparse);
671                
672                // Only include if score > 0 (at least one matching index)
673                if score > 0.0 {
674                    results.push((point.clone(), score));
675                }
676            }
677        }
678        
679        // Sort by score descending
680        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
681        results.truncate(limit);
682        
683        results
684    }
685}
686