velesdb_core/collection/search/batch.rs
1//! Batch and multi-query search methods for Collection.
2
3use crate::collection::types::Collection;
4use crate::error::{Error, Result};
5use crate::index::SearchQuality;
6use crate::point::{Point, SearchResult};
7use crate::storage::{PayloadStorage, VectorStorage};
8
9impl Collection {
10 /// Performs batch search for multiple query vectors in parallel with metadata filtering.
11 /// Supports a different filter for each query in the batch.
12 ///
13 /// # Arguments
14 ///
15 /// * `queries` - List of query vector slices
16 /// * `k` - Maximum number of results per query
17 /// * `filters` - List of optional filters (must match queries length)
18 ///
19 /// # Returns
20 ///
21 /// Vector of search results for each query, matching its respective filter.
22 ///
23 /// # Errors
24 ///
25 /// Returns an error if queries and filters have different lengths or dimension mismatch.
26 pub fn search_batch_with_filters(
27 &self,
28 queries: &[&[f32]],
29 k: usize,
30 filters: &[Option<crate::filter::Filter>],
31 ) -> Result<Vec<Vec<SearchResult>>> {
32 if queries.len() != filters.len() {
33 return Err(Error::Config(format!(
34 "Queries count ({}) does not match filters count ({})",
35 queries.len(),
36 filters.len()
37 )));
38 }
39
40 let config = self.config.read();
41 let dimension = config.dimension;
42 drop(config);
43
44 // Validate all query dimensions
45 for query in queries {
46 if query.len() != dimension {
47 return Err(Error::DimensionMismatch {
48 expected: dimension,
49 actual: query.len(),
50 });
51 }
52 }
53
54 // We need to retrieve more candidates for post-filtering
55 let candidates_k = k.saturating_mul(4).max(k + 10);
56 let index_results =
57 self.index
58 .search_batch_parallel(queries, candidates_k, SearchQuality::Balanced);
59
60 let vector_storage = self.vector_storage.read();
61 let payload_storage = self.payload_storage.read();
62
63 let mut all_results = Vec::with_capacity(queries.len());
64
65 for (query_results, filter_opt) in index_results.into_iter().zip(filters) {
66 let mut filtered_results: Vec<SearchResult> = query_results
67 .into_iter()
68 .filter_map(|(id, score)| {
69 let payload = payload_storage.retrieve(id).ok().flatten();
70
71 // Apply filter if present
72 if let Some(ref filter) = filter_opt {
73 if let Some(ref p) = payload {
74 if !filter.matches(p) {
75 return None;
76 }
77 } else if !filter.matches(&serde_json::Value::Null) {
78 return None;
79 }
80 }
81
82 let vector = vector_storage.retrieve(id).ok().flatten()?;
83 Some(SearchResult {
84 point: Point {
85 id,
86 vector,
87 payload,
88 },
89 score,
90 })
91 })
92 .collect();
93
94 // Sort and truncate to k
95 let higher_is_better = self.config.read().metric.higher_is_better();
96 if higher_is_better {
97 filtered_results.sort_by(|a, b| {
98 b.score
99 .partial_cmp(&a.score)
100 .unwrap_or(std::cmp::Ordering::Equal)
101 });
102 } else {
103 filtered_results.sort_by(|a, b| {
104 a.score
105 .partial_cmp(&b.score)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108 }
109 filtered_results.truncate(k);
110
111 all_results.push(filtered_results);
112 }
113
114 Ok(all_results)
115 }
116
117 /// Performs batch search for multiple query vectors in parallel with a single metadata filter.
118 ///
119 /// # Arguments
120 ///
121 /// * `queries` - List of query vector slices
122 /// * `k` - Maximum number of results per query
123 /// * `filter` - Metadata filter to apply to all results
124 ///
125 /// # Errors
126 ///
127 /// Returns an error if any query has incorrect dimension.
128 pub fn search_batch_with_filter(
129 &self,
130 queries: &[&[f32]],
131 k: usize,
132 filter: &crate::filter::Filter,
133 ) -> Result<Vec<Vec<SearchResult>>> {
134 let filters: Vec<Option<crate::filter::Filter>> = vec![Some(filter.clone()); queries.len()];
135 self.search_batch_with_filters(queries, k, &filters)
136 }
137
138 /// Performs batch search for multiple query vectors in parallel.
139 ///
140 /// This method is optimized for high throughput using parallel index traversal.
141 ///
142 /// # Arguments
143 ///
144 /// * `queries` - List of query vector slices
145 /// * `k` - Maximum number of results per query
146 ///
147 /// # Returns
148 ///
149 /// Vector of search results for each query, with full point data.
150 ///
151 /// # Errors
152 ///
153 /// Returns an error if any query vector dimension doesn't match the collection.
154 pub fn search_batch_parallel(
155 &self,
156 queries: &[&[f32]],
157 k: usize,
158 ) -> Result<Vec<Vec<SearchResult>>> {
159 let config = self.config.read();
160 let dimension = config.dimension;
161 drop(config);
162
163 // Validate all query dimensions first
164 for query in queries {
165 if query.len() != dimension {
166 return Err(Error::DimensionMismatch {
167 expected: dimension,
168 actual: query.len(),
169 });
170 }
171 }
172
173 // Perf: Use parallel HNSW search (P0 optimization)
174 let index_results = self
175 .index
176 .search_batch_parallel(queries, k, SearchQuality::Balanced);
177
178 // Map results to SearchResult with full point data
179 let vector_storage = self.vector_storage.read();
180 let payload_storage = self.payload_storage.read();
181
182 let results: Vec<Vec<SearchResult>> = index_results
183 .into_iter()
184 .map(|query_results: Vec<(u64, f32)>| {
185 query_results
186 .into_iter()
187 .filter_map(|(id, score)| {
188 let vector = vector_storage.retrieve(id).ok().flatten()?;
189 let payload = payload_storage.retrieve(id).ok().flatten();
190 Some(SearchResult {
191 point: Point {
192 id,
193 vector,
194 payload,
195 },
196 score,
197 })
198 })
199 .collect()
200 })
201 .collect();
202
203 Ok(results)
204 }
205
206 /// Performs multi-query search with result fusion.
207 ///
208 /// This method executes parallel searches for multiple query vectors and fuses
209 /// the results using the specified fusion strategy. Ideal for Multiple Query
210 /// Generation (MQG) pipelines where multiple reformulations of a user query
211 /// are searched simultaneously.
212 ///
213 /// # Arguments
214 ///
215 /// * `vectors` - Slice of query vectors (all must have same dimension)
216 /// * `top_k` - Maximum number of results to return after fusion
217 /// * `fusion` - Strategy for combining results (Average, Maximum, RRF, Weighted)
218 /// * `filter` - Optional metadata filter to apply to all queries
219 ///
220 /// # Returns
221 ///
222 /// Vector of `SearchResult` sorted by fused score descending.
223 ///
224 /// # Errors
225 ///
226 /// Returns an error if:
227 /// - `vectors` is empty
228 /// - Any vector has incorrect dimension
229 /// - More than 10 vectors are provided (configurable limit)
230 #[allow(clippy::needless_pass_by_value)]
231 pub fn multi_query_search(
232 &self,
233 vectors: &[&[f32]],
234 top_k: usize,
235 fusion: crate::fusion::FusionStrategy,
236 filter: Option<&crate::filter::Filter>,
237 ) -> Result<Vec<SearchResult>> {
238 const MAX_VECTORS: usize = 10;
239
240 // Validation: non-empty
241 if vectors.is_empty() {
242 return Err(Error::Config(
243 "multi_query_search requires at least one vector".into(),
244 ));
245 }
246
247 // Validation: max vectors limit
248 if vectors.len() > MAX_VECTORS {
249 return Err(Error::Config(format!(
250 "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
251 vectors.len()
252 )));
253 }
254
255 // Validation: dimension consistency
256 let config = self.config.read();
257 let dimension = config.dimension;
258 drop(config);
259
260 for vector in vectors {
261 if vector.len() != dimension {
262 return Err(Error::DimensionMismatch {
263 expected: dimension,
264 actual: vector.len(),
265 });
266 }
267 }
268
269 // Calculate overfetch factor for better fusion quality
270 let overfetch_k = match top_k {
271 0..=10 => top_k * 20,
272 11..=50 => top_k * 10,
273 51..=100 => top_k * 5,
274 _ => top_k * 2,
275 };
276
277 // Execute parallel batch search
278 let batch_results =
279 self.index
280 .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
281
282 // Apply filter if present (pre-fusion filtering)
283 let filtered_results: Vec<Vec<(u64, f32)>> = if let Some(f) = filter {
284 let payload_storage = self.payload_storage.read();
285 batch_results
286 .into_iter()
287 .map(|query_results| {
288 query_results
289 .into_iter()
290 .filter(|(id, _score)| {
291 if let Ok(Some(payload)) = payload_storage.retrieve(*id) {
292 f.matches(&payload)
293 } else {
294 false
295 }
296 })
297 .collect()
298 })
299 .collect()
300 } else {
301 batch_results
302 };
303
304 // Fuse results using the specified strategy
305 let fused = fusion
306 .fuse(filtered_results)
307 .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
308
309 // Fetch full point data for top_k results
310 let vector_storage = self.vector_storage.read();
311 let payload_storage = self.payload_storage.read();
312
313 let results: Vec<SearchResult> = fused
314 .into_iter()
315 .take(top_k)
316 .filter_map(|(id, score)| {
317 let vector = vector_storage.retrieve(id).ok().flatten()?;
318 let payload = payload_storage.retrieve(id).ok().flatten();
319
320 let point = Point {
321 id,
322 vector,
323 payload,
324 };
325
326 Some(SearchResult::new(point, score))
327 })
328 .collect();
329
330 Ok(results)
331 }
332
333 /// Performs multi-query search returning only IDs and fused scores.
334 ///
335 /// This is a faster variant of `multi_query_search` that skips fetching
336 /// vector and payload data. Use when you only need document IDs.
337 ///
338 /// # Arguments
339 ///
340 /// * `vectors` - Slice of query vectors
341 /// * `top_k` - Maximum number of results
342 /// * `fusion` - Fusion strategy
343 ///
344 /// # Returns
345 ///
346 /// Vector of `(id, fused_score)` tuples sorted by score descending.
347 ///
348 /// # Errors
349 ///
350 /// Returns an error if vectors is empty, exceeds max limit, or has dimension mismatch.
351 #[allow(clippy::needless_pass_by_value)]
352 pub fn multi_query_search_ids(
353 &self,
354 vectors: &[&[f32]],
355 top_k: usize,
356 fusion: crate::fusion::FusionStrategy,
357 ) -> Result<Vec<(u64, f32)>> {
358 const MAX_VECTORS: usize = 10;
359
360 if vectors.is_empty() {
361 return Err(Error::Config(
362 "multi_query_search requires at least one vector".into(),
363 ));
364 }
365
366 if vectors.len() > MAX_VECTORS {
367 return Err(Error::Config(format!(
368 "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
369 vectors.len()
370 )));
371 }
372
373 let config = self.config.read();
374 let dimension = config.dimension;
375 drop(config);
376
377 for vector in vectors {
378 if vector.len() != dimension {
379 return Err(Error::DimensionMismatch {
380 expected: dimension,
381 actual: vector.len(),
382 });
383 }
384 }
385
386 let overfetch_k = match top_k {
387 0..=10 => top_k * 20,
388 11..=50 => top_k * 10,
389 51..=100 => top_k * 5,
390 _ => top_k * 2,
391 };
392
393 let batch_results =
394 self.index
395 .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
396
397 let fused = fusion
398 .fuse(batch_results)
399 .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
400
401 Ok(fused.into_iter().take(top_k).collect())
402 }
403}