Skip to main content

sochdb_vector/
bm25.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! BM25 Scoring for Lexical Search (Task 4)
19//!
20//! This module implements BM25 (Best Matching 25) scoring for keyword search.
21//! BM25 is the standard ranking function for lexical retrieval, balancing:
22//! - Term frequency (TF): How often a term appears in a document
23//! - Inverse document frequency (IDF): How rare a term is across all documents
24//! - Document length normalization: Penalizing very long documents
25//!
26//! ## BM25 Formula
27//!
28//! ```text
29//! score(q, d) = Σ IDF(t) * (TF(t,d) * (k1 + 1)) / (TF(t,d) + k1 * (1 - b + b * |d|/avgdl))
30//! ```
31//!
32//! Where:
33//! - `TF(t,d)` = term frequency of term t in document d
34//! - `IDF(t)` = log((N - df(t) + 0.5) / (df(t) + 0.5) + 1)
35//! - `N` = total number of documents
36//! - `df(t)` = number of documents containing term t
37//! - `|d|` = length of document d
38//! - `avgdl` = average document length
39//! - `k1` = term frequency saturation parameter (typically 1.2)
40//! - `b` = length normalization parameter (typically 0.75)
41
42use std::collections::HashMap;
43
44// ============================================================================
45// BM25 Configuration
46// ============================================================================
47
/// Tunable parameters of the BM25 ranking function.
#[derive(Debug, Clone, Copy)]
pub struct BM25Config {
    /// `k1`: term-frequency saturation.
    ///
    /// Larger values let repeated occurrences of a term keep adding to the
    /// score for longer. Commonly chosen between 1.2 and 2.0.
    pub k1: f32,

    /// `b`: strength of document-length normalization.
    ///
    /// `0.0` disables length normalization, `1.0` applies it fully.
    /// The conventional value is 0.75.
    pub b: f32,

    /// IDF floor used to suppress very common terms.
    pub min_idf: f32,
}

impl Default for BM25Config {
    /// The conventional BM25 setup: `k1 = 1.2`, `b = 0.75`, no IDF floor.
    fn default() -> Self {
        let (k1, b, min_idf) = (1.2, 0.75, 0.0);
        Self { k1, b, min_idf }
    }
}

impl BM25Config {
    /// Lucene-style parameters (currently the same values as [`Default`]).
    pub fn lucene() -> Self {
        Self::default()
    }

    /// Elasticsearch-style parameters (currently the same values as [`Default`]).
    pub fn elasticsearch() -> Self {
        Self::default()
    }

    /// Preset aimed at short queries: stronger term-frequency weight
    /// (`k1 = 1.5`) and weaker length normalization (`b = 0.5`).
    pub fn short_queries() -> Self {
        Self {
            k1: 1.5,
            b: 0.5,
            ..Self::default()
        }
    }
}
104
105// ============================================================================
106// BM25 Scorer
107// ============================================================================
108
109/// BM25 scorer for a document collection
110pub struct BM25Scorer {
111    /// Configuration
112    config: BM25Config,
113
114    /// Total number of documents
115    num_docs: usize,
116
117    /// Total token length across all documents. The average document length is
118    /// derived from this on read so it can never drift out of sync.
119    total_len: usize,
120
121    /// Document frequency for each term
122    doc_freqs: HashMap<String, usize>,
123}
124
125impl BM25Scorer {
126    /// Create a new BM25 scorer
127    pub fn new(config: BM25Config) -> Self {
128        Self {
129            config,
130            num_docs: 0,
131            total_len: 0,
132            doc_freqs: HashMap::new(),
133        }
134    }
135
136    /// Build the scorer from a collection of documents
137    pub fn build<I, D, T>(documents: I, config: BM25Config) -> Self
138    where
139        I: IntoIterator<Item = D>,
140        D: IntoIterator<Item = T>,
141        T: AsRef<str>,
142    {
143        let mut scorer = Self::new(config);
144        let mut total_len = 0usize;
145        let mut num_docs = 0usize;
146        let mut doc_freqs: HashMap<String, usize> = HashMap::new();
147
148        for doc in documents {
149            num_docs += 1;
150            let mut seen_terms: std::collections::HashSet<String> =
151                std::collections::HashSet::new();
152            let mut doc_len = 0usize;
153
154            for token in doc {
155                let term = token.as_ref().to_lowercase();
156                if !term.is_empty() {
157                    seen_terms.insert(term);
158                    doc_len += 1;
159                }
160            }
161
162            total_len += doc_len;
163
164            for term in seen_terms {
165                *doc_freqs.entry(term).or_insert(0) += 1;
166            }
167        }
168
169        scorer.num_docs = num_docs;
170        scorer.total_len = total_len;
171        scorer.doc_freqs = doc_freqs;
172
173        scorer
174    }
175
176    /// Average document length, derived from running totals.
177    ///
178    /// Computed on read so it can never drift out of sync with the corpus
179    /// (avgdl changes for every term on every insert and delete).
180    #[inline]
181    pub fn avg_doc_len(&self) -> f32 {
182        if self.num_docs > 0 {
183            self.total_len as f32 / self.num_docs as f32
184        } else {
185            0.0
186        }
187    }
188
189    /// The scoring configuration this scorer was built with.
190    #[inline]
191    pub fn config(&self) -> BM25Config {
192        self.config
193    }
194
195    /// Compute IDF from a document frequency and corpus size.
196    ///
197    /// Single source of truth for IDF, used by every scoring path (batch and
198    /// incremental). The `+ 1` floor (Robertson-Sparck Jones with smoothing)
199    /// keeps IDF strictly positive even for terms in more than half the corpus.
200    #[inline]
201    fn compute_idf(&self, df: usize, n: usize) -> f32 {
202        let n = n as f32;
203        let df = df as f32;
204        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
205    }
206
207    /// Get IDF for a term.
208    ///
209    /// Computed lazily from the current `(df, N)` so it is always consistent
210    /// with the live corpus; unknown terms use `df = 0` (maximum IDF). Terms
211    /// whose IDF falls below `min_idf` contribute nothing.
212    pub fn idf(&self, term: &str) -> f32 {
213        let df = self
214            .doc_freqs
215            .get(&term.to_lowercase())
216            .copied()
217            .unwrap_or(0);
218        let idf = self.compute_idf(df, self.num_docs);
219        if idf < self.config.min_idf { 0.0 } else { idf }
220    }
221
222    /// Score a document for a query
223    pub fn score<I, T>(&self, query_terms: I, doc_terms: &[T], doc_len: usize) -> f32
224    where
225        I: IntoIterator<Item = T>,
226        T: AsRef<str> + std::hash::Hash + Eq,
227    {
228        // Build term frequency map for document
229        let mut tf: HashMap<&str, usize> = HashMap::new();
230        for term in doc_terms {
231            *tf.entry(term.as_ref()).or_insert(0) += 1;
232        }
233
234        let k1 = self.config.k1;
235        let b = self.config.b;
236        let dl = doc_len as f32;
237        let avgdl = self.avg_doc_len();
238
239        let mut score = 0.0f32;
240
241        for query_term in query_terms {
242            let term = query_term.as_ref().to_lowercase();
243            let term_str = term.as_str();
244
245            // Get TF for this term in the document
246            let term_tf = *tf.get(term_str).unwrap_or(&0) as f32;
247            if term_tf == 0.0 {
248                continue;
249            }
250
251            // Get IDF
252            let idf = self.idf(&term);
253
254            // BM25 scoring formula
255            let numerator = term_tf * (k1 + 1.0);
256            let denominator = term_tf + k1 * (1.0 - b + b * dl / avgdl);
257
258            score += idf * numerator / denominator;
259        }
260
261        score
262    }
263
264    /// Score a document given precomputed term frequencies
265    #[inline]
266    pub fn score_with_tf(
267        &self,
268        query_terms: &[String],
269        doc_tf: &HashMap<String, usize>,
270        doc_len: usize,
271    ) -> f32 {
272        self.score_tf_lookup(query_terms, doc_len, |term| {
273            *doc_tf.get(term).unwrap_or(&0) as f32
274        })
275    }
276
277    /// Score a document directly from a `u32`-valued term-frequency map.
278    ///
279    /// Identical math to [`score_with_tf`](Self::score_with_tf) but lets callers
280    /// whose postings already store `u32` frequencies (the inverted index) score
281    /// without cloning the whole `term_freqs` map into a `usize`-valued copy on
282    /// every query.
283    #[inline]
284    pub fn score_with_tf_u32(
285        &self,
286        query_terms: &[String],
287        doc_tf: &HashMap<String, u32>,
288        doc_len: usize,
289    ) -> f32 {
290        self.score_tf_lookup(query_terms, doc_len, |term| {
291            *doc_tf.get(term).unwrap_or(&0) as f32
292        })
293    }
294
295    /// Shared BM25 scoring core: sums the per-term contribution using a caller
296    /// supplied term-frequency lookup, so the formula has exactly one definition.
297    #[inline]
298    fn score_tf_lookup<F>(&self, query_terms: &[String], doc_len: usize, mut tf_of: F) -> f32
299    where
300        F: FnMut(&str) -> f32,
301    {
302        let k1 = self.config.k1;
303        let b = self.config.b;
304        let dl = doc_len as f32;
305        let avgdl = self.avg_doc_len();
306
307        let mut score = 0.0f32;
308
309        for term in query_terms {
310            let term_tf = tf_of(term);
311            if term_tf == 0.0 {
312                continue;
313            }
314
315            let idf = self.idf(term);
316            let numerator = term_tf * (k1 + 1.0);
317            let denominator = term_tf + k1 * (1.0 - b + b * dl / avgdl);
318
319            score += idf * numerator / denominator;
320        }
321
322        score
323    }
324
325    /// Update stats when adding a document
326    pub fn add_document<I, T>(&mut self, tokens: I)
327    where
328        I: IntoIterator<Item = T>,
329        T: AsRef<str>,
330    {
331        let mut seen: std::collections::HashSet<String> = std::collections::HashSet::new();
332        let mut doc_len = 0usize;
333
334        for token in tokens {
335            let term = token.as_ref().to_lowercase();
336            if !term.is_empty() {
337                seen.insert(term);
338                doc_len += 1;
339            }
340        }
341
342        // Update running totals (average document length is derived from these
343        // on read, so it never goes stale).
344        self.num_docs += 1;
345        self.total_len += doc_len;
346
347        // Update document frequencies. IDF is computed lazily at query time from
348        // (df, N), so there is no cache to keep in sync here.
349        for term in seen {
350            *self.doc_freqs.entry(term).or_insert(0) += 1;
351        }
352    }
353
354    /// Update stats when removing a document.
355    ///
356    /// Inverse of [`add_document`](Self::add_document): pass the document's
357    /// unique terms and its token length. Keeps `num_docs`, `total_len`, and
358    /// `doc_freqs` consistent so IDF and avgdl never drift under deletion, and
359    /// drops terms whose document frequency reaches zero (no vocabulary leak).
360    pub fn remove_document<'a, I>(&mut self, unique_terms: I, doc_len: usize)
361    where
362        I: IntoIterator<Item = &'a str>,
363    {
364        if self.num_docs == 0 {
365            return;
366        }
367        self.num_docs -= 1;
368        self.total_len = self.total_len.saturating_sub(doc_len);
369
370        for term in unique_terms {
371            let term = term.to_lowercase();
372            if let Some(df) = self.doc_freqs.get_mut(&term) {
373                *df -= 1;
374                if *df == 0 {
375                    self.doc_freqs.remove(&term);
376                }
377            }
378        }
379    }
380
381    /// Get statistics
382    pub fn stats(&self) -> BM25Stats {
383        BM25Stats {
384            num_docs: self.num_docs,
385            avg_doc_len: self.avg_doc_len(),
386            vocab_size: self.doc_freqs.len(),
387        }
388    }
389}
390
/// Snapshot of a BM25 scorer's corpus statistics.
///
/// Derives `PartialEq` so snapshots can be compared directly, e.g. to assert
/// that two scorers describe the same corpus.
#[derive(Debug, Clone, PartialEq)]
pub struct BM25Stats {
    /// Number of documents in the corpus.
    pub num_docs: usize,
    /// Average document length, in tokens.
    pub avg_doc_len: f32,
    /// Number of distinct terms with a recorded document frequency.
    pub vocab_size: usize,
}
398
399// ============================================================================
400// Simple Tokenizer
401// ============================================================================
402
/// Whitespace tokenizer with lowercasing.
///
/// Splits `text` on Unicode whitespace and lowercases every piece. Punctuation
/// stays attached to its word (`"Hello,"` becomes `"hello,"`). Whitespace
/// splitting never produces empty pieces, so no further filtering is needed.
///
/// Deliberately minimal; stemming or stop-word handling can be layered on
/// later.
pub fn tokenize(text: &str) -> Vec<String> {
    let mut tokens = Vec::new();
    for word in text.split_whitespace() {
        tokens.push(word.to_lowercase());
    }
    tokens
}
413
/// Tokenize with minimal normalization.
///
/// Splits on whitespace and ASCII punctuation, lowercases each piece, and
/// keeps only pieces longer than one byte. The length test is on the UTF-8
/// byte length of the lowercased piece, so single ASCII characters are dropped
/// while a single multi-byte character (e.g. `"é"`) is kept.
pub fn tokenize_minimal(text: &str) -> Vec<String> {
    let is_separator = |c: char| c.is_whitespace() || c.is_ascii_punctuation();

    text.split(is_separator)
        .filter_map(|piece| {
            let lowered = piece.to_lowercase();
            if lowered.len() > 1 {
                Some(lowered)
            } else {
                None
            }
        })
        .collect()
}
421
/// Tokenize a query string.
///
/// Splits on whitespace and returns the lowercased form of each piece, in
/// order. Only the lowercased token is emitted; the original-case spelling is
/// not kept.
pub fn tokenize_query(text: &str) -> Vec<String> {
    text.split_whitespace().map(str::to_lowercase).collect()
}
431
432// ============================================================================
433// Tests
434// ============================================================================
435
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a scorer with default parameters from pre-tokenized documents.
    fn scorer_for(docs: &[Vec<&str>]) -> BM25Scorer {
        BM25Scorer::build(docs.iter().map(|doc| doc.iter()), BM25Config::default())
    }

    #[test]
    fn test_bm25_basic() {
        let scorer = scorer_for(&[
            vec!["hello", "world"],
            vec!["hello", "there"],
            vec!["goodbye", "world"],
        ]);

        assert_eq!(scorer.num_docs, 3);

        // Every document has two tokens, so the average is exactly two.
        let avg = scorer.avg_doc_len();
        assert!((avg - 2.0).abs() < 0.001);
    }

    #[test]
    fn test_bm25_idf() {
        // "common" is in all three documents, "rare" in just one.
        let scorer = scorer_for(&[
            vec!["common", "common", "rare"],
            vec!["common", "other"],
            vec!["common", "another"],
        ]);

        let (common, rare) = (scorer.idf("common"), scorer.idf("rare"));

        // The rarer term must carry the higher IDF.
        assert!(rare > common);
    }

    #[test]
    fn test_bm25_scoring() {
        let scorer = scorer_for(&[
            vec!["the", "quick", "brown", "fox"],
            vec!["the", "lazy", "dog"],
            vec!["quick", "quick", "quick"], // high TF for "quick"
        ]);

        // Third document: "quick" three times.
        let repeated = scorer.score(vec!["quick"], &["quick", "quick", "quick"], 3);
        assert!(repeated > 0.0);

        // First document: "quick" once among other words.
        let single = scorer.score(vec!["quick"], &["the", "quick", "brown", "fox"], 4);

        // More occurrences of the query term should rank higher.
        assert!(repeated > single);
    }

    #[test]
    fn test_tokenize() {
        assert_eq!(
            tokenize("Hello, World! This is a test."),
            vec!["hello,", "world!", "this", "is", "a", "test."]
        );
    }

    #[test]
    fn test_tokenize_minimal() {
        let tokens = tokenize_minimal("Hello, World! This is a test.");
        let has = |word: &str| tokens.iter().any(|t| t.as_str() == word);

        // Punctuation is stripped and single characters are dropped.
        assert!(has("hello"));
        assert!(has("world"));
        assert!(!has("a")); // single char
    }

    #[test]
    fn test_add_document() {
        let mut scorer = BM25Scorer::new(BM25Config::default());

        scorer.add_document(vec!["hello", "world"]);
        assert_eq!(scorer.num_docs, 1);

        scorer.add_document(vec!["hello", "there", "friend"]);
        assert_eq!(scorer.num_docs, 2);

        // (2 + 3) / 2 = 2.5
        let avg = scorer.avg_doc_len();
        assert!((avg - 2.5).abs() < 0.001);
    }

    #[test]
    fn test_build_equals_incremental() {
        // Scoring depends only on corpus content, not on how the corpus was
        // assembled: a batch-built scorer and one fed document by document must
        // agree bit-for-bit, since both derive IDF/avgdl from the same integer
        // totals.
        let docs: Vec<Vec<&str>> = vec![
            vec!["the", "quick", "brown", "fox"],
            vec!["the", "lazy", "dog", "sleeps"],
            vec!["quick", "quick", "brown", "dog"],
            vec!["the", "fox", "and", "the", "dog"],
        ];

        let batch = scorer_for(&docs);

        let mut incremental = BM25Scorer::new(BM25Config::default());
        docs.iter()
            .for_each(|d| incremental.add_document(d.iter().copied()));

        // Corpus-level statistics agree.
        assert_eq!(batch.num_docs, incremental.num_docs);
        assert_eq!(batch.total_len, incremental.total_len);
        assert_eq!(
            batch.avg_doc_len().to_bits(),
            incremental.avg_doc_len().to_bits()
        );

        // IDF agrees for every vocabulary term.
        let vocabulary = [
            "the", "quick", "brown", "fox", "lazy", "dog", "sleeps", "and",
        ];
        for term in vocabulary {
            assert_eq!(
                batch.idf(term).to_bits(),
                incremental.idf(term).to_bits(),
                "IDF mismatch for term {term:?}"
            );
        }

        // Full BM25 scores agree as well.
        let doc = ["quick", "quick", "brown", "dog"];
        let from_batch = batch.score(vec!["quick", "dog"], &doc, doc.len());
        let from_incremental = incremental.score(vec!["quick", "dog"], &doc, doc.len());
        assert_eq!(from_batch.to_bits(), from_incremental.to_bits());
    }
}