scuffle_bytes_util/
bytes_cursor.rs

1use std::io;
2
3use bytes::Bytes;
4
5pub type BytesCursor = io::Cursor<Bytes>;
6
7/// A helper trait to implement zero copy reads on a `Cursor<Bytes>` type.
8///
9/// Allowing for zero copy reads from a `Cursor<Bytes>` type.
10pub trait BytesCursorExt {
11    /// Extracts the remaining bytes from the cursor.
12    ///
13    /// This does not do a copy of the bytes, and is O(1) time.
14    ///
15    /// This is the same as `BytesCursor::extract_bytes(self.remaining())`.
16    ///
17    /// This is equivalent if you were to read the remaining data into a new
18    /// buffer, however this is more efficient as it does not copy the
19    /// bytes.
20    fn extract_remaining(&mut self) -> Bytes;
21
22    /// Extracts bytes from the cursor.
23    ///
24    /// This does not do a copy of the bytes, and is O(1) time.
25    /// Returns an error if the size is greater than the remaining bytes.
26    ///
27    /// This is equivalent if you were to read the remaining data into a new
28    /// buffer, however this is more efficient as it does not copy the
29    /// bytes.
30    fn extract_bytes(&mut self, size: usize) -> io::Result<Bytes>;
31}
32
/// Returns how many bytes of `cursor`'s buffer lie at or after its current
/// position.
///
/// A position at or beyond the end of the buffer yields `0`.
fn remaining<T: AsRef<[u8]>>(cursor: &io::Cursor<T>) -> usize {
    let len = cursor.get_ref().as_ref().len();
    let position = cursor.position();

    // Compare in `u64` before narrowing: on 32-bit targets a position above
    // `usize::MAX` would otherwise be truncated by `as usize` and yield a
    // bogus non-zero count.
    if position >= len as u64 {
        0
    } else {
        // Lossless: `position < len <= usize::MAX` here.
        len - position as usize
    }
}
36
37impl BytesCursorExt for BytesCursor {
38    fn extract_remaining(&mut self) -> Bytes {
39        // We don't really care if we fail here since the desired behavior is
40        // to return all bytes remaining in the cursor. If we fail its because
41        // there are not enough bytes left in the cursor to read.
42        self.extract_bytes(remaining(self)).unwrap_or_default()
43    }
44
45    fn extract_bytes(&mut self, size: usize) -> io::Result<Bytes> {
46        // If the size is zero we can just return an empty bytes slice.
47        if size == 0 {
48            return Ok(Bytes::new());
49        }
50
51        // If the size is greater than the remaining bytes we can just return an
52        // error.
53        if size > remaining(self) {
54            return Err(io::Error::new(io::ErrorKind::UnexpectedEof, "not enough bytes"));
55        }
56
57        let position = self.position() as usize;
58
59        // We slice bytes here which is a O(1) operation as it only modifies a few
60        // reference counters and does not copy the memory.
61        let slice = self.get_ref().slice(position..position + size);
62
63        // We advance the cursor because we have now "read" the bytes.
64        self.set_position((position + size) as u64);
65
66        Ok(slice)
67    }
68}
69
#[cfg(test)]
#[cfg_attr(all(test, coverage_nightly), coverage(off))]
mod tests {
    use super::*;

    #[test]
    fn test_bytes_cursor_extract_remaining() {
        let mut cursor = io::Cursor::new(Bytes::from_static(&[1, 2, 3, 4, 5]));
        let rest = cursor.extract_remaining();
        assert_eq!(rest, Bytes::from_static(&[1, 2, 3, 4, 5]));
        assert_eq!(cursor.position(), 5);
        assert_eq!(remaining(&cursor), 0);
    }

    #[test]
    fn test_bytes_cursor_extract_remaining_after_partial_read() {
        let mut cursor = io::Cursor::new(Bytes::from_static(&[1, 2, 3, 4, 5]));
        assert_eq!(cursor.extract_bytes(2).unwrap(), Bytes::from_static(&[1, 2]));
        assert_eq!(cursor.extract_remaining(), Bytes::from_static(&[3, 4, 5]));
        assert_eq!(cursor.position(), 5);
        assert_eq!(remaining(&cursor), 0);
    }

    #[test]
    fn test_bytes_cursor_extract_bytes() {
        let mut cursor = io::Cursor::new(Bytes::from_static(&[1, 2, 3, 4, 5]));
        let bytes = cursor.extract_bytes(3).unwrap();
        assert_eq!(bytes, Bytes::from_static(&[1, 2, 3]));
        assert_eq!(remaining(&cursor), 2);

        let bytes = cursor.extract_bytes(2).unwrap();
        assert_eq!(bytes, Bytes::from_static(&[4, 5]));
        assert_eq!(remaining(&cursor), 0);

        // A failed extraction reports EOF and must not move the cursor.
        let err = cursor.extract_bytes(1).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert_eq!(cursor.position(), 5);

        let bytes = cursor.extract_bytes(0).unwrap();
        assert_eq!(bytes, Bytes::from_static(&[]));
        assert_eq!(remaining(&cursor), 0);

        let bytes = cursor.extract_remaining();
        assert_eq!(bytes, Bytes::from_static(&[]));
        assert_eq!(remaining(&cursor), 0);
    }

    #[test]
    fn test_bytes_cursor_extract_is_zero_copy() {
        let data = Bytes::from_static(&[1, 2, 3, 4, 5]);
        let base = data.as_ptr();
        let mut cursor = io::Cursor::new(data);
        cursor.set_position(1);

        let bytes = cursor.extract_bytes(3).unwrap();
        assert_eq!(bytes, Bytes::from_static(&[2, 3, 4]));
        // The extracted handle points into the original buffer, not a copy.
        assert_eq!(bytes.as_ptr(), base.wrapping_add(1));
    }

    #[test]
    fn seek_out_of_bounds() {
        let mut cursor = io::Cursor::new(Bytes::from_static(&[1, 2, 3, 4, 5]));
        cursor.set_position(10);
        assert_eq!(remaining(&cursor), 0);

        let bytes = cursor.extract_remaining();
        assert_eq!(bytes, Bytes::from_static(&[]));

        let err = cursor.extract_bytes(1).unwrap_err();
        assert_eq!(err.kind(), io::ErrorKind::UnexpectedEof);
        assert_eq!(cursor.position(), 10);

        let bytes = cursor.extract_bytes(0);
        assert_eq!(bytes.unwrap(), Bytes::from_static(&[]));
    }
}