/// A fixed-length bitmap over `nrows` rows, packed 64 rows per `u64` word.
///
/// Row `i` is stored in bit `i % 64` (least-significant bit first) of `words[i / 64]`.
/// The constructors leave every bit at a position `>= nrows` in the final word cleared;
/// `count_ones` and `iter_set` rely on that, which is also what makes the derived
/// `PartialEq`/`Eq` a faithful row-by-row comparison.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BitMask {
    /// Packed row bits; `words.len()` is `nrows` rounded up to whole 64-bit words.
    pub(crate) words: Vec<u64>,
    /// Logical number of rows covered by the mask.
    pub(crate) nrows: usize,
}
17
18impl BitMask {
19 pub fn all_true(nrows: usize) -> Self {
21 let nwords = nwords_for(nrows);
22 let mut words = vec![u64::MAX; nwords];
23 if nrows % 64 != 0 && nwords > 0 {
25 let tail = nrows % 64;
26 words[nwords - 1] = (1u64 << tail) - 1;
27 }
28 BitMask { words, nrows }
29 }
30
31 pub fn all_false(nrows: usize) -> Self {
33 let nwords = nwords_for(nrows);
34 BitMask {
35 words: vec![0u64; nwords],
36 nrows,
37 }
38 }
39
40 pub fn from_bools(bools: &[bool]) -> Self {
42 let nrows = bools.len();
43 let nwords = nwords_for(nrows);
44 let mut words = vec![0u64; nwords];
45 for (i, &b) in bools.iter().enumerate() {
46 if b {
47 words[i / 64] |= 1u64 << (i % 64);
48 }
49 }
50 BitMask { words, nrows }
51 }
52
53 #[inline]
55 pub fn get(&self, i: usize) -> bool {
56 debug_assert!(i < self.nrows);
57 (self.words[i / 64] >> (i % 64)) & 1 == 1
58 }
59
60 #[inline]
62 pub fn set(&mut self, i: usize) {
63 debug_assert!(i < self.nrows);
64 self.words[i / 64] |= 1u64 << (i % 64);
65 }
66
67 #[inline]
69 pub fn clear(&mut self, i: usize) {
70 debug_assert!(i < self.nrows);
71 self.words[i / 64] &= !(1u64 << (i % 64));
72 }
73
74 pub fn count_ones(&self) -> usize {
76 self.words.iter().map(|w| w.count_ones() as usize).sum()
77 }
78
79 pub fn and(&self, other: &BitMask) -> BitMask {
83 assert_eq!(
84 self.nrows, other.nrows,
85 "BitMask::and: nrows mismatch ({} vs {})",
86 self.nrows, other.nrows
87 );
88 let words = self
89 .words
90 .iter()
91 .zip(other.words.iter())
92 .map(|(a, b)| a & b)
93 .collect();
94 BitMask {
95 words,
96 nrows: self.nrows,
97 }
98 }
99
100 pub fn or(&self, other: &BitMask) -> BitMask {
102 assert_eq!(self.nrows, other.nrows);
103 let words = self
104 .words
105 .iter()
106 .zip(other.words.iter())
107 .map(|(a, b)| a | b)
108 .collect();
109 BitMask {
110 words,
111 nrows: self.nrows,
112 }
113 }
114
115 pub fn iter_set(&self) -> impl Iterator<Item = usize> + '_ {
120 self.words.iter().enumerate().flat_map(|(word_idx, &word)| {
121 let base = word_idx * 64;
122 let mut w = word;
123 std::iter::from_fn(move || {
124 if w == 0 {
125 return None;
126 }
127 let bit = w.trailing_zeros() as usize;
128 w &= w - 1; Some(base + bit)
130 })
131 })
132 }
133
134 pub fn nrows(&self) -> usize {
136 self.nrows
137 }
138
139 pub fn nwords(&self) -> usize {
141 self.words.len()
142 }
143
144 pub fn size_bytes(&self) -> usize {
146 self.words.len() * 8
147 }
148}
149
/// Returns the number of `u64` words needed to hold `nrows` bits, i.e. `ceil(nrows / 64)`.
///
/// Computed as quotient plus a remainder check rather than `(nrows + 63) / 64`,
/// which overflows for `nrows > usize::MAX - 63`.
#[inline]
pub fn nwords_for(nrows: usize) -> usize {
    nrows / 64 + usize::from(nrows % 64 != 0)
}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_all_true() {
        let mask = BitMask::all_true(100);
        assert_eq!(mask.count_ones(), 100);
        for i in 0..100 {
            assert!(mask.get(i));
        }
    }

    #[test]
    fn test_all_true_word_multiples_and_empty() {
        // Exact multiples of 64 have no partial tail word and must stay fully set.
        for &nrows in &[0usize, 64, 128] {
            let mask = BitMask::all_true(nrows);
            assert_eq!(mask.count_ones(), nrows);
            assert_eq!(mask.iter_set().count(), nrows);
        }
        assert_eq!(BitMask::all_true(0).nwords(), 0);
    }

    #[test]
    fn test_all_false() {
        let mask = BitMask::all_false(100);
        assert_eq!(mask.count_ones(), 0);
    }

    #[test]
    fn test_iter_set() {
        let mask = BitMask::from_bools(&[true, false, true, false, true]);
        let indices: Vec<usize> = mask.iter_set().collect();
        assert_eq!(indices, vec![0, 2, 4]);
    }

    #[test]
    fn test_and() {
        let a = BitMask::from_bools(&[true, true, false, false]);
        let b = BitMask::from_bools(&[true, false, true, false]);
        let c = a.and(&b);
        let indices: Vec<usize> = c.iter_set().collect();
        assert_eq!(indices, vec![0]);
    }

    #[test]
    fn test_or() {
        let a = BitMask::from_bools(&[true, true, false, false]);
        let b = BitMask::from_bools(&[true, false, true, false]);
        let c = a.or(&b);
        let indices: Vec<usize> = c.iter_set().collect();
        assert_eq!(indices, vec![0, 1, 2]);
    }

    #[test]
    #[should_panic]
    fn test_and_nrows_mismatch_panics() {
        let _ = BitMask::all_true(3).and(&BitMask::all_true(4));
    }

    #[test]
    fn test_set_clear_across_words() {
        let mut mask = BitMask::all_false(130);
        mask.set(0);
        mask.set(64);
        mask.set(129);
        assert_eq!(mask.count_ones(), 3);
        mask.clear(64);
        assert!(!mask.get(64));
        let indices: Vec<usize> = mask.iter_set().collect();
        assert_eq!(indices, vec![0, 129]);
    }

    #[test]
    fn test_iter_set_word_boundary() {
        // Set bits on both edges of two adjacent words.
        let mut bools = vec![false; 128];
        bools[0] = true;
        bools[63] = true;
        bools[64] = true;
        bools[127] = true;
        let mask = BitMask::from_bools(&bools);
        let indices: Vec<usize> = mask.iter_set().collect();
        assert_eq!(indices, vec![0, 63, 64, 127]);
    }

    #[test]
    fn test_size_bytes() {
        // 1_000_000 rows -> 15_625 words -> 125_000 bytes.
        let mask = BitMask::all_true(1_000_000);
        assert_eq!(mask.nrows(), 1_000_000);
        assert_eq!(mask.size_bytes(), 125_000);
    }
}