tdb_succinct/
bitarray.rs

1#![allow(clippy::precedence, clippy::verbose_bit_mask)]
2
3//! Code for reading, writing, and using bit arrays.
4//!
5//! A bit array is a contiguous sequence of N bits contained in L words. By choosing L as the
6//! minimal number of words required for N bits, the sequence is compressed and yet aligned on a
7//! word boundary.
8//!
9//! # Notes
10//!
11//! * All words are stored in a standard big-endian encoding.
12//! * The maximum number of bits is 2^64-1.
13//!
14//! # Naming
15//!
16//! Because of the ambiguity of the English language and the possibility to confuse the meanings of
17//! the words used to describe aspects of this code, we try to use the following definitions
18//! consistently throughout:
19//!
20//! * buffer: a contiguous sequence of bytes
21//!
22//! * size: the number of bytes in a buffer
23//!
24//! * word: a 64-bit contiguous sequence aligned on 8-byte boundaries starting at the beginning of
25//!     the input buffer
26//!
27//! * index: the logical address of a bit in the data buffer.
28//!
29//! * length: the number of usable bits in the bit array
30
31use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
32
33use super::util;
34use crate::bititer::BitIter;
35use crate::storage::{FileLoad, SyncableFile};
36use byteorder::{BigEndian, ByteOrder};
37use bytes::{Buf, BufMut, Bytes, BytesMut};
38use futures::io;
39use futures::stream::{Stream, StreamExt, TryStreamExt};
40use std::{convert::TryFrom, error, fmt};
41use tokio_util::codec::{Decoder, FramedRead};
42
43/// A thread-safe, reference-counted, compressed bit sequence.
44///
45/// A `BitArray` is a wrapper around a [`Bytes`] that provides a view of the underlying data as a
46/// compressed sequence of bits.
47///
48/// [`Bytes`]: ../../../bytes/struct.Bytes.html
49///
50/// As with other types in [`structures`], a `BitArray` is created from an existing buffer, rather
51/// than constructed from parts. The buffer may be read from a file or other source and may be very
52/// large. A `BitArray` preserves the buffer to save memory but provides a simple abstraction of
53/// being a vector of `bool`s.
54///
55/// [`structures`]: ../index.html
56#[derive(Clone)]
57pub struct BitArray {
58    /// Number of usable bits in the array.
59    len: u64,
60
61    /// Shared reference to the buffer containing the sequence of bits.
62    ///
63    /// The buffer does not contain the control word.
64    buf: Bytes,
65}
66
/// Errors reported while validating the input buffer of a bit array.
#[derive(Debug, PartialEq)]
pub enum BitArrayError {
    /// The input buffer is too small to contain the 8-byte control word.
    ///
    /// Carries the actual input buffer size in bytes.
    InputBufferTooSmall(usize),
    /// The input buffer size does not match the size implied by the control word.
    ///
    /// Carries, in order: the actual input buffer size in bytes, the expected input buffer size
    /// in bytes, and the number of bits read from the control word.
    UnexpectedInputBufferSize(u64, u64, u64),
}
73
74impl BitArrayError {
75    /// Validate the input buffer size.
76    ///
77    /// It must have at least the control word.
78    fn validate_input_buf_size(input_buf_size: usize) -> Result<(), Self> {
79        if input_buf_size < 8 {
80            return Err(BitArrayError::InputBufferTooSmall(input_buf_size));
81        }
82        Ok(())
83    }
84
85    /// Validate the length.
86    ///
87    /// The input buffer size should be the appropriate multiple of 8 to include the number of bits
88    /// plus the control word.
89    fn validate_len(input_buf_size: usize, len: u64) -> Result<(), Self> {
90        // Calculate the expected input buffer size. This includes the control word.
91        let expected_buf_size = {
92            // The following steps are necessary to avoid overflow. If we add first and shift
93            // second, the addition might result in a value greater than `u64::max_value()`.
94            // Therefore, we right-shift first to produce a value that cannot overflow, check how
95            // much we need to add, and add it.
96            let after_shifting = len >> 6 << 3;
97            if len & 63 == 0 {
98                // The number of bits fit evenly into 64-bit words. Add only the control word.
99                after_shifting + 8
100            } else {
101                // The number of bits do not fit evenly into 64-bit words. Add a word for the
102                // leftovers plus the control word.
103                after_shifting + 16
104            }
105        };
106        let input_buf_size = u64::try_from(input_buf_size).unwrap();
107
108        if input_buf_size != expected_buf_size {
109            return Err(BitArrayError::UnexpectedInputBufferSize(
110                input_buf_size,
111                expected_buf_size,
112                len,
113            ));
114        }
115
116        Ok(())
117    }
118}
119
120impl fmt::Display for BitArrayError {
121    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
122        use BitArrayError::*;
123        match self {
124            InputBufferTooSmall(input_buf_size) => {
125                write!(f, "expected input buffer size ({}) >= 8", input_buf_size)
126            }
127            UnexpectedInputBufferSize(input_buf_size, expected_buf_size, len) => write!(
128                f,
129                "expected input buffer size ({}) to be {} for {} bits",
130                input_buf_size, expected_buf_size, len
131            ),
132        }
133    }
134}
135
136impl error::Error for BitArrayError {}
137
138impl From<BitArrayError> for io::Error {
139    fn from(err: BitArrayError) -> io::Error {
140        io::Error::new(io::ErrorKind::InvalidData, err)
141    }
142}
143
144/// Read the length from the control word buffer. `buf` must start at the first word after the data
145/// buffer. `input_buf_size` is used for validation.
146fn read_control_word(buf: &[u8], input_buf_size: usize) -> Result<u64, BitArrayError> {
147    let len = BigEndian::read_u64(buf);
148    BitArrayError::validate_len(input_buf_size, len)?;
149    Ok(len)
150}
151
152impl BitArray {
153    /// Construct a `BitArray` by parsing a `Bytes` buffer.
154    pub fn from_bits(mut buf: Bytes) -> Result<BitArray, BitArrayError> {
155        let input_buf_size = buf.len();
156        BitArrayError::validate_input_buf_size(input_buf_size)?;
157
158        let len = read_control_word(&buf.split_off(input_buf_size - 8), input_buf_size)?;
159
160        Ok(BitArray { buf, len })
161    }
162
163    /// Returns a reference to the buffer slice.
164    pub fn bits(&self) -> &[u8] {
165        &self.buf
166    }
167
168    /// Returns the number of usable bits in the bit array.
169    pub fn len(&self) -> usize {
170        usize::try_from(self.len).unwrap_or_else(|_| {
171            panic!(
172                "expected length ({}) to fit in {} bytes",
173                self.len,
174                std::mem::size_of::<usize>()
175            )
176        })
177    }
178
179    /// Returns `true` if there are no usable bits.
180    pub fn is_empty(&self) -> bool {
181        self.len == 0
182    }
183
184    /// Reads the data buffer and returns the logical value of the bit at the bit `index`.
185    ///
186    /// Panics if `index` is >= the length of the bit array.
187    pub fn get(&self, index: usize) -> bool {
188        let len = self.len();
189        debug_assert!(index < len, "expected index ({}) < length ({})", index, len);
190
191        let byte = self.buf[index / 8];
192        let mask = 0b1000_0000 >> index % 8;
193
194        byte & mask != 0
195    }
196
197    pub fn iter(&self) -> impl Iterator<Item = bool> {
198        let bits = self.clone();
199        (0..bits.len()).map(move |index| bits.get(index))
200    }
201}
202
/// Incrementally builds the serialized form of a bit array into an in-memory destination.
///
/// Bits are accumulated into a 64-bit word that is written to the destination each time it
/// fills up.
pub struct BitArrayBufBuilder<B> {
    /// Destination of the bit array data.
    dest: B,
    /// The word currently being filled; it has not been written to `dest` yet.
    current: u64,
    /// Number of bits pushed so far.
    count: u64,
}
211
212impl<B: BufMut> BitArrayBufBuilder<B> {
213    pub fn new(dest: B) -> BitArrayBufBuilder<B> {
214        BitArrayBufBuilder {
215            dest,
216            current: 0,
217            count: 0,
218        }
219    }
220
221    pub fn push(&mut self, bit: bool) {
222        // Set the bit in the current word.
223        if bit {
224            // Determine the position of the bit to be set from `count`.
225            let pos = self.count & 0b11_1111;
226            self.current |= 0x8000_0000_0000_0000 >> pos;
227        }
228
229        // Advance the bit count.
230        self.count += 1;
231
232        // Check if the new `count` has reached a word boundary.
233        if self.count & 0b11_1111 == 0 {
234            // We have filled `current`, so write it to the destination.
235            self.dest.put_u64(self.current);
236            self.current = 0;
237        }
238    }
239
240    pub fn push_all<I: Iterator<Item = bool>>(&mut self, mut iter: I) {
241        while let Some(bit) = iter.next() {
242            self.push(bit);
243        }
244    }
245
246    fn finalize_data(&mut self) {
247        if self.count & 0b11_1111 != 0 {
248            self.dest.put_u64(self.current);
249        }
250    }
251
252    pub fn finalize(mut self) -> B {
253        let count = self.count;
254        // Write the final data word.
255        self.finalize_data();
256        // Write the control word.
257        self.dest.put_u64(count);
258
259        self.dest
260    }
261
262    pub fn count(&self) -> u64 {
263        self.count
264    }
265}
266
/// Incrementally builds the serialized form of a bit array into an asynchronously written file.
///
/// Bits are accumulated into a 64-bit word that is written to the destination each time it
/// fills up.
pub struct BitArrayFileBuilder<W> {
    /// Destination of the bit array data.
    dest: W,
    /// The word currently being filled; it has not been written to `dest` yet.
    current: u64,
    /// Number of bits pushed so far.
    count: u64,
}
275
276impl<W: SyncableFile> BitArrayFileBuilder<W> {
277    pub fn new(dest: W) -> BitArrayFileBuilder<W> {
278        BitArrayFileBuilder {
279            dest,
280            current: 0,
281            count: 0,
282        }
283    }
284
285    pub async fn push(&mut self, bit: bool) -> io::Result<()> {
286        // Set the bit in the current word.
287        if bit {
288            // Determine the position of the bit to be set from `count`.
289            let pos = self.count & 0b11_1111;
290            self.current |= 0x8000_0000_0000_0000 >> pos;
291        }
292
293        // Advance the bit count.
294        self.count += 1;
295
296        // Check if the new `count` has reached a word boundary.
297        if self.count & 0b11_1111 == 0 {
298            // We have filled `current`, so write it to the destination.
299            util::write_u64(&mut self.dest, self.current).await?;
300            self.current = 0;
301        }
302
303        Ok(())
304    }
305
306    pub async fn push_all<S: Stream<Item = io::Result<bool>> + Unpin>(
307        &mut self,
308        mut stream: S,
309    ) -> io::Result<()> {
310        while let Some(bit) = stream.next().await {
311            let bit = bit?;
312            self.push(bit).await?;
313        }
314
315        Ok(())
316    }
317
318    async fn finalize_data(&mut self) -> io::Result<()> {
319        if self.count & 0b11_1111 != 0 {
320            util::write_u64(&mut self.dest, self.current).await?;
321        }
322
323        Ok(())
324    }
325
326    pub async fn finalize(mut self) -> io::Result<()> {
327        let count = self.count;
328        // Write the final data word.
329        self.finalize_data().await?;
330        // Write the control word.
331        util::write_u64(&mut self.dest, count).await?;
332        // Flush the `dest`.
333        self.dest.flush().await?;
334        self.dest.sync_all().await?;
335
336        Ok(())
337    }
338
339    pub fn count(&self) -> u64 {
340        self.count
341    }
342}
343
/// Decoder state for reading the 64-bit data words of a serialized bit array.
pub struct BitArrayBlockDecoder {
    /// The most recently read word, which has not been returned yet.
    ///
    /// Decoding always stays one word behind the input. That way the final word of the input,
    /// which is the control word rather than data, is never returned.
    readahead: Option<u64>,
}
351
352impl Decoder for BitArrayBlockDecoder {
353    type Item = u64;
354    type Error = io::Error;
355
356    /// Decode the next block of the bit array.
357    fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<u64>, io::Error> {
358        Ok(decode_next_bitarray_block(bytes, &mut self.readahead))
359    }
360}
361
362fn decode_next_bitarray_block<B: Buf>(bytes: &mut B, readahead: &mut Option<u64>) -> Option<u64> {
363    // If there isn't a full word available in the buffer, stop.
364    if bytes.remaining() < 8 {
365        return None;
366    }
367
368    // Read the next word. If `readahead` was `Some`, return that value; otherwise,
369    // recurse to read a second word and then return the first word.
370    //
371    // This trick means that we don't return the last word in the buffer, which is the control
372    // word. The consequence is that we read an extra word at the beginning of the decoding
373    // process.
374    match readahead.replace(bytes.get_u64()) {
375        Some(word) => Some(word),
376        None => decode_next_bitarray_block(bytes, readahead),
377    }
378}
379
380pub fn bitarray_stream_blocks<R: AsyncRead + Unpin>(r: R) -> FramedRead<R, BitArrayBlockDecoder> {
381    FramedRead::new(r, BitArrayBlockDecoder { readahead: None })
382}
383
384pub fn bitarray_iter_blocks<B: Buf>(b: B) -> BitArrayBlockIterator<B> {
385    BitArrayBlockIterator {
386        buf: b,
387        readahead: None,
388    }
389}
390
/// Iterator state for reading the 64-bit data words of a serialized bit array from a buffer.
///
/// The type parameter is left unbounded here, matching the builder structs in this file; the
/// `Buf` requirement is stated on the `Iterator` implementation and on `bitarray_iter_blocks`,
/// which are the places that actually read from the buffer.
pub struct BitArrayBlockIterator<B> {
    /// The buffer the words are read from.
    buf: B,
    /// The most recently read word, which has not been returned yet.
    readahead: Option<u64>,
}
395
396impl<B: Buf> Iterator for BitArrayBlockIterator<B> {
397    type Item = u64;
398    fn next(&mut self) -> Option<u64> {
399        decode_next_bitarray_block(&mut self.buf, &mut self.readahead)
400    }
401}
402
403/// Read the length (number of bits) from a `FileLoad`.
404pub async fn bitarray_len_from_file<F: FileLoad>(f: F) -> io::Result<u64> {
405    BitArrayError::validate_input_buf_size(f.size().await?)?;
406    let mut control_word = vec![0; 8];
407    f.open_read_from(f.size().await? - 8)
408        .await?
409        .read_exact(&mut control_word)
410        .await?;
411    Ok(read_control_word(&control_word, f.size().await?)?)
412}
413
414pub async fn bitarray_stream_bits<F: FileLoad>(
415    f: F,
416) -> io::Result<impl Stream<Item = io::Result<bool>> + Unpin> {
417    // Read the length.
418    let len = bitarray_len_from_file(f.clone()).await?;
419
420    // Read the words into a `Stream`.
421    Ok(bitarray_stream_blocks(f.open_read().await?)
422        // For each word, read the bits into a `Stream`.
423        .map_ok(|block| util::stream_iter_ok(BitIter::new(block)))
424        // Turn the `Stream` of bit `Stream`s into a bit `Stream`.
425        .try_flatten()
426        .into_stream()
427        // Cut the `Stream` off after the length of bits is reached.
428        .take(len as usize))
429}
430
#[cfg(test)]
mod tests {
    use crate::storage::memory::MemoryBackedStore;
    use crate::storage::FileStore;

    use super::*;
    use futures::executor::block_on;
    use futures::future;

    /// Write `contents` to a fresh in-memory store and return the message of the error that
    /// `bitarray_len_from_file` reports for it. Panics if no error is reported.
    async fn len_error_message(contents: &[u8]) -> String {
        let store = MemoryBackedStore::new();
        let mut writer = store.open_write().await.unwrap();
        writer.write_all(contents).await.unwrap();
        writer.sync_all().await.unwrap();

        block_on(bitarray_len_from_file(store))
            .err()
            .unwrap()
            .to_string()
    }

    #[test]
    fn bit_array_error() {
        // Display
        let too_small = BitArrayError::InputBufferTooSmall(7).to_string();
        assert_eq!("expected input buffer size (7) >= 8", too_small);

        let unexpected = BitArrayError::UnexpectedInputBufferSize(9, 8, 0).to_string();
        assert_eq!(
            "expected input buffer size (9) to be 8 for 0 bits",
            unexpected
        );

        // From<BitArrayError> for io::Error
        let built_by_hand = io::Error::new(
            io::ErrorKind::InvalidData,
            BitArrayError::InputBufferTooSmall(7),
        );
        let converted = io::Error::from(BitArrayError::InputBufferTooSmall(7));
        assert_eq!(built_by_hand.to_string(), converted.to_string());
    }

    #[test]
    fn validate_input_buf_size() {
        assert_eq!(
            Err(BitArrayError::InputBufferTooSmall(7)),
            BitArrayError::validate_input_buf_size(7)
        );

        for &accepted in &[8, 9, usize::max_value()] {
            assert_eq!(Ok(()), BitArrayError::validate_input_buf_size(accepted));
        }
    }

    #[test]
    fn validate_len() {
        assert_eq!(
            Err(BitArrayError::UnexpectedInputBufferSize(0, 8, 0)),
            BitArrayError::validate_len(0, 0)
        );
        assert_eq!(Ok(()), BitArrayError::validate_len(16, 1));
        assert_eq!(Ok(()), BitArrayError::validate_len(16, 2));

        #[cfg(target_pointer_width = "64")]
        {
            // Compute the expected size in 128 bits, where rounding up cannot overflow.
            let widest = ((u128::from(u64::max_value()) + 65) >> 6) << 3;
            assert_eq!(
                Ok(()),
                BitArrayError::validate_len(usize::try_from(widest).unwrap(), u64::max_value())
            );
        }
    }

    #[test]
    fn decode() {
        // A lone word is the control word, so there is no data block to return.
        let mut input = BytesMut::from([0u8; 8].as_ref());
        let mut decoder = BitArrayBlockDecoder { readahead: None };
        let decoded = Decoder::decode(&mut decoder, &mut input).unwrap();
        assert_eq!(None, decoded);
    }

    #[test]
    fn empty() {
        let control_word_only = Bytes::from([0u8; 8].as_ref());
        let bitarray = BitArray::from_bits(control_word_only).unwrap();
        assert!(bitarray.is_empty());
    }

    #[tokio::test]
    async fn construct_and_parse_small_bitarray() {
        let store = MemoryBackedStore::new();
        let expected = vec![true, true, false, false, true];

        let mut builder = BitArrayFileBuilder::new(store.open_write().await.unwrap());
        block_on(builder.push_all(util::stream_iter_ok(expected.clone()))).unwrap();
        block_on(builder.finalize()).unwrap();

        let loaded = block_on(store.map()).unwrap();
        let bitarray = BitArray::from_bits(loaded).unwrap();

        for (index, bit) in expected.iter().enumerate() {
            assert_eq!(*bit, bitarray.get(index));
        }
    }

    #[tokio::test]
    async fn construct_and_parse_large_bitarray() {
        let store = MemoryBackedStore::new();
        let every_third = (0..).map(|n| n % 3 == 0).take(123456);

        let mut builder = BitArrayFileBuilder::new(store.open_write().await.unwrap());
        block_on(builder.push_all(util::stream_iter_ok(every_third))).unwrap();
        block_on(builder.finalize()).unwrap();

        let loaded = block_on(store.map()).unwrap();
        let bitarray = BitArray::from_bits(loaded).unwrap();

        (0..bitarray.len()).for_each(|i| assert_eq!(i % 3 == 0, bitarray.get(i)));
    }

    #[tokio::test]
    async fn bitarray_len_from_file_errors() {
        // Shorter than the control word.
        assert_eq!(
            io::Error::from(BitArrayError::InputBufferTooSmall(3)).to_string(),
            len_error_message(&[0, 0, 0]).await
        );

        // The control word announces 2 bits, but there is no data word.
        assert_eq!(
            io::Error::from(BitArrayError::UnexpectedInputBufferSize(8, 16, 2)).to_string(),
            len_error_message(&[0, 0, 0, 0, 0, 0, 0, 2]).await
        );
    }

    #[tokio::test]
    async fn stream_blocks() {
        let store = MemoryBackedStore::new();
        let contents: Vec<bool> = (0..).map(|n| n % 4 == 1).take(256).collect();

        let mut builder = BitArrayFileBuilder::new(store.open_write().await.unwrap());
        builder
            .push_all(util::stream_iter_ok(contents))
            .await
            .unwrap();
        builder.finalize().await.unwrap();

        bitarray_stream_blocks(store.open_read().await.unwrap())
            .try_for_each(|block| {
                assert_eq!(0x4444444444444444, block);
                future::ok(())
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn stream_bits() {
        let store = MemoryBackedStore::new();
        let contents: Vec<_> = (0..).map(|n| n % 4 == 1).take(123).collect();

        let mut builder = BitArrayFileBuilder::new(store.open_write().await.unwrap());
        block_on(builder.push_all(util::stream_iter_ok(contents.clone()))).unwrap();
        block_on(builder.finalize()).unwrap();

        let bits = bitarray_stream_bits(store).await.unwrap();
        let result: Vec<_> = block_on(bits.try_collect()).unwrap();

        assert_eq!(contents, result);
    }
}
626}