Skip to main content

sochdb_query/
bm25_filtered.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! BM25 Filter Pushdown via Posting-Set Intersection (Task 6)
19//!
20//! This module implements filter pushdown for BM25 full-text search by
21//! intersecting posting lists with the allowed set BEFORE scoring.
22//!
23//! ## Key Insight
24//!
25//! ```text
26//! Traditional BM25: score ALL docs with term → filter → top-k
27//! Filtered BM25: intersect posting(term) ∩ AllowedSet → score only intersection
28//! ```
29//!
30//! ## Optimization: Term Reordering
31//!
32//! For multi-term queries, we reorder terms by increasing document frequency (DF):
33//!
34//! ```text
35//! Query: "machine learning algorithms"
36//! DFs: machine=10000, learning=8000, algorithms=500
37//! 
38//! Order: algorithms → learning → machine
39//! 
40//! Step 1: posting(algorithms) ∩ AllowedSet → small_set
41//! Step 2: small_set ∩ posting(learning) → smaller_set  
42//! Step 3: smaller_set ∩ posting(machine) → final_set
43//! ```
44//!
45//! This minimizes intermediate set sizes, reducing memory and CPU.
46//!
47//! ## Cost Model
48//!
49//! Let:
50//! - N = total docs
51//! - |A| = allowed set size
52//! - df(t) = document frequency of term t
53//!
54//! Without pushdown: O(Σ df(t)) scoring operations
//! With pushdown: O(min(|A|, min_t df(t))) scoring operations
56//!
57//! When |A| << N, this is a massive win.
58
59use std::collections::HashMap;
60use std::sync::Arc;
61
62use crate::candidate_gate::AllowedSet;
63use crate::filtered_vector_search::ScoredResult;
64
65// ============================================================================
66// BM25 Parameters
67// ============================================================================
68
/// BM25 scoring parameters
#[derive(Debug, Clone)]
pub struct Bm25Params {
    /// Term frequency saturation parameter (default: 1.2)
    pub k1: f32,
    /// Document length normalization (default: 0.75)
    pub b: f32,
    /// Average document length (computed from corpus)
    pub avgdl: f32,
    /// Total number of documents in corpus
    pub total_docs: u64,
}

impl Default for Bm25Params {
    /// Conventional Okapi defaults with a 1M-doc, avgdl=100 corpus assumption.
    fn default() -> Self {
        Bm25Params {
            k1: 1.2,
            b: 0.75,
            avgdl: 100.0,
            total_docs: 1_000_000,
        }
    }
}

impl Bm25Params {
    /// Compute the smoothed, always-positive IDF for a term, given its
    /// document frequency: `ln((N - df + 0.5) / (df + 0.5) + 1)`.
    pub fn idf(&self, doc_freq: u64) -> f32 {
        let total = self.total_docs as f32;
        let df = doc_freq as f32;
        let ratio = (total - df + 0.5) / (df + 0.5);
        (ratio + 1.0).ln()
    }

    /// Compute a single term's BM25 contribution for one document, from the
    /// raw term frequency `tf`, the document length `doc_len` (in tokens),
    /// and the term's precomputed `idf`.
    pub fn term_score(&self, tf: f32, doc_len: f32, idf: f32) -> f32 {
        // Length normalization: docs longer than avgdl are penalized by `b`.
        let saturated = tf * (self.k1 + 1.0);
        let damping = tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avgdl);
        idf * saturated / damping
    }
}
108
109// ============================================================================
110// Term Posting List
111// ============================================================================
112
113/// A posting list for a single term
114#[derive(Debug, Clone)]
115pub struct PostingList {
116    /// Term text
117    pub term: String,
118    /// Document IDs containing this term (sorted)
119    pub doc_ids: Vec<u64>,
120    /// Term frequencies in each document
121    pub term_freqs: Vec<u32>,
122    /// Document frequency (len of doc_ids)
123    pub doc_freq: u64,
124}
125
126impl PostingList {
127    /// Create a new posting list
128    pub fn new(term: impl Into<String>, entries: Vec<(u64, u32)>) -> Self {
129        let term = term.into();
130        let doc_freq = entries.len() as u64;
131        let mut doc_ids = Vec::with_capacity(entries.len());
132        let mut term_freqs = Vec::with_capacity(entries.len());
133        
134        for (doc_id, tf) in entries {
135            doc_ids.push(doc_id);
136            term_freqs.push(tf);
137        }
138        
139        Self {
140            term,
141            doc_ids,
142            term_freqs,
143            doc_freq,
144        }
145    }
146    
147    /// Intersect with an allowed set, returning (doc_id, tf) pairs
148    pub fn intersect_with_allowed(&self, allowed: &AllowedSet) -> Vec<(u64, u32)> {
149        match allowed {
150            AllowedSet::All => {
151                self.doc_ids.iter()
152                    .zip(self.term_freqs.iter())
153                    .map(|(&id, &tf)| (id, tf))
154                    .collect()
155            }
156            AllowedSet::None => vec![],
157            _ => {
158                self.doc_ids.iter()
159                    .zip(self.term_freqs.iter())
160                    .filter(|&(&id, _)| allowed.contains(id))
161                    .map(|(&id, &tf)| (id, tf))
162                    .collect()
163            }
164        }
165    }
166}
167
168// ============================================================================
169// Inverted Index Interface
170// ============================================================================
171
/// Trait for accessing an inverted index
///
/// `Send + Sync` is required so executors can share an implementation
/// across threads behind an `Arc`.
pub trait InvertedIndex: Send + Sync {
    /// Get posting list for a term (None if term not in vocabulary)
    fn get_posting_list(&self, term: &str) -> Option<PostingList>;
    
    /// Get document length (in tokens); None if the doc id is unknown
    fn get_doc_length(&self, doc_id: u64) -> Option<u32>;
    
    /// Get BM25 parameters
    fn get_params(&self) -> &Bm25Params;
}
183
184// ============================================================================
185// Filtered BM25 Executor
186// ============================================================================
187
188/// A BM25 executor that applies filter pushdown
189pub struct FilteredBm25Executor<I: InvertedIndex> {
190    index: Arc<I>,
191}
192
193impl<I: InvertedIndex> FilteredBm25Executor<I> {
194    /// Create a new executor
195    pub fn new(index: Arc<I>) -> Self {
196        Self { index }
197    }
198    
199    /// Execute a BM25 query with filter pushdown
200    ///
201    /// # Algorithm
202    ///
203    /// 1. Tokenize query
204    /// 2. Get posting lists for each term
205    /// 3. Sort terms by DF (ascending) for early pruning
206    /// 4. Intersect posting lists with AllowedSet (in DF order)
207    /// 5. Score only docs in intersection
208    /// 6. Return top-k
209    pub fn search(
210        &self,
211        query: &str,
212        k: usize,
213        allowed: &AllowedSet,
214    ) -> Vec<ScoredResult> {
215        // Short-circuit if nothing allowed
216        if allowed.is_empty() {
217            return vec![];
218        }
219        
220        // Tokenize (simple whitespace split for now)
221        let terms: Vec<&str> = query
222            .split_whitespace()
223            .filter(|t| t.len() >= 2) // Skip very short terms
224            .collect();
225        
226        if terms.is_empty() {
227            return vec![];
228        }
229        
230        // Get posting lists and sort by DF
231        let mut posting_lists: Vec<PostingList> = terms
232            .iter()
233            .filter_map(|t| self.index.get_posting_list(t))
234            .collect();
235        
236        // Sort by document frequency (ascending)
237        posting_lists.sort_by_key(|pl| pl.doc_freq);
238        
239        // Progressive intersection with AllowedSet
240        let candidates = self.progressive_intersection(&posting_lists, allowed);
241        
242        if candidates.is_empty() {
243            return vec![];
244        }
245        
246        // Score candidates
247        let params = self.index.get_params();
248        let scores = self.score_candidates(&candidates, &posting_lists, params);
249        
250        // Return top-k
251        self.top_k(scores, k)
252    }
253    
254    /// Progressively intersect posting lists with allowed set
255    ///
256    /// Returns map of doc_id -> term frequencies for each term
257    fn progressive_intersection(
258        &self,
259        posting_lists: &[PostingList],
260        allowed: &AllowedSet,
261    ) -> HashMap<u64, Vec<u32>> {
262        if posting_lists.is_empty() {
263            return HashMap::new();
264        }
265        
266        // Start with first (smallest DF) term intersected with allowed
267        let first = &posting_lists[0];
268        let mut candidates: HashMap<u64, Vec<u32>> = first
269            .intersect_with_allowed(allowed)
270            .into_iter()
271            .map(|(id, tf)| (id, vec![tf]))
272            .collect();
273        
274        // Intersect with remaining terms
275        for (_term_idx, posting_list) in posting_lists.iter().enumerate().skip(1) {
276            // Build a lookup for this term's postings
277            let term_postings: HashMap<u64, u32> = posting_list
278                .doc_ids.iter()
279                .zip(posting_list.term_freqs.iter())
280                .map(|(&id, &tf)| (id, tf))
281                .collect();
282            
283            // Keep only candidates that appear in this term's postings
284            candidates.retain(|doc_id, tfs| {
285                if let Some(&tf) = term_postings.get(doc_id) {
286                    tfs.push(tf);
287                    true
288                } else {
289                    false
290                }
291            });
292            
293            // Early exit if no candidates remain
294            if candidates.is_empty() {
295                break;
296            }
297        }
298        
299        candidates
300    }
301    
302    /// Score candidates using BM25
303    fn score_candidates(
304        &self,
305        candidates: &HashMap<u64, Vec<u32>>,
306        posting_lists: &[PostingList],
307        params: &Bm25Params,
308    ) -> Vec<ScoredResult> {
309        // Precompute IDFs
310        let idfs: Vec<f32> = posting_lists
311            .iter()
312            .map(|pl| params.idf(pl.doc_freq))
313            .collect();
314        
315        candidates
316            .iter()
317            .filter_map(|(&doc_id, tfs)| {
318                let doc_len = self.index.get_doc_length(doc_id)? as f32;
319                
320                let score: f32 = tfs.iter()
321                    .zip(idfs.iter())
322                    .map(|(&tf, &idf)| params.term_score(tf as f32, doc_len, idf))
323                    .sum();
324                
325                Some(ScoredResult::new(doc_id, score))
326            })
327            .collect()
328    }
329    
330    /// Get top-k results
331    fn top_k(&self, mut scores: Vec<ScoredResult>, k: usize) -> Vec<ScoredResult> {
332        // Sort by score descending
333        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
334        scores.truncate(k);
335        scores
336    }
337}
338
339// ============================================================================
340// Disjunctive (OR) Query Support
341// ============================================================================
342
343/// A disjunctive BM25 query (OR semantics)
344pub struct DisjunctiveBm25Executor<I: InvertedIndex> {
345    index: Arc<I>,
346}
347
348impl<I: InvertedIndex> DisjunctiveBm25Executor<I> {
349    /// Create a new executor
350    pub fn new(index: Arc<I>) -> Self {
351        Self { index }
352    }
353    
354    /// Execute with OR semantics (documents match if they have ANY query term)
355    pub fn search(
356        &self,
357        query: &str,
358        k: usize,
359        allowed: &AllowedSet,
360    ) -> Vec<ScoredResult> {
361        if allowed.is_empty() {
362            return vec![];
363        }
364        
365        let terms: Vec<&str> = query.split_whitespace().collect();
366        if terms.is_empty() {
367            return vec![];
368        }
369        
370        // Get posting lists
371        let posting_lists: Vec<PostingList> = terms
372            .iter()
373            .filter_map(|t| self.index.get_posting_list(t))
374            .collect();
375        
376        let params = self.index.get_params();
377        
378        // Accumulate scores for all docs matching any term
379        let mut scores: HashMap<u64, f32> = HashMap::new();
380        
381        for posting_list in &posting_lists {
382            let idf = params.idf(posting_list.doc_freq);
383            
384            // Only score docs in allowed set
385            for (&doc_id, &tf) in posting_list.doc_ids.iter().zip(posting_list.term_freqs.iter()) {
386                if !allowed.contains(doc_id) {
387                    continue;
388                }
389                
390                if let Some(doc_len) = self.index.get_doc_length(doc_id) {
391                    let term_score = params.term_score(tf as f32, doc_len as f32, idf);
392                    *scores.entry(doc_id).or_insert(0.0) += term_score;
393                }
394            }
395        }
396        
397        // Convert to results and get top-k
398        let mut results: Vec<ScoredResult> = scores
399            .into_iter()
400            .map(|(id, score)| ScoredResult::new(id, score))
401            .collect();
402        
403        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
404        results.truncate(k);
405        results
406    }
407}
408
409// ============================================================================
410// Phrase Query Support
411// ============================================================================
412
413/// Position information for a term in a document
414#[derive(Debug, Clone)]
415pub struct PositionalPosting {
416    /// Document ID
417    pub doc_id: u64,
418    /// Positions where term appears (sorted)
419    pub positions: Vec<u32>,
420}
421
422/// Trait for positional index access
423pub trait PositionalIndex: InvertedIndex {
424    /// Get positional posting list
425    fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>>;
426}
427
428/// Phrase query executor with filter pushdown
429pub struct FilteredPhraseExecutor<I: PositionalIndex> {
430    index: Arc<I>,
431}
432
433impl<I: PositionalIndex> FilteredPhraseExecutor<I> {
434    /// Create a new executor
435    pub fn new(index: Arc<I>) -> Self {
436        Self { index }
437    }
438    
439    /// Execute a phrase query
440    ///
441    /// Documents must contain all terms in sequence.
442    pub fn search(
443        &self,
444        phrase: &[&str],
445        k: usize,
446        allowed: &AllowedSet,
447    ) -> Vec<ScoredResult> {
448        if phrase.is_empty() || allowed.is_empty() {
449            return vec![];
450        }
451        
452        // Get positional postings for each term
453        let mut positional_postings: Vec<Vec<PositionalPosting>> = vec![];
454        for term in phrase {
455            match self.index.get_positional_posting(term) {
456                Some(postings) => positional_postings.push(postings),
457                None => return vec![], // Term not in index → no matches
458            }
459        }
460        
461        // Find documents containing all terms
462        let candidates = self.find_phrase_matches(&positional_postings, allowed);
463        
464        // Score by phrase frequency
465        let params = self.index.get_params();
466        let results: Vec<ScoredResult> = candidates
467            .into_iter()
468            .filter_map(|(doc_id, phrase_freq)| {
469                let doc_len = self.index.get_doc_length(doc_id)? as f32;
470                // Use phrase frequency as TF, use min DF for IDF approximation
471                let min_df = positional_postings.iter()
472                    .map(|pp| pp.len() as u64)
473                    .min()
474                    .unwrap_or(1);
475                let idf = params.idf(min_df);
476                let score = params.term_score(phrase_freq as f32, doc_len, idf);
477                Some(ScoredResult::new(doc_id, score))
478            })
479            .collect();
480        
481        let mut results = results;
482        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
483        results.truncate(k);
484        results
485    }
486    
487    /// Find documents containing the phrase
488    fn find_phrase_matches(
489        &self,
490        positional_postings: &[Vec<PositionalPosting>],
491        allowed: &AllowedSet,
492    ) -> Vec<(u64, u32)> {
493        if positional_postings.is_empty() {
494            return vec![];
495        }
496        
497        // Index postings by doc_id
498        let indexed: Vec<HashMap<u64, &Vec<u32>>> = positional_postings
499            .iter()
500            .map(|postings| {
501                postings.iter()
502                    .filter(|p| allowed.contains(p.doc_id))
503                    .map(|p| (p.doc_id, &p.positions))
504                    .collect()
505            })
506            .collect();
507        
508        // Start with docs having first term
509        let first_docs: std::collections::HashSet<u64> = indexed[0].keys().copied().collect();
510        
511        // Keep only docs that have all terms
512        let candidate_docs: Vec<u64> = first_docs
513            .into_iter()
514            .filter(|doc_id| indexed.iter().all(|idx| idx.contains_key(doc_id)))
515            .collect();
516        
517        // Check phrase positions
518        let mut matches = vec![];
519        
520        for doc_id in candidate_docs {
521            let mut phrase_count = 0u32;
522            
523            // Get positions for first term
524            let first_positions = indexed[0].get(&doc_id).unwrap();
525            
526            'outer: for &start_pos in first_positions.iter() {
527                // Check if subsequent terms appear at consecutive positions
528                for (term_idx, term_positions) in indexed.iter().enumerate().skip(1) {
529                    let expected_pos = start_pos + term_idx as u32;
530                    let positions = term_positions.get(&doc_id).unwrap();
531                    
532                    // Binary search for expected position
533                    if positions.binary_search(&expected_pos).is_err() {
534                        continue 'outer;
535                    }
536                }
537                
538                // Found a phrase match
539                phrase_count += 1;
540            }
541            
542            if phrase_count > 0 {
543                matches.push((doc_id, phrase_count));
544            }
545        }
546        
547        matches
548    }
549}
550
551// ============================================================================
552// Tests
553// ============================================================================
554
#[cfg(test)]
mod tests {
    use super::*;
    use crate::candidate_gate::AllowedSet;
    
    // Mock inverted index over a fixed 5-doc corpus:
    //   rust     → docs 1, 2, 3, 5
    //   database → docs 1, 3, 4
    //   vector   → docs 1, 2, 4, 5
    // All docs have length 100 (== avgdl), so length normalization is neutral.
    struct MockIndex {
        postings: HashMap<String, PostingList>,
        doc_lengths: HashMap<u64, u32>,
        params: Bm25Params,
    }
    
    impl MockIndex {
        fn new() -> Self {
            let mut postings = HashMap::new();
            let mut doc_lengths = HashMap::new();
            
            // Add some test data
            postings.insert("rust".to_string(), PostingList::new("rust", vec![
                (1, 3), (2, 1), (3, 2), (5, 1),
            ]));
            postings.insert("database".to_string(), PostingList::new("database", vec![
                (1, 1), (3, 4), (4, 1),
            ]));
            postings.insert("vector".to_string(), PostingList::new("vector", vec![
                (1, 2), (2, 3), (4, 1), (5, 2),
            ]));
            
            // Doc lengths
            for i in 1..=5 {
                doc_lengths.insert(i, 100);
            }
            
            Self {
                postings,
                doc_lengths,
                params: Bm25Params {
                    k1: 1.2,
                    b: 0.75,
                    avgdl: 100.0,
                    total_docs: 1000,
                },
            }
        }
    }
    
    impl InvertedIndex for MockIndex {
        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
            self.postings.get(term).cloned()
        }
        
        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
            self.doc_lengths.get(&doc_id).copied()
        }
        
        fn get_params(&self) -> &Bm25Params {
            &self.params
        }
    }
    
    #[test]
    fn test_conjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);
        
        // Search for "rust database"
        // Should match docs 1 and 3 (the only docs with both terms)
        let results = executor.search("rust database", 10, &AllowedSet::All);
        
        assert_eq!(results.len(), 2);
        let doc_ids: Vec<u64> = results.iter().map(|r| r.doc_id).collect();
        assert!(doc_ids.contains(&1));
        assert!(doc_ids.contains(&3));
    }
    
    #[test]
    fn test_filter_pushdown() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);
        
        // Only allow doc 1
        let allowed = AllowedSet::SortedVec(Arc::new(vec![1]));
        
        let results = executor.search("rust database", 10, &allowed);
        
        // Doc 3 also matches both terms but is filtered out.
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 1);
    }
    
    #[test]
    fn test_empty_allowed_set() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);
        
        let results = executor.search("rust", 10, &AllowedSet::None);
        assert!(results.is_empty());
    }
    
    #[test]
    fn test_disjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = DisjunctiveBm25Executor::new(index);
        
        // Search for "rust database" with OR semantics.
        // Every doc (1-5) contains at least one of the two terms.
        let results = executor.search("rust database", 10, &AllowedSet::All);
        
        assert_eq!(results.len(), 5);
    }
    
    #[test]
    fn test_term_ordering_by_df() {
        // This tests that progressive intersection starts with lowest DF.
        // (Bindings need neither `mut` nor clones — they are moved into the vec.)
        let pl1 = PostingList::new("rare", vec![(1, 1), (2, 1)]); // DF=2
        let pl2 = PostingList::new("common", vec![(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]); // DF=5
        
        let mut lists = vec![pl2, pl1];
        lists.sort_by_key(|pl| pl.doc_freq);
        
        // Should be sorted: rare (DF=2) before common (DF=5)
        assert_eq!(lists[0].term, "rare");
        assert_eq!(lists[1].term, "common");
    }
    
    #[test]
    fn test_bm25_scoring() {
        let params = Bm25Params::default();
        
        // IDF for rare term (DF=10 in 1M docs) should be higher than common (DF=100K)
        let idf_rare = params.idf(10);
        let idf_common = params.idf(100_000);
        
        assert!(idf_rare > idf_common);
        
        // Score with higher TF should be higher
        let score_tf_1 = params.term_score(1.0, 100.0, idf_rare);
        let score_tf_5 = params.term_score(5.0, 100.0, idf_rare);
        
        assert!(score_tf_5 > score_tf_1);
    }
}