velesdb_core/collection/search/
text.rs1use super::OrderedFloat;
4use crate::collection::types::Collection;
5use crate::error::{Error, Result};
6use crate::point::{Point, SearchResult};
7use crate::storage::{PayloadStorage, VectorStorage};
8
9impl Collection {
10 #[must_use]
21 pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
22 let bm25_results = self.text_index.search(query, k);
23
24 let vector_storage = self.vector_storage.read();
25 let payload_storage = self.payload_storage.read();
26
27 bm25_results
28 .into_iter()
29 .filter_map(|(id, score)| {
30 let vector = vector_storage.retrieve(id).ok().flatten()?;
31 let payload = payload_storage.retrieve(id).ok().flatten();
32
33 let point = Point {
34 id,
35 vector,
36 payload,
37 sparse_vectors: None,
38 };
39
40 Some(SearchResult::new(point, score))
41 })
42 .collect()
43 }
44
45 #[must_use]
57 pub fn text_search_with_filter(
58 &self,
59 query: &str,
60 k: usize,
61 filter: &crate::filter::Filter,
62 ) -> Vec<SearchResult> {
63 let candidates_k = k.saturating_mul(4).max(k + 10);
65 let bm25_results = self.text_index.search(query, candidates_k);
66
67 let vector_storage = self.vector_storage.read();
68 let payload_storage = self.payload_storage.read();
69
70 bm25_results
71 .into_iter()
72 .filter_map(|(id, score)| {
73 let vector = vector_storage.retrieve(id).ok().flatten()?;
74 let payload = payload_storage.retrieve(id).ok().flatten();
75
76 let payload_ref = payload.as_ref()?;
78 if !filter.matches(payload_ref) {
79 return None;
80 }
81
82 let point = Point {
83 id,
84 vector,
85 payload,
86 sparse_vectors: None,
87 };
88
89 Some(SearchResult::new(point, score))
90 })
91 .take(k)
92 .collect()
93 }
94
95 pub fn hybrid_search(
116 &self,
117 vector_query: &[f32],
118 text_query: &str,
119 k: usize,
120 vector_weight: Option<f32>,
121 ) -> Result<Vec<SearchResult>> {
122 use crate::index::VectorIndex;
123 use std::cmp::Reverse;
124 use std::collections::BinaryHeap;
125
126 let config = self.config.read();
127 if vector_query.len() != config.dimension {
128 return Err(Error::DimensionMismatch {
129 expected: config.dimension,
130 actual: vector_query.len(),
131 });
132 }
133 let metric = config.metric;
134 drop(config);
135
136 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
137 let text_weight = 1.0 - weight;
138
139 let overfetch_k = k * 2;
142 let raw_vector_results = self.index.search(vector_query, overfetch_k);
143 let vector_results =
144 self.merge_delta(raw_vector_results, vector_query, overfetch_k, metric);
145
146 let text_results = self.text_index.search(text_query, k * 2);
148
149 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> =
152 rustc_hash::FxHashMap::with_capacity_and_hasher(
153 vector_results.len() + text_results.len(),
154 rustc_hash::FxBuildHasher,
155 );
156
157 #[allow(clippy::cast_precision_loss)]
159 for (rank, (id, _)) in vector_results.iter().enumerate() {
160 let rrf_score = weight / (rank as f32 + 60.0);
161 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
162 }
163
164 #[allow(clippy::cast_precision_loss)]
166 for (rank, (id, _)) in text_results.iter().enumerate() {
167 let rrf_score = text_weight / (rank as f32 + 60.0);
168 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
169 }
170
171 let mut top_k: BinaryHeap<Reverse<(OrderedFloat, u64)>> = BinaryHeap::with_capacity(k + 1);
174
175 for (id, score) in fused_scores {
176 top_k.push(Reverse((OrderedFloat(score), id)));
177 if top_k.len() > k {
178 top_k.pop(); }
180 }
181
182 let mut scored_ids: Vec<(u64, f32)> = top_k
184 .into_iter()
185 .map(|Reverse((OrderedFloat(s), id))| (id, s))
186 .collect();
187 scored_ids.sort_by(|a, b| b.1.total_cmp(&a.1));
188
189 let vector_storage = self.vector_storage.read();
191 let payload_storage = self.payload_storage.read();
192
193 let results: Vec<SearchResult> = scored_ids
194 .into_iter()
195 .filter_map(|(id, score)| {
196 let vector = vector_storage.retrieve(id).ok().flatten()?;
197 let payload = payload_storage.retrieve(id).ok().flatten();
198
199 let point = Point {
200 id,
201 vector,
202 payload,
203 sparse_vectors: None,
204 };
205
206 Some(SearchResult::new(point, score))
207 })
208 .collect();
209
210 Ok(results)
211 }
212
213 pub fn hybrid_search_with_filter(
230 &self,
231 vector_query: &[f32],
232 text_query: &str,
233 k: usize,
234 vector_weight: Option<f32>,
235 filter: &crate::filter::Filter,
236 ) -> Result<Vec<SearchResult>> {
237 use crate::index::VectorIndex;
238
239 let config = self.config.read();
240 if vector_query.len() != config.dimension {
241 return Err(Error::DimensionMismatch {
242 expected: config.dimension,
243 actual: vector_query.len(),
244 });
245 }
246 let metric = config.metric;
247 drop(config);
248
249 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
250 let text_weight = 1.0 - weight;
251
252 let candidates_k = k.saturating_mul(4).max(k + 10);
255
256 let raw_vector_results = self.index.search(vector_query, candidates_k);
258 let vector_results =
259 self.merge_delta(raw_vector_results, vector_query, candidates_k, metric);
260
261 let text_results = self.text_index.search(text_query, candidates_k);
263
264 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
266
267 #[allow(clippy::cast_precision_loss)]
268 for (rank, (id, _)) in vector_results.iter().enumerate() {
269 let rrf_score = weight / (rank as f32 + 60.0);
270 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
271 }
272
273 #[allow(clippy::cast_precision_loss)]
274 for (rank, (id, _)) in text_results.iter().enumerate() {
275 let rrf_score = text_weight / (rank as f32 + 60.0);
276 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
277 }
278
279 let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
281 scored_ids.sort_by(|a, b| b.1.total_cmp(&a.1));
282
283 let vector_storage = self.vector_storage.read();
285 let payload_storage = self.payload_storage.read();
286
287 let results: Vec<SearchResult> = scored_ids
288 .into_iter()
289 .filter_map(|(id, score)| {
290 let vector = vector_storage.retrieve(id).ok().flatten()?;
291 let payload = payload_storage.retrieve(id).ok().flatten();
292
293 let payload_ref = payload.as_ref()?;
295 if !filter.matches(payload_ref) {
296 return None;
297 }
298
299 let point = Point {
300 id,
301 vector,
302 payload,
303 sparse_vectors: None,
304 };
305
306 Some(SearchResult::new(point, score))
307 })
308 .take(k)
309 .collect();
310
311 Ok(results)
312 }
313}