Skip to main content

sochdb_vector/
filter.rs

1//! Filter support with bitsets and filter-aware widening.
2
3use crate::types::*;
4
/// Inclusion mask over vector IDs, packed one bit per vector.
#[derive(Debug, Clone)]
pub struct BitsetFilter {
    /// Packed membership words; the bit for ID `i` is bit `i % 64` of word `i / 64`.
    bits: Vec<u64>,
    /// How many vector IDs the filter spans; IDs at or above this are never included.
    n_vec: u32,
    /// Memoized number of set bits, or `None` when it has to be recounted.
    popcount: Option<u32>,
}
15
16impl BitsetFilter {
17    /// Create an empty filter (all vectors excluded)
18    pub fn new_empty(n_vec: u32) -> Self {
19        let num_words = (n_vec as usize + 63) / 64;
20        Self {
21            bits: vec![0; num_words],
22            n_vec,
23            popcount: Some(0),
24        }
25    }
26
27    /// Create a full filter (all vectors included)
28    pub fn new_full(n_vec: u32) -> Self {
29        let num_words = (n_vec as usize + 63) / 64;
30        let mut bits = vec![u64::MAX; num_words];
31
32        // Clear bits beyond n_vec
33        let remainder = n_vec as usize % 64;
34        if remainder > 0 && !bits.is_empty() {
35            bits[num_words - 1] = (1u64 << remainder) - 1;
36        }
37
38        Self {
39            bits,
40            n_vec,
41            popcount: Some(n_vec),
42        }
43    }
44
45    /// Create from a list of included vector IDs
46    pub fn from_ids(n_vec: u32, ids: &[VectorId]) -> Self {
47        let mut filter = Self::new_empty(n_vec);
48        for &id in ids {
49            filter.set(id);
50        }
51        filter.popcount = None; // Invalidate cache
52        filter
53    }
54
55    /// Set a bit (include vector)
56    #[inline]
57    pub fn set(&mut self, id: VectorId) {
58        if id < self.n_vec {
59            let word_idx = id as usize / 64;
60            let bit_idx = id as usize % 64;
61            self.bits[word_idx] |= 1u64 << bit_idx;
62            self.popcount = None;
63        }
64    }
65
66    /// Clear a bit (exclude vector)
67    #[inline]
68    pub fn clear(&mut self, id: VectorId) {
69        if id < self.n_vec {
70            let word_idx = id as usize / 64;
71            let bit_idx = id as usize % 64;
72            self.bits[word_idx] &= !(1u64 << bit_idx);
73            self.popcount = None;
74        }
75    }
76
77    /// Check if a vector is included
78    #[inline]
79    pub fn contains(&self, id: VectorId) -> bool {
80        if id >= self.n_vec {
81            return false;
82        }
83        let word_idx = id as usize / 64;
84        let bit_idx = id as usize % 64;
85        (self.bits[word_idx] & (1u64 << bit_idx)) != 0
86    }
87
88    /// Count of included vectors
89    pub fn count(&mut self) -> u32 {
90        if let Some(c) = self.popcount {
91            return c;
92        }
93        let c = self.bits.iter().map(|w| w.count_ones()).sum();
94        self.popcount = Some(c);
95        c
96    }
97
98    /// Get selectivity (fraction of vectors included)
99    pub fn selectivity(&mut self) -> f32 {
100        if self.n_vec == 0 {
101            return 0.0;
102        }
103        self.count() as f32 / self.n_vec as f32
104    }
105
106    /// AND with another filter
107    pub fn and(&self, other: &BitsetFilter) -> BitsetFilter {
108        assert_eq!(self.n_vec, other.n_vec);
109        let bits: Vec<u64> = self
110            .bits
111            .iter()
112            .zip(other.bits.iter())
113            .map(|(&a, &b)| a & b)
114            .collect();
115        BitsetFilter {
116            bits,
117            n_vec: self.n_vec,
118            popcount: None,
119        }
120    }
121
122    /// OR with another filter
123    pub fn or(&self, other: &BitsetFilter) -> BitsetFilter {
124        assert_eq!(self.n_vec, other.n_vec);
125        let bits: Vec<u64> = self
126            .bits
127            .iter()
128            .zip(other.bits.iter())
129            .map(|(&a, &b)| a | b)
130            .collect();
131        BitsetFilter {
132            bits,
133            n_vec: self.n_vec,
134            popcount: None,
135        }
136    }
137
138    /// NOT (invert filter)
139    pub fn not(&self) -> BitsetFilter {
140        let num_words = self.bits.len();
141        let mut bits: Vec<u64> = self.bits.iter().map(|&w| !w).collect();
142
143        // Clear bits beyond n_vec
144        let remainder = self.n_vec as usize % 64;
145        if remainder > 0 && !bits.is_empty() {
146            bits[num_words - 1] &= (1u64 << remainder) - 1;
147        }
148
149        BitsetFilter {
150            bits,
151            n_vec: self.n_vec,
152            popcount: None,
153        }
154    }
155
156    /// Apply filter to candidates, returning only included ones
157    pub fn filter_candidates(&self, candidates: &[ScoredCandidate]) -> Vec<ScoredCandidate> {
158        candidates
159            .iter()
160            .filter(|c| self.contains(c.id))
161            .copied()
162            .collect()
163    }
164
165    /// Get raw bits
166    pub fn bits(&self) -> &[u64] {
167        &self.bits
168    }
169
170    /// Get number of vectors
171    pub fn n_vec(&self) -> u32 {
172        self.n_vec
173    }
174}
175
/// Returns how much to over-fetch so that enough candidates survive a filter.
///
/// For a positive `selectivity` the result is `1 / selectivity`, capped at
/// `max_factor`. A zero or negative `selectivity` yields `max_factor` itself.
pub fn compute_widening_factor(selectivity: f32, max_factor: f32) -> f32 {
    if selectivity > 0.0 {
        let uncapped = 1.0 / selectivity;
        uncapped.min(max_factor)
    } else {
        max_factor
    }
}
184
185/// Apply filter-aware widening to candidate count
186pub fn widen_for_filter(base_count: usize, selectivity: f32, max_factor: f32) -> usize {
187    let factor = compute_widening_factor(selectivity, max_factor);
188    ((base_count as f32) * factor).ceil() as usize
189}
190
191/// Combine a tombstone bitset with a filter
192pub fn apply_tombstones(filter: &BitsetFilter, tombstones: &[u64]) -> BitsetFilter {
193    let mut result = filter.clone();
194    for (i, &tombstone_word) in tombstones.iter().enumerate() {
195        if i < result.bits.len() {
196            // Clear bits that are tombstoned
197            result.bits[i] &= !tombstone_word;
198        }
199    }
200    result.popcount = None;
201    result
202}
203
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a filter over 100 vectors with exactly the given IDs included.
    fn hundred_with(ids: &[VectorId]) -> BitsetFilter {
        let mut filter = BitsetFilter::new_empty(100);
        for &id in ids {
            filter.set(id);
        }
        filter
    }

    /// Setting, clearing and counting individual bits.
    #[test]
    fn test_bitset_basic() {
        let mut mask = BitsetFilter::new_empty(100);
        assert!(!mask.contains(0));

        mask.set(0);
        assert!(mask.contains(0));

        mask.set(50);
        mask.set(99);
        assert_eq!(mask.count(), 3);

        mask.clear(50);
        assert!(!mask.contains(50));
        assert_eq!(mask.count(), 2);
    }

    /// A full filter includes every in-range ID and nothing beyond the range.
    #[test]
    fn test_bitset_full() {
        let mut everything = BitsetFilter::new_full(100);
        assert!(everything.contains(0));
        assert!(everything.contains(99));
        assert!(!everything.contains(100));
        assert_eq!(everything.count(), 100);
    }

    /// AND keeps only the IDs present in both operands.
    #[test]
    fn test_bitset_and() {
        let left = hundred_with(&[0, 1, 2]);
        let right = hundred_with(&[1, 2, 3]);

        let mut both = left.and(&right);
        assert!(!both.contains(0));
        assert!(both.contains(1));
        assert!(both.contains(2));
        assert!(!both.contains(3));
        assert_eq!(both.count(), 2);
    }

    /// OR keeps the IDs present in either operand.
    #[test]
    fn test_bitset_or() {
        let left = hundred_with(&[0, 1]);
        let right = hundred_with(&[1, 2]);

        let mut either = left.or(&right);
        assert!(either.contains(0));
        assert!(either.contains(1));
        assert!(either.contains(2));
        assert!(!either.contains(3));
        assert_eq!(either.count(), 3);
    }

    /// 100 of 1000 vectors included gives a selectivity of about 0.1.
    #[test]
    fn test_selectivity() {
        let mut tenth = BitsetFilter::new_empty(1000);
        (0..100).for_each(|id| tenth.set(id));

        let fraction = tenth.selectivity();
        assert!((fraction - 0.1).abs() < 0.001);
    }

    /// The widening factor is `1 / selectivity`, capped at the given maximum.
    #[test]
    fn test_widening_factor() {
        // (selectivity, expected factor) with the cap fixed at 20:
        // 10% -> 10x, 1% -> capped at 20x, 100% -> no widening.
        let cases = [(0.1_f32, 10.0_f32), (0.01, 20.0), (1.0, 1.0)];
        for &(selectivity, expected) in cases.iter() {
            let factor = compute_widening_factor(selectivity, 20.0);
            assert!((factor - expected).abs() < 0.001);
        }
    }

    /// Only candidates whose IDs are included survive, in their original order.
    #[test]
    fn test_filter_candidates() {
        let mask = hundred_with(&[1, 3, 5]);

        let pool = vec![
            ScoredCandidate { id: 0, score: 1.0 },
            ScoredCandidate { id: 1, score: 2.0 },
            ScoredCandidate { id: 2, score: 3.0 },
            ScoredCandidate { id: 3, score: 4.0 },
        ];

        let survivors = mask.filter_candidates(&pool);
        assert_eq!(survivors.len(), 2);
        assert_eq!(survivors[0].id, 1);
        assert_eq!(survivors[1].id, 3);
    }
}