velesdb_core/collection/search/
batch.rs1use crate::collection::types::Collection;
4use crate::error::{Error, Result};
5use crate::index::SearchQuality;
6use crate::point::{Point, SearchResult};
7use crate::storage::{PayloadStorage, VectorStorage};
8
9impl Collection {
10 pub fn search_batch_with_filters(
27 &self,
28 queries: &[&[f32]],
29 k: usize,
30 filters: &[Option<crate::filter::Filter>],
31 ) -> Result<Vec<Vec<SearchResult>>> {
32 if queries.len() != filters.len() {
33 return Err(Error::Config(format!(
34 "Queries count ({}) does not match filters count ({})",
35 queries.len(),
36 filters.len()
37 )));
38 }
39
40 let config = self.config.read();
41 let dimension = config.dimension;
42 drop(config);
43
44 for query in queries {
46 if query.len() != dimension {
47 return Err(Error::DimensionMismatch {
48 expected: dimension,
49 actual: query.len(),
50 });
51 }
52 }
53
54 let candidates_k = k.saturating_mul(4).max(k + 10);
56 let metric = self.config.read().metric;
57 let index_results =
58 self.index
59 .search_batch_parallel(queries, candidates_k, SearchQuality::Balanced);
60
61 let vector_storage = self.vector_storage.read();
62 let payload_storage = self.payload_storage.read();
63
64 let mut all_results = Vec::with_capacity(queries.len());
65
66 for ((query_results, filter_opt), query) in
67 index_results.into_iter().zip(filters).zip(queries)
68 {
69 let query_results = self.merge_delta(query_results, query, candidates_k, metric);
71 let mut filtered_results: Vec<SearchResult> = query_results
72 .into_iter()
73 .filter_map(|(id, score)| {
74 let payload = payload_storage.retrieve(id).ok().flatten();
75
76 if let Some(ref filter) = filter_opt {
78 if let Some(ref p) = payload {
79 if !filter.matches(p) {
80 return None;
81 }
82 } else if !filter.matches(&serde_json::Value::Null) {
83 return None;
84 }
85 }
86
87 let vector = vector_storage.retrieve(id).ok().flatten()?;
88 Some(SearchResult {
89 point: Point {
90 id,
91 vector,
92 payload,
93 sparse_vectors: None,
94 },
95 score,
96 })
97 })
98 .collect();
99
100 let higher_is_better = self.config.read().metric.higher_is_better();
102 if higher_is_better {
103 filtered_results.sort_by(|a, b| {
104 b.score
105 .partial_cmp(&a.score)
106 .unwrap_or(std::cmp::Ordering::Equal)
107 });
108 } else {
109 filtered_results.sort_by(|a, b| {
110 a.score
111 .partial_cmp(&b.score)
112 .unwrap_or(std::cmp::Ordering::Equal)
113 });
114 }
115 filtered_results.truncate(k);
116
117 all_results.push(filtered_results);
118 }
119
120 Ok(all_results)
121 }
122
123 pub fn search_batch_with_filter(
135 &self,
136 queries: &[&[f32]],
137 k: usize,
138 filter: &crate::filter::Filter,
139 ) -> Result<Vec<Vec<SearchResult>>> {
140 let filters: Vec<Option<crate::filter::Filter>> = vec![Some(filter.clone()); queries.len()];
141 self.search_batch_with_filters(queries, k, &filters)
142 }
143
144 pub fn search_batch_parallel(
161 &self,
162 queries: &[&[f32]],
163 k: usize,
164 ) -> Result<Vec<Vec<SearchResult>>> {
165 let config = self.config.read();
166 let dimension = config.dimension;
167 drop(config);
168
169 for query in queries {
171 if query.len() != dimension {
172 return Err(Error::DimensionMismatch {
173 expected: dimension,
174 actual: query.len(),
175 });
176 }
177 }
178
179 let metric = self.config.read().metric;
181 let index_results = self
182 .index
183 .search_batch_parallel(queries, k, SearchQuality::Balanced);
184
185 let vector_storage = self.vector_storage.read();
187 let payload_storage = self.payload_storage.read();
188
189 let results: Vec<Vec<SearchResult>> = index_results
190 .into_iter()
191 .zip(queries)
192 .map(|(query_results, query): (Vec<(u64, f32)>, &&[f32])| {
193 let query_results = self.merge_delta(query_results, query, k, metric);
195 query_results
196 .into_iter()
197 .filter_map(|(id, score)| {
198 let vector = vector_storage.retrieve(id).ok().flatten()?;
199 let payload = payload_storage.retrieve(id).ok().flatten();
200 Some(SearchResult {
201 point: Point {
202 id,
203 vector,
204 payload,
205 sparse_vectors: None,
206 },
207 score,
208 })
209 })
210 .collect()
211 })
212 .collect();
213
214 Ok(results)
215 }
216
217 #[allow(clippy::needless_pass_by_value)]
242 #[allow(clippy::too_many_lines)]
243 pub fn multi_query_search(
244 &self,
245 vectors: &[&[f32]],
246 top_k: usize,
247 fusion: crate::fusion::FusionStrategy,
248 filter: Option<&crate::filter::Filter>,
249 ) -> Result<Vec<SearchResult>> {
250 const MAX_VECTORS: usize = 10;
251
252 if vectors.is_empty() {
254 return Err(Error::Config(
255 "multi_query_search requires at least one vector".into(),
256 ));
257 }
258
259 if vectors.len() > MAX_VECTORS {
261 return Err(Error::Config(format!(
262 "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
263 vectors.len()
264 )));
265 }
266
267 let config = self.config.read();
269 let dimension = config.dimension;
270 drop(config);
271
272 for vector in vectors {
273 if vector.len() != dimension {
274 return Err(Error::DimensionMismatch {
275 expected: dimension,
276 actual: vector.len(),
277 });
278 }
279 }
280
281 let overfetch_k = match top_k {
283 0..=10 => top_k * 20,
284 11..=50 => top_k * 10,
285 51..=100 => top_k * 5,
286 _ => top_k * 2,
287 };
288
289 let metric = self.config.read().metric;
290
291 let batch_results =
293 self.index
294 .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
295
296 let batch_results: Vec<Vec<(u64, f32)>> = batch_results
298 .into_iter()
299 .zip(vectors)
300 .map(|(query_results, query)| {
301 self.merge_delta(query_results, query, overfetch_k, metric)
302 })
303 .collect();
304
305 let filtered_results: Vec<Vec<(u64, f32)>> = if let Some(f) = filter {
307 let payload_storage = self.payload_storage.read();
308 batch_results
309 .into_iter()
310 .map(|query_results| {
311 query_results
312 .into_iter()
313 .filter(|(id, _score)| {
314 if let Ok(Some(payload)) = payload_storage.retrieve(*id) {
315 f.matches(&payload)
316 } else {
317 false
318 }
319 })
320 .collect()
321 })
322 .collect()
323 } else {
324 batch_results
325 };
326
327 let fused = fusion
329 .fuse(filtered_results)
330 .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
331
332 let vector_storage = self.vector_storage.read();
334 let payload_storage = self.payload_storage.read();
335
336 let results: Vec<SearchResult> = fused
337 .into_iter()
338 .take(top_k)
339 .filter_map(|(id, score)| {
340 let vector = vector_storage.retrieve(id).ok().flatten()?;
341 let payload = payload_storage.retrieve(id).ok().flatten();
342
343 let point = Point {
344 id,
345 vector,
346 payload,
347 sparse_vectors: None,
348 };
349
350 Some(SearchResult::new(point, score))
351 })
352 .collect();
353
354 Ok(results)
355 }
356
357 #[allow(clippy::needless_pass_by_value)]
376 pub fn multi_query_search_ids(
377 &self,
378 vectors: &[&[f32]],
379 top_k: usize,
380 fusion: crate::fusion::FusionStrategy,
381 ) -> Result<Vec<(u64, f32)>> {
382 const MAX_VECTORS: usize = 10;
383
384 if vectors.is_empty() {
385 return Err(Error::Config(
386 "multi_query_search requires at least one vector".into(),
387 ));
388 }
389
390 if vectors.len() > MAX_VECTORS {
391 return Err(Error::Config(format!(
392 "multi_query_search supports at most {MAX_VECTORS} vectors, got {}",
393 vectors.len()
394 )));
395 }
396
397 let config = self.config.read();
398 let dimension = config.dimension;
399 drop(config);
400
401 for vector in vectors {
402 if vector.len() != dimension {
403 return Err(Error::DimensionMismatch {
404 expected: dimension,
405 actual: vector.len(),
406 });
407 }
408 }
409
410 let overfetch_k = match top_k {
411 0..=10 => top_k * 20,
412 11..=50 => top_k * 10,
413 51..=100 => top_k * 5,
414 _ => top_k * 2,
415 };
416
417 let metric = self.config.read().metric;
418
419 let batch_results =
420 self.index
421 .search_batch_parallel(vectors, overfetch_k, crate::SearchQuality::Balanced);
422
423 let batch_results: Vec<Vec<(u64, f32)>> = batch_results
425 .into_iter()
426 .zip(vectors)
427 .map(|(query_results, query)| {
428 self.merge_delta(query_results, query, overfetch_k, metric)
429 })
430 .collect();
431
432 let fused = fusion
433 .fuse(batch_results)
434 .map_err(|e| Error::Config(format!("Fusion error: {e}")))?;
435
436 Ok(fused.into_iter().take(top_k).collect())
437 }
438}