velesdb_core/
collection.rs

1//! Collection management for `VelesDB`.
2
3use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{Bm25Index, HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::quantization::{BinaryQuantizedVector, QuantizedVector, StorageMode};
8use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
9
10use std::collections::HashMap;
11
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
17/// Metadata for a collection.
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CollectionConfig {
20    /// Name of the collection.
21    pub name: String,
22
23    /// Vector dimension.
24    pub dimension: usize,
25
26    /// Distance metric.
27    pub metric: DistanceMetric,
28
29    /// Number of points in the collection.
30    pub point_count: usize,
31
32    /// Storage mode for vectors (Full, SQ8, Binary).
33    #[serde(default)]
34    pub storage_mode: StorageMode,
35}
36
37/// A collection of vectors with associated metadata.
38#[derive(Clone)]
39pub struct Collection {
40    /// Path to the collection data.
41    path: PathBuf,
42
43    /// Collection configuration.
44    config: Arc<RwLock<CollectionConfig>>,
45
46    /// Vector storage (on-disk, memory-mapped).
47    vector_storage: Arc<RwLock<MmapStorage>>,
48
49    /// Payload storage (on-disk, log-structured).
50    payload_storage: Arc<RwLock<LogPayloadStorage>>,
51
52    /// HNSW index for fast approximate nearest neighbor search.
53    index: Arc<HnswIndex>,
54
55    /// BM25 index for full-text search.
56    text_index: Arc<Bm25Index>,
57
58    /// SQ8 quantized vectors cache (for SQ8 storage mode).
59    sq8_cache: Arc<RwLock<HashMap<u64, QuantizedVector>>>,
60
61    /// Binary quantized vectors cache (for Binary storage mode).
62    binary_cache: Arc<RwLock<HashMap<u64, BinaryQuantizedVector>>>,
63}
64
65impl Collection {
66    /// Creates a new collection at the specified path.
67    ///
68    /// # Errors
69    ///
70    /// Returns an error if the directory cannot be created or the config cannot be saved.
71    pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
72        Self::create_with_options(path, dimension, metric, StorageMode::default())
73    }
74
75    /// Creates a new collection with custom storage options.
76    ///
77    /// # Arguments
78    ///
79    /// * `path` - Path to the collection directory
80    /// * `dimension` - Vector dimension
81    /// * `metric` - Distance metric
82    /// * `storage_mode` - Vector storage mode (Full, SQ8, Binary)
83    ///
84    /// # Errors
85    ///
86    /// Returns an error if the directory cannot be created or the config cannot be saved.
87    pub fn create_with_options(
88        path: PathBuf,
89        dimension: usize,
90        metric: DistanceMetric,
91        storage_mode: StorageMode,
92    ) -> Result<Self> {
93        std::fs::create_dir_all(&path)?;
94
95        let name = path
96            .file_name()
97            .and_then(|n| n.to_str())
98            .unwrap_or("unknown")
99            .to_string();
100
101        let config = CollectionConfig {
102            name,
103            dimension,
104            metric,
105            point_count: 0,
106            storage_mode,
107        };
108
109        // Initialize persistent storages
110        let vector_storage = Arc::new(RwLock::new(
111            MmapStorage::new(&path, dimension).map_err(Error::Io)?,
112        ));
113
114        let payload_storage = Arc::new(RwLock::new(
115            LogPayloadStorage::new(&path).map_err(Error::Io)?,
116        ));
117
118        // Create HNSW index
119        let index = Arc::new(HnswIndex::new(dimension, metric));
120
121        // Create BM25 index for full-text search
122        let text_index = Arc::new(Bm25Index::new());
123
124        let collection = Self {
125            path,
126            config: Arc::new(RwLock::new(config)),
127            vector_storage,
128            payload_storage,
129            index,
130            text_index,
131            sq8_cache: Arc::new(RwLock::new(HashMap::new())),
132            binary_cache: Arc::new(RwLock::new(HashMap::new())),
133        };
134
135        collection.save_config()?;
136
137        Ok(collection)
138    }
139
140    /// Opens an existing collection from the specified path.
141    ///
142    /// # Errors
143    ///
144    /// Returns an error if the config file cannot be read or parsed.
145    pub fn open(path: PathBuf) -> Result<Self> {
146        let config_path = path.join("config.json");
147        let config_data = std::fs::read_to_string(&config_path)?;
148        let config: CollectionConfig =
149            serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
150
151        // Open persistent storages
152        let vector_storage = Arc::new(RwLock::new(
153            MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
154        ));
155
156        let payload_storage = Arc::new(RwLock::new(
157            LogPayloadStorage::new(&path).map_err(Error::Io)?,
158        ));
159
160        // Load HNSW index if it exists, otherwise create new (empty)
161        let index = if path.join("hnsw.bin").exists() {
162            Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
163        } else {
164            Arc::new(HnswIndex::new(config.dimension, config.metric))
165        };
166
167        // Create and rebuild BM25 index from existing payloads
168        let text_index = Arc::new(Bm25Index::new());
169
170        // Rebuild BM25 index from persisted payloads
171        {
172            let storage = payload_storage.read();
173            let ids = storage.ids();
174            for id in ids {
175                if let Ok(Some(payload)) = storage.retrieve(id) {
176                    let text = Self::extract_text_from_payload(&payload);
177                    if !text.is_empty() {
178                        text_index.add_document(id, &text);
179                    }
180                }
181            }
182        }
183
184        Ok(Self {
185            path,
186            config: Arc::new(RwLock::new(config)),
187            vector_storage,
188            payload_storage,
189            index,
190            text_index,
191            sq8_cache: Arc::new(RwLock::new(HashMap::new())),
192            binary_cache: Arc::new(RwLock::new(HashMap::new())),
193        })
194    }
195
196    /// Returns the collection configuration.
197    #[must_use]
198    pub fn config(&self) -> CollectionConfig {
199        self.config.read().clone()
200    }
201
202    /// Inserts or updates points in the collection.
203    ///
204    /// Accepts any iterator of points (Vec, slice, array, etc.)
205    ///
206    /// # Errors
207    ///
208    /// Returns an error if any point has a mismatched dimension.
209    pub fn upsert(&self, points: impl IntoIterator<Item = Point>) -> Result<()> {
210        let points: Vec<Point> = points.into_iter().collect();
211        let config = self.config.read();
212        let dimension = config.dimension;
213        let storage_mode = config.storage_mode;
214        drop(config);
215
216        // Validate dimensions first
217        for point in &points {
218            if point.dimension() != dimension {
219                return Err(Error::DimensionMismatch {
220                    expected: dimension,
221                    actual: point.dimension(),
222                });
223            }
224        }
225
226        let mut vector_storage = self.vector_storage.write();
227        let mut payload_storage = self.payload_storage.write();
228
229        // Get quantized caches if needed
230        let mut sq8_cache = match storage_mode {
231            StorageMode::SQ8 => Some(self.sq8_cache.write()),
232            _ => None,
233        };
234        let mut binary_cache = match storage_mode {
235            StorageMode::Binary => Some(self.binary_cache.write()),
236            _ => None,
237        };
238
239        for point in points {
240            // 1. Store Vector
241            vector_storage
242                .store(point.id, &point.vector)
243                .map_err(Error::Io)?;
244
245            // 2. Store quantized vector based on storage_mode
246            match storage_mode {
247                StorageMode::SQ8 => {
248                    if let Some(ref mut cache) = sq8_cache {
249                        let quantized = QuantizedVector::from_f32(&point.vector);
250                        cache.insert(point.id, quantized);
251                    }
252                }
253                StorageMode::Binary => {
254                    if let Some(ref mut cache) = binary_cache {
255                        let quantized = BinaryQuantizedVector::from_f32(&point.vector);
256                        cache.insert(point.id, quantized);
257                    }
258                }
259                StorageMode::Full => {}
260            }
261
262            // 3. Store Payload (if present)
263            if let Some(payload) = &point.payload {
264                payload_storage
265                    .store(point.id, payload)
266                    .map_err(Error::Io)?;
267            } else {
268                let _ = payload_storage.delete(point.id);
269            }
270
271            // 4. Update Vector Index
272            self.index.insert(point.id, &point.vector);
273
274            // 5. Update BM25 Text Index
275            if let Some(payload) = &point.payload {
276                let text = Self::extract_text_from_payload(payload);
277                if !text.is_empty() {
278                    self.text_index.add_document(point.id, &text);
279                }
280            } else {
281                self.text_index.remove_document(point.id);
282            }
283        }
284
285        // Update point count
286        let mut config = self.config.write();
287        config.point_count = vector_storage.len();
288
289        // Auto-flush for durability
290        vector_storage.flush().map_err(Error::Io)?;
291        payload_storage.flush().map_err(Error::Io)?;
292        self.index.save(&self.path).map_err(Error::Io)?;
293
294        Ok(())
295    }
296
297    /// Bulk insert optimized for high-throughput import.
298    ///
299    /// # Performance
300    ///
301    /// This method is optimized for bulk loading:
302    /// - Uses sequential HNSW insertion (reliable, no rayon conflicts)
303    /// - Single flush at the end (not per-point)
304    /// - No HNSW index save (deferred for performance)
305    /// - ~20-30% faster than previous parallel approach on large batches (5000+)
306    /// - Benchmark: 1.5-2.1 Kvec/s on 768D vectors
307    ///
308    /// # Errors
309    ///
310    /// Returns an error if any point has a mismatched dimension.
311    pub fn upsert_bulk(&self, points: &[Point]) -> Result<usize> {
312        if points.is_empty() {
313            return Ok(0);
314        }
315
316        let config = self.config.read();
317        let dimension = config.dimension;
318        drop(config);
319
320        // Validate dimensions first
321        for point in points {
322            if point.dimension() != dimension {
323                return Err(Error::DimensionMismatch {
324                    expected: dimension,
325                    actual: point.dimension(),
326                });
327            }
328        }
329
330        // Perf: Collect vectors for parallel HNSW insertion (needed for clone anyway)
331        let vectors_for_hnsw: Vec<(u64, Vec<f32>)> =
332            points.iter().map(|p| (p.id, p.vector.clone())).collect();
333
334        // Perf: Single batch WAL write + contiguous mmap write
335        // Use references from vectors_for_hnsw to avoid double allocation
336        let vectors_for_storage: Vec<(u64, &[f32])> = vectors_for_hnsw
337            .iter()
338            .map(|(id, v)| (*id, v.as_slice()))
339            .collect();
340
341        let mut vector_storage = self.vector_storage.write();
342        vector_storage
343            .store_batch(&vectors_for_storage)
344            .map_err(Error::Io)?;
345        drop(vector_storage);
346
347        // Store payloads and update BM25 (still sequential for now)
348        let mut payload_storage = self.payload_storage.write();
349        for point in points {
350            if let Some(payload) = &point.payload {
351                payload_storage
352                    .store(point.id, payload)
353                    .map_err(Error::Io)?;
354
355                // Update BM25 text index
356                let text = Self::extract_text_from_payload(payload);
357                if !text.is_empty() {
358                    self.text_index.add_document(point.id, &text);
359                }
360            }
361        }
362        drop(payload_storage);
363
364        // Perf: Parallel HNSW insertion (CPU bound - benefits from parallelism)
365        let inserted = self.index.insert_batch_parallel(vectors_for_hnsw);
366        self.index.set_searching_mode();
367
368        // Update point count
369        let mut config = self.config.write();
370        config.point_count = self.vector_storage.read().len();
371        drop(config);
372
373        // Perf: Only flush vector/payload storage (fast mmap sync)
374        // Skip expensive HNSW index save - will be saved on collection close/explicit flush
375        // This is safe: HNSW is in-memory and rebuilt from vector storage on restart
376        self.vector_storage.write().flush().map_err(Error::Io)?;
377        self.payload_storage.write().flush().map_err(Error::Io)?;
378        // NOTE: index.save() removed - too slow for batch operations
379        // Call collection.flush() explicitly if durability is critical
380
381        Ok(inserted)
382    }
383
384    /// Retrieves points by their IDs.
385    #[must_use]
386    pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
387        let vector_storage = self.vector_storage.read();
388        let payload_storage = self.payload_storage.read();
389
390        ids.iter()
391            .map(|&id| {
392                // Retrieve vector
393                let vector = vector_storage.retrieve(id).ok().flatten()?;
394
395                // Retrieve payload
396                let payload = payload_storage.retrieve(id).ok().flatten();
397
398                Some(Point {
399                    id,
400                    vector,
401                    payload,
402                })
403            })
404            .collect()
405    }
406
407    /// Deletes points by their IDs.
408    ///
409    /// # Errors
410    ///
411    /// Returns an error if storage operations fail.
412    pub fn delete(&self, ids: &[u64]) -> Result<()> {
413        let mut vector_storage = self.vector_storage.write();
414        let mut payload_storage = self.payload_storage.write();
415
416        for &id in ids {
417            vector_storage.delete(id).map_err(Error::Io)?;
418            payload_storage.delete(id).map_err(Error::Io)?;
419            self.index.remove(id);
420        }
421
422        let mut config = self.config.write();
423        config.point_count = vector_storage.len();
424
425        Ok(())
426    }
427
428    /// Searches for the k nearest neighbors of the query vector.
429    ///
430    /// Uses HNSW index for fast approximate nearest neighbor search.
431    ///
432    /// # Errors
433    ///
434    /// Returns an error if the query vector dimension doesn't match the collection.
435    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
436        let config = self.config.read();
437
438        if query.len() != config.dimension {
439            return Err(Error::DimensionMismatch {
440                expected: config.dimension,
441                actual: query.len(),
442            });
443        }
444        drop(config);
445
446        // Use HNSW index for fast ANN search
447        let index_results = self.index.search(query, k);
448
449        let vector_storage = self.vector_storage.read();
450        let payload_storage = self.payload_storage.read();
451
452        // Map index results to SearchResult with full point data
453        let results: Vec<SearchResult> = index_results
454            .into_iter()
455            .filter_map(|(id, score)| {
456                // We need to fetch vector and payload
457                let vector = vector_storage.retrieve(id).ok().flatten()?;
458                let payload = payload_storage.retrieve(id).ok().flatten();
459
460                let point = Point {
461                    id,
462                    vector,
463                    payload,
464                };
465
466                Some(SearchResult::new(point, score))
467            })
468            .collect();
469
470        Ok(results)
471    }
472
473    /// Performs fast vector similarity search returning only IDs and scores.
474    ///
475    /// Perf: This is ~3-5x faster than `search()` because it skips vector/payload retrieval.
476    /// Use this when you only need IDs and scores, not full point data.
477    ///
478    /// # Arguments
479    ///
480    /// * `query` - Query vector
481    /// * `k` - Maximum number of results to return
482    ///
483    /// # Returns
484    ///
485    /// Vector of (id, score) tuples sorted by similarity.
486    ///
487    /// # Errors
488    ///
489    /// Returns an error if the query vector dimension doesn't match the collection.
490    pub fn search_ids(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
491        let config = self.config.read();
492
493        if query.len() != config.dimension {
494            return Err(Error::DimensionMismatch {
495                expected: config.dimension,
496                actual: query.len(),
497            });
498        }
499        drop(config);
500
501        // Perf: Direct HNSW search without vector/payload retrieval (Round 8)
502        Ok(self.index.search(query, k))
503    }
504
505    /// Performs batch vector similarity search in parallel using rayon.
506    ///
507    /// Perf: This is significantly faster than calling `search` in a loop
508    /// because it parallelizes across CPU cores and amortizes lock overhead.
509    ///
510    /// # Arguments
511    ///
512    /// * `queries` - Slice of query vectors
513    /// * `k` - Maximum number of results per query
514    ///
515    /// # Returns
516    ///
517    /// Vector of search results for each query, with full point data.
518    ///
519    /// # Errors
520    ///
521    /// Returns an error if any query vector dimension doesn't match the collection.
522    pub fn search_batch_parallel(
523        &self,
524        queries: &[&[f32]],
525        k: usize,
526    ) -> Result<Vec<Vec<SearchResult>>> {
527        use crate::index::SearchQuality;
528
529        let config = self.config.read();
530        let dimension = config.dimension;
531        drop(config);
532
533        // Validate all query dimensions first
534        for query in queries {
535            if query.len() != dimension {
536                return Err(Error::DimensionMismatch {
537                    expected: dimension,
538                    actual: query.len(),
539                });
540            }
541        }
542
543        // Perf: Use parallel HNSW search (P0 optimization)
544        let index_results = self
545            .index
546            .search_batch_parallel(queries, k, SearchQuality::Balanced);
547
548        // Map results to SearchResult with full point data
549        let vector_storage = self.vector_storage.read();
550        let payload_storage = self.payload_storage.read();
551
552        let results: Vec<Vec<SearchResult>> = index_results
553            .into_iter()
554            .map(|query_results: Vec<(u64, f32)>| {
555                query_results
556                    .into_iter()
557                    .filter_map(|(id, score)| {
558                        let vector = vector_storage.retrieve(id).ok().flatten()?;
559                        let payload = payload_storage.retrieve(id).ok().flatten();
560                        Some(SearchResult {
561                            point: Point {
562                                id,
563                                vector,
564                                payload,
565                            },
566                            score,
567                        })
568                    })
569                    .collect()
570            })
571            .collect();
572
573        Ok(results)
574    }
575
576    /// Returns the number of points in the collection.
577    /// Perf: Uses cached `point_count` from config instead of acquiring storage lock
578    #[must_use]
579    pub fn len(&self) -> usize {
580        self.config.read().point_count
581    }
582
583    /// Returns true if the collection is empty.
584    /// Perf: Uses cached `point_count` from config instead of acquiring storage lock
585    #[must_use]
586    pub fn is_empty(&self) -> bool {
587        self.config.read().point_count == 0
588    }
589
590    /// Saves the collection configuration and index to disk.
591    ///
592    /// # Errors
593    ///
594    /// Returns an error if storage operations fail.
595    pub fn flush(&self) -> Result<()> {
596        self.save_config()?;
597        self.vector_storage.write().flush().map_err(Error::Io)?;
598        self.payload_storage.write().flush().map_err(Error::Io)?;
599        self.index.save(&self.path).map_err(Error::Io)?;
600        Ok(())
601    }
602
603    /// Saves the collection configuration to disk.
604    fn save_config(&self) -> Result<()> {
605        let config = self.config.read();
606        let config_path = self.path.join("config.json");
607        let config_data = serde_json::to_string_pretty(&*config)
608            .map_err(|e| Error::Serialization(e.to_string()))?;
609        std::fs::write(config_path, config_data)?;
610        Ok(())
611    }
612
613    /// Performs full-text search using BM25.
614    ///
615    /// # Arguments
616    ///
617    /// * `query` - Text query to search for
618    /// * `k` - Maximum number of results to return
619    ///
620    /// # Returns
621    ///
622    /// Vector of search results sorted by BM25 score (descending).
623    #[must_use]
624    pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
625        let bm25_results = self.text_index.search(query, k);
626
627        let vector_storage = self.vector_storage.read();
628        let payload_storage = self.payload_storage.read();
629
630        bm25_results
631            .into_iter()
632            .filter_map(|(id, score)| {
633                let vector = vector_storage.retrieve(id).ok().flatten()?;
634                let payload = payload_storage.retrieve(id).ok().flatten();
635
636                let point = Point {
637                    id,
638                    vector,
639                    payload,
640                };
641
642                Some(SearchResult::new(point, score))
643            })
644            .collect()
645    }
646
647    /// Performs hybrid search combining vector similarity and full-text search.
648    ///
649    /// Uses Reciprocal Rank Fusion (RRF) to combine results from both searches.
650    ///
651    /// # Arguments
652    ///
653    /// * `vector_query` - Query vector for similarity search
654    /// * `text_query` - Text query for BM25 search
655    /// * `k` - Maximum number of results to return
656    /// * `vector_weight` - Weight for vector results (0.0-1.0, default 0.5)
657    ///
658    /// # Errors
659    ///
660    /// Returns an error if the query vector dimension doesn't match.
661    pub fn hybrid_search(
662        &self,
663        vector_query: &[f32],
664        text_query: &str,
665        k: usize,
666        vector_weight: Option<f32>,
667    ) -> Result<Vec<SearchResult>> {
668        let config = self.config.read();
669        if vector_query.len() != config.dimension {
670            return Err(Error::DimensionMismatch {
671                expected: config.dimension,
672                actual: vector_query.len(),
673            });
674        }
675        drop(config);
676
677        let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
678        let text_weight = 1.0 - weight;
679
680        // Get vector search results (more than k to allow for fusion)
681        let vector_results = self.index.search(vector_query, k * 2);
682
683        // Get BM25 text search results
684        let text_results = self.text_index.search(text_query, k * 2);
685
686        // Perf: Apply RRF (Reciprocal Rank Fusion) with FxHashMap for faster hashing
687        // RRF score = 1 / (rank + 60) - the constant 60 is standard
688        let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
689
690        // Add vector scores with RRF
691        #[allow(clippy::cast_precision_loss)]
692        for (rank, (id, _)) in vector_results.iter().enumerate() {
693            let rrf_score = weight / (rank as f32 + 60.0);
694            *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
695        }
696
697        // Add text scores with RRF
698        #[allow(clippy::cast_precision_loss)]
699        for (rank, (id, _)) in text_results.iter().enumerate() {
700            let rrf_score = text_weight / (rank as f32 + 60.0);
701            *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
702        }
703
704        // Perf: Use partial sort for top-k instead of full sort
705        let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
706        if scored_ids.len() > k {
707            scored_ids.select_nth_unstable_by(k, |a, b| {
708                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
709            });
710            scored_ids.truncate(k);
711            scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
712        } else {
713            scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
714        }
715
716        // Fetch full point data
717        let vector_storage = self.vector_storage.read();
718        let payload_storage = self.payload_storage.read();
719
720        let results: Vec<SearchResult> = scored_ids
721            .into_iter()
722            .filter_map(|(id, score)| {
723                let vector = vector_storage.retrieve(id).ok().flatten()?;
724                let payload = payload_storage.retrieve(id).ok().flatten();
725
726                let point = Point {
727                    id,
728                    vector,
729                    payload,
730                };
731
732                Some(SearchResult::new(point, score))
733            })
734            .collect();
735
736        Ok(results)
737    }
738
739    /// Extracts all string values from a JSON payload for text indexing.
740    fn extract_text_from_payload(payload: &serde_json::Value) -> String {
741        let mut texts = Vec::new();
742        Self::collect_strings(payload, &mut texts);
743        texts.join(" ")
744    }
745
746    /// Recursively collects all string values from a JSON value.
747    fn collect_strings(value: &serde_json::Value, texts: &mut Vec<String>) {
748        match value {
749            serde_json::Value::String(s) => texts.push(s.clone()),
750            serde_json::Value::Array(arr) => {
751                for item in arr {
752                    Self::collect_strings(item, texts);
753                }
754            }
755            serde_json::Value::Object(obj) => {
756                for v in obj.values() {
757                    Self::collect_strings(v, texts);
758                }
759            }
760            _ => {}
761        }
762    }
763}
764
765#[cfg(test)]
766mod tests {
767    use super::*;
768    use serde_json::json;
769    use tempfile::tempdir;
770
771    #[test]
772    fn test_collection_create() {
773        let dir = tempdir().unwrap();
774        let path = dir.path().join("test_collection");
775
776        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
777        let config = collection.config();
778
779        assert_eq!(config.dimension, 3);
780        assert_eq!(config.metric, DistanceMetric::Cosine);
781        assert_eq!(config.point_count, 0);
782    }
783
784    #[test]
785    fn test_collection_upsert_and_search() {
786        let dir = tempdir().unwrap();
787        let path = dir.path().join("test_collection");
788
789        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
790
791        let points = vec![
792            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
793            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
794            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
795        ];
796
797        collection.upsert(points).unwrap();
798        assert_eq!(collection.len(), 3);
799
800        let query = vec![1.0, 0.0, 0.0];
801        let results = collection.search(&query, 2).unwrap();
802
803        assert_eq!(results.len(), 2);
804        assert_eq!(results[0].point.id, 1); // Most similar
805    }
806
807    #[test]
808    fn test_dimension_mismatch() {
809        let dir = tempdir().unwrap();
810        let path = dir.path().join("test_collection");
811
812        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
813
814        let points = vec![Point::without_payload(1, vec![1.0, 0.0])]; // Wrong dimension
815
816        let result = collection.upsert(points);
817        assert!(result.is_err());
818    }
819
820    #[test]
821    fn test_collection_open_existing() {
822        let dir = tempdir().unwrap();
823        let path = dir.path().join("test_collection");
824
825        // Create and populate collection
826        {
827            let collection =
828                Collection::create(path.clone(), 3, DistanceMetric::Euclidean).unwrap();
829            let points = vec![
830                Point::without_payload(1, vec![1.0, 2.0, 3.0]),
831                Point::without_payload(2, vec![4.0, 5.0, 6.0]),
832            ];
833            collection.upsert(points).unwrap();
834            collection.flush().unwrap();
835        }
836
837        // Reopen and verify
838        let collection = Collection::open(path).unwrap();
839        let config = collection.config();
840
841        assert_eq!(config.dimension, 3);
842        assert_eq!(config.metric, DistanceMetric::Euclidean);
843        assert_eq!(collection.len(), 2);
844    }
845
846    #[test]
847    fn test_collection_get_points() {
848        let dir = tempdir().unwrap();
849        let path = dir.path().join("test_collection");
850
851        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
852        let points = vec![
853            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
854            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
855        ];
856        collection.upsert(points).unwrap();
857
858        // Get existing points
859        let retrieved = collection.get(&[1, 2, 999]);
860
861        assert!(retrieved[0].is_some());
862        assert_eq!(retrieved[0].as_ref().unwrap().id, 1);
863        assert!(retrieved[1].is_some());
864        assert_eq!(retrieved[1].as_ref().unwrap().id, 2);
865        assert!(retrieved[2].is_none()); // 999 doesn't exist
866    }
867
868    #[test]
869    fn test_collection_delete_points() {
870        let dir = tempdir().unwrap();
871        let path = dir.path().join("test_collection");
872
873        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
874        let points = vec![
875            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
876            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
877            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
878        ];
879        collection.upsert(points).unwrap();
880        assert_eq!(collection.len(), 3);
881
882        // Delete one point
883        collection.delete(&[2]).unwrap();
884        assert_eq!(collection.len(), 2);
885
886        // Verify it's gone
887        let retrieved = collection.get(&[2]);
888        assert!(retrieved[0].is_none());
889    }
890
891    #[test]
892    fn test_collection_is_empty() {
893        let dir = tempdir().unwrap();
894        let path = dir.path().join("test_collection");
895
896        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
897        assert!(collection.is_empty());
898
899        collection
900            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
901            .unwrap();
902        assert!(!collection.is_empty());
903    }
904
905    #[test]
906    fn test_collection_with_payload() {
907        let dir = tempdir().unwrap();
908        let path = dir.path().join("test_collection");
909
910        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
911
912        let points = vec![Point::new(
913            1,
914            vec![1.0, 0.0, 0.0],
915            Some(json!({"title": "Test Document", "category": "tech"})),
916        )];
917        collection.upsert(points).unwrap();
918
919        let retrieved = collection.get(&[1]);
920        assert!(retrieved[0].is_some());
921
922        let point = retrieved[0].as_ref().unwrap();
923        assert!(point.payload.is_some());
924        assert_eq!(point.payload.as_ref().unwrap()["title"], "Test Document");
925    }
926
927    #[test]
928    fn test_collection_search_dimension_mismatch() {
929        let dir = tempdir().unwrap();
930        let path = dir.path().join("test_collection");
931
932        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
933        collection
934            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
935            .unwrap();
936
937        // Search with wrong dimension
938        let result = collection.search(&[1.0, 0.0], 5);
939        assert!(result.is_err());
940    }
941
942    #[test]
943    fn test_collection_search_ids_fast() {
944        // Round 8: Test fast search returning only IDs and scores
945        let dir = tempdir().unwrap();
946        let path = dir.path().join("test_collection");
947
948        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
949        collection
950            .upsert(vec![
951                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
952                Point::without_payload(2, vec![0.9, 0.1, 0.0]),
953                Point::without_payload(3, vec![0.0, 1.0, 0.0]),
954            ])
955            .unwrap();
956
957        // Fast search returns (id, score) tuples
958        let results = collection.search_ids(&[1.0, 0.0, 0.0], 2).unwrap();
959        assert_eq!(results.len(), 2);
960        assert_eq!(results[0].0, 1); // Best match
961        assert!(results[0].1 > results[1].1); // Scores are sorted
962    }
963
964    #[test]
965    fn test_collection_upsert_replaces_payload() {
966        let dir = tempdir().unwrap();
967        let path = dir.path().join("test_collection");
968
969        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
970
971        // Insert with payload
972        collection
973            .upsert(vec![Point::new(
974                1,
975                vec![1.0, 0.0, 0.0],
976                Some(json!({"version": 1})),
977            )])
978            .unwrap();
979
980        // Upsert without payload (should clear it)
981        collection
982            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
983            .unwrap();
984
985        let retrieved = collection.get(&[1]);
986        let point = retrieved[0].as_ref().unwrap();
987        assert!(point.payload.is_none());
988    }
989
990    #[test]
991    fn test_collection_flush() {
992        let dir = tempdir().unwrap();
993        let path = dir.path().join("test_collection");
994
995        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
996        collection
997            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
998            .unwrap();
999
1000        // Explicit flush should succeed
1001        let result = collection.flush();
1002        assert!(result.is_ok());
1003    }
1004
1005    #[test]
1006    fn test_collection_euclidean_metric() {
1007        let dir = tempdir().unwrap();
1008        let path = dir.path().join("test_collection");
1009
1010        let collection = Collection::create(path, 3, DistanceMetric::Euclidean).unwrap();
1011
1012        let points = vec![
1013            Point::without_payload(1, vec![0.0, 0.0, 0.0]),
1014            Point::without_payload(2, vec![1.0, 0.0, 0.0]),
1015            Point::without_payload(3, vec![10.0, 0.0, 0.0]),
1016        ];
1017        collection.upsert(points).unwrap();
1018
1019        let query = vec![0.5, 0.0, 0.0];
1020        let results = collection.search(&query, 3).unwrap();
1021
1022        // Point 1 (0,0,0) and Point 2 (1,0,0) should be closest to query (0.5,0,0)
1023        assert!(results[0].point.id == 1 || results[0].point.id == 2);
1024    }
1025
1026    #[test]
1027    fn test_collection_text_search() {
1028        let dir = tempdir().unwrap();
1029        let path = dir.path().join("test_collection");
1030
1031        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1032
1033        let points = vec![
1034            Point::new(
1035                1,
1036                vec![1.0, 0.0, 0.0],
1037                Some(json!({"title": "Rust Programming", "content": "Learn Rust language"})),
1038            ),
1039            Point::new(
1040                2,
1041                vec![0.0, 1.0, 0.0],
1042                Some(json!({"title": "Python Tutorial", "content": "Python is great"})),
1043            ),
1044            Point::new(
1045                3,
1046                vec![0.0, 0.0, 1.0],
1047                Some(json!({"title": "Rust Performance", "content": "Rust is fast"})),
1048            ),
1049        ];
1050        collection.upsert(points).unwrap();
1051
1052        // Search for "rust" - should match docs 1 and 3
1053        let results = collection.text_search("rust", 10);
1054        assert_eq!(results.len(), 2);
1055
1056        let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
1057        assert!(ids.contains(&1));
1058        assert!(ids.contains(&3));
1059    }
1060
1061    #[test]
1062    fn test_collection_hybrid_search() {
1063        let dir = tempdir().unwrap();
1064        let path = dir.path().join("test_collection");
1065
1066        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1067
1068        let points = vec![
1069            Point::new(
1070                1,
1071                vec![1.0, 0.0, 0.0],
1072                Some(json!({"title": "Rust Programming"})),
1073            ),
1074            Point::new(
1075                2,
1076                vec![0.9, 0.1, 0.0], // Similar vector to query
1077                Some(json!({"title": "Python Programming"})),
1078            ),
1079            Point::new(
1080                3,
1081                vec![0.0, 1.0, 0.0],
1082                Some(json!({"title": "Rust Performance"})),
1083            ),
1084        ];
1085        collection.upsert(points).unwrap();
1086
1087        // Hybrid search: vector close to [1,0,0], text "rust"
1088        // Doc 1 matches both (vector + text)
1089        // Doc 2 matches vector only
1090        // Doc 3 matches text only
1091        let query = vec![1.0, 0.0, 0.0];
1092        let results = collection
1093            .hybrid_search(&query, "rust", 3, Some(0.5))
1094            .unwrap();
1095
1096        assert!(!results.is_empty());
1097        // Doc 1 should rank high (matches both)
1098        assert_eq!(results[0].point.id, 1);
1099    }
1100
1101    #[test]
1102    fn test_extract_text_from_payload() {
1103        // Test nested payload extraction
1104        let payload = json!({
1105            "title": "Hello",
1106            "meta": {
1107                "author": "World",
1108                "tags": ["rust", "fast"]
1109            }
1110        });
1111
1112        let text = Collection::extract_text_from_payload(&payload);
1113        assert!(text.contains("Hello"));
1114        assert!(text.contains("World"));
1115        assert!(text.contains("rust"));
1116        assert!(text.contains("fast"));
1117    }
1118
1119    #[test]
1120    fn test_text_search_empty_query() {
1121        let dir = tempdir().unwrap();
1122        let path = dir.path().join("test_collection");
1123
1124        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1125
1126        let points = vec![Point::new(
1127            1,
1128            vec![1.0, 0.0, 0.0],
1129            Some(json!({"content": "test document"})),
1130        )];
1131        collection.upsert(points).unwrap();
1132
1133        // Empty query should return empty results
1134        let results = collection.text_search("", 10);
1135        assert!(results.is_empty());
1136    }
1137
1138    #[test]
1139    fn test_text_search_no_payload() {
1140        let dir = tempdir().unwrap();
1141        let path = dir.path().join("test_collection");
1142
1143        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1144
1145        // Points without payload
1146        let points = vec![
1147            Point::new(1, vec![1.0, 0.0, 0.0], None),
1148            Point::new(2, vec![0.0, 1.0, 0.0], None),
1149        ];
1150        collection.upsert(points).unwrap();
1151
1152        // Text search should return empty (no text indexed)
1153        let results = collection.text_search("test", 10);
1154        assert!(results.is_empty());
1155    }
1156
1157    #[test]
1158    fn test_hybrid_search_text_weight_zero() {
1159        let dir = tempdir().unwrap();
1160        let path = dir.path().join("test_collection");
1161
1162        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1163
1164        let points = vec![
1165            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Rust"}))),
1166            Point::new(2, vec![0.9, 0.1, 0.0], Some(json!({"title": "Python"}))),
1167        ];
1168        collection.upsert(points).unwrap();
1169
1170        // vector_weight=1.0 means text_weight=0.0 (pure vector search)
1171        let query = vec![0.9, 0.1, 0.0];
1172        let results = collection
1173            .hybrid_search(&query, "rust", 2, Some(1.0))
1174            .unwrap();
1175
1176        // Doc 2 should be first (closest vector) even though "rust" matches doc 1
1177        assert_eq!(results[0].point.id, 2);
1178    }
1179
1180    #[test]
1181    fn test_hybrid_search_vector_weight_zero() {
1182        let dir = tempdir().unwrap();
1183        let path = dir.path().join("test_collection");
1184
1185        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1186
1187        let points = vec![
1188            Point::new(
1189                1,
1190                vec![1.0, 0.0, 0.0],
1191                Some(json!({"title": "Rust programming language"})),
1192            ),
1193            Point::new(
1194                2,
1195                vec![0.99, 0.01, 0.0], // Very close to query vector
1196                Some(json!({"title": "Python programming"})),
1197            ),
1198        ];
1199        collection.upsert(points).unwrap();
1200
1201        // vector_weight=0.0 means text_weight=1.0 (pure text search)
1202        let query = vec![0.99, 0.01, 0.0];
1203        let results = collection
1204            .hybrid_search(&query, "rust", 2, Some(0.0))
1205            .unwrap();
1206
1207        // Doc 1 should be first (matches "rust") even though doc 2 has closer vector
1208        assert_eq!(results[0].point.id, 1);
1209    }
1210
1211    #[test]
1212    fn test_bm25_update_document() {
1213        let dir = tempdir().unwrap();
1214        let path = dir.path().join("test_collection");
1215
1216        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1217
1218        // Insert initial document
1219        let points = vec![Point::new(
1220            1,
1221            vec![1.0, 0.0, 0.0],
1222            Some(json!({"content": "rust programming"})),
1223        )];
1224        collection.upsert(points).unwrap();
1225
1226        // Verify it's indexed
1227        let results = collection.text_search("rust", 10);
1228        assert_eq!(results.len(), 1);
1229
1230        // Update document with different text
1231        let points = vec![Point::new(
1232            1,
1233            vec![1.0, 0.0, 0.0],
1234            Some(json!({"content": "python programming"})),
1235        )];
1236        collection.upsert(points).unwrap();
1237
1238        // Should no longer match "rust"
1239        let results = collection.text_search("rust", 10);
1240        assert!(results.is_empty());
1241
1242        // Should now match "python"
1243        let results = collection.text_search("python", 10);
1244        assert_eq!(results.len(), 1);
1245    }
1246
1247    #[test]
1248    fn test_bm25_large_dataset() {
1249        let dir = tempdir().unwrap();
1250        let path = dir.path().join("test_collection");
1251
1252        let collection = Collection::create(path, 4, DistanceMetric::Cosine).unwrap();
1253
1254        // Insert 100 documents
1255        let points: Vec<Point> = (0..100)
1256            .map(|i| {
1257                let content = if i % 10 == 0 {
1258                    format!("rust document number {i}")
1259                } else {
1260                    format!("other document number {i}")
1261                };
1262                Point::new(
1263                    i,
1264                    vec![0.1, 0.2, 0.3, 0.4],
1265                    Some(json!({"content": content})),
1266                )
1267            })
1268            .collect();
1269        collection.upsert(points).unwrap();
1270
1271        // Search for "rust" - should find 10 documents (0, 10, 20, ..., 90)
1272        let results = collection.text_search("rust", 100);
1273        assert_eq!(results.len(), 10);
1274
1275        // All results should have IDs divisible by 10
1276        for result in &results {
1277            assert_eq!(result.point.id % 10, 0);
1278        }
1279    }
1280
1281    #[test]
1282    fn test_bm25_persistence_on_reopen() {
1283        let dir = tempdir().unwrap();
1284        let path = dir.path().join("test_collection");
1285
1286        // Create collection and add documents
1287        {
1288            let collection = Collection::create(path.clone(), 4, DistanceMetric::Cosine).unwrap();
1289
1290            let points = vec![
1291                Point::new(
1292                    1,
1293                    vec![1.0, 0.0, 0.0, 0.0],
1294                    Some(json!({"content": "Rust programming language"})),
1295                ),
1296                Point::new(
1297                    2,
1298                    vec![0.0, 1.0, 0.0, 0.0],
1299                    Some(json!({"content": "Python tutorial"})),
1300                ),
1301                Point::new(
1302                    3,
1303                    vec![0.0, 0.0, 1.0, 0.0],
1304                    Some(json!({"content": "Rust is fast and safe"})),
1305                ),
1306            ];
1307            collection.upsert(points).unwrap();
1308
1309            // Verify search works before closing
1310            let results = collection.text_search("rust", 10);
1311            assert_eq!(results.len(), 2);
1312        }
1313
1314        // Reopen collection and verify BM25 index is rebuilt
1315        {
1316            let collection = Collection::open(path).unwrap();
1317
1318            // BM25 should be rebuilt from persisted payloads
1319            let results = collection.text_search("rust", 10);
1320            assert_eq!(results.len(), 2);
1321
1322            let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
1323            assert!(ids.contains(&1));
1324            assert!(ids.contains(&3));
1325        }
1326    }
1327
1328    // =========================================================================
1329    // Tests for upsert_bulk (optimized bulk import)
1330    // =========================================================================
1331
1332    #[test]
1333    fn test_upsert_bulk_basic() {
1334        let dir = tempdir().unwrap();
1335        let path = dir.path().join("test_collection");
1336        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1337
1338        let points = vec![
1339            Point::new(1, vec![1.0, 0.0, 0.0], None),
1340            Point::new(2, vec![0.0, 1.0, 0.0], None),
1341            Point::new(3, vec![0.0, 0.0, 1.0], None),
1342        ];
1343
1344        let inserted = collection.upsert_bulk(&points).unwrap();
1345        assert_eq!(inserted, 3);
1346        assert_eq!(collection.len(), 3);
1347    }
1348
1349    #[test]
1350    fn test_upsert_bulk_with_payload() {
1351        let dir = tempdir().unwrap();
1352        let path = dir.path().join("test_collection");
1353        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1354
1355        let points = vec![
1356            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Doc 1"}))),
1357            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"title": "Doc 2"}))),
1358        ];
1359
1360        collection.upsert_bulk(&points).unwrap();
1361        let retrieved = collection.get(&[1, 2]);
1362        assert_eq!(retrieved.len(), 2);
1363        assert!(retrieved[0].as_ref().unwrap().payload.is_some());
1364    }
1365
1366    #[test]
1367    fn test_upsert_bulk_empty() {
1368        let dir = tempdir().unwrap();
1369        let path = dir.path().join("test_collection");
1370        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1371
1372        let points: Vec<Point> = vec![];
1373        let inserted = collection.upsert_bulk(&points).unwrap();
1374        assert_eq!(inserted, 0);
1375    }
1376
1377    #[test]
1378    fn test_upsert_bulk_dimension_mismatch() {
1379        let dir = tempdir().unwrap();
1380        let path = dir.path().join("test_collection");
1381        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1382
1383        let points = vec![
1384            Point::new(1, vec![1.0, 0.0, 0.0], None),
1385            Point::new(2, vec![0.0, 1.0], None), // Wrong dimension
1386        ];
1387
1388        let result = collection.upsert_bulk(&points);
1389        assert!(result.is_err());
1390    }
1391
1392    #[test]
1393    #[allow(clippy::cast_precision_loss)]
1394    fn test_upsert_bulk_large_batch() {
1395        let dir = tempdir().unwrap();
1396        let path = dir.path().join("test_collection");
1397        let collection = Collection::create(path, 64, DistanceMetric::Cosine).unwrap();
1398
1399        let points: Vec<Point> = (0_u64..500)
1400            .map(|i| {
1401                let vector: Vec<f32> = (0_u64..64)
1402                    .map(|j| ((i + j) % 100) as f32 / 100.0)
1403                    .collect();
1404                Point::new(i, vector, None)
1405            })
1406            .collect();
1407
1408        let inserted = collection.upsert_bulk(&points).unwrap();
1409        assert_eq!(inserted, 500);
1410        assert_eq!(collection.len(), 500);
1411    }
1412
1413    #[test]
1414    fn test_upsert_bulk_search_works() {
1415        let dir = tempdir().unwrap();
1416        let path = dir.path().join("test_collection");
1417        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1418
1419        // Use more distinct vectors to ensure deterministic search results
1420        let points = vec![
1421            Point::new(1, vec![1.0, 0.0, 0.0], None),
1422            Point::new(2, vec![0.0, 1.0, 0.0], None),
1423            Point::new(3, vec![0.0, 0.0, 1.0], None),
1424        ];
1425
1426        collection.upsert_bulk(&points).unwrap();
1427
1428        let query = vec![1.0, 0.0, 0.0];
1429        let results = collection.search(&query, 3).unwrap();
1430        assert!(!results.is_empty());
1431        // With distinct orthogonal vectors, id=1 should always be the top result
1432        assert_eq!(results[0].point.id, 1);
1433    }
1434
1435    #[test]
1436    fn test_upsert_bulk_bm25_indexing() {
1437        let dir = tempdir().unwrap();
1438        let path = dir.path().join("test_collection");
1439        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1440
1441        let points = vec![
1442            Point::new(
1443                1,
1444                vec![1.0, 0.0, 0.0],
1445                Some(json!({"content": "Rust lang"})),
1446            ),
1447            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"content": "Python"}))),
1448            Point::new(
1449                3,
1450                vec![0.0, 0.0, 1.0],
1451                Some(json!({"content": "Rust fast"})),
1452            ),
1453        ];
1454
1455        collection.upsert_bulk(&points).unwrap();
1456        let results = collection.text_search("rust", 10);
1457        assert_eq!(results.len(), 2);
1458    }
1459}