// tantivy_bitpacker/bitpacker.rs
1use std::io;
2use std::ops::{Range, RangeInclusive};
3
4use bitpacking::{BitPacker as ExternalBitPackerTrait, BitPacker1x};
5
/// Streaming writer that packs `u64` values using an arbitrary number of
/// bits per value, accumulating bits in a 64-bit staging buffer and
/// emitting it (little-endian) whenever it fills up.
pub struct BitPacker {
    // Staging area holding up to 64 not-yet-emitted bits.
    mini_buffer: u64,
    // Number of valid low-order bits currently stored in `mini_buffer`.
    mini_buffer_written: usize,
}
10
11impl Default for BitPacker {
12    fn default() -> Self {
13        BitPacker::new()
14    }
15}
16impl BitPacker {
17    pub fn new() -> BitPacker {
18        BitPacker {
19            mini_buffer: 0u64,
20            mini_buffer_written: 0,
21        }
22    }
23
24    #[inline]
25    pub fn write<TWrite: io::Write + ?Sized>(
26        &mut self,
27        val: u64,
28        num_bits: u8,
29        output: &mut TWrite,
30    ) -> io::Result<()> {
31        let num_bits = num_bits as usize;
32        if self.mini_buffer_written + num_bits > 64 {
33            self.mini_buffer |= val.wrapping_shl(self.mini_buffer_written as u32);
34            output.write_all(self.mini_buffer.to_le_bytes().as_ref())?;
35            self.mini_buffer = val.wrapping_shr((64 - self.mini_buffer_written) as u32);
36            self.mini_buffer_written = self.mini_buffer_written + num_bits - 64;
37        } else {
38            self.mini_buffer |= val << self.mini_buffer_written;
39            self.mini_buffer_written += num_bits;
40            if self.mini_buffer_written == 64 {
41                output.write_all(self.mini_buffer.to_le_bytes().as_ref())?;
42                self.mini_buffer_written = 0;
43                self.mini_buffer = 0u64;
44            }
45        }
46        Ok(())
47    }
48
49    pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
50        if self.mini_buffer_written > 0 {
51            let num_bytes = self.mini_buffer_written.div_ceil(8);
52            let bytes = self.mini_buffer.to_le_bytes();
53            output.write_all(&bytes[..num_bytes])?;
54            self.mini_buffer_written = 0;
55            self.mini_buffer = 0;
56        }
57        Ok(())
58    }
59
60    pub fn close<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
61        self.flush(output)?;
62        Ok(())
63    }
64}
65
/// Reads back fixed-bitwidth values produced by [`BitPacker`].
#[derive(Clone, Debug, Default, Copy)]
pub struct BitUnpacker {
    // Number of bits per value; kept as `usize` to avoid casts in the
    // bit-address arithmetic of `get`.
    num_bits: usize,
    // Precomputed mask with the `num_bits` low-order bits set.
    mask: u64,
}
71
72impl BitUnpacker {
73    /// Creates a bit unpacker, that assumes the same bitwidth for all values.
74    ///
75    /// The bitunpacker works by doing an unaligned read of 8 bytes.
76    /// For this reason, values of `num_bits` between
77    /// [57..63] are forbidden.
78    pub fn new(num_bits: u8) -> BitUnpacker {
79        assert!(num_bits <= 7 * 8 || num_bits == 64);
80        let mask: u64 = if num_bits == 64 {
81            !0u64
82        } else {
83            (1u64 << num_bits) - 1u64
84        };
85        BitUnpacker {
86            num_bits: usize::from(num_bits),
87            mask,
88        }
89    }
90
91    pub fn bit_width(&self) -> u8 {
92        self.num_bits as u8
93    }
94
95    #[inline]
96    pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
97        let addr_in_bits = idx as usize * self.num_bits;
98        let addr = addr_in_bits >> 3;
99        if addr + 8 > data.len() {
100            if self.num_bits == 0 {
101                return 0;
102            }
103            let bit_shift = addr_in_bits & 7;
104            return self.get_slow_path(addr, bit_shift as u32, data);
105        }
106        let bit_shift = addr_in_bits & 7;
107        let bytes: [u8; 8] = (&data[addr..addr + 8]).try_into().unwrap();
108        let val_unshifted_unmasked: u64 = u64::from_le_bytes(bytes);
109        let val_shifted = val_unshifted_unmasked >> bit_shift;
110        val_shifted & self.mask
111    }
112
113    #[inline(never)]
114    fn get_slow_path(&self, addr: usize, bit_shift: u32, data: &[u8]) -> u64 {
115        let mut bytes: [u8; 8] = [0u8; 8];
116        let available_bytes = data.len() - addr;
117        // This function is meant to only be called if we did not have 8 bytes to load.
118        debug_assert!(available_bytes < 8);
119        bytes[..available_bytes].copy_from_slice(&data[addr..]);
120        let val_unshifted_unmasked: u64 = u64::from_le_bytes(bytes);
121        let val_shifted = val_unshifted_unmasked >> bit_shift;
122        val_shifted & self.mask
123    }
124
125    // Decodes the range of bitpacked `u32` values with idx
126    // in [start_idx, start_idx + output.len()).
127    //
128    // #Panics
129    //
130    // This methods panics if `num_bits` is > 32.
131    fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
132        assert!(
133            self.bit_width() <= 32,
134            "Bitwidth must be <= 32 to use this method."
135        );
136
137        let end_idx: u32 = start_idx + output.len() as u32;
138
139        // We use `usize` here to avoid overflow issues.
140        let end_bit_read = (end_idx as usize) * self.num_bits;
141        let end_byte_read = end_bit_read.div_ceil(8);
142        assert!(
143            end_byte_read <= data.len(),
144            "Requested index is out of bounds."
145        );
146
147        // Simple slow implementation of get_batch_u32s, to deal with our ramps.
148        let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
149            for (out, idx) in output.iter_mut().zip(start_idx..) {
150                *out = self.get(idx, data) as u32;
151            }
152        };
153
154        // We use an unrolled routine to decode 32 values at once.
155        // We therefore decompose our range of values to decode into three ranges:
156        // - Entrance ramp: [start_idx, fast_track_start) (up to 31 values)
157        // - Highway: [fast_track_start, fast_track_end) (a length multiple of 32s)
158        // - Exit ramp: [fast_track_end, start_idx + output.len()) (up to 31 values)
159
160        // We want the start of the fast track to start align with bytes.
161        // A sufficient condition is to start with an idx that is a multiple of 8,
162        // so highway start is the closest multiple of 8 that is >= start_idx.
163        let entrance_ramp_len: u32 = 8 - (start_idx % 8) % 8;
164
165        let highway_start: u32 = start_idx + entrance_ramp_len;
166
167        if highway_start + (BitPacker1x::BLOCK_LEN as u32) > end_idx {
168            // We don't have enough values to have even a single block of highway.
169            // Let's just supply the values the simple way.
170            get_batch_ramp(start_idx, output);
171            return;
172        }
173
174        let num_blocks: usize = (end_idx - highway_start) as usize / BitPacker1x::BLOCK_LEN;
175
176        // Entrance ramp
177        get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
178
179        // Highway
180        let mut offset = (highway_start as usize * self.num_bits) / 8;
181        let mut output_cursor = (highway_start - start_idx) as usize;
182        for _ in 0..num_blocks {
183            offset += BitPacker1x.decompress(
184                &data[offset..],
185                &mut output[output_cursor..],
186                self.num_bits as u8,
187            );
188            output_cursor += 32;
189        }
190
191        // Exit ramp
192        let highway_end: u32 = highway_start + (num_blocks * BitPacker1x::BLOCK_LEN) as u32;
193        get_batch_ramp(highway_end, &mut output[output_cursor..]);
194    }
195
196    pub fn get_ids_for_value_range(
197        &self,
198        range: RangeInclusive<u64>,
199        id_range: Range<u32>,
200        data: &[u8],
201        positions: &mut Vec<u32>,
202    ) {
203        if self.bit_width() > 32 {
204            self.get_ids_for_value_range_slow(range, id_range, data, positions)
205        } else {
206            if *range.start() > u32::MAX as u64 {
207                positions.clear();
208                return;
209            }
210            let range_u32 = (*range.start() as u32)..=(*range.end()).min(u32::MAX as u64) as u32;
211            self.get_ids_for_value_range_fast(range_u32, id_range, data, positions)
212        }
213    }
214
215    fn get_ids_for_value_range_slow(
216        &self,
217        range: RangeInclusive<u64>,
218        id_range: Range<u32>,
219        data: &[u8],
220        positions: &mut Vec<u32>,
221    ) {
222        positions.clear();
223        for i in id_range {
224            // If we cared we could make this branchless, but the slow implementation should rarely
225            // kick in.
226            let val = self.get(i, data);
227            if range.contains(&val) {
228                positions.push(i);
229            }
230        }
231    }
232
233    fn get_ids_for_value_range_fast(
234        &self,
235        value_range: RangeInclusive<u32>,
236        id_range: Range<u32>,
237        data: &[u8],
238        positions: &mut Vec<u32>,
239    ) {
240        positions.resize(id_range.len(), 0u32);
241        self.get_batch_u32s(id_range.start, data, positions);
242        crate::filter_vec::filter_vec_in_place(value_range, id_range.start, positions)
243    }
244}
245
#[cfg(test)]
mod test {
    use super::{BitPacker, BitUnpacker};

    // Packs `len` values of `num_bits` bits each; returns the matching
    // unpacker together with the original values and the packed bytes.
    fn create_bitpacker(len: usize, num_bits: u8) -> (BitUnpacker, Vec<u64>, Vec<u8>) {
        let mut data = Vec::new();
        let mut bitpacker = BitPacker::new();
        let max_val: u64 = (1u64 << num_bits as u64) - 1u64;
        let vals: Vec<u64> = (0u64..len as u64)
            .map(|i| if max_val == 0 { 0 } else { i % max_val })
            .collect();
        for &val in &vals {
            bitpacker.write(val, num_bits, &mut data).unwrap();
        }
        bitpacker.close(&mut data).unwrap();
        // The packed stream must occupy exactly ceil(len * num_bits / 8) bytes.
        assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
        let bitunpacker = BitUnpacker::new(num_bits);
        (bitunpacker, vals, data)
    }

    // Round-trip check: every packed value must unpack to its original.
    fn test_bitpacker_util(len: usize, num_bits: u8) {
        let (bitunpacker, vals, data) = create_bitpacker(len, num_bits);
        for (i, val) in vals.iter().enumerate() {
            assert_eq!(bitunpacker.get(i as u32, &data), *val);
        }
    }

    #[test]
    fn test_bitpacker() {
        test_bitpacker_util(10, 3);
        test_bitpacker_util(10, 0);
        test_bitpacker_util(10, 1);
        test_bitpacker_util(6, 14);
        test_bitpacker_util(1000, 14);
    }

    use proptest::prelude::*;

    // Bitwidths of interest: the 0/1 degenerate cases, the general range,
    // the maximum allowed below-64 width (56), and the full 64 bits.
    fn num_bits_strategy() -> impl Strategy<Value = u8> {
        prop_oneof!(Just(0), Just(1), 2u8..56u8, Just(56), Just(64),)
    }

    // Generates a bitwidth plus up to 100 values that fit within it.
    fn vals_strategy() -> impl Strategy<Value = (u8, Vec<u64>)> {
        (num_bits_strategy(), 0usize..100usize).prop_flat_map(|(num_bits, len)| {
            let max_val = if num_bits == 64 {
                u64::MAX
            } else {
                (1u64 << num_bits as u32) - 1
            };
            let vals = proptest::collection::vec(0..=max_val, len);
            vals.prop_map(move |vals| (num_bits, vals))
        })
    }

    // Pack/unpack round trip with exact size check, driven by proptest.
    fn test_bitpacker_aux(num_bits: u8, vals: &[u64]) {
        let mut buffer: Vec<u8> = Vec::new();
        let mut bitpacker = BitPacker::new();
        for &val in vals {
            bitpacker.write(val, num_bits, &mut buffer).unwrap();
        }
        bitpacker.flush(&mut buffer).unwrap();
        assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
        let bitunpacker = BitUnpacker::new(num_bits);
        let max_val = if num_bits == 64 {
            u64::MAX
        } else {
            (1u64 << num_bits) - 1
        };
        for (i, val) in vals.iter().copied().enumerate() {
            assert!(val <= max_val);
            assert_eq!(bitunpacker.get(i as u32, &buffer), val);
        }
    }

    proptest::proptest! {
        #[test]
        fn test_bitpacker_proptest((num_bits, vals) in vals_strategy()) {
            test_bitpacker_aux(num_bits, &vals);
        }
    }

    #[test]
    #[should_panic]
    fn test_get_batch_panics_over_32_bits() {
        let bitunpacker = BitUnpacker::new(33);
        let mut output: [u32; 1] = [0u32];
        bitunpacker.get_batch_u32s(0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
    }

    // Reading the last 3 bits of a 32-bit buffer must stay in bounds.
    #[test]
    fn test_get_batch_limit() {
        let bitunpacker = BitUnpacker::new(1);
        let mut output: [u32; 3] = [0u32, 0u32, 0u32];
        bitunpacker.get_batch_u32s(8 * 4 - 3, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
    }

    #[test]
    #[should_panic]
    fn test_get_batch_panics_when_off_scope() {
        let bitunpacker = BitUnpacker::new(1);
        let mut output: [u32; 3] = [0u32, 0u32, 0u32];
        // We are missing exactly one bit.
        bitunpacker.get_batch_u32s(8 * 4 - 2, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
    }

    proptest::proptest! {
        // Exhaustively sweeps batch lengths and start offsets (covering all
        // entrance-ramp alignments) for every bitwidth in 0..=32.
        #[test]
        fn test_get_batch_u32s_proptest(num_bits in 0u8..=32u8) {
            let mask =
                if num_bits == 32u8 {
                    u32::MAX
                } else {
                    (1u32 << num_bits) - 1
                };
            let mut buffer: Vec<u8> = Vec::new();
            let mut bitpacker = BitPacker::new();
            for val in 0..100 {
                bitpacker.write(val & mask as u64, num_bits, &mut buffer).unwrap();
            }
            bitpacker.flush(&mut buffer).unwrap();
            let bitunpacker = BitUnpacker::new(num_bits);
            let mut output: Vec<u32> = Vec::new();
            for len in [0, 1, 2, 32, 33, 34, 64] {
                for start_idx in 0u32..32u32 {
                    output.resize(len, 0);
                    bitunpacker.get_batch_u32s(start_idx, &buffer, &mut output);
                    for (i, output_byte) in output.iter().enumerate() {
                        let expected = (start_idx + i as u32) & mask;
                        assert_eq!(*output_byte, expected);
                    }
                }
            }
        }
    }
}