1use std::borrow::Cow;
10use std::cmp::Ordering;
11use std::collections::{BTreeSet, BinaryHeap};
12use std::time::Duration;
13
14use roaring::RoaringBitmap;
15use selene_core::{CancellationCause, CancellationChecker, DbString, NodeId, Value};
16
17use crate::error::{GraphError, GraphResult};
18use crate::graph::SeleneGraph;
19use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
20use crate::shared::SharedGraph;
21use crate::store::RowIndex;
22
/// Rows processed between cancellation checks by the serial text scan.
pub(crate) const TEXT_SEARCH_CANCEL_STRIDE: usize = 1_024;

/// Rows per chunk handed to `try_reduce_bitmap_chunks` on the parallel path.
/// A much smaller value is used under `cfg(test)`, presumably so small fixtures span
/// several chunks.
#[cfg(not(test))]
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = 2_048;
#[cfg(test)]
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = 4;

/// Row-count threshold passed to `should_parallelize_scan` when choosing between the
/// serial and parallel scans. Lowered under `cfg(test)`, presumably so tests can reach
/// the parallel path.
#[cfg(not(test))]
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = 16 * 1_024;
#[cfg(test)]
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = 8;

/// BM25 term-frequency saturation parameter `k1`.
const BM25_K1: f64 = 1.2;
/// BM25 document-length normalization parameter `b`.
const BM25_B: f64 = 0.75;
35
/// A single result of an exact text search.
#[derive(Clone, Debug, PartialEq)]
pub struct TextSearchHit {
    /// Node whose property text produced this hit.
    pub node_id: NodeId,
    /// BM25 relevance score; result lists are ordered by descending score.
    pub score: f64,
}
44
/// Errors returned by the cancellable (`*_checked`) text-search entry points.
#[derive(Debug, thiserror::Error)]
pub enum TextSearchError {
    /// A graph-level failure, e.g. a label-index row with no node id or no property row.
    #[error(transparent)]
    Graph(#[from] GraphError),
    /// The [`CancellationChecker`] reported `CancellationCause::Cancelled`.
    #[error("text search cancelled")]
    Cancelled,
    /// The [`CancellationChecker`] reported `CancellationCause::Timeout`.
    #[error("text search timed out after {elapsed:?}")]
    Timeout {
        /// Elapsed time carried by the checker's timeout cause.
        elapsed: Duration,
    },
}
61
62impl TextSearchError {
63 fn into_graph_error(self) -> GraphError {
64 match self {
65 Self::Graph(error) => error,
66 Self::Cancelled | Self::Timeout { .. } => GraphError::Inconsistent {
67 reason: format!("disabled text-search checker returned {self}"),
68 },
69 }
70 }
71}
72
73impl From<CancellationCause> for TextSearchError {
74 fn from(cause: CancellationCause) -> Self {
75 match cause {
76 CancellationCause::Cancelled => Self::Cancelled,
77 CancellationCause::Timeout { elapsed } => Self::Timeout { elapsed },
78 }
79 }
80}
81
82impl SeleneGraph {
83 pub fn exact_text_search_nodes(
91 &self,
92 label: &DbString,
93 property: &DbString,
94 query: &str,
95 k: usize,
96 ) -> GraphResult<Vec<TextSearchHit>> {
97 self.exact_text_search_nodes_checked(
98 label,
99 property,
100 query,
101 k,
102 CancellationChecker::disabled(),
103 )
104 .map_err(TextSearchError::into_graph_error)
105 }
106
107 pub fn exact_text_search_nodes_checked(
109 &self,
110 label: &DbString,
111 property: &DbString,
112 query: &str,
113 k: usize,
114 checker: CancellationChecker<'_>,
115 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
116 checker.check()?;
117 if k == 0 {
118 return Ok(Vec::new());
119 }
120 let query_terms = unique_query_terms(query);
121 if query_terms.is_empty() {
122 return Ok(Vec::new());
123 }
124 let Some(label_rows) = self.nodes_with_label(label) else {
125 return Ok(Vec::new());
126 };
127
128 let scan = TextScan::new(self, label, property, &query_terms);
129 let chunk = if should_parallelize_text_scan(label_rows, k) {
130 exact_text_scan_parallel(scan, label_rows, checker)?
131 } else {
132 exact_text_scan_serial(scan, label_rows, checker)?
133 };
134 Ok(rank_text_docs(chunk, k))
135 }
136}
137
138impl SharedGraph {
139 pub fn exact_text_search_nodes(
141 &self,
142 label: &DbString,
143 property: &DbString,
144 query: &str,
145 k: usize,
146 ) -> GraphResult<Vec<TextSearchHit>> {
147 self.read()
148 .exact_text_search_nodes(label, property, query, k)
149 }
150
151 pub fn exact_text_search_nodes_checked(
153 &self,
154 label: &DbString,
155 property: &DbString,
156 query: &str,
157 k: usize,
158 checker: CancellationChecker<'_>,
159 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
160 self.read()
161 .exact_text_search_nodes_checked(label, property, query, k, checker)
162 }
163}
164
/// Borrowed inputs shared by every row visited in one text scan. It is `Copy`, so it is
/// passed by value to both the serial and the per-chunk scanners.
#[derive(Clone, Copy)]
struct TextScan<'a> {
    /// Graph whose node store is read row by row.
    graph: &'a SeleneGraph,
    /// Label being searched; only used in error messages here.
    label: &'a DbString,
    /// Property whose string value is tokenized and scored.
    property: &'a DbString,
    /// Query terms as produced by `unique_query_terms` (sorted, deduplicated), which
    /// `document_stats` relies on for binary search.
    query_terms: &'a [String],
}
172
173impl<'a> TextScan<'a> {
174 fn new(
175 graph: &'a SeleneGraph,
176 label: &'a DbString,
177 property: &'a DbString,
178 query_terms: &'a [String],
179 ) -> Self {
180 Self {
181 graph,
182 label,
183 property,
184 query_terms,
185 }
186 }
187
188 fn document_for_row(self, raw_row: u32) -> Result<Option<DocumentStats>, TextSearchError> {
189 if !self.graph.node_store.is_alive(raw_row) {
190 return Ok(None);
191 }
192 let row = RowIndex::new(raw_row);
193 let node_id = self
194 .graph
195 .node_id_for_row(row)
196 .ok_or_else(|| GraphError::Inconsistent {
197 reason: format!(
198 "label index row {raw_row} for {} has no node id",
199 self.label.as_str()
200 ),
201 })?;
202 let properties = self
203 .graph
204 .node_store
205 .properties
206 .get(raw_row as usize)
207 .ok_or_else(|| GraphError::Inconsistent {
208 reason: format!(
209 "text search row {raw_row} for {} has no property row",
210 self.label.as_str()
211 ),
212 })?;
213 let Some(Value::String(text)) = properties.get(self.property) else {
214 return Ok(None);
215 };
216 Ok(document_stats(node_id, text.as_str(), self.query_terms))
217 }
218}
219
/// Partial scan result: per-document stats plus the corpus statistics that BM25 needs.
#[derive(Debug)]
struct TextScanChunk {
    /// Stats of every scanned document that produced at least one token.
    docs: Vec<DocumentStats>,
    /// For each query term (same order as the query terms), how many docs contain it.
    document_frequencies: Vec<u32>,
    /// Sum of the token counts of `docs`, used to derive the average document length.
    total_document_len: u64,
}
226
227impl TextScanChunk {
228 fn empty(query_term_count: usize) -> Self {
229 Self {
230 docs: Vec::new(),
231 document_frequencies: vec![0; query_term_count],
232 total_document_len: 0,
233 }
234 }
235
236 fn push(&mut self, doc: DocumentStats) {
237 for (frequency, count) in self.document_frequencies.iter_mut().zip(&doc.term_counts) {
238 if *count > 0 {
239 *frequency = frequency.saturating_add(1);
240 }
241 }
242 self.total_document_len = self.total_document_len.saturating_add(u64::from(doc.len));
243 self.docs.push(doc);
244 }
245}
246
247fn should_parallelize_text_scan(rows: &RoaringBitmap, k: usize) -> bool {
248 should_parallelize_scan(rows.len(), k, TEXT_SEARCH_PARALLEL_MIN_ROWS)
249}
250
251fn exact_text_scan_parallel(
252 scan: TextScan<'_>,
253 rows: &RoaringBitmap,
254 checker: CancellationChecker<'_>,
255) -> Result<TextScanChunk, TextSearchError> {
256 try_reduce_bitmap_chunks(
257 rows,
258 TEXT_SEARCH_PARALLEL_CHUNK_ROWS,
259 checker,
260 || TextScanChunk::empty(scan.query_terms.len()),
261 |chunk| exact_text_scan_chunk(scan, chunk),
262 merge_text_scan_chunks,
263 )
264}
265
266fn exact_text_scan_serial(
267 scan: TextScan<'_>,
268 rows: &RoaringBitmap,
269 checker: CancellationChecker<'_>,
270) -> Result<TextScanChunk, TextSearchError> {
271 let mut chunk = TextScanChunk::empty(scan.query_terms.len());
272 let mut rows_since_check = 0usize;
273 for raw_row in rows.iter() {
274 rows_since_check += 1;
275 if rows_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
276 checker.check()?;
277 rows_since_check = 0;
278 }
279 if let Some(doc) = scan.document_for_row(raw_row)? {
280 chunk.push(doc);
281 }
282 }
283 Ok(chunk)
284}
285
286fn exact_text_scan_chunk(
287 scan: TextScan<'_>,
288 rows: &[u32],
289) -> Result<TextScanChunk, TextSearchError> {
290 let mut chunk = TextScanChunk::empty(scan.query_terms.len());
291 for &raw_row in rows {
292 if let Some(doc) = scan.document_for_row(raw_row)? {
293 chunk.push(doc);
294 }
295 }
296 Ok(chunk)
297}
298
299fn merge_text_scan_chunks(
300 mut lhs: TextScanChunk,
301 mut rhs: TextScanChunk,
302) -> Result<TextScanChunk, TextSearchError> {
303 for (lhs_frequency, rhs_frequency) in lhs
304 .document_frequencies
305 .iter_mut()
306 .zip(&rhs.document_frequencies)
307 {
308 *lhs_frequency = lhs_frequency.saturating_add(*rhs_frequency);
309 }
310 lhs.total_document_len = lhs
311 .total_document_len
312 .saturating_add(rhs.total_document_len);
313 lhs.docs.append(&mut rhs.docs);
314 Ok(lhs)
315}
316
317fn rank_text_docs(chunk: TextScanChunk, k: usize) -> Vec<TextSearchHit> {
318 if chunk.docs.is_empty() {
319 return Vec::new();
320 }
321 let corpus_len = chunk.docs.len() as f64;
322 let average_document_len = chunk.total_document_len as f64 / corpus_len;
323 let mut top_k = TextTopK::new(k);
324 for doc in chunk.docs {
325 let score = bm25_score(
326 &doc,
327 &chunk.document_frequencies,
328 corpus_len,
329 average_document_len,
330 );
331 if score > 0.0 {
332 top_k.push(doc.node_id, score);
333 }
334 }
335 top_k.into_hits()
336}
337
/// Token statistics for one document.
#[derive(Debug)]
pub(crate) struct DocumentStats {
    /// Node the document's text belongs to.
    pub(crate) node_id: NodeId,
    /// Number of tokens in the document (saturating at `u32::MAX` when built by
    /// `document_stats`).
    len: u32,
    /// Occurrences of each query term, indexed in the same order as the query terms.
    pub(crate) term_counts: Vec<u32>,
}
344
345impl DocumentStats {
346 pub(crate) fn zero(node_id: NodeId, len: u32, query_term_count: usize) -> Self {
347 Self {
348 node_id,
349 len,
350 term_counts: vec![0; query_term_count],
351 }
352 }
353}
354
355pub(crate) fn unique_query_terms(query: &str) -> Vec<String> {
356 let terms: BTreeSet<_> = tokenize_borrowed(query).map(Cow::into_owned).collect();
357 terms.into_iter().collect()
358}
359
360fn document_stats(node_id: NodeId, text: &str, query_terms: &[String]) -> Option<DocumentStats> {
361 let mut term_counts = vec![0_u32; query_terms.len()];
362 let mut len = 0_u32;
363 for token in tokenize_borrowed(text) {
364 len = len.saturating_add(1);
365 if let Ok(index) = query_terms.binary_search_by(|term| term.as_str().cmp(token.as_ref())) {
366 term_counts[index] = term_counts[index].saturating_add(1);
367 }
368 }
369 (len > 0).then_some(DocumentStats {
370 node_id,
371 len,
372 term_counts,
373 })
374}
375
376pub(crate) fn tokenize_borrowed(text: &str) -> Tokenizer<'_> {
378 Tokenizer { text, offset: 0 }
379}
380
/// Iterator over lowercase alphanumeric tokens of a string; created by `tokenize_borrowed`.
pub(crate) struct Tokenizer<'a> {
    /// Text being tokenized.
    text: &'a str,
    /// Byte offset at which the next call to `next` resumes scanning.
    offset: usize,
}
386
impl<'a> Iterator for Tokenizer<'a> {
    type Item = Cow<'a, str>;

    /// Returns the next maximal run of alphanumeric characters, lowercased.
    ///
    /// The token is borrowed from the input while lowercasing leaves it unchanged. It is
    /// copied into an owned `String` only from the first character whose lowercase form
    /// differs.
    fn next(&mut self) -> Option<Self::Item> {
        // Byte index of the token's first char, once a token has started.
        let mut start = None;
        // Exclusive end of the token; stays at `text.len()` if the token runs to the end.
        let mut end = self.text.len();
        // Lowercased copy of the token, created lazily on the first char that changes.
        let mut owned = None::<String>;

        let base = self.offset;
        for (relative_index, ch) in self.text[base..].char_indices() {
            let index = base + relative_index;
            if !ch.is_alphanumeric() {
                if start.is_some() {
                    // Separator after a token: the token ends here, and the next call
                    // resumes just past this separator.
                    end = index;
                    self.offset = index + ch.len_utf8();
                    break;
                }
                // Separator before any token: skip it.
                self.offset = index + ch.len_utf8();
                continue;
            }

            let start_index = *start.get_or_insert(index);
            // `to_lowercase` yields one or more chars; inspect the first two to decide
            // whether this char is unchanged by lowercasing.
            let mut lowercase = ch.to_lowercase();
            let first = lowercase
                .next()
                .expect("char lowercase mapping yields at least one char");
            let second = lowercase.next();
            let unchanged = first == ch && second.is_none();
            if let Some(buffer) = owned.as_mut() {
                // Already copying: append either the original char or its full mapping.
                if unchanged {
                    buffer.push(ch);
                } else {
                    buffer.push(first);
                    if let Some(second) = second {
                        buffer.push(second);
                    }
                    buffer.extend(lowercase);
                }
            } else if !unchanged {
                // First changed char: copy the unchanged prefix, then the lowercase mapping.
                let mut buffer = self.text[start_index..index].to_owned();
                buffer.push(first);
                if let Some(second) = second {
                    buffer.push(second);
                }
                buffer.extend(lowercase);
                owned = Some(buffer);
            }
        }

        // No alphanumeric char remained, so iteration is finished.
        let start = start?;
        // The token ran to the end of the text without a trailing separator, so `offset`
        // was never moved past it. Jump to the end so the next call returns `None`.
        if self.offset <= start {
            self.offset = self.text.len();
        }

        Some(match owned {
            Some(token) => Cow::Owned(token),
            None => Cow::Borrowed(&self.text[start..end]),
        })
    }
}
447
448pub(crate) fn bm25_score(
449 doc: &DocumentStats,
450 document_frequencies: &[u32],
451 corpus_len: f64,
452 average_document_len: f64,
453) -> f64 {
454 let document_len = f64::from(doc.len);
455 doc.term_counts
456 .iter()
457 .zip(document_frequencies)
458 .filter(|(term_count, _)| **term_count > 0)
459 .map(|(term_count, document_frequency)| {
460 let term_count = f64::from(*term_count);
461 let document_frequency = f64::from(*document_frequency);
462 let idf =
463 (1.0 + (corpus_len - document_frequency + 0.5) / (document_frequency + 0.5)).ln();
464 let normalization = term_count
465 + BM25_K1 * (1.0 - BM25_B + BM25_B * document_len / average_document_len);
466 idf * (term_count * (BM25_K1 + 1.0)) / normalization
467 })
468 .sum()
469}
470
/// Bounded collector that retains the `k` best-ranked hits.
#[derive(Debug)]
pub(crate) struct TextTopK {
    /// Maximum number of hits to retain.
    k: usize,
    /// Retained entries. Because of `TextHeapEntry`'s ordering, the top of this max-heap
    /// is the worst retained hit.
    heap: BinaryHeap<TextHeapEntry>,
}
476
477impl TextTopK {
478 pub(crate) fn new(k: usize) -> Self {
479 Self {
480 k,
481 heap: BinaryHeap::new(),
482 }
483 }
484
485 pub(crate) fn push(&mut self, node_id: NodeId, score: f64) {
486 debug_assert!(score.is_finite(), "BM25 scores must be finite");
487 if self.k == 0 {
488 return;
489 }
490 let entry = TextHeapEntry { score, node_id };
491 if self.heap.len() < self.k {
492 self.heap.push(entry);
493 return;
494 }
495 let Some(worst) = self.heap.peek() else {
496 return;
497 };
498 if entry.cmp(worst).is_lt() {
499 self.heap.pop();
500 self.heap.push(entry);
501 }
502 }
503
504 pub(crate) fn into_hits(self) -> Vec<TextSearchHit> {
505 let mut hits: Vec<_> = self
506 .heap
507 .into_iter()
508 .map(|entry| TextSearchHit {
509 node_id: entry.node_id,
510 score: entry.score,
511 })
512 .collect();
513 hits.sort_by(compare_hit);
514 hits
515 }
516}
517
/// Entry stored in `TextTopK`'s heap.
///
/// Its `Ord` makes a higher score (then a smaller node id) compare as *less*, so the
/// max-heap's top is the worst retained hit.
#[derive(Debug)]
struct TextHeapEntry {
    score: f64,
    node_id: NodeId,
}
523
524impl Eq for TextHeapEntry {}
525
526impl PartialEq for TextHeapEntry {
527 fn eq(&self, rhs: &Self) -> bool {
528 self.score.to_bits() == rhs.score.to_bits() && self.node_id == rhs.node_id
529 }
530}
531
532impl Ord for TextHeapEntry {
533 fn cmp(&self, rhs: &Self) -> Ordering {
534 rhs.score
535 .total_cmp(&self.score)
536 .then_with(|| self.node_id.cmp(&rhs.node_id))
537 }
538}
539
540impl PartialOrd for TextHeapEntry {
541 fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
542 Some(self.cmp(rhs))
543 }
544}
545
546fn compare_hit(lhs: &TextSearchHit, rhs: &TextSearchHit) -> Ordering {
547 rhs.score
548 .total_cmp(&lhs.score)
549 .then_with(|| lhs.node_id.cmp(&rhs.node_id))
550}
551
552#[cfg(test)]
553#[path = "text_search/tests.rs"]
554mod tests;