safer_bytes/
safe_buf.rs

//! Extension trait for extracting primitives and custom objects from a [`bytes::Buf`] without panicking.
2
3use crate::{error, FromBuf};
4use bytes::{Buf, Bytes};
5use paste::paste;
6
// Generates `try_get_<type>`: a bounds-checked wrapper around the big-endian
// `Buf::get_<type>` accessor that returns `error::Truncated` instead of panicking.
//
// `$width` must be the size of `$t` in bytes. It is verified at compile time,
// because a width smaller than the type would let the check pass on a short
// buffer and the wrapped `get_*` call would then panic.
macro_rules! get_primitive_checked_be {
    ($t:ty, $width:literal) => {
        paste! {
            #[doc = "This method wraps [`Buf::get_" $t "`] with a bounds check to ensure there are enough bytes remaining, without panicking."]
            fn [<try_get_ $t>](&mut self) -> std::result::Result<$t, error::Truncated> {
                // Compile-time guard: the declared width must match what `get_*` consumes.
                const _: () = assert!(
                    std::mem::size_of::<$t>() == $width,
                    "declared width does not match the size of the primitive type"
                );
                if self.remaining() >= $width {
                    Ok(self.[<get_ $t>]())
                } else {
                    Err(error::Truncated)
                }
            }
        }
    };
}
21
// Generates `try_get_<type>_le`: a bounds-checked wrapper around the
// little-endian `Buf::get_<type>_le` accessor that returns `error::Truncated`
// instead of panicking.
//
// `$width` must be the size of `$t` in bytes. It is verified at compile time,
// because a width smaller than the type would let the check pass on a short
// buffer and the wrapped `get_*_le` call would then panic.
macro_rules! get_primitive_checked_le {
    ($t:ty, $width:literal) => {
        paste! {
            #[doc = "This method wraps [`Buf::get_" $t "_le`] with a bounds check to ensure there are enough bytes remaining, without panicking."]
            fn [<try_get_ $t _le>](&mut self) -> std::result::Result<$t, error::Truncated> {
                // Compile-time guard: the declared width must match what `get_*_le` consumes.
                const _: () = assert!(
                    std::mem::size_of::<$t>() == $width,
                    "declared width does not match the size of the primitive type"
                );
                if self.remaining() >= $width {
                    Ok(self.[<get_ $t _le>]())
                } else {
                    Err(error::Truncated)
                }
            }
        }
    };
}
36
37/// Extension trait for [`bytes::Buf`]
38pub trait SafeBuf: Buf {
39    /// Take a given number of bytes from the buffer, with a check to ensure
40    /// there are enough remaining
41    ///
42    /// # Errors
43    ///
44    /// This method will return an error if the number of bytes remaining in the
45    /// buffer is insufficent
46    fn try_copy_to_bytes(&mut self, len: usize) -> std::result::Result<Bytes, error::Truncated> {
47        if self.remaining() < len {
48            Err(error::Truncated)
49        } else {
50            Ok(self.copy_to_bytes(len))
51        }
52    }
53
54    /// Take a given number of bytes from the buffer and write to a slice, with
55    /// a check to ensure there are enough remaining
56    ///
57    /// # Errors
58    ///
59    /// This method will return an error if the number of bytes remaining in the
60    /// buffer is insufficent
61    fn try_copy_to_slice(&mut self, dst: &mut [u8]) -> std::result::Result<(), error::Truncated> {
62        if self.remaining() < dst.len() {
63            Err(error::Truncated)
64        } else {
65            self.copy_to_slice(dst);
66            Ok(())
67        }
68    }
69
70    /// Read a custom object from a buffer
71    ///
72    /// # Errors
73    ///
74    /// This method will return an error if the number of bytes remaining in the
75    /// buffer is insufficent, or if the type cannot be parsed from the bytes.
76    fn extract<T>(&mut self) -> crate::Result<T>
77    where
78        T: FromBuf,
79    {
80        T::from_buf(self)
81    }
82
83    /// Check whether this reader is exhausted (out of bytes).
84    ///
85    /// # Errors
86    ///
87    /// this method will return [`error::ExtraneousBytes`] if there are bytes
88    /// left in the buffer.
89    fn should_be_exhausted(&self) -> std::result::Result<(), error::ExtraneousBytes> {
90        if self.has_remaining() {
91            Err(error::ExtraneousBytes)
92        } else {
93            Ok(())
94        }
95    }
96
97    get_primitive_checked_be!(u8, 1);
98    get_primitive_checked_be!(i8, 1);
99
100    get_primitive_checked_be!(u16, 2);
101    get_primitive_checked_be!(i16, 2);
102    get_primitive_checked_be!(u32, 4);
103    get_primitive_checked_be!(i32, 4);
104    get_primitive_checked_be!(u64, 8);
105    get_primitive_checked_be!(i64, 8);
106    get_primitive_checked_be!(u128, 16);
107    get_primitive_checked_be!(i128, 16);
108
109    get_primitive_checked_le!(u16, 2);
110    get_primitive_checked_le!(i16, 2);
111    get_primitive_checked_le!(u32, 4);
112    get_primitive_checked_le!(i32, 4);
113    get_primitive_checked_le!(u64, 8);
114    get_primitive_checked_le!(i64, 8);
115    get_primitive_checked_le!(u128, 16);
116    get_primitive_checked_le!(i128, 16);
117}
118
119impl<T> SafeBuf for T where T: Buf {}
120
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use paste::paste;

    use super::SafeBuf;
    use crate::BufMut;

    #[test]
    fn try_copy_to_bytes() {
        let mut bytes = BytesMut::new();
        bytes.extend_from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

        assert_eq!(&bytes.try_copy_to_bytes(4).unwrap()[..], &[0, 1, 2, 3]);
        assert_eq!(&bytes.try_copy_to_bytes(4).unwrap()[..], &[4, 5, 6, 7]);

        // Only two bytes remain: the request must fail without consuming them.
        assert!(bytes.try_copy_to_bytes(4).is_err());
        assert_eq!(bytes.len(), 2);
    }

    #[test]
    fn try_copy_to_slice() {
        let mut bytes = BytesMut::new();
        bytes.extend_from_slice(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);

        let mut dst = [0_u8; 4];

        assert!(bytes.try_copy_to_slice(&mut dst).is_ok());
        assert_eq!(dst, [0, 1, 2, 3]);
        assert!(bytes.try_copy_to_slice(&mut dst).is_ok());
        assert_eq!(dst, [4, 5, 6, 7]);

        // A failed copy must leave both the buffer and the destination untouched.
        assert!(bytes.try_copy_to_slice(&mut dst).is_err());
        assert_eq!(dst, [4, 5, 6, 7]);
        assert_eq!(bytes.len(), 2);
    }

    #[test]
    fn byte_order() {
        // Distinct bytes make a byte-order mix-up visible, which a round trip cannot detect.
        let mut be: &[u8] = &[0x01, 0x02, 0x03, 0x04];
        assert_eq!(be.try_get_u32().unwrap(), 0x0102_0304);

        let mut le: &[u8] = &[0x01, 0x02, 0x03, 0x04];
        assert_eq!(le.try_get_u32_le().unwrap(), 0x0403_0201);
    }

    // Round-trips a type through the plain `put_*` / `try_get_*` pair, and checks
    // that empty and one-byte-short buffers are rejected rather than panicking.
    macro_rules! round_trip {
        ($t:ty) => {
            paste! {
                #[test]
                fn [<round_trip_ $t>]() {
                    for &input in &[17, <$t>::MIN, <$t>::MAX] {
                        let mut buffer = BytesMut::new();
                        buffer.[<put_ $t>](input);
                        let output = buffer.[<try_get_ $t>]().unwrap();

                        assert!(buffer.[<try_get_ $t>]().is_err());
                        assert_eq!(input, output);
                        assert!(buffer.is_empty());
                    }

                    let mut buffer = BytesMut::new();
                    buffer.[<put_ $t>](17);
                    let short = buffer.len() - 1;
                    buffer.truncate(short);
                    assert!(buffer.[<try_get_ $t>]().is_err());
                }
            }
        };
    }

    // Same as `round_trip!`, but for the little-endian `put_*_le` / `try_get_*_le` pair.
    macro_rules! round_trip_le {
        ($t:ty) => {
            paste! {
                #[test]
                fn [<round_trip_ $t _le>]() {
                    for &input in &[17, <$t>::MIN, <$t>::MAX] {
                        let mut buffer = BytesMut::new();
                        buffer.[<put_ $t _le>](input);
                        let output = buffer.[<try_get_ $t _le>]().unwrap();

                        assert!(buffer.[<try_get_ $t _le>]().is_err());
                        assert_eq!(input, output);
                        assert!(buffer.is_empty());
                    }

                    let mut buffer = BytesMut::new();
                    buffer.[<put_ $t _le>](17);
                    let short = buffer.len() - 1;
                    buffer.truncate(short);
                    assert!(buffer.[<try_get_ $t _le>]().is_err());
                }
            }
        };
    }

    round_trip!(u8);
    round_trip!(i8);
    round_trip!(u16);
    round_trip!(i16);
    round_trip!(u32);
    round_trip!(i32);
    round_trip!(u64);
    round_trip!(i64);
    round_trip!(u128);
    round_trip!(i128);

    round_trip_le!(u16);
    round_trip_le!(i16);
    round_trip_le!(u32);
    round_trip_le!(i32);
    round_trip_le!(u64);
    round_trip_le!(i64);
    round_trip_le!(u128);
    round_trip_le!(i128);
}