1use std::collections::HashMap;
37use std::sync::{Arc, RwLock};
38use std::sync::atomic::{AtomicU64, Ordering};
39use std::path::Path;
40use crate::{VectorIndexWrapper, VectorIndex, Vector, SearchResult, SimilarityMetric, EmbeddingFunction, PersistenceError, save_collection_to_file, load_collection_from_file};
41use crate::errors::{VectorLiteError, VectorLiteResult};
42
/// Client that owns a set of named collections together with the embedding
/// function used to turn text into vectors for them.
pub struct VectorLiteClient {
    /// Collections keyed by their name.
    collections: HashMap<String, CollectionRef>,
    /// Embedding function handed to every collection operation that embeds text.
    embedding_function: Arc<dyn EmbeddingFunction>,
}
69
/// Empty settings type with no fields.
///
/// NOTE(review): nothing in the visible part of this file references it —
/// presumably a placeholder for future configuration; confirm before relying on it.
pub struct Settings {}
74
75impl VectorLiteClient {
76 pub fn new(embedding_function: Box<dyn EmbeddingFunction>) -> Self {
77 Self {
78 collections: HashMap::new(),
79 embedding_function: Arc::from(embedding_function),
80 }
81 }
82
83 pub fn create_collection(&mut self, name: &str, index_type: IndexType) -> VectorLiteResult<()> {
84 if self.collections.contains_key(name) {
85 return Err(VectorLiteError::CollectionAlreadyExists { name: name.to_string() });
86 }
87
88 let dimension = self.embedding_function.dimension();
89 let index = match index_type {
90 IndexType::Flat => VectorIndexWrapper::Flat(crate::FlatIndex::new(dimension, Vec::new())),
91 IndexType::HNSW => VectorIndexWrapper::HNSW(Box::new(crate::HNSWIndex::new(dimension))),
92 };
93
94 let collection = Collection {
95 name: name.to_string(),
96 index: Arc::new(RwLock::new(index)),
97 next_id: Arc::new(AtomicU64::new(0)),
98 };
99
100 self.collections.insert(name.to_string(), Arc::new(collection));
101 Ok(())
102 }
103
104 pub fn get_collection(&self, name: &str) -> Option<&CollectionRef> {
105 self.collections.get(name)
106 }
107
108 pub fn list_collections(&self) -> Vec<String> {
109 self.collections.keys().cloned().collect()
110 }
111
112 pub fn delete_collection(&mut self, name: &str) -> VectorLiteResult<()> {
113 if self.collections.remove(name).is_some() {
114 Ok(())
115 } else {
116 Err(VectorLiteError::CollectionNotFound { name: name.to_string() })
117 }
118 }
119
120 pub fn has_collection(&self, name: &str) -> bool {
121 self.collections.contains_key(name)
122 }
123
124 pub fn add_text_to_collection(&self, collection_name: &str, text: &str, metadata: Option<serde_json::Value>) -> VectorLiteResult<u64> {
125 let collection = self.collections.get(collection_name)
126 .ok_or_else(|| VectorLiteError::CollectionNotFound { name: collection_name.to_string() })?;
127
128 collection.add_text_with_metadata(text, metadata, self.embedding_function.as_ref())
129 }
130
131
132 pub fn search_text_in_collection(&self, collection_name: &str, query_text: &str, k: usize, similarity_metric: SimilarityMetric) -> VectorLiteResult<Vec<SearchResult>> {
133 let collection = self.collections.get(collection_name)
134 .ok_or_else(|| VectorLiteError::CollectionNotFound { name: collection_name.to_string() })?;
135
136 collection.search_text(query_text, k, similarity_metric, self.embedding_function.as_ref())
137 }
138
139
140 pub fn delete_from_collection(&self, collection_name: &str, id: u64) -> VectorLiteResult<()> {
141 let collection = self.collections.get(collection_name)
142 .ok_or_else(|| VectorLiteError::CollectionNotFound { name: collection_name.to_string() })?;
143
144 collection.delete(id)
145 }
146
147 pub fn get_vector_from_collection(&self, collection_name: &str, id: u64) -> VectorLiteResult<Option<Vector>> {
148 let collection = self.collections.get(collection_name)
149 .ok_or_else(|| VectorLiteError::CollectionNotFound { name: collection_name.to_string() })?;
150
151 collection.get_vector(id)
152 }
153
154 pub fn get_collection_info(&self, collection_name: &str) -> VectorLiteResult<CollectionInfo> {
155 let collection = self.collections.get(collection_name)
156 .ok_or_else(|| VectorLiteError::CollectionNotFound { name: collection_name.to_string() })?;
157
158 collection.get_info()
159 }
160
161 pub fn add_collection(&mut self, collection: Collection) -> VectorLiteResult<()> {
163 let name = collection.name().to_string();
164 if self.collections.contains_key(&name) {
165 return Err(VectorLiteError::CollectionAlreadyExists { name });
166 }
167 self.collections.insert(name, Arc::new(collection));
168 Ok(())
169 }
170
171}
172
/// Selects which index implementation `VectorLiteClient::create_collection` builds.
#[derive(Debug, Clone, Copy)]
pub enum IndexType {
    /// Builds a `FlatIndex`, wrapped as `VectorIndexWrapper::Flat`.
    Flat,
    /// Builds a boxed `HNSWIndex`, wrapped as `VectorIndexWrapper::HNSW`.
    HNSW,
}
211
/// A named vector index plus the counter used to assign IDs to new vectors.
pub struct Collection {
    /// Name of the collection, as returned by `name()`.
    name: String,
    /// The underlying index, guarded by a read/write lock.
    index: Arc<RwLock<VectorIndexWrapper>>,
    /// ID that will be handed to the next vector added through this collection.
    next_id: Arc<AtomicU64>,
}
226
/// Reference-counted handle to a `Collection`, as stored in the client's map.
type CollectionRef = Arc<Collection>;
229
/// Serializable summary of a collection, as produced by `Collection::get_info`.
#[derive(Debug, Clone, serde::Serialize)]
pub struct CollectionInfo {
    /// Name of the collection.
    pub name: String,
    /// Value reported by the index's `len()`.
    pub count: usize,
    /// Value reported by the index's `is_empty()`.
    pub is_empty: bool,
    /// Value reported by the index's `dimension()`.
    pub dimension: usize,
}
261
262impl std::fmt::Debug for Collection {
263 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
264 f.debug_struct("Collection")
265 .field("name", &self.name)
266 .field("next_id", &self.next_id.load(Ordering::Relaxed))
267 .finish()
268 }
269}
270
271impl Collection {
272 pub fn new(name: String, index: VectorIndexWrapper) -> Self {
274 let next_id = match &index {
276 VectorIndexWrapper::Flat(flat_index) => {
277 flat_index.max_id()
278 .map(|max_id| max_id + 1)
279 .unwrap_or(0)
280 }
281 VectorIndexWrapper::HNSW(hnsw_index) => {
282 hnsw_index.max_id()
283 .map(|max_id| max_id + 1)
284 .unwrap_or(0)
285 }
286 };
287
288 Self {
289 name,
290 index: Arc::new(RwLock::new(index)),
291 next_id: Arc::new(AtomicU64::new(next_id)),
292 }
293 }
294
295 pub fn add_text(&self, text: &str, embedding_function: &dyn EmbeddingFunction) -> VectorLiteResult<u64> {
296 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
297
298 let embedding = embedding_function.generate_embedding(text)?;
300
301 let vector = Vector {
302 id,
303 values: embedding,
304 text: text.to_string(),
305 metadata: None
306 };
307 let vector_dimension = vector.values.len();
308 let vector_id = vector.id;
309
310 let mut index = self.index.write().map_err(|_| VectorLiteError::LockError("Failed to acquire write lock for add_text".to_string()))?;
312 index.add(vector).map_err(|e| {
313 if e.contains("dimension") {
314 VectorLiteError::DimensionMismatch {
315 expected: index.dimension(),
316 actual: vector_dimension
317 }
318 } else if e.contains("already exists") {
319 VectorLiteError::DuplicateVectorId { id: vector_id }
320 } else {
321 VectorLiteError::InternalError(e)
322 }
323 })?;
324 Ok(id)
325 }
326
327 pub fn add_text_with_metadata(&self, text: &str, metadata: Option<serde_json::Value>, embedding_function: &dyn EmbeddingFunction) -> VectorLiteResult<u64> {
328 let id = self.next_id.fetch_add(1, Ordering::Relaxed);
329
330 let embedding = embedding_function.generate_embedding(text)?;
332
333 let vector = Vector {
334 id,
335 values: embedding,
336 text: text.to_string(),
337 metadata
338 };
339 let vector_dimension = vector.values.len();
340 let vector_id = vector.id;
341
342 let mut index = self.index.write().map_err(|_| VectorLiteError::LockError("Failed to acquire write lock for add_text_with_metadata".to_string()))?;
344 index.add(vector).map_err(|e| {
345 if e.contains("dimension") {
346 VectorLiteError::DimensionMismatch {
347 expected: index.dimension(),
348 actual: vector_dimension
349 }
350 } else if e.contains("already exists") {
351 VectorLiteError::DuplicateVectorId { id: vector_id }
352 } else {
353 VectorLiteError::InternalError(e)
354 }
355 })?;
356 Ok(id)
357 }
358
359
360 pub fn delete(&self, id: u64) -> VectorLiteResult<()> {
361 let mut index = self.index.write().map_err(|_| VectorLiteError::LockError("Failed to acquire write lock for delete".to_string()))?;
362 index.delete(id).map_err(|e| {
363 if e.contains("does not exist") {
364 VectorLiteError::VectorNotFound { id }
365 } else {
366 VectorLiteError::InternalError(e)
367 }
368 })
369 }
370
371 pub fn search_text(&self, query_text: &str, k: usize, similarity_metric: SimilarityMetric, embedding_function: &dyn EmbeddingFunction) -> VectorLiteResult<Vec<SearchResult>> {
372 let query_embedding = embedding_function.generate_embedding(query_text)?;
374
375 let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for search_text".to_string()))?;
377 Ok(index.search(&query_embedding, k, similarity_metric))
378 }
379
380
381 pub fn get_vector(&self, id: u64) -> VectorLiteResult<Option<Vector>> {
382 let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for get_vector".to_string()))?;
383 Ok(index.get_vector(id).cloned())
384 }
385
386 pub fn get_info(&self) -> VectorLiteResult<CollectionInfo> {
387 let index = self.index.read().map_err(|_| VectorLiteError::LockError("Failed to acquire read lock for get_info".to_string()))?;
388 Ok(CollectionInfo {
389 name: self.name.clone(),
390 count: index.len(),
391 is_empty: index.is_empty(),
392 dimension: index.dimension(),
393 })
394 }
395
396 pub fn name(&self) -> &str {
397 &self.name
398 }
399
400 pub fn next_id(&self) -> u64 {
402 self.next_id.load(Ordering::Relaxed)
403 }
404
405 pub fn index_read(&self) -> Result<std::sync::RwLockReadGuard<'_, VectorIndexWrapper>, String> {
407 self.index.read().map_err(|_| "Failed to acquire read lock".to_string())
408 }
409
410 pub fn save_to_file(&self, path: &Path) -> Result<(), PersistenceError> {
441 save_collection_to_file(self, path)
442 }
443
444 pub fn load_from_file(path: &Path) -> Result<Self, PersistenceError> {
472 load_collection_from_file(path)
473 }
474}
475
#[cfg(test)]
mod tests {
    use super::*;

    /// Embedding stub that returns an all-ones vector of a fixed dimension,
    /// regardless of the input text.
    struct MockEmbeddingFunction {
        dimension: usize,
    }

    impl MockEmbeddingFunction {
        fn new(dimension: usize) -> Self {
            Self { dimension }
        }
    }

    impl EmbeddingFunction for MockEmbeddingFunction {
        fn generate_embedding(&self, _text: &str) -> crate::embeddings::Result<Vec<f64>> {
            Ok(vec![1.0; self.dimension])
        }

        fn dimension(&self) -> usize {
            self.dimension
        }
    }

    /// Builds a client around a 3-dimensional mock embedding function.
    fn new_client() -> VectorLiteClient {
        VectorLiteClient::new(Box::new(MockEmbeddingFunction::new(3)))
    }

    /// Builds a client that already contains one empty collection.
    fn client_with_collection(name: &str, index_type: IndexType) -> VectorLiteClient {
        let mut client = new_client();
        client.create_collection(name, index_type).unwrap();
        client
    }

    #[test]
    fn test_client_creation() {
        let client = new_client();

        assert!(client.collections.is_empty());
        assert!(client.list_collections().is_empty());
    }

    #[test]
    fn test_create_collection() {
        let mut client = new_client();

        let created = client.create_collection("test_collection", IndexType::Flat);
        assert!(created.is_ok());

        assert!(client.has_collection("test_collection"));
        assert_eq!(client.list_collections(), vec!["test_collection"]);
    }

    #[test]
    fn test_create_duplicate_collection() {
        let mut client = client_with_collection("test_collection", IndexType::Flat);

        let second_attempt = client.create_collection("test_collection", IndexType::Flat);
        assert!(second_attempt.is_err());
        assert!(matches!(second_attempt.unwrap_err(), VectorLiteError::CollectionAlreadyExists { .. }));
    }

    #[test]
    fn test_get_collection() {
        let client = client_with_collection("test_collection", IndexType::Flat);

        let found = client.get_collection("test_collection");
        assert!(found.is_some());
        assert_eq!(found.unwrap().name(), "test_collection");

        let missing = client.get_collection("non_existent");
        assert!(missing.is_none());
    }

    #[test]
    fn test_delete_collection() {
        let mut client = client_with_collection("test_collection", IndexType::Flat);
        assert!(client.has_collection("test_collection"));

        let removed = client.delete_collection("test_collection");
        assert!(removed.is_ok());
        assert!(!client.has_collection("test_collection"));

        let removed_missing = client.delete_collection("non_existent");
        assert!(removed_missing.is_err());
    }

    #[test]
    fn test_add_text_to_collection() {
        let client = client_with_collection("test_collection", IndexType::Flat);

        let first = client.add_text_to_collection("test_collection", "Hello world", None);
        assert!(first.is_ok());
        assert_eq!(first.unwrap(), 0);

        let second = client.add_text_to_collection("test_collection", "Another text", None);
        assert!(second.is_ok());
        assert_eq!(second.unwrap(), 1);

        let info = client.get_collection_info("test_collection").unwrap();
        assert_eq!(info.count, 2);
    }

    #[test]
    fn test_add_text_to_nonexistent_collection() {
        let client = new_client();

        let outcome = client.add_text_to_collection("non_existent", "Hello world", None);
        assert!(outcome.is_err());
        assert!(matches!(outcome.unwrap_err(), VectorLiteError::CollectionNotFound { .. }));
    }

    #[test]
    fn test_collection_operations() {
        let name = "test_collection";
        let client = client_with_collection(name, IndexType::Flat);

        // Freshly created collection is empty.
        let info = client.get_collection_info(name).unwrap();
        assert!(info.is_empty);
        assert_eq!(info.count, 0);
        assert_eq!(info.name, "test_collection");

        // First insert gets ID 0.
        let first_id = client.add_text_to_collection(name, "Hello world", None).unwrap();
        assert_eq!(first_id, 0);

        let info = client.get_collection_info(name).unwrap();
        assert!(!info.is_empty);
        assert_eq!(info.count, 1);

        // Second insert gets ID 1.
        let second_id = client.add_text_to_collection(name, "Another text", None).unwrap();
        assert_eq!(second_id, 1);

        let info = client.get_collection_info(name).unwrap();
        assert_eq!(info.count, 2);

        // Search and direct lookup.
        let hits = client.search_text_in_collection(name, "Hello", 1, SimilarityMetric::Cosine).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].id, 0);

        let fetched = client.get_vector_from_collection(name, 0).unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().id, 0);

        // Deleting removes the vector and shrinks the count.
        client.delete_from_collection(name, 0).unwrap();

        let info = client.get_collection_info(name).unwrap();
        assert_eq!(info.count, 1);

        let fetched = client.get_vector_from_collection(name, 0).unwrap();
        assert!(fetched.is_none());
    }

    #[test]
    fn test_collection_with_hnsw_index() {
        let name = "hnsw_collection";
        let client = client_with_collection(name, IndexType::HNSW);

        let id1 = client.add_text_to_collection(name, "First document", None).unwrap();
        let id2 = client.add_text_to_collection(name, "Second document", None).unwrap();

        assert_eq!(id1, 0);
        assert_eq!(id2, 1);

        let info = client.get_collection_info(name).unwrap();
        assert_eq!(info.count, 2);

        let hits = client.search_text_in_collection(name, "First", 1, SimilarityMetric::Cosine).unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_collection_save_and_load() {
        let client = client_with_collection("test_collection", IndexType::Flat);
        client.add_text_to_collection("test_collection", "Hello world", None).unwrap();
        client.add_text_to_collection("test_collection", "Another text", None).unwrap();

        let collection = client.get_collection("test_collection").unwrap();

        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test_collection.vlc");

        collection.save_to_file(&file_path).unwrap();
        assert!(file_path.exists());

        let loaded_collection = Collection::load_from_file(&file_path).unwrap();

        assert_eq!(loaded_collection.name(), "test_collection");

        let info = loaded_collection.get_info().unwrap();
        assert_eq!(info.count, 2);
        assert_eq!(info.dimension, 3);
        assert!(!info.is_empty);

        let test_embedding_fn = MockEmbeddingFunction::new(3);
        let hits = loaded_collection
            .search_text("Hello", 2, SimilarityMetric::Cosine, &test_embedding_fn)
            .unwrap();
        assert_eq!(hits.len(), 2);
    }

    #[test]
    fn test_collection_save_and_load_hnsw() {
        let client = client_with_collection("test_hnsw_collection", IndexType::HNSW);
        client.add_text_to_collection("test_hnsw_collection", "First document", None).unwrap();
        client.add_text_to_collection("test_hnsw_collection", "Second document", None).unwrap();

        let collection = client.get_collection("test_hnsw_collection").unwrap();

        // State before saving.
        let info = collection.get_info().unwrap();
        assert_eq!(info.count, 2);
        assert_eq!(info.dimension, 3);

        let test_embedding_fn = MockEmbeddingFunction::new(3);

        let hits = collection
            .search_text("First", 1, SimilarityMetric::Cosine, &test_embedding_fn)
            .unwrap();
        assert_eq!(hits.len(), 1);

        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("test_hnsw_collection.vlc");

        collection.save_to_file(&file_path).unwrap();
        assert!(file_path.exists());

        // State after a round trip through the file.
        let loaded_collection = Collection::load_from_file(&file_path).unwrap();

        assert_eq!(loaded_collection.name(), "test_hnsw_collection");

        let info = loaded_collection.get_info().unwrap();
        assert_eq!(info.count, 2);
        assert_eq!(info.dimension, 3);
        assert!(!info.is_empty);

        let hits = loaded_collection
            .search_text("First", 1, SimilarityMetric::Cosine, &test_embedding_fn)
            .unwrap();
        assert_eq!(hits.len(), 1);
    }

    #[test]
    fn test_collection_save_nonexistent_directory() {
        let client = client_with_collection("test_collection", IndexType::Flat);
        client.add_text_to_collection("test_collection", "Hello world", None).unwrap();

        let collection = client.get_collection("test_collection").unwrap();

        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nonexistent").join("test_collection.vlc");

        let saved = collection.save_to_file(&file_path);
        assert!(saved.is_ok());
        assert!(file_path.exists());
    }

    #[test]
    fn test_collection_load_nonexistent_file() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("nonexistent.vlc");

        let loaded = Collection::load_from_file(&file_path);
        assert!(loaded.is_err());
        assert!(matches!(loaded.unwrap_err(), PersistenceError::Io(_)));
    }

    #[test]
    fn test_collection_load_invalid_json() {
        let temp_dir = tempfile::TempDir::new().unwrap();
        let file_path = temp_dir.path().join("invalid.vlc");

        std::fs::write(&file_path, "invalid json content").unwrap();

        let loaded = Collection::load_from_file(&file_path);
        assert!(loaded.is_err());
        assert!(matches!(loaded.unwrap_err(), PersistenceError::Serialization(_)));
    }
}