velesdb_core/
collection.rs

1//! Collection management for `VelesDB`.
2
3use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{Bm25Index, HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
8
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use std::sync::Arc;
13
14/// Metadata for a collection.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CollectionConfig {
17    /// Name of the collection.
18    pub name: String,
19
20    /// Vector dimension.
21    pub dimension: usize,
22
23    /// Distance metric.
24    pub metric: DistanceMetric,
25
26    /// Number of points in the collection.
27    pub point_count: usize,
28}
29
30/// A collection of vectors with associated metadata.
31#[derive(Clone)]
32pub struct Collection {
33    /// Path to the collection data.
34    path: PathBuf,
35
36    /// Collection configuration.
37    config: Arc<RwLock<CollectionConfig>>,
38
39    /// Vector storage (on-disk, memory-mapped).
40    vector_storage: Arc<RwLock<MmapStorage>>,
41
42    /// Payload storage (on-disk, log-structured).
43    payload_storage: Arc<RwLock<LogPayloadStorage>>,
44
45    /// HNSW index for fast approximate nearest neighbor search.
46    index: Arc<HnswIndex>,
47
48    /// BM25 index for full-text search.
49    text_index: Arc<Bm25Index>,
50}
51
52impl Collection {
53    /// Creates a new collection at the specified path.
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if the directory cannot be created or the config cannot be saved.
58    pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
59        std::fs::create_dir_all(&path)?;
60
61        let name = path
62            .file_name()
63            .and_then(|n| n.to_str())
64            .unwrap_or("unknown")
65            .to_string();
66
67        let config = CollectionConfig {
68            name,
69            dimension,
70            metric,
71            point_count: 0,
72        };
73
74        // Initialize persistent storages
75        let vector_storage = Arc::new(RwLock::new(
76            MmapStorage::new(&path, dimension).map_err(Error::Io)?,
77        ));
78
79        let payload_storage = Arc::new(RwLock::new(
80            LogPayloadStorage::new(&path).map_err(Error::Io)?,
81        ));
82
83        // Create HNSW index
84        let index = Arc::new(HnswIndex::new(dimension, metric));
85
86        // Create BM25 index for full-text search
87        let text_index = Arc::new(Bm25Index::new());
88
89        let collection = Self {
90            path,
91            config: Arc::new(RwLock::new(config)),
92            vector_storage,
93            payload_storage,
94            index,
95            text_index,
96        };
97
98        collection.save_config()?;
99
100        Ok(collection)
101    }
102
103    /// Opens an existing collection from the specified path.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the config file cannot be read or parsed.
108    pub fn open(path: PathBuf) -> Result<Self> {
109        let config_path = path.join("config.json");
110        let config_data = std::fs::read_to_string(&config_path)?;
111        let config: CollectionConfig =
112            serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
113
114        // Open persistent storages
115        let vector_storage = Arc::new(RwLock::new(
116            MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
117        ));
118
119        let payload_storage = Arc::new(RwLock::new(
120            LogPayloadStorage::new(&path).map_err(Error::Io)?,
121        ));
122
123        // Load HNSW index if it exists, otherwise create new (empty)
124        let index = if path.join("hnsw.bin").exists() {
125            Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
126        } else {
127            Arc::new(HnswIndex::new(config.dimension, config.metric))
128        };
129
130        // Create and rebuild BM25 index from existing payloads
131        let text_index = Arc::new(Bm25Index::new());
132
133        // Rebuild BM25 index from persisted payloads
134        {
135            let storage = payload_storage.read();
136            let ids = storage.ids();
137            for id in ids {
138                if let Ok(Some(payload)) = storage.retrieve(id) {
139                    let text = Self::extract_text_from_payload(&payload);
140                    if !text.is_empty() {
141                        text_index.add_document(id, &text);
142                    }
143                }
144            }
145        }
146
147        Ok(Self {
148            path,
149            config: Arc::new(RwLock::new(config)),
150            vector_storage,
151            payload_storage,
152            index,
153            text_index,
154        })
155    }
156
157    /// Returns the collection configuration.
158    #[must_use]
159    pub fn config(&self) -> CollectionConfig {
160        self.config.read().clone()
161    }
162
163    /// Inserts or updates points in the collection.
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if any point has a mismatched dimension.
168    pub fn upsert(&self, points: Vec<Point>) -> Result<()> {
169        let config = self.config.read();
170        let dimension = config.dimension;
171        drop(config);
172
173        // Validate dimensions first
174        for point in &points {
175            if point.dimension() != dimension {
176                return Err(Error::DimensionMismatch {
177                    expected: dimension,
178                    actual: point.dimension(),
179                });
180            }
181        }
182
183        let mut vector_storage = self.vector_storage.write();
184        let mut payload_storage = self.payload_storage.write();
185
186        for point in points {
187            // 1. Store Vector
188            vector_storage
189                .store(point.id, &point.vector)
190                .map_err(Error::Io)?;
191
192            // 2. Store Payload (if present)
193            if let Some(payload) = &point.payload {
194                payload_storage
195                    .store(point.id, payload)
196                    .map_err(Error::Io)?;
197            } else {
198                // If payload is None, check if we need to delete existing payload?
199                // For now, let's assume upsert with None doesn't clear payload unless explicit.
200                // Or consistency: Point represents full state. If None, maybe we should delete?
201                // Let's stick to: if None, do nothing (keep existing) or delete?
202                // Typically upsert replaces. Let's say if None, we delete potential existing payload to be consistent.
203                let _ = payload_storage.delete(point.id); // Ignore error if not found
204            }
205
206            // 3. Update Vector Index
207            // Note: HnswIndex.insert() skips if ID already exists (no updates supported)
208            // For true upsert semantics, we'd need to remove then re-insert
209            self.index.insert(point.id, &point.vector);
210
211            // 4. Update BM25 Text Index
212            if let Some(payload) = &point.payload {
213                let text = Self::extract_text_from_payload(payload);
214                if !text.is_empty() {
215                    self.text_index.add_document(point.id, &text);
216                }
217            } else {
218                // Remove from text index if payload was cleared
219                self.text_index.remove_document(point.id);
220            }
221        }
222
223        // Update point count
224        let mut config = self.config.write();
225        config.point_count = vector_storage.len();
226
227        // Auto-flush for durability (MVP choice: consistent but slower)
228        // In prod, this might be backgrounded or explicit.
229        vector_storage.flush().map_err(Error::Io)?;
230        payload_storage.flush().map_err(Error::Io)?;
231        self.index.save(&self.path).map_err(Error::Io)?;
232
233        Ok(())
234    }
235
236    /// Bulk insert optimized for high-throughput import.
237    ///
238    /// # Performance
239    ///
240    /// This method is optimized for bulk loading:
241    /// - Uses parallel HNSW insertion (rayon)
242    /// - Single flush at the end (not per-point)
243    /// - ~2-3x faster than regular `upsert()` for large batches
244    ///
245    /// # Errors
246    ///
247    /// Returns an error if any point has a mismatched dimension.
248    pub fn upsert_bulk(&self, points: &[Point]) -> Result<usize> {
249        if points.is_empty() {
250            return Ok(0);
251        }
252
253        let config = self.config.read();
254        let dimension = config.dimension;
255        drop(config);
256
257        // Validate dimensions first
258        for point in points {
259            if point.dimension() != dimension {
260                return Err(Error::DimensionMismatch {
261                    expected: dimension,
262                    actual: point.dimension(),
263                });
264            }
265        }
266
267        // Perf: Collect vectors for parallel HNSW insertion (needed for clone anyway)
268        let vectors_for_hnsw: Vec<(u64, Vec<f32>)> =
269            points.iter().map(|p| (p.id, p.vector.clone())).collect();
270
271        // Perf: Single batch WAL write + contiguous mmap write
272        // Use references from vectors_for_hnsw to avoid double allocation
273        let vectors_for_storage: Vec<(u64, &[f32])> = vectors_for_hnsw
274            .iter()
275            .map(|(id, v)| (*id, v.as_slice()))
276            .collect();
277
278        let mut vector_storage = self.vector_storage.write();
279        vector_storage
280            .store_batch(&vectors_for_storage)
281            .map_err(Error::Io)?;
282        drop(vector_storage);
283
284        // Store payloads and update BM25 (still sequential for now)
285        let mut payload_storage = self.payload_storage.write();
286        for point in points {
287            if let Some(payload) = &point.payload {
288                payload_storage
289                    .store(point.id, payload)
290                    .map_err(Error::Io)?;
291
292                // Update BM25 text index
293                let text = Self::extract_text_from_payload(payload);
294                if !text.is_empty() {
295                    self.text_index.add_document(point.id, &text);
296                }
297            }
298        }
299        drop(payload_storage);
300
301        // Perf: Parallel HNSW insertion (CPU bound - benefits from parallelism)
302        let inserted = self.index.insert_batch_parallel(vectors_for_hnsw);
303        self.index.set_searching_mode();
304
305        // Update point count
306        let mut config = self.config.write();
307        config.point_count = self.vector_storage.read().len();
308        drop(config);
309
310        // Single flush at the end (not per-point)
311        self.vector_storage.write().flush().map_err(Error::Io)?;
312        self.payload_storage.write().flush().map_err(Error::Io)?;
313        self.index.save(&self.path).map_err(Error::Io)?;
314
315        Ok(inserted)
316    }
317
318    /// Retrieves points by their IDs.
319    #[must_use]
320    pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
321        let vector_storage = self.vector_storage.read();
322        let payload_storage = self.payload_storage.read();
323
324        ids.iter()
325            .map(|&id| {
326                // Retrieve vector
327                let vector = vector_storage.retrieve(id).ok().flatten()?;
328
329                // Retrieve payload
330                let payload = payload_storage.retrieve(id).ok().flatten();
331
332                Some(Point {
333                    id,
334                    vector,
335                    payload,
336                })
337            })
338            .collect()
339    }
340
341    /// Deletes points by their IDs.
342    ///
343    /// # Errors
344    ///
345    /// Returns an error if storage operations fail.
346    pub fn delete(&self, ids: &[u64]) -> Result<()> {
347        let mut vector_storage = self.vector_storage.write();
348        let mut payload_storage = self.payload_storage.write();
349
350        for &id in ids {
351            vector_storage.delete(id).map_err(Error::Io)?;
352            payload_storage.delete(id).map_err(Error::Io)?;
353            self.index.remove(id);
354        }
355
356        let mut config = self.config.write();
357        config.point_count = vector_storage.len();
358
359        Ok(())
360    }
361
362    /// Searches for the k nearest neighbors of the query vector.
363    ///
364    /// Uses HNSW index for fast approximate nearest neighbor search.
365    ///
366    /// # Errors
367    ///
368    /// Returns an error if the query vector dimension doesn't match the collection.
369    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
370        let config = self.config.read();
371
372        if query.len() != config.dimension {
373            return Err(Error::DimensionMismatch {
374                expected: config.dimension,
375                actual: query.len(),
376            });
377        }
378        drop(config);
379
380        // Use HNSW index for fast ANN search
381        let index_results = self.index.search(query, k);
382
383        let vector_storage = self.vector_storage.read();
384        let payload_storage = self.payload_storage.read();
385
386        // Map index results to SearchResult with full point data
387        let results: Vec<SearchResult> = index_results
388            .into_iter()
389            .filter_map(|(id, score)| {
390                // We need to fetch vector and payload
391                let vector = vector_storage.retrieve(id).ok().flatten()?;
392                let payload = payload_storage.retrieve(id).ok().flatten();
393
394                let point = Point {
395                    id,
396                    vector,
397                    payload,
398                };
399
400                Some(SearchResult::new(point, score))
401            })
402            .collect();
403
404        Ok(results)
405    }
406
407    /// Returns the number of points in the collection.
408    #[must_use]
409    pub fn len(&self) -> usize {
410        self.vector_storage.read().len()
411    }
412
413    /// Returns true if the collection is empty.
414    #[must_use]
415    pub fn is_empty(&self) -> bool {
416        self.vector_storage.read().is_empty()
417    }
418
419    /// Saves the collection configuration and index to disk.
420    ///
421    /// # Errors
422    ///
423    /// Returns an error if storage operations fail.
424    pub fn flush(&self) -> Result<()> {
425        self.save_config()?;
426        self.vector_storage.write().flush().map_err(Error::Io)?;
427        self.payload_storage.write().flush().map_err(Error::Io)?;
428        self.index.save(&self.path).map_err(Error::Io)?;
429        Ok(())
430    }
431
432    /// Saves the collection configuration to disk.
433    fn save_config(&self) -> Result<()> {
434        let config = self.config.read();
435        let config_path = self.path.join("config.json");
436        let config_data = serde_json::to_string_pretty(&*config)
437            .map_err(|e| Error::Serialization(e.to_string()))?;
438        std::fs::write(config_path, config_data)?;
439        Ok(())
440    }
441
442    /// Performs full-text search using BM25.
443    ///
444    /// # Arguments
445    ///
446    /// * `query` - Text query to search for
447    /// * `k` - Maximum number of results to return
448    ///
449    /// # Returns
450    ///
451    /// Vector of search results sorted by BM25 score (descending).
452    #[must_use]
453    pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
454        let bm25_results = self.text_index.search(query, k);
455
456        let vector_storage = self.vector_storage.read();
457        let payload_storage = self.payload_storage.read();
458
459        bm25_results
460            .into_iter()
461            .filter_map(|(id, score)| {
462                let vector = vector_storage.retrieve(id).ok().flatten()?;
463                let payload = payload_storage.retrieve(id).ok().flatten();
464
465                let point = Point {
466                    id,
467                    vector,
468                    payload,
469                };
470
471                Some(SearchResult::new(point, score))
472            })
473            .collect()
474    }
475
476    /// Performs hybrid search combining vector similarity and full-text search.
477    ///
478    /// Uses Reciprocal Rank Fusion (RRF) to combine results from both searches.
479    ///
480    /// # Arguments
481    ///
482    /// * `vector_query` - Query vector for similarity search
483    /// * `text_query` - Text query for BM25 search
484    /// * `k` - Maximum number of results to return
485    /// * `vector_weight` - Weight for vector results (0.0-1.0, default 0.5)
486    ///
487    /// # Errors
488    ///
489    /// Returns an error if the query vector dimension doesn't match.
490    pub fn hybrid_search(
491        &self,
492        vector_query: &[f32],
493        text_query: &str,
494        k: usize,
495        vector_weight: Option<f32>,
496    ) -> Result<Vec<SearchResult>> {
497        let config = self.config.read();
498        if vector_query.len() != config.dimension {
499            return Err(Error::DimensionMismatch {
500                expected: config.dimension,
501                actual: vector_query.len(),
502            });
503        }
504        drop(config);
505
506        let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
507        let text_weight = 1.0 - weight;
508
509        // Get vector search results (more than k to allow for fusion)
510        let vector_results = self.index.search(vector_query, k * 2);
511
512        // Get BM25 text search results
513        let text_results = self.text_index.search(text_query, k * 2);
514
515        // Perf: Apply RRF (Reciprocal Rank Fusion) with FxHashMap for faster hashing
516        // RRF score = 1 / (rank + 60) - the constant 60 is standard
517        let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
518
519        // Add vector scores with RRF
520        #[allow(clippy::cast_precision_loss)]
521        for (rank, (id, _)) in vector_results.iter().enumerate() {
522            let rrf_score = weight / (rank as f32 + 60.0);
523            *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
524        }
525
526        // Add text scores with RRF
527        #[allow(clippy::cast_precision_loss)]
528        for (rank, (id, _)) in text_results.iter().enumerate() {
529            let rrf_score = text_weight / (rank as f32 + 60.0);
530            *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
531        }
532
533        // Perf: Use partial sort for top-k instead of full sort
534        let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
535        if scored_ids.len() > k {
536            scored_ids.select_nth_unstable_by(k, |a, b| {
537                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
538            });
539            scored_ids.truncate(k);
540            scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
541        } else {
542            scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
543        }
544
545        // Fetch full point data
546        let vector_storage = self.vector_storage.read();
547        let payload_storage = self.payload_storage.read();
548
549        let results: Vec<SearchResult> = scored_ids
550            .into_iter()
551            .filter_map(|(id, score)| {
552                let vector = vector_storage.retrieve(id).ok().flatten()?;
553                let payload = payload_storage.retrieve(id).ok().flatten();
554
555                let point = Point {
556                    id,
557                    vector,
558                    payload,
559                };
560
561                Some(SearchResult::new(point, score))
562            })
563            .collect();
564
565        Ok(results)
566    }
567
568    /// Extracts all string values from a JSON payload for text indexing.
569    fn extract_text_from_payload(payload: &serde_json::Value) -> String {
570        let mut texts = Vec::new();
571        Self::collect_strings(payload, &mut texts);
572        texts.join(" ")
573    }
574
575    /// Recursively collects all string values from a JSON value.
576    fn collect_strings(value: &serde_json::Value, texts: &mut Vec<String>) {
577        match value {
578            serde_json::Value::String(s) => texts.push(s.clone()),
579            serde_json::Value::Array(arr) => {
580                for item in arr {
581                    Self::collect_strings(item, texts);
582                }
583            }
584            serde_json::Value::Object(obj) => {
585                for v in obj.values() {
586                    Self::collect_strings(v, texts);
587                }
588            }
589            _ => {}
590        }
591    }
592}
593
594#[cfg(test)]
595mod tests {
596    use super::*;
597    use serde_json::json;
598    use tempfile::tempdir;
599
600    #[test]
601    fn test_collection_create() {
602        let dir = tempdir().unwrap();
603        let path = dir.path().join("test_collection");
604
605        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
606        let config = collection.config();
607
608        assert_eq!(config.dimension, 3);
609        assert_eq!(config.metric, DistanceMetric::Cosine);
610        assert_eq!(config.point_count, 0);
611    }
612
613    #[test]
614    fn test_collection_upsert_and_search() {
615        let dir = tempdir().unwrap();
616        let path = dir.path().join("test_collection");
617
618        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
619
620        let points = vec![
621            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
622            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
623            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
624        ];
625
626        collection.upsert(points).unwrap();
627        assert_eq!(collection.len(), 3);
628
629        let query = vec![1.0, 0.0, 0.0];
630        let results = collection.search(&query, 2).unwrap();
631
632        assert_eq!(results.len(), 2);
633        assert_eq!(results[0].point.id, 1); // Most similar
634    }
635
636    #[test]
637    fn test_dimension_mismatch() {
638        let dir = tempdir().unwrap();
639        let path = dir.path().join("test_collection");
640
641        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
642
643        let points = vec![Point::without_payload(1, vec![1.0, 0.0])]; // Wrong dimension
644
645        let result = collection.upsert(points);
646        assert!(result.is_err());
647    }
648
649    #[test]
650    fn test_collection_open_existing() {
651        let dir = tempdir().unwrap();
652        let path = dir.path().join("test_collection");
653
654        // Create and populate collection
655        {
656            let collection =
657                Collection::create(path.clone(), 3, DistanceMetric::Euclidean).unwrap();
658            let points = vec![
659                Point::without_payload(1, vec![1.0, 2.0, 3.0]),
660                Point::without_payload(2, vec![4.0, 5.0, 6.0]),
661            ];
662            collection.upsert(points).unwrap();
663            collection.flush().unwrap();
664        }
665
666        // Reopen and verify
667        let collection = Collection::open(path).unwrap();
668        let config = collection.config();
669
670        assert_eq!(config.dimension, 3);
671        assert_eq!(config.metric, DistanceMetric::Euclidean);
672        assert_eq!(collection.len(), 2);
673    }
674
675    #[test]
676    fn test_collection_get_points() {
677        let dir = tempdir().unwrap();
678        let path = dir.path().join("test_collection");
679
680        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
681        let points = vec![
682            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
683            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
684        ];
685        collection.upsert(points).unwrap();
686
687        // Get existing points
688        let retrieved = collection.get(&[1, 2, 999]);
689
690        assert!(retrieved[0].is_some());
691        assert_eq!(retrieved[0].as_ref().unwrap().id, 1);
692        assert!(retrieved[1].is_some());
693        assert_eq!(retrieved[1].as_ref().unwrap().id, 2);
694        assert!(retrieved[2].is_none()); // 999 doesn't exist
695    }
696
697    #[test]
698    fn test_collection_delete_points() {
699        let dir = tempdir().unwrap();
700        let path = dir.path().join("test_collection");
701
702        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
703        let points = vec![
704            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
705            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
706            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
707        ];
708        collection.upsert(points).unwrap();
709        assert_eq!(collection.len(), 3);
710
711        // Delete one point
712        collection.delete(&[2]).unwrap();
713        assert_eq!(collection.len(), 2);
714
715        // Verify it's gone
716        let retrieved = collection.get(&[2]);
717        assert!(retrieved[0].is_none());
718    }
719
720    #[test]
721    fn test_collection_is_empty() {
722        let dir = tempdir().unwrap();
723        let path = dir.path().join("test_collection");
724
725        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
726        assert!(collection.is_empty());
727
728        collection
729            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
730            .unwrap();
731        assert!(!collection.is_empty());
732    }
733
734    #[test]
735    fn test_collection_with_payload() {
736        let dir = tempdir().unwrap();
737        let path = dir.path().join("test_collection");
738
739        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
740
741        let points = vec![Point::new(
742            1,
743            vec![1.0, 0.0, 0.0],
744            Some(json!({"title": "Test Document", "category": "tech"})),
745        )];
746        collection.upsert(points).unwrap();
747
748        let retrieved = collection.get(&[1]);
749        assert!(retrieved[0].is_some());
750
751        let point = retrieved[0].as_ref().unwrap();
752        assert!(point.payload.is_some());
753        assert_eq!(point.payload.as_ref().unwrap()["title"], "Test Document");
754    }
755
756    #[test]
757    fn test_collection_search_dimension_mismatch() {
758        let dir = tempdir().unwrap();
759        let path = dir.path().join("test_collection");
760
761        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
762        collection
763            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
764            .unwrap();
765
766        // Search with wrong dimension
767        let result = collection.search(&[1.0, 0.0], 5);
768        assert!(result.is_err());
769    }
770
771    #[test]
772    fn test_collection_upsert_replaces_payload() {
773        let dir = tempdir().unwrap();
774        let path = dir.path().join("test_collection");
775
776        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
777
778        // Insert with payload
779        collection
780            .upsert(vec![Point::new(
781                1,
782                vec![1.0, 0.0, 0.0],
783                Some(json!({"version": 1})),
784            )])
785            .unwrap();
786
787        // Upsert without payload (should clear it)
788        collection
789            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
790            .unwrap();
791
792        let retrieved = collection.get(&[1]);
793        let point = retrieved[0].as_ref().unwrap();
794        assert!(point.payload.is_none());
795    }
796
797    #[test]
798    fn test_collection_flush() {
799        let dir = tempdir().unwrap();
800        let path = dir.path().join("test_collection");
801
802        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
803        collection
804            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
805            .unwrap();
806
807        // Explicit flush should succeed
808        let result = collection.flush();
809        assert!(result.is_ok());
810    }
811
812    #[test]
813    fn test_collection_euclidean_metric() {
814        let dir = tempdir().unwrap();
815        let path = dir.path().join("test_collection");
816
817        let collection = Collection::create(path, 3, DistanceMetric::Euclidean).unwrap();
818
819        let points = vec![
820            Point::without_payload(1, vec![0.0, 0.0, 0.0]),
821            Point::without_payload(2, vec![1.0, 0.0, 0.0]),
822            Point::without_payload(3, vec![10.0, 0.0, 0.0]),
823        ];
824        collection.upsert(points).unwrap();
825
826        let query = vec![0.5, 0.0, 0.0];
827        let results = collection.search(&query, 3).unwrap();
828
829        // Point 1 (0,0,0) and Point 2 (1,0,0) should be closest to query (0.5,0,0)
830        assert!(results[0].point.id == 1 || results[0].point.id == 2);
831    }
832
833    #[test]
834    fn test_collection_text_search() {
835        let dir = tempdir().unwrap();
836        let path = dir.path().join("test_collection");
837
838        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
839
840        let points = vec![
841            Point::new(
842                1,
843                vec![1.0, 0.0, 0.0],
844                Some(json!({"title": "Rust Programming", "content": "Learn Rust language"})),
845            ),
846            Point::new(
847                2,
848                vec![0.0, 1.0, 0.0],
849                Some(json!({"title": "Python Tutorial", "content": "Python is great"})),
850            ),
851            Point::new(
852                3,
853                vec![0.0, 0.0, 1.0],
854                Some(json!({"title": "Rust Performance", "content": "Rust is fast"})),
855            ),
856        ];
857        collection.upsert(points).unwrap();
858
859        // Search for "rust" - should match docs 1 and 3
860        let results = collection.text_search("rust", 10);
861        assert_eq!(results.len(), 2);
862
863        let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
864        assert!(ids.contains(&1));
865        assert!(ids.contains(&3));
866    }
867
868    #[test]
869    fn test_collection_hybrid_search() {
870        let dir = tempdir().unwrap();
871        let path = dir.path().join("test_collection");
872
873        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
874
875        let points = vec![
876            Point::new(
877                1,
878                vec![1.0, 0.0, 0.0],
879                Some(json!({"title": "Rust Programming"})),
880            ),
881            Point::new(
882                2,
883                vec![0.9, 0.1, 0.0], // Similar vector to query
884                Some(json!({"title": "Python Programming"})),
885            ),
886            Point::new(
887                3,
888                vec![0.0, 1.0, 0.0],
889                Some(json!({"title": "Rust Performance"})),
890            ),
891        ];
892        collection.upsert(points).unwrap();
893
894        // Hybrid search: vector close to [1,0,0], text "rust"
895        // Doc 1 matches both (vector + text)
896        // Doc 2 matches vector only
897        // Doc 3 matches text only
898        let query = vec![1.0, 0.0, 0.0];
899        let results = collection
900            .hybrid_search(&query, "rust", 3, Some(0.5))
901            .unwrap();
902
903        assert!(!results.is_empty());
904        // Doc 1 should rank high (matches both)
905        assert_eq!(results[0].point.id, 1);
906    }
907
908    #[test]
909    fn test_extract_text_from_payload() {
910        // Test nested payload extraction
911        let payload = json!({
912            "title": "Hello",
913            "meta": {
914                "author": "World",
915                "tags": ["rust", "fast"]
916            }
917        });
918
919        let text = Collection::extract_text_from_payload(&payload);
920        assert!(text.contains("Hello"));
921        assert!(text.contains("World"));
922        assert!(text.contains("rust"));
923        assert!(text.contains("fast"));
924    }
925
926    #[test]
927    fn test_text_search_empty_query() {
928        let dir = tempdir().unwrap();
929        let path = dir.path().join("test_collection");
930
931        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
932
933        let points = vec![Point::new(
934            1,
935            vec![1.0, 0.0, 0.0],
936            Some(json!({"content": "test document"})),
937        )];
938        collection.upsert(points).unwrap();
939
940        // Empty query should return empty results
941        let results = collection.text_search("", 10);
942        assert!(results.is_empty());
943    }
944
945    #[test]
946    fn test_text_search_no_payload() {
947        let dir = tempdir().unwrap();
948        let path = dir.path().join("test_collection");
949
950        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
951
952        // Points without payload
953        let points = vec![
954            Point::new(1, vec![1.0, 0.0, 0.0], None),
955            Point::new(2, vec![0.0, 1.0, 0.0], None),
956        ];
957        collection.upsert(points).unwrap();
958
959        // Text search should return empty (no text indexed)
960        let results = collection.text_search("test", 10);
961        assert!(results.is_empty());
962    }
963
964    #[test]
965    fn test_hybrid_search_text_weight_zero() {
966        let dir = tempdir().unwrap();
967        let path = dir.path().join("test_collection");
968
969        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
970
971        let points = vec![
972            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Rust"}))),
973            Point::new(2, vec![0.9, 0.1, 0.0], Some(json!({"title": "Python"}))),
974        ];
975        collection.upsert(points).unwrap();
976
977        // vector_weight=1.0 means text_weight=0.0 (pure vector search)
978        let query = vec![0.9, 0.1, 0.0];
979        let results = collection
980            .hybrid_search(&query, "rust", 2, Some(1.0))
981            .unwrap();
982
983        // Doc 2 should be first (closest vector) even though "rust" matches doc 1
984        assert_eq!(results[0].point.id, 2);
985    }
986
987    #[test]
988    fn test_hybrid_search_vector_weight_zero() {
989        let dir = tempdir().unwrap();
990        let path = dir.path().join("test_collection");
991
992        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
993
994        let points = vec![
995            Point::new(
996                1,
997                vec![1.0, 0.0, 0.0],
998                Some(json!({"title": "Rust programming language"})),
999            ),
1000            Point::new(
1001                2,
1002                vec![0.99, 0.01, 0.0], // Very close to query vector
1003                Some(json!({"title": "Python programming"})),
1004            ),
1005        ];
1006        collection.upsert(points).unwrap();
1007
1008        // vector_weight=0.0 means text_weight=1.0 (pure text search)
1009        let query = vec![0.99, 0.01, 0.0];
1010        let results = collection
1011            .hybrid_search(&query, "rust", 2, Some(0.0))
1012            .unwrap();
1013
1014        // Doc 1 should be first (matches "rust") even though doc 2 has closer vector
1015        assert_eq!(results[0].point.id, 1);
1016    }
1017
1018    #[test]
1019    fn test_bm25_update_document() {
1020        let dir = tempdir().unwrap();
1021        let path = dir.path().join("test_collection");
1022
1023        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1024
1025        // Insert initial document
1026        let points = vec![Point::new(
1027            1,
1028            vec![1.0, 0.0, 0.0],
1029            Some(json!({"content": "rust programming"})),
1030        )];
1031        collection.upsert(points).unwrap();
1032
1033        // Verify it's indexed
1034        let results = collection.text_search("rust", 10);
1035        assert_eq!(results.len(), 1);
1036
1037        // Update document with different text
1038        let points = vec![Point::new(
1039            1,
1040            vec![1.0, 0.0, 0.0],
1041            Some(json!({"content": "python programming"})),
1042        )];
1043        collection.upsert(points).unwrap();
1044
1045        // Should no longer match "rust"
1046        let results = collection.text_search("rust", 10);
1047        assert!(results.is_empty());
1048
1049        // Should now match "python"
1050        let results = collection.text_search("python", 10);
1051        assert_eq!(results.len(), 1);
1052    }
1053
1054    #[test]
1055    fn test_bm25_large_dataset() {
1056        let dir = tempdir().unwrap();
1057        let path = dir.path().join("test_collection");
1058
1059        let collection = Collection::create(path, 4, DistanceMetric::Cosine).unwrap();
1060
1061        // Insert 100 documents
1062        let points: Vec<Point> = (0..100)
1063            .map(|i| {
1064                let content = if i % 10 == 0 {
1065                    format!("rust document number {i}")
1066                } else {
1067                    format!("other document number {i}")
1068                };
1069                Point::new(
1070                    i,
1071                    vec![0.1, 0.2, 0.3, 0.4],
1072                    Some(json!({"content": content})),
1073                )
1074            })
1075            .collect();
1076        collection.upsert(points).unwrap();
1077
1078        // Search for "rust" - should find 10 documents (0, 10, 20, ..., 90)
1079        let results = collection.text_search("rust", 100);
1080        assert_eq!(results.len(), 10);
1081
1082        // All results should have IDs divisible by 10
1083        for result in &results {
1084            assert_eq!(result.point.id % 10, 0);
1085        }
1086    }
1087
1088    #[test]
1089    fn test_bm25_persistence_on_reopen() {
1090        let dir = tempdir().unwrap();
1091        let path = dir.path().join("test_collection");
1092
1093        // Create collection and add documents
1094        {
1095            let collection = Collection::create(path.clone(), 4, DistanceMetric::Cosine).unwrap();
1096
1097            let points = vec![
1098                Point::new(
1099                    1,
1100                    vec![1.0, 0.0, 0.0, 0.0],
1101                    Some(json!({"content": "Rust programming language"})),
1102                ),
1103                Point::new(
1104                    2,
1105                    vec![0.0, 1.0, 0.0, 0.0],
1106                    Some(json!({"content": "Python tutorial"})),
1107                ),
1108                Point::new(
1109                    3,
1110                    vec![0.0, 0.0, 1.0, 0.0],
1111                    Some(json!({"content": "Rust is fast and safe"})),
1112                ),
1113            ];
1114            collection.upsert(points).unwrap();
1115
1116            // Verify search works before closing
1117            let results = collection.text_search("rust", 10);
1118            assert_eq!(results.len(), 2);
1119        }
1120
1121        // Reopen collection and verify BM25 index is rebuilt
1122        {
1123            let collection = Collection::open(path).unwrap();
1124
1125            // BM25 should be rebuilt from persisted payloads
1126            let results = collection.text_search("rust", 10);
1127            assert_eq!(results.len(), 2);
1128
1129            let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
1130            assert!(ids.contains(&1));
1131            assert!(ids.contains(&3));
1132        }
1133    }
1134
1135    // =========================================================================
1136    // Tests for upsert_bulk (optimized bulk import)
1137    // =========================================================================
1138
1139    #[test]
1140    fn test_upsert_bulk_basic() {
1141        let dir = tempdir().unwrap();
1142        let path = dir.path().join("test_collection");
1143        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1144
1145        let points = vec![
1146            Point::new(1, vec![1.0, 0.0, 0.0], None),
1147            Point::new(2, vec![0.0, 1.0, 0.0], None),
1148            Point::new(3, vec![0.0, 0.0, 1.0], None),
1149        ];
1150
1151        let inserted = collection.upsert_bulk(&points).unwrap();
1152        assert_eq!(inserted, 3);
1153        assert_eq!(collection.len(), 3);
1154    }
1155
1156    #[test]
1157    fn test_upsert_bulk_with_payload() {
1158        let dir = tempdir().unwrap();
1159        let path = dir.path().join("test_collection");
1160        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1161
1162        let points = vec![
1163            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Doc 1"}))),
1164            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"title": "Doc 2"}))),
1165        ];
1166
1167        collection.upsert_bulk(&points).unwrap();
1168        let retrieved = collection.get(&[1, 2]);
1169        assert_eq!(retrieved.len(), 2);
1170        assert!(retrieved[0].as_ref().unwrap().payload.is_some());
1171    }
1172
1173    #[test]
1174    fn test_upsert_bulk_empty() {
1175        let dir = tempdir().unwrap();
1176        let path = dir.path().join("test_collection");
1177        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1178
1179        let points: Vec<Point> = vec![];
1180        let inserted = collection.upsert_bulk(&points).unwrap();
1181        assert_eq!(inserted, 0);
1182    }
1183
1184    #[test]
1185    fn test_upsert_bulk_dimension_mismatch() {
1186        let dir = tempdir().unwrap();
1187        let path = dir.path().join("test_collection");
1188        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1189
1190        let points = vec![
1191            Point::new(1, vec![1.0, 0.0, 0.0], None),
1192            Point::new(2, vec![0.0, 1.0], None), // Wrong dimension
1193        ];
1194
1195        let result = collection.upsert_bulk(&points);
1196        assert!(result.is_err());
1197    }
1198
1199    #[test]
1200    #[allow(clippy::cast_precision_loss)]
1201    fn test_upsert_bulk_large_batch() {
1202        let dir = tempdir().unwrap();
1203        let path = dir.path().join("test_collection");
1204        let collection = Collection::create(path, 64, DistanceMetric::Cosine).unwrap();
1205
1206        let points: Vec<Point> = (0_u64..500)
1207            .map(|i| {
1208                let vector: Vec<f32> = (0_u64..64)
1209                    .map(|j| ((i + j) % 100) as f32 / 100.0)
1210                    .collect();
1211                Point::new(i, vector, None)
1212            })
1213            .collect();
1214
1215        let inserted = collection.upsert_bulk(&points).unwrap();
1216        assert_eq!(inserted, 500);
1217        assert_eq!(collection.len(), 500);
1218    }
1219
1220    #[test]
1221    fn test_upsert_bulk_search_works() {
1222        let dir = tempdir().unwrap();
1223        let path = dir.path().join("test_collection");
1224        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1225
1226        let points = vec![
1227            Point::new(1, vec![1.0, 0.0, 0.0], None),
1228            Point::new(2, vec![0.9, 0.1, 0.0], None),
1229            Point::new(3, vec![0.0, 1.0, 0.0], None),
1230        ];
1231
1232        collection.upsert_bulk(&points).unwrap();
1233
1234        let query = vec![1.0, 0.0, 0.0];
1235        let results = collection.search(&query, 3).unwrap();
1236        assert!(!results.is_empty());
1237        assert_eq!(results[0].point.id, 1);
1238    }
1239
1240    #[test]
1241    fn test_upsert_bulk_bm25_indexing() {
1242        let dir = tempdir().unwrap();
1243        let path = dir.path().join("test_collection");
1244        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1245
1246        let points = vec![
1247            Point::new(
1248                1,
1249                vec![1.0, 0.0, 0.0],
1250                Some(json!({"content": "Rust lang"})),
1251            ),
1252            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"content": "Python"}))),
1253            Point::new(
1254                3,
1255                vec![0.0, 0.0, 1.0],
1256                Some(json!({"content": "Rust fast"})),
1257            ),
1258        ];
1259
1260        collection.upsert_bulk(&points).unwrap();
1261        let results = collection.text_search("rust", 10);
1262        assert_eq!(results.len(), 2);
1263    }
1264}