use std::io;
/// Reads an underlying byte source one bit at a time, most-significant bit of
/// each byte first.
///
/// The struct itself places no bounds on `T`; the bit-reading methods require
/// `T: io::Read` and the seeking methods additionally `T: io::Seek`.
#[derive(Debug)]
#[must_use]
pub struct BitReader<T> {
/// The wrapped byte source.
data: T,
/// Index (0..=7) of the next bit to take from `current_byte`, counted from the
/// most-significant bit. 0 means the reader is aligned: the next bit read
/// fetches a fresh byte from `data`.
bit_pos: u8,
/// The byte most recently fetched from `data`.
current_byte: u8,
}
impl<T> BitReader<T> {
    /// Wraps `data` in a reader positioned at the first bit of its next byte.
    ///
    /// Nothing is read from `data` here; the first byte is fetched lazily by
    /// the first bit read.
    pub const fn new(data: T) -> Self {
        Self {
            current_byte: 0,
            bit_pos: 0,
            data,
        }
    }
}
impl<T: io::Read> BitReader<T> {
    /// Reads the next bit, taking each byte's bits from most- to least-significant.
    ///
    /// A new byte is pulled from the inner reader only when the reader is
    /// aligned, i.e. all bits of the previously loaded byte have been used.
    ///
    /// # Errors
    /// Propagates any error from reading the inner stream, including
    /// `UnexpectedEof` once it is exhausted.
    pub fn read_bit(&mut self) -> io::Result<bool> {
        if self.is_aligned() {
            self.update_byte()?;
        }
        let mask = 0x80u8 >> self.bit_pos;
        // Wrap back to 0 after the last (least-significant) bit of the byte.
        self.bit_pos = if self.bit_pos == 7 { 0 } else { self.bit_pos + 1 };
        Ok((self.current_byte & mask) != 0)
    }

    /// Reads `count` bits and returns them right-aligned in a `u64`. The first
    /// bit read ends up as the most significant of the returned bits.
    ///
    /// `count` is clamped to 64. If an error occurs part-way, the bits already
    /// consumed are not returned.
    ///
    /// # Errors
    /// Propagates the first error from [`read_bit`](Self::read_bit).
    pub fn read_bits(&mut self, count: u8) -> io::Result<u64> {
        (0..count.min(64)).try_fold(0u64, |acc, _| -> io::Result<u64> {
            let bit = self.read_bit()?;
            Ok((acc << 1) | u64::from(bit))
        })
    }

    /// Discards any bits left in the current byte, so the next read starts at
    /// the beginning of the following inner byte.
    ///
    /// This never fails; the `io::Result` return type is kept for API stability.
    #[inline(always)]
    pub fn align(&mut self) -> io::Result<()> {
        self.bit_pos = 0;
        Ok(())
    }

    /// Fetches exactly one byte from the inner reader into `current_byte`.
    ///
    /// `current_byte` is overwritten only after the read has succeeded.
    fn update_byte(&mut self) -> io::Result<()> {
        let mut byte = [0u8; 1];
        self.data.read_exact(&mut byte)?;
        let [value] = byte;
        self.current_byte = value;
        Ok(())
    }
}
impl<T> BitReader<T> {
    /// Returns `true` when the reader sits on a byte boundary (bit offset 0).
    #[inline(always)]
    #[must_use]
    pub const fn is_aligned(&self) -> bool {
        matches!(self.bit_pos, 0)
    }

    /// Returns the index (0..=7) of the next bit within the current byte,
    /// counted from the most-significant bit.
    #[inline(always)]
    #[must_use]
    pub const fn bit_pos(&self) -> u8 {
        self.bit_pos
    }

    /// Borrows the wrapped byte source.
    #[inline(always)]
    #[must_use]
    pub const fn get_ref(&self) -> &T {
        &self.data
    }

    /// Consumes the reader and returns the wrapped byte source. Any
    /// unconsumed bits of the current byte are dropped.
    #[inline(always)]
    #[must_use]
    pub fn into_inner(self) -> T {
        let Self { data, .. } = self;
        data
    }
}
impl<T: io::Read> io::Read for BitReader<T> {
    /// Reads whole bytes into `buf`, honouring the current bit offset.
    ///
    /// While the reader is byte-aligned this delegates directly to the inner
    /// reader. Otherwise each output byte is assembled from the next 8 bits
    /// (MSB first). It therefore spans two inner bytes, and the reader keeps
    /// the same bit offset afterwards.
    ///
    /// Follows the [`io::Read`] contract:
    /// - If the inner reader runs out part-way through `buf`, the number of
    ///   bytes completed so far is returned.
    /// - `Ok(0)` means not even one whole byte was available.
    /// - A byte that cannot be completed is not consumed. The bit position is
    ///   rolled back, so its leading bits can still be read bit by bit.
    ///
    /// # Errors
    /// An inner-reader error other than `UnexpectedEof` is returned only when
    /// it occurs before any byte of `buf` has been filled. Otherwise the short
    /// count is returned instead.
    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
        if self.is_aligned() {
            return self.data.read(buf);
        }
        for (filled, slot) in buf.iter_mut().enumerate() {
            // Within a byte, the only I/O is fetching a fresh inner byte once
            // the current one is used up. A failed fetch leaves `current_byte`
            // untouched. With a one-byte buffer, the default `read_exact` has
            // consumed nothing when it fails. So restoring `bit_pos` alone
            // undoes the partially assembled byte.
            let start_bit_pos = self.bit_pos;
            let mut value = 0u8;
            for _ in 0..8 {
                match self.read_bit() {
                    Ok(bit) => value = (value << 1) | u8::from(bit),
                    Err(err) => {
                        self.bit_pos = start_bit_pos;
                        return if filled > 0 || err.kind() == io::ErrorKind::UnexpectedEof {
                            Ok(filled)
                        } else {
                            Err(err)
                        };
                    }
                }
            }
            *slot = value;
        }
        Ok(buf.len())
    }
}
impl<B: AsRef<[u8]>> BitReader<std::io::Cursor<B>> {
pub const fn new_from_slice(data: B) -> Self {
Self::new(std::io::Cursor::new(data))
}
}
impl<W: io::Seek + io::Read> BitReader<W> {
    /// Returns the reader's current position, in bits, from the start of the
    /// inner stream.
    ///
    /// While unaligned, the inner stream has already advanced past the byte
    /// whose bits are being read. That byte therefore starts one byte before
    /// the inner position.
    ///
    /// # Errors
    /// Propagates errors from querying the inner stream position.
    pub fn bit_stream_position(&mut self) -> io::Result<u64> {
        let byte_pos = self.data.stream_position()?;
        if self.is_aligned() {
            Ok(byte_pos * 8)
        } else {
            Ok((byte_pos - 1) * 8 + u64::from(self.bit_pos))
        }
    }

    /// Moves the reader by `count` bits (negative values move backwards). It
    /// returns the new absolute bit position, as reported by
    /// [`bit_stream_position`](Self::bit_stream_position).
    ///
    /// If the target bit is not on a byte boundary, the byte containing it is
    /// read from the inner stream so that subsequent bit reads continue there.
    ///
    /// # Errors
    /// - `InvalidInput` if the target offset overflows `i64`. The reader is
    ///   left unchanged.
    /// - Any error from the inner seek (for example seeking before the start).
    ///   This reader's bit offset is left unchanged.
    /// - Any error from reading the byte that contains the target bit (for
    ///   example seeking past the end). The inner stream has already moved to
    ///   that byte, so the reader is left byte-aligned there. It does not keep
    ///   a non-zero bit offset paired with a stale byte.
    pub fn seek_bits(&mut self, count: i64) -> io::Result<u64> {
        if count == 0 {
            return self.bit_stream_position();
        }
        // Target offset in bits, measured from the start of the byte being
        // read (or, when aligned, from the inner stream's position).
        let target = i64::from(self.bit_pos).checked_add(count).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidInput, "bit seek offset overflows i64")
        })?;
        // Euclidean division rounds toward negative infinity, so a backwards
        // seek into the middle of an earlier byte lands on that byte.
        let mut byte_move = target.div_euclid(8);
        // `rem_euclid(8)` is always in 0..8, so the narrowing cast is lossless.
        let bit_offset = target.rem_euclid(8) as u8;
        if !self.is_aligned() {
            // The inner stream is already one byte past the byte being read.
            byte_move -= 1;
        }
        let byte_pos = self.data.seek(io::SeekFrom::Current(byte_move))?;
        if bit_offset == 0 {
            self.bit_pos = 0;
            return Ok(byte_pos * 8);
        }
        if let Err(err) = self.update_byte() {
            self.bit_pos = 0;
            return Err(err);
        }
        self.bit_pos = bit_offset;
        Ok(byte_pos * 8 + u64::from(bit_offset))
    }
}
impl<T: io::Seek + io::Read> io::Seek for BitReader<T> {
    /// Seeks the reader in whole bytes.
    ///
    /// A relative seek (`SeekFrom::Current`) while unaligned is done bit-wise
    /// via [`seek_bits`](BitReader::seek_bits), so the offset within the byte
    /// is preserved. The returned byte position is the bit position rounded up
    /// to whole bytes. Every other seek discards the partial byte, realigns
    /// the reader and delegates to the inner stream.
    ///
    /// # Errors
    /// - `InvalidInput` if a relative offset is too large to express in bits
    ///   as an `i64`. The reader is left unchanged.
    /// - Otherwise, errors from the underlying seek.
    fn seek(&mut self, pos: io::SeekFrom) -> io::Result<u64> {
        match pos {
            io::SeekFrom::Current(offset) if !self.is_aligned() => {
                let bits = offset.checked_mul(8).ok_or_else(|| {
                    io::Error::new(io::ErrorKind::InvalidInput, "seek offset in bits overflows i64")
                })?;
                Ok(self.seek_bits(bits)?.div_ceil(8))
            }
            _ => {
                self.bit_pos = 0;
                self.data.seek(pos)
            }
        }
    }
}
// Unit tests for BitReader, driven over in-memory cursors built with
// new_from_slice.
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
// Brings the trait methods used below (read_exact, seek, stream_position) into scope.
use io::{Read, Seek};
use super::*;
// Every bit of a 32-bit big-endian value is read MSB-first, after which the
// stream is exhausted.
#[test]
fn test_bit_reader() {
let binary = 0b10101010110011001111000101010101u32;
let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
for i in 0..32 {
assert_eq!(
reader.read_bit().unwrap(),
(binary & (1 << (31 - i))) != 0,
"bit {} is not correct",
i
);
}
assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
}
// Multi-bit reads of varying widths (several crossing byte boundaries) that
// together consume exactly 32 bits.
#[test]
fn test_bit_reader_read_bits() {
let binary = 0b10101010110011001111000101010101u32;
let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
let cases = [
(3, 0b101),
(4, 0b0101),
(3, 0b011),
(3, 0b001),
(3, 0b100),
(3, 0b111),
(5, 0b10001),
(1, 0b0),
(7, 0b1010101),
];
for (i, (count, expected)) in cases.into_iter().enumerate() {
assert_eq!(
reader.read_bits(count).ok(),
Some(expected),
"reading {} bits ({i}) are not correct",
count
);
}
assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
}
// Each byte has only its top bit set. After reading that bit, align()
// discards the remaining 7 bits, so the inner stream advances exactly one
// byte per iteration.
#[test]
fn test_bit_reader_align() {
let mut reader = BitReader::new_from_slice([0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000, 0b10000000]);
for i in 0..6 {
let pos = reader.data.stream_position().unwrap();
assert_eq!(pos, i, "stream pos");
assert_eq!(reader.bit_pos(), 0, "bit pos");
assert!(reader.read_bit().unwrap(), "bit {} is not correct", i);
reader.align().unwrap();
let pos = reader.data.stream_position().unwrap();
assert_eq!(pos, i + 1, "stream pos");
assert_eq!(reader.bit_pos(), 0, "bit pos");
}
assert!(reader.read_bit().is_err(), "there shouldnt be any bits left");
}
// io::Read delegates while aligned. After one extra bit is consumed, the next
// byte is assembled from the last 7 bits of byte 1 plus the first bit of byte 2.
#[test]
fn test_bit_reader_io_read() {
let binary = 0b10101010110011001111000101010101u32;
let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
let mut buf = [0; 1];
reader.read_exact(&mut buf).unwrap();
assert_eq!(buf, [0b10101010]);
assert_eq!(reader.read_bits(1).unwrap(), 0b1);
let mut buf = [0; 1];
reader.read_exact(&mut buf).unwrap();
assert_eq!(buf, [0b10011001]);
}
// Bit-granular seeking forwards, by zero and backwards. While unaligned, the
// inner stream position is one byte past the byte whose bits are being read.
#[test]
fn test_bit_reader_seek() {
let binary = 0b10101010110011001111000101010101u32;
let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
assert_eq!(reader.seek_bits(5).unwrap(), 5);
assert_eq!(reader.data.stream_position().unwrap(), 1);
assert_eq!(reader.bit_pos(), 5);
assert_eq!(reader.read_bits(1).unwrap(), 0b0);
assert_eq!(reader.bit_pos(), 6);
assert_eq!(reader.seek_bits(0).unwrap(), 6);
assert_eq!(reader.seek_bits(10).unwrap(), 16);
assert_eq!(reader.data.stream_position().unwrap(), 2);
assert_eq!(reader.bit_pos(), 0);
assert_eq!(reader.read_bits(1).unwrap(), 0b1);
assert_eq!(reader.bit_pos(), 1);
assert_eq!(reader.data.stream_position().unwrap(), 3);
assert_eq!(reader.seek_bits(-8).unwrap(), 9);
assert_eq!(reader.data.stream_position().unwrap(), 2);
assert_eq!(reader.bit_pos(), 1);
assert_eq!(reader.read_bits(1).unwrap(), 0b1);
assert_eq!(reader.bit_pos(), 2);
assert_eq!(reader.data.stream_position().unwrap(), 2);
// Landing exactly on a byte boundary leaves the reader aligned.
assert_eq!(reader.seek_bits(-2).unwrap(), 8);
assert_eq!(reader.data.stream_position().unwrap(), 1);
assert_eq!(reader.bit_pos(), 0);
assert_eq!(reader.read_bits(1).unwrap(), 0b1);
assert_eq!(reader.bit_pos(), 1);
assert_eq!(reader.data.stream_position().unwrap(), 2);
}
// io::Seek: relative seeks while unaligned keep the bit offset within the
// byte, while seeks from Start/End leave the reader aligned.
#[test]
fn test_bit_reader_io_seek() {
let binary = 0b10101010110011001111000101010101u32;
let mut reader = BitReader::new_from_slice(binary.to_be_bytes());
assert_eq!(reader.seek(io::SeekFrom::Start(1)).unwrap(), 1);
assert_eq!(reader.bit_pos(), 0);
assert_eq!(reader.data.stream_position().unwrap(), 1);
assert_eq!(reader.read_bits(1).unwrap(), 0b1);
assert_eq!(reader.bit_pos(), 1);
assert_eq!(reader.data.stream_position().unwrap(), 2);
assert_eq!(reader.seek(io::SeekFrom::Current(1)).unwrap(), 3);
assert_eq!(reader.bit_pos(), 1);
assert_eq!(reader.data.stream_position().unwrap(), 3);
assert_eq!(reader.read_bits(1).unwrap(), 0b1);
assert_eq!(reader.bit_pos(), 2);
assert_eq!(reader.data.stream_position().unwrap(), 3);
assert_eq!(reader.seek(io::SeekFrom::Current(-1)).unwrap(), 2);
assert_eq!(reader.bit_pos(), 2);
assert_eq!(reader.data.stream_position().unwrap(), 2);
assert_eq!(reader.read_bits(1).unwrap(), 0b0);
assert_eq!(reader.bit_pos(), 3);
assert_eq!(reader.data.stream_position().unwrap(), 2);
assert_eq!(reader.seek(io::SeekFrom::End(-1)).unwrap(), 3);
assert_eq!(reader.bit_pos(), 0);
assert_eq!(reader.data.stream_position().unwrap(), 3);
assert_eq!(reader.read_bits(1).unwrap(), 0b0);
assert_eq!(reader.bit_pos(), 1);
assert_eq!(reader.data.stream_position().unwrap(), 4);
}
}