1use std::collections::{BTreeMap, HashMap};
35
36use super::bm25::{Bm25Params, score as bm25_score};
37use super::tokenizer::tokenize;
38
/// In-memory inverted index over tokenized documents, keyed by rowid.
///
/// Keeps the per-term postings together with the per-document token counts
/// that the scoring methods hand to the BM25 scorer.
#[derive(Clone, Debug, Default)]
pub struct PostingList {
    /// term -> (rowid -> number of occurrences of that term in the document).
    postings: BTreeMap<String, BTreeMap<i64, u32>>,
    /// rowid -> token count of the document as it was indexed.
    doc_lengths: BTreeMap<i64, u32>,
    /// Running sum of all values stored in `doc_lengths`.
    total_tokens: u64,
}
52
53impl PostingList {
54 pub fn new() -> Self {
56 Self::default()
57 }
58
59 pub fn len(&self) -> usize {
61 self.doc_lengths.len()
62 }
63
64 pub fn is_empty(&self) -> bool {
66 self.doc_lengths.is_empty()
67 }
68
69 pub fn avg_doc_len(&self) -> f64 {
72 if self.doc_lengths.is_empty() {
73 0.0
74 } else {
75 self.total_tokens as f64 / self.doc_lengths.len() as f64
76 }
77 }
78
79 pub fn serialize_doc_lengths(&self) -> Vec<(i64, u32)> {
84 self.doc_lengths
85 .iter()
86 .map(|(id, len)| (*id, *len))
87 .collect()
88 }
89
90 pub fn serialize_postings(&self) -> Vec<(String, Vec<(i64, u32)>)> {
96 self.postings
97 .iter()
98 .map(|(term, postings)| {
99 let entries = postings.iter().map(|(id, freq)| (*id, *freq)).collect();
100 (term.clone(), entries)
101 })
102 .collect()
103 }
104
105 pub fn from_persisted_postings<I, J>(doc_lengths: I, postings: J) -> Self
114 where
115 I: IntoIterator<Item = (i64, u32)>,
116 J: IntoIterator<Item = (String, Vec<(i64, u32)>)>,
117 {
118 let mut doc_lengths_map: BTreeMap<i64, u32> = BTreeMap::new();
119 let mut total_tokens: u64 = 0;
120 for (rowid, len) in doc_lengths {
121 doc_lengths_map.insert(rowid, len);
122 total_tokens += len as u64;
123 }
124
125 let mut postings_map: BTreeMap<String, BTreeMap<i64, u32>> = BTreeMap::new();
126 for (term, entries) in postings {
127 let inner: BTreeMap<i64, u32> = entries.into_iter().collect();
128 if !inner.is_empty() {
132 postings_map.insert(term, inner);
133 }
134 }
135
136 Self {
137 postings: postings_map,
138 doc_lengths: doc_lengths_map,
139 total_tokens,
140 }
141 }
142
143 pub fn insert(&mut self, rowid: i64, text: &str) {
151 if self.doc_lengths.contains_key(&rowid) {
152 self.remove(rowid);
153 }
154
155 let tokens = tokenize(text);
156 let doc_len = tokens.len() as u32;
157 self.total_tokens += doc_len as u64;
158 self.doc_lengths.insert(rowid, doc_len);
159
160 let mut tf: HashMap<&str, u32> = HashMap::new();
164 for tok in &tokens {
165 *tf.entry(tok.as_str()).or_insert(0) += 1;
166 }
167 for (term, freq) in tf {
168 self.postings
169 .entry(term.to_string())
170 .or_default()
171 .insert(rowid, freq);
172 }
173 }
174
175 pub fn remove(&mut self, rowid: i64) {
179 let Some(doc_len) = self.doc_lengths.remove(&rowid) else {
180 return;
181 };
182 self.total_tokens -= doc_len as u64;
183
184 let mut empty_terms = Vec::new();
188 for (term, postings) in self.postings.iter_mut() {
189 if postings.remove(&rowid).is_some() && postings.is_empty() {
190 empty_terms.push(term.clone());
191 }
192 }
193 for term in empty_terms {
194 self.postings.remove(&term);
195 }
196 }
197
198 pub fn matches(&self, rowid: i64, query: &str) -> bool {
202 if !self.doc_lengths.contains_key(&rowid) {
203 return false;
204 }
205 for term in tokenize(query) {
206 if let Some(postings) = self.postings.get(&term) {
207 if postings.contains_key(&rowid) {
208 return true;
209 }
210 }
211 }
212 false
213 }
214
215 pub fn score(&self, rowid: i64, query: &str, params: &Bm25Params) -> f64 {
218 let Some(&doc_len) = self.doc_lengths.get(&rowid) else {
219 return 0.0;
220 };
221 let query_terms = tokenize(query);
222 if query_terms.is_empty() {
223 return 0.0;
224 }
225
226 let term_freq = self.term_freq_for_doc(rowid, &query_terms);
227 let n_docs_with = self.n_docs_with_for_terms(&query_terms);
228 bm25_score(
229 &query_terms,
230 &term_freq,
231 doc_len,
232 self.avg_doc_len(),
233 &n_docs_with,
234 self.doc_lengths.len() as u32,
235 params,
236 )
237 }
238
239 pub fn query(&self, query: &str, params: &Bm25Params) -> Vec<(i64, f64)> {
248 let query_terms = tokenize(query);
249 if query_terms.is_empty() || self.doc_lengths.is_empty() {
250 return Vec::new();
251 }
252
253 let mut candidates: BTreeMap<i64, u32> = BTreeMap::new();
258 for term in &query_terms {
259 if let Some(postings) = self.postings.get(term) {
260 for &rowid in postings.keys() {
261 candidates.entry(rowid).or_insert(0);
262 }
263 }
264 }
265 if candidates.is_empty() {
266 return Vec::new();
267 }
268
269 let n_docs_with = self.n_docs_with_for_terms(&query_terms);
270 let avg = self.avg_doc_len();
271 let total_docs = self.doc_lengths.len() as u32;
272
273 let mut scored: Vec<(i64, f64)> = candidates
274 .into_keys()
275 .map(|rowid| {
276 let doc_len = self.doc_lengths[&rowid];
277 let tf = self.term_freq_for_doc(rowid, &query_terms);
278 let s = bm25_score(
279 &query_terms,
280 &tf,
281 doc_len,
282 avg,
283 &n_docs_with,
284 total_docs,
285 params,
286 );
287 (rowid, s)
288 })
289 .collect();
290
291 scored.sort_by(|a, b| {
295 b.1.partial_cmp(&a.1)
296 .unwrap_or(std::cmp::Ordering::Equal)
297 .then_with(|| a.0.cmp(&b.0))
298 });
299 scored
300 }
301
302 fn term_freq_for_doc(&self, rowid: i64, query_terms: &[String]) -> HashMap<String, u32> {
303 let mut tf = HashMap::with_capacity(query_terms.len());
304 for term in query_terms {
305 if tf.contains_key(term) {
306 continue;
307 }
308 let freq = self
309 .postings
310 .get(term)
311 .and_then(|p| p.get(&rowid).copied())
312 .unwrap_or(0);
313 tf.insert(term.clone(), freq);
314 }
315 tf
316 }
317
318 fn n_docs_with_for_terms(&self, query_terms: &[String]) -> HashMap<String, u32> {
319 let mut n = HashMap::with_capacity(query_terms.len());
320 for term in query_terms {
321 if n.contains_key(term) {
322 continue;
323 }
324 let count = self.postings.get(term).map(|p| p.len() as u32).unwrap_or(0);
325 n.insert(term.clone(), count);
326 }
327 n
328 }
329}
330
331#[cfg(test)]
332mod tests {
333 use super::*;
334
335 #[test]
336 fn empty_list_is_empty() {
337 let pl = PostingList::new();
338 assert!(pl.is_empty());
339 assert_eq!(pl.len(), 0);
340 assert_eq!(pl.avg_doc_len(), 0.0);
341 assert!(pl.query("anything", &Bm25Params::default()).is_empty());
342 assert_eq!(pl.score(1, "anything", &Bm25Params::default()), 0.0);
343 assert!(!pl.matches(1, "anything"));
344 }
345
346 #[test]
347 fn empty_query_returns_empty_results() {
348 let mut pl = PostingList::new();
349 pl.insert(1, "rust embedded database");
350 assert!(pl.query("", &Bm25Params::default()).is_empty());
351 assert!(pl.query("!!!", &Bm25Params::default()).is_empty());
352 assert_eq!(pl.score(1, "", &Bm25Params::default()), 0.0);
353 }
354
355 #[test]
356 fn insert_and_query_two_docs_ranks_correctly() {
357 let mut pl = PostingList::new();
358 pl.insert(1, "rust rust embedded database");
359 pl.insert(2, "rust language");
360 let res = pl.query("rust", &Bm25Params::default());
361 assert_eq!(res.len(), 2);
362 let (id_a, s_a) = res[0];
367 let (id_b, s_b) = res[1];
368 assert!(s_a > 0.0 && s_b > 0.0);
369 assert!(s_a >= s_b);
370 assert!(
371 (id_a == 1 || id_a == 2) && (id_b == 1 || id_b == 2) && id_a != id_b,
372 "result rowids should be {{1,2}}, got ({}, {})",
373 id_a,
374 id_b
375 );
376
377 assert!(pl.matches(1, "rust"));
379 assert!(pl.matches(2, "rust"));
380 assert!(!pl.matches(1, "python"));
381 }
382
383 #[test]
384 fn score_method_matches_bulk_query() {
385 let mut pl = PostingList::new();
386 pl.insert(10, "rust embedded database");
387 pl.insert(20, "go embedded database");
388 pl.insert(30, "python web framework");
389
390 let params = Bm25Params::default();
391 let bulk = pl.query("embedded", ¶ms);
392 for (rowid, score) in &bulk {
393 let direct = pl.score(*rowid, "embedded", ¶ms);
394 assert!(
395 (direct - score).abs() < f64::EPSILON * 16.0,
396 "score({}, ...) = {} vs query() reported {}",
397 rowid,
398 direct,
399 score
400 );
401 }
402 assert_eq!(pl.score(30, "embedded", ¶ms), 0.0);
403 }
404
405 #[test]
406 fn remove_clears_doc_and_prunes_empty_terms() {
407 let mut pl = PostingList::new();
408 pl.insert(1, "rust");
409 pl.insert(2, "rust embedded");
410 assert_eq!(pl.len(), 2);
411 assert_eq!(pl.total_tokens, 3);
412 assert!(pl.postings.contains_key("rust"));
413 assert!(pl.postings.contains_key("embedded"));
414
415 pl.remove(2);
416 assert_eq!(pl.len(), 1);
417 assert_eq!(pl.total_tokens, 1);
418 assert!(!pl.postings.contains_key("embedded"));
420 assert!(pl.postings.contains_key("rust"));
421
422 pl.remove(1);
423 assert!(pl.is_empty());
424 assert!(pl.postings.is_empty());
425 assert_eq!(pl.total_tokens, 0);
426
427 pl.remove(1);
429 pl.remove(99);
430 assert!(pl.is_empty());
431 }
432
433 #[test]
434 fn reinsert_replaces_prior_postings() {
435 let mut pl = PostingList::new();
436 pl.insert(1, "rust rust rust");
437 assert_eq!(pl.postings["rust"][&1], 3);
438 assert_eq!(pl.total_tokens, 3);
439
440 pl.insert(1, "go");
441 assert_eq!(pl.len(), 1);
442 assert_eq!(pl.total_tokens, 1);
443 assert!(!pl.postings.contains_key("rust"));
444 assert_eq!(pl.postings["go"][&1], 1);
445 }
446
447 #[test]
448 fn tie_break_orders_by_rowid_ascending() {
449 let mut pl = PostingList::new();
451 pl.insert(7, "alpha beta");
452 pl.insert(3, "alpha beta");
453 pl.insert(5, "alpha beta");
454 let res = pl.query("alpha", &Bm25Params::default());
455 let ids: Vec<i64> = res.iter().map(|(id, _)| *id).collect();
456 assert_eq!(ids, vec![3, 5, 7]);
457 let s = res[0].1;
459 for (_, score) in &res {
460 assert_eq!(*score, s);
461 }
462 }
463
464 #[test]
465 fn multi_term_query_unions_candidates_any_term() {
466 let mut pl = PostingList::new();
467 pl.insert(1, "rust embedded");
468 pl.insert(2, "rust web");
469 pl.insert(3, "go embedded");
470 pl.insert(4, "python web");
471 let res = pl.query("rust embedded", &Bm25Params::default());
472 let ids: std::collections::BTreeSet<i64> = res.iter().map(|(id, _)| *id).collect();
473 assert_eq!(ids, [1, 2, 3].iter().copied().collect());
476 assert_eq!(res[0].0, 1);
478 }
479
480 #[test]
481 fn serialize_round_trips_through_from_persisted() {
482 let mut pl = PostingList::new();
486 pl.insert(1, "rust embedded database");
487 pl.insert(2, "rust web framework");
488 pl.insert(3, ""); pl.insert(4, "rust rust rust embedded power");
490
491 let docs = pl.serialize_doc_lengths();
492 let postings = pl.serialize_postings();
493 let roundtripped = PostingList::from_persisted_postings(docs, postings);
494
495 assert_eq!(roundtripped.len(), pl.len(), "doc count");
496 assert_eq!(roundtripped.avg_doc_len(), pl.avg_doc_len(), "avg_doc_len");
497 let q = pl.query("rust", &Bm25Params::default());
499 let q2 = roundtripped.query("rust", &Bm25Params::default());
500 assert_eq!(q, q2, "query results must match after round-trip");
501 assert!(roundtripped.matches(1, "rust"));
504 assert!(!roundtripped.matches(3, "rust"));
505 }
506
507 #[test]
508 fn synthetic_thousand_doc_corpus_top_ten_is_stable() {
509 let mut pl = PostingList::new();
514 let rare_rows: [i64; 5] = [137, 248, 391, 642, 873];
515 for i in 0..1000_i64 {
516 let words = ["alpha", "beta", "gamma", "delta", "epsilon", "zeta"];
518 let pick_a = words[((i as usize) * 7) % words.len()];
519 let pick_b = words[((i as usize) * 13 + 1) % words.len()];
520 let body = if rare_rows.contains(&i) {
521 format!("quasar {} {}", pick_a, pick_b)
522 } else {
523 format!("{} {}", pick_a, pick_b)
524 };
525 pl.insert(i, &body);
526 }
527 assert_eq!(pl.len(), 1000);
528
529 let res = pl.query("quasar", &Bm25Params::default());
530 assert_eq!(res.len(), 5, "exactly five docs should contain 'quasar'");
531 let returned: std::collections::BTreeSet<i64> = res.iter().map(|(id, _)| *id).collect();
532 let expected: std::collections::BTreeSet<i64> = rare_rows.iter().copied().collect();
533 assert_eq!(returned, expected);
534
535 let res2 = pl.query("quasar", &Bm25Params::default());
538 assert_eq!(res, res2);
539 }
540}