Skip to main content

rustolio_utils/
bit_vec.rs

1//
2// SPDX-License-Identifier: MPL-2.0
3//
4// Copyright (c) 2026 Tobias Binnewies. All rights reserved.
5//
6// This Source Code Form is subject to the terms of the Mozilla Public
7// License, v. 2.0. If a copy of the MPL was not distributed with this
8// file, You can obtain one at http://mozilla.org/MPL/2.0/.
9//
10
11use std::ops::{BitAnd, BitOr, BitXor, Index, Not};
12
13use crate::prelude::*;
14
/// A fixed-length sequence of bits packed eight to a byte.
///
/// Bit `i` is stored in byte `i / 8`. Within a byte the order is big-endian:
/// the lowest bit index occupies the most significant bit.
///
/// Equality is derived, so it compares the raw storage bytes together with the
/// bit length.
#[derive(Debug, Clone, PartialEq, Eq, Encode, Decode)]
pub struct BitVector {
    // Packed bit storage; `new` allocates `len.div_ceil(8)` bytes.
    data: Box<[u8]>,
    len: usize, // length in bits
}
20
21impl BitVector {
22    /// Creates a new BitVector with the specified number of bits, all initialized to `false`
23    pub fn new(len: usize) -> Self {
24        let bytes_needed = len.div_ceil(8); // Ceiling division
25        BitVector {
26            data: vec![0; bytes_needed].into_boxed_slice(),
27            len,
28        }
29    }
30
31    /// Creates a BitVector from a binary string (e.g., "101010")
32    pub fn from_binary_str(s: &str) -> crate::Result<Self> {
33        let mut bv = BitVector::new(s.len());
34        for (i, c) in s.chars().enumerate() {
35            match c {
36                '0' => bv.set(i, false),
37                '1' => bv.set(i, true),
38                _ => return Err(crate::Error::new("Invalid character in binary string")),
39            }
40        }
41        Ok(bv)
42    }
43
44    /// Returns the length in bits
45    #[inline]
46    pub fn len(&self) -> usize {
47        self.len
48    }
49
50    /// Returns the number of bytes used for storage
51    pub fn bytes_len(&self) -> usize {
52        self.data.len()
53    }
54
55    /// Gets the value of the bit at the specified index
56    /// Panics if index is out of bounds
57    pub fn get(&self, index: usize) -> bool {
58        assert!(index < self.len, "Index out of bounds");
59        let byte_index = index / 8;
60        let bit_index = 7 - (index % 8); // Using big-endian bit order within bytes
61        (self.data[byte_index] >> bit_index) & 1 == 1
62    }
63
64    /// Sets the value of the bit at the specified index
65    /// Panics if index is out of bounds
66    pub fn set(&mut self, index: usize, value: bool) {
67        assert!(index < self.len, "Index out of bounds");
68        let byte_index = index / 8;
69        let bit_index = 7 - (index % 8); // Using big-endian bit order within bytes
70
71        if value {
72            // Set the bit to 1
73            self.data[byte_index] |= 1 << bit_index;
74        } else {
75            // Set the bit to 0
76            self.data[byte_index] &= !(1 << bit_index);
77        }
78    }
79
80    /// Toggles (flips) the bit at the specified index
81    pub fn toggle(&mut self, index: usize) {
82        assert!(index < self.len, "Index out of bounds");
83        let byte_index = index / 8;
84        let bit_index = 7 - (index % 8);
85        self.data[byte_index] ^= 1 << bit_index;
86    }
87
88    /// Sets all bits to `false`
89    pub fn clear(&mut self) {
90        for byte in &mut self.data {
91            *byte = 0;
92        }
93    }
94
95    /// Sets all bits to `true`
96    pub fn set_all(&mut self) {
97        for byte in &mut self.data {
98            *byte = 0xFF;
99        }
100    }
101
102    /// Counts the number of bits set to `true`
103    pub fn count_ones(&self) -> usize {
104        if self.len == 0 {
105            return 0;
106        }
107        let last_bytes_used = self.len % 8;
108
109        let last_bytes_ones = if last_bytes_used == 0 {
110            self.data.last().unwrap().count_ones()
111        } else {
112            (self.data.last().unwrap() >> (8 - last_bytes_used)).count_ones()
113        };
114
115        last_bytes_ones as usize
116            + self
117                .data
118                .iter()
119                .rev()
120                .skip(1)
121                .map(|&byte| byte.count_ones() as usize)
122                .sum::<usize>()
123    }
124
125    /// Counts the number of bits set to `false`
126    pub fn count_zeros(&self) -> usize {
127        self.len - self.count_ones()
128    }
129
130    /// Returns true if all bits are `true`
131    pub fn all(&self) -> bool {
132        self.count_ones() == self.len
133    }
134
135    /// Returns true if any bit is `true`
136    pub fn any(&self) -> bool {
137        !self.none()
138    }
139
140    /// Returns true if all bits are `false`
141    pub fn none(&self) -> bool {
142        !self.data.iter().any(|&byte| byte != 0)
143    }
144
145    /// Returns the underlying byte vector
146    pub fn as_bytes(&self) -> &[u8] {
147        &self.data
148    }
149
150    /// Returns an iterator over the bits
151    pub fn iter<'a>(&'a self) -> BitVectorIter<'a> {
152        BitVectorIter { bv: self, index: 0 }
153    }
154
155    /// Performs a bitwise AND with another bitvector
156    /// Panics if bitvectors have different lengths
157    pub fn and(&mut self, other: &BitVector) {
158        assert_eq!(self.len, other.len, "Bitvectors must have the same length");
159        for i in 0..self.data.len() {
160            self.data[i] &= other.data[i];
161        }
162    }
163
164    /// Performs a bitwise OR with another bitvector
165    /// Panics if bitvectors have different lengths
166    pub fn or(&mut self, other: &BitVector) {
167        assert_eq!(self.len, other.len, "Bitvectors must have the same length");
168        for i in 0..self.data.len() {
169            self.data[i] |= other.data[i];
170        }
171    }
172
173    /// Performs a bitwise XOR with another bitvector
174    /// Panics if bitvectors have different lengths
175    pub fn xor(&mut self, other: &BitVector) {
176        assert_eq!(self.len, other.len, "Bitvectors must have the same length");
177        for i in 0..self.data.len() {
178            self.data[i] ^= other.data[i];
179        }
180    }
181
182    /// Returns the bitvector as a binary string
183    pub fn to_binary_string(&self) -> String {
184        self.iter().map(|b| if b { '1' } else { '0' }).collect()
185    }
186}
187
188// Implement indexing with [] operator
189impl Index<usize> for BitVector {
190    type Output = bool;
191
192    fn index(&self, index: usize) -> &Self::Output {
193        if self.get(index) {
194            return &true;
195        }
196        &false
197    }
198}
199
200// Iterator for BitVector
201pub struct BitVectorIter<'a> {
202    bv: &'a BitVector,
203    index: usize,
204}
205
206impl<'a> Iterator for BitVectorIter<'a> {
207    type Item = bool;
208
209    fn next(&mut self) -> Option<Self::Item> {
210        if self.index < self.bv.len() {
211            let result = self.bv.get(self.index);
212            self.index += 1;
213            Some(result)
214        } else {
215            None
216        }
217    }
218
219    fn size_hint(&self) -> (usize, Option<usize>) {
220        let remaining = self.bv.len() - self.index;
221        (remaining, Some(remaining))
222    }
223}
224
225impl<'a> ExactSizeIterator for BitVectorIter<'a> {}
226
227// Implement bitwise operators
228impl BitAnd for &BitVector {
229    type Output = BitVector;
230
231    fn bitand(self, other: Self) -> BitVector {
232        assert_eq!(self.len, other.len, "Bitvectors must have the same length");
233        let mut result = self.clone();
234        result.and(other);
235        result
236    }
237}
238
239impl BitOr for &BitVector {
240    type Output = BitVector;
241
242    fn bitor(self, other: Self) -> BitVector {
243        assert_eq!(self.len, other.len, "Bitvectors must have the same length");
244        let mut result = self.clone();
245        result.or(other);
246        result
247    }
248}
249
250impl BitXor for &BitVector {
251    type Output = BitVector;
252
253    fn bitxor(self, other: Self) -> BitVector {
254        assert_eq!(self.len, other.len, "Bitvectors must have the same length");
255        let mut result = self.clone();
256        result.xor(other);
257        result
258    }
259}
260
261impl Not for &BitVector {
262    type Output = BitVector;
263
264    fn not(self) -> BitVector {
265        let mut result = self.clone();
266        for byte in &mut result.data {
267            *byte = !*byte;
268        }
269        result
270    }
271}
272
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let bv = BitVector::new(10);
        assert_eq!(bv.len(), 10);
        assert_eq!(bv.bytes_len(), 2); // 10 bits = 2 bytes
        assert!(bv.none()); // All bits should be false
    }

    #[test]
    fn test_from_str() {
        let bv = BitVector::from_binary_str("1010101011001100").unwrap();
        assert_eq!(bv.len(), 16);
        assert_eq!(bv.bytes_len(), 2);
        assert!(bv.get(0));
        assert!(bv[0]);
        assert!(!bv.get(1));
        assert!(!bv.get(7));
        assert!(bv.get(8));
        assert!(!bv[15]);
    }

    #[test]
    fn test_from_str_rejects_invalid_characters() {
        assert!(BitVector::from_binary_str("10201").is_err());
        assert!(BitVector::from_binary_str("1 0").is_err());
        // Multi-byte characters are never valid binary digits.
        assert!(BitVector::from_binary_str("10\u{e9}").is_err());
    }

    #[test]
    fn test_from_str_empty() {
        let bv = BitVector::from_binary_str("").unwrap();
        assert_eq!(bv.len(), 0);
        assert_eq!(bv.bytes_len(), 0);
        assert_eq!(bv.to_binary_string(), "");
    }

    #[test]
    fn test_get_set() {
        let mut bv = BitVector::new(20);

        // Initially all false
        for i in 0..20 {
            assert!(!bv.get(i));
        }

        // Set some bits
        bv.set(0, true);
        bv.set(5, true);
        bv.set(19, true);

        // Check set bits
        assert!(bv.get(0));
        assert!(bv.get(5));
        assert!(bv.get(19));

        // Check unset bits
        assert!(!bv.get(1));
        assert!(!bv.get(18));

        // Overwrite a bit
        bv.set(5, false);
        assert!(!bv.get(5));
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_get_out_of_bounds() {
        let bv = BitVector::new(10);
        bv.get(10);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_set_out_of_bounds() {
        let mut bv = BitVector::new(10);
        bv.set(10, true);
    }

    #[test]
    #[should_panic(expected = "Index out of bounds")]
    fn test_toggle_out_of_bounds() {
        let mut bv = BitVector::new(10);
        bv.toggle(10);
    }

    #[test]
    fn test_toggle() {
        let mut bv = BitVector::new(8);
        bv.set(3, true);

        bv.toggle(3);
        assert!(!bv.get(3));

        bv.toggle(3);
        assert!(bv.get(3));

        bv.toggle(0);
        assert!(bv.get(0));
    }

    #[test]
    fn test_clear() {
        let mut bv = BitVector::new(8);
        assert!(bv.none());

        bv.toggle(1);
        assert!(bv.any());

        bv.clear();
        assert!(bv.none());
        assert_eq!(bv.count_ones(), 0);
    }

    #[test]
    fn test_count_ones_zeros() {
        let bv = BitVector::from_binary_str("10101010").unwrap();
        assert_eq!(bv.count_ones(), 4);
        assert_eq!(bv.count_zeros(), 4);

        let bv2 = BitVector::new(100);
        assert_eq!(bv2.count_ones(), 0);
        assert_eq!(bv2.count_zeros(), 100);
    }

    #[test]
    fn test_all_any_none() {
        let mut bv = BitVector::new(8);
        assert!(bv.none());
        assert!(!bv.any());
        assert!(!bv.all());

        bv.set(3, true);
        assert!(!bv.none());
        assert!(bv.any());
        assert!(!bv.all());

        bv.set_all();
        assert!(!bv.none());
        assert!(bv.any());
        assert!(bv.all());
    }

    #[test]
    fn test_iter() {
        let bv = BitVector::from_binary_str("1010").unwrap();
        let bits: Vec<bool> = bv.iter().collect();
        assert_eq!(bits, vec![true, false, true, false]);

        // Test with exact size iterator
        let mut iter = bv.iter();
        assert_eq!(iter.len(), 4);
        iter.next();
        assert_eq!(iter.len(), 3);
    }

    #[test]
    fn test_bitwise_operations() {
        let bv1 = BitVector::from_binary_str("1100").unwrap();
        let bv2 = BitVector::from_binary_str("1010").unwrap();

        // AND
        let and = &bv1 & &bv2;
        assert_eq!(and.to_binary_string(), "1000");

        // OR
        let or = &bv1 | &bv2;
        assert_eq!(or.to_binary_string(), "1110");

        // XOR
        let xor = &bv1 ^ &bv2;
        assert_eq!(xor.to_binary_string(), "0110");

        // NOT
        let not = !&bv1;
        assert_eq!(not.to_binary_string(), "0011");

        // In-place operations
        let mut bv3 = bv1.clone();
        bv3.and(&bv2);
        assert_eq!(bv3.to_binary_string(), "1000");
    }

    #[test]
    fn test_in_place_or_xor() {
        let bv1 = BitVector::from_binary_str("1100").unwrap();
        let bv2 = BitVector::from_binary_str("1010").unwrap();

        let mut or = bv1.clone();
        or.or(&bv2);
        assert_eq!(or.to_binary_string(), "1110");

        let mut xor = bv1.clone();
        xor.xor(&bv2);
        assert_eq!(xor.to_binary_string(), "0110");
    }

    #[test]
    #[should_panic(expected = "Bitvectors must have the same length")]
    fn test_bitwise_length_mismatch() {
        let bv1 = BitVector::new(4);
        let bv2 = BitVector::new(5);
        let _ = &bv1 & &bv2;
    }

    #[test]
    fn test_to_binary_string() {
        let bv = BitVector::from_binary_str("10101100").unwrap();
        assert_eq!(bv.to_binary_string(), "10101100");

        let bv2 = BitVector::new(3);
        assert_eq!(bv2.to_binary_string(), "000");
    }

    #[test]
    fn test_partial_byte_handling() {
        // Test with 12 bits (1.5 bytes)
        let mut bv = BitVector::new(12);
        bv.set_all();
        // Only the 12 logical bits count, even though storage is 16 bits wide.
        assert_eq!(bv.count_ones(), 12);

        // `set_all` writes 0xFF into every storage byte, so the 4 padding bits
        // of the second byte are set as well; this pins the raw byte layout.
        let bytes = bv.as_bytes();
        assert_eq!(bytes.len(), 2);
        assert_eq!(bytes[1], 0xFF);
    }

    #[test]
    fn test_clone() {
        let bv1 = BitVector::from_binary_str("101010").unwrap();
        let bv2 = bv1.clone();
        assert_eq!(bv1, bv2);

        // Modify clone and ensure original unchanged
        let mut bv3 = bv2;
        bv3.set(0, false);
        assert_ne!(bv1, bv3);
    }

    #[test]
    fn test_edge_cases() {
        // Empty bitvector
        let bv = BitVector::new(0);
        assert_eq!(bv.len(), 0);
        assert_eq!(bv.bytes_len(), 0);
        assert!(bv.none());
        assert!(!bv.any());
        assert!(bv.all()); // Vacuous truth

        // Single bit
        let mut bv2 = BitVector::new(1);
        bv2.set(0, true);
        assert_eq!(bv2.len(), 1);
        assert_eq!(bv2.bytes_len(), 1);
        assert!(bv2.any());
        assert!(bv2.all());

        // Exactly 8 bits (1 full byte)
        let bv3 = BitVector::new(8);
        assert_eq!(bv3.bytes_len(), 1);

        // Exactly 9 bits (needs 2 bytes)
        let bv4 = BitVector::new(9);
        assert_eq!(bv4.bytes_len(), 2);
    }

    #[test]
    fn test_bitwise_not_with_partial_bytes() {
        let bv = BitVector::from_binary_str("111100").unwrap(); // 6 bits
        let not_bv = !&bv;

        // NOT of 111100 is 000011. The string has one character per logical
        // bit, so the 2 padding bits of the storage byte never appear in it.
        assert_eq!(not_bv.to_binary_string(), "000011");
    }
}