selene_graph/vector_search/
approx_turbo_quant.rs1use roaring::RoaringBitmap;
4use selene_core::{CancellationChecker, DbString, VectorValue};
5
6use crate::error::GraphError;
7use crate::graph::SeleneGraph;
8
9use super::{
10 ApproximateVectorSearchOptions, VectorCandidateSet, VectorNodeSearchHit, VectorSearchError,
11 approx_batch, rerank_ann_row_candidates, turbo_quant_exact,
12};
13
14impl SeleneGraph {
15 pub fn approximate_vector_search_candidate_set_checked(
28 &self,
29 label: &DbString,
30 property: &DbString,
31 query: &VectorValue,
32 candidates: &VectorCandidateSet,
33 options: ApproximateVectorSearchOptions,
34 checker: CancellationChecker<'_>,
35 ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
36 checker.check()?;
37 if options.k == 0 || candidates.is_empty() {
38 return Ok(Vec::new());
39 }
40 let query_dimension = u32::try_from(query.dimension())
41 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
42 let Some(index) = self
43 .vector_index_for(label, property)
44 .filter(|index| index.dimension() == query_dimension)
45 else {
46 return Err(VectorSearchError::ApproximateIndexMissing);
47 };
48 let Some(indexed_metric) = index.ann_metric() else {
49 return Err(VectorSearchError::ApproximateIndexMissing);
50 };
51 if indexed_metric != options.metric {
52 return Err(VectorSearchError::ApproximateMetricMismatch {
53 indexed: indexed_metric,
54 requested: options.metric,
55 });
56 }
57 if !index.is_turbo_quant() {
58 return Err(VectorSearchError::ApproximateIndexMissing);
59 }
60
61 let allowed_rows = self.vector_candidate_rows(candidates, index.rows(), &checker)?;
62 if allowed_rows.is_empty() {
63 return Ok(Vec::new());
64 }
65 if turbo_quant_exact::covers_rows(&allowed_rows, options) {
66 return rerank_ann_row_candidates(
67 self,
68 property,
69 query,
70 options.metric,
71 options.k,
72 turbo_quant_exact::row_hits(&allowed_rows),
73 &checker,
74 );
75 }
76 let row_hits = index
77 .turbo_quant_candidates_in_rows(query, options.k, options.ef_search, &allowed_rows)
78 .ok_or(VectorSearchError::ApproximateIndexMissing)?
79 .map_err(GraphError::from)?;
80 rerank_ann_row_candidates(
81 self,
82 property,
83 query,
84 options.metric,
85 options.k,
86 row_hits,
87 &checker,
88 )
89 }
90
91 pub fn approximate_vector_search_candidate_sets_batch_checked(
97 &self,
98 label: &DbString,
99 property: &DbString,
100 queries: &[VectorValue],
101 candidate_sets: &[VectorCandidateSet],
102 options: ApproximateVectorSearchOptions,
103 checker: CancellationChecker<'_>,
104 ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
105 checker.check()?;
106 if queries.len() != candidate_sets.len() {
107 return Err(VectorSearchError::BatchLengthMismatch {
108 queries: queries.len(),
109 candidate_sets: candidate_sets.len(),
110 });
111 }
112 let Some(first_query) = queries.first() else {
113 return Ok(Vec::new());
114 };
115 if options.k == 0 {
116 return Ok(vec![Vec::new(); queries.len()]);
117 }
118 let query_dimension = u32::try_from(first_query.dimension())
119 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
120 let Some(index) = self
121 .vector_index_for(label, property)
122 .filter(|index| index.dimension() == query_dimension)
123 else {
124 return Err(VectorSearchError::ApproximateIndexMissing);
125 };
126 let Some(indexed_metric) = index.ann_metric() else {
127 return Err(VectorSearchError::ApproximateIndexMissing);
128 };
129 if indexed_metric != options.metric {
130 return Err(VectorSearchError::ApproximateMetricMismatch {
131 indexed: indexed_metric,
132 requested: options.metric,
133 });
134 }
135 if !index.is_turbo_quant() {
136 return Err(VectorSearchError::ApproximateIndexMissing);
137 }
138 for query in queries {
139 checker.check()?;
140 let dimension = u32::try_from(query.dimension())
141 .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
142 if dimension != query_dimension {
143 return Err(VectorSearchError::ApproximateIndexMissing);
144 }
145 }
146
147 if let Some(first_candidate_set) = candidate_sets.first()
148 && candidate_sets.iter().skip(1).all(|candidate_set| {
149 approx_batch::candidate_sets_match(first_candidate_set, candidate_set)
150 })
151 {
152 let allowed_rows =
153 self.vector_candidate_rows(first_candidate_set, index.rows(), &checker)?;
154 if turbo_quant_exact::covers_rows(&allowed_rows, options) {
155 let row_batches = vec![turbo_quant_exact::row_hits(&allowed_rows); queries.len()];
156 return approx_batch::rerank_ann_row_candidate_batches(
157 self,
158 property,
159 queries,
160 options.metric,
161 options.k,
162 row_batches,
163 &checker,
164 );
165 }
166 let row_batches = index
167 .turbo_quant_candidates_batch_in_shared_rows(
168 queries,
169 options.k,
170 options.ef_search,
171 &allowed_rows,
172 )
173 .ok_or(VectorSearchError::ApproximateIndexMissing)?
174 .map_err(GraphError::from)?;
175 return approx_batch::rerank_ann_row_candidate_batches(
176 self,
177 property,
178 queries,
179 options.metric,
180 options.k,
181 row_batches,
182 &checker,
183 );
184 }
185
186 let allowed_rows = candidate_sets
187 .iter()
188 .map(|candidates| self.vector_candidate_rows(candidates, index.rows(), &checker))
189 .collect::<Result<Vec<_>, _>>()?;
190 if allowed_rows
191 .iter()
192 .all(|rows| turbo_quant_exact::covers_rows(rows, options))
193 {
194 let row_batches = allowed_rows
195 .iter()
196 .map(turbo_quant_exact::row_hits)
197 .collect::<Vec<_>>();
198 return approx_batch::rerank_ann_row_candidate_batches(
199 self,
200 property,
201 queries,
202 options.metric,
203 options.k,
204 row_batches,
205 &checker,
206 );
207 }
208 let row_batches = index
209 .turbo_quant_candidates_batch_in_rows(
210 queries,
211 options.k,
212 options.ef_search,
213 &allowed_rows,
214 )
215 .ok_or(VectorSearchError::ApproximateIndexMissing)?
216 .map_err(GraphError::from)?;
217 approx_batch::rerank_ann_row_candidate_batches(
218 self,
219 property,
220 queries,
221 options.metric,
222 options.k,
223 row_batches,
224 &checker,
225 )
226 }
227
228 fn vector_candidate_rows(
229 &self,
230 candidates: &VectorCandidateSet,
231 index_rows: &RoaringBitmap,
232 checker: &CancellationChecker<'_>,
233 ) -> Result<RoaringBitmap, VectorSearchError> {
234 let mut rows = RoaringBitmap::new();
235 for node_id in candidates.as_nodes().iter().copied() {
236 checker.check()?;
237 let Some(row) = self.row_for_node_id(node_id) else {
238 continue;
239 };
240 let raw_row = row.get();
241 if self.node_store.is_alive(raw_row) && index_rows.contains(raw_row) {
242 rows.insert(raw_row);
243 }
244 }
245 Ok(rows)
246 }
247}