1use selene_core::{
2 CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
3 VectorTopK, VectorValue,
4};
5
6use crate::error::{GraphError, GraphResult};
7use crate::graph::SeleneGraph;
8use crate::parallel_scan::{should_parallelize_scan, try_reduce_chunks};
9
10use super::{
11 VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet, VectorNeighborDirection,
12 VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError, merge_top_k,
13 score_candidate_batch::{
14 candidate_sets_all_match, should_parallelize_candidate_batch_scoring,
15 should_parallelize_repeated_candidate_batch,
16 },
17 vector_node_hits,
18};
19
/// Candidate count handed to `should_parallelize_scan` as the minimum size before
/// candidate scoring considers the chunked parallel path.
///
/// Test builds use a tiny threshold so unit tests exercise the parallel path with small
/// fixtures.
const VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES: usize = if cfg!(test) { 8 } else { 4096 };

/// Number of candidates scored per chunk on the parallel path.
///
/// Shrunk under `cfg(test)` so small test inputs still split into several chunks.
const VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES: usize = if cfg!(test) { 4 } else { 1024 };
29
30impl SeleneGraph {
31 pub fn score_vector_nodes(
39 &self,
40 property: &DbString,
41 query: &VectorValue,
42 candidates: &[NodeId],
43 metric: VectorMetric,
44 k: usize,
45 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
46 self.score_vector_nodes_checked(
47 property,
48 query,
49 candidates,
50 metric,
51 k,
52 CancellationChecker::disabled(),
53 )
54 .map_err(VectorSearchError::into_graph_error)
55 }
56
57 pub fn score_vector_nodes_checked(
62 &self,
63 property: &DbString,
64 query: &VectorValue,
65 candidates: &[NodeId],
66 metric: VectorMetric,
67 k: usize,
68 checker: CancellationChecker<'_>,
69 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
70 checker.check()?;
71 if k == 0 || candidates.is_empty() {
72 return Ok(Vec::new());
73 }
74
75 let candidates = VectorCandidateSet::from_nodes(candidates.iter().copied());
76 self.score_vector_candidate_set_after_initial_check(
77 property,
78 query,
79 &candidates,
80 metric,
81 k,
82 checker,
83 )
84 }
85
86 pub fn score_vector_candidate_set(
93 &self,
94 property: &DbString,
95 query: &VectorValue,
96 candidates: &VectorCandidateSet,
97 metric: VectorMetric,
98 k: usize,
99 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
100 self.score_vector_candidate_set_checked(
101 property,
102 query,
103 candidates,
104 metric,
105 k,
106 CancellationChecker::disabled(),
107 )
108 .map_err(VectorSearchError::into_graph_error)
109 }
110
111 pub fn score_vector_candidate_set_checked(
113 &self,
114 property: &DbString,
115 query: &VectorValue,
116 candidates: &VectorCandidateSet,
117 metric: VectorMetric,
118 k: usize,
119 checker: CancellationChecker<'_>,
120 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
121 checker.check()?;
122 if k == 0 || candidates.is_empty() {
123 return Ok(Vec::new());
124 }
125 self.score_vector_candidate_set_after_initial_check(
126 property, query, candidates, metric, k, checker,
127 )
128 }
129
130 fn score_vector_candidate_set_after_initial_check(
131 &self,
132 property: &DbString,
133 query: &VectorValue,
134 candidates: &VectorCandidateSet,
135 metric: VectorMetric,
136 k: usize,
137 checker: CancellationChecker<'_>,
138 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
139 checker.check()?;
140
141 let scorer = metric.bind_query(query).map_err(GraphError::from)?;
142 if should_parallelize_candidate_scoring(candidates.len(), k) {
143 return self
144 .score_vector_candidate_set_parallel(property, scorer, candidates, k, checker);
145 }
146
147 self.score_vector_candidate_set_serial(property, scorer, candidates, k, checker)
148 }
149
150 pub(super) fn score_vector_candidate_set_serial(
151 &self,
152 property: &DbString,
153 scorer: VectorMetricQuery<'_>,
154 candidates: &VectorCandidateSet,
155 k: usize,
156 checker: CancellationChecker<'_>,
157 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
158 let mut top_k = VectorTopK::new(k);
159 for (offset, node_id) in candidates.as_nodes().iter().copied().enumerate() {
160 if offset % VECTOR_SEARCH_CANCEL_STRIDE == 0 {
161 checker.check()?;
162 }
163 let Some(properties) = self.node_properties(node_id) else {
164 continue;
165 };
166 let Some(Value::Vector(vector)) = properties.get(property) else {
167 continue;
168 };
169 let distance = scorer.distance(vector).map_err(GraphError::from)?;
170 top_k.push_distance(node_id, distance);
171 }
172
173 Ok(vector_node_hits(top_k))
174 }
175
176 fn score_vector_candidate_set_parallel(
177 &self,
178 property: &DbString,
179 scorer: VectorMetricQuery<'_>,
180 candidates: &VectorCandidateSet,
181 k: usize,
182 checker: CancellationChecker<'_>,
183 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
184 let top_k = try_reduce_chunks(
185 candidates.as_nodes(),
186 VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES,
187 checker,
188 || VectorTopK::new(k),
189 |chunk| self.score_vector_candidate_set_chunk(property, scorer, chunk, k),
190 merge_top_k,
191 )?;
192
193 Ok(vector_node_hits(top_k))
194 }
195
196 fn score_vector_candidate_set_chunk(
197 &self,
198 property: &DbString,
199 scorer: VectorMetricQuery<'_>,
200 candidates: &[NodeId],
201 k: usize,
202 ) -> Result<VectorTopK<NodeId>, VectorSearchError> {
203 let mut top_k = VectorTopK::new(k);
204 for node_id in candidates.iter().copied() {
205 let Some(properties) = self.node_properties(node_id) else {
206 continue;
207 };
208 let Some(Value::Vector(vector)) = properties.get(property) else {
209 continue;
210 };
211 let distance = scorer.distance(vector).map_err(GraphError::from)?;
212 top_k.push_distance(node_id, distance);
213 }
214 Ok(top_k)
215 }
216
217 pub fn score_vector_nodes_batch<C>(
225 &self,
226 property: &DbString,
227 queries: &[VectorValue],
228 candidate_sets: &[C],
229 metric: VectorMetric,
230 k: usize,
231 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>>
232 where
233 C: AsRef<[NodeId]>,
234 {
235 self.score_vector_nodes_batch_checked(
236 property,
237 queries,
238 candidate_sets,
239 metric,
240 k,
241 CancellationChecker::disabled(),
242 )
243 .map_err(VectorSearchError::into_graph_error)
244 }
245
246 pub fn score_vector_nodes_batch_checked<C>(
252 &self,
253 property: &DbString,
254 queries: &[VectorValue],
255 candidate_sets: &[C],
256 metric: VectorMetric,
257 k: usize,
258 checker: CancellationChecker<'_>,
259 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError>
260 where
261 C: AsRef<[NodeId]>,
262 {
263 checker.check()?;
264 validate_batch_inputs(queries, candidate_sets.len())?;
265 if queries.is_empty() {
266 return Ok(Vec::new());
267 }
268 if k == 0 {
269 return Ok(vec![Vec::new(); queries.len()]);
270 }
271
272 let mut canonical_sets = Vec::with_capacity(candidate_sets.len());
273 for candidates in candidate_sets {
274 checker.check()?;
275 canonical_sets.push(VectorCandidateSet::from_nodes(
276 candidates.as_ref().iter().copied(),
277 ));
278 }
279 self.score_vector_candidate_sets_batch_checked(
280 property,
281 queries,
282 &canonical_sets,
283 metric,
284 k,
285 checker,
286 )
287 }
288
289 pub fn score_vector_candidate_sets_batch(
296 &self,
297 property: &DbString,
298 queries: &[VectorValue],
299 candidate_sets: &[VectorCandidateSet],
300 metric: VectorMetric,
301 k: usize,
302 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
303 self.score_vector_candidate_sets_batch_checked(
304 property,
305 queries,
306 candidate_sets,
307 metric,
308 k,
309 CancellationChecker::disabled(),
310 )
311 .map_err(VectorSearchError::into_graph_error)
312 }
313
314 pub fn score_vector_candidate_sets_batch_checked(
316 &self,
317 property: &DbString,
318 queries: &[VectorValue],
319 candidate_sets: &[VectorCandidateSet],
320 metric: VectorMetric,
321 k: usize,
322 checker: CancellationChecker<'_>,
323 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
324 checker.check()?;
325 validate_batch_inputs(queries, candidate_sets.len())?;
326 if queries.is_empty() {
327 return Ok(Vec::new());
328 }
329 if k == 0 {
330 return Ok(vec![Vec::new(); queries.len()]);
331 }
332
333 let should_parallelize_batch =
334 should_parallelize_candidate_batch_scoring(candidate_sets, k);
335 if let Some(candidates) = candidate_sets.first()
336 && should_parallelize_repeated_candidate_batch(queries.len(), candidates.len(), k)
337 && candidate_sets_all_match(candidate_sets)
338 {
339 return self.score_repeated_vector_candidate_set_batch_parallel(
340 property, queries, candidates, metric, k, checker,
341 );
342 }
343
344 if should_parallelize_batch {
345 return self.score_vector_candidate_sets_batch_parallel(
346 property,
347 queries,
348 candidate_sets,
349 metric,
350 k,
351 checker,
352 );
353 }
354 if candidate_sets_all_match(candidate_sets) {
355 return self.score_repeated_vector_candidate_set_batch_serial(
356 property,
357 queries,
358 &candidate_sets[0],
359 metric,
360 k,
361 checker,
362 );
363 }
364
365 self.score_vector_candidate_sets_batch_grouped_serial(
366 property,
367 queries,
368 candidate_sets,
369 metric,
370 k,
371 checker,
372 )
373 }
374
375 pub fn score_vector_neighbors(
382 &self,
383 property: &DbString,
384 query: &VectorValue,
385 anchor: NodeId,
386 options: VectorNeighborSearchOptions<'_>,
387 ) -> GraphResult<Vec<VectorNodeSearchHit>> {
388 self.score_vector_neighbors_checked(
389 property,
390 query,
391 anchor,
392 options,
393 CancellationChecker::disabled(),
394 )
395 .map_err(VectorSearchError::into_graph_error)
396 }
397
398 pub fn score_vector_neighbors_checked(
400 &self,
401 property: &DbString,
402 query: &VectorValue,
403 anchor: NodeId,
404 options: VectorNeighborSearchOptions<'_>,
405 checker: CancellationChecker<'_>,
406 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
407 checker.check()?;
408 if options.k == 0 {
409 return Ok(Vec::new());
410 }
411 let candidates =
412 self.vector_neighbor_candidates(anchor, options.edge_label, options.direction);
413 self.score_vector_candidate_set_checked(
414 property,
415 query,
416 &candidates,
417 options.metric,
418 options.k,
419 checker,
420 )
421 }
422
423 pub fn score_vector_neighbors_batch(
429 &self,
430 property: &DbString,
431 queries: &[VectorValue],
432 anchors: &[NodeId],
433 options: VectorNeighborSearchOptions<'_>,
434 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
435 self.score_vector_neighbors_batch_checked(
436 property,
437 queries,
438 anchors,
439 options,
440 CancellationChecker::disabled(),
441 )
442 .map_err(VectorSearchError::into_graph_error)
443 }
444
445 pub fn score_vector_neighbors_batch_checked(
447 &self,
448 property: &DbString,
449 queries: &[VectorValue],
450 anchors: &[NodeId],
451 options: VectorNeighborSearchOptions<'_>,
452 checker: CancellationChecker<'_>,
453 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
454 checker.check()?;
455 validate_batch_inputs(queries, anchors.len())?;
456 if queries.is_empty() {
457 return Ok(Vec::new());
458 }
459 if options.k == 0 {
460 return Ok(vec![Vec::new(); queries.len()]);
461 }
462 let candidate_sets = self.vector_neighbor_candidate_sets_batch(
463 anchors,
464 options.edge_label,
465 options.direction,
466 options.k,
467 checker,
468 )?;
469 self.score_vector_candidate_sets_batch_checked(
470 property,
471 queries,
472 &candidate_sets,
473 options.metric,
474 options.k,
475 checker,
476 )
477 }
478
479 pub fn score_vector_expanded_candidate_sets_batch(
486 &self,
487 property: &DbString,
488 queries: &[VectorValue],
489 root_sets: &[VectorCandidateSet],
490 options: VectorNeighborSearchOptions<'_>,
491 ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
492 self.score_vector_expanded_candidate_sets_batch_checked(
493 property,
494 queries,
495 root_sets,
496 options,
497 CancellationChecker::disabled(),
498 )
499 .map_err(VectorSearchError::into_graph_error)
500 }
501
502 pub fn score_vector_expanded_candidate_sets_batch_checked(
504 &self,
505 property: &DbString,
506 queries: &[VectorValue],
507 root_sets: &[VectorCandidateSet],
508 options: VectorNeighborSearchOptions<'_>,
509 checker: CancellationChecker<'_>,
510 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
511 checker.check()?;
512 validate_batch_inputs(queries, root_sets.len())?;
513 if queries.is_empty() {
514 return Ok(Vec::new());
515 }
516 if options.k == 0 {
517 return Ok(vec![Vec::new(); queries.len()]);
518 }
519
520 let expanded_sets = self.expand_vector_candidate_sets_batch(
521 root_sets,
522 options.edge_label,
523 options.direction,
524 options.k,
525 checker,
526 )?;
527 self.score_vector_candidate_sets_batch_checked(
528 property,
529 queries,
530 &expanded_sets,
531 options.metric,
532 options.k,
533 checker,
534 )
535 }
536
537 pub fn expand_vector_candidate_sets_batch_checked(
544 &self,
545 root_sets: &[VectorCandidateSet],
546 edge_label: &DbString,
547 direction: VectorNeighborDirection,
548 k: usize,
549 checker: CancellationChecker<'_>,
550 ) -> Result<Vec<VectorCandidateSet>, VectorSearchError> {
551 self.expand_vector_candidate_sets_batch(root_sets, edge_label, direction, k, checker)
552 }
553
554 #[must_use]
561 pub fn vector_neighbor_candidates(
562 &self,
563 anchor: NodeId,
564 edge_label: &DbString,
565 direction: VectorNeighborDirection,
566 ) -> VectorCandidateSet {
567 let mut candidates = Vec::new();
568 if matches!(
569 direction,
570 VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
571 ) && let Some(entry) = self.outgoing_edges(anchor)
572 {
573 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
574 }
575 if matches!(
576 direction,
577 VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
578 ) && let Some(entry) = self.incoming_edges(anchor)
579 {
580 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
581 }
582 VectorCandidateSet::from_nodes(candidates)
583 }
584
585 #[must_use]
593 pub fn expand_vector_candidate_set(
594 &self,
595 roots: &VectorCandidateSet,
596 edge_label: &DbString,
597 direction: VectorNeighborDirection,
598 ) -> VectorCandidateSet {
599 self.expand_vector_candidate_set_checked(
600 roots,
601 edge_label,
602 direction,
603 CancellationChecker::disabled(),
604 )
605 .expect("disabled cancellation cannot fail")
606 }
607
608 pub fn expand_vector_candidate_set_checked(
610 &self,
611 roots: &VectorCandidateSet,
612 edge_label: &DbString,
613 direction: VectorNeighborDirection,
614 checker: CancellationChecker<'_>,
615 ) -> Result<VectorCandidateSet, VectorSearchError> {
616 checker.check()?;
617 if roots.is_empty() {
618 return Ok(VectorCandidateSet::default());
619 }
620 let mut candidates = Vec::with_capacity(roots.len());
621 candidates.extend_from_slice(roots.as_nodes());
622 for (offset, root) in roots.as_nodes().iter().copied().enumerate() {
623 if offset % VECTOR_SEARCH_CANCEL_STRIDE == 0 {
624 checker.check()?;
625 }
626 if matches!(
627 direction,
628 VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
629 ) && let Some(entry) = self.outgoing_edges(root)
630 {
631 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
632 }
633 if matches!(
634 direction,
635 VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
636 ) && let Some(entry) = self.incoming_edges(root)
637 {
638 candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
639 }
640 }
641 Ok(VectorCandidateSet::from_nodes(candidates))
642 }
643}
644
645fn should_parallelize_candidate_scoring(candidate_count: usize, k: usize) -> bool {
646 should_parallelize_scan(
647 candidate_count as u64,
648 k,
649 VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES as u64,
650 )
651}
652
653fn validate_batch_inputs(
654 queries: &[VectorValue],
655 candidate_set_count: usize,
656) -> Result<(), VectorSearchError> {
657 if queries.len() != candidate_set_count {
658 return Err(VectorSearchError::BatchLengthMismatch {
659 queries: queries.len(),
660 candidate_sets: candidate_set_count,
661 });
662 }
663 let Some(first_query) = queries.first() else {
664 return Ok(());
665 };
666 let first_dimension = first_query.dimension();
667 for query in &queries[1..] {
668 if query.dimension() != first_dimension {
669 return Err(GraphError::from(CoreError::VectorDimensionMismatch {
670 lhs: first_dimension,
671 rhs: query.dimension(),
672 })
673 .into());
674 }
675 }
676 Ok(())
677}