Skip to main content

zrip_core/bitstream/
reader.rs

1#![forbid(unsafe_code)]
2
3use crate::error::DecompressError;
4
/// Cursor that reads bit fields out of a borrowed byte slice.
///
/// Bits are consumed starting from the least-significant bit of each byte.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// The complete input being read.
    data: &'a [u8],
    /// Index into `data` of the byte holding the next unread bit; equals
    /// `data.len()` once every byte has been fully consumed.
    pos: usize,
    /// Number of bits (0..=7) of `data[pos]` that have already been consumed.
    bit_pos: u8,
}
10
11impl<'a> BitReader<'a> {
12    pub fn new(data: &'a [u8]) -> Self {
13        Self {
14            data,
15            pos: 0,
16            bit_pos: 0,
17        }
18    }
19
20    #[inline]
21    pub fn bits_consumed(&self) -> usize {
22        self.pos * 8 + self.bit_pos as usize
23    }
24
25    #[inline]
26    pub fn bits_remaining(&self) -> usize {
27        self.data.len() * 8 - self.bits_consumed()
28    }
29
30    #[inline]
31    pub fn is_empty(&self) -> bool {
32        self.bits_remaining() == 0
33    }
34
35    #[inline]
36    pub fn bytes_consumed(&self) -> usize {
37        if self.bit_pos == 0 {
38            self.pos
39        } else {
40            self.pos + 1
41        }
42    }
43
44    pub fn read_bits(&mut self, n: u8) -> Result<u32, DecompressError> {
45        debug_assert!(n <= 25);
46        if n == 0 {
47            return Ok(0);
48        }
49        if self.bits_remaining() < n as usize {
50            return Err(DecompressError::InputExhausted);
51        }
52
53        let mut result = 0u32;
54        let mut bits_left = n;
55        let mut bit_offset = 0u8;
56
57        while bits_left > 0 {
58            let avail = 8 - self.bit_pos;
59            let take = bits_left.min(avail);
60            let byte = self.data[self.pos] as u32;
61            let mask = (1u32 << take) - 1;
62            let bits = (byte >> self.bit_pos) & mask;
63            result |= bits << bit_offset;
64
65            bit_offset += take;
66            bits_left -= take;
67            self.bit_pos += take;
68            if self.bit_pos == 8 {
69                self.bit_pos = 0;
70                self.pos += 1;
71            }
72        }
73
74        Ok(result)
75    }
76
77    pub fn read_bits_u16(&mut self, n: u8) -> Result<u16, DecompressError> {
78        self.read_bits(n).map(|v| v as u16)
79    }
80
81    pub fn peek_bits(&self, n: u8) -> Result<u32, DecompressError> {
82        let mut copy = Self {
83            data: self.data,
84            pos: self.pos,
85            bit_pos: self.bit_pos,
86        };
87        copy.read_bits(n)
88    }
89
90    pub fn align_to_byte(&mut self) {
91        if self.bit_pos != 0 {
92            self.bit_pos = 0;
93            self.pos += 1;
94        }
95    }
96
97    pub fn remaining_bytes(&self) -> &'a [u8] {
98        if self.bit_pos == 0 {
99            &self.data[self.pos..]
100        } else {
101            &self.data[self.pos + 1..]
102        }
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn read_single_bits() {
        let data = [0b1011_0100_u8];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(1).unwrap(), 0);
        assert_eq!(r.read_bits(1).unwrap(), 0);
        assert_eq!(r.read_bits(1).unwrap(), 1);
        assert_eq!(r.read_bits(1).unwrap(), 0);
        assert_eq!(r.read_bits(1).unwrap(), 1);
        assert_eq!(r.read_bits(1).unwrap(), 1);
        assert_eq!(r.read_bits(1).unwrap(), 0);
        assert_eq!(r.read_bits(1).unwrap(), 1);
        assert!(r.is_empty());
    }

    #[test]
    fn read_multi_bit() {
        let data = [0xFF, 0x01];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(8).unwrap(), 0xFF);
        assert_eq!(r.read_bits(8).unwrap(), 0x01);
    }

    #[test]
    fn read_cross_byte() {
        let data = [0b1101_0110, 0b1011_0001];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(4).unwrap(), 0b0110);
        assert_eq!(r.read_bits(8).unwrap(), 0b0001_1101);
        assert_eq!(r.read_bits(4).unwrap(), 0b1011);
    }

    #[test]
    fn read_zero_bits() {
        let data = [0xFF];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(0).unwrap(), 0);
        assert_eq!(r.bits_consumed(), 0);
    }

    #[test]
    fn exhaustion() {
        let data = [0xFF];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(8).unwrap(), 0xFF);
        assert!(r.read_bits(1).is_err());
    }

    /// A read that fails for lack of input must not move the cursor.
    #[test]
    fn failed_read_does_not_advance() {
        let data = [0xAB];
        let mut r = BitReader::new(&data);
        assert_eq!(r.read_bits(3).unwrap(), 0b011);
        assert!(r.read_bits(6).is_err());
        assert_eq!(r.bits_consumed(), 3);
        assert_eq!(r.read_bits(5).unwrap(), 0b10101);
    }

    #[test]
    fn bit_counters_track_position() {
        let data = [0u8; 3];
        let mut r = BitReader::new(&data);
        assert_eq!(r.bits_remaining(), 24);
        r.read_bits(11).unwrap();
        assert_eq!(r.bits_consumed(), 11);
        assert_eq!(r.bits_remaining(), 13);
        // The partially read second byte counts as consumed.
        assert_eq!(r.bytes_consumed(), 2);
        assert!(!r.is_empty());
    }

    #[test]
    fn peek_does_not_advance() {
        let data = [0b1010_0101, 0x3C];
        let r = BitReader::new(&data);
        assert_eq!(r.peek_bits(4).unwrap(), 0b0101);
        assert_eq!(r.peek_bits(12).unwrap(), 0xCA5);
        assert_eq!(r.bits_consumed(), 0);
        assert!(r.peek_bits(17).is_err());
    }

    #[test]
    fn align_and_remaining_bytes() {
        let data = [0xF0, 0x11, 0x22];
        let mut r = BitReader::new(&data);
        assert_eq!(r.remaining_bytes(), &data[..]);
        assert_eq!(r.read_bits(5).unwrap(), 0x10);
        assert_eq!(r.bytes_consumed(), 1);
        // The rest of the partially read first byte is skipped.
        assert_eq!(r.remaining_bytes(), &data[1..]);
        r.align_to_byte();
        assert_eq!(r.bits_consumed(), 8);
        // Aligning an already-aligned reader is a no-op.
        r.align_to_byte();
        assert_eq!(r.bits_consumed(), 8);
        assert_eq!(r.read_bits_u16(16).unwrap(), 0x2211);
        assert!(r.is_empty());
        assert!(r.remaining_bytes().is_empty());
    }

    #[test]
    fn empty_input() {
        let data: [u8; 0] = [];
        let mut r = BitReader::new(&data);
        assert!(r.is_empty());
        assert_eq!(r.bits_remaining(), 0);
        assert_eq!(r.read_bits(0).unwrap(), 0);
        assert!(r.read_bits(1).is_err());
        assert!(r.remaining_bytes().is_empty());
    }
}