velesdb_core/collection/search/
text.rs1use super::OrderedFloat;
4use crate::collection::types::Collection;
5use crate::error::{Error, Result};
6use crate::point::{Point, SearchResult};
7use crate::storage::{PayloadStorage, VectorStorage};
8
9impl Collection {
10 #[must_use]
21 pub fn text_search(&self, query: &str, k: usize) -> Vec<SearchResult> {
22 let bm25_results = self.text_index.search(query, k);
23
24 let vector_storage = self.vector_storage.read();
25 let payload_storage = self.payload_storage.read();
26
27 bm25_results
28 .into_iter()
29 .filter_map(|(id, score)| {
30 let vector = vector_storage.retrieve(id).ok().flatten()?;
31 let payload = payload_storage.retrieve(id).ok().flatten();
32
33 let point = Point {
34 id,
35 vector,
36 payload,
37 };
38
39 Some(SearchResult::new(point, score))
40 })
41 .collect()
42 }
43
44 #[must_use]
56 pub fn text_search_with_filter(
57 &self,
58 query: &str,
59 k: usize,
60 filter: &crate::filter::Filter,
61 ) -> Vec<SearchResult> {
62 let candidates_k = k.saturating_mul(4).max(k + 10);
64 let bm25_results = self.text_index.search(query, candidates_k);
65
66 let vector_storage = self.vector_storage.read();
67 let payload_storage = self.payload_storage.read();
68
69 bm25_results
70 .into_iter()
71 .filter_map(|(id, score)| {
72 let vector = vector_storage.retrieve(id).ok().flatten()?;
73 let payload = payload_storage.retrieve(id).ok().flatten();
74
75 let payload_ref = payload.as_ref()?;
77 if !filter.matches(payload_ref) {
78 return None;
79 }
80
81 let point = Point {
82 id,
83 vector,
84 payload,
85 };
86
87 Some(SearchResult::new(point, score))
88 })
89 .take(k)
90 .collect()
91 }
92
93 pub fn hybrid_search(
114 &self,
115 vector_query: &[f32],
116 text_query: &str,
117 k: usize,
118 vector_weight: Option<f32>,
119 ) -> Result<Vec<SearchResult>> {
120 use crate::index::VectorIndex;
121 use std::cmp::Reverse;
122 use std::collections::BinaryHeap;
123
124 let config = self.config.read();
125 if vector_query.len() != config.dimension {
126 return Err(Error::DimensionMismatch {
127 expected: config.dimension,
128 actual: vector_query.len(),
129 });
130 }
131 drop(config);
132
133 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
134 let text_weight = 1.0 - weight;
135
136 let vector_results = self.index.search(vector_query, k * 2);
138
139 let text_results = self.text_index.search(text_query, k * 2);
141
142 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> =
145 rustc_hash::FxHashMap::with_capacity_and_hasher(
146 vector_results.len() + text_results.len(),
147 rustc_hash::FxBuildHasher,
148 );
149
150 #[allow(clippy::cast_precision_loss)]
152 for (rank, (id, _)) in vector_results.iter().enumerate() {
153 let rrf_score = weight / (rank as f32 + 60.0);
154 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
155 }
156
157 #[allow(clippy::cast_precision_loss)]
159 for (rank, (id, _)) in text_results.iter().enumerate() {
160 let rrf_score = text_weight / (rank as f32 + 60.0);
161 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
162 }
163
164 let mut top_k: BinaryHeap<Reverse<(OrderedFloat, u64)>> = BinaryHeap::with_capacity(k + 1);
167
168 for (id, score) in fused_scores {
169 top_k.push(Reverse((OrderedFloat(score), id)));
170 if top_k.len() > k {
171 top_k.pop(); }
173 }
174
175 let mut scored_ids: Vec<(u64, f32)> = top_k
177 .into_iter()
178 .map(|Reverse((OrderedFloat(s), id))| (id, s))
179 .collect();
180 scored_ids.sort_by(|a, b| b.1.total_cmp(&a.1));
181
182 let vector_storage = self.vector_storage.read();
184 let payload_storage = self.payload_storage.read();
185
186 let results: Vec<SearchResult> = scored_ids
187 .into_iter()
188 .filter_map(|(id, score)| {
189 let vector = vector_storage.retrieve(id).ok().flatten()?;
190 let payload = payload_storage.retrieve(id).ok().flatten();
191
192 let point = Point {
193 id,
194 vector,
195 payload,
196 };
197
198 Some(SearchResult::new(point, score))
199 })
200 .collect();
201
202 Ok(results)
203 }
204
205 pub fn hybrid_search_with_filter(
222 &self,
223 vector_query: &[f32],
224 text_query: &str,
225 k: usize,
226 vector_weight: Option<f32>,
227 filter: &crate::filter::Filter,
228 ) -> Result<Vec<SearchResult>> {
229 use crate::index::VectorIndex;
230
231 let config = self.config.read();
232 if vector_query.len() != config.dimension {
233 return Err(Error::DimensionMismatch {
234 expected: config.dimension,
235 actual: vector_query.len(),
236 });
237 }
238 drop(config);
239
240 let weight = vector_weight.unwrap_or(0.5).clamp(0.0, 1.0);
241 let text_weight = 1.0 - weight;
242
243 let candidates_k = k.saturating_mul(4).max(k + 10);
245
246 let vector_results = self.index.search(vector_query, candidates_k);
248
249 let text_results = self.text_index.search(text_query, candidates_k);
251
252 let mut fused_scores: rustc_hash::FxHashMap<u64, f32> = rustc_hash::FxHashMap::default();
254
255 #[allow(clippy::cast_precision_loss)]
256 for (rank, (id, _)) in vector_results.iter().enumerate() {
257 let rrf_score = weight / (rank as f32 + 60.0);
258 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
259 }
260
261 #[allow(clippy::cast_precision_loss)]
262 for (rank, (id, _)) in text_results.iter().enumerate() {
263 let rrf_score = text_weight / (rank as f32 + 60.0);
264 *fused_scores.entry(*id).or_insert(0.0) += rrf_score;
265 }
266
267 let mut scored_ids: Vec<_> = fused_scores.into_iter().collect();
269 scored_ids.sort_by(|a, b| b.1.total_cmp(&a.1));
270
271 let vector_storage = self.vector_storage.read();
273 let payload_storage = self.payload_storage.read();
274
275 let results: Vec<SearchResult> = scored_ids
276 .into_iter()
277 .filter_map(|(id, score)| {
278 let vector = vector_storage.retrieve(id).ok().flatten()?;
279 let payload = payload_storage.retrieve(id).ok().flatten();
280
281 let payload_ref = payload.as_ref()?;
283 if !filter.matches(payload_ref) {
284 return None;
285 }
286
287 let point = Point {
288 id,
289 vector,
290 payload,
291 };
292
293 Some(SearchResult::new(point, score))
294 })
295 .take(k)
296 .collect();
297
298 Ok(results)
299 }
300}