1use bit_vec::BitVec;
2use std::ops::{BitAnd, BitOr, BitXor};
3use std::sync::atomic::{AtomicU32, Ordering};
4
/// Storage block type of the underlying `BitVec`: state bits are packed 64 per `u64` block.
type B = u64;
6
/// A set of [`State`]s stored as a bit vector, where bit `i` is set when `State(i)` is a member.
///
/// Sets are sized from the thread-local state count, so they must be constructed inside
/// [`StateSet::scope`]. Binary operations between sets assume both were built with the same
/// state count, i.e. both have the same bit length.
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub struct StateSet(BitVec<B>);
12
/// Identifier of a single state: a plain index into a [`StateSet`]'s bit vector.
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct State(pub(crate) u32);

impl State {
    /// Builds the state whose index is `n`.
    ///
    /// No range check is performed here; indexing a `StateSet` with an out-of-range state
    /// is what would fail.
    pub fn nth(n: u32) -> Self {
        State(n)
    }
}
23
thread_local! {
    /// Per-thread state count used to size new `StateSet`s.
    /// `u32::MAX` is a sentinel meaning "no `StateSet::scope` has set a count".
    static STATE_COUNT: AtomicU32 = const { AtomicU32::new(u32::MAX) };
}

/// Reads this thread's state count, as last set by `StateSet::scope`.
///
/// In debug builds, panics if the count still holds the `u32::MAX` sentinel, which means no
/// scope has set it.
fn state_count() -> u32 {
    let current = STATE_COUNT.with(|cell| cell.load(Ordering::Relaxed));
    debug_assert_ne!(
        current,
        u32::MAX,
        "all StateSet's must be constructed within StateSet::scope"
    );
    current
}
39
40impl StateSet {
41 pub fn scope<R>(state_count: u32, scope: impl FnOnce() -> R) -> R {
43 STATE_COUNT.with(|count| {
44 #[cfg(debug_assertions)]
45 let old = count.load(Ordering::Relaxed);
46 count.store(state_count, Ordering::Relaxed);
47 let ret = scope();
48 #[cfg(debug_assertions)]
49 count.store(old, Ordering::Relaxed);
50 ret
51 })
52 }
53
54 pub fn with_states(states: &[State]) -> Self {
56 let mut ret = BitVec::<B>::default();
57 ret.grow(state_count() as usize, false);
58 for state in states {
59 ret.set(state.0 as usize, true);
60 }
61 Self(ret)
62 }
63
64 #[inline(always)]
66 pub fn len() -> u32 {
67 state_count()
68 }
69
70 pub fn all() -> Self {
72 let mut ret = BitVec::<B>::default();
73 ret.grow(state_count() as usize, true);
74 Self(ret)
75 }
76
77 #[inline(always)]
79 pub fn entropy(&self) -> u32 {
80 (self.0.count_ones() as u32).saturating_sub(1)
81 }
82
83 #[inline(always)]
85 pub fn has(&self, state: State) -> bool {
86 self.0.get(state.0 as usize).unwrap()
87 }
88
89 #[inline(always)]
91 pub fn has_any(&self, states: &Self) -> bool {
92 self.0
93 .blocks()
94 .zip(states.0.blocks())
95 .map(|(a, b)| (a & b != 0) as u32)
96 .sum::<u32>()
97 > 0
98 }
99
100 #[inline(always)]
102 pub fn remove(&mut self, state: State) {
103 self.0.set(state.0 as usize, false);
104 }
105
106 pub fn remove_all(&mut self, states: &Self) {
108 for (state, present) in states.0.iter().enumerate() {
109 if present {
110 self.0.set(state, false);
111 }
112 }
113 }
114
115 #[inline(always)]
117 pub fn add(&mut self, state: State) {
118 self.0.set(state.0 as usize, true);
119 }
120
121 pub fn add_all(&mut self, states: &Self) {
123 self.0.or(&states.0);
124 }
125
126 pub(crate) fn iter(&self) -> impl Iterator<Item = State> + '_ {
127 self.0
128 .iter()
129 .enumerate()
130 .filter(|(_, p)| *p)
131 .map(|(i, _)| State::nth(i as u32))
132 }
133
134 pub fn retain(&mut self, mut filter: impl FnMut(State) -> bool) {
136 for s in 0..StateSet::len() {
137 let s = State::nth(s);
138 if self.has(s) && !filter(s) {
139 self.remove(s);
140 }
141 }
142 }
143}
144
145impl BitOr<Self> for StateSet {
146 type Output = Self;
147
148 fn bitor(mut self, rhs: Self) -> Self::Output {
149 self.0.or(&rhs.0);
150 self
151 }
152}
153
154impl BitAnd for StateSet {
155 type Output = Self;
156
157 fn bitand(mut self, rhs: Self) -> Self::Output {
158 self.0.and(&rhs.0);
159 self
160 }
161}
162
163impl BitXor for StateSet {
164 type Output = Self;
165
166 fn bitxor(mut self, rhs: Self) -> Self::Output {
167 self.0.xor(&rhs.0);
168 self
169 }
170}
171
172impl BitOr for State {
173 type Output = StateSet;
174
175 fn bitor(self, rhs: Self) -> Self::Output {
176 StateSet::with_states(&[self, rhs])
177 }
178}
179
180impl BitOr<State> for StateSet {
181 type Output = Self;
182
183 fn bitor(mut self, rhs: State) -> Self::Output {
184 self.add(rhs);
185 self
186 }
187}