Skip to main content

synapse_primitives/
flags.rs

//! Bitfield flag packing for efficient boolean storage
//!
//! This module provides utilities to pack multiple boolean flags into a single
//! integer ([`Flags64`] wraps a `u64`, [`Flags32`] wraps a `u32`), reducing
//! memory usage and protocol overhead.
//!
//! # Examples
//!
//! ```
//! use synapse_primitives::flags::Flags64;
//!
//! let mut flags = Flags64::new();
//! flags.set(0, true);  // require_auth
//! flags.set(1, true);  // idempotent
//! flags.set(2, false); // allow_batch
//!
//! assert!(flags.get(0));
//! assert!(flags.get(1));
//! assert!(!flags.get(2));
//! ```
20
21use std::fmt;
22
/// Packs up to 64 boolean flags into a single `u64`.
///
/// The wrapped integer is the flag storage itself; a defaulted value has
/// every bit (and therefore every flag) cleared.
#[derive(Default, Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub struct Flags64(u64);
26
/// Packs up to 32 boolean flags into a single `u32`.
///
/// The wrapped integer is the flag storage itself; a defaulted value has
/// every bit (and therefore every flag) cleared.
#[derive(Default, Hash, Eq, PartialEq, Copy, Clone, Debug)]
pub struct Flags32(u32);
30
31impl Flags64 {
32    /// Maximum number of flags (0-63)
33    pub const MAX_FLAGS: u8 = 64;
34
35    /// Create empty flags (all false)
36    pub const fn new() -> Self {
37        Self(0)
38    }
39
40    /// Create from raw u64 value
41    pub const fn from_raw(value: u64) -> Self {
42        Self(value)
43    }
44
45    /// Get the raw u64 value
46    pub const fn as_u64(&self) -> u64 {
47        self.0
48    }
49
50    /// Set a flag at the given bit position
51    ///
52    /// # Panics
53    ///
54    /// Panics if `bit` >= 64
55    pub fn set(&mut self, bit: u8, value: bool) {
56        assert!(bit < 64, "Bit position must be < 64");
57        if value {
58            self.0 |= 1u64 << bit;
59        } else {
60            self.0 &= !(1u64 << bit);
61        }
62    }
63
64    /// Get a flag at the given bit position
65    ///
66    /// # Panics
67    ///
68    /// Panics if `bit` >= 64
69    pub const fn get(&self, bit: u8) -> bool {
70        assert!(bit < 64, "Bit position must be < 64");
71        (self.0 & (1u64 << bit)) != 0
72    }
73
74    /// Set multiple flags from an iterator of (bit, value) pairs
75    pub fn set_multiple<I>(&mut self, flags: I)
76    where
77        I: IntoIterator<Item = (u8, bool)>,
78    {
79        for (bit, value) in flags {
80            self.set(bit, value);
81        }
82    }
83
84    /// Count the number of set flags (population count)
85    pub const fn count_set(&self) -> u32 {
86        self.0.count_ones()
87    }
88
89    /// Check if any flags are set
90    pub const fn any(&self) -> bool {
91        self.0 != 0
92    }
93
94    /// Check if all flags are unset
95    pub const fn none(&self) -> bool {
96        self.0 == 0
97    }
98
99    /// Check if all flags are set
100    pub const fn all(&self) -> bool {
101        self.0 == u64::MAX
102    }
103
104    /// Clear all flags
105    pub fn clear(&mut self) {
106        self.0 = 0;
107    }
108
109    /// Merge with another flag set (bitwise OR)
110    pub fn merge(&mut self, other: Flags64) {
111        self.0 |= other.0;
112    }
113
114    /// Intersect with another flag set (bitwise AND)
115    pub fn intersect(&mut self, other: Flags64) {
116        self.0 &= other.0;
117    }
118
119    /// Check if this flag set contains all flags from another set
120    pub const fn contains(&self, other: Flags64) -> bool {
121        (self.0 & other.0) == other.0
122    }
123}
124
125impl Flags32 {
126    /// Maximum number of flags (0-31)
127    pub const MAX_FLAGS: u8 = 32;
128
129    /// Create empty flags (all false)
130    pub const fn new() -> Self {
131        Self(0)
132    }
133
134    /// Create from raw u32 value
135    pub const fn from_raw(value: u32) -> Self {
136        Self(value)
137    }
138
139    /// Get the raw u32 value
140    pub const fn as_u32(&self) -> u32 {
141        self.0
142    }
143
144    /// Set a flag at the given bit position
145    ///
146    /// # Panics
147    ///
148    /// Panics if `bit` >= 32
149    pub fn set(&mut self, bit: u8, value: bool) {
150        assert!(bit < 32, "Bit position must be < 32");
151        if value {
152            self.0 |= 1u32 << bit;
153        } else {
154            self.0 &= !(1u32 << bit);
155        }
156    }
157
158    /// Get a flag at the given bit position
159    ///
160    /// # Panics
161    ///
162    /// Panics if `bit` >= 32
163    pub const fn get(&self, bit: u8) -> bool {
164        assert!(bit < 32, "Bit position must be < 32");
165        (self.0 & (1u32 << bit)) != 0
166    }
167
168    /// Set multiple flags from an iterator
169    pub fn set_multiple<I>(&mut self, flags: I)
170    where
171        I: IntoIterator<Item = (u8, bool)>,
172    {
173        for (bit, value) in flags {
174            self.set(bit, value);
175        }
176    }
177
178    /// Count the number of set flags
179    pub const fn count_set(&self) -> u32 {
180        self.0.count_ones()
181    }
182
183    /// Check if any flags are set
184    pub const fn any(&self) -> bool {
185        self.0 != 0
186    }
187
188    /// Check if all flags are unset
189    pub const fn none(&self) -> bool {
190        self.0 == 0
191    }
192
193    /// Clear all flags
194    pub fn clear(&mut self) {
195        self.0 = 0;
196    }
197
198    /// Merge with another flag set (bitwise OR)
199    pub fn merge(&mut self, other: Flags32) {
200        self.0 |= other.0;
201    }
202
203    /// Intersect with another flag set (bitwise AND)
204    pub fn intersect(&mut self, other: Flags32) {
205        self.0 &= other.0;
206    }
207
208    /// Check if this flag set contains all flags from another set
209    pub const fn contains(&self, other: Flags32) -> bool {
210        (self.0 & other.0) == other.0
211    }
212}
213
214impl From<Flags64> for u64 {
215    fn from(f: Flags64) -> u64 {
216        f.0
217    }
218}
219
220impl From<u64> for Flags64 {
221    fn from(value: u64) -> Flags64 {
222        Flags64(value)
223    }
224}
225
226impl From<Flags32> for u32 {
227    fn from(f: Flags32) -> u32 {
228        f.0
229    }
230}
231
232impl From<u32> for Flags32 {
233    fn from(value: u32) -> Flags32 {
234        Flags32(value)
235    }
236}
237
238impl fmt::Binary for Flags64 {
239    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
240        write!(f, "{:064b}", self.0)
241    }
242}
243
244impl fmt::Binary for Flags32 {
245    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
246        write!(f, "{:032b}", self.0)
247    }
248}
249
/// Macro to define a typed flag set with named flags
///
/// For each `NAME = bit` entry the generated type gets a getter and a setter
/// whose names are the flag identifier lowercased, e.g. `REQUIRE_AUTH`
/// becomes `require_auth()` / `set_require_auth(bool)`. Bit positions are
/// checked against the base type's `MAX_FLAGS` at compile time.
///
/// The base type is expected to be one of this module's containers
/// (`Flags64`/`Flags32`): the expansion relies on its `new`, `get`, `set` and
/// `MAX_FLAGS` items.
///
/// Note: Requires the `paste` crate (as a dependency of the invoking crate)
/// for generating the lowercased method names.
///
/// # Examples
///
/// ```ignore
/// use synapse_primitives::define_flags;
///
/// define_flags! {
///     RequestFlags: Flags64 {
///         REQUIRE_AUTH = 0,
///         IDEMPOTENT = 1,
///         ALLOW_BATCH = 2,
///         COMPRESS_RESPONSE = 3,
///     }
/// }
///
/// let mut flags = RequestFlags::new();
/// flags.set_require_auth(true);
/// flags.set_idempotent(true);
///
/// assert!(flags.require_auth());
/// assert!(flags.idempotent());
/// assert!(!flags.allow_batch());
/// ```
#[macro_export]
macro_rules! define_flags {
    (
        $name:ident: $base:ty {
            $($flag:ident = $bit:expr),* $(,)?
        }
    ) => {
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub struct $name($base);

        // Reject out-of-range bit positions when the crate is compiled rather
        // than on the first runtime `get`/`set`.
        const _: () = {
            $(
                assert!(
                    ($bit) < <$base>::MAX_FLAGS,
                    "define_flags!: flag bit position out of range"
                );
            )*
        };

        ::paste::paste! {
            impl $name {
                /// Creates the flag set with every flag cleared.
                pub const fn new() -> Self {
                    Self(<$base>::new())
                }

                /// Builds the flag set from anything convertible into the base
                /// container (e.g. its raw integer).
                ///
                /// Not `const`: calling the `Into` trait method is not
                /// permitted in a `const fn`.
                pub fn from_raw(value: impl Into<$base>) -> Self {
                    Self(value.into())
                }

                $(
                    pub const fn [<$flag:lower>](&self) -> bool {
                        self.0.get($bit)
                    }

                    pub fn [<set_ $flag:lower>](&mut self, value: bool) {
                        self.0.set($bit, value);
                    }
                )*

                /// Returns the underlying base container.
                pub const fn as_raw(&self) -> $base {
                    self.0
                }
            }
        }

        impl Default for $name {
            fn default() -> Self {
                Self::new()
            }
        }
    };
}
318
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_flags64_basic() {
        let mut flags = Flags64::new();
        assert!(flags.none());
        assert!(!flags.any());

        flags.set(0, true);
        assert!(flags.any());
        assert!(flags.get(0));
        assert!(!flags.get(1));

        flags.set(63, true);
        assert!(flags.get(63));
        assert_eq!(flags.count_set(), 2);
    }

    #[test]
    fn test_flags32_basic() {
        let mut flags = Flags32::new();
        assert!(flags.none());

        flags.set(0, true);
        flags.set(31, true);
        assert_eq!(flags.count_set(), 2);
    }

    #[test]
    fn test_set_multiple() {
        let mut flags = Flags64::new();
        flags.set_multiple([(0, true), (5, true), (10, true)]);

        assert!(flags.get(0));
        assert!(flags.get(5));
        assert!(flags.get(10));
        assert!(!flags.get(1));
        assert_eq!(flags.count_set(), 3);
    }

    #[test]
    fn test_set_false_clears_only_target_bit() {
        let mut flags = Flags64::from_raw(0b111);
        flags.set(1, false);
        assert_eq!(flags.as_u64(), 0b101);

        // Clearing an already-clear bit is a no-op.
        flags.set(1, false);
        assert_eq!(flags.as_u64(), 0b101);
    }

    #[test]
    fn test_merge() {
        let mut flags1 = Flags64::new();
        flags1.set(0, true);
        flags1.set(1, true);

        let mut flags2 = Flags64::new();
        flags2.set(2, true);
        flags2.set(3, true);

        flags1.merge(flags2);
        assert!(flags1.get(0));
        assert!(flags1.get(1));
        assert!(flags1.get(2));
        assert!(flags1.get(3));
        assert_eq!(flags1.count_set(), 4);
    }

    #[test]
    fn test_intersect() {
        let mut flags1 = Flags64::new();
        flags1.set(0, true);
        flags1.set(1, true);
        flags1.set(2, true);

        let mut flags2 = Flags64::new();
        flags2.set(1, true);
        flags2.set(2, true);
        flags2.set(3, true);

        flags1.intersect(flags2);
        assert!(!flags1.get(0));
        assert!(flags1.get(1));
        assert!(flags1.get(2));
        assert!(!flags1.get(3));
    }

    #[test]
    fn test_contains() {
        let mut flags1 = Flags64::new();
        flags1.set(0, true);
        flags1.set(1, true);
        flags1.set(2, true);

        let mut flags2 = Flags64::new();
        flags2.set(0, true);
        flags2.set(1, true);

        assert!(flags1.contains(flags2));

        flags2.set(5, true);
        assert!(!flags1.contains(flags2));
    }

    #[test]
    fn test_clear() {
        let mut flags = Flags64::new();
        flags.set(0, true);
        flags.set(10, true);
        assert!(flags.any());

        flags.clear();
        assert!(flags.none());
    }

    #[test]
    fn test_flags64_all() {
        assert!(Flags64::from_raw(u64::MAX).all());
        assert!(!Flags64::from_raw(u64::MAX - 1).all());
        assert!(!Flags64::new().all());
    }

    #[test]
    #[should_panic(expected = "Bit position must be < 64")]
    fn test_flags64_set_out_of_range_panics() {
        let mut flags = Flags64::new();
        flags.set(64, true);
    }

    #[test]
    #[should_panic(expected = "Bit position must be < 32")]
    fn test_flags32_get_out_of_range_panics() {
        let _ = Flags32::new().get(32);
    }

    #[test]
    fn test_from_raw() {
        let flags = Flags64::from_raw(0b1010);
        assert!(!flags.get(0));
        assert!(flags.get(1));
        assert!(!flags.get(2));
        assert!(flags.get(3));
    }

    #[test]
    fn test_raw_conversions_round_trip() {
        let raw = 0xDEAD_BEEF_u64;
        assert_eq!(u64::from(Flags64::from(raw)), raw);
        assert_eq!(Flags64::from(raw), Flags64::from_raw(raw));

        let raw32 = 0xBEEF_u32;
        assert_eq!(u32::from(Flags32::from(raw32)), raw32);
        assert_eq!(Flags32::from_raw(raw32).as_u32(), raw32);
    }

    #[test]
    fn test_binary_is_zero_padded_to_full_width() {
        let s64 = format!("{:b}", Flags64::from_raw(0b101));
        assert_eq!(s64.len(), 64);
        assert!(s64.ends_with("101"));

        let s32 = format!("{:b}", Flags32::from_raw(1));
        assert_eq!(s32.len(), 32);
        assert!(s32.ends_with('1'));
    }
}