1#![allow(clippy::precedence, clippy::verbose_bit_mask)]
2
3use tokio::io::{AsyncRead, AsyncReadExt, AsyncWriteExt};
32
33use super::util;
34use crate::bititer::BitIter;
35use crate::storage::{FileLoad, SyncableFile};
36use byteorder::{BigEndian, ByteOrder};
37use bytes::{Buf, BufMut, Bytes, BytesMut};
38use futures::io;
39use futures::stream::{Stream, StreamExt, TryStreamExt};
40use std::{convert::TryFrom, error, fmt};
41use tokio_util::codec::{Decoder, FramedRead};
42
43#[derive(Clone)]
57pub struct BitArray {
58 len: u64,
60
61 buf: Bytes,
65}
66
/// Errors raised while validating the serialized layout of a bit array.
#[derive(Debug, PartialEq)]
pub enum BitArrayError {
    /// The input is shorter than the 8-byte control word; carries its size.
    InputBufferTooSmall(usize),
    /// The input size disagrees with the bit count in the control word.
    /// Fields: `(actual_size, expected_size, bit_count)`.
    UnexpectedInputBufferSize(u64, u64, u64),
}

impl BitArrayError {
    /// Succeeds when `input_buf_size` is large enough to hold the control word.
    fn validate_input_buf_size(input_buf_size: usize) -> Result<(), Self> {
        if input_buf_size >= 8 {
            Ok(())
        } else {
            Err(BitArrayError::InputBufferTooSmall(input_buf_size))
        }
    }

    /// Succeeds when `input_buf_size` is exactly the serialized size of a bit
    /// array holding `len` bits.
    ///
    /// That size is: 8 bytes per complete 64-bit word, plus 8 bytes for a
    /// partially filled trailing word (if any), plus the 8-byte control word.
    fn validate_len(input_buf_size: usize, len: u64) -> Result<(), Self> {
        let full_word_bytes = (len / 64) * 8;
        let partial_word_bytes = if len % 64 == 0 { 0 } else { 8 };
        let expected_buf_size = full_word_bytes + partial_word_bytes + 8;

        let actual_buf_size = u64::try_from(input_buf_size).unwrap();
        if actual_buf_size == expected_buf_size {
            Ok(())
        } else {
            Err(BitArrayError::UnexpectedInputBufferSize(
                actual_buf_size,
                expected_buf_size,
                len,
            ))
        }
    }
}

impl fmt::Display for BitArrayError {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        match *self {
            BitArrayError::InputBufferTooSmall(size) => {
                write!(f, "expected input buffer size ({}) >= 8", size)
            }
            BitArrayError::UnexpectedInputBufferSize(actual, expected, bits) => write!(
                f,
                "expected input buffer size ({}) to be {} for {} bits",
                actual, expected, bits
            ),
        }
    }
}

impl error::Error for BitArrayError {}

/// Layout errors surface as `io::ErrorKind::InvalidData`, so they can be
/// propagated with `?` from functions returning `io::Result`.
impl From<BitArrayError> for io::Error {
    fn from(err: BitArrayError) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, err)
    }
}
143
144fn read_control_word(buf: &[u8], input_buf_size: usize) -> Result<u64, BitArrayError> {
147 let len = BigEndian::read_u64(buf);
148 BitArrayError::validate_len(input_buf_size, len)?;
149 Ok(len)
150}
151
152impl BitArray {
153 pub fn from_bits(mut buf: Bytes) -> Result<BitArray, BitArrayError> {
155 let input_buf_size = buf.len();
156 BitArrayError::validate_input_buf_size(input_buf_size)?;
157
158 let len = read_control_word(&buf.split_off(input_buf_size - 8), input_buf_size)?;
159
160 Ok(BitArray { buf, len })
161 }
162
163 pub fn bits(&self) -> &[u8] {
165 &self.buf
166 }
167
168 pub fn len(&self) -> usize {
170 usize::try_from(self.len).unwrap_or_else(|_| {
171 panic!(
172 "expected length ({}) to fit in {} bytes",
173 self.len,
174 std::mem::size_of::<usize>()
175 )
176 })
177 }
178
179 pub fn is_empty(&self) -> bool {
181 self.len == 0
182 }
183
184 pub fn get(&self, index: usize) -> bool {
188 let len = self.len();
189 debug_assert!(index < len, "expected index ({}) < length ({})", index, len);
190
191 let byte = self.buf[index / 8];
192 let mask = 0b1000_0000 >> index % 8;
193
194 byte & mask != 0
195 }
196
197 pub fn iter(&self) -> impl Iterator<Item = bool> {
198 let bits = self.clone();
199 (0..bits.len()).map(move |index| bits.get(index))
200 }
201}
202
203pub struct BitArrayBufBuilder<B> {
204 dest: B,
206 current: u64,
208 count: u64,
210}
211
212impl<B: BufMut> BitArrayBufBuilder<B> {
213 pub fn new(dest: B) -> BitArrayBufBuilder<B> {
214 BitArrayBufBuilder {
215 dest,
216 current: 0,
217 count: 0,
218 }
219 }
220
221 pub fn push(&mut self, bit: bool) {
222 if bit {
224 let pos = self.count & 0b11_1111;
226 self.current |= 0x8000_0000_0000_0000 >> pos;
227 }
228
229 self.count += 1;
231
232 if self.count & 0b11_1111 == 0 {
234 self.dest.put_u64(self.current);
236 self.current = 0;
237 }
238 }
239
240 pub fn push_all<I: Iterator<Item = bool>>(&mut self, mut iter: I) {
241 while let Some(bit) = iter.next() {
242 self.push(bit);
243 }
244 }
245
246 fn finalize_data(&mut self) {
247 if self.count & 0b11_1111 != 0 {
248 self.dest.put_u64(self.current);
249 }
250 }
251
252 pub fn finalize(mut self) -> B {
253 let count = self.count;
254 self.finalize_data();
256 self.dest.put_u64(count);
258
259 self.dest
260 }
261
262 pub fn count(&self) -> u64 {
263 self.count
264 }
265}
266
267pub struct BitArrayFileBuilder<W> {
268 dest: W,
270 current: u64,
272 count: u64,
274}
275
276impl<W: SyncableFile> BitArrayFileBuilder<W> {
277 pub fn new(dest: W) -> BitArrayFileBuilder<W> {
278 BitArrayFileBuilder {
279 dest,
280 current: 0,
281 count: 0,
282 }
283 }
284
285 pub async fn push(&mut self, bit: bool) -> io::Result<()> {
286 if bit {
288 let pos = self.count & 0b11_1111;
290 self.current |= 0x8000_0000_0000_0000 >> pos;
291 }
292
293 self.count += 1;
295
296 if self.count & 0b11_1111 == 0 {
298 util::write_u64(&mut self.dest, self.current).await?;
300 self.current = 0;
301 }
302
303 Ok(())
304 }
305
306 pub async fn push_all<S: Stream<Item = io::Result<bool>> + Unpin>(
307 &mut self,
308 mut stream: S,
309 ) -> io::Result<()> {
310 while let Some(bit) = stream.next().await {
311 let bit = bit?;
312 self.push(bit).await?;
313 }
314
315 Ok(())
316 }
317
318 async fn finalize_data(&mut self) -> io::Result<()> {
319 if self.count & 0b11_1111 != 0 {
320 util::write_u64(&mut self.dest, self.current).await?;
321 }
322
323 Ok(())
324 }
325
326 pub async fn finalize(mut self) -> io::Result<()> {
327 let count = self.count;
328 self.finalize_data().await?;
330 util::write_u64(&mut self.dest, count).await?;
332 self.dest.flush().await?;
334 self.dest.sync_all().await?;
335
336 Ok(())
337 }
338
339 pub fn count(&self) -> u64 {
340 self.count
341 }
342}
343
344pub struct BitArrayBlockDecoder {
345 readahead: Option<u64>,
350}
351
352impl Decoder for BitArrayBlockDecoder {
353 type Item = u64;
354 type Error = io::Error;
355
356 fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<u64>, io::Error> {
358 Ok(decode_next_bitarray_block(bytes, &mut self.readahead))
359 }
360}
361
362fn decode_next_bitarray_block<B: Buf>(bytes: &mut B, readahead: &mut Option<u64>) -> Option<u64> {
363 if bytes.remaining() < 8 {
365 return None;
366 }
367
368 match readahead.replace(bytes.get_u64()) {
375 Some(word) => Some(word),
376 None => decode_next_bitarray_block(bytes, readahead),
377 }
378}
379
380pub fn bitarray_stream_blocks<R: AsyncRead + Unpin>(r: R) -> FramedRead<R, BitArrayBlockDecoder> {
381 FramedRead::new(r, BitArrayBlockDecoder { readahead: None })
382}
383
384pub fn bitarray_iter_blocks<B: Buf>(b: B) -> BitArrayBlockIterator<B> {
385 BitArrayBlockIterator {
386 buf: b,
387 readahead: None,
388 }
389}
390
391pub struct BitArrayBlockIterator<B: Buf> {
392 buf: B,
393 readahead: Option<u64>,
394}
395
396impl<B: Buf> Iterator for BitArrayBlockIterator<B> {
397 type Item = u64;
398 fn next(&mut self) -> Option<u64> {
399 decode_next_bitarray_block(&mut self.buf, &mut self.readahead)
400 }
401}
402
403pub async fn bitarray_len_from_file<F: FileLoad>(f: F) -> io::Result<u64> {
405 BitArrayError::validate_input_buf_size(f.size().await?)?;
406 let mut control_word = vec![0; 8];
407 f.open_read_from(f.size().await? - 8)
408 .await?
409 .read_exact(&mut control_word)
410 .await?;
411 Ok(read_control_word(&control_word, f.size().await?)?)
412}
413
414pub async fn bitarray_stream_bits<F: FileLoad>(
415 f: F,
416) -> io::Result<impl Stream<Item = io::Result<bool>> + Unpin> {
417 let len = bitarray_len_from_file(f.clone()).await?;
419
420 Ok(bitarray_stream_blocks(f.open_read().await?)
422 .map_ok(|block| util::stream_iter_ok(BitIter::new(block)))
424 .try_flatten()
426 .into_stream()
427 .take(len as usize))
429}
430
// Unit tests for the bit-array validation helpers, builders and decoders.
//
// NOTE(review): the `#[tokio::test]` functions below drive futures with
// `futures::executor::block_on` from inside the tokio runtime. This presumably
// works only because the in-memory store's futures do not need the tokio
// reactor to make progress — confirm before switching to another store.
#[cfg(test)]
mod tests {
    use crate::storage::memory::MemoryBackedStore;
    use crate::storage::FileStore;

    use super::*;
    use futures::executor::block_on;
    use futures::future;

    // Checks the Display text of both error variants, and that converting to
    // `io::Error` gives the same message as wrapping with `InvalidData` by hand.
    #[test]
    fn bit_array_error() {
        assert_eq!(
            // Display output of the "too small" variant.
            "expected input buffer size (7) >= 8",
            BitArrayError::InputBufferTooSmall(7).to_string()
        );
        assert_eq!(
            "expected input buffer size (9) to be 8 for 0 bits",
            BitArrayError::UnexpectedInputBufferSize(9, 8, 0).to_string()
        );

        assert_eq!(
            // The `From` impl must match an explicit InvalidData wrap.
            io::Error::new(
                io::ErrorKind::InvalidData,
                BitArrayError::InputBufferTooSmall(7)
            )
            .to_string(),
            io::Error::from(BitArrayError::InputBufferTooSmall(7)).to_string()
        );
    }

    // Inputs shorter than the 8-byte control word are rejected; anything at
    // least 8 bytes long (up to usize::MAX) passes.
    #[test]
    fn validate_input_buf_size() {
        let val = |buf_size| BitArrayError::validate_input_buf_size(buf_size);
        let err = |buf_size| Err(BitArrayError::InputBufferTooSmall(buf_size));
        assert_eq!(err(7), val(7));
        assert_eq!(Ok(()), val(8));
        assert_eq!(Ok(()), val(9));
        assert_eq!(Ok(()), val(usize::max_value()));
    }

    // The expected size is the data words plus the control word. Covers an
    // empty array, a single partial word, and the maximum bit count.
    #[test]
    fn validate_len() {
        let val = |buf_size, len| BitArrayError::validate_len(buf_size, len);
        let err = |buf_size, expected, len| {
            Err(BitArrayError::UnexpectedInputBufferSize(
                buf_size, expected, len,
            ))
        };

        assert_eq!(err(0, 8, 0), val(0, 0));
        assert_eq!(Ok(()), val(16, 1));
        assert_eq!(Ok(()), val(16, 2));

        // The expected size for u64::MAX bits is computed in u128 so the
        // `+ 65` cannot overflow; it only fits in a 64-bit usize.
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Ok(()),
            val(
                usize::try_from(u128::from(u64::max_value()) + 65 >> 6 << 3).unwrap(),
                u64::max_value()
            )
        );
    }

    // A lone 8-byte word is held in the decoder's readahead (it could be the
    // control word), so nothing is emitted.
    #[test]
    fn decode() {
        let mut decoder = BitArrayBlockDecoder { readahead: None };
        let mut bytes = BytesMut::from([0u8; 8].as_ref());
        assert_eq!(None, Decoder::decode(&mut decoder, &mut bytes).unwrap());
    }

    // A buffer holding only a zero control word parses as an empty array.
    #[test]
    fn empty() {
        assert!(BitArray::from_bits(Bytes::from([0u8; 8].as_ref()))
            .unwrap()
            .is_empty());
    }

    // Round-trip of fewer than 64 bits: file builder -> in-memory store -> BitArray.
    #[tokio::test]
    async fn construct_and_parse_small_bitarray() {
        let x = MemoryBackedStore::new();
        let contents = vec![true, true, false, false, true];

        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        block_on(async {
            builder.push_all(util::stream_iter_ok(contents)).await?;
            builder.finalize().await?;

            Ok::<_, io::Error>(())
        })
        .unwrap();

        let loaded = block_on(x.map()).unwrap();

        let bitarray = BitArray::from_bits(loaded).unwrap();

        assert_eq!(true, bitarray.get(0));
        assert_eq!(true, bitarray.get(1));
        assert_eq!(false, bitarray.get(2));
        assert_eq!(false, bitarray.get(3));
        assert_eq!(true, bitarray.get(4));
    }

    // Round-trip of 123456 bits (many full words plus a partial one), where
    // bit i is set exactly when i is a multiple of 3.
    #[tokio::test]
    async fn construct_and_parse_large_bitarray() {
        let x = MemoryBackedStore::new();
        let contents = (0..).map(|n| n % 3 == 0).take(123456);

        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        block_on(async {
            builder.push_all(util::stream_iter_ok(contents)).await?;
            builder.finalize().await?;

            Ok::<_, io::Error>(())
        })
        .unwrap();

        let loaded = block_on(x.map()).unwrap();

        let bitarray = BitArray::from_bits(loaded).unwrap();

        for i in 0..bitarray.len() {
            assert_eq!(i % 3 == 0, bitarray.get(i));
        }
    }

    // A 3-byte file is too small; an 8-byte file whose control word says
    // 2 bits should have been 16 bytes long.
    #[tokio::test]
    async fn bitarray_len_from_file_errors() {
        let store = MemoryBackedStore::new();
        let mut writer = store.open_write().await.unwrap();
        writer.write_all(&[0, 0, 0]).await.unwrap();
        writer.sync_all().await.unwrap();
        assert_eq!(
            io::Error::from(BitArrayError::InputBufferTooSmall(3)).to_string(),
            block_on(bitarray_len_from_file(store))
                .err()
                .unwrap()
                .to_string()
        );

        let store = MemoryBackedStore::new();
        let mut writer = store.open_write().await.unwrap();
        writer.write_all(&[0, 0, 0, 0, 0, 0, 0, 2]).await.unwrap();
        writer.sync_all().await.unwrap();
        assert_eq!(
            io::Error::from(BitArrayError::UnexpectedInputBufferSize(8, 16, 2)).to_string(),
            block_on(bitarray_len_from_file(store))
                .err()
                .unwrap()
                .to_string()
        );
    }

    // 256 bits with every 4th bit (offset 1) set give four words of
    // 0b0100_0100... = 0x4444444444444444. The control word must not be streamed.
    #[tokio::test]
    async fn stream_blocks() {
        let x = MemoryBackedStore::new();
        let contents: Vec<bool> = (0..).map(|n| n % 4 == 1).take(256).collect();

        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        builder
            .push_all(util::stream_iter_ok(contents))
            .await
            .unwrap();
        builder.finalize().await.unwrap();

        let stream = bitarray_stream_blocks(x.open_read().await.unwrap());

        stream
            .try_for_each(|block| future::ok(assert_eq!(0x4444444444444444, block)))
            .await
            .unwrap();
    }

    // Streaming bits back yields exactly the 123 pushed bits, with the padding
    // of the last word cut off.
    #[tokio::test]
    async fn stream_bits() {
        let x = MemoryBackedStore::new();
        let contents: Vec<_> = (0..).map(|n| n % 4 == 1).take(123).collect();

        let mut builder = BitArrayFileBuilder::new(x.open_write().await.unwrap());
        block_on(async {
            builder
                .push_all(util::stream_iter_ok(contents.clone()))
                .await?;
            builder.finalize().await?;

            Ok::<_, io::Error>(())
        })
        .unwrap();

        let result: Vec<_> =
            block_on(bitarray_stream_bits(x).await.unwrap().try_collect()).unwrap();

        assert_eq!(contents, result);
    }
}