Skip to main content

sqlrite/sql/fts/
bm25.rs

1//! BM25 relevance scoring — the standard ranking function for keyword
2//! retrieval. Pure math; no SQL coupling.
3//!
4//! Resolves Phase 8 plan Q4 + Q5: no stemming and no stop-list. The
5//! caller is responsible for tokenizing both the query and the document
6//! (see [`super::tokenizer::tokenize`]); this module just consumes term
7//! frequencies + corpus stats and produces a score.
8//!
9//! ## Formula (Robertson/Spärck Jones BM25)
10//!
11//! For a document `d` and query `q`:
12//!
13//! ```text
14//! score(d, q) = Σ_{t ∈ q} idf(t) · (tf(t,d) · (k1 + 1)) /
15//!                                  (tf(t,d) + k1 · (1 - b + b · |d| / avgdl))
16//!
17//! idf(t) = ln(1 + (N - n(t) + 0.5) / (n(t) + 0.5))
18//! ```
19//!
20//! - `N`        = total documents in corpus
21//! - `n(t)`     = number of documents containing term `t`
22//! - `tf(t,d)`  = frequency of `t` in `d`
23//! - `|d|`      = length of `d` in tokens
24//! - `avgdl`    = average document length across the corpus
25//! - `k1`, `b`  = tuning constants (Q4 — fixed for the MVP at `k1 = 1.5`, `b = 0.75`; note SQLite FTS5 documents `k1 = 1.2`)
26//!
27//! The `+ 1` inside the IDF log keeps the term non-negative even when
28//! `n(t) > N/2`, which would otherwise give the classic BM25 negative
29//! IDF and require clipping. This is the Lucene-style IDF (not Lv & Zhai's "BM25+", which instead adds a constant to the TF term).
30
31use std::collections::HashMap;
32
/// Tuning parameters for BM25.
///
/// Per Phase 8 Q4 the public surface exposes these as a struct so that
/// per-call overrides can be added later without breaking signatures.
/// For the MVP the [`Bm25Params::default()`] values (`k1 = 1.5`,
/// `b = 0.75`) are fixed.
///
/// NOTE(review): this type was previously documented as matching the
/// SQLite FTS5 defaults, but the FTS5 documentation for `bm25()` states
/// that `k1` is hard-coded at `1.2` there (only `b = 0.75` coincides).
/// Confirm whether `1.5` is intentional before relying on score parity
/// with FTS5.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Bm25Params {
    /// Term-frequency saturation. Higher → less aggressive saturation
    /// (each additional occurrence keeps adding to the score). Typical
    /// range is `[1.2, 2.0]`; the default here is `1.5` (SQLite FTS5
    /// documents `1.2`).
    pub k1: f64,
    /// Length-normalization weight. `0.0` → no length normalization,
    /// `1.0` → fully proportional to `|d| / avgdl`. The default here is
    /// `0.75`, the same value SQLite FTS5 documents.
    pub b: f64,
}
47
48impl Default for Bm25Params {
49    fn default() -> Self {
50        Self { k1: 1.5, b: 0.75 }
51    }
52}
53
54/// Compute the BM25 score for a single (document, query) pair.
55///
56/// - `query_terms` is the pre-tokenized query. Duplicate tokens are
57///   summed naturally — if the user typed `"rust rust db"`, the `rust`
58///   contribution gets counted twice, matching the standard formulation.
59/// - `term_freq` maps each *unique* term in the document to its
60///   frequency within that document. The caller can build this from
61///   [`super::tokenizer::tokenize`] output.
62/// - `n_docs_with` is the corpus statistic — for each term, how many
63///   distinct documents contain it. Only entries for query terms are
64///   read; extra entries are ignored.
65/// - Returns `0.0` for the empty query, the empty corpus
66///   (`total_docs == 0`), or a document whose terms don't intersect the
67///   query.
68pub fn score(
69    query_terms: &[String],
70    term_freq: &HashMap<String, u32>,
71    doc_len: u32,
72    avg_doc_len: f64,
73    n_docs_with: &HashMap<String, u32>,
74    total_docs: u32,
75    params: &Bm25Params,
76) -> f64 {
77    if query_terms.is_empty() || total_docs == 0 {
78        return 0.0;
79    }
80
81    let n = total_docs as f64;
82    let dl = doc_len as f64;
83    // avgdl == 0 only if every doc is empty; guard the division.
84    let length_norm = if avg_doc_len > 0.0 {
85        params.b * (dl / avg_doc_len)
86    } else {
87        0.0
88    };
89    let denom_base = params.k1 * (1.0 - params.b + length_norm);
90
91    let mut total = 0.0;
92    for term in query_terms {
93        let tf = term_freq.get(term).copied().unwrap_or(0) as f64;
94        if tf == 0.0 {
95            continue;
96        }
97        let n_t = n_docs_with.get(term).copied().unwrap_or(0) as f64;
98        // BM25+ IDF: ln(1 + (N - n_t + 0.5) / (n_t + 0.5))
99        let idf = (1.0 + (n - n_t + 0.5) / (n_t + 0.5)).ln();
100        let numerator = tf * (params.k1 + 1.0);
101        let denominator = tf + denom_base;
102        total += idf * (numerator / denominator);
103    }
104    total
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// Default tuning parameters, shared by every test below.
    fn defaults() -> Bm25Params {
        Bm25Params::default()
    }

    /// Builds a `term -> count` map from literal pairs. Used both for
    /// per-document term frequencies and for corpus document frequencies.
    fn counts(entries: &[(&str, u32)]) -> HashMap<String, u32> {
        let mut map = HashMap::with_capacity(entries.len());
        for &(term, count) in entries {
            map.insert(term.to_string(), count);
        }
        map
    }

    /// Builds an owned, pre-tokenized query from string literals.
    fn query(words: &[&str]) -> Vec<String> {
        words.iter().map(|w| (*w).to_string()).collect()
    }

    #[test]
    fn empty_query_or_corpus_returns_zero() {
        // Empty query.
        let empty_query_score = score(&[], &counts(&[]), 0, 0.0, &counts(&[]), 0, &defaults());
        assert_eq!(empty_query_score, 0.0);

        // Non-empty query and matching document, but total_docs == 0.
        let rust = query(&["rust"]);
        let doc_tf = counts(&[("rust", 3)]);
        let doc_freq = counts(&[("rust", 1)]);
        let empty_corpus_score = score(&rust, &doc_tf, 10, 10.0, &doc_freq, 0, &defaults());
        assert_eq!(empty_corpus_score, 0.0);
    }

    #[test]
    fn zero_term_freq_yields_zero_score() {
        // The document only contains "python", so a "rust" query misses.
        let rust = query(&["rust"]);
        let doc_tf = counts(&[("python", 5)]);
        let doc_freq = counts(&[("rust", 1), ("python", 1)]);
        let s = score(&rust, &doc_tf, 10, 10.0, &doc_freq, 5, &defaults());
        assert_eq!(s, 0.0);
    }

    #[test]
    fn higher_tf_strictly_higher_score_at_fixed_length() {
        let rust = query(&["rust"]);
        let doc_freq = counts(&[("rust", 2)]);
        let score_at = |occurrences: u32| {
            score(
                &rust,
                &counts(&[("rust", occurrences)]),
                10,
                10.0,
                &doc_freq,
                100,
                &defaults(),
            )
        };
        let s_low = score_at(1);
        let s_hi = score_at(5);
        assert!(s_hi > s_low, "tf=5 ({}) should beat tf=1 ({})", s_hi, s_low);
    }

    #[test]
    fn longer_doc_scores_lower_at_same_tf() {
        // Identical term frequency; only the document length differs.
        // With b > 0 the longer document is penalized.
        let rust = query(&["rust"]);
        let doc_freq = counts(&[("rust", 2)]);
        let doc_tf = counts(&[("rust", 3)]);
        let score_with_len =
            |len: u32| score(&rust, &doc_tf, len, 50.0, &doc_freq, 100, &defaults());
        let s_short = score_with_len(10);
        let s_long = score_with_len(200);
        assert!(
            s_short > s_long,
            "short ({}) should beat long ({}) at same tf",
            s_short,
            s_long
        );
    }

    #[test]
    fn rare_term_dominates_common_term() {
        // "the" appears in every doc (n_t == N == 1000), so its IDF is
        // ln(1 + 0.5 / 1000.5) ≈ 0.0005: tiny but still positive, because
        // the `1 +` inside the log keeps it from going negative. "quasar"
        // appears in a single doc, so its IDF is far larger. With equal
        // TF and length, the rare term wins by a wide margin.
        let doc_freq = counts(&[("the", 1000), ("quasar", 1)]);
        let score_for = |word: &str| {
            score(
                &query(&[word]),
                &counts(&[(word, 2)]),
                20,
                20.0,
                &doc_freq,
                1000,
                &defaults(),
            )
        };
        let s_common = score_for("the");
        let s_rare = score_for("quasar");
        assert!(
            s_rare > s_common * 5.0,
            "rare term ({}) should dominate common term ({})",
            s_rare,
            s_common
        );
    }

    #[test]
    fn hand_computed_reference_three_doc_corpus() {
        // Corpus of three documents, query = ["rust"]:
        //   doc1: "rust rust db"      tf=2, len=3
        //   doc2: "rust db lang"      tf=1, len=3
        //   doc3: "python db tool"    tf=0, len=3
        // n("rust") = 2, N = 3, avgdl = 3.0, k1 = 1.5, b = 0.75
        //
        //   length_norm  = 0.75 * (3 / 3) = 0.75
        //   denom_base   = 1.5 * (1 - 0.75 + 0.75) = 1.5
        //   idf("rust")  = ln(1 + (3 - 2 + 0.5) / (2 + 0.5))
        //                = ln(1 + 1.5 / 2.5) = ln(1.6) = 0.47000362924...
        //
        //   doc1: idf * (2 * 2.5) / (2 + 1.5) = idf * 5 / 3.5 = 0.67143375606...
        //   doc2: idf * (1 * 2.5) / (1 + 1.5) = idf           = 0.47000362924...
        //   doc3: 0.0 (no "rust")
        let rust = query(&["rust"]);
        let doc_freq = counts(&[
            ("rust", 2),
            ("db", 3),
            ("lang", 1),
            ("python", 1),
            ("tool", 1),
        ]);
        let avgdl = 3.0;
        let score_doc =
            |doc_tf: HashMap<String, u32>| score(&rust, &doc_tf, 3, avgdl, &doc_freq, 3, &defaults());

        let s1 = score_doc(counts(&[("rust", 2), ("db", 1)]));
        let s2 = score_doc(counts(&[("rust", 1), ("db", 1), ("lang", 1)]));
        let s3 = score_doc(counts(&[("python", 1), ("db", 1), ("tool", 1)]));

        let idf = (1.0_f64 + (3.0 - 2.0 + 0.5) / (2.0 + 0.5)).ln();
        let expected_s1 = idf * (2.0 * (1.5 + 1.0)) / (2.0 + 1.5);
        let expected_s2 = idf * (1.0 * (1.5 + 1.0)) / (1.0 + 1.5);
        let tol = f64::EPSILON * 16.0;
        assert!(
            (s1 - expected_s1).abs() < tol,
            "doc1 score {} vs expected {}",
            s1,
            expected_s1
        );
        assert!(
            (s2 - expected_s2).abs() < tol,
            "doc2 score {} vs expected {}",
            s2,
            expected_s2
        );
        assert_eq!(s3, 0.0);
        assert!(s1 > s2, "doc1 (tf=2) should outrank doc2 (tf=1)");
    }

    #[test]
    fn duplicate_query_tokens_compound() {
        // Repeating a token in the query adds its contribution again.
        let doc_tf = counts(&[("rust", 1)]);
        let doc_freq = counts(&[("rust", 2)]);
        let score_query = |q: Vec<String>| score(&q, &doc_tf, 5, 5.0, &doc_freq, 10, &defaults());
        let s1 = score_query(query(&["rust"]));
        let s2 = score_query(query(&["rust", "rust"]));
        assert!(
            (s2 - 2.0 * s1).abs() < f64::EPSILON * 8.0,
            "duplicated query token should double the score: 2*s1={}, s2={}",
            2.0 * s1,
            s2
        );
    }
}