Skip to main content

selene_graph/vector_search/
score.rs

1use selene_core::{
2    CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
3    VectorTopK, VectorValue,
4};
5
6use crate::error::{GraphError, GraphResult};
7use crate::graph::SeleneGraph;
8use crate::parallel_scan::{should_parallelize_scan, try_reduce_chunks};
9
10use super::{
11    VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet, VectorNeighborDirection,
12    VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError, merge_top_k,
13    score_candidate_batch::{
14        candidate_sets_all_match, should_parallelize_candidate_batch_scoring,
15        should_parallelize_repeated_candidate_batch,
16    },
17    vector_node_hits,
18};
19
/// Threshold handed to the parallel-scan heuristic when deciding whether a
/// candidate set should be scored in parallel.
///
/// Test builds use a much smaller value than production builds
/// (presumably so small fixtures can reach the parallel path — confirm).
const VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES: usize = if cfg!(test) { 8 } else { 4096 };
24
/// Chunk size passed to the chunked parallel reduction when candidates are
/// scored in parallel.
///
/// Test builds use a much smaller value than production builds
/// (presumably so small fixtures span several chunks — confirm).
const VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES: usize = if cfg!(test) { 4 } else { 1024 };
29
30impl SeleneGraph {
31    /// Score an explicit node candidate set against one query vector.
32    ///
33    /// This is the graph-retrieval rerank primitive: callers can produce
34    /// candidates from graph pattern matches, graph algorithms, or ANN indexes,
35    /// then rank only those nodes by a vector-valued property. Candidate ids are
36    /// deduplicated before scoring. Missing, deleted, and non-vector candidates
37    /// are skipped to match normal live-snapshot visibility.
38    pub fn score_vector_nodes(
39        &self,
40        property: &DbString,
41        query: &VectorValue,
42        candidates: &[NodeId],
43        metric: VectorMetric,
44        k: usize,
45    ) -> GraphResult<Vec<VectorNodeSearchHit>> {
46        self.score_vector_nodes_checked(
47            property,
48            query,
49            candidates,
50            metric,
51            k,
52            CancellationChecker::disabled(),
53        )
54        .map_err(VectorSearchError::into_graph_error)
55    }
56
57    /// Score explicit node candidates with cancellation checks.
58    ///
59    /// This preserves [`Self::score_vector_nodes`] ordering and visibility while
60    /// checking `checker` before work begins and every 1024 unique candidates.
61    pub fn score_vector_nodes_checked(
62        &self,
63        property: &DbString,
64        query: &VectorValue,
65        candidates: &[NodeId],
66        metric: VectorMetric,
67        k: usize,
68        checker: CancellationChecker<'_>,
69    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
70        checker.check()?;
71        if k == 0 || candidates.is_empty() {
72            return Ok(Vec::new());
73        }
74
75        let candidates = VectorCandidateSet::from_nodes(candidates.iter().copied());
76        self.score_vector_candidate_set_after_initial_check(
77            property,
78            query,
79            &candidates,
80            metric,
81            k,
82            checker,
83        )
84    }
85
86    /// Score one canonical node candidate set against one query vector.
87    ///
88    /// This is the zero-renormalization companion to
89    /// [`Self::score_vector_nodes`]. Callers that already hold a
90    /// [`VectorCandidateSet`] can avoid the extra sort/dedup pass while keeping
91    /// the same live-snapshot visibility, metric, and hit ordering semantics.
92    pub fn score_vector_candidate_set(
93        &self,
94        property: &DbString,
95        query: &VectorValue,
96        candidates: &VectorCandidateSet,
97        metric: VectorMetric,
98        k: usize,
99    ) -> GraphResult<Vec<VectorNodeSearchHit>> {
100        self.score_vector_candidate_set_checked(
101            property,
102            query,
103            candidates,
104            metric,
105            k,
106            CancellationChecker::disabled(),
107        )
108        .map_err(VectorSearchError::into_graph_error)
109    }
110
111    /// Score one canonical node candidate set with cancellation checks.
112    pub fn score_vector_candidate_set_checked(
113        &self,
114        property: &DbString,
115        query: &VectorValue,
116        candidates: &VectorCandidateSet,
117        metric: VectorMetric,
118        k: usize,
119        checker: CancellationChecker<'_>,
120    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
121        checker.check()?;
122        if k == 0 || candidates.is_empty() {
123            return Ok(Vec::new());
124        }
125        self.score_vector_candidate_set_after_initial_check(
126            property, query, candidates, metric, k, checker,
127        )
128    }
129
130    fn score_vector_candidate_set_after_initial_check(
131        &self,
132        property: &DbString,
133        query: &VectorValue,
134        candidates: &VectorCandidateSet,
135        metric: VectorMetric,
136        k: usize,
137        checker: CancellationChecker<'_>,
138    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
139        checker.check()?;
140
141        let scorer = metric.bind_query(query).map_err(GraphError::from)?;
142        if should_parallelize_candidate_scoring(candidates.len(), k) {
143            return self
144                .score_vector_candidate_set_parallel(property, scorer, candidates, k, checker);
145        }
146
147        self.score_vector_candidate_set_serial(property, scorer, candidates, k, checker)
148    }
149
150    pub(super) fn score_vector_candidate_set_serial(
151        &self,
152        property: &DbString,
153        scorer: VectorMetricQuery<'_>,
154        candidates: &VectorCandidateSet,
155        k: usize,
156        checker: CancellationChecker<'_>,
157    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
158        let mut top_k = VectorTopK::new(k);
159        let mut candidates_since_check = 0usize;
160        for node_id in candidates.as_nodes().iter().copied() {
161            candidates_since_check += 1;
162            if candidates_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
163                checker.note_nodes_scanned(candidates_since_check)?;
164                candidates_since_check = 0;
165            }
166            let Some(properties) = self.node_properties(node_id) else {
167                continue;
168            };
169            let Some(Value::Vector(vector)) = properties.get(property) else {
170                continue;
171            };
172            let distance = scorer.distance(vector).map_err(GraphError::from)?;
173            top_k.push_distance(node_id, distance);
174        }
175        if candidates_since_check > 0 {
176            checker.note_nodes_scanned(candidates_since_check)?;
177        }
178
179        Ok(vector_node_hits(top_k))
180    }
181
182    fn score_vector_candidate_set_parallel(
183        &self,
184        property: &DbString,
185        scorer: VectorMetricQuery<'_>,
186        candidates: &VectorCandidateSet,
187        k: usize,
188        checker: CancellationChecker<'_>,
189    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
190        let top_k = try_reduce_chunks(
191            candidates.as_nodes(),
192            VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES,
193            checker,
194            || VectorTopK::new(k),
195            |chunk| self.score_vector_candidate_set_chunk(property, scorer, chunk, k),
196            merge_top_k,
197        )?;
198
199        Ok(vector_node_hits(top_k))
200    }
201
202    fn score_vector_candidate_set_chunk(
203        &self,
204        property: &DbString,
205        scorer: VectorMetricQuery<'_>,
206        candidates: &[NodeId],
207        k: usize,
208    ) -> Result<VectorTopK<NodeId>, VectorSearchError> {
209        let mut top_k = VectorTopK::new(k);
210        for node_id in candidates.iter().copied() {
211            let Some(properties) = self.node_properties(node_id) else {
212                continue;
213            };
214            let Some(Value::Vector(vector)) = properties.get(property) else {
215                continue;
216            };
217            let distance = scorer.distance(vector).map_err(GraphError::from)?;
218            top_k.push_distance(node_id, distance);
219        }
220        Ok(top_k)
221    }
222
223    /// Score one explicit candidate set for each query vector.
224    ///
225    /// The result position corresponds to the input query position. Candidate
226    /// sets are independent and follow [`Self::score_vector_nodes`] semantics:
227    /// each set is deduplicated, non-live or non-vector nodes are skipped, and
228    /// hits are ordered by distance then node id. The method rejects mismatched
229    /// query/candidate-set counts and mixed query dimensions before scoring.
230    pub fn score_vector_nodes_batch<C>(
231        &self,
232        property: &DbString,
233        queries: &[VectorValue],
234        candidate_sets: &[C],
235        metric: VectorMetric,
236        k: usize,
237    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>>
238    where
239        C: AsRef<[NodeId]>,
240    {
241        self.score_vector_nodes_batch_checked(
242            property,
243            queries,
244            candidate_sets,
245            metric,
246            k,
247            CancellationChecker::disabled(),
248        )
249        .map_err(VectorSearchError::into_graph_error)
250    }
251
252    /// Score batched explicit node candidates with cancellation checks.
253    ///
254    /// This preserves [`Self::score_vector_nodes_batch`] ordering and
255    /// visibility while checking `checker` before batch validation and before
256    /// each query's candidate set is scored.
257    pub fn score_vector_nodes_batch_checked<C>(
258        &self,
259        property: &DbString,
260        queries: &[VectorValue],
261        candidate_sets: &[C],
262        metric: VectorMetric,
263        k: usize,
264        checker: CancellationChecker<'_>,
265    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError>
266    where
267        C: AsRef<[NodeId]>,
268    {
269        checker.check()?;
270        validate_batch_inputs(queries, candidate_sets.len())?;
271        if queries.is_empty() {
272            return Ok(Vec::new());
273        }
274        if k == 0 {
275            return Ok(vec![Vec::new(); queries.len()]);
276        }
277
278        let mut canonical_sets = Vec::with_capacity(candidate_sets.len());
279        for candidates in candidate_sets {
280            checker.check()?;
281            canonical_sets.push(VectorCandidateSet::from_nodes(
282                candidates.as_ref().iter().copied(),
283            ));
284        }
285        self.score_vector_candidate_sets_batch_checked(
286            property,
287            queries,
288            &canonical_sets,
289            metric,
290            k,
291            checker,
292        )
293    }
294
295    /// Score one canonical candidate set for each query vector.
296    ///
297    /// This is the batch companion to [`Self::score_vector_candidate_set`]. It
298    /// preserves the generic batch scoring contract while avoiding a second
299    /// normalization pass for callers that already hold canonical candidate
300    /// sets.
301    pub fn score_vector_candidate_sets_batch(
302        &self,
303        property: &DbString,
304        queries: &[VectorValue],
305        candidate_sets: &[VectorCandidateSet],
306        metric: VectorMetric,
307        k: usize,
308    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
309        self.score_vector_candidate_sets_batch_checked(
310            property,
311            queries,
312            candidate_sets,
313            metric,
314            k,
315            CancellationChecker::disabled(),
316        )
317        .map_err(VectorSearchError::into_graph_error)
318    }
319
320    /// Score batched canonical candidate sets with cancellation checks.
321    pub fn score_vector_candidate_sets_batch_checked(
322        &self,
323        property: &DbString,
324        queries: &[VectorValue],
325        candidate_sets: &[VectorCandidateSet],
326        metric: VectorMetric,
327        k: usize,
328        checker: CancellationChecker<'_>,
329    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
330        checker.check()?;
331        validate_batch_inputs(queries, candidate_sets.len())?;
332        if queries.is_empty() {
333            return Ok(Vec::new());
334        }
335        if k == 0 {
336            return Ok(vec![Vec::new(); queries.len()]);
337        }
338
339        let should_parallelize_batch =
340            should_parallelize_candidate_batch_scoring(candidate_sets, k);
341        if let Some(candidates) = candidate_sets.first()
342            && should_parallelize_repeated_candidate_batch(queries.len(), candidates.len(), k)
343            && candidate_sets_all_match(candidate_sets)
344        {
345            return self.score_repeated_vector_candidate_set_batch_parallel(
346                property, queries, candidates, metric, k, checker,
347            );
348        }
349
350        if should_parallelize_batch {
351            return self.score_vector_candidate_sets_batch_parallel(
352                property,
353                queries,
354                candidate_sets,
355                metric,
356                k,
357                checker,
358            );
359        }
360        if candidate_sets_all_match(candidate_sets) {
361            return self.score_repeated_vector_candidate_set_batch_serial(
362                property,
363                queries,
364                &candidate_sets[0],
365                metric,
366                k,
367                checker,
368            );
369        }
370
371        self.score_vector_candidate_sets_batch_grouped_serial(
372            property,
373            queries,
374            candidate_sets,
375            metric,
376            k,
377            checker,
378        )
379    }
380
381    /// Score vector-valued neighbors reached from one anchor through `edge_label`.
382    ///
383    /// This is the one-hop graph candidate-set companion to
384    /// [`Self::score_vector_nodes`]. It derives candidates from the snapshot's
385    /// directed adjacency, then applies the same dedupe, visibility, metric, and
386    /// ordering rules as explicit candidate scoring.
387    pub fn score_vector_neighbors(
388        &self,
389        property: &DbString,
390        query: &VectorValue,
391        anchor: NodeId,
392        options: VectorNeighborSearchOptions<'_>,
393    ) -> GraphResult<Vec<VectorNodeSearchHit>> {
394        self.score_vector_neighbors_checked(
395            property,
396            query,
397            anchor,
398            options,
399            CancellationChecker::disabled(),
400        )
401        .map_err(VectorSearchError::into_graph_error)
402    }
403
404    /// Score vector-valued neighbors with cancellation checks.
405    pub fn score_vector_neighbors_checked(
406        &self,
407        property: &DbString,
408        query: &VectorValue,
409        anchor: NodeId,
410        options: VectorNeighborSearchOptions<'_>,
411        checker: CancellationChecker<'_>,
412    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
413        checker.check()?;
414        if options.k == 0 {
415            return Ok(Vec::new());
416        }
417        let candidates =
418            self.vector_neighbor_candidates(anchor, options.edge_label, options.direction);
419        self.score_vector_candidate_set_checked(
420            property,
421            query,
422            &candidates,
423            options.metric,
424            options.k,
425            checker,
426        )
427    }
428
429    /// Score one anchor's vector-valued neighbors for each query vector.
430    ///
431    /// `queries[i]` is scored against neighbors derived from `anchors[i]`.
432    /// Mismatched query/anchor counts and mixed query dimensions are rejected
433    /// before scoring.
434    pub fn score_vector_neighbors_batch(
435        &self,
436        property: &DbString,
437        queries: &[VectorValue],
438        anchors: &[NodeId],
439        options: VectorNeighborSearchOptions<'_>,
440    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
441        self.score_vector_neighbors_batch_checked(
442            property,
443            queries,
444            anchors,
445            options,
446            CancellationChecker::disabled(),
447        )
448        .map_err(VectorSearchError::into_graph_error)
449    }
450
451    /// Score batched one-hop graph neighbors with cancellation checks.
452    pub fn score_vector_neighbors_batch_checked(
453        &self,
454        property: &DbString,
455        queries: &[VectorValue],
456        anchors: &[NodeId],
457        options: VectorNeighborSearchOptions<'_>,
458        checker: CancellationChecker<'_>,
459    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
460        checker.check()?;
461        validate_batch_inputs(queries, anchors.len())?;
462        if queries.is_empty() {
463            return Ok(Vec::new());
464        }
465        if options.k == 0 {
466            return Ok(vec![Vec::new(); queries.len()]);
467        }
468        let candidate_sets = self.vector_neighbor_candidate_sets_batch(
469            anchors,
470            options.edge_label,
471            options.direction,
472            options.k,
473            checker,
474        )?;
475        self.score_vector_candidate_sets_batch_checked(
476            property,
477            queries,
478            &candidate_sets,
479            options.metric,
480            options.k,
481            checker,
482        )
483    }
484
485    /// Expand one canonical root set per query through one graph hop, then score it.
486    ///
487    /// `queries[i]` is scored against `root_sets[i]` plus nodes reached from
488    /// those roots through `options.edge_label` in `options.direction`. This is
489    /// the batch graph-retrieval primitive for ANN-then-graph-expansion and
490    /// graph-query-then-vector-rerank workloads.
491    pub fn score_vector_expanded_candidate_sets_batch(
492        &self,
493        property: &DbString,
494        queries: &[VectorValue],
495        root_sets: &[VectorCandidateSet],
496        options: VectorNeighborSearchOptions<'_>,
497    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
498        self.score_vector_expanded_candidate_sets_batch_checked(
499            property,
500            queries,
501            root_sets,
502            options,
503            CancellationChecker::disabled(),
504        )
505        .map_err(VectorSearchError::into_graph_error)
506    }
507
508    /// Expand and score batched canonical root sets with cancellation checks.
509    pub fn score_vector_expanded_candidate_sets_batch_checked(
510        &self,
511        property: &DbString,
512        queries: &[VectorValue],
513        root_sets: &[VectorCandidateSet],
514        options: VectorNeighborSearchOptions<'_>,
515        checker: CancellationChecker<'_>,
516    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
517        checker.check()?;
518        validate_batch_inputs(queries, root_sets.len())?;
519        if queries.is_empty() {
520            return Ok(Vec::new());
521        }
522        if options.k == 0 {
523            return Ok(vec![Vec::new(); queries.len()]);
524        }
525
526        let expanded_sets = self.expand_vector_candidate_sets_batch(
527            root_sets,
528            options.edge_label,
529            options.direction,
530            options.k,
531            checker,
532        )?;
533        self.score_vector_candidate_sets_batch_checked(
534            property,
535            queries,
536            &expanded_sets,
537            options.metric,
538            options.k,
539            checker,
540        )
541    }
542
543    /// Expand one canonical root set per query through one graph hop.
544    ///
545    /// This is the reusable batch expansion primitive used by graph-expanded
546    /// vector scorers. It preserves input order, reuses duplicate root-set
547    /// expansion work within bounded batches, and checks cancellation while
548    /// deriving candidates.
549    pub fn expand_vector_candidate_sets_batch_checked(
550        &self,
551        root_sets: &[VectorCandidateSet],
552        edge_label: &DbString,
553        direction: VectorNeighborDirection,
554        k: usize,
555        checker: CancellationChecker<'_>,
556    ) -> Result<Vec<VectorCandidateSet>, VectorSearchError> {
557        self.expand_vector_candidate_sets_batch(root_sets, edge_label, direction, k, checker)
558    }
559
560    /// Return canonical vector-score candidates reached from one graph anchor.
561    ///
562    /// Candidates are filtered by edge label and direction, sorted by
563    /// [`NodeId`], and deduplicated. The returned set intentionally does not
564    /// check vector property presence or node liveness; scoring APIs apply
565    /// normal snapshot visibility when the set is consumed.
566    #[must_use]
567    pub fn vector_neighbor_candidates(
568        &self,
569        anchor: NodeId,
570        edge_label: &DbString,
571        direction: VectorNeighborDirection,
572    ) -> VectorCandidateSet {
573        let mut candidates = Vec::new();
574        if matches!(
575            direction,
576            VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
577        ) && let Some(entry) = self.outgoing_edges(anchor)
578        {
579            candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
580        }
581        if matches!(
582            direction,
583            VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
584        ) && let Some(entry) = self.incoming_edges(anchor)
585        {
586            candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
587        }
588        VectorCandidateSet::from_nodes(candidates)
589    }
590
591    /// Expand canonical candidates through one labeled graph hop.
592    ///
593    /// The returned set contains every root candidate plus neighbors reached
594    /// from those roots through `edge_label` in `direction`. This is the
595    /// production primitive behind graph-authored support/provenance expansion:
596    /// callers can build a small root set from graph queries or ANN hits, expand
597    /// through graph topology, then pass the canonical result to vector scoring.
598    #[must_use]
599    pub fn expand_vector_candidate_set(
600        &self,
601        roots: &VectorCandidateSet,
602        edge_label: &DbString,
603        direction: VectorNeighborDirection,
604    ) -> VectorCandidateSet {
605        self.expand_vector_candidate_set_checked(
606            roots,
607            edge_label,
608            direction,
609            CancellationChecker::disabled(),
610        )
611        .expect("disabled cancellation cannot fail")
612    }
613
614    /// Expand canonical candidates through one labeled graph hop with cancellation checks.
615    pub fn expand_vector_candidate_set_checked(
616        &self,
617        roots: &VectorCandidateSet,
618        edge_label: &DbString,
619        direction: VectorNeighborDirection,
620        checker: CancellationChecker<'_>,
621    ) -> Result<VectorCandidateSet, VectorSearchError> {
622        checker.check()?;
623        if roots.is_empty() {
624            return Ok(VectorCandidateSet::default());
625        }
626        let mut candidates = Vec::with_capacity(roots.len());
627        candidates.extend_from_slice(roots.as_nodes());
628        let mut roots_since_check = 0usize;
629        for root in roots.as_nodes().iter().copied() {
630            roots_since_check += 1;
631            if roots_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
632                checker.note_nodes_scanned(roots_since_check)?;
633                roots_since_check = 0;
634            }
635            if matches!(
636                direction,
637                VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
638            ) && let Some(entry) = self.outgoing_edges(root)
639            {
640                candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
641            }
642            if matches!(
643                direction,
644                VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
645            ) && let Some(entry) = self.incoming_edges(root)
646            {
647                candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
648            }
649        }
650        if roots_since_check > 0 {
651            checker.note_nodes_scanned(roots_since_check)?;
652        }
653        Ok(VectorCandidateSet::from_nodes(candidates))
654    }
655}
656
657fn should_parallelize_candidate_scoring(candidate_count: usize, k: usize) -> bool {
658    should_parallelize_scan(
659        candidate_count as u64,
660        k,
661        VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES as u64,
662    )
663}
664
665fn validate_batch_inputs(
666    queries: &[VectorValue],
667    candidate_set_count: usize,
668) -> Result<(), VectorSearchError> {
669    if queries.len() != candidate_set_count {
670        return Err(VectorSearchError::BatchLengthMismatch {
671            queries: queries.len(),
672            candidate_sets: candidate_set_count,
673        });
674    }
675    let Some(first_query) = queries.first() else {
676        return Ok(());
677    };
678    let first_dimension = first_query.dimension();
679    for query in &queries[1..] {
680        if query.dimension() != first_dimension {
681            return Err(GraphError::from(CoreError::VectorDimensionMismatch {
682                lhs: first_dimension,
683                rhs: query.dimension(),
684            })
685            .into());
686        }
687    }
688    Ok(())
689}