// vexil_runtime/bit_reader.rs
use crate::error::DecodeError;
use crate::{MAX_BYTES_LENGTH, MAX_RECURSION_DEPTH};
3
/// LSB-first bit-level reader over a borrowed byte slice.
///
/// Tracks a byte index plus a sub-byte bit offset so callers can mix
/// bit-granular reads with byte-aligned primitive reads.
pub struct BitReader<'a> {
    // Backing buffer being decoded; never modified.
    data: &'a [u8],
    // Index of the next byte to read within `data`.
    byte_pos: usize,
    // Bit position (0..8) within the current byte; 0 means byte-aligned.
    bit_offset: u8,
    // Current decode nesting depth, bounded by MAX_RECURSION_DEPTH.
    recursion_depth: u32,
}
19
20impl<'a> BitReader<'a> {
21 pub fn new(data: &'a [u8]) -> Self {
23 Self {
24 data,
25 byte_pos: 0,
26 bit_offset: 0,
27 recursion_depth: 0,
28 }
29 }
30
31 pub fn read_bits(&mut self, count: u8) -> Result<u64, DecodeError> {
33 let mut result: u64 = 0;
34 for i in 0..count {
35 if self.byte_pos >= self.data.len() {
36 return Err(DecodeError::UnexpectedEof);
37 }
38 let bit = (self.data[self.byte_pos] >> self.bit_offset) & 1;
39 result |= u64::from(bit) << i;
40 self.bit_offset += 1;
41 if self.bit_offset == 8 {
42 self.byte_pos += 1;
43 self.bit_offset = 0;
44 }
45 }
46 Ok(result)
47 }
48
49 pub fn read_bool(&mut self) -> Result<bool, DecodeError> {
51 Ok(self.read_bits(1)? != 0)
52 }
53
54 pub fn flush_to_byte_boundary(&mut self) {
57 if self.bit_offset > 0 {
58 self.byte_pos += 1;
59 self.bit_offset = 0;
60 }
61 }
62
63 fn remaining(&self) -> usize {
65 self.data.len().saturating_sub(self.byte_pos)
66 }
67
68 pub fn read_u8(&mut self) -> Result<u8, DecodeError> {
70 self.flush_to_byte_boundary();
71 if self.remaining() < 1 {
72 return Err(DecodeError::UnexpectedEof);
73 }
74 let v = self.data[self.byte_pos];
75 self.byte_pos += 1;
76 Ok(v)
77 }
78
79 pub fn read_u16(&mut self) -> Result<u16, DecodeError> {
81 self.flush_to_byte_boundary();
82 if self.remaining() < 2 {
83 return Err(DecodeError::UnexpectedEof);
84 }
85 let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
86 .try_into()
87 .unwrap();
88 self.byte_pos += 2;
89 Ok(u16::from_le_bytes(bytes))
90 }
91
92 pub fn read_u32(&mut self) -> Result<u32, DecodeError> {
94 self.flush_to_byte_boundary();
95 if self.remaining() < 4 {
96 return Err(DecodeError::UnexpectedEof);
97 }
98 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
99 .try_into()
100 .unwrap();
101 self.byte_pos += 4;
102 Ok(u32::from_le_bytes(bytes))
103 }
104
105 pub fn read_u64(&mut self) -> Result<u64, DecodeError> {
107 self.flush_to_byte_boundary();
108 if self.remaining() < 8 {
109 return Err(DecodeError::UnexpectedEof);
110 }
111 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
112 .try_into()
113 .unwrap();
114 self.byte_pos += 8;
115 Ok(u64::from_le_bytes(bytes))
116 }
117
118 pub fn read_i8(&mut self) -> Result<i8, DecodeError> {
120 self.flush_to_byte_boundary();
121 if self.remaining() < 1 {
122 return Err(DecodeError::UnexpectedEof);
123 }
124 let bytes: [u8; 1] = [self.data[self.byte_pos]];
125 self.byte_pos += 1;
126 Ok(i8::from_le_bytes(bytes))
127 }
128
129 pub fn read_i16(&mut self) -> Result<i16, DecodeError> {
131 self.flush_to_byte_boundary();
132 if self.remaining() < 2 {
133 return Err(DecodeError::UnexpectedEof);
134 }
135 let bytes: [u8; 2] = self.data[self.byte_pos..self.byte_pos + 2]
136 .try_into()
137 .unwrap();
138 self.byte_pos += 2;
139 Ok(i16::from_le_bytes(bytes))
140 }
141
142 pub fn read_i32(&mut self) -> Result<i32, DecodeError> {
144 self.flush_to_byte_boundary();
145 if self.remaining() < 4 {
146 return Err(DecodeError::UnexpectedEof);
147 }
148 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
149 .try_into()
150 .unwrap();
151 self.byte_pos += 4;
152 Ok(i32::from_le_bytes(bytes))
153 }
154
155 pub fn read_i64(&mut self) -> Result<i64, DecodeError> {
157 self.flush_to_byte_boundary();
158 if self.remaining() < 8 {
159 return Err(DecodeError::UnexpectedEof);
160 }
161 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
162 .try_into()
163 .unwrap();
164 self.byte_pos += 8;
165 Ok(i64::from_le_bytes(bytes))
166 }
167
168 pub fn read_f32(&mut self) -> Result<f32, DecodeError> {
170 self.flush_to_byte_boundary();
171 if self.remaining() < 4 {
172 return Err(DecodeError::UnexpectedEof);
173 }
174 let bytes: [u8; 4] = self.data[self.byte_pos..self.byte_pos + 4]
175 .try_into()
176 .unwrap();
177 self.byte_pos += 4;
178 Ok(f32::from_le_bytes(bytes))
179 }
180
181 pub fn read_f64(&mut self) -> Result<f64, DecodeError> {
183 self.flush_to_byte_boundary();
184 if self.remaining() < 8 {
185 return Err(DecodeError::UnexpectedEof);
186 }
187 let bytes: [u8; 8] = self.data[self.byte_pos..self.byte_pos + 8]
188 .try_into()
189 .unwrap();
190 self.byte_pos += 8;
191 Ok(f64::from_le_bytes(bytes))
192 }
193
194 pub fn read_leb128(&mut self, max_bytes: u8) -> Result<u64, DecodeError> {
196 self.flush_to_byte_boundary();
197 let (value, consumed) = crate::leb128::decode(&self.data[self.byte_pos..], max_bytes)?;
198 self.byte_pos += consumed;
199 Ok(value)
200 }
201
202 pub fn read_zigzag(&mut self, _type_bits: u8, max_bytes: u8) -> Result<i64, DecodeError> {
204 let raw = self.read_leb128(max_bytes)?;
205 Ok(crate::zigzag::zigzag_decode(raw))
206 }
207
208 pub fn read_string(&mut self) -> Result<String, DecodeError> {
210 self.flush_to_byte_boundary();
211 let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
212 if len > MAX_BYTES_LENGTH {
213 return Err(DecodeError::LimitExceeded {
214 field: "string",
215 limit: MAX_BYTES_LENGTH,
216 actual: len,
217 });
218 }
219 let len = len as usize;
220 if self.remaining() < len {
221 return Err(DecodeError::UnexpectedEof);
222 }
223 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
224 self.byte_pos += len;
225 String::from_utf8(bytes).map_err(|_| DecodeError::InvalidUtf8)
226 }
227
228 pub fn read_bytes(&mut self) -> Result<Vec<u8>, DecodeError> {
230 self.flush_to_byte_boundary();
231 let len = self.read_leb128(crate::MAX_LENGTH_PREFIX_BYTES)?;
232 if len > MAX_BYTES_LENGTH {
233 return Err(DecodeError::LimitExceeded {
234 field: "bytes",
235 limit: MAX_BYTES_LENGTH,
236 actual: len,
237 });
238 }
239 let len = len as usize;
240 if self.remaining() < len {
241 return Err(DecodeError::UnexpectedEof);
242 }
243 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
244 self.byte_pos += len;
245 Ok(bytes)
246 }
247
248 pub fn read_raw_bytes(&mut self, len: usize) -> Result<Vec<u8>, DecodeError> {
250 self.flush_to_byte_boundary();
251 if self.remaining() < len {
252 return Err(DecodeError::UnexpectedEof);
253 }
254 let bytes = self.data[self.byte_pos..self.byte_pos + len].to_vec();
255 self.byte_pos += len;
256 Ok(bytes)
257 }
258
259 pub fn enter_recursive(&mut self) -> Result<(), DecodeError> {
261 self.recursion_depth += 1;
262 if self.recursion_depth > MAX_RECURSION_DEPTH {
263 return Err(DecodeError::RecursionLimitExceeded);
264 }
265 Ok(())
266 }
267
268 pub fn leave_recursive(&mut self) {
270 self.recursion_depth = self.recursion_depth.saturating_sub(1);
271 }
272}
273
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BitWriter;

    #[test]
    fn read_single_bit() {
        let mut r = BitReader::new(&[0x01]);
        assert!(r.read_bool().unwrap());
    }

    #[test]
    fn round_trip_sub_byte() {
        let mut w = BitWriter::new();
        w.write_bits(5, 3);
        w.write_bits(19, 5);
        w.write_bits(42, 6);
        let buf = w.finish();
        let mut r = BitReader::new(&buf);
        assert_eq!(r.read_bits(3).unwrap(), 5);
        assert_eq!(r.read_bits(5).unwrap(), 19);
        assert_eq!(r.read_bits(6).unwrap(), 42);
    }

    #[test]
    fn round_trip_u16() {
        let mut w = BitWriter::new();
        w.write_u16(0x1234);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_u16().unwrap(), 0x1234);
    }

    #[test]
    fn round_trip_i32_neg() {
        let mut w = BitWriter::new();
        w.write_i32(-42);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_i32().unwrap(), -42);
    }

    #[test]
    fn round_trip_f32() {
        let mut w = BitWriter::new();
        w.write_f32(std::f32::consts::PI);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_f32().unwrap(), std::f32::consts::PI);
    }

    #[test]
    fn round_trip_f64_nan() {
        let mut w = BitWriter::new();
        w.write_f64(f64::NAN);
        let b = w.finish();
        let v = BitReader::new(&b).read_f64().unwrap();
        assert!(v.is_nan());
        // The exact bit pattern of f64::NAN is not guaranteed by the language,
        // so compare against the constant's own bits: this still proves the
        // round trip is bit-exact without assuming a specific quiet-NaN encoding.
        assert_eq!(v.to_bits(), f64::NAN.to_bits());
    }

    #[test]
    fn round_trip_string() {
        let mut w = BitWriter::new();
        w.write_string("hello");
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_string().unwrap(), "hello");
    }

    #[test]
    fn round_trip_leb128() {
        let mut w = BitWriter::new();
        w.write_leb128(300);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_leb128(4).unwrap(), 300);
    }

    #[test]
    fn round_trip_zigzag() {
        let mut w = BitWriter::new();
        w.write_zigzag(-42, 64);
        let b = w.finish();
        assert_eq!(BitReader::new(&b).read_zigzag(64, 10).unwrap(), -42);
    }

    #[test]
    fn unexpected_eof() {
        assert_eq!(
            BitReader::new(&[]).read_u8().unwrap_err(),
            DecodeError::UnexpectedEof
        );
    }

    #[test]
    fn invalid_utf8() {
        let mut w = BitWriter::new();
        w.write_leb128(2);
        w.write_raw_bytes(&[0xFF, 0xFE]);
        let b = w.finish();
        assert_eq!(
            BitReader::new(&b).read_string().unwrap_err(),
            DecodeError::InvalidUtf8
        );
    }

    #[test]
    fn recursion_depth_limit() {
        let mut r = BitReader::new(&[]);
        // Use the crate constant rather than a hard-coded 64 so the test
        // keeps tracking the real limit if it ever changes.
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        assert_eq!(
            r.enter_recursive().unwrap_err(),
            DecodeError::RecursionLimitExceeded
        );
    }

    #[test]
    fn recursion_depth_leave() {
        let mut r = BitReader::new(&[]);
        for _ in 0..MAX_RECURSION_DEPTH {
            r.enter_recursive().unwrap();
        }
        // Leaving one level frees capacity for exactly one more enter.
        r.leave_recursive();
        r.enter_recursive().unwrap();
    }

    #[test]
    fn trailing_bytes_not_rejected() {
        // Extra bytes after the decoded value are tolerated by design.
        let data = [0x2a, 0x00, 0x00, 0x00, 0x63, 0x00];
        let mut r = BitReader::new(&data);
        let x = r.read_u32().unwrap();
        assert_eq!(x, 42);
        r.flush_to_byte_boundary();
    }

    #[test]
    fn flush_reader() {
        let mut w = BitWriter::new();
        w.write_bits(0b101, 3);
        w.flush_to_byte_boundary();
        w.write_u8(0xAB);
        let b = w.finish();
        let mut r = BitReader::new(&b);
        assert_eq!(r.read_bits(3).unwrap(), 0b101);
        r.flush_to_byte_boundary();
        assert_eq!(r.read_u8().unwrap(), 0xAB);
    }
}