winter_utils/serde/byte_reader.rs
1// Copyright (c) Facebook, Inc. and its affiliates.
2//
3// This source code is licensed under the MIT license found in the
4// LICENSE file in the root directory of this source tree.
5
6#[cfg(feature = "std")]
7use alloc::string::ToString;
8use alloc::{string::String, vec::Vec};
9#[cfg(feature = "std")]
10use core::cell::{Ref, RefCell};
11#[cfg(feature = "std")]
12use std::io::BufRead;
13
14use super::{Deserializable, DeserializationError};
15
16// BYTE READER TRAIT
17// ================================================================================================
18
19/// Defines how primitive values are to be read from `Self`.
20///
21/// Whenever data is read from the reader using any of the `read_*` functions, the reader advances
22/// to the next unread byte. If the error occurs, the reader is not rolled back to the state prior
23/// to calling any of the function.
24pub trait ByteReader {
25 // REQUIRED METHODS
26 // --------------------------------------------------------------------------------------------
27
28 /// Returns a single byte read from `self`.
29 ///
30 /// # Errors
31 /// Returns a [DeserializationError] error the reader is at EOF.
32 fn read_u8(&mut self) -> Result<u8, DeserializationError>;
33
34 /// Returns the next byte to be read from `self` without advancing the reader to the next byte.
35 ///
36 /// # Errors
37 /// Returns a [DeserializationError] error the reader is at EOF.
38 fn peek_u8(&self) -> Result<u8, DeserializationError>;
39
40 /// Returns a slice of bytes of the specified length read from `self`.
41 ///
42 /// # Errors
43 /// Returns a [DeserializationError] if a slice of the specified length could not be read
44 /// from `self`.
45 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError>;
46
47 /// Returns a byte array of length `N` read from `self`.
48 ///
49 /// # Errors
50 /// Returns a [DeserializationError] if an array of the specified length could not be read
51 /// from `self`.
52 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError>;
53
54 /// Checks if it is possible to read at least `num_bytes` bytes from this ByteReader
55 ///
56 /// # Errors
57 /// Returns an error if, when reading the requested number of bytes, we go beyond the
58 /// the data available in the reader.
59 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError>;
60
61 /// Returns true if there are more bytes left to be read from `self`.
62 fn has_more_bytes(&self) -> bool;
63
64 // PROVIDED METHODS
65 // --------------------------------------------------------------------------------------------
66
67 /// Returns a boolean value read from `self` consuming 1 byte from the reader.
68 ///
69 /// # Errors
70 /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
71 fn read_bool(&mut self) -> Result<bool, DeserializationError> {
72 let byte = self.read_u8()?;
73 match byte {
74 0 => Ok(false),
75 1 => Ok(true),
76 _ => Err(DeserializationError::InvalidValue(format!("{byte} is not a boolean value"))),
77 }
78 }
79
80 /// Returns a u16 value read from `self` in little-endian byte order.
81 ///
82 /// # Errors
83 /// Returns a [DeserializationError] if a u16 value could not be read from `self`.
84 fn read_u16(&mut self) -> Result<u16, DeserializationError> {
85 let bytes = self.read_array::<2>()?;
86 Ok(u16::from_le_bytes(bytes))
87 }
88
89 /// Returns a u32 value read from `self` in little-endian byte order.
90 ///
91 /// # Errors
92 /// Returns a [DeserializationError] if a u32 value could not be read from `self`.
93 fn read_u32(&mut self) -> Result<u32, DeserializationError> {
94 let bytes = self.read_array::<4>()?;
95 Ok(u32::from_le_bytes(bytes))
96 }
97
98 /// Returns a u64 value read from `self` in little-endian byte order.
99 ///
100 /// # Errors
101 /// Returns a [DeserializationError] if a u64 value could not be read from `self`.
102 fn read_u64(&mut self) -> Result<u64, DeserializationError> {
103 let bytes = self.read_array::<8>()?;
104 Ok(u64::from_le_bytes(bytes))
105 }
106
107 /// Returns a u128 value read from `self` in little-endian byte order.
108 ///
109 /// # Errors
110 /// Returns a [DeserializationError] if a u128 value could not be read from `self`.
111 fn read_u128(&mut self) -> Result<u128, DeserializationError> {
112 let bytes = self.read_array::<16>()?;
113 Ok(u128::from_le_bytes(bytes))
114 }
115
116 /// Returns a usize value read from `self` in [vint64](https://docs.rs/vint64/latest/vint64/)
117 /// format.
118 ///
119 /// # Errors
120 /// Returns a [DeserializationError] if:
121 /// * usize value could not be read from `self`.
122 /// * encoded value is greater than `usize` maximum value on a given platform.
123 fn read_usize(&mut self) -> Result<usize, DeserializationError> {
124 let first_byte = self.peek_u8()?;
125 let length = first_byte.trailing_zeros() as usize + 1;
126
127 let result = if length == 9 {
128 // 9-byte special case
129 self.read_u8()?;
130 let value = self.read_array::<8>()?;
131 u64::from_le_bytes(value)
132 } else {
133 let mut encoded = [0u8; 8];
134 let value = self.read_slice(length)?;
135 encoded[..length].copy_from_slice(value);
136 u64::from_le_bytes(encoded) >> length
137 };
138
139 // check if the result value is within acceptable bounds for `usize` on a given platform
140 if result > usize::MAX as u64 {
141 return Err(DeserializationError::InvalidValue(format!(
142 "Encoded value must be less than {}, but {} was provided",
143 usize::MAX,
144 result
145 )));
146 }
147
148 Ok(result as usize)
149 }
150
151 /// Returns a byte vector of the specified length read from `self`.
152 ///
153 /// # Errors
154 /// Returns a [DeserializationError] if a vector of the specified length could not be read
155 /// from `self`.
156 fn read_vec(&mut self, len: usize) -> Result<Vec<u8>, DeserializationError> {
157 let data = self.read_slice(len)?;
158 Ok(data.to_vec())
159 }
160
161 /// Returns a String of the specified length read from `self`.
162 ///
163 /// # Errors
164 /// Returns a [DeserializationError] if a String of the specified length could not be read
165 /// from `self`.
166 fn read_string(&mut self, num_bytes: usize) -> Result<String, DeserializationError> {
167 let data = self.read_vec(num_bytes)?;
168 String::from_utf8(data).map_err(|err| DeserializationError::InvalidValue(format!("{err}")))
169 }
170
171 /// Reads a deserializable value from `self`.
172 ///
173 /// # Errors
174 /// Returns a [DeserializationError] if the specified value could not be read from `self`.
175 fn read<D>(&mut self) -> Result<D, DeserializationError>
176 where
177 Self: Sized,
178 D: Deserializable,
179 {
180 D::read_from(self)
181 }
182
183 /// Reads a sequence of bytes from `self`, attempts to deserialize these bytes into a vector
184 /// with the specified number of `D` elements, and returns the result.
185 ///
186 /// # Errors
187 /// Returns a [DeserializationError] if the specified number elements could not be read from
188 /// `self`.
189 fn read_many<D>(&mut self, num_elements: usize) -> Result<Vec<D>, DeserializationError>
190 where
191 Self: Sized,
192 D: Deserializable,
193 {
194 let mut result = Vec::with_capacity(num_elements);
195 for _ in 0..num_elements {
196 let element = D::read_from(self)?;
197 result.push(element)
198 }
199 Ok(result)
200 }
201}
202
203// STANDARD LIBRARY ADAPTER
204// ================================================================================================
205
/// An adapter of [ByteReader] to any type that implements [std::io::Read].
///
/// In particular, this covers things like [std::fs::File], standard input, etc.
#[cfg(feature = "std")]
pub struct ReadAdapter<'a> {
    // The underlying input, buffered by a [std::io::BufReader].
    //
    // [ByteReader] exposes `peek_u8`, `has_more_bytes` and `check_eor` through `&self`, while the
    // equivalent [std::io::BufRead] operations (such as `fill_buf`) need `&mut self` because they
    // may have to pull more data from the input. Wrapping the reader in a [RefCell] lets those
    // `&self` methods borrow it mutably, with Rust's borrowing rules checked at runtime instead
    // of at compile time; a conflicting borrow panics rather than being rejected by the compiler.
    //
    // This is a stopgap: the proper fix is to align the [ByteReader] trait with the standard
    // library I/O traits.
    reader: RefCell<std::io::BufReader<&'a mut dyn std::io::Read>>,
    // Staging area for bytes copied out of `reader`, used when a request cannot be served from
    // the reader's own buffer. Requests are served directly from `reader` whenever possible.
    buf: Vec<u8>,
    // Offset in `buf` of the next unread byte; bytes before it have already been consumed.
    pos: usize,
    // Set once a fill of `reader` came back empty: the input is exhausted, so nothing beyond what
    // remains in `buf` can be read. Lets methods such as `check_eor` answer more accurately.
    guaranteed_eof: bool,
}
243
#[cfg(feature = "std")]
impl<'a> ReadAdapter<'a> {
    /// Create a new [ByteReader] adapter for the given implementation of [std::io::Read]
    pub fn new(reader: &'a mut dyn std::io::Read) -> Self {
        Self {
            reader: RefCell::new(std::io::BufReader::with_capacity(256, reader)),
            buf: Default::default(),
            pos: 0,
            guaranteed_eof: false,
        }
    }

    /// Get the unread part of the internal adapter buffer as a (possibly empty) slice of bytes.
    ///
    /// `pos` may point past the end of `buf` after the buffer has been released by
    /// [Self::release_drained_buffer]; that state is reported as an empty buffer.
    #[inline(always)]
    fn buffer(&self) -> &[u8] {
        self.buf.get(self.pos..).unwrap_or(&[])
    }

    /// Get the internal adapter buffer as a slice of bytes, or `None` if the buffer is empty
    #[inline(always)]
    fn non_empty_buffer(&self) -> Option<&[u8]> {
        self.buf.get(self.pos..).filter(|b| !b.is_empty())
    }

    /// Return the current reader buffer as a (possibly empty) slice of bytes.
    ///
    /// This buffer being empty _does not_ mean we're at EOF, you must call
    /// [Self::non_empty_reader_buffer_mut] or [Self::non_empty_reader_buffer] first.
    #[inline(always)]
    fn reader_buffer(&self) -> Ref<'_, [u8]> {
        Ref::map(self.reader.borrow(), |r| r.buffer())
    }

    /// Converts an error raised while filling the reader buffer into a [DeserializationError].
    ///
    /// The whole [std::io::Error] is rendered (not just its [std::io::ErrorKind]) so that any
    /// message attached by the underlying reader is preserved.
    fn map_io_error(err: std::io::Error) -> DeserializationError {
        match err.kind() {
            std::io::ErrorKind::UnexpectedEof => DeserializationError::UnexpectedEOF,
            _ => DeserializationError::UnknownError(err.to_string()),
        }
    }

    /// Return the current reader buffer, reading from the underlying reader
    /// if the buffer is empty.
    ///
    /// Returns `Ok` only if the buffer is non-empty, and no errors occurred
    /// while filling it (if filling was needed). An empty fill marks the
    /// input as exhausted (`guaranteed_eof`).
    fn non_empty_reader_buffer_mut(&mut self) -> Result<&[u8], DeserializationError> {
        let buf = self.reader.get_mut().fill_buf().map_err(Self::map_io_error)?;
        if buf.is_empty() {
            self.guaranteed_eof = true;
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(buf)
        }
    }

    /// Same as [Self::non_empty_reader_buffer_mut], but with dynamically-enforced
    /// borrow check rules so that it can be called in functions like `peek_u8`.
    ///
    /// This comes with overhead for the dynamic checks, so you should prefer
    /// to call [Self::non_empty_reader_buffer_mut] if you already have a mutable
    /// reference to `self`. Being `&self`, it cannot record `guaranteed_eof`.
    fn non_empty_reader_buffer(&self) -> Result<Ref<'_, [u8]>, DeserializationError> {
        {
            let mut reader = self.reader.borrow_mut();
            let buf = reader.fill_buf().map_err(Self::map_io_error)?;
            if buf.is_empty() {
                return Err(DeserializationError::UnexpectedEOF);
            }
        }
        // The mutable borrow has been released above; re-borrow immutably.
        Ok(self.reader_buffer())
    }

    /// Returns true if there is sufficient capacity remaining in `buf` to hold `n` bytes.
    ///
    /// NOTE(review): this subtracts only the *unread* length from the capacity, so the
    /// already-consumed bytes before `pos` are counted as free space, i.e. it answers whether
    /// `n` more bytes would fit after compaction. `read_slice` uses it to decide when to compact;
    /// confirm this is the intended heuristic before changing it.
    #[inline]
    fn has_remaining_capacity(&self, n: usize) -> bool {
        let remaining = self.buf.capacity() - self.buffer().len();
        remaining >= n
    }

    /// Empties `buf` once all of its bytes have been consumed so that its capacity can be reused.
    ///
    /// `pos` is intentionally left unchanged; [Self::buffer_at_least] rewinds it before any new
    /// bytes are appended, so the stale offset can never cause freshly buffered bytes to be
    /// skipped.
    fn release_drained_buffer(&mut self) {
        if self.pos > 0 && self.pos >= self.buf.len() {
            self.buf.clear();
        }
    }

    /// Takes the next byte from the input, returning an error if the operation fails
    fn pop(&mut self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.non_empty_buffer().map(|b| b[0]) {
            self.pos += 1;
            self.release_drained_buffer();
            return Ok(byte);
        }
        // `guaranteed_eof` is recorded by `non_empty_reader_buffer_mut` when the input is really
        // exhausted; other I/O errors must not mark the input as ended.
        let byte = self.non_empty_reader_buffer_mut()?[0];
        self.reader.get_mut().consume(1);
        Ok(byte)
    }

    /// Takes the next `N` bytes from the input as an array, returning an error if the operation
    /// fails
    fn read_exact<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let mut output = [0u8; N];
        let buffered = self.buffer().len();

        if buffered >= N {
            // Everything we need is already in `buf`.
            output.copy_from_slice(&self.buffer()[..N]);
            self.pos += N;
        } else {
            // Make sure the reader has buffered data; this fails with `UnexpectedEOF` only when
            // the input is exhausted.
            self.non_empty_reader_buffer_mut()?;

            let needed = N - buffered;
            // Borrow the reader through its field so that `self.buf` can be read concurrently.
            let reader = self.reader.get_mut();
            let reader_buf = reader.buffer();
            if reader_buf.len() >= needed {
                // `buf` (possibly empty) plus the reader's buffer hold enough bytes: copy
                // straight out of both without staging anything in `buf`.
                let local = self.buf.get(self.pos..).unwrap_or(&[]);
                output[..buffered].copy_from_slice(local);
                output[buffered..].copy_from_slice(&reader_buf[..needed]);
                reader.consume(needed);
                self.pos += buffered;
            } else {
                // A short chunk from the reader does not imply EOF (`std::io::Read` may return
                // fewer bytes than requested), so keep pulling into `buf` until `N` unread bytes
                // are available or the input really ends.
                self.buffer_at_least(N)?;
                output.copy_from_slice(&self.buffer()[..N]);
                self.pos += N;
            }
        }

        self.release_drained_buffer();
        Ok(output)
    }

    /// Ensures `self.buf` holds at least `count` unread bytes, copying data out of the underlying
    /// reader as needed.
    ///
    /// This should only be called when we can't read from the reader directly.
    ///
    /// # Errors
    /// Returns [DeserializationError::UnexpectedEOF] if the input ends before `count` unread
    /// bytes are available, or another [DeserializationError] if the underlying reader fails.
    fn buffer_at_least(&mut self, count: usize) -> Result<(), DeserializationError> {
        // After `release_drained_buffer`, `pos` can point past the end of the (now empty) `buf`.
        // Rewind before appending, otherwise the first `pos` new bytes would be treated as
        // already consumed.
        if self.pos > self.buf.len() {
            self.buf.clear();
            self.pos = 0;
        }

        // `count` is the total number of unread bytes required, so compare against the whole
        // unread length on every iteration rather than decrementing `count`.
        while self.buffer().len() < count {
            // This operation will return an error if the underlying reader hits EOF
            self.non_empty_reader_buffer_mut()?;

            // Move everything the reader has buffered into `self.buf`.
            //
            // NOTE: We borrow the reader through its field (not via a `&self` method) so that
            // `self.buf` can be extended while the reader's buffer is borrowed.
            let reader = self.reader.get_mut();
            let chunk = reader.buffer();
            let chunk_len = chunk.len();
            self.buf.extend_from_slice(chunk);
            reader.consume(chunk_len);
        }
        Ok(())
    }
}
458
#[cfg(feature = "std")]
impl ByteReader for ReadAdapter<'_> {
    #[inline(always)]
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        self.pop()
    }

    /// NOTE: If we happen to not have any bytes buffered yet when this is called, then we will be
    /// forced to try and read from the underlying reader. This requires a mutable reference, which
    /// is obtained dynamically via [RefCell].
    ///
    /// Slices returned by [Self::read_slice] borrow the adapter mutably for as long as they are
    /// alive, so the compiler already prevents calling this method while holding one.
    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        if let Some(byte) = self.buffer().first() {
            return Ok(*byte);
        }
        self.non_empty_reader_buffer().map(|b| b[0])
    }

    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        // Edge case
        if len == 0 {
            return Ok(&[]);
        }

        // If we have unused buffer, and the consumed portion is
        // large enough, we will move the unused portion of the buffer
        // to the start, freeing up bytes at the end for more reads
        // before forcing a reallocation
        let should_optimize_storage = self.pos >= 16 && !self.has_remaining_capacity(len);
        if should_optimize_storage {
            // `pos` can exceed `buf.len()` once the buffer has been released, so clamp it.
            let consumed = self.pos.min(self.buf.len());
            self.buf.drain(..consumed);
            self.pos = 0;
        }

        // Fill the buffer so we have at least `len` bytes available,
        // this will return an error if we hit EOF first
        self.buffer_at_least(len)?;

        // Slice with `get` so that a short buffer is reported as EOF rather than panicking.
        let start = self.pos;
        let end = start.saturating_add(len);
        let slice = self.buf.get(start..end).ok_or(DeserializationError::UnexpectedEOF)?;
        self.pos = end;
        Ok(slice)
    }

    #[inline]
    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        if N == 0 {
            return Ok([0; N]);
        }
        self.read_exact()
    }

    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        // Do we have sufficient data in the local buffer?
        let buffer_len = self.buffer().len();
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // What about if we include what is in the local buffer and the reader's buffer?
        let reader_buffer_len = self.non_empty_reader_buffer().map(|b| b.len())?;
        let buffer_len = buffer_len + reader_buffer_len;
        if buffer_len >= num_bytes {
            return Ok(());
        }

        // We have no more input, thus can't fulfill a request of `num_bytes`
        if self.guaranteed_eof {
            return Err(DeserializationError::UnexpectedEOF);
        }

        // Because this function is read-only, we must optimistically assume we can read `num_bytes`
        // from the input, and fail later if that does not hold. We know we're not at EOF yet, but
        // that's all we can say without buffering more from the reader. We could make use of
        // `buffer_at_least`, which would guarantee a correct result, but it requires `&mut self`.
        // Since it is not a memory safety violation to return an optimistic result here, it makes
        // for a better tradeoff.
        Ok(())
    }

    #[inline]
    fn has_more_bytes(&self) -> bool {
        !self.buffer().is_empty() || self.non_empty_reader_buffer().is_ok()
    }
}
558
559// CURSOR
560// ================================================================================================
561
#[cfg(feature = "std")]
macro_rules! cursor_remaining_buf {
    ($cursor:ident) => {{
        // A `Cursor` may be positioned past the end of its data; clamp the offset so that case
        // yields an empty slice instead of panicking.
        let data: &[u8] = $cursor.get_ref().as_ref();
        let offset = core::cmp::min($cursor.position(), data.len() as u64) as usize;
        &data[offset..]
    }};
}
570
#[cfg(feature = "std")]
impl<T: AsRef<[u8]>> ByteReader for std::io::Cursor<T> {
    fn read_u8(&mut self) -> Result<u8, DeserializationError> {
        let next = cursor_remaining_buf!(self).first().copied();
        match next {
            Some(byte) => {
                self.set_position(self.position() + 1);
                Ok(byte)
            },
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    fn peek_u8(&self) -> Result<u8, DeserializationError> {
        match cursor_remaining_buf!(self).first() {
            Some(&byte) => Ok(byte),
            None => Err(DeserializationError::UnexpectedEOF),
        }
    }

    fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
        let start = self.position();
        let total = self.get_ref().as_ref().len() as u64;
        let available = total.saturating_sub(start);
        if len as u64 > available {
            return Err(DeserializationError::UnexpectedEOF);
        }

        self.set_position(start + len as u64);
        let offset = start.min(total) as usize;
        Ok(&self.get_ref().as_ref()[offset..(offset + len)])
    }

    fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
        let bytes = self.read_slice(N)?;
        let mut array = [0u8; N];
        array.copy_from_slice(bytes);
        Ok(array)
    }

    fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
        let remaining = cursor_remaining_buf!(self).len();
        if remaining < num_bytes {
            Err(DeserializationError::UnexpectedEOF)
        } else {
            Ok(())
        }
    }

    #[inline]
    fn has_more_bytes(&self) -> bool {
        self.position() < self.get_ref().as_ref().len() as u64
    }
}
626
627// SLICE READER
628// ================================================================================================
629
/// A [ByteReader] over an in-memory slice of bytes.
///
/// NOTE: When building with the `std` feature, [std::io::Cursor] is probably the better choice.
/// [SliceReader] remains useful in no-std environments until the `core_io_borrowed_buf` feature
/// is stabilized.
pub struct SliceReader<'a> {
    /// The bytes being read.
    source: &'a [u8],
    /// Index into `source` of the next byte to be read.
    pos: usize,
}
639
640impl<'a> SliceReader<'a> {
641 /// Creates a new slice reader from the specified slice.
642 pub fn new(source: &'a [u8]) -> Self {
643 SliceReader { source, pos: 0 }
644 }
645}
646
647impl ByteReader for SliceReader<'_> {
648 fn read_u8(&mut self) -> Result<u8, DeserializationError> {
649 self.check_eor(1)?;
650 let result = self.source[self.pos];
651 self.pos += 1;
652 Ok(result)
653 }
654
655 fn peek_u8(&self) -> Result<u8, DeserializationError> {
656 self.check_eor(1)?;
657 Ok(self.source[self.pos])
658 }
659
660 fn read_slice(&mut self, len: usize) -> Result<&[u8], DeserializationError> {
661 self.check_eor(len)?;
662 let result = &self.source[self.pos..self.pos + len];
663 self.pos += len;
664 Ok(result)
665 }
666
667 fn read_array<const N: usize>(&mut self) -> Result<[u8; N], DeserializationError> {
668 self.check_eor(N)?;
669 let mut result = [0_u8; N];
670 result.copy_from_slice(&self.source[self.pos..self.pos + N]);
671 self.pos += N;
672 Ok(result)
673 }
674
675 fn check_eor(&self, num_bytes: usize) -> Result<(), DeserializationError> {
676 if self.pos + num_bytes > self.source.len() {
677 return Err(DeserializationError::UnexpectedEOF);
678 }
679 Ok(())
680 }
681
682 fn has_more_bytes(&self) -> bool {
683 self.pos < self.source.len()
684 }
685}
686
#[cfg(all(test, feature = "std"))]
mod tests {
    use std::{fs::File, io::Cursor, path::PathBuf};

    use super::*;
    use crate::ByteWriter;

    /// Returns a fixture path in the system temp directory that is unique to this test process,
    /// so that concurrently running test binaries do not overwrite each other's files.
    fn temp_file_path(name: &str) -> PathBuf {
        std::env::temp_dir().join(format!("{name}_{}.bin", std::process::id()))
    }

    #[test]
    fn read_adapter_empty() -> Result<(), DeserializationError> {
        let mut reader = std::io::empty();
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(1), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array::<1>(), Err(DeserializationError::UnexpectedEOF));
        Ok(())
    }

    #[test]
    fn read_adapter_passthrough() -> Result<(), DeserializationError> {
        let mut reader = std::io::repeat(0b101);
        let mut adapter = ReadAdapter::new(&mut reader);
        assert!(adapter.has_more_bytes());
        assert_eq!(adapter.check_eor(8), Ok(()));
        assert_eq!(adapter.peek_u8(), Ok(0b101));
        assert_eq!(adapter.read_u8(), Ok(0b101));
        assert_eq!(adapter.read_slice(0), Ok([].as_slice()));
        assert_eq!(adapter.read_slice(4), Ok([0b101, 0b101, 0b101, 0b101].as_slice()));
        assert_eq!(adapter.read_array(), Ok([]));
        assert_eq!(adapter.read_array(), Ok([0b101, 0b101]));
        Ok(())
    }

    #[test]
    fn read_adapter_exact() {
        const VALUE: usize = 2048;
        let mut reader = Cursor::new(VALUE.to_le_bytes());
        let mut adapter = ReadAdapter::new(&mut reader);
        assert_eq!(usize::from_le_bytes(adapter.read_array().unwrap()), VALUE);
        assert!(!adapter.has_more_bytes());
        assert_eq!(adapter.peek_u8(), Err(DeserializationError::UnexpectedEOF));
        assert_eq!(adapter.read_u8(), Err(DeserializationError::UnexpectedEOF));
    }

    #[test]
    fn read_adapter_roundtrip() {
        const VALUE: usize = 2048;

        // Write VALUE to storage
        let mut cursor = Cursor::new([0; core::mem::size_of::<usize>()]);
        cursor.write_usize(VALUE);

        // Read VALUE from storage
        cursor.set_position(0);
        let mut adapter = ReadAdapter::new(&mut cursor);

        assert_eq!(adapter.read_usize(), Ok(VALUE));
    }

    #[test]
    fn read_adapter_for_file() {
        let path = temp_file_path("read_adapter_for_file");

        // Encode some data to a buffer, then write that buffer to a file
        {
            let mut buf = Vec::<u8>::with_capacity(256);
            buf.write_bytes(b"MAGIC\0");
            buf.write_bool(true);
            buf.write_u32(0xbeef);
            buf.write_usize(0xfeed);
            buf.write_u16(0x5);

            std::fs::write(&path, &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items
        {
            let mut file = File::open(&path).unwrap();
            let mut reader = ReadAdapter::new(&mut file);
            assert_eq!(reader.peek_u8().unwrap(), b'M');
            assert_eq!(reader.read_slice(6).unwrap(), b"MAGIC\0");
            assert!(reader.read_bool().unwrap());
            assert_eq!(reader.read_u32().unwrap(), 0xbeef);
            assert_eq!(reader.read_usize().unwrap(), 0xfeed);
            assert_eq!(reader.read_u16().unwrap(), 0x5);
            assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
        }

        // Best-effort cleanup; the file handle was closed at the end of the block above.
        let _ = std::fs::remove_file(&path);
    }

    #[test]
    fn read_adapter_issue_383() {
        const STR_BYTES: &[u8] = b"just a string";

        let path = temp_file_path("issue_383");

        // Encode some data to a buffer, then write that buffer to a file
        {
            let mut buf = Vec::<u8>::with_capacity(1024);
            buf.write_u128(2 * u64::MAX as u128);
            // Zero-pad so that the string starts at offset 512
            buf.resize(512, 0);
            buf.write_bytes(STR_BYTES);
            buf.write_u32(0xbeef);

            std::fs::write(&path, &buf).unwrap();
        }

        // Open the file, and try to decode the encoded items
        {
            let mut file = File::open(&path).unwrap();
            let mut reader = ReadAdapter::new(&mut file);
            assert_eq!(reader.read_u128().unwrap(), 2 * u64::MAX as u128);
            assert_eq!(reader.buf.len(), 0);
            assert_eq!(reader.pos, 0);
            // Read to offset 512 (we're 16 bytes into the underlying file, i.e. offset of 496)
            reader.read_slice(496).unwrap();
            assert_eq!(reader.buf.len(), 496);
            assert_eq!(reader.pos, 496);
            // The byte string is 13 bytes, followed by 4 bytes containing the trailing u32 value.
            // We expect that the underlying reader will buffer the remaining bytes of the file
            // when reading STR_BYTES, so the total size of our adapter's buffer should be
            // 496 + STR_BYTES.len() + size_of::<u32>();
            assert_eq!(reader.read_slice(STR_BYTES.len()).unwrap(), STR_BYTES);
            assert_eq!(reader.buf.len(), 496 + STR_BYTES.len() + core::mem::size_of::<u32>());
            // We haven't read the u32 yet
            assert_eq!(reader.pos, 509);
            assert_eq!(reader.read_u32().unwrap(), 0xbeef);
            // Now we have
            assert_eq!(reader.pos, 513);
            assert!(!reader.has_more_bytes(), "expected there to be no more data in the input");
        }

        // Best-effort cleanup; the file handle was closed at the end of the block above.
        let _ = std::fs::remove_file(&path);
    }
}