selene_graph/text_index/
candidate.rs1use rustc_hash::FxHashSet;
4
5use selene_core::{CancellationChecker, NodeId};
6
7use super::{QueryDocumentFrequencies, QueryPostings, TextIndex, TextPosting};
8use crate::text_search::{
9 DocumentStats, TEXT_SEARCH_CANCEL_STRIDE, TextSearchError, TextSearchHit, TextTopK, bm25_score,
10 unique_query_terms,
11};
12
13impl TextIndex {
14 #[must_use]
22 pub fn search_candidates(
23 &self,
24 query: &str,
25 candidates: &[NodeId],
26 k: usize,
27 ) -> Vec<TextSearchHit> {
28 self.search_candidates_checked(query, candidates, k, CancellationChecker::disabled())
29 .expect("disabled text-index checker cannot fail")
30 }
31
32 pub fn search_candidates_checked(
40 &self,
41 query: &str,
42 candidates: &[NodeId],
43 k: usize,
44 checker: CancellationChecker<'_>,
45 ) -> Result<Vec<TextSearchHit>, TextSearchError> {
46 checker.check()?;
47 if k == 0 || candidates.is_empty() || self.document_lengths.is_empty() {
48 return Ok(Vec::new());
49 }
50 let query_terms = unique_query_terms(query);
51 if query_terms.is_empty() {
52 return Ok(Vec::new());
53 }
54
55 let (document_frequencies, postings_by_term) = self.query_postings(&query_terms);
56 if postings_by_term.iter().all(Option::is_none) {
57 return Ok(Vec::new());
58 }
59 let candidate_set = self.indexed_candidate_set(candidates, checker)?;
60 if candidate_set.is_empty() {
61 return Ok(Vec::new());
62 }
63
64 let corpus_len = self.document_lengths.len() as f64;
65 let average_document_len = self.total_document_len as f64 / corpus_len;
66 let mut top_k = TextTopK::new(k);
67 let mut candidates_since_check = 0usize;
68 for node_id in candidate_set {
69 candidates_since_check += 1;
70 if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
71 checker.note_nodes_scanned(candidates_since_check)?;
72 candidates_since_check = 0;
73 }
74 let len = *self
75 .document_lengths
76 .get(&node_id)
77 .expect("candidate set contains only indexed documents");
78 let Some(doc) = candidate_document_stats(node_id, len, &postings_by_term) else {
79 continue;
80 };
81 let score = bm25_score(
82 &doc,
83 &document_frequencies,
84 corpus_len,
85 average_document_len,
86 );
87 if score > 0.0 {
88 top_k.push(node_id, score);
89 }
90 }
91 if candidates_since_check > 0 {
92 checker.note_nodes_scanned(candidates_since_check)?;
93 }
94 Ok(top_k.into_hits())
95 }
96
97 fn indexed_candidate_set(
98 &self,
99 candidates: &[NodeId],
100 checker: CancellationChecker<'_>,
101 ) -> Result<FxHashSet<NodeId>, TextSearchError> {
102 let mut set = FxHashSet::default();
103 set.reserve(candidates.len().min(self.document_lengths.len()));
104 let mut candidates_since_check = 0usize;
105 for &candidate in candidates {
106 candidates_since_check += 1;
107 if candidates_since_check >= TEXT_SEARCH_CANCEL_STRIDE {
108 checker.note_nodes_scanned(candidates_since_check)?;
109 candidates_since_check = 0;
110 }
111 if self.document_lengths.contains_key(&candidate) {
112 set.insert(candidate);
113 }
114 }
115 if candidates_since_check > 0 {
116 checker.note_nodes_scanned(candidates_since_check)?;
117 }
118 Ok(set)
119 }
120
121 fn query_postings<'a>(
122 &'a self,
123 query_terms: &[String],
124 ) -> (QueryDocumentFrequencies, QueryPostings<'a>) {
125 let mut document_frequencies = QueryDocumentFrequencies::with_capacity(query_terms.len());
126 let mut postings_by_term = QueryPostings::with_capacity(query_terms.len());
127 for term in query_terms {
128 match self.postings.get(term) {
129 Some(postings) => {
130 document_frequencies.push(u32::try_from(postings.len()).unwrap_or(u32::MAX));
131 postings_by_term.push(Some(postings.as_slice()));
132 }
133 None => {
134 document_frequencies.push(0);
135 postings_by_term.push(None);
136 }
137 }
138 }
139 (document_frequencies, postings_by_term)
140 }
141}
142
143fn candidate_document_stats(
144 node_id: NodeId,
145 len: u32,
146 postings_by_term: &[Option<&[TextPosting]>],
147) -> Option<DocumentStats> {
148 let mut doc = DocumentStats::zero(node_id, len, postings_by_term.len());
149 let mut matched = false;
150 for (term_index, postings) in postings_by_term.iter().enumerate() {
151 let Some(postings) = postings else {
152 continue;
153 };
154 if let Ok(index) = postings.binary_search_by_key(&node_id, |posting| posting.node_id) {
155 doc.term_counts[term_index] = postings[index].term_count;
156 matched = true;
157 }
158 }
159 matched.then_some(doc)
160}