Skip to main content

zrip_core/bitstream/
reader_reverse.rs

1use crate::error::DecompressError;
2
/// Backwards-reading bit cursor that follows C zstd's `bitsConsumed` bookkeeping.
///
/// The reader does not keep a `bits_available` counter (decremented on every read and
/// incremented on every refill). It records `bits_consumed` instead: the count grows as
/// bits are taken from the 64-bit window and shrinks again when the window is reloaded
/// from earlier bytes. Peeking is then a double shift,
/// `(container << consumed) >> (64 - n)`, i.e. 2 operations where the older model
/// needed 3 (shift + mask + subtract).
pub struct ReverseBitReader<'a> {
    /// The complete input; the stream is read from its last byte towards its first.
    pub data: &'a [u8],
    /// Current 64-bit window, loaded little-endian from `data` starting at `ptr`.
    pub container: u64,
    /// Number of bits of `container`, counted from the most significant end, that have
    /// already been used (this includes any padding and the end-of-stream marker).
    pub bits_consumed: u32,
    /// Byte offset into `data` at which `container` was loaded.
    pub ptr: usize,
    /// Lower bound on `ptr` asserted by `refill_fast`; `new` sets it to 8 for inputs of
    /// at least 8 bytes and to 0 otherwise.
    pub limit_ptr: usize,
}
16
17impl<'a> ReverseBitReader<'a> {
18    pub fn new(data: &'a [u8]) -> Result<Self, DecompressError> {
19        if data.is_empty() {
20            return Err(DecompressError::InputExhausted);
21        }
22
23        let last_byte = *data.last().unwrap();
24        if last_byte == 0 {
25            return Err(DecompressError::CorruptSequences);
26        }
27
28        let initial_consumed = last_byte.leading_zeros() + 1;
29
30        let ptr = if data.len() >= 8 { data.len() - 8 } else { 0 };
31
32        let container = if data.len() >= 8 {
33            unsafe { u64::from_le((data.as_ptr().add(ptr) as *const u64).read_unaligned()) }
34        } else {
35            let mut val = 0u64;
36            for (i, &b) in data.iter().enumerate() {
37                val |= (b as u64) << (i * 8);
38            }
39            val
40        };
41
42        let bits_consumed = if data.len() >= 8 {
43            initial_consumed
44        } else {
45            64 - (data.len() as u32) * 8 + initial_consumed
46        };
47
48        let limit_ptr = if data.len() >= 8 { 8 } else { 0 };
49
50        Ok(Self {
51            data,
52            container,
53            bits_consumed,
54            ptr,
55            limit_ptr,
56        })
57    }
58
59    #[inline(always)]
60    pub fn refill(&mut self) {
61        if self.bits_consumed <= 7 || self.ptr == 0 {
62            return;
63        }
64        let byte_shift = (self.bits_consumed >> 3) as usize;
65        let actual_shift = byte_shift.min(self.ptr);
66        self.ptr -= actual_shift;
67        self.bits_consumed -= (actual_shift as u32) * 8;
68        if self.ptr + 8 <= self.data.len() {
69            self.container = unsafe {
70                u64::from_le((self.data.as_ptr().add(self.ptr) as *const u64).read_unaligned())
71            };
72        } else {
73            let mut val = 0u64;
74            let avail = self.data.len() - self.ptr;
75            for i in 0..avail {
76                val |= (unsafe { *self.data.get_unchecked(self.ptr + i) } as u64) << (i * 8);
77            }
78            self.container = val;
79        }
80    }
81
82    #[inline]
83    pub fn read_bits(&mut self, n: u8) -> Result<u32, DecompressError> {
84        debug_assert!(n <= 32);
85        if n == 0 {
86            return Ok(0);
87        }
88        self.refill();
89        let avail = 64u32.saturating_sub(self.bits_consumed);
90        if (n as u32) > avail {
91            return Err(DecompressError::InputExhausted);
92        }
93        let result = ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32;
94        self.bits_consumed += n as u32;
95        Ok(result)
96    }
97
98    #[inline]
99    pub fn read_bits_unchecked(&mut self, n: u8) -> u32 {
100        debug_assert!(n <= 32);
101        if n == 0 {
102            return 0;
103        }
104        self.refill();
105        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
106        let result = ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32;
107        self.bits_consumed += n as u32;
108        result
109    }
110
111    #[inline(always)]
112    pub fn consume_bits(&mut self, n: u8) {
113        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
114        self.bits_consumed += n as u32;
115        self.refill();
116    }
117
118    #[inline(always)]
119    pub fn read_bits_fast(&mut self, n: u8) -> u32 {
120        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
121        if n == 0 {
122            return 0;
123        }
124        let result = ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32;
125        self.bits_consumed += n as u32;
126        result
127    }
128
129    #[inline(always)]
130    pub fn read_bits_branchless(&mut self, n: u8) -> u32 {
131        debug_assert!(n <= 32);
132        let result = ((self.container << (self.bits_consumed & 63)) >> 1 >> (63 - n as u32)) as u32;
133        self.bits_consumed += n as u32;
134        result
135    }
136
137    #[inline(always)]
138    pub fn refill_fast(&mut self) {
139        debug_assert!(self.ptr >= self.limit_ptr);
140        let byte_shift = (self.bits_consumed >> 3) as usize;
141        self.ptr -= byte_shift;
142        self.bits_consumed -= (byte_shift as u32) * 8;
143        self.container = unsafe {
144            u64::from_le((self.data.as_ptr().add(self.ptr) as *const u64).read_unaligned())
145        };
146    }
147
148    #[inline]
149    pub fn peek_bits(&self, n: u8) -> u32 {
150        debug_assert!(n <= 32);
151        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
152        if n == 0 {
153            return 0;
154        }
155        ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32
156    }
157
158    #[inline]
159    pub fn bits_remaining(&self) -> usize {
160        64usize.saturating_sub(self.bits_consumed as usize) + self.ptr * 8
161    }
162
163    #[inline]
164    pub fn is_empty(&self) -> bool {
165        self.bits_consumed >= 64 && self.ptr == 0
166    }
167}
168
#[cfg(test)]
mod tests {
    use super::*;
    use crate::bitstream::writer::BitWriter;

    /// An empty slice has no marker byte, so construction must fail.
    #[test]
    fn empty_input() {
        let outcome = ReverseBitReader::new(&[]);
        assert!(outcome.is_err());
    }

    /// A final byte of zero carries no end-of-stream marker bit.
    #[test]
    fn zero_last_byte() {
        let outcome = ReverseBitReader::new(&[0x00]);
        assert!(outcome.is_err());
    }

    /// A lone marker bit means the stream holds no payload bits at all.
    #[test]
    fn sentinel_only_no_data() {
        let marker_only = [0b00000001];
        let reader = ReverseBitReader::new(&marker_only).unwrap();
        assert_eq!(reader.bits_remaining(), 0);
    }

    /// Values written forward come back in reverse order, last written first.
    #[test]
    fn roundtrip_with_forward_writer() {
        let mut writer = BitWriter::new();
        writer.write_bits(0b101, 3);
        writer.write_bits(0b11001010, 8);
        writer.write_bits(0b1, 1);
        writer.close_reverse_stream();
        let encoded = writer.into_bytes();

        let mut reader = ReverseBitReader::new(&encoded).unwrap();

        let third = reader.read_bits(1).unwrap();
        assert_eq!(third, 0b1);

        let second = reader.read_bits(8).unwrap();
        assert_eq!(second, 0b11001010);

        let first = reader.read_bits(3).unwrap();
        assert_eq!(first, 0b101);

        assert_eq!(reader.bits_remaining(), 0);
    }

    /// One byte holding the marker at bit 3 and three payload bits below it.
    #[test]
    fn single_byte_with_data() {
        let packed = [0b00001101];
        let mut reader = ReverseBitReader::new(&packed).unwrap();

        let payload = reader.read_bits(3).unwrap();
        assert_eq!(payload, 0b101);
        assert_eq!(reader.bits_remaining(), 0);
    }

    /// A payload that spills past the first byte is still read back in reverse order.
    #[test]
    fn multi_byte_stream() {
        let mut writer = BitWriter::new();
        writer.write_bits(0xFF, 8);
        writer.write_bits(0x3, 2);
        writer.close_reverse_stream();
        let encoded = writer.into_bytes();

        let mut reader = ReverseBitReader::new(&encoded).unwrap();

        let trailing = reader.read_bits(2).unwrap();
        assert_eq!(trailing, 0x3);

        let leading = reader.read_bits(8).unwrap();
        assert_eq!(leading, 0xFF);
    }
}