1use std::collections::HashSet;
4use std::sync::Arc;
5use std::time::Instant;
6
7use crate::config::EngineConfig;
8use crate::error::{Error, Result};
9use crate::filter::BitsetFilter;
10use crate::rotation::Rotator;
11use crate::segment::Segment;
12use crate::segment::bps::{BpsBuilder, BpsScanner};
13use crate::segment::rdf::RdfScorer;
14use crate::segment::rerank::{Reranker, quantize_query};
15use crate::types::*;
16
17pub struct QueryEngine {
19 config: EngineConfig,
20 segments: Vec<Arc<Segment>>,
21 rotator: Rotator,
22}
23
24impl QueryEngine {
25 pub fn new(config: EngineConfig) -> Result<Self> {
27 config.validate()?;
28 let rotator = Rotator::new(config.dim);
29
30 Ok(Self {
31 config,
32 segments: Vec::new(),
33 rotator,
34 })
35 }
36
37 pub fn add_segment(&mut self, segment: Arc<Segment>) -> Result<()> {
39 if segment.dim() != self.config.dim {
40 return Err(Error::DimensionMismatch {
41 expected: self.config.dim,
42 got: segment.dim(),
43 });
44 }
45 self.segments.push(segment);
46 Ok(())
47 }
48
49 pub fn load_segment(&mut self, path: &str) -> Result<()> {
51 let segment = Segment::open(path)?;
52 self.add_segment(Arc::new(segment))
53 }
54
55 pub fn search(&self, query: &[f32], params: &QueryParams) -> Result<QueryResult> {
57 if query.len() != self.config.dim as usize {
58 return Err(Error::DimensionMismatch {
59 expected: self.config.dim,
60 got: query.len() as u32,
61 });
62 }
63
64 if self.segments.is_empty() {
65 return Err(Error::EmptyIndex);
66 }
67
68 let total_start = Instant::now();
69 let mut stats = QueryStats::default();
70
71 let rotate_start = Instant::now();
73 let rotated_query = self.rotator.rotate(query);
74 stats.time_rotate_ns = rotate_start.elapsed().as_nanos() as u64;
75
76 let filter = params.filter.as_ref().map(|bits| {
78 BitsetFilter::from_ids(
79 self.total_vectors(),
80 &bits.iter().map(|&id| id as VectorId).collect::<Vec<_>>(),
81 )
82 });
83
84 let mut all_candidates: Vec<ScoredCandidate> = Vec::new();
86
87 for segment in &self.segments {
88 let segment_result = self.search_segment(
89 segment,
90 &rotated_query,
91 query,
92 params,
93 filter.as_ref(),
94 &mut stats,
95 )?;
96 all_candidates.extend(segment_result);
97 }
98
99 all_candidates.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
101 all_candidates.truncate(params.k);
102
103 stats.total_time_ns = total_start.elapsed().as_nanos() as u64;
104
105 Ok(QueryResult {
106 candidates: all_candidates,
107 stats,
108 })
109 }
110
111 fn search_segment(
113 &self,
114 segment: &Segment,
115 rotated_query: &[f32],
116 _original_query: &[f32],
117 params: &QueryParams,
118 filter: Option<&BitsetFilter>,
119 stats: &mut QueryStats,
120 ) -> Result<Vec<ScoredCandidate>> {
121 let header = segment.header();
122 let _n_vec = header.n_vec as usize;
123
124 let selectivity = filter
126 .map(|f| {
127 let mut f = f.clone();
128 f.selectivity()
129 })
130 .unwrap_or(1.0);
131
132 let widening_factor = if selectivity < 1.0 && params.adaptive {
133 (1.0 / selectivity).min(4.0)
134 } else {
135 1.0
136 };
137
138 let rdf_start = Instant::now();
140 let rdf_candidates = if header
141 .flags
142 .has(crate::segment::format::SegmentFlags::HAS_RDF)
143 {
144 let l_a_widened = ((params.l_a as f32) * widening_factor) as usize;
145 self.rdf_search(segment, rotated_query, l_a_widened)
146 } else {
147 Vec::new()
148 };
149 stats.time_rdf_ns += rdf_start.elapsed().as_nanos() as u64;
150 stats.rdf_candidates += rdf_candidates.len();
151
152 let bps_start = Instant::now();
154 let bps_candidates = if header
155 .flags
156 .has(crate::segment::format::SegmentFlags::HAS_BPS)
157 {
158 let l_b_widened = ((params.l_b as f32) * widening_factor) as usize;
159 self.bps_search(segment, rotated_query, l_b_widened)
160 } else {
161 Vec::new()
162 };
163 stats.time_bps_ns += bps_start.elapsed().as_nanos() as u64;
164 stats.bps_candidates += bps_candidates.len();
165
166 let mut candidate_set: HashSet<VectorId> = HashSet::new();
168 for c in &rdf_candidates {
169 candidate_set.insert(c.id);
170 }
171 for (vid, _) in &bps_candidates {
172 candidate_set.insert(*vid);
173 }
174 stats.union_size += candidate_set.len();
175
176 let filter_start = Instant::now();
178 let filtered_candidates: Vec<VectorId> = if let Some(f) = filter {
179 candidate_set
180 .into_iter()
181 .filter(|&id| f.contains(id) && !segment.is_tombstoned(id))
182 .collect()
183 } else {
184 candidate_set
185 .into_iter()
186 .filter(|&id| !segment.is_tombstoned(id))
187 .collect()
188 };
189 stats.time_filter_ns += filter_start.elapsed().as_nanos() as u64;
190 stats.post_filter_size += filtered_candidates.len();
191
192 let rerank_start = Instant::now();
194 let reranked = self.rerank(segment, rotated_query, &filtered_candidates, params.r)?;
195 stats.time_rerank_ns += rerank_start.elapsed().as_nanos() as u64;
196 stats.rerank_count += reranked.len();
197
198 if params.adaptive && reranked.len() < params.k {
200 stats.widening_applied = true;
201 }
203
204 Ok(reranked)
205 }
206
207 fn rdf_search(
209 &self,
210 segment: &Segment,
211 rotated_query: &[f32],
212 l_a: usize,
213 ) -> Vec<ScoredCandidate> {
214 let directory = segment.rdf_directory();
215 if directory.is_empty() {
216 return Vec::new();
217 }
218
219 let rdf_data = unsafe {
220 let ptr = segment.rdf_data_ptr();
221 let len = segment.header().file_len as usize - segment.header().off_rdf_data as usize;
222 std::slice::from_raw_parts(ptr, len.min(1024 * 1024 * 100)) };
224
225 let dim_weights = segment.dim_weights();
226 let scorer = RdfScorer::new(
227 directory,
228 rdf_data,
229 dim_weights,
230 segment.header().rdf_stripe_shift,
231 segment.num_vectors(),
232 );
233
234 scorer.score(rotated_query, self.config.rdf.top_t as usize, l_a)
235 }
236
237 fn bps_search(
239 &self,
240 segment: &Segment,
241 rotated_query: &[f32],
242 l_b: usize,
243 ) -> Vec<(VectorId, Distance)> {
244 let header = segment.header();
245 let bps_data = segment.bps_data();
246
247 #[allow(deprecated)]
251 let query_sketch = if let Some(qparams) = segment.bps_qparams() {
252 BpsBuilder::compute_query_sketch_with_params(&self.config.bps, rotated_query, qparams)
253 } else {
254 BpsBuilder::compute_query_sketch(&self.config.bps, rotated_query)
255 };
256
257 let scanner = BpsScanner::new(
258 bps_data,
259 header.n_vec as usize,
260 header.num_bps_blocks() as usize,
261 header.bps_proj as usize,
262 );
263
264 scanner.top_k(&query_sketch, l_b)
265 }
266
267 fn rerank(
269 &self,
270 segment: &Segment,
271 rotated_query: &[f32],
272 candidates: &[VectorId],
273 r: usize,
274 ) -> Result<Vec<ScoredCandidate>> {
275 if candidates.is_empty() {
276 return Ok(Vec::new());
277 }
278
279 let header = segment.header();
280 let i8_data = segment.i8_data();
281 let scales = segment.scales_data();
282
283 let (query_i8, query_scale) = quantize_query(rotated_query, &self.config.rerank);
285
286 let outliers = if header
288 .flags
289 .has(crate::segment::format::SegmentFlags::HAS_OUTLIERS)
290 {
291 unsafe {
292 std::slice::from_raw_parts(
293 segment.outliers_ptr(),
294 header.n_vec as usize * header.num_outliers as usize,
295 )
296 }
297 } else {
298 &[]
299 };
300
301 let reranker = Reranker::new(
302 i8_data,
303 &scales[..header.n_vec as usize], outliers,
305 header.dim as usize,
306 header.num_outliers as usize,
307 );
308
309 Ok(reranker.rerank(candidates, &query_i8, query_scale, r))
310 }
311
312 pub fn verify(
314 &self,
315 segment: &Segment,
316 candidates: &[ScoredCandidate],
317 query: &[f32],
318 k: usize,
319 ) -> Vec<ScoredCandidate> {
320 if let Some(fp32_data) = segment.fp32_data() {
321 let dim = segment.dim() as usize;
322 let mut verified: Vec<ScoredCandidate> = candidates
323 .iter()
324 .map(|c| {
325 let offset = c.id as usize * dim;
326 if offset + dim <= fp32_data.len() {
327 let vec = &fp32_data[offset..offset + dim];
328 let score = query.iter().zip(vec.iter()).map(|(a, b)| a * b).sum();
329 ScoredCandidate { id: c.id, score }
330 } else {
331 *c
332 }
333 })
334 .collect();
335
336 verified.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap());
337 verified.truncate(k);
338 verified
339 } else {
340 candidates.to_vec()
341 }
342 }
343
344 pub fn total_vectors(&self) -> u32 {
346 self.segments.iter().map(|s| s.num_vectors()).sum()
347 }
348
349 pub fn config(&self) -> &EngineConfig {
351 &self.config
352 }
353
354 pub fn segments(&self) -> &[Arc<Segment>] {
356 &self.segments
357 }
358}
359
#[cfg(test)]
mod tests {
    use super::*;
    use crate::segment::SegmentWriter;
    use tempfile::NamedTempFile;

    /// Builds a segment of 1000 vectors with 64 dimensions in a temporary file.
    ///
    /// Vector `i` is zero everywhere except for 1.0 at dimension `i % 64` and
    /// 0.5 at dimension `(i + 1) % 64`. The temp file handle is returned so
    /// the file outlives the test that uses it.
    fn create_test_index() -> (NamedTempFile, EngineConfig) {
        let config = EngineConfig::with_dim(64);
        let mut writer = SegmentWriter::new(config.clone()).unwrap();

        for i in 0..1000 {
            let mut vec = vec![0.0f32; 64];
            vec[i % 64] = 1.0;
            vec[(i + 1) % 64] = 0.5;
            writer.add(&vec).unwrap();
        }

        let file = NamedTempFile::new().unwrap();
        writer.build(file.path()).unwrap();

        (file, config)
    }

    /// An unfiltered query shaped like vector 0 must return at least one candidate.
    #[test]
    fn test_query_engine_basic() {
        let (file, config) = create_test_index();

        let mut engine = QueryEngine::new(config).unwrap();
        engine.load_segment(file.path().to_str().unwrap()).unwrap();

        let mut query = vec![0.0f32; 64];
        query[0] = 1.0;
        query[1] = 0.5;

        let params = QueryParams {
            k: 10,
            l_a: 100,
            l_b: 200,
            r: 50,
            adaptive: false,
            filter: None,
        };

        let result = engine.search(&query, &params).unwrap();

        assert!(!result.candidates.is_empty());
        println!("Query stats: {}", result.stats);
    }

    /// With a filter that admits only even ids, every returned id must be even.
    #[test]
    fn test_query_with_filter() {
        let (file, config) = create_test_index();

        let mut engine = QueryEngine::new(config).unwrap();
        engine.load_segment(file.path().to_str().unwrap()).unwrap();

        let mut query = vec![0.0f32; 64];
        query[0] = 1.0;

        // Even ids 0, 2, ..., 998.
        let filter: Vec<u64> = (0..500).map(|i| i * 2).collect();

        let params = QueryParams {
            k: 10,
            l_a: 100,
            l_b: 200,
            r: 50,
            adaptive: false,
            filter: Some(filter),
        };

        let result = engine.search(&query, &params).unwrap();

        for c in &result.candidates {
            assert!(c.id % 2 == 0, "Expected even ID, got {}", c.id);
        }
    }
}