1use std::collections::HashMap;
43
/// Tuning parameters for Okapi BM25 scoring.
#[derive(Debug, Clone, Copy)]
pub struct BM25Config {
    /// Term-frequency saturation. Larger values let repeated occurrences of a term keep
    /// adding to the score for longer before flattening out.
    pub k1: f32,

    /// Strength of document-length normalization; `0.0` disables it, and it is
    /// conventionally kept within `[0.0, 1.0]`.
    pub b: f32,

    /// IDF floor: a term whose IDF is below this value contributes `0.0` instead.
    pub min_idf: f32,
}
65
66impl Default for BM25Config {
67 fn default() -> Self {
68 Self {
69 k1: 1.2,
70 b: 0.75,
71 min_idf: 0.0,
72 }
73 }
74}
75
76impl BM25Config {
77 pub fn lucene() -> Self {
79 Self {
80 k1: 1.2,
81 b: 0.75,
82 min_idf: 0.0,
83 }
84 }
85
86 pub fn elasticsearch() -> Self {
88 Self {
89 k1: 1.2,
90 b: 0.75,
91 min_idf: 0.0,
92 }
93 }
94
95 pub fn short_queries() -> Self {
97 Self {
98 k1: 1.5,
99 b: 0.5, min_idf: 0.0,
101 }
102 }
103}
104
105pub struct BM25Scorer {
111 config: BM25Config,
113
114 num_docs: usize,
116
117 total_len: usize,
120
121 doc_freqs: HashMap<String, usize>,
123}
124
125impl BM25Scorer {
126 pub fn new(config: BM25Config) -> Self {
128 Self {
129 config,
130 num_docs: 0,
131 total_len: 0,
132 doc_freqs: HashMap::new(),
133 }
134 }
135
136 pub fn build<I, D, T>(documents: I, config: BM25Config) -> Self
138 where
139 I: IntoIterator<Item = D>,
140 D: IntoIterator<Item = T>,
141 T: AsRef<str>,
142 {
143 let mut scorer = Self::new(config);
144 let mut total_len = 0usize;
145 let mut num_docs = 0usize;
146 let mut doc_freqs: HashMap<String, usize> = HashMap::new();
147
148 for doc in documents {
149 num_docs += 1;
150 let mut seen_terms: std::collections::HashSet<String> =
151 std::collections::HashSet::new();
152 let mut doc_len = 0usize;
153
154 for token in doc {
155 let term = token.as_ref().to_lowercase();
156 if !term.is_empty() {
157 seen_terms.insert(term);
158 doc_len += 1;
159 }
160 }
161
162 total_len += doc_len;
163
164 for term in seen_terms {
165 *doc_freqs.entry(term).or_insert(0) += 1;
166 }
167 }
168
169 scorer.num_docs = num_docs;
170 scorer.total_len = total_len;
171 scorer.doc_freqs = doc_freqs;
172
173 scorer
174 }
175
176 #[inline]
181 pub fn avg_doc_len(&self) -> f32 {
182 if self.num_docs > 0 {
183 self.total_len as f32 / self.num_docs as f32
184 } else {
185 0.0
186 }
187 }
188
189 #[inline]
191 pub fn config(&self) -> BM25Config {
192 self.config
193 }
194
195 #[inline]
201 fn compute_idf(&self, df: usize, n: usize) -> f32 {
202 let n = n as f32;
203 let df = df as f32;
204 ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
205 }
206
207 pub fn idf(&self, term: &str) -> f32 {
213 let df = self
214 .doc_freqs
215 .get(&term.to_lowercase())
216 .copied()
217 .unwrap_or(0);
218 let idf = self.compute_idf(df, self.num_docs);
219 if idf < self.config.min_idf { 0.0 } else { idf }
220 }
221
222 pub fn score<I, T>(&self, query_terms: I, doc_terms: &[T], doc_len: usize) -> f32
224 where
225 I: IntoIterator<Item = T>,
226 T: AsRef<str> + std::hash::Hash + Eq,
227 {
228 let mut tf: HashMap<&str, usize> = HashMap::new();
230 for term in doc_terms {
231 *tf.entry(term.as_ref()).or_insert(0) += 1;
232 }
233
234 let k1 = self.config.k1;
235 let b = self.config.b;
236 let dl = doc_len as f32;
237 let avgdl = self.avg_doc_len();
238
239 let mut score = 0.0f32;
240
241 for query_term in query_terms {
242 let term = query_term.as_ref().to_lowercase();
243 let term_str = term.as_str();
244
245 let term_tf = *tf.get(term_str).unwrap_or(&0) as f32;
247 if term_tf == 0.0 {
248 continue;
249 }
250
251 let idf = self.idf(&term);
253
254 let numerator = term_tf * (k1 + 1.0);
256 let denominator = term_tf + k1 * (1.0 - b + b * dl / avgdl);
257
258 score += idf * numerator / denominator;
259 }
260
261 score
262 }
263
264 #[inline]
266 pub fn score_with_tf(
267 &self,
268 query_terms: &[String],
269 doc_tf: &HashMap<String, usize>,
270 doc_len: usize,
271 ) -> f32 {
272 self.score_tf_lookup(query_terms, doc_len, |term| {
273 *doc_tf.get(term).unwrap_or(&0) as f32
274 })
275 }
276
277 #[inline]
284 pub fn score_with_tf_u32(
285 &self,
286 query_terms: &[String],
287 doc_tf: &HashMap<String, u32>,
288 doc_len: usize,
289 ) -> f32 {
290 self.score_tf_lookup(query_terms, doc_len, |term| {
291 *doc_tf.get(term).unwrap_or(&0) as f32
292 })
293 }
294
295 #[inline]
298 fn score_tf_lookup<F>(&self, query_terms: &[String], doc_len: usize, mut tf_of: F) -> f32
299 where
300 F: FnMut(&str) -> f32,
301 {
302 let k1 = self.config.k1;
303 let b = self.config.b;
304 let dl = doc_len as f32;
305 let avgdl = self.avg_doc_len();
306
307 let mut score = 0.0f32;
308
309 for term in query_terms {
310 let term_tf = tf_of(term);
311 if term_tf == 0.0 {
312 continue;
313 }
314
315 let idf = self.idf(term);
316 let numerator = term_tf * (k1 + 1.0);
317 let denominator = term_tf + k1 * (1.0 - b + b * dl / avgdl);
318
319 score += idf * numerator / denominator;
320 }
321
322 score
323 }
324
325 pub fn add_document<I, T>(&mut self, tokens: I)
327 where
328 I: IntoIterator<Item = T>,
329 T: AsRef<str>,
330 {
331 let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
332 let mut doc_len = 0usize;
333
334 for token in tokens {
335 let term = token.as_ref().to_lowercase();
336 if !term.is_empty() {
337 seen.insert(term);
338 doc_len += 1;
339 }
340 }
341
342 self.num_docs += 1;
345 self.total_len += doc_len;
346
347 for term in seen {
350 *self.doc_freqs.entry(term).or_insert(0) += 1;
351 }
352 }
353
354 pub fn remove_document<'a, I>(&mut self, unique_terms: I, doc_len: usize)
361 where
362 I: IntoIterator<Item = &'a str>,
363 {
364 if self.num_docs == 0 {
365 return;
366 }
367 self.num_docs -= 1;
368 self.total_len = self.total_len.saturating_sub(doc_len);
369
370 for term in unique_terms {
371 let term = term.to_lowercase();
372 if let Some(df) = self.doc_freqs.get_mut(&term) {
373 *df -= 1;
374 if *df == 0 {
375 self.doc_freqs.remove(&term);
376 }
377 }
378 }
379 }
380
381 pub fn stats(&self) -> BM25Stats {
383 BM25Stats {
384 num_docs: self.num_docs,
385 avg_doc_len: self.avg_doc_len(),
386 vocab_size: self.doc_freqs.len(),
387 }
388 }
389}
390
/// Point-in-time summary of a scorer's corpus statistics.
#[derive(Debug, Clone)]
pub struct BM25Stats {
    /// Number of indexed documents.
    pub num_docs: usize,
    /// Mean document length in tokens (`0.0` for an empty corpus).
    pub avg_doc_len: f32,
    /// Number of distinct terms with a non-zero document frequency.
    pub vocab_size: usize,
}
398
/// Splits `text` on Unicode whitespace and lowercases each piece.
///
/// Punctuation stays attached to its word, so `"Hello,"` becomes `"hello,"`.
pub fn tokenize(text: &str) -> Vec<String> {
    let mut out = Vec::new();
    // `split_whitespace` never yields empty pieces, and lowercasing a non-empty string
    // never empties it, so no length filter is needed here.
    for word in text.split_whitespace() {
        out.push(word.to_lowercase());
    }
    out
}
413
/// Splits `text` on whitespace and ASCII punctuation, lowercases the pieces, and drops
/// pieces shorter than two characters.
///
/// Length is measured in `char`s, not UTF-8 bytes. A single non-ASCII letter such as `"é"`
/// is therefore dropped just like `"a"`. The check runs on the piece *before* lowercasing
/// because lowercasing can expand one character into several (`'İ'` becomes `"i\u{307}"`).
pub fn tokenize_minimal(text: &str) -> Vec<String> {
    text.split(|c: char| c.is_whitespace() || c.is_ascii_punctuation())
        // "has a second char" == at least two characters; also rejects empty pieces.
        .filter(|piece| piece.chars().nth(1).is_some())
        .map(str::to_lowercase)
        .collect()
}
421
/// Tokenizes a search query by splitting on Unicode whitespace and lowercasing each piece.
///
/// Punctuation is left attached to the surrounding word.
pub fn tokenize_query(text: &str) -> Vec<String> {
    text.split_whitespace().map(str::to_lowercase).collect()
}
431
// Unit tests for corpus statistics, IDF ordering, scoring, tokenizers, and the
// equivalence of batch (`build`) versus incremental (`add_document`) indexing.
#[cfg(test)]
mod tests {
    use super::*;

    // Three two-token documents: the count is 3 and the average length is exactly 2.
    #[test]
    fn test_bm25_basic() {
        let docs = vec![
            vec!["hello", "world"],
            vec!["hello", "there"],
            vec!["goodbye", "world"],
        ];

        let scorer = BM25Scorer::build(docs.iter().map(|d| d.iter()), BM25Config::default());

        assert_eq!(scorer.num_docs, 3);
        assert!((scorer.avg_doc_len() - 2.0).abs() < 0.001);
    }

    // "common" occurs in every document, "rare" in only one, so "rare" must have the
    // higher IDF. The repeated "common" in the first document does not raise its df.
    #[test]
    fn test_bm25_idf() {
        let docs = vec![
            vec!["common", "common", "rare"],
            vec!["common", "other"],
            vec!["common", "another"],
        ];

        let scorer = BM25Scorer::build(docs.iter().map(|d| d.iter()), BM25Config::default());

        let idf_common = scorer.idf("common");
        let idf_rare = scorer.idf("rare");

        assert!(idf_rare > idf_common);
    }

    // A document containing the query term three times must outscore one that contains it
    // only once.
    #[test]
    fn test_bm25_scoring() {
        let docs = vec![
            vec!["the", "quick", "brown", "fox"],
            vec!["the", "lazy", "dog"],
            vec!["quick", "quick", "quick"],
        ];

        let scorer = BM25Scorer::build(docs.iter().map(|d| d.iter()), BM25Config::default());

        let score = scorer.score(vec!["quick"], &["quick", "quick", "quick"], 3);

        assert!(score > 0.0);

        let score1 = scorer.score(vec!["quick"], &["the", "quick", "brown", "fox"], 4);

        assert!(score > score1);
    }

    // Whitespace tokenization keeps punctuation attached to words.
    #[test]
    fn test_tokenize() {
        let text = "Hello, World! This is a test.";
        let tokens = tokenize(text);

        assert_eq!(tokens, vec!["hello,", "world!", "this", "is", "a", "test."]);
    }

    // Minimal tokenization strips ASCII punctuation and drops one-character tokens.
    #[test]
    fn test_tokenize_minimal() {
        let text = "Hello, World! This is a test.";
        let tokens = tokenize_minimal(text);

        assert!(tokens.contains(&"hello".to_string()));
        assert!(tokens.contains(&"world".to_string()));
        assert!(!tokens.contains(&"a".to_string()));
    }

    // Incremental indexing updates the document count and average length (2 + 3) / 2.
    #[test]
    fn test_add_document() {
        let mut scorer = BM25Scorer::new(BM25Config::default());

        scorer.add_document(vec!["hello", "world"]);
        assert_eq!(scorer.num_docs, 1);

        scorer.add_document(vec!["hello", "there", "friend"]);
        assert_eq!(scorer.num_docs, 2);

        assert!((scorer.avg_doc_len() - 2.5).abs() < 0.001);
    }

    // Batch and incremental construction must agree bit-for-bit on every statistic and
    // score. Comparing `to_bits` rather than using a tolerance catches any divergence in
    // how the two paths count tokens.
    #[test]
    fn test_build_equals_incremental() {
        let docs: Vec<Vec<&str>> = vec![
            vec!["the", "quick", "brown", "fox"],
            vec!["the", "lazy", "dog", "sleeps"],
            vec!["quick", "quick", "brown", "dog"],
            vec!["the", "fox", "and", "the", "dog"],
        ];

        let batch = BM25Scorer::build(docs.iter().map(|d| d.iter()), BM25Config::default());

        let mut incremental = BM25Scorer::new(BM25Config::default());
        for d in &docs {
            incremental.add_document(d.iter().copied());
        }

        assert_eq!(batch.num_docs, incremental.num_docs);
        assert_eq!(batch.total_len, incremental.total_len);
        assert_eq!(
            batch.avg_doc_len().to_bits(),
            incremental.avg_doc_len().to_bits()
        );

        for term in [
            "the", "quick", "brown", "fox", "lazy", "dog", "sleeps", "and",
        ] {
            assert_eq!(
                batch.idf(term).to_bits(),
                incremental.idf(term).to_bits(),
                "IDF mismatch for term {term:?}"
            );
        }

        let doc = ["quick", "quick", "brown", "dog"];
        assert_eq!(
            batch.score(vec!["quick", "dog"], &doc, doc.len()).to_bits(),
            incremental
                .score(vec!["quick", "dog"], &doc, doc.len())
                .to_bits(),
        );
    }
}