Skip to main content

velesdb_core/collection/search/
batch.rs

1//! Batch and multi-query search methods for Collection.
2
3use crate::collection::types::Collection;
4use crate::error::{Error, Result};
5use crate::index::SearchQuality;
6use crate::point::{Point, SearchResult};
7use crate::storage::{PayloadStorage, VectorStorage};
8
9impl Collection {
10    /// Performs batch search for multiple query vectors in parallel with metadata filtering.
11    /// Supports a different filter for each query in the batch.
12    ///
13    /// # Arguments
14    ///
15    /// * `queries` - List of query vector slices
16    /// * `k` - Maximum number of results per query
17    /// * `filters` - List of optional filters (must match queries length)
18    ///
19    /// # Returns
20    ///
21    /// Vector of search results for each query, matching its respective filter.
22    ///
23    /// # Errors
24    ///
25    /// Returns an error if queries and filters have different lengths or dimension mismatch.
26    pub fn search_batch_with_filters(
27        &self,
28        queries: &[&[f32]],
29        k: usize,
30        filters: &[Option<crate::filter::Filter>],
31    ) -> Result<Vec<Vec<SearchResult>>> {
32        if queries.len() != filters.len() {
33            return Err(Error::Config(format!(
34                "Queries count ({}) does not match filters count ({})",
35                queries.len(),
36                filters.len()
37            )));
38        }
39
40        let config = self.config.read();
41        let dimension = config.dimension;
42        drop(config);
43
44        // Validate all query dimensions
45        for query in queries {
46            if query.len() != dimension {
47                return Err(Error::DimensionMismatch {
48                    expected: dimension,
49                    actual: query.len(),
50                });
51            }
52        }
53
54        // We need to retrieve more candidates for post-filtering
55        let candidates_k = k.saturating_mul(4).max(k + 10);
56        let metric = self.config.read().metric;
57        let index_results =
58            self.index
59                .search_batch_parallel(queries, candidates_k, SearchQuality::Balanced);
60
61        let vector_storage = self.vector_storage.read();
62        let payload_storage = self.payload_storage.read();
63
64        let mut all_results = Vec::with_capacity(queries.len());
65
66        for ((query_results, filter_opt), query) in
67            index_results.into_iter().zip(filters).zip(queries)
68        {
69            // Merge with delta buffer before filtering
70            let query_results = self.merge_delta(query_results, query, candidates_k, metric);
71            let mut filtered_results: Vec<SearchResult> = query_results
72                .into_iter()
73                .filter_map(|(id, score)| {
74                    let payload = payload_storage.retrieve(id).ok().flatten();
75
76                    // Apply filter if present
77                    if let Some(ref filter) = filter_opt {
78                        if let Some(ref p) = payload {
79                            if !filter.matches(p) {
80                                return None;
81                            }
82                        } else if !filter.matches(&serde_json::Value::Null) {
83                            return None;
84                        }
85                    }
86
87                    let vector = vector_storage.retrieve(id).ok().flatten()?;
88                    Some(SearchResult {
89                        point: Point {
90                            id,
91                            vector,
92                            payload,
93                            sparse_vectors: None,
94                        },
95                        score,
96                    })
97                })
98                .collect();
99
100            // Sort and truncate to k
101            let higher_is_better = self.config.read().metric.higher_is_better();
102            if higher_is_better {
103                filtered_results.sort_by(|a, b| {
104                    b.score
105                        .partial_cmp(&a.score)
106                        .unwrap_or(std::cmp::Ordering::Equal)
107                });
108            } else {
109                filtered_results.sort_by(|a, b| {
110                    a.score
111                        .partial_cmp(&b.score)
112                        .unwrap_or(std::cmp::Ordering::Equal)
113                });
114            }
115            filtered_results.truncate(k);
116
117            all_results.push(filtered_results);
118        }
119
120        Ok(all_results)
121    }
122
123    /// Performs batch search for multiple query vectors in parallel with a single metadata filter.
124    ///
125    /// # Arguments
126    ///
127    /// * `queries` - List of query vector slices
128    /// * `k` - Maximum number of results per query
129    /// * `filter` - Metadata filter to apply to all results
130    ///
131    /// # Errors
132    ///
133    /// Returns an error if any query has incorrect dimension.
134    pub fn search_batch_with_filter(
135        &self,
136        queries: &[&[f32]],
137        k: usize,
138        filter: &crate::filter::Filter,
139    ) -> Result<Vec<Vec<SearchResult>>> {
140        let filters: Vec<Option<crate::filter::Filter>> = vec![Some(filter.clone()); queries.len()];
141        self.search_batch_with_filters(queries, k, &filters)
142    }
143
144    /// Performs batch search for multiple query vectors in parallel.
145    ///
146    /// This method is optimized for high throughput using parallel index traversal.
147    ///
148    /// # Arguments
149    ///
150    /// * `queries` - List of query vector slices
151    /// * `k` - Maximum number of results per query
152    ///
153    /// # Returns
154    ///
155    /// Vector of search results for each query, with full point data.
156    ///
157    /// # Errors
158    ///
159    /// Returns an error if any query vector dimension doesn't match the collection.
160    pub fn search_batch_parallel(
161        &self,
162        queries: &[&[f32]],
163        k: usize,
164    ) -> Result<Vec<Vec<SearchResult>>> {
165        let config = self.config.read();
166        let dimension = config.dimension;
167        drop(config);
168
169        // Validate all query dimensions first
170        for query in queries {
171            if query.len() != dimension {
172                return Err(Error::DimensionMismatch {
173                    expected: dimension,
174                    actual: query.len(),
175                });
176            }
177        }
178
179        // Perf: Use parallel HNSW search (P0 optimization)
180        let metric = self.config.read().metric;
181        let index_results = self
182            .index
183            .search_batch_parallel(queries, k, SearchQuality::Balanced);
184
185        // Map results to SearchResult with full point data
186        let vector_storage = self.vector_storage.read();
187        let payload_storage = self.payload_storage.read();
188
189        let results: Vec<Vec<SearchResult>> = index_results
190            .into_iter()
191            .zip(queries)
192            .map(|(query_results, query): (Vec<(u64, f32)>, &&[f32])| {
193                // Merge with delta buffer per query
194                let query_results = self.merge_delta(query_results, query, k, metric);
195                query_results
196                    .into_iter()
197                    .filter_map(|(id, score)| {
198                        let vector = vector_storage.retrieve(id).ok().flatten()?;
199                        let payload = payload_storage.retrieve(id).ok().flatten();
200                        Some(SearchResult {
201                            point: Point {
202                                id,
203                                vector,
204                                payload,
205                                sparse_vectors: None,
206                            },
207                            score,
208                        })
209                    })
210                    .collect()
211            })
212            .collect();
213
214        Ok(results)
215    }
216
217    /// Performs multi-query search with result fusion.
218    ///
219    /// This method executes parallel searches for multiple query vectors and fuses
220    /// the results using the specified fusion strategy. Ideal for Multiple Query
221    /// Generation (MQG) pipelines where multiple reformulations of a user query
222    /// are searched simultaneously.
223    ///
224    /// # Arguments
225    ///
226    /// * `vectors` - Slice of query vectors (all must have same dimension)
227    /// * `top_k` - Maximum number of results to return after fusion
228    /// * `fusion` - Strategy for combining results (Average, Maximum, RRF, Weighted)
229    /// * `filter` - Optional metadata filter to apply to all queries
230    ///
231    /// # Returns
232    ///
233    /// Vector of `SearchResult` sorted by fused score descending.
234    ///
235    /// # Errors
236    ///
237    /// Returns an error if:
238    /// - `vectors` is empty
239    /// - Any vector has incorrect dimension
240    /// - More than 10 vectors are provided (configurable limit)
241    #[allow(clippy::needless_pass_by_value)]
242    #[allow(clippy::too_many_lines)]
243    pub fn multi_query_search(
244        &self,
245        vectors: &[&[f32]],
246        top_k: usize,
247        fusion: crate::fusion::FusionStrategy,
248        filter: Option<&crate::filter::Filter>,
249    ) -> Result<Vec<SearchResult>> {
250        const MAX_VECTORS: usize = 10;
251
252        // Validation: non-empty
253        if vectors.is_empty() {
254            return Err(Error::Config(
255                "multi_query_search requires at least one vector".into(),
256            ));
257        }
258
259        // Validation: max vectors limit
260        if vectors.len() > MAX_VECTORS {
261            return Err(Error::Config(format!(
262                "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
263                vectors.len()
264            )));
265        }
266
267        // Validation: dimension consistency
268        let config = self.config.read();
269        let dimension = config.dimension;
270        drop(config);
271
272        for vector in vectors {
273            if vector.len() != dimension {
274                return Err(Error::DimensionMismatch {
275                    expected: dimension,
276                    actual: vector.len(),
277                });
278            }
279        }
280
281        // Calculate overfetch factor for better fusion quality
282        let overfetch_k = match top_k {
283            0..=10 => top_k * 20,
284            11..=50 => top_k * 10,
285            51..=100 => top_k * 5,
286            _ => top_k * 2,
287        };
288
289        let metric = self.config.read().metric;
290
291        // Execute parallel batch search
292        let batch_results =
293            self.index
294                .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
295
296        // Merge with delta buffer per query before fusion (C-2: was bypassed).
297        let batch_results: Vec<Vec<(u64, f32)>> = batch_results
298            .into_iter()
299            .zip(vectors)
300            .map(|(query_results, query)| {
301                self.merge_delta(query_results, query, overfetch_k, metric)
302            })
303            .collect();
304
305        // Apply filter if present (pre-fusion filtering)
306        let filtered_results: Vec<Vec<(u64, f32)>> = if let Some(f) = filter {
307            let payload_storage = self.payload_storage.read();
308            batch_results
309                .into_iter()
310                .map(|query_results| {
311                    query_results
312                        .into_iter()
313                        .filter(|(id, _score)| {
314                            if let Ok(Some(payload)) = payload_storage.retrieve(*id) {
315                                f.matches(&payload)
316                            } else {
317                                false
318                            }
319                        })
320                        .collect()
321                })
322                .collect()
323        } else {
324            batch_results
325        };
326
327        // Fuse results using the specified strategy
328        let fused = fusion
329            .fuse(filtered_results)
330            .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
331
332        // Fetch full point data for top_k results
333        let vector_storage = self.vector_storage.read();
334        let payload_storage = self.payload_storage.read();
335
336        let results: Vec<SearchResult> = fused
337            .into_iter()
338            .take(top_k)
339            .filter_map(|(id, score)| {
340                let vector = vector_storage.retrieve(id).ok().flatten()?;
341                let payload = payload_storage.retrieve(id).ok().flatten();
342
343                let point = Point {
344                    id,
345                    vector,
346                    payload,
347                    sparse_vectors: None,
348                };
349
350                Some(SearchResult::new(point, score))
351            })
352            .collect();
353
354        Ok(results)
355    }
356
357    /// Performs multi-query search returning only IDs and fused scores.
358    ///
359    /// This is a faster variant of `multi_query_search` that skips fetching
360    /// vector and payload data. Use when you only need document IDs.
361    ///
362    /// # Arguments
363    ///
364    /// * `vectors` - Slice of query vectors
365    /// * `top_k` - Maximum number of results
366    /// * `fusion` - Fusion strategy
367    ///
368    /// # Returns
369    ///
370    /// Vector of `(id, fused_score)` tuples sorted by score descending.
371    ///
372    /// # Errors
373    ///
374    /// Returns an error if vectors is empty, exceeds max limit, or has dimension mismatch.
375    #[allow(clippy::needless_pass_by_value)]
376    pub fn multi_query_search_ids(
377        &self,
378        vectors: &[&[f32]],
379        top_k: usize,
380        fusion: crate::fusion::FusionStrategy,
381    ) -> Result<Vec<(u64, f32)>> {
382        const MAX_VECTORS: usize = 10;
383
384        if vectors.is_empty() {
385            return Err(Error::Config(
386                "multi_query_search requires at least one vector".into(),
387            ));
388        }
389
390        if vectors.len() > MAX_VECTORS {
391            return Err(Error::Config(format!(
392                "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
393                vectors.len()
394            )));
395        }
396
397        let config = self.config.read();
398        let dimension = config.dimension;
399        drop(config);
400
401        for vector in vectors {
402            if vector.len() != dimension {
403                return Err(Error::DimensionMismatch {
404                    expected: dimension,
405                    actual: vector.len(),
406                });
407            }
408        }
409
410        let overfetch_k = match top_k {
411            0..=10 => top_k * 20,
412            11..=50 => top_k * 10,
413            51..=100 => top_k * 5,
414            _ => top_k * 2,
415        };
416
417        let metric = self.config.read().metric;
418
419        let batch_results =
420            self.index
421                .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
422
423        // Merge with delta buffer per query before fusion (C-2: was bypassed).
424        let batch_results: Vec<Vec<(u64, f32)>> = batch_results
425            .into_iter()
426            .zip(vectors)
427            .map(|(query_results, query)| {
428                self.merge_delta(query_results, query, overfetch_k, metric)
429            })
430            .collect();
431
432        let fused = fusion
433            .fuse(batch_results)
434            .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
435
436        Ok(fused.into_iter().take(top_k).collect())
437    }
438}