velesdb_core/
collection.rs

1//! Collection management for `VelesDB`.
2
3use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{Bm25Index, HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
8
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use std::sync::Arc;
13
14/// Metadata for a collection.
15#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CollectionConfig {
17    /// Name of the collection.
18    pub name: String,
19
20    /// Vector dimension.
21    pub dimension: usize,
22
23    /// Distance metric.
24    pub metric: DistanceMetric,
25
26    /// Number of points in the collection.
27    pub point_count: usize,
28}
29
30/// A collection of vectors with associated metadata.
31#[derive(Clone)]
32pub struct Collection {
33    /// Path to the collection data.
34    path: PathBuf,
35
36    /// Collection configuration.
37    config: Arc<RwLock<CollectionConfig>>,
38
39    /// Vector storage (on-disk, memory-mapped).
40    vector_storage: Arc<RwLock<MmapStorage>>,
41
42    /// Payload storage (on-disk, log-structured).
43    payload_storage: Arc<RwLock<LogPayloadStorage>>,
44
45    /// HNSW index for fast approximate nearest neighbor search.
46    index: Arc<HnswIndex>,
47
48    /// BM25 index for full-text search.
49    text_index: Arc<Bm25Index>,
50}
51
52impl Collection {
53    /// Creates a new collection at the specified path.
54    ///
55    /// # Errors
56    ///
57    /// Returns an error if the directory cannot be created or the config cannot be saved.
58    pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
59        std::fs::create_dir_all(&path)?;
60
61        let name = path
62            .file_name()
63            .and_then(|n| n.to_str())
64            .unwrap_or("unknown")
65            .to_string();
66
67        let config = CollectionConfig {
68            name,
69            dimension,
70            metric,
71            point_count: 0,
72        };
73
74        // Initialize persistent storages
75        let vector_storage = Arc::new(RwLock::new(
76            MmapStorage::new(&path, dimension).map_err(Error::Io)?,
77        ));
78
79        let payload_storage = Arc::new(RwLock::new(
80            LogPayloadStorage::new(&path).map_err(Error::Io)?,
81        ));
82
83        // Create HNSW index
84        let index = Arc::new(HnswIndex::new(dimension, metric));
85
86        // Create BM25 index for full-text search
87        let text_index = Arc::new(Bm25Index::new());
88
89        let collection = Self {
90            path,
91            config: Arc::new(RwLock::new(config)),
92            vector_storage,
93            payload_storage,
94            index,
95            text_index,
96        };
97
98        collection.save_config()?;
99
100        Ok(collection)
101    }
102
103    /// Opens an existing collection from the specified path.
104    ///
105    /// # Errors
106    ///
107    /// Returns an error if the config file cannot be read or parsed.
108    pub fn open(path: PathBuf) -> Result<Self> {
109        let config_path = path.join("config.json");
110        let config_data = std::fs::read_to_string(&config_path)?;
111        let config: CollectionConfig =
112            serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
113
114        // Open persistent storages
115        let vector_storage = Arc::new(RwLock::new(
116            MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
117        ));
118
119        let payload_storage = Arc::new(RwLock::new(
120            LogPayloadStorage::new(&path).map_err(Error::Io)?,
121        ));
122
123        // Load HNSW index if it exists, otherwise create new (empty)
124        let index = if path.join("hnsw.bin").exists() {
125            Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
126        } else {
127            Arc::new(HnswIndex::new(config.dimension, config.metric))
128        };
129
130        // Create and rebuild BM25 index from existing payloads
131        let text_index = Arc::new(Bm25Index::new());
132
133        // Rebuild BM25 index from persisted payloads
134        {
135            let storage = payload_storage.read();
136            let ids = storage.ids();
137            for id in ids {
138                if let Ok(Some(payload)) = storage.retrieve(id) {
139                    let text = Self::extract_text_from_payload(&payload);
140                    if !text.is_empty() {
141                        text_index.add_document(id, &text);
142                    }
143                }
144            }
145        }
146
147        Ok(Self {
148            path,
149            config: Arc::new(RwLock::new(config)),
150            vector_storage,
151            payload_storage,
152            index,
153            text_index,
154        })
155    }
156
157    /// Returns the collection configuration.
158    #[must_use]
159    pub fn config(&self) -> CollectionConfig {
160        self.config.read().clone()
161    }
162
163    /// Inserts or updates points in the collection.
164    ///
165    /// # Errors
166    ///
167    /// Returns an error if any point has a mismatched dimension.
168    pub fn upsert(&self, points: Vec<Point>) -> Result<()> {
169        let config = self.config.read();
170        let dimension = config.dimension;
171        drop(config);
172
173        // Validate dimensions first
174        for point in &points {
175            if point.dimension() != dimension {
176                return Err(Error::DimensionMismatch {
177                    expected: dimension,
178                    actual: point.dimension(),
179                });
180            }
181        }
182
183        let mut vector_storage = self.vector_storage.write();
184        let mut payload_storage = self.payload_storage.write();
185
186        for point in points {
187            // 1. Store Vector
188            vector_storage
189                .store(point.id, &point.vector)
190                .map_err(Error::Io)?;
191
192            // 2. Store Payload (if present)
193            if let Some(payload) = &point.payload {
194                payload_storage
195                    .store(point.id, payload)
196                    .map_err(Error::Io)?;
197            } else {
198                // If payload is None, check if we need to delete existing payload?
199                // For now, let's assume upsert with None doesn't clear payload unless explicit.
200                // Or consistency: Point represents full state. If None, maybe we should delete?
201                // Let's stick to: if None, do nothing (keep existing) or delete?
202                // Typically upsert replaces. Let's say if None, we delete potential existing payload to be consistent.
203                let _ = payload_storage.delete(point.id); // Ignore error if not found
204            }
205
206            // 3. Update Vector Index
207            // Note: HnswIndex.insert() skips if ID already exists (no updates supported)
208            // For true upsert semantics, we'd need to remove then re-insert
209            self.index.insert(point.id, &point.vector);
210
211            // 4. Update BM25 Text Index
212            if let Some(payload) = &point.payload {
213                let text = Self::extract_text_from_payload(payload);
214                if !text.is_empty() {
215                    self.text_index.add_document(point.id, &text);
216                }
217            } else {
218                // Remove from text index if payload was cleared
219                self.text_index.remove_document(point.id);
220            }
221        }
222
223        // Update point count
224        let mut config = self.config.write();
225        config.point_count = vector_storage.len();
226
227        // Auto-flush for durability (MVP choice: consistent but slower)
228        // In prod, this might be backgrounded or explicit.
229        vector_storage.flush().map_err(Error::Io)?;
230        payload_storage.flush().map_err(Error::Io)?;
231        self.index.save(&self.path).map_err(Error::Io)?;
232
233        Ok(())
234    }
235
236    /// Retrieves points by their IDs.
237    #[must_use]
238    pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
239        let vector_storage = self.vector_storage.read();
240        let payload_storage = self.payload_storage.read();
241
242        ids.iter()
243            .map(|&id| {
244                // Retrieve vector
245                let vector = vector_storage.retrieve(id).ok().flatten()?;
246
247                // Retrieve payload
248                let payload = payload_storage.retrieve(id).ok().flatten();
249
250                Some(Point {
251                    id,
252                    vector,
253                    payload,
254                })
255            })
256            .collect()
257    }
258
259    /// Deletes points by their IDs.
260    ///
261    /// # Errors
262    ///
263    /// Returns an error if storage operations fail.
264    pub fn delete(&self, ids: &[u64]) -> Result<()> {
265        let mut vector_storage = self.vector_storage.write();
266        let mut payload_storage = self.payload_storage.write();
267
268        for &id in ids {
269            vector_storage.delete(id).map_err(Error::Io)?;
270            payload_storage.delete(id).map_err(Error::Io)?;
271            self.index.remove(id);
272        }
273
274        let mut config = self.config.write();
275        config.point_count = vector_storage.len();
276
277        Ok(())
278    }
279
280    /// Searches for the k nearest neighbors of the query vector.
281    ///
282    /// Uses HNSW index for fast approximate nearest neighbor search.
283    ///
284    /// # Errors
285    ///
286    /// Returns an error if the query vector dimension doesn't match the collection.
287    pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
288        let config = self.config.read();
289
290        if query.len() != config.dimension {
291            return Err(Error::DimensionMismatch {
292                expected: config.dimension,
293                actual: query.len(),
294            });
295        }
296        drop(config);
297
298        // Use HNSW index for fast ANN search
299        let index_results = self.index.search(query, k);
300
301        let vector_storage = self.vector_storage.read();
302        let payload_storage = self.payload_storage.read();
303
304        // Map index results to SearchResult with full point data
305        let results: Vec<SearchResult> = index_results
306            .into_iter()
307            .filter_map(|(id, score)| {
308                // We need to fetch vector and payload
309                let vector = vector_storage.retrieve(id).ok().flatten()?;
310                let payload = payload_storage.retrieve(id).ok().flatten();
311
312                let point = Point {
313                    id,
314                    vector,
315                    payload,
316                };
317
318                Some(SearchResult::new(point, score))
319            })
320            .collect();
321
322        Ok(results)
323    }
324
325    /// Returns the number of points in the collection.
326    #[must_use]
327    pub fn len(&self) -> usize {
328        self.vector_storage.read().len()
329    }
330
331    /// Returns true if the collection is empty.
332    #[must_use]
333    pub fn is_empty(&self) -> bool {
334        self.vector_storage.read().is_empty()
335    }
336
337    /// Saves the collection configuration and index to disk.
338    ///
339    /// # Errors
340    ///
341    /// Returns an error if storage operations fail.
342    pub fn flush(&self) -> Result<()> {
343        self.save_config()?;
344        self.vector_storage.write().flush().map_err(Error::Io)?;
345        self.payload_storage.write().flush().map_err(Error::Io)?;
346        self.index.save(&self.path).map_err(Error::Io)?;
347        Ok(())
348    }
349
350    /// Saves the collection configuration to disk.
351    fn save_config(&self) -> Result<()> {
352        let config = self.config.read();
353        let config_path = self.path.join("config.json");
354        let config_data = serde_json::to_string_pretty(&*config)
355            .map_err(|e| Error::Serialization(e.to_string()))?;
356        std::fs::write(config_path, config_data)?;
357        Ok(())
358    }
359
360    /// Performs full-text search using BM25.
361    ///
362    /// # Arguments
363    ///
364    /// * `query` - Text query to search for
365    /// * `k` - Maximum number of results to return
366    ///
367    /// # Returns
368    ///
369    /// Vector of search results sorted by BM25 score (descending).
370    #[must_use]
371    pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
372        let bm25_results = self.text_index.search(query, k);
373
374        let vector_storage = self.vector_storage.read();
375        let payload_storage = self.payload_storage.read();
376
377        bm25_results
378            .into_iter()
379            .filter_map(|(id, score)| {
380                let vector = vector_storage.retrieve(id).ok().flatten()?;
381                let payload = payload_storage.retrieve(id).ok().flatten();
382
383                let point = Point {
384                    id,
385                    vector,
386                    payload,
387                };
388
389                Some(SearchResult::new(point, score))
390            })
391            .collect()
392    }
393
394    /// Performs hybrid search combining vector similarity and full-text search.
395    ///
396    /// Uses Reciprocal Rank Fusion (RRF) to combine results from both searches.
397    ///
398    /// # Arguments
399    ///
400    /// * `vector_query` - Query vector for similarity search
401    /// * `text_query` - Text query for BM25 search
402    /// * `k` - Maximum number of results to return
403    /// * `vector_weight` - Weight for vector results (0.0-1.0, default 0.5)
404    ///
405    /// # Errors
406    ///
407    /// Returns an error if the query vector dimension doesn't match.
408    pub fn hybrid_search(
409        &self,
410        vector_query: &[f32],
411        text_query: &str,
412        k: usize,
413        vector_weight: Option<f32>,
414    ) -> Result<Vec<SearchResult>> {
415        let config = self.config.read();
416        if vector_query.len() != config.dimension {
417            return Err(Error::DimensionMismatch {
418                expected: config.dimension,
419                actual: vector_query.len(),
420            });
421        }
422        drop(config);
423
424        let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
425        let text_weight = 1.0 - weight;
426
427        // Get vector search results (more than k to allow for fusion)
428        let vector_results = self.index.search(vector_query, k * 2);
429
430        // Get BM25 text search results
431        let text_results = self.text_index.search(text_query, k * 2);
432
433        // Perf: Apply RRF (Reciprocal Rank Fusion) with FxHashMap for faster hashing
434        // RRF score = 1 / (rank + 60) - the constant 60 is standard
435        let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
436
437        // Add vector scores with RRF
438        #[allow(clippy::cast_precision_loss)]
439        for (rank, (id, _)) in vector_results.iter().enumerate() {
440            let rrf_score = weight / (rank as f32 + 60.0);
441            *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
442        }
443
444        // Add text scores with RRF
445        #[allow(clippy::cast_precision_loss)]
446        for (rank, (id, _)) in text_results.iter().enumerate() {
447            let rrf_score = text_weight / (rank as f32 + 60.0);
448            *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
449        }
450
451        // Perf: Use partial sort for top-k instead of full sort
452        let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
453        if scored_ids.len() > k {
454            scored_ids.select_nth_unstable_by(k, |a, b| {
455                b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
456            });
457            scored_ids.truncate(k);
458            scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
459        } else {
460            scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
461        }
462
463        // Fetch full point data
464        let vector_storage = self.vector_storage.read();
465        let payload_storage = self.payload_storage.read();
466
467        let results: Vec<SearchResult> = scored_ids
468            .into_iter()
469            .filter_map(|(id, score)| {
470                let vector = vector_storage.retrieve(id).ok().flatten()?;
471                let payload = payload_storage.retrieve(id).ok().flatten();
472
473                let point = Point {
474                    id,
475                    vector,
476                    payload,
477                };
478
479                Some(SearchResult::new(point, score))
480            })
481            .collect();
482
483        Ok(results)
484    }
485
486    /// Extracts all string values from a JSON payload for text indexing.
487    fn extract_text_from_payload(payload: &serde_json::Value) -> String {
488        let mut texts = Vec::new();
489        Self::collect_strings(payload, &mut texts);
490        texts.join(" ")
491    }
492
493    /// Recursively collects all string values from a JSON value.
494    fn collect_strings(value: &serde_json::Value, texts: &mut Vec<String>) {
495        match value {
496            serde_json::Value::String(s) => texts.push(s.clone()),
497            serde_json::Value::Array(arr) => {
498                for item in arr {
499                    Self::collect_strings(item, texts);
500                }
501            }
502            serde_json::Value::Object(obj) => {
503                for v in obj.values() {
504                    Self::collect_strings(v, texts);
505                }
506            }
507            _ => {}
508        }
509    }
510}
511
512#[cfg(test)]
513mod tests {
514    use super::*;
515    use serde_json::json;
516    use tempfile::tempdir;
517
#[test]
fn test_collection_create() {
    // A freshly created collection reports the requested settings and holds no points.
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    let CollectionConfig {
        dimension,
        metric,
        point_count,
        ..
    } = collection.config();

    assert_eq!(dimension, 3);
    assert_eq!(metric, DistanceMetric::Cosine);
    assert_eq!(point_count, 0);
}
530
#[test]
fn test_collection_upsert_and_search() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    // Three mutually orthogonal unit vectors.
    collection
        .upsert(vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
        ])
        .unwrap();
    assert_eq!(collection.len(), 3);

    // Querying with the first axis must rank point 1 on top.
    let results = collection.search(&[1.0, 0.0, 0.0], 2).unwrap();

    assert_eq!(results.len(), 2);
    assert_eq!(results[0].point.id, 1);
}
553
#[test]
fn test_dimension_mismatch() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    // Only two components for a three-dimensional collection.
    let too_short = Point::without_payload(1, vec![1.0, 0.0]);

    assert!(collection.upsert(vec![too_short]).is_err());
}
566
#[test]
fn test_collection_open_existing() {
    let dir = tempdir().unwrap();
    let path = dir.path().join("test_collection");

    // First handle: create, fill and flush, then let it go out of scope.
    {
        let original = Collection::create(path.clone(), 3, DistanceMetric::Euclidean).unwrap();
        original
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 2.0, 3.0]),
                Point::without_payload(2, vec![4.0, 5.0, 6.0]),
            ])
            .unwrap();
        original.flush().unwrap();
    }

    // Second handle: everything must come back from disk.
    let reopened = Collection::open(path).unwrap();
    let config = reopened.config();

    assert_eq!(config.dimension, 3);
    assert_eq!(config.metric, DistanceMetric::Euclidean);
    assert_eq!(reopened.len(), 2);
}
592
#[test]
fn test_collection_get_points() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();
    collection
        .upsert(vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
        ])
        .unwrap();

    // Two stored ids followed by one that was never inserted.
    let found_ids: Vec<Option<u64>> = collection
        .get(&[1, 2, 999])
        .iter()
        .map(|entry| entry.as_ref().map(|point| point.id))
        .collect();

    assert_eq!(found_ids, vec![Some(1), Some(2), None]);
}
614
#[test]
fn test_collection_delete_points() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();
    collection
        .upsert(vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
        ])
        .unwrap();
    assert_eq!(collection.len(), 3);

    // Remove the middle point only.
    collection.delete(&[2]).unwrap();
    assert_eq!(collection.len(), 2);

    // It can no longer be fetched.
    assert!(collection.get(&[2])[0].is_none());
}
637
#[test]
fn test_collection_is_empty() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    // Empty right after creation...
    assert!(collection.is_empty());

    // ...and not empty once a point has been stored.
    let single = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![single]).unwrap();
    assert!(!collection.is_empty());
}
651
#[test]
fn test_collection_with_payload() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    let document = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"title": "Test Document", "category": "tech"})),
    );
    collection.upsert(vec![document]).unwrap();

    // The point and its payload must both round-trip.
    let stored = collection.get(&[1]);
    let point = stored[0].as_ref();
    assert!(point.is_some());

    let payload = point.unwrap().payload.as_ref();
    assert!(payload.is_some());
    assert_eq!(payload.unwrap()["title"], "Test Document");
}
673
#[test]
fn test_collection_search_dimension_mismatch() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    let single = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![single]).unwrap();

    // A two-component query against a three-dimensional collection must be rejected.
    let outcome = collection.search(&[1.0, 0.0], 5);
    assert!(outcome.is_err());
}
688
#[test]
fn test_collection_upsert_replaces_payload() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    // First version of the point carries a payload.
    let with_payload = Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"version": 1})));
    collection.upsert(vec![with_payload]).unwrap();

    // Second version of the same id has none, which should clear the stored payload.
    let bare = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![bare]).unwrap();

    let stored = collection.get(&[1]);
    assert!(stored[0].as_ref().unwrap().payload.is_none());
}
714
#[test]
fn test_collection_flush() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    let single = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![single]).unwrap();

    // An explicit flush on top of the implicit one must still succeed.
    assert!(collection.flush().is_ok());
}
729
#[test]
fn test_collection_euclidean_metric() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Euclidean)
            .unwrap();

    // Points on the x axis at 0, 1 and 10.
    collection
        .upsert(vec![
            Point::without_payload(1, vec![0.0, 0.0, 0.0]),
            Point::without_payload(2, vec![1.0, 0.0, 0.0]),
            Point::without_payload(3, vec![10.0, 0.0, 0.0]),
        ])
        .unwrap();

    let results = collection.search(&[0.5, 0.0, 0.0], 3).unwrap();

    // The query sits halfway between points 1 and 2, so either may come first.
    let nearest = results[0].point.id;
    assert!(nearest == 1 || nearest == 2);
}
750
#[test]
fn test_collection_text_search() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    let rust_intro = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"title": "Rust Programming", "content": "Learn Rust language"})),
    );
    let python_intro = Point::new(
        2,
        vec![0.0, 1.0, 0.0],
        Some(json!({"title": "Python Tutorial", "content": "Python is great"})),
    );
    let rust_perf = Point::new(
        3,
        vec![0.0, 0.0, 1.0],
        Some(json!({"title": "Rust Performance", "content": "Rust is fast"})),
    );
    collection
        .upsert(vec![rust_intro, python_intro, rust_perf])
        .unwrap();

    // "rust" appears in documents 1 and 3 only.
    let results = collection.text_search("rust", 10);
    assert_eq!(results.len(), 2);

    for expected in &[1u64, 3] {
        assert!(results.iter().any(|hit| hit.point.id == *expected));
    }
}
785
#[test]
fn test_collection_hybrid_search() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    // Doc 1: matches both the vector and the text query.
    let both = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"title": "Rust Programming"})),
    );
    // Doc 2: vector is similar to the query, text does not match.
    let vector_only = Point::new(
        2,
        vec![0.9, 0.1, 0.0],
        Some(json!({"title": "Python Programming"})),
    );
    // Doc 3: text matches, vector does not.
    let text_only = Point::new(
        3,
        vec![0.0, 1.0, 0.0],
        Some(json!({"title": "Rust Performance"})),
    );
    collection.upsert(vec![both, vector_only, text_only]).unwrap();

    // Equal weighting of vector similarity to [1,0,0] and the text query "rust".
    let results = collection
        .hybrid_search(&[1.0, 0.0, 0.0], "rust", 3, Some(0.5))
        .unwrap();

    assert!(!results.is_empty());
    // The document that matches on both sides should win.
    assert_eq!(results[0].point.id, 1);
}
825
#[test]
fn test_extract_text_from_payload() {
    // Strings at the top level, inside a nested object and inside an array must all be found.
    let text = Collection::extract_text_from_payload(&json!({
        "title": "Hello",
        "meta": {
            "author": "World",
            "tags": ["rust", "fast"]
        }
    }));

    for expected in &["Hello", "World", "rust", "fast"] {
        assert!(text.contains(*expected));
    }
}
843
#[test]
fn test_text_search_empty_query() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    let document = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"content": "test document"})),
    );
    collection.upsert(vec![document]).unwrap();

    // An empty query matches nothing, even though a document is indexed.
    assert!(collection.text_search("", 10).is_empty());
}
862
#[test]
fn test_text_search_no_payload() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    // Neither point carries a payload, so nothing is text-indexed.
    collection
        .upsert(vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            Point::new(2, vec![0.0, 1.0, 0.0], None),
        ])
        .unwrap();

    assert!(collection.text_search("test", 10).is_empty());
}
881
#[test]
fn test_hybrid_search_text_weight_zero() {
    let dir = tempdir().unwrap();
    let collection =
        Collection::create(dir.path().join("test_collection"), 3, DistanceMetric::Cosine)
            .unwrap();

    collection
        .upsert(vec![
            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Rust"}))),
            Point::new(2, vec![0.9, 0.1, 0.0], Some(json!({"title": "Python"}))),
        ])
        .unwrap();

    // With vector_weight = 1.0 the text side contributes nothing (pure vector search).
    let results = collection
        .hybrid_search(&[0.9, 0.1, 0.0], "rust", 2, Some(1.0))
        .unwrap();

    // Doc 2 has the closest vector and must win although only doc 1 matches "rust".
    assert_eq!(results[0].point.id, 2);
}
904
905    #[test]
906    fn test_hybrid_search_vector_weight_zero() {
907        let dir = tempdir().unwrap();
908        let path = dir.path().join("test_collection");
909
910        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
911
912        let points = vec![
913            Point::new(
914                1,
915                vec![1.0, 0.0, 0.0],
916                Some(json!({"title": "Rust programming language"})),
917            ),
918            Point::new(
919                2,
920                vec![0.99, 0.01, 0.0], // Very close to query vector
921                Some(json!({"title": "Python programming"})),
922            ),
923        ];
924        collection.upsert(points).unwrap();
925
926        // vector_weight=0.0 means text_weight=1.0 (pure text search)
927        let query = vec![0.99, 0.01, 0.0];
928        let results = collection
929            .hybrid_search(&query, "rust", 2, Some(0.0))
930            .unwrap();
931
932        // Doc 1 should be first (matches "rust") even though doc 2 has closer vector
933        assert_eq!(results[0].point.id, 1);
934    }
935
936    #[test]
937    fn test_bm25_update_document() {
938        let dir = tempdir().unwrap();
939        let path = dir.path().join("test_collection");
940
941        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
942
943        // Insert initial document
944        let points = vec![Point::new(
945            1,
946            vec![1.0, 0.0, 0.0],
947            Some(json!({"content": "rust programming"})),
948        )];
949        collection.upsert(points).unwrap();
950
951        // Verify it's indexed
952        let results = collection.text_search("rust", 10);
953        assert_eq!(results.len(), 1);
954
955        // Update document with different text
956        let points = vec![Point::new(
957            1,
958            vec![1.0, 0.0, 0.0],
959            Some(json!({"content": "python programming"})),
960        )];
961        collection.upsert(points).unwrap();
962
963        // Should no longer match "rust"
964        let results = collection.text_search("rust", 10);
965        assert!(results.is_empty());
966
967        // Should now match "python"
968        let results = collection.text_search("python", 10);
969        assert_eq!(results.len(), 1);
970    }
971
972    #[test]
973    fn test_bm25_large_dataset() {
974        let dir = tempdir().unwrap();
975        let path = dir.path().join("test_collection");
976
977        let collection = Collection::create(path, 4, DistanceMetric::Cosine).unwrap();
978
979        // Insert 100 documents
980        let points: Vec<Point> = (0..100)
981            .map(|i| {
982                let content = if i % 10 == 0 {
983                    format!("rust document number {i}")
984                } else {
985                    format!("other document number {i}")
986                };
987                Point::new(
988                    i,
989                    vec![0.1, 0.2, 0.3, 0.4],
990                    Some(json!({"content": content})),
991                )
992            })
993            .collect();
994        collection.upsert(points).unwrap();
995
996        // Search for "rust" - should find 10 documents (0, 10, 20, ..., 90)
997        let results = collection.text_search("rust", 100);
998        assert_eq!(results.len(), 10);
999
1000        // All results should have IDs divisible by 10
1001        for result in &results {
1002            assert_eq!(result.point.id % 10, 0);
1003        }
1004    }
1005
1006    #[test]
1007    fn test_bm25_persistence_on_reopen() {
1008        let dir = tempdir().unwrap();
1009        let path = dir.path().join("test_collection");
1010
1011        // Create collection and add documents
1012        {
1013            let collection = Collection::create(path.clone(), 4, DistanceMetric::Cosine).unwrap();
1014
1015            let points = vec![
1016                Point::new(
1017                    1,
1018                    vec![1.0, 0.0, 0.0, 0.0],
1019                    Some(json!({"content": "Rust programming language"})),
1020                ),
1021                Point::new(
1022                    2,
1023                    vec![0.0, 1.0, 0.0, 0.0],
1024                    Some(json!({"content": "Python tutorial"})),
1025                ),
1026                Point::new(
1027                    3,
1028                    vec![0.0, 0.0, 1.0, 0.0],
1029                    Some(json!({"content": "Rust is fast and safe"})),
1030                ),
1031            ];
1032            collection.upsert(points).unwrap();
1033
1034            // Verify search works before closing
1035            let results = collection.text_search("rust", 10);
1036            assert_eq!(results.len(), 2);
1037        }
1038
1039        // Reopen collection and verify BM25 index is rebuilt
1040        {
1041            let collection = Collection::open(path).unwrap();
1042
1043            // BM25 should be rebuilt from persisted payloads
1044            let results = collection.text_search("rust", 10);
1045            assert_eq!(results.len(), 2);
1046
1047            let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
1048            assert!(ids.contains(&1));
1049            assert!(ids.contains(&3));
1050        }
1051    }
1052}