Skip to main content

vexil_runtime/
bit_reader.rs

1use crate::error::DecodeError;
2use crate::{MAX_BYTES_LENGTH, MAX_RECURSION_DEPTH};
3
/// A cursor over a byte slice that reads fields LSB-first at the bit level.
///
/// Created with [`BitReader::new`], consumed with `read_*` methods. Tracks
/// a byte position and a sub-byte bit offset, plus a recursion depth counter
/// for safely decoding recursive types.
///
/// Sub-byte reads pull individual bits from the current byte. Multi-byte reads
/// (e.g. [`read_u16`](Self::read_u16)) first align to the next byte boundary,
/// then interpret the bytes as little-endian.
///
/// The reader only borrows its input, so cloning it is cheap and yields an
/// independent cursor over the same bytes.
#[derive(Debug, Clone)]
pub struct BitReader<'a> {
    /// The complete input buffer being decoded.
    data: &'a [u8],
    /// Index into `data` of the byte the next read starts from.
    byte_pos: usize,
    /// Number of bits already consumed from `data[byte_pos]` (0 through 7).
    bit_offset: u8,
    /// Current nesting level, maintained by `enter_recursive` / `leave_recursive`.
    recursion_depth: u32,
}
19
20impl<'a> BitReader<'a> {
21    /// Create a new `BitReader` over the given byte slice.
22    pub fn new(data: &'a [u8]) -> Self {
23        Self {
24            data,
25            byte_pos: 0,
26            bit_offset: 0,
27            recursion_depth: 0,
28        }
29    }
30
31    /// Read `count` bits LSB-first into a u64.
32    pub fn read_bits(&mut self, count: u8) -> Result<u64, DecodeError> {
33        let mut result: u64 = 0;
34        for i in 0..count {
35            if self.byte_pos >= self.data.len() {
36                return Err(DecodeError::UnexpectedEof);
37            }
38            let bit = (self.data[self.byte_pos] >> self.bit_offset) & 1;
39            result |= u64::from(bit) << i;
40            self.bit_offset += 1;
41            if self.bit_offset == 8 {
42                self.byte_pos += 1;
43                self.bit_offset = 0;
44            }
45        }
46        Ok(result)
47    }
48
49    /// Read a single bit as bool.
50    pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
51        Ok(self.read_bits(1)? != 0)
52    }
53
54    /// Advance to the next byte boundary, discarding any remaining bits in the current byte.
55    /// Infallible.
56    pub fn flush_to_byte_boundary(&mut self) {
57        if self.bit_offset > 0 {
58            self.byte_pos += 1;
59            self.bit_offset = 0;
60        }
61    }
62
63    /// Remaining bytes from byte_pos.
64    fn remaining(&self) -> usize {
65        self.data.len().saturating_sub(self.byte_pos)
66    }
67
68    /// Read a `u8`, aligning to a byte boundary first.
69    pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
70        self.flush_to_byte_boundary();
71        if self.remaining() < 1 {
72            return Err(DecodeError::UnexpectedEof);
73        }
74        let v = self.data[self.byte_pos];
75        self.byte_pos += 1;
76        Ok(v)
77    }
78
79    /// Read a little-endian `u16`, aligning to a byte boundary first.
80    pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
81        self.flush_to_byte_boundary();
82        if self.remaining() < 2 {
83            return Err(DecodeError::UnexpectedEof);
84        }
85        let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
86            .try_into()
87            .unwrap();
88        self.byte_pos += 2;
89        Ok(u16::from_le_bytes(bytes))
90    }
91
92    /// Read a little-endian `u32`, aligning to a byte boundary first.
93    pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
94        self.flush_to_byte_boundary();
95        if self.remaining() < 4 {
96            return Err(DecodeError::UnexpectedEof);
97        }
98        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
99            .try_into()
100            .unwrap();
101        self.byte_pos += 4;
102        Ok(u32::from_le_bytes(bytes))
103    }
104
105    /// Read a little-endian `u64`, aligning to a byte boundary first.
106    pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
107        self.flush_to_byte_boundary();
108        if self.remaining() < 8 {
109            return Err(DecodeError::UnexpectedEof);
110        }
111        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
112            .try_into()
113            .unwrap();
114        self.byte_pos += 8;
115        Ok(u64::from_le_bytes(bytes))
116    }
117
118    /// Read an `i8`, aligning to a byte boundary first.
119    pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
120        self.flush_to_byte_boundary();
121        if self.remaining() < 1 {
122            return Err(DecodeError::UnexpectedEof);
123        }
124        let bytes: [u8; 1] = [self.data[self.byte_pos]];
125        self.byte_pos += 1;
126        Ok(i8::from_le_bytes(bytes))
127    }
128
129    /// Read a little-endian `i16`, aligning to a byte boundary first.
130    pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
131        self.flush_to_byte_boundary();
132        if self.remaining() < 2 {
133            return Err(DecodeError::UnexpectedEof);
134        }
135        let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
136            .try_into()
137            .unwrap();
138        self.byte_pos += 2;
139        Ok(i16::from_le_bytes(bytes))
140    }
141
142    /// Read a little-endian `i32`, aligning to a byte boundary first.
143    pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
144        self.flush_to_byte_boundary();
145        if self.remaining() < 4 {
146            return Err(DecodeError::UnexpectedEof);
147        }
148        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
149            .try_into()
150            .unwrap();
151        self.byte_pos += 4;
152        Ok(i32::from_le_bytes(bytes))
153    }
154
155    /// Read a little-endian `i64`, aligning to a byte boundary first.
156    pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
157        self.flush_to_byte_boundary();
158        if self.remaining() < 8 {
159            return Err(DecodeError::UnexpectedEof);
160        }
161        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
162            .try_into()
163            .unwrap();
164        self.byte_pos += 8;
165        Ok(i64::from_le_bytes(bytes))
166    }
167
168    /// Read a little-endian `f32`, aligning to a byte boundary first.
169    pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
170        self.flush_to_byte_boundary();
171        if self.remaining() < 4 {
172            return Err(DecodeError::UnexpectedEof);
173        }
174        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
175            .try_into()
176            .unwrap();
177        self.byte_pos += 4;
178        Ok(f32::from_le_bytes(bytes))
179    }
180
181    /// Read a little-endian `f64`, aligning to a byte boundary first.
182    pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
183        self.flush_to_byte_boundary();
184        if self.remaining() < 8 {
185            return Err(DecodeError::UnexpectedEof);
186        }
187        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
188            .try_into()
189            .unwrap();
190        self.byte_pos += 8;
191        Ok(f64::from_le_bytes(bytes))
192    }
193
194    /// Read a LEB128-encoded u64, consuming at most `max_bytes` bytes.
195    pub fn read_leb128(&mut self, max_bytes: u8) -> Result<u64, DecodeError> {
196        self.flush_to_byte_boundary();
197        let (value, consumed) = crate::leb128::decode(&self.data[self.byte_pos..], max_bytes)?;
198        self.byte_pos += consumed;
199        Ok(value)
200    }
201
202    /// Read a ZigZag + LEB128 encoded signed integer.
203    pub fn read_zigzag(&mut self, _type_bits: u8, max_bytes: u8) -> Result<i64, DecodeError> {
204        let raw = self.read_leb128(max_bytes)?;
205        Ok(crate::zigzag::zigzag_decode(raw))
206    }
207
208    /// Read a length-prefixed UTF-8 string.
209    pub fn read_string(&mut self) -> Result<String, DecodeError> {
210        self.flush_to_byte_boundary();
211        let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
212        if len > MAX_BYTES_LENGTH {
213            return Err(DecodeError::LimitExceeded {
214                field: "string",
215                limit: MAX_BYTES_LENGTH,
216                actual: len,
217            });
218        }
219        let len = len as usize;
220        if self.remaining() < len {
221            return Err(DecodeError::UnexpectedEof);
222        }
223        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
224        self.byte_pos += len;
225        String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
226    }
227
228    /// Read a length-prefixed byte vector.
229    pub fn read_bytes(&mut self) -> Result<Vec<u8>, DecodeError> {
230        self.flush_to_byte_boundary();
231        let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
232        if len > MAX_BYTES_LENGTH {
233            return Err(DecodeError::LimitExceeded {
234                field: "bytes",
235                limit: MAX_BYTES_LENGTH,
236                actual: len,
237            });
238        }
239        let len = len as usize;
240        if self.remaining() < len {
241            return Err(DecodeError::UnexpectedEof);
242        }
243        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
244        self.byte_pos += len;
245        Ok(bytes)
246    }
247
248    /// Read exactly `len` raw bytes with no length prefix.
249    pub fn read_raw_bytes(&mut self, len: usize) -> Result<Vec<u8>, DecodeError> {
250        self.flush_to_byte_boundary();
251        if self.remaining() < len {
252            return Err(DecodeError::UnexpectedEof);
253        }
254        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
255        self.byte_pos += len;
256        Ok(bytes)
257    }
258
259    /// Increment recursion depth; return error if limit exceeded.
260    pub fn enter_recursive(&mut self) -> Result<(), DecodeError> {
261        self.recursion_depth += 1;
262        if self.recursion_depth > MAX_RECURSION_DEPTH {
263            return Err(DecodeError::RecursionLimitExceeded);
264        }
265        Ok(())
266    }
267
268    /// Decrement recursion depth.
269    pub fn leave_recursive(&mut self) {
270        self.recursion_depth = self.recursion_depth.saturating_sub(1);
271    }
272}
273
#[cfg(test)]
mod tests {
    //! Unit tests for `BitReader`, mostly round-trips through `BitWriter`.

    use super::*;
    use crate::BitWriter;

    #[test]
    fn read_single_bit() {
        let mut r = BitReader::new(&[0x01]);
        assert!(r.read_bool().unwrap());
    }

    #[test]
    fn round_trip_sub_byte() {
        let mut w = BitWriter::new();
        w.write_bits(5, 3);
        w.write_bits(19, 5);
        w.write_bits(42, 6);
        let buf = w.finish();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3).unwrap(), 5);
        assert_eq!(r.read_bits(5).unwrap(), 19);
        assert_eq!(r.read_bits(6).unwrap(), 42);
    }

    #[test]
    fn round_trip_u16() {
        let mut w = BitWriter::new();
        w.write_u16(0x1234);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_u16().unwrap(), 0x1234);
    }

    #[test]
    fn round_trip_i32_neg() {
        let mut w = BitWriter::new();
        w.write_i32(-42);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_i32().unwrap(), -42);
    }

    #[test]
    fn round_trip_f32() {
        let mut w = BitWriter::new();
        w.write_f32(std::f32::consts::PI);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_f32().unwrap(), std::f32::consts::PI);
    }

    #[test]
    fn round_trip_f64_nan() {
        let mut w = BitWriter::new();
        w.write_f64(f64::NAN);
        let b = w.finish();
        let v = BitReader::new(&b).read_f64().unwrap();
        assert!(v.is_nan());
        // The exact NaN bit pattern must survive the round trip.
        assert_eq!(v.to_bits(), 0x7FF8_0000_0000_0000);
    }

    #[test]
    fn round_trip_string() {
        let mut w = BitWriter::new();
        w.write_string("hello");
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_string().unwrap(), "hello");
    }

    #[test]
    fn round_trip_leb128() {
        let mut w = BitWriter::new();
        w.write_leb128(300);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_leb128(4).unwrap(), 300);
    }

    #[test]
    fn round_trip_zigzag() {
        let mut w = BitWriter::new();
        w.write_zigzag(-42, 64);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_zigzag(64, 10).unwrap(), -42);
    }

    #[test]
    fn unexpected_eof() {
        assert_eq!(
            BitReader::new(&[]).read_u8().unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    #[test]
    fn invalid_utf8() {
        let mut w = BitWriter::new();
        w.write_leb128(2);
        w.write_raw_bytes(&[0xFF, 0xFE]);
        let b = w.finish();
        assert_eq!(
            BitReader::new(&b).read_string().unwrap_err(),
            DecodeError::InvalidUtf8
        );
    }

    #[test]
    fn recursion_depth_limit() {
        let mut r = BitReader::new(&[]);
        // Use the real limit so the test tracks the constant if it changes.
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        assert_eq!(
            r.enter_recursive().unwrap_err(),
            DecodeError::RecursionLimitExceeded
        );
    }

    #[test]
    fn recursion_depth_leave() {
        let mut r = BitReader::new(&[]);
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        // Leaving one level frees room for exactly one more entry.
        r.leave_recursive();
        r.enter_recursive().unwrap();
    }

    #[test]
    fn trailing_bytes_not_rejected() {
        // Simulate v2-encoded message read by v1 decoder:
        // v2 wrote u32(42) + u16(99), v1 only reads u32(42)
        let data = [0x2a, 0x00, 0x00, 0x00, 0x63, 0x00];
        let mut r = BitReader::new(&data);
        let x = r.read_u32().unwrap();
        assert_eq!(x, 42);
        r.flush_to_byte_boundary();
        // The two trailing bytes (0x63, 0x00) stay unread and cause no error.
        assert_eq!(r.remaining(), 2);
        // BitReader can be dropped with unread data — no panic.
    }

    #[test]
    fn flush_reader() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        w.flush_to_byte_boundary();
        w.write_u8(0xAB);
        let b = w.finish();
        let mut r = BitReader::new(&b);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        r.flush_to_byte_boundary();
        assert_eq!(r.read_u8().unwrap(), 0xAB);
    }
}