Skip to main content

selene_graph/text_index/
candidate.rs

1//! Candidate-scoped BM25 scoring for maintained text indexes.
2
3use rustc_hash::FxHashSet;
4
5use selene_core::{CancellationChecker, NodeId};
6
7use super::{QueryDocumentFrequencies, QueryPostings, TextIndex, TextPosting};
8use crate::text_search::{
9    DocumentStats, TEXT_SEARCH_CANCEL_STRIDE, TextSearchError, TextSearchHit, TextTopK, bm25_score,
10    unique_query_terms,
11};
12
13impl TextIndex {
14    /// Rank explicit node candidates for `query` using this index's BM25 corpus stats.
15    ///
16    /// Candidate ids are deduplicated and missing/non-indexed ids are ignored.
17    /// Term document frequencies and average document length remain global to
18    /// the maintained index, so this returns the same ordering as a full
19    /// [`Self::search`] followed by candidate filtering when enough global hits
20    /// are materialized.
21    #[must_use]
22    pub fn search_candidates(
23        &self,
24        query: &str,
25        candidates: &[NodeId],
26        k: usize,
27    ) -> Vec<TextSearchHit> {
28        self.search_candidates_checked(query, candidates, k, CancellationChecker::disabled())
29            .expect("disabled text-index checker cannot fail")
30    }
31
32    /// Rank explicit node candidates for `query` with cooperative cancellation checks.
33    ///
34    /// # Errors
35    ///
36    /// Returns [`TextSearchError::Cancelled`], [`TextSearchError::Timeout`], or
37    /// [`TextSearchError::NodeScanBudgetExceeded`] when the supplied checker
38    /// trips while deduplicating or scoring candidates.
39    pub fn search_candidates_checked(
40        &self,
41        query: &str,
42        candidates: &[NodeId],
43        k: usize,
44        checker: CancellationChecker<'_>,
45    ) -> Result<Vec<TextSearchHit>, TextSearchError> {
46        checker.check()?;
47        if k == 0 || candidates.is_empty() || self.document_lengths.is_empty() {
48            return Ok(Vec::new());
49        }
50        let query_terms = unique_query_terms(query);
51        if query_terms.is_empty() {
52            return Ok(Vec::new());
53        }
54
55        let (document_frequencies, postings_by_term) = self.query_postings(&query_terms);
56        if postings_by_term.iter().all(Option::is_none) {
57            return Ok(Vec::new());
58        }
59        let candidate_set = self.indexed_candidate_set(candidates, checker)?;
60        if candidate_set.is_empty() {
61            return Ok(Vec::new());
62        }
63
64        let corpus_len = self.document_lengths.len() as f64;
65        let average_document_len = self.total_document_len as f64 / corpus_len;
66        let mut top_k = TextTopK::new(k);
67        let mut candidates_since_check = 0usize;
68        for node_id in candidate_set {
69            candidates_since_check += 1;
70            if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
71                checker.note_nodes_scanned(candidates_since_check)?;
72                candidates_since_check = 0;
73            }
74            let len = *self
75                .document_lengths
76                .get(&node_id)
77                .expect("candidate set contains only indexed documents");
78            let Some(doc) = candidate_document_stats(node_id, len, &postings_by_term) else {
79                continue;
80            };
81            let score = bm25_score(
82                &doc,
83                &document_frequencies,
84                corpus_len,
85                average_document_len,
86            );
87            if score > 0.0 {
88                top_k.push(node_id, score);
89            }
90        }
91        if candidates_since_check > 0 {
92            checker.note_nodes_scanned(candidates_since_check)?;
93        }
94        Ok(top_k.into_hits())
95    }
96
97    fn indexed_candidate_set(
98        &self,
99        candidates: &[NodeId],
100        checker: CancellationChecker<'_>,
101    ) -> Result<FxHashSet<NodeId>, TextSearchError> {
102        let mut set = FxHashSet::default();
103        set.reserve(candidates.len().min(self.document_lengths.len()));
104        let mut candidates_since_check = 0usize;
105        for &candidate in candidates {
106            candidates_since_check += 1;
107            if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
108                checker.note_nodes_scanned(candidates_since_check)?;
109                candidates_since_check = 0;
110            }
111            if self.document_lengths.contains_key(&candidate) {
112                set.insert(candidate);
113            }
114        }
115        if candidates_since_check > 0 {
116            checker.note_nodes_scanned(candidates_since_check)?;
117        }
118        Ok(set)
119    }
120
121    fn query_postings<'a>(
122        &'a self,
123        query_terms: &[String],
124    ) -> (QueryDocumentFrequencies, QueryPostings<'a>) {
125        let mut document_frequencies = QueryDocumentFrequencies::with_capacity(query_terms.len());
126        let mut postings_by_term = QueryPostings::with_capacity(query_terms.len());
127        for term in query_terms {
128            match self.postings.get(term) {
129                Some(postings) => {
130                    document_frequencies.push(u32::try_from(postings.len()).unwrap_or(u32::MAX));
131                    postings_by_term.push(Some(postings.as_slice()));
132                }
133                None => {
134                    document_frequencies.push(0);
135                    postings_by_term.push(None);
136                }
137            }
138        }
139        (document_frequencies, postings_by_term)
140    }
141}
142
143fn candidate_document_stats(
144    node_id: NodeId,
145    len: u32,
146    postings_by_term: &[Option<&[TextPosting]>],
147) -> Option<DocumentStats> {
148    let mut doc = DocumentStats::zero(node_id, len, postings_by_term.len());
149    let mut matched = false;
150    for (term_index, postings) in postings_by_term.iter().enumerate() {
151        let Some(postings) = postings else {
152            continue;
153        };
154        if let Ok(index) = postings.binary_search_by_key(&node_id, |posting| posting.node_id) {
155            doc.term_counts[term_index] = postings[index].term_count;
156            matched = true;
157        }
158    }
159    matched.then_some(doc)
160}