Skip to main content

selene_graph/
text_search.rs

//! Exact BM25 full-text search over graph node properties.
//!
//! This module is the full-text correctness oracle: it scans the current graph
//! snapshot, tokenizes string properties, computes BM25 document statistics for
//! the requested `(label, property)` surface, and returns a deterministic
//! top-`k` ranking. Future maintained or postings-backed text indexes should
//! use this path as their ordering and recall reference.
8
9use std::borrow::Cow;
10use std::cmp::Ordering;
11use std::collections::{BTreeSet, BinaryHeap};
12use std::time::Duration;
13
14use roaring::RoaringBitmap;
15use selene_core::{CancellationCause, CancellationChecker, DbString, NodeId, Value};
16
17use crate::error::{GraphError, GraphResult};
18use crate::graph::SeleneGraph;
19use crate::parallel_scan::{should_parallelize_scan, try_reduce_bitmap_chunks};
20use crate::shared::SharedGraph;
21use crate::store::RowIndex;
22
/// Number of rows the serial scan visits between two cancellation checks.
pub(crate) const TEXT_SEARCH_CANCEL_STRIDE: usize = 1024;

/// Chunk-size argument handed to the parallel bitmap reducer.
///
/// Test builds use a tiny value — presumably so small fixtures still produce
/// several chunks (confirm against the test suite).
const TEXT_SEARCH_PARALLEL_CHUNK_ROWS: usize = if cfg!(test) { 4 } else { 2048 };

/// Row-count threshold passed to `should_parallelize_scan`.
///
/// Test builds use a tiny value — presumably so small fixtures can reach the
/// parallel path (confirm against the test suite).
const TEXT_SEARCH_PARALLEL_MIN_ROWS: u64 = if cfg!(test) { 8 } else { 16_384 };

/// BM25 `k1` parameter used in the term-frequency part of the score.
const BM25_K1: f64 = 1.2;
/// BM25 `b` parameter weighting the document-length term of the score.
const BM25_B: f64 = 0.75;
35
36/// One BM25-ranked node hit.
37#[derive(Clone, Debug, PartialEq)]
38pub struct TextSearchHit {
39    /// Matched node id.
40    pub node_id: NodeId,
41    /// Higher-is-better BM25 score.
42    pub score: f64,
43}
44
45/// Error returned by checked text-search APIs.
46#[derive(Debug, thiserror::Error)]
47pub enum TextSearchError {
48    /// Graph storage or consistency failure.
49    #[error(transparent)]
50    Graph(#[from] GraphError),
51    /// Caller requested cooperative cancellation.
52    #[error("text search cancelled")]
53    Cancelled,
54    /// Statement deadline elapsed.
55    #[error("text search timed out after {elapsed:?}")]
56    Timeout {
57        /// Wall-clock duration since the deadline elapsed.
58        elapsed: Duration,
59    },
60}
61
62impl TextSearchError {
63    fn into_graph_error(self) -> GraphError {
64        match self {
65            Self::Graph(error) => error,
66            Self::Cancelled | Self::Timeout { .. } => GraphError::Inconsistent {
67                reason: format!("disabled text-search checker returned {self}"),
68            },
69        }
70    }
71}
72
73impl From<CancellationCause> for TextSearchError {
74    fn from(cause: CancellationCause) -> Self {
75        match cause {
76            CancellationCause::Cancelled => Self::Cancelled,
77            CancellationCause::Timeout { elapsed } => Self::Timeout { elapsed },
78        }
79    }
80}
81
82impl SeleneGraph {
83    /// Exhaustively rank string-valued node properties using BM25.
84    ///
85    /// This is the full-text correctness oracle and small-corpus path. It scans
86    /// the row bitmap for `label`, skips nodes where `property` is absent or not
87    /// a string, tokenizes documents with the built-in Unicode-aware tokenizer,
88    /// and ranks matches with Okapi BM25 (`k1 = 1.2`, `b = 0.75`). Query tokens
89    /// are deduplicated so repeated query terms do not overweight a document.
90    pub fn exact_text_search_nodes(
91        &self,
92        label: &DbString,
93        property: &DbString,
94        query: &str,
95        k: usize,
96    ) -> GraphResult<Vec<TextSearchHit>> {
97        self.exact_text_search_nodes_checked(
98            label,
99            property,
100            query,
101            k,
102            CancellationChecker::disabled(),
103        )
104        .map_err(TextSearchError::into_graph_error)
105    }
106
107    /// Exhaustively rank string-valued node properties with cancellation checks.
108    pub fn exact_text_search_nodes_checked(
109        &self,
110        label: &DbString,
111        property: &DbString,
112        query: &str,
113        k: usize,
114        checker: CancellationChecker<'_>,
115    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
116        checker.check()?;
117        if k == 0 {
118            return Ok(Vec::new());
119        }
120        let query_terms = unique_query_terms(query);
121        if query_terms.is_empty() {
122            return Ok(Vec::new());
123        }
124        let Some(label_rows) = self.nodes_with_label(label) else {
125            return Ok(Vec::new());
126        };
127
128        let scan = TextScan::new(self, label, property, &query_terms);
129        let chunk = if should_parallelize_text_scan(label_rows, k) {
130            exact_text_scan_parallel(scan, label_rows, checker)?
131        } else {
132            exact_text_scan_serial(scan, label_rows, checker)?
133        };
134        Ok(rank_text_docs(chunk, k))
135    }
136}
137
138impl SharedGraph {
139    /// Exhaustively rank string-valued node properties in the current snapshot.
140    pub fn exact_text_search_nodes(
141        &self,
142        label: &DbString,
143        property: &DbString,
144        query: &str,
145        k: usize,
146    ) -> GraphResult<Vec<TextSearchHit>> {
147        self.read()
148            .exact_text_search_nodes(label, property, query, k)
149    }
150
151    /// Exhaustively rank string-valued node properties with cancellation checks.
152    pub fn exact_text_search_nodes_checked(
153        &self,
154        label: &DbString,
155        property: &DbString,
156        query: &str,
157        k: usize,
158        checker: CancellationChecker<'_>,
159    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
160        self.read()
161            .exact_text_search_nodes_checked(label, property, query, k, checker)
162    }
163}
164
165#[derive(Clone, Copy)]
166struct TextScan<'a> {
167    graph: &'a SeleneGraph,
168    label: &'a DbString,
169    property: &'a DbString,
170    query_terms: &'a [String],
171}
172
173impl<'a> TextScan<'a> {
174    fn new(
175        graph: &'a SeleneGraph,
176        label: &'a DbString,
177        property: &'a DbString,
178        query_terms: &'a [String],
179    ) -> Self {
180        Self {
181            graph,
182            label,
183            property,
184            query_terms,
185        }
186    }
187
188    fn document_for_row(self, raw_row: u32) -> Result<Option<DocumentStats>, TextSearchError> {
189        if !self.graph.node_store.is_alive(raw_row) {
190            return Ok(None);
191        }
192        let row = RowIndex::new(raw_row);
193        let node_id = self
194            .graph
195            .node_id_for_row(row)
196            .ok_or_else(|| GraphError::Inconsistent {
197                reason: format!(
198                    "label index row {raw_row} for {} has no node id",
199                    self.label.as_str()
200                ),
201            })?;
202        let properties = self
203            .graph
204            .node_store
205            .properties
206            .get(raw_row as usize)
207            .ok_or_else(|| GraphError::Inconsistent {
208                reason: format!(
209                    "text search row {raw_row} for {} has no property row",
210                    self.label.as_str()
211                ),
212            })?;
213        let Some(Value::String(text)) = properties.get(self.property) else {
214            return Ok(None);
215        };
216        Ok(document_stats(node_id, text.as_str(), self.query_terms))
217    }
218}
219
220#[derive(Debug)]
221struct TextScanChunk {
222    docs: Vec<DocumentStats>,
223    document_frequencies: Vec<u32>,
224    total_document_len: u64,
225}
226
227impl TextScanChunk {
228    fn empty(query_term_count: usize) -> Self {
229        Self {
230            docs: Vec::new(),
231            document_frequencies: vec![0; query_term_count],
232            total_document_len: 0,
233        }
234    }
235
236    fn push(&mut self, doc: DocumentStats) {
237        for (frequency, count) in self.document_frequencies.iter_mut().zip(&doc.term_counts) {
238            if *count > 0 {
239                *frequency = frequency.saturating_add(1);
240            }
241        }
242        self.total_document_len = self.total_document_len.saturating_add(u64::from(doc.len));
243        self.docs.push(doc);
244    }
245}
246
247fn should_parallelize_text_scan(rows: &RoaringBitmap, k: usize) -> bool {
248    should_parallelize_scan(rows.len(), k, TEXT_SEARCH_PARALLEL_MIN_ROWS)
249}
250
251fn exact_text_scan_parallel(
252    scan: TextScan<'_>,
253    rows: &RoaringBitmap,
254    checker: CancellationChecker<'_>,
255) -> Result<TextScanChunk, TextSearchError> {
256    try_reduce_bitmap_chunks(
257        rows,
258        TEXT_SEARCH_PARALLEL_CHUNK_ROWS,
259        checker,
260        || TextScanChunk::empty(scan.query_terms.len()),
261        |chunk| exact_text_scan_chunk(scan, chunk),
262        merge_text_scan_chunks,
263    )
264}
265
266fn exact_text_scan_serial(
267    scan: TextScan<'_>,
268    rows: &RoaringBitmap,
269    checker: CancellationChecker<'_>,
270) -> Result<TextScanChunk, TextSearchError> {
271    let mut chunk = TextScanChunk::empty(scan.query_terms.len());
272    let mut rows_since_check = 0usize;
273    for raw_row in rows.iter() {
274        rows_since_check += 1;
275        if rows_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
276            checker.check()?;
277            rows_since_check = 0;
278        }
279        if let Some(doc) = scan.document_for_row(raw_row)? {
280            chunk.push(doc);
281        }
282    }
283    Ok(chunk)
284}
285
286fn exact_text_scan_chunk(
287    scan: TextScan<'_>,
288    rows: &[u32],
289) -> Result<TextScanChunk, TextSearchError> {
290    let mut chunk = TextScanChunk::empty(scan.query_terms.len());
291    for &raw_row in rows {
292        if let Some(doc) = scan.document_for_row(raw_row)? {
293            chunk.push(doc);
294        }
295    }
296    Ok(chunk)
297}
298
299fn merge_text_scan_chunks(
300    mut lhs: TextScanChunk,
301    mut rhs: TextScanChunk,
302) -> Result<TextScanChunk, TextSearchError> {
303    for (lhs_frequency, rhs_frequency) in lhs
304        .document_frequencies
305        .iter_mut()
306        .zip(&rhs.document_frequencies)
307    {
308        *lhs_frequency = lhs_frequency.saturating_add(*rhs_frequency);
309    }
310    lhs.total_document_len = lhs
311        .total_document_len
312        .saturating_add(rhs.total_document_len);
313    lhs.docs.append(&mut rhs.docs);
314    Ok(lhs)
315}
316
317fn rank_text_docs(chunk: TextScanChunk, k: usize) -> Vec<TextSearchHit> {
318    if chunk.docs.is_empty() {
319        return Vec::new();
320    }
321    let corpus_len = chunk.docs.len() as f64;
322    let average_document_len = chunk.total_document_len as f64 / corpus_len;
323    let mut top_k = TextTopK::new(k);
324    for doc in chunk.docs {
325        let score = bm25_score(
326            &doc,
327            &chunk.document_frequencies,
328            corpus_len,
329            average_document_len,
330        );
331        if score > 0.0 {
332            top_k.push(doc.node_id, score);
333        }
334    }
335    top_k.into_hits()
336}
337
338#[derive(Debug)]
339pub(crate) struct DocumentStats {
340    pub(crate) node_id: NodeId,
341    len: u32,
342    pub(crate) term_counts: Vec<u32>,
343}
344
345impl DocumentStats {
346    pub(crate) fn zero(node_id: NodeId, len: u32, query_term_count: usize) -> Self {
347        Self {
348            node_id,
349            len,
350            term_counts: vec![0; query_term_count],
351        }
352    }
353}
354
355pub(crate) fn unique_query_terms(query: &str) -> Vec<String> {
356    let terms: BTreeSet<_> = tokenize_borrowed(query).map(Cow::into_owned).collect();
357    terms.into_iter().collect()
358}
359
360fn document_stats(node_id: NodeId, text: &str, query_terms: &[String]) -> Option<DocumentStats> {
361    let mut term_counts = vec![0_u32; query_terms.len()];
362    let mut len = 0_u32;
363    for token in tokenize_borrowed(text) {
364        len = len.saturating_add(1);
365        if let Ok(index) = query_terms.binary_search_by(|term| term.as_str().cmp(token.as_ref())) {
366            term_counts[index] = term_counts[index].saturating_add(1);
367        }
368    }
369    (len > 0).then_some(DocumentStats {
370        node_id,
371        len,
372        term_counts,
373    })
374}
375
376/// Iterate lowercase alphanumeric tokens, borrowing when lowercase is unchanged.
377pub(crate) fn tokenize_borrowed(text: &str) -> Tokenizer<'_> {
378    Tokenizer { text, offset: 0 }
379}
380
/// Lazy tokenizer used for both BM25 queries and documents.
///
/// A token is a maximal run of characters for which `char::is_alphanumeric`
/// holds; every other character is a separator.
pub(crate) struct Tokenizer<'a> {
    /// Text being tokenized.
    text: &'a str,
    /// Byte offset at which the next call to `next` resumes.
    offset: usize,
}

impl<'a> Iterator for Tokenizer<'a> {
    type Item = Cow<'a, str>;

    /// Yield the next token, lowercased one character at a time.
    ///
    /// The token is borrowed from the input when every character already
    /// equals its own lowercase mapping; otherwise an owned copy is built.
    fn next(&mut self) -> Option<Self::Item> {
        let text = self.text;

        // Skip separators; with no alphanumeric character left we are done.
        let Some(skipped) = text[self.offset..].find(char::is_alphanumeric) else {
            self.offset = text.len();
            return None;
        };
        let start = self.offset + skipped;

        // Until a separator proves otherwise, the token runs to the end of the text.
        let mut end = text.len();
        let mut resume_at = text.len();
        // Allocated only once a character actually changes under lowercasing.
        let mut lowered: Option<String> = None;

        for (relative, ch) in text[start..].char_indices() {
            let index = start + relative;
            if !ch.is_alphanumeric() {
                end = index;
                // Resume after the separator that terminated this token.
                resume_at = index + ch.len_utf8();
                break;
            }

            let unchanged = ch.to_lowercase().eq(std::iter::once(ch));
            if unchanged {
                if let Some(buffer) = lowered.as_mut() {
                    buffer.push(ch);
                }
            } else {
                // The first changed character seeds the buffer with the
                // untouched prefix of the token.
                lowered
                    .get_or_insert_with(|| text[start..index].to_owned())
                    .extend(ch.to_lowercase());
            }
        }

        self.offset = resume_at;
        Some(match lowered {
            Some(token) => Cow::Owned(token),
            None => Cow::Borrowed(&text[start..end]),
        })
    }
}
447
448pub(crate) fn bm25_score(
449    doc: &DocumentStats,
450    document_frequencies: &[u32],
451    corpus_len: f64,
452    average_document_len: f64,
453) -> f64 {
454    let document_len = f64::from(doc.len);
455    doc.term_counts
456        .iter()
457        .zip(document_frequencies)
458        .filter(|(term_count, _)| **term_count > 0)
459        .map(|(term_count, document_frequency)| {
460            let term_count = f64::from(*term_count);
461            let document_frequency = f64::from(*document_frequency);
462            let idf =
463                (1.0 + (corpus_len - document_frequency + 0.5) / (document_frequency + 0.5)).ln();
464            let normalization = term_count
465                + BM25_K1 * (1.0 - BM25_B + BM25_B * document_len / average_document_len);
466            idf * (term_count * (BM25_K1 + 1.0)) / normalization
467        })
468        .sum()
469}
470
471#[derive(Debug)]
472pub(crate) struct TextTopK {
473    k: usize,
474    heap: BinaryHeap<TextHeapEntry>,
475}
476
477impl TextTopK {
478    pub(crate) fn new(k: usize) -> Self {
479        Self {
480            k,
481            heap: BinaryHeap::new(),
482        }
483    }
484
485    pub(crate) fn push(&mut self, node_id: NodeId, score: f64) {
486        debug_assert!(score.is_finite(), "BM25 scores must be finite");
487        if self.k == 0 {
488            return;
489        }
490        let entry = TextHeapEntry { score, node_id };
491        if self.heap.len() < self.k {
492            self.heap.push(entry);
493            return;
494        }
495        let Some(worst) = self.heap.peek() else {
496            return;
497        };
498        if entry.cmp(worst).is_lt() {
499            self.heap.pop();
500            self.heap.push(entry);
501        }
502    }
503
504    pub(crate) fn into_hits(self) -> Vec<TextSearchHit> {
505        let mut hits: Vec<_> = self
506            .heap
507            .into_iter()
508            .map(|entry| TextSearchHit {
509                node_id: entry.node_id,
510                score: entry.score,
511            })
512            .collect();
513        hits.sort_by(compare_hit);
514        hits
515    }
516}
517
518#[derive(Debug)]
519struct TextHeapEntry {
520    score: f64,
521    node_id: NodeId,
522}
523
524impl Eq for TextHeapEntry {}
525
526impl PartialEq for TextHeapEntry {
527    fn eq(&self, rhs: &Self) -> bool {
528        self.score.to_bits() == rhs.score.to_bits() && self.node_id == rhs.node_id
529    }
530}
531
532impl Ord for TextHeapEntry {
533    fn cmp(&self, rhs: &Self) -> Ordering {
534        rhs.score
535            .total_cmp(&self.score)
536            .then_with(|| self.node_id.cmp(&rhs.node_id))
537    }
538}
539
540impl PartialOrd for TextHeapEntry {
541    fn partial_cmp(&self, rhs: &Self) -> Option<Ordering> {
542        Some(self.cmp(rhs))
543    }
544}
545
546fn compare_hit(lhs: &TextSearchHit, rhs: &TextSearchHit) -> Ordering {
547    rhs.score
548        .total_cmp(&lhs.score)
549        .then_with(|| lhs.node_id.cmp(&rhs.node_id))
550}
551
552#[cfg(test)]
553#[path = "text_search/tests.rs"]
554mod tests;