selene_graph/vector_search/
approx_turbo_quant.rs1use roaring::RoaringBitmap;
4use selene_core::{CancellationChecker, DbString, VectorValue};
5
6use crate::error::GraphError;
7use crate::graph::SeleneGraph;
8
9use super::{
10 ApproximateVectorSearchOptions, VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet,
11 VectorNodeSearchHit, VectorSearchError, approx_batch, rerank_ann_row_candidates,
12 turbo_quant_exact,
13};
14
15impl SeleneGraph {
16 pub fn approximate_vector_search_candidate_set_checked(
29 &self,
30 label: &DbString,
31 property: &DbString,
32 query: &VectorValue,
33 candidates: &VectorCandidateSet,
34 options: ApproximateVectorSearchOptions,
35 checker: CancellationChecker<'_>,
36 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
37 checker.check()?;
38 if options.k == 0 || candidates.is_empty() {
39 return Ok(Vec::new());
40 }
41 let query_dimension = u32::try_from(query.dimension())
42 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
43 let Some(index) = self
44 .vector_index_for(label, property)
45 .filter(|index| index.dimension() == query_dimension)
46 else {
47 return Err(VectorSearchError::ApproximateIndexMissing);
48 };
49 let Some(indexed_metric) = index.ann_metric() else {
50 return Err(VectorSearchError::ApproximateIndexMissing);
51 };
52 if indexed_metric != options.metric {
53 return Err(VectorSearchError::ApproximateMetricMismatch {
54 indexed: indexed_metric,
55 requested: options.metric,
56 });
57 }
58 if !index.is_turbo_quant() {
59 return Err(VectorSearchError::ApproximateIndexMissing);
60 }
61
62 let allowed_rows = self.vector_candidate_rows(candidates, index.rows(), &checker)?;
63 if allowed_rows.is_empty() {
64 return Ok(Vec::new());
65 }
66 if turbo_quant_exact::covers_rows(&allowed_rows, options) {
67 return rerank_ann_row_candidates(
68 self,
69 property,
70 query,
71 options.metric,
72 options.k,
73 turbo_quant_exact::row_hits(&allowed_rows),
74 &checker,
75 );
76 }
77 let row_hits = index
78 .turbo_quant_candidates_in_rows(query, options.k, options.ef_search, &allowed_rows)
79 .ok_or(VectorSearchError::ApproximateIndexMissing)?
80 .map_err(GraphError::from)?;
81 rerank_ann_row_candidates(
82 self,
83 property,
84 query,
85 options.metric,
86 options.k,
87 row_hits,
88 &checker,
89 )
90 }
91
92 pub fn approximate_vector_search_candidate_sets_batch_checked(
98 &self,
99 label: &DbString,
100 property: &DbString,
101 queries: &[VectorValue],
102 candidate_sets: &[VectorCandidateSet],
103 options: ApproximateVectorSearchOptions,
104 checker: CancellationChecker<'_>,
105 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
106 checker.check()?;
107 if queries.len() != candidate_sets.len() {
108 return Err(VectorSearchError::BatchLengthMismatch {
109 queries: queries.len(),
110 candidate_sets: candidate_sets.len(),
111 });
112 }
113 let Some(first_query) = queries.first() else {
114 return Ok(Vec::new());
115 };
116 if options.k == 0 {
117 return Ok(vec![Vec::new(); queries.len()]);
118 }
119 let query_dimension = u32::try_from(first_query.dimension())
120 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
121 let Some(index) = self
122 .vector_index_for(label, property)
123 .filter(|index| index.dimension() == query_dimension)
124 else {
125 return Err(VectorSearchError::ApproximateIndexMissing);
126 };
127 let Some(indexed_metric) = index.ann_metric() else {
128 return Err(VectorSearchError::ApproximateIndexMissing);
129 };
130 if indexed_metric != options.metric {
131 return Err(VectorSearchError::ApproximateMetricMismatch {
132 indexed: indexed_metric,
133 requested: options.metric,
134 });
135 }
136 if !index.is_turbo_quant() {
137 return Err(VectorSearchError::ApproximateIndexMissing);
138 }
139 for query in queries {
140 checker.check()?;
141 let dimension = u32::try_from(query.dimension())
142 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
143 if dimension != query_dimension {
144 return Err(VectorSearchError::ApproximateIndexMissing);
145 }
146 }
147
148 if let Some(first_candidate_set) = candidate_sets.first()
149 && candidate_sets.iter().skip(1).all(|candidate_set| {
150 approx_batch::candidate_sets_match(first_candidate_set, candidate_set)
151 })
152 {
153 let allowed_rows =
154 self.vector_candidate_rows(first_candidate_set, index.rows(), &checker)?;
155 if turbo_quant_exact::covers_rows(&allowed_rows, options) {
156 let row_batches = vec![turbo_quant_exact::row_hits(&allowed_rows); queries.len()];
157 return approx_batch::rerank_ann_row_candidate_batches(
158 self,
159 property,
160 queries,
161 options.metric,
162 options.k,
163 row_batches,
164 &checker,
165 );
166 }
167 let row_batches = index
168 .turbo_quant_candidates_batch_in_shared_rows(
169 queries,
170 options.k,
171 options.ef_search,
172 &allowed_rows,
173 )
174 .ok_or(VectorSearchError::ApproximateIndexMissing)?
175 .map_err(GraphError::from)?;
176 return approx_batch::rerank_ann_row_candidate_batches(
177 self,
178 property,
179 queries,
180 options.metric,
181 options.k,
182 row_batches,
183 &checker,
184 );
185 }
186
187 let allowed_rows = candidate_sets
188 .iter()
189 .map(|candidates| self.vector_candidate_rows(candidates, index.rows(), &checker))
190 .collect::<Result<Vec<_>, _>>()?;
191 if allowed_rows
192 .iter()
193 .all(|rows| turbo_quant_exact::covers_rows(rows, options))
194 {
195 let row_batches = allowed_rows
196 .iter()
197 .map(turbo_quant_exact::row_hits)
198 .collect::<Vec<_>>();
199 return approx_batch::rerank_ann_row_candidate_batches(
200 self,
201 property,
202 queries,
203 options.metric,
204 options.k,
205 row_batches,
206 &checker,
207 );
208 }
209 let row_batches = index
210 .turbo_quant_candidates_batch_in_rows(
211 queries,
212 options.k,
213 options.ef_search,
214 &allowed_rows,
215 )
216 .ok_or(VectorSearchError::ApproximateIndexMissing)?
217 .map_err(GraphError::from)?;
218 approx_batch::rerank_ann_row_candidate_batches(
219 self,
220 property,
221 queries,
222 options.metric,
223 options.k,
224 row_batches,
225 &checker,
226 )
227 }
228
229 fn vector_candidate_rows(
230 &self,
231 candidates: &VectorCandidateSet,
232 index_rows: &RoaringBitmap,
233 checker: &CancellationChecker<'_>,
234 ) -> Result<RoaringBitmap, VectorSearchError> {
235 let mut rows = RoaringBitmap::new();
236 let mut candidates_since_check = 0usize;
237 for node_id in candidates.as_nodes().iter().copied() {
238 candidates_since_check += 1;
239 if candidates_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
240 checker.note_nodes_scanned(candidates_since_check)?;
241 candidates_since_check = 0;
242 }
243 let Some(row) = self.row_for_node_id(node_id) else {
244 continue;
245 };
246 let raw_row = row.get();
247 if self.node_store.is_alive(raw_row) && index_rows.contains(raw_row) {
248 rows.insert(raw_row);
249 }
250 }
251 if candidates_since_check > 0 {
252 checker.note_nodes_scanned(candidates_since_check)?;
253 }
254 Ok(rows)
255 }
256}