// winter_utils/serde/byte_reader.rs

// Copyright (c) Facebook, Inc. and its affiliates.
//
// This source code is licensed under the MIT license found in the
// LICENSE file in the root directory of this source tree.

6#[cfg(feature = "std")]
7use alloc::string::ToString;
8use alloc::{string::String, vec::Vec};
9#[cfg(feature = "std")]
10use core::cell::{Ref, RefCell};
11#[cfg(feature = "std")]
12use std::io::BufRead;
13
14use super::{Deserializable, DeserializationError};

// BYTE READER TRAIT
// ================================================================================================

19/// Defines how primitive values are to be read from `Self`.
20///
21/// Whenever data is read from the reader using any of the `read_*` functions, the reader advances
22/// to the next unread byte. If the error occurs, the reader is not rolled back to the state prior
23/// to calling any of the function.
24pub trait ByteReader {
25    // REQUIRED METHODS
26    // --------------------------------------------------------------------------------------------
27
28    /// Returns a single byte read from `self`.
29    ///
30    /// # Errors
31    /// Returns a [DeserializationError] error the reader is at EOF.
32    fn read_u8(&mut self) -> Result<u8, DeserializationError>;
33
34    /// Returns the next byte to be read from `self` without advancing the reader to the next byte.
35    ///
36    /// # Errors
37    /// Returns a [DeserializationError] error the reader is at EOF.
38    fn peek_u8(&self) -> Result<u8, DeserializationError>;
39
40    /// Returns a slice of bytes of the specified length read from `self`.
41    ///
42    /// # Errors
43    /// Returns a [DeserializationError] if a slice of the specified length could not be read
44    /// from `self`.
45    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError>;
46
47    /// Returns a byte array of length `N` read from `self`.
48    ///
49    /// # Errors
50    /// Returns a [DeserializationError] if an array of the specified length could not be read
51    /// from `self`.
52    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError>;
53
54    /// Checks if it is possible to read at least `num_bytes` bytes from this ByteReader
55    ///
56    /// # Errors
57    /// Returns an error if, when reading the requested number of bytes, we go beyond the
58    /// the data available in the reader.
59    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError>;
60
61    /// Returns true if there are more bytes left to be read from `self`.
62    fn has_more_bytes(&self) -> bool;
63
64    // PROVIDED METHODS
65    // --------------------------------------------------------------------------------------------
66
67    /// Returns a boolean value read from `self` consuming 1 byte from the reader.
68    ///
69    /// # Errors
70    /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
71    fn read_bool(&mut self) -> Result<bool, DeserializationError> {
72        let byte = self.read_u8()?;
73        match byte {
74            0 => Ok(false),
75            1 => Ok(true),
76            _ => Err(DeserializationError::InvalidValue(format!("{byte} is not a boolean value"))),
77        }
78    }
79
80    /// Returns a u16 value read from `self` in little-endian byte order.
81    ///
82    /// # Errors
83    /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
84    fn read_u16(&mut self) -> Result<u16, DeserializationError> {
85        let bytes = self.read_array::<2>()?;
86        Ok(u16::from_le_bytes(bytes))
87    }
88
89    /// Returns a u32 value read from `self` in little-endian byte order.
90    ///
91    /// # Errors
92    /// Returns a [DeserializationError] if a u32 value could not be read from `self`.
93    fn read_u32(&mut self) -> Result<u32, DeserializationError> {
94        let bytes = self.read_array::<4>()?;
95        Ok(u32::from_le_bytes(bytes))
96    }
97
98    /// Returns a u64 value read from `self` in little-endian byte order.
99    ///
100    /// # Errors
101    /// Returns a [DeserializationError] if a u64 value could not be read from `self`.
102    fn read_u64(&mut self) -> Result<u64, DeserializationError> {
103        let bytes = self.read_array::<8>()?;
104        Ok(u64::from_le_bytes(bytes))
105    }
106
107    /// Returns a u128 value read from `self` in little-endian byte order.
108    ///
109    /// # Errors
110    /// Returns a [DeserializationError] if a u128 value could not be read from `self`.
111    fn read_u128(&mut self) -> Result<u128, DeserializationError> {
112        let bytes = self.read_array::<16>()?;
113        Ok(u128::from_le_bytes(bytes))
114    }
115
116    /// Returns a usize value read from `self` in [vint64](https://docs.rs/vint64/latest/vint64/)
117    /// format.
118    ///
119    /// # Errors
120    /// Returns a [DeserializationError] if:
121    /// * usize value could not be read from `self`.
122    /// * encoded value is greater than `usize` maximum value on a given platform.
123    fn read_usize(&mut self) -> Result<usize, DeserializationError> {
124        let first_byte = self.peek_u8()?;
125        let length = first_byte.trailing_zeros() as usize + 1;
126
127        let result = if length == 9 {
128            // 9-byte special case
129            self.read_u8()?;
130            let value = self.read_array::<8>()?;
131            u64::from_le_bytes(value)
132        } else {
133            let mut encoded = [0u8; 8];
134            let value = self.read_slice(length)?;
135            encoded[..length].copy_from_slice(value);
136            u64::from_le_bytes(encoded) >> length
137        };
138
139        // check if the result value is within acceptable bounds for `usize` on a given platform
140        if result > usize::MAX as u64 {
141            return Err(DeserializationError::InvalidValue(format!(
142                "Encoded value must be less than {}, but {} was provided",
143                usize::MAX,
144                result
145            )));
146        }
147
148        Ok(result as usize)
149    }
150
151    /// Returns a byte vector of the specified length read from `self`.
152    ///
153    /// # Errors
154    /// Returns a [DeserializationError] if a vector of the specified length could not be read
155    /// from `self`.
156    fn read_vec(&mut self, len: usize) -> Result<Vec<u8>, DeserializationError> {
157        let data = self.read_slice(len)?;
158        Ok(data.to_vec())
159    }
160
161    /// Returns a String of the specified length read from `self`.
162    ///
163    /// # Errors
164    /// Returns a [DeserializationError] if a String of the specified length could not be read
165    /// from `self`.
166    fn read_string(&mut self, num_bytes: usize) -> Result<String, DeserializationError> {
167        let data = self.read_vec(num_bytes)?;
168        String::from_utf8(data).map_err(|err| DeserializationError::InvalidValue(format!("{err}")))
169    }
170
171    /// Reads a deserializable value from `self`.
172    ///
173    /// # Errors
174    /// Returns a [DeserializationError] if the specified value could not be read from `self`.
175    fn read<D>(&mut self) -> Result<D, DeserializationError>
176    where
177        Self: Sized,
178        D: Deserializable,
179    {
180        D::read_from(self)
181    }
182
183    /// Reads a sequence of bytes from `self`, attempts to deserialize these bytes into a vector
184    /// with the specified number of `D` elements, and returns the result.
185    ///
186    /// # Errors
187    /// Returns a [DeserializationError] if the specified number elements could not be read from
188    /// `self`.
189    fn read_many<D>(&mut self, num_elements: usize) -> Result<Vec<D>, DeserializationError>
190    where
191        Self: Sized,
192        D: Deserializable,
193    {
194        let mut result = Vec::with_capacity(num_elements);
195        for _ in 0..num_elements {
196            let element = D::read_from(self)?;
197            result.push(element)
198        }
199        Ok(result)
200    }
201}

// STANDARD LIBRARY ADAPTER
// ================================================================================================

/// Exposes any implementation of [std::io::Read] through the [ByteReader] interface.
///
/// This covers, for example, [std::fs::File], standard input, and similar sources.
#[cfg(feature = "std")]
pub struct ReadAdapter<'a> {
    // The underlying input, buffered through a [std::io::BufReader].
    //
    // `peek_u8`, `has_more_bytes` and `check_eor` take `&self`, yet answering them may require
    // pulling more data from the input, which [std::io::BufRead] only permits through `&mut`.
    // Wrapping the reader in a [RefCell] lets those methods fill the buffer anyway, with Rust's
    // borrow rules enforced at runtime instead of at compile time.
    //
    // Note that `read_slice` takes `&mut self` and hands out a slice of `buf`, so the borrow
    // checker already prevents calling any of the `&self` methods above while such a slice is
    // alive.
    //
    // This is a stop-gap until [ByteReader] is aligned with the standard library I/O traits.
    reader: RefCell<std::io::BufReader<&'a mut dyn std::io::Read>>,
    // Bytes taken out of `reader` that have not all been handed out yet. It is used whenever a
    // request cannot be served from `reader`'s own buffer directly.
    buf: Vec<u8>,
    // Index in `buf` of the next unread byte; `buf[pos..]` is the unread portion.
    pos: usize,
    // Set once reading from `reader` yielded no data (or failed). Once `buf` is drained, no more
    // input can arrive, which allows `check_eor` to answer definitively.
    guaranteed_eof: bool,
}

#[cfg(feature = "std")]
impl<'a> ReadAdapter<'a> {
    /// Create a new [ByteReader] adapter for the given implementation of [std::io::Read].
    ///
    /// The input is wrapped in a [std::io::BufReader] with a 256-byte buffer.
    pub fn new(reader: &'a mut dyn std::io::Read) -> Self {
        Self {
            reader: RefCell::new(std::io::BufReader::with_capacity(256, reader)),
            buf: Default::default(),
            pos: 0,
            guaranteed_eof: false,
        }
    }

    /// Returns the unread part of the adapter's own buffer (`buf[pos..]`), possibly empty.
    #[inline(always)]
    fn buffer(&self) -> &[u8] {
        self.buf.get(self.pos..).unwrap_or(&[])
    }

    /// Returns the unread part of the adapter's own buffer, or `None` if it is empty.
    #[inline(always)]
    fn non_empty_buffer(&self) -> Option<&[u8]> {
        self.buf.get(self.pos..).filter(|b| !b.is_empty())
    }

    /// Returns the bytes currently held in the underlying reader's buffer, without reading more
    /// input.
    ///
    /// An empty result does _not_ mean we're at EOF; call [Self::non_empty_reader_buffer_mut]
    /// or [Self::non_empty_reader_buffer] to find out.
    #[inline(always)]
    fn reader_buffer(&self) -> Ref<'_, [u8]> {
        Ref::map(self.reader.borrow(), |r| r.buffer())
    }

    /// Converts an error from the underlying reader into a [DeserializationError].
    ///
    /// [std::io::ErrorKind::UnexpectedEof] maps to [DeserializationError::UnexpectedEOF]. Any
    /// other error becomes [DeserializationError::UnknownError] carrying the error's full message
    /// (previously only the name of its kind was kept).
    fn map_io_error(err: std::io::Error) -> DeserializationError {
        match err.kind() {
            std::io::ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            _ => DeserializationError::UnknownError(err.to_string()),
        }
    }

    /// Returns the underlying reader's buffer, reading from the underlying input first if that
    /// buffer is empty.
    ///
    /// # Errors
    /// Returns [DeserializationError::UnexpectedEOF] if the input is exhausted, and records this
    /// in `guaranteed_eof`. Returns another error if reading from the input fails.
    fn non_empty_reader_buffer_mut(&mut self) -> Result<&[u8], DeserializationError> {
        let buf = self.reader.get_mut().fill_buf().map_err(Self::map_io_error)?;
        if buf.is_empty() {
            self.guaranteed_eof = true;
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(buf)
        }
    }

    /// Same as [Self::non_empty_reader_buffer_mut], but borrows the reader dynamically through
    /// its [RefCell]. This lets it be called from `&self` methods such as `peek_u8`.
    ///
    /// Unlike the `_mut` variant, it cannot record that EOF was reached. The dynamic borrow check
    /// adds overhead, so prefer [Self::non_empty_reader_buffer_mut] when `&mut self` is at hand.
    fn non_empty_reader_buffer(&self) -> Result<Ref<'_, [u8]>, DeserializationError> {
        let mut reader = self.reader.borrow_mut();
        let buf = reader.fill_buf().map_err(Self::map_io_error)?;
        if buf.is_empty() {
            Err(DeserializationError::UnexpectedEOF)
        } else {
            // Release the mutable borrow, then re-borrow immutably for the caller.
            drop(reader);
            Ok(self.reader_buffer())
        }
    }

    /// Returns true if `buf`'s capacity minus the number of *unread* bytes in it is at least `n`.
    ///
    /// In other words, it answers whether `n` more bytes would fit without reallocating once the
    /// consumed prefix `buf[..pos]` has been discarded.
    ///
    /// NOTE(review): because the consumed prefix is not subtracted, `read_slice` (which compacts
    /// only when this returns `false`) rarely compacts, and `buf` can keep growing. Confirm
    /// whether `capacity() - buf.len()` was intended.
    #[inline]
    fn has_remaining_capacity(&self, n: usize) -> bool {
        let remaining = self.buf.capacity() - self.buffer().len();
        remaining >= n
    }

    /// Takes the next byte from the input.
    ///
    /// # Errors
    /// Returns an error if no byte could be read. On any error, `guaranteed_eof` is set.
    fn pop(&mut self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.non_empty_buffer().map(|b| b[0]) {
            self.pos += 1;
            return Ok(byte);
        }
        let result = self.non_empty_reader_buffer_mut().map(|b| b[0]);
        if result.is_ok() {
            self.reader.get_mut().consume(1);
        } else {
            self.guaranteed_eof = true;
        }
        result
    }

    /// Removes the already-consumed prefix `buf[..pos]`, moving any unread bytes to the front of
    /// `buf`, and resets `pos` to 0.
    fn discard_consumed(&mut self) {
        // `min` keeps the drain in bounds even if `pos` were ever past the end of `buf`.
        let consumed = self.pos.min(self.buf.len());
        self.buf.drain(..consumed);
        self.pos = 0;
    }

    /// Takes the next `N` bytes from the input as an array.
    ///
    /// The bytes are taken from one of three places:
    /// * the adapter's own buffer, when it already holds `N` unread bytes;
    /// * the underlying reader's buffer, when the adapter's buffer is empty and the reader's
    ///   buffer already holds `N` bytes;
    /// * otherwise, from the adapter's buffer after at least `N` bytes have been accumulated in it.
    ///
    /// # Errors
    /// Returns [DeserializationError::UnexpectedEOF] if the input ends before `N` bytes could be
    /// read, or another error if reading from the underlying input fails.
    fn read_exact<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let mut output = [0u8; N];

        // Fast path: everything we need is already buffered locally.
        if let Some(bytes) = self.buffer().get(..N) {
            output.copy_from_slice(bytes);
            self.pos += N;
            return Ok(output);
        }

        // Fast path: nothing is buffered locally, and the reader's buffer already holds `N` bytes,
        // so copy straight out of it without touching `buf`.
        if self.buffer().is_empty() {
            let reader_buf = self.non_empty_reader_buffer_mut()?;
            if let Some(bytes) = reader_buf.get(..N) {
                output.copy_from_slice(bytes);
                self.reader.get_mut().consume(N);
                return Ok(output);
            }
            // A short reader buffer is not EOF: `fill_buf` only returns what is currently
            // buffered (or what a single `read` produced). Fall through and keep reading.
        }

        // Slow path: the bytes are split across, or not yet in, our buffers. Drop the consumed
        // prefix first so `buf` does not grow without bound, then accumulate `N` unread bytes.
        self.discard_consumed();
        self.buffer_at_least(N)?;
        output.copy_from_slice(&self.buffer()[..N]);
        self.pos += N;
        Ok(output)
    }

    /// Reads from the underlying reader into `buf` until at least `count` unread bytes are
    /// buffered locally, i.e. until `self.buffer().len() >= count`.
    ///
    /// Each iteration moves the reader's entire current buffer into `buf`, so more than `count`
    /// bytes may end up buffered.
    ///
    /// # Errors
    /// Returns [DeserializationError::UnexpectedEOF] if the input is exhausted first, or another
    /// error if reading from the underlying input fails.
    fn buffer_at_least(&mut self, count: usize) -> Result<(), DeserializationError> {
        // `count` is the *total* number of unread bytes required, so re-check the buffered length
        // on every iteration. Subtracting each chunk from `count` while also comparing it against
        // the growing buffer would stop early with too few bytes.
        while self.buffer().len() < count {
            // Errors (and records EOF) if the underlying reader is exhausted.
            self.non_empty_reader_buffer_mut()?;

            // Move the reader's buffered bytes into `buf`. `reader` and `buf` are disjoint
            // fields, so both can be borrowed at once.
            let reader = self.reader.get_mut();
            let chunk = reader.buffer();
            let consumed = chunk.len();
            self.buf.extend_from_slice(chunk);
            reader.consume(consumed);
        }
        Ok(())
    }
}

/// Regression tests for inputs that deliver data in small pieces (pipes, sockets, etc.).
#[cfg(all(test, feature = "std"))]
mod read_adapter_short_read_tests {
    use super::*;

    /// A reader that hands out at most `chunk` bytes per `read` call.
    struct Trickle<'a> {
        data: &'a [u8],
        chunk: usize,
    }

    impl std::io::Read for Trickle<'_> {
        fn read(&mut self, out: &mut [u8]) -> std::io::Result<usize> {
            let n = self.chunk.min(out.len()).min(self.data.len());
            out[..n].copy_from_slice(&self.data[..n]);
            self.data = &self.data[n..];
            Ok(n)
        }
    }

    #[test]
    fn reads_span_multiple_short_fills() {
        let data: Vec<u8> = (1..=20).collect();
        let mut input = Trickle { data: &data, chunk: 3 };
        let mut adapter = ReadAdapter::new(&mut input);
        assert_eq!(adapter.read_u64(), Ok(u64::from_le_bytes([1, 2, 3, 4, 5, 6, 7, 8])));
        assert_eq!(adapter.read_slice(5), Ok([9, 10, 11, 12, 13].as_slice()));
        assert_eq!(adapter.read_u16(), Ok(u16::from_le_bytes([14, 15])));
        assert_eq!(adapter.read_u32(), Ok(u32::from_le_bytes([16, 17, 18, 19])));
        assert_eq!(adapter.read_u8(), Ok(20));
        assert!(!adapter.has_more_bytes());
    }

    #[test]
    fn partially_buffered_read_of_truncated_input_is_eof() {
        let data = [1u8, 2, 3, 4, 5];
        let mut input = Trickle { data: &data, chunk: 2 };
        let mut adapter = ReadAdapter::new(&mut input);
        assert_eq!(adapter.read_slice(1), Ok([1].as_slice()));
        assert_eq!(adapter.read_u64(), Err(DeserializationError::UnexpectedEOF));
    }
}

#[cfg(feature = "std")]
impl ByteReader for ReadAdapter<'_> {
    #[inline(always)]
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.pop()
    }

    /// Returns the next byte without consuming it.
    ///
    /// NOTE: If no bytes are buffered yet, this has to read from the underlying reader. That
    /// requires a mutable borrow, which is obtained dynamically via [RefCell]. Slices returned by
    /// [Self::read_slice] borrow the adapter mutably, so the borrow checker already prevents
    /// calling this method while such a slice is alive.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.buffer().first() {
            return Ok(*byte);
        }
        self.non_empty_reader_buffer().map(|b| b[0])
    }

    /// Returns the next `len` bytes as a slice of the adapter's internal buffer.
    ///
    /// The bytes are always moved into the adapter's own buffer first, so the returned slice
    /// never points into the underlying reader's buffer.
    ///
    /// # Errors
    /// Returns [DeserializationError::UnexpectedEOF] if the input ends before `len` bytes could
    /// be read, or another error if reading from the underlying input fails.
    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        // Edge case
        if len == 0 {
            return Ok(&[]);
        }

        // Once a sizeable prefix of `buf` has been consumed and `has_remaining_capacity` reports
        // that `len` more bytes would not fit, move the unread bytes to the front of `buf`. The
        // space taken by consumed bytes can then be reused before the vector reallocates.
        if self.pos >= 16 && !self.has_remaining_capacity(len) {
            // `min` keeps the drain in bounds even if `pos` were ever past the end of `buf`.
            let consumed = self.pos.min(self.buf.len());
            self.buf.drain(..consumed);
            self.pos = 0;
        }

        // Fill the buffer so we have at least `len` unread bytes; errors if we hit EOF first.
        self.buffer_at_least(len)?;

        // Checked rather than indexed, so a short buffer is reported as EOF instead of panicking.
        if self.buffer().len() < len {
            return Err(DeserializationError::UnexpectedEOF);
        }
        let start = self.pos;
        self.pos += len;
        Ok(&self.buf[start..start + len])
    }

    #[inline]
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        if N == 0 {
            return Ok([0; N]);
        }
        self.read_exact()
    }

    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        // Do we have sufficient data in the local buffer?
        let buffer_len = self.buffer().len();
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // What about if we include what is in the local buffer and the reader's buffer?
        let reader_buffer_len = self.non_empty_reader_buffer().map(|b| b.len())?;
        let buffer_len = buffer_len + reader_buffer_len;
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // We have no more input, thus can't fulfill a request of `num_bytes`
        if self.guaranteed_eof {
            return Err(DeserializationError::UnexpectedEOF);
        }

        // Because this function is read-only, we must optimistically assume we can read
        // `num_bytes` from the input, and fail later if that does not hold. We know we're not at
        // EOF yet, but that's all we can say without buffering more from the reader. Doing so
        // would require `&mut self`, and returning an optimistic answer here is not a memory
        // safety problem, so this is the better tradeoff.
        Ok(())
    }

    #[inline]
    fn has_more_bytes(&self) -> bool {
        !self.buffer().is_empty() || self.non_empty_reader_buffer().is_ok()
    }
}

// CURSOR
// ================================================================================================

#[cfg(feature = "std")]
macro_rules! cursor_remaining_buf {
    // Expands to the not-yet-read tail of a `std::io::Cursor`'s underlying bytes. A position past
    // the end is clamped to the end, yielding an empty slice rather than panicking.
    ($cursor:ident) => {{
        let whole = $cursor.get_ref().as_ref();
        let offset = core::cmp::min($cursor.position(), whole.len() as u64) as usize;
        &whole[offset..]
    }};
}

#[cfg(feature = "std")]
impl<T: AsRef<[u8]>> ByteReader for std::io::Cursor<T> {
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        let next = self.peek_u8()?;
        self.set_position(self.position() + 1);
        Ok(next)
    }

    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match cursor_remaining_buf!(self) {
            [first, ..] => Ok(*first),
            [] => Err(DeserializationError::UnexpectedEOF),
        }
    }

    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        let start = self.position();
        let total = self.get_ref().as_ref().len() as u64;
        let wanted = len as u64;
        // Not enough bytes left between the current position and the end of the data.
        if wanted > total.saturating_sub(start) {
            return Err(DeserializationError::UnexpectedEOF);
        }
        self.set_position(start + wanted);
        let from = start.min(total) as usize;
        Ok(&self.get_ref().as_ref()[from..from + len])
    }

    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let bytes = self.read_slice(N)?;
        let mut array = [0u8; N];
        array.copy_from_slice(bytes);
        Ok(array)
    }

    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        let available = cursor_remaining_buf!(self).len();
        if available < num_bytes {
            return Err(DeserializationError::UnexpectedEOF);
        }
        Ok(())
    }

    #[inline]
    fn has_more_bytes(&self) -> bool {
        !cursor_remaining_buf!(self).is_empty()
    }
}

// SLICE READER
// ================================================================================================

/// Implements [ByteReader] trait for a slice of bytes.
///
/// NOTE: If you are building with the `std` feature, you should probably prefer
/// [std::io::Cursor] instead. However, [SliceReader] is still useful in no-std environments until
/// the `core_io_borrowed_buf` feature is stabilized.
#[derive(Debug)]
pub struct SliceReader<'a> {
    // The bytes being read.
    source: &'a [u8],
    // Index in `source` of the next byte to read.
    pos: usize,
}

640impl<'a> SliceReader<'a> {
641    /// Creates a new slice reader from the specified slice.
642    pub fn new(source: &'a [u8]) -> Self {
643        SliceReader { source, pos: 0 }
644    }
645}
646
647impl ByteReader for SliceReader<'_> {
648    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
649        self.check_eor(1)?;
650        let result = self.source[self.pos];
651        self.pos += 1;
652        Ok(result)
653    }
654
655    fn peek_u8(&self) -> Result<u8, DeserializationError> {
656        self.check_eor(1)?;
657        Ok(self.source[self.pos])
658    }
659
660    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
661        self.check_eor(len)?;
662        let result = &self.source[self.pos..self.pos + len];
663        self.pos += len;
664        Ok(result)
665    }
666
667    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
668        self.check_eor(N)?;
669        let mut result = [0_u8; N];
670        result.copy_from_slice(&self.source[self.pos..self.pos + N]);
671        self.pos += N;
672        Ok(result)
673    }
674
675    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
676        if self.pos + num_bytes > self.source.len() {
677            return Err(DeserializationError::UnexpectedEOF);
678        }
679        Ok(())
680    }
681
682    fn has_more_bytes(&self) -> bool {
683        self.pos < self.source.len()
684    }
685}
686
#[cfg(all(test, feature = "std"))]
mod tests {
    use std::io::Cursor;

    use super::*;
    use crate::ByteWriter;

    /// A file in the system temp directory that is removed when dropped.
    ///
    /// The file name is prefixed with the current process id so that concurrent test runs do not
    /// overwrite each other's files.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn new(name: &str) -> Self {
            Self(std::env::temp_dir().join(format!("{}_{name}", std::process::id())))
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a leftover temp file is not worth failing a test over.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn read_adapter_empty() -> Result<(), DeserializationError> {
        let mut reader = std::io::empty();
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(1), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array::<1>(), Err(DeserializationError::UnexpectedEOF));
        Ok(())
    }

    #[test]
    fn read_adapter_passthrough() -> Result<(), DeserializationError> {
        let mut reader = std::io::repeat(0b101);
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Ok(()));
        assert_eq!(adapter.peek_u8(), Ok(0b101));
        assert_eq!(adapter.read_u8(), Ok(0b101));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(4), Ok([0b101, 0b101, 0b101, 0b101].as_slice()));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array(), Ok([0b101, 0b101]));
        Ok(())
    }

    #[test]
    fn read_adapter_exact() {
        const VALUE: usize = 2048;
        let mut reader = Cursor::new(VALUE.to_le_bytes());
        let mut adapter = ReadAdapter::new(&mut reader);
        assert_eq!(usize::from_le_bytes(adapter.read_array().unwrap()), VALUE);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
    }

    #[test]
    fn read_adapter_roundtrip() {
        const VALUE: usize = 2048;

        // Write VALUE to storage
        let mut cursor = Cursor::new([0; core::mem::size_of::<usize>()]);
        cursor.write_usize(VALUE);

        // Read VALUE from storage
        cursor.set_position(0);
        let mut adapter = ReadAdapter::new(&mut cursor);

        assert_eq!(adapter.read_usize(), Ok(VALUE));
    }

    #[test]
    fn read_adapter_for_file() {
        use std::fs::File;

        let tmp = TempFile::new("read_adapter_for_file.bin");

        // Encode some data to a buffer, then write that buffer to a file
        {
            let mut buf = Vec::<u8>::with_capacity(256);
            buf.write_bytes(b"MAGIC\0");
            buf.write_bool(true);
            buf.write_u32(0xbeef);
            buf.write_usize(0xfeed);
            buf.write_u16(0x5);

            std::fs::write(&tmp.0, &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items. Locals drop in reverse order, so
        // the file is closed before `tmp` removes it.
        let mut file = File::open(&tmp.0).unwrap();
        let mut reader = ReadAdapter::new(&mut file);
        assert_eq!(reader.peek_u8().unwrap(), b'M');
        assert_eq!(reader.read_slice(6).unwrap(), b"MAGIC\0");
        assert!(reader.read_bool().unwrap());
        assert_eq!(reader.read_u32().unwrap(), 0xbeef);
        assert_eq!(reader.read_usize().unwrap(), 0xfeed);
        assert_eq!(reader.read_u16().unwrap(), 0x5);
        assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
    }

    #[test]
    fn read_adapter_issue_383() {
        const STR_BYTES: &[u8] = b"just a string";

        use std::fs::File;

        let tmp = TempFile::new("issue_383.bin");

        // Encode a u128, zero-pad up to offset 512, then append a byte string and a u32
        {
            let mut buf = Vec::<u8>::with_capacity(1024);
            buf.write_u128(2 * u64::MAX as u128);
            buf.resize(512, 0);
            buf.write_bytes(STR_BYTES);
            buf.write_u32(0xbeef);

            std::fs::write(&tmp.0, &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items
        let mut file = File::open(&tmp.0).unwrap();
        let mut reader = ReadAdapter::new(&mut file);
        assert_eq!(reader.read_u128().unwrap(), 2 * u64::MAX as u128);
        assert_eq!(reader.buf.len(), 0);
        assert_eq!(reader.pos, 0);
        // Read to offset 512 (we're 16 bytes into the underlying file, i.e. offset of 496)
        reader.read_slice(496).unwrap();
        assert_eq!(reader.buf.len(), 496);
        assert_eq!(reader.pos, 496);
        // The byte string is 13 bytes, followed by 4 bytes containing the trailing u32 value.
        // We expect that the underlying reader will buffer the remaining bytes of the file when
        // reading STR_BYTES, so the total size of our adapter's buffer should be
        // 496 + STR_BYTES.len() + size_of::<u32>();
        assert_eq!(reader.read_slice(STR_BYTES.len()).unwrap(), STR_BYTES);
        assert_eq!(reader.buf.len(), 496 + STR_BYTES.len() + core::mem::size_of::<u32>());
        // We haven't read the u32 yet
        assert_eq!(reader.pos, 509);
        assert_eq!(reader.read_u32().unwrap(), 0xbeef);
        // Now we have
        assert_eq!(reader.pos, 513);
        assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
    }
}