Skip to main content

selene_graph/vector_search/
approx_turbo_quant.rs

//! TurboQuant approximate search over explicit candidate sets.
2
3use roaring::RoaringBitmap;
4use selene_core::{CancellationChecker, DbString, VectorValue};
5
6use crate::error::GraphError;
7use crate::graph::SeleneGraph;
8
9use super::{
10    ApproximateVectorSearchOptions, VECTOR_SEARCH_CANCEL_STRIDE, VectorCandidateSet,
11    VectorNodeSearchHit, VectorSearchError, approx_batch, rerank_ann_row_candidates,
12    turbo_quant_exact,
13};
14
15impl SeleneGraph {
16    /// Approximately rank a canonical node candidate set through a TurboQuant index.
17    ///
18    /// This is the approximate counterpart to
19    /// [`Self::score_vector_candidate_set_checked`]: callers supply the
20    /// candidate set explicitly, TurboQuant preselects up to `ef_search`
21    /// candidates within that set, and the graph layer exact-reranks the
22    /// returned rows against primary `VECTOR` values. Missing nodes, nodes
23    /// outside the registered `(label, property)` vector index, and nodes
24    /// without a vector value are skipped under the normal snapshot visibility
25    /// rules. When the search width covers every surviving indexed candidate
26    /// row, the compressed preselection pass is skipped and those rows are
27    /// exact-scored directly.
28    pub fn approximate_vector_search_candidate_set_checked(
29        &self,
30        label: &DbString,
31        property: &DbString,
32        query: &VectorValue,
33        candidates: &VectorCandidateSet,
34        options: ApproximateVectorSearchOptions,
35        checker: CancellationChecker<'_>,
36    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
37        checker.check()?;
38        if options.k == 0 || candidates.is_empty() {
39            return Ok(Vec::new());
40        }
41        let query_dimension = u32::try_from(query.dimension())
42            .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
43        let Some(index) = self
44            .vector_index_for(label, property)
45            .filter(|index| index.dimension() == query_dimension)
46        else {
47            return Err(VectorSearchError::ApproximateIndexMissing);
48        };
49        let Some(indexed_metric) = index.ann_metric() else {
50            return Err(VectorSearchError::ApproximateIndexMissing);
51        };
52        if indexed_metric != options.metric {
53            return Err(VectorSearchError::ApproximateMetricMismatch {
54                indexed: indexed_metric,
55                requested: options.metric,
56            });
57        }
58        if !index.is_turbo_quant() {
59            return Err(VectorSearchError::ApproximateIndexMissing);
60        }
61
62        let allowed_rows = self.vector_candidate_rows(candidates, index.rows(), &checker)?;
63        if allowed_rows.is_empty() {
64            return Ok(Vec::new());
65        }
66        if turbo_quant_exact::covers_rows(&allowed_rows, options) {
67            return rerank_ann_row_candidates(
68                self,
69                property,
70                query,
71                options.metric,
72                options.k,
73                turbo_quant_exact::row_hits(&allowed_rows),
74                &checker,
75            );
76        }
77        let row_hits = index
78            .turbo_quant_candidates_in_rows(query, options.k, options.ef_search, &allowed_rows)
79            .ok_or(VectorSearchError::ApproximateIndexMissing)?
80            .map_err(GraphError::from)?;
81        rerank_ann_row_candidates(
82            self,
83            property,
84            query,
85            options.metric,
86            options.k,
87            row_hits,
88            &checker,
89        )
90    }
91
92    /// Approximately rank one canonical node candidate set per query.
93    ///
94    /// Each `queries[i]` is searched only within `candidate_sets[i]` through a
95    /// matching TurboQuant index, then exact-reranked against primary vector
96    /// values. Output positions correspond to input query positions.
97    pub fn approximate_vector_search_candidate_sets_batch_checked(
98        &self,
99        label: &DbString,
100        property: &DbString,
101        queries: &[VectorValue],
102        candidate_sets: &[VectorCandidateSet],
103        options: ApproximateVectorSearchOptions,
104        checker: CancellationChecker<'_>,
105    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
106        checker.check()?;
107        if queries.len() != candidate_sets.len() {
108            return Err(VectorSearchError::BatchLengthMismatch {
109                queries: queries.len(),
110                candidate_sets: candidate_sets.len(),
111            });
112        }
113        let Some(first_query) = queries.first() else {
114            return Ok(Vec::new());
115        };
116        if options.k == 0 {
117            return Ok(vec![Vec::new(); queries.len()]);
118        }
119        let query_dimension = u32::try_from(first_query.dimension())
120            .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
121        let Some(index) = self
122            .vector_index_for(label, property)
123            .filter(|index| index.dimension() == query_dimension)
124        else {
125            return Err(VectorSearchError::ApproximateIndexMissing);
126        };
127        let Some(indexed_metric) = index.ann_metric() else {
128            return Err(VectorSearchError::ApproximateIndexMissing);
129        };
130        if indexed_metric != options.metric {
131            return Err(VectorSearchError::ApproximateMetricMismatch {
132                indexed: indexed_metric,
133                requested: options.metric,
134            });
135        }
136        if !index.is_turbo_quant() {
137            return Err(VectorSearchError::ApproximateIndexMissing);
138        }
139        for query in queries {
140            checker.check()?;
141            let dimension = u32::try_from(query.dimension())
142                .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
143            if dimension != query_dimension {
144                return Err(VectorSearchError::ApproximateIndexMissing);
145            }
146        }
147
148        if let Some(first_candidate_set) = candidate_sets.first()
149            && candidate_sets.iter().skip(1).all(|candidate_set| {
150                approx_batch::candidate_sets_match(first_candidate_set, candidate_set)
151            })
152        {
153            let allowed_rows =
154                self.vector_candidate_rows(first_candidate_set, index.rows(), &checker)?;
155            if turbo_quant_exact::covers_rows(&allowed_rows, options) {
156                let row_batches = vec![turbo_quant_exact::row_hits(&allowed_rows); queries.len()];
157                return approx_batch::rerank_ann_row_candidate_batches(
158                    self,
159                    property,
160                    queries,
161                    options.metric,
162                    options.k,
163                    row_batches,
164                    &checker,
165                );
166            }
167            let row_batches = index
168                .turbo_quant_candidates_batch_in_shared_rows(
169                    queries,
170                    options.k,
171                    options.ef_search,
172                    &allowed_rows,
173                )
174                .ok_or(VectorSearchError::ApproximateIndexMissing)?
175                .map_err(GraphError::from)?;
176            return approx_batch::rerank_ann_row_candidate_batches(
177                self,
178                property,
179                queries,
180                options.metric,
181                options.k,
182                row_batches,
183                &checker,
184            );
185        }
186
187        let allowed_rows = candidate_sets
188            .iter()
189            .map(|candidates| self.vector_candidate_rows(candidates, index.rows(), &checker))
190            .collect::<Result<Vec<_>, _>>()?;
191        if allowed_rows
192            .iter()
193            .all(|rows| turbo_quant_exact::covers_rows(rows, options))
194        {
195            let row_batches = allowed_rows
196                .iter()
197                .map(turbo_quant_exact::row_hits)
198                .collect::<Vec<_>>();
199            return approx_batch::rerank_ann_row_candidate_batches(
200                self,
201                property,
202                queries,
203                options.metric,
204                options.k,
205                row_batches,
206                &checker,
207            );
208        }
209        let row_batches = index
210            .turbo_quant_candidates_batch_in_rows(
211                queries,
212                options.k,
213                options.ef_search,
214                &allowed_rows,
215            )
216            .ok_or(VectorSearchError::ApproximateIndexMissing)?
217            .map_err(GraphError::from)?;
218        approx_batch::rerank_ann_row_candidate_batches(
219            self,
220            property,
221            queries,
222            options.metric,
223            options.k,
224            row_batches,
225            &checker,
226        )
227    }
228
229    fn vector_candidate_rows(
230        &self,
231        candidates: &VectorCandidateSet,
232        index_rows: &RoaringBitmap,
233        checker: &CancellationChecker<'_>,
234    ) -> Result<RoaringBitmap, VectorSearchError> {
235        let mut rows = RoaringBitmap::new();
236        let mut candidates_since_check = 0usize;
237        for node_id in candidates.as_nodes().iter().copied() {
238            candidates_since_check += 1;
239            if candidates_since_check >= VECTOR_SEARCH_CANCEL_STRIDE {
240                checker.note_nodes_scanned(candidates_since_check)?;
241                candidates_since_check = 0;
242            }
243            let Some(row) = self.row_for_node_id(node_id) else {
244                continue;
245            };
246            let raw_row = row.get();
247            if self.node_store.is_alive(raw_row) && index_rows.contains(raw_row) {
248                rows.insert(raw_row);
249            }
250        }
251        if candidates_since_check > 0 {
252            checker.note_nodes_scanned(candidates_since_check)?;
253        }
254        Ok(rows)
255    }
256}