Skip to main content

zrip_core/bitstream/
reader_reverse.rs

1#![forbid(unsafe_code)]
2
3use crate::bitstream::primitives;
4use crate::error::DecompressError;
5
/// Reverse bitstream reader using C zstd's bitsConsumed model.
///
/// Instead of tracking `bits_available` (decrement on consume, increment on refill),
/// this tracks `bits_consumed` (increment on consume, reset on reload). Peek uses
/// a double-shift: `(container << consumed) >> (64 - n)` = 2 ops vs the old model's
/// 3 ops (shift + mask + subtract).
///
/// The stream is read from the end of `data` toward its start: the last byte holds a
/// sentinel `1` bit, and bits are taken from the most-significant end of a
/// little-endian window (`container`) that slides toward index 0.
#[derive(Debug, Clone)]
pub struct ReverseBitReader<'a> {
    /// The whole input buffer; the sentinel byte is `data[data.len() - 1]`.
    pub data: &'a [u8],
    /// Bytes of `data` starting at `ptr`, packed little-endian, so the byte with the
    /// highest index occupies the most-significant position.
    pub container: u64,
    /// Bits already consumed from the most-significant end of `container`.
    /// Reads that skip refilling can push this above 64, meaning the reader has run
    /// past the bits currently loaded.
    pub bits_consumed: u32,
    /// Index in `data` of the byte loaded into the least-significant position of
    /// `container`.
    pub ptr: usize,
    /// Lower bound that `refill_fast` debug-asserts `ptr` against; set to 8 when
    /// `data.len() >= 8` and to 0 otherwise.
    pub limit_ptr: usize,
}
19
20impl<'a> ReverseBitReader<'a> {
21    pub fn new(data: &'a [u8]) -> Result<Self, DecompressError> {
22        if data.is_empty() {
23            return Err(DecompressError::InputExhausted);
24        }
25
26        let last_byte = *data.last().unwrap();
27        if last_byte == 0 {
28            return Err(DecompressError::CorruptSequences);
29        }
30
31        let initial_consumed = last_byte.leading_zeros() + 1;
32
33        let ptr = if data.len() >= 8 { data.len() - 8 } else { 0 };
34
35        let container = if data.len() >= 8 {
36            primitives::read_u64_le_unaligned(data, ptr)
37        } else {
38            let mut val = 0u64;
39            for (i, &b) in data.iter().enumerate() {
40                val |= (b as u64) << (i * 8);
41            }
42            val
43        };
44
45        let bits_consumed = if data.len() >= 8 {
46            initial_consumed
47        } else {
48            64 - (data.len() as u32) * 8 + initial_consumed
49        };
50
51        let limit_ptr = if data.len() >= 8 { 8 } else { 0 };
52
53        Ok(Self {
54            data,
55            container,
56            bits_consumed,
57            ptr,
58            limit_ptr,
59        })
60    }
61
62    #[inline(always)]
63    pub fn refill(&mut self) {
64        if self.bits_consumed <= 7 || self.ptr == 0 {
65            return;
66        }
67        let byte_shift = (self.bits_consumed >> 3) as usize;
68        let actual_shift = byte_shift.min(self.ptr);
69        self.ptr -= actual_shift;
70        self.bits_consumed -= (actual_shift as u32) * 8;
71        if self.ptr + 8 <= self.data.len() {
72            self.container = primitives::read_u64_le_unaligned(self.data, self.ptr);
73        } else {
74            let mut val = 0u64;
75            let avail = self.data.len() - self.ptr;
76            for i in 0..avail {
77                val |= (primitives::get_byte_unchecked(self.data, self.ptr + i) as u64) << (i * 8);
78            }
79            self.container = val;
80        }
81    }
82
83    #[inline]
84    pub fn read_bits(&mut self, n: u8) -> Result<u32, DecompressError> {
85        debug_assert!(n <= 32);
86        if n == 0 {
87            return Ok(0);
88        }
89        self.refill();
90        let avail = 64u32.saturating_sub(self.bits_consumed);
91        if (n as u32) > avail {
92            return Err(DecompressError::InputExhausted);
93        }
94        let result = ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32;
95        self.bits_consumed += n as u32;
96        Ok(result)
97    }
98
99    #[inline]
100    pub fn read_bits_unchecked(&mut self, n: u8) -> u32 {
101        debug_assert!(n <= 32);
102        if n == 0 {
103            return 0;
104        }
105        self.refill();
106        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
107        let result = ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32;
108        self.bits_consumed += n as u32;
109        result
110    }
111
112    #[inline(always)]
113    pub fn consume_bits(&mut self, n: u8) {
114        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
115        self.bits_consumed += n as u32;
116        self.refill();
117    }
118
119    #[inline(always)]
120    pub fn read_bits_fast(&mut self, n: u8) -> u32 {
121        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
122        if n == 0 {
123            return 0;
124        }
125        let result = ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32;
126        self.bits_consumed += n as u32;
127        result
128    }
129
130    #[inline(always)]
131    pub fn read_bits_branchless(&mut self, n: u8) -> u32 {
132        debug_assert!(n <= 32);
133        let result = ((self.container << (self.bits_consumed & 63)) >> 1 >> (63 - n as u32)) as u32;
134        self.bits_consumed += n as u32;
135        result
136    }
137
138    #[inline(always)]
139    pub fn refill_fast(&mut self) {
140        let byte_shift = (self.bits_consumed >> 3) as usize;
141        debug_assert!(self.ptr >= self.limit_ptr);
142        debug_assert!(byte_shift <= self.ptr);
143        debug_assert!(self.ptr - byte_shift + 8 <= self.data.len());
144        self.ptr -= byte_shift;
145        self.bits_consumed -= (byte_shift as u32) * 8;
146        self.container = primitives::read_u64_le_unaligned(self.data, self.ptr);
147    }
148
149    #[inline]
150    pub fn peek_bits(&self, n: u8) -> u32 {
151        debug_assert!(n <= 32);
152        debug_assert!((n as u32) <= 64u32.saturating_sub(self.bits_consumed));
153        if n == 0 {
154            return 0;
155        }
156        ((self.container << self.bits_consumed) >> (64 - n as u32)) as u32
157    }
158
159    #[inline]
160    pub fn bits_remaining(&self) -> usize {
161        64usize.saturating_sub(self.bits_consumed as usize) + self.ptr * 8
162    }
163
164    #[inline]
165    pub fn is_empty(&self) -> bool {
166        self.bits_consumed >= 64 && self.ptr == 0
167    }
168}
169
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn empty_input() {
        assert!(ReverseBitReader::new(&[]).is_err());
    }

    #[test]
    fn zero_last_byte() {
        assert!(ReverseBitReader::new(&[0x00]).is_err());
    }

    #[test]
    fn sentinel_only_no_data() {
        let data = [0b0000_0001];
        let r = ReverseBitReader::new(&data).unwrap();
        assert_eq!(r.bits_remaining(), 0);
    }

    #[test]
    fn roundtrip_with_forward_writer() {
        use crate::bitstream::writer::BitWriter;

        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        w.write_bits(0b1100_1010, 8);
        w.write_bits(0b1, 1);
        w.close_reverse_stream();
        let bytes = w.into_bytes();

        let mut r = ReverseBitReader::new(&bytes).unwrap();
        assert_eq!(r.read_bits(1).unwrap(), 0b1);
        assert_eq!(r.read_bits(8).unwrap(), 0b1100_1010);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        assert_eq!(r.bits_remaining(), 0);
    }

    #[test]
    fn single_byte_with_data() {
        let data = [0b0000_1101];
        let mut r = ReverseBitReader::new(&data).unwrap();
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        assert_eq!(r.bits_remaining(), 0);
    }

    #[test]
    fn multi_byte_stream() {
        use crate::bitstream::writer::BitWriter;

        let mut w = BitWriter::new();
        w.write_bits(0xFF, 8);
        w.write_bits(0x3, 2);
        w.close_reverse_stream();
        let bytes = w.into_bytes();

        let mut r = ReverseBitReader::new(&bytes).unwrap();
        assert_eq!(r.read_bits(2).unwrap(), 0x3);
        assert_eq!(r.read_bits(8).unwrap(), 0xFF);
    }

    /// Exercises the >= 8-byte path: full u64 window loads, byte-shifting refills,
    /// high-to-low byte order, and exhaustion at the start of the buffer.
    #[test]
    fn long_stream_refills_across_bytes() {
        // Payload bytes 1..=9, followed by a sentinel-only last byte.
        let data = [1u8, 2, 3, 4, 5, 6, 7, 8, 9, 0b0000_0001];

        let mut r = ReverseBitReader::new(&data).unwrap();
        assert_eq!(r.bits_remaining(), 72);
        assert_eq!(r.peek_bits(8), 9);
        for expected in (1..=9u32).rev() {
            assert_eq!(r.read_bits(8).unwrap(), expected);
        }
        assert_eq!(r.bits_remaining(), 0);
        assert!(r.is_empty());
        assert!(r.read_bits(1).is_err());

        // Reads that straddle a byte boundary.
        let mut r = ReverseBitReader::new(&data).unwrap();
        assert_eq!(r.read_bits(4).unwrap(), 0x0);
        assert_eq!(r.read_bits(12).unwrap(), 0x908);
        assert_eq!(r.bits_remaining(), 56);
    }
}