summavy_bitpacker/
bitpacker.rs

1use std::convert::TryInto;
2use std::io;
3
/// Serializes integers of a caller-chosen bit width (0 to 64 bits) into a
/// contiguous little-endian bit stream.
///
/// Bits are accumulated in a 64-bit staging word. Every time that word is
/// full, it is emitted to the output as 8 little-endian bytes.
pub struct BitPacker {
    /// Staging word holding the bits that have not been written out yet.
    mini_buffer: u64,
    /// Number of low bits of `mini_buffer` currently in use.
    /// It is strictly below 64 whenever `write` returns `Ok`.
    mini_buffer_written: usize,
}

impl Default for BitPacker {
    fn default() -> Self {
        BitPacker::new()
    }
}

impl BitPacker {
    /// Creates a packer with an empty staging word.
    pub fn new() -> BitPacker {
        BitPacker {
            mini_buffer: 0u64,
            mini_buffer_written: 0,
        }
    }

    /// Appends the `num_bits` low bits of `val` to the bit stream.
    ///
    /// Bits of `val` above `num_bits` are discarded, so an out-of-range value
    /// cannot bleed into the values written after it.
    ///
    /// # Errors
    ///
    /// Returns an error of kind [`io::ErrorKind::InvalidInput`] if `num_bits`
    /// is greater than 64, and forwards any error reported by `output`.
    #[inline]
    pub fn write<TWrite: io::Write + ?Sized>(
        &mut self,
        val: u64,
        num_bits: u8,
        output: &mut TWrite,
    ) -> io::Result<()> {
        if num_bits > 64 {
            return Err(io::Error::new(
                io::ErrorKind::InvalidInput,
                "num_bits must be at most 64",
            ));
        }
        let num_bits = usize::from(num_bits);
        // Keep only the bits that belong to this value.
        let val = if num_bits == 64 {
            val
        } else {
            val & ((1u64 << num_bits) - 1)
        };
        if self.mini_buffer_written + num_bits > 64 {
            // The value straddles the staging word: its low part completes the
            // current word, its high part starts the next one.
            // Here `1 <= mini_buffer_written <= 63`, so both shifts are in range.
            self.mini_buffer |= val << self.mini_buffer_written;
            output.write_all(self.mini_buffer.to_le_bytes().as_ref())?;
            self.mini_buffer = val >> (64 - self.mini_buffer_written);
            self.mini_buffer_written = self.mini_buffer_written + num_bits - 64;
        } else {
            self.mini_buffer |= val << self.mini_buffer_written;
            self.mini_buffer_written += num_bits;
            if self.mini_buffer_written == 64 {
                // The word is exactly full: emit it and start from scratch.
                output.write_all(self.mini_buffer.to_le_bytes().as_ref())?;
                self.mini_buffer_written = 0;
                self.mini_buffer = 0u64;
            }
        }
        Ok(())
    }

    /// Writes out the pending bits, rounded up to a whole number of bytes,
    /// and resets the staging word. Does nothing if no bits are pending.
    ///
    /// # Errors
    ///
    /// Forwards any error reported by `output`.
    pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
        if self.mini_buffer_written > 0 {
            let num_bytes = (self.mini_buffer_written + 7) / 8;
            let bytes = self.mini_buffer.to_le_bytes();
            output.write_all(&bytes[..num_bytes])?;
            self.mini_buffer_written = 0;
            self.mini_buffer = 0;
        }
        Ok(())
    }

    /// Flushes the pending bits and appends 7 zero bytes of padding.
    ///
    /// The padding lets a reader always load a full 8-byte window, even for
    /// the last value of the stream.
    ///
    /// # Errors
    ///
    /// Forwards any error reported by `output`.
    pub fn close<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
        self.flush(output)?;
        // Padding the write file to simplify reads.
        output.write_all(&[0u8; 7])?;
        Ok(())
    }
}
64
/// Random-access reader for a bit stream of fixed-width integers.
///
/// All values of the stream share the same bit width, which is fixed when the
/// unpacker is created.
#[derive(Clone, Debug, Default)]
pub struct BitUnpacker {
    /// Width of every value, in bits.
    num_bits: u64,
    /// Mask with the `num_bits` low bits set.
    mask: u64,
}

impl BitUnpacker {
    /// Creates an unpacker for values that are `num_bits` wide.
    ///
    /// `num_bits` is expected to be at most 64.
    pub fn new(num_bits: u8) -> BitUnpacker {
        debug_assert!(num_bits <= 64, "num_bits must be at most 64");
        // `1u64 << 64` would overflow, so the full-width mask is special-cased.
        let mask: u64 = if num_bits == 64 {
            !0u64
        } else {
            (1u64 << num_bits) - 1u64
        };
        BitUnpacker {
            num_bits: u64::from(num_bits),
            mask,
        }
    }

    /// Returns the width, in bits, of the values read by this unpacker.
    pub fn bit_width(&self) -> u8 {
        self.num_bits as u8
    }

    /// Returns the `idx`-th value stored in `data`.
    ///
    /// When the bit width is 0 the result is always 0 and `data` is not read.
    ///
    /// # Panics
    ///
    /// Panics if `data` does not hold 8 readable bytes starting at the byte
    /// in which the requested value begins.
    #[inline]
    pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
        if self.num_bits == 0 {
            return 0u64;
        }
        // Computed in 64 bits: `idx * num_bits` does not fit in a `u32` once
        // the bit offset reaches 2^32.
        let addr_in_bits = u64::from(idx) * self.num_bits;
        let addr = (addr_in_bits >> 3) as usize;
        let bit_shift = addr_in_bits & 7;
        debug_assert!(
            addr + 8 <= data.len(),
            "The fast field should have been padded with 7 bytes."
        );
        let window: [u8; 8] = data[addr..addr + 8]
            .try_into()
            .expect("the slice is exactly 8 bytes long");
        let low_bits = u64::from_le_bytes(window) >> bit_shift;
        let val = if bit_shift + self.num_bits > 64 {
            // The value does not fit in the 8-byte window: its highest bits
            // live in the following byte. `bit_shift` is at least 1 here, so
            // the left shift stays within 63 bits.
            let spill = u64::from(data[addr + 8]);
            low_bits | (spill << (64 - bit_shift))
        } else {
            low_bits
        };
        val & self.mask
    }
}
106
#[cfg(test)]
mod test {
    use super::{BitPacker, BitUnpacker};

    /// Packs `len` values of `num_bits` bits each and returns the matching
    /// unpacker, the values that were written, and the serialized bytes.
    fn create_fastfield_bitpacker(len: usize, num_bits: u8) -> (BitUnpacker, Vec<u64>, Vec<u8>) {
        let mut data = Vec::new();
        let mut bitpacker = BitPacker::new();
        // `checked_shl` keeps the mask computation valid for a 64-bit width,
        // where `1u64 << num_bits` would overflow.
        let max_val: u64 = 1u64
            .checked_shl(u32::from(num_bits))
            .map_or(u64::MAX, |bit| bit - 1);
        // Multiplying by an odd constant spreads the set bits over the whole
        // width, so the high bits and the maximum value get exercised too.
        let vals: Vec<u64> = (0u64..len as u64)
            .map(|i| i.wrapping_mul(0x9E37_79B9_7F4A_7C15) & max_val)
            .collect();
        for &val in &vals {
            bitpacker.write(val, num_bits, &mut data).unwrap();
        }
        bitpacker.close(&mut data).unwrap();
        // Payload rounded up to whole bytes, plus the 7 bytes of padding.
        assert_eq!(data.len(), ((num_bits as usize) * len + 7) / 8 + 7);
        let bitunpacker = BitUnpacker::new(num_bits);
        (bitunpacker, vals, data)
    }

    /// Checks that every packed value is read back unchanged.
    fn test_bitpacker_util(len: usize, num_bits: u8) {
        let (bitunpacker, vals, data) = create_fastfield_bitpacker(len, num_bits);
        for (i, val) in vals.iter().enumerate() {
            assert_eq!(bitunpacker.get(i as u32, &data), *val);
        }
    }

    #[test]
    fn test_bitpacker() {
        test_bitpacker_util(10, 3);
        test_bitpacker_util(10, 0);
        test_bitpacker_util(10, 1);
        test_bitpacker_util(6, 14);
        test_bitpacker_util(1000, 14);
    }

    #[test]
    fn test_bitpacker_wide_values() {
        test_bitpacker_util(100, 33);
        test_bitpacker_util(100, 56);
        test_bitpacker_util(10, 64);
    }

    #[test]
    fn test_bitpacker_empty() {
        test_bitpacker_util(0, 7);
    }
}