Skip to main content

ruvector_graph/
bm25.rs

//! Compact BM25 keyword index over a node text property (ADR-252 P4).
//!
//! Provides the keyword arm of the tri-modal hybrid query (BM25 + ANN vector +
//! graph traversal). Self-contained: an in-memory inverted index with Okapi
//! BM25 scoring and no external search engine. Built from `(NodeId, text)`
//! pairs, where the text is any `AsRef<str>`, and queried for the top-k
//! documents by keyword relevance.

8use crate::types::NodeId;
9use std::collections::HashMap;
10
/// Tuning knobs for Okapi BM25 scoring.
///
/// The defaults (`k1 = 1.2`, `b = 0.75`) are the conventional choices.
#[derive(Clone, Copy, Debug)]
pub struct Bm25Params {
    /// Term-frequency saturation: controls how much additional occurrences of
    /// a term keep raising a document's score.
    pub k1: f32,
    /// Length-normalisation strength: `0.0` ignores document length, `1.0`
    /// normalises fully against the average document length.
    pub b: f32,
}

impl Default for Bm25Params {
    /// Returns the standard parameters `k1 = 1.2`, `b = 0.75`.
    fn default() -> Bm25Params {
        let (k1, b) = (1.2, 0.75);
        Bm25Params { k1, b }
    }
}
23
24/// In-memory BM25 inverted index over a single text field.
25#[derive(Debug, Clone)]
26pub struct Bm25Index {
27    params: Bm25Params,
28    /// term -> postings as (doc index, term frequency).
29    postings: HashMap<String, Vec<(u32, u32)>>,
30    doc_ids: Vec<NodeId>,
31    doc_len: Vec<u32>,
32    avgdl: f32,
33}
34
35impl Bm25Index {
36    /// Lowercase, split on non-alphanumeric. Cheap and dependency-free.
37    pub fn tokenize(text: &str) -> Vec<String> {
38        text.split(|c: char| !c.is_alphanumeric())
39            .filter(|t| !t.is_empty())
40            .map(|t| t.to_ascii_lowercase())
41            .collect()
42    }
43
44    /// Build the index from `(id, text)` pairs.
45    pub fn build<I, S>(docs: I, params: Bm25Params) -> Self
46    where
47        I: IntoIterator<Item = (NodeId, S)>,
48        S: AsRef<str>,
49    {
50        let mut postings: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
51        let mut doc_ids = Vec::new();
52        let mut doc_len = Vec::new();
53        let mut total_len: u64 = 0;
54
55        for (id, text) in docs {
56            let doc_idx = doc_ids.len() as u32;
57            let tokens = Self::tokenize(text.as_ref());
58            doc_len.push(tokens.len() as u32);
59            total_len += tokens.len() as u64;
60
61            // Term frequencies within this doc.
62            let mut tf: HashMap<String, u32> = HashMap::new();
63            for tok in tokens {
64                *tf.entry(tok).or_insert(0) += 1;
65            }
66            for (term, freq) in tf {
67                postings.entry(term).or_default().push((doc_idx, freq));
68            }
69            doc_ids.push(id);
70        }
71
72        let n = doc_ids.len().max(1) as f32;
73        let avgdl = if doc_ids.is_empty() { 0.0 } else { total_len as f32 / n };
74        Self { params, postings, doc_ids, doc_len, avgdl }
75    }
76
77    /// Number of indexed documents.
78    pub fn len(&self) -> usize {
79        self.doc_ids.len()
80    }
81    pub fn is_empty(&self) -> bool {
82        self.doc_ids.is_empty()
83    }
84
85    /// Top-`k` documents by BM25 score for `query`, descending. Only documents
86    /// with a positive score are returned.
87    pub fn search(&self, query: &str, k: usize) -> Vec<(NodeId, f32)> {
88        if self.doc_ids.is_empty() || k == 0 {
89            return Vec::new();
90        }
91        let n = self.doc_ids.len() as f32;
92        let (k1, b) = (self.params.k1, self.params.b);
93        let mut scores: HashMap<u32, f32> = HashMap::new();
94
95        // Deduplicate query terms; each contributes once via its idf.
96        let mut seen_terms = std::collections::HashSet::new();
97        for term in Self::tokenize(query) {
98            if !seen_terms.insert(term.clone()) {
99                continue;
100            }
101            let Some(postings) = self.postings.get(&term) else {
102                continue;
103            };
104            let df = postings.len() as f32;
105            // Robertson/Spärck-Jones idf with +1 to stay non-negative.
106            let idf = ((n - df + 0.5) / (df + 0.5) + 1.0).ln();
107            for &(doc_idx, freq) in postings {
108                let dl = self.doc_len[doc_idx as usize] as f32;
109                let tf = freq as f32;
110                let denom = tf + k1 * (1.0 - b + b * dl / self.avgdl.max(1e-6));
111                let contribution = idf * (tf * (k1 + 1.0)) / denom;
112                *scores.entry(doc_idx).or_insert(0.0) += contribution;
113            }
114        }
115
116        let mut ranked: Vec<(NodeId, f32)> = scores
117            .into_iter()
118            .map(|(idx, s)| (self.doc_ids[idx as usize].clone(), s))
119            .collect();
120        ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
121        ranked.truncate(k);
122        ranked
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;

    fn corpus() -> Vec<(NodeId, &'static str)> {
        vec![
            ("d1".into(), "the quick brown fox jumps over the lazy dog"),
            ("d2".into(), "machine learning models for vector search"),
            ("d3".into(), "vector databases enable semantic search at scale"),
            ("d4".into(), "a recipe for italian pasta with tomato sauce"),
        ]
    }

    #[test]
    fn ranks_relevant_docs_first() {
        let idx = Bm25Index::build(corpus(), Bm25Params::default());
        assert_eq!(idx.len(), 4);
        let res = idx.search("vector search", 4);
        // Only d2 and d3 contain the query terms. Both match "vector" and
        // "search" once, so the shorter d2 wins through length normalisation.
        assert_eq!(res.len(), 2);
        assert_eq!(res[0].0, "d2");
        assert_eq!(res[1].0, "d3");
        assert!(res[0].1 > res[1].1);
        assert!(res.iter().all(|(id, _)| id != "d1" && id != "d4"));
    }

    #[test]
    fn respects_k() {
        let idx = Bm25Index::build(corpus(), Bm25Params::default());
        let res = idx.search("vector search", 1);
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].0, "d2");
        assert!(idx.search("vector search", 0).is_empty());
    }

    #[test]
    fn rare_term_finds_its_document() {
        let idx = Bm25Index::build(corpus(), Bm25Params::default());
        // "pasta" occurs only in d4.
        let res = idx.search("pasta", 4);
        assert_eq!(res.len(), 1);
        assert_eq!(res[0].0, "d4");
    }

    #[test]
    fn idf_downweights_common_terms() {
        // "x" occurs in every document and "y" only in "a". All docs have the
        // same length and tf, so any score gap for "a" comes from idf alone.
        let docs: Vec<(NodeId, &str)> = vec![
            ("a".into(), "x y"),
            ("b".into(), "x z"),
            ("c".into(), "x w"),
        ];
        let idx = Bm25Index::build(docs, Bm25Params::default());
        let common = idx.search("x", 3);
        let rare = idx.search("y", 3);
        assert_eq!(common.len(), 3);
        assert_eq!(rare.len(), 1);
        assert_eq!(rare[0].0, "a");
        assert!(common.iter().all(|&(_, s)| s < rare[0].1));
    }

    #[test]
    fn repeated_query_terms_count_once() {
        let idx = Bm25Index::build(corpus(), Bm25Params::default());
        let once = idx.search("pasta", 1);
        let twice = idx.search("pasta pasta", 1);
        assert_eq!(once[0].0, twice[0].0);
        assert_eq!(once[0].1, twice[0].1);
    }

    #[test]
    fn empty_query_and_index_safe() {
        let empty = Bm25Index::build(Vec::<(NodeId, &str)>::new(), Bm25Params::default());
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);
        assert!(empty.search("anything", 5).is_empty());
        let idx = Bm25Index::build(corpus(), Bm25Params::default());
        assert!(idx.search("", 5).is_empty());
        assert!(idx.search("zzz nonexistent", 5).is_empty());
    }
}
166}