1use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{Bm25Index, HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::quantization::{BinaryQuantizedVector, QuantizedVector, StorageMode};
8use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
9
10use std::collections::HashMap;
11
12use parking_lot::RwLock;
13use serde::{Deserialize, Serialize};
14use std::path::PathBuf;
15use std::sync::Arc;
16
/// Settings of a collection, serialized as pretty-printed JSON to
/// `config.json` inside the collection directory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionConfig {
    /// Collection name; taken from the directory name when the collection is
    /// created (`"unknown"` if it cannot be determined).
    pub name: String,

    /// Required length of every stored and queried vector.
    pub dimension: usize,

    /// Distance metric the HNSW index is built with.
    pub metric: DistanceMetric,

    /// Number of stored points, refreshed from vector storage after each
    /// upsert/delete. On disk it is only updated when the config is saved.
    pub point_count: usize,

    /// Selects which quantized in-memory copy (if any) upserts maintain.
    /// A `config.json` without this field deserializes to
    /// `StorageMode::default()`.
    #[serde(default)]
    pub storage_mode: StorageMode,
}
36
/// A vector collection rooted in one directory: raw vectors, JSON payloads,
/// an HNSW graph for vector search and a BM25 index for text search.
///
/// `Clone` is shallow: every component sits behind an `Arc`, so clones share
/// the same underlying state.
#[derive(Clone)]
pub struct Collection {
    /// Collection directory: holds `config.json` and `hnsw.bin`, and is passed
    /// to both storage backends.
    path: PathBuf,

    /// Shared, mutable collection settings.
    config: Arc<RwLock<CollectionConfig>>,

    /// Full-precision vector storage (memory-mapped, judging by the type
    /// name — TODO confirm).
    vector_storage: Arc<RwLock<MmapStorage>>,

    /// Per-point JSON payload storage.
    payload_storage: Arc<RwLock<LogPayloadStorage>>,

    /// HNSW approximate-nearest-neighbour index; shared without an outer lock.
    index: Arc<HnswIndex>,

    /// BM25 index over the string values found in payloads.
    text_index: Arc<Bm25Index>,

    /// SQ8-quantized copies of vectors, filled by `upsert` in `StorageMode::SQ8`.
    sq8_cache: Arc<RwLock<HashMap<u64, QuantizedVector>>>,

    /// Binary-quantized copies of vectors, filled by `upsert` in
    /// `StorageMode::Binary`.
    binary_cache: Arc<RwLock<HashMap<u64, BinaryQuantizedVector>>>,
}
64
65impl Collection {
66 pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
72 Self::create_with_options(path, dimension, metric, StorageMode::default())
73 }
74
75 pub fn create_with_options(
88 path: PathBuf,
89 dimension: usize,
90 metric: DistanceMetric,
91 storage_mode: StorageMode,
92 ) -> Result<Self> {
93 std::fs::create_dir_all(&path)?;
94
95 let name = path
96 .file_name()
97 .and_then(|n| n.to_str())
98 .unwrap_or("unknown")
99 .to_string();
100
101 let config = CollectionConfig {
102 name,
103 dimension,
104 metric,
105 point_count: 0,
106 storage_mode,
107 };
108
109 let vector_storage = Arc::new(RwLock::new(
111 MmapStorage::new(&path, dimension).map_err(Error::Io)?,
112 ));
113
114 let payload_storage = Arc::new(RwLock::new(
115 LogPayloadStorage::new(&path).map_err(Error::Io)?,
116 ));
117
118 let index = Arc::new(HnswIndex::new(dimension, metric));
120
121 let text_index = Arc::new(Bm25Index::new());
123
124 let collection = Self {
125 path,
126 config: Arc::new(RwLock::new(config)),
127 vector_storage,
128 payload_storage,
129 index,
130 text_index,
131 sq8_cache: Arc::new(RwLock::new(HashMap::new())),
132 binary_cache: Arc::new(RwLock::new(HashMap::new())),
133 };
134
135 collection.save_config()?;
136
137 Ok(collection)
138 }
139
140 pub fn open(path: PathBuf) -> Result<Self> {
146 let config_path = path.join("config.json");
147 let config_data = std::fs::read_to_string(&config_path)?;
148 let config: CollectionConfig =
149 serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
150
151 let vector_storage = Arc::new(RwLock::new(
153 MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
154 ));
155
156 let payload_storage = Arc::new(RwLock::new(
157 LogPayloadStorage::new(&path).map_err(Error::Io)?,
158 ));
159
160 let index = if path.join("hnsw.bin").exists() {
162 Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
163 } else {
164 Arc::new(HnswIndex::new(config.dimension, config.metric))
165 };
166
167 let text_index = Arc::new(Bm25Index::new());
169
170 {
172 let storage = payload_storage.read();
173 let ids = storage.ids();
174 for id in ids {
175 if let Ok(Some(payload)) = storage.retrieve(id) {
176 let text = Self::extract_text_from_payload(&payload);
177 if !text.is_empty() {
178 text_index.add_document(id, &text);
179 }
180 }
181 }
182 }
183
184 Ok(Self {
185 path,
186 config: Arc::new(RwLock::new(config)),
187 vector_storage,
188 payload_storage,
189 index,
190 text_index,
191 sq8_cache: Arc::new(RwLock::new(HashMap::new())),
192 binary_cache: Arc::new(RwLock::new(HashMap::new())),
193 })
194 }
195
196 #[must_use]
198 pub fn config(&self) -> CollectionConfig {
199 self.config.read().clone()
200 }
201
202 pub fn upsert(&self, points: impl IntoIterator<Item = Point>) -> Result<()> {
210 let points: Vec<Point> = points.into_iter().collect();
211 let config = self.config.read();
212 let dimension = config.dimension;
213 let storage_mode = config.storage_mode;
214 drop(config);
215
216 for point in &points {
218 if point.dimension() != dimension {
219 return Err(Error::DimensionMismatch {
220 expected: dimension,
221 actual: point.dimension(),
222 });
223 }
224 }
225
226 let mut vector_storage = self.vector_storage.write();
227 let mut payload_storage = self.payload_storage.write();
228
229 let mut sq8_cache = match storage_mode {
231 StorageMode::SQ8 => Some(self.sq8_cache.write()),
232 _ => None,
233 };
234 let mut binary_cache = match storage_mode {
235 StorageMode::Binary => Some(self.binary_cache.write()),
236 _ => None,
237 };
238
239 for point in points {
240 vector_storage
242 .store(point.id, &point.vector)
243 .map_err(Error::Io)?;
244
245 match storage_mode {
247 StorageMode::SQ8 => {
248 if let Some(ref mut cache) = sq8_cache {
249 let quantized = QuantizedVector::from_f32(&point.vector);
250 cache.insert(point.id, quantized);
251 }
252 }
253 StorageMode::Binary => {
254 if let Some(ref mut cache) = binary_cache {
255 let quantized = BinaryQuantizedVector::from_f32(&point.vector);
256 cache.insert(point.id, quantized);
257 }
258 }
259 StorageMode::Full => {}
260 }
261
262 if let Some(payload) = &point.payload {
264 payload_storage
265 .store(point.id, payload)
266 .map_err(Error::Io)?;
267 } else {
268 let _ = payload_storage.delete(point.id);
269 }
270
271 self.index.insert(point.id, &point.vector);
273
274 if let Some(payload) = &point.payload {
276 let text = Self::extract_text_from_payload(payload);
277 if !text.is_empty() {
278 self.text_index.add_document(point.id, &text);
279 }
280 } else {
281 self.text_index.remove_document(point.id);
282 }
283 }
284
285 let mut config = self.config.write();
287 config.point_count = vector_storage.len();
288
289 vector_storage.flush().map_err(Error::Io)?;
291 payload_storage.flush().map_err(Error::Io)?;
292 self.index.save(&self.path).map_err(Error::Io)?;
293
294 Ok(())
295 }
296
297 pub fn upsert_bulk(&self, points: &[Point]) -> Result<usize> {
312 if points.is_empty() {
313 return Ok(0);
314 }
315
316 let config = self.config.read();
317 let dimension = config.dimension;
318 drop(config);
319
320 for point in points {
322 if point.dimension() != dimension {
323 return Err(Error::DimensionMismatch {
324 expected: dimension,
325 actual: point.dimension(),
326 });
327 }
328 }
329
330 let vectors_for_hnsw: Vec<(u64, Vec<f32>)> =
332 points.iter().map(|p| (p.id, p.vector.clone())).collect();
333
334 let vectors_for_storage: Vec<(u64, &[f32])> = vectors_for_hnsw
337 .iter()
338 .map(|(id, v)| (*id, v.as_slice()))
339 .collect();
340
341 let mut vector_storage = self.vector_storage.write();
342 vector_storage
343 .store_batch(&vectors_for_storage)
344 .map_err(Error::Io)?;
345 drop(vector_storage);
346
347 let mut payload_storage = self.payload_storage.write();
349 for point in points {
350 if let Some(payload) = &point.payload {
351 payload_storage
352 .store(point.id, payload)
353 .map_err(Error::Io)?;
354
355 let text = Self::extract_text_from_payload(payload);
357 if !text.is_empty() {
358 self.text_index.add_document(point.id, &text);
359 }
360 }
361 }
362 drop(payload_storage);
363
364 let inserted = self.index.insert_batch_parallel(vectors_for_hnsw);
366 self.index.set_searching_mode();
367
368 let mut config = self.config.write();
370 config.point_count = self.vector_storage.read().len();
371 drop(config);
372
373 self.vector_storage.write().flush().map_err(Error::Io)?;
377 self.payload_storage.write().flush().map_err(Error::Io)?;
378 Ok(inserted)
382 }
383
384 #[must_use]
386 pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
387 let vector_storage = self.vector_storage.read();
388 let payload_storage = self.payload_storage.read();
389
390 ids.iter()
391 .map(|&id| {
392 let vector = vector_storage.retrieve(id).ok().flatten()?;
394
395 let payload = payload_storage.retrieve(id).ok().flatten();
397
398 Some(Point {
399 id,
400 vector,
401 payload,
402 })
403 })
404 .collect()
405 }
406
407 pub fn delete(&self, ids: &[u64]) -> Result<()> {
413 let mut vector_storage = self.vector_storage.write();
414 let mut payload_storage = self.payload_storage.write();
415
416 for &id in ids {
417 vector_storage.delete(id).map_err(Error::Io)?;
418 payload_storage.delete(id).map_err(Error::Io)?;
419 self.index.remove(id);
420 }
421
422 let mut config = self.config.write();
423 config.point_count = vector_storage.len();
424
425 Ok(())
426 }
427
428 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
436 let config = self.config.read();
437
438 if query.len() != config.dimension {
439 return Err(Error::DimensionMismatch {
440 expected: config.dimension,
441 actual: query.len(),
442 });
443 }
444 drop(config);
445
446 let index_results = self.index.search(query, k);
448
449 let vector_storage = self.vector_storage.read();
450 let payload_storage = self.payload_storage.read();
451
452 let results: Vec<SearchResult> = index_results
454 .into_iter()
455 .filter_map(|(id, score)| {
456 let vector = vector_storage.retrieve(id).ok().flatten()?;
458 let payload = payload_storage.retrieve(id).ok().flatten();
459
460 let point = Point {
461 id,
462 vector,
463 payload,
464 };
465
466 Some(SearchResult::new(point, score))
467 })
468 .collect();
469
470 Ok(results)
471 }
472
473 pub fn search_with_ef(
482 &self,
483 query: &[f32],
484 k: usize,
485 ef_search: usize,
486 ) -> Result<Vec<SearchResult>> {
487 let config = self.config.read();
488
489 if query.len() != config.dimension {
490 return Err(Error::DimensionMismatch {
491 expected: config.dimension,
492 actual: query.len(),
493 });
494 }
495 drop(config);
496
497 let quality = match ef_search {
499 0..=64 => crate::SearchQuality::Fast,
500 65..=128 => crate::SearchQuality::Balanced,
501 129..=256 => crate::SearchQuality::Accurate,
502 257..=1024 => crate::SearchQuality::HighRecall,
503 _ => crate::SearchQuality::Perfect,
504 };
505
506 let index_results = self.index.search_with_quality(query, k, quality);
507
508 let vector_storage = self.vector_storage.read();
509 let payload_storage = self.payload_storage.read();
510
511 let results: Vec<SearchResult> = index_results
512 .into_iter()
513 .filter_map(|(id, score)| {
514 let vector = vector_storage.retrieve(id).ok().flatten()?;
515 let payload = payload_storage.retrieve(id).ok().flatten();
516
517 let point = Point {
518 id,
519 vector,
520 payload,
521 };
522
523 Some(SearchResult::new(point, score))
524 })
525 .collect();
526
527 Ok(results)
528 }
529
530 pub fn search_ids(&self, query: &[f32], k: usize) -> Result<Vec<(u64, f32)>> {
548 let config = self.config.read();
549
550 if query.len() != config.dimension {
551 return Err(Error::DimensionMismatch {
552 expected: config.dimension,
553 actual: query.len(),
554 });
555 }
556 drop(config);
557
558 Ok(self.index.search(query, k))
560 }
561
562 pub fn search_batch_parallel(
580 &self,
581 queries: &[&[f32]],
582 k: usize,
583 ) -> Result<Vec<Vec<SearchResult>>> {
584 use crate::index::SearchQuality;
585
586 let config = self.config.read();
587 let dimension = config.dimension;
588 drop(config);
589
590 for query in queries {
592 if query.len() != dimension {
593 return Err(Error::DimensionMismatch {
594 expected: dimension,
595 actual: query.len(),
596 });
597 }
598 }
599
600 let index_results = self
602 .index
603 .search_batch_parallel(queries, k, SearchQuality::Balanced);
604
605 let vector_storage = self.vector_storage.read();
607 let payload_storage = self.payload_storage.read();
608
609 let results: Vec<Vec<SearchResult>> = index_results
610 .into_iter()
611 .map(|query_results: Vec<(u64, f32)>| {
612 query_results
613 .into_iter()
614 .filter_map(|(id, score)| {
615 let vector = vector_storage.retrieve(id).ok().flatten()?;
616 let payload = payload_storage.retrieve(id).ok().flatten();
617 Some(SearchResult {
618 point: Point {
619 id,
620 vector,
621 payload,
622 },
623 score,
624 })
625 })
626 .collect()
627 })
628 .collect();
629
630 Ok(results)
631 }
632
633 #[must_use]
636 pub fn len(&self) -> usize {
637 self.config.read().point_count
638 }
639
640 #[must_use]
643 pub fn is_empty(&self) -> bool {
644 self.config.read().point_count == 0
645 }
646
647 pub fn flush(&self) -> Result<()> {
653 self.save_config()?;
654 self.vector_storage.write().flush().map_err(Error::Io)?;
655 self.payload_storage.write().flush().map_err(Error::Io)?;
656 self.index.save(&self.path).map_err(Error::Io)?;
657 Ok(())
658 }
659
660 fn save_config(&self) -> Result<()> {
662 let config = self.config.read();
663 let config_path = self.path.join("config.json");
664 let config_data = serde_json::to_string_pretty(&*config)
665 .map_err(|e| Error::Serialization(e.to_string()))?;
666 std::fs::write(config_path, config_data)?;
667 Ok(())
668 }
669
670 #[must_use]
681 pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
682 let bm25_results = self.text_index.search(query, k);
683
684 let vector_storage = self.vector_storage.read();
685 let payload_storage = self.payload_storage.read();
686
687 bm25_results
688 .into_iter()
689 .filter_map(|(id, score)| {
690 let vector = vector_storage.retrieve(id).ok().flatten()?;
691 let payload = payload_storage.retrieve(id).ok().flatten();
692
693 let point = Point {
694 id,
695 vector,
696 payload,
697 };
698
699 Some(SearchResult::new(point, score))
700 })
701 .collect()
702 }
703
704 pub fn hybrid_search(
719 &self,
720 vector_query: &[f32],
721 text_query: &str,
722 k: usize,
723 vector_weight: Option<f32>,
724 ) -> Result<Vec<SearchResult>> {
725 let config = self.config.read();
726 if vector_query.len() != config.dimension {
727 return Err(Error::DimensionMismatch {
728 expected: config.dimension,
729 actual: vector_query.len(),
730 });
731 }
732 drop(config);
733
734 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
735 let text_weight = 1.0 - weight;
736
737 let vector_results = self.index.search(vector_query, k * 2);
739
740 let text_results = self.text_index.search(text_query, k * 2);
742
743 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
746
747 #[allow(clippy::cast_precision_loss)]
749 for (rank, (id, _)) in vector_results.iter().enumerate() {
750 let rrf_score = weight / (rank as f32 + 60.0);
751 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
752 }
753
754 #[allow(clippy::cast_precision_loss)]
756 for (rank, (id, _)) in text_results.iter().enumerate() {
757 let rrf_score = text_weight / (rank as f32 + 60.0);
758 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
759 }
760
761 let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
763 if scored_ids.len() > k {
764 scored_ids.select_nth_unstable_by(k, |a, b| {
765 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
766 });
767 scored_ids.truncate(k);
768 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
769 } else {
770 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
771 }
772
773 let vector_storage = self.vector_storage.read();
775 let payload_storage = self.payload_storage.read();
776
777 let results: Vec<SearchResult> = scored_ids
778 .into_iter()
779 .filter_map(|(id, score)| {
780 let vector = vector_storage.retrieve(id).ok().flatten()?;
781 let payload = payload_storage.retrieve(id).ok().flatten();
782
783 let point = Point {
784 id,
785 vector,
786 payload,
787 };
788
789 Some(SearchResult::new(point, score))
790 })
791 .collect();
792
793 Ok(results)
794 }
795
796 fn extract_text_from_payload(payload: &serde_json::Value) -> String {
798 let mut texts = Vec::new();
799 Self::collect_strings(payload, &mut texts);
800 texts.join(" ")
801 }
802
803 fn collect_strings(value: &serde_json::Value, texts: &mut Vec<String>) {
805 match value {
806 serde_json::Value::String(s) => texts.push(s.clone()),
807 serde_json::Value::Array(arr) => {
808 for item in arr {
809 Self::collect_strings(item, texts);
810 }
811 }
812 serde_json::Value::Object(obj) => {
813 for v in obj.values() {
814 Self::collect_strings(v, texts);
815 }
816 }
817 _ => {}
818 }
819 }
820}
821
#[cfg(test)]
mod tests {
    // Unit tests for `Collection`: creation, CRUD, vector / text / hybrid
    // search, reopening from disk, and bulk upsert. Every test works inside its
    // own temporary directory.
    use super::*;
    use serde_json::json;
    use tempfile::tempdir;

    // A fresh collection reports the requested dimension and metric, and no points.
    #[test]
    fn test_collection_create() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
        let config = collection.config();

        assert_eq!(config.dimension, 3);
        assert_eq!(config.metric, DistanceMetric::Cosine);
        assert_eq!(config.point_count, 0);
    }

    // Upserted points are counted, and an exact match ranks first.
    #[test]
    fn test_collection_upsert_and_search() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
        ];

        collection.upsert(points).unwrap();
        assert_eq!(collection.len(), 3);

        let query = vec![1.0, 0.0, 0.0];
        let results = collection.search(&query, 2).unwrap();

        assert_eq!(results.len(), 2);
        // The query is identical to point 1's vector.
        assert_eq!(results[0].point.id, 1);
    }

    // Upserting a vector of the wrong length is rejected.
    #[test]
    fn test_dimension_mismatch() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        // Two components against a 3-dimensional collection.
        let points = vec![Point::without_payload(1, vec![1.0, 0.0])];
        let result = collection.upsert(points);
        assert!(result.is_err());
    }

    // Config and point count survive a flush + reopen cycle.
    #[test]
    fn test_collection_open_existing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        {
            // First session: create, insert, flush, then drop the handle.
            let collection =
                Collection::create(path.clone(), 3, DistanceMetric::Euclidean).unwrap();
            let points = vec![
                Point::without_payload(1, vec![1.0, 2.0, 3.0]),
                Point::without_payload(2, vec![4.0, 5.0, 6.0]),
            ];
            collection.upsert(points).unwrap();
            collection.flush().unwrap();
        }

        // Second session: reopen the same directory.
        let collection = Collection::open(path).unwrap();
        let config = collection.config();

        assert_eq!(config.dimension, 3);
        assert_eq!(config.metric, DistanceMetric::Euclidean);
        assert_eq!(collection.len(), 2);
    }

    // `get` preserves the order of the requested ids and yields `None` for unknown ids.
    #[test]
    fn test_collection_get_points() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
        let points = vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
        ];
        collection.upsert(points).unwrap();

        // 999 was never inserted.
        let retrieved = collection.get(&[1, 2, 999]);

        assert!(retrieved[0].is_some());
        assert_eq!(retrieved[0].as_ref().unwrap().id, 1);
        assert!(retrieved[1].is_some());
        assert_eq!(retrieved[1].as_ref().unwrap().id, 2);
        assert!(retrieved[2].is_none());
    }

    // Deleting a point lowers the count and makes it unretrievable.
    #[test]
    fn test_collection_delete_points() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
        let points = vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
        ];
        collection.upsert(points).unwrap();
        assert_eq!(collection.len(), 3);

        collection.delete(&[2]).unwrap();
        assert_eq!(collection.len(), 2);

        let retrieved = collection.get(&[2]);
        assert!(retrieved[0].is_none());
    }

    // `is_empty` flips once the first point is stored.
    #[test]
    fn test_collection_is_empty() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
        assert!(collection.is_empty());

        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();
        assert!(!collection.is_empty());
    }

    // JSON payloads round-trip through upsert/get.
    #[test]
    fn test_collection_with_payload() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![Point::new(
            1,
            vec![1.0, 0.0, 0.0],
            Some(json!({"title": "Test Document", "category": "tech"})),
        )];
        collection.upsert(points).unwrap();

        let retrieved = collection.get(&[1]);
        assert!(retrieved[0].is_some());

        let point = retrieved[0].as_ref().unwrap();
        assert!(point.payload.is_some());
        assert_eq!(point.payload.as_ref().unwrap()["title"], "Test Document");
    }

    // A query of the wrong length is rejected by `search`.
    #[test]
    fn test_collection_search_dimension_mismatch() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let result = collection.search(&[1.0, 0.0], 5);
        assert!(result.is_err());
    }

    // `search_ids` returns raw (id, score) pairs, ordered best-first.
    #[test]
    fn test_collection_search_ids_fast() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.9, 0.1, 0.0]),
                Point::without_payload(3, vec![0.0, 1.0, 0.0]),
            ])
            .unwrap();

        let results = collection.search_ids(&[1.0, 0.0, 0.0], 2).unwrap();
        assert_eq!(results.len(), 2);
        // Exact match first, and it scores strictly higher than the runner-up.
        assert_eq!(results[0].0, 1);
        assert!(results[0].1 > results[1].1);
    }

    // Re-upserting a point without a payload clears the old payload.
    #[test]
    fn test_collection_upsert_replaces_payload() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        collection
            .upsert(vec![Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"version": 1})),
            )])
            .unwrap();

        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let retrieved = collection.get(&[1]);
        let point = retrieved[0].as_ref().unwrap();
        assert!(point.payload.is_none());
    }

    // An explicit flush succeeds on a populated collection.
    #[test]
    fn test_collection_flush() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let result = collection.flush();
        assert!(result.is_ok());
    }

    // With the Euclidean metric, one of the two points at distance 0.5 ranks first.
    #[test]
    fn test_collection_euclidean_metric() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Euclidean).unwrap();

        let points = vec![
            Point::without_payload(1, vec![0.0, 0.0, 0.0]),
            Point::without_payload(2, vec![1.0, 0.0, 0.0]),
            Point::without_payload(3, vec![10.0, 0.0, 0.0]),
        ];
        collection.upsert(points).unwrap();

        // Equidistant from points 1 and 2, so either may come first.
        let query = vec![0.5, 0.0, 0.0];
        let results = collection.search(&query, 3).unwrap();

        assert!(results[0].point.id == 1 || results[0].point.id == 2);
    }

    // BM25 search matches case-insensitively across payload fields.
    #[test]
    fn test_collection_text_search() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"title": "Rust Programming", "content": "Learn Rust language"})),
            ),
            Point::new(
                2,
                vec![0.0, 1.0, 0.0],
                Some(json!({"title": "Python Tutorial", "content": "Python is great"})),
            ),
            Point::new(
                3,
                vec![0.0, 0.0, 1.0],
                Some(json!({"title": "Rust Performance", "content": "Rust is fast"})),
            ),
        ];
        collection.upsert(points).unwrap();

        // Lower-case query against capitalised "Rust" in the payloads.
        let results = collection.text_search("rust", 10);
        assert_eq!(results.len(), 2);

        let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
        assert!(ids.contains(&1));
        assert!(ids.contains(&3));
    }

    // Point 1 wins both the vector and the text ranking, so it leads the fusion.
    #[test]
    fn test_collection_hybrid_search() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"title": "Rust Programming"})),
            ),
            Point::new(
                2,
                // Close to the query vector, but not about Rust.
                vec![0.9, 0.1, 0.0],
                Some(json!({"title": "Python Programming"})),
            ),
            Point::new(
                3,
                vec![0.0, 1.0, 0.0],
                Some(json!({"title": "Rust Performance"})),
            ),
        ];
        collection.upsert(points).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        // Equal weighting of the vector and text rankings.
        let results = collection
            .hybrid_search(&query, "rust", 3, Some(0.5))
            .unwrap();

        assert!(!results.is_empty());
        assert_eq!(results[0].point.id, 1);
    }

    // Strings nested in objects and arrays are all extracted.
    #[test]
    fn test_extract_text_from_payload() {
        let payload = json!({
            "title": "Hello",
            "meta": {
                "author": "World",
                "tags": ["rust", "fast"]
            }
        });

        let text = Collection::extract_text_from_payload(&payload);
        assert!(text.contains("Hello"));
        assert!(text.contains("World"));
        assert!(text.contains("rust"));
        assert!(text.contains("fast"));
    }

    // An empty text query matches nothing.
    #[test]
    fn test_text_search_empty_query() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![Point::new(
            1,
            vec![1.0, 0.0, 0.0],
            Some(json!({"content": "test document"})),
        )];
        collection.upsert(points).unwrap();

        let results = collection.text_search("", 10);
        assert!(results.is_empty());
    }

    // Points without payloads never show up in text search.
    #[test]
    fn test_text_search_no_payload() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            Point::new(2, vec![0.0, 1.0, 0.0], None),
        ];
        collection.upsert(points).unwrap();

        let results = collection.text_search("test", 10);
        assert!(results.is_empty());
    }

    // vector_weight = 1.0 (text weight 0): ranking follows the vector query alone.
    #[test]
    fn test_hybrid_search_text_weight_zero() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Rust"}))),
            Point::new(2, vec![0.9, 0.1, 0.0], Some(json!({"title": "Python"}))),
        ];
        collection.upsert(points).unwrap();

        // Identical to point 2's vector, even though point 1 matches "rust".
        let query = vec![0.9, 0.1, 0.0];
        let results = collection
            .hybrid_search(&query, "rust", 2, Some(1.0))
            .unwrap();

        assert_eq!(results[0].point.id, 2);
    }

    // vector_weight = 0.0: ranking follows the text query alone.
    #[test]
    fn test_hybrid_search_vector_weight_zero() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"title": "Rust programming language"})),
            ),
            Point::new(
                2,
                // Nearest to the query vector, but does not mention Rust.
                vec![0.99, 0.01, 0.0],
                Some(json!({"title": "Python programming"})),
            ),
        ];
        collection.upsert(points).unwrap();

        let query = vec![0.99, 0.01, 0.0];
        let results = collection
            .hybrid_search(&query, "rust", 2, Some(0.0))
            .unwrap();

        assert_eq!(results[0].point.id, 1);
    }

    // Re-upserting a point replaces its indexed text instead of appending to it.
    #[test]
    fn test_bm25_update_document() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        // First version mentions "rust".
        let points = vec![Point::new(
            1,
            vec![1.0, 0.0, 0.0],
            Some(json!({"content": "rust programming"})),
        )];
        collection.upsert(points).unwrap();

        let results = collection.text_search("rust", 10);
        assert_eq!(results.len(), 1);

        // Second version replaces "rust" with "python".
        let points = vec![Point::new(
            1,
            vec![1.0, 0.0, 0.0],
            Some(json!({"content": "python programming"})),
        )];
        collection.upsert(points).unwrap();

        let results = collection.text_search("rust", 10);
        assert!(results.is_empty());

        let results = collection.text_search("python", 10);
        assert_eq!(results.len(), 1);
    }

    // Only the 10 documents that mention "rust" (ids divisible by 10) match.
    #[test]
    fn test_bm25_large_dataset() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        let collection = Collection::create(path, 4, DistanceMetric::Cosine).unwrap();

        let points: Vec<Point> = (0..100)
            .map(|i| {
                let content = if i % 10 == 0 {
                    format!("rust document number {i}")
                } else {
                    format!("other document number {i}")
                };
                Point::new(
                    i,
                    vec![0.1, 0.2, 0.3, 0.4],
                    Some(json!({"content": content})),
                )
            })
            .collect();
        collection.upsert(points).unwrap();

        let results = collection.text_search("rust", 100);
        assert_eq!(results.len(), 10);

        for result in &results {
            assert_eq!(result.point.id % 10, 0);
        }
    }

    // The text index is rebuilt from stored payloads when a collection is reopened.
    #[test]
    fn test_bm25_persistence_on_reopen() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        {
            // First session: no explicit flush; relies on what `upsert` persists.
            let collection = Collection::create(path.clone(), 4, DistanceMetric::Cosine).unwrap();

            let points = vec![
                Point::new(
                    1,
                    vec![1.0, 0.0, 0.0, 0.0],
                    Some(json!({"content": "Rust programming language"})),
                ),
                Point::new(
                    2,
                    vec![0.0, 1.0, 0.0, 0.0],
                    Some(json!({"content": "Python tutorial"})),
                ),
                Point::new(
                    3,
                    vec![0.0, 0.0, 1.0, 0.0],
                    Some(json!({"content": "Rust is fast and safe"})),
                ),
            ];
            collection.upsert(points).unwrap();

            let results = collection.text_search("rust", 10);
            assert_eq!(results.len(), 2);
        }

        {
            // Second session: same results after reopening.
            let collection = Collection::open(path).unwrap();

            let results = collection.text_search("rust", 10);
            assert_eq!(results.len(), 2);

            let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
            assert!(ids.contains(&1));
            assert!(ids.contains(&3));
        }
    }

    // Bulk upsert reports every inserted point and updates the count.
    #[test]
    fn test_upsert_bulk_basic() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");
        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            Point::new(2, vec![0.0, 1.0, 0.0], None),
            Point::new(3, vec![0.0, 0.0, 1.0], None),
        ];

        let inserted = collection.upsert_bulk(&points).unwrap();
        assert_eq!(inserted, 3);
        assert_eq!(collection.len(), 3);
    }

    // Payloads written through the bulk path are retrievable.
    #[test]
    fn test_upsert_bulk_with_payload() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");
        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Doc 1"}))),
            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"title": "Doc 2"}))),
        ];

        collection.upsert_bulk(&points).unwrap();
        let retrieved = collection.get(&[1, 2]);
        assert_eq!(retrieved.len(), 2);
        assert!(retrieved[0].as_ref().unwrap().payload.is_some());
    }

    // Empty input short-circuits to zero inserted.
    #[test]
    fn test_upsert_bulk_empty() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");
        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points: Vec<Point> = vec![];
        let inserted = collection.upsert_bulk(&points).unwrap();
        assert_eq!(inserted, 0);
    }

    // A single bad vector rejects the whole batch.
    #[test]
    fn test_upsert_bulk_dimension_mismatch() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");
        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            // Wrong length: 2 instead of 3.
            Point::new(2, vec![0.0, 1.0], None),
        ];

        let result = collection.upsert_bulk(&points);
        assert!(result.is_err());
    }

    // 500 points of dimension 64 go through the parallel HNSW insert.
    #[test]
    // The u64 -> f32 casts below only involve values < 100, so no precision is lost.
    #[allow(clippy::cast_precision_loss)]
    fn test_upsert_bulk_large_batch() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");
        let collection = Collection::create(path, 64, DistanceMetric::Cosine).unwrap();

        let points: Vec<Point> = (0_u64..500)
            .map(|i| {
                let vector: Vec<f32> = (0_u64..64)
                    .map(|j| ((i + j) % 100) as f32 / 100.0)
                    .collect();
                Point::new(i, vector, None)
            })
            .collect();

        let inserted = collection.upsert_bulk(&points).unwrap();
        assert_eq!(inserted, 500);
        assert_eq!(collection.len(), 500);
    }

    // Points inserted in bulk are immediately searchable.
    #[test]
    fn test_upsert_bulk_search_works() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");
        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            Point::new(2, vec![0.0, 1.0, 0.0], None),
            Point::new(3, vec![0.0, 0.0, 1.0], None),
        ];

        collection.upsert_bulk(&points).unwrap();

        let query = vec![1.0, 0.0, 0.0];
        let results = collection.search(&query, 3).unwrap();
        assert!(!results.is_empty());
        assert_eq!(results[0].point.id, 1);
    }

    // The bulk path also feeds the BM25 index.
    #[test]
    fn test_upsert_bulk_bm25_indexing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");
        let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();

        let points = vec![
            Point::new(
                1,
                vec![1.0, 0.0, 0.0],
                Some(json!({"content": "Rust lang"})),
            ),
            Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"content": "Python"}))),
            Point::new(
                3,
                vec![0.0, 0.0, 1.0],
                Some(json!({"content": "Rust fast"})),
            ),
        ];

        collection.upsert_bulk(&points).unwrap();
        let results = collection.text_search("rust", 10);
        assert_eq!(results.len(), 2);
    }
}