Skip to main content

velesdb_core/collection/search/
batch.rs

1//! Batch and multi-query search methods for Collection.
2
3use crate::collection::types::Collection;
4use crate::error::{Error, Result};
5use crate::index::SearchQuality;
6use crate::point::{Point, SearchResult};
7use crate::storage::{PayloadStorage, VectorStorage};
8
9impl Collection {
10    /// Performs batch search for multiple query vectors in parallel with metadata filtering.
11    /// Supports a different filter for each query in the batch.
12    ///
13    /// # Arguments
14    ///
15    /// * `queries` - List of query vector slices
16    /// * `k` - Maximum number of results per query
17    /// * `filters` - List of optional filters (must match queries length)
18    ///
19    /// # Returns
20    ///
21    /// Vector of search results for each query, matching its respective filter.
22    ///
23    /// # Errors
24    ///
25    /// Returns an error if queries and filters have different lengths or dimension mismatch.
26    pub fn search_batch_with_filters(
27        &self,
28        queries: &[&[f32]],
29        k: usize,
30        filters: &[Option<crate::filter::Filter>],
31    ) -> Result<Vec<Vec<SearchResult>>> {
32        if queries.len() != filters.len() {
33            return Err(Error::Config(format!(
34                "Queries count ({}) does not match filters count ({})",
35                queries.len(),
36                filters.len()
37            )));
38        }
39
40        let config = self.config.read();
41        let dimension = config.dimension;
42        drop(config);
43
44        // Validate all query dimensions
45        for query in queries {
46            if query.len() != dimension {
47                return Err(Error::DimensionMismatch {
48                    expected: dimension,
49                    actual: query.len(),
50                });
51            }
52        }
53
54        // We need to retrieve more candidates for post-filtering
55        let candidates_k = k.saturating_mul(4).max(k + 10);
56        let index_results =
57            self.index
58                .search_batch_parallel(queries, candidates_k, SearchQuality::Balanced);
59
60        let vector_storage = self.vector_storage.read();
61        let payload_storage = self.payload_storage.read();
62
63        let mut all_results = Vec::with_capacity(queries.len());
64
65        for (query_results, filter_opt) in index_results.into_iter().zip(filters) {
66            let mut filtered_results: Vec<SearchResult> = query_results
67                .into_iter()
68                .filter_map(|(id, score)| {
69                    let payload = payload_storage.retrieve(id).ok().flatten();
70
71                    // Apply filter if present
72                    if let Some(ref filter) = filter_opt {
73                        if let Some(ref p) = payload {
74                            if !filter.matches(p) {
75                                return None;
76                            }
77                        } else if !filter.matches(&serde_json::Value::Null) {
78                            return None;
79                        }
80                    }
81
82                    let vector = vector_storage.retrieve(id).ok().flatten()?;
83                    Some(SearchResult {
84                        point: Point {
85                            id,
86                            vector,
87                            payload,
88                        },
89                        score,
90                    })
91                })
92                .collect();
93
94            // Sort and truncate to k
95            let higher_is_better = self.config.read().metric.higher_is_better();
96            if higher_is_better {
97                filtered_results.sort_by(|a, b| {
98                    b.score
99                        .partial_cmp(&a.score)
100                        .unwrap_or(std::cmp::Ordering::Equal)
101                });
102            } else {
103                filtered_results.sort_by(|a, b| {
104                    a.score
105                        .partial_cmp(&b.score)
106                        .unwrap_or(std::cmp::Ordering::Equal)
107                });
108            }
109            filtered_results.truncate(k);
110
111            all_results.push(filtered_results);
112        }
113
114        Ok(all_results)
115    }
116
117    /// Performs batch search for multiple query vectors in parallel with a single metadata filter.
118    ///
119    /// # Arguments
120    ///
121    /// * `queries` - List of query vector slices
122    /// * `k` - Maximum number of results per query
123    /// * `filter` - Metadata filter to apply to all results
124    ///
125    /// # Errors
126    ///
127    /// Returns an error if any query has incorrect dimension.
128    pub fn search_batch_with_filter(
129        &self,
130        queries: &[&[f32]],
131        k: usize,
132        filter: &crate::filter::Filter,
133    ) -> Result<Vec<Vec<SearchResult>>> {
134        let filters: Vec<Option<crate::filter::Filter>> = vec![Some(filter.clone()); queries.len()];
135        self.search_batch_with_filters(queries, k, &filters)
136    }
137
138    /// Performs batch search for multiple query vectors in parallel.
139    ///
140    /// This method is optimized for high throughput using parallel index traversal.
141    ///
142    /// # Arguments
143    ///
144    /// * `queries` - List of query vector slices
145    /// * `k` - Maximum number of results per query
146    ///
147    /// # Returns
148    ///
149    /// Vector of search results for each query, with full point data.
150    ///
151    /// # Errors
152    ///
153    /// Returns an error if any query vector dimension doesn't match the collection.
154    pub fn search_batch_parallel(
155        &self,
156        queries: &[&[f32]],
157        k: usize,
158    ) -> Result<Vec<Vec<SearchResult>>> {
159        let config = self.config.read();
160        let dimension = config.dimension;
161        drop(config);
162
163        // Validate all query dimensions first
164        for query in queries {
165            if query.len() != dimension {
166                return Err(Error::DimensionMismatch {
167                    expected: dimension,
168                    actual: query.len(),
169                });
170            }
171        }
172
173        // Perf: Use parallel HNSW search (P0 optimization)
174        let index_results = self
175            .index
176            .search_batch_parallel(queries, k, SearchQuality::Balanced);
177
178        // Map results to SearchResult with full point data
179        let vector_storage = self.vector_storage.read();
180        let payload_storage = self.payload_storage.read();
181
182        let results: Vec<Vec<SearchResult>> = index_results
183            .into_iter()
184            .map(|query_results: Vec<(u64, f32)>| {
185                query_results
186                    .into_iter()
187                    .filter_map(|(id, score)| {
188                        let vector = vector_storage.retrieve(id).ok().flatten()?;
189                        let payload = payload_storage.retrieve(id).ok().flatten();
190                        Some(SearchResult {
191                            point: Point {
192                                id,
193                                vector,
194                                payload,
195                            },
196                            score,
197                        })
198                    })
199                    .collect()
200            })
201            .collect();
202
203        Ok(results)
204    }
205
206    /// Performs multi-query search with result fusion.
207    ///
208    /// This method executes parallel searches for multiple query vectors and fuses
209    /// the results using the specified fusion strategy. Ideal for Multiple Query
210    /// Generation (MQG) pipelines where multiple reformulations of a user query
211    /// are searched simultaneously.
212    ///
213    /// # Arguments
214    ///
215    /// * `vectors` - Slice of query vectors (all must have same dimension)
216    /// * `top_k` - Maximum number of results to return after fusion
217    /// * `fusion` - Strategy for combining results (Average, Maximum, RRF, Weighted)
218    /// * `filter` - Optional metadata filter to apply to all queries
219    ///
220    /// # Returns
221    ///
222    /// Vector of `SearchResult` sorted by fused score descending.
223    ///
224    /// # Errors
225    ///
226    /// Returns an error if:
227    /// - `vectors` is empty
228    /// - Any vector has incorrect dimension
229    /// - More than 10 vectors are provided (configurable limit)
230    #[allow(clippy::needless_pass_by_value)]
231    pub fn multi_query_search(
232        &self,
233        vectors: &[&[f32]],
234        top_k: usize,
235        fusion: crate::fusion::FusionStrategy,
236        filter: Option<&crate::filter::Filter>,
237    ) -> Result<Vec<SearchResult>> {
238        const MAX_VECTORS: usize = 10;
239
240        // Validation: non-empty
241        if vectors.is_empty() {
242            return Err(Error::Config(
243                "multi_query_search requires at least one vector".into(),
244            ));
245        }
246
247        // Validation: max vectors limit
248        if vectors.len() > MAX_VECTORS {
249            return Err(Error::Config(format!(
250                "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
251                vectors.len()
252            )));
253        }
254
255        // Validation: dimension consistency
256        let config = self.config.read();
257        let dimension = config.dimension;
258        drop(config);
259
260        for vector in vectors {
261            if vector.len() != dimension {
262                return Err(Error::DimensionMismatch {
263                    expected: dimension,
264                    actual: vector.len(),
265                });
266            }
267        }
268
269        // Calculate overfetch factor for better fusion quality
270        let overfetch_k = match top_k {
271            0..=10 => top_k * 20,
272            11..=50 => top_k * 10,
273            51..=100 => top_k * 5,
274            _ => top_k * 2,
275        };
276
277        // Execute parallel batch search
278        let batch_results =
279            self.index
280                .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
281
282        // Apply filter if present (pre-fusion filtering)
283        let filtered_results: Vec<Vec<(u64, f32)>> = if let Some(f) = filter {
284            let payload_storage = self.payload_storage.read();
285            batch_results
286                .into_iter()
287                .map(|query_results| {
288                    query_results
289                        .into_iter()
290                        .filter(|(id, _score)| {
291                            if let Ok(Some(payload)) = payload_storage.retrieve(*id) {
292                                f.matches(&payload)
293                            } else {
294                                false
295                            }
296                        })
297                        .collect()
298                })
299                .collect()
300        } else {
301            batch_results
302        };
303
304        // Fuse results using the specified strategy
305        let fused = fusion
306            .fuse(filtered_results)
307            .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
308
309        // Fetch full point data for top_k results
310        let vector_storage = self.vector_storage.read();
311        let payload_storage = self.payload_storage.read();
312
313        let results: Vec<SearchResult> = fused
314            .into_iter()
315            .take(top_k)
316            .filter_map(|(id, score)| {
317                let vector = vector_storage.retrieve(id).ok().flatten()?;
318                let payload = payload_storage.retrieve(id).ok().flatten();
319
320                let point = Point {
321                    id,
322                    vector,
323                    payload,
324                };
325
326                Some(SearchResult::new(point, score))
327            })
328            .collect();
329
330        Ok(results)
331    }
332
333    /// Performs multi-query search returning only IDs and fused scores.
334    ///
335    /// This is a faster variant of `multi_query_search` that skips fetching
336    /// vector and payload data. Use when you only need document IDs.
337    ///
338    /// # Arguments
339    ///
340    /// * `vectors` - Slice of query vectors
341    /// * `top_k` - Maximum number of results
342    /// * `fusion` - Fusion strategy
343    ///
344    /// # Returns
345    ///
346    /// Vector of `(id, fused_score)` tuples sorted by score descending.
347    ///
348    /// # Errors
349    ///
350    /// Returns an error if vectors is empty, exceeds max limit, or has dimension mismatch.
351    #[allow(clippy::needless_pass_by_value)]
352    pub fn multi_query_search_ids(
353        &self,
354        vectors: &[&[f32]],
355        top_k: usize,
356        fusion: crate::fusion::FusionStrategy,
357    ) -> Result<Vec<(u64, f32)>> {
358        const MAX_VECTORS: usize = 10;
359
360        if vectors.is_empty() {
361            return Err(Error::Config(
362                "multi_query_search requires at least one vector".into(),
363            ));
364        }
365
366        if vectors.len() > MAX_VECTORS {
367            return Err(Error::Config(format!(
368                "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
369                vectors.len()
370            )));
371        }
372
373        let config = self.config.read();
374        let dimension = config.dimension;
375        drop(config);
376
377        for vector in vectors {
378            if vector.len() != dimension {
379                return Err(Error::DimensionMismatch {
380                    expected: dimension,
381                    actual: vector.len(),
382                });
383            }
384        }
385
386        let overfetch_k = match top_k {
387            0..=10 => top_k * 20,
388            11..=50 => top_k * 10,
389            51..=100 => top_k * 5,
390            _ => top_k * 2,
391        };
392
393        let batch_results =
394            self.index
395                .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
396
397        let fused = fusion
398            .fuse(batch_results)
399            .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
400
401        Ok(fused.into_iter().take(top_k).collect())
402    }
403}