// wasm_msgpack/decode/serde/mod.rs

1use crate::marker::Marker;
2use core::fmt;
3use paste::paste;
4use serde::de::{self, Visitor};
5
6use self::{enum_::UnitVariantAccess, map::MapAccess, seq::SeqAccess};
7
8mod enum_;
9mod map;
10mod seq;
11
12use super::Error;
13
14type Result<T> = core::result::Result<T, Error>;
15
#[cfg(test)]
fn print_debug<T>(prefix: &str, function_name: &str, de: &Deserializer) {
    #[cfg(not(feature = "std"))]
    extern crate std;
    #[cfg(not(feature = "std"))]
    use std::println;

    // Trace line: call name, visitor type, and up to the next ten unread bytes.
    let window_end = core::cmp::min(de.slice.len(), de.index + 10);
    let upcoming = &de.slice[de.index..window_end];
    let type_name = core::any::type_name::<T>();
    println!("{}{}<{}> ({:02x?})", prefix, function_name, type_name, upcoming);
}
30
#[cfg(test)]
fn print_debug_value<T, V: core::fmt::Debug>(function_name: &str, de: &Deserializer, value: &V) {
    #[cfg(not(feature = "std"))]
    extern crate std;
    #[cfg(not(feature = "std"))]
    use std::println;

    // Trace line after a decode: call name, decoded type, value, and up to ten bytes that follow.
    let window_end = core::cmp::min(de.slice.len(), de.index + 10);
    let upcoming = &de.slice[de.index..window_end];
    let type_name = core::any::type_name::<T>();
    println!("{}<{}> => {:?}   ({:02x?})", function_name, type_name, value, upcoming);
}
45
46#[cfg(not(test))]
47#[allow(clippy::missing_const_for_fn)]
48fn print_debug<T>(_prefix: &str, _function_name: &str, _de: &Deserializer) {}
49#[cfg(not(test))]
50#[allow(clippy::missing_const_for_fn)]
51fn print_debug_value<T, V: core::fmt::Debug>(_function_name: &str, _de: &Deserializer, _value: &V) {}
52
53pub(crate) struct Deserializer<'b> {
54    slice: &'b [u8],
55    index: usize,
56    state: State,
57}
58
59impl<'a> Deserializer<'a> {
60    pub const fn new(slice: &'a [u8]) -> Deserializer<'_> {
61        Deserializer {
62            slice,
63            index: 0,
64            state: State::Normal,
65        }
66    }
67
68    fn eat_byte(&mut self) {
69        self.index += 1;
70    }
71
72    fn peek(&mut self) -> Option<Marker> {
73        Some(Marker::from_u8(*self.slice.get(self.index)?))
74    }
75}
76
// Generates `deserialize_<ty>` trait methods that simply forward to `deserialize_<into>`.
//
// `$into` is the wider primitive that actually decodes the value, so the visitor receives
// the wide value (e.g. `visit_i64` when an `i16` was requested) and must narrow it itself.
// The outer `paste!` already expands every `[<...>]` in the generated body, so no nested
// `paste!` is needed. A trailing comma in the type list is accepted.
macro_rules! deserialize_primitives {
    ($into:ident, $($ty:ident),* $(,)?) => {
        $(paste! {
            fn [<deserialize_ $ty>]<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
                self.[<deserialize_ $into>](visitor)
            }
        })*
    };
}
85
/// Decoding mode of a `Deserializer`.
enum State {
    /// Ordinary values: every read starts at a marker byte.
    Normal,
    /// Inside an ext/timestamp value whose payload is this many raw bytes.
    ///
    /// `deserialize_i8` reads the ext type byte verbatim in this state, and `deserialize_bytes`
    /// then reads the payload and returns to `Normal`.
    ///
    /// The variant only exists with the `ext` feature. Every `match` on `State` handles it
    /// behind `#[cfg(feature = "ext")]`, so an unconditional variant would make those matches
    /// non-exhaustive (a compile error) when the feature is disabled.
    #[cfg(feature = "ext")]
    Ext(usize),
}
90
91impl<'a, 'de> de::Deserializer<'de> for &'a mut Deserializer<'de> {
92    type Error = Error;
93
94    deserialize_primitives!(i64, i16, i32);
95    deserialize_primitives!(u64, u8, u16, u32);
96    deserialize_primitives!(f64, f32);
97
98    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value>
99    where
100        V: Visitor<'de>,
101    {
102        print_debug::<V>("Deserializer::deserialize_", "i8", self);
103        let (value, len) = match self.state {
104            State::Normal => super::read_i8(&self.slice[self.index..])?,
105            // read the ext type as raw byte and not encoded as a normal i8
106            #[cfg(feature = "ext")]
107            State::Ext(_) => (self.slice[self.index] as i8, 1),
108        };
109        self.index += len;
110        print_debug_value::<i8, i8>("Deserializer::deserialize_i8", self, &value);
111        visitor.visit_i8(value)
112    }
113
114    fn deserialize_str<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
115        print_debug::<V>("Deserializer::deserialize_", "str", self);
116        let (s, len) = super::read_str(&self.slice[self.index..])?;
117        self.index += len;
118        visitor.visit_borrowed_str(s)
119    }
120
121    fn deserialize_bytes<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
122        print_debug::<V>("Deserializer::deserialize_", "bytes", self);
123        let (value, len) = match self.state {
124            State::Normal => super::read_bin(&self.slice[self.index..])?,
125            // read the ext type as raw byte and not encoded as a normal i8
126            #[cfg(feature = "ext")]
127            State::Ext(len) => {
128                self.state = State::Normal;
129                (&self.slice[self.index..self.index + len], len)
130            }
131        };
132        self.index += len;
133        visitor.visit_borrowed_bytes(value)
134    }
135
136    fn deserialize_byte_buf<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
137        print_debug::<V>("Deserializer::deserialize_", "byte_buf", self);
138        self.deserialize_bytes(visitor)
139    }
140
141    fn deserialize_option<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
142        print_debug::<V>("Deserializer::deserialize_", "option", self);
143        let marker = self.peek().ok_or(Error::EndOfBuffer(Marker::Reserved))?;
144        match marker {
145            Marker::Null => {
146                self.eat_byte();
147                visitor.visit_none()
148            }
149            _ => visitor.visit_some(self),
150        }
151    }
152
153    fn deserialize_seq<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
154        print_debug::<V>("Deserializer::deserialize_", "seq", self);
155        let (len, header_len) = crate::decode::read_array_len(&self.slice[self.index..])?;
156        self.index += header_len;
157        visitor.visit_seq(SeqAccess::new(self, len))
158    }
159
160    fn deserialize_tuple<V: Visitor<'de>>(self, _len: usize, visitor: V) -> Result<V::Value> {
161        print_debug::<V>("Deserializer::deserialize_", "tuple", self);
162        self.deserialize_seq(visitor)
163    }
164
165    fn deserialize_tuple_struct<V: Visitor<'de>>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value> {
166        print_debug::<V>("Deserializer::deserialize_", "tuple_struct", self);
167        self.deserialize_seq(visitor)
168    }
169
170    fn deserialize_map<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
171        print_debug::<V>("Deserializer::deserialize_", "map", self);
172        let (len, header_len) = crate::decode::read_map_len(&self.slice[self.index..])?;
173        self.index += header_len;
174        visitor.visit_map(MapAccess::new(self, len))
175    }
176
177    fn deserialize_struct<V: Visitor<'de>>(self, name: &'static str, _fields: &'static [&'static str], visitor: V) -> Result<V::Value> {
178        print_debug::<V>("Deserializer::deserialize_", "struct", self);
179        match name {
180            #[cfg(feature = "ext")]
181            crate::ext::TYPE_NAME | crate::timestamp::TYPE_NAME => {
182                if let Some(marker) = self.peek() {
183                    match marker {
184                        Marker::FixExt1
185                        | Marker::FixExt2
186                        | Marker::FixExt4
187                        | Marker::FixExt8
188                        | Marker::FixExt16
189                        | Marker::Ext8
190                        | Marker::Ext16
191                        | Marker::Ext32 => {
192                            let (header_len, data_len) = crate::ext::read_ext_len(&self.slice[self.index..])?;
193                            self.index += header_len - 1; // move forward minus 1 byte for the ext type (header_len includes the type byte)
194                            self.state = State::Ext(data_len);
195                            visitor.visit_seq(SeqAccess::new(self, 2))
196                        }
197                        _ => Err(Error::InvalidType),
198                    }
199                } else {
200                    Err(Error::EndOfBuffer(Marker::Reserved))
201                }
202            }
203            _ => self.deserialize_map(visitor),
204        }
205    }
206
207    fn deserialize_enum<V: Visitor<'de>>(self, _name: &'static str, _variants: &'static [&'static str], visitor: V) -> Result<V::Value> {
208        print_debug::<V>("Deserializer::deserialize_", "enum", self);
209        visitor.visit_enum(UnitVariantAccess::new(self))
210    }
211
212    fn deserialize_identifier<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
213        print_debug::<V>("Deserializer::deserialize_", "identifier", self);
214        let marker = self.peek().ok_or(Error::EndOfBuffer(Marker::Reserved))?;
215        #[allow(clippy::single_match)]
216        match marker {
217            Marker::FixMap(_) => {
218                let (_len, header_len) = crate::decode::read_map_len(&self.slice[self.index..])?;
219                self.index += header_len;
220            }
221            _ => {}
222        }
223        self.deserialize_str(visitor)
224    }
225
226    /// Unsupported. Can’t parse a value without knowing its expected type.
227    fn deserialize_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
228        let marker = self.peek().ok_or(Error::EndOfBuffer(Marker::Reserved))?;
229        match marker {
230            Marker::FixPos(_) => self.deserialize_u8(visitor),
231            Marker::FixMap(_) => self.deserialize_map(visitor),
232            Marker::Map16 => self.deserialize_map(visitor),
233            Marker::Map32 => self.deserialize_map(visitor),
234            Marker::FixArray(_) => self.deserialize_seq(visitor),
235            Marker::Array16 => self.deserialize_seq(visitor),
236            Marker::Array32 => self.deserialize_seq(visitor),
237            Marker::Str8 => self.deserialize_str(visitor),
238            Marker::Str16 => self.deserialize_str(visitor),
239            Marker::Str32 => self.deserialize_str(visitor),
240            Marker::Bin8 => self.deserialize_bytes(visitor),
241            Marker::Bin16 => self.deserialize_bytes(visitor),
242            Marker::Bin32 => self.deserialize_bytes(visitor),
243            Marker::FixStr(_) => self.deserialize_str(visitor),
244            Marker::F32 => self.deserialize_f32(visitor),
245            Marker::F64 => self.deserialize_f64(visitor),
246            Marker::I16 => self.deserialize_i16(visitor),
247            Marker::I32 => self.deserialize_i32(visitor),
248            Marker::I64 => self.deserialize_i64(visitor),
249            Marker::I8 => self.deserialize_i8(visitor),
250            Marker::U16 => self.deserialize_u16(visitor),
251            Marker::U32 => self.deserialize_u32(visitor),
252            Marker::U64 => self.deserialize_u64(visitor),
253            Marker::U8 => self.deserialize_u8(visitor),
254            Marker::True => self.deserialize_bool(visitor),
255            Marker::False => self.deserialize_bool(visitor),
256            _ => {
257                print_debug::<V>("Deserializer::deserialize_", "any", self);
258                let (_, n) = super::skip_any(&self.slice[self.index..])?;
259                self.index += n;
260                visitor.visit_unit()
261            }
262        }
263    }
264
265    /// Used to throw out fields that we don’t want to keep in our structs.
266    fn deserialize_ignored_any<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
267        print_debug::<V>("Deserializer::deserialize_", "ignored_any", self);
268        self.deserialize_any(visitor)
269    }
270
271    /// Unsupported. Use a more specific deserialize_* method
272    fn deserialize_unit<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
273        print_debug::<V>("Deserializer::deserialize_", "unit", self);
274        let marker = self.peek().ok_or(Error::EndOfBuffer(Marker::Reserved))?;
275        match marker {
276            Marker::Null | Marker::FixArray(0) => {
277                self.eat_byte();
278                visitor.visit_unit()
279            }
280            _ => Err(Error::InvalidType),
281        }
282    }
283
284    /// Unsupported. Use a more specific deserialize_* method
285    fn deserialize_unit_struct<V: Visitor<'de>>(self, _name: &'static str, visitor: V) -> Result<V::Value> {
286        print_debug::<V>("Deserializer::deserialize_", "unit_struct", self);
287        self.deserialize_unit(visitor)
288    }
289
290    fn deserialize_char<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
291        print_debug::<V>("Deserializer::deserialize_", "char", self);
292        //TODO Need to decide how to encode this. Probably as a str?
293        self.deserialize_str(visitor)
294    }
295
296    fn deserialize_newtype_struct<V: Visitor<'de>>(self, _name: &'static str, visitor: V) -> Result<V::Value> {
297        print_debug::<V>("Deserializer::deserialize_", "newtype_struct", self);
298        visitor.visit_newtype_struct(self)
299    }
300
301    fn deserialize_string<V: Visitor<'de>>(self, visitor: V) -> Result<V::Value> {
302        print_debug::<V>("Deserializer::deserialize_", "string", self);
303        self.deserialize_str(visitor)
304    }
305
306    fn deserialize_i64<V>(self, visitor: V) -> core::result::Result<V::Value, Self::Error>
307    where
308        V: Visitor<'de>,
309    {
310        print_debug::<V>("Deserializer::deserialize_", "i64", self);
311        let (value, len) = super::read_i64(&self.slice[self.index..])?;
312        self.index += len;
313        print_debug_value::<i64, i64>("Deserializer::deserialize_i64", self, &value);
314        visitor.visit_i64(value)
315    }
316
317    fn deserialize_u64<V>(self, visitor: V) -> core::result::Result<V::Value, Self::Error>
318    where
319        V: Visitor<'de>,
320    {
321        print_debug::<V>("Deserializer::deserialize_", "u64", self);
322        let (value, len) = super::read_u64(&self.slice[self.index..])?;
323        self.index += len;
324        print_debug_value::<u64, u64>("Deserializer::deserialize_u64", self, &value);
325        visitor.visit_u64(value)
326    }
327
328    fn deserialize_f64<V>(self, visitor: V) -> core::result::Result<V::Value, Self::Error>
329    where
330        V: Visitor<'de>,
331    {
332        print_debug::<V>("Deserializer::deserialize_", "f64", self);
333        let (value, len) = super::read_f64(&self.slice[self.index..])?;
334        self.index += len;
335        print_debug_value::<f64, f64>("Deserializer::deserialize_f64", self, &value);
336        visitor.visit_f64(value)
337    }
338
339    fn deserialize_bool<V>(self, visitor: V) -> core::result::Result<V::Value, Self::Error>
340    where
341        V: Visitor<'de>,
342    {
343        print_debug::<V>("Deserializer::deserialize_", "bool", self);
344        let (value, len) = super::read_bool(&self.slice[self.index..])?;
345        self.index += len;
346        print_debug_value::<bool, bool>("Deserializer::deserialize_bool", self, &value);
347        visitor.visit_bool(value)
348    }
349}
350
351impl ::serde::de::StdError for Error {}
352impl de::Error for Error {
353    #[cfg_attr(not(feature = "custom-error-messages"), allow(unused_variables))]
354    fn custom<T>(msg: T) -> Self
355    where
356        T: fmt::Display,
357    {
358        #[cfg(not(feature = "custom-error-messages"))]
359        {
360            Error::CustomError
361        }
362        #[cfg(all(not(feature = "std"), feature = "custom-error-messages"))]
363        {
364            use core::fmt::Write;
365
366            let mut string = heapless::String::new();
367            write!(string, "{:.512}", msg).unwrap();
368            Error::CustomErrorWithMessage(string)
369        }
370        #[cfg(all(feature = "std", feature = "custom-error-messages"))]
371        {
372            Error::CustomErrorWithMessage(msg.to_string())
373        }
374    }
375}
376
377impl fmt::Display for Error {
378    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
379        #[cfg(feature = "std")]
380        let s;
381        write!(
382            f,
383            "{}",
384            match self {
385                Error::InvalidType => "Unexpected type encountered.",
386                Error::OutOfBounds => "Index out of bounds.",
387                #[cfg(not(feature = "std"))]
388                Error::EndOfBuffer(_) => "End of buffer reached.",
389                #[cfg(feature = "std")]
390                Error::EndOfBuffer(m) => {
391                    s = format!("End of buffer reached: {}", u8::from(*m));
392                    s.as_str()
393                }
394                Error::CustomError => "Did not match deserializer's expected format.",
395                #[cfg(feature = "custom-error-messages")]
396                Error::CustomErrorWithMessage(msg) => msg.as_str(),
397                Error::NotAscii => "String contains non-ascii chars.",
398                Error::InvalidBoolean => "Invalid boolean marker.",
399                Error::InvalidBinType => "Invalid binary marker.",
400                Error::InvalidStringType => "Invalid string marker.",
401                Error::InvalidArrayType => "Invalid array marker.",
402                Error::InvalidMapType => "Invalid map marker.",
403                Error::InvalidNewTypeLength => "Invalid array length for newtype.",
404                Error::InvalidUtf8(_) => "Invalid Utf8.",
405            }
406        )
407    }
408}