// vexil_runtime/bit_reader.rs

1use crate::error::DecodeError;
2use crate::{MAX_BYTES_LENGTH, MAX_RECURSION_DEPTH};
3
4/// A cursor over a byte slice that reads fields LSB-first at the bit level.
5///
6/// Created with [`BitReader::new`], consumed with `read_*` methods. Tracks
7/// a byte position and a sub-byte bit offset, plus a recursion depth counter
8/// for safely decoding recursive types.
9///
10/// Sub-byte reads pull individual bits from the current byte. Multi-byte reads
11/// (e.g. [`read_u16`](Self::read_u16)) first align to the next byte boundary,
12/// then interpret the bytes as little-endian.
13pub struct BitReader<'a> {
14    data: &'a [u8],
15    byte_pos: usize,
16    bit_offset: u8,
17    recursion_depth: u32,
18}
19
20impl<'a> BitReader<'a> {
21    /// Create a new `BitReader` over the given byte slice.
22    pub fn new(data: &'a [u8]) -> Self {
23        Self {
24            data,
25            byte_pos: 0,
26            bit_offset: 0,
27            recursion_depth: 0,
28        }
29    }
30
31    /// Read `count` bits LSB-first into a u64.
32    pub fn read_bits(&mut self, count: u8) -> Result<u64, DecodeError> {
33        let mut result: u64 = 0;
34        for i in 0..count {
35            if self.byte_pos >= self.data.len() {
36                return Err(DecodeError::UnexpectedEof);
37            }
38            let bit = (self.data[self.byte_pos] >> self.bit_offset) & 1;
39            result |= u64::from(bit) << i;
40            self.bit_offset += 1;
41            if self.bit_offset == 8 {
42                self.byte_pos += 1;
43                self.bit_offset = 0;
44            }
45        }
46        Ok(result)
47    }
48
49    /// Read a single bit as bool.
50    pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
51        Ok(self.read_bits(1)? != 0)
52    }
53
54    /// Advance to the next byte boundary, discarding any remaining bits in the current byte.
55    /// Infallible.
56    pub fn flush_to_byte_boundary(&mut self) {
57        if self.bit_offset > 0 {
58            self.byte_pos += 1;
59            self.bit_offset = 0;
60        }
61    }
62
63    /// Remaining bytes from byte_pos.
64    fn remaining(&self) -> usize {
65        self.data.len().saturating_sub(self.byte_pos)
66    }
67
68    /// Read a `u8`, aligning to a byte boundary first.
69    pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
70        self.flush_to_byte_boundary();
71        if self.remaining() < 1 {
72            return Err(DecodeError::UnexpectedEof);
73        }
74        let v = self.data[self.byte_pos];
75        self.byte_pos += 1;
76        Ok(v)
77    }
78
79    /// Read a little-endian `u16`, aligning to a byte boundary first.
80    pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
81        self.flush_to_byte_boundary();
82        if self.remaining() < 2 {
83            return Err(DecodeError::UnexpectedEof);
84        }
85        let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
86            .try_into()
87            .map_err(|_| DecodeError::UnexpectedEof)?;
88        self.byte_pos += 2;
89        Ok(u16::from_le_bytes(bytes))
90    }
91
92    /// Read a little-endian `u32`, aligning to a byte boundary first.
93    pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
94        self.flush_to_byte_boundary();
95        if self.remaining() < 4 {
96            return Err(DecodeError::UnexpectedEof);
97        }
98        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
99            .try_into()
100            .map_err(|_| DecodeError::UnexpectedEof)?;
101        self.byte_pos += 4;
102        Ok(u32::from_le_bytes(bytes))
103    }
104
105    /// Read a little-endian `u64`, aligning to a byte boundary first.
106    pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
107        self.flush_to_byte_boundary();
108        if self.remaining() < 8 {
109            return Err(DecodeError::UnexpectedEof);
110        }
111        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
112            .try_into()
113            .map_err(|_| DecodeError::UnexpectedEof)?;
114        self.byte_pos += 8;
115        Ok(u64::from_le_bytes(bytes))
116    }
117
118    /// Read an `i8`, aligning to a byte boundary first.
119    pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
120        self.flush_to_byte_boundary();
121        if self.remaining() < 1 {
122            return Err(DecodeError::UnexpectedEof);
123        }
124        let bytes: [u8; 1] = [self.data[self.byte_pos]];
125        self.byte_pos += 1;
126        Ok(i8::from_le_bytes(bytes))
127    }
128
129    /// Read a little-endian `i16`, aligning to a byte boundary first.
130    pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
131        self.flush_to_byte_boundary();
132        if self.remaining() < 2 {
133            return Err(DecodeError::UnexpectedEof);
134        }
135        let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
136            .try_into()
137            .map_err(|_| DecodeError::UnexpectedEof)?;
138        self.byte_pos += 2;
139        Ok(i16::from_le_bytes(bytes))
140    }
141
142    /// Read a little-endian `i32`, aligning to a byte boundary first.
143    pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
144        self.flush_to_byte_boundary();
145        if self.remaining() < 4 {
146            return Err(DecodeError::UnexpectedEof);
147        }
148        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
149            .try_into()
150            .map_err(|_| DecodeError::UnexpectedEof)?;
151        self.byte_pos += 4;
152        Ok(i32::from_le_bytes(bytes))
153    }
154
155    /// Read a little-endian `i64`, aligning to a byte boundary first.
156    pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
157        self.flush_to_byte_boundary();
158        if self.remaining() < 8 {
159            return Err(DecodeError::UnexpectedEof);
160        }
161        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
162            .try_into()
163            .map_err(|_| DecodeError::UnexpectedEof)?;
164        self.byte_pos += 8;
165        Ok(i64::from_le_bytes(bytes))
166    }
167
168    /// Read a little-endian `f32`, aligning to a byte boundary first.
169    pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
170        self.flush_to_byte_boundary();
171        if self.remaining() < 4 {
172            return Err(DecodeError::UnexpectedEof);
173        }
174        let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
175            .try_into()
176            .map_err(|_| DecodeError::UnexpectedEof)?;
177        self.byte_pos += 4;
178        Ok(f32::from_le_bytes(bytes))
179    }
180
181    /// Read a little-endian `f64`, aligning to a byte boundary first.
182    pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
183        self.flush_to_byte_boundary();
184        if self.remaining() < 8 {
185            return Err(DecodeError::UnexpectedEof);
186        }
187        let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
188            .try_into()
189            .map_err(|_| DecodeError::UnexpectedEof)?;
190        self.byte_pos += 8;
191        Ok(f64::from_le_bytes(bytes))
192    }
193
194    /// Read a LEB128-encoded u64, consuming at most `max_bytes` bytes.
195    pub fn read_leb128(&mut self, max_bytes: u8) -> Result<u64, DecodeError> {
196        self.flush_to_byte_boundary();
197        let (value, consumed) = crate::leb128::decode(&self.data[self.byte_pos..], max_bytes)?;
198        self.byte_pos += consumed;
199        Ok(value)
200    }
201
202    /// Read a ZigZag + LEB128 encoded signed integer.
203    pub fn read_zigzag(&mut self, _type_bits: u8, max_bytes: u8) -> Result<i64, DecodeError> {
204        let raw = self.read_leb128(max_bytes)?;
205        Ok(crate::zigzag::zigzag_decode(raw))
206    }
207
208    /// Read a length-prefixed UTF-8 string.
209    pub fn read_string(&mut self) -> Result<String, DecodeError> {
210        self.flush_to_byte_boundary();
211        let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
212        if len > MAX_BYTES_LENGTH {
213            return Err(DecodeError::LimitExceeded {
214                field: "string",
215                limit: MAX_BYTES_LENGTH,
216                actual: len,
217            });
218        }
219        let len = len as usize;
220        if self.remaining() < len {
221            return Err(DecodeError::UnexpectedEof);
222        }
223        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
224        self.byte_pos += len;
225        String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
226    }
227
228    /// Read a length-prefixed byte vector.
229    pub fn read_bytes(&mut self) -> Result<Vec<u8>, DecodeError> {
230        self.flush_to_byte_boundary();
231        let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
232        if len > MAX_BYTES_LENGTH {
233            return Err(DecodeError::LimitExceeded {
234                field: "bytes",
235                limit: MAX_BYTES_LENGTH,
236                actual: len,
237            });
238        }
239        let len = len as usize;
240        if self.remaining() < len {
241            return Err(DecodeError::UnexpectedEof);
242        }
243        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
244        self.byte_pos += len;
245        Ok(bytes)
246    }
247
248    /// Read exactly `len` raw bytes with no length prefix.
249    pub fn read_raw_bytes(&mut self, len: usize) -> Result<Vec<u8>, DecodeError> {
250        self.flush_to_byte_boundary();
251        if self.remaining() < len {
252            return Err(DecodeError::UnexpectedEof);
253        }
254        let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
255        self.byte_pos += len;
256        Ok(bytes)
257    }
258
259    /// Read all remaining bytes from the current position to the end.
260    /// Flushes to byte boundary first. Returns an empty Vec if no bytes remain.
261    pub fn read_remaining(&mut self) -> Vec<u8> {
262        self.flush_to_byte_boundary();
263        let remaining = self.data.len().saturating_sub(self.byte_pos);
264        if remaining == 0 {
265            return Vec::new();
266        }
267        let result = self.data[self.byte_pos..].to_vec();
268        self.byte_pos = self.data.len();
269        result
270    }
271
272    /// Increment recursion depth; return error if limit exceeded.
273    pub fn enter_recursive(&mut self) -> Result<(), DecodeError> {
274        self.recursion_depth += 1;
275        if self.recursion_depth > MAX_RECURSION_DEPTH {
276            return Err(DecodeError::RecursionLimitExceeded);
277        }
278        Ok(())
279    }
280
281    /// Decrement recursion depth.
282    pub fn leave_recursive(&mut self) {
283        self.recursion_depth = self.recursion_depth.saturating_sub(1);
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BitWriter;

    /// The lowest bit of the first byte is the first bit returned.
    #[test]
    fn read_single_bit() {
        let mut reader = BitReader::new(&[0x01]);
        assert_eq!(reader.read_bool(), Ok(true));
    }

    /// Fields narrower than a byte survive a write/read cycle, including one
    /// that straddles a byte boundary.
    #[test]
    fn round_trip_sub_byte() {
        let mut writer = BitWriter::new();
        w_bits(&mut writer);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        assert_eq!(reader.read_bits(3), Ok(5));
        assert_eq!(reader.read_bits(5), Ok(19));
        assert_eq!(reader.read_bits(6), Ok(42));

        fn w_bits(writer: &mut BitWriter) {
            writer.write_bits(5, 3);
            writer.write_bits(19, 5);
            writer.write_bits(42, 6);
        }
    }

    #[test]
    fn round_trip_u16() {
        let mut writer = BitWriter::new();
        writer.write_u16(0x1234);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        assert_eq!(reader.read_u16(), Ok(0x1234));
    }

    #[test]
    fn round_trip_i32_neg() {
        let mut writer = BitWriter::new();
        writer.write_i32(-42);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        assert_eq!(reader.read_i32(), Ok(-42));
    }

    #[test]
    fn round_trip_f32() {
        let expected = std::f32::consts::PI;
        let mut writer = BitWriter::new();
        writer.write_f32(expected);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        let decoded = reader.read_f32().unwrap();
        assert_eq!(decoded, expected);
    }

    /// NaN round-trips as NaN and carries the expected bit pattern.
    #[test]
    fn round_trip_f64_nan() {
        let mut writer = BitWriter::new();
        writer.write_f64(f64::NAN);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        let decoded = reader.read_f64().unwrap();
        assert!(decoded.is_nan());
        assert_eq!(decoded.to_bits(), 0x7FF8000000000000);
    }

    #[test]
    fn round_trip_string() {
        let mut writer = BitWriter::new();
        writer.write_string("hello");
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        let decoded = reader.read_string().unwrap();
        assert_eq!(decoded, "hello");
    }

    #[test]
    fn round_trip_leb128() {
        let mut writer = BitWriter::new();
        writer.write_leb128(300);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        assert_eq!(reader.read_leb128(4), Ok(300));
    }

    #[test]
    fn round_trip_zigzag() {
        let mut writer = BitWriter::new();
        writer.write_zigzag(-42, 64);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        assert_eq!(reader.read_zigzag(64, 10), Ok(-42));
    }

    /// Reading from empty input reports end-of-input.
    #[test]
    fn unexpected_eof() {
        let mut reader = BitReader::new(&[]);
        assert_eq!(reader.read_u8(), Err(DecodeError::UnexpectedEof));
    }

    /// A well-formed length prefix followed by non-UTF-8 bytes is rejected.
    #[test]
    fn invalid_utf8() {
        let mut writer = BitWriter::new();
        writer.write_leb128(2);
        writer.write_raw_bytes(&[0xFF, 0xFE]);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        assert_eq!(reader.read_string(), Err(DecodeError::InvalidUtf8));
    }

    /// 64 nested enters succeed; the next one is refused.
    #[test]
    fn recursion_depth_limit() {
        let mut reader = BitReader::new(&[]);
        (0..64).for_each(|_| reader.enter_recursive().unwrap());
        assert_eq!(
            reader.enter_recursive(),
            Err(DecodeError::RecursionLimitExceeded)
        );
    }

    /// Leaving a level frees room for another enter.
    #[test]
    fn recursion_depth_leave() {
        let mut reader = BitReader::new(&[]);
        (0..64).for_each(|_| reader.enter_recursive().unwrap());
        reader.leave_recursive();
        reader.enter_recursive().unwrap();
    }

    #[test]
    fn trailing_bytes_not_rejected() {
        // Models a v2-encoded message consumed by a v1 decoder:
        // v2 wrote u32(42) followed by u16(99); v1 only reads the u32.
        let payload = [0x2a, 0x00, 0x00, 0x00, 0x63, 0x00];
        let mut reader = BitReader::new(&payload);
        assert_eq!(reader.read_u32().unwrap(), 42);
        reader.flush_to_byte_boundary();
        // The unread tail (0x63, 0x00) must not trigger an error; dropping a
        // reader that still has unread data must not panic.
    }

    #[test]
    fn read_remaining_after_partial_decode() {
        let payload = [0x2a, 0x00, 0x00, 0x00, 0x63, 0x00];
        let mut reader = BitReader::new(&payload);
        let _head = reader.read_u32().unwrap();
        assert_eq!(reader.read_remaining(), vec![0x63, 0x00]);
    }

    #[test]
    fn read_remaining_when_fully_consumed() {
        let payload = [0x2a, 0x00, 0x00, 0x00];
        let mut reader = BitReader::new(&payload);
        let _head = reader.read_u32().unwrap();
        assert!(reader.read_remaining().is_empty());
    }

    #[test]
    fn read_remaining_from_start() {
        let payload = [0x01, 0x02, 0x03];
        let mut reader = BitReader::new(&payload);
        assert_eq!(reader.read_remaining(), vec![0x01, 0x02, 0x03]);
    }

    /// An explicit flush on both sides keeps writer and reader in step.
    #[test]
    fn flush_reader() {
        let mut writer = BitWriter::new();
        writer.write_bits(0b101, 3);
        writer.flush_to_byte_boundary();
        writer.write_u8(0xAB);
        let encoded = writer.finish();

        let mut reader = BitReader::new(&encoded);
        assert_eq!(reader.read_bits(3), Ok(0b101));
        reader.flush_to_byte_boundary();
        assert_eq!(reader.read_u8(), Ok(0xAB));
    }
}