use std::io;
use std::ops::{Range, RangeInclusive};

use bitpacking::{BitPacker as ExternalBitPackerTrait, BitPacker1x};

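/// Writes `u64` values using a fixed, caller-chosen number of bits per value.
///
/// Bits are accumulated in a 64-bit mini buffer and flushed to the output
/// writer in little-endian byte order whenever the buffer fills up.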
pub struct BitPacker {
    mini_buffer: u64,
    mini_buffer_written: usize,
}

impl Default for BitPacker {
    fn default() -> Self {
        BitPacker::new()
    }
}

impl BitPacker {
    pub fn new() -> BitPacker {
        BitPacker {
            mini_buffer: 0u64,
            mini_buffer_written: 0,
        }
    }

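    /// Appends the `num_bits` lowest bits of `val` to the stream.
    ///
    /// Callers are expected to pass a `val` that fits in `num_bits` bits;
    /// higher bits would bleed into the following values.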
    #[inline]
    pub fn write<TWrite: io::Write + ?Sized>(
        &mut self,
        val: u64,
        num_bits: u8,
        output: &mut TWrite,
    ) -> io::Result<()> {
        let num_bits = num_bits as usize;
        if self.mini_buffer_written + num_bits > 64 {
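            // `val` straddles the end of the mini buffer: emit its low bits
            // together with the buffered content, and keep the remaining
            // high bits for the next write.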
            self.mini_buffer |= val.wrapping_shl(self.mini_buffer_written as u32);
            output.write_all(self.mini_buffer.to_le_bytes().as_ref())?;
            self.mini_buffer = val.wrapping_shr((64 - self.mini_buffer_written) as u32);
            self.mini_buffer_written = self.mini_buffer_written + num_bits - 64;
        } else {
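            // `val` fits entirely into the mini buffer.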
            self.mini_buffer |= val << self.mini_buffer_written;
            self.mini_buffer_written += num_bits;
            if self.mini_buffer_written == 64 {
                output.write_all(self.mini_buffer.to_le_bytes().as_ref())?;
                self.mini_buffer_written = 0;
                self.mini_buffer = 0u64;
            }
        }
        Ok(())
    }

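    /// Writes any pending bits to the output, zero-padding the final byte.
    /// Only the bytes that actually contain data are emitted.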
    pub fn flush<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
        if self.mini_buffer_written > 0 {
            let num_bytes = self.mini_buffer_written.div_ceil(8);
            let bytes = self.mini_buffer.to_le_bytes();
            output.write_all(&bytes[..num_bytes])?;
            self.mini_buffer_written = 0;
            self.mini_buffer = 0;
        }
        Ok(())
    }

    pub fn close<TWrite: io::Write + ?Sized>(&mut self, output: &mut TWrite) -> io::Result<()> {
        self.flush(output)
    }
}

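/// Reads values of a fixed bit width back out of a bit-packed byte slice.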
#[derive(Clone, Debug, Default, Copy)]
pub struct BitUnpacker {
    num_bits: usize,
    mask: u64,
}

impl BitUnpacker {
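    /// Creates a bit unpacker assuming the same bit width for all values.
    ///
    /// `num_bits` must be at most 56, so that a value always fits within a
    /// single unaligned 8-byte load regardless of its bit offset, or exactly
    /// 64, in which case every value is 8-byte aligned and no shift is needed.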
    pub fn new(num_bits: u8) -> BitUnpacker {
        assert!(num_bits <= 7 * 8 || num_bits == 64);
        let mask: u64 = if num_bits == 64 {
            !0u64
        } else {
            (1u64 << num_bits) - 1u64
        };
        BitUnpacker {
            num_bits: usize::from(num_bits),
            mask,
        }
    }

    pub fn bit_width(&self) -> u8 {
        self.num_bits as u8
    }

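    /// Returns the value stored at index `idx`.
    ///
    /// The fast path is a single unaligned 8-byte load followed by a shift
    /// and a mask; indices whose load would run past the end of `data` fall
    /// back to `get_slow_path`.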
    #[inline]
    pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
        let addr_in_bits = idx as usize * self.num_bits;
        let addr = addr_in_bits >> 3;
        if addr + 8 > data.len() {
            if self.num_bits == 0 {
                return 0;
            }
            let bit_shift = addr_in_bits & 7;
            return self.get_slow_path(addr, bit_shift as u32, data);
        }
        let bit_shift = addr_in_bits & 7;
        let bytes: [u8; 8] = (&data[addr..addr + 8]).try_into().unwrap();
        let val_unshifted_unmasked: u64 = u64::from_le_bytes(bytes);
        let val_shifted = val_unshifted_unmasked >> bit_shift;
        val_shifted & self.mask
    }

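    /// Fallback for indices near the end of `data`: fewer than 8 bytes
    /// remain after `addr`, so they are copied into a zero-padded buffer
    /// and decoded from there.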
    #[inline(never)]
    fn get_slow_path(&self, addr: usize, bit_shift: u32, data: &[u8]) -> u64 {
        let mut bytes: [u8; 8] = [0u8; 8];
        let available_bytes = data.len() - addr;
        debug_assert!(available_bytes < 8);
        bytes[..available_bytes].copy_from_slice(&data[addr..]);
        let val_unshifted_unmasked: u64 = u64::from_le_bytes(bytes);
        let val_shifted = val_unshifted_unmasked >> bit_shift;
        val_shifted & self.mask
    }

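    /// Decodes `output.len()` values starting at `start_idx` into `output`.
    ///
    /// The batch is split into three parts: a scalar "entrance ramp" up to
    /// the first byte-aligned value, a "highway" decoded in blocks of
    /// `BitPacker1x::BLOCK_LEN` (32) values via the `bitpacking` crate, and
    /// a scalar "exit ramp" for the remainder.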
    fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
        assert!(
            self.bit_width() <= 32,
            "Bitwidth must be <= 32 to use this method."
        );

        let end_idx: u32 = start_idx + output.len() as u32;

        let end_bit_read = (end_idx as usize) * self.num_bits;
        let end_byte_read = end_bit_read.div_ceil(8);
        assert!(
            end_byte_read <= data.len(),
            "Requested index is out of bounds."
        );

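        // Scalar fallback used for the unaligned edges of the batch.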
        let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
            for (out, idx) in output.iter_mut().zip(start_idx..) {
                *out = self.get(idx, data) as u32;
            }
        };

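        // `BitPacker1x::decompress` reads from a byte-aligned position.
        // Any index that is a multiple of 8 is byte-aligned regardless of
        // the bit width, so decode values one by one until we reach one.
        // (An already-aligned `start_idx` simply gets a full ramp of 8.)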
        let entrance_ramp_len: u32 = 8 - (start_idx % 8) % 8;

        let highway_start: u32 = start_idx + entrance_ramp_len;

        if highway_start + (BitPacker1x::BLOCK_LEN as u32) > end_idx {
            // Not enough values for even a single SIMD block: decode the
            // whole batch with the scalar ramp.
            get_batch_ramp(start_idx, output);
            return;
        }

        let num_blocks: usize = (end_idx - highway_start) as usize / BitPacker1x::BLOCK_LEN;

        get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);

        let mut offset = (highway_start as usize * self.num_bits) / 8;
        let mut output_cursor = (highway_start - start_idx) as usize;
        for _ in 0..num_blocks {
            offset += BitPacker1x.decompress(
                &data[offset..],
                &mut output[output_cursor..],
                self.num_bits as u8,
            );
            output_cursor += BitPacker1x::BLOCK_LEN;
        }

        let highway_end: u32 = highway_start + (num_blocks * BitPacker1x::BLOCK_LEN) as u32;
        get_batch_ramp(highway_end, &mut output[output_cursor..]);
    }

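    /// Fills `positions` with the ids in `id_range` whose stored value falls
    /// within `range`; any previous content of `positions` is discarded.
    ///
    /// For bit widths of at most 32, values are batch-decoded as `u32`s and
    /// filtered in place; wider bit widths use a scalar fallback.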
    pub fn get_ids_for_value_range(
        &self,
        range: RangeInclusive<u64>,
        id_range: Range<u32>,
        data: &[u8],
        positions: &mut Vec<u32>,
    ) {
        if self.bit_width() > 32 {
            self.get_ids_for_value_range_slow(range, id_range, data, positions)
        } else {
            if *range.start() > u32::MAX as u64 {
                positions.clear();
                return;
            }
            let range_u32 = (*range.start() as u32)..=(*range.end()).min(u32::MAX as u64) as u32;
            self.get_ids_for_value_range_fast(range_u32, id_range, data, positions)
        }
    }

    fn get_ids_for_value_range_slow(
        &self,
        range: RangeInclusive<u64>,
        id_range: Range<u32>,
        data: &[u8],
        positions: &mut Vec<u32>,
    ) {
        positions.clear();
        for i in id_range {
            let val = self.get(i, data);
            if range.contains(&val) {
                positions.push(i);
            }
        }
    }

    fn get_ids_for_value_range_fast(
        &self,
        value_range: RangeInclusive<u32>,
        id_range: Range<u32>,
        data: &[u8],
        positions: &mut Vec<u32>,
    ) {
        positions.resize(id_range.len(), 0u32);
        self.get_batch_u32s(id_range.start, data, positions);
        crate::filter_vec::filter_vec_in_place(value_range, id_range.start, positions)
    }
}


#[cfg(test)]
mod test {
    use super::{BitPacker, BitUnpacker};

    fn create_bitpacker(len: usize, num_bits: u8) -> (BitUnpacker, Vec<u64>, Vec<u8>) {
        let mut data = Vec::new();
        let mut bitpacker = BitPacker::new();
        let max_val: u64 = (1u64 << num_bits as u64) - 1u64;
        let vals: Vec<u64> = (0u64..len as u64)
            .map(|i| if max_val == 0 { 0 } else { i % max_val })
            .collect();
        for &val in &vals {
            bitpacker.write(val, num_bits, &mut data).unwrap();
        }
        bitpacker.close(&mut data).unwrap();
        assert_eq!(data.len(), ((num_bits as usize) * len).div_ceil(8));
        let bitunpacker = BitUnpacker::new(num_bits);
        (bitunpacker, vals, data)
    }

    fn test_bitpacker_util(len: usize, num_bits: u8) {
        let (bitunpacker, vals, data) = create_bitpacker(len, num_bits);
        for (i, val) in vals.iter().enumerate() {
            assert_eq!(bitunpacker.get(i as u32, &data), *val);
        }
    }

    #[test]
    fn test_bitpacker() {
        test_bitpacker_util(10, 3);
        test_bitpacker_util(10, 0);
        test_bitpacker_util(10, 1);
        test_bitpacker_util(6, 14);
        test_bitpacker_util(1000, 14);
    }
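
    // Round-trip sanity check for `get_ids_for_value_range` on the slow
    // (bit width > 32) path: values 0, 3, 6, ... are packed with 40 bits,
    // and the ids whose value lies in 9..=30 are expected back.
    #[test]
    fn test_get_ids_for_value_range_slow_path() {
        let num_bits = 40u8;
        let mut data = Vec::new();
        let mut bitpacker = BitPacker::new();
        for val in (0u64..50u64).map(|i| i * 3) {
            bitpacker.write(val, num_bits, &mut data).unwrap();
        }
        bitpacker.close(&mut data).unwrap();
        let bitunpacker = BitUnpacker::new(num_bits);
        let mut positions: Vec<u32> = Vec::new();
        bitunpacker.get_ids_for_value_range(9..=30, 0..50, &data, &mut positions);
        // Values 9, 12, ..., 30 sit at ids 3 through 10.
        assert_eq!(positions, (3u32..=10).collect::<Vec<u32>>());
    }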

    use proptest::prelude::*;

    fn num_bits_strategy() -> impl Strategy<Value = u8> {
        prop_oneof!(Just(0), Just(1), 2u8..56u8, Just(56), Just(64))
    }

    fn vals_strategy() -> impl Strategy<Value = (u8, Vec<u64>)> {
        (num_bits_strategy(), 0usize..100usize).prop_flat_map(|(num_bits, len)| {
            let max_val = if num_bits == 64 {
                u64::MAX
            } else {
                (1u64 << num_bits as u32) - 1
            };
            let vals = proptest::collection::vec(0..=max_val, len);
            vals.prop_map(move |vals| (num_bits, vals))
        })
    }

    fn test_bitpacker_aux(num_bits: u8, vals: &[u64]) {
        let mut buffer: Vec<u8> = Vec::new();
        let mut bitpacker = BitPacker::new();
        for &val in vals {
            bitpacker.write(val, num_bits, &mut buffer).unwrap();
        }
        bitpacker.flush(&mut buffer).unwrap();
        assert_eq!(buffer.len(), (vals.len() * num_bits as usize).div_ceil(8));
        let bitunpacker = BitUnpacker::new(num_bits);
        let max_val = if num_bits == 64 {
            u64::MAX
        } else {
            (1u64 << num_bits) - 1
        };
        for (i, val) in vals.iter().copied().enumerate() {
            assert!(val <= max_val);
            assert_eq!(bitunpacker.get(i as u32, &buffer), val);
        }
    }

    proptest::proptest! {
        #[test]
        fn test_bitpacker_proptest((num_bits, vals) in vals_strategy()) {
            test_bitpacker_aux(num_bits, &vals);
        }
    }

    #[test]
    #[should_panic]
    fn test_get_batch_panics_over_32_bits() {
        let bitunpacker = BitUnpacker::new(33);
        let mut output: [u32; 1] = [0u32];
        bitunpacker.get_batch_u32s(0, &[0, 0, 0, 0, 0, 0, 0, 0], &mut output[..]);
    }

    #[test]
    fn test_get_batch_limit() {
        let bitunpacker = BitUnpacker::new(1);
        let mut output: [u32; 3] = [0u32, 0u32, 0u32];
        bitunpacker.get_batch_u32s(8 * 4 - 3, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
    }

    #[test]
    #[should_panic]
    fn test_get_batch_panics_when_off_scope() {
        let bitunpacker = BitUnpacker::new(1);
        let mut output: [u32; 3] = [0u32, 0u32, 0u32];
        bitunpacker.get_batch_u32s(8 * 4 - 2, &[0u8, 0u8, 0u8, 0u8], &mut output[..]);
    }

    proptest::proptest! {
        #[test]
        fn test_get_batch_u32s_proptest(num_bits in 0u8..=32u8) {
            let mask = if num_bits == 32u8 {
                u32::MAX
            } else {
                (1u32 << num_bits) - 1
            };
            let mut buffer: Vec<u8> = Vec::new();
            let mut bitpacker = BitPacker::new();
            for val in 0..100 {
                bitpacker.write(val & mask as u64, num_bits, &mut buffer).unwrap();
            }
            bitpacker.flush(&mut buffer).unwrap();
            let bitunpacker = BitUnpacker::new(num_bits);
            let mut output: Vec<u32> = Vec::new();
            for len in [0, 1, 2, 32, 33, 34, 64] {
                for start_idx in 0u32..32u32 {
                    output.resize(len, 0);
                    bitunpacker.get_batch_u32s(start_idx, &buffer, &mut output);
                    for (i, output_val) in output.iter().enumerate() {
                        let expected = (start_idx + i as u32) & mask;
                        assert_eq!(*output_val, expected);
                    }
                }
            }
        }
    }
}