Skip to main content

selene_graph/vector_search/
exact_batch.rs

1use roaring::RoaringBitmap;
2use selene_core::{
3    CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
4    VectorTopK, VectorValue,
5};
6
7use super::{
8    VECTOR_SEARCH_CANCEL_STRIDE, VECTOR_SEARCH_PARALLEL_CHUNK_ROWS, VectorNodeSearchHit,
9    VectorSearchError, should_parallelize_exact_scan, vector_node_hits,
10};
11use crate::error::GraphError;
12use crate::graph::SeleneGraph;
13use crate::parallel_scan::try_reduce_bitmap_chunks;
14use crate::store::RowIndex;
15
16impl SeleneGraph {
17    /// Exhaustively rank vector-valued node properties for a batch of queries.
18    ///
19    /// The output position corresponds to the input query position. This keeps
20    /// the exact single-query semantics but resolves the row set once and scans
21    /// candidates once, which is useful for agent-memory workloads that probe
22    /// several embeddings over the same `(label, property)` surface.
23    pub fn exact_vector_search_nodes_batch_checked(
24        &self,
25        label: &DbString,
26        property: &DbString,
27        queries: &[VectorValue],
28        metric: VectorMetric,
29        k: usize,
30        checker: CancellationChecker<'_>,
31    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
32        checker.check()?;
33        let Some(first_query) = queries.first() else {
34            return Ok(Vec::new());
35        };
36        let first_dimension = first_query.dimension();
37        for query in &queries[1..] {
38            if query.dimension() != first_dimension {
39                return Err(GraphError::from(CoreError::VectorDimensionMismatch {
40                    lhs: first_dimension,
41                    rhs: query.dimension(),
42                })
43                .into());
44            }
45        }
46        if k == 0 {
47            return Ok(vec![Vec::new(); queries.len()]);
48        }
49        let Some(label_rows) = self.nodes_with_label(label) else {
50            return Ok(vec![Vec::new(); queries.len()]);
51        };
52
53        let query_dimension = u32::try_from(first_dimension).ok();
54        let vector_index = query_dimension.and_then(|dimension| {
55            self.vector_index_for(label, property)
56                .filter(|index| index.dimension() == dimension)
57        });
58        let rows = vector_index
59            .as_ref()
60            .map_or(label_rows, |index| index.rows());
61        let scorers: Result<Vec<_>, GraphError> = queries
62            .iter()
63            .map(|query| metric.bind_query(query).map_err(GraphError::from))
64            .collect();
65        let scorers = scorers?;
66        if should_parallelize_exact_scan(rows, k) {
67            return self
68                .exact_vector_search_batch_parallel(label, property, &scorers, k, rows, checker);
69        }
70
71        let top_ks =
72            self.exact_vector_search_batch_serial(label, property, &scorers, k, rows, checker)?;
73        Ok(top_ks.into_iter().map(vector_node_hits).collect())
74    }
75
76    fn exact_vector_search_batch_parallel(
77        &self,
78        label: &DbString,
79        property: &DbString,
80        scorers: &[VectorMetricQuery<'_>],
81        k: usize,
82        rows: &RoaringBitmap,
83        checker: CancellationChecker<'_>,
84    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
85        let top_ks = try_reduce_bitmap_chunks(
86            rows,
87            VECTOR_SEARCH_PARALLEL_CHUNK_ROWS,
88            checker,
89            || new_batch_top_ks(scorers.len(), k),
90            |chunk| self.exact_vector_search_batch_chunk(label, property, scorers, k, chunk),
91            merge_batch_top_ks,
92        )?;
93
94        Ok(top_ks.into_iter().map(vector_node_hits).collect())
95    }
96
97    fn exact_vector_search_batch_serial(
98        &self,
99        label: &DbString,
100        property: &DbString,
101        scorers: &[VectorMetricQuery<'_>],
102        k: usize,
103        rows: &RoaringBitmap,
104        checker: CancellationChecker<'_>,
105    ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
106        let mut top_ks = new_batch_top_ks(scorers.len(), k);
107        let mut rows_since_check = 0usize;
108        for raw_row in rows.iter() {
109            rows_since_check += 1;
110            if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
111                checker.check()?;
112                rows_since_check = 0;
113            }
114            self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
115        }
116        Ok(top_ks)
117    }
118
119    fn exact_vector_search_batch_chunk(
120        &self,
121        label: &DbString,
122        property: &DbString,
123        scorers: &[VectorMetricQuery<'_>],
124        k: usize,
125        rows: &[u32],
126    ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
127        let mut top_ks = new_batch_top_ks(scorers.len(), k);
128        for &raw_row in rows {
129            self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
130        }
131        Ok(top_ks)
132    }
133
134    fn push_batch_row(
135        &self,
136        label: &DbString,
137        property: &DbString,
138        scorers: &[VectorMetricQuery<'_>],
139        top_ks: &mut [VectorTopK<NodeId>],
140        raw_row: u32,
141    ) -> Result<(), VectorSearchError> {
142        if !self.node_store.is_alive(raw_row) {
143            return Ok(());
144        }
145        let row = RowIndex::new(raw_row);
146        let node_id = self
147            .node_id_for_row(row)
148            .ok_or_else(|| GraphError::Inconsistent {
149                reason: format!(
150                    "vector search row {raw_row} for {} has no node id",
151                    label.as_str()
152                ),
153            })?;
154        let properties = self
155            .node_store
156            .properties
157            .get(raw_row as usize)
158            .ok_or_else(|| GraphError::Inconsistent {
159                reason: format!(
160                    "vector search row {raw_row} for {} has no property row",
161                    label.as_str()
162                ),
163            })?;
164        let Some(Value::Vector(vector)) = properties.get(property) else {
165            return Ok(());
166        };
167        for (scorer, top_k) in scorers.iter().zip(top_ks) {
168            let distance = scorer.distance(vector).map_err(GraphError::from)?;
169            top_k.push_distance(node_id, distance);
170        }
171        Ok(())
172    }
173}
174
175fn new_batch_top_ks(query_count: usize, k: usize) -> Vec<VectorTopK<NodeId>> {
176    (0..query_count).map(|_| VectorTopK::new(k)).collect()
177}
178
179fn merge_batch_top_ks(
180    mut lhs: Vec<VectorTopK<NodeId>>,
181    rhs: Vec<VectorTopK<NodeId>>,
182) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
183    debug_assert_eq!(lhs.len(), rhs.len());
184    for (lhs_top_k, rhs_top_k) in lhs.iter_mut().zip(rhs) {
185        for hit in rhs_top_k.into_hits() {
186            lhs_top_k.push_distance(hit.key, hit.distance);
187        }
188    }
189    Ok(lhs)
190}