1use std::cmp::Ordering;
4
5use roaring::RoaringBitmap;
6use selene_core::{
7 CancellationChecker, DbString, NodeId, Value, VectorMetric, VectorMetricQuery, VectorTopK,
8 VectorValue,
9};
10
11use crate::error::{GraphError, GraphResult};
12use crate::graph::SeleneGraph;
13use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
14#[cfg(test)]
15use crate::shared::SharedGraph;
16use crate::store::RowIndex;
17use crate::vector_index::VectorIndexSearchHit;
18#[path = "vector_search/types.rs"]
19mod types;
20pub use types::{
21 ApproximateVectorExpansionOptions, ApproximateVectorSearchOptions, VectorCandidateSet,
22 VectorNeighborDirection, VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError,
23};
24#[path = "vector_search/approx_batch.rs"]
25mod approx_batch;
26#[path = "vector_search/approx_turbo_quant.rs"]
27mod approx_turbo_quant;
28#[path = "vector_search/exact_batch.rs"]
29mod exact_batch;
30#[path = "vector_search/shared_wrappers.rs"]
31mod shared_wrappers;
32#[path = "vector_search/turbo_quant_exact.rs"]
33mod turbo_quant_exact;
34
/// Number of rows the sequential exact scan visits between cancellation polls.
const VECTOR_SEARCH_CANCEL_STRIDE: usize = 1024;
/// Chunk size, in rows, handed to `try_reduce_bitmap_chunks` by the parallel exact scan.
const VECTOR_SEARCH_PARALLEL_CHUNK_ROWS: usize = 2048;

/// Row-count threshold forwarded to `should_parallelize_scan` when deciding
/// whether an exact scan takes the parallel path (non-test builds).
#[cfg(not(test))]
const VECTOR_SEARCH_PARALLEL_MIN_ROWS: u64 = 16_384;
/// Test-build value of the same threshold.
// NOTE(review): presumably lowered so small test fixtures reach the parallel path — confirm.
#[cfg(test)]
const VECTOR_SEARCH_PARALLEL_MIN_ROWS: u64 = 8;
42
43impl SeleneGraph {
44 pub fn exact_vector_search_nodes(
53 &self,
54 label: &DbString,
55 property: &DbString,
56 query: &VectorValue,
57 metric: VectorMetric,
58 k: usize,
59 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
60 self.exact_vector_search_nodes_checked(
61 label,
62 property,
63 query,
64 metric,
65 k,
66 CancellationChecker::disabled(),
67 )
68 .map_err(VectorSearchError::into_graph_error)
69 }
70
71 pub fn exact_vector_search_nodes_checked(
79 &self,
80 label: &DbString,
81 property: &DbString,
82 query: &VectorValue,
83 metric: VectorMetric,
84 k: usize,
85 checker: CancellationChecker<'_>,
86 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
87 checker.check()?;
88 if k == 0 {
89 return Ok(Vec::new());
90 }
91 let Some(label_rows) = self.nodes_with_label(label) else {
92 return Ok(Vec::new());
93 };
94 let query_dimension = u32::try_from(query.dimension()).ok();
95 let vector_index = query_dimension.and_then(|dimension| {
96 self.vector_index_for(label, property)
97 .filter(|index| index.dimension() == dimension)
98 });
99 let rows = vector_index
100 .as_ref()
101 .map_or(label_rows, |index| index.rows());
102 let scorer = metric.bind_query(query).map_err(GraphError::from)?;
103 if should_parallelize_exact_scan(rows, k) {
104 return self.exact_vector_search_parallel(label, property, scorer, k, rows, checker);
105 }
106
107 let mut top_k = VectorTopK::new(k);
108 let mut rows_since_check = 0usize;
109 for raw_row in rows.iter() {
110 rows_since_check += 1;
111 if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
112 checker.check()?;
113 rows_since_check = 0;
114 }
115 if !self.node_store.is_alive(raw_row) {
116 continue;
117 }
118 let row = RowIndex::new(raw_row);
119 let node_id = self
120 .node_id_for_row(row)
121 .ok_or_else(|| GraphError::Inconsistent {
122 reason: format!(
123 "label index row {raw_row} for {} has no node id",
124 label.as_str()
125 ),
126 })?;
127 let properties = self
128 .node_store
129 .properties
130 .get(raw_row as usize)
131 .ok_or_else(|| GraphError::Inconsistent {
132 reason: format!(
133 "label index row {raw_row} for {} has no property row",
134 label.as_str()
135 ),
136 })?;
137 let Some(Value::Vector(vector)) = properties.get(property) else {
138 continue;
139 };
140 let distance = scorer.distance(vector).map_err(GraphError::from)?;
141 top_k.push_distance(node_id, distance);
142 }
143
144 Ok(top_k
145 .into_hits()
146 .into_iter()
147 .map(|hit| VectorNodeSearchHit {
148 node_id: hit.key,
149 distance: hit.distance,
150 })
151 .collect())
152 }
153
154 fn exact_vector_search_parallel(
155 &self,
156 label: &DbString,
157 property: &DbString,
158 scorer: VectorMetricQuery<'_>,
159 k: usize,
160 rows: &RoaringBitmap,
161 checker: CancellationChecker<'_>,
162 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
163 let top_k = try_reduce_bitmap_chunks(
164 rows,
165 VECTOR_SEARCH_PARALLEL_CHUNK_ROWS,
166 checker,
167 || VectorTopK::new(k),
168 |chunk| self.exact_vector_search_chunk(label, property, scorer, k, chunk),
169 merge_top_k,
170 )?;
171
172 Ok(vector_node_hits(top_k))
173 }
174
175 fn exact_vector_search_chunk(
176 &self,
177 label: &DbString,
178 property: &DbString,
179 scorer: VectorMetricQuery<'_>,
180 k: usize,
181 rows: &[u32],
182 ) -> Result<VectorTopK<NodeId>, VectorSearchError> {
183 let mut top_k = VectorTopK::new(k);
184 for &raw_row in rows {
185 if !self.node_store.is_alive(raw_row) {
186 continue;
187 }
188 let row = RowIndex::new(raw_row);
189 let node_id = self
190 .node_id_for_row(row)
191 .ok_or_else(|| GraphError::Inconsistent {
192 reason: format!(
193 "vector search row {raw_row} for {} has no node id",
194 label.as_str()
195 ),
196 })?;
197 let properties = self
198 .node_store
199 .properties
200 .get(raw_row as usize)
201 .ok_or_else(|| GraphError::Inconsistent {
202 reason: format!(
203 "vector search row {raw_row} for {} has no property row",
204 label.as_str()
205 ),
206 })?;
207 let Some(Value::Vector(vector)) = properties.get(property) else {
208 continue;
209 };
210 let distance = scorer.distance(vector).map_err(GraphError::from)?;
211 top_k.push_distance(node_id, distance);
212 }
213 Ok(top_k)
214 }
215
216 pub fn approximate_vector_search_nodes_checked(
226 &self,
227 label: &DbString,
228 property: &DbString,
229 query: &VectorValue,
230 options: ApproximateVectorSearchOptions,
231 checker: CancellationChecker<'_>,
232 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
233 checker.check()?;
234 let query_dimension = u32::try_from(query.dimension())
235 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
236 let Some(index) = self
237 .vector_index_for(label, property)
238 .filter(|index| index.dimension() == query_dimension)
239 else {
240 return Err(VectorSearchError::ApproximateIndexMissing);
241 };
242 let Some(indexed_metric) = index.ann_metric() else {
243 return Err(VectorSearchError::ApproximateIndexMissing);
244 };
245 if indexed_metric != options.metric {
246 return Err(VectorSearchError::ApproximateMetricMismatch {
247 indexed: indexed_metric,
248 requested: options.metric,
249 });
250 }
251 if index.is_turbo_quant() {
252 if turbo_quant_exact::covers_rows(index.rows(), options) {
253 return self.exact_vector_search_nodes_checked(
254 label,
255 property,
256 query,
257 options.metric,
258 options.k,
259 checker,
260 );
261 }
262 let row_hits = index
263 .turbo_quant_candidates(query, options.k, options.ef_search)
264 .ok_or(VectorSearchError::ApproximateIndexMissing)?
265 .map_err(GraphError::from)?;
266 return rerank_ann_row_candidates(
267 self,
268 property,
269 query,
270 options.metric,
271 options.k,
272 row_hits,
273 &checker,
274 );
275 }
276 let row_hits = index
277 .ann_search(query, options.k, options.ef_search)
278 .ok_or(VectorSearchError::ApproximateIndexMissing)?
279 .map_err(GraphError::from)?;
280
281 ann_row_hits_to_node_hits(self, label, row_hits, &checker)
282 }
283
284 pub fn approximate_vector_search_nodes_batch_checked(
293 &self,
294 label: &DbString,
295 property: &DbString,
296 queries: &[VectorValue],
297 options: ApproximateVectorSearchOptions,
298 checker: CancellationChecker<'_>,
299 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
300 checker.check()?;
301 let Some(first_query) = queries.first() else {
302 return Ok(Vec::new());
303 };
304 let query_dimension = u32::try_from(first_query.dimension())
305 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
306 let Some(index) = self
307 .vector_index_for(label, property)
308 .filter(|index| index.dimension() == query_dimension)
309 else {
310 return Err(VectorSearchError::ApproximateIndexMissing);
311 };
312 let Some(indexed_metric) = index.ann_metric() else {
313 return Err(VectorSearchError::ApproximateIndexMissing);
314 };
315 if indexed_metric != options.metric {
316 return Err(VectorSearchError::ApproximateMetricMismatch {
317 indexed: indexed_metric,
318 requested: options.metric,
319 });
320 }
321
322 if index.is_turbo_quant() {
323 for query in queries {
324 checker.check()?;
325 let dimension = u32::try_from(query.dimension())
326 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
327 if dimension != query_dimension {
328 return Err(VectorSearchError::ApproximateIndexMissing);
329 }
330 }
331 if turbo_quant_exact::covers_rows(index.rows(), options) {
332 return self.exact_vector_search_nodes_batch_checked(
333 label,
334 property,
335 queries,
336 options.metric,
337 options.k,
338 checker,
339 );
340 }
341 if !index.turbo_quant_prefers_fused_batch(queries.len()) {
342 let mut batch_hits = Vec::with_capacity(queries.len());
343 for query in queries {
344 checker.check()?;
345 let row_hits = index
346 .turbo_quant_candidates(query, options.k, options.ef_search)
347 .ok_or(VectorSearchError::ApproximateIndexMissing)?
348 .map_err(GraphError::from)?;
349 batch_hits.push(rerank_ann_row_candidates(
350 self,
351 property,
352 query,
353 options.metric,
354 options.k,
355 row_hits,
356 &checker,
357 )?);
358 }
359 return Ok(batch_hits);
360 }
361 let row_batches = index
362 .turbo_quant_candidates_batch(queries, options.k, options.ef_search)
363 .ok_or(VectorSearchError::ApproximateIndexMissing)?
364 .map_err(GraphError::from)?;
365 let mut batch_hits = Vec::with_capacity(queries.len());
366 for (query, row_hits) in queries.iter().zip(row_batches) {
367 batch_hits.push(rerank_ann_row_candidates(
368 self,
369 property,
370 query,
371 options.metric,
372 options.k,
373 row_hits,
374 &checker,
375 )?);
376 }
377 return Ok(batch_hits);
378 }
379
380 approx_batch::ann_index_batch_search(
381 self,
382 label,
383 &index,
384 queries,
385 options,
386 query_dimension,
387 checker,
388 )
389 }
390
391 pub fn approximate_vector_search_expanded_candidates_checked(
401 &self,
402 label: &DbString,
403 property: &DbString,
404 query: &VectorValue,
405 options: ApproximateVectorExpansionOptions<'_>,
406 checker: CancellationChecker<'_>,
407 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
408 checker.check()?;
409 let root_hits = self.approximate_vector_search_nodes_checked(
410 label,
411 property,
412 query,
413 ApproximateVectorSearchOptions::new(options.metric, options.root_k, options.ef_search),
414 checker,
415 )?;
416 if options.k == 0 || root_hits.is_empty() {
417 return Ok(Vec::new());
418 }
419
420 let roots = VectorCandidateSet::from_search_hits(&root_hits);
421 let expanded = self.expand_vector_candidate_set_checked(
422 &roots,
423 options.edge_label,
424 options.direction,
425 checker,
426 )?;
427 self.score_vector_candidate_set_checked(
428 property,
429 query,
430 &expanded,
431 options.metric,
432 options.k,
433 checker,
434 )
435 }
436
437 pub fn approximate_vector_search_expanded_candidates_batch_checked(
445 &self,
446 label: &DbString,
447 property: &DbString,
448 queries: &[VectorValue],
449 options: ApproximateVectorExpansionOptions<'_>,
450 checker: CancellationChecker<'_>,
451 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
452 checker.check()?;
453 let root_hits = self.approximate_vector_search_nodes_batch_checked(
454 label,
455 property,
456 queries,
457 ApproximateVectorSearchOptions::new(options.metric, options.root_k, options.ef_search),
458 checker,
459 )?;
460 if options.k == 0 {
461 return Ok(vec![Vec::new(); queries.len()]);
462 }
463
464 let root_sets = root_hits
465 .iter()
466 .map(VectorCandidateSet::from_search_hits)
467 .collect::<Vec<_>>();
468 self.score_vector_expanded_candidate_sets_batch_checked(
469 property,
470 queries,
471 &root_sets,
472 VectorNeighborSearchOptions::new(
473 options.edge_label,
474 options.direction,
475 options.metric,
476 options.k,
477 ),
478 checker,
479 )
480 }
481}
482
483fn should_parallelize_exact_scan(rows: &RoaringBitmap, k: usize) -> bool {
484 should_parallelize_scan(rows.len(), k, VECTOR_SEARCH_PARALLEL_MIN_ROWS)
485}
486
487fn merge_top_k(
488 mut lhs: VectorTopK<NodeId>,
489 rhs: VectorTopK<NodeId>,
490) -> Result<VectorTopK<NodeId>, VectorSearchError> {
491 for hit in rhs.into_hits() {
492 lhs.push_distance(hit.key, hit.distance);
493 }
494 Ok(lhs)
495}
496
497fn vector_node_hits(top_k: VectorTopK<NodeId>) -> Vec<VectorNodeSearchHit> {
498 top_k
499 .into_hits()
500 .into_iter()
501 .map(|hit| VectorNodeSearchHit {
502 node_id: hit.key,
503 distance: hit.distance,
504 })
505 .collect()
506}
507
508fn ann_row_hits_to_node_hits(
509 graph: &SeleneGraph,
510 label: &DbString,
511 row_hits: Vec<VectorIndexSearchHit>,
512 checker: &CancellationChecker<'_>,
513) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
514 let mut hits = Vec::with_capacity(row_hits.len());
515 let mut needs_sort = false;
516 for hit in row_hits {
517 checker.check()?;
518 if !graph.node_store.is_alive(hit.row) {
519 continue;
520 }
521 let row = RowIndex::new(hit.row);
522 let node_id = graph
523 .node_id_for_row(row)
524 .ok_or_else(|| GraphError::Inconsistent {
525 reason: format!(
526 "ANN vector index row {} for {} has no node id",
527 hit.row,
528 label.as_str()
529 ),
530 })?;
531 let node_hit = VectorNodeSearchHit {
532 node_id,
533 distance: hit.distance,
534 };
535 needs_sort |= hits
536 .last()
537 .is_some_and(|previous| compare_node_search_hit(previous, &node_hit).is_gt());
538 hits.push(node_hit);
539 }
540 if needs_sort {
541 hits.sort_by(compare_node_search_hit);
542 }
543 Ok(hits)
544}
545
546fn rerank_ann_row_candidates(
547 graph: &SeleneGraph,
548 property: &DbString,
549 query: &VectorValue,
550 metric: VectorMetric,
551 k: usize,
552 row_hits: Vec<VectorIndexSearchHit>,
553 checker: &CancellationChecker<'_>,
554) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
555 let scorer = metric.bind_query(query).map_err(GraphError::from)?;
556 let mut top_k = VectorTopK::new(k);
557 for hit in row_hits {
558 checker.check()?;
559 if !graph.node_store.is_alive(hit.row) {
560 continue;
561 }
562 let row = RowIndex::new(hit.row);
563 let node_id = graph
564 .node_id_for_row(row)
565 .ok_or_else(|| GraphError::Inconsistent {
566 reason: format!("ANN vector candidate row {} has no node id", hit.row),
567 })?;
568 let properties = graph
569 .node_store
570 .properties
571 .get(hit.row as usize)
572 .ok_or_else(|| GraphError::Inconsistent {
573 reason: format!("ANN vector candidate row {} has no property row", hit.row),
574 })?;
575 let Some(Value::Vector(vector)) = properties.get(property) else {
576 continue;
577 };
578 let distance = scorer.distance(vector).map_err(GraphError::from)?;
579 top_k.push_distance(node_id, distance);
580 }
581 Ok(vector_node_hits(top_k))
582}
583
584fn compare_node_search_hit(lhs: &VectorNodeSearchHit, rhs: &VectorNodeSearchHit) -> Ordering {
585 lhs.distance
586 .total_cmp(&rhs.distance)
587 .then_with(|| lhs.node_id.cmp(&rhs.node_id))
588}
589
590#[cfg(test)]
591#[path = "vector_search/ann_conversion_tests.rs"]
592mod ann_conversion_tests;
593#[cfg(test)]
594#[path = "vector_search/ann_expansion_tests.rs"]
595mod ann_expansion_tests;
596#[cfg(test)]
597#[path = "vector_search/batch_tests.rs"]
598mod batch_tests;
599#[cfg(test)]
600#[path = "vector_search/recall_tests.rs"]
601mod recall_tests;
602#[path = "vector_search/score.rs"]
603mod score;
604#[path = "vector_search/score_candidate_batch.rs"]
605mod score_candidate_batch;
606#[path = "vector_search/score_expanded_batch.rs"]
607mod score_expanded_batch;
608#[path = "vector_search/score_neighbor_batch.rs"]
609mod score_neighbor_batch;
610#[path = "vector_search/score_shared.rs"]
611mod score_shared;
612#[cfg(test)]
613#[path = "vector_search/score_tests.rs"]
614mod score_tests;
615#[cfg(test)]
616#[path = "vector_search/tests.rs"]
617mod tests;