// texform_transform/rewrite/level_set.rs
use super::rule::NormalizationLevel;

/// A compact set of [`NormalizationLevel`]s stored as one bit per level in a `u8`.
///
/// Only the low four bits are ever set; the associated constants on this type
/// name the bit assigned to each level.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NormalizationLevelSet(u8);

8impl NormalizationLevelSet {
9 pub const STANDARD: Self = Self(1 << 0);
10 pub const EXPAND: Self = Self(1 << 1);
11 pub const DROP: Self = Self(1 << 2);
12 pub const EQUIV: Self = Self(1 << 3);
13
14 pub const fn empty() -> Self {
15 Self(0)
16 }
17
18 pub const fn all() -> Self {
19 Self(0b1111)
20 }
21
22 pub const fn union(self, other: Self) -> Self {
23 Self(self.0 | other.0)
24 }
25
26 pub const fn intersects(self, other: Self) -> bool {
27 self.0 & other.0 != 0
28 }
29
30 pub const fn contains(self, level: NormalizationLevel) -> bool {
31 let bit = match level {
32 NormalizationLevel::Standard => 1 << 0,
33 NormalizationLevel::Expand => 1 << 1,
34 NormalizationLevel::Drop => 1 << 2,
35 NormalizationLevel::Equiv => 1 << 3,
36 };
37 self.0 & bit != 0
38 }
39
40 pub fn iter(self) -> impl Iterator<Item = NormalizationLevel> {
41 const ORDER: [NormalizationLevel; 4] = [
42 NormalizationLevel::Standard,
43 NormalizationLevel::Expand,
44 NormalizationLevel::Drop,
45 NormalizationLevel::Equiv,
46 ];
47 ORDER.into_iter().filter(move |level| self.contains(*level))
48 }
49}
50
51impl Default for NormalizationLevelSet {
52 fn default() -> Self {
53 Self::empty()
54 }
55}
56
57impl std::ops::BitOr for NormalizationLevelSet {
58 type Output = Self;
59
60 fn bitor(self, rhs: Self) -> Self {
61 self.union(rhs)
62 }
63}
64
65impl std::ops::BitOrAssign for NormalizationLevelSet {
66 fn bitor_assign(&mut self, rhs: Self) {
67 self.0 |= rhs.0;
68 }
69}
70
71impl From<NormalizationLevel> for NormalizationLevelSet {
72 fn from(level: NormalizationLevel) -> Self {
73 match level {
74 NormalizationLevel::Standard => Self::STANDARD,
75 NormalizationLevel::Expand => Self::EXPAND,
76 NormalizationLevel::Drop => Self::DROP,
77 NormalizationLevel::Equiv => Self::EQUIV,
78 }
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;

    /// Every level, in the canonical order `NormalizationLevelSet::iter` yields.
    const ALL_LEVELS: [NormalizationLevel; 4] = [
        NormalizationLevel::Standard,
        NormalizationLevel::Expand,
        NormalizationLevel::Drop,
        NormalizationLevel::Equiv,
    ];

    #[test]
    fn union_combines_bits() {
        let set = NormalizationLevelSet::STANDARD | NormalizationLevelSet::EXPAND;
        assert!(set.contains(NormalizationLevel::Standard));
        assert!(set.contains(NormalizationLevel::Expand));
        assert!(!set.contains(NormalizationLevel::Drop));
    }

    #[test]
    fn iter_emits_in_canonical_order() {
        let set = NormalizationLevelSet::DROP | NormalizationLevelSet::STANDARD;
        assert_eq!(
            set.iter().collect::<Vec<_>>(),
            vec![NormalizationLevel::Standard, NormalizationLevel::Drop]
        );
    }

    #[test]
    fn all_preset_contains_every_level() {
        for level in ALL_LEVELS {
            assert!(NormalizationLevelSet::all().contains(level));
        }
    }

    #[test]
    fn all_preset_iterates_every_level_in_order() {
        assert_eq!(
            NormalizationLevelSet::all().iter().collect::<Vec<_>>(),
            ALL_LEVELS.to_vec()
        );
    }

    #[test]
    fn empty_and_default_contain_nothing() {
        assert_eq!(NormalizationLevelSet::default(), NormalizationLevelSet::empty());
        assert_eq!(NormalizationLevelSet::empty().iter().count(), 0);
        for level in ALL_LEVELS {
            assert!(!NormalizationLevelSet::empty().contains(level));
        }
    }

    #[test]
    fn from_level_is_singleton() {
        for level in ALL_LEVELS {
            let set = NormalizationLevelSet::from(level);
            assert_eq!(set.iter().collect::<Vec<_>>(), vec![level]);
        }
    }

    #[test]
    fn intersects_requires_shared_level() {
        let left = NormalizationLevelSet::STANDARD | NormalizationLevelSet::DROP;
        assert!(left.intersects(NormalizationLevelSet::DROP | NormalizationLevelSet::EQUIV));
        assert!(!left.intersects(NormalizationLevelSet::EXPAND));
        assert!(!left.intersects(NormalizationLevelSet::empty()));
    }

    #[test]
    fn bitor_assign_matches_union() {
        let mut set = NormalizationLevelSet::EXPAND;
        set |= NormalizationLevelSet::EQUIV;
        assert_eq!(
            set,
            NormalizationLevelSet::EXPAND.union(NormalizationLevelSet::EQUIV)
        );
    }
}