Skip to main content

selene_graph/text_index/
candidate.rs

1//! Candidate-scoped BM25 scoring for maintained text indexes.
2
3use rustc_hash::FxHashSet;
4
5use selene_core::{CancellationChecker, NodeId};
6
7use super::{TextIndex, TextPosting};
8use crate::text_search::{
9    DocumentStats, TEXT_SEARCH_CANCEL_STRIDE, TextSearchError, TextSearchHit, TextTopK, bm25_score,
10    unique_query_terms,
11};
12
13impl TextIndex {
14    /// Rank explicit node candidates for `query` using this index's BM25 corpus stats.
15    ///
16    /// Candidate ids are deduplicated and missing/non-indexed ids are ignored.
17    /// Term document frequencies and average document length remain global to
18    /// the maintained index, so this returns the same ordering as a full
19    /// [`Self::search`] followed by candidate filtering when enough global hits
20    /// are materialized.
21    #[must_use]
22    pub fn search_candidates(
23        &self,
24        query: &str,
25        candidates: &[NodeId],
26        k: usize,
27    ) -> Vec<TextSearchHit> {
28        self.search_candidates_checked(query, candidates, k, CancellationChecker::disabled())
29            .expect("disabled text-index checker cannot fail")
30    }
31
32    /// Rank explicit node candidates for `query` with cooperative cancellation checks.
33    ///
34    /// # Errors
35    ///
36    /// Returns [`TextSearchError::Cancelled`] or [`TextSearchError::Timeout`] when
37    /// the supplied checker trips while deduplicating or scoring candidates.
38    pub fn search_candidates_checked(
39        &self,
40        query: &str,
41        candidates: &[NodeId],
42        k: usize,
43        checker: CancellationChecker<'_>,
44    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
45        checker.check()?;
46        if k == 0 || candidates.is_empty() || self.document_lengths.is_empty() {
47            return Ok(Vec::new());
48        }
49        let query_terms = unique_query_terms(query);
50        if query_terms.is_empty() {
51            return Ok(Vec::new());
52        }
53
54        let (document_frequencies, postings_by_term) = self.query_postings(&query_terms);
55        if postings_by_term.iter().all(Option::is_none) {
56            return Ok(Vec::new());
57        }
58        let candidate_set = self.indexed_candidate_set(candidates, checker)?;
59        if candidate_set.is_empty() {
60            return Ok(Vec::new());
61        }
62
63        let corpus_len = self.document_lengths.len() as f64;
64        let average_document_len = self.total_document_len as f64 / corpus_len;
65        let mut top_k = TextTopK::new(k);
66        let mut candidates_since_check = 0usize;
67        for node_id in candidate_set {
68            candidates_since_check += 1;
69            if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
70                checker.check()?;
71                candidates_since_check = 0;
72            }
73            let len = *self
74                .document_lengths
75                .get(&node_id)
76                .expect("candidate set contains only indexed documents");
77            let Some(doc) = candidate_document_stats(node_id, len, &postings_by_term) else {
78                continue;
79            };
80            let score = bm25_score(
81                &doc,
82                &document_frequencies,
83                corpus_len,
84                average_document_len,
85            );
86            if score > 0.0 {
87                top_k.push(node_id, score);
88            }
89        }
90        Ok(top_k.into_hits())
91    }
92
93    fn indexed_candidate_set(
94        &self,
95        candidates: &[NodeId],
96        checker: CancellationChecker<'_>,
97    ) -> Result<FxHashSet<NodeId>, TextSearchError> {
98        let mut set = FxHashSet::default();
99        let mut candidates_since_check = 0usize;
100        for &candidate in candidates {
101            candidates_since_check += 1;
102            if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
103                checker.check()?;
104                candidates_since_check = 0;
105            }
106            if self.document_lengths.contains_key(&candidate) {
107                set.insert(candidate);
108            }
109        }
110        Ok(set)
111    }
112
113    fn query_postings<'a>(
114        &'a self,
115        query_terms: &[String],
116    ) -> (Vec<u32>, Vec<Option<&'a [TextPosting]>>) {
117        let mut document_frequencies = Vec::with_capacity(query_terms.len());
118        let mut postings_by_term = Vec::with_capacity(query_terms.len());
119        for term in query_terms {
120            match self.postings.get(term) {
121                Some(postings) => {
122                    document_frequencies.push(u32::try_from(postings.len()).unwrap_or(u32::MAX));
123                    postings_by_term.push(Some(postings.as_slice()));
124                }
125                None => {
126                    document_frequencies.push(0);
127                    postings_by_term.push(None);
128                }
129            }
130        }
131        (document_frequencies, postings_by_term)
132    }
133}
134
135fn candidate_document_stats(
136    node_id: NodeId,
137    len: u32,
138    postings_by_term: &[Option<&[TextPosting]>],
139) -> Option<DocumentStats> {
140    let mut doc = DocumentStats::zero(node_id, len, postings_by_term.len());
141    let mut matched = false;
142    for (term_index, postings) in postings_by_term.iter().enumerate() {
143        let Some(postings) = postings else {
144            continue;
145        };
146        if let Ok(index) = postings.binary_search_by_key(&node_id, |posting| posting.node_id) {
147            doc.term_counts[term_index] = postings[index].term_count;
148            matched = true;
149        }
150    }
151    matched.then_some(doc)
152}