1#![feature(generic_const_exprs)]
2
3use std::ops::{Add, Sub};
4
/// Number of bits in one `usize` storage word (8 bits per byte times the word size).
const USIZE_BITS: usize = 8 * std::mem::size_of::<usize>();
7
/// Error returned by [`StackBitSet`] accessors when a bit index is not below the
/// set's capacity `N`.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum StackBitSetError {
    /// The requested bit index is `>= N`.
    IndexOutOfBounds,
}

impl std::fmt::Display for StackBitSetError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            StackBitSetError::IndexOutOfBounds => f.write_str("bit index out of bounds"),
        }
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` and similar containers.
impl std::error::Error for StackBitSetError {}
13
14pub const fn usize_count(n: usize) -> usize {
16 (n / USIZE_BITS) + if n % USIZE_BITS == 0 { 0 } else { 1 }
17}
18
/// The smaller of `a` and `b`, usable in `const` contexts (unlike `Ord::min`).
pub const fn const_min(a: usize, b: usize) -> usize {
    if b <= a {
        b
    } else {
        a
    }
}
26
27#[derive(Clone, Copy, Debug)]
43pub struct StackBitSet<const N: usize>
44where
45 [(); usize_count(N)]: Sized,
46{
47 data: [usize; usize_count(N)],
48}
49
50impl<const N: usize> Default for StackBitSet<N>
51where
52 [(); usize_count(N)]: Sized,
53{
54 fn default() -> Self {
55 Self::new()
56 }
57}
58
59pub struct StackBitSetIterator<'a, const N: usize>
60where
61 [(); usize_count(N)]: Sized,
62{
63 index: usize,
64 limit: usize,
65 bitset: &'a StackBitSet<N>,
66}
67
68impl<'a, const N: usize> StackBitSetIterator<'a, N>
69where
70 [(); usize_count(N)]: Sized,
71{
72 pub fn new(bitset: &'a StackBitSet<N>) -> Self {
73 Self::new_limit(bitset, N)
74 }
75
76 pub fn new_limit(bitset: &'a StackBitSet<N>, limit: usize) -> Self {
77 Self {
78 index: 0,
79 limit,
80 bitset,
81 }
82 }
83}
84
85impl<'a, const N: usize> Iterator for StackBitSetIterator<'a, N>
86where
87 [(); usize_count(N)]: Sized,
88{
89 type Item = usize;
90
91 fn next(&mut self) -> Option<Self::Item> {
92 for i in self.index..const_min(N, self.limit) {
93 if self.bitset.get(i).unwrap() {
94 self.index = i + 1;
95 return Some(i);
96 }
97 }
98 None
99 }
100}
101
102impl<const N: usize> StackBitSet<N>
103where
104 [(); usize_count(N)]: Sized,
105{
106 pub fn new() -> Self {
108 StackBitSet {
109 data: [0usize; usize_count(N)],
110 }
111 }
112
113 pub fn iter(&self) -> StackBitSetIterator<N> {
114 StackBitSetIterator::new(self)
115 }
116
117 pub fn iter_limit(&self, limit: usize) -> StackBitSetIterator<N> {
118 StackBitSetIterator::new_limit(self, limit)
119 }
120
121 pub fn get(&self, idx: usize) -> Result<bool, StackBitSetError> {
123 if let Some(chunk) = self.data.get(idx / USIZE_BITS).filter(|_| idx < N) {
124 Ok(chunk & (1 << (idx % USIZE_BITS)) != 0)
125 } else {
126 Err(StackBitSetError::IndexOutOfBounds)
127 }
128 }
129
130 pub fn set(&mut self, idx: usize) -> Result<(), StackBitSetError> {
132 if let Some(chunk) = self.data.get_mut(idx / USIZE_BITS).filter(|_| idx < N) {
133 *chunk |= 1 << (idx % USIZE_BITS);
134 Ok(())
135 } else {
136 Err(StackBitSetError::IndexOutOfBounds)
137 }
138 }
139
140 pub fn reset(&mut self, idx: usize) -> Result<(), StackBitSetError> {
142 if let Some(chunk) = self.data.get_mut(idx / USIZE_BITS).filter(|_| idx < N) {
143 *chunk &= !(1 << (idx % USIZE_BITS));
144 Ok(())
145 } else {
146 Err(StackBitSetError::IndexOutOfBounds)
147 }
148 }
149}
150
151impl<const N: usize> StackBitSet<N>
152where
153 [(); usize_count(N)]: Sized,
154{
155 pub fn union<const M: usize>(&self, other: &StackBitSet<M>) -> StackBitSet<{ const_min(N, M) }>
156 where
157 [(); usize_count(M)]: Sized,
158 [(); usize_count(const_min(N, M))]: Sized,
159 {
160 let mut res = StackBitSet::new();
161 for i in self.iter_limit(M).chain(other.iter_limit(N)) {
162 res.set(i).unwrap();
163 }
164 res
165 }
166 pub fn intersection<const M: usize>(
167 &self,
168 other: &StackBitSet<M>,
169 ) -> StackBitSet<{ const_min(N, M) }>
170 where
171 [(); usize_count(M)]: Sized,
172 [(); usize_count(const_min(N, M))]: Sized,
173 {
174 let mut res = StackBitSet::new();
175 for i in self.iter_limit(M) {
176 if other.get(i).unwrap() {
177 res.set(i).unwrap();
178 }
179 }
180 res
181 }
182 pub fn difference<const M: usize>(
183 &self,
184 other: &StackBitSet<M>,
185 ) -> StackBitSet<{ const_min(N, M) }>
186 where
187 [(); usize_count(M)]: Sized,
188 [(); usize_count(const_min(N, M))]: Sized,
189 {
190 let mut res = StackBitSet::new();
191 for i in 0..(const_min(N, M)) {
192 if self.get(i).unwrap() {
193 res.set(i).unwrap();
194 }
195 if other.get(i).unwrap() {
196 res.reset(i).unwrap();
197 }
198 }
199 res
200 }
201 pub fn complement(&self) -> StackBitSet<N> {
202 let mut res = StackBitSet::new();
203 for i in 0..N {
204 if !self.get(i).unwrap() {
205 res.set(i).unwrap();
206 }
207 }
208 res
209 }
210 pub fn is_subset<const M: usize>(&self, other: &StackBitSet<M>) -> bool
211 where
212 [(); usize_count(M)]: Sized,
213 {
214 for i in 0..N {
215 if (i < M && (!other.get(i).unwrap() && self.get(i).unwrap()))
216 || (i >= M && self.get(i).unwrap())
217 {
218 return false;
219 }
220 }
221 !self.is_equal(other)
222 }
223 pub fn is_equal<const M: usize>(&self, other: &StackBitSet<M>) -> bool
224 where
225 [(); usize_count(M)]: Sized,
226 {
227 for i in 0..(N + M - const_min(N, M)) {
228 if i < N && i < M && (other.get(i).unwrap() ^ self.get(i).unwrap()) {
229 println!("1");
230 return false;
231 } else if i >= M && i < N && self.get(i).unwrap() {
232 println!("2");
233 return false;
234 } else if i >= N && i < M && other.get(i).unwrap() {
235 println!("3");
236 return false;
237 }
238 }
239 true
240 }
241 pub fn is_superset<const M: usize>(&self, other: &StackBitSet<M>) -> bool
242 where
243 [(); usize_count(M)]: Sized,
244 {
245 !self.is_equal(other) && !self.is_subset(other)
246 }
247}
248
249impl<const N: usize, const M: usize> Add<&StackBitSet<M>> for StackBitSet<N>
250where
251 [(); usize_count(N)]: Sized,
252 [(); usize_count(M)]: Sized,
253 [(); usize_count(const_min(N, M))]: Sized,
254{
255 type Output = StackBitSet<{ const_min(N, M) }>;
256
257 fn add(self, other: &StackBitSet<M>) -> Self::Output {
258 self.union(other)
259 }
260}
261
262macro_rules! add_impl {
263 ($($t:ty)*) => ($(
264
265 impl<const N: usize> Add<$t> for StackBitSet<N>
266where
267 [(); usize_count(N)]: Sized,
268{
269 type Output = StackBitSet<N>;
270
271 fn add(mut self, other: $t) -> StackBitSet<N> {
272 self.set(other as usize).unwrap();
273 self
274 }
275}
276 )*)
277}
278
279add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }
280
281macro_rules! sub_impl {
282 ($($t:ty)*) => ($(
283
284 impl<const N: usize> Sub<$t> for StackBitSet<N>
285where
286 [(); usize_count(N)]: Sized,
287{
288 type Output = StackBitSet<N>;
289
290 fn sub(mut self, other: $t) -> StackBitSet<N> {
291 self.reset(other as usize).unwrap();
292 self
293 }
294}
295 )*)
296}
297
298sub_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }
299
#[cfg(test)]
mod tests {
    use crate::StackBitSet;

    #[test]
    fn bitset_create() {
        let _a: StackBitSet<42> = StackBitSet::new();
    }

    #[test]
    fn set_reset_bit() {
        let mut a: StackBitSet<42> = StackBitSet::new();
        assert!(!a.get(12).unwrap());
        a.set(12).unwrap();
        assert!(a.get(12).unwrap());
        a.reset(12).unwrap();
        assert!(!a.get(12).unwrap());
    }

    #[test]
    fn out_of_bounds() {
        let mut a: StackBitSet<42> = StackBitSet::new();
        assert!(a.get(42).is_err());
        assert!(a.set(42).is_err());
        assert!(a.reset(42).is_err());
        assert!(a.set(41).is_ok());
        assert!(a.get(41).unwrap());
    }

    #[test]
    fn equality() {
        let mut a: StackBitSet<42> = StackBitSet::new();
        let mut b: StackBitSet<69> = StackBitSet::new();
        assert!(a.is_equal(&b));
        a.set(12).unwrap();
        assert!(!a.is_equal(&b));
        b.set(12).unwrap();
        assert!(a.is_equal(&b));
    }

    #[test]
    fn iteration() {
        // Members straddle word boundaries on both 32- and 64-bit targets.
        let mut a: StackBitSet<130> = StackBitSet::new();
        for i in [0, 31, 32, 63, 64, 129] {
            a.set(i).unwrap();
        }
        assert_eq!(a.iter().collect::<Vec<_>>(), vec![0, 31, 32, 63, 64, 129]);
        assert_eq!(a.iter_limit(64).collect::<Vec<_>>(), vec![0, 31, 32, 63]);
        assert_eq!(a.iter_limit(1000).count(), 6);

        let mut it = a.iter();
        assert_eq!(it.by_ref().count(), 6);
        assert_eq!(it.next(), None);
        assert_eq!(it.next(), None);
    }

    #[test]
    fn union() {
        let mut a: StackBitSet<42> = StackBitSet::new();
        let mut b: StackBitSet<69> = StackBitSet::new();
        a.set(12).unwrap();
        b.set(29).unwrap();
        let mut c: StackBitSet<37> = StackBitSet::new();
        c.set(12).unwrap();
        c.set(29).unwrap();
        assert!(c.is_equal(&(a.union(&b))));
        assert!(a.is_subset(&c));
        assert!(b.is_subset(&c));
        let d: StackBitSet<93> = StackBitSet::new();
        assert!((c.intersection(&a)).intersection(&b).is_equal(&d));
    }

    #[test]
    fn difference_and_complement() {
        let mut a: StackBitSet<42> = StackBitSet::new();
        let mut b: StackBitSet<69> = StackBitSet::new();
        a.set(1).unwrap();
        a.set(5).unwrap();
        b.set(5).unwrap();
        b.set(50).unwrap();
        assert_eq!(a.difference(&b).iter().collect::<Vec<_>>(), vec![1]);

        let mut c: StackBitSet<4> = StackBitSet::new();
        c.set(1).unwrap();
        assert_eq!(c.complement().iter().collect::<Vec<_>>(), vec![0, 2, 3]);
    }

    #[test]
    fn subset() {
        let mut a: StackBitSet<42> = StackBitSet::new();
        let mut b: StackBitSet<69> = StackBitSet::new();
        a.set(12).unwrap();
        b.set(12).unwrap();
        b.set(29).unwrap();

        assert!(a.is_subset(&b));
        assert!(!b.is_subset(&a));
        assert!(b.is_superset(&a));
        assert!(!b.is_equal(&a));
    }

    #[test]
    fn operators() {
        let a = StackBitSet::<10>::new() + 3u8 + 7usize;
        assert!(a.get(3).unwrap());
        assert!(a.get(7).unwrap());

        let a = a - 3i32;
        assert!(!a.get(3).unwrap());
        assert!(a.get(7).unwrap());

        let mut b: StackBitSet<20> = StackBitSet::new();
        b.set(2).unwrap();
        let u = a + &b;
        assert_eq!(u.iter().collect::<Vec<_>>(), vec![2, 7]);
    }
}