Skip to main content

sochdb_query/
bm25_filtered.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
//! BM25 Filter Pushdown via Posting-Set Intersection (Task 6)
//!
//! This module implements filter pushdown for BM25 full-text search by
//! intersecting posting lists with the allowed set BEFORE scoring.
//!
//! ## Key Insight
//!
//! ```text
//! Traditional BM25: score ALL docs with term → filter → top-k
//! Filtered BM25: intersect posting(term) ∩ AllowedSet → score only the intersection
//! ```
//!
//! ## Optimization: Term Reordering
//!
//! For multi-term queries, we reorder terms by increasing document frequency (DF):
//!
//! ```text
//! Query: "machine learning algorithms"
//! DFs: machine=10000, learning=8000, algorithms=500
//!
//! Order: algorithms → learning → machine
//!
//! Step 1: posting(algorithms) ∩ AllowedSet → small_set
//! Step 2: small_set ∩ posting(learning) → smaller_set
//! Step 3: smaller_set ∩ posting(machine) → final_set
//! ```
//!
//! This minimizes intermediate set sizes, reducing memory and CPU.
//!
//! ## Cost Model
//!
//! Let:
//! - N = total docs
//! - |A| = allowed set size
//! - df(t) = document frequency of term t
//!
//! Without pushdown: O(Σ_t df(t)) scoring operations
//! With pushdown: O(min(|A|, min_t df(t))) scoring operations
//!
//! When |A| << N, this is a massive win.
58
59use std::collections::HashMap;
60use std::sync::Arc;
61
62use crate::candidate_gate::AllowedSet;
63use crate::filtered_vector_search::ScoredResult;
64
65// ============================================================================
66// BM25 Parameters
67// ============================================================================
68
69/// BM25 scoring parameters
70#[derive(Debug, Clone)]
71pub struct Bm25Params {
72    /// Term frequency saturation parameter (default: 1.2)
73    pub k1: f32,
74    /// Document length normalization (default: 0.75)
75    pub b: f32,
76    /// Average document length (computed from corpus)
77    pub avgdl: f32,
78    /// Total number of documents in corpus
79    pub total_docs: u64,
80}
81
82impl Default for Bm25Params {
83    fn default() -> Self {
84        Self {
85            k1: 1.2,
86            b: 0.75,
87            avgdl: 100.0,
88            total_docs: 1_000_000,
89        }
90    }
91}
92
93impl Bm25Params {
94    /// Compute IDF for a term
95    pub fn idf(&self, doc_freq: u64) -> f32 {
96        let n = self.total_docs as f32;
97        let df = doc_freq as f32;
98        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
99    }
100
101    /// Compute term score component
102    pub fn term_score(&self, tf: f32, doc_len: f32, idf: f32) -> f32 {
103        let numerator = tf * (self.k1 + 1.0);
104        let denominator = tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avgdl);
105        idf * numerator / denominator
106    }
107}
108
109// ============================================================================
110// Term Posting List
111// ============================================================================
112
113/// A posting list for a single term
114#[derive(Debug, Clone)]
115pub struct PostingList {
116    /// Term text
117    pub term: String,
118    /// Document IDs containing this term (sorted)
119    pub doc_ids: Vec<u64>,
120    /// Term frequencies in each document
121    pub term_freqs: Vec<u32>,
122    /// Document frequency (len of doc_ids)
123    pub doc_freq: u64,
124}
125
126impl PostingList {
127    /// Create a new posting list
128    pub fn new(term: impl Into<String>, entries: Vec<(u64, u32)>) -> Self {
129        let term = term.into();
130        let doc_freq = entries.len() as u64;
131        let mut doc_ids = Vec::with_capacity(entries.len());
132        let mut term_freqs = Vec::with_capacity(entries.len());
133
134        for (doc_id, tf) in entries {
135            doc_ids.push(doc_id);
136            term_freqs.push(tf);
137        }
138
139        Self {
140            term,
141            doc_ids,
142            term_freqs,
143            doc_freq,
144        }
145    }
146
147    /// Intersect with an allowed set, returning (doc_id, tf) pairs
148    pub fn intersect_with_allowed(&self, allowed: &AllowedSet) -> Vec<(u64, u32)> {
149        match allowed {
150            AllowedSet::All => self
151                .doc_ids
152                .iter()
153                .zip(self.term_freqs.iter())
154                .map(|(&id, &tf)| (id, tf))
155                .collect(),
156            AllowedSet::None => vec![],
157            _ => self
158                .doc_ids
159                .iter()
160                .zip(self.term_freqs.iter())
161                .filter(|&(&id, _)| allowed.contains(id))
162                .map(|(&id, &tf)| (id, tf))
163                .collect(),
164        }
165    }
166}
167
168// ============================================================================
169// Inverted Index Interface
170// ============================================================================
171
172/// Trait for accessing an inverted index
173pub trait InvertedIndex: Send + Sync {
174    /// Get posting list for a term (None if term not in vocabulary)
175    fn get_posting_list(&self, term: &str) -> Option<PostingList>;
176
177    /// Get document length (in tokens)
178    fn get_doc_length(&self, doc_id: u64) -> Option<u32>;
179
180    /// Get BM25 parameters
181    fn get_params(&self) -> &Bm25Params;
182}
183
184// ============================================================================
185// Filtered BM25 Executor
186// ============================================================================
187
188/// A BM25 executor that applies filter pushdown
189pub struct FilteredBm25Executor<I: InvertedIndex> {
190    index: Arc<I>,
191}
192
193impl<I: InvertedIndex> FilteredBm25Executor<I> {
194    /// Create a new executor
195    pub fn new(index: Arc<I>) -> Self {
196        Self { index }
197    }
198
199    /// Execute a BM25 query with filter pushdown
200    ///
201    /// # Algorithm
202    ///
203    /// 1. Tokenize query
204    /// 2. Get posting lists for each term
205    /// 3. Sort terms by DF (ascending) for early pruning
206    /// 4. Intersect posting lists with AllowedSet (in DF order)
207    /// 5. Score only docs in intersection
208    /// 6. Return top-k
209    pub fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
210        // Short-circuit if nothing allowed
211        if allowed.is_empty() {
212            return vec![];
213        }
214
215        // Tokenize (simple whitespace split for now)
216        let terms: Vec<&str> = query
217            .split_whitespace()
218            .filter(|t| t.len() >= 2) // Skip very short terms
219            .collect();
220
221        if terms.is_empty() {
222            return vec![];
223        }
224
225        // Get posting lists and sort by DF
226        let mut posting_lists: Vec<PostingList> = terms
227            .iter()
228            .filter_map(|t| self.index.get_posting_list(t))
229            .collect();
230
231        // Sort by document frequency (ascending)
232        posting_lists.sort_by_key(|pl| pl.doc_freq);
233
234        // Progressive intersection with AllowedSet
235        let candidates = self.progressive_intersection(&posting_lists, allowed);
236
237        if candidates.is_empty() {
238            return vec![];
239        }
240
241        // Score candidates
242        let params = self.index.get_params();
243        let scores = self.score_candidates(&candidates, &posting_lists, params);
244
245        // Return top-k
246        self.top_k(scores, k)
247    }
248
249    /// Progressively intersect posting lists with allowed set
250    ///
251    /// Returns map of doc_id -> term frequencies for each term
252    fn progressive_intersection(
253        &self,
254        posting_lists: &[PostingList],
255        allowed: &AllowedSet,
256    ) -> HashMap<u64, Vec<u32>> {
257        if posting_lists.is_empty() {
258            return HashMap::new();
259        }
260
261        // Start with first (smallest DF) term intersected with allowed
262        let first = &posting_lists[0];
263        let mut candidates: HashMap<u64, Vec<u32>> = first
264            .intersect_with_allowed(allowed)
265            .into_iter()
266            .map(|(id, tf)| (id, vec![tf]))
267            .collect();
268
269        // Intersect with remaining terms
270        for (_term_idx, posting_list) in posting_lists.iter().enumerate().skip(1) {
271            // Build a lookup for this term's postings
272            let term_postings: HashMap<u64, u32> = posting_list
273                .doc_ids
274                .iter()
275                .zip(posting_list.term_freqs.iter())
276                .map(|(&id, &tf)| (id, tf))
277                .collect();
278
279            // Keep only candidates that appear in this term's postings
280            candidates.retain(|doc_id, tfs| {
281                if let Some(&tf) = term_postings.get(doc_id) {
282                    tfs.push(tf);
283                    true
284                } else {
285                    false
286                }
287            });
288
289            // Early exit if no candidates remain
290            if candidates.is_empty() {
291                break;
292            }
293        }
294
295        candidates
296    }
297
298    /// Score candidates using BM25
299    fn score_candidates(
300        &self,
301        candidates: &HashMap<u64, Vec<u32>>,
302        posting_lists: &[PostingList],
303        params: &Bm25Params,
304    ) -> Vec<ScoredResult> {
305        // Precompute IDFs
306        let idfs: Vec<f32> = posting_lists
307            .iter()
308            .map(|pl| params.idf(pl.doc_freq))
309            .collect();
310
311        candidates
312            .iter()
313            .filter_map(|(&doc_id, tfs)| {
314                let doc_len = self.index.get_doc_length(doc_id)? as f32;
315
316                let score: f32 = tfs
317                    .iter()
318                    .zip(idfs.iter())
319                    .map(|(&tf, &idf)| params.term_score(tf as f32, doc_len, idf))
320                    .sum();
321
322                Some(ScoredResult::new(doc_id, score))
323            })
324            .collect()
325    }
326
327    /// Get top-k results
328    fn top_k(&self, mut scores: Vec<ScoredResult>, k: usize) -> Vec<ScoredResult> {
329        // Sort by score descending
330        scores.sort_by(|a, b| {
331            b.score
332                .partial_cmp(&a.score)
333                .unwrap_or(std::cmp::Ordering::Equal)
334        });
335        scores.truncate(k);
336        scores
337    }
338}
339
340// ============================================================================
341// Disjunctive (OR) Query Support
342// ============================================================================
343
344/// A disjunctive BM25 query (OR semantics)
345pub struct DisjunctiveBm25Executor<I: InvertedIndex> {
346    index: Arc<I>,
347}
348
349impl<I: InvertedIndex> DisjunctiveBm25Executor<I> {
350    /// Create a new executor
351    pub fn new(index: Arc<I>) -> Self {
352        Self { index }
353    }
354
355    /// Execute with OR semantics (documents match if they have ANY query term)
356    pub fn search(&self, query: &str, k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
357        if allowed.is_empty() {
358            return vec![];
359        }
360
361        let terms: Vec<&str> = query.split_whitespace().collect();
362        if terms.is_empty() {
363            return vec![];
364        }
365
366        // Get posting lists
367        let posting_lists: Vec<PostingList> = terms
368            .iter()
369            .filter_map(|t| self.index.get_posting_list(t))
370            .collect();
371
372        let params = self.index.get_params();
373
374        // Accumulate scores for all docs matching any term
375        let mut scores: HashMap<u64, f32> = HashMap::new();
376
377        for posting_list in &posting_lists {
378            let idf = params.idf(posting_list.doc_freq);
379
380            // Only score docs in allowed set
381            for (&doc_id, &tf) in posting_list
382                .doc_ids
383                .iter()
384                .zip(posting_list.term_freqs.iter())
385            {
386                if !allowed.contains(doc_id) {
387                    continue;
388                }
389
390                if let Some(doc_len) = self.index.get_doc_length(doc_id) {
391                    let term_score = params.term_score(tf as f32, doc_len as f32, idf);
392                    *scores.entry(doc_id).or_insert(0.0) += term_score;
393                }
394            }
395        }
396
397        // Convert to results and get top-k
398        let mut results: Vec<ScoredResult> = scores
399            .into_iter()
400            .map(|(id, score)| ScoredResult::new(id, score))
401            .collect();
402
403        results.sort_by(|a, b| {
404            b.score
405                .partial_cmp(&a.score)
406                .unwrap_or(std::cmp::Ordering::Equal)
407        });
408        results.truncate(k);
409        results
410    }
411}
412
413// ============================================================================
414// Phrase Query Support
415// ============================================================================
416
417/// Position information for a term in a document
418#[derive(Debug, Clone)]
419pub struct PositionalPosting {
420    /// Document ID
421    pub doc_id: u64,
422    /// Positions where term appears (sorted)
423    pub positions: Vec<u32>,
424}
425
426/// Trait for positional index access
427pub trait PositionalIndex: InvertedIndex {
428    /// Get positional posting list
429    fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>>;
430}
431
432/// Phrase query executor with filter pushdown
433pub struct FilteredPhraseExecutor<I: PositionalIndex> {
434    index: Arc<I>,
435}
436
437impl<I: PositionalIndex> FilteredPhraseExecutor<I> {
438    /// Create a new executor
439    pub fn new(index: Arc<I>) -> Self {
440        Self { index }
441    }
442
443    /// Execute a phrase query
444    ///
445    /// Documents must contain all terms in sequence.
446    pub fn search(&self, phrase: &[&str], k: usize, allowed: &AllowedSet) -> Vec<ScoredResult> {
447        if phrase.is_empty() || allowed.is_empty() {
448            return vec![];
449        }
450
451        // Get positional postings for each term
452        let mut positional_postings: Vec<Vec<PositionalPosting>> = vec![];
453        for term in phrase {
454            match self.index.get_positional_posting(term) {
455                Some(postings) => positional_postings.push(postings),
456                None => return vec![], // Term not in index → no matches
457            }
458        }
459
460        // Find documents containing all terms
461        let candidates = self.find_phrase_matches(&positional_postings, allowed);
462
463        // Score by phrase frequency
464        let params = self.index.get_params();
465        let results: Vec<ScoredResult> = candidates
466            .into_iter()
467            .filter_map(|(doc_id, phrase_freq)| {
468                let doc_len = self.index.get_doc_length(doc_id)? as f32;
469                // Use phrase frequency as TF, use min DF for IDF approximation
470                let min_df = positional_postings
471                    .iter()
472                    .map(|pp| pp.len() as u64)
473                    .min()
474                    .unwrap_or(1);
475                let idf = params.idf(min_df);
476                let score = params.term_score(phrase_freq as f32, doc_len, idf);
477                Some(ScoredResult::new(doc_id, score))
478            })
479            .collect();
480
481        let mut results = results;
482        results.sort_by(|a, b| {
483            b.score
484                .partial_cmp(&a.score)
485                .unwrap_or(std::cmp::Ordering::Equal)
486        });
487        results.truncate(k);
488        results
489    }
490
491    /// Find documents containing the phrase
492    fn find_phrase_matches(
493        &self,
494        positional_postings: &[Vec<PositionalPosting>],
495        allowed: &AllowedSet,
496    ) -> Vec<(u64, u32)> {
497        if positional_postings.is_empty() {
498            return vec![];
499        }
500
501        // Index postings by doc_id
502        let indexed: Vec<HashMap<u64, &Vec<u32>>> = positional_postings
503            .iter()
504            .map(|postings| {
505                postings
506                    .iter()
507                    .filter(|p| allowed.contains(p.doc_id))
508                    .map(|p| (p.doc_id, &p.positions))
509                    .collect()
510            })
511            .collect();
512
513        // Start with docs having first term
514        let first_docs: std::collections::HashSet<u64> = indexed[0].keys().copied().collect();
515
516        // Keep only docs that have all terms
517        let candidate_docs: Vec<u64> = first_docs
518            .into_iter()
519            .filter(|doc_id| indexed.iter().all(|idx| idx.contains_key(doc_id)))
520            .collect();
521
522        // Check phrase positions
523        let mut matches = vec![];
524
525        for doc_id in candidate_docs {
526            let mut phrase_count = 0u32;
527
528            // Get positions for first term
529            let first_positions = indexed[0].get(&doc_id).unwrap();
530
531            'outer: for &start_pos in first_positions.iter() {
532                // Check if subsequent terms appear at consecutive positions
533                for (term_idx, term_positions) in indexed.iter().enumerate().skip(1) {
534                    let expected_pos = start_pos + term_idx as u32;
535                    let positions = term_positions.get(&doc_id).unwrap();
536
537                    // Binary search for expected position
538                    if positions.binary_search(&expected_pos).is_err() {
539                        continue 'outer;
540                    }
541                }
542
543                // Found a phrase match
544                phrase_count += 1;
545            }
546
547            if phrase_count > 0 {
548                matches.push((doc_id, phrase_count));
549            }
550        }
551
552        matches
553    }
554}
555
556// ============================================================================
557// Tests
558// ============================================================================
559
560#[cfg(test)]
561mod tests {
562    use super::*;
563    use crate::candidate_gate::AllowedSet;
564
565    // Mock inverted index for testing
566    struct MockIndex {
567        postings: HashMap<String, PostingList>,
568        doc_lengths: HashMap<u64, u32>,
569        params: Bm25Params,
570    }
571
572    impl MockIndex {
573        fn new() -> Self {
574            let mut postings = HashMap::new();
575            let mut doc_lengths = HashMap::new();
576
577            // Add some test data
578            postings.insert(
579                "rust".to_string(),
580                PostingList::new("rust", vec![(1, 3), (2, 1), (3, 2), (5, 1)]),
581            );
582            postings.insert(
583                "database".to_string(),
584                PostingList::new("database", vec![(1, 1), (3, 4), (4, 1)]),
585            );
586            postings.insert(
587                "vector".to_string(),
588                PostingList::new("vector", vec![(1, 2), (2, 3), (4, 1), (5, 2)]),
589            );
590
591            // Doc lengths
592            for i in 1..=5 {
593                doc_lengths.insert(i, 100);
594            }
595
596            Self {
597                postings,
598                doc_lengths,
599                params: Bm25Params {
600                    k1: 1.2,
601                    b: 0.75,
602                    avgdl: 100.0,
603                    total_docs: 1000,
604                },
605            }
606        }
607    }
608
609    impl InvertedIndex for MockIndex {
610        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
611            self.postings.get(term).cloned()
612        }
613
614        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
615            self.doc_lengths.get(&doc_id).copied()
616        }
617
618        fn get_params(&self) -> &Bm25Params {
619            &self.params
620        }
621    }
622
623    #[test]
624    fn test_conjunctive_search() {
625        let index = Arc::new(MockIndex::new());
626        let executor = FilteredBm25Executor::new(index);
627
628        // Search for "rust database"
629        // Should match docs 1 and 3 (have both terms)
630        let results = executor.search("rust database", 10, &AllowedSet::All);
631
632        assert_eq!(results.len(), 2);
633        let doc_ids: Vec<u64> = results.iter().map(|r| r.doc_id).collect();
634        assert!(doc_ids.contains(&1));
635        assert!(doc_ids.contains(&3));
636    }
637
638    #[test]
639    fn test_filter_pushdown() {
640        let index = Arc::new(MockIndex::new());
641        let executor = FilteredBm25Executor::new(index);
642
643        // Only allow doc 1
644        let allowed = AllowedSet::SortedVec(Arc::new(vec![1]));
645
646        let results = executor.search("rust database", 10, &allowed);
647
648        assert_eq!(results.len(), 1);
649        assert_eq!(results[0].doc_id, 1);
650    }
651
652    #[test]
653    fn test_empty_allowed_set() {
654        let index = Arc::new(MockIndex::new());
655        let executor = FilteredBm25Executor::new(index);
656
657        let results = executor.search("rust", 10, &AllowedSet::None);
658        assert!(results.is_empty());
659    }
660
661    #[test]
662    fn test_disjunctive_search() {
663        let index = Arc::new(MockIndex::new());
664        let executor = DisjunctiveBm25Executor::new(index);
665
666        // Search for "rust database" with OR semantics
667        // Should match docs 1, 2, 3, 4, 5 (any with either term)
668        let results = executor.search("rust database", 10, &AllowedSet::All);
669
670        // Docs 1-5 have at least one of the terms
671        assert!(results.len() >= 4);
672    }
673
674    #[test]
675    fn test_term_ordering_by_df() {
676        // This tests that progressive intersection starts with lowest DF
677        let mut pl1 = PostingList::new("rare", vec![(1, 1), (2, 1)]); // DF=2
678        let mut pl2 = PostingList::new("common", vec![(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]); // DF=5
679
680        let mut lists = vec![pl2.clone(), pl1.clone()];
681        lists.sort_by_key(|pl| pl.doc_freq);
682
683        // Should be sorted: rare (DF=2) before common (DF=5)
684        assert_eq!(lists[0].term, "rare");
685        assert_eq!(lists[1].term, "common");
686    }
687
688    #[test]
689    fn test_bm25_scoring() {
690        let params = Bm25Params::default();
691
692        // IDF for rare term (DF=10 in 1M docs) should be higher than common (DF=100K)
693        let idf_rare = params.idf(10);
694        let idf_common = params.idf(100_000);
695
696        assert!(idf_rare > idf_common);
697
698        // Score with higher TF should be higher
699        let score_tf_1 = params.term_score(1.0, 100.0, idf_rare);
700        let score_tf_5 = params.term_score(5.0, 100.0, idf_rare);
701
702        assert!(score_tf_5 > score_tf_1);
703    }
704}