Skip to main content

texform_transform/rewrite/
level_set.rs

1//! Bitset of rewrite normalization levels. Const-friendly and runtime-mutable.
2
3use super::rule::NormalizationLevel;
4
/// A set of [`NormalizationLevel`]s packed into a single `u8`.
///
/// Each level occupies one bit of the inner byte (bits 0–3, see the
/// associated constants such as `NormalizationLevelSet::STANDARD`). The type
/// is `Copy`, so sets are passed and combined by value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NormalizationLevelSet(u8);
7
8impl NormalizationLevelSet {
9    pub const STANDARD: Self = Self(1 << 0);
10    pub const EXPAND: Self = Self(1 << 1);
11    pub const DROP: Self = Self(1 << 2);
12    pub const EQUIV: Self = Self(1 << 3);
13
14    pub const fn empty() -> Self {
15        Self(0)
16    }
17
18    pub const fn all() -> Self {
19        Self(0b1111)
20    }
21
22    pub const fn union(self, other: Self) -> Self {
23        Self(self.0 | other.0)
24    }
25
26    pub const fn intersects(self, other: Self) -> bool {
27        self.0 & other.0 != 0
28    }
29
30    pub const fn contains(self, level: NormalizationLevel) -> bool {
31        let bit = match level {
32            NormalizationLevel::Standard => 1 << 0,
33            NormalizationLevel::Expand => 1 << 1,
34            NormalizationLevel::Drop => 1 << 2,
35            NormalizationLevel::Equiv => 1 << 3,
36        };
37        self.0 & bit != 0
38    }
39
40    pub fn iter(self) -> impl Iterator<Item = NormalizationLevel> {
41        const ORDER: [NormalizationLevel; 4] = [
42            NormalizationLevel::Standard,
43            NormalizationLevel::Expand,
44            NormalizationLevel::Drop,
45            NormalizationLevel::Equiv,
46        ];
47        ORDER.into_iter().filter(move |level| self.contains(*level))
48    }
49}
50
51impl Default for NormalizationLevelSet {
52    fn default() -> Self {
53        Self::empty()
54    }
55}
56
57impl std::ops::BitOr for NormalizationLevelSet {
58    type Output = Self;
59
60    fn bitor(self, rhs: Self) -> Self {
61        self.union(rhs)
62    }
63}
64
65impl std::ops::BitOrAssign for NormalizationLevelSet {
66    fn bitor_assign(&mut self, rhs: Self) {
67        self.0 |= rhs.0;
68    }
69}
70
71impl From<NormalizationLevel> for NormalizationLevelSet {
72    fn from(level: NormalizationLevel) -> Self {
73        match level {
74            NormalizationLevel::Standard => Self::STANDARD,
75            NormalizationLevel::Expand => Self::EXPAND,
76            NormalizationLevel::Drop => Self::DROP,
77            NormalizationLevel::Equiv => Self::EQUIV,
78        }
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Every level, in the canonical order `iter` is expected to yield.
    const LEVELS: [NormalizationLevel; 4] = [
        NormalizationLevel::Standard,
        NormalizationLevel::Expand,
        NormalizationLevel::Drop,
        NormalizationLevel::Equiv,
    ];

    #[test]
    fn union_combines_bits() {
        let set = NormalizationLevelSet::STANDARD | NormalizationLevelSet::EXPAND;
        assert!(set.contains(NormalizationLevel::Standard));
        assert!(set.contains(NormalizationLevel::Expand));
        assert!(!set.contains(NormalizationLevel::Drop));
        assert!(!set.contains(NormalizationLevel::Equiv));
    }

    #[test]
    fn iter_emits_in_canonical_order() {
        let set = NormalizationLevelSet::DROP | NormalizationLevelSet::STANDARD;
        assert_eq!(
            set.iter().collect::<Vec<_>>(),
            vec![NormalizationLevel::Standard, NormalizationLevel::Drop]
        );
    }

    #[test]
    fn all_preset_contains_every_level() {
        let all = NormalizationLevelSet::all();
        for level in LEVELS {
            assert!(all.contains(level));
        }
        assert_eq!(all.iter().collect::<Vec<_>>(), LEVELS.to_vec());
    }

    #[test]
    fn empty_and_default_contain_nothing() {
        let empty = NormalizationLevelSet::empty();
        assert_eq!(NormalizationLevelSet::default(), empty);
        assert_eq!(empty.iter().next(), None);
        for level in LEVELS {
            assert!(!empty.contains(level));
        }
    }

    #[test]
    fn from_level_yields_matching_singleton() {
        for level in LEVELS {
            let single = NormalizationLevelSet::from(level);
            assert_eq!(single.iter().collect::<Vec<_>>(), vec![level]);
        }
    }

    #[test]
    fn intersects_requires_a_shared_level() {
        let set = NormalizationLevelSet::STANDARD | NormalizationLevelSet::DROP;
        assert!(set.intersects(NormalizationLevelSet::DROP | NormalizationLevelSet::EQUIV));
        assert!(!set.intersects(NormalizationLevelSet::EXPAND | NormalizationLevelSet::EQUIV));
        assert!(!set.intersects(NormalizationLevelSet::empty()));
    }

    #[test]
    fn bitor_assign_matches_bitor() {
        let mut set = NormalizationLevelSet::EXPAND;
        set |= NormalizationLevelSet::EQUIV;
        assert_eq!(set, NormalizationLevelSet::EXPAND | NormalizationLevelSet::EQUIV);
    }

    #[test]
    fn constructors_are_usable_in_const_context() {
        // The module advertises itself as const-friendly; pin that down.
        const COMBINED: NormalizationLevelSet =
            NormalizationLevelSet::STANDARD.union(NormalizationLevelSet::EQUIV);
        assert!(COMBINED.contains(NormalizationLevel::Equiv));
        assert!(!COMBINED.contains(NormalizationLevel::Drop));
    }
}