Skip to main content

selene_graph/vector_search/
approx_turbo_quant.rs

1//! TurboQuant approximate search over explicit candidate sets.
2
3use roaring::RoaringBitmap;
4use selene_core::{CancellationChecker, DbString, VectorValue};
5
6use crate::error::GraphError;
7use crate::graph::SeleneGraph;
8
9use super::{
10    ApproximateVectorSearchOptions, VectorCandidateSet, VectorNodeSearchHit, VectorSearchError,
11    approx_batch, rerank_ann_row_candidates, turbo_quant_exact,
12};
13
14impl SeleneGraph {
15    /// Approximately rank a canonical node candidate set through a TurboQuant index.
16    ///
17    /// This is the approximate counterpart to
18    /// [`Self::score_vector_candidate_set_checked`]: callers supply the
19    /// candidate set explicitly, TurboQuant preselects up to `ef_search`
20    /// candidates within that set, and the graph layer exact-reranks the
21    /// returned rows against primary `VECTOR` values. Missing nodes, nodes
22    /// outside the registered `(label, property)` vector index, and nodes
23    /// without a vector value are skipped under the normal snapshot visibility
24    /// rules. When the search width covers every surviving indexed candidate
25    /// row, the compressed preselection pass is skipped and those rows are
26    /// exact-scored directly.
27    pub fn approximate_vector_search_candidate_set_checked(
28        &self,
29        label: &DbString,
30        property: &DbString,
31        query: &VectorValue,
32        candidates: &VectorCandidateSet,
33        options: ApproximateVectorSearchOptions,
34        checker: CancellationChecker<'_>,
35    ) -> Result<Vec<VectorNodeSearchHit>, VectorSearchError> {
36        checker.check()?;
37        if options.k == 0 || candidates.is_empty() {
38            return Ok(Vec::new());
39        }
40        let query_dimension = u32::try_from(query.dimension())
41            .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
42        let Some(index) = self
43            .vector_index_for(label, property)
44            .filter(|index| index.dimension() == query_dimension)
45        else {
46            return Err(VectorSearchError::ApproximateIndexMissing);
47        };
48        let Some(indexed_metric) = index.ann_metric() else {
49            return Err(VectorSearchError::ApproximateIndexMissing);
50        };
51        if indexed_metric != options.metric {
52            return Err(VectorSearchError::ApproximateMetricMismatch {
53                indexed: indexed_metric,
54                requested: options.metric,
55            });
56        }
57        if !index.is_turbo_quant() {
58            return Err(VectorSearchError::ApproximateIndexMissing);
59        }
60
61        let allowed_rows = self.vector_candidate_rows(candidates, index.rows(), &checker)?;
62        if allowed_rows.is_empty() {
63            return Ok(Vec::new());
64        }
65        if turbo_quant_exact::covers_rows(&allowed_rows, options) {
66            return rerank_ann_row_candidates(
67                self,
68                property,
69                query,
70                options.metric,
71                options.k,
72                turbo_quant_exact::row_hits(&allowed_rows),
73                &checker,
74            );
75        }
76        let row_hits = index
77            .turbo_quant_candidates_in_rows(query, options.k, options.ef_search, &allowed_rows)
78            .ok_or(VectorSearchError::ApproximateIndexMissing)?
79            .map_err(GraphError::from)?;
80        rerank_ann_row_candidates(
81            self,
82            property,
83            query,
84            options.metric,
85            options.k,
86            row_hits,
87            &checker,
88        )
89    }
90
91    /// Approximately rank one canonical node candidate set per query.
92    ///
93    /// Each `queries[i]` is searched only within `candidate_sets[i]` through a
94    /// matching TurboQuant index, then exact-reranked against primary vector
95    /// values. Output positions correspond to input query positions.
96    pub fn approximate_vector_search_candidate_sets_batch_checked(
97        &self,
98        label: &DbString,
99        property: &DbString,
100        queries: &[VectorValue],
101        candidate_sets: &[VectorCandidateSet],
102        options: ApproximateVectorSearchOptions,
103        checker: CancellationChecker<'_>,
104    ) -> Result<Vec<Vec<VectorNodeSearchHit>>, VectorSearchError> {
105        checker.check()?;
106        if queries.len() != candidate_sets.len() {
107            return Err(VectorSearchError::BatchLengthMismatch {
108                queries: queries.len(),
109                candidate_sets: candidate_sets.len(),
110            });
111        }
112        let Some(first_query) = queries.first() else {
113            return Ok(Vec::new());
114        };
115        if options.k == 0 {
116            return Ok(vec![Vec::new(); queries.len()]);
117        }
118        let query_dimension = u32::try_from(first_query.dimension())
119            .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
120        let Some(index) = self
121            .vector_index_for(label, property)
122            .filter(|index| index.dimension() == query_dimension)
123        else {
124            return Err(VectorSearchError::ApproximateIndexMissing);
125        };
126        let Some(indexed_metric) = index.ann_metric() else {
127            return Err(VectorSearchError::ApproximateIndexMissing);
128        };
129        if indexed_metric != options.metric {
130            return Err(VectorSearchError::ApproximateMetricMismatch {
131                indexed: indexed_metric,
132                requested: options.metric,
133            });
134        }
135        if !index.is_turbo_quant() {
136            return Err(VectorSearchError::ApproximateIndexMissing);
137        }
138        for query in queries {
139            checker.check()?;
140            let dimension = u32::try_from(query.dimension())
141                .map_err(|_| VectorSearchError::ApproximateIndexMissing)?;
142            if dimension != query_dimension {
143                return Err(VectorSearchError::ApproximateIndexMissing);
144            }
145        }
146
147        if let Some(first_candidate_set) = candidate_sets.first()
148            && candidate_sets.iter().skip(1).all(|candidate_set| {
149                approx_batch::candidate_sets_match(first_candidate_set, candidate_set)
150            })
151        {
152            let allowed_rows =
153                self.vector_candidate_rows(first_candidate_set, index.rows(), &checker)?;
154            if turbo_quant_exact::covers_rows(&allowed_rows, options) {
155                let row_batches = vec![turbo_quant_exact::row_hits(&allowed_rows); queries.len()];
156                return approx_batch::rerank_ann_row_candidate_batches(
157                    self,
158                    property,
159                    queries,
160                    options.metric,
161                    options.k,
162                    row_batches,
163                    &checker,
164                );
165            }
166            let row_batches = index
167                .turbo_quant_candidates_batch_in_shared_rows(
168                    queries,
169                    options.k,
170                    options.ef_search,
171                    &allowed_rows,
172                )
173                .ok_or(VectorSearchError::ApproximateIndexMissing)?
174                .map_err(GraphError::from)?;
175            return approx_batch::rerank_ann_row_candidate_batches(
176                self,
177                property,
178                queries,
179                options.metric,
180                options.k,
181                row_batches,
182                &checker,
183            );
184        }
185
186        let allowed_rows = candidate_sets
187            .iter()
188            .map(|candidates| self.vector_candidate_rows(candidates, index.rows(), &checker))
189            .collect::<Result<Vec<_>, _>>()?;
190        if allowed_rows
191            .iter()
192            .all(|rows| turbo_quant_exact::covers_rows(rows, options))
193        {
194            let row_batches = allowed_rows
195                .iter()
196                .map(turbo_quant_exact::row_hits)
197                .collect::<Vec<_>>();
198            return approx_batch::rerank_ann_row_candidate_batches(
199                self,
200                property,
201                queries,
202                options.metric,
203                options.k,
204                row_batches,
205                &checker,
206            );
207        }
208        let row_batches = index
209            .turbo_quant_candidates_batch_in_rows(
210                queries,
211                options.k,
212                options.ef_search,
213                &allowed_rows,
214            )
215            .ok_or(VectorSearchError::ApproximateIndexMissing)?
216            .map_err(GraphError::from)?;
217        approx_batch::rerank_ann_row_candidate_batches(
218            self,
219            property,
220            queries,
221            options.metric,
222            options.k,
223            row_batches,
224            &checker,
225        )
226    }
227
228    fn vector_candidate_rows(
229        &self,
230        candidates: &VectorCandidateSet,
231        index_rows: &RoaringBitmap,
232        checker: &CancellationChecker<'_>,
233    ) -> Result<RoaringBitmap, VectorSearchError> {
234        let mut rows = RoaringBitmap::new();
235        for node_id in candidates.as_nodes().iter().copied() {
236            checker.check()?;
237            let Some(row) = self.row_for_node_id(node_id) else {
238                continue;
239            };
240            let raw_row = row.get();
241            if self.node_store.is_alive(raw_row) && index_rows.contains(raw_row) {
242                rows.insert(raw_row);
243            }
244        }
245        Ok(rows)
246    }
247}