1use super::error::{Error, Result};
2use byteorder::{LittleEndian, ReadBytesExt};
3use serde::de::{self, Deserialize, DeserializeSeed, SeqAccess, Visitor};
4
5struct Deserializer<'de> {
6 input: &'de [u8],
7}
8
9impl<'de> Deserializer<'de> {
10 fn from_bytes(input: &'de [u8]) -> Self {
11 Deserializer { input }
12 }
13
14 fn parse_string(&mut self) -> Result<&'de str> {
15 if let Some(end) = self.input.iter().position(|&b| b == 0) {
16 let string = &self.input[0..end];
17 self.input = &self.input[end + 1..];
18 let string = &std::str::from_utf8(&string)?;
19 Ok(string)
20 } else {
21 Err(Error::EndlessString)
22 }
23 }
24
25 fn parse_bool(&mut self) -> Result<bool> {
26 let b = self.input.read_u8()?;
27 match b {
28 1 => Ok(true),
29 0 => Ok(false),
30 _ => Err(Error::InvalidBool),
31 }
32 }
33}
34
35pub fn from_bytes<'b, T>(s: &'b [u8]) -> Result<T>
39where
40 T: Deserialize<'b>,
41{
42 let mut deserializer = Deserializer::from_bytes(s);
43 let t = T::deserialize(&mut deserializer)?;
44 if deserializer.input.is_empty() {
45 Ok(t)
46 } else {
47 Err(Error::TrailingCharacters)
48 }
49}
50
51pub trait PacketRead {
54 fn read_packet(&mut self) -> Result<(u8, Vec<u8>)>;
56}
57
58impl<T: std::io::Read> PacketRead for T {
59 fn read_packet(&mut self) -> Result<(u8, Vec<u8>)> {
60 let length = self.read_u16::<LittleEndian>()? as usize;
61 let packet_type = self.read_u8()?;
62 let buffer_length = length - 3;
63 let mut buffer = vec![0u8; buffer_length];
64 self.read_exact(&mut buffer)?;
65 Ok((packet_type, buffer))
66 }
67}
68
69impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
70 type Error = Error;
71
72 fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
73 where
74 V: Visitor<'de>,
75 {
76 Err(Error::NotSupported)
78 }
79
80 fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
81 where
82 V: Visitor<'de>,
83 {
84 let string = self.parse_string()?;
85 if string.chars().count() != 1 {
86 Err(Error::InvalidChar)
87 } else {
88 visitor.visit_char(string.chars().next().unwrap())
89 }
90 }
91
92 fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
93 where
94 V: Visitor<'de>,
95 {
96 visitor.visit_bool(self.parse_bool()?)
97 }
98
99 fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
100 where
101 V: Visitor<'de>,
102 {
103 visitor.visit_i8(self.input.read_i8()?)
104 }
105
106 fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
107 where
108 V: Visitor<'de>,
109 {
110 visitor.visit_u8(self.input.read_u8()?)
111 }
112
113 fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
114 where
115 V: Visitor<'de>,
116 {
117 visitor.visit_i16(self.input.read_i16::<LittleEndian>()?)
118 }
119
120 fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
121 where
122 V: Visitor<'de>,
123 {
124 visitor.visit_u16(self.input.read_u16::<LittleEndian>()?)
125 }
126
127 fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
128 where
129 V: Visitor<'de>,
130 {
131 visitor.visit_i32(self.input.read_i32::<LittleEndian>()?)
132 }
133
134 fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
135 where
136 V: Visitor<'de>,
137 {
138 visitor.visit_u32(self.input.read_u32::<LittleEndian>()?)
139 }
140
141 fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
142 where
143 V: Visitor<'de>,
144 {
145 visitor.visit_i64(self.input.read_i64::<LittleEndian>()?)
146 }
147
148 fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
149 where
150 V: Visitor<'de>,
151 {
152 visitor.visit_u64(self.input.read_u64::<LittleEndian>()?)
153 }
154
155 fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
156 where
157 V: Visitor<'de>,
158 {
159 visitor.visit_f32(self.input.read_f32::<LittleEndian>()?)
160 }
161
162 fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
163 where
164 V: Visitor<'de>,
165 {
166 visitor.visit_f64(self.input.read_f64::<LittleEndian>()?)
167 }
168
169 fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
170 where
171 V: Visitor<'de>,
172 {
173 let string = self.parse_string()?;
174 visitor.visit_borrowed_str(string)
175 }
176
177 fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
178 where
179 V: Visitor<'de>,
180 {
181 let string = self.parse_string()?;
182 visitor.visit_borrowed_str(string)
183 }
184
185 fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
186 where
187 V: Visitor<'de>,
188 {
189 visitor.visit_seq(self)
190 }
191
192 fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
193 where
194 V: Visitor<'de>,
195 {
196 visitor.visit_seq(self)
197 }
198
199 fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
200 where
201 V: Visitor<'de>,
202 {
203 if self.input.is_empty() {
204 visitor.visit_none()
205 } else {
206 visitor.visit_some(self)
207 }
208 }
209
210 fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
211 where
212 V: Visitor<'de>,
213 {
214 visitor.visit_unit()
215 }
216
217 fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
218 where
219 V: Visitor<'de>,
220 {
221 visitor.visit_unit()
222 }
223
224 fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
225 where
226 V: Visitor<'de>,
227 {
228 visitor.visit_newtype_struct(self)
229 }
230
231 fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
232 where
233 V: Visitor<'de>,
234 {
235 visitor.visit_seq(self)
236 }
237
238 fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
239 where
240 V: Visitor<'de>,
241 {
242 visitor.visit_seq(FixedSizeSeqAccess { de: self, len })
243 }
244
245 fn deserialize_tuple_struct<V>(
246 self,
247 _name: &'static str,
248 len: usize,
249 visitor: V,
250 ) -> Result<V::Value>
251 where
252 V: Visitor<'de>,
253 {
254 visitor.visit_seq(FixedSizeSeqAccess { de: self, len })
255 }
256
257 fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
258 where
259 V: Visitor<'de>,
260 {
261 Err(Error::NotSupported)
263 }
264
265 fn deserialize_struct<V>(
266 self,
267 _name: &'static str,
268 fields: &'static [&'static str],
269 visitor: V,
270 ) -> Result<V::Value>
271 where
272 V: Visitor<'de>,
273 {
274 visitor.visit_seq(FixedSizeSeqAccess {
275 de: self,
276 len: fields.len(),
277 })
278 }
279
280 fn deserialize_enum<V>(
281 self,
282 _name: &'static str,
283 _variants: &'static [&'static str],
284 _visitor: V,
285 ) -> Result<V::Value>
286 where
287 V: Visitor<'de>,
288 {
289 Err(Error::NotSupported)
291 }
292
293 fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
294 where
295 V: Visitor<'de>,
296 {
297 Err(Error::NotSupported)
298 }
299
300 fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
301 where
302 V: Visitor<'de>,
303 {
304 Err(Error::NotSupported)
305 }
306}
307
308impl<'de> SeqAccess<'de> for Deserializer<'de> {
309 type Error = Error;
310 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
311 where
312 T: DeserializeSeed<'de>,
313 {
314 let expect_next = self.parse_bool()?;
315 if expect_next {
316 seed.deserialize(self).map(Some)
317 } else {
318 Ok(None)
319 }
320 }
321}
322
323struct FixedSizeSeqAccess<'a, 'de: 'a> {
325 de: &'a mut Deserializer<'de>,
326 len: usize,
327}
328
329impl<'de, 'a> SeqAccess<'de> for FixedSizeSeqAccess<'a, 'de> {
330 type Error = Error;
331
332 fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
333 where
334 T: DeserializeSeed<'de>,
335 {
336 if self.len == 0 {
337 Ok(None)
338 } else {
339 self.len -= 1;
340 seed.deserialize(&mut *self.de).map(Some)
341 }
342 }
343}
344
#[cfg(test)]
mod test {
    use super::PacketRead;
    use super::*;
    use serde_derive::Deserialize;

    /// A frame whose length equals the header size carries an empty payload.
    #[test]
    fn test_empty_packet_read() {
        // length = 3 (little-endian), type = 10, no payload
        let mut empty_packet: &[u8] = &[3, 0, 10];
        let packet = empty_packet.read_packet().unwrap();
        assert_eq!(packet, (10, Vec::new()));
    }

    /// Struct fields are decoded in declaration order with fixed widths.
    #[test]
    fn test_simple_struct_read() {
        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct SimpleStruct {
            a: u8,
            b: u16,
            c: u32,
            d: bool,
        }

        let mut input: &[u8] = &[
            11, 0, // frame length
            10, // packet type
            1, // a
            2, 0, // b
            3, 0, 0, 0, // c
            1, // d
        ];
        let expected = SimpleStruct {
            a: 1,
            b: 2,
            c: 3,
            d: true,
        };

        let (packet_type, buffer) = input.read_packet().unwrap();
        assert_eq!(packet_type, 10);
        assert_eq!(from_bytes::<SimpleStruct>(&buffer).unwrap(), expected);
    }

    /// Each vector element is preceded by a `1` flag; a `0` flag ends it.
    #[test]
    fn test_vec_ser() {
        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct VecStruct {
            item: Vec<u8>,
        }

        let mut input: &[u8] = &[
            14, 0, // frame length
            0xFF, // packet type
            1, 0, // element 0
            1, 1, // element 1
            1, 2, // element 2
            1, 3, // element 3
            1, 4, // element 4
            0, // end of sequence
        ];
        let expected = VecStruct {
            item: vec![0, 1, 2, 3, 4],
        };

        let (packet_type, buffer) = input.read_packet().unwrap();
        assert_eq!(packet_type, 0xFF);
        assert_eq!(from_bytes::<VecStruct>(&buffer).unwrap(), expected);
    }

    mod option_tests {
        use super::*;

        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct OptionStruct {
            mandatory: u8,
            optional: Option<u8>,
        }

        /// A trailing optional field is `Some` when bytes remain for it.
        #[test]
        fn test_some_ser() {
            // length = 5, type = 3, mandatory = 10, optional = 10
            let mut input: &[u8] = &[5, 0, 3, 10, 10];
            let expected = OptionStruct {
                mandatory: 10,
                optional: Some(10),
            };

            let (packet_type, buffer) = input.read_packet().unwrap();
            assert_eq!(packet_type, 3);
            assert_eq!(from_bytes::<OptionStruct>(&buffer).unwrap(), expected);
        }

        /// A trailing optional field is `None` when the payload ends early.
        #[test]
        fn test_none_ser() {
            // length = 4, type = 3, mandatory = 10, nothing for `optional`
            let mut input: &[u8] = &[4, 0, 3, 10];
            let expected = OptionStruct {
                mandatory: 10,
                optional: None,
            };

            let (packet_type, buffer) = input.read_packet().unwrap();
            assert_eq!(packet_type, 3);
            assert_eq!(from_bytes::<OptionStruct>(&buffer).unwrap(), expected);
        }
    }
}