rust_openttd_admin/packet/serde/de.rs

1use super::error::{Error, Result};
2use byteorder::{LittleEndian, ReadBytesExt};
3use serde::de::{self, Deserialize, DeserializeSeed, SeqAccess, Visitor};
4
5struct Deserializer<'de> {
6    input: &'de [u8],
7}
8
9impl<'de> Deserializer<'de> {
10    fn from_bytes(input: &'de [u8]) -> Self {
11        Deserializer { input }
12    }
13
14    fn parse_string(&mut self) -> Result<&'de str> {
15        if let Some(end) = self.input.iter().position(|&b| b == 0) {
16            let string = &self.input[0..end];
17            self.input = &self.input[end + 1..];
18            let string = &std::str::from_utf8(&string)?;
19            Ok(string)
20        } else {
21            Err(Error::EndlessString)
22        }
23    }
24
25    fn parse_bool(&mut self) -> Result<bool> {
26        let b = self.input.read_u8()?;
27        match b {
28            1 => Ok(true),
29            0 => Ok(false),
30            _ => Err(Error::InvalidBool),
31        }
32    }
33}
34
35/// Decode a serializable type from an OpenTTD buffer. The input should be the
36/// data buffer, without the preceding length and packet type. Usually this is
37/// is used to implement [`PacketRead`].
38pub fn from_bytes<'b, T>(s: &'b [u8]) -> Result<T>
39where
40    T: Deserialize<'b>,
41{
42    let mut deserializer = Deserializer::from_bytes(s);
43    let t = T::deserialize(&mut deserializer)?;
44    if deserializer.input.is_empty() {
45        Ok(t)
46    } else {
47        Err(Error::TrailingCharacters)
48    }
49}
50
51/// A trait that provides the [`read_packet`](PacketRead#read_packet) function
52/// to a type implementing [`std::io::Read`].
53pub trait PacketRead {
54    /// Read a packet and return the packet type and data.
55    fn read_packet(&mut self) -> Result<(u8, Vec<u8>)>;
56}
57
58impl<T: std::io::Read> PacketRead for T {
59    fn read_packet(&mut self) -> Result<(u8, Vec<u8>)> {
60        let length = self.read_u16::<LittleEndian>()? as usize;
61        let packet_type = self.read_u8()?;
62        let buffer_length = length - 3;
63        let mut buffer = vec![0u8; buffer_length];
64        self.read_exact(&mut buffer)?;
65        Ok((packet_type, buffer))
66    }
67}
68
69impl<'de, 'a> de::Deserializer<'de> for &'a mut Deserializer<'de> {
70    type Error = Error;
71
72    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
73    where
74        V: Visitor<'de>,
75    {
76        // We really have no idea what the data means.
77        Err(Error::NotSupported)
78    }
79
80    fn deserialize_char<V>(self, visitor: V) -> Result<V::Value>
81    where
82        V: Visitor<'de>,
83    {
84        let string = self.parse_string()?;
85        if string.chars().count() != 1 {
86            Err(Error::InvalidChar)
87        } else {
88            visitor.visit_char(string.chars().next().unwrap())
89        }
90    }
91
92    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
93    where
94        V: Visitor<'de>,
95    {
96        visitor.visit_bool(self.parse_bool()?)
97    }
98
99    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
100    where
101        V: Visitor<'de>,
102    {
103        visitor.visit_i8(self.input.read_i8()?)
104    }
105
106    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value>
107    where
108        V: Visitor<'de>,
109    {
110        visitor.visit_u8(self.input.read_u8()?)
111    }
112
113    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value>
114    where
115        V: Visitor<'de>,
116    {
117        visitor.visit_i16(self.input.read_i16::<LittleEndian>()?)
118    }
119
120    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value>
121    where
122        V: Visitor<'de>,
123    {
124        visitor.visit_u16(self.input.read_u16::<LittleEndian>()?)
125    }
126
127    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value>
128    where
129        V: Visitor<'de>,
130    {
131        visitor.visit_i32(self.input.read_i32::<LittleEndian>()?)
132    }
133
134    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
135    where
136        V: Visitor<'de>,
137    {
138        visitor.visit_u32(self.input.read_u32::<LittleEndian>()?)
139    }
140
141    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value>
142    where
143        V: Visitor<'de>,
144    {
145        visitor.visit_i64(self.input.read_i64::<LittleEndian>()?)
146    }
147
148    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
149    where
150        V: Visitor<'de>,
151    {
152        visitor.visit_u64(self.input.read_u64::<LittleEndian>()?)
153    }
154
155    fn deserialize_f32<V>(self, visitor: V) -> Result<V::Value>
156    where
157        V: Visitor<'de>,
158    {
159        visitor.visit_f32(self.input.read_f32::<LittleEndian>()?)
160    }
161
162    fn deserialize_f64<V>(self, visitor: V) -> Result<V::Value>
163    where
164        V: Visitor<'de>,
165    {
166        visitor.visit_f64(self.input.read_f64::<LittleEndian>()?)
167    }
168
169    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value>
170    where
171        V: Visitor<'de>,
172    {
173        let string = self.parse_string()?;
174        visitor.visit_borrowed_str(string)
175    }
176
177    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value>
178    where
179        V: Visitor<'de>,
180    {
181        let string = self.parse_string()?;
182        visitor.visit_borrowed_str(string)
183    }
184
185    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
186    where
187        V: Visitor<'de>,
188    {
189        visitor.visit_seq(self)
190    }
191
192    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
193    where
194        V: Visitor<'de>,
195    {
196        visitor.visit_seq(self)
197    }
198
199    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
200    where
201        V: Visitor<'de>,
202    {
203        if self.input.is_empty() {
204            visitor.visit_none()
205        } else {
206            visitor.visit_some(self)
207        }
208    }
209
210    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value>
211    where
212        V: Visitor<'de>,
213    {
214        visitor.visit_unit()
215    }
216
217    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
218    where
219        V: Visitor<'de>,
220    {
221        visitor.visit_unit()
222    }
223
224    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value>
225    where
226        V: Visitor<'de>,
227    {
228        visitor.visit_newtype_struct(self)
229    }
230
231    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
232    where
233        V: Visitor<'de>,
234    {
235        visitor.visit_seq(self)
236    }
237
238    fn deserialize_tuple<V>(self, len: usize, visitor: V) -> Result<V::Value>
239    where
240        V: Visitor<'de>,
241    {
242        visitor.visit_seq(FixedSizeSeqAccess { de: self, len })
243    }
244
245    fn deserialize_tuple_struct<V>(
246        self,
247        _name: &'static str,
248        len: usize,
249        visitor: V,
250    ) -> Result<V::Value>
251    where
252        V: Visitor<'de>,
253    {
254        visitor.visit_seq(FixedSizeSeqAccess { de: self, len })
255    }
256
257    fn deserialize_map<V>(self, _visitor: V) -> Result<V::Value>
258    where
259        V: Visitor<'de>,
260    {
261        // Maps are not supported
262        Err(Error::NotSupported)
263    }
264
265    fn deserialize_struct<V>(
266        self,
267        _name: &'static str,
268        fields: &'static [&'static str],
269        visitor: V,
270    ) -> Result<V::Value>
271    where
272        V: Visitor<'de>,
273    {
274        visitor.visit_seq(FixedSizeSeqAccess {
275            de: self,
276            len: fields.len(),
277        })
278    }
279
280    fn deserialize_enum<V>(
281        self,
282        _name: &'static str,
283        _variants: &'static [&'static str],
284        _visitor: V,
285    ) -> Result<V::Value>
286    where
287        V: Visitor<'de>,
288    {
289        // Use custom implementation instead.
290        Err(Error::NotSupported)
291    }
292
293    fn deserialize_identifier<V>(self, _visitor: V) -> Result<V::Value>
294    where
295        V: Visitor<'de>,
296    {
297        Err(Error::NotSupported)
298    }
299
300    fn deserialize_ignored_any<V>(self, _visitor: V) -> Result<V::Value>
301    where
302        V: Visitor<'de>,
303    {
304        Err(Error::NotSupported)
305    }
306}
307
308impl<'de> SeqAccess<'de> for Deserializer<'de> {
309    type Error = Error;
310    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
311    where
312        T: DeserializeSeed<'de>,
313    {
314        let expect_next = self.parse_bool()?;
315        if expect_next {
316            seed.deserialize(self).map(Some)
317        } else {
318            Ok(None)
319        }
320    }
321}
322
323/// A SeqAccess of fixed size without delimiters.
324struct FixedSizeSeqAccess<'a, 'de: 'a> {
325    de: &'a mut Deserializer<'de>,
326    len: usize,
327}
328
329impl<'de, 'a> SeqAccess<'de> for FixedSizeSeqAccess<'a, 'de> {
330    type Error = Error;
331
332    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
333    where
334        T: DeserializeSeed<'de>,
335    {
336        if self.len == 0 {
337            Ok(None)
338        } else {
339            self.len -= 1;
340            seed.deserialize(&mut *self.de).map(Some)
341        }
342    }
343}
344
#[cfg(test)]
mod test {
    use super::*;
    use serde_derive::Deserialize;

    #[test]
    fn test_empty_packet_read() {
        let mut empty_packet: &[u8] = &[3, 0, 10];
        assert_eq!(empty_packet.read_packet().unwrap(), (10, Vec::new()));
    }

    #[test]
    fn test_simple_struct_read() {
        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct SimpleStruct {
            a: u8,
            b: u16,
            c: u32,
            d: bool,
        }
        let mut input: &[u8] = &[
            11, 0, // Length
            10, // PACKET_TYPE
            1, // a
            2, 0, // b
            3, 0, 0, 0, // c
            1, // d
        ];
        let simple_struct = SimpleStruct {
            a: 1,
            b: 2,
            c: 3,
            d: true,
        };
        let (packet_type, buffer) = input.read_packet().unwrap();
        assert_eq!(packet_type, 10);
        assert_eq!(from_bytes::<SimpleStruct>(&buffer).unwrap(), simple_struct);
    }

    #[test]
    fn test_vec_ser() {
        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct VecStruct {
            item: Vec<u8>,
        }
        let mut input: &[u8] = &[
            14, 0, // Length
            0xFF, // Packet type
            1, 0, // boolean, item
            1, 1, //
            1, 2, //
            1, 3, //
            1, 4, //
            0, // False
        ];
        let vec_struct = VecStruct {
            item: vec![0, 1, 2, 3, 4],
        };
        let (packet_type, buffer) = input.read_packet().unwrap();
        assert_eq!(packet_type, 0xFF);
        assert_eq!(from_bytes::<VecStruct>(&buffer).unwrap(), vec_struct);
    }

    #[test]
    fn test_string_read() {
        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct StringStruct<'a> {
            owned: String,
            borrowed: &'a str,
            c: char,
        }
        // NUL-terminated UTF-8 strings; the char is U+00E9 (two UTF-8 bytes).
        let buffer: &[u8] = b"hello\0world\0\xc3\xa9\0";
        let expected = StringStruct {
            owned: "hello".to_string(),
            borrowed: "world",
            c: '\u{e9}',
        };
        assert_eq!(from_bytes::<StringStruct>(buffer).unwrap(), expected);
    }

    #[test]
    fn test_tuple_read() {
        // Tuples are fields back to back, without delimiters.
        assert_eq!(from_bytes::<(u8, u16)>(&[1, 2, 0]).unwrap(), (1, 2));
    }

    #[test]
    fn test_invalid_payloads() {
        // Booleans must be exactly 0 or 1.
        assert!(matches!(from_bytes::<bool>(&[2]), Err(Error::InvalidBool)));
        // Strings need a NUL terminator.
        assert!(matches!(
            from_bytes::<String>(b"abc"),
            Err(Error::EndlessString)
        ));
        // A char must be exactly one code point.
        assert!(matches!(
            from_bytes::<char>(b"ab\0"),
            Err(Error::InvalidChar)
        ));
        // The whole payload must be consumed.
        assert!(matches!(
            from_bytes::<u8>(&[1, 2]),
            Err(Error::TrailingCharacters)
        ));
    }

    mod option_tests {
        use super::*;

        #[derive(Deserialize, Eq, PartialEq, Debug)]
        struct OptionStruct {
            mandatory: u8,
            optional: Option<u8>,
        }

        #[test]
        fn test_some_ser() {
            let mut input: &[u8] = &[
                5, 0, // Length
                3, // PACKET_TYPE
                10, // mandatory
                10, // optional
            ];
            let some_struct = OptionStruct {
                mandatory: 10,
                optional: Some(10),
            };
            let (packet_type, buffer) = input.read_packet().unwrap();
            assert_eq!(packet_type, 3);
            assert_eq!(from_bytes::<OptionStruct>(&buffer).unwrap(), some_struct);
        }

        #[test]
        fn test_none_ser() {
            let mut input: &[u8] = &[
                4, 0, // Length
                3, // PACKET_TYPE
                10, // mandatory
            ];
            let none_struct = OptionStruct {
                mandatory: 10,
                optional: None,
            };
            let (packet_type, buffer) = input.read_packet().unwrap();
            assert_eq!(packet_type, 3);
            assert_eq!(from_bytes::<OptionStruct>(&buffer).unwrap(), none_struct);
        }
    }
}