1use selene_core::{
2 CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
3 VectorTopK, VectorValue,
4};
5
6use crate::error::{GraphError, GraphResult};
7use crate::graph::SeleneGraph;
8use crate::parallel_scan::{should_parallelize_scan, try_reduce_chunks};
9
10use super::{
11 VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet, VectorNeighborDirection,
12 VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError, merge_top_k,
13 score_candidate_batch::{
14 candidate_sets_all_match, should_parallelize_candidate_batch_scoring,
15 should_parallelize_repeated_candidate_batch,
16 },
17 vector_node_hits,
18};
19
/// Minimum number of candidates before single-query scoring switches to the
/// chunked parallel path.
///
/// The much smaller value under `cfg(test)` presumably lets unit tests reach the
/// parallel code with tiny fixtures.
const VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES: usize = if cfg!(test) { 8 } else { 4096 };

/// Number of candidates handed to each parallel scoring chunk.
///
/// Shrunk under `cfg(test)` for the same reason as the minimum-node threshold.
const VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES: usize = if cfg!(test) { 4 } else { 1024 };
29
30impl SeleneGraph {
31 pub fn score_vector_nodes(
39 &self,
40 property: &DbString,
41 query: &VectorValue,
42 candidates: &[NodeId],
43 metric: VectorMetric,
44 k: usize,
45 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
46 self.score_vector_nodes_checked(
47 property,
48 query,
49 candidates,
50 metric,
51 k,
52 CancellationChecker::disabled(),
53 )
54 .map_err(VectorSearchError::into_graph_error)
55 }
56
57 pub fn score_vector_nodes_checked(
62 &self,
63 property: &DbString,
64 query: &VectorValue,
65 candidates: &[NodeId],
66 metric: VectorMetric,
67 k: usize,
68 checker: CancellationChecker<'_>,
69 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
70 checker.check()?;
71 if k == 0 || candidates.is_empty() {
72 return Ok(Vec::new());
73 }
74
75 let candidates = VectorCandidateSet::from_nodes(candidates.iter().copied());
76 self.score_vector_candidate_set_after_initial_check(
77 property,
78 query,
79 &candidates,
80 metric,
81 k,
82 checker,
83 )
84 }
85
86 pub fn score_vector_candidate_set(
93 &self,
94 property: &DbString,
95 query: &VectorValue,
96 candidates: &VectorCandidateSet,
97 metric: VectorMetric,
98 k: usize,
99 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
100 self.score_vector_candidate_set_checked(
101 property,
102 query,
103 candidates,
104 metric,
105 k,
106 CancellationChecker::disabled(),
107 )
108 .map_err(VectorSearchError::into_graph_error)
109 }
110
111 pub fn score_vector_candidate_set_checked(
113 &self,
114 property: &DbString,
115 query: &VectorValue,
116 candidates: &VectorCandidateSet,
117 metric: VectorMetric,
118 k: usize,
119 checker: CancellationChecker<'_>,
120 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
121 checker.check()?;
122 if k == 0 || candidates.is_empty() {
123 return Ok(Vec::new());
124 }
125 self.score_vector_candidate_set_after_initial_check(
126 property, query, candidates, metric, k, checker,
127 )
128 }
129
130 fn score_vector_candidate_set_after_initial_check(
131 &self,
132 property: &DbString,
133 query: &VectorValue,
134 candidates: &VectorCandidateSet,
135 metric: VectorMetric,
136 k: usize,
137 checker: CancellationChecker<'_>,
138 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
139 checker.check()?;
140
141 let scorer = metric.bind_query(query).map_err(GraphError::from)?;
142 if should_parallelize_candidate_scoring(candidates.len(), k) {
143 return self
144 .score_vector_candidate_set_parallel(property, scorer, candidates, k, checker);
145 }
146
147 self.score_vector_candidate_set_serial(property, scorer, candidates, k, checker)
148 }
149
150 pub(super) fn score_vector_candidate_set_serial(
151 &self,
152 property: &DbString,
153 scorer: VectorMetricQuery<'_>,
154 candidates: &VectorCandidateSet,
155 k: usize,
156 checker: CancellationChecker<'_>,
157 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
158 let mut top_k = VectorTopK::new(k);
159 let mut candidates_since_check = 0usize;
160 for node_id in candidates.as_nodes().iter().copied() {
161 candidates_since_check += 1;
162 if candidates_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
163 checker.note_nodes_scanned(candidates_since_check)?;
164 candidates_since_check = 0;
165 }
166 let Some(properties) = self.node_properties(node_id) else {
167 continue;
168 };
169 let Some(Value::Vector(vector)) = properties.get(property) else {
170 continue;
171 };
172 let distance = scorer.distance(vector).map_err(GraphError::from)?;
173 top_k.push_distance(node_id, distance);
174 }
175 if candidates_since_check > 0 {
176 checker.note_nodes_scanned(candidates_since_check)?;
177 }
178
179 Ok(vector_node_hits(top_k))
180 }
181
182 fn score_vector_candidate_set_parallel(
183 &self,
184 property: &DbString,
185 scorer: VectorMetricQuery<'_>,
186 candidates: &VectorCandidateSet,
187 k: usize,
188 checker: CancellationChecker<'_>,
189 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
190 let top_k = try_reduce_chunks(
191 candidates.as_nodes(),
192 VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES,
193 checker,
194 || VectorTopK::new(k),
195 |chunk| self.score_vector_candidate_set_chunk(property, scorer, chunk, k),
196 merge_top_k,
197 )?;
198
199 Ok(vector_node_hits(top_k))
200 }
201
202 fn score_vector_candidate_set_chunk(
203 &self,
204 property: &DbString,
205 scorer: VectorMetricQuery<'_>,
206 candidates: &[NodeId],
207 k: usize,
208 ) -> Result<VectorTopK<NodeId>, VectorSearchError> {
209 let mut top_k = VectorTopK::new(k);
210 for node_id in candidates.iter().copied() {
211 let Some(properties) = self.node_properties(node_id) else {
212 continue;
213 };
214 let Some(Value::Vector(vector)) = properties.get(property) else {
215 continue;
216 };
217 let distance = scorer.distance(vector).map_err(GraphError::from)?;
218 top_k.push_distance(node_id, distance);
219 }
220 Ok(top_k)
221 }
222
223 pub fn score_vector_nodes_batch<C>(
231 &self,
232 property: &DbString,
233 queries: &[VectorValue],
234 candidate_sets: &[C],
235 metric: VectorMetric,
236 k: usize,
237 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>>
238 where
239 C: AsRef<[NodeId]>,
240 {
241 self.score_vector_nodes_batch_checked(
242 property,
243 queries,
244 candidate_sets,
245 metric,
246 k,
247 CancellationChecker::disabled(),
248 )
249 .map_err(VectorSearchError::into_graph_error)
250 }
251
252 pub fn score_vector_nodes_batch_checked<C>(
258 &self,
259 property: &DbString,
260 queries: &[VectorValue],
261 candidate_sets: &[C],
262 metric: VectorMetric,
263 k: usize,
264 checker: CancellationChecker<'_>,
265 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError>
266 where
267 C: AsRef<[NodeId]>,
268 {
269 checker.check()?;
270 validate_batch_inputs(queries, candidate_sets.len())?;
271 if queries.is_empty() {
272 return Ok(Vec::new());
273 }
274 if k == 0 {
275 return Ok(vec![Vec::new(); queries.len()]);
276 }
277
278 let mut canonical_sets = Vec::with_capacity(candidate_sets.len());
279 for candidates in candidate_sets {
280 checker.check()?;
281 canonical_sets.push(VectorCandidateSet::from_nodes(
282 candidates.as_ref().iter().copied(),
283 ));
284 }
285 self.score_vector_candidate_sets_batch_checked(
286 property,
287 queries,
288 &canonical_sets,
289 metric,
290 k,
291 checker,
292 )
293 }
294
295 pub fn score_vector_candidate_sets_batch(
302 &self,
303 property: &DbString,
304 queries: &[VectorValue],
305 candidate_sets: &[VectorCandidateSet],
306 metric: VectorMetric,
307 k: usize,
308 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
309 self.score_vector_candidate_sets_batch_checked(
310 property,
311 queries,
312 candidate_sets,
313 metric,
314 k,
315 CancellationChecker::disabled(),
316 )
317 .map_err(VectorSearchError::into_graph_error)
318 }
319
320 pub fn score_vector_candidate_sets_batch_checked(
322 &self,
323 property: &DbString,
324 queries: &[VectorValue],
325 candidate_sets: &[VectorCandidateSet],
326 metric: VectorMetric,
327 k: usize,
328 checker: CancellationChecker<'_>,
329 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
330 checker.check()?;
331 validate_batch_inputs(queries, candidate_sets.len())?;
332 if queries.is_empty() {
333 return Ok(Vec::new());
334 }
335 if k == 0 {
336 return Ok(vec![Vec::new(); queries.len()]);
337 }
338
339 let should_parallelize_batch =
340 should_parallelize_candidate_batch_scoring(candidate_sets, k);
341 if let Some(candidates) = candidate_sets.first()
342 && should_parallelize_repeated_candidate_batch(queries.len(), candidates.len(), k)
343 && candidate_sets_all_match(candidate_sets)
344 {
345 return self.score_repeated_vector_candidate_set_batch_parallel(
346 property, queries, candidates, metric, k, checker,
347 );
348 }
349
350 if should_parallelize_batch {
351 return self.score_vector_candidate_sets_batch_parallel(
352 property,
353 queries,
354 candidate_sets,
355 metric,
356 k,
357 checker,
358 );
359 }
360 if candidate_sets_all_match(candidate_sets) {
361 return self.score_repeated_vector_candidate_set_batch_serial(
362 property,
363 queries,
364 &candidate_sets[0],
365 metric,
366 k,
367 checker,
368 );
369 }
370
371 self.score_vector_candidate_sets_batch_grouped_serial(
372 property,
373 queries,
374 candidate_sets,
375 metric,
376 k,
377 checker,
378 )
379 }
380
381 pub fn score_vector_neighbors(
388 &self,
389 property: &DbString,
390 query: &VectorValue,
391 anchor: NodeId,
392 options: VectorNeighborSearchOptions<'_>,
393 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
394 self.score_vector_neighbors_checked(
395 property,
396 query,
397 anchor,
398 options,
399 CancellationChecker::disabled(),
400 )
401 .map_err(VectorSearchError::into_graph_error)
402 }
403
404 pub fn score_vector_neighbors_checked(
406 &self,
407 property: &DbString,
408 query: &VectorValue,
409 anchor: NodeId,
410 options: VectorNeighborSearchOptions<'_>,
411 checker: CancellationChecker<'_>,
412 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
413 checker.check()?;
414 if options.k == 0 {
415 return Ok(Vec::new());
416 }
417 let candidates =
418 self.vector_neighbor_candidates(anchor, options.edge_label, options.direction);
419 self.score_vector_candidate_set_checked(
420 property,
421 query,
422 &candidates,
423 options.metric,
424 options.k,
425 checker,
426 )
427 }
428
429 pub fn score_vector_neighbors_batch(
435 &self,
436 property: &DbString,
437 queries: &[VectorValue],
438 anchors: &[NodeId],
439 options: VectorNeighborSearchOptions<'_>,
440 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
441 self.score_vector_neighbors_batch_checked(
442 property,
443 queries,
444 anchors,
445 options,
446 CancellationChecker::disabled(),
447 )
448 .map_err(VectorSearchError::into_graph_error)
449 }
450
451 pub fn score_vector_neighbors_batch_checked(
453 &self,
454 property: &DbString,
455 queries: &[VectorValue],
456 anchors: &[NodeId],
457 options: VectorNeighborSearchOptions<'_>,
458 checker: CancellationChecker<'_>,
459 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
460 checker.check()?;
461 validate_batch_inputs(queries, anchors.len())?;
462 if queries.is_empty() {
463 return Ok(Vec::new());
464 }
465 if options.k == 0 {
466 return Ok(vec![Vec::new(); queries.len()]);
467 }
468 let candidate_sets = self.vector_neighbor_candidate_sets_batch(
469 anchors,
470 options.edge_label,
471 options.direction,
472 options.k,
473 checker,
474 )?;
475 self.score_vector_candidate_sets_batch_checked(
476 property,
477 queries,
478 &candidate_sets,
479 options.metric,
480 options.k,
481 checker,
482 )
483 }
484
485 pub fn score_vector_expanded_candidate_sets_batch(
492 &self,
493 property: &DbString,
494 queries: &[VectorValue],
495 root_sets: &[VectorCandidateSet],
496 options: VectorNeighborSearchOptions<'_>,
497 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
498 self.score_vector_expanded_candidate_sets_batch_checked(
499 property,
500 queries,
501 root_sets,
502 options,
503 CancellationChecker::disabled(),
504 )
505 .map_err(VectorSearchError::into_graph_error)
506 }
507
508 pub fn score_vector_expanded_candidate_sets_batch_checked(
510 &self,
511 property: &DbString,
512 queries: &[VectorValue],
513 root_sets: &[VectorCandidateSet],
514 options: VectorNeighborSearchOptions<'_>,
515 checker: CancellationChecker<'_>,
516 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
517 checker.check()?;
518 validate_batch_inputs(queries, root_sets.len())?;
519 if queries.is_empty() {
520 return Ok(Vec::new());
521 }
522 if options.k == 0 {
523 return Ok(vec![Vec::new(); queries.len()]);
524 }
525
526 let expanded_sets = self.expand_vector_candidate_sets_batch(
527 root_sets,
528 options.edge_label,
529 options.direction,
530 options.k,
531 checker,
532 )?;
533 self.score_vector_candidate_sets_batch_checked(
534 property,
535 queries,
536 &expanded_sets,
537 options.metric,
538 options.k,
539 checker,
540 )
541 }
542
543 pub fn expand_vector_candidate_sets_batch_checked(
550 &self,
551 root_sets: &[VectorCandidateSet],
552 edge_label: &DbString,
553 direction: VectorNeighborDirection,
554 k: usize,
555 checker: CancellationChecker<'_>,
556 ) -> Result<Vec<VectorCandidateSet>, VectorSearchError> {
557 self.expand_vector_candidate_sets_batch(root_sets, edge_label, direction, k, checker)
558 }
559
560 #[must_use]
567 pub fn vector_neighbor_candidates(
568 &self,
569 anchor: NodeId,
570 edge_label: &DbString,
571 direction: VectorNeighborDirection,
572 ) -> VectorCandidateSet {
573 let mut candidates = Vec::new();
574 if matches!(
575 direction,
576 VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
577 ) && let Some(entry) = self.outgoing_edges(anchor)
578 {
579 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
580 }
581 if matches!(
582 direction,
583 VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
584 ) && let Some(entry) = self.incoming_edges(anchor)
585 {
586 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
587 }
588 VectorCandidateSet::from_nodes(candidates)
589 }
590
591 #[must_use]
599 pub fn expand_vector_candidate_set(
600 &self,
601 roots: &VectorCandidateSet,
602 edge_label: &DbString,
603 direction: VectorNeighborDirection,
604 ) -> VectorCandidateSet {
605 self.expand_vector_candidate_set_checked(
606 roots,
607 edge_label,
608 direction,
609 CancellationChecker::disabled(),
610 )
611 .expect("disabled cancellation cannot fail")
612 }
613
614 pub fn expand_vector_candidate_set_checked(
616 &self,
617 roots: &VectorCandidateSet,
618 edge_label: &DbString,
619 direction: VectorNeighborDirection,
620 checker: CancellationChecker<'_>,
621 ) -> Result<VectorCandidateSet, VectorSearchError> {
622 checker.check()?;
623 if roots.is_empty() {
624 return Ok(VectorCandidateSet::default());
625 }
626 let mut candidates = Vec::with_capacity(roots.len());
627 candidates.extend_from_slice(roots.as_nodes());
628 let mut roots_since_check = 0usize;
629 for root in roots.as_nodes().iter().copied() {
630 roots_since_check += 1;
631 if roots_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
632 checker.note_nodes_scanned(roots_since_check)?;
633 roots_since_check = 0;
634 }
635 if matches!(
636 direction,
637 VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
638 ) && let Some(entry) = self.outgoing_edges(root)
639 {
640 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
641 }
642 if matches!(
643 direction,
644 VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
645 ) && let Some(entry) = self.incoming_edges(root)
646 {
647 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
648 }
649 }
650 if roots_since_check > 0 {
651 checker.note_nodes_scanned(roots_since_check)?;
652 }
653 Ok(VectorCandidateSet::from_nodes(candidates))
654 }
655}
656
657fn should_parallelize_candidate_scoring(candidate_count: usize, k: usize) -> bool {
658 should_parallelize_scan(
659 candidate_count as u64,
660 k,
661 VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES as u64,
662 )
663}
664
665fn validate_batch_inputs(
666 queries: &[VectorValue],
667 candidate_set_count: usize,
668) -> Result<(), VectorSearchError> {
669 if queries.len() != candidate_set_count {
670 return Err(VectorSearchError::BatchLengthMismatch {
671 queries: queries.len(),
672 candidate_sets: candidate_set_count,
673 });
674 }
675 let Some(first_query) = queries.first() else {
676 return Ok(());
677 };
678 let first_dimension = first_query.dimension();
679 for query in &queries[1..] {
680 if query.dimension() != first_dimension {
681 return Err(GraphError::from(CoreError::VectorDimensionMismatch {
682 lhs: first_dimension,
683 rhs: query.dimension(),
684 })
685 .into());
686 }
687 }
688 Ok(())
689}