1use std::fmt::{self, Debug, Formatter};
2
3use thin_vec::ThinVec;
4
/// Number of bits in one `usize` word; this is the width of a single storage
/// chunk in both bitset types (e.g. 64 on 64-bit targets, 32 on 32-bit ones).
const BITS: usize = usize::BITS as _;
7
8#[derive(Clone, PartialEq, Hash)]
17pub struct BitSet(ThinVec<usize>);
18
19impl BitSet {
20    pub fn new() -> Self {
22        Self(ThinVec::new())
23    }
24
25    pub fn insert(&mut self, value: usize) {
27        let chunk = value / BITS;
28        let within = value % BITS;
29        if chunk >= self.0.len() {
30            self.0.resize(chunk + 1, 0);
31        }
32        self.0[chunk] |= 1 << within;
33    }
34
35    pub fn contains(&self, value: usize) -> bool {
37        let chunk = value / BITS;
38        let within = value % BITS;
39        let Some(bits) = self.0.get(chunk) else { return false };
40        (bits & (1 << within)) != 0
41    }
42}
43
44impl Default for BitSet {
45    fn default() -> Self {
46        Self::new()
47    }
48}
49
50impl Debug for BitSet {
51    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
52        let mut list = f.debug_list();
53        let chunks = self.0.len();
54        for v in 0..chunks * BITS {
55            if self.contains(v) {
56                list.entry(&v);
57            }
58        }
59        list.finish()
60    }
61}
62
63#[derive(Clone, PartialEq, Hash)]
67pub struct SmallBitSet {
68    low: usize,
70    hi: BitSet,
72}
73
74impl SmallBitSet {
75    pub fn new() -> Self {
77        Self { low: 0, hi: BitSet::new() }
78    }
79
80    pub fn insert(&mut self, value: usize) {
82        if value < BITS {
83            self.low |= 1 << value;
84        } else {
85            self.hi.insert(value - BITS);
86        }
87    }
88
89    pub fn contains(&self, value: usize) -> bool {
91        if value < BITS {
92            (self.low & (1 << value)) != 0
93        } else {
94            self.hi.contains(value - BITS)
95        }
96    }
97}
98
99impl Default for SmallBitSet {
100    fn default() -> Self {
101        Self::new()
102    }
103}
104
105impl Debug for SmallBitSet {
106    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
107        let mut list = f.debug_list();
108        let chunks = 1 + self.hi.0.len();
109        for v in 0..chunks * BITS {
110            if self.contains(v) {
111                list.entry(&v);
112            }
113        }
114        list.finish()
115    }
116}
117
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_bitset() {
        let mut set = SmallBitSet::new();
        assert!(!set.contains(0));
        assert!(!set.contains(5));
        set.insert(0);
        set.insert(1);
        set.insert(5);
        set.insert(64);
        set.insert(105);
        set.insert(208);
        assert!(set.contains(0));
        assert!(set.contains(1));
        assert!(!set.contains(2));
        assert!(set.contains(5));
        assert!(!set.contains(63));
        assert!(set.contains(64));
        assert!(!set.contains(65));
        assert!(!set.contains(104));
        assert!(set.contains(105));
        assert!(!set.contains(106));
        assert!(set.contains(208));
        assert!(!set.contains(209));
        assert_eq!(format!("{set:?}"), "[0, 1, 5, 64, 105, 208]");
    }

    /// Exercises the inline/overflow boundary at `BITS`, which is 32 or 64
    /// depending on the target. The hard-coded values above only hit it
    /// on 64-bit targets.
    #[test]
    fn test_small_bitset_word_boundary() {
        let mut set = SmallBitSet::new();
        set.insert(BITS - 1);
        set.insert(BITS);
        assert!(!set.contains(BITS - 2));
        assert!(set.contains(BITS - 1));
        assert!(set.contains(BITS));
        assert!(!set.contains(BITS + 1));
        assert_eq!(format!("{set:?}"), format!("[{}, {}]", BITS - 1, BITS));
    }

    #[test]
    fn test_empty_debug() {
        assert_eq!(format!("{:?}", SmallBitSet::new()), "[]");
        assert_eq!(format!("{:?}", BitSet::new()), "[]");
    }

    #[test]
    fn test_plain_bitset() {
        let mut set = BitSet::new();
        set.insert(3);
        set.insert(2 * BITS + 7);
        assert!(set.contains(3));
        assert!(!set.contains(4));
        assert!(!set.contains(2 * BITS + 6));
        assert!(set.contains(2 * BITS + 7));
        // Far beyond the allocated words: must report absence, not panic.
        assert!(!set.contains(100 * BITS));
        assert_eq!(format!("{set:?}"), format!("[3, {}]", 2 * BITS + 7));
    }

    #[test]
    fn test_equality_ignores_insertion_order() {
        let mut a = SmallBitSet::new();
        a.insert(200);
        a.insert(1);
        let mut b = SmallBitSet::new();
        b.insert(1);
        b.insert(200);
        assert_eq!(a, b);
    }
}
147}