stack_bitset/
lib.rs

1#![feature(generic_const_exprs)]
2
3use std::ops::{Add, Sub};
4
/// Width of one storage word: how many bits a single `usize` holds.
const USIZE_BITS: usize = std::mem::size_of::<usize>() * (u8::BITS as usize);
7
/// Error type of the crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StackBitSetError {
    /// The requested index is not smaller than the bitset's compile-time size.
    IndexOutOfBounds,
}

impl std::fmt::Display for StackBitSetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::IndexOutOfBounds => f.write_str("index out of bounds"),
        }
    }
}

impl std::error::Error for StackBitSetError {}
13
14/// Computes the number of `usize` chunks needed for a bitset of `n` elements.
15pub const fn usize_count(n: usize) -> usize {
16    (n / USIZE_BITS) + if n % USIZE_BITS == 0 { 0 } else { 1 }
17}
18
/// Returns the smaller of `a` and `b`; usable in constant expressions.
pub const fn const_min(a: usize, b: usize) -> usize {
    match a < b {
        true => a,
        false => b,
    }
}
26
27/// BitSet with compile-time size. It does not require any allocation
28/// and is entirely stored on the stack.
29///
30/// The only field is an array of `usize`. Each element is stored in a bit
31///
32/// # Examples
33///
34/// ```rust
35/// use stack_bitset::StackBitSet;
36///
37/// let mut a: StackBitSet<42> = StackBitSet::new();
38/// a.set(12).unwrap();
39/// assert!(a.get(12).unwrap());
40/// ```
41///
42#[derive(Clone, Copy, Debug)]
43pub struct StackBitSet<const N: usize>
44where
45    [(); usize_count(N)]: Sized,
46{
47    data: [usize; usize_count(N)],
48}
49
50impl<const N: usize> Default for StackBitSet<N>
51where
52    [(); usize_count(N)]: Sized,
53{
54    fn default() -> Self {
55        Self::new()
56    }
57}
58
59pub struct StackBitSetIterator<'a, const N: usize>
60where
61    [(); usize_count(N)]: Sized,
62{
63    index: usize,
64    limit: usize,
65    bitset: &'a StackBitSet<N>,
66}
67
68impl<'a, const N: usize> StackBitSetIterator<'a, N>
69where
70    [(); usize_count(N)]: Sized,
71{
72    pub fn new(bitset: &'a StackBitSet<N>) -> Self {
73        Self::new_limit(bitset, N)
74    }
75
76    pub fn new_limit(bitset: &'a StackBitSet<N>, limit: usize) -> Self {
77        Self {
78            index: 0,
79            limit,
80            bitset,
81        }
82    }
83}
84
85impl<'a, const N: usize> Iterator for StackBitSetIterator<'a, N>
86where
87    [(); usize_count(N)]: Sized,
88{
89    type Item = usize;
90
91    fn next(&mut self) -> Option<Self::Item> {
92        for i in self.index..const_min(N, self.limit) {
93            if self.bitset.get(i).unwrap() {
94                self.index = i + 1;
95                return Some(i);
96            }
97        }
98        None
99    }
100}
101
102impl<const N: usize> StackBitSet<N>
103where
104    [(); usize_count(N)]: Sized,
105{
106    /// Create a new empty instance of the bitset
107    pub fn new() -> Self {
108        StackBitSet {
109            data: [0usize; usize_count(N)],
110        }
111    }
112
113    pub fn iter(&self) -> StackBitSetIterator<N> {
114        StackBitSetIterator::new(self)
115    }
116
117    pub fn iter_limit(&self, limit: usize) -> StackBitSetIterator<N> {
118        StackBitSetIterator::new_limit(self, limit)
119    }
120
121    /// Returns whether the elements at index `idx` in the bitset is set
122    pub fn get(&self, idx: usize) -> Result<bool, StackBitSetError> {
123        if let Some(chunk) = self.data.get(idx / USIZE_BITS).filter(|_| idx < N) {
124            Ok(chunk & (1 << (idx % USIZE_BITS)) != 0)
125        } else {
126            Err(StackBitSetError::IndexOutOfBounds)
127        }
128    }
129
130    /// sets the elements at index `idx` in the bitset
131    pub fn set(&mut self, idx: usize) -> Result<(), StackBitSetError> {
132        if let Some(chunk) = self.data.get_mut(idx / USIZE_BITS).filter(|_| idx < N) {
133            *chunk |= 1 << (idx % USIZE_BITS);
134            Ok(())
135        } else {
136            Err(StackBitSetError::IndexOutOfBounds)
137        }
138    }
139
140    /// Resets the element at index `idx` in the bitset
141    pub fn reset(&mut self, idx: usize) -> Result<(), StackBitSetError> {
142        if let Some(chunk) = self.data.get_mut(idx / USIZE_BITS).filter(|_| idx < N) {
143            *chunk &= !(1 << (idx % USIZE_BITS));
144            Ok(())
145        } else {
146            Err(StackBitSetError::IndexOutOfBounds)
147        }
148    }
149}
150
151impl<const N: usize> StackBitSet<N>
152where
153    [(); usize_count(N)]: Sized,
154{
155    pub fn union<const M: usize>(&self, other: &StackBitSet<M>) -> StackBitSet<{ const_min(N, M) }>
156    where
157        [(); usize_count(M)]: Sized,
158        [(); usize_count(const_min(N, M))]: Sized,
159    {
160        let mut res = StackBitSet::new();
161        for i in self.iter_limit(M).chain(other.iter_limit(N)) {
162            res.set(i).unwrap();
163        }
164        res
165    }
166    pub fn intersection<const M: usize>(
167        &self,
168        other: &StackBitSet<M>,
169    ) -> StackBitSet<{ const_min(N, M) }>
170    where
171        [(); usize_count(M)]: Sized,
172        [(); usize_count(const_min(N, M))]: Sized,
173    {
174        let mut res = StackBitSet::new();
175        for i in self.iter_limit(M) {
176            if other.get(i).unwrap() {
177                res.set(i).unwrap();
178            }
179        }
180        res
181    }
182    pub fn difference<const M: usize>(
183        &self,
184        other: &StackBitSet<M>,
185    ) -> StackBitSet<{ const_min(N, M) }>
186    where
187        [(); usize_count(M)]: Sized,
188        [(); usize_count(const_min(N, M))]: Sized,
189    {
190        let mut res = StackBitSet::new();
191        for i in 0..(const_min(N, M)) {
192            if self.get(i).unwrap() {
193                res.set(i).unwrap();
194            }
195            if other.get(i).unwrap() {
196                res.reset(i).unwrap();
197            }
198        }
199        res
200    }
201    pub fn complement(&self) -> StackBitSet<N> {
202        let mut res = StackBitSet::new();
203        for i in 0..N {
204            if !self.get(i).unwrap() {
205                res.set(i).unwrap();
206            }
207        }
208        res
209    }
210    pub fn is_subset<const M: usize>(&self, other: &StackBitSet<M>) -> bool
211    where
212        [(); usize_count(M)]: Sized,
213    {
214        for i in 0..N {
215            if (i < M && (!other.get(i).unwrap() && self.get(i).unwrap()))
216                || (i >= M && self.get(i).unwrap())
217            {
218                return false;
219            }
220        }
221        !self.is_equal(other)
222    }
223    pub fn is_equal<const M: usize>(&self, other: &StackBitSet<M>) -> bool
224    where
225        [(); usize_count(M)]: Sized,
226    {
227        for i in 0..(N + M - const_min(N, M)) {
228            if i < N && i < M && (other.get(i).unwrap() ^ self.get(i).unwrap()) {
229                println!("1");
230                return false;
231            } else if i >= M && i < N && self.get(i).unwrap() {
232                println!("2");
233                return false;
234            } else if i >= N && i < M && other.get(i).unwrap() {
235                println!("3");
236                return false;
237            }
238        }
239        true
240    }
241    pub fn is_superset<const M: usize>(&self, other: &StackBitSet<M>) -> bool
242    where
243        [(); usize_count(M)]: Sized,
244    {
245        !self.is_equal(other) && !self.is_subset(other)
246    }
247}
248
249impl<const N: usize, const M: usize> Add<&StackBitSet<M>> for StackBitSet<N>
250where
251    [(); usize_count(N)]: Sized,
252    [(); usize_count(M)]: Sized,
253    [(); usize_count(const_min(N, M))]: Sized,
254{
255    type Output = StackBitSet<{ const_min(N, M) }>;
256
257    fn add(self, other: &StackBitSet<M>) -> Self::Output {
258        self.union(other)
259    }
260}
261
262macro_rules! add_impl {
263    ($($t:ty)*) => ($(
264
265        impl<const N: usize> Add<$t> for StackBitSet<N>
266where
267    [(); usize_count(N)]: Sized,
268{
269    type Output = StackBitSet<N>;
270
271    fn add(mut self, other: $t) -> StackBitSet<N> {
272        self.set(other as usize).unwrap();
273        self
274    }
275}
276    )*)
277}
278
279add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }
280
281macro_rules! sub_impl {
282    ($($t:ty)*) => ($(
283
284        impl<const N: usize> Sub<$t> for StackBitSet<N>
285where
286    [(); usize_count(N)]: Sized,
287{
288    type Output = StackBitSet<N>;
289
290    fn sub(mut self, other: $t) -> StackBitSet<N> {
291        self.reset(other as usize).unwrap();
292        self
293    }
294}
295    )*)
296}
297
298sub_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }
299
#[cfg(test)]
mod tests {
    use super::StackBitSet;

    /// Constructing a bitset must not panic.
    #[test]
    fn bitset_create() {
        let _a = StackBitSet::<42>::new();
    }

    /// A bit reads back as set after `set` and as clear after `reset`.
    #[test]
    fn set_reset_bit() {
        let mut a = StackBitSet::<42>::new();
        assert!(matches!(a.get(12), Ok(false)));

        a.set(12).unwrap();
        assert!(matches!(a.get(12), Ok(true)));

        a.reset(12).unwrap();
        assert!(matches!(a.get(12), Ok(false)));
    }

    /// Equality only looks at the members, not at the compile-time sizes.
    #[test]
    fn equality() {
        let mut a = StackBitSet::<42>::new();
        let mut b = StackBitSet::<69>::new();
        assert!(a.is_equal(&b));

        a.set(12).unwrap();
        assert!(!a.is_equal(&b));

        b.set(12).unwrap();
        assert!(a.is_equal(&b));
    }

    /// Union gathers the members of both operands; chained intersections of
    /// sets without a common member are empty.
    #[test]
    fn union() {
        let mut a = StackBitSet::<42>::new();
        let mut b = StackBitSet::<69>::new();
        a.set(12).unwrap();
        b.set(29).unwrap();

        let mut c = StackBitSet::<37>::new();
        for idx in [12, 29] {
            c.set(idx).unwrap();
        }

        let joined = a.union(&b);
        assert!(c.is_equal(&joined));
        assert!(a.is_subset(&c));
        assert!(b.is_subset(&c));

        let d = StackBitSet::<93>::new();
        let common = c.intersection(&a).intersection(&b);
        assert!(common.is_equal(&d));
    }

    /// Subset and superset relations between a set and a strictly larger one.
    #[test]
    fn subset() {
        let mut a = StackBitSet::<42>::new();
        let mut b = StackBitSet::<69>::new();
        a.set(12).unwrap();
        for idx in [12, 29] {
            b.set(idx).unwrap();
        }

        assert!(a.is_subset(&b));
        assert!(!b.is_subset(&a));
        assert!(b.is_superset(&a));
        assert!(!b.is_equal(&a));
    }
}