selene_graph/vector_search/
exact_batch.rs1use roaring::RoaringBitmap;
2use selene_core::{
3 CancellationChecker, CoreError, DbString, NodeId, Value, VectorMetric, VectorMetricQuery,
4 VectorTopK, VectorValue,
5};
6
7use super::{
8 VECTOR_SEARCH_CANCEL_STRIDE, VECTOR_SEARCH_PARALLEL_CHUNK_ROWS, VectorNodeSearchHit,
9 VectorSearchError, should_parallelize_exact_scan, vector_node_hits,
10};
11use crate::error::GraphError;
12use crate::graph::SeleneGraph;
13use crate::parallel_scan::try_reduce_bitmap_chunks;
14use crate::store::RowIndex;
15
16impl SeleneGraph {
17 pub fn exact_vector_search_nodes_batch_checked(
24 &self,
25 label: &DbString,
26 property: &DbString,
27 queries: &[VectorValue],
28 metric: VectorMetric,
29 k: usize,
30 checker: CancellationChecker<'_>,
31 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
32 checker.check()?;
33 let Some(first_query) = queries.first() else {
34 return Ok(Vec::new());
35 };
36 let first_dimension = first_query.dimension();
37 for query in &queries[1..] {
38 if query.dimension() != first_dimension {
39 return Err(GraphError::from(CoreError::VectorDimensionMismatch {
40 lhs: first_dimension,
41 rhs: query.dimension(),
42 })
43 .into());
44 }
45 }
46 if k == 0 {
47 return Ok(vec![Vec::new(); queries.len()]);
48 }
49 let Some(label_rows) = self.nodes_with_label(label) else {
50 return Ok(vec![Vec::new(); queries.len()]);
51 };
52
53 let query_dimension = u32::try_from(first_dimension).ok();
54 let vector_index = query_dimension.and_then(|dimension| {
55 self.vector_index_for(label, property)
56 .filter(|index| index.dimension() == dimension)
57 });
58 let rows = vector_index
59 .as_ref()
60 .map_or(label_rows, |index| index.rows());
61 let scorers: Result<Vec<_>, GraphError> = queries
62 .iter()
63 .map(|query| metric.bind_query(query).map_err(GraphError::from))
64 .collect();
65 let scorers = scorers?;
66 if should_parallelize_exact_scan(rows, k) {
67 return self
68 .exact_vector_search_batch_parallel(label, property, &scorers, k, rows, checker);
69 }
70
71 let top_ks =
72 self.exact_vector_search_batch_serial(label, property, &scorers, k, rows, checker)?;
73 Ok(top_ks.into_iter().map(vector_node_hits).collect())
74 }
75
76 fn exact_vector_search_batch_parallel(
77 &self,
78 label: &DbString,
79 property: &DbString,
80 scorers: &[VectorMetricQuery<'_>],
81 k: usize,
82 rows: &RoaringBitmap,
83 checker: CancellationChecker<'_>,
84 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
85 let top_ks = try_reduce_bitmap_chunks(
86 rows,
87 VECTOR_SEARCH_PARALLEL_CHUNK_ROWS,
88 checker,
89 || new_batch_top_ks(scorers.len(), k),
90 |chunk| self.exact_vector_search_batch_chunk(label, property, scorers, k, chunk),
91 merge_batch_top_ks,
92 )?;
93
94 Ok(top_ks.into_iter().map(vector_node_hits).collect())
95 }
96
97 fn exact_vector_search_batch_serial(
98 &self,
99 label: &DbString,
100 property: &DbString,
101 scorers: &[VectorMetricQuery<'_>],
102 k: usize,
103 rows: &RoaringBitmap,
104 checker: CancellationChecker<'_>,
105 ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
106 let mut top_ks = new_batch_top_ks(scorers.len(), k);
107 let mut rows_since_check = 0usize;
108 for raw_row in rows.iter() {
109 rows_since_check += 1;
110 if rows_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
111 checker.note_nodes_scanned(rows_since_check)?;
112 rows_since_check = 0;
113 }
114 self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
115 }
116 if rows_since_check > 0 {
117 checker.note_nodes_scanned(rows_since_check)?;
118 }
119 Ok(top_ks)
120 }
121
122 fn exact_vector_search_batch_chunk(
123 &self,
124 label: &DbString,
125 property: &DbString,
126 scorers: &[VectorMetricQuery<'_>],
127 k: usize,
128 rows: &[u32],
129 ) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
130 let mut top_ks = new_batch_top_ks(scorers.len(), k);
131 for &raw_row in rows {
132 self.push_batch_row(label, property, scorers, &mut top_ks, raw_row)?;
133 }
134 Ok(top_ks)
135 }
136
137 fn push_batch_row(
138 &self,
139 label: &DbString,
140 property: &DbString,
141 scorers: &[VectorMetricQuery<'_>],
142 top_ks: &mut [VectorTopK<NodeId>],
143 raw_row: u32,
144 ) -> Result<(), VectorSearchError> {
145 if !self.node_store.is_alive(raw_row) {
146 return Ok(());
147 }
148 let row = RowIndex::new(raw_row);
149 let node_id = self
150 .node_id_for_row(row)
151 .ok_or_else(|| GraphError::Inconsistent {
152 reason: format!(
153 "vector search row {raw_row} for {} has no node id",
154 label.as_str()
155 ),
156 })?;
157 let properties = self
158 .node_store
159 .properties
160 .get(raw_row as usize)
161 .ok_or_else(|| GraphError::Inconsistent {
162 reason: format!(
163 "vector search row {raw_row} for {} has no property row",
164 label.as_str()
165 ),
166 })?;
167 let Some(Value::Vector(vector)) = properties.get(property) else {
168 return Ok(());
169 };
170 for (scorer, top_k) in scorers.iter().zip(top_ks) {
171 let distance = scorer.distance(vector).map_err(GraphError::from)?;
172 top_k.push_distance(node_id, distance);
173 }
174 Ok(())
175 }
176}
177
178fn new_batch_top_ks(query_count: usize, k: usize) -> Vec<VectorTopK<NodeId>> {
179 (0..query_count).map(|_| VectorTopK::new(k)).collect()
180}
181
182fn merge_batch_top_ks(
183 mut lhs: Vec<VectorTopK<NodeId>>,
184 rhs: Vec<VectorTopK<NodeId>>,
185) -> Result<Vec<VectorTopK<NodeId>>, VectorSearchError> {
186 debug_assert_eq!(lhs.len(), rhs.len());
187 for (lhs_top_k, rhs_top_k) in lhs.iter_mut().zip(rhs) {
188 for hit in rhs_top_k.into_hits() {
189 lhs_top_k.push_distance(hit.key, hit.distance);
190 }
191 }
192 Ok(lhs)
193}