Skip to main content

sochdb_query/
candidate_gate.rs

1// SPDX-License-Identifier: AGPL-3.0-or-later
2// SochDB - LLM-Optimized Embedded Database
3// Copyright (C) 2026 Sushanth Reddy Vanagala (https://github.com/sushanthpy)
4//
5// This program is free software: you can redistribute it and/or modify
6// it under the terms of the GNU Affero General Public License as published by
7// the Free Software Foundation, either version 3 of the License, or
8// (at your option) any later version.
9//
10// This program is distributed in the hope that it will be useful,
11// but WITHOUT ANY WARRANTY; without even the implied warranty of
12// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
13// GNU Affero General Public License for more details.
14//
15// You should have received a copy of the GNU Affero General Public License
16// along with this program. If not, see <https://www.gnu.org/licenses/>.
17
18//! Unified Candidate Gate Interface (Task 4)
19//!
20//! This module defines the `AllowedSet` abstraction that every retrieval
21//! executor MUST accept. The gate guarantees:
22//!
23//! 1. **Never return a doc outside AllowedSet** - structural enforcement
24//! 2. **Apply constraints during generation** - no post-filtering
25//! 3. **Consistent semantics** across vector/BM25/hybrid/context
26//!
27//! ## The Contract
28//!
29//! Every executor receives an `AllowedSet` and must:
30//! - Check membership BEFORE including any candidate
31//! - Short-circuit if AllowedSet is empty (return empty results)
32//! - Report selectivity for query planning
33//!
34//! ## Representations
35//!
36//! `AllowedSet` supports multiple representations for efficiency:
37//!
38//! | Representation | Best For | Membership | Space |
39//! |----------------|----------|------------|-------|
40//! | Bitmap | Dense sets | O(1) | O(N/8) |
41//! | SortedVec | Sparse sets | O(log n) | O(n) |
42//! | HashSet | Random access | O(1) avg | O(n) |
43//! | All | No constraint | O(1) | O(1) |
44//!
45//! ## Selectivity
46//!
47//! Executors use selectivity `|S|/N` to choose execution strategy:
//! - High selectivity (>= 0.1): Standard search with filter
//! - Low selectivity (>= 0.001 and < 0.1): Scan only allowed IDs
//! - Very low (< 0.001): Consider alternative strategy
51
52use std::collections::HashSet;
53use std::fmt;
54use std::sync::Arc;
55
56// ============================================================================
57// AllowedSet - Core Abstraction
58// ============================================================================
59
60/// The unified gate for candidate filtering
61///
62/// Every executor MUST check `allowed_set.contains(doc_id)` before
63/// including any result. This is the structural enforcement of pushdown.
64#[derive(Clone)]
65pub enum AllowedSet {
66    /// All documents are allowed (no filter constraint)
67    All,
68
69    /// Bitmap representation (efficient for dense sets)
70    Bitmap(Arc<AllowedBitmap>),
71
72    /// Sorted vector (efficient for sparse sets with iteration)
73    SortedVec(Arc<Vec<u64>>),
74
75    /// Hash set (efficient for random access)
76    HashSet(Arc<HashSet<u64>>),
77
78    /// No documents allowed (empty result shortcut)
79    None,
80}
81
82impl AllowedSet {
83    /// Create an AllowedSet from a bitmap
84    pub fn from_bitmap(bitmap: AllowedBitmap) -> Self {
85        if bitmap.is_empty() {
86            Self::None
87        } else if bitmap.is_all() {
88            Self::All
89        } else {
90            Self::Bitmap(Arc::new(bitmap))
91        }
92    }
93
94    /// Create an AllowedSet from a sorted vector of doc IDs
95    pub fn from_sorted_vec(mut ids: Vec<u64>) -> Self {
96        if ids.is_empty() {
97            return Self::None;
98        }
99        ids.sort_unstable();
100        ids.dedup();
101        Self::SortedVec(Arc::new(ids))
102    }
103
104    /// Create an AllowedSet from an iterator of doc IDs
105    pub fn from_iter(ids: impl IntoIterator<Item = u64>) -> Self {
106        let set: HashSet<u64> = ids.into_iter().collect();
107        if set.is_empty() {
108            Self::None
109        } else {
110            Self::HashSet(Arc::new(set))
111        }
112    }
113
114    /// Check if a document ID is allowed
115    ///
116    /// This is the core operation that executors MUST call.
117    #[inline]
118    pub fn contains(&self, doc_id: u64) -> bool {
119        match self {
120            Self::All => true,
121            Self::Bitmap(bm) => bm.contains(doc_id),
122            Self::SortedVec(vec) => vec.binary_search(&doc_id).is_ok(),
123            Self::HashSet(set) => set.contains(&doc_id),
124            Self::None => false,
125        }
126    }
127
128    /// Check if this set is empty (no allowed documents)
129    pub fn is_empty(&self) -> bool {
130        matches!(self, Self::None)
131    }
132
133    /// Check if this set allows all documents
134    pub fn is_all(&self) -> bool {
135        matches!(self, Self::All)
136    }
137
138    /// Get the cardinality (number of allowed documents)
139    ///
140    /// Returns None for All (unknown without universe size)
141    pub fn cardinality(&self) -> Option<usize> {
142        match self {
143            Self::All => None,
144            Self::Bitmap(bm) => Some(bm.count()),
145            Self::SortedVec(vec) => Some(vec.len()),
146            Self::HashSet(set) => Some(set.len()),
147            Self::None => Some(0),
148        }
149    }
150
151    /// Compute selectivity against a universe of size N
152    ///
153    /// Returns |S| / N, the fraction of allowed documents
154    pub fn selectivity(&self, universe_size: usize) -> f64 {
155        if universe_size == 0 {
156            return 0.0;
157        }
158        match self {
159            Self::All => 1.0,
160            Self::None => 0.0,
161            other => other
162                .cardinality()
163                .map(|c| c as f64 / universe_size as f64)
164                .unwrap_or(1.0),
165        }
166    }
167
168    /// Intersect with another AllowedSet
169    pub fn intersect(&self, other: &AllowedSet) -> AllowedSet {
170        match (self, other) {
171            // Identity cases
172            (Self::All, x) | (x, Self::All) => x.clone(),
173            (Self::None, _) | (_, Self::None) => Self::None,
174
175            // Both are sets - compute intersection
176            (Self::SortedVec(a), Self::SortedVec(b)) => {
177                let result = sorted_vec_intersect(a, b);
178                Self::from_sorted_vec(result)
179            }
180            (Self::HashSet(a), Self::HashSet(b)) => {
181                let result: HashSet<_> = a.intersection(b).copied().collect();
182                if result.is_empty() {
183                    Self::None
184                } else {
185                    Self::HashSet(Arc::new(result))
186                }
187            }
188            (Self::Bitmap(a), Self::Bitmap(b)) => {
189                let result = a.intersect(b);
190                Self::from_bitmap(result)
191            }
192
193            // Mixed - convert to hash set
194            (a, b) => {
195                let set_a: HashSet<u64> = a.iter().collect();
196                let set_b: HashSet<u64> = b.iter().collect();
197                let result: HashSet<_> = set_a.intersection(&set_b).copied().collect();
198                if result.is_empty() {
199                    Self::None
200                } else {
201                    Self::HashSet(Arc::new(result))
202                }
203            }
204        }
205    }
206
207    /// Union with another AllowedSet
208    pub fn union(&self, other: &AllowedSet) -> AllowedSet {
209        match (self, other) {
210            (Self::All, _) | (_, Self::All) => Self::All,
211            (Self::None, x) | (x, Self::None) => x.clone(),
212
213            (Self::HashSet(a), Self::HashSet(b)) => {
214                let result: HashSet<_> = a.union(b).copied().collect();
215                Self::HashSet(Arc::new(result))
216            }
217
218            // Mixed - convert to hash set
219            (a, b) => {
220                let mut result: HashSet<u64> = a.iter().collect();
221                result.extend(b.iter());
222                Self::HashSet(Arc::new(result))
223            }
224        }
225    }
226
227    /// Iterate over allowed document IDs
228    ///
229    /// Note: For All, this returns an empty iterator (unknown universe)
230    pub fn iter(&self) -> AllowedSetIter<'_> {
231        match self {
232            Self::All => AllowedSetIter::Empty,
233            Self::Bitmap(bm) => AllowedSetIter::Bitmap(bm.iter()),
234            Self::SortedVec(vec) => AllowedSetIter::SortedVec(vec.iter()),
235            Self::HashSet(set) => AllowedSetIter::HashSet(set.iter()),
236            Self::None => AllowedSetIter::Empty,
237        }
238    }
239
240    /// Convert to a Vec (for small sets)
241    pub fn to_vec(&self) -> Vec<u64> {
242        self.iter().collect()
243    }
244}
245
246impl fmt::Debug for AllowedSet {
247    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
248        match self {
249            Self::All => write!(f, "AllowedSet::All"),
250            Self::None => write!(f, "AllowedSet::None"),
251            Self::Bitmap(bm) => write!(f, "AllowedSet::Bitmap(count={})", bm.count()),
252            Self::SortedVec(vec) => write!(f, "AllowedSet::SortedVec(len={})", vec.len()),
253            Self::HashSet(set) => write!(f, "AllowedSet::HashSet(len={})", set.len()),
254        }
255    }
256}
257
258impl Default for AllowedSet {
259    fn default() -> Self {
260        Self::All
261    }
262}
263
// Merge-style intersection of two ascending slices.
//
// Both inputs are walked front to back in a single pass; whichever side has
// the smaller head is advanced, and equal heads are emitted once and advance
// both sides.
fn sorted_vec_intersect(a: &[u64], b: &[u64]) -> Vec<u64> {
    use std::cmp::Ordering;

    let mut common = Vec::with_capacity(a.len().min(b.len()));
    let (mut left, mut right) = (a, b);

    while let (Some((&x, left_tail)), Some((&y, right_tail))) =
        (left.split_first(), right.split_first())
    {
        match x.cmp(&y) {
            Ordering::Less => left = left_tail,
            Ordering::Greater => right = right_tail,
            Ordering::Equal => {
                common.push(x);
                left = left_tail;
                right = right_tail;
            }
        }
    }

    common
}
284
285// ============================================================================
286// AllowedSet Iterator
287// ============================================================================
288
289/// Iterator over allowed document IDs
290pub enum AllowedSetIter<'a> {
291    Empty,
292    Bitmap(BitmapIter<'a>),
293    SortedVec(std::slice::Iter<'a, u64>),
294    HashSet(std::collections::hash_set::Iter<'a, u64>),
295}
296
297impl<'a> Iterator for AllowedSetIter<'a> {
298    type Item = u64;
299
300    fn next(&mut self) -> Option<Self::Item> {
301        match self {
302            Self::Empty => None,
303            Self::Bitmap(iter) => iter.next(),
304            Self::SortedVec(iter) => iter.next().copied(),
305            Self::HashSet(iter) => iter.next().copied(),
306        }
307    }
308
309    fn size_hint(&self) -> (usize, Option<usize>) {
310        match self {
311            Self::Empty => (0, Some(0)),
312            Self::Bitmap(iter) => iter.size_hint(),
313            Self::SortedVec(iter) => iter.size_hint(),
314            Self::HashSet(iter) => iter.size_hint(),
315        }
316    }
317}
318
319// ============================================================================
320// Bitmap Implementation
321// ============================================================================
322
/// Simple bitmap for allowed document IDs
///
/// Bit `b` of word `w` stands for document ID `w * 64 + b`.
///
/// This is a basic implementation. For production, consider using
/// the `roaring` crate for compressed bitmaps.
#[derive(Debug, Clone)]
pub struct AllowedBitmap {
    /// Bits stored as u64 words
    words: Vec<u64>,
    /// Total number of set bits (cached)
    count: usize,
    /// Whether this bitmap was built to represent "all" documents
    /// (reported through `is_all`)
    all: bool,
}
335
336impl AllowedBitmap {
337    /// Create a new empty bitmap
338    pub fn new() -> Self {
339        Self {
340            words: Vec::new(),
341            count: 0,
342            all: false,
343        }
344    }
345
346    /// Create a bitmap with all bits set up to max_id
347    pub fn all(max_id: u64) -> Self {
348        let word_count = (max_id as usize / 64) + 1;
349        Self {
350            words: vec![u64::MAX; word_count],
351            count: max_id as usize + 1,
352            all: true,
353        }
354    }
355
356    /// Create a bitmap from a set of IDs
357    pub fn from_ids(ids: &[u64]) -> Self {
358        if ids.is_empty() {
359            return Self::new();
360        }
361
362        let max_id = *ids.iter().max().unwrap();
363        let word_count = (max_id as usize / 64) + 1;
364        let mut words = vec![0u64; word_count];
365
366        for &id in ids {
367            let word_idx = id as usize / 64;
368            let bit_idx = id % 64;
369            words[word_idx] |= 1 << bit_idx;
370        }
371
372        Self {
373            words,
374            count: ids.len(),
375            all: false,
376        }
377    }
378
379    /// Set a bit
380    pub fn set(&mut self, id: u64) {
381        let word_idx = id as usize / 64;
382        let bit_idx = id % 64;
383
384        // Extend if necessary
385        if word_idx >= self.words.len() {
386            self.words.resize(word_idx + 1, 0);
387        }
388
389        let old = self.words[word_idx];
390        self.words[word_idx] |= 1 << bit_idx;
391        if old != self.words[word_idx] {
392            self.count += 1;
393        }
394    }
395
396    /// Clear a bit
397    pub fn clear(&mut self, id: u64) {
398        let word_idx = id as usize / 64;
399        if word_idx >= self.words.len() {
400            return;
401        }
402
403        let bit_idx = id % 64;
404        let old = self.words[word_idx];
405        self.words[word_idx] &= !(1 << bit_idx);
406        if old != self.words[word_idx] {
407            self.count -= 1;
408        }
409    }
410
411    /// Check if a bit is set
412    #[inline]
413    pub fn contains(&self, id: u64) -> bool {
414        let word_idx = id as usize / 64;
415        if word_idx >= self.words.len() {
416            return false;
417        }
418        let bit_idx = id % 64;
419        (self.words[word_idx] & (1 << bit_idx)) != 0
420    }
421
422    /// Get the count of set bits
423    pub fn count(&self) -> usize {
424        self.count
425    }
426
427    /// Check if empty
428    pub fn is_empty(&self) -> bool {
429        self.count == 0
430    }
431
432    /// Check if all bits are set
433    pub fn is_all(&self) -> bool {
434        self.all
435    }
436
437    /// Intersect with another bitmap
438    pub fn intersect(&self, other: &AllowedBitmap) -> AllowedBitmap {
439        let min_len = self.words.len().min(other.words.len());
440        let mut words = Vec::with_capacity(min_len);
441        let mut count = 0;
442
443        for i in 0..min_len {
444            let word = self.words[i] & other.words[i];
445            count += word.count_ones() as usize;
446            words.push(word);
447        }
448
449        AllowedBitmap {
450            words,
451            count,
452            all: false,
453        }
454    }
455
456    /// Union with another bitmap
457    pub fn union(&self, other: &AllowedBitmap) -> AllowedBitmap {
458        let max_len = self.words.len().max(other.words.len());
459        let mut words = Vec::with_capacity(max_len);
460        let mut count = 0;
461
462        for i in 0..max_len {
463            let a = self.words.get(i).copied().unwrap_or(0);
464            let b = other.words.get(i).copied().unwrap_or(0);
465            let word = a | b;
466            count += word.count_ones() as usize;
467            words.push(word);
468        }
469
470        AllowedBitmap {
471            words,
472            count,
473            all: false,
474        }
475    }
476
477    /// Iterate over set bit positions
478    pub fn iter(&self) -> BitmapIter<'_> {
479        BitmapIter {
480            words: &self.words,
481            word_idx: 0,
482            bit_offset: 0,
483            remaining: self.count,
484        }
485    }
486}
487
488impl Default for AllowedBitmap {
489    fn default() -> Self {
490        Self::new()
491    }
492}
493
/// Iterator over set bits in a bitmap
///
/// Yields the position of every set bit in ascending order, where bit `b`
/// of word `w` is position `w * 64 + b`. At most `remaining` positions are
/// produced.
pub struct BitmapIter<'a> {
    /// Words being scanned.
    words: &'a [u64],
    /// Index of the word currently being scanned.
    word_idx: usize,
    /// First bit of the current word that has not been examined yet
    /// (always below 64).
    bit_offset: u64,
    /// Upper bound on the number of positions still to be yielded.
    remaining: usize,
}

impl<'a> Iterator for BitmapIter<'a> {
    type Item = u64;

    /// Return the position of the next set bit, or `None` once `remaining`
    /// positions were yielded or the words are exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        if self.remaining == 0 {
            return None;
        }

        while let Some(&word) = self.words.get(self.word_idx) {
            // Discard the bits of this word that were already visited.
            let pending = word >> self.bit_offset;

            if pending != 0 {
                let bit_pos = self.bit_offset + u64::from(pending.trailing_zeros());
                // Compute the ID before the cursor moves on.
                let id = self.word_idx as u64 * 64 + bit_pos;

                if bit_pos == 63 {
                    // Last bit of the word: continue with the next word.
                    self.word_idx += 1;
                    self.bit_offset = 0;
                } else {
                    self.bit_offset = bit_pos + 1;
                }

                self.remaining -= 1;
                return Some(id);
            }

            self.word_idx += 1;
            self.bit_offset = 0;
        }

        // The words ran out before `remaining` reached zero (the supplied
        // count was larger than the number of set bits); keep `size_hint`
        // consistent with the iterator being finished.
        self.remaining = 0;
        None
    }

    /// Reports `remaining` as both the lower and the upper bound.
    fn size_hint(&self) -> (usize, Option<usize>) {
        (self.remaining, Some(self.remaining))
    }
}

impl<'a> BitmapIter<'a> {
    /// Start iterating `words` from bit 0, yielding at most `count` positions.
    // `AllowedBitmap::iter` builds the struct with a literal, so this
    // constructor may be unused in non-test builds.
    #[allow(dead_code)]
    fn new(words: &'a [u64], count: usize) -> Self {
        Self {
            words,
            word_idx: 0,
            bit_offset: 0,
            remaining: count,
        }
    }
}
556
557// Simple correct iterator implementation
558impl AllowedBitmap {
559    /// Iterate over set bit positions (simple implementation)
560    pub fn iter_simple(&self) -> impl Iterator<Item = u64> + '_ {
561        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
562            (0..64).filter_map(move |bit| {
563                if (word & (1 << bit)) != 0 {
564                    Some(word_idx as u64 * 64 + bit as u64)
565                } else {
566                    None
567                }
568            })
569        })
570    }
571}
572
573// ============================================================================
574// Candidate Gate Trait
575// ============================================================================
576
577/// The candidate gate trait that all executors must implement
578///
579/// This trait ensures every retrieval path respects the AllowedSet.
580pub trait CandidateGate {
581    /// The query type
582    type Query;
583
584    /// The result type  
585    type Result;
586
587    /// The error type
588    type Error;
589
590    /// Execute with a mandatory allowed set
591    ///
592    /// # Contract
593    ///
594    /// - MUST NOT return any result with doc_id not in allowed_set
595    /// - SHOULD short-circuit if allowed_set is empty
596    /// - SHOULD use selectivity to choose execution strategy
597    fn execute_with_gate(
598        &self,
599        query: &Self::Query,
600        allowed_set: &AllowedSet,
601    ) -> Result<Self::Result, Self::Error>;
602
603    /// Get the execution strategy for a given selectivity
604    fn strategy_for_selectivity(&self, selectivity: f64) -> ExecutionStrategy {
605        if selectivity >= 0.1 {
606            ExecutionStrategy::FilterDuringSearch
607        } else if selectivity >= 0.001 {
608            ExecutionStrategy::ScanAllowedIds
609        } else {
610            ExecutionStrategy::LinearScan
611        }
612    }
613}
614
/// How an executor should run a query, chosen from the filter's selectivity.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ExecutionStrategy {
    /// Run the standard search and check the filter while traversing.
    FilterDuringSearch,

    /// Walk the allowed IDs and compute distances for each of them.
    ScanAllowedIds,

    /// Fall back to a linear scan (very low selectivity).
    LinearScan,

    /// Refuse to execute because the query would be too expensive.
    Reject,
}
630
631// ============================================================================
632// Tests
633// ============================================================================
634
#[cfg(test)]
mod tests {
    use super::*;

    /// `contains` for the `All`, `None`, sorted-vector and hash-set forms.
    #[test]
    fn test_allowed_set_contains() {
        // All
        let unrestricted = AllowedSet::All;
        for id in [0u64, 1000000] {
            assert!(unrestricted.contains(id));
        }

        // None
        assert!(!AllowedSet::None.contains(0));

        // SortedVec
        let sorted = AllowedSet::from_sorted_vec(vec![1, 3, 5, 7, 9]);
        for id in [1u64, 5] {
            assert!(sorted.contains(id));
        }
        for id in [2u64, 10] {
            assert!(!sorted.contains(id));
        }

        // HashSet
        let hashed = AllowedSet::from_iter([1, 3, 5, 7, 9]);
        for id in [1u64, 5] {
            assert!(hashed.contains(id));
        }
        assert!(!hashed.contains(2));
    }

    /// Selectivity is cardinality divided by the universe size.
    #[test]
    fn test_allowed_set_selectivity() {
        let five_ids = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);

        for (universe, expected) in [(100usize, 0.05f64), (10, 0.5)] {
            assert_eq!(five_ids.selectivity(universe), expected);
        }

        assert_eq!(AllowedSet::All.selectivity(100), 1.0);
        assert_eq!(AllowedSet::None.selectivity(100), 0.0);
    }

    /// Intersecting two sorted vectors keeps only the shared IDs.
    #[test]
    fn test_allowed_set_intersection() {
        let left = AllowedSet::from_sorted_vec(vec![1, 2, 3, 4, 5]);
        let right = AllowedSet::from_sorted_vec(vec![3, 4, 5, 6, 7]);

        let shared = left.intersect(&right);
        assert_eq!(shared.cardinality(), Some(3));
        for id in [3u64, 4, 5] {
            assert!(shared.contains(id));
        }
        for id in [1u64, 7] {
            assert!(!shared.contains(id));
        }
    }

    /// Bits set one at a time are reported by `contains` and `count`.
    #[test]
    fn test_bitmap_basic() {
        let members = [0u64, 5, 64, 100];

        let mut bitmap = AllowedBitmap::new();
        for id in members {
            bitmap.set(id);
        }

        for id in members {
            assert!(bitmap.contains(id));
        }
        for id in [1u64, 63] {
            assert!(!bitmap.contains(id));
        }

        assert_eq!(bitmap.count(), 4);
    }

    /// A bitmap built from a slice contains exactly those IDs.
    #[test]
    fn test_bitmap_from_ids() {
        let members = vec![1, 5, 10, 100, 1000];
        let bitmap = AllowedBitmap::from_ids(&members);

        assert!(members.iter().all(|&id| bitmap.contains(id)));
        for id in [0u64, 50] {
            assert!(!bitmap.contains(id));
        }
    }

    /// Bitmap intersection keeps only the shared bits.
    #[test]
    fn test_bitmap_intersection() {
        let left = AllowedBitmap::from_ids(&[1, 2, 3, 4, 5]);
        let right = AllowedBitmap::from_ids(&[3, 4, 5, 6, 7]);

        let shared = left.intersect(&right);
        assert_eq!(shared.count(), 3);
        for id in [3u64, 4, 5] {
            assert!(shared.contains(id));
        }
    }

    /// The default strategy mapping for three representative selectivities.
    #[test]
    fn test_execution_strategy() {
        struct DummyGate;
        impl CandidateGate for DummyGate {
            type Query = ();
            type Result = ();
            type Error = ();
            fn execute_with_gate(&self, _: &(), _: &AllowedSet) -> Result<(), ()> {
                Ok(())
            }
        }

        let gate = DummyGate;
        let cases = [
            (0.5, ExecutionStrategy::FilterDuringSearch),
            (0.01, ExecutionStrategy::ScanAllowedIds),
            (0.0001, ExecutionStrategy::LinearScan),
        ];
        for (selectivity, expected) in cases {
            assert_eq!(gate.strategy_for_selectivity(selectivity), expected);
        }
    }
}