zune_core/bytestream/
reader.rs

1/*
2 * Copyright (c) 2023.
3 *
4 * This software is free software;
5 *
6 * You can redistribute it or modify it under terms of the MIT, Apache License or Zlib license
7 */
8
9use core::cmp::min;
10
11use crate::bytestream::traits::ZReaderTrait;
12
/// Error message returned by the fallible readers in this file when the
/// requested bytes cannot be served from the underlying stream.
const ERROR_MSG: &str = "No more bytes";
14
/// An encapsulation of a byte stream reader
///
/// This provides an interface similar to [std::io::Cursor], but
/// adds fine-grained methods for reading different integer data types from
/// the underlying buffer.
///
/// Most readers come in two variants: an error variant and a non-error variant.
/// The error variants are useful when you need bytes from the underlying
/// stream and cannot make do with a zero result.
/// The non-error variants are useful when you have already proved the data exists,
/// e.g. by using the [`has`] method, or when you are okay with getting zero back
/// once the underlying buffer has been completely read.
///
/// [std::io::Cursor]: https://doc.rust-lang.org/std/io/struct.Cursor.html
/// [`has`]: Self::has
pub struct ZByteReader<T: ZReaderTrait> {
    /// Data stream
    stream:   T,
    /// Cursor: offset into `stream` at which the next read starts.
    /// It may point past the end of the stream, in which case reads
    /// report an error or zero depending on the variant used.
    position: usize
}
35
/// Byte order used by the integer readers when turning raw bytes into an integer.
enum Mode {
    /// Big endian (decoded with `from_be_bytes`)
    BE,
    /// Little endian (decoded with `from_le_bytes`)
    LE
}
#[cfg(feature = "std")]
impl<T: ZReaderTrait> std::io::Read for ZByteReader<T> {
    /// Copy as many unread bytes as fit into `buf`, returning how many were copied.
    ///
    /// This delegates to the inherent [`ZByteReader::read`]. An error reported by
    /// that method is converted into a [`std::io::Error`] instead of panicking.
    fn read(&mut self, buf: &mut [u8]) -> std::io::Result<usize> {
        // Path syntax resolves to the inherent `read`, not to this trait method,
        // so there is no accidental recursion.
        ZByteReader::read(self, buf)
            .map_err(|msg| std::io::Error::new(std::io::ErrorKind::Other, msg))
    }
}
48
49impl<T: ZReaderTrait> ZByteReader<T> {
50    /// Create a new instance of the byte stream
51    ///
52    /// Bytes will be read from the start of `buf`.
53    ///
54    /// `buf` is expected to live as long as this and
55    /// all references to it live
56    ///
57    /// # Returns
58    /// A byte reader which will pull bits from bye
59    pub const fn new(buf: T) -> ZByteReader<T> {
60        ZByteReader {
61            stream:   buf,
62            position: 0
63        }
64    }
65    /// Destroy this reader returning
66    /// the underlying source of the bytes
67    /// from which we were decoding
68    pub fn consume(self) -> T {
69        self.stream
70    }
71    /// Skip `num` bytes ahead of the stream.
72    ///
73    /// This bumps up the internal cursor wit a wrapping addition
74    /// The bytes between current position and `num` will be skipped
75    ///
76    /// # Arguments
77    /// `num`: How many bytes to skip
78    ///
79    /// # Note
80    /// This does not consider length of the buffer, so skipping more bytes
81    /// than possible and then reading bytes will return an error if using error variants
82    /// or zero if using non-error variants
83    ///
84    /// # Example
85    /// ```
86    /// use zune_core::bytestream::ZByteReader;
87    /// let zero_to_hundred:Vec<u8> = (0..100).collect();
88    /// let mut stream = ZByteReader::new(&zero_to_hundred);
89    /// // skip 37 bytes
90    /// stream.skip(37);
91    ///
92    /// assert_eq!(stream.get_u8(),37);
93    /// ```
94    ///
95    /// See [`rewind`](ZByteReader::rewind) for moving the internal cursor back
96    pub fn skip(&mut self, num: usize) {
97        // Can this overflow ??
98        self.position = self.position.wrapping_add(num);
99    }
100    /// Undo a buffer read by moving the position pointer `num`
101    /// bytes behind.
102    ///
103    /// This operation will saturate at zero
104    pub fn rewind(&mut self, num: usize) {
105        self.position = self.position.saturating_sub(num);
106    }
107
108    /// Return whether the underlying buffer
109    /// has `num` bytes available for reading
110    ///
111    /// # Example
112    ///
113    /// ```
114    /// use zune_core::bytestream::ZByteReader;
115    /// let data = [0_u8;120];
116    /// let reader = ZByteReader::new(data.as_slice());
117    /// assert!(reader.has(3));
118    /// assert!(!reader.has(121));
119    /// ```
120    #[inline]
121    pub fn has(&self, num: usize) -> bool {
122        self.position.saturating_add(num) <= self.stream.get_len()
123    }
124    /// Get number of bytes available in the stream
125    #[inline]
126    pub fn get_bytes_left(&self) -> usize {
127        // Must be saturating to prevent underflow
128        self.stream.get_len().saturating_sub(self.position)
129    }
130    /// Get length of the underlying buffer.
131    ///
132    /// To get the number of bytes left in the buffer,
133    /// use [remaining] method
134    ///
135    /// [remaining]: Self::remaining
136    #[inline]
137    pub fn len(&self) -> usize {
138        self.stream.get_len()
139    }
140    /// Return true if the underlying buffer stream is empty
141    #[inline]
142    pub fn is_empty(&self) -> bool {
143        self.stream.get_len() == 0
144    }
145    /// Get current position of the buffer.
146    #[inline]
147    pub const fn get_position(&self) -> usize {
148        self.position
149    }
150    /// Return true whether or not we read to the end of the
151    /// buffer and have no more bytes left.
152    #[inline]
153    pub fn eof(&self) -> bool {
154        self.position >= self.len()
155    }
156    /// Get number of bytes unread inside this
157    /// stream.
158    ///
159    /// To get the length of the underlying stream,
160    /// use [len] method
161    ///
162    /// [len]: Self::len()
163    #[inline]
164    pub fn remaining(&self) -> usize {
165        self.stream.get_len().saturating_sub(self.position)
166    }
167    /// Get a part of the bytestream as a reference.
168    ///
169    /// This increments the position to point past the bytestream
170    /// if position+num is in bounds
171    pub fn get(&mut self, num: usize) -> Result<&[u8], &'static str> {
172        match self.stream.get_slice(self.position..self.position + num) {
173            Some(bytes) => {
174                self.position += num;
175                Ok(bytes)
176            }
177            None => Err(ERROR_MSG)
178        }
179    }
180    /// Look ahead position bytes and return a reference
181    /// to num_bytes from that position, or an error if the
182    /// peek would be out of bounds.
183    ///
184    /// This doesn't increment the position, bytes would have to be discarded
185    /// at a later point.
186    #[inline]
187    pub fn peek_at(&self, position: usize, num_bytes: usize) -> Result<&[u8], &'static str> {
188        let start = self.position + position;
189        let end = self.position + position + num_bytes;
190
191        match self.stream.get_slice(start..end) {
192            Some(bytes) => Ok(bytes),
193            None => Err(ERROR_MSG)
194        }
195    }
196    /// Get a fixed amount of bytes or return an error if we cant
197    /// satisfy the read
198    ///
199    /// This should be combined with [`has`] since if there are no
200    /// more bytes you get an error.
201    ///
202    /// But it's useful for cases where you expect bytes but they are not present
203    ///
204    /// For the zero  variant see, [`get_fixed_bytes_or_zero`]
205    ///
206    /// # Example
207    /// ```rust
208    /// use zune_core::bytestream::ZByteReader;
209    /// let mut stream = ZByteReader::new([0x0,0x5,0x3,0x2].as_slice());
210    /// let first_bytes = stream.get_fixed_bytes_or_err::<10>(); // not enough bytes
211    /// assert!(first_bytes.is_err());
212    /// ```
213    ///
214    /// [`has`]:Self::has
215    /// [`get_fixed_bytes_or_zero`]: Self::get_fixed_bytes_or_zero
216    #[inline]
217    pub fn get_fixed_bytes_or_err<const N: usize>(&mut self) -> Result<[u8; N], &'static str> {
218        let mut byte_store: [u8; N] = [0; N];
219
220        match self.stream.get_slice(self.position..self.position + N) {
221            Some(bytes) => {
222                self.position += N;
223                byte_store.copy_from_slice(bytes);
224
225                Ok(byte_store)
226            }
227            None => Err(ERROR_MSG)
228        }
229    }
230
231    /// Get a fixed amount of bytes or return a zero array size
232    /// if we can't satisfy the read
233    ///
234    /// This should be combined with [`has`] since if there are no
235    /// more bytes you get a zero initialized array
236    ///
237    /// For the error variant see, [`get_fixed_bytes_or_err`]
238    ///
239    /// # Example
240    /// ```rust
241    /// use zune_core::bytestream::ZByteReader;
242    /// let mut stream = ZByteReader::new([0x0,0x5,0x3,0x2].as_slice());
243    /// let first_bytes = stream.get_fixed_bytes_or_zero::<2>();
244    /// assert_eq!(first_bytes,[0x0,0x5]);
245    /// ```
246    ///
247    /// [`has`]:Self::has
248    /// [`get_fixed_bytes_or_err`]: Self::get_fixed_bytes_or_err
249    #[inline]
250    pub fn get_fixed_bytes_or_zero<const N: usize>(&mut self) -> [u8; N] {
251        let mut byte_store: [u8; N] = [0; N];
252
253        match self.stream.get_slice(self.position..self.position + N) {
254            Some(bytes) => {
255                self.position += N;
256                byte_store.copy_from_slice(bytes);
257
258                byte_store
259            }
260            None => byte_store
261        }
262    }
263    #[inline]
264    /// Skip bytes until a condition becomes false or the stream runs out of bytes
265    ///
266    /// # Example
267    ///
268    /// ```rust
269    /// use zune_core::bytestream::ZByteReader;
270    /// let mut stream = ZByteReader::new([0;10].as_slice());
271    /// stream.skip_until_false(|x| x.is_ascii()) // skip until we meet a non ascii character
272    /// ```
273    pub fn skip_until_false<F: Fn(u8) -> bool>(&mut self, func: F) {
274        // iterate until we have no more bytes
275        while !self.eof() {
276            // get a byte from stream
277            let byte = self.get_u8();
278
279            if !(func)(byte) {
280                // function returned false meaning we stop skipping
281                self.rewind(1);
282                break;
283            }
284        }
285    }
286    /// Return the remaining unread bytes in this byte reader
287    pub fn remaining_bytes(&self) -> &[u8] {
288        debug_assert!(self.position <= self.len());
289        self.stream.get_slice(self.position..self.len()).unwrap()
290    }
291
292    pub fn read(&mut self, buf: &mut [u8]) -> Result<usize, &'static str> {
293        let buf_length = buf.len();
294        let start = self.position;
295        let end = min(self.len(), self.position + buf_length);
296        let diff = end - start;
297
298        buf[0..diff].copy_from_slice(self.stream.get_slice(start..end).unwrap());
299
300        self.skip(diff);
301
302        Ok(diff)
303    }
304
305    /// Read enough bytes to fill in
306    pub fn read_exact(&mut self, buf: &mut [u8]) -> Result<(), &'static str> {
307        let size = self.read(buf)?;
308
309        if size != buf.len() {
310            return Err("Could not read into the whole buffer");
311        }
312        Ok(())
313    }
314
315    /// Set the cursor position
316    ///
317    /// After this, all reads will proceed from the position as an anchor
318    /// point
319    pub fn set_position(&mut self, position: usize) {
320        self.position = position;
321    }
322}
323
/// Generate the multi-byte integer readers for one integer type.
///
/// Arguments, in order:
/// 1. `$name`: private helper, returns `0` when there are not enough bytes
/// 2. `$name2`: private helper, returns an error when there are not enough bytes
/// 3. `$name3`: public big endian reader, error variant
/// 4. `$name4`: public little endian reader, error variant
/// 5. `$name5`: public big endian reader, zero variant
/// 6. `$name6`: public little endian reader, zero variant
/// 7. `$int_type`: the integer type being read
///
/// A failed read never moves the cursor.
macro_rules! get_single_type {
    ($name:tt,$name2:tt,$name3:tt,$name4:tt,$name5:tt,$name6:tt,$int_type:tt) => {
        impl<T:ZReaderTrait> ZByteReader<T>
        {
            #[inline(always)]
            fn $name(&mut self, mode: Mode) -> $int_type
            {
                // Same read as the error variant, with failure mapped to zero.
                self.$name2(mode).unwrap_or(0)
            }

            #[inline(always)]
            fn $name2(&mut self, mode: Mode) -> Result<$int_type, &'static str>
            {
                const SIZE_OF_VAL: usize = core::mem::size_of::<$int_type>();

                let start = self.position;
                // Checked add: the cursor may be close to `usize::MAX` (e.g. after
                // a large skip); a plain `+` would panic in debug builds.
                let end = start.checked_add(SIZE_OF_VAL).ok_or(ERROR_MSG)?;
                let bytes = self.stream.get_slice(start..end).ok_or(ERROR_MSG)?;

                let mut space = [0; SIZE_OF_VAL];
                space.copy_from_slice(bytes);
                self.position = end;

                Ok(match mode
                {
                    Mode::LE => $int_type::from_le_bytes(space),
                    Mode::BE => $int_type::from_be_bytes(space),
                })
            }
            #[doc=concat!("Read ",stringify!($int_type)," as a big endian integer")]
            #[doc=concat!("Returning an error if the underlying buffer cannot support a ",stringify!($int_type)," read.")]
            #[inline]
            pub fn $name3(&mut self) -> Result<$int_type, &'static str>
            {
                self.$name2(Mode::BE)
            }

            #[doc=concat!("Read ",stringify!($int_type)," as a little endian integer")]
            #[doc=concat!("Returning an error if the underlying buffer cannot support a ",stringify!($int_type)," read.")]
            #[inline]
            pub fn $name4(&mut self) -> Result<$int_type, &'static str>
            {
                self.$name2(Mode::LE)
            }
            #[doc=concat!("Read ",stringify!($int_type)," as a big endian integer")]
            #[doc=concat!("Returning 0 if the underlying  buffer does not have enough bytes for a ",stringify!($int_type)," read.")]
            #[inline(always)]
            pub fn $name5(&mut self) -> $int_type
            {
                self.$name(Mode::BE)
            }
            #[doc=concat!("Read ",stringify!($int_type)," as a little endian integer")]
            #[doc=concat!("Returning 0 if the underlying buffer does not have enough bytes for a ",stringify!($int_type)," read.")]
            #[inline(always)]
            pub fn $name6(&mut self) -> $int_type
            {
                self.$name(Mode::LE)
            }
        }
    };
}
// U8 implementation
// The `u8` readers are written by hand instead of being generated by the macro
// because they sit on hot paths and are called many times, e.g. by JPEG during
// Huffman decoding. Writing them by hand lets us make them leaner; for example,
// `get_u8` is branchless.
411impl<T> ZByteReader<T>
412where
413    T: ZReaderTrait
414{
415    /// Retrieve a byte from the underlying stream
416    /// returning 0 if there are no more bytes available
417    ///
418    /// This means 0 might indicate a bit or an end of stream, but
419    /// this is useful for some scenarios where one needs a byte.
420    ///
421    /// For the panicking one, see [`get_u8_err`]
422    ///
423    /// [`get_u8_err`]: Self::get_u8_err
424    #[inline(always)]
425    pub fn get_u8(&mut self) -> u8 {
426        let byte = *self.stream.get_byte(self.position).unwrap_or(&0);
427
428        self.position += usize::from(self.position < self.len());
429        byte
430    }
431
432    /// Retrieve a byte from the underlying stream
433    /// returning an error if there are no more bytes available
434    ///
435    /// For the non panicking one, see [`get_u8`]
436    ///
437    /// [`get_u8`]: Self::get_u8
438    #[inline(always)]
439    pub fn get_u8_err(&mut self) -> Result<u8, &'static str> {
440        match self.stream.get_byte(self.position) {
441            Some(byte) => {
442                self.position += 1;
443                Ok(*byte)
444            }
445            None => Err(ERROR_MSG)
446        }
447    }
448}
449
// u16, u32 and u64 readers are generated by the `get_single_type!` macro.
//
// For each type the arguments are, in order: the two private helpers
// (zero-on-failure, error-on-failure), then the public `*_be_err`, `*_le_err`,
// `*_be` and `*_le` readers, and finally the integer type itself.

// u16 readers
get_single_type!(
    get_u16_inner_or_default,
    get_u16_inner_or_die,
    get_u16_be_err,
    get_u16_le_err,
    get_u16_be,
    get_u16_le,
    u16
);
// u32 readers
get_single_type!(
    get_u32_inner_or_default,
    get_u32_inner_or_die,
    get_u32_be_err,
    get_u32_le_err,
    get_u32_be,
    get_u32_le,
    u32
);
// u64 readers
get_single_type!(
    get_u64_inner_or_default,
    get_u64_inner_or_die,
    get_u64_be_err,
    get_u64_le_err,
    get_u64_be,
    get_u64_le,
    u64
);