Skip to main content

selene_graph/vector_search/
exact_batch.rs

1use roaring::RoaringBitmap;
2use selene_core::{
3    CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
4    VectorTopK, VectorValue,
5};
6
7use super::{
8    VECTOR_SEARCH_CANCEL_STRIDE, VECTOR_SEARCH_PARALLEL_CHUNK_ROWS, VectorNodeSearchHit,
9    VectorSearchError, should_parallelize_exact_scan, vector_node_hits,
10};
11use crate::error::GraphError;
12use crate::graph::SeleneGraph;
13use crate::parallel_scan::try_reduce_bitmap_chunks;
14use crate::store::RowIndex;
15
16impl SeleneGraph {
17    /// Exhaustively rank vector-valued node properties for a batch of queries.
18    ///
19    /// The output position corresponds to the input query position. This keeps
20    /// the exact single-query semantics but resolves the row set once and scans
21    /// candidates once, which is useful for agent-memory workloads that probe
22    /// several embeddings over the same `(label, property)` surface.
23    pub fn exact_vector_search_nodes_batch_checked(
24        &self,
25        label: &DbString,
26        property: &DbString,
27        queries: &[VectorValue],
28        metric: VectorMetric,
29        k: usize,
30        checker: CancellationChecker<'_>,
31    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
32        checker.check()?;
33        let Some(first_query) = queries.first() else {
34            return Ok(Vec::new());
35        };
36        let first_dimension = first_query.dimension();
37        for query in &queries[1..] {
38            if query.dimension() != first_dimension {
39                return Err(GraphError::from(CoreError::VectorDimensionMismatch {
40                    lhs: first_dimension,
41                    rhs: query.dimension(),
42                })
43                .into());
44            }
45        }
46        if k == 0 {
47            return Ok(vec![Vec::new(); queries.len()]);
48        }
49        let Some(label_rows) = self.nodes_with_label(label) else {
50            return Ok(vec![Vec::new(); queries.len()]);
51        };
52
53        let query_dimension = u32::try_from(first_dimension).ok();
54        let vector_index = query_dimension.and_then(|dimension| {
55            self.vector_index_for(label, property)
56                .filter(|index| index.dimension() == dimension)
57        });
58        let rows = vector_index
59            .as_ref()
60            .map_or(label_rows, |index| index.rows());
61        let scorers: Result<Vec<_>, GraphError> = queries
62            .iter()
63            .map(|query| metric.bind_query(query).map_err(GraphError::from))
64            .collect();
65        let scorers = scorers?;
66        if should_parallelize_exact_scan(rows, k) {
67            return self
68                .exact_vector_search_batch_parallel(label, property, &scorers, k, rows, checker);
69        }
70
71        let top_ks =
72            self.exact_vector_search_batch_serial(label, property, &scorers, k, rows, checker)?;
73        Ok(top_ks.into_iter().map(vector_node_hits).collect())
74    }
75
76    fn exact_vector_search_batch_parallel(
77        &self,
78        label: &DbString,
79        property: &DbString,
80        scorers: &[VectorMetricQuery<'_>],
81        k: usize,
82        rows: &RoaringBitmap,
83        checker: CancellationChecker<'_>,
84    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
85        let top_ks = try_reduce_bitmap_chunks(
86            rows,
87            VECTOR_SEARCH_PARALLEL_CHUNK_ROWS,
88            checker,
89            || new_batch_top_ks(scorers.len(), k),
90            |chunk| self.exact_vector_search_batch_chunk(label, property, scorers, k, chunk),
91            merge_batch_top_ks,
92        )?;
93
94        Ok(top_ks.into_iter().map(vector_node_hits).collect())
95    }
96
97    fn exact_vector_search_batch_serial(
98        &self,
99        label: &DbString,
100        property: &DbString,
101        scorers: &[VectorMetricQuery<'_>],
102        k: usize,
103        rows: &RoaringBitmap,
104        checker: CancellationChecker<'_>,
105    ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
106        let mut top_ks = new_batch_top_ks(scorers.len(), k);
107        let mut rows_since_check = 0usize;
108        for raw_row in rows.iter() {
109            rows_since_check += 1;
110            if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
111                checker.note_nodes_scanned(rows_since_check)?;
112                rows_since_check = 0;
113            }
114            self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
115        }
116        if rows_since_check > 0 {
117            checker.note_nodes_scanned(rows_since_check)?;
118        }
119        Ok(top_ks)
120    }
121
122    fn exact_vector_search_batch_chunk(
123        &self,
124        label: &DbString,
125        property: &DbString,
126        scorers: &[VectorMetricQuery<'_>],
127        k: usize,
128        rows: &[u32],
129    ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
130        let mut top_ks = new_batch_top_ks(scorers.len(), k);
131        for &raw_row in rows {
132            self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
133        }
134        Ok(top_ks)
135    }
136
137    fn push_batch_row(
138        &self,
139        label: &DbString,
140        property: &DbString,
141        scorers: &[VectorMetricQuery<'_>],
142        top_ks: &mut [VectorTopK<NodeId>],
143        raw_row: u32,
144    ) -> Result<(), VectorSearchError> {
145        if !self.node_store.is_alive(raw_row) {
146            return Ok(());
147        }
148        let row = RowIndex::new(raw_row);
149        let node_id = self
150            .node_id_for_row(row)
151            .ok_or_else(|| GraphError::Inconsistent {
152                reason: format!(
153                    "vector search row {raw_row} for {} has no node id",
154                    label.as_str()
155                ),
156            })?;
157        let properties = self
158            .node_store
159            .properties
160            .get(raw_row as usize)
161            .ok_or_else(|| GraphError::Inconsistent {
162                reason: format!(
163                    "vector search row {raw_row} for {} has no property row",
164                    label.as_str()
165                ),
166            })?;
167        let Some(Value::Vector(vector)) = properties.get(property) else {
168            return Ok(());
169        };
170        for (scorer, top_k) in scorers.iter().zip(top_ks) {
171            let distance = scorer.distance(vector).map_err(GraphError::from)?;
172            top_k.push_distance(node_id, distance);
173        }
174        Ok(())
175    }
176}
177
178fn new_batch_top_ks(query_count: usize, k: usize) -> Vec<VectorTopK<NodeId>> {
179    (0..query_count).map(|_| VectorTopK::new(k)).collect()
180}
181
182fn merge_batch_top_ks(
183    mut lhs: Vec<VectorTopK<NodeId>>,
184    rhs: Vec<VectorTopK<NodeId>>,
185) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
186    debug_assert_eq!(lhs.len(), rhs.len());
187    for (lhs_top_k, rhs_top_k) in lhs.iter_mut().zip(rhs) {
188        for hit in rhs_top_k.into_hits() {
189            lhs_top_k.push_distance(hit.key, hit.distance);
190        }
191    }
192    Ok(lhs)
193}