sqlrite/sql/fts/bm25.rs
1//! BM25 relevance scoring — the standard ranking function for keyword
2//! retrieval. Pure math; no SQL coupling.
3//!
4//! Resolves Phase 8 plan Q4 + Q5: no stemming and no stop-list. The
5//! caller is responsible for tokenizing both the query and the document
6//! (see [`super::tokenizer::tokenize`]); this module just consumes term
7//! frequencies + corpus stats and produces a score.
8//!
9//! ## Formula (Robertson/Spärck Jones BM25)
10//!
11//! For a document `d` and query `q`:
12//!
13//! ```text
14//! score(d, q) = Σ_{t ∈ q} idf(t) · (tf(t,d) · (k1 + 1)) /
15//! (tf(t,d) + k1 · (1 - b + b · |d| / avgdl))
16//!
17//! idf(t) = ln(1 + (N - n(t) + 0.5) / (n(t) + 0.5))
18//! ```
19//!
20//! - `N` = total documents in corpus
21//! - `n(t)` = number of documents containing term `t`
22//! - `tf(t,d)` = frequency of `t` in `d`
23//! - `|d|` = length of `d` in tokens
24//! - `avgdl` = average document length across the corpus
//! - `k1`, `b` = tuning constants (Q4 — fixed for the MVP at `k1 = 1.5`,
//!   `b = 0.75`; `b` matches SQLite FTS5, whose documented `k1` is `1.2`)
26//!
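//! The `+ 1` inside the IDF log keeps the term non-negative even when
//! `n(t) > N/2` (as long as `n(t) ≤ N`), which would otherwise give the
//! classic BM25 negative IDF and require clipping. This is the smoothed
//! IDF used by Lucene's BM25 implementation; it is not Lv & Zhai's
//! "BM25+", which instead adds a lower bound to the term-frequency
//! component.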
30
31use std::collections::HashMap;
32
/// BM25 tuning knobs.
///
/// Per Phase 8 Q4 these are exposed as a struct so per-call overrides can
/// be added later without breaking function signatures. For the MVP the
/// values returned by [`Bm25Params::default()`] (`k1 = 1.5`, `b = 0.75`)
/// are fixed.
///
/// NOTE(review): these defaults have been described as matching SQLite
/// FTS5, but the FTS5 documentation lists its hard-coded constants as
/// `k1 = 1.2`, `b = 0.75` — confirm which `k1` Q4 intends.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct Bm25Params {
    /// Term-frequency saturation. A larger value saturates less
    /// aggressively, so every additional occurrence of a term keeps
    /// adding to the score. Values in `[1.2, 2.0]` are typical.
    pub k1: f64,
    /// Length-normalization weight. `0.0` disables length normalization;
    /// `1.0` scales it fully with `|d| / avgdl`.
    pub b: f64,
}
47
48impl Default for Bm25Params {
49 fn default() -> Self {
50 Self { k1: 1.5, b: 0.75 }
51 }
52}
53
54/// Compute the BM25 score for a single (document, query) pair.
55///
56/// - `query_terms` is the pre-tokenized query. Duplicate tokens are
57/// summed naturally — if the user typed `"rust rust db"`, the `rust`
58/// contribution gets counted twice, matching the standard formulation.
59/// - `term_freq` maps each *unique* term in the document to its
60/// frequency within that document. The caller can build this from
61/// [`super::tokenizer::tokenize`] output.
62/// - `n_docs_with` is the corpus statistic — for each term, how many
63/// distinct documents contain it. Only entries for query terms are
64/// read; extra entries are ignored.
65/// - Returns `0.0` for the empty query, the empty corpus
66/// (`total_docs == 0`), or a document whose terms don't intersect the
67/// query.
68pub fn score(
69 query_terms: &[String],
70 term_freq: &HashMap<String, u32>,
71 doc_len: u32,
72 avg_doc_len: f64,
73 n_docs_with: &HashMap<String, u32>,
74 total_docs: u32,
75 params: &Bm25Params,
76) -> f64 {
77 if query_terms.is_empty() || total_docs == 0 {
78 return 0.0;
79 }
80
81 let n = total_docs as f64;
82 let dl = doc_len as f64;
83 // avgdl == 0 only if every doc is empty; guard the division.
84 let length_norm = if avg_doc_len > 0.0 {
85 params.b * (dl / avg_doc_len)
86 } else {
87 0.0
88 };
89 let denom_base = params.k1 * (1.0 - params.b + length_norm);
90
91 let mut total = 0.0;
92 for term in query_terms {
93 let tf = term_freq.get(term).copied().unwrap_or(0) as f64;
94 if tf == 0.0 {
95 continue;
96 }
97 let n_t = n_docs_with.get(term).copied().unwrap_or(0) as f64;
98 // BM25+ IDF: ln(1 + (N - n_t + 0.5) / (n_t + 0.5))
99 let idf = (1.0 + (n - n_t + 0.5) / (n_t + 0.5)).ln();
100 let numerator = tf * (params.k1 + 1.0);
101 let denominator = tf + denom_base;
102 total += idf * (numerator / denominator);
103 }
104 total
105}
106
#[cfg(test)]
mod tests {
    use super::*;

    /// The default tuning parameters (`k1 = 1.5`, `b = 0.75`).
    fn defaults() -> Bm25Params {
        Bm25Params::default()
    }

    /// Build a `term → count` map from literal pairs. Serves both as a
    /// per-document term-frequency map and as a corpus document-frequency
    /// map.
    fn counts(pairs: &[(&str, u32)]) -> HashMap<String, u32> {
        let mut map = HashMap::with_capacity(pairs.len());
        for &(term, count) in pairs {
            map.insert(term.to_string(), count);
        }
        map
    }

    /// Build an owned, pre-tokenized query from string literals.
    fn query(terms: &[&str]) -> Vec<String> {
        terms.iter().map(|term| (*term).to_string()).collect()
    }

    #[test]
    fn empty_query_or_corpus_returns_zero() {
        // Empty query (and, incidentally, an empty corpus).
        let nothing = counts(&[]);
        let empty_query = score(&[], &nothing, 0, 0.0, &nothing, 0, &defaults());
        assert_eq!(empty_query, 0.0);

        // Non-empty query that would match, but `total_docs == 0`.
        let empty_corpus = score(
            &query(&["rust"]),
            &counts(&[("rust", 3)]),
            10,
            10.0,
            &counts(&[("rust", 1)]),
            0,
            &defaults(),
        );
        assert_eq!(empty_corpus, 0.0);
    }

    #[test]
    fn zero_term_freq_yields_zero_score() {
        // The document only contains "python"; the query asks for "rust".
        let doc = counts(&[("python", 5)]);
        let doc_freqs = counts(&[("rust", 1), ("python", 1)]);
        let got = score(&query(&["rust"]), &doc, 10, 10.0, &doc_freqs, 5, &defaults());
        assert_eq!(got, 0.0);
    }

    #[test]
    fn higher_tf_strictly_higher_score_at_fixed_length() {
        let q = query(&["rust"]);
        let doc_freqs = counts(&[("rust", 2)]);
        // Only the term frequency varies between the two calls.
        let score_at_tf = |occurrences: u32| {
            let doc = counts(&[("rust", occurrences)]);
            score(&q, &doc, 10, 10.0, &doc_freqs, 100, &defaults())
        };

        let s_low = score_at_tf(1);
        let s_hi = score_at_tf(5);
        assert!(s_hi > s_low, "tf=5 ({}) should beat tf=1 ({})", s_hi, s_low);
    }

    #[test]
    fn longer_doc_scores_lower_at_same_tf() {
        // With b > 0, length normalization penalizes the longer document
        // when the term frequency is held constant.
        let q = query(&["rust"]);
        let doc = counts(&[("rust", 3)]);
        let doc_freqs = counts(&[("rust", 2)]);
        let score_at_len = |len: u32| score(&q, &doc, len, 50.0, &doc_freqs, 100, &defaults());

        let s_short = score_at_len(10);
        let s_long = score_at_len(200);
        assert!(
            s_short > s_long,
            "short ({}) should beat long ({}) at same tf",
            s_short,
            s_long
        );
    }

    #[test]
    fn rare_term_dominates_common_term() {
        // "the" occurs in every document (n_t == N == 1000), so its IDF is
        // ln(1 + 0.5 / 1000.5) ≈ 0.0005: tiny, but still positive thanks
        // to the `+ 1` inside the log. "quasar" occurs in a single
        // document, giving ln(1 + 999.5 / 1.5) ≈ 6.5. With identical TF
        // and length, the rare term must win by a wide margin.
        let doc_freqs = counts(&[("the", 1000), ("quasar", 1)]);
        let score_single = |term: &str| {
            score(
                &query(&[term]),
                &counts(&[(term, 2)]),
                20,
                20.0,
                &doc_freqs,
                1000,
                &defaults(),
            )
        };

        let s_common = score_single("the");
        let s_rare = score_single("quasar");
        assert!(
            s_rare > s_common * 5.0,
            "rare term ({}) should dominate common term ({})",
            s_rare,
            s_common
        );
    }

    #[test]
    fn hand_computed_reference_three_doc_corpus() {
        // Corpus of three documents, query = ["rust"]:
        //   doc1: "rust rust db"     tf=2, len=3
        //   doc2: "rust db lang"     tf=1, len=3
        //   doc3: "python db tool"   tf=0, len=3
        // n("rust") = 2, N = 3, avgdl = 3.0, k1 = 1.5, b = 0.75
        //
        //   length_norm = 0.75 * (3 / 3)            = 0.75
        //   denom_base  = 1.5 * (1 - 0.75 + 0.75)   = 1.5
        //   idf("rust") = ln(1 + (3 - 2 + 0.5) / (2 + 0.5))
        //               = ln(1 + 1.5 / 2.5) = ln(1.6) = 0.47000362924...
        //
        //   doc1: idf * (2 * 2.5) / (2 + 1.5) = idf * 5 / 3.5 = 0.67143375606...
        //   doc2: idf * (1 * 2.5) / (1 + 1.5) = idf           = 0.47000362924...
        //   doc3: 0.0 (no "rust")
        let q = query(&["rust"]);
        let doc_freqs = counts(&[
            ("rust", 2),
            ("db", 3),
            ("lang", 1),
            ("python", 1),
            ("tool", 1),
        ]);
        let avgdl = 3.0;
        let score_doc =
            |doc: &[(&str, u32)]| score(&q, &counts(doc), 3, avgdl, &doc_freqs, 3, &defaults());

        let s1 = score_doc(&[("rust", 2), ("db", 1)]);
        let s2 = score_doc(&[("rust", 1), ("db", 1), ("lang", 1)]);
        let s3 = score_doc(&[("python", 1), ("db", 1), ("tool", 1)]);

        let idf = (1.0_f64 + (3.0 - 2.0 + 0.5) / (2.0 + 0.5)).ln();
        let expected = |tf: f64| idf * (tf * (1.5 + 1.0)) / (tf + 1.5);
        let expected_s1 = expected(2.0);
        let expected_s2 = expected(1.0);
        let tol = f64::EPSILON * 16.0;

        assert!(
            (s1 - expected_s1).abs() < tol,
            "doc1 score {} vs expected {}",
            s1,
            expected_s1
        );
        assert!(
            (s2 - expected_s2).abs() < tol,
            "doc2 score {} vs expected {}",
            s2,
            expected_s2
        );
        assert_eq!(s3, 0.0);
        assert!(s1 > s2, "doc1 (tf=2) should outrank doc2 (tf=1)");
    }

    #[test]
    fn duplicate_query_tokens_compound() {
        // Repeating a query token repeats its contribution.
        let doc = counts(&[("rust", 1)]);
        let doc_freqs = counts(&[("rust", 2)]);
        let score_query = |terms: &[&str]| {
            score(&query(terms), &doc, 5, 5.0, &doc_freqs, 10, &defaults())
        };

        let s1 = score_query(&["rust"]);
        let s2 = score_query(&["rust", "rust"]);
        assert!(
            (s2 - 2.0 * s1).abs() < f64::EPSILON * 8.0,
            "duplicated query token should double the score: 2*s1={}, s2={}",
            2.0 * s1,
            s2
        );
    }
}