Skip to main content

sochdb_vector/query/
engine.rs

1//! Query engine implementation.
2
3use std::collections::HashSet;
4use std::sync::Arc;
5use std::time::Instant;
6
7use crate::config::EngineConfig;
8use crate::error::{Error, Result};
9use crate::filter::BitsetFilter;
10use crate::rotation::Rotator;
11use crate::segment::Segment;
12use crate::segment::bps::{BpsBuilder, BpsScanner};
13use crate::segment::rdf::RdfScorer;
14use crate::segment::rerank::{Reranker, quantize_query};
15use crate::types::*;
16
17/// Query engine for executing vector searches
18pub struct QueryEngine {
19    config: EngineConfig,
20    segments: Vec<Arc<Segment>>,
21    rotator: Rotator,
22}
23
24impl QueryEngine {
25    /// Create a new query engine
26    pub fn new(config: EngineConfig) -> Result<Self> {
27        config.validate()?;
28        let rotator = Rotator::new(config.dim);
29
30        Ok(Self {
31            config,
32            segments: Vec::new(),
33            rotator,
34        })
35    }
36
37    /// Add a segment to the engine
38    pub fn add_segment(&mut self, segment: Arc<Segment>) -> Result<()> {
39        if segment.dim() != self.config.dim {
40            return Err(Error::DimensionMismatch {
41                expected: self.config.dim,
42                got: segment.dim(),
43            });
44        }
45        self.segments.push(segment);
46        Ok(())
47    }
48
49    /// Load a segment from file
50    pub fn load_segment(&mut self, path: &str) -> Result<()> {
51        let segment = Segment::open(path)?;
52        self.add_segment(Arc::new(segment))
53    }
54
55    /// Execute a query
56    pub fn search(&self, query: &[f32], params: &QueryParams) -> Result<QueryResult> {
57        if query.len() != self.config.dim as usize {
58            return Err(Error::DimensionMismatch {
59                expected: self.config.dim,
60                got: query.len() as u32,
61            });
62        }
63
64        if self.segments.is_empty() {
65            return Err(Error::EmptyIndex);
66        }
67
68        let total_start = Instant::now();
69        let mut stats = QueryStats::default();
70
71        // Step 0: Rotate query
72        let rotate_start = Instant::now();
73        let rotated_query = self.rotator.rotate(query);
74        stats.time_rotate_ns = rotate_start.elapsed().as_nanos() as u64;
75
76        // Prepare filter
77        let filter = params.filter.as_ref().map(|bits| {
78            BitsetFilter::from_ids(
79                self.total_vectors(),
80                &bits.iter().map(|&id| id as VectorId).collect::<Vec<_>>(),
81            )
82        });
83
84        // Search each segment
85        let mut all_candidates: Vec<ScoredCandidate> = Vec::new();
86
87        for segment in &self.segments {
88            let segment_result = self.search_segment(
89                segment,
90                &rotated_query,
91                query,
92                params,
93                filter.as_ref(),
94                &mut stats,
95            )?;
96            all_candidates.extend(segment_result);
97        }
98
99        // Merge and select top k
100        all_candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
101        all_candidates.truncate(params.k);
102
103        stats.total_time_ns = total_start.elapsed().as_nanos() as u64;
104
105        Ok(QueryResult {
106            candidates: all_candidates,
107            stats,
108        })
109    }
110
111    /// Search within a single segment
112    fn search_segment(
113        &self,
114        segment: &Segment,
115        rotated_query: &[f32],
116        _original_query: &[f32],
117        params: &QueryParams,
118        filter: Option<&BitsetFilter>,
119        stats: &mut QueryStats,
120    ) -> Result<Vec<ScoredCandidate>> {
121        let header = segment.header();
122        let _n_vec = header.n_vec as usize;
123
124        // Compute selectivity for filter-aware widening
125        let selectivity = filter
126            .map(|f| {
127                let mut f = f.clone();
128                f.selectivity()
129            })
130            .unwrap_or(1.0);
131
132        let widening_factor = if selectivity < 1.0 && params.adaptive {
133            (1.0 / selectivity).min(4.0)
134        } else {
135            1.0
136        };
137
138        // Step 2: RDF candidate generation
139        let rdf_start = Instant::now();
140        let rdf_candidates = if header
141            .flags
142            .has(crate::segment::format::SegmentFlags::HAS_RDF)
143        {
144            let l_a_widened = ((params.l_a as f32) * widening_factor) as usize;
145            self.rdf_search(segment, rotated_query, l_a_widened)
146        } else {
147            Vec::new()
148        };
149        stats.time_rdf_ns += rdf_start.elapsed().as_nanos() as u64;
150        stats.rdf_candidates += rdf_candidates.len();
151
152        // Step 3: BPS candidate generation
153        let bps_start = Instant::now();
154        let bps_candidates = if header
155            .flags
156            .has(crate::segment::format::SegmentFlags::HAS_BPS)
157        {
158            let l_b_widened = ((params.l_b as f32) * widening_factor) as usize;
159            self.bps_search(segment, rotated_query, l_b_widened)
160        } else {
161            Vec::new()
162        };
163        stats.time_bps_ns += bps_start.elapsed().as_nanos() as u64;
164        stats.bps_candidates += bps_candidates.len();
165
166        // Step 4: Union candidates
167        let mut candidate_set: HashSet<VectorId> = HashSet::new();
168        for c in &rdf_candidates {
169            candidate_set.insert(c.id);
170        }
171        for (vid, _) in &bps_candidates {
172            candidate_set.insert(*vid);
173        }
174        stats.union_size += candidate_set.len();
175
176        // Apply filter
177        let filter_start = Instant::now();
178        let filtered_candidates: Vec<VectorId> = if let Some(f) = filter {
179            candidate_set
180                .into_iter()
181                .filter(|&id| f.contains(id) && !segment.is_tombstoned(id))
182                .collect()
183        } else {
184            candidate_set
185                .into_iter()
186                .filter(|&id| !segment.is_tombstoned(id))
187                .collect()
188        };
189        stats.time_filter_ns += filter_start.elapsed().as_nanos() as u64;
190        stats.post_filter_size += filtered_candidates.len();
191
192        // Step 5: Rerank
193        let rerank_start = Instant::now();
194        let reranked = self.rerank(segment, rotated_query, &filtered_candidates, params.r)?;
195        stats.time_rerank_ns += rerank_start.elapsed().as_nanos() as u64;
196        stats.rerank_count += reranked.len();
197
198        // Adaptive widening check
199        if params.adaptive && reranked.len() < params.k {
200            stats.widening_applied = true;
201            // Could widen and retry here
202        }
203
204        Ok(reranked)
205    }
206
207    /// RDF-based candidate generation
208    fn rdf_search(
209        &self,
210        segment: &Segment,
211        rotated_query: &[f32],
212        l_a: usize,
213    ) -> Vec<ScoredCandidate> {
214        let directory = segment.rdf_directory();
215        if directory.is_empty() {
216            return Vec::new();
217        }
218
219        let rdf_data = unsafe {
220            let ptr = segment.rdf_data_ptr();
221            let len = segment.header().file_len as usize - segment.header().off_rdf_data as usize;
222            std::slice::from_raw_parts(ptr, len.min(1024 * 1024 * 100)) // Cap at 100MB for safety
223        };
224
225        let dim_weights = segment.dim_weights();
226        let scorer = RdfScorer::new(
227            directory,
228            rdf_data,
229            dim_weights,
230            segment.header().rdf_stripe_shift,
231            segment.num_vectors(),
232        );
233
234        scorer.score(rotated_query, self.config.rdf.top_t as usize, l_a)
235    }
236
237    /// BPS-based candidate generation
238    fn bps_search(
239        &self,
240        segment: &Segment,
241        rotated_query: &[f32],
242        l_b: usize,
243    ) -> Vec<(VectorId, Distance)> {
244        let header = segment.header();
245        let bps_data = segment.bps_data();
246
247        // Compute query sketch using stored qparams when available (correct
248        // asymmetric quantization).  Fall back to legacy symmetric quantization
249        // only for segments written before qparams were persisted.
250        #[allow(deprecated)]
251        let query_sketch = if let Some(qparams) = segment.bps_qparams() {
252            BpsBuilder::compute_query_sketch_with_params(&self.config.bps, rotated_query, qparams)
253        } else {
254            BpsBuilder::compute_query_sketch(&self.config.bps, rotated_query)
255        };
256
257        let scanner = BpsScanner::new(
258            bps_data,
259            header.n_vec as usize,
260            header.num_bps_blocks() as usize,
261            header.bps_proj as usize,
262        );
263
264        scanner.top_k(&query_sketch, l_b)
265    }
266
267    /// Rerank candidates using int8 dot product
268    fn rerank(
269        &self,
270        segment: &Segment,
271        rotated_query: &[f32],
272        candidates: &[VectorId],
273        r: usize,
274    ) -> Result<Vec<ScoredCandidate>> {
275        if candidates.is_empty() {
276            return Ok(Vec::new());
277        }
278
279        let header = segment.header();
280        let i8_data = segment.i8_data();
281        let scales = segment.scales_data();
282
283        // Quantize query
284        let (query_i8, query_scale) = quantize_query(rotated_query, &self.config.rerank);
285
286        // Get outliers if available
287        let outliers = if header
288            .flags
289            .has(crate::segment::format::SegmentFlags::HAS_OUTLIERS)
290        {
291            unsafe {
292                std::slice::from_raw_parts(
293                    segment.outliers_ptr(),
294                    header.n_vec as usize * header.num_outliers as usize,
295                )
296            }
297        } else {
298            &[]
299        };
300
301        let reranker = Reranker::new(
302            i8_data,
303            &scales[..header.n_vec as usize], // One scale per vector
304            outliers,
305            header.dim as usize,
306            header.num_outliers as usize,
307        );
308
309        Ok(reranker.rerank(candidates, &query_i8, query_scale, r))
310    }
311
312    /// Optional verification with fp32
313    pub fn verify(
314        &self,
315        segment: &Segment,
316        candidates: &[ScoredCandidate],
317        query: &[f32],
318        k: usize,
319    ) -> Vec<ScoredCandidate> {
320        if let Some(fp32_data) = segment.fp32_data() {
321            let dim = segment.dim() as usize;
322            let mut verified: Vec<ScoredCandidate> = candidates
323                .iter()
324                .map(|c| {
325                    let offset = c.id as usize * dim;
326                    if offset + dim <= fp32_data.len() {
327                        let vec = &fp32_data[offset..offset + dim];
328                        let score = query.iter().zip(vec.iter()).map(|(a, b)| a * b).sum();
329                        ScoredCandidate { id: c.id, score }
330                    } else {
331                        *c
332                    }
333                })
334                .collect();
335
336            verified.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
337            verified.truncate(k);
338            verified
339        } else {
340            candidates.to_vec()
341        }
342    }
343
344    /// Get total vector count across all segments
345    pub fn total_vectors(&self) -> u32 {
346        self.segments.iter().map(|s| s.num_vectors()).sum()
347    }
348
349    /// Get config
350    pub fn config(&self) -> &EngineConfig {
351        &self.config
352    }
353
354    /// Get segments
355    pub fn segments(&self) -> &[Arc<Segment>] {
356        &self.segments
357    }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::segment::SegmentWriter;
    use tempfile::NamedTempFile;

    /// Dimensionality of every vector in the fixture index.
    const DIM: usize = 64;

    /// Write a 1000-vector, 64-dimensional segment to a temp file.
    ///
    /// Vector `i` is all zeros except `1.0` at coordinate `i % 64` and `0.5` at
    /// `(i + 1) % 64`. The temp file is returned so it outlives the engine reading it.
    fn build_fixture() -> (NamedTempFile, EngineConfig) {
        let config = EngineConfig::with_dim(64);
        let mut writer = SegmentWriter::new(config.clone()).unwrap();

        // One row buffer, reset before each vector is written.
        let mut row = vec![0.0f32; DIM];
        for i in 0..1000usize {
            row.fill(0.0);
            row[i % DIM] = 1.0;
            row[(i + 1) % DIM] = 0.5;
            writer.add(&row).unwrap();
        }

        let file = NamedTempFile::new().unwrap();
        writer.build(file.path()).unwrap();

        (file, config)
    }

    /// Create an engine over the fixture segment stored in `file`.
    fn open_engine(file: &NamedTempFile, config: EngineConfig) -> QueryEngine {
        let mut engine = QueryEngine::new(config).unwrap();
        let path = file.path().to_str().unwrap();
        engine.load_segment(path).unwrap();
        engine
    }

    /// Query parameters shared by the tests, with an optional ID allow-list.
    fn test_params(filter: Option<Vec<u64>>) -> QueryParams {
        QueryParams {
            k: 10,
            l_a: 100,
            l_b: 200,
            r: 50,
            adaptive: false,
            filter,
        }
    }

    #[test]
    fn test_query_engine_basic() {
        let (file, config) = build_fixture();
        let engine = open_engine(&file, config);

        // Same pattern as fixture vector 0: 1.0 at coordinate 0, 0.5 at coordinate 1.
        let mut query = vec![0.0f32; DIM];
        query[0] = 1.0;
        query[1] = 0.5;

        let result = engine.search(&query, &test_params(None)).unwrap();

        assert!(!result.candidates.is_empty());
        println!("Query stats: {}", result.stats);
    }

    #[test]
    fn test_query_with_filter() {
        let (file, config) = build_fixture();
        let engine = open_engine(&file, config);

        let mut query = vec![0.0f32; DIM];
        query[0] = 1.0;

        // Allow-list of the even IDs 0, 2, ..., 998.
        let even_ids: Vec<u64> = (0..1000).step_by(2).collect();

        let result = engine.search(&query, &test_params(Some(even_ids))).unwrap();

        // Every returned candidate must come from the allow-list.
        for candidate in result.candidates.iter() {
            assert!(candidate.id % 2 == 0, "Expected even ID, got {}", candidate.id);
        }
    }
}