1use crate::{error, FromBuf};
4use bytes::{Buf, Bytes};
5use paste::paste;
6
// Generates a `try_get_<type>` method for use inside the `SafeBuf` trait body.
//
// `$t` is the primitive type to read and `$width` is the number of bytes the
// read consumes. The generated method verifies that at least `$width` bytes
// remain before delegating to the corresponding `Buf::get_<type>` method, and
// returns `error::Truncated` otherwise.
macro_rules! get_primitive_checked_be {
    ($t:ty, $width:literal) => {
        paste! {
            #[doc = "Checked counterpart of [`Buf::get_" $t "`]: returns `error::Truncated` instead of panicking when too few bytes remain in the buffer."]
            fn [<try_get_ $t>](&mut self) -> std::result::Result<$t, error::Truncated> {
                // Bail out before touching the buffer so a short read never reaches `Buf`.
                if self.remaining() < $width {
                    return Err(error::Truncated);
                }

                Ok(self.[<get_ $t>]())
            }
        }
    };
}
21
// Generates a `try_get_<type>_le` method for use inside the `SafeBuf` trait body.
//
// `$t` is the primitive type to read and `$width` is the number of bytes the
// read consumes. The generated method verifies that at least `$width` bytes
// remain before delegating to the corresponding `Buf::get_<type>_le` method,
// and returns `error::Truncated` otherwise.
macro_rules! get_primitive_checked_le {
    ($t:ty, $width:literal) => {
        paste! {
            #[doc = "Checked counterpart of [`Buf::get_" $t "_le`]: returns `error::Truncated` instead of panicking when too few bytes remain in the buffer."]
            fn [<try_get_ $t _le>](&mut self) -> std::result::Result<$t, error::Truncated> {
                // Bail out before touching the buffer so a short read never reaches `Buf`.
                if self.remaining() < $width {
                    return Err(error::Truncated);
                }

                Ok(self.[<get_ $t _le>]())
            }
        }
    };
}
36
37pub trait SafeBuf: Buf {
39 fn try_copy_to_bytes(&mut self, len: usize) -> std::result::Result<Bytes, error::Truncated> {
47 if self.remaining() < len {
48 Err(error::Truncated)
49 } else {
50 Ok(self.copy_to_bytes(len))
51 }
52 }
53
54 fn try_copy_to_slice(&mut self, dst: &mut [u8]) -> std::result::Result<(), error::Truncated> {
62 if self.remaining() < dst.len() {
63 Err(error::Truncated)
64 } else {
65 self.copy_to_slice(dst);
66 Ok(())
67 }
68 }
69
70 fn extract<T>(&mut self) -> crate::Result<T>
77 where
78 T: FromBuf,
79 {
80 T::from_buf(self)
81 }
82
83 fn should_be_exhausted(&self) -> std::result::Result<(), error::ExtraneousBytes> {
90 if self.has_remaining() {
91 Err(error::ExtraneousBytes)
92 } else {
93 Ok(())
94 }
95 }
96
97 get_primitive_checked_be!(u8, 1);
98 get_primitive_checked_be!(i8, 1);
99
100 get_primitive_checked_be!(u16, 2);
101 get_primitive_checked_be!(i16, 2);
102 get_primitive_checked_be!(u32, 4);
103 get_primitive_checked_be!(i32, 4);
104 get_primitive_checked_be!(u64, 8);
105 get_primitive_checked_be!(i64, 8);
106 get_primitive_checked_be!(u128, 16);
107 get_primitive_checked_be!(i128, 16);
108
109 get_primitive_checked_le!(u16, 2);
110 get_primitive_checked_le!(i16, 2);
111 get_primitive_checked_le!(u32, 4);
112 get_primitive_checked_le!(i32, 4);
113 get_primitive_checked_le!(u64, 8);
114 get_primitive_checked_le!(i64, 8);
115 get_primitive_checked_le!(u128, 16);
116 get_primitive_checked_le!(i128, 16);
117}
118
119impl<T> SafeBuf for T where T: Buf {}
120
#[cfg(test)]
mod tests {
    use bytes::BytesMut;
    use paste::paste;

    use super::SafeBuf;
    use crate::BufMut;

    /// Builds a ten-byte buffer holding the values `0..=9`.
    fn ten_byte_buffer() -> BytesMut {
        BytesMut::from(&[0_u8, 1, 2, 3, 4, 5, 6, 7, 8, 9][..])
    }

    #[test]
    fn try_copy_to_bytes() {
        let mut bytes = ten_byte_buffer();

        // Two four-byte reads fit into ten bytes...
        for _ in 0..2 {
            assert!(bytes.try_copy_to_bytes(4).is_ok());
        }
        // ...but only two bytes are left for the third.
        assert!(bytes.try_copy_to_bytes(4).is_err());
    }

    #[test]
    fn try_copy_to_slice() {
        let mut bytes = ten_byte_buffer();
        let mut dst = [0_u8; 4];

        // Two four-byte reads fit into ten bytes...
        for _ in 0..2 {
            assert!(bytes.try_copy_to_slice(&mut dst).is_ok());
        }
        // ...but only two bytes are left for the third.
        assert!(bytes.try_copy_to_slice(&mut dst).is_err());
    }

    // Generates a test that writes one value of type `$t`, reads it back through
    // the checked getter, and verifies that a second read fails on the now-empty
    // buffer.
    macro_rules! round_trip {
        ($t:ty) => {
            paste! {
                #[test]
                fn [<round_trip_ $t>]() {
                    let input: $t = 17;

                    let mut buffer = BytesMut::new();
                    buffer.[<put_ $t>](input);

                    let output = buffer.[<try_get_ $t>]().unwrap();

                    assert!(buffer.[<try_get_ $t>]().is_err());
                    assert_eq!(input, output);
                    assert!(buffer.is_empty());
                }
            }
        };
    }

    round_trip!(u8);
    round_trip!(i8);
    round_trip!(u16);
    round_trip!(i16);
    round_trip!(u32);
    round_trip!(i32);
    round_trip!(u64);
    round_trip!(i64);
}