1use std::cmp::Ordering;
4
5use roaring::RoaringBitmap;
6use selene_core::{
7 CancellationChecker, DbString, NodeId, Value, VectorMetric, VectorMetricQuery, VectorTopK,
8 VectorValue,
9};
10
11use crate::error::{GraphError, GraphResult};
12use crate::graph::SeleneGraph;
13use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
14#[cfg(test)]
15use crate::shared::SharedGraph;
16use crate::store::RowIndex;
17use crate::vector_index::VectorIndexSearchHit;
18#[path = "vector_search/types.rs"]
19mod types;
20pub use types::{
21 ApproximateVectorExpansionOptions, ApproximateVectorSearchOptions, VectorCandidateSet,
22 VectorNeighborDirection, VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError,
23};
24#[path = "vector_search/approx_batch.rs"]
25mod approx_batch;
26#[path = "vector_search/approx_filter.rs"]
27mod approx_filter;
28#[path = "vector_search/approx_turbo_quant.rs"]
29mod approx_turbo_quant;
30#[path = "vector_search/exact_batch.rs"]
31mod exact_batch;
32#[path = "vector_search/shared_wrappers.rs"]
33mod shared_wrappers;
34#[path = "vector_search/turbo_quant_exact.rs"]
35mod turbo_quant_exact;
36
// Number of visited rows the sequential loops in this file accumulate before
// reporting them through `CancellationChecker::note_nodes_scanned`.
const VECTOR_SEARCH_CANCEL_STRIDE: usize = 1024;
// Chunk size handed to `try_reduce_bitmap_chunks` by the parallel exact scan.
const VECTOR_SEARCH_PARALLEL_CHUNK_ROWS: usize = 2048;

// Row threshold forwarded to `should_parallelize_scan`. Test builds use a far
// smaller value (presumably so small fixtures still reach the parallel path;
// TODO confirm against the test suite).
#[cfg(not(test))]
const VECTOR_SEARCH_PARALLEL_MIN_ROWS: u64 = 16_384;
#[cfg(test)]
const VECTOR_SEARCH_PARALLEL_MIN_ROWS: u64 = 8;
44
45impl SeleneGraph {
46 pub fn exact_vector_search_nodes(
55 &self,
56 label: &DbString,
57 property: &DbString,
58 query: &VectorValue,
59 metric: VectorMetric,
60 k: usize,
61 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
62 self.exact_vector_search_nodes_checked(
63 label,
64 property,
65 query,
66 metric,
67 k,
68 CancellationChecker::disabled(),
69 )
70 .map_err(VectorSearchError::into_graph_error)
71 }
72
73 pub fn exact_vector_search_nodes_checked(
81 &self,
82 label: &DbString,
83 property: &DbString,
84 query: &VectorValue,
85 metric: VectorMetric,
86 k: usize,
87 checker: CancellationChecker<'_>,
88 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
89 checker.check()?;
90 if k == 0 {
91 return Ok(Vec::new());
92 }
93 let Some(label_rows) = self.nodes_with_label(label) else {
94 return Ok(Vec::new());
95 };
96 let query_dimension = u32::try_from(query.dimension()).ok();
97 let vector_index = query_dimension.and_then(|dimension| {
98 self.vector_index_for(label, property)
99 .filter(|index| index.dimension() == dimension)
100 });
101 let rows = vector_index
102 .as_ref()
103 .map_or(label_rows, |index| index.rows());
104 let scorer = metric.bind_query(query).map_err(GraphError::from)?;
105 if should_parallelize_exact_scan(rows, k) {
106 return self.exact_vector_search_parallel(label, property, scorer, k, rows, checker);
107 }
108
109 let mut top_k = VectorTopK::new(k);
110 let mut rows_since_check = 0usize;
111 for raw_row in rows.iter() {
112 rows_since_check += 1;
113 if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
114 checker.note_nodes_scanned(rows_since_check)?;
115 rows_since_check = 0;
116 }
117 if !self.node_store.is_alive(raw_row) {
118 continue;
119 }
120 let row = RowIndex::new(raw_row);
121 let node_id = self
122 .node_id_for_row(row)
123 .ok_or_else(|| GraphError::Inconsistent {
124 reason: format!(
125 "label index row {raw_row} for {} has no node id",
126 label.as_str()
127 ),
128 })?;
129 let properties = self
130 .node_store
131 .properties
132 .get(raw_row as usize)
133 .ok_or_else(|| GraphError::Inconsistent {
134 reason: format!(
135 "label index row {raw_row} for {} has no property row",
136 label.as_str()
137 ),
138 })?;
139 let Some(Value::Vector(vector)) = properties.get(property) else {
140 continue;
141 };
142 let distance = scorer.distance(vector).map_err(GraphError::from)?;
143 top_k.push_distance(node_id, distance);
144 }
145 if rows_since_check > 0 {
146 checker.note_nodes_scanned(rows_since_check)?;
147 }
148
149 Ok(top_k
150 .into_hits()
151 .into_iter()
152 .map(|hit| VectorNodeSearchHit {
153 node_id: hit.key,
154 distance: hit.distance,
155 })
156 .collect())
157 }
158
159 fn exact_vector_search_parallel(
160 &self,
161 label: &DbString,
162 property: &DbString,
163 scorer: VectorMetricQuery<'_>,
164 k: usize,
165 rows: &RoaringBitmap,
166 checker: CancellationChecker<'_>,
167 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
168 let top_k = try_reduce_bitmap_chunks(
169 rows,
170 VECTOR_SEARCH_PARALLEL_CHUNK_ROWS,
171 checker,
172 || VectorTopK::new(k),
173 |chunk| self.exact_vector_search_chunk(label, property, scorer, k, chunk),
174 merge_top_k,
175 )?;
176
177 Ok(vector_node_hits(top_k))
178 }
179
180 fn exact_vector_search_chunk(
181 &self,
182 label: &DbString,
183 property: &DbString,
184 scorer: VectorMetricQuery<'_>,
185 k: usize,
186 rows: &[u32],
187 ) -> Result<VectorTopK<NodeId>, VectorSearchError> {
188 let mut top_k = VectorTopK::new(k);
189 for &raw_row in rows {
190 if !self.node_store.is_alive(raw_row) {
191 continue;
192 }
193 let row = RowIndex::new(raw_row);
194 let node_id = self
195 .node_id_for_row(row)
196 .ok_or_else(|| GraphError::Inconsistent {
197 reason: format!(
198 "vector search row {raw_row} for {} has no node id",
199 label.as_str()
200 ),
201 })?;
202 let properties = self
203 .node_store
204 .properties
205 .get(raw_row as usize)
206 .ok_or_else(|| GraphError::Inconsistent {
207 reason: format!(
208 "vector search row {raw_row} for {} has no property row",
209 label.as_str()
210 ),
211 })?;
212 let Some(Value::Vector(vector)) = properties.get(property) else {
213 continue;
214 };
215 let distance = scorer.distance(vector).map_err(GraphError::from)?;
216 top_k.push_distance(node_id, distance);
217 }
218 Ok(top_k)
219 }
220
221 pub fn approximate_vector_search_nodes_checked(
231 &self,
232 label: &DbString,
233 property: &DbString,
234 query: &VectorValue,
235 options: ApproximateVectorSearchOptions,
236 checker: CancellationChecker<'_>,
237 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
238 checker.check()?;
239 let query_dimension = u32::try_from(query.dimension())
240 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
241 let Some(index) = self
242 .vector_index_for(label, property)
243 .filter(|index| index.dimension() == query_dimension)
244 else {
245 return Err(VectorSearchError::ApproximateIndexMissing);
246 };
247 let Some(indexed_metric) = index.ann_metric() else {
248 return Err(VectorSearchError::ApproximateIndexMissing);
249 };
250 if indexed_metric != options.metric {
251 return Err(VectorSearchError::ApproximateMetricMismatch {
252 indexed: indexed_metric,
253 requested: options.metric,
254 });
255 }
256 if index.is_turbo_quant() {
257 if turbo_quant_exact::covers_rows(index.rows(), options) {
258 return self.exact_vector_search_nodes_checked(
259 label,
260 property,
261 query,
262 options.metric,
263 options.k,
264 checker,
265 );
266 }
267 let row_hits = index
268 .turbo_quant_candidates(query, options.k, options.ef_search)
269 .ok_or(VectorSearchError::ApproximateIndexMissing)?
270 .map_err(GraphError::from)?;
271 return rerank_ann_row_candidates(
272 self,
273 property,
274 query,
275 options.metric,
276 options.k,
277 row_hits,
278 &checker,
279 );
280 }
281 let row_hits = index
282 .ann_search(query, options.k, options.ef_search)
283 .ok_or(VectorSearchError::ApproximateIndexMissing)?
284 .map_err(GraphError::from)?;
285
286 ann_row_hits_to_node_hits(self, label, row_hits, &checker)
287 }
288
289 pub fn approximate_vector_search_nodes_batch_checked(
298 &self,
299 label: &DbString,
300 property: &DbString,
301 queries: &[VectorValue],
302 options: ApproximateVectorSearchOptions,
303 checker: CancellationChecker<'_>,
304 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
305 checker.check()?;
306 let Some(first_query) = queries.first() else {
307 return Ok(Vec::new());
308 };
309 let query_dimension = u32::try_from(first_query.dimension())
310 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
311 let Some(index) = self
312 .vector_index_for(label, property)
313 .filter(|index| index.dimension() == query_dimension)
314 else {
315 return Err(VectorSearchError::ApproximateIndexMissing);
316 };
317 let Some(indexed_metric) = index.ann_metric() else {
318 return Err(VectorSearchError::ApproximateIndexMissing);
319 };
320 if indexed_metric != options.metric {
321 return Err(VectorSearchError::ApproximateMetricMismatch {
322 indexed: indexed_metric,
323 requested: options.metric,
324 });
325 }
326
327 if index.is_turbo_quant() {
328 for query in queries {
329 checker.check()?;
330 let dimension = u32::try_from(query.dimension())
331 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
332 if dimension != query_dimension {
333 return Err(VectorSearchError::ApproximateIndexMissing);
334 }
335 }
336 if turbo_quant_exact::covers_rows(index.rows(), options) {
337 return self.exact_vector_search_nodes_batch_checked(
338 label,
339 property,
340 queries,
341 options.metric,
342 options.k,
343 checker,
344 );
345 }
346 if !index.turbo_quant_prefers_fused_batch(queries.len()) {
347 let mut batch_hits = Vec::with_capacity(queries.len());
348 for query in queries {
349 checker.check()?;
350 let row_hits = index
351 .turbo_quant_candidates(query, options.k, options.ef_search)
352 .ok_or(VectorSearchError::ApproximateIndexMissing)?
353 .map_err(GraphError::from)?;
354 batch_hits.push(rerank_ann_row_candidates(
355 self,
356 property,
357 query,
358 options.metric,
359 options.k,
360 row_hits,
361 &checker,
362 )?);
363 }
364 return Ok(batch_hits);
365 }
366 let row_batches = index
367 .turbo_quant_candidates_batch(queries, options.k, options.ef_search)
368 .ok_or(VectorSearchError::ApproximateIndexMissing)?
369 .map_err(GraphError::from)?;
370 let mut batch_hits = Vec::with_capacity(queries.len());
371 for (query, row_hits) in queries.iter().zip(row_batches) {
372 batch_hits.push(rerank_ann_row_candidates(
373 self,
374 property,
375 query,
376 options.metric,
377 options.k,
378 row_hits,
379 &checker,
380 )?);
381 }
382 return Ok(batch_hits);
383 }
384
385 approx_batch::ann_index_batch_search(
386 self,
387 label,
388 &index,
389 queries,
390 options,
391 query_dimension,
392 checker,
393 )
394 }
395
396 pub fn approximate_vector_search_expanded_candidates_checked(
406 &self,
407 label: &DbString,
408 property: &DbString,
409 query: &VectorValue,
410 options: ApproximateVectorExpansionOptions<'_>,
411 checker: CancellationChecker<'_>,
412 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
413 checker.check()?;
414 let root_hits = self.approximate_vector_search_nodes_checked(
415 label,
416 property,
417 query,
418 ApproximateVectorSearchOptions::new(options.metric, options.root_k, options.ef_search),
419 checker,
420 )?;
421 if options.k == 0 || root_hits.is_empty() {
422 return Ok(Vec::new());
423 }
424
425 let roots = VectorCandidateSet::from_search_hits(&root_hits);
426 let expanded = self.expand_vector_candidate_set_checked(
427 &roots,
428 options.edge_label,
429 options.direction,
430 checker,
431 )?;
432 self.score_vector_candidate_set_checked(
433 property,
434 query,
435 &expanded,
436 options.metric,
437 options.k,
438 checker,
439 )
440 }
441
442 pub fn approximate_vector_search_expanded_candidates_batch_checked(
450 &self,
451 label: &DbString,
452 property: &DbString,
453 queries: &[VectorValue],
454 options: ApproximateVectorExpansionOptions<'_>,
455 checker: CancellationChecker<'_>,
456 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
457 checker.check()?;
458 let root_hits = self.approximate_vector_search_nodes_batch_checked(
459 label,
460 property,
461 queries,
462 ApproximateVectorSearchOptions::new(options.metric, options.root_k, options.ef_search),
463 checker,
464 )?;
465 if options.k == 0 {
466 return Ok(vec![Vec::new(); queries.len()]);
467 }
468
469 let root_sets = root_hits
470 .iter()
471 .map(VectorCandidateSet::from_search_hits)
472 .collect::<Vec<_>>();
473 self.score_vector_expanded_candidate_sets_batch_checked(
474 property,
475 queries,
476 &root_sets,
477 VectorNeighborSearchOptions::new(
478 options.edge_label,
479 options.direction,
480 options.metric,
481 options.k,
482 ),
483 checker,
484 )
485 }
486}
487
488fn should_parallelize_exact_scan(rows: &RoaringBitmap, k: usize) -> bool {
489 should_parallelize_scan(rows.len(), k, VECTOR_SEARCH_PARALLEL_MIN_ROWS)
490}
491
492fn merge_top_k(
493 mut lhs: VectorTopK<NodeId>,
494 rhs: VectorTopK<NodeId>,
495) -> Result<VectorTopK<NodeId>, VectorSearchError> {
496 for hit in rhs.into_hits() {
497 lhs.push_distance(hit.key, hit.distance);
498 }
499 Ok(lhs)
500}
501
502fn vector_node_hits(top_k: VectorTopK<NodeId>) -> Vec<VectorNodeSearchHit> {
503 top_k
504 .into_hits()
505 .into_iter()
506 .map(|hit| VectorNodeSearchHit {
507 node_id: hit.key,
508 distance: hit.distance,
509 })
510 .collect()
511}
512
513fn ann_row_hits_to_node_hits(
514 graph: &SeleneGraph,
515 label: &DbString,
516 row_hits: Vec<VectorIndexSearchHit>,
517 checker: &CancellationChecker<'_>,
518) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
519 let mut hits = Vec::with_capacity(row_hits.len());
520 let mut needs_sort = false;
521 let mut rows_since_check = 0usize;
522 for hit in row_hits {
523 rows_since_check += 1;
524 if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
525 checker.note_nodes_scanned(rows_since_check)?;
526 rows_since_check = 0;
527 }
528 if !graph.node_store.is_alive(hit.row) {
529 continue;
530 }
531 let row = RowIndex::new(hit.row);
532 let node_id = graph
533 .node_id_for_row(row)
534 .ok_or_else(|| GraphError::Inconsistent {
535 reason: format!(
536 "ANN vector index row {} for {} has no node id",
537 hit.row,
538 label.as_str()
539 ),
540 })?;
541 let node_hit = VectorNodeSearchHit {
542 node_id,
543 distance: hit.distance,
544 };
545 needs_sort |= hits
546 .last()
547 .is_some_and(|previous| compare_node_search_hit(previous, &node_hit).is_gt());
548 hits.push(node_hit);
549 }
550 if rows_since_check > 0 {
551 checker.note_nodes_scanned(rows_since_check)?;
552 }
553 if needs_sort {
554 hits.sort_by(compare_node_search_hit);
555 }
556 Ok(hits)
557}
558
559fn rerank_ann_row_candidates(
560 graph: &SeleneGraph,
561 property: &DbString,
562 query: &VectorValue,
563 metric: VectorMetric,
564 k: usize,
565 row_hits: Vec<VectorIndexSearchHit>,
566 checker: &CancellationChecker<'_>,
567) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
568 let scorer = metric.bind_query(query).map_err(GraphError::from)?;
569 let mut top_k = VectorTopK::new(k);
570 let mut rows_since_check = 0usize;
571 for hit in row_hits {
572 rows_since_check += 1;
573 if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
574 checker.note_nodes_scanned(rows_since_check)?;
575 rows_since_check = 0;
576 }
577 if !graph.node_store.is_alive(hit.row) {
578 continue;
579 }
580 let row = RowIndex::new(hit.row);
581 let node_id = graph
582 .node_id_for_row(row)
583 .ok_or_else(|| GraphError::Inconsistent {
584 reason: format!("ANN vector candidate row {} has no node id", hit.row),
585 })?;
586 let properties = graph
587 .node_store
588 .properties
589 .get(hit.row as usize)
590 .ok_or_else(|| GraphError::Inconsistent {
591 reason: format!("ANN vector candidate row {} has no property row", hit.row),
592 })?;
593 let Some(Value::Vector(vector)) = properties.get(property) else {
594 continue;
595 };
596 let distance = scorer.distance(vector).map_err(GraphError::from)?;
597 top_k.push_distance(node_id, distance);
598 }
599 if rows_since_check > 0 {
600 checker.note_nodes_scanned(rows_since_check)?;
601 }
602 Ok(vector_node_hits(top_k))
603}
604
605fn compare_node_search_hit(lhs: &VectorNodeSearchHit, rhs: &VectorNodeSearchHit) -> Ordering {
606 lhs.distance
607 .total_cmp(&rhs.distance)
608 .then_with(|| lhs.node_id.cmp(&rhs.node_id))
609}
610
611#[cfg(test)]
612#[path = "vector_search/ann_conversion_tests.rs"]
613mod ann_conversion_tests;
614#[cfg(test)]
615#[path = "vector_search/ann_expansion_tests.rs"]
616mod ann_expansion_tests;
617#[cfg(test)]
618#[path = "vector_search/batch_tests.rs"]
619mod batch_tests;
620#[cfg(test)]
621#[path = "vector_search/recall_tests.rs"]
622mod recall_tests;
623#[path = "vector_search/score.rs"]
624mod score;
625#[path = "vector_search/score_candidate_batch.rs"]
626mod score_candidate_batch;
627#[path = "vector_search/score_expanded_batch.rs"]
628mod score_expanded_batch;
629#[path = "vector_search/score_neighbor_batch.rs"]
630mod score_neighbor_batch;
631#[path = "vector_search/score_shared.rs"]
632mod score_shared;
633#[cfg(test)]
634#[path = "vector_search/score_tests.rs"]
635mod score_tests;
636#[cfg(test)]
637#[path = "vector_search/tests.rs"]
638mod tests;