1use crate::types::NodeId;
9use std::collections::HashMap;
10
/// Tuning knobs for BM25 scoring.
///
/// Both values are read by `Bm25Index::search` when it computes the per-term
/// contribution of a document.
#[derive(Clone, Copy, Debug)]
pub struct Bm25Params {
    /// Scales the length-normalisation term in the denominator and, as `k1 + 1`,
    /// the term frequency in the numerator.
    pub k1: f32,
    /// Blend factor between no length normalisation (`0.0`) and full
    /// normalisation by `document length / average length` (`1.0`).
    pub b: f32,
}
17
18impl Default for Bm25Params {
19 fn default() -> Self {
20 Self { k1: 1.2, b: 0.75 }
21 }
22}
23
24#[derive(Debug, Clone)]
26pub struct Bm25Index {
27 params: Bm25Params,
28 postings: HashMap<String, Vec<(u32, u32)>>,
30 doc_ids: Vec<NodeId>,
31 doc_len: Vec<u32>,
32 avgdl: f32,
33}
34
35impl Bm25Index {
36 pub fn tokenize(text: &str) -> Vec<String> {
38 text.split(|c: char| !c.is_alphanumeric())
39 .filter(|t| !t.is_empty())
40 .map(|t| t.to_ascii_lowercase())
41 .collect()
42 }
43
44 pub fn build<I, S>(docs: I, params: Bm25Params) -> Self
46 where
47 I: IntoIterator<Item = (NodeId, S)>,
48 S: AsRef<str>,
49 {
50 let mut postings: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
51 let mut doc_ids = Vec::new();
52 let mut doc_len = Vec::new();
53 let mut total_len: u64 = 0;
54
55 for (id, text) in docs {
56 let doc_idx = doc_ids.len() as u32;
57 let tokens = Self::tokenize(text.as_ref());
58 doc_len.push(tokens.len() as u32);
59 total_len += tokens.len() as u64;
60
61 let mut tf: HashMap<String, u32> = HashMap::new();
63 for tok in tokens {
64 *tf.entry(tok).or_insert(0) += 1;
65 }
66 for (term, freq) in tf {
67 postings.entry(term).or_default().push((doc_idx, freq));
68 }
69 doc_ids.push(id);
70 }
71
72 let n = doc_ids.len().max(1) as f32;
73 let avgdl = if doc_ids.is_empty() { 0.0 } else { total_len as f32 / n };
74 Self { params, postings, doc_ids, doc_len, avgdl }
75 }
76
77 pub fn len(&self) -> usize {
79 self.doc_ids.len()
80 }
81 pub fn is_empty(&self) -> bool {
82 self.doc_ids.is_empty()
83 }
84
85 pub fn search(&self, query: &str, k: usize) -> Vec<(NodeId, f32)> {
88 if self.doc_ids.is_empty() || k == 0 {
89 return Vec::new();
90 }
91 let n = self.doc_ids.len() as f32;
92 let (k1, b) = (self.params.k1, self.params.b);
93 let mut scores: HashMap<u32, f32> = HashMap::new();
94
95 let mut seen_terms = std::collections::HashSet::new();
97 for term in Self::tokenize(query) {
98 if !seen_terms.insert(term.clone()) {
99 continue;
100 }
101 let Some(postings) = self.postings.get(&term) else {
102 continue;
103 };
104 let df = postings.len() as f32;
105 let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
107 for &(doc_idx, freq) in postings {
108 let dl = self.doc_len[doc_idx as usize] as f32;
109 let tf = freq as f32;
110 let denom = tf + k1 * (1.0 - b + b * dl / self.avgdl.max(1e-6));
111 let contribution = idf * (tf * (k1 + 1.0)) / denom;
112 *scores.entry(doc_idx).or_insert(0.0) += contribution;
113 }
114 }
115
116 let mut ranked: Vec<(NodeId, f32)> = scores
117 .into_iter()
118 .map(|(idx, s)| (self.doc_ids[idx as usize].clone(), s))
119 .collect();
120 ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
121 ranked.truncate(k);
122 ranked
123 }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    /// Raw `(id, text)` fixtures shared by every test in this module.
    const DOCS: [(&str, &str); 4] = [
        ("d1", "the quick brown fox jumps over the lazy dog"),
        ("d2", "machine learning models for vector search"),
        ("d3", "vector databases enable semantic search at scale"),
        ("d4", "a recipe for italian pasta with tomato sauce"),
    ];

    /// Converts the fixtures into the `(NodeId, text)` pairs `build` expects.
    fn corpus() -> Vec<(NodeId, &'static str)> {
        DOCS.iter().map(|&(id, text)| (id.into(), text)).collect()
    }

    /// Index over the fixture corpus with default parameters.
    fn fixture_index() -> Bm25Index {
        Bm25Index::build(corpus(), Bm25Params::default())
    }

    #[test]
    fn ranks_relevant_docs_first() {
        let index = fixture_index();
        assert_eq!(index.len(), 4);

        let hits = index.search("vector search", 4);
        assert!(!hits.is_empty());

        // Either of the two documents mentioning both query terms may lead.
        let leader_is_relevant = hits[0].0 == "d2" || hits[0].0 == "d3";
        assert!(leader_is_relevant);

        // The unrelated recipe document is either absent or ranked last.
        let recipe_absent = hits.iter().all(|(id, _)| id != "d4");
        assert!(recipe_absent || hits.last().unwrap().0 == "d4");
    }

    #[test]
    fn idf_downweights_common_terms() {
        let hits = fixture_index().search("pasta", 4);
        assert_eq!(hits[0].0, "d4");
    }

    #[test]
    fn empty_query_and_index_safe() {
        // Searching an index with no documents returns nothing.
        let no_docs = Vec::<(NodeId, &str)>::new();
        let empty = Bm25Index::build(no_docs, Bm25Params::default());
        assert!(empty.search("anything", 5).is_empty());

        // Neither an empty query nor unknown terms produce hits.
        let index = fixture_index();
        for query in ["", "zzz nonexistent"] {
            assert!(index.search(query, 5).is_empty());
        }
    }
}