sochdb_query/
candidate_gate.rs

1// Copyright 2025 Sushanth (https://github.com/sushanthpy)
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Unified Candidate Gate Interface (Task 4)
16//!
17//! This module defines the `AllowedSet` abstraction that every retrieval
18//! executor MUST accept. The gate guarantees:
19//!
20//! 1. **Never return a doc outside AllowedSet** - structural enforcement
21//! 2. **Apply constraints during generation** - no post-filtering
22//! 3. **Consistent semantics** across vector/BM25/hybrid/context
23//!
24//! ## The Contract
25//!
26//! Every executor receives an `AllowedSet` and must:
27//! - Check membership BEFORE including any candidate
28//! - Short-circuit if AllowedSet is empty (return empty results)
29//! - Report selectivity for query planning
30//!
31//! ## Representations
32//!
33//! `AllowedSet` supports multiple representations for efficiency:
34//!
35//! | Representation | Best For | Membership | Space |
36//! |----------------|----------|------------|-------|
37//! | Bitmap | Dense sets | O(1) | O(N/8) |
38//! | SortedVec | Sparse sets | O(log n) | O(n) |
39//! | HashSet | Random access | O(1) avg | O(n) |
40//! | All | No constraint | O(1) | O(1) |
41//!
42//! ## Selectivity
43//!
44//! Executors use selectivity `|S|/N` to choose execution strategy:
//! - High selectivity (>= 0.1): Standard search with filter
//! - Low selectivity (>= 0.001 and < 0.1): Scan only allowed IDs
//! - Very low (< 0.001): Consider alternative strategy
48
49use std::collections::HashSet;
50use std::fmt;
51use std::sync::Arc;
52
53// ============================================================================
54// AllowedSet - Core Abstraction
55// ============================================================================
56
57/// The unified gate for candidate filtering
58///
59/// Every executor MUST check `allowed_set.contains(doc_id)` before
60/// including any result. This is the structural enforcement of pushdown.
61#[derive(Clone)]
62pub enum AllowedSet {
63    /// All documents are allowed (no filter constraint)
64    All,
65    
66    /// Bitmap representation (efficient for dense sets)
67    Bitmap(Arc<AllowedBitmap>),
68    
69    /// Sorted vector (efficient for sparse sets with iteration)
70    SortedVec(Arc<Vec<u64>>),
71    
72    /// Hash set (efficient for random access)
73    HashSet(Arc<HashSet<u64>>),
74    
75    /// No documents allowed (empty result shortcut)
76    None,
77}
78
79impl AllowedSet {
80    /// Create an AllowedSet from a bitmap
81    pub fn from_bitmap(bitmap: AllowedBitmap) -> Self {
82        if bitmap.is_empty() {
83            Self::None
84        } else if bitmap.is_all() {
85            Self::All
86        } else {
87            Self::Bitmap(Arc::new(bitmap))
88        }
89    }
90    
91    /// Create an AllowedSet from a sorted vector of doc IDs
92    pub fn from_sorted_vec(mut ids: Vec<u64>) -> Self {
93        if ids.is_empty() {
94            return Self::None;
95        }
96        ids.sort_unstable();
97        ids.dedup();
98        Self::SortedVec(Arc::new(ids))
99    }
100    
101    /// Create an AllowedSet from an iterator of doc IDs
102    pub fn from_iter(ids: impl IntoIterator<Item = u64>) -> Self {
103        let set: HashSet<u64> = ids.into_iter().collect();
104        if set.is_empty() {
105            Self::None
106        } else {
107            Self::HashSet(Arc::new(set))
108        }
109    }
110    
111    /// Check if a document ID is allowed
112    ///
113    /// This is the core operation that executors MUST call.
114    #[inline]
115    pub fn contains(&self, doc_id: u64) -> bool {
116        match self {
117            Self::All => true,
118            Self::Bitmap(bm) => bm.contains(doc_id),
119            Self::SortedVec(vec) => vec.binary_search(&doc_id).is_ok(),
120            Self::HashSet(set) => set.contains(&doc_id),
121            Self::None => false,
122        }
123    }
124    
125    /// Check if this set is empty (no allowed documents)
126    pub fn is_empty(&self) -> bool {
127        matches!(self, Self::None)
128    }
129    
130    /// Check if this set allows all documents
131    pub fn is_all(&self) -> bool {
132        matches!(self, Self::All)
133    }
134    
135    /// Get the cardinality (number of allowed documents)
136    ///
137    /// Returns None for All (unknown without universe size)
138    pub fn cardinality(&self) -> Option<usize> {
139        match self {
140            Self::All => None,
141            Self::Bitmap(bm) => Some(bm.count()),
142            Self::SortedVec(vec) => Some(vec.len()),
143            Self::HashSet(set) => Some(set.len()),
144            Self::None => Some(0),
145        }
146    }
147    
148    /// Compute selectivity against a universe of size N
149    ///
150    /// Returns |S| / N, the fraction of allowed documents
151    pub fn selectivity(&self, universe_size: usize) -> f64 {
152        if universe_size == 0 {
153            return 0.0;
154        }
155        match self {
156            Self::All => 1.0,
157            Self::None => 0.0,
158            other => {
159                other.cardinality()
160                    .map(|c| c as f64 / universe_size as f64)
161                    .unwrap_or(1.0)
162            }
163        }
164    }
165    
166    /// Intersect with another AllowedSet
167    pub fn intersect(&self, other: &AllowedSet) -> AllowedSet {
168        match (self, other) {
169            // Identity cases
170            (Self::All, x) | (x, Self::All) => x.clone(),
171            (Self::None, _) | (_, Self::None) => Self::None,
172            
173            // Both are sets - compute intersection
174            (Self::SortedVec(a), Self::SortedVec(b)) => {
175                let result = sorted_vec_intersect(a, b);
176                Self::from_sorted_vec(result)
177            }
178            (Self::HashSet(a), Self::HashSet(b)) => {
179                let result: HashSet<_> = a.intersection(b).copied().collect();
180                if result.is_empty() {
181                    Self::None
182                } else {
183                    Self::HashSet(Arc::new(result))
184                }
185            }
186            (Self::Bitmap(a), Self::Bitmap(b)) => {
187                let result = a.intersect(b);
188                Self::from_bitmap(result)
189            }
190            
191            // Mixed - convert to hash set
192            (a, b) => {
193                let set_a: HashSet<u64> = a.iter().collect();
194                let set_b: HashSet<u64> = b.iter().collect();
195                let result: HashSet<_> = set_a.intersection(&set_b).copied().collect();
196                if result.is_empty() {
197                    Self::None
198                } else {
199                    Self::HashSet(Arc::new(result))
200                }
201            }
202        }
203    }
204    
205    /// Union with another AllowedSet
206    pub fn union(&self, other: &AllowedSet) -> AllowedSet {
207        match (self, other) {
208            (Self::All, _) | (_, Self::All) => Self::All,
209            (Self::None, x) | (x, Self::None) => x.clone(),
210            
211            (Self::HashSet(a), Self::HashSet(b)) => {
212                let result: HashSet<_> = a.union(b).copied().collect();
213                Self::HashSet(Arc::new(result))
214            }
215            
216            // Mixed - convert to hash set
217            (a, b) => {
218                let mut result: HashSet<u64> = a.iter().collect();
219                result.extend(b.iter());
220                Self::HashSet(Arc::new(result))
221            }
222        }
223    }
224    
225    /// Iterate over allowed document IDs
226    ///
227    /// Note: For All, this returns an empty iterator (unknown universe)
228    pub fn iter(&self) -> AllowedSetIter<'_> {
229        match self {
230            Self::All => AllowedSetIter::Empty,
231            Self::Bitmap(bm) => AllowedSetIter::Bitmap(bm.iter()),
232            Self::SortedVec(vec) => AllowedSetIter::SortedVec(vec.iter()),
233            Self::HashSet(set) => AllowedSetIter::HashSet(set.iter()),
234            Self::None => AllowedSetIter::Empty,
235        }
236    }
237    
238    /// Convert to a Vec (for small sets)
239    pub fn to_vec(&self) -> Vec<u64> {
240        self.iter().collect()
241    }
242}
243
244impl fmt::Debug for AllowedSet {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        match self {
247            Self::All => write!(f, "AllowedSet::All"),
248            Self::None => write!(f, "AllowedSet::None"),
249            Self::Bitmap(bm) => write!(f, "AllowedSet::Bitmap(count={})", bm.count()),
250            Self::SortedVec(vec) => write!(f, "AllowedSet::SortedVec(len={})", vec.len()),
251            Self::HashSet(set) => write!(f, "AllowedSet::HashSet(len={})", set.len()),
252        }
253    }
254}
255
256impl Default for AllowedSet {
257    fn default() -> Self {
258        Self::All
259    }
260}
261
// Intersection of two ascending slices via a single two-cursor merge pass.
//
// Whenever the heads differ, the smaller one cannot appear on the other side
// (both inputs ascend), so only that cursor advances. Equal heads are emitted
// once and both cursors advance.
fn sorted_vec_intersect(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::cmp::Ordering;

    // The intersection can never be longer than the shorter input.
    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let (mut left, mut right) = (a, b);

    while let (Some((&x, left_rest)), Some((&y, right_rest))) =
        (left.split_first(), right.split_first())
    {
        match x.cmp(&y) {
            Ordering::Less => left = left_rest,
            Ordering::Greater => right = right_rest,
            Ordering::Equal => {
                common.push(x);
                left = left_rest;
                right = right_rest;
            }
        }
    }

    common
}
282
283// ============================================================================
284// AllowedSet Iterator
285// ============================================================================
286
287/// Iterator over allowed document IDs
288pub enum AllowedSetIter<'a> {
289    Empty,
290    Bitmap(BitmapIter<'a>),
291    SortedVec(std::slice::Iter<'a, u64>),
292    HashSet(std::collections::hash_set::Iter<'a, u64>),
293}
294
295impl<'a> Iterator for AllowedSetIter<'a> {
296    type Item = u64;
297    
298    fn next(&mut self) -> Option<Self::Item> {
299        match self {
300            Self::Empty => None,
301            Self::Bitmap(iter) => iter.next(),
302            Self::SortedVec(iter) => iter.next().copied(),
303            Self::HashSet(iter) => iter.next().copied(),
304        }
305    }
306    
307    fn size_hint(&self) -> (usize, Option<usize>) {
308        match self {
309            Self::Empty => (0, Some(0)),
310            Self::Bitmap(iter) => iter.size_hint(),
311            Self::SortedVec(iter) => iter.size_hint(),
312            Self::HashSet(iter) => iter.size_hint(),
313        }
314    }
315}
316
317// ============================================================================
318// Bitmap Implementation
319// ============================================================================
320
321/// Simple bitmap for allowed document IDs
322///
323/// This is a basic implementation. For production, consider using
324/// the `roaring` crate for compressed bitmaps.
325pub struct AllowedBitmap {
326    /// Bits stored as u64 words
327    words: Vec<u64>,
328    /// Total number of set bits (cached)
329    count: usize,
330    /// Whether this represents "all" (complement mode)
331    all: bool,
332}
333
334impl AllowedBitmap {
335    /// Create a new empty bitmap
336    pub fn new() -> Self {
337        Self {
338            words: Vec::new(),
339            count: 0,
340            all: false,
341        }
342    }
343    
344    /// Create a bitmap with all bits set up to max_id
345    pub fn all(max_id: u64) -> Self {
346        let word_count = (max_id as usize / 64) + 1;
347        Self {
348            words: vec![u64::MAX; word_count],
349            count: max_id as usize + 1,
350            all: true,
351        }
352    }
353    
354    /// Create a bitmap from a set of IDs
355    pub fn from_ids(ids: &[u64]) -> Self {
356        if ids.is_empty() {
357            return Self::new();
358        }
359        
360        let max_id = *ids.iter().max().unwrap();
361        let word_count = (max_id as usize / 64) + 1;
362        let mut words = vec![0u64; word_count];
363        
364        for &id in ids {
365            let word_idx = id as usize / 64;
366            let bit_idx = id % 64;
367            words[word_idx] |= 1 << bit_idx;
368        }
369        
370        Self {
371            words,
372            count: ids.len(),
373            all: false,
374        }
375    }
376    
377    /// Set a bit
378    pub fn set(&mut self, id: u64) {
379        let word_idx = id as usize / 64;
380        let bit_idx = id % 64;
381        
382        // Extend if necessary
383        if word_idx >= self.words.len() {
384            self.words.resize(word_idx + 1, 0);
385        }
386        
387        let old = self.words[word_idx];
388        self.words[word_idx] |= 1 << bit_idx;
389        if old != self.words[word_idx] {
390            self.count += 1;
391        }
392    }
393    
394    /// Clear a bit
395    pub fn clear(&mut self, id: u64) {
396        let word_idx = id as usize / 64;
397        if word_idx >= self.words.len() {
398            return;
399        }
400        
401        let bit_idx = id % 64;
402        let old = self.words[word_idx];
403        self.words[word_idx] &= !(1 << bit_idx);
404        if old != self.words[word_idx] {
405            self.count -= 1;
406        }
407    }
408    
409    /// Check if a bit is set
410    #[inline]
411    pub fn contains(&self, id: u64) -> bool {
412        let word_idx = id as usize / 64;
413        if word_idx >= self.words.len() {
414            return false;
415        }
416        let bit_idx = id % 64;
417        (self.words[word_idx] & (1 << bit_idx)) != 0
418    }
419    
420    /// Get the count of set bits
421    pub fn count(&self) -> usize {
422        self.count
423    }
424    
425    /// Check if empty
426    pub fn is_empty(&self) -> bool {
427        self.count == 0
428    }
429    
430    /// Check if all bits are set
431    pub fn is_all(&self) -> bool {
432        self.all
433    }
434    
435    /// Intersect with another bitmap
436    pub fn intersect(&self, other: &AllowedBitmap) -> AllowedBitmap {
437        let min_len = self.words.len().min(other.words.len());
438        let mut words = Vec::with_capacity(min_len);
439        let mut count = 0;
440        
441        for i in 0..min_len {
442            let word = self.words[i] & other.words[i];
443            count += word.count_ones() as usize;
444            words.push(word);
445        }
446        
447        AllowedBitmap {
448            words,
449            count,
450            all: false,
451        }
452    }
453    
454    /// Union with another bitmap
455    pub fn union(&self, other: &AllowedBitmap) -> AllowedBitmap {
456        let max_len = self.words.len().max(other.words.len());
457        let mut words = Vec::with_capacity(max_len);
458        let mut count = 0;
459        
460        for i in 0..max_len {
461            let a = self.words.get(i).copied().unwrap_or(0);
462            let b = other.words.get(i).copied().unwrap_or(0);
463            let word = a | b;
464            count += word.count_ones() as usize;
465            words.push(word);
466        }
467        
468        AllowedBitmap {
469            words,
470            count,
471            all: false,
472        }
473    }
474    
475    /// Iterate over set bit positions
476    pub fn iter(&self) -> BitmapIter<'_> {
477        BitmapIter {
478            words: &self.words,
479            word_idx: 0,
480            bit_offset: 0,
481            remaining: self.count,
482        }
483    }
484}
485
486impl Default for AllowedBitmap {
487    fn default() -> Self {
488        Self::new()
489    }
490}
491
/// Iterator over set bits in a bitmap, yielding their positions in
/// ascending order.
pub struct BitmapIter<'a> {
    /// Backing words; bit `b` of `words[w]` stands for ID `w * 64 + b`.
    words: &'a [u64],
    /// Index of the word currently being scanned.
    word_idx: usize,
    /// First bit of the current word not yet examined (always below 64).
    bit_offset: u64,
    /// Upper bound on how many IDs are still to be yielded.
    remaining: usize,
}

impl<'a> Iterator for BitmapIter<'a> {
    type Item = u64;

    /// Return the position of the next set bit, or `None` once `remaining`
    /// hits zero or the words run out.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }

        while let Some(&word) = self.words.get(self.word_idx) {
            // Drop the bits already visited in this word.
            let pending = word >> self.bit_offset;
            if pending == 0 {
                self.word_idx += 1;
                self.bit_offset = 0;
                continue;
            }

            let bit_pos = self.bit_offset + u64::from(pending.trailing_zeros());
            // Compute the ID before the cursor moves on to the next position.
            let id = self.word_idx as u64 * 64 + bit_pos;

            if bit_pos == 63 {
                // Last bit of the word: continue with the next word so that
                // `bit_offset` never reaches 64 (shifting by 64 would overflow).
                self.word_idx += 1;
                self.bit_offset = 0;
            } else {
                self.bit_offset = bit_pos + 1;
            }

            self.remaining -= 1;
            return Some(id);
        }

        // Words exhausted before the expected count was reached: report the
        // iterator as finished so `size_hint` stays truthful.
        self.remaining = 0;
        None
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl<'a> BitmapIter<'a> {
    /// Start iterating over `words`, yielding at most `count` IDs.
    // Kept as a convenience constructor; nothing outside tests calls it yet.
    #[allow(dead_code)]
    fn new(words: &'a [u64], count: usize) -> Self {
        Self {
            words,
            word_idx: 0,
            bit_offset: 0,
            remaining: count,
        }
    }
}
550
551// Simple correct iterator implementation
552impl AllowedBitmap {
553    /// Iterate over set bit positions (simple implementation)
554    pub fn iter_simple(&self) -> impl Iterator<Item = u64> + '_ {
555        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
556            (0..64).filter_map(move |bit| {
557                if (word & (1 << bit)) != 0 {
558                    Some(word_idx as u64 * 64 + bit as u64)
559                } else {
560                    None
561                }
562            })
563        })
564    }
565}
566
567// ============================================================================
568// Candidate Gate Trait
569// ============================================================================
570
571/// The candidate gate trait that all executors must implement
572///
573/// This trait ensures every retrieval path respects the AllowedSet.
574pub trait CandidateGate {
575    /// The query type
576    type Query;
577    
578    /// The result type  
579    type Result;
580    
581    /// The error type
582    type Error;
583    
584    /// Execute with a mandatory allowed set
585    ///
586    /// # Contract
587    ///
588    /// - MUST NOT return any result with doc_id not in allowed_set
589    /// - SHOULD short-circuit if allowed_set is empty
590    /// - SHOULD use selectivity to choose execution strategy
591    fn execute_with_gate(
592        &self,
593        query: &Self::Query,
594        allowed_set: &AllowedSet,
595    ) -> Result<Self::Result, Self::Error>;
596    
597    /// Get the execution strategy for a given selectivity
598    fn strategy_for_selectivity(&self, selectivity: f64) -> ExecutionStrategy {
599        if selectivity >= 0.1 {
600            ExecutionStrategy::FilterDuringSearch
601        } else if selectivity >= 0.001 {
602            ExecutionStrategy::ScanAllowedIds
603        } else {
604            ExecutionStrategy::LinearScan
605        }
606    }
607}
608
/// How an executor should run a gated query, chosen from the selectivity
/// of the allowed set.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum ExecutionStrategy {
    /// Run the standard search and check the filter during traversal.
    FilterDuringSearch,

    /// Walk the allowed IDs directly and compute distances for each.
    ScanAllowedIds,

    /// Linear-scan fallback for very low selectivity.
    LinearScan,

    /// Do not execute at all (too expensive).
    Reject,
}
624
625// ============================================================================
626// Tests
627// ============================================================================
628
#[cfg(test)]
mod tests {
    use super::*;

    /// Membership for the `All`, `None`, sorted-vector and hash-set forms.
    #[test]
    fn test_allowed_set_contains() {
        // `All` admits any ID, `None` admits nothing.
        for id in [0u64, 1_000_000] {
            assert!(AllowedSet::All.contains(id));
        }
        assert!(!AllowedSet::None.contains(0));

        let sorted = AllowedSet::from_sorted_vec(vec![1, 3, 5, 7, 9]);
        let hashed = AllowedSet::from_iter([1, 3, 5, 7, 9]);
        for present in [1u64, 5] {
            assert!(sorted.contains(present));
            assert!(hashed.contains(present));
        }
        assert!(!sorted.contains(2));
        assert!(!sorted.contains(10));
        assert!(!hashed.contains(2));
    }

    /// Selectivity is cardinality over universe size, with fixed values for
    /// `All` and `None`.
    #[test]
    fn test_allowed_set_selectivity() {
        let five = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        for (universe, expected) in [(100usize, 0.05), (10, 0.5)] {
            assert_eq!(five.selectivity(universe), expected);
        }

        assert_eq!(AllowedSet::All.selectivity(100), 1.0);
        assert_eq!(AllowedSet::None.selectivity(100), 0.0);
    }

    /// Intersecting two sorted sets keeps exactly the shared IDs.
    #[test]
    fn test_allowed_set_intersection() {
        let left = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        let right = AllowedSet::from_sorted_vec(vec![3, 4, 5, 6, 7]);

        let overlap = left.intersect(&right);
        assert_eq!(overlap.cardinality(), Some(3));
        for shared in [3u64, 4, 5] {
            assert!(overlap.contains(shared));
        }
        for exclusive in [1u64, 7] {
            assert!(!overlap.contains(exclusive));
        }
    }

    /// Setting bits across two words updates membership and the count.
    #[test]
    fn test_bitmap_basic() {
        let members = [0u64, 5, 64, 100];
        let mut bm = AllowedBitmap::new();
        for id in members {
            bm.set(id);
        }

        for id in members {
            assert!(bm.contains(id));
        }
        for id in [1u64, 63] {
            assert!(!bm.contains(id));
        }

        assert_eq!(bm.count(), 4);
    }

    /// A bitmap built from a slice contains exactly the listed IDs.
    #[test]
    fn test_bitmap_from_ids() {
        let ids = vec![1, 5, 10, 100, 1000];
        let bm = AllowedBitmap::from_ids(&ids);

        assert!(ids.iter().all(|&id| bm.contains(id)));
        assert!(!bm.contains(0));
        assert!(!bm.contains(50));
    }

    /// Bitmap intersection keeps the shared IDs and recounts them.
    #[test]
    fn test_bitmap_intersection() {
        let left = AllowedBitmap::from_ids(&[1, 2, 3, 4, 5]);
        let right = AllowedBitmap::from_ids(&[3, 4, 5, 6, 7]);

        let overlap = left.intersect(&right);
        assert_eq!(overlap.count(), 3);
        for shared in [3u64, 4, 5] {
            assert!(overlap.contains(shared));
        }
    }

    /// The default strategy picker maps each selectivity band to its strategy.
    #[test]
    fn test_execution_strategy() {
        struct DummyGate;
        impl CandidateGate for DummyGate {
            type Query = ();
            type Result = ();
            type Error = ();
            fn execute_with_gate(&self, _: &(), _: &AllowedSet) -> Result<(), ()> {
                Ok(())
            }
        }

        let gate = DummyGate;
        let expectations = [
            (0.5, ExecutionStrategy::FilterDuringSearch),
            (0.01, ExecutionStrategy::ScanAllowedIds),
            (0.0001, ExecutionStrategy::LinearScan),
        ];
        for (selectivity, expected) in expectations {
            assert_eq!(gate.strategy_for_selectivity(selectivity), expected);
        }
    }
}