Skip to main content

sqlrite/sql/fts/
posting_list.rs

1//! In-memory inverted index for FTS — `term -> { rowid -> term_freq }`,
2//! plus per-document length cache. Wraps the [`super::tokenizer`] +
3//! [`super::bm25`] primitives into a usable index. Pure data structure;
4//! no SQL coupling.
5//!
6//! Mirrors the role of [`crate::sql::hnsw::HnswIndex`] in 7d.1: this is
7//! the in-memory state that 8b will hang off `Table` (via a future
8//! `fts_indexes: Vec<FtsIndex>` field) and that 8c will serialize into
9//! `KIND_FTS_POSTING` cells.
10//!
11//! ## Identity choices
12//!
13//! - Rowids are `i64` (matches HNSW's `node_id` and SQLRite's row-id
14//!   convention; see [`crate::sql::hnsw::HnswIndex::insert`]).
15//! - The map structure is `BTreeMap<String, BTreeMap<i64, u32>>` rather
16//!   than `HashMap` so that (1) persistence (8c) gets a deterministic
17//!   on-disk byte order for free — postings are emitted in lexicographic
18//!   term order, each posting list in ascending rowid order — and (2)
19//!   tests get stable ordering without sorting. `HashMap` is faster on a
20//!   per-op basis but the lookups in the FTS hot path are bounded by
21//!   query-term count (single digits in practice), so the BTreeMap log-N
22//!   factor is negligible.
23//!
24//! ## What it does NOT do (yet)
25//!
//! - **No on-disk I/O.** State lives entirely in memory. The
//!   `serialize_*` / `from_persisted_postings` helpers only convert to
//!   and from plain tuples; 8c wires them into the page format under
//!   cell-kind `0x06`.
28//! - **No transaction integration.** 8b is responsible for batching
29//!   updates inside a `BEGIN; ... COMMIT;` block.
30//! - **No phrase / boolean queries.** Single-token any-term match only
31//!   for the MVP per the plan's "Out of scope" section. Multi-token
32//!   queries OR the per-term hits — no AND, NOT, or positional info.
33
34use std::collections::{BTreeMap, HashMap};
35
36use super::bm25::{Bm25Params, score as bm25_score};
37use super::tokenizer::tokenize;
38
/// Inverted index held entirely in memory. The module-level docs cover
/// the design rationale.
#[derive(Clone, Debug, Default)]
pub struct PostingList {
    /// Maps each term to its posting list: `rowid -> number of times the
    /// term occurs in that row`.
    postings: BTreeMap<String, BTreeMap<i64, u32>>,
    /// Maps each indexed rowid to its token count after tokenization.
    /// This map is the source of truth for which rows are indexed;
    /// `len()` and `is_empty()` read it directly.
    doc_lengths: BTreeMap<i64, u32>,
    /// Running sum of every value in `doc_lengths`, kept up to date on
    /// each mutation so [`PostingList::avg_doc_len`] never has to walk
    /// the map.
    total_tokens: u64,
}
52
53impl PostingList {
54    /// Empty index with no postings and no documents.
55    pub fn new() -> Self {
56        Self::default()
57    }
58
59    /// Number of indexed documents.
60    pub fn len(&self) -> usize {
61        self.doc_lengths.len()
62    }
63
64    /// True iff no document has been inserted (or all have been removed).
65    pub fn is_empty(&self) -> bool {
66        self.doc_lengths.is_empty()
67    }
68
69    /// Average document length in tokens. Returns `0.0` when the index
70    /// is empty so BM25 can guard cleanly without a div-by-zero.
71    pub fn avg_doc_len(&self) -> f64 {
72        if self.doc_lengths.is_empty() {
73            0.0
74        } else {
75            self.total_tokens as f64 / self.doc_lengths.len() as f64
76        }
77    }
78
79    /// Phase 8c — emit `(rowid, doc_len)` pairs for every indexed doc,
80    /// in ascending rowid order. The pager writes these into the FTS
81    /// index's doc-lengths sidecar cell; reload feeds them back to
82    /// [`Self::from_persisted_postings`].
83    pub fn serialize_doc_lengths(&self) -> Vec<(i64, u32)> {
84        self.doc_lengths
85            .iter()
86            .map(|(id, len)| (*id, *len))
87            .collect()
88    }
89
90    /// Phase 8c — emit `(term, [(rowid, term_freq)])` triples in
91    /// lexicographic term order; per-term entries are in ascending
92    /// rowid order (the underlying `BTreeMap` already guarantees this).
93    /// One element per unique indexed term; pager writes one cell per
94    /// element.
95    pub fn serialize_postings(&self) -> Vec<(String, Vec<(i64, u32)>)> {
96        self.postings
97            .iter()
98            .map(|(term, postings)| {
99                let entries = postings.iter().map(|(id, freq)| (*id, *freq)).collect();
100                (term.clone(), entries)
101            })
102            .collect()
103    }
104
105    /// Phase 8c — rebuild a `PostingList` directly from the persisted
106    /// doc-lengths sidecar + per-term postings. No tokenization runs;
107    /// the resulting index is byte-equivalent to what was saved
108    /// (assuming the input came from `serialize_*`).
109    ///
110    /// `doc_lengths` is the full `(rowid, doc_len)` map written into
111    /// the sidecar cell. `postings` is one `(term, [(rowid, tf)])`
112    /// element per term cell.
113    pub fn from_persisted_postings<I, J>(doc_lengths: I, postings: J) -> Self
114    where
115        I: IntoIterator<Item = (i64, u32)>,
116        J: IntoIterator<Item = (String, Vec<(i64, u32)>)>,
117    {
118        let mut doc_lengths_map: BTreeMap<i64, u32> = BTreeMap::new();
119        let mut total_tokens: u64 = 0;
120        for (rowid, len) in doc_lengths {
121            doc_lengths_map.insert(rowid, len);
122            total_tokens += len as u64;
123        }
124
125        let mut postings_map: BTreeMap<String, BTreeMap<i64, u32>> = BTreeMap::new();
126        for (term, entries) in postings {
127            let inner: BTreeMap<i64, u32> = entries.into_iter().collect();
128            // An empty posting list shouldn't be persisted, but if it
129            // somehow was, drop it on load — `remove()` would have
130            // pruned the same way at runtime.
131            if !inner.is_empty() {
132                postings_map.insert(term, inner);
133            }
134        }
135
136        Self {
137            postings: postings_map,
138            doc_lengths: doc_lengths_map,
139            total_tokens,
140        }
141    }
142
143    /// Tokenize `text` and add its postings under `rowid`. If `rowid` is
144    /// already indexed, its previous postings are removed first — i.e.
145    /// `insert` is idempotent for re-indexing the same row.
146    ///
147    /// A row whose tokenization yields zero tokens is still recorded
148    /// (with `doc_len = 0` and no posting entries). This keeps `len()`
149    /// honest for "indexed but empty" rows; BM25 returns 0.0 for them.
150    pub fn insert(&mut self, rowid: i64, text: &str) {
151        if self.doc_lengths.contains_key(&rowid) {
152            self.remove(rowid);
153        }
154
155        let tokens = tokenize(text);
156        let doc_len = tokens.len() as u32;
157        self.total_tokens += doc_len as u64;
158        self.doc_lengths.insert(rowid, doc_len);
159
160        // Aggregate per-term frequency for this doc, then push into the
161        // global postings map. This avoids bumping the same posting
162        // entry repeatedly for a doc with many occurrences of one term.
163        let mut tf: HashMap<&str, u32> = HashMap::new();
164        for tok in &tokens {
165            *tf.entry(tok.as_str()).or_insert(0) += 1;
166        }
167        for (term, freq) in tf {
168            self.postings
169                .entry(term.to_string())
170                .or_default()
171                .insert(rowid, freq);
172        }
173    }
174
175    /// Remove all postings for `rowid`. No-op if `rowid` was never
176    /// inserted. Empty per-term posting lists left behind by the last
177    /// referencing row are pruned to keep the BTreeMap tight.
178    pub fn remove(&mut self, rowid: i64) {
179        let Some(doc_len) = self.doc_lengths.remove(&rowid) else {
180            return;
181        };
182        self.total_tokens -= doc_len as u64;
183
184        // Walk every term — fine because term count grows with vocab,
185        // not corpus size, and remove is rare. 8b's incremental DELETE
186        // path uses the rebuild-at-save strategy (Q7) anyway.
187        let mut empty_terms = Vec::new();
188        for (term, postings) in self.postings.iter_mut() {
189            if postings.remove(&rowid).is_some() && postings.is_empty() {
190                empty_terms.push(term.clone());
191            }
192        }
193        for term in empty_terms {
194            self.postings.remove(&term);
195        }
196    }
197
198    /// True iff `rowid` is indexed and at least one of its terms is in
199    /// the (tokenized) `query`. Powers `fts_match(col, 'q')` in 8b
200    /// without going through scoring.
201    pub fn matches(&self, rowid: i64, query: &str) -> bool {
202        if !self.doc_lengths.contains_key(&rowid) {
203            return false;
204        }
205        for term in tokenize(query) {
206            if let Some(postings) = self.postings.get(&term) {
207                if postings.contains_key(&rowid) {
208                    return true;
209                }
210            }
211        }
212        false
213    }
214
215    /// BM25 score for a single (rowid, query) pair. Returns `0.0` if
216    /// `rowid` is unknown or no query terms hit.
217    pub fn score(&self, rowid: i64, query: &str, params: &Bm25Params) -> f64 {
218        let Some(&doc_len) = self.doc_lengths.get(&rowid) else {
219            return 0.0;
220        };
221        let query_terms = tokenize(query);
222        if query_terms.is_empty() {
223            return 0.0;
224        }
225
226        let term_freq = self.term_freq_for_doc(rowid, &query_terms);
227        let n_docs_with = self.n_docs_with_for_terms(&query_terms);
228        bm25_score(
229            &query_terms,
230            &term_freq,
231            doc_len,
232            self.avg_doc_len(),
233            &n_docs_with,
234            self.doc_lengths.len() as u32,
235            params,
236        )
237    }
238
239    /// Score every doc that contains at least one query term and return
240    /// `(rowid, score)` sorted by score descending, ties broken by
241    /// rowid ascending. Powers the bulk path used by 8b's
242    /// `try_fts_probe` optimizer hook.
243    ///
244    /// Empty query → empty result. Empty index → empty result. Rows
245    /// that don't match any query term are not scored at all (they
246    /// would score 0.0 — including them just bloats the result).
247    pub fn query(&self, query: &str, params: &Bm25Params) -> Vec<(i64, f64)> {
248        let query_terms = tokenize(query);
249        if query_terms.is_empty() || self.doc_lengths.is_empty() {
250            return Vec::new();
251        }
252
253        // Collect candidate rowids: every doc that has at least one
254        // query term in its postings. BTreeMap iteration is sorted, so
255        // the candidate set comes out in ascending rowid order — handy
256        // for the tie-break below.
257        let mut candidates: BTreeMap<i64, u32> = BTreeMap::new();
258        for term in &query_terms {
259            if let Some(postings) = self.postings.get(term) {
260                for &rowid in postings.keys() {
261                    candidates.entry(rowid).or_insert(0);
262                }
263            }
264        }
265        if candidates.is_empty() {
266            return Vec::new();
267        }
268
269        let n_docs_with = self.n_docs_with_for_terms(&query_terms);
270        let avg = self.avg_doc_len();
271        let total_docs = self.doc_lengths.len() as u32;
272
273        let mut scored: Vec<(i64, f64)> = candidates
274            .into_keys()
275            .map(|rowid| {
276                let doc_len = self.doc_lengths[&rowid];
277                let tf = self.term_freq_for_doc(rowid, &query_terms);
278                let s = bm25_score(
279                    &query_terms,
280                    &tf,
281                    doc_len,
282                    avg,
283                    &n_docs_with,
284                    total_docs,
285                    params,
286                );
287                (rowid, s)
288            })
289            .collect();
290
291        // Score desc, then rowid asc on ties. f64::partial_cmp + the
292        // candidate set already being sorted ascending means we only
293        // need a stable sort_by on score.
294        scored.sort_by(|a, b| {
295            b.1.partial_cmp(&a.1)
296                .unwrap_or(std::cmp::Ordering::Equal)
297                .then_with(|| a.0.cmp(&b.0))
298        });
299        scored
300    }
301
302    fn term_freq_for_doc(&self, rowid: i64, query_terms: &[String]) -> HashMap<String, u32> {
303        let mut tf = HashMap::with_capacity(query_terms.len());
304        for term in query_terms {
305            if tf.contains_key(term) {
306                continue;
307            }
308            let freq = self
309                .postings
310                .get(term)
311                .and_then(|p| p.get(&rowid).copied())
312                .unwrap_or(0);
313            tf.insert(term.clone(), freq);
314        }
315        tf
316    }
317
318    fn n_docs_with_for_terms(&self, query_terms: &[String]) -> HashMap<String, u32> {
319        let mut n = HashMap::with_capacity(query_terms.len());
320        for term in query_terms {
321            if n.contains_key(term) {
322                continue;
323            }
324            let count = self.postings.get(term).map(|p| p.len() as u32).unwrap_or(0);
325            n.insert(term.clone(), count);
326        }
327        n
328    }
329}
330
#[cfg(test)]
mod tests {
    use super::*;

    /// Default scoring parameters; every test scores with these.
    fn defaults() -> Bm25Params {
        Bm25Params::default()
    }

    /// A freshly constructed index reports itself empty through every
    /// read-only entry point.
    #[test]
    fn empty_list_is_empty() {
        let index = PostingList::new();
        let params = defaults();
        assert!(index.is_empty());
        assert_eq!(index.len(), 0);
        assert_eq!(index.avg_doc_len(), 0.0);
        assert!(index.query("anything", &params).is_empty());
        assert_eq!(index.score(1, "anything", &params), 0.0);
        assert!(!index.matches(1, "anything"));
    }

    /// Queries that tokenize to nothing produce no hits and a zero score.
    #[test]
    fn empty_query_returns_empty_results() {
        let mut index = PostingList::new();
        index.insert(1, "rust embedded database");
        let params = defaults();
        for blank in ["", "!!!"] {
            assert!(index.query(blank, &params).is_empty());
        }
        assert_eq!(index.score(1, "", &params), 0.0);
    }

    #[test]
    fn insert_and_query_two_docs_ranks_correctly() {
        let mut index = PostingList::new();
        index.insert(1, "rust rust embedded database");
        index.insert(2, "rust language");
        let hits = index.query("rust", &defaults());
        assert_eq!(hits.len(), 2);
        // Row 1 has tf=2 in a longer doc, row 2 has tf=1 in a shorter
        // one, so length normalization makes the exact ranking
        // non-obvious. Only require both rows, positive scores, and a
        // non-increasing score order.
        let (id_a, s_a) = hits[0];
        let (id_b, s_b) = hits[1];
        assert!(s_a > 0.0 && s_b > 0.0);
        assert!(s_a >= s_b);
        assert!(
            matches!((id_a, id_b), (1, 2) | (2, 1)),
            "result rowids should be {{1,2}}, got ({}, {})",
            id_a,
            id_b
        );

        // `matches()` must agree with the bulk query on which rows hit.
        for rowid in [1, 2] {
            assert!(index.matches(rowid, "rust"));
        }
        assert!(!index.matches(1, "python"));
    }

    #[test]
    fn score_method_matches_bulk_query() {
        let mut index = PostingList::new();
        let corpus = [
            (10, "rust embedded database"),
            (20, "go embedded database"),
            (30, "python web framework"),
        ];
        for (rowid, text) in corpus {
            index.insert(rowid, text);
        }

        let params = defaults();
        for (rowid, score) in &index.query("embedded", &params) {
            let direct = index.score(*rowid, "embedded", &params);
            assert!(
                (direct - score).abs() < f64::EPSILON * 16.0,
                "score({}, ...) = {} vs query() reported {}",
                rowid,
                direct,
                score
            );
        }
        assert_eq!(index.score(30, "embedded", &params), 0.0);
    }

    #[test]
    fn remove_clears_doc_and_prunes_empty_terms() {
        let mut index = PostingList::new();
        index.insert(1, "rust");
        index.insert(2, "rust embedded");
        assert_eq!(index.len(), 2);
        assert_eq!(index.total_tokens, 3);
        for term in ["rust", "embedded"] {
            assert!(index.postings.contains_key(term));
        }

        index.remove(2);
        assert_eq!(index.len(), 1);
        assert_eq!(index.total_tokens, 1);
        // Only row 2 contained "embedded", so the term itself must be
        // pruned; "rust" is still referenced by row 1.
        assert!(!index.postings.contains_key("embedded"));
        assert!(index.postings.contains_key("rust"));

        index.remove(1);
        assert!(index.is_empty());
        assert!(index.postings.is_empty());
        assert_eq!(index.total_tokens, 0);

        // Removing an already-removed or never-seen rowid changes nothing.
        for gone in [1, 99] {
            index.remove(gone);
        }
        assert!(index.is_empty());
    }

    #[test]
    fn reinsert_replaces_prior_postings() {
        let mut index = PostingList::new();
        index.insert(1, "rust rust rust");
        assert_eq!(index.postings["rust"][&1], 3);
        assert_eq!(index.total_tokens, 3);

        // Same rowid, new text: the old postings must be gone entirely.
        index.insert(1, "go");
        assert_eq!(index.len(), 1);
        assert_eq!(index.total_tokens, 1);
        assert!(!index.postings.contains_key("rust"));
        assert_eq!(index.postings["go"][&1], 1);
    }

    #[test]
    fn tie_break_orders_by_rowid_ascending() {
        // Identical docs score identically, so the order must fall back
        // to rowid ascending regardless of insertion order.
        let mut index = PostingList::new();
        for rowid in [7, 3, 5] {
            index.insert(rowid, "alpha beta");
        }
        let hits = index.query("alpha", &defaults());
        let ids: Vec<i64> = hits.iter().map(|&(id, _)| id).collect();
        assert_eq!(ids, vec![3, 5, 7]);
        // The scores must be bit-for-bit equal, not merely close.
        let reference = hits[0].1;
        for (_, score) in &hits {
            assert_eq!(*score, reference);
        }
    }

    #[test]
    fn multi_term_query_unions_candidates_any_term() {
        let mut index = PostingList::new();
        let corpus = [
            (1, "rust embedded"),
            (2, "rust web"),
            (3, "go embedded"),
            (4, "python web"),
        ];
        for (rowid, text) in corpus {
            index.insert(rowid, text);
        }
        let hits = index.query("rust embedded", &defaults());
        let ids: std::collections::BTreeSet<i64> = hits.iter().map(|&(id, _)| id).collect();
        // "Any-term" semantics: row 4 contains neither term and must be
        // absent; rows 1-3 each contain at least one and must be present.
        assert_eq!(ids, [1, 2, 3].iter().copied().collect());
        // Row 1 contains both terms, so it should rank first.
        assert_eq!(hits[0].0, 1);
    }

    #[test]
    fn serialize_round_trips_through_from_persisted() {
        // Phase 8c — serializing and reloading must reproduce the saved
        // state, including a zero-token row and repeated terms.
        let mut original = PostingList::new();
        original.insert(1, "rust embedded database");
        original.insert(2, "rust web framework");
        original.insert(3, ""); // zero-token doc, lives only in the doc-lengths list
        original.insert(4, "rust rust rust embedded power");

        let restored = PostingList::from_persisted_postings(
            original.serialize_doc_lengths(),
            original.serialize_postings(),
        );

        assert_eq!(restored.len(), original.len(), "doc count");
        assert_eq!(
            restored.avg_doc_len(),
            original.avg_doc_len(),
            "avg_doc_len"
        );
        // Rows and scores must be identical on both sides.
        let params = defaults();
        let before = original.query("rust", &params);
        let after = restored.query("rust", &params);
        assert_eq!(before, after, "query results must match after round-trip");
        // The zero-token row still counts toward the corpus size but
        // can never match a query.
        assert!(restored.matches(1, "rust"));
        assert!(!restored.matches(3, "rust"));
    }

    #[test]
    fn synthetic_thousand_doc_corpus_top_ten_is_stable() {
        // 1000 deterministic docs, of which exactly five contain the rare
        // term "quasar". Non-matching rows are filtered out before
        // scoring, so the query must return those five and nothing else.
        const WORDS: [&str; 6] = ["alpha", "beta", "gamma", "delta", "epsilon", "zeta"];
        let rare_rows: [i64; 5] = [137, 248, 391, 642, 873];

        let mut index = PostingList::new();
        for i in 0..1000_i64 {
            // Two filler words chosen deterministically from `i`.
            let slot = i as usize;
            let pick_a = WORDS[(slot * 7) % WORDS.len()];
            let pick_b = WORDS[(slot * 13 + 1) % WORDS.len()];
            let body = if rare_rows.contains(&i) {
                format!("quasar {} {}", pick_a, pick_b)
            } else {
                format!("{} {}", pick_a, pick_b)
            };
            index.insert(i, &body);
        }
        assert_eq!(index.len(), 1000);

        let params = defaults();
        let first = index.query("quasar", &params);
        assert_eq!(first.len(), 5, "exactly five docs should contain 'quasar'");
        let returned: std::collections::BTreeSet<i64> = first.iter().map(|&(id, _)| id).collect();
        let expected: std::collections::BTreeSet<i64> = rare_rows.iter().copied().collect();
        assert_eq!(returned, expected);

        // Running the same query again must give identical output, so no
        // HashMap iteration order can be leaking into the result.
        let second = index.query("quasar", &params);
        assert_eq!(first, second);
    }
}