1use crate::distance::DistanceMetric;
4use crate::error::{Error, Result};
5use crate::index::{HnswIndex, VectorIndex};
6use crate::point::{Point, SearchResult};
7use crate::storage::{LogPayloadStorage, MmapStorage, PayloadStorage, VectorStorage};
8
9use parking_lot::RwLock;
10use serde::{Deserialize, Serialize};
11use std::path::PathBuf;
12use std::sync::Arc;
13
/// Serializable description of a collection.
///
/// This is the value written to and read from the collection's `config.json`
/// file, and the value handed out (as a clone) by `Collection::config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CollectionConfig {
    /// Collection name. `Collection::create` derives it from the final
    /// component of the collection directory, falling back to `"unknown"`.
    pub name: String,

    /// Number of components every stored vector and every search query must
    /// have; `upsert` and `search` reject inputs of any other length.
    pub dimension: usize,

    /// Distance metric the HNSW index is constructed with.
    pub metric: DistanceMetric,

    /// Number of points, as last recorded from the vector storage after an
    /// `upsert` or `delete`.
    pub point_count: usize,
}
29
/// Handle to a single on-disk collection of points.
///
/// Every field except `path` sits behind an `Arc`, so cloning a `Collection`
/// yields another handle to the same configuration, storages and index
/// rather than an independent copy.
#[derive(Clone)]
pub struct Collection {
    /// Directory that holds `config.json`, the storages and the index file.
    path: PathBuf,

    /// Current configuration, shared between clones of this handle.
    config: Arc<RwLock<CollectionConfig>>,

    /// Storage for the raw vectors, keyed by point id.
    vector_storage: Arc<RwLock<MmapStorage>>,

    /// Storage for the optional per-point payloads, keyed by point id.
    payload_storage: Arc<RwLock<LogPayloadStorage>>,

    /// HNSW index used by `search`. It is not wrapped in a lock here; its
    /// mutating methods are called through a shared reference.
    index: Arc<HnswIndex>,
}
48
49impl Collection {
50 pub fn create(path: PathBuf, dimension: usize, metric: DistanceMetric) -> Result<Self> {
56 std::fs::create_dir_all(&path)?;
57
58 let name = path
59 .file_name()
60 .and_then(|n| n.to_str())
61 .unwrap_or("unknown")
62 .to_string();
63
64 let config = CollectionConfig {
65 name,
66 dimension,
67 metric,
68 point_count: 0,
69 };
70
71 let vector_storage = Arc::new(RwLock::new(
73 MmapStorage::new(&path, dimension).map_err(Error::Io)?,
74 ));
75
76 let payload_storage = Arc::new(RwLock::new(
77 LogPayloadStorage::new(&path).map_err(Error::Io)?,
78 ));
79
80 let index = Arc::new(HnswIndex::new(dimension, metric));
82
83 let collection = Self {
84 path,
85 config: Arc::new(RwLock::new(config)),
86 vector_storage,
87 payload_storage,
88 index,
89 };
90
91 collection.save_config()?;
92
93 Ok(collection)
94 }
95
96 pub fn open(path: PathBuf) -> Result<Self> {
102 let config_path = path.join("config.json");
103 let config_data = std::fs::read_to_string(&config_path)?;
104 let config: CollectionConfig =
105 serde_json::from_str(&config_data).map_err(|e| Error::Serialization(e.to_string()))?;
106
107 let vector_storage = Arc::new(RwLock::new(
109 MmapStorage::new(&path, config.dimension).map_err(Error::Io)?,
110 ));
111
112 let payload_storage = Arc::new(RwLock::new(
113 LogPayloadStorage::new(&path).map_err(Error::Io)?,
114 ));
115
116 let index = if path.join("hnsw.bin").exists() {
118 Arc::new(HnswIndex::load(&path, config.dimension, config.metric).map_err(Error::Io)?)
119 } else {
120 Arc::new(HnswIndex::new(config.dimension, config.metric))
121 };
122
123 Ok(Self {
124 path,
125 config: Arc::new(RwLock::new(config)),
126 vector_storage,
127 payload_storage,
128 index,
129 })
130 }
131
132 #[must_use]
134 pub fn config(&self) -> CollectionConfig {
135 self.config.read().clone()
136 }
137
138 pub fn upsert(&self, points: Vec<Point>) -> Result<()> {
144 let config = self.config.read();
145 let dimension = config.dimension;
146 drop(config);
147
148 for point in &points {
150 if point.dimension() != dimension {
151 return Err(Error::DimensionMismatch {
152 expected: dimension,
153 actual: point.dimension(),
154 });
155 }
156 }
157
158 let mut vector_storage = self.vector_storage.write();
159 let mut payload_storage = self.payload_storage.write();
160
161 for point in points {
162 vector_storage
164 .store(point.id, &point.vector)
165 .map_err(Error::Io)?;
166
167 if let Some(payload) = &point.payload {
169 payload_storage
170 .store(point.id, payload)
171 .map_err(Error::Io)?;
172 } else {
173 let _ = payload_storage.delete(point.id); }
180
181 self.index.insert(point.id, &point.vector);
185 }
186
187 let mut config = self.config.write();
189 config.point_count = vector_storage.len();
190
191 vector_storage.flush().map_err(Error::Io)?;
194 payload_storage.flush().map_err(Error::Io)?;
195 self.index.save(&self.path).map_err(Error::Io)?;
196
197 Ok(())
198 }
199
200 #[must_use]
202 pub fn get(&self, ids: &[u64]) -> Vec<Option<Point>> {
203 let vector_storage = self.vector_storage.read();
204 let payload_storage = self.payload_storage.read();
205
206 ids.iter()
207 .map(|&id| {
208 let vector = vector_storage.retrieve(id).ok().flatten()?;
210
211 let payload = payload_storage.retrieve(id).ok().flatten();
213
214 Some(Point {
215 id,
216 vector,
217 payload,
218 })
219 })
220 .collect()
221 }
222
223 pub fn delete(&self, ids: &[u64]) -> Result<()> {
229 let mut vector_storage = self.vector_storage.write();
230 let mut payload_storage = self.payload_storage.write();
231
232 for &id in ids {
233 vector_storage.delete(id).map_err(Error::Io)?;
234 payload_storage.delete(id).map_err(Error::Io)?;
235 self.index.remove(id);
236 }
237
238 let mut config = self.config.write();
239 config.point_count = vector_storage.len();
240
241 Ok(())
242 }
243
244 pub fn search(&self, query: &[f32], k: usize) -> Result<Vec<SearchResult>> {
252 let config = self.config.read();
253
254 if query.len() != config.dimension {
255 return Err(Error::DimensionMismatch {
256 expected: config.dimension,
257 actual: query.len(),
258 });
259 }
260 drop(config);
261
262 let index_results = self.index.search(query, k);
264
265 let vector_storage = self.vector_storage.read();
266 let payload_storage = self.payload_storage.read();
267
268 let results: Vec<SearchResult> = index_results
270 .into_iter()
271 .filter_map(|(id, score)| {
272 let vector = vector_storage.retrieve(id).ok().flatten()?;
274 let payload = payload_storage.retrieve(id).ok().flatten();
275
276 let point = Point {
277 id,
278 vector,
279 payload,
280 };
281
282 Some(SearchResult::new(point, score))
283 })
284 .collect();
285
286 Ok(results)
287 }
288
289 #[must_use]
291 pub fn len(&self) -> usize {
292 self.vector_storage.read().len()
293 }
294
295 #[must_use]
297 pub fn is_empty(&self) -> bool {
298 self.vector_storage.read().is_empty()
299 }
300
301 pub fn flush(&self) -> Result<()> {
307 self.save_config()?;
308 self.vector_storage.write().flush().map_err(Error::Io)?;
309 self.payload_storage.write().flush().map_err(Error::Io)?;
310 self.index.save(&self.path).map_err(Error::Io)?;
311 Ok(())
312 }
313
314 fn save_config(&self) -> Result<()> {
316 let config = self.config.read();
317 let config_path = self.path.join("config.json");
318 let config_data = serde_json::to_string_pretty(&*config)
319 .map_err(|e| Error::Serialization(e.to_string()))?;
320 std::fs::write(config_path, config_data)?;
321 Ok(())
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use tempfile::tempdir;

    /// Creates a 3-dimensional collection named `test_collection` inside a
    /// new temporary directory.
    ///
    /// The directory guard is returned together with the collection so that
    /// callers keep it alive for as long as the collection is in use.
    fn fresh_collection(metric: DistanceMetric) -> (tempfile::TempDir, Collection) {
        let dir = tempdir().unwrap();
        let collection =
            Collection::create(dir.path().join("test_collection"), 3, metric).unwrap();
        (dir, collection)
    }

    /// A newly created collection reports the requested settings and no points.
    #[test]
    fn test_collection_create() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);

        let CollectionConfig {
            dimension,
            metric,
            point_count,
            ..
        } = collection.config();

        assert_eq!(dimension, 3);
        assert_eq!(metric, DistanceMetric::Cosine);
        assert_eq!(point_count, 0);
    }

    /// Upserted points are counted and the closest one is ranked first.
    #[test]
    fn test_collection_upsert_and_search() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);

        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.0, 1.0, 0.0]),
                Point::without_payload(3, vec![0.0, 0.0, 1.0]),
            ])
            .unwrap();
        assert_eq!(collection.len(), 3);

        let hits = collection.search(&[1.0, 0.0, 0.0], 2).unwrap();

        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0].point.id, 1);
    }

    /// Upserting a 2-component vector into a 3-dimensional collection fails.
    #[test]
    fn test_dimension_mismatch() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);

        let outcome = collection.upsert(vec![Point::without_payload(1, vec![1.0, 0.0])]);

        assert!(outcome.is_err());
    }

    /// A flushed collection can be reopened with its settings and points.
    #[test]
    fn test_collection_open_existing() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("test_collection");

        // Scoped so the first handle is dropped before the directory is
        // opened again.
        {
            let writer = Collection::create(path.clone(), 3, DistanceMetric::Euclidean).unwrap();
            writer
                .upsert(vec![
                    Point::without_payload(1, vec![1.0, 2.0, 3.0]),
                    Point::without_payload(2, vec![4.0, 5.0, 6.0]),
                ])
                .unwrap();
            writer.flush().unwrap();
        }

        let reopened = Collection::open(path).unwrap();
        let config = reopened.config();

        assert_eq!(config.dimension, 3);
        assert_eq!(config.metric, DistanceMetric::Euclidean);
        assert_eq!(reopened.len(), 2);
    }

    /// `get` returns one entry per id, with `None` for unknown ids.
    #[test]
    fn test_collection_get_points() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);
        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.0, 1.0, 0.0]),
            ])
            .unwrap();

        let fetched = collection.get(&[1, 2, 999]);

        assert_eq!(fetched[0].as_ref().map(|p| p.id), Some(1));
        assert_eq!(fetched[1].as_ref().map(|p| p.id), Some(2));
        assert!(fetched[2].is_none());
    }

    /// Deleting a point lowers the count and makes the id unretrievable.
    #[test]
    fn test_collection_delete_points() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);
        collection
            .upsert(vec![
                Point::without_payload(1, vec![1.0, 0.0, 0.0]),
                Point::without_payload(2, vec![0.0, 1.0, 0.0]),
                Point::without_payload(3, vec![0.0, 0.0, 1.0]),
            ])
            .unwrap();
        assert_eq!(collection.len(), 3);

        collection.delete(&[2]).unwrap();
        assert_eq!(collection.len(), 2);

        let fetched = collection.get(&[2]);
        assert!(fetched[0].is_none());
    }

    /// `is_empty` flips from true to false after the first upsert.
    #[test]
    fn test_collection_is_empty() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);
        assert!(collection.is_empty());

        let single = vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])];
        collection.upsert(single).unwrap();
        assert!(!collection.is_empty());
    }

    /// A payload stored with a point is returned by `get`.
    #[test]
    fn test_collection_with_payload() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);

        let payload = json!({"title": "Test Document", "category": "tech"});
        collection
            .upsert(vec![Point::new(1, vec![1.0, 0.0, 0.0], Some(payload))])
            .unwrap();

        let fetched = collection.get(&[1]);
        assert!(fetched[0].is_some());

        let point = fetched[0].as_ref().unwrap();
        assert!(point.payload.is_some());
        assert_eq!(point.payload.as_ref().unwrap()["title"], "Test Document");
    }

    /// Searching with a 2-component query in a 3-dimensional collection fails.
    #[test]
    fn test_collection_search_dimension_mismatch() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);
        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let outcome = collection.search(&[1.0, 0.0], 5);

        assert!(outcome.is_err());
    }

    /// Re-upserting a point without payload clears the payload stored earlier.
    #[test]
    fn test_collection_upsert_replaces_payload() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);

        let with_payload = Point::new(1, vec![1.0, 0.0, 0.0], Some(json!({"version": 1})));
        collection.upsert(vec![with_payload]).unwrap();

        let without_payload = Point::without_payload(1, vec![1.0, 0.0, 0.0]);
        collection.upsert(vec![without_payload]).unwrap();

        let fetched = collection.get(&[1]);
        let point = fetched[0].as_ref().unwrap();
        assert!(point.payload.is_none());
    }

    /// `flush` succeeds on a collection that holds a point.
    #[test]
    fn test_collection_flush() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Cosine);
        collection
            .upsert(vec![Point::without_payload(1, vec![1.0, 0.0, 0.0])])
            .unwrap();

        let outcome = collection.flush();

        assert!(outcome.is_ok());
    }

    /// With the Euclidean metric, one of the two nearest points ranks first.
    #[test]
    fn test_collection_euclidean_metric() {
        let (_dir, collection) = fresh_collection(DistanceMetric::Euclidean);

        collection
            .upsert(vec![
                Point::without_payload(1, vec![0.0, 0.0, 0.0]),
                Point::without_payload(2, vec![1.0, 0.0, 0.0]),
                Point::without_payload(3, vec![10.0, 0.0, 0.0]),
            ])
            .unwrap();

        let hits = collection.search(&[0.5, 0.0, 0.0], 3).unwrap();

        // Points 1 and 2 are equally close to the query, so either may lead.
        assert!(matches!(hits[0].point.id, 1 | 2));
    }
}