1use crate::{Error, Point, Result, Vector, HnswIndex, BM25Index, Filter, MultiVector};
2use parking_lot::RwLock;
3use std::collections::HashMap;
4use std::sync::Arc;
5use std::sync::atomic::{AtomicBool, Ordering};
6
7#[derive(Debug, Clone)]
9pub struct CollectionConfig {
10 pub name: String,
11 pub vector_dim: usize,
12 pub distance: Distance,
13 pub use_hnsw: bool,
14 pub enable_bm25: bool,
15}
16
17impl Default for CollectionConfig {
18 fn default() -> Self {
19 Self {
20 name: String::new(),
21 vector_dim: 128,
22 distance: Distance::Cosine,
23 use_hnsw: true,
24 enable_bm25: false,
25 }
26 }
27}
28
/// Measure a collection uses to rank stored vectors against a query.
///
/// Collection search treats larger scores as better for every variant. Euclidean distances
/// are therefore reported negated.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Distance {
    /// Cosine similarity between the query and the stored vector.
    Cosine,
    /// Euclidean (L2) distance, reported as its negation.
    Euclidean,
    /// Plain dot product.
    Dot,
}
35
/// Kind of secondary index requested for a payload field through
/// `Collection::create_payload_index`.
///
/// NOTE(review): `Collection` only records the requested kind. The per-variant meanings
/// below are inferred from the names and should be confirmed against the index implementation.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PayloadIndexType {
    /// Presumably exact-match string values.
    Keyword,
    /// Presumably integer values.
    Integer,
    /// Presumably floating-point values.
    Float,
    /// Presumably boolean values.
    Bool,
    /// Presumably geographic coordinates.
    Geo,
    /// Presumably tokenised free text.
    Text,
}
46
47pub struct Collection {
49 config: CollectionConfig,
50 points: Arc<RwLock<HashMap<String, Point>>>,
51 hnsw: Option<Arc<RwLock<HnswIndex>>>,
52 bm25: Option<Arc<RwLock<BM25Index>>>,
53 hnsw_built: Arc<RwLock<bool>>,
54 hnsw_rebuilding: Arc<AtomicBool>,
55 batch_mode: Arc<RwLock<bool>>,
56 pending_points: Arc<RwLock<Vec<Point>>>,
57 payload_indexes: Arc<RwLock<HashMap<String, PayloadIndexType>>>,
59 operation_counter: Arc<std::sync::atomic::AtomicU64>,
61}
62
63impl Collection {
64 pub fn new(config: CollectionConfig) -> Self {
65 let hnsw = if config.use_hnsw {
66 Some(Arc::new(RwLock::new(HnswIndex::new(16, 3))))
67 } else {
68 None
69 };
70
71 let bm25 = if config.enable_bm25 {
72 Some(Arc::new(RwLock::new(BM25Index::new())))
73 } else {
74 None
75 };
76
77 Self {
78 config,
79 points: Arc::new(RwLock::new(HashMap::new())),
80 hnsw,
81 bm25,
82 hnsw_built: Arc::new(RwLock::new(false)),
83 hnsw_rebuilding: Arc::new(AtomicBool::new(false)),
84 batch_mode: Arc::new(RwLock::new(false)),
85 pending_points: Arc::new(RwLock::new(Vec::new())),
86 payload_indexes: Arc::new(RwLock::new(HashMap::new())),
87 operation_counter: Arc::new(std::sync::atomic::AtomicU64::new(0)),
88 }
89 }
90
91 #[inline]
93 pub fn next_operation_id(&self) -> u64 {
94 self.operation_counter.fetch_add(1, std::sync::atomic::Ordering::SeqCst)
95 }
96
97 #[inline]
98 #[must_use]
99 pub fn name(&self) -> &str {
100 &self.config.name
101 }
102
103 #[inline]
104 #[must_use]
105 pub fn vector_dim(&self) -> usize {
106 self.config.vector_dim
107 }
108
109 #[inline]
110 #[must_use]
111 pub fn distance(&self) -> Distance {
112 self.config.distance
113 }
114
115 #[inline]
116 #[must_use]
117 pub fn use_hnsw(&self) -> bool {
118 self.config.use_hnsw
119 }
120
121 #[inline]
122 #[must_use]
123 pub fn enable_bm25(&self) -> bool {
124 self.config.enable_bm25
125 }
126
127 #[inline]
128 #[must_use]
129 pub fn count(&self) -> usize {
130 self.points.read().len()
131 }
132
133 #[inline]
134 #[must_use]
135 pub fn is_empty(&self) -> bool {
136 self.points.read().is_empty()
137 }
138
139 pub fn get_all_points(&self) -> Vec<Point> {
141 self.points.read().values().cloned().collect()
142 }
143
144 pub fn upsert(&self, point: Point) -> Result<()> {
146 if self.config.vector_dim > 0 && point.vector.dim() != self.config.vector_dim {
148 return Err(Error::InvalidDimension {
149 expected: self.config.vector_dim,
150 actual: point.vector.dim(),
151 });
152 }
153
154 let id_str = point.id.to_string();
155
156 let new_version = {
158 let points = self.points.read();
159 if let Some(existing) = points.get(&id_str) {
160 existing.version + 1
161 } else {
162 0
163 }
164 };
165
166 let mut versioned_point = point;
168 versioned_point.version = new_version;
169
170 let in_batch = *self.batch_mode.read();
171 if in_batch {
172 self.points.write().insert(id_str.clone(), versioned_point.clone());
173 self.pending_points.write().push(versioned_point);
174 return Ok(());
175 }
176
177 if let Some(hnsw) = &self.hnsw {
178 let built = *self.hnsw_built.read();
179 if built {
180 let mut normalized_point = versioned_point.clone();
181 normalized_point.vector.normalize();
182
183 let mut index = hnsw.write();
184 index.insert(normalized_point);
185 }
186 }
187
188 if let Some(bm25) = &self.bm25 {
189 if let Some(payload) = &versioned_point.payload {
190 if let Some(text) = payload.get("text").and_then(|v| v.as_str()) {
191 let mut index = bm25.write();
192 index.insert_doc(&id_str, text);
193 }
194 }
195 }
196
197 self.points.write().insert(id_str, versioned_point);
198 Ok(())
199 }
200
201 pub fn start_batch(&self) {
203 *self.batch_mode.write() = true;
204 self.pending_points.write().clear();
205 }
206
207 pub fn end_batch(&self) -> Result<()> {
209 *self.batch_mode.write() = false;
210
211 if let Some(hnsw) = &self.hnsw {
212 let points = self.points.read();
213 let point_count = points.len();
214
215 const HNSW_REBUILD_THRESHOLD: usize = 10_000;
216
217 if point_count > HNSW_REBUILD_THRESHOLD && !self.hnsw_rebuilding.load(Ordering::Acquire) {
218 self.hnsw_rebuilding.store(true, Ordering::Release);
219 let points_clone: Vec<Point> = points.values().cloned().collect();
220 let hnsw_clone = hnsw.clone();
221 let built_flag = self.hnsw_built.clone();
222 let rebuilding_flag = self.hnsw_rebuilding.clone();
223
224 let job = crate::background::HnswRebuildJob::new(
225 points_clone,
226 hnsw_clone,
227 built_flag,
228 rebuilding_flag,
229 );
230 crate::background::get_background_system().submit(Box::new(job));
231 }
232 }
233
234 self.pending_points.write().clear();
235 Ok(())
236 }
237
238 pub fn batch_upsert(&self, points: Vec<Point>) -> Result<()> {
240 self.start_batch();
241 for point in points {
242 self.upsert(point)?;
243 }
244 self.end_batch()?;
245 Ok(())
246 }
247
248 pub fn batch_upsert_with_prewarm(&self, points: Vec<Point>, prewarm: bool) -> Result<()> {
250 self.batch_upsert(points)?;
251 if prewarm {
252 self.prewarm_index()?;
253 }
254 Ok(())
255 }
256
257 #[inline]
259 pub fn get(&self, id: &str) -> Option<Point> {
260 self.points.read().get(id).cloned()
261 }
262
263 pub fn delete(&self, id: &str) -> Result<bool> {
265 if let Some(hnsw) = &self.hnsw {
266 let mut index = hnsw.write();
267 index.remove(id);
268 }
269
270 if let Some(bm25) = &self.bm25 {
271 let mut index = bm25.write();
272 index.delete_doc(id);
273 }
274
275 let mut points = self.points.write();
276 Ok(points.remove(id).is_some())
277 }
278
279 pub fn set_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
281 let mut points = self.points.write();
282 if let Some(point) = points.get_mut(id) {
283 if let Some(existing) = &mut point.payload {
284 if let (Some(existing_obj), Some(new_obj)) = (existing.as_object_mut(), payload.as_object()) {
285 for (key, value) in new_obj {
286 existing_obj.insert(key.clone(), value.clone());
287 }
288 }
289 } else {
290 point.payload = Some(payload);
291 }
292 Ok(true)
293 } else {
294 Ok(false)
295 }
296 }
297
298 pub fn overwrite_payload(&self, id: &str, payload: serde_json::Value) -> Result<bool> {
300 let mut points = self.points.write();
301 if let Some(point) = points.get_mut(id) {
302 point.payload = Some(payload);
303 Ok(true)
304 } else {
305 Ok(false)
306 }
307 }
308
309 pub fn delete_payload_keys(&self, id: &str, keys: &[String]) -> Result<bool> {
311 let mut points = self.points.write();
312 if let Some(point) = points.get_mut(id) {
313 if let Some(payload) = &mut point.payload {
314 if let Some(obj) = payload.as_object_mut() {
315 for key in keys {
316 obj.remove(key);
317 }
318 }
319 }
320 Ok(true)
321 } else {
322 Ok(false)
323 }
324 }
325
326 pub fn clear_payload(&self, id: &str) -> Result<bool> {
328 let mut points = self.points.write();
329 if let Some(point) = points.get_mut(id) {
330 point.payload = None;
331 Ok(true)
332 } else {
333 Ok(false)
334 }
335 }
336
337 pub fn update_vector(&self, id: &str, vector: Vector) -> Result<bool> {
339 let mut points = self.points.write();
340 if let Some(point) = points.get_mut(id) {
341 point.vector = vector.clone();
342
343 if let Some(hnsw) = &self.hnsw {
345 let mut index = hnsw.write();
346 index.remove(id);
347 index.insert(point.clone());
349 }
350 Ok(true)
351 } else {
352 Ok(false)
353 }
354 }
355
356 pub fn update_multivector(&self, id: &str, multivector: Option<MultiVector>) -> Result<bool> {
358 let mut points = self.points.write();
359 if let Some(point) = points.get_mut(id) {
360 point.multivector = multivector;
361 Ok(true)
362 } else {
363 Ok(false)
364 }
365 }
366
367 pub fn delete_vector(&self, id: &str) -> Result<bool> {
369 self.delete(id)
372 }
373
374 pub fn create_payload_index(&self, field_name: &str, index_type: PayloadIndexType) -> Result<bool> {
376 let mut indexes = self.payload_indexes.write();
377 indexes.insert(field_name.to_string(), index_type);
378 Ok(true)
379 }
380
381 pub fn delete_payload_index(&self, field_name: &str) -> Result<bool> {
383 let mut indexes = self.payload_indexes.write();
384 Ok(indexes.remove(field_name).is_some())
385 }
386
387 pub fn get_payload_indexes(&self) -> HashMap<String, PayloadIndexType> {
389 self.payload_indexes.read().clone()
390 }
391
392 pub fn is_field_indexed(&self, field_name: &str) -> bool {
394 self.payload_indexes.read().contains_key(field_name)
395 }
396
397 pub fn prewarm_index(&self) -> Result<()> {
399 if let Some(hnsw) = &self.hnsw {
400 let mut built = self.hnsw_built.write();
401 if !*built {
402 let points = self.points.read();
403 if !points.is_empty() {
404 let mut index = hnsw.write();
405 *index = HnswIndex::new(16, 3);
406 for point in points.values() {
407 index.insert(point.clone());
408 }
409 *built = true;
410 }
411 }
412 }
413 Ok(())
414 }
415
416 fn brute_force_search(&self, query: &Vector, limit: usize, filter: Option<&dyn Filter>) -> Vec<(Point, f32)> {
418 use rayon::prelude::*;
419
420 let points = self.points.read();
421 let query_slice = query.as_slice();
422 let distance = self.config.distance.clone();
423
424 let point_vec: Vec<_> = points.values().collect();
426
427 let scored: Vec<(usize, f32)> = if point_vec.len() >= 10000 && filter.is_none() {
431 point_vec
433 .par_iter()
434 .enumerate()
435 .map(|(idx, point)| {
436 let score = match distance {
437 Distance::Cosine => {
438 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
439 }
440 Distance::Euclidean => {
441 -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
442 }
443 Distance::Dot => {
444 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
445 }
446 };
447 (idx, score)
448 })
449 .collect()
450 } else {
451 let mut results = Vec::with_capacity(point_vec.len());
453
454 if filter.is_none() && matches!(distance, Distance::Cosine) {
455 for (idx, point) in point_vec.iter().enumerate() {
457 let score = crate::simd::dot_product_simd(query_slice, point.vector.as_slice());
458 results.push((idx, score));
459 }
460 } else {
461 for (idx, point) in point_vec.iter().enumerate() {
463 if let Some(f) = filter {
464 if !f.matches(point) {
465 continue;
466 }
467 }
468
469 let score = match distance {
470 Distance::Cosine => {
471 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
472 }
473 Distance::Euclidean => {
474 -crate::simd::l2_distance_simd(query_slice, point.vector.as_slice())
475 }
476 Distance::Dot => {
477 crate::simd::dot_product_simd(query_slice, point.vector.as_slice())
478 }
479 };
480
481 results.push((idx, score));
482 }
483 }
484 results
485 };
486
487 let mut scored = scored;
489 if scored.len() > limit {
490 scored.select_nth_unstable_by(limit, |a, b| {
491 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
492 });
493 scored.truncate(limit);
494 }
495 scored.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
496
497 scored
499 .into_iter()
500 .map(|(idx, score)| (point_vec[idx].clone(), score))
501 .collect()
502 }
503
504 pub fn search(
507 &self,
508 query: &Vector,
509 limit: usize,
510 filter: Option<&dyn Filter>,
511 ) -> Vec<(Point, f32)> {
512 let normalized_query = query.normalized();
513 let point_count = self.points.read().len();
514
515 const BRUTE_FORCE_THRESHOLD: usize = 10000;
517 if point_count < BRUTE_FORCE_THRESHOLD {
518 return self.brute_force_search(&normalized_query, limit, filter);
519 }
520
521 if let Some(hnsw) = &self.hnsw {
522 {
524 let mut built = self.hnsw_built.write();
525 if !*built {
526 let points = self.points.read();
527 if !points.is_empty() {
528 let mut index = hnsw.write();
529 *index = HnswIndex::new(16, 3);
530 for point in points.values() {
531 index.insert(point.clone());
532 }
533 *built = true;
534 }
535 }
536 }
537
538 let mut index = hnsw.write();
540 let mut results = index.search(&normalized_query, limit, None);
541
542 if let Some(f) = filter {
543 results.retain(|(point, _)| f.matches(point));
544 }
545
546 results
547 } else {
548 let points = self.points.read();
549 let results: Vec<(Point, f32)> = points
550 .values()
551 .filter(|point| {
552 filter.map(|f| f.matches(point)).unwrap_or(true)
553 })
554 .map(|point| {
555 let score = match self.config.distance {
556 Distance::Cosine => point.vector.cosine_similarity(query),
557 Distance::Euclidean => -point.vector.l2_distance(query),
558 Distance::Dot => {
559 point.vector.as_slice()
560 .iter()
561 .zip(query.as_slice().iter())
562 .map(|(a, b)| a * b)
563 .sum()
564 }
565 };
566 (point.clone(), score)
567 })
568 .collect();
569
570 let mut sorted = results;
571 sorted.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
572 sorted.truncate(limit);
573 sorted
574 }
575 }
576
577 pub fn search_text(&self, query: &str, limit: usize) -> Vec<(String, f32)> {
579 if let Some(bm25) = &self.bm25 {
580 let index = bm25.read();
581 index.search(query, limit)
582 } else {
583 Vec::new()
584 }
585 }
586
587 pub fn search_multivector(
592 &self,
593 query: &MultiVector,
594 limit: usize,
595 filter: Option<&dyn Filter>,
596 ) -> Vec<(Point, f32)> {
597 let points = self.points.read();
598
599 let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
600
601 for point in points.values() {
602 if let Some(f) = filter {
603 if !f.matches(point) {
604 continue;
605 }
606 }
607
608 let score = if let Some(doc_mv) = &point.multivector {
610 match self.config.distance {
612 Distance::Cosine => query.max_sim_cosine(doc_mv),
613 Distance::Euclidean => query.max_sim_l2(doc_mv),
614 Distance::Dot => query.max_sim(doc_mv),
615 }
616 } else {
617 let doc_mv = MultiVector::from_single(point.vector.as_slice().to_vec())
619 .unwrap_or_else(|_| MultiVector::new(vec![vec![0.0; query.dim()]]).unwrap());
620 match self.config.distance {
621 Distance::Cosine => query.max_sim_cosine(&doc_mv),
622 Distance::Euclidean => query.max_sim_l2(&doc_mv),
623 Distance::Dot => query.max_sim(&doc_mv),
624 }
625 };
626
627 results.push((point.clone(), score));
628 }
629
630 if results.len() > limit {
632 results.select_nth_unstable_by(limit, |a, b| {
633 b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
634 });
635 results.truncate(limit);
636 }
637
638 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
639 results
640 }
641
642 pub fn iter(&self) -> Vec<Point> {
644 self.points.read().values().cloned().collect()
645 }
646
647 pub fn search_sparse(
649 &self,
650 query: &crate::point::SparseVector,
651 vector_name: &str,
652 limit: usize,
653 filter: Option<&dyn Filter>,
654 ) -> Vec<(Point, f32)> {
655 let points = self.points.read();
656
657 let mut results: Vec<(Point, f32)> = Vec::with_capacity(points.len().min(limit * 2));
658
659 for point in points.values() {
660 if let Some(f) = filter {
662 if !f.matches(point) {
663 continue;
664 }
665 }
666
667 if let Some(point_sparse) = point.sparse_vectors.get(vector_name) {
669 let score = query.dot(point_sparse);
671
672 if score > 0.0 {
674 results.push((point.clone(), score));
675 }
676 }
677 }
678
679 results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
681 results.truncate(limit);
682
683 results
684 }
685}
686