#![allow(clippy::precedence, clippy::verbose_bit_mask)]
use super::util;
use crate::storage::*;
use crate::structure::bititer::BitIter;
use byteorder::{BigEndian, ByteOrder};
use bytes::{Bytes, BytesMut};
use futures::prelude::*;
use std::{convert::TryFrom, error, fmt, io};
use tokio::{
codec::{Decoder, FramedRead},
prelude::*,
};
/// An in-memory bit array backed by a [`Bytes`] buffer.
///
/// Bits are packed most-significant-bit first within each byte. `buf` holds only the
/// data words; the number of valid bits is kept separately in `len` because the last
/// data word may extend past the final bit.
#[derive(Clone)]
pub struct BitArray {
    /// Packed data bytes (the serialized control word is not included).
    buf: Bytes,
    /// Number of valid bits stored in `buf`.
    len: u64,
}
/// Reasons a serialized bit array can be rejected.
#[derive(Debug, PartialEq)]
pub enum BitArrayError {
    /// The input is shorter than the 8-byte control word.
    /// Payload: the input buffer size in bytes.
    InputBufferTooSmall(usize),
    /// The input size disagrees with the size implied by its bit count.
    /// Payload: `(actual size, expected size, bit count)`.
    UnexpectedInputBufferSize(u64, u64, u64),
}

impl BitArrayError {
    /// Checks that an input buffer can hold the 8-byte control word.
    ///
    /// # Errors
    /// [`BitArrayError::InputBufferTooSmall`] when `input_buf_size < 8`.
    fn validate_input_buf_size(input_buf_size: usize) -> Result<(), Self> {
        if input_buf_size >= 8 {
            Ok(())
        } else {
            Err(BitArrayError::InputBufferTooSmall(input_buf_size))
        }
    }

    /// Checks that a serialized bit array of `input_buf_size` bytes has exactly the
    /// size implied by a bit count of `len`.
    ///
    /// The expected size is 8 bytes per complete group of 64 bits, 8 more bytes for
    /// any leftover bits, plus the 8-byte control word.
    ///
    /// # Errors
    /// [`BitArrayError::UnexpectedInputBufferSize`] when the sizes differ.
    fn validate_len(input_buf_size: usize, len: u64) -> Result<(), Self> {
        // Dividing first keeps this overflow-free even for `len == u64::MAX`,
        // unlike rounding `len + 63` up.
        let full_word_bytes = len / 64 * 8;
        let trailing_bytes = if len % 64 == 0 { 8 } else { 16 };
        let expected_buf_size = full_word_bytes + trailing_bytes;
        let actual_buf_size = u64::try_from(input_buf_size).unwrap();
        if actual_buf_size == expected_buf_size {
            Ok(())
        } else {
            Err(BitArrayError::UnexpectedInputBufferSize(
                actual_buf_size,
                expected_buf_size,
                len,
            ))
        }
    }
}

impl fmt::Display for BitArrayError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match *self {
            BitArrayError::InputBufferTooSmall(size) => {
                write!(f, "expected input buffer size ({}) >= 8", size)
            }
            BitArrayError::UnexpectedInputBufferSize(actual, expected, bits) => write!(
                f,
                "expected input buffer size ({}) to be {} for {} bits",
                actual, expected, bits
            ),
        }
    }
}

impl error::Error for BitArrayError {}

impl From<BitArrayError> for io::Error {
    /// Wraps the error as an [`io::ErrorKind::InvalidData`] I/O error.
    fn from(err: BitArrayError) -> io::Error {
        io::Error::new(io::ErrorKind::InvalidData, err)
    }
}
/// Decodes the big-endian bit count from the first 8 bytes of `buf` (the control
/// word) and checks it against the total serialized size `input_buf_size`.
///
/// # Errors
/// [`BitArrayError::UnexpectedInputBufferSize`] when the size does not match the bit count.
///
/// # Panics
/// If `buf` is shorter than 8 bytes (`BigEndian::read_u64` requires 8 bytes).
fn read_control_word(buf: &[u8], input_buf_size: usize) -> Result<u64, BitArrayError> {
    let bit_count = BigEndian::read_u64(buf);
    BitArrayError::validate_len(input_buf_size, bit_count).map(|()| bit_count)
}
impl BitArray {
    /// Parses a serialized bit array: packed data bytes followed by an 8-byte
    /// big-endian control word holding the number of bits.
    ///
    /// # Errors
    /// - [`BitArrayError::InputBufferTooSmall`] if `buf` is shorter than 8 bytes.
    /// - [`BitArrayError::UnexpectedInputBufferSize`] if the size of `buf` does not
    ///   match the bit count in its control word.
    pub fn from_bits(mut buf: Bytes) -> Result<BitArray, BitArrayError> {
        let total_size = buf.len();
        BitArrayError::validate_input_buf_size(total_size)?;
        // Afterwards `buf` holds only the data words; `control_word` holds the last 8 bytes.
        let control_word = buf.split_off(total_size - 8);
        let len = read_control_word(&control_word, total_size)?;
        Ok(BitArray { buf, len })
    }

    /// Returns the packed data bytes, excluding the control word.
    ///
    /// The last 8-byte word may extend past `len()` bits.
    pub fn bits(&self) -> &[u8] {
        &self.buf[..]
    }

    /// Returns the number of bits in the array.
    ///
    /// # Panics
    /// If the stored bit count does not fit in a `usize` on this platform.
    pub fn len(&self) -> usize {
        match usize::try_from(self.len) {
            Ok(bit_count) => bit_count,
            Err(_) => panic!(
                "expected length ({}) to fit in {} bytes",
                self.len,
                std::mem::size_of::<usize>()
            ),
        }
    }

    /// Returns `true` if the array holds no bits.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }

    /// Returns the bit at `index`, counting from the most significant bit of the
    /// first byte.
    ///
    /// # Panics
    /// If `index >= self.len()`.
    pub fn get(&self, index: usize) -> bool {
        let len = self.len();
        assert!(index < len, "expected index ({}) < length ({})", index, len);
        let (byte_index, bit_offset) = (index / 8, index % 8);
        // Bits are packed MSB-first, so offset 0 is the byte's highest bit.
        (self.buf[byte_index] >> (7 - bit_offset)) & 1 == 1
    }
}
/// Incrementally writes a bit array to an `AsyncWrite` destination.
///
/// Bits are accumulated MSB-first into a 64-bit word. The word is written with
/// `util::write_u64` each time it fills. `finalize` writes any partially filled word,
/// then the total bit count as the trailing control word, then flushes.
pub struct BitArrayFileBuilder<W> {
    /// Destination for the serialized words.
    dest: W,
    /// The word currently being filled. Positions not yet pushed are zero.
    current: u64,
    /// Total number of bits pushed so far.
    count: u64,
}

impl<W: AsyncWrite> BitArrayFileBuilder<W> {
    /// Creates a builder writing to `dest`, starting with no bits.
    pub fn new(dest: W) -> BitArrayFileBuilder<W> {
        BitArrayFileBuilder {
            count: 0,
            current: 0,
            dest,
        }
    }

    /// Appends one bit, writing out the current word if this bit completes it.
    ///
    /// Consumes the builder and resolves to it once any required write has finished.
    pub fn push(self, bit: bool) -> impl Future<Item = BitArrayFileBuilder<W>, Error = io::Error> {
        let BitArrayFileBuilder {
            dest,
            current,
            count,
        } = self;
        // The k-th bit of a word (k = 0 first) lands at bit position 63 - k (MSB-first).
        let slot = count % 64;
        let current = current | (u64::from(bit) << (63 - slot));
        let count = count + 1;
        if count % 64 != 0 {
            // Word not full yet: nothing to write.
            future::Either::B(future::ok(BitArrayFileBuilder {
                dest,
                current,
                count,
            }))
        } else {
            future::Either::A(util::write_u64(dest, current).map(move |dest| {
                BitArrayFileBuilder {
                    dest,
                    current: 0,
                    count,
                }
            }))
        }
    }

    /// Appends every bit yielded by `stream`, in order.
    pub fn push_all<S: Stream<Item = bool, Error = io::Error>>(
        self,
        stream: S,
    ) -> impl Future<Item = BitArrayFileBuilder<W>, Error = io::Error> {
        stream.fold(self, Self::push)
    }

    /// Writes the partially filled word, if there is one, and resolves to the destination.
    fn finalize_data(self) -> impl Future<Item = W, Error = io::Error> {
        let BitArrayFileBuilder {
            dest,
            current,
            count,
        } = self;
        if count % 64 != 0 {
            future::Either::B(util::write_u64(dest, current))
        } else {
            // Every pushed bit was already written as part of a full word.
            future::Either::A(future::ok(dest))
        }
    }

    /// Completes the bit array and resolves to the destination.
    ///
    /// Writes any partial data word, then the bit count as the trailing control
    /// word, then flushes `dest`.
    pub fn finalize(self) -> impl Future<Item = W, Error = io::Error> {
        let bit_count = self.count();
        self.finalize_data()
            .and_then(move |dest| util::write_u64(dest, bit_count))
            .and_then(tokio::io::flush)
    }

    /// Returns the number of bits pushed so far.
    pub fn count(&self) -> u64 {
        self.count
    }
}
/// A `Decoder` that splits a serialized bit array into its 64-bit data words.
///
/// Each word is held back until the next one has been read, so the final word of
/// the input (the bit-count control word) is never emitted.
pub struct BitArrayBlockDecoder {
    /// The most recently read word, not yet emitted.
    readahead: Option<u64>,
}

impl Decoder for BitArrayBlockDecoder {
    type Item = u64;
    type Error = io::Error;

    /// Reads big-endian words from `bytes`. Each time a new word arrives, the
    /// previously read word is emitted.
    ///
    /// Returns `Ok(None)` once fewer than 8 bytes remain and no word is ready.
    fn decode(&mut self, bytes: &mut BytesMut) -> Result<Option<u64>, io::Error> {
        while bytes.len() >= 8 {
            let word = BigEndian::read_u64(&bytes.split_to(8));
            if let Some(previous) = self.readahead.replace(word) {
                return Ok(Some(previous));
            }
        }
        Ok(None)
    }
}
/// Streams the 64-bit data words of a serialized bit array read from `r`.
///
/// The trailing control word is not included.
pub fn bitarray_stream_blocks<R: AsyncRead>(r: R) -> FramedRead<R, BitArrayBlockDecoder> {
    let decoder = BitArrayBlockDecoder { readahead: None };
    FramedRead::new(r, decoder)
}
fn bitarray_len_from_file<F: FileLoad>(f: F) -> impl Future<Item = (F, u64), Error = io::Error> {
BitArrayError::validate_input_buf_size(f.size())
.map_or_else(|e| Err(e.into()), |_| Ok(f))
.into_future()
.and_then(|f| {
tokio::io::read_exact(f.open_read_from(f.size() - 8), [0; 8]).map(|(_, buf)| (f, buf))
})
.and_then(|(f, control_word)| {
read_control_word(&control_word, f.size())
.map_or_else(|e| Err(e.into()), |len| Ok((f, len)))
.into_future()
})
}
/// Streams the bits of the bit array stored in `f`, in order.
///
/// The stream stops after the number of bits recorded in the control word, so any
/// trailing bits of the last data word are skipped.
pub fn bitarray_stream_bits<F: FileLoad>(f: F) -> impl Stream<Item = bool, Error = io::Error> {
    bitarray_len_from_file(f)
        .map(|(f, len)| {
            bitarray_stream_blocks(f.open_read())
                .map(|block| stream::iter_ok(BitIter::new(block)))
                .flatten()
                // The last data word may extend past `len` bits.
                .take(len)
        })
        .flatten_stream()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::storage::memory::*;

    #[test]
    fn bit_array_error() {
        assert_eq!(
            "expected input buffer size (7) >= 8",
            BitArrayError::InputBufferTooSmall(7).to_string()
        );
        assert_eq!(
            "expected input buffer size (9) to be 8 for 0 bits",
            BitArrayError::UnexpectedInputBufferSize(9, 8, 0).to_string()
        );
        assert_eq!(
            io::Error::new(
                io::ErrorKind::InvalidData,
                BitArrayError::InputBufferTooSmall(7)
            )
            .to_string(),
            io::Error::from(BitArrayError::InputBufferTooSmall(7)).to_string()
        );
    }

    #[test]
    fn validate_input_buf_size() {
        let val = |buf_size| BitArrayError::validate_input_buf_size(buf_size);
        let err = |buf_size| Err(BitArrayError::InputBufferTooSmall(buf_size));
        assert_eq!(err(7), val(7));
        assert_eq!(Ok(()), val(8));
        assert_eq!(Ok(()), val(9));
        assert_eq!(Ok(()), val(usize::max_value()));
    }

    #[test]
    fn validate_len() {
        let val = |buf_size, len| BitArrayError::validate_len(buf_size, len);
        let err = |buf_size, expected, len| {
            Err(BitArrayError::UnexpectedInputBufferSize(
                buf_size, expected, len,
            ))
        };
        assert_eq!(err(0, 8, 0), val(0, 0));
        assert_eq!(Ok(()), val(16, 1));
        assert_eq!(Ok(()), val(16, 2));
        // Exactly one full word: no extra padding word is expected.
        assert_eq!(Ok(()), val(16, 64));
        // One bit past a full word needs a second data word.
        assert_eq!(Ok(()), val(24, 65));
        #[cfg(target_pointer_width = "64")]
        assert_eq!(
            Ok(()),
            val(
                usize::try_from(u128::from(u64::max_value()) + 65 >> 6 << 3).unwrap(),
                u64::max_value()
            )
        );
    }

    #[test]
    fn decode() {
        let mut decoder = BitArrayBlockDecoder { readahead: None };
        let mut bytes = BytesMut::from([0u8; 8].as_ref());
        assert_eq!(None, Decoder::decode(&mut decoder, &mut bytes).unwrap());
    }

    #[test]
    fn decode_holds_back_last_word() {
        let mut decoder = BitArrayBlockDecoder { readahead: None };
        let mut bytes =
            BytesMut::from([0u8, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 2].as_ref());
        // The first word is emitted only once the second one has been read...
        assert_eq!(Some(1), Decoder::decode(&mut decoder, &mut bytes).unwrap());
        // ...and the last word (the control word in a real file) is held back.
        assert_eq!(None, Decoder::decode(&mut decoder, &mut bytes).unwrap());
        assert_eq!(Some(2), decoder.readahead);
    }

    #[test]
    pub fn empty() {
        assert!(BitArray::from_bits(Bytes::from([0u8; 8].as_ref()))
            .unwrap()
            .is_empty());
    }

    #[test]
    fn from_bits_reads_bits_msb_first() {
        // One data word whose first byte is 0b1010_0000, then a control word of 5 bits.
        let raw = vec![
            0b1010_0000u8,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            0,
            5,
        ];
        let bitarray = BitArray::from_bits(Bytes::from(raw)).unwrap();
        assert_eq!(5, bitarray.len());
        assert_eq!(8, bitarray.bits().len());
        let bits: Vec<bool> = (0..bitarray.len()).map(|i| bitarray.get(i)).collect();
        assert_eq!(vec![true, false, true, false, false], bits);
    }

    #[test]
    fn from_bits_errors() {
        assert_eq!(
            Some(BitArrayError::InputBufferTooSmall(7)),
            BitArray::from_bits(Bytes::from(vec![0u8; 7])).err()
        );
        assert_eq!(
            Some(BitArrayError::UnexpectedInputBufferSize(8, 16, 2)),
            BitArray::from_bits(Bytes::from(vec![0u8, 0, 0, 0, 0, 0, 0, 2])).err()
        );
    }

    #[test]
    #[should_panic(expected = "expected index (5) < length (5)")]
    fn get_past_end_panics() {
        let raw = vec![0u8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5];
        BitArray::from_bits(Bytes::from(raw)).unwrap().get(5);
    }

    #[test]
    pub fn construct_and_parse_small_bitarray() {
        let x = MemoryBackedStore::new();
        let contents = vec![true, true, false, false, true];
        BitArrayFileBuilder::new(x.open_write())
            .push_all(stream::iter_ok(contents))
            .and_then(|b| b.finalize())
            .wait()
            .unwrap();
        let loaded = x.map().wait().unwrap();
        let bitarray = BitArray::from_bits(loaded).unwrap();
        assert_eq!(5, bitarray.len());
        assert_eq!(true, bitarray.get(0));
        assert_eq!(true, bitarray.get(1));
        assert_eq!(false, bitarray.get(2));
        assert_eq!(false, bitarray.get(3));
        assert_eq!(true, bitarray.get(4));
    }

    #[test]
    pub fn construct_and_parse_large_bitarray() {
        let x = MemoryBackedStore::new();
        let contents = (0..).map(|n| n % 3 == 0).take(123456);
        BitArrayFileBuilder::new(x.open_write())
            .push_all(stream::iter_ok(contents))
            .and_then(|b| b.finalize())
            .wait()
            .unwrap();
        let loaded = x.map().wait().unwrap();
        let bitarray = BitArray::from_bits(loaded).unwrap();
        // Without this the loop below would pass trivially on an empty result.
        assert_eq!(123456, bitarray.len());
        for i in 0..bitarray.len() {
            assert_eq!(i % 3 == 0, bitarray.get(i));
        }
    }

    #[test]
    fn bitarray_len_from_file_errors() {
        let store = MemoryBackedStore::new();
        tokio::io::write_all(store.open_write(), [0, 0, 0])
            .wait()
            .unwrap();
        assert_eq!(
            io::Error::from(BitArrayError::InputBufferTooSmall(3)).to_string(),
            bitarray_len_from_file(store)
                .wait()
                .err()
                .unwrap()
                .to_string()
        );
        let store = MemoryBackedStore::new();
        tokio::io::write_all(store.open_write(), [0, 0, 0, 0, 0, 0, 0, 2])
            .wait()
            .unwrap();
        assert_eq!(
            io::Error::from(BitArrayError::UnexpectedInputBufferSize(8, 16, 2)).to_string(),
            bitarray_len_from_file(store)
                .wait()
                .err()
                .unwrap()
                .to_string()
        );
    }

    #[test]
    fn bitarray_len_from_file_reads_len() {
        let store = MemoryBackedStore::new();
        BitArrayFileBuilder::new(store.open_write())
            .push_all(stream::iter_ok(vec![true, false, true]))
            .and_then(|b| b.finalize())
            .wait()
            .unwrap();
        let (_, len) = bitarray_len_from_file(store).wait().unwrap();
        assert_eq!(3, len);
    }

    #[test]
    pub fn stream_blocks() {
        let x = MemoryBackedStore::new();
        let contents = (0..).map(|n| n % 4 == 1).take(256);
        BitArrayFileBuilder::new(x.open_write())
            .push_all(stream::iter_ok(contents))
            .and_then(|b| b.finalize())
            .wait()
            .unwrap();
        let blocks: Vec<u64> = bitarray_stream_blocks(x.open_read())
            .collect()
            .wait()
            .unwrap();
        // 256 bits = exactly 4 data words; the control word must not be emitted.
        assert_eq!(vec![0x4444_4444_4444_4444u64; 4], blocks);
    }

    #[test]
    fn stream_bits() {
        let x = MemoryBackedStore::new();
        let contents: Vec<_> = (0..).map(|n| n % 4 == 1).take(123).collect();
        BitArrayFileBuilder::new(x.open_write())
            .push_all(stream::iter_ok(contents.clone()))
            .and_then(|b| b.finalize())
            .wait()
            .unwrap();
        let result = bitarray_stream_bits(x).collect().wait().unwrap();
        assert_eq!(contents, result);
    }

    #[test]
    fn stream_bits_whole_words() {
        // A multiple of 64 bits exercises the "no partial word" finalize path.
        let x = MemoryBackedStore::new();
        let contents: Vec<_> = (0..).map(|n| n % 3 == 0).take(128).collect();
        BitArrayFileBuilder::new(x.open_write())
            .push_all(stream::iter_ok(contents.clone()))
            .and_then(|b| b.finalize())
            .wait()
            .unwrap();
        let result = bitarray_stream_bits(x).collect().wait().unwrap();
        assert_eq!(contents, result);
    }
}