1use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{Bm25Index, HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::quantization::{BinaryQuantizedVector, QuantizedVector, StorageMode};
8use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
9
10use std::collections::HashMap;
11
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct CollectionConfig {
20 pub name: String,
22
23 pub dimension: usize,
25
26 pub metric: DistanceMetric,
28
29 pub point_count: usize,
31
32 #[serde(default)]
34 pub storage_mode: StorageMode,
35}
36
37#[derive(Clone)]
39pub struct Collection {
40 path: PathBuf,
42
43 config: Arc<RwLock<CollectionConfig>>,
45
46 vector_storage: Arc<RwLock<MmapStorage>>,
48
49 payload_storage: Arc<RwLock<LogPayloadStorage>>,
51
52 index: Arc<HnswIndex>,
54
55 text_index: Arc<Bm25Index>,
57
58 sq8_cache: Arc<RwLock<HashMap<u64, QuantizedVector>>>,
60
61 binary_cache: Arc<RwLock<HashMap<u64, BinaryQuantizedVector>>>,
63}
64
65impl Collection {
66 pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
72 Self::create_with_options(path, dimension, metric, StorageMode::default())
73 }
74
75 pub fn create_with_options(
88 path: PathBuf,
89 dimension: usize,
90 metric: DistanceMetric,
91 storage_mode: StorageMode,
92 ) -> Result<Self> {
93 std::fs::create_dir_all(&path)?;
94
95 let name = path
96 .file_name()
97 .and_then(|n| n.to_str())
98 .unwrap_or("unknown")
99 .to_string();
100
101 let config = CollectionConfig {
102 name,
103 dimension,
104 metric,
105 point_count: 0,
106 storage_mode,
107 };
108
109 let vector_storage = Arc::new(RwLock::new(
111 MmapStorage::new(&path, dimension).map_err(Error::Io)?,
112 ));
113
114 let payload_storage = Arc::new(RwLock::new(
115 LogPayloadStorage::new(&path).map_err(Error::Io)?,
116 ));
117
118 let index = Arc::new(HnswIndex::new(dimension, metric));
120
121 let text_index = Arc::new(Bm25Index::new());
123
124 let collection = Self {
125 path,
126 config: Arc::new(RwLock::new(config)),
127 vector_storage,
128 payload_storage,
129 index,
130 text_index,
131 sq8_cache: Arc::new(RwLock::new(HashMap::new())),
132 binary_cache: Arc::new(RwLock::new(HashMap::new())),
133 };
134
135 collection.save_config()?;
136
137 Ok(collection)
138 }
139
140 pub fn open(path: PathBuf) -> Result<Self> {
146 let config_path = path.join("config.json");
147 let config_data = std::fs::read_to_string(&config_path)?;
148 let config: CollectionConfig =
149 serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
150
151 let vector_storage = Arc::new(RwLock::new(
153 MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
154 ));
155
156 let payload_storage = Arc::new(RwLock::new(
157 LogPayloadStorage::new(&path).map_err(Error::Io)?,
158 ));
159
160 let index = if path.join("hnsw.bin").exists() {
162 Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
163 } else {
164 Arc::new(HnswIndex::new(config.dimension, config.metric))
165 };
166
167 let text_index = Arc::new(Bm25Index::new());
169
170 {
172 let storage = payload_storage.read();
173 let ids = storage.ids();
174 for id in ids {
175 if let Ok(Some(payload)) = storage.retrieve(id) {
176 let text = Self::extract_text_from_payload(&payload);
177 if !text.is_empty() {
178 text_index.add_document(id, &text);
179 }
180 }
181 }
182 }
183
184 Ok(Self {
185 path,
186 config: Arc::new(RwLock::new(config)),
187 vector_storage,
188 payload_storage,
189 index,
190 text_index,
191 sq8_cache: Arc::new(RwLock::new(HashMap::new())),
192 binary_cache: Arc::new(RwLock::new(HashMap::new())),
193 })
194 }
195
196 #[must_use]
198 pub fn config(&self) -> CollectionConfig {
199 self.config.read().clone()
200 }
201
202 pub fn upsert(&self, points: impl IntoIterator<Item = Point>) -> Result<()> {
210 let points: Vec<Point> = points.into_iter().collect();
211 let config = self.config.read();
212 let dimension = config.dimension;
213 let storage_mode = config.storage_mode;
214 drop(config);
215
216 for point in &points {
218 if point.dimension() != dimension {
219 return Err(Error::DimensionMismatch {
220 expected: dimension,
221 actual: point.dimension(),
222 });
223 }
224 }
225
226 let mut vector_storage = self.vector_storage.write();
227 let mut payload_storage = self.payload_storage.write();
228
229 let mut sq8_cache = match storage_mode {
231 StorageMode::SQ8 => Some(self.sq8_cache.write()),
232 _ => None,
233 };
234 let mut binary_cache = match storage_mode {
235 StorageMode::Binary => Some(self.binary_cache.write()),
236 _ => None,
237 };
238
239 for point in points {
240 vector_storage
242 .store(point.id, &point.vector)
243 .map_err(Error::Io)?;
244
245 match storage_mode {
247 StorageMode::SQ8 => {
248 if let Some(ref mut cache) = sq8_cache {
249 let quantized = QuantizedVector::from_f32(&point.vector);
250 cache.insert(point.id, quantized);
251 }
252 }
253 StorageMode::Binary => {
254 if let Some(ref mut cache) = binary_cache {
255 let quantized = BinaryQuantizedVector::from_f32(&point.vector);
256 cache.insert(point.id, quantized);
257 }
258 }
259 StorageMode::Full => {}
260 }
261
262 if let Some(payload) = &point.payload {
264 payload_storage
265 .store(point.id, payload)
266 .map_err(Error::Io)?;
267 } else {
268 let _ = payload_storage.delete(point.id);
269 }
270
271 self.index.insert(point.id, &point.vector);
273
274 if let Some(payload) = &point.payload {
276 let text = Self::extract_text_from_payload(payload);
277 if !text.is_empty() {
278 self.text_index.add_document(point.id, &text);
279 }
280 } else {
281 self.text_index.remove_document(point.id);
282 }
283 }
284
285 let mut config = self.config.write();
287 config.point_count = vector_storage.len();
288
289 vector_storage.flush().map_err(Error::Io)?;
291 payload_storage.flush().map_err(Error::Io)?;
292 self.index.save(&self.path).map_err(Error::Io)?;
293
294 Ok(())
295 }
296
297 pub fn upsert_bulk(&self, points: &[Point]) -> Result<usize> {
312 if points.is_empty() {
313 return Ok(0);
314 }
315
316 let config = self.config.read();
317 let dimension = config.dimension;
318 drop(config);
319
320 for point in points {
322 if point.dimension() != dimension {
323 return Err(Error::DimensionMismatch {
324 expected: dimension,
325 actual: point.dimension(),
326 });
327 }
328 }
329
330 let vectors_for_hnsw: Vec<(u64, Vec<f32>)> =
332 points.iter().map(|p| (p.id, p.vector.clone())).collect();
333
334 let vectors_for_storage: Vec<(u64, &[f32])> = vectors_for_hnsw
337 .iter()
338 .map(|(id, v)| (*id, v.as_slice()))
339 .collect();
340
341 let mut vector_storage = self.vector_storage.write();
342 vector_storage
343 .store_batch(&vectors_for_storage)
344 .map_err(Error::Io)?;
345 drop(vector_storage);
346
347 let mut payload_storage = self.payload_storage.write();
349 for point in points {
350 if let Some(payload) = &point.payload {
351 payload_storage
352 .store(point.id, payload)
353 .map_err(Error::Io)?;
354
355 let text = Self::extract_text_from_payload(payload);
357 if !text.is_empty() {
358 self.text_index.add_document(point.id, &text);
359 }
360 }
361 }
362 drop(payload_storage);
363
364 let inserted = self.index.insert_batch_parallel(vectors_for_hnsw);
366 self.index.set_searching_mode();
367
368 let mut config = self.config.write();
370 config.point_count = self.vector_storage.read().len();
371 drop(config);
372
373 self.vector_storage.write().flush().map_err(Error::Io)?;
377 self.payload_storage.write().flush().map_err(Error::Io)?;
378 Ok(inserted)
382 }
383
384 #[must_use]
386 pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
387 let vector_storage = self.vector_storage.read();
388 let payload_storage = self.payload_storage.read();
389
390 ids.iter()
391 .map(|&id| {
392 let vector = vector_storage.retrieve(id).ok().flatten()?;
394
395 let payload = payload_storage.retrieve(id).ok().flatten();
397
398 Some(Point {
399 id,
400 vector,
401 payload,
402 })
403 })
404 .collect()
405 }
406
407 pub fn delete(&self, ids: &[u64]) -> Result<()> {
413 let mut vector_storage = self.vector_storage.write();
414 let mut payload_storage = self.payload_storage.write();
415
416 for &id in ids {
417 vector_storage.delete(id).map_err(Error::Io)?;
418 payload_storage.delete(id).map_err(Error::Io)?;
419 self.index.remove(id);
420 }
421
422 let mut config = self.config.write();
423 config.point_count = vector_storage.len();
424
425 Ok(())
426 }
427
428 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
436 let config = self.config.read();
437
438 if query.len() != config.dimension {
439 return Err(Error::DimensionMismatch {
440 expected: config.dimension,
441 actual: query.len(),
442 });
443 }
444 drop(config);
445
446 let index_results = self.index.search(query, k);
448
449 let vector_storage = self.vector_storage.read();
450 let payload_storage = self.payload_storage.read();
451
452 let results: Vec<SearchResult> = index_results
454 .into_iter()
455 .filter_map(|(id, score)| {
456 let vector = vector_storage.retrieve(id).ok().flatten()?;
458 let payload = payload_storage.retrieve(id).ok().flatten();
459
460 let point = Point {
461 id,
462 vector,
463 payload,
464 };
465
466 Some(SearchResult::new(point, score))
467 })
468 .collect();
469
470 Ok(results)
471 }
472
473 pub fn search_ids(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
491 let config = self.config.read();
492
493 if query.len() != config.dimension {
494 return Err(Error::DimensionMismatch {
495 expected: config.dimension,
496 actual: query.len(),
497 });
498 }
499 drop(config);
500
501 Ok(self.index.search(query, k))
503 }
504
505 pub fn search_batch_parallel(
523 &self,
524 queries: &[&[f32]],
525 k: usize,
526 ) -> Result<Vec<Vec<SearchResult>>> {
527 use crate::index::SearchQuality;
528
529 let config = self.config.read();
530 let dimension = config.dimension;
531 drop(config);
532
533 for query in queries {
535 if query.len() != dimension {
536 return Err(Error::DimensionMismatch {
537 expected: dimension,
538 actual: query.len(),
539 });
540 }
541 }
542
543 let index_results = self
545 .index
546 .search_batch_parallel(queries, k, SearchQuality::Balanced);
547
548 let vector_storage = self.vector_storage.read();
550 let payload_storage = self.payload_storage.read();
551
552 let results: Vec<Vec<SearchResult>> = index_results
553 .into_iter()
554 .map(|query_results: Vec<(u64, f32)>| {
555 query_results
556 .into_iter()
557 .filter_map(|(id, score)| {
558 let vector = vector_storage.retrieve(id).ok().flatten()?;
559 let payload = payload_storage.retrieve(id).ok().flatten();
560 Some(SearchResult {
561 point: Point {
562 id,
563 vector,
564 payload,
565 },
566 score,
567 })
568 })
569 .collect()
570 })
571 .collect();
572
573 Ok(results)
574 }
575
576 #[must_use]
579 pub fn len(&self) -> usize {
580 self.config.read().point_count
581 }
582
583 #[must_use]
586 pub fn is_empty(&self) -> bool {
587 self.config.read().point_count == 0
588 }
589
590 pub fn flush(&self) -> Result<()> {
596 self.save_config()?;
597 self.vector_storage.write().flush().map_err(Error::Io)?;
598 self.payload_storage.write().flush().map_err(Error::Io)?;
599 self.index.save(&self.path).map_err(Error::Io)?;
600 Ok(())
601 }
602
603 fn save_config(&self) -> Result<()> {
605 let config = self.config.read();
606 let config_path = self.path.join("config.json");
607 let config_data = serde_json::to_string_pretty(&*config)
608 .map_err(|e| Error::Serialization(e.to_string()))?;
609 std::fs::write(config_path, config_data)?;
610 Ok(())
611 }
612
613 #[must_use]
624 pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
625 let bm25_results = self.text_index.search(query, k);
626
627 let vector_storage = self.vector_storage.read();
628 let payload_storage = self.payload_storage.read();
629
630 bm25_results
631 .into_iter()
632 .filter_map(|(id, score)| {
633 let vector = vector_storage.retrieve(id).ok().flatten()?;
634 let payload = payload_storage.retrieve(id).ok().flatten();
635
636 let point = Point {
637 id,
638 vector,
639 payload,
640 };
641
642 Some(SearchResult::new(point, score))
643 })
644 .collect()
645 }
646
647 pub fn hybrid_search(
662 &self,
663 vector_query: &[f32],
664 text_query: &str,
665 k: usize,
666 vector_weight: Option<f32>,
667 ) -> Result<Vec<SearchResult>> {
668 let config = self.config.read();
669 if vector_query.len() != config.dimension {
670 return Err(Error::DimensionMismatch {
671 expected: config.dimension,
672 actual: vector_query.len(),
673 });
674 }
675 drop(config);
676
677 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
678 let text_weight = 1.0 - weight;
679
680 let vector_results = self.index.search(vector_query, k * 2);
682
683 let text_results = self.text_index.search(text_query, k * 2);
685
686 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
689
690 #[allow(clippy::cast_precision_loss)]
692 for (rank, (id, _)) in vector_results.iter().enumerate() {
693 let rrf_score = weight / (rank as f32 + 60.0);
694 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
695 }
696
697 #[allow(clippy::cast_precision_loss)]
699 for (rank, (id, _)) in text_results.iter().enumerate() {
700 let rrf_score = text_weight / (rank as f32 + 60.0);
701 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
702 }
703
704 let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
706 if scored_ids.len() > k {
707 scored_ids.select_nth_unstable_by(k, |a, b| {
708 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
709 });
710 scored_ids.truncate(k);
711 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
712 } else {
713 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
714 }
715
716 let vector_storage = self.vector_storage.read();
718 let payload_storage = self.payload_storage.read();
719
720 let results: Vec<SearchResult> = scored_ids
721 .into_iter()
722 .filter_map(|(id, score)| {
723 let vector = vector_storage.retrieve(id).ok().flatten()?;
724 let payload = payload_storage.retrieve(id).ok().flatten();
725
726 let point = Point {
727 id,
728 vector,
729 payload,
730 };
731
732 Some(SearchResult::new(point, score))
733 })
734 .collect();
735
736 Ok(results)
737 }
738
739 fn extract_text_from_payload(payload: &serde_json::Value) -> String {
741 let mut texts = Vec::new();
742 Self::collect_strings(payload, &mut texts);
743 texts.join(" ")
744 }
745
746 fn collect_strings(value: &serde_json::Value, texts: &mut Vec<String>) {
748 match value {
749 serde_json::Value::String(s) => texts.push(s.clone()),
750 serde_json::Value::Array(arr) => {
751 for item in arr {
752 Self::collect_strings(item, texts);
753 }
754 }
755 serde_json::Value::Object(obj) => {
756 for v in obj.values() {
757 Self::collect_strings(v, texts);
758 }
759 }
760 _ => {}
761 }
762 }
763}
764
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tempfile::tempdir;

    /// Directory name used for every throwaway collection.
    const COLLECTION_DIR: &str = "test_collection";

    /// Creates an empty collection inside a fresh temporary directory.
    ///
    /// The directory guard is returned alongside the collection so the
    /// directory stays alive for the duration of the test.
    fn fresh(dimension: usize, metric: DistanceMetric) -> (tempfile::TempDir, Collection) {
        let dir = tempdir().unwrap();
        let path = dir.path().join(COLLECTION_DIR);
        let collection = Collection::create(path, dimension, metric).unwrap();
        (dir, collection)
    }

    #[test]
    fn test_collection_create() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        let config = collection.config();
        assert_eq!(config.dimension, 3);
        assert_eq!(config.metric, DistanceMetric::Cosine);
        assert_eq!(config.point_count, 0);
    }

    #[test]
    fn test_collection_upsert_and_search() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.0, 1.0, 0.0]),
                Point::without_payload(3, vec![0.0, 0.0, 1.0]),
            ])
            .unwrap();
        assert_eq!(collection.len(), 3);

        let results = collection.search(&[1.0, 0.0, 0.0], 2).unwrap();

        assert_eq!(results.len(), 2);
        // The query equals point 1, so it must rank first.
        assert_eq!(results[0].point.id, 1);
    }

    #[test]
    fn test_dimension_mismatch() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        // Two components against a three-dimensional collection.
        let result = collection.upsert(vec![Point::without_payload(1, vec![1.0, 0.0])]);
        assert!(result.is_err());
    }

    #[test]
    fn test_collection_open_existing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join(COLLECTION_DIR);

        {
            let collection =
                Collection::create(path.clone(), 3, DistanceMetric::Euclidean).unwrap();
            collection
                .upsert(vec![
                    Point::without_payload(1, vec![1.0, 2.0, 3.0]),
                    Point::without_payload(2, vec![4.0, 5.0, 6.0]),
                ])
                .unwrap();
            collection.flush().unwrap();
        }

        let reopened = Collection::open(path).unwrap();
        let config = reopened.config();

        assert_eq!(config.dimension, 3);
        assert_eq!(config.metric, DistanceMetric::Euclidean);
        assert_eq!(reopened.len(), 2);
    }

    #[test]
    fn test_collection_get_points() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);
        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            ])
            .unwrap();

        let retrieved = collection.get(&[1, 2, 999]);

        assert!(retrieved[0].is_some());
        assert_eq!(retrieved[0].as_ref().unwrap().id, 1);
        assert!(retrieved[1].is_some());
        assert_eq!(retrieved[1].as_ref().unwrap().id, 2);
        // Id 999 was never inserted.
        assert!(retrieved[2].is_none());
    }

    #[test]
    fn test_collection_delete_points() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);
        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.0, 1.0, 0.0]),
                Point::without_payload(3, vec![0.0, 0.0, 1.0]),
            ])
            .unwrap();
        assert_eq!(collection.len(), 3);

        collection.delete(&[2]).unwrap();
        assert_eq!(collection.len(), 2);

        let retrieved = collection.get(&[2]);
        assert!(retrieved[0].is_none());
    }

    #[test]
    fn test_collection_is_empty() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);
        assert!(collection.is_empty());

        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();
        assert!(!collection.is_empty());
    }

    #[test]
    fn test_collection_with_payload() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"title": "Test Document", "category": "tech"})),
            )])
            .unwrap();

        let retrieved = collection.get(&[1]);
        assert!(retrieved[0].is_some());

        let point = retrieved[0].as_ref().unwrap();
        assert!(point.payload.is_some());
        assert_eq!(point.payload.as_ref().unwrap()["title"], "Test Document");
    }

    #[test]
    fn test_collection_search_dimension_mismatch() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);
        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let result = collection.search(&[1.0, 0.0], 5);
        assert!(result.is_err());
    }

    #[test]
    fn test_collection_search_ids_fast() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);
        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.9, 0.1, 0.0]),
                Point::without_payload(3, vec![0.0, 1.0, 0.0]),
            ])
            .unwrap();

        let results = collection.search_ids(&[1.0, 0.0, 0.0], 2).unwrap();

        assert_eq!(results.len(), 2);
        assert_eq!(results[0].0, 1);
        assert!(results[0].1 > results[1].1);
    }

    #[test]
    fn test_collection_upsert_replaces_payload() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        // First version carries a payload.
        collection
            .upsert(vec![Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"version": 1})),
            )])
            .unwrap();

        // Second version of the same id has none.
        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let retrieved = collection.get(&[1]);
        let point = retrieved[0].as_ref().unwrap();
        assert!(point.payload.is_none());
    }

    #[test]
    fn test_collection_flush() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);
        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let result = collection.flush();
        assert!(result.is_ok());
    }

    #[test]
    fn test_collection_euclidean_metric() {
        let (_dir, collection) = fresh(3, DistanceMetric::Euclidean);

        collection
            .upsert(vec![
                Point::without_payload(1, vec![0.0, 0.0, 0.0]),
                Point::without_payload(2, vec![1.0, 0.0, 0.0]),
                Point::without_payload(3, vec![10.0, 0.0, 0.0]),
            ])
            .unwrap();

        let results = collection.search(&[0.5, 0.0, 0.0], 3).unwrap();

        // Points 1 and 2 are equally close to the query.
        assert!(results[0].point.id == 1 || results[0].point.id == 2);
    }

    #[test]
    fn test_collection_text_search() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![
                Point::new(
                    1,
                    vec![1.0, 0.0, 0.0],
                    Some(json!({"title": "Rust Programming", "content": "Learn Rust language"})),
                ),
                Point::new(
                    2,
                    vec![0.0, 1.0, 0.0],
                    Some(json!({"title": "Python Tutorial", "content": "Python is great"})),
                ),
                Point::new(
                    3,
                    vec![0.0, 0.0, 1.0],
                    Some(json!({"title": "Rust Performance", "content": "Rust is fast"})),
                ),
            ])
            .unwrap();

        let results = collection.text_search("rust", 10);
        assert_eq!(results.len(), 2);

        let ids: Vec<u64> = results.iter().map(|hit| hit.point.id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&3));
    }

    #[test]
    fn test_collection_hybrid_search() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![
                Point::new(
                    1,
                    vec![1.0, 0.0, 0.0],
                    Some(json!({"title": "Rust Programming"})),
                ),
                Point::new(
                    2,
                    vec![0.9, 0.1, 0.0],
                    Some(json!({"title": "Python Programming"})),
                ),
                Point::new(
                    3,
                    vec![0.0, 1.0, 0.0],
                    Some(json!({"title": "Rust Performance"})),
                ),
            ])
            .unwrap();

        let results = collection
            .hybrid_search(&[1.0, 0.0, 0.0], "rust", 3, Some(0.5))
            .unwrap();

        assert!(!results.is_empty());
        // Point 1 matches both the vector and the text query.
        assert_eq!(results[0].point.id, 1);
    }

    #[test]
    fn test_extract_text_from_payload() {
        let payload = json!({
            "title": "Hello",
            "meta": {
                "author": "World",
                "tags": ["rust", "fast"]
            }
        });

        let text = Collection::extract_text_from_payload(&payload);
        for expected in ["Hello", "World", "rust", "fast"] {
            assert!(text.contains(expected));
        }
    }

    #[test]
    fn test_text_search_empty_query() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"content": "test document"})),
            )])
            .unwrap();

        let results = collection.text_search("", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_text_search_no_payload() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![
                Point::new(1, vec![1.0, 0.0, 0.0], None),
                Point::new(2, vec![0.0, 1.0, 0.0], None),
            ])
            .unwrap();

        let results = collection.text_search("test", 10);
        assert!(results.is_empty());
    }

    #[test]
    fn test_hybrid_search_text_weight_zero() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![
                Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Rust"}))),
                Point::new(2, vec![0.9, 0.1, 0.0], Some(json!({"title": "Python"}))),
            ])
            .unwrap();

        // A vector weight of 1.0 leaves the text side with weight 0.
        let results = collection
            .hybrid_search(&[0.9, 0.1, 0.0], "rust", 2, Some(1.0))
            .unwrap();

        assert_eq!(results[0].point.id, 2);
    }

    #[test]
    fn test_hybrid_search_vector_weight_zero() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![
                Point::new(
                    1,
                    vec![1.0, 0.0, 0.0],
                    Some(json!({"title": "Rust programming language"})),
                ),
                Point::new(
                    2,
                    vec![0.99, 0.01, 0.0],
                    Some(json!({"title": "Python programming"})),
                ),
            ])
            .unwrap();

        // A vector weight of 0.0 makes the text match decide the ranking.
        let results = collection
            .hybrid_search(&[0.99, 0.01, 0.0], "rust", 2, Some(0.0))
            .unwrap();

        assert_eq!(results[0].point.id, 1);
    }

    #[test]
    fn test_bm25_update_document() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        collection
            .upsert(vec![Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"content": "rust programming"})),
            )])
            .unwrap();
        assert_eq!(collection.text_search("rust", 10).len(), 1);

        // Re-upsert the same id with different text.
        collection
            .upsert(vec![Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"content": "python programming"})),
            )])
            .unwrap();

        assert!(collection.text_search("rust", 10).is_empty());
        assert_eq!(collection.text_search("python", 10).len(), 1);
    }

    #[test]
    fn test_bm25_large_dataset() {
        let (_dir, collection) = fresh(4, DistanceMetric::Cosine);

        // Every tenth document mentions "rust".
        let points: Vec<Point> = (0..100)
            .map(|i| {
                let content = if i % 10 == 0 {
                    format!("rust document number {i}")
                } else {
                    format!("other document number {i}")
                };
                Point::new(
                    i,
                    vec![0.1, 0.2, 0.3, 0.4],
                    Some(json!({"content": content})),
                )
            })
            .collect();
        collection.upsert(points).unwrap();

        let results = collection.text_search("rust", 100);
        assert_eq!(results.len(), 10);
        assert!(results.iter().all(|hit| hit.point.id % 10 == 0));
    }

    #[test]
    fn test_bm25_persistence_on_reopen() {
        let dir = tempdir().unwrap();
        let path = dir.path().join(COLLECTION_DIR);

        {
            let collection = Collection::create(path.clone(), 4, DistanceMetric::Cosine).unwrap();

            collection
                .upsert(vec![
                    Point::new(
                        1,
                        vec![1.0, 0.0, 0.0, 0.0],
                        Some(json!({"content": "Rust programming language"})),
                    ),
                    Point::new(
                        2,
                        vec![0.0, 1.0, 0.0, 0.0],
                        Some(json!({"content": "Python tutorial"})),
                    ),
                    Point::new(
                        3,
                        vec![0.0, 0.0, 1.0, 0.0],
                        Some(json!({"content": "Rust is fast and safe"})),
                    ),
                ])
                .unwrap();

            assert_eq!(collection.text_search("rust", 10).len(), 2);
        }

        {
            // The text index must be rebuilt from the stored payloads.
            let reopened = Collection::open(path).unwrap();

            let results = reopened.text_search("rust", 10);
            assert_eq!(results.len(), 2);

            let ids: Vec<u64> = results.iter().map(|hit| hit.point.id).collect();
            assert!(ids.contains(&1));
            assert!(ids.contains(&3));
        }
    }

    #[test]
    fn test_upsert_bulk_basic() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            Point::new(2, vec![0.0, 1.0, 0.0], None),
            Point::new(3, vec![0.0, 0.0, 1.0], None),
        ];

        let inserted = collection.upsert_bulk(&points).unwrap();
        assert_eq!(inserted, 3);
        assert_eq!(collection.len(), 3);
    }

    #[test]
    fn test_upsert_bulk_with_payload() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Doc 1"}))),
            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"title": "Doc 2"}))),
        ];
        collection.upsert_bulk(&points).unwrap();

        let retrieved = collection.get(&[1, 2]);
        assert_eq!(retrieved.len(), 2);
        assert!(retrieved[0].as_ref().unwrap().payload.is_some());
    }

    #[test]
    fn test_upsert_bulk_empty() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        let points: Vec<Point> = vec![];
        let inserted = collection.upsert_bulk(&points).unwrap();
        assert_eq!(inserted, 0);
    }

    #[test]
    fn test_upsert_bulk_dimension_mismatch() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            // Only two components.
            Point::new(2, vec![0.0, 1.0], None),
        ];

        let result = collection.upsert_bulk(&points);
        assert!(result.is_err());
    }

    #[test]
    #[allow(clippy::cast_precision_loss)]
    fn test_upsert_bulk_large_batch() {
        let (_dir, collection) = fresh(64, DistanceMetric::Cosine);

        let points: Vec<Point> = (0_u64..500)
            .map(|i| {
                let vector: Vec<f32> = (0_u64..64)
                    .map(|j| ((i + j) % 100) as f32 / 100.0)
                    .collect();
                Point::new(i, vector, None)
            })
            .collect();

        let inserted = collection.upsert_bulk(&points).unwrap();
        assert_eq!(inserted, 500);
        assert_eq!(collection.len(), 500);
    }

    #[test]
    fn test_upsert_bulk_search_works() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            Point::new(2, vec![0.0, 1.0, 0.0], None),
            Point::new(3, vec![0.0, 0.0, 1.0], None),
        ];
        collection.upsert_bulk(&points).unwrap();

        let results = collection.search(&[1.0, 0.0, 0.0], 3).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].point.id, 1);
    }

    #[test]
    fn test_upsert_bulk_bm25_indexing() {
        let (_dir, collection) = fresh(3, DistanceMetric::Cosine);

        let points = vec![
            Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"content": "Rust lang"})),
            ),
            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"content": "Python"}))),
            Point::new(
                3,
                vec![0.0, 0.0, 1.0],
                Some(json!({"content": "Rust fast"})),
            ),
        ];
        collection.upsert_bulk(&points).unwrap();

        let results = collection.text_search("rust", 10);
        assert_eq!(results.len(), 2);
    }
}