// sochdb_query/bm25_filtered.rs

// Copyright 2025 Sushanth (https://github.com/sushanthpy)
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

//! BM25 Filter Pushdown via Posting-Set Intersection (Task 6)
//!
//! This module implements filter pushdown for BM25 full-text search by
//! intersecting posting lists with the allowed set BEFORE scoring.
//!
//! ## Key Insight
//!
//! ```text
//! Traditional BM25: score ALL docs with term → filter → top-k
//! Filtered BM25: intersect posting(term) ∩ AllowedSet → score only intersection
//! ```
//!
//! ## Optimization: Term Reordering
//!
//! For multi-term queries, we reorder terms by increasing document frequency (DF):
//!
//! ```text
//! Query: "machine learning algorithms"
//! DFs: machine=10000, learning=8000, algorithms=500
//!
//! Order: algorithms → learning → machine
//!
//! Step 1: posting(algorithms) ∩ AllowedSet → small_set
//! Step 2: small_set ∩ posting(learning) → smaller_set
//! Step 3: smaller_set ∩ posting(machine) → final_set
//! ```
//!
//! This minimizes intermediate set sizes, reducing memory and CPU.
//!
//! ## Cost Model
//!
//! Let:
//! - N = total docs
//! - |A| = allowed set size
//! - df(t) = document frequency of term t
//!
//! Without pushdown: O(Σ df(t)) scoring operations
//! With pushdown: O(min(|A|, min_t df(t))) scoring operations
//!
//! When |A| << N, this is a massive win.

use std::collections::HashMap;
use std::sync::Arc;

use crate::candidate_gate::AllowedSet;
use crate::filtered_vector_search::ScoredResult;

// ============================================================================
// BM25 Parameters
// ============================================================================

/// Tuning knobs and corpus statistics for BM25 scoring.
#[derive(Debug, Clone)]
pub struct Bm25Params {
    /// Term-frequency saturation parameter (default: 1.2)
    pub k1: f32,
    /// Document-length normalization strength (default: 0.75)
    pub b: f32,
    /// Average document length over the corpus (in tokens)
    pub avgdl: f32,
    /// Total number of documents in the corpus
    pub total_docs: u64,
}

79impl Default for Bm25Params {
80    fn default() -> Self {
81        Self {
82            k1: 1.2,
83            b: 0.75,
84            avgdl: 100.0,
85            total_docs: 1_000_000,
86        }
87    }
88}
89
90impl Bm25Params {
91    /// Compute IDF for a term
92    pub fn idf(&self, doc_freq: u64) -> f32 {
93        let n = self.total_docs as f32;
94        let df = doc_freq as f32;
95        ((n - df + 0.5) / (df + 0.5) + 1.0).ln()
96    }
97    
98    /// Compute term score component
99    pub fn term_score(&self, tf: f32, doc_len: f32, idf: f32) -> f32 {
100        let numerator = tf * (self.k1 + 1.0);
101        let denominator = tf + self.k1 * (1.0 - self.b + self.b * doc_len / self.avgdl);
102        idf * numerator / denominator
103    }
104}
105
// ============================================================================
// Term Posting List
// ============================================================================

/// Posting list for one term: the documents that contain it and the
/// per-document occurrence counts.
#[derive(Debug, Clone)]
pub struct PostingList {
    /// The term itself
    pub term: String,
    /// IDs of documents containing the term (sorted)
    pub doc_ids: Vec<u64>,
    /// Term frequency for each entry of `doc_ids` (parallel vector)
    pub term_freqs: Vec<u32>,
    /// Document frequency (number of entries in `doc_ids`)
    pub doc_freq: u64,
}

123impl PostingList {
124    /// Create a new posting list
125    pub fn new(term: impl Into<String>, entries: Vec<(u64, u32)>) -> Self {
126        let term = term.into();
127        let doc_freq = entries.len() as u64;
128        let mut doc_ids = Vec::with_capacity(entries.len());
129        let mut term_freqs = Vec::with_capacity(entries.len());
130        
131        for (doc_id, tf) in entries {
132            doc_ids.push(doc_id);
133            term_freqs.push(tf);
134        }
135        
136        Self {
137            term,
138            doc_ids,
139            term_freqs,
140            doc_freq,
141        }
142    }
143    
144    /// Intersect with an allowed set, returning (doc_id, tf) pairs
145    pub fn intersect_with_allowed(&self, allowed: &AllowedSet) -> Vec<(u64, u32)> {
146        match allowed {
147            AllowedSet::All => {
148                self.doc_ids.iter()
149                    .zip(self.term_freqs.iter())
150                    .map(|(&id, &tf)| (id, tf))
151                    .collect()
152            }
153            AllowedSet::None => vec![],
154            _ => {
155                self.doc_ids.iter()
156                    .zip(self.term_freqs.iter())
157                    .filter(|&(&id, _)| allowed.contains(id))
158                    .map(|(&id, &tf)| (id, tf))
159                    .collect()
160            }
161        }
162    }
163}
164
// ============================================================================
// Inverted Index Interface
// ============================================================================

169/// Trait for accessing an inverted index
170pub trait InvertedIndex: Send + Sync {
171    /// Get posting list for a term (None if term not in vocabulary)
172    fn get_posting_list(&self, term: &str) -> Option<PostingList>;
173    
174    /// Get document length (in tokens)
175    fn get_doc_length(&self, doc_id: u64) -> Option<u32>;
176    
177    /// Get BM25 parameters
178    fn get_params(&self) -> &Bm25Params;
179}
180
// ============================================================================
// Filtered BM25 Executor
// ============================================================================

185/// A BM25 executor that applies filter pushdown
186pub struct FilteredBm25Executor<I: InvertedIndex> {
187    index: Arc<I>,
188}
189
190impl<I: InvertedIndex> FilteredBm25Executor<I> {
191    /// Create a new executor
192    pub fn new(index: Arc<I>) -> Self {
193        Self { index }
194    }
195    
196    /// Execute a BM25 query with filter pushdown
197    ///
198    /// # Algorithm
199    ///
200    /// 1. Tokenize query
201    /// 2. Get posting lists for each term
202    /// 3. Sort terms by DF (ascending) for early pruning
203    /// 4. Intersect posting lists with AllowedSet (in DF order)
204    /// 5. Score only docs in intersection
205    /// 6. Return top-k
206    pub fn search(
207        &self,
208        query: &str,
209        k: usize,
210        allowed: &AllowedSet,
211    ) -> Vec<ScoredResult> {
212        // Short-circuit if nothing allowed
213        if allowed.is_empty() {
214            return vec![];
215        }
216        
217        // Tokenize (simple whitespace split for now)
218        let terms: Vec<&str> = query
219            .split_whitespace()
220            .filter(|t| t.len() >= 2) // Skip very short terms
221            .collect();
222        
223        if terms.is_empty() {
224            return vec![];
225        }
226        
227        // Get posting lists and sort by DF
228        let mut posting_lists: Vec<PostingList> = terms
229            .iter()
230            .filter_map(|t| self.index.get_posting_list(t))
231            .collect();
232        
233        // Sort by document frequency (ascending)
234        posting_lists.sort_by_key(|pl| pl.doc_freq);
235        
236        // Progressive intersection with AllowedSet
237        let candidates = self.progressive_intersection(&posting_lists, allowed);
238        
239        if candidates.is_empty() {
240            return vec![];
241        }
242        
243        // Score candidates
244        let params = self.index.get_params();
245        let scores = self.score_candidates(&candidates, &posting_lists, params);
246        
247        // Return top-k
248        self.top_k(scores, k)
249    }
250    
251    /// Progressively intersect posting lists with allowed set
252    ///
253    /// Returns map of doc_id -> term frequencies for each term
254    fn progressive_intersection(
255        &self,
256        posting_lists: &[PostingList],
257        allowed: &AllowedSet,
258    ) -> HashMap<u64, Vec<u32>> {
259        if posting_lists.is_empty() {
260            return HashMap::new();
261        }
262        
263        // Start with first (smallest DF) term intersected with allowed
264        let first = &posting_lists[0];
265        let mut candidates: HashMap<u64, Vec<u32>> = first
266            .intersect_with_allowed(allowed)
267            .into_iter()
268            .map(|(id, tf)| (id, vec![tf]))
269            .collect();
270        
271        // Intersect with remaining terms
272        for (_term_idx, posting_list) in posting_lists.iter().enumerate().skip(1) {
273            // Build a lookup for this term's postings
274            let term_postings: HashMap<u64, u32> = posting_list
275                .doc_ids.iter()
276                .zip(posting_list.term_freqs.iter())
277                .map(|(&id, &tf)| (id, tf))
278                .collect();
279            
280            // Keep only candidates that appear in this term's postings
281            candidates.retain(|doc_id, tfs| {
282                if let Some(&tf) = term_postings.get(doc_id) {
283                    tfs.push(tf);
284                    true
285                } else {
286                    false
287                }
288            });
289            
290            // Early exit if no candidates remain
291            if candidates.is_empty() {
292                break;
293            }
294        }
295        
296        candidates
297    }
298    
299    /// Score candidates using BM25
300    fn score_candidates(
301        &self,
302        candidates: &HashMap<u64, Vec<u32>>,
303        posting_lists: &[PostingList],
304        params: &Bm25Params,
305    ) -> Vec<ScoredResult> {
306        // Precompute IDFs
307        let idfs: Vec<f32> = posting_lists
308            .iter()
309            .map(|pl| params.idf(pl.doc_freq))
310            .collect();
311        
312        candidates
313            .iter()
314            .filter_map(|(&doc_id, tfs)| {
315                let doc_len = self.index.get_doc_length(doc_id)? as f32;
316                
317                let score: f32 = tfs.iter()
318                    .zip(idfs.iter())
319                    .map(|(&tf, &idf)| params.term_score(tf as f32, doc_len, idf))
320                    .sum();
321                
322                Some(ScoredResult::new(doc_id, score))
323            })
324            .collect()
325    }
326    
327    /// Get top-k results
328    fn top_k(&self, mut scores: Vec<ScoredResult>, k: usize) -> Vec<ScoredResult> {
329        // Sort by score descending
330        scores.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
331        scores.truncate(k);
332        scores
333    }
334}
335
// ============================================================================
// Disjunctive (OR) Query Support
// ============================================================================

340/// A disjunctive BM25 query (OR semantics)
341pub struct DisjunctiveBm25Executor<I: InvertedIndex> {
342    index: Arc<I>,
343}
344
345impl<I: InvertedIndex> DisjunctiveBm25Executor<I> {
346    /// Create a new executor
347    pub fn new(index: Arc<I>) -> Self {
348        Self { index }
349    }
350    
351    /// Execute with OR semantics (documents match if they have ANY query term)
352    pub fn search(
353        &self,
354        query: &str,
355        k: usize,
356        allowed: &AllowedSet,
357    ) -> Vec<ScoredResult> {
358        if allowed.is_empty() {
359            return vec![];
360        }
361        
362        let terms: Vec<&str> = query.split_whitespace().collect();
363        if terms.is_empty() {
364            return vec![];
365        }
366        
367        // Get posting lists
368        let posting_lists: Vec<PostingList> = terms
369            .iter()
370            .filter_map(|t| self.index.get_posting_list(t))
371            .collect();
372        
373        let params = self.index.get_params();
374        
375        // Accumulate scores for all docs matching any term
376        let mut scores: HashMap<u64, f32> = HashMap::new();
377        
378        for posting_list in &posting_lists {
379            let idf = params.idf(posting_list.doc_freq);
380            
381            // Only score docs in allowed set
382            for (&doc_id, &tf) in posting_list.doc_ids.iter().zip(posting_list.term_freqs.iter()) {
383                if !allowed.contains(doc_id) {
384                    continue;
385                }
386                
387                if let Some(doc_len) = self.index.get_doc_length(doc_id) {
388                    let term_score = params.term_score(tf as f32, doc_len as f32, idf);
389                    *scores.entry(doc_id).or_insert(0.0) += term_score;
390                }
391            }
392        }
393        
394        // Convert to results and get top-k
395        let mut results: Vec<ScoredResult> = scores
396            .into_iter()
397            .map(|(id, score)| ScoredResult::new(id, score))
398            .collect();
399        
400        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
401        results.truncate(k);
402        results
403    }
404}
405
// ============================================================================
// Phrase Query Support
// ============================================================================

/// Where a term occurs inside one document: the document ID plus the
/// sorted token positions of every occurrence.
#[derive(Debug, Clone)]
pub struct PositionalPosting {
    /// Document ID
    pub doc_id: u64,
    /// Token positions of each occurrence (sorted ascending)
    pub positions: Vec<u32>,
}

419/// Trait for positional index access
420pub trait PositionalIndex: InvertedIndex {
421    /// Get positional posting list
422    fn get_positional_posting(&self, term: &str) -> Option<Vec<PositionalPosting>>;
423}
424
425/// Phrase query executor with filter pushdown
426pub struct FilteredPhraseExecutor<I: PositionalIndex> {
427    index: Arc<I>,
428}
429
430impl<I: PositionalIndex> FilteredPhraseExecutor<I> {
431    /// Create a new executor
432    pub fn new(index: Arc<I>) -> Self {
433        Self { index }
434    }
435    
436    /// Execute a phrase query
437    ///
438    /// Documents must contain all terms in sequence.
439    pub fn search(
440        &self,
441        phrase: &[&str],
442        k: usize,
443        allowed: &AllowedSet,
444    ) -> Vec<ScoredResult> {
445        if phrase.is_empty() || allowed.is_empty() {
446            return vec![];
447        }
448        
449        // Get positional postings for each term
450        let mut positional_postings: Vec<Vec<PositionalPosting>> = vec![];
451        for term in phrase {
452            match self.index.get_positional_posting(term) {
453                Some(postings) => positional_postings.push(postings),
454                None => return vec![], // Term not in index → no matches
455            }
456        }
457        
458        // Find documents containing all terms
459        let candidates = self.find_phrase_matches(&positional_postings, allowed);
460        
461        // Score by phrase frequency
462        let params = self.index.get_params();
463        let results: Vec<ScoredResult> = candidates
464            .into_iter()
465            .filter_map(|(doc_id, phrase_freq)| {
466                let doc_len = self.index.get_doc_length(doc_id)? as f32;
467                // Use phrase frequency as TF, use min DF for IDF approximation
468                let min_df = positional_postings.iter()
469                    .map(|pp| pp.len() as u64)
470                    .min()
471                    .unwrap_or(1);
472                let idf = params.idf(min_df);
473                let score = params.term_score(phrase_freq as f32, doc_len, idf);
474                Some(ScoredResult::new(doc_id, score))
475            })
476            .collect();
477        
478        let mut results = results;
479        results.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
480        results.truncate(k);
481        results
482    }
483    
484    /// Find documents containing the phrase
485    fn find_phrase_matches(
486        &self,
487        positional_postings: &[Vec<PositionalPosting>],
488        allowed: &AllowedSet,
489    ) -> Vec<(u64, u32)> {
490        if positional_postings.is_empty() {
491            return vec![];
492        }
493        
494        // Index postings by doc_id
495        let indexed: Vec<HashMap<u64, &Vec<u32>>> = positional_postings
496            .iter()
497            .map(|postings| {
498                postings.iter()
499                    .filter(|p| allowed.contains(p.doc_id))
500                    .map(|p| (p.doc_id, &p.positions))
501                    .collect()
502            })
503            .collect();
504        
505        // Start with docs having first term
506        let first_docs: std::collections::HashSet<u64> = indexed[0].keys().copied().collect();
507        
508        // Keep only docs that have all terms
509        let candidate_docs: Vec<u64> = first_docs
510            .into_iter()
511            .filter(|doc_id| indexed.iter().all(|idx| idx.contains_key(doc_id)))
512            .collect();
513        
514        // Check phrase positions
515        let mut matches = vec![];
516        
517        for doc_id in candidate_docs {
518            let mut phrase_count = 0u32;
519            
520            // Get positions for first term
521            let first_positions = indexed[0].get(&doc_id).unwrap();
522            
523            'outer: for &start_pos in first_positions.iter() {
524                // Check if subsequent terms appear at consecutive positions
525                for (term_idx, term_positions) in indexed.iter().enumerate().skip(1) {
526                    let expected_pos = start_pos + term_idx as u32;
527                    let positions = term_positions.get(&doc_id).unwrap();
528                    
529                    // Binary search for expected position
530                    if positions.binary_search(&expected_pos).is_err() {
531                        continue 'outer;
532                    }
533                }
534                
535                // Found a phrase match
536                phrase_count += 1;
537            }
538            
539            if phrase_count > 0 {
540                matches.push((doc_id, phrase_count));
541            }
542        }
543        
544        matches
545    }
546}
547
// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;
    use crate::candidate_gate::AllowedSet;

    // Mock inverted index for testing.
    struct MockIndex {
        postings: HashMap<String, PostingList>,
        doc_lengths: HashMap<u64, u32>,
        params: Bm25Params,
    }

    impl MockIndex {
        fn new() -> Self {
            let mut postings = HashMap::new();
            let mut doc_lengths = HashMap::new();

            // Test corpus:
            //   rust     → docs 1, 2, 3, 5
            //   database → docs 1, 3, 4
            //   vector   → docs 1, 2, 4, 5
            postings.insert("rust".to_string(), PostingList::new("rust", vec![
                (1, 3), (2, 1), (3, 2), (5, 1),
            ]));
            postings.insert("database".to_string(), PostingList::new("database", vec![
                (1, 1), (3, 4), (4, 1),
            ]));
            postings.insert("vector".to_string(), PostingList::new("vector", vec![
                (1, 2), (2, 3), (4, 1), (5, 2),
            ]));

            // All docs have uniform length so scoring differences come
            // purely from TF and IDF.
            for i in 1..=5 {
                doc_lengths.insert(i, 100);
            }

            Self {
                postings,
                doc_lengths,
                params: Bm25Params {
                    k1: 1.2,
                    b: 0.75,
                    avgdl: 100.0,
                    total_docs: 1000,
                },
            }
        }
    }

    impl InvertedIndex for MockIndex {
        fn get_posting_list(&self, term: &str) -> Option<PostingList> {
            self.postings.get(term).cloned()
        }

        fn get_doc_length(&self, doc_id: u64) -> Option<u32> {
            self.doc_lengths.get(&doc_id).copied()
        }

        fn get_params(&self) -> &Bm25Params {
            &self.params
        }
    }

    #[test]
    fn test_conjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        // Search for "rust database": only docs 1 and 3 have both terms.
        let results = executor.search("rust database", 10, &AllowedSet::All);

        assert_eq!(results.len(), 2);
        let doc_ids: Vec<u64> = results.iter().map(|r| r.doc_id).collect();
        assert!(doc_ids.contains(&1));
        assert!(doc_ids.contains(&3));
    }

    #[test]
    fn test_filter_pushdown() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        // Only allow doc 1
        let allowed = AllowedSet::SortedVec(Arc::new(vec![1]));

        let results = executor.search("rust database", 10, &allowed);

        assert_eq!(results.len(), 1);
        assert_eq!(results[0].doc_id, 1);
    }

    #[test]
    fn test_empty_allowed_set() {
        let index = Arc::new(MockIndex::new());
        let executor = FilteredBm25Executor::new(index);

        let results = executor.search("rust", 10, &AllowedSet::None);
        assert!(results.is_empty());
    }

    #[test]
    fn test_disjunctive_search() {
        let index = Arc::new(MockIndex::new());
        let executor = DisjunctiveBm25Executor::new(index);

        // Search for "rust database" with OR semantics: docs 1-5 all
        // contain at least one of the two terms (doc 5 only "rust",
        // doc 4 only "database").
        let results = executor.search("rust database", 10, &AllowedSet::All);

        assert!(results.len() >= 4);
    }

    #[test]
    fn test_term_ordering_by_df() {
        // Progressive intersection must start with the lowest-DF term.
        // (No `mut`/clones needed: the lists are moved into the Vec.)
        let pl1 = PostingList::new("rare", vec![(1, 1), (2, 1)]); // DF=2
        let pl2 = PostingList::new("common", vec![(1, 1), (2, 1), (3, 1), (4, 1), (5, 1)]); // DF=5

        let mut lists = vec![pl2, pl1];
        lists.sort_by_key(|pl| pl.doc_freq);

        // Should be sorted: rare (DF=2) before common (DF=5)
        assert_eq!(lists[0].term, "rare");
        assert_eq!(lists[1].term, "common");
    }

    #[test]
    fn test_bm25_scoring() {
        let params = Bm25Params::default();

        // IDF for a rare term (DF=10 in 1M docs) should exceed that of a
        // common one (DF=100K).
        let idf_rare = params.idf(10);
        let idf_common = params.idf(100_000);

        assert!(idf_rare > idf_common);

        // Higher TF must yield a higher score (monotone saturation).
        let score_tf_1 = params.term_score(1.0, 100.0, idf_rare);
        let score_tf_5 = params.term_score(5.0, 100.0, idf_rare);

        assert!(score_tf_5 > score_tf_1);
    }
}
693}