Skip to main content

virtual_frame/
bitmask.rs

//! Packed bitmask — one bit per row, 64-bit words.
//!
//! The bitmask is the foundation of virtual-frame's zero-copy filtering.
//! A million-row filter costs about 122 KiB (125,000 bytes) of bitmask
//! memory. The original data stays untouched. Chained filters AND their
//! bitmasks together — still no data copied, and the result is still 122 KiB.
7
8/// Packed bitmask over N rows. One bit per row, stored in 64-bit words.
9///
10/// Iteration uses `trailing_zeros()` for O(popcount) scanning —
11/// entirely-zero words are skipped in O(1).
12#[derive(Debug, Clone)]
13pub struct BitMask {
14    pub(crate) words: Vec<u64>,
15    pub(crate) nrows: usize,
16}
17
18impl BitMask {
19    /// All rows set (all-true mask).
20    pub fn all_true(nrows: usize) -> Self {
21        let nwords = nwords_for(nrows);
22        let mut words = vec![u64::MAX; nwords];
23        // Zero tail bits beyond nrows for determinism
24        if nrows % 64 != 0 && nwords > 0 {
25            let tail = nrows % 64;
26            words[nwords - 1] = (1u64 << tail) - 1;
27        }
28        BitMask { words, nrows }
29    }
30
31    /// No rows set (all-false mask).
32    pub fn all_false(nrows: usize) -> Self {
33        let nwords = nwords_for(nrows);
34        BitMask {
35            words: vec![0u64; nwords],
36            nrows,
37        }
38    }
39
40    /// Construct from a bool slice, one entry per row.
41    pub fn from_bools(bools: &[bool]) -> Self {
42        let nrows = bools.len();
43        let nwords = nwords_for(nrows);
44        let mut words = vec![0u64; nwords];
45        for (i, &b) in bools.iter().enumerate() {
46            if b {
47                words[i / 64] |= 1u64 << (i % 64);
48            }
49        }
50        BitMask { words, nrows }
51    }
52
53    /// Get bit at row `i`.
54    #[inline]
55    pub fn get(&self, i: usize) -> bool {
56        debug_assert!(i < self.nrows);
57        (self.words[i / 64] >> (i % 64)) & 1 == 1
58    }
59
60    /// Set bit at row `i`.
61    #[inline]
62    pub fn set(&mut self, i: usize) {
63        debug_assert!(i < self.nrows);
64        self.words[i / 64] |= 1u64 << (i % 64);
65    }
66
67    /// Clear bit at row `i`.
68    #[inline]
69    pub fn clear(&mut self, i: usize) {
70        debug_assert!(i < self.nrows);
71        self.words[i / 64] &= !(1u64 << (i % 64));
72    }
73
74    /// Number of set bits (visible rows).
75    pub fn count_ones(&self) -> usize {
76        self.words.iter().map(|w| w.count_ones() as usize).sum()
77    }
78
79    /// AND two masks together (chained filter semantics).
80    ///
81    /// Panics if nrows differs — this is a programming error (same base data).
82    pub fn and(&self, other: &BitMask) -> BitMask {
83        assert_eq!(
84            self.nrows, other.nrows,
85            "BitMask::and: nrows mismatch ({} vs {})",
86            self.nrows, other.nrows
87        );
88        let words = self
89            .words
90            .iter()
91            .zip(other.words.iter())
92            .map(|(a, b)| a & b)
93            .collect();
94        BitMask {
95            words,
96            nrows: self.nrows,
97        }
98    }
99
100    /// OR two masks together.
101    pub fn or(&self, other: &BitMask) -> BitMask {
102        assert_eq!(self.nrows, other.nrows);
103        let words = self
104            .words
105            .iter()
106            .zip(other.words.iter())
107            .map(|(a, b)| a | b)
108            .collect();
109        BitMask {
110            words,
111            nrows: self.nrows,
112        }
113    }
114
115    /// Iterate over set row indices in ascending order (deterministic).
116    ///
117    /// Uses word-level bit scanning with `trailing_zeros()` for O(popcount)
118    /// iteration. Entirely-zero words are skipped in O(1).
119    pub fn iter_set(&self) -> impl Iterator<Item = usize> + '_ {
120        self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
121            let base = word_idx * 64;
122            let mut w = word;
123            std::iter::from_fn(move || {
124                if w == 0 {
125                    return None;
126                }
127                let bit = w.trailing_zeros() as usize;
128                w &= w - 1; // clear lowest set bit
129                Some(base + bit)
130            })
131        })
132    }
133
134    /// Total rows this mask covers.
135    pub fn nrows(&self) -> usize {
136        self.nrows
137    }
138
139    /// Number of backing u64 words.
140    pub fn nwords(&self) -> usize {
141        self.words.len()
142    }
143
144    /// Memory size in bytes.
145    pub fn size_bytes(&self) -> usize {
146        self.words.len() * 8
147    }
148}
149
/// Compute the number of u64 words needed for `nrows` bits.
///
/// This is `ceil(nrows / 64)`, computed without forming `nrows + 63`, so it
/// cannot overflow for any `usize` input.
#[inline]
pub fn nwords_for(nrows: usize) -> usize {
    nrows / 64 + usize::from(nrows % 64 != 0)
}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// An all-true mask reports every one of its rows as set.
    #[test]
    fn test_all_true() {
        let full = BitMask::all_true(100);
        assert_eq!(full.count_ones(), 100);
        assert!((0..100).all(|row| full.get(row)));
    }

    /// An all-false mask has no set rows.
    #[test]
    fn test_all_false() {
        assert_eq!(BitMask::all_false(100).count_ones(), 0);
    }

    /// Set rows come back in ascending order.
    #[test]
    fn test_iter_set() {
        let flags = [true, false, true, false, true];
        let rows: Vec<usize> = BitMask::from_bools(&flags).iter_set().collect();
        assert_eq!(rows, vec![0, 2, 4]);
    }

    /// AND keeps only rows set in both operands.
    #[test]
    fn test_and() {
        let lhs = BitMask::from_bools(&[true, true, false, false]);
        let rhs = BitMask::from_bools(&[true, false, true, false]);
        let rows: Vec<usize> = lhs.and(&rhs).iter_set().collect();
        assert_eq!(rows, vec![0]);
    }

    /// Rows on either side of the 64-bit word boundary are all reported.
    #[test]
    fn test_iter_set_word_boundary() {
        let expected = vec![0usize, 63, 64, 127];
        let mut flags = vec![false; 128];
        for &row in &expected {
            flags[row] = true;
        }
        let rows: Vec<usize> = BitMask::from_bools(&flags).iter_set().collect();
        assert_eq!(rows, expected);
    }

    /// 1M rows = ceil(1M/64) = 15,625 words = 125,000 bytes ≈ 122 KiB.
    #[test]
    fn test_size_bytes() {
        let million = BitMask::all_true(1_000_000);
        assert_eq!(million.size_bytes(), 125_000);
    }
}