1use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{Bm25Index, HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
8
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CollectionConfig {
17 pub name: String,
19
20 pub dimension: usize,
22
23 pub metric: DistanceMetric,
25
26 pub point_count: usize,
28}
29
30#[derive(Clone)]
32pub struct Collection {
33 path: PathBuf,
35
36 config: Arc<RwLock<CollectionConfig>>,
38
39 vector_storage: Arc<RwLock<MmapStorage>>,
41
42 payload_storage: Arc<RwLock<LogPayloadStorage>>,
44
45 index: Arc<HnswIndex>,
47
48 text_index: Arc<Bm25Index>,
50}
51
52impl Collection {
53 pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
59 std::fs::create_dir_all(&path)?;
60
61 let name = path
62 .file_name()
63 .and_then(|n| n.to_str())
64 .unwrap_or("unknown")
65 .to_string();
66
67 let config = CollectionConfig {
68 name,
69 dimension,
70 metric,
71 point_count: 0,
72 };
73
74 let vector_storage = Arc::new(RwLock::new(
76 MmapStorage::new(&path, dimension).map_err(Error::Io)?,
77 ));
78
79 let payload_storage = Arc::new(RwLock::new(
80 LogPayloadStorage::new(&path).map_err(Error::Io)?,
81 ));
82
83 let index = Arc::new(HnswIndex::new(dimension, metric));
85
86 let text_index = Arc::new(Bm25Index::new());
88
89 let collection = Self {
90 path,
91 config: Arc::new(RwLock::new(config)),
92 vector_storage,
93 payload_storage,
94 index,
95 text_index,
96 };
97
98 collection.save_config()?;
99
100 Ok(collection)
101 }
102
103 pub fn open(path: PathBuf) -> Result<Self> {
109 let config_path = path.join("config.json");
110 let config_data = std::fs::read_to_string(&config_path)?;
111 let config: CollectionConfig =
112 serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
113
114 let vector_storage = Arc::new(RwLock::new(
116 MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
117 ));
118
119 let payload_storage = Arc::new(RwLock::new(
120 LogPayloadStorage::new(&path).map_err(Error::Io)?,
121 ));
122
123 let index = if path.join("hnsw.bin").exists() {
125 Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
126 } else {
127 Arc::new(HnswIndex::new(config.dimension, config.metric))
128 };
129
130 let text_index = Arc::new(Bm25Index::new());
132
133 {
135 let storage = payload_storage.read();
136 let ids = storage.ids();
137 for id in ids {
138 if let Ok(Some(payload)) = storage.retrieve(id) {
139 let text = Self::extract_text_from_payload(&payload);
140 if !text.is_empty() {
141 text_index.add_document(id, &text);
142 }
143 }
144 }
145 }
146
147 Ok(Self {
148 path,
149 config: Arc::new(RwLock::new(config)),
150 vector_storage,
151 payload_storage,
152 index,
153 text_index,
154 })
155 }
156
157 #[must_use]
159 pub fn config(&self) -> CollectionConfig {
160 self.config.read().clone()
161 }
162
163 pub fn upsert(&self, points: Vec<Point>) -> Result<()> {
169 let config = self.config.read();
170 let dimension = config.dimension;
171 drop(config);
172
173 for point in &points {
175 if point.dimension() != dimension {
176 return Err(Error::DimensionMismatch {
177 expected: dimension,
178 actual: point.dimension(),
179 });
180 }
181 }
182
183 let mut vector_storage = self.vector_storage.write();
184 let mut payload_storage = self.payload_storage.write();
185
186 for point in points {
187 vector_storage
189 .store(point.id, &point.vector)
190 .map_err(Error::Io)?;
191
192 if let Some(payload) = &point.payload {
194 payload_storage
195 .store(point.id, payload)
196 .map_err(Error::Io)?;
197 } else {
198 let _ = payload_storage.delete(point.id); }
205
206 self.index.insert(point.id, &point.vector);
210
211 if let Some(payload) = &point.payload {
213 let text = Self::extract_text_from_payload(payload);
214 if !text.is_empty() {
215 self.text_index.add_document(point.id, &text);
216 }
217 } else {
218 self.text_index.remove_document(point.id);
220 }
221 }
222
223 let mut config = self.config.write();
225 config.point_count = vector_storage.len();
226
227 vector_storage.flush().map_err(Error::Io)?;
230 payload_storage.flush().map_err(Error::Io)?;
231 self.index.save(&self.path).map_err(Error::Io)?;
232
233 Ok(())
234 }
235
236 pub fn upsert_bulk(&self, points: &[Point]) -> Result<usize> {
249 if points.is_empty() {
250 return Ok(0);
251 }
252
253 let config = self.config.read();
254 let dimension = config.dimension;
255 drop(config);
256
257 for point in points {
259 if point.dimension() != dimension {
260 return Err(Error::DimensionMismatch {
261 expected: dimension,
262 actual: point.dimension(),
263 });
264 }
265 }
266
267 let vectors_for_hnsw: Vec<(u64, Vec<f32>)> =
269 points.iter().map(|p| (p.id, p.vector.clone())).collect();
270
271 let vectors_for_storage: Vec<(u64, &[f32])> = vectors_for_hnsw
274 .iter()
275 .map(|(id, v)| (*id, v.as_slice()))
276 .collect();
277
278 let mut vector_storage = self.vector_storage.write();
279 vector_storage
280 .store_batch(&vectors_for_storage)
281 .map_err(Error::Io)?;
282 drop(vector_storage);
283
284 let mut payload_storage = self.payload_storage.write();
286 for point in points {
287 if let Some(payload) = &point.payload {
288 payload_storage
289 .store(point.id, payload)
290 .map_err(Error::Io)?;
291
292 let text = Self::extract_text_from_payload(payload);
294 if !text.is_empty() {
295 self.text_index.add_document(point.id, &text);
296 }
297 }
298 }
299 drop(payload_storage);
300
301 let inserted = self.index.insert_batch_parallel(vectors_for_hnsw);
303 self.index.set_searching_mode();
304
305 let mut config = self.config.write();
307 config.point_count = self.vector_storage.read().len();
308 drop(config);
309
310 self.vector_storage.write().flush().map_err(Error::Io)?;
312 self.payload_storage.write().flush().map_err(Error::Io)?;
313 self.index.save(&self.path).map_err(Error::Io)?;
314
315 Ok(inserted)
316 }
317
318 #[must_use]
320 pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
321 let vector_storage = self.vector_storage.read();
322 let payload_storage = self.payload_storage.read();
323
324 ids.iter()
325 .map(|&id| {
326 let vector = vector_storage.retrieve(id).ok().flatten()?;
328
329 let payload = payload_storage.retrieve(id).ok().flatten();
331
332 Some(Point {
333 id,
334 vector,
335 payload,
336 })
337 })
338 .collect()
339 }
340
341 pub fn delete(&self, ids: &[u64]) -> Result<()> {
347 let mut vector_storage = self.vector_storage.write();
348 let mut payload_storage = self.payload_storage.write();
349
350 for &id in ids {
351 vector_storage.delete(id).map_err(Error::Io)?;
352 payload_storage.delete(id).map_err(Error::Io)?;
353 self.index.remove(id);
354 }
355
356 let mut config = self.config.write();
357 config.point_count = vector_storage.len();
358
359 Ok(())
360 }
361
362 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
370 let config = self.config.read();
371
372 if query.len() != config.dimension {
373 return Err(Error::DimensionMismatch {
374 expected: config.dimension,
375 actual: query.len(),
376 });
377 }
378 drop(config);
379
380 let index_results = self.index.search(query, k);
382
383 let vector_storage = self.vector_storage.read();
384 let payload_storage = self.payload_storage.read();
385
386 let results: Vec<SearchResult> = index_results
388 .into_iter()
389 .filter_map(|(id, score)| {
390 let vector = vector_storage.retrieve(id).ok().flatten()?;
392 let payload = payload_storage.retrieve(id).ok().flatten();
393
394 let point = Point {
395 id,
396 vector,
397 payload,
398 };
399
400 Some(SearchResult::new(point, score))
401 })
402 .collect();
403
404 Ok(results)
405 }
406
407 #[must_use]
409 pub fn len(&self) -> usize {
410 self.vector_storage.read().len()
411 }
412
413 #[must_use]
415 pub fn is_empty(&self) -> bool {
416 self.vector_storage.read().is_empty()
417 }
418
419 pub fn flush(&self) -> Result<()> {
425 self.save_config()?;
426 self.vector_storage.write().flush().map_err(Error::Io)?;
427 self.payload_storage.write().flush().map_err(Error::Io)?;
428 self.index.save(&self.path).map_err(Error::Io)?;
429 Ok(())
430 }
431
432 fn save_config(&self) -> Result<()> {
434 let config = self.config.read();
435 let config_path = self.path.join("config.json");
436 let config_data = serde_json::to_string_pretty(&*config)
437 .map_err(|e| Error::Serialization(e.to_string()))?;
438 std::fs::write(config_path, config_data)?;
439 Ok(())
440 }
441
442 #[must_use]
453 pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
454 let bm25_results = self.text_index.search(query, k);
455
456 let vector_storage = self.vector_storage.read();
457 let payload_storage = self.payload_storage.read();
458
459 bm25_results
460 .into_iter()
461 .filter_map(|(id, score)| {
462 let vector = vector_storage.retrieve(id).ok().flatten()?;
463 let payload = payload_storage.retrieve(id).ok().flatten();
464
465 let point = Point {
466 id,
467 vector,
468 payload,
469 };
470
471 Some(SearchResult::new(point, score))
472 })
473 .collect()
474 }
475
476 pub fn hybrid_search(
491 &self,
492 vector_query: &[f32],
493 text_query: &str,
494 k: usize,
495 vector_weight: Option<f32>,
496 ) -> Result<Vec<SearchResult>> {
497 let config = self.config.read();
498 if vector_query.len() != config.dimension {
499 return Err(Error::DimensionMismatch {
500 expected: config.dimension,
501 actual: vector_query.len(),
502 });
503 }
504 drop(config);
505
506 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
507 let text_weight = 1.0 - weight;
508
509 let vector_results = self.index.search(vector_query, k * 2);
511
512 let text_results = self.text_index.search(text_query, k * 2);
514
515 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
518
519 #[allow(clippy::cast_precision_loss)]
521 for (rank, (id, _)) in vector_results.iter().enumerate() {
522 let rrf_score = weight / (rank as f32 + 60.0);
523 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
524 }
525
526 #[allow(clippy::cast_precision_loss)]
528 for (rank, (id, _)) in text_results.iter().enumerate() {
529 let rrf_score = text_weight / (rank as f32 + 60.0);
530 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
531 }
532
533 let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
535 if scored_ids.len() > k {
536 scored_ids.select_nth_unstable_by(k, |a, b| {
537 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
538 });
539 scored_ids.truncate(k);
540 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
541 } else {
542 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
543 }
544
545 let vector_storage = self.vector_storage.read();
547 let payload_storage = self.payload_storage.read();
548
549 let results: Vec<SearchResult> = scored_ids
550 .into_iter()
551 .filter_map(|(id, score)| {
552 let vector = vector_storage.retrieve(id).ok().flatten()?;
553 let payload = payload_storage.retrieve(id).ok().flatten();
554
555 let point = Point {
556 id,
557 vector,
558 payload,
559 };
560
561 Some(SearchResult::new(point, score))
562 })
563 .collect();
564
565 Ok(results)
566 }
567
568 fn extract_text_from_payload(payload: &serde_json::Value) -> String {
570 let mut texts = Vec::new();
571 Self::collect_strings(payload, &mut texts);
572 texts.join(" ")
573 }
574
575 fn collect_strings(value: &serde_json::Value, texts: &mut Vec<String>) {
577 match value {
578 serde_json::Value::String(s) => texts.push(s.clone()),
579 serde_json::Value::Array(arr) => {
580 for item in arr {
581 Self::collect_strings(item, texts);
582 }
583 }
584 serde_json::Value::Object(obj) => {
585 for v in obj.values() {
586 Self::collect_strings(v, texts);
587 }
588 }
589 _ => {}
590 }
591 }
592}
593
594#[cfg(test)]
595mod tests {
596 use super::*;
597 use serde_json::json;
598 use tempfile::tempdir;
599
600 #[test]
601 fn test_collection_create() {
602 let dir = tempdir().unwrap();
603 let path = dir.path().join("test_collection");
604
605 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
606 let config = collection.config();
607
608 assert_eq!(config.dimension, 3);
609 assert_eq!(config.metric, DistanceMetric::Cosine);
610 assert_eq!(config.point_count, 0);
611 }
612
613 #[test]
614 fn test_collection_upsert_and_search() {
615 let dir = tempdir().unwrap();
616 let path = dir.path().join("test_collection");
617
618 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
619
620 let points = vec![
621 Point::without_payload(1, vec![1.0, 0.0, 0.0]),
622 Point::without_payload(2, vec![0.0, 1.0, 0.0]),
623 Point::without_payload(3, vec![0.0, 0.0, 1.0]),
624 ];
625
626 collection.upsert(points).unwrap();
627 assert_eq!(collection.len(), 3);
628
629 let query = vec![1.0, 0.0, 0.0];
630 let results = collection.search(&query, 2).unwrap();
631
632 assert_eq!(results.len(), 2);
633 assert_eq!(results[0].point.id, 1); }
635
636 #[test]
637 fn test_dimension_mismatch() {
638 let dir = tempdir().unwrap();
639 let path = dir.path().join("test_collection");
640
641 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
642
643 let points = vec![Point::without_payload(1, vec![1.0, 0.0])]; let result = collection.upsert(points);
646 assert!(result.is_err());
647 }
648
649 #[test]
650 fn test_collection_open_existing() {
651 let dir = tempdir().unwrap();
652 let path = dir.path().join("test_collection");
653
654 {
656 let collection =
657 Collection::create(path.clone(), 3, DistanceMetric::Euclidean).unwrap();
658 let points = vec![
659 Point::without_payload(1, vec![1.0, 2.0, 3.0]),
660 Point::without_payload(2, vec![4.0, 5.0, 6.0]),
661 ];
662 collection.upsert(points).unwrap();
663 collection.flush().unwrap();
664 }
665
666 let collection = Collection::open(path).unwrap();
668 let config = collection.config();
669
670 assert_eq!(config.dimension, 3);
671 assert_eq!(config.metric, DistanceMetric::Euclidean);
672 assert_eq!(collection.len(), 2);
673 }
674
675 #[test]
676 fn test_collection_get_points() {
677 let dir = tempdir().unwrap();
678 let path = dir.path().join("test_collection");
679
680 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
681 let points = vec![
682 Point::without_payload(1, vec![1.0, 0.0, 0.0]),
683 Point::without_payload(2, vec![0.0, 1.0, 0.0]),
684 ];
685 collection.upsert(points).unwrap();
686
687 let retrieved = collection.get(&[1, 2, 999]);
689
690 assert!(retrieved[0].is_some());
691 assert_eq!(retrieved[0].as_ref().unwrap().id, 1);
692 assert!(retrieved[1].is_some());
693 assert_eq!(retrieved[1].as_ref().unwrap().id, 2);
694 assert!(retrieved[2].is_none()); }
696
697 #[test]
698 fn test_collection_delete_points() {
699 let dir = tempdir().unwrap();
700 let path = dir.path().join("test_collection");
701
702 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
703 let points = vec![
704 Point::without_payload(1, vec![1.0, 0.0, 0.0]),
705 Point::without_payload(2, vec![0.0, 1.0, 0.0]),
706 Point::without_payload(3, vec![0.0, 0.0, 1.0]),
707 ];
708 collection.upsert(points).unwrap();
709 assert_eq!(collection.len(), 3);
710
711 collection.delete(&[2]).unwrap();
713 assert_eq!(collection.len(), 2);
714
715 let retrieved = collection.get(&[2]);
717 assert!(retrieved[0].is_none());
718 }
719
720 #[test]
721 fn test_collection_is_empty() {
722 let dir = tempdir().unwrap();
723 let path = dir.path().join("test_collection");
724
725 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
726 assert!(collection.is_empty());
727
728 collection
729 .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
730 .unwrap();
731 assert!(!collection.is_empty());
732 }
733
734 #[test]
735 fn test_collection_with_payload() {
736 let dir = tempdir().unwrap();
737 let path = dir.path().join("test_collection");
738
739 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
740
741 let points = vec![Point::new(
742 1,
743 vec![1.0, 0.0, 0.0],
744 Some(json!({"title": "Test Document", "category": "tech"})),
745 )];
746 collection.upsert(points).unwrap();
747
748 let retrieved = collection.get(&[1]);
749 assert!(retrieved[0].is_some());
750
751 let point = retrieved[0].as_ref().unwrap();
752 assert!(point.payload.is_some());
753 assert_eq!(point.payload.as_ref().unwrap()["title"], "Test Document");
754 }
755
756 #[test]
757 fn test_collection_search_dimension_mismatch() {
758 let dir = tempdir().unwrap();
759 let path = dir.path().join("test_collection");
760
761 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
762 collection
763 .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
764 .unwrap();
765
766 let result = collection.search(&[1.0, 0.0], 5);
768 assert!(result.is_err());
769 }
770
771 #[test]
772 fn test_collection_upsert_replaces_payload() {
773 let dir = tempdir().unwrap();
774 let path = dir.path().join("test_collection");
775
776 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
777
778 collection
780 .upsert(vec![Point::new(
781 1,
782 vec![1.0, 0.0, 0.0],
783 Some(json!({"version": 1})),
784 )])
785 .unwrap();
786
787 collection
789 .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
790 .unwrap();
791
792 let retrieved = collection.get(&[1]);
793 let point = retrieved[0].as_ref().unwrap();
794 assert!(point.payload.is_none());
795 }
796
797 #[test]
798 fn test_collection_flush() {
799 let dir = tempdir().unwrap();
800 let path = dir.path().join("test_collection");
801
802 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
803 collection
804 .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
805 .unwrap();
806
807 let result = collection.flush();
809 assert!(result.is_ok());
810 }
811
812 #[test]
813 fn test_collection_euclidean_metric() {
814 let dir = tempdir().unwrap();
815 let path = dir.path().join("test_collection");
816
817 let collection = Collection::create(path, 3, DistanceMetric::Euclidean).unwrap();
818
819 let points = vec![
820 Point::without_payload(1, vec![0.0, 0.0, 0.0]),
821 Point::without_payload(2, vec![1.0, 0.0, 0.0]),
822 Point::without_payload(3, vec![10.0, 0.0, 0.0]),
823 ];
824 collection.upsert(points).unwrap();
825
826 let query = vec![0.5, 0.0, 0.0];
827 let results = collection.search(&query, 3).unwrap();
828
829 assert!(results[0].point.id == 1 || results[0].point.id == 2);
831 }
832
833 #[test]
834 fn test_collection_text_search() {
835 let dir = tempdir().unwrap();
836 let path = dir.path().join("test_collection");
837
838 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
839
840 let points = vec![
841 Point::new(
842 1,
843 vec![1.0, 0.0, 0.0],
844 Some(json!({"title": "Rust Programming", "content": "Learn Rust language"})),
845 ),
846 Point::new(
847 2,
848 vec![0.0, 1.0, 0.0],
849 Some(json!({"title": "Python Tutorial", "content": "Python is great"})),
850 ),
851 Point::new(
852 3,
853 vec![0.0, 0.0, 1.0],
854 Some(json!({"title": "Rust Performance", "content": "Rust is fast"})),
855 ),
856 ];
857 collection.upsert(points).unwrap();
858
859 let results = collection.text_search("rust", 10);
861 assert_eq!(results.len(), 2);
862
863 let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
864 assert!(ids.contains(&1));
865 assert!(ids.contains(&3));
866 }
867
868 #[test]
869 fn test_collection_hybrid_search() {
870 let dir = tempdir().unwrap();
871 let path = dir.path().join("test_collection");
872
873 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
874
875 let points = vec![
876 Point::new(
877 1,
878 vec![1.0, 0.0, 0.0],
879 Some(json!({"title": "Rust Programming"})),
880 ),
881 Point::new(
882 2,
883 vec![0.9, 0.1, 0.0], Some(json!({"title": "Python Programming"})),
885 ),
886 Point::new(
887 3,
888 vec![0.0, 1.0, 0.0],
889 Some(json!({"title": "Rust Performance"})),
890 ),
891 ];
892 collection.upsert(points).unwrap();
893
894 let query = vec![1.0, 0.0, 0.0];
899 let results = collection
900 .hybrid_search(&query, "rust", 3, Some(0.5))
901 .unwrap();
902
903 assert!(!results.is_empty());
904 assert_eq!(results[0].point.id, 1);
906 }
907
908 #[test]
909 fn test_extract_text_from_payload() {
910 let payload = json!({
912 "title": "Hello",
913 "meta": {
914 "author": "World",
915 "tags": ["rust", "fast"]
916 }
917 });
918
919 let text = Collection::extract_text_from_payload(&payload);
920 assert!(text.contains("Hello"));
921 assert!(text.contains("World"));
922 assert!(text.contains("rust"));
923 assert!(text.contains("fast"));
924 }
925
926 #[test]
927 fn test_text_search_empty_query() {
928 let dir = tempdir().unwrap();
929 let path = dir.path().join("test_collection");
930
931 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
932
933 let points = vec![Point::new(
934 1,
935 vec![1.0, 0.0, 0.0],
936 Some(json!({"content": "test document"})),
937 )];
938 collection.upsert(points).unwrap();
939
940 let results = collection.text_search("", 10);
942 assert!(results.is_empty());
943 }
944
945 #[test]
946 fn test_text_search_no_payload() {
947 let dir = tempdir().unwrap();
948 let path = dir.path().join("test_collection");
949
950 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
951
952 let points = vec![
954 Point::new(1, vec![1.0, 0.0, 0.0], None),
955 Point::new(2, vec![0.0, 1.0, 0.0], None),
956 ];
957 collection.upsert(points).unwrap();
958
959 let results = collection.text_search("test", 10);
961 assert!(results.is_empty());
962 }
963
964 #[test]
965 fn test_hybrid_search_text_weight_zero() {
966 let dir = tempdir().unwrap();
967 let path = dir.path().join("test_collection");
968
969 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
970
971 let points = vec![
972 Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Rust"}))),
973 Point::new(2, vec![0.9, 0.1, 0.0], Some(json!({"title": "Python"}))),
974 ];
975 collection.upsert(points).unwrap();
976
977 let query = vec![0.9, 0.1, 0.0];
979 let results = collection
980 .hybrid_search(&query, "rust", 2, Some(1.0))
981 .unwrap();
982
983 assert_eq!(results[0].point.id, 2);
985 }
986
987 #[test]
988 fn test_hybrid_search_vector_weight_zero() {
989 let dir = tempdir().unwrap();
990 let path = dir.path().join("test_collection");
991
992 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
993
994 let points = vec![
995 Point::new(
996 1,
997 vec![1.0, 0.0, 0.0],
998 Some(json!({"title": "Rust programming language"})),
999 ),
1000 Point::new(
1001 2,
1002 vec![0.99, 0.01, 0.0], Some(json!({"title": "Python programming"})),
1004 ),
1005 ];
1006 collection.upsert(points).unwrap();
1007
1008 let query = vec![0.99, 0.01, 0.0];
1010 let results = collection
1011 .hybrid_search(&query, "rust", 2, Some(0.0))
1012 .unwrap();
1013
1014 assert_eq!(results[0].point.id, 1);
1016 }
1017
1018 #[test]
1019 fn test_bm25_update_document() {
1020 let dir = tempdir().unwrap();
1021 let path = dir.path().join("test_collection");
1022
1023 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1024
1025 let points = vec![Point::new(
1027 1,
1028 vec![1.0, 0.0, 0.0],
1029 Some(json!({"content": "rust programming"})),
1030 )];
1031 collection.upsert(points).unwrap();
1032
1033 let results = collection.text_search("rust", 10);
1035 assert_eq!(results.len(), 1);
1036
1037 let points = vec![Point::new(
1039 1,
1040 vec![1.0, 0.0, 0.0],
1041 Some(json!({"content": "python programming"})),
1042 )];
1043 collection.upsert(points).unwrap();
1044
1045 let results = collection.text_search("rust", 10);
1047 assert!(results.is_empty());
1048
1049 let results = collection.text_search("python", 10);
1051 assert_eq!(results.len(), 1);
1052 }
1053
1054 #[test]
1055 fn test_bm25_large_dataset() {
1056 let dir = tempdir().unwrap();
1057 let path = dir.path().join("test_collection");
1058
1059 let collection = Collection::create(path, 4, DistanceMetric::Cosine).unwrap();
1060
1061 let points: Vec<Point> = (0..100)
1063 .map(|i| {
1064 let content = if i % 10 == 0 {
1065 format!("rust document number {i}")
1066 } else {
1067 format!("other document number {i}")
1068 };
1069 Point::new(
1070 i,
1071 vec![0.1, 0.2, 0.3, 0.4],
1072 Some(json!({"content": content})),
1073 )
1074 })
1075 .collect();
1076 collection.upsert(points).unwrap();
1077
1078 let results = collection.text_search("rust", 100);
1080 assert_eq!(results.len(), 10);
1081
1082 for result in &results {
1084 assert_eq!(result.point.id % 10, 0);
1085 }
1086 }
1087
1088 #[test]
1089 fn test_bm25_persistence_on_reopen() {
1090 let dir = tempdir().unwrap();
1091 let path = dir.path().join("test_collection");
1092
1093 {
1095 let collection = Collection::create(path.clone(), 4, DistanceMetric::Cosine).unwrap();
1096
1097 let points = vec![
1098 Point::new(
1099 1,
1100 vec![1.0, 0.0, 0.0, 0.0],
1101 Some(json!({"content": "Rust programming language"})),
1102 ),
1103 Point::new(
1104 2,
1105 vec![0.0, 1.0, 0.0, 0.0],
1106 Some(json!({"content": "Python tutorial"})),
1107 ),
1108 Point::new(
1109 3,
1110 vec![0.0, 0.0, 1.0, 0.0],
1111 Some(json!({"content": "Rust is fast and safe"})),
1112 ),
1113 ];
1114 collection.upsert(points).unwrap();
1115
1116 let results = collection.text_search("rust", 10);
1118 assert_eq!(results.len(), 2);
1119 }
1120
1121 {
1123 let collection = Collection::open(path).unwrap();
1124
1125 let results = collection.text_search("rust", 10);
1127 assert_eq!(results.len(), 2);
1128
1129 let ids: Vec<u64> = results.iter().map(|r| r.point.id).collect();
1130 assert!(ids.contains(&1));
1131 assert!(ids.contains(&3));
1132 }
1133 }
1134
1135 #[test]
1140 fn test_upsert_bulk_basic() {
1141 let dir = tempdir().unwrap();
1142 let path = dir.path().join("test_collection");
1143 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1144
1145 let points = vec![
1146 Point::new(1, vec![1.0, 0.0, 0.0], None),
1147 Point::new(2, vec![0.0, 1.0, 0.0], None),
1148 Point::new(3, vec![0.0, 0.0, 1.0], None),
1149 ];
1150
1151 let inserted = collection.upsert_bulk(&points).unwrap();
1152 assert_eq!(inserted, 3);
1153 assert_eq!(collection.len(), 3);
1154 }
1155
1156 #[test]
1157 fn test_upsert_bulk_with_payload() {
1158 let dir = tempdir().unwrap();
1159 let path = dir.path().join("test_collection");
1160 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1161
1162 let points = vec![
1163 Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Doc 1"}))),
1164 Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"title": "Doc 2"}))),
1165 ];
1166
1167 collection.upsert_bulk(&points).unwrap();
1168 let retrieved = collection.get(&[1, 2]);
1169 assert_eq!(retrieved.len(), 2);
1170 assert!(retrieved[0].as_ref().unwrap().payload.is_some());
1171 }
1172
1173 #[test]
1174 fn test_upsert_bulk_empty() {
1175 let dir = tempdir().unwrap();
1176 let path = dir.path().join("test_collection");
1177 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1178
1179 let points: Vec<Point> = vec![];
1180 let inserted = collection.upsert_bulk(&points).unwrap();
1181 assert_eq!(inserted, 0);
1182 }
1183
1184 #[test]
1185 fn test_upsert_bulk_dimension_mismatch() {
1186 let dir = tempdir().unwrap();
1187 let path = dir.path().join("test_collection");
1188 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1189
1190 let points = vec![
1191 Point::new(1, vec![1.0, 0.0, 0.0], None),
1192 Point::new(2, vec![0.0, 1.0], None), ];
1194
1195 let result = collection.upsert_bulk(&points);
1196 assert!(result.is_err());
1197 }
1198
1199 #[test]
1200 #[allow(clippy::cast_precision_loss)]
1201 fn test_upsert_bulk_large_batch() {
1202 let dir = tempdir().unwrap();
1203 let path = dir.path().join("test_collection");
1204 let collection = Collection::create(path, 64, DistanceMetric::Cosine).unwrap();
1205
1206 let points: Vec<Point> = (0_u64..500)
1207 .map(|i| {
1208 let vector: Vec<f32> = (0_u64..64)
1209 .map(|j| ((i + j) % 100) as f32 / 100.0)
1210 .collect();
1211 Point::new(i, vector, None)
1212 })
1213 .collect();
1214
1215 let inserted = collection.upsert_bulk(&points).unwrap();
1216 assert_eq!(inserted, 500);
1217 assert_eq!(collection.len(), 500);
1218 }
1219
1220 #[test]
1221 fn test_upsert_bulk_search_works() {
1222 let dir = tempdir().unwrap();
1223 let path = dir.path().join("test_collection");
1224 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1225
1226 let points = vec![
1227 Point::new(1, vec![1.0, 0.0, 0.0], None),
1228 Point::new(2, vec![0.9, 0.1, 0.0], None),
1229 Point::new(3, vec![0.0, 1.0, 0.0], None),
1230 ];
1231
1232 collection.upsert_bulk(&points).unwrap();
1233
1234 let query = vec![1.0, 0.0, 0.0];
1235 let results = collection.search(&query, 3).unwrap();
1236 assert!(!results.is_empty());
1237 assert_eq!(results[0].point.id, 1);
1238 }
1239
1240 #[test]
1241 fn test_upsert_bulk_bm25_indexing() {
1242 let dir = tempdir().unwrap();
1243 let path = dir.path().join("test_collection");
1244 let collection = Collection::create(path, 3, DistanceMetric::Cosine).unwrap();
1245
1246 let points = vec![
1247 Point::new(
1248 1,
1249 vec![1.0, 0.0, 0.0],
1250 Some(json!({"content": "Rust lang"})),
1251 ),
1252 Point::new(2, vec![0.0, 1.0, 0.0], Some(json!({"content": "Python"}))),
1253 Point::new(
1254 3,
1255 vec![0.0, 0.0, 1.0],
1256 Some(json!({"content": "Rust fast"})),
1257 ),
1258 ];
1259
1260 collection.upsert_bulk(&points).unwrap();
1261 let results = collection.text_search("rust", 10);
1262 assert_eq!(results.len(), 2);
1263 }
1264}