simple_wfc/
state.rs

1use bit_vec::BitVec;
2use std::ops::{BitAnd, BitOr, BitXor};
3use std::sync::atomic::{AtomicU32, Ordering};
4
/// Block type used as the storage parameter of the [`BitVec`] inside [`StateSet`].
type B = u64;
6
/// A superposition of multiple [`State`]s, stored as one bit per state.
///
/// You must use [`Self::scope`] to set the total number of states before
/// constructing any set.
#[derive(PartialEq, Eq, Clone, Hash, Debug)]
pub struct StateSet(BitVec<B>);
12
/// One possible state at a location, identified by its 0-based index.
#[derive(PartialEq, Eq, Copy, Clone, Hash, Debug)]
pub struct State(pub(crate) u32);
16
17impl State {
18    /// Creates the `n`th unique state (0-indexed).
19    pub fn nth(n: u32) -> Self {
20        Self(n)
21    }
22}
23
// Per-thread number of states that newly created `StateSet`s are sized to.
// `u32::MAX` is the sentinel for "not inside `StateSet::scope`"; `state_count`
// debug-asserts against it.
thread_local! {
    static STATE_COUNT: AtomicU32 = const { AtomicU32::new(u32::MAX) };
}
27
28fn state_count() -> u32 {
29    STATE_COUNT.with(|count| {
30        let loaded = count.load(Ordering::Relaxed);
31        debug_assert_ne!(
32            loaded,
33            u32::MAX,
34            "all StateSet's must be constructed within StateSet::scope"
35        );
36        loaded
37    })
38}
39
40impl StateSet {
41    /// All [`StateSet`]'s created in `scope` will have `state_count` states.
42    pub fn scope<R>(state_count: u32, scope: impl FnOnce() -> R) -> R {
43        STATE_COUNT.with(|count| {
44            #[cfg(debug_assertions)]
45            let old = count.load(Ordering::Relaxed);
46            count.store(state_count, Ordering::Relaxed);
47            let ret = scope();
48            #[cfg(debug_assertions)]
49            count.store(old, Ordering::Relaxed);
50            ret
51        })
52    }
53
54    /// Creates a superposition of `states`.
55    pub fn with_states(states: &[State]) -> Self {
56        let mut ret = BitVec::<B>::default();
57        ret.grow(state_count() as usize, false);
58        for state in states {
59            ret.set(state.0 as usize, true);
60        }
61        Self(ret)
62    }
63
64    /// The total number of states, as defined by [Self::scope].
65    #[inline(always)]
66    pub fn len() -> u32 {
67        state_count()
68    }
69
70    /// Superposition of all states.
71    pub fn all() -> Self {
72        let mut ret = BitVec::<B>::default();
73        ret.grow(state_count() as usize, true);
74        Self(ret)
75    }
76
77    /// Total number of possible states, minus 1.
78    #[inline(always)]
79    pub fn entropy(&self) -> u32 {
80        (self.0.count_ones() as u32).saturating_sub(1)
81    }
82
83    /// Is `state` within the superposition?
84    #[inline(always)]
85    pub fn has(&self, state: State) -> bool {
86        self.0.get(state.0 as usize).unwrap()
87    }
88
89    /// Are any of `states` within the superposition?
90    #[inline(always)]
91    pub fn has_any(&self, states: &Self) -> bool {
92        self.0
93            .blocks()
94            .zip(states.0.blocks())
95            .map(|(a, b)| (a & b != 0) as u32)
96            .sum::<u32>()
97            > 0
98    }
99
100    /// Remove `state` from the superposition.
101    #[inline(always)]
102    pub fn remove(&mut self, state: State) {
103        self.0.set(state.0 as usize, false);
104    }
105
106    /// Remove all `states` from the superposition.
107    pub fn remove_all(&mut self, states: &Self) {
108        for (state, present) in states.0.iter().enumerate() {
109            if present {
110                self.0.set(state, false);
111            }
112        }
113    }
114
115    /// Add `state` to the superposition.
116    #[inline(always)]
117    pub fn add(&mut self, state: State) {
118        self.0.set(state.0 as usize, true);
119    }
120
121    /// Add all `states` to the superposition.
122    pub fn add_all(&mut self, states: &Self) {
123        self.0.or(&states.0);
124    }
125
126    pub(crate) fn iter(&self) -> impl Iterator<Item = State> + '_ {
127        self.0
128            .iter()
129            .enumerate()
130            .filter(|(_, p)| *p)
131            .map(|(i, _)| State::nth(i as u32))
132    }
133
134    /// Filter states in place.
135    pub fn retain(&mut self, mut filter: impl FnMut(State) -> bool) {
136        for s in 0..StateSet::len() {
137            let s = State::nth(s);
138            if self.has(s) && !filter(s) {
139                self.remove(s);
140            }
141        }
142    }
143}
144
145impl BitOr<Self> for StateSet {
146    type Output = Self;
147
148    fn bitor(mut self, rhs: Self) -> Self::Output {
149        self.0.or(&rhs.0);
150        self
151    }
152}
153
154impl BitAnd for StateSet {
155    type Output = Self;
156
157    fn bitand(mut self, rhs: Self) -> Self::Output {
158        self.0.and(&rhs.0);
159        self
160    }
161}
162
163impl BitXor for StateSet {
164    type Output = Self;
165
166    fn bitxor(mut self, rhs: Self) -> Self::Output {
167        self.0.xor(&rhs.0);
168        self
169    }
170}
171
172impl BitOr for State {
173    type Output = StateSet;
174
175    fn bitor(self, rhs: Self) -> Self::Output {
176        StateSet::with_states(&[self, rhs])
177    }
178}
179
180impl BitOr<State> for StateSet {
181    type Output = Self;
182
183    fn bitor(mut self, rhs: State) -> Self::Output {
184        self.add(rhs);
185        self
186    }
187}