selene_graph/text_index/
candidate.rs1use rustc_hash::FxHashSet;
4
5use selene_core::{CancellationChecker, NodeId};
6
7use super::{TextIndex, TextPosting};
8use crate::text_search::{
9 DocumentStats, TEXT_SEARCH_CANCEL_STRIDE, TextSearchError, TextSearchHit, TextTopK, bm25_score,
10 unique_query_terms,
11};
12
13impl TextIndex {
14 #[must_use]
22 pub fn search_candidates(
23 &self,
24 query: &str,
25 candidates: &[NodeId],
26 k: usize,
27 ) -> Vec<TextSearchHit> {
28 self.search_candidates_checked(query, candidates, k, CancellationChecker::disabled())
29 .expect("disabled text-index checker cannot fail")
30 }
31
32 pub fn search_candidates_checked(
39 &self,
40 query: &str,
41 candidates: &[NodeId],
42 k: usize,
43 checker: CancellationChecker<'_>,
44 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
45 checker.check()?;
46 if k == 0 || candidates.is_empty() || self.document_lengths.is_empty() {
47 return Ok(Vec::new());
48 }
49 let query_terms = unique_query_terms(query);
50 if query_terms.is_empty() {
51 return Ok(Vec::new());
52 }
53
54 let (document_frequencies, postings_by_term) = self.query_postings(&query_terms);
55 if postings_by_term.iter().all(Option::is_none) {
56 return Ok(Vec::new());
57 }
58 let candidate_set = self.indexed_candidate_set(candidates, checker)?;
59 if candidate_set.is_empty() {
60 return Ok(Vec::new());
61 }
62
63 let corpus_len = self.document_lengths.len() as f64;
64 let average_document_len = self.total_document_len as f64 / corpus_len;
65 let mut top_k = TextTopK::new(k);
66 let mut candidates_since_check = 0usize;
67 for node_id in candidate_set {
68 candidates_since_check += 1;
69 if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
70 checker.check()?;
71 candidates_since_check = 0;
72 }
73 let len = *self
74 .document_lengths
75 .get(&node_id)
76 .expect("candidate set contains only indexed documents");
77 let Some(doc) = candidate_document_stats(node_id, len, &postings_by_term) else {
78 continue;
79 };
80 let score = bm25_score(
81 &doc,
82 &document_frequencies,
83 corpus_len,
84 average_document_len,
85 );
86 if score > 0.0 {
87 top_k.push(node_id, score);
88 }
89 }
90 Ok(top_k.into_hits())
91 }
92
93 fn indexed_candidate_set(
94 &self,
95 candidates: &[NodeId],
96 checker: CancellationChecker<'_>,
97 ) -> Result<FxHashSet<NodeId>, TextSearchError> {
98 let mut set = FxHashSet::default();
99 let mut candidates_since_check = 0usize;
100 for &candidate in candidates {
101 candidates_since_check += 1;
102 if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
103 checker.check()?;
104 candidates_since_check = 0;
105 }
106 if self.document_lengths.contains_key(&candidate) {
107 set.insert(candidate);
108 }
109 }
110 Ok(set)
111 }
112
113 fn query_postings<'a>(
114 &'a self,
115 query_terms: &[String],
116 ) -> (Vec<u32>, Vec<Option<&'a [TextPosting]>>) {
117 let mut document_frequencies = Vec::with_capacity(query_terms.len());
118 let mut postings_by_term = Vec::with_capacity(query_terms.len());
119 for term in query_terms {
120 match self.postings.get(term) {
121 Some(postings) => {
122 document_frequencies.push(u32::try_from(postings.len()).unwrap_or(u32::MAX));
123 postings_by_term.push(Some(postings.as_slice()));
124 }
125 None => {
126 document_frequencies.push(0);
127 postings_by_term.push(None);
128 }
129 }
130 }
131 (document_frequencies, postings_by_term)
132 }
133}
134
135fn candidate_document_stats(
136 node_id: NodeId,
137 len: u32,
138 postings_by_term: &[Option<&[TextPosting]>],
139) -> Option<DocumentStats> {
140 let mut doc = DocumentStats::zero(node_id, len, postings_by_term.len());
141 let mut matched = false;
142 for (term_index, postings) in postings_by_term.iter().enumerate() {
143 let Some(postings) = postings else {
144 continue;
145 };
146 if let Ok(index) = postings.binary_search_by_key(&node_id, |posting| posting.node_id) {
147 doc.term_counts[term_index] = postings[index].term_count;
148 matched = true;
149 }
150 }
151 matched.then_some(doc)
152}