1use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{Bm25Index, HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
8
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use std::sync::Arc;
13
14#[derive(Debug, Clone, Serialize, Deserialize)]
16pub struct CollectionConfig {
17 pub name: String,
19
20 pub dimension: usize,
22
23 pub metric: DistanceMetric,
25
26 pub point_count: usize,
28}
29
30#[derive(Clone)]
32pub struct Collection {
33 path: PathBuf,
35
36 config: Arc<RwLock<CollectionConfig>>,
38
39 vector_storage: Arc<RwLock<MmapStorage>>,
41
42 payload_storage: Arc<RwLock<LogPayloadStorage>>,
44
45 index: Arc<HnswIndex>,
47
48 text_index: Arc<Bm25Index>,
50}
51
52impl Collection {
53 pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
59 std::fs::create_dir_all(&path)?;
60
61 let name = path
62 .file_name()
63 .and_then(|n| n.to_str())
64 .unwrap_or("unknown")
65 .to_string();
66
67 let config = CollectionConfig {
68 name,
69 dimension,
70 metric,
71 point_count: 0,
72 };
73
74 let vector_storage = Arc::new(RwLock::new(
76 MmapStorage::new(&path, dimension).map_err(Error::Io)?,
77 ));
78
79 let payload_storage = Arc::new(RwLock::new(
80 LogPayloadStorage::new(&path).map_err(Error::Io)?,
81 ));
82
83 let index = Arc::new(HnswIndex::new(dimension, metric));
85
86 let text_index = Arc::new(Bm25Index::new());
88
89 let collection = Self {
90 path,
91 config: Arc::new(RwLock::new(config)),
92 vector_storage,
93 payload_storage,
94 index,
95 text_index,
96 };
97
98 collection.save_config()?;
99
100 Ok(collection)
101 }
102
103 pub fn open(path: PathBuf) -> Result<Self> {
109 let config_path = path.join("config.json");
110 let config_data = std::fs::read_to_string(&config_path)?;
111 let config: CollectionConfig =
112 serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
113
114 let vector_storage = Arc::new(RwLock::new(
116 MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
117 ));
118
119 let payload_storage = Arc::new(RwLock::new(
120 LogPayloadStorage::new(&path).map_err(Error::Io)?,
121 ));
122
123 let index = if path.join("hnsw.bin").exists() {
125 Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
126 } else {
127 Arc::new(HnswIndex::new(config.dimension, config.metric))
128 };
129
130 let text_index = Arc::new(Bm25Index::new());
132
133 {
135 let storage = payload_storage.read();
136 let ids = storage.ids();
137 for id in ids {
138 if let Ok(Some(payload)) = storage.retrieve(id) {
139 let text = Self::extract_text_from_payload(&payload);
140 if !text.is_empty() {
141 text_index.add_document(id, &text);
142 }
143 }
144 }
145 }
146
147 Ok(Self {
148 path,
149 config: Arc::new(RwLock::new(config)),
150 vector_storage,
151 payload_storage,
152 index,
153 text_index,
154 })
155 }
156
157 #[must_use]
159 pub fn config(&self) -> CollectionConfig {
160 self.config.read().clone()
161 }
162
163 pub fn upsert(&self, points: Vec<Point>) -> Result<()> {
169 let config = self.config.read();
170 let dimension = config.dimension;
171 drop(config);
172
173 for point in &points {
175 if point.dimension() != dimension {
176 return Err(Error::DimensionMismatch {
177 expected: dimension,
178 actual: point.dimension(),
179 });
180 }
181 }
182
183 let mut vector_storage = self.vector_storage.write();
184 let mut payload_storage = self.payload_storage.write();
185
186 for point in points {
187 vector_storage
189 .store(point.id, &point.vector)
190 .map_err(Error::Io)?;
191
192 if let Some(payload) = &point.payload {
194 payload_storage
195 .store(point.id, payload)
196 .map_err(Error::Io)?;
197 } else {
198 let _ = payload_storage.delete(point.id); }
205
206 self.index.insert(point.id, &point.vector);
210
211 if let Some(payload) = &point.payload {
213 let text = Self::extract_text_from_payload(payload);
214 if !text.is_empty() {
215 self.text_index.add_document(point.id, &text);
216 }
217 } else {
218 self.text_index.remove_document(point.id);
220 }
221 }
222
223 let mut config = self.config.write();
225 config.point_count = vector_storage.len();
226
227 vector_storage.flush().map_err(Error::Io)?;
230 payload_storage.flush().map_err(Error::Io)?;
231 self.index.save(&self.path).map_err(Error::Io)?;
232
233 Ok(())
234 }
235
236 #[must_use]
238 pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
239 let vector_storage = self.vector_storage.read();
240 let payload_storage = self.payload_storage.read();
241
242 ids.iter()
243 .map(|&id| {
244 let vector = vector_storage.retrieve(id).ok().flatten()?;
246
247 let payload = payload_storage.retrieve(id).ok().flatten();
249
250 Some(Point {
251 id,
252 vector,
253 payload,
254 })
255 })
256 .collect()
257 }
258
259 pub fn delete(&self, ids: &[u64]) -> Result<()> {
265 let mut vector_storage = self.vector_storage.write();
266 let mut payload_storage = self.payload_storage.write();
267
268 for &id in ids {
269 vector_storage.delete(id).map_err(Error::Io)?;
270 payload_storage.delete(id).map_err(Error::Io)?;
271 self.index.remove(id);
272 }
273
274 let mut config = self.config.write();
275 config.point_count = vector_storage.len();
276
277 Ok(())
278 }
279
280 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
288 let config = self.config.read();
289
290 if query.len() != config.dimension {
291 return Err(Error::DimensionMismatch {
292 expected: config.dimension,
293 actual: query.len(),
294 });
295 }
296 drop(config);
297
298 let index_results = self.index.search(query, k);
300
301 let vector_storage = self.vector_storage.read();
302 let payload_storage = self.payload_storage.read();
303
304 let results: Vec<SearchResult> = index_results
306 .into_iter()
307 .filter_map(|(id, score)| {
308 let vector = vector_storage.retrieve(id).ok().flatten()?;
310 let payload = payload_storage.retrieve(id).ok().flatten();
311
312 let point = Point {
313 id,
314 vector,
315 payload,
316 };
317
318 Some(SearchResult::new(point, score))
319 })
320 .collect();
321
322 Ok(results)
323 }
324
325 #[must_use]
327 pub fn len(&self) -> usize {
328 self.vector_storage.read().len()
329 }
330
331 #[must_use]
333 pub fn is_empty(&self) -> bool {
334 self.vector_storage.read().is_empty()
335 }
336
337 pub fn flush(&self) -> Result<()> {
343 self.save_config()?;
344 self.vector_storage.write().flush().map_err(Error::Io)?;
345 self.payload_storage.write().flush().map_err(Error::Io)?;
346 self.index.save(&self.path).map_err(Error::Io)?;
347 Ok(())
348 }
349
350 fn save_config(&self) -> Result<()> {
352 let config = self.config.read();
353 let config_path = self.path.join("config.json");
354 let config_data = serde_json::to_string_pretty(&*config)
355 .map_err(|e| Error::Serialization(e.to_string()))?;
356 std::fs::write(config_path, config_data)?;
357 Ok(())
358 }
359
360 #[must_use]
371 pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
372 let bm25_results = self.text_index.search(query, k);
373
374 let vector_storage = self.vector_storage.read();
375 let payload_storage = self.payload_storage.read();
376
377 bm25_results
378 .into_iter()
379 .filter_map(|(id, score)| {
380 let vector = vector_storage.retrieve(id).ok().flatten()?;
381 let payload = payload_storage.retrieve(id).ok().flatten();
382
383 let point = Point {
384 id,
385 vector,
386 payload,
387 };
388
389 Some(SearchResult::new(point, score))
390 })
391 .collect()
392 }
393
394 pub fn hybrid_search(
409 &self,
410 vector_query: &[f32],
411 text_query: &str,
412 k: usize,
413 vector_weight: Option<f32>,
414 ) -> Result<Vec<SearchResult>> {
415 let config = self.config.read();
416 if vector_query.len() != config.dimension {
417 return Err(Error::DimensionMismatch {
418 expected: config.dimension,
419 actual: vector_query.len(),
420 });
421 }
422 drop(config);
423
424 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
425 let text_weight = 1.0 - weight;
426
427 let vector_results = self.index.search(vector_query, k * 2);
429
430 let text_results = self.text_index.search(text_query, k * 2);
432
433 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
436
437 #[allow(clippy::cast_precision_loss)]
439 for (rank, (id, _)) in vector_results.iter().enumerate() {
440 let rrf_score = weight / (rank as f32 + 60.0);
441 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
442 }
443
444 #[allow(clippy::cast_precision_loss)]
446 for (rank, (id, _)) in text_results.iter().enumerate() {
447 let rrf_score = text_weight / (rank as f32 + 60.0);
448 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
449 }
450
451 let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
453 if scored_ids.len() > k {
454 scored_ids.select_nth_unstable_by(k, |a, b| {
455 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
456 });
457 scored_ids.truncate(k);
458 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
459 } else {
460 scored_ids.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
461 }
462
463 let vector_storage = self.vector_storage.read();
465 let payload_storage = self.payload_storage.read();
466
467 let results: Vec<SearchResult> = scored_ids
468 .into_iter()
469 .filter_map(|(id, score)| {
470 let vector = vector_storage.retrieve(id).ok().flatten()?;
471 let payload = payload_storage.retrieve(id).ok().flatten();
472
473 let point = Point {
474 id,
475 vector,
476 payload,
477 };
478
479 Some(SearchResult::new(point, score))
480 })
481 .collect();
482
483 Ok(results)
484 }
485
486 fn extract_text_from_payload(payload: &serde_json::Value) -> String {
488 let mut texts = Vec::new();
489 Self::collect_strings(payload, &mut texts);
490 texts.join(" ")
491 }
492
493 fn collect_strings(value: &serde_json::Value, texts: &mut Vec<String>) {
495 match value {
496 serde_json::Value::String(s) => texts.push(s.clone()),
497 serde_json::Value::Array(arr) => {
498 for item in arr {
499 Self::collect_strings(item, texts);
500 }
501 }
502 serde_json::Value::Object(obj) => {
503 for v in obj.values() {
504 Self::collect_strings(v, texts);
505 }
506 }
507 _ => {}
508 }
509 }
510}
511
512#[cfg(test)]
513mod tests {
514 use super::*;
515 use serde_json::json;
516 use tempfile::tempdir;
517
/// A freshly created collection reports the requested settings and zero points.
#[test]
fn test_collection_create() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");

    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let CollectionConfig {
        dimension,
        metric,
        point_count,
        ..
    } = collection.config();
    assert_eq!(dimension, 3);
    assert_eq!(metric, DistanceMetric::Cosine);
    assert_eq!(point_count, 0);
}
530
/// Upserted points are counted and the nearest one is ranked first.
#[test]
fn test_collection_upsert_and_search() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    // One unit vector per axis.
    collection
        .upsert(vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
        ])
        .unwrap();
    assert_eq!(collection.len(), 3);

    let hits = collection.search(&[1.0, 0.0, 0.0], 2).unwrap();

    assert_eq!(hits.len(), 2);
    assert_eq!(hits[0].point.id, 1);
}
553
/// Upserting a vector of the wrong length is rejected.
#[test]
fn test_dimension_mismatch() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    // Two components instead of the required three.
    let too_short = Point::without_payload(1, vec![1.0, 0.0]);
    let outcome = collection.upsert(vec![too_short]);

    assert!(outcome.is_err());
}
566
/// A flushed collection can be reopened with its settings and points intact.
#[test]
fn test_collection_open_existing() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");

    {
        let writer = Collection::create(location.clone(), 3, DistanceMetric::Euclidean).unwrap();
        writer
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 2.0, 3.0]),
                Point::without_payload(2, vec![4.0, 5.0, 6.0]),
            ])
            .unwrap();
        writer.flush().unwrap();
    }

    let reopened = Collection::open(location).unwrap();
    let restored = reopened.config();

    assert_eq!(restored.dimension, 3);
    assert_eq!(restored.metric, DistanceMetric::Euclidean);
    assert_eq!(reopened.len(), 2);
}
592
/// `get` returns existing points in request order and `None` for unknown ids.
#[test]
fn test_collection_get_points() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    collection
        .upsert(vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
        ])
        .unwrap();

    let fetched = collection.get(&[1, 2, 999]);

    assert_eq!(fetched[0].as_ref().map(|p| p.id), Some(1));
    assert_eq!(fetched[1].as_ref().map(|p| p.id), Some(2));
    assert!(fetched[2].is_none());
}
614
/// Deleting a point shrinks the collection and makes the point unretrievable.
#[test]
fn test_collection_delete_points() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    collection
        .upsert(vec![
            Point::without_payload(1, vec![1.0, 0.0, 0.0]),
            Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            Point::without_payload(3, vec![0.0, 0.0, 1.0]),
        ])
        .unwrap();
    assert_eq!(collection.len(), 3);

    collection.delete(&[2]).unwrap();
    assert_eq!(collection.len(), 2);

    let after_delete = collection.get(&[2]);
    assert!(after_delete[0].is_none());
}
637
/// `is_empty` flips from true to false once a point is stored.
#[test]
fn test_collection_is_empty() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    assert!(collection.is_empty());

    let only_point = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![only_point]).unwrap();

    assert!(!collection.is_empty());
}
651
/// A payload stored with a point comes back unchanged from `get`.
#[test]
fn test_collection_with_payload() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let document = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"title": "Test Document", "category": "tech"})),
    );
    collection.upsert(vec![document]).unwrap();

    let fetched = collection.get(&[1]);
    assert!(fetched[0].is_some());

    let stored = fetched[0].as_ref().unwrap();
    assert!(stored.payload.is_some());
    assert_eq!(stored.payload.as_ref().unwrap()["title"], "Test Document");
}
673
/// Searching with a query of the wrong length is rejected.
#[test]
fn test_collection_search_dimension_mismatch() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let seed = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![seed]).unwrap();

    // Two components instead of the required three.
    let outcome = collection.search(&[1.0, 0.0], 5);

    assert!(outcome.is_err());
}
688
/// Re-upserting a point without a payload clears the previously stored payload.
#[test]
fn test_collection_upsert_replaces_payload() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let with_payload = Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"version": 1})));
    collection.upsert(vec![with_payload]).unwrap();

    let without_payload = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![without_payload]).unwrap();

    let fetched = collection.get(&[1]);
    let stored = fetched[0].as_ref().unwrap();
    assert!(stored.payload.is_none());
}
714
/// Flushing a populated collection succeeds.
#[test]
fn test_collection_flush() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let seed = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
    collection.upsert(vec![seed]).unwrap();

    let outcome = collection.flush();
    assert!(outcome.is_ok());
}
729
/// With the Euclidean metric, one of the two points nearest the query ranks first.
#[test]
fn test_collection_euclidean_metric() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Euclidean).unwrap();

    collection
        .upsert(vec![
            Point::without_payload(1, vec![0.0, 0.0, 0.0]),
            Point::without_payload(2, vec![1.0, 0.0, 0.0]),
            Point::without_payload(3, vec![10.0, 0.0, 0.0]),
        ])
        .unwrap();

    // The query sits midway between points 1 and 2, so either may come first.
    let hits = collection.search(&[0.5, 0.0, 0.0], 3).unwrap();

    assert!(matches!(hits[0].point.id, 1 | 2));
}
750
/// Text search returns exactly the documents that mention the query term.
#[test]
fn test_collection_text_search() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let rust_intro = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"title": "Rust Programming", "content": "Learn Rust language"})),
    );
    let python_intro = Point::new(
        2,
        vec![0.0, 1.0, 0.0],
        Some(json!({"title": "Python Tutorial", "content": "Python is great"})),
    );
    let rust_perf = Point::new(
        3,
        vec![0.0, 0.0, 1.0],
        Some(json!({"title": "Rust Performance", "content": "Rust is fast"})),
    );
    collection
        .upsert(vec![rust_intro, python_intro, rust_perf])
        .unwrap();

    let hits = collection.text_search("rust", 10);
    assert_eq!(hits.len(), 2);

    let matched: Vec<u64> = hits.iter().map(|hit| hit.point.id).collect();
    assert!(matched.contains(&1));
    assert!(matched.contains(&3));
}
785
/// The point that matches both the vector and the text query ranks first.
#[test]
fn test_collection_hybrid_search() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let rust_close = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"title": "Rust Programming"})),
    );
    let python_close = Point::new(
        2,
        vec![0.9, 0.1, 0.0],
        Some(json!({"title": "Python Programming"})),
    );
    let rust_far = Point::new(
        3,
        vec![0.0, 1.0, 0.0],
        Some(json!({"title": "Rust Performance"})),
    );
    collection
        .upsert(vec![rust_close, python_close, rust_far])
        .unwrap();

    let hits = collection
        .hybrid_search(&[1.0, 0.0, 0.0], "rust", 3, Some(0.5))
        .unwrap();

    assert!(!hits.is_empty());
    assert_eq!(hits[0].point.id, 1);
}
825
/// Strings nested inside objects and arrays all end up in the extracted text.
#[test]
fn test_extract_text_from_payload() {
    let nested = json!({
        "title": "Hello",
        "meta": {
            "author": "World",
            "tags": ["rust", "fast"]
        }
    });

    let extracted = Collection::extract_text_from_payload(&nested);

    for expected in ["Hello", "World", "rust", "fast"] {
        assert!(extracted.contains(expected));
    }
}
843
/// An empty text query matches nothing.
#[test]
fn test_text_search_empty_query() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let document = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"content": "test document"})),
    );
    collection.upsert(vec![document]).unwrap();

    let hits = collection.text_search("", 10);
    assert!(hits.is_empty());
}
862
/// Points without payloads never show up in text search results.
#[test]
fn test_text_search_no_payload() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    collection
        .upsert(vec![
            Point::new(1, vec![1.0, 0.0, 0.0], None),
            Point::new(2, vec![0.0, 1.0, 0.0], None),
        ])
        .unwrap();

    let hits = collection.text_search("test", 10);
    assert!(hits.is_empty());
}
881
/// With all weight on the vector side, the closest vector wins regardless of text.
#[test]
fn test_hybrid_search_text_weight_zero() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    collection
        .upsert(vec![
            Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"title": "Rust"}))),
            Point::new(2, vec![0.9, 0.1, 0.0], Some(json!({"title": "Python"}))),
        ])
        .unwrap();

    // The query equals point 2's vector; the text query only matches point 1.
    let hits = collection
        .hybrid_search(&[0.9, 0.1, 0.0], "rust", 2, Some(1.0))
        .unwrap();

    assert_eq!(hits[0].point.id, 2);
}
904
/// With all weight on the text side, the text match wins regardless of vectors.
#[test]
fn test_hybrid_search_vector_weight_zero() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let rust_doc = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"title": "Rust programming language"})),
    );
    let python_doc = Point::new(
        2,
        vec![0.99, 0.01, 0.0],
        Some(json!({"title": "Python programming"})),
    );
    collection.upsert(vec![rust_doc, python_doc]).unwrap();

    // The query equals point 2's vector; the text query only matches point 1.
    let hits = collection
        .hybrid_search(&[0.99, 0.01, 0.0], "rust", 2, Some(0.0))
        .unwrap();

    assert_eq!(hits[0].point.id, 1);
}
935
/// Re-upserting a point with new text replaces its entry in the text index.
#[test]
fn test_bm25_update_document() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 3, DistanceMetric::Cosine).unwrap();

    let original = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"content": "rust programming"})),
    );
    collection.upsert(vec![original]).unwrap();

    let rust_before = collection.text_search("rust", 10);
    assert_eq!(rust_before.len(), 1);

    let replacement = Point::new(
        1,
        vec![1.0, 0.0, 0.0],
        Some(json!({"content": "python programming"})),
    );
    collection.upsert(vec![replacement]).unwrap();

    // The old term no longer matches; the new one does.
    let rust_after = collection.text_search("rust", 10);
    assert!(rust_after.is_empty());

    let python_after = collection.text_search("python", 10);
    assert_eq!(python_after.len(), 1);
}
971
/// Among 100 documents, text search finds exactly the ten that mention the term.
#[test]
fn test_bm25_large_dataset() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");
    let collection = Collection::create(location, 4, DistanceMetric::Cosine).unwrap();

    // Every tenth document mentions "rust"; the rest mention "other".
    let mut batch: Vec<Point> = Vec::with_capacity(100);
    for i in 0..100 {
        let topic = if i % 10 == 0 { "rust" } else { "other" };
        let content = format!("{topic} document number {i}");
        batch.push(Point::new(
            i,
            vec![0.1, 0.2, 0.3, 0.4],
            Some(json!({"content": content})),
        ));
    }
    collection.upsert(batch).unwrap();

    let hits = collection.text_search("rust", 100);
    assert_eq!(hits.len(), 10);

    assert!(hits.iter().all(|hit| hit.point.id % 10 == 0));
}
1005
/// Text search still works after the collection is dropped and reopened.
#[test]
fn test_bm25_persistence_on_reopen() {
    let tmp = tempdir().unwrap();
    let location = tmp.path().join("test_collection");

    {
        let writer = Collection::create(location.clone(), 4, DistanceMetric::Cosine).unwrap();

        let rust_lang = Point::new(
            1,
            vec![1.0, 0.0, 0.0, 0.0],
            Some(json!({"content": "Rust programming language"})),
        );
        let python_tutorial = Point::new(
            2,
            vec![0.0, 1.0, 0.0, 0.0],
            Some(json!({"content": "Python tutorial"})),
        );
        let rust_fast = Point::new(
            3,
            vec![0.0, 0.0, 1.0, 0.0],
            Some(json!({"content": "Rust is fast and safe"})),
        );
        writer
            .upsert(vec![rust_lang, python_tutorial, rust_fast])
            .unwrap();

        let hits_before = writer.text_search("rust", 10);
        assert_eq!(hits_before.len(), 2);
    }

    {
        let reopened = Collection::open(location).unwrap();

        let hits_after = reopened.text_search("rust", 10);
        assert_eq!(hits_after.len(), 2);

        let matched: Vec<u64> = hits_after.iter().map(|hit| hit.point.id).collect();
        assert!(matched.contains(&1));
        assert!(matched.contains(&3));
    }
}
1052}