Skip to main content

selene_graph/vector_search/
score.rs

1use selene_core::{
2    CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
3    VectorTopK, VectorValue,
4};
5
6use crate::error::{GraphError, GraphResult};
7use crate::graph::SeleneGraph;
8use crate::parallel_scan::{should_parallelize_scan, try_reduce_chunks};
9
10use super::{
11    VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet, VectorNeighborDirection,
12    VectorNeighborSearchOptions, VectorNodeSearchHit, VectorSearchError, merge_top_k,
13    score_candidate_batch::{
14        candidate_sets_all_match, should_parallelize_candidate_batch_scoring,
15        should_parallelize_repeated_candidate_batch,
16    },
17    vector_node_hits,
18};
19
/// Candidate-count threshold passed to `should_parallelize_scan` when deciding
/// whether one candidate set is scored on the parallel path.
///
/// NOTE(review): the test build uses a far smaller value, presumably so small
/// fixtures still reach the parallel path — confirm.
#[cfg(test)]
const VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES: usize = 8;
/// Candidate-count threshold passed to `should_parallelize_scan` when deciding
/// whether one candidate set is scored on the parallel path.
#[cfg(not(test))]
const VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES: usize = 4096;

/// Number of candidate nodes per chunk handed to `try_reduce_chunks` by the
/// parallel scoring path.
///
/// NOTE(review): the test build uses a far smaller value, presumably so small
/// fixtures span several chunks — confirm.
#[cfg(test)]
const VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES: usize = 4;
/// Number of candidate nodes per chunk handed to `try_reduce_chunks` by the
/// parallel scoring path.
#[cfg(not(test))]
const VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES: usize = 1024;
29
30impl SeleneGraph {
31    /// Score an explicit node candidate set against one query vector.
32    ///
33    /// This is the graph-retrieval rerank primitive: callers can produce
34    /// candidates from graph pattern matches, graph algorithms, or ANN indexes,
35    /// then rank only those nodes by a vector-valued property. Candidate ids are
36    /// deduplicated before scoring. Missing, deleted, and non-vector candidates
37    /// are skipped to match normal live-snapshot visibility.
38    pub fn score_vector_nodes(
39        &self,
40        property: &DbString,
41        query: &VectorValue,
42        candidates: &[NodeId],
43        metric: VectorMetric,
44        k: usize,
45    ) -> GraphResult<Vec<VectorNodeSearchHit>> {
46        self.score_vector_nodes_checked(
47            property,
48            query,
49            candidates,
50            metric,
51            k,
52            CancellationChecker::disabled(),
53        )
54        .map_err(VectorSearchError::into_graph_error)
55    }
56
57    /// Score explicit node candidates with cancellation checks.
58    ///
59    /// This preserves [`Self::score_vector_nodes`] ordering and visibility while
60    /// checking `checker` before work begins and every 1024 unique candidates.
61    pub fn score_vector_nodes_checked(
62        &self,
63        property: &DbString,
64        query: &VectorValue,
65        candidates: &[NodeId],
66        metric: VectorMetric,
67        k: usize,
68        checker: CancellationChecker<'_>,
69    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
70        checker.check()?;
71        if k == 0 || candidates.is_empty() {
72            return Ok(Vec::new());
73        }
74
75        let candidates = VectorCandidateSet::from_nodes(candidates.iter().copied());
76        self.score_vector_candidate_set_after_initial_check(
77            property,
78            query,
79            &candidates,
80            metric,
81            k,
82            checker,
83        )
84    }
85
86    /// Score one canonical node candidate set against one query vector.
87    ///
88    /// This is the zero-renormalization companion to
89    /// [`Self::score_vector_nodes`]. Callers that already hold a
90    /// [`VectorCandidateSet`] can avoid the extra sort/dedup pass while keeping
91    /// the same live-snapshot visibility, metric, and hit ordering semantics.
92    pub fn score_vector_candidate_set(
93        &self,
94        property: &DbString,
95        query: &VectorValue,
96        candidates: &VectorCandidateSet,
97        metric: VectorMetric,
98        k: usize,
99    ) -> GraphResult<Vec<VectorNodeSearchHit>> {
100        self.score_vector_candidate_set_checked(
101            property,
102            query,
103            candidates,
104            metric,
105            k,
106            CancellationChecker::disabled(),
107        )
108        .map_err(VectorSearchError::into_graph_error)
109    }
110
111    /// Score one canonical node candidate set with cancellation checks.
112    pub fn score_vector_candidate_set_checked(
113        &self,
114        property: &DbString,
115        query: &VectorValue,
116        candidates: &VectorCandidateSet,
117        metric: VectorMetric,
118        k: usize,
119        checker: CancellationChecker<'_>,
120    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
121        checker.check()?;
122        if k == 0 || candidates.is_empty() {
123            return Ok(Vec::new());
124        }
125        self.score_vector_candidate_set_after_initial_check(
126            property, query, candidates, metric, k, checker,
127        )
128    }
129
130    fn score_vector_candidate_set_after_initial_check(
131        &self,
132        property: &DbString,
133        query: &VectorValue,
134        candidates: &VectorCandidateSet,
135        metric: VectorMetric,
136        k: usize,
137        checker: CancellationChecker<'_>,
138    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
139        checker.check()?;
140
141        let scorer = metric.bind_query(query).map_err(GraphError::from)?;
142        if should_parallelize_candidate_scoring(candidates.len(), k) {
143            return self
144                .score_vector_candidate_set_parallel(property, scorer, candidates, k, checker);
145        }
146
147        self.score_vector_candidate_set_serial(property, scorer, candidates, k, checker)
148    }
149
150    pub(super) fn score_vector_candidate_set_serial(
151        &self,
152        property: &DbString,
153        scorer: VectorMetricQuery<'_>,
154        candidates: &VectorCandidateSet,
155        k: usize,
156        checker: CancellationChecker<'_>,
157    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
158        let mut top_k = VectorTopK::new(k);
159        for (offset, node_id) in candidates.as_nodes().iter().copied().enumerate() {
160            if offset % VECTOR_SEARCH_CANCEL_STRIDE == 0 {
161                checker.check()?;
162            }
163            let Some(properties) = self.node_properties(node_id) else {
164                continue;
165            };
166            let Some(Value::Vector(vector)) = properties.get(property) else {
167                continue;
168            };
169            let distance = scorer.distance(vector).map_err(GraphError::from)?;
170            top_k.push_distance(node_id, distance);
171        }
172
173        Ok(vector_node_hits(top_k))
174    }
175
176    fn score_vector_candidate_set_parallel(
177        &self,
178        property: &DbString,
179        scorer: VectorMetricQuery<'_>,
180        candidates: &VectorCandidateSet,
181        k: usize,
182        checker: CancellationChecker<'_>,
183    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
184        let top_k = try_reduce_chunks(
185            candidates.as_nodes(),
186            VECTOR_CANDIDATE_SCORE_PARALLEL_CHUNK_NODES,
187            checker,
188            || VectorTopK::new(k),
189            |chunk| self.score_vector_candidate_set_chunk(property, scorer, chunk, k),
190            merge_top_k,
191        )?;
192
193        Ok(vector_node_hits(top_k))
194    }
195
196    fn score_vector_candidate_set_chunk(
197        &self,
198        property: &DbString,
199        scorer: VectorMetricQuery<'_>,
200        candidates: &[NodeId],
201        k: usize,
202    ) -> Result<VectorTopK<NodeId>, VectorSearchError> {
203        let mut top_k = VectorTopK::new(k);
204        for node_id in candidates.iter().copied() {
205            let Some(properties) = self.node_properties(node_id) else {
206                continue;
207            };
208            let Some(Value::Vector(vector)) = properties.get(property) else {
209                continue;
210            };
211            let distance = scorer.distance(vector).map_err(GraphError::from)?;
212            top_k.push_distance(node_id, distance);
213        }
214        Ok(top_k)
215    }
216
217    /// Score one explicit candidate set for each query vector.
218    ///
219    /// The result position corresponds to the input query position. Candidate
220    /// sets are independent and follow [`Self::score_vector_nodes`] semantics:
221    /// each set is deduplicated, non-live or non-vector nodes are skipped, and
222    /// hits are ordered by distance then node id. The method rejects mismatched
223    /// query/candidate-set counts and mixed query dimensions before scoring.
224    pub fn score_vector_nodes_batch<C>(
225        &self,
226        property: &DbString,
227        queries: &[VectorValue],
228        candidate_sets: &[C],
229        metric: VectorMetric,
230        k: usize,
231    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>>
232    where
233        C: AsRef<[NodeId]>,
234    {
235        self.score_vector_nodes_batch_checked(
236            property,
237            queries,
238            candidate_sets,
239            metric,
240            k,
241            CancellationChecker::disabled(),
242        )
243        .map_err(VectorSearchError::into_graph_error)
244    }
245
246    /// Score batched explicit node candidates with cancellation checks.
247    ///
248    /// This preserves [`Self::score_vector_nodes_batch`] ordering and
249    /// visibility while checking `checker` before batch validation and before
250    /// each query's candidate set is scored.
251    pub fn score_vector_nodes_batch_checked<C>(
252        &self,
253        property: &DbString,
254        queries: &[VectorValue],
255        candidate_sets: &[C],
256        metric: VectorMetric,
257        k: usize,
258        checker: CancellationChecker<'_>,
259    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError>
260    where
261        C: AsRef<[NodeId]>,
262    {
263        checker.check()?;
264        validate_batch_inputs(queries, candidate_sets.len())?;
265        if queries.is_empty() {
266            return Ok(Vec::new());
267        }
268        if k == 0 {
269            return Ok(vec![Vec::new(); queries.len()]);
270        }
271
272        let mut canonical_sets = Vec::with_capacity(candidate_sets.len());
273        for candidates in candidate_sets {
274            checker.check()?;
275            canonical_sets.push(VectorCandidateSet::from_nodes(
276                candidates.as_ref().iter().copied(),
277            ));
278        }
279        self.score_vector_candidate_sets_batch_checked(
280            property,
281            queries,
282            &canonical_sets,
283            metric,
284            k,
285            checker,
286        )
287    }
288
289    /// Score one canonical candidate set for each query vector.
290    ///
291    /// This is the batch companion to [`Self::score_vector_candidate_set`]. It
292    /// preserves the generic batch scoring contract while avoiding a second
293    /// normalization pass for callers that already hold canonical candidate
294    /// sets.
295    pub fn score_vector_candidate_sets_batch(
296        &self,
297        property: &DbString,
298        queries: &[VectorValue],
299        candidate_sets: &[VectorCandidateSet],
300        metric: VectorMetric,
301        k: usize,
302    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
303        self.score_vector_candidate_sets_batch_checked(
304            property,
305            queries,
306            candidate_sets,
307            metric,
308            k,
309            CancellationChecker::disabled(),
310        )
311        .map_err(VectorSearchError::into_graph_error)
312    }
313
314    /// Score batched canonical candidate sets with cancellation checks.
315    pub fn score_vector_candidate_sets_batch_checked(
316        &self,
317        property: &DbString,
318        queries: &[VectorValue],
319        candidate_sets: &[VectorCandidateSet],
320        metric: VectorMetric,
321        k: usize,
322        checker: CancellationChecker<'_>,
323    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
324        checker.check()?;
325        validate_batch_inputs(queries, candidate_sets.len())?;
326        if queries.is_empty() {
327            return Ok(Vec::new());
328        }
329        if k == 0 {
330            return Ok(vec![Vec::new(); queries.len()]);
331        }
332
333        let should_parallelize_batch =
334            should_parallelize_candidate_batch_scoring(candidate_sets, k);
335        if let Some(candidates) = candidate_sets.first()
336            && should_parallelize_repeated_candidate_batch(queries.len(), candidates.len(), k)
337            && candidate_sets_all_match(candidate_sets)
338        {
339            return self.score_repeated_vector_candidate_set_batch_parallel(
340                property, queries, candidates, metric, k, checker,
341            );
342        }
343
344        if should_parallelize_batch {
345            return self.score_vector_candidate_sets_batch_parallel(
346                property,
347                queries,
348                candidate_sets,
349                metric,
350                k,
351                checker,
352            );
353        }
354        if candidate_sets_all_match(candidate_sets) {
355            return self.score_repeated_vector_candidate_set_batch_serial(
356                property,
357                queries,
358                &candidate_sets[0],
359                metric,
360                k,
361                checker,
362            );
363        }
364
365        self.score_vector_candidate_sets_batch_grouped_serial(
366            property,
367            queries,
368            candidate_sets,
369            metric,
370            k,
371            checker,
372        )
373    }
374
375    /// Score vector-valued neighbors reached from one anchor through `edge_label`.
376    ///
377    /// This is the one-hop graph candidate-set companion to
378    /// [`Self::score_vector_nodes`]. It derives candidates from the snapshot's
379    /// directed adjacency, then applies the same dedupe, visibility, metric, and
380    /// ordering rules as explicit candidate scoring.
381    pub fn score_vector_neighbors(
382        &self,
383        property: &DbString,
384        query: &VectorValue,
385        anchor: NodeId,
386        options: VectorNeighborSearchOptions<'_>,
387    ) -> GraphResult<Vec<VectorNodeSearchHit>> {
388        self.score_vector_neighbors_checked(
389            property,
390            query,
391            anchor,
392            options,
393            CancellationChecker::disabled(),
394        )
395        .map_err(VectorSearchError::into_graph_error)
396    }
397
398    /// Score vector-valued neighbors with cancellation checks.
399    pub fn score_vector_neighbors_checked(
400        &self,
401        property: &DbString,
402        query: &VectorValue,
403        anchor: NodeId,
404        options: VectorNeighborSearchOptions<'_>,
405        checker: CancellationChecker<'_>,
406    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
407        checker.check()?;
408        if options.k == 0 {
409            return Ok(Vec::new());
410        }
411        let candidates =
412            self.vector_neighbor_candidates(anchor, options.edge_label, options.direction);
413        self.score_vector_candidate_set_checked(
414            property,
415            query,
416            &candidates,
417            options.metric,
418            options.k,
419            checker,
420        )
421    }
422
423    /// Score one anchor's vector-valued neighbors for each query vector.
424    ///
425    /// `queries[i]` is scored against neighbors derived from `anchors[i]`.
426    /// Mismatched query/anchor counts and mixed query dimensions are rejected
427    /// before scoring.
428    pub fn score_vector_neighbors_batch(
429        &self,
430        property: &DbString,
431        queries: &[VectorValue],
432        anchors: &[NodeId],
433        options: VectorNeighborSearchOptions<'_>,
434    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
435        self.score_vector_neighbors_batch_checked(
436            property,
437            queries,
438            anchors,
439            options,
440            CancellationChecker::disabled(),
441        )
442        .map_err(VectorSearchError::into_graph_error)
443    }
444
445    /// Score batched one-hop graph neighbors with cancellation checks.
446    pub fn score_vector_neighbors_batch_checked(
447        &self,
448        property: &DbString,
449        queries: &[VectorValue],
450        anchors: &[NodeId],
451        options: VectorNeighborSearchOptions<'_>,
452        checker: CancellationChecker<'_>,
453    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
454        checker.check()?;
455        validate_batch_inputs(queries, anchors.len())?;
456        if queries.is_empty() {
457            return Ok(Vec::new());
458        }
459        if options.k == 0 {
460            return Ok(vec![Vec::new(); queries.len()]);
461        }
462        let candidate_sets = self.vector_neighbor_candidate_sets_batch(
463            anchors,
464            options.edge_label,
465            options.direction,
466            options.k,
467            checker,
468        )?;
469        self.score_vector_candidate_sets_batch_checked(
470            property,
471            queries,
472            &candidate_sets,
473            options.metric,
474            options.k,
475            checker,
476        )
477    }
478
479    /// Expand one canonical root set per query through one graph hop, then score it.
480    ///
481    /// `queries[i]` is scored against `root_sets[i]` plus nodes reached from
482    /// those roots through `options.edge_label` in `options.direction`. This is
483    /// the batch graph-retrieval primitive for ANN-then-graph-expansion and
484    /// graph-query-then-vector-rerank workloads.
485    pub fn score_vector_expanded_candidate_sets_batch(
486        &self,
487        property: &DbString,
488        queries: &[VectorValue],
489        root_sets: &[VectorCandidateSet],
490        options: VectorNeighborSearchOptions<'_>,
491    ) -> GraphResult<Vec<Vec<VectorNodeSearchHit>>> {
492        self.score_vector_expanded_candidate_sets_batch_checked(
493            property,
494            queries,
495            root_sets,
496            options,
497            CancellationChecker::disabled(),
498        )
499        .map_err(VectorSearchError::into_graph_error)
500    }
501
502    /// Expand and score batched canonical root sets with cancellation checks.
503    pub fn score_vector_expanded_candidate_sets_batch_checked(
504        &self,
505        property: &DbString,
506        queries: &[VectorValue],
507        root_sets: &[VectorCandidateSet],
508        options: VectorNeighborSearchOptions<'_>,
509        checker: CancellationChecker<'_>,
510    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
511        checker.check()?;
512        validate_batch_inputs(queries, root_sets.len())?;
513        if queries.is_empty() {
514            return Ok(Vec::new());
515        }
516        if options.k == 0 {
517            return Ok(vec![Vec::new(); queries.len()]);
518        }
519
520        let expanded_sets = self.expand_vector_candidate_sets_batch(
521            root_sets,
522            options.edge_label,
523            options.direction,
524            options.k,
525            checker,
526        )?;
527        self.score_vector_candidate_sets_batch_checked(
528            property,
529            queries,
530            &expanded_sets,
531            options.metric,
532            options.k,
533            checker,
534        )
535    }
536
537    /// Expand one canonical root set per query through one graph hop.
538    ///
539    /// This is the reusable batch expansion primitive used by graph-expanded
540    /// vector scorers. It preserves input order, reuses duplicate root-set
541    /// expansion work within bounded batches, and checks cancellation while
542    /// deriving candidates.
543    pub fn expand_vector_candidate_sets_batch_checked(
544        &self,
545        root_sets: &[VectorCandidateSet],
546        edge_label: &DbString,
547        direction: VectorNeighborDirection,
548        k: usize,
549        checker: CancellationChecker<'_>,
550    ) -> Result<Vec<VectorCandidateSet>, VectorSearchError> {
551        self.expand_vector_candidate_sets_batch(root_sets, edge_label, direction, k, checker)
552    }
553
554    /// Return canonical vector-score candidates reached from one graph anchor.
555    ///
556    /// Candidates are filtered by edge label and direction, sorted by
557    /// [`NodeId`], and deduplicated. The returned set intentionally does not
558    /// check vector property presence or node liveness; scoring APIs apply
559    /// normal snapshot visibility when the set is consumed.
560    #[must_use]
561    pub fn vector_neighbor_candidates(
562        &self,
563        anchor: NodeId,
564        edge_label: &DbString,
565        direction: VectorNeighborDirection,
566    ) -> VectorCandidateSet {
567        let mut candidates = Vec::new();
568        if matches!(
569            direction,
570            VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
571        ) && let Some(entry) = self.outgoing_edges(anchor)
572        {
573            candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
574        }
575        if matches!(
576            direction,
577            VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
578        ) && let Some(entry) = self.incoming_edges(anchor)
579        {
580            candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
581        }
582        VectorCandidateSet::from_nodes(candidates)
583    }
584
585    /// Expand canonical candidates through one labeled graph hop.
586    ///
587    /// The returned set contains every root candidate plus neighbors reached
588    /// from those roots through `edge_label` in `direction`. This is the
589    /// production primitive behind graph-authored support/provenance expansion:
590    /// callers can build a small root set from graph queries or ANN hits, expand
591    /// through graph topology, then pass the canonical result to vector scoring.
592    #[must_use]
593    pub fn expand_vector_candidate_set(
594        &self,
595        roots: &VectorCandidateSet,
596        edge_label: &DbString,
597        direction: VectorNeighborDirection,
598    ) -> VectorCandidateSet {
599        self.expand_vector_candidate_set_checked(
600            roots,
601            edge_label,
602            direction,
603            CancellationChecker::disabled(),
604        )
605        .expect("disabled cancellation cannot fail")
606    }
607
608    /// Expand canonical candidates through one labeled graph hop with cancellation checks.
609    pub fn expand_vector_candidate_set_checked(
610        &self,
611        roots: &VectorCandidateSet,
612        edge_label: &DbString,
613        direction: VectorNeighborDirection,
614        checker: CancellationChecker<'_>,
615    ) -> Result<VectorCandidateSet, VectorSearchError> {
616        checker.check()?;
617        if roots.is_empty() {
618            return Ok(VectorCandidateSet::default());
619        }
620        let mut candidates = Vec::with_capacity(roots.len());
621        candidates.extend_from_slice(roots.as_nodes());
622        for (offset, root) in roots.as_nodes().iter().copied().enumerate() {
623            if offset % VECTOR_SEARCH_CANCEL_STRIDE == 0 {
624                checker.check()?;
625            }
626            if matches!(
627                direction,
628                VectorNeighborDirection::Outgoing | VectorNeighborDirection::Both
629            ) && let Some(entry) = self.outgoing_edges(root)
630            {
631                candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
632            }
633            if matches!(
634                direction,
635                VectorNeighborDirection::Incoming | VectorNeighborDirection::Both
636            ) && let Some(entry) = self.incoming_edges(root)
637            {
638                candidates.extend(entry.iter_label(edge_label).map(|edge| edge.neighbor));
639            }
640        }
641        Ok(VectorCandidateSet::from_nodes(candidates))
642    }
643}
644
645fn should_parallelize_candidate_scoring(candidate_count: usize, k: usize) -> bool {
646    should_parallelize_scan(
647        candidate_count as u64,
648        k,
649        VECTOR_CANDIDATE_SCORE_PARALLEL_MIN_NODES as u64,
650    )
651}
652
653fn validate_batch_inputs(
654    queries: &[VectorValue],
655    candidate_set_count: usize,
656) -> Result<(), VectorSearchError> {
657    if queries.len() != candidate_set_count {
658        return Err(VectorSearchError::BatchLengthMismatch {
659            queries: queries.len(),
660            candidate_sets: candidate_set_count,
661        });
662    }
663    let Some(first_query) = queries.first() else {
664        return Ok(());
665    };
666    let first_dimension = first_query.dimension();
667    for query in &queries[1..] {
668        if query.dimension() != first_dimension {
669            return Err(GraphError::from(CoreError::VectorDimensionMismatch {
670                lhs: first_dimension,
671                rhs: query.dimension(),
672            })
673            .into());
674        }
675    }
676    Ok(())
677}