// zerocopy_bitslice/lib.rs

1#![cfg_attr(not(feature = "std"), no_std)]
2
3use shadow_nft_common::array_from_fn;
4
/// Integer type used for bit counts, bit indices, and the serialized length prefix.
///
/// LenType can be any type 32-bit or lower.
type LenType = u32;

/// Number of bytes taken by the little-endian length prefix stored ahead of the bit bytes.
const LEN_SIZE: usize = ::core::mem::size_of::<LenType>();
/// Number of bits held by each byte of the backing buffer.
const BITS_PER_BYTE: usize = 8;
10
11/// A zero copy container that helps read from and write to the bits in a `&[u8]`.
12pub struct ZeroCopyBitSlice<'a> {
13    /// The bytes that contain the bits
14    bit_bytes: &'a mut [u8],
15    /// The number of bits contained in the bytes (possibly not equal to )
16    bit_len: LenType,
17}
18
19/// A zero copy container that helps read from and write to the bits in a `&[u8]`.
20pub struct ZeroCopyBitSliceRead<'a> {
21    /// The bytes that contain the bits
22    bit_bytes: &'a [u8],
23    /// The number of bits contained in the bytes (possibly not equal to )
24    bit_len: LenType,
25}
26
27impl<'a> ZeroCopyBitSlice<'a> {
28    /// Initializes a `ZeroCopyBitSlice` with `num_bits` bits within the buffer provided.
29    /// Requires space for the bit length in addition to bytes for all bits.
30    ///
31    /// We also update the position of the reference to just after the end of the bitslice.
32    ///
33    /// Panics if buffer is not large enough.
34    pub fn intialize_in<'o>(
35        num_bits: LenType,
36        bytes: &'o mut &'a mut [u8],
37    ) -> ZeroCopyBitSlice<'a> {
38        // Calculate number of contiguous bytes required to store this number of bits.
39        // Because the `LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
40        // nonzero size) div_ceil can be done with unchecked math
41        let byte_len = (num_bits as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
42
43        // Calculate required buffer size
44        let required_size = byte_len + LEN_SIZE;
45        if bytes.len() < required_size {
46            panic!("buffer provided for initialization is not large enough")
47        }
48
49        // Write size to buffer
50        bytes[0..LEN_SIZE].copy_from_slice(num_bits.to_le_bytes().as_ref());
51
52        // Split input bytes and update pointer
53        let (bit_bytes, rest) = unsafe {
54            ::core::mem::transmute::<(&'o mut [u8], &'o mut [u8]), (&'a mut [u8], &'a mut [u8])>(
55                bytes.split_at_mut(LEN_SIZE + byte_len),
56            )
57        };
58        *bytes = rest;
59
60        ZeroCopyBitSlice {
61            bit_bytes: &mut bit_bytes[LEN_SIZE..LEN_SIZE + byte_len],
62            bit_len: num_bits,
63        }
64    }
65
66    /// Given a pointer to a `&mut [u8]` with a **previously initialized** `ZeropCopyBitSlice<'a>,
67    /// perform zero copy deserialization on the byte array containing the bits.
68    pub fn from_bytes(bytes: &'a mut [u8]) -> ZeroCopyBitSlice<'a> {
69        // Get lengths
70        // 1) Read length (in bits!) as LenType from a little endian format
71        // 2) Calculate number of contiguous bytes required to store this number of bits.
72        //    Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
73        //    nonzero size) div_ceil can be done with unchecked math
74        let bit_len: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| bytes[i]));
75        let byte_len = (bit_len as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
76
77        // Then get rest
78        let bit_bytes: &mut [u8] = &mut bytes[LEN_SIZE..LEN_SIZE + byte_len];
79
80        ZeroCopyBitSlice { bit_bytes, bit_len }
81    }
82
83    /// Given a pointer to a `&mut [u8]` with a **previously initialized** `ZeropCopyBitSlice<'a>,
84    /// perform zero copy deserialization on the byte array containing the bits.
85    pub fn from_bytes_update<'o>(bytes: &'o mut &'a mut [u8]) -> ZeroCopyBitSlice<'a> {
86        // Get lengths
87        // 1) Read length (in bits!) as LenType from a little endian format
88        // 2) Calculate number of contiguous bytes required to store this number of bits.
89        //    Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
90        //    nonzero size) div_ceil can be done with unchecked math
91        let bit_len: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| bytes[i]));
92        let byte_len = (bit_len as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
93
94        // Then get bit_bytes, rest of bytes, and update pointer
95        let (bit_bytes, rest) = unsafe {
96            ::core::mem::transmute::<(&'o mut [u8], &'o mut [u8]), (&'a mut [u8], &'a mut [u8])>(
97                bytes[LEN_SIZE..].split_at_mut(byte_len),
98            )
99        };
100        *bytes = rest;
101
102        ZeroCopyBitSlice { bit_bytes, bit_len }
103    }
104
105    /// The number of bits contained within this bit slice
106    pub fn bit_len(&self) -> LenType {
107        self.bit_len
108    }
109
110    /// The number bytes used to store these bits
111    pub fn byte_len(&self) -> usize {
112        self.bit_bytes.len()
113    }
114
115    /// Sets the value of the bit at position `idx`.
116    ///
117    /// Panics if out of bounds.
118    pub fn set(&mut self, idx: LenType, value: bool) {
119        // Check bounds
120        if idx >= self.bit_len {
121            panic!("bit index is out of bounds");
122        }
123
124        // Calculate bit location.
125        let byte_idx: usize = idx as usize / BITS_PER_BYTE;
126        let rel_bit_idx: u8 = idx as u8 % BITS_PER_BYTE as u8;
127
128        if value {
129            // If value is true, we set the bit at rel_bit_idx to 1.
130            self.bit_bytes[byte_idx] |= 1 << rel_bit_idx;
131        } else {
132            // If value is false, we set the bit at rel_bit_idx to 0.
133            self.bit_bytes[byte_idx] &= !(1 << rel_bit_idx);
134        }
135    }
136
137    /// Retrieves the value of the bit at position `idx`.
138    ///
139    /// Panics if out of bounds.
140    pub fn get(&self, idx: LenType) -> bool {
141        // Check bounds
142        if idx as LenType >= self.bit_len {
143            panic!("bit index is out of bounds");
144        }
145
146        // Calculate bit location
147        let byte_idx: usize = (idx as usize) / BITS_PER_BYTE;
148        let rel_bit_idx: u8 = (idx as u8) % (BITS_PER_BYTE as u8);
149
150        // Check if the bit at rel_bit_idx is set. This operation masks all other bits and checks if the result is non-zero.
151        (self.bit_bytes[byte_idx] & (1 << rel_bit_idx)) != 0
152    }
153
154    /// Chooses a random zero bit from within the bitslice, setting the bit to 1 and returning the bit's index.
155    ///
156    /// Returns `None` if all bits are set to 1.
157    #[cfg(feature = "choose-random-zero")]
158    pub fn choose_random_zero(&mut self, seed: impl AsRef<[u8]>) -> Option<LenType> {
159        use sha2::{Digest, Sha256};
160
161        // First calculate the number of zeros and handle num_zeros = 0 case
162        let num_zeros: LenType = self.num_zeros();
163        if num_zeros == 0 {
164            return None;
165        }
166
167        // Then calculate a hash-based pseudorandom seed
168        let seed = {
169            let mut hasher = Sha256::new();
170            hasher.update(seed.as_ref());
171            hasher.finalize()
172        };
173
174        // Use the first `LEN_SIZE` bytes from the hash to construct an index that is maybe out of bounds,
175        // the bring it in within bounds. Note this is the index among zeros not the global bit index.
176        let maybe_oob_index: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| seed[i]));
177        let zero_index = maybe_oob_index % num_zeros;
178
179        // Finally, find global index of zero bit and set that bit to 1
180        find_and_set_nth_zero_bit(self.bit_bytes, zero_index, self.bit_len)
181    }
182
183    /// Calculates the number of zero bits in the bitslice.
184    pub fn num_zeros(&self) -> LenType {
185        // Handle zero length case
186        if self.byte_len() == 0 {
187            return 0;
188        }
189
190        // First count the zeros in all complete bytes (all bytes which fully occupy 8 bits,
191        // which is either all of them when bit_len is divisible by 8 or all but the last if not)
192        let num_complete_bytes = (self.bit_len as usize)
193            .checked_div(BITS_PER_BYTE)
194            .expect("already handled zero byte_len case");
195
196        // Initialize accumulator variable
197        let mut zeros = 0;
198
199        // Count all zeros in complete bytes
200        for byte in &self.bit_bytes[0..num_complete_bytes] {
201            zeros += byte.count_zeros();
202        }
203
204        // Add last byte if necessary
205        if num_complete_bytes < self.byte_len() {
206            // Get the last byte
207            let last_byte = *self
208                .bit_bytes
209                .last()
210                .expect("already handled zero byte_len case");
211
212            // If we are in this branch, it is because bit_len is not divisible by 8.
213            // So this arithmetic should not overflow
214            let up_to = (self.bit_len() % 8) as u8 - 1;
215
216            // Count zeros up to the `up_to` bit (zero index)
217            zeros += count_zero_bits_in_byte_up_to_bit(last_byte, up_to);
218        }
219
220        zeros
221    }
222
223    /// Calculates the requires number of bytes to initialize a `ZeroCopyBitSlice`, e.g. via
224    /// `initialize_in`.
225    pub const fn required_space(num_bits: LenType) -> usize {
226        // Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
227        // nonzero size) div_ceil can be done with unchecked math
228        LEN_SIZE + (num_bits as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE
229    }
230}
231
232/// Given a slice of bytes, we find the `nth` zero bit (zero indexed; 0 is first zero bit). Upon
233/// finding it, we set the bit to 1 and return its global bit index.
234///
235/// Instead of iterating through individual bits to find a particular zero, it is possible to
236/// count the number of zero bits one byte at a time. So, to find the `nth` bit, we take this chunk
237/// approach until we reach a byte where the cumulative number of zero bits exceeds the target bit.
238/// Then, we switch over to a bit by bit approach, find the nth zero bit, set it to 1, and return the
239/// global bit position.
240fn find_and_set_nth_zero_bit(bytes: &mut [u8], nth: LenType, bit_len: LenType) -> Option<LenType> {
241    // Initialize zero bit counter
242    let mut zero_count: LenType = 0;
243
244    for (&byte, byte_idx) in bytes.iter().zip(0..) {
245        // Calculate the number of zeros in this byte
246        let byte_zero_count = byte.count_zeros() as LenType;
247
248        // If the cumulative number of zeros exceeds the target zero bit, switch to bit-by-bit approach
249        if zero_count + byte_zero_count > nth {
250            // Iterate through each bit in the byte
251            for rel_bit_idx in 0..8 {
252                // Handle case where the number of bits is not divisible by 8 and return early
253                if byte_idx * 8 + rel_bit_idx == bit_len {
254                    return None;
255                }
256
257                // Check for zero bit
258                if (byte & (1 << rel_bit_idx)) == 0 {
259                    // Check if zero bit is nth zero bit
260                    if zero_count == nth {
261                        // If so, set bit to 1 and return global bit index
262                        bytes[byte_idx as usize] |= 1 << rel_bit_idx;
263                        return Some(byte_idx * 8 + rel_bit_idx);
264                    }
265
266                    // Increment if we find zero bit
267                    zero_count += 1;
268                }
269            }
270        } else {
271            // Batch increment
272            zero_count += byte_zero_count;
273        }
274    }
275
276    None
277}
278impl<'a> ZeroCopyBitSliceRead<'a> {
279    /// Given a pointer to a `&mut [u8]` with a **previously initialized** `ZeropCopyBitSlice<'a>,
280    /// perform zero copy deserialization on the byte array containing the bits.
281    pub unsafe fn from_bytes(bytes: &'a [u8]) -> ZeroCopyBitSliceRead<'a> {
282        // Get lengths
283        // 1) Read length (in bits!) as LenType from a little endian format
284        // 2) Calculate number of contiguous bytes required to store this number of bits.
285        //    Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
286        //    nonzero size) div_ceil can be done with unchecked math
287        let bit_len: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| bytes[i]));
288        let byte_len = (bit_len as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
289
290        // Then get rest
291        let bit_bytes: &[u8] = &bytes[LEN_SIZE..LEN_SIZE + byte_len];
292
293        ZeroCopyBitSliceRead { bit_bytes, bit_len }
294    }
295
296    /// Calculates the number of zero bits in the bitslice.
297    pub fn num_zeros(&self) -> LenType {
298        // Handle zero length case
299        if self.byte_len() == 0 {
300            return 0;
301        }
302
303        // First count the zeros in all complete bytes (all bytes which fully occupy 8 bits,
304        // which is either all of them when bit_len is divisible by 8 or all but the last if not)
305        let num_complete_bytes = (self.bit_len as usize)
306            .checked_div(BITS_PER_BYTE)
307            .expect("already handled zero byte_len case");
308
309        // Initialize accumulator variable
310        let mut zeros = 0;
311
312        // Count all zeros in complete bytes
313        for byte in &self.bit_bytes[0..num_complete_bytes] {
314            zeros += byte.count_zeros();
315        }
316
317        // Add last byte if necessary
318        if num_complete_bytes < self.byte_len() {
319            // Get the last byte
320            let last_byte = *self
321                .bit_bytes
322                .last()
323                .expect("already handled zero byte_len case");
324
325            // If we are in this branch, it is because bit_len is not divisible by 8.
326            // So this arithmetic should not overflow
327            let up_to = (self.bit_len() % 8) as u8 - 1;
328
329            // Count zeros up to the `up_to` bit (zero index)
330            zeros += count_zero_bits_in_byte_up_to_bit(last_byte, up_to);
331        }
332
333        zeros
334    }
335
336    /// The number of bits contained within this bit slice
337    pub fn bit_len(&self) -> LenType {
338        self.bit_len
339    }
340
341    /// The number bytes used to store these bits
342    pub fn byte_len(&self) -> usize {
343        self.bit_bytes.len()
344    }
345}
346
#[test]
fn test_find_and_set_bit() {
    // Some bytes to play with.
    // In this example, the third zero bit (nth = 2) is the 4th bit (index = 3).
    // AFTER having set that one to 1, the 5th zero bit (nth = 4) is the 7th bit (index = 6);
    // AFTER these two are set to 1, the 9th zero bit (nth = 8) is the 15th bit (index = 14);
    //
    // Note: Recall that the bits in an individual byte are read from right to left.
    let mut bytes = [0b_0000_0010, 0b_1010_1010];
    let expected_ = [0b_0100_1010, 0b_1110_1010];
    let exp_idxs = [3, 6, 14];

    let idx1 = find_and_set_nth_zero_bit(&mut bytes, 2, 16).unwrap();
    let idx2 = find_and_set_nth_zero_bit(&mut bytes, 4, 16).unwrap();
    // Save the state for some later tests (arrays of bytes are `Copy`, so this is a snapshot)
    let later_test = bytes;
    let idx3 = find_and_set_nth_zero_bit(&mut bytes, 8, 16).unwrap();

    // Check for correct indices and bytes state
    assert_eq!([idx1, idx2, idx3], exp_idxs);
    assert_eq!(bytes, expected_);

    // Here we test for cases where the number of bits is not divisible by 8.
    //
    // If we had set bit_len to 15, this last operation should still be in bounds,
    // but a bit_len of 14 should result in an out-of-bounds access -> None
    //
    // We reuse the state of the bytes as frozen before idx3 for these two tests
    let mut in_bounds = later_test;
    let idx3_2 = find_and_set_nth_zero_bit(&mut in_bounds, 8, 15).unwrap();
    assert_eq!(idx3, idx3_2);

    let mut out_of_bounds = later_test;
    assert!(find_and_set_nth_zero_bit(&mut out_of_bounds, 8, 14).is_none());
}
381
382#[inline(always)]
383fn count_zero_bits_in_byte_up_to_bit(byte: u8, up_to_bit: u8) -> LenType {
384    (!byte << (7 - up_to_bit)).count_ones()
385}
386
#[test]
fn test_count_zero_up_to() {
    // Sanity check, not full correctness check: looking at bits 0..=up_to_bit can never yield
    // more than up_to_bit + 1 zeros.
    for up_to_bit in 0..=7 {
        for byte in 0..=u8::MAX {
            let count = count_zero_bits_in_byte_up_to_bit(byte, up_to_bit);
            let max_count = up_to_bit as u32 + 1;
            assert!(
                count <= max_count,
                "got {count} zeros (more than {max_count}) for {byte:08b} up to bit {up_to_bit}"
            );
        }
    }

    // Check a few cases
    let few_cases = [0b_0000_0100, 0b_0010_0010, 0b_1010_1010, 0b_1111_0111];
    // Up to index 3 (first 4)
    let expected_3 = [3, 3, 2, 1];
    // up to index 5 (first 6)
    let expected_5 = [5, 4, 3, 1];
    for ((&byte, &want_3), &want_5) in few_cases.iter().zip(&expected_3).zip(&expected_5) {
        assert_eq!(count_zero_bits_in_byte_up_to_bit(byte, 3), want_3);
        assert_eq!(count_zero_bits_in_byte_up_to_bit(byte, 5), want_5);
    }
}
417
#[test]
fn test_deserialization() {
    // For this test, must be a number in 9..=16 because of the two flag bytes
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;

    // There exists some underlying allocation of bytes with:
    // 1) two dummy bytes
    let [d1, d2] = [7, 3];
    // 2) a LenType length
    let [l1, l2, l3, l4] = LenType::to_le_bytes(BIT_LEN);
    // 3) bits living in bytes within their corresponding length
    let [flags1, flags2] = [0b10101010, 0b10000000];
    // 4) another two dummy bytes
    let [d3, d4] = [4, 5];
    // A fixed-size array is used (rather than a `Vec`) so the test needs no allocator.
    let mut bytes: [u8; 10] = [d1, d2, l1, l2, l3, l4, flags1, flags2, d3, d4];

    {
        // Construct it w/ bytes at correct location
        let zcbs = ZeroCopyBitSlice::from_bytes(&mut bytes[2..]);

        // Check for correct lengths and content
        assert_eq!(zcbs.bit_len(), BIT_LEN);
        assert_eq!(zcbs.byte_len(), BYTE_LEN);
        assert_eq!(&zcbs.bit_bytes[..], &[flags1, flags2][..]);
    }

    // Deserialization must leave the underlying bytes untouched
    assert_eq!(bytes, [d1, d2, l1, l2, l3, l4, flags1, flags2, d3, d4]);
}
443
#[test]
fn test_initialization() {
    // For this test, must be a number in 9..=16 so that exactly two bit bytes are needed
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;

    // Initialize a buffer that is exactly large enough: length prefix plus bit bytes
    let mut buffer = [0u8; BYTE_LEN + LEN_SIZE];
    let mut buf_slice = buffer.as_mut_slice();

    // Construct it at the start of the buffer
    let zcbs = ZeroCopyBitSlice::intialize_in(BIT_LEN, &mut buf_slice);
    assert_eq!(zcbs.bit_len(), BIT_LEN);
    assert_eq!(zcbs.byte_len(), BYTE_LEN);
    drop(zcbs);

    // Check buf_slice reference was updated properly: nothing is left after the bitslice
    assert_eq!(buf_slice.len(), 0);

    // Check correct bit_len was written to buffer as a little-endian prefix
    assert_eq!(buffer[..LEN_SIZE], BIT_LEN.to_le_bytes());
}
467
#[test]
fn test_reads_and_writes() {
    // A bit length that is not a multiple of 8, so the last byte is only partially used
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;

    // Construct bit slice directly over zeroed storage
    let mut zcbs = ZeroCopyBitSlice {
        bit_bytes: &mut [0; BYTE_LEN],
        bit_len: BIT_LEN,
    };

    // All bits should be false
    for bit_idx in 0..BIT_LEN {
        assert!(!zcbs.get(bit_idx));
    }

    // First set every even bit to true, then flip everything so that every odd bit is true;
    // after each pass, read all bits back and compare.
    for even_bits_set in [true, false] {
        for bit_idx in 0..BIT_LEN {
            zcbs.set(bit_idx, (bit_idx % 2 == 0) == even_bits_set);
        }
        for bit_idx in 0..BIT_LEN {
            assert_eq!(zcbs.get(bit_idx), (bit_idx % 2 == 0) == even_bits_set);
        }
    }
}
505
#[test]
#[cfg(feature = "choose-random-zero")]
fn test_get_random() {
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
    const TRIALS: u16 = 128;

    // In this test we generate random sequences.
    // We cannot straightforwardly predict the deterministic random sequences.
    //
    // So, we probe correctness with the following checks:
    // 1) We check that >3/4 of the trials are not equal to the sorted sequence
    // 2) The sequences must contain all elements, so we compare with sorted array

    // The expected sorted results for all seeds
    let expected_sorted: [LenType; BIT_LEN as usize] = array_from_fn::from_fn(|i| i as LenType);
    let mut unequal_counter = 0;

    for trial in 0..TRIALS {
        // For every seed initialize a new zcbs
        let mut zcbs = ZeroCopyBitSlice {
            bit_bytes: &mut [0; BYTE_LEN],
            bit_len: BIT_LEN,
        };

        // Turn the trial into a `&[u8]` seed
        let seed = trial.to_le_bytes();

        // Get the random sequence
        let mut sequence: [LenType; BIT_LEN as usize] = array_from_fn::from_fn(|i| {
            // Check that the number of zeros decrements properly
            assert_eq!(zcbs.num_zeros(), BIT_LEN - i as LenType);

            // Get next element in random sequence
            zcbs.choose_random_zero(seed).unwrap()
        });

        // Every bit has been chosen, so nothing is left to pick
        assert_eq!(zcbs.num_zeros(), 0);

        // Increment in unequal for Check #1
        if sequence != expected_sorted {
            unequal_counter += 1;
        }

        // Do check #2 (`sort_unstable` does not allocate; for integers the result is the same
        // as a stable sort)
        sequence.sort_unstable();
        assert_eq!(sequence, expected_sorted);
    }

    // Do check #1
    assert!(unequal_counter > 3 * TRIALS / 4);
}
553    // Do check #1
554    assert!(unequal_counter > 3 * TRIALS / 4);
555}