zerocopy_bitslice/lib.rs
1#![cfg_attr(not(feature = "std"), no_std)]
2
3use shadow_nft_common::array_from_fn;
4
/// Number of bits stored in a single byte.
const BITS_PER_BYTE: usize = 8;

/// Integer type used for the bit-length prefix and for bit indices.
///
/// NOTE(review): the original note says any integer type of 32 bits or fewer may be used here;
/// confirm the rest of the file compiles with a narrower type before changing it.
type LenType = u32;

/// Size in bytes of a serialized [`LenType`] length prefix.
const LEN_SIZE: usize = ::core::mem::size_of::<LenType>();
10
11/// A zero copy container that helps read from and write to the bits in a `&[u8]`.
12pub struct ZeroCopyBitSlice<'a> {
13 /// The bytes that contain the bits
14 bit_bytes: &'a mut [u8],
15 /// The number of bits contained in the bytes (possibly not equal to )
16 bit_len: LenType,
17}
18
19/// A zero copy container that helps read from and write to the bits in a `&[u8]`.
20pub struct ZeroCopyBitSliceRead<'a> {
21 /// The bytes that contain the bits
22 bit_bytes: &'a [u8],
23 /// The number of bits contained in the bytes (possibly not equal to )
24 bit_len: LenType,
25}
26
27impl<'a> ZeroCopyBitSlice<'a> {
28 /// Initializes a `ZeroCopyBitSlice` with `num_bits` bits within the buffer provided.
29 /// Requires space for the bit length in addition to bytes for all bits.
30 ///
31 /// We also update the position of the reference to just after the end of the bitslice.
32 ///
33 /// Panics if buffer is not large enough.
34 pub fn intialize_in<'o>(
35 num_bits: LenType,
36 bytes: &'o mut &'a mut [u8],
37 ) -> ZeroCopyBitSlice<'a> {
38 // Calculate number of contiguous bytes required to store this number of bits.
39 // Because the `LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
40 // nonzero size) div_ceil can be done with unchecked math
41 let byte_len = (num_bits as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
42
43 // Calculate required buffer size
44 let required_size = byte_len + LEN_SIZE;
45 if bytes.len() < required_size {
46 panic!("buffer provided for initialization is not large enough")
47 }
48
49 // Write size to buffer
50 bytes[0..LEN_SIZE].copy_from_slice(num_bits.to_le_bytes().as_ref());
51
52 // Split input bytes and update pointer
53 let (bit_bytes, rest) = unsafe {
54 ::core::mem::transmute::<(&'o mut [u8], &'o mut [u8]), (&'a mut [u8], &'a mut [u8])>(
55 bytes.split_at_mut(LEN_SIZE + byte_len),
56 )
57 };
58 *bytes = rest;
59
60 ZeroCopyBitSlice {
61 bit_bytes: &mut bit_bytes[LEN_SIZE..LEN_SIZE + byte_len],
62 bit_len: num_bits,
63 }
64 }
65
66 /// Given a pointer to a `&mut [u8]` with a **previously initialized** `ZeropCopyBitSlice<'a>,
67 /// perform zero copy deserialization on the byte array containing the bits.
68 pub fn from_bytes(bytes: &'a mut [u8]) -> ZeroCopyBitSlice<'a> {
69 // Get lengths
70 // 1) Read length (in bits!) as LenType from a little endian format
71 // 2) Calculate number of contiguous bytes required to store this number of bits.
72 // Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
73 // nonzero size) div_ceil can be done with unchecked math
74 let bit_len: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| bytes[i]));
75 let byte_len = (bit_len as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
76
77 // Then get rest
78 let bit_bytes: &mut [u8] = &mut bytes[LEN_SIZE..LEN_SIZE + byte_len];
79
80 ZeroCopyBitSlice { bit_bytes, bit_len }
81 }
82
83 /// Given a pointer to a `&mut [u8]` with a **previously initialized** `ZeropCopyBitSlice<'a>,
84 /// perform zero copy deserialization on the byte array containing the bits.
85 pub fn from_bytes_update<'o>(bytes: &'o mut &'a mut [u8]) -> ZeroCopyBitSlice<'a> {
86 // Get lengths
87 // 1) Read length (in bits!) as LenType from a little endian format
88 // 2) Calculate number of contiguous bytes required to store this number of bits.
89 // Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
90 // nonzero size) div_ceil can be done with unchecked math
91 let bit_len: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| bytes[i]));
92 let byte_len = (bit_len as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
93
94 // Then get bit_bytes, rest of bytes, and update pointer
95 let (bit_bytes, rest) = unsafe {
96 ::core::mem::transmute::<(&'o mut [u8], &'o mut [u8]), (&'a mut [u8], &'a mut [u8])>(
97 bytes[LEN_SIZE..].split_at_mut(byte_len),
98 )
99 };
100 *bytes = rest;
101
102 ZeroCopyBitSlice { bit_bytes, bit_len }
103 }
104
105 /// The number of bits contained within this bit slice
106 pub fn bit_len(&self) -> LenType {
107 self.bit_len
108 }
109
110 /// The number bytes used to store these bits
111 pub fn byte_len(&self) -> usize {
112 self.bit_bytes.len()
113 }
114
115 /// Sets the value of the bit at position `idx`.
116 ///
117 /// Panics if out of bounds.
118 pub fn set(&mut self, idx: LenType, value: bool) {
119 // Check bounds
120 if idx >= self.bit_len {
121 panic!("bit index is out of bounds");
122 }
123
124 // Calculate bit location.
125 let byte_idx: usize = idx as usize / BITS_PER_BYTE;
126 let rel_bit_idx: u8 = idx as u8 % BITS_PER_BYTE as u8;
127
128 if value {
129 // If value is true, we set the bit at rel_bit_idx to 1.
130 self.bit_bytes[byte_idx] |= 1 << rel_bit_idx;
131 } else {
132 // If value is false, we set the bit at rel_bit_idx to 0.
133 self.bit_bytes[byte_idx] &= !(1 << rel_bit_idx);
134 }
135 }
136
137 /// Retrieves the value of the bit at position `idx`.
138 ///
139 /// Panics if out of bounds.
140 pub fn get(&self, idx: LenType) -> bool {
141 // Check bounds
142 if idx as LenType >= self.bit_len {
143 panic!("bit index is out of bounds");
144 }
145
146 // Calculate bit location
147 let byte_idx: usize = (idx as usize) / BITS_PER_BYTE;
148 let rel_bit_idx: u8 = (idx as u8) % (BITS_PER_BYTE as u8);
149
150 // Check if the bit at rel_bit_idx is set. This operation masks all other bits and checks if the result is non-zero.
151 (self.bit_bytes[byte_idx] & (1 << rel_bit_idx)) != 0
152 }
153
154 /// Chooses a random zero bit from within the bitslice, setting the bit to 1 and returning the bit's index.
155 ///
156 /// Returns `None` if all bits are set to 1.
157 #[cfg(feature = "choose-random-zero")]
158 pub fn choose_random_zero(&mut self, seed: impl AsRef<[u8]>) -> Option<LenType> {
159 use sha2::{Digest, Sha256};
160
161 // First calculate the number of zeros and handle num_zeros = 0 case
162 let num_zeros: LenType = self.num_zeros();
163 if num_zeros == 0 {
164 return None;
165 }
166
167 // Then calculate a hash-based pseudorandom seed
168 let seed = {
169 let mut hasher = Sha256::new();
170 hasher.update(seed.as_ref());
171 hasher.finalize()
172 };
173
174 // Use the first `LEN_SIZE` bytes from the hash to construct an index that is maybe out of bounds,
175 // the bring it in within bounds. Note this is the index among zeros not the global bit index.
176 let maybe_oob_index: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| seed[i]));
177 let zero_index = maybe_oob_index % num_zeros;
178
179 // Finally, find global index of zero bit and set that bit to 1
180 find_and_set_nth_zero_bit(self.bit_bytes, zero_index, self.bit_len)
181 }
182
183 /// Calculates the number of zero bits in the bitslice.
184 pub fn num_zeros(&self) -> LenType {
185 // Handle zero length case
186 if self.byte_len() == 0 {
187 return 0;
188 }
189
190 // First count the zeros in all complete bytes (all bytes which fully occupy 8 bits,
191 // which is either all of them when bit_len is divisible by 8 or all but the last if not)
192 let num_complete_bytes = (self.bit_len as usize)
193 .checked_div(BITS_PER_BYTE)
194 .expect("already handled zero byte_len case");
195
196 // Initialize accumulator variable
197 let mut zeros = 0;
198
199 // Count all zeros in complete bytes
200 for byte in &self.bit_bytes[0..num_complete_bytes] {
201 zeros += byte.count_zeros();
202 }
203
204 // Add last byte if necessary
205 if num_complete_bytes < self.byte_len() {
206 // Get the last byte
207 let last_byte = *self
208 .bit_bytes
209 .last()
210 .expect("already handled zero byte_len case");
211
212 // If we are in this branch, it is because bit_len is not divisible by 8.
213 // So this arithmetic should not overflow
214 let up_to = (self.bit_len() % 8) as u8 - 1;
215
216 // Count zeros up to the `up_to` bit (zero index)
217 zeros += count_zero_bits_in_byte_up_to_bit(last_byte, up_to);
218 }
219
220 zeros
221 }
222
223 /// Calculates the requires number of bytes to initialize a `ZeroCopyBitSlice`, e.g. via
224 /// `initialize_in`.
225 pub const fn required_space(num_bits: LenType) -> usize {
226 // Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
227 // nonzero size) div_ceil can be done with unchecked math
228 LEN_SIZE + (num_bits as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE
229 }
230}
231
232/// Given a slice of bytes, we find the `nth` zero bit (zero indexed; 0 is first zero bit). Upon
233/// finding it, we set the bit to 1 and return its global bit index.
234///
235/// Instead of iterating through individual bits to find a particular zero, it is possible to
236/// count the number of zero bits one byte at a time. So, to find the `nth` bit, we take this chunk
237/// approach until we reach a byte where the cumulative number of zero bits exceeds the target bit.
238/// Then, we switch over to a bit by bit approach, find the nth zero bit, set it to 1, and return the
239/// global bit position.
240fn find_and_set_nth_zero_bit(bytes: &mut [u8], nth: LenType, bit_len: LenType) -> Option<LenType> {
241 // Initialize zero bit counter
242 let mut zero_count: LenType = 0;
243
244 for (&byte, byte_idx) in bytes.iter().zip(0..) {
245 // Calculate the number of zeros in this byte
246 let byte_zero_count = byte.count_zeros() as LenType;
247
248 // If the cumulative number of zeros exceeds the target zero bit, switch to bit-by-bit approach
249 if zero_count + byte_zero_count > nth {
250 // Iterate through each bit in the byte
251 for rel_bit_idx in 0..8 {
252 // Handle case where the number of bits is not divisible by 8 and return early
253 if byte_idx * 8 + rel_bit_idx == bit_len {
254 return None;
255 }
256
257 // Check for zero bit
258 if (byte & (1 << rel_bit_idx)) == 0 {
259 // Check if zero bit is nth zero bit
260 if zero_count == nth {
261 // If so, set bit to 1 and return global bit index
262 bytes[byte_idx as usize] |= 1 << rel_bit_idx;
263 return Some(byte_idx * 8 + rel_bit_idx);
264 }
265
266 // Increment if we find zero bit
267 zero_count += 1;
268 }
269 }
270 } else {
271 // Batch increment
272 zero_count += byte_zero_count;
273 }
274 }
275
276 None
277}
278impl<'a> ZeroCopyBitSliceRead<'a> {
279 /// Given a pointer to a `&mut [u8]` with a **previously initialized** `ZeropCopyBitSlice<'a>,
280 /// perform zero copy deserialization on the byte array containing the bits.
281 pub unsafe fn from_bytes(bytes: &'a [u8]) -> ZeroCopyBitSliceRead<'a> {
282 // Get lengths
283 // 1) Read length (in bits!) as LenType from a little endian format
284 // 2) Calculate number of contiguous bytes required to store this number of bits.
285 // Because the LenType` is <= 32 bits (i.e. `LenType::MAX < usize::MAX` and a `usize` has
286 // nonzero size) div_ceil can be done with unchecked math
287 let bit_len: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| bytes[i]));
288 let byte_len = (bit_len as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
289
290 // Then get rest
291 let bit_bytes: &[u8] = &bytes[LEN_SIZE..LEN_SIZE + byte_len];
292
293 ZeroCopyBitSliceRead { bit_bytes, bit_len }
294 }
295
296 /// Calculates the number of zero bits in the bitslice.
297 pub fn num_zeros(&self) -> LenType {
298 // Handle zero length case
299 if self.byte_len() == 0 {
300 return 0;
301 }
302
303 // First count the zeros in all complete bytes (all bytes which fully occupy 8 bits,
304 // which is either all of them when bit_len is divisible by 8 or all but the last if not)
305 let num_complete_bytes = (self.bit_len as usize)
306 .checked_div(BITS_PER_BYTE)
307 .expect("already handled zero byte_len case");
308
309 // Initialize accumulator variable
310 let mut zeros = 0;
311
312 // Count all zeros in complete bytes
313 for byte in &self.bit_bytes[0..num_complete_bytes] {
314 zeros += byte.count_zeros();
315 }
316
317 // Add last byte if necessary
318 if num_complete_bytes < self.byte_len() {
319 // Get the last byte
320 let last_byte = *self
321 .bit_bytes
322 .last()
323 .expect("already handled zero byte_len case");
324
325 // If we are in this branch, it is because bit_len is not divisible by 8.
326 // So this arithmetic should not overflow
327 let up_to = (self.bit_len() % 8) as u8 - 1;
328
329 // Count zeros up to the `up_to` bit (zero index)
330 zeros += count_zero_bits_in_byte_up_to_bit(last_byte, up_to);
331 }
332
333 zeros
334 }
335
336 /// The number of bits contained within this bit slice
337 pub fn bit_len(&self) -> LenType {
338 self.bit_len
339 }
340
341 /// The number bytes used to store these bits
342 pub fn byte_len(&self) -> usize {
343 self.bit_bytes.len()
344 }
345}
346
#[test]
fn test_find_and_set_bit() {
    // Starting bytes. Bits inside a byte are numbered from the right (least significant first).
    //
    // - the third zero bit (nth = 2) sits at global index 3;
    // - once that is set, the fifth zero bit (nth = 4) sits at global index 6;
    // - once both are set, the ninth zero bit (nth = 8) sits at global index 14.
    let mut bits = [0b_0000_0010, 0b_1010_1010];

    let first = find_and_set_nth_zero_bit(&mut bits, 2, 16).unwrap();
    let second = find_and_set_nth_zero_bit(&mut bits, 4, 16).unwrap();
    // Freeze the state before the third lookup for the bit-length checks further down.
    let snapshot = bits;
    let third = find_and_set_nth_zero_bit(&mut bits, 8, 16).unwrap();

    // The reported indices and the resulting bytes must both be right.
    assert_eq!(first, 3);
    assert_eq!(second, 6);
    assert_eq!(third, 14);
    assert_eq!(bits, [0b_0100_1010, 0b_1110_1010]);

    // Bit lengths that are not a multiple of 8, replayed on the frozen state:
    // with 15 bits, index 14 is still in bounds and must be found again ...
    let mut still_fits = snapshot;
    let third_again = find_and_set_nth_zero_bit(&mut still_fits, 8, 15).unwrap();
    assert_eq!(third, third_again);

    // ... while with 14 bits, index 14 is out of bounds, so nothing is found.
    let mut too_short = snapshot;
    assert!(find_and_set_nth_zero_bit(&mut too_short, 8, 14).is_none());
}
381
382#[inline(always)]
383fn count_zero_bits_in_byte_up_to_bit(byte: u8, up_to_bit: u8) -> LenType {
384 (!byte << (7 - up_to_bit)).count_ones()
385}
386
#[test]
fn test_count_zero_up_to() {
    // Sanity check, not a full correctness check: the count can never exceed the number of
    // bit positions that were inspected (`up_to_bit + 1`).
    for up_to_bit in 0..=7u8 {
        let inspected = LenType::from(up_to_bit) + 1;
        for byte in 0..=u8::MAX {
            let count = count_zero_bits_in_byte_up_to_bit(byte, up_to_bit);
            assert!(
                count <= inspected,
                "got {count} zero bits but only {inspected} bits were inspected for {byte:08b} up to bit {up_to_bit}"
            );
        }
    }

    // Spot checks: (byte, zeros up to index 3 i.e. lowest 4 bits, zeros up to index 5 i.e.
    // lowest 6 bits).
    let cases: [(u8, LenType, LenType); 4] = [
        (0b_0000_0100, 3, 5),
        (0b_0010_0010, 3, 4),
        (0b_1010_1010, 2, 3),
        (0b_1111_0111, 1, 1),
    ];
    for &(byte, zeros_up_to_3, zeros_up_to_5) in cases.iter() {
        assert_eq!(count_zero_bits_in_byte_up_to_bit(byte, 3), zeros_up_to_3);
        assert_eq!(count_zero_bits_in_byte_up_to_bit(byte, 5), zeros_up_to_5);
    }
}
417
#[test]
fn test_deserialization() {
    // Two flag bytes back the bits in this test, so BIT_LEN must lie in 9..=16.
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;

    // There exists some underlying run of bytes with:
    // 1) two dummy bytes
    let [d1, d2] = [7, 3];
    // 2) a LenType length, little endian
    let [l1, l2, l3, l4] = LenType::to_le_bytes(BIT_LEN);
    // 3) bits living in bytes within their corresponding length
    let [flags1, flags2] = [0b10101010, 0b10000000];
    // 4) another two dummy bytes
    let [d3, d4] = [4, 5];
    // A fixed-size array is enough here and keeps the test free of heap types, which are not
    // in scope when the crate is built as `no_std`.
    let mut bytes: [u8; 2 + LEN_SIZE + BYTE_LEN + 2] =
        [d1, d2, l1, l2, l3, l4, flags1, flags2, d3, d4];

    // Construct it with the bytes starting at the length prefix
    let zcbs = ZeroCopyBitSlice::from_bytes(&mut bytes[2..]);

    // Check for correct lengths and content
    assert_eq!(zcbs.bit_len(), BIT_LEN);
    assert_eq!(zcbs.byte_len(), BYTE_LEN);
    assert_eq!(zcbs.bit_bytes, &[flags1, flags2]);
}
443
#[test]
fn test_initialization() {
    // Any bit count works here; 9 exercises a partially filled last byte.
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;

    // Backing storage sized exactly for the length prefix plus the bit bytes.
    let mut buffer = [0; BYTE_LEN + LEN_SIZE];
    let mut remaining = buffer.as_mut_slice();

    {
        // Initialize the bitslice at the start of the buffer and check its reported sizes.
        let zcbs = ZeroCopyBitSlice::intialize_in(BIT_LEN, &mut remaining);
        assert_eq!(zcbs.bit_len(), BIT_LEN);
        assert_eq!(zcbs.byte_len(), BYTE_LEN);
    }

    // The slice reference must have been advanced past the bitslice; nothing is left over.
    assert_eq!(remaining.len(), 0);

    // The correct bit length must have been written to the start of the buffer.
    let stored_len: LenType = LenType::from_le_bytes(array_from_fn::from_fn(|i| buffer[i]));
    assert_eq!(stored_len, BIT_LEN);
}
467
#[test]
fn test_reads_and_writes() {
    // Bit count and the number of bytes needed to hold it
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;

    // Build the bit slice directly over zeroed storage
    let mut storage = [0; BYTE_LEN];
    let mut zcbs = ZeroCopyBitSlice {
        bit_bytes: &mut storage,
        bit_len: BIT_LEN,
    };

    // A fresh slice has every bit cleared
    for bit_idx in 0..BIT_LEN {
        assert!(!zcbs.get(bit_idx));
    }

    // First pass: turn on the even-indexed bits only.
    // Second pass: flip everything, so only the odd-indexed bits are on.
    for even_bits_on in [true, false].iter().copied() {
        for bit_idx in 0..BIT_LEN {
            zcbs.set(bit_idx, (bit_idx % 2 == 0) == even_bits_on);
        }

        // Every bit must read back exactly as it was written
        for bit_idx in 0..BIT_LEN {
            assert_eq!(zcbs.get(bit_idx), (bit_idx % 2 == 0) == even_bits_on);
        }
    }
}
505
#[test]
#[cfg(feature = "choose-random-zero")]
fn test_get_random() {
    const BIT_LEN: LenType = 9;
    const BYTE_LEN: usize = (BIT_LEN as usize + BITS_PER_BYTE - 1) / BITS_PER_BYTE;
    const TRIALS: u16 = 128;

    // The sequences drawn below are deterministic but not easy to predict by hand, so
    // correctness is probed indirectly:
    // 1) more than 3/4 of the trials must differ from the plain ascending order;
    // 2) every trial must draw each index exactly once, i.e. equal the ascending order
    //    after sorting.

    // Ascending indices 0, 1, 2, ...: the expected sorted result for every seed
    let ascending: [LenType; BIT_LEN as usize] = array_from_fn::from_fn(|i| i as LenType);
    let mut shuffled_trials = 0;

    for trial in 0..TRIALS {
        // Each seed starts from a fresh, all-zero bit slice
        let mut zcbs = ZeroCopyBitSlice {
            bit_bytes: &mut [0; BYTE_LEN],
            bit_len: BIT_LEN,
        };

        // The trial number, as bytes, is the seed
        let seed = trial.to_le_bytes();

        // Draw until every bit has been chosen
        let mut drawn: [LenType; BIT_LEN as usize] = array_from_fn::from_fn(|draw| {
            // The number of zeros must shrink by one with every draw
            assert_eq!(zcbs.num_zeros(), BIT_LEN - draw as LenType);

            // Next element of the pseudorandom sequence
            zcbs.choose_random_zero(seed).unwrap()
        });

        // Bookkeeping for check #1
        if drawn != ascending {
            shuffled_trials += 1;
        }

        // Check #2
        drawn.sort_unstable();
        assert_eq!(drawn, ascending);
    }

    // Check #1
    assert!(shuffled_trials > 3 * TRIALS / 4);
}