wasm_msgpack/ext/mod.rs

1#[cfg(feature = "timestamp")]
2pub mod timestamp;
3
4use crate::encode::{Binary, Error, SerializeIntoSlice};
5#[allow(unused_imports)]
6use crate::marker::Marker;
7#[allow(unused_imports)]
8use byteorder::{BigEndian, ByteOrder};
9use core::{convert::TryInto, fmt::Display, marker::PhantomData};
10use serde::{ser::SerializeStruct, Deserialize, Serialize};
11
12#[repr(transparent)]
13#[derive(Copy, Clone, PartialEq, Eq, PartialOrd, Ord, Serialize, Deserialize)]
14#[serde(transparent)]
15struct ExtType(i8);
16
17impl core::fmt::Debug for ExtType {
18    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
19        f.debug_tuple("ExtType").field(&self.0).finish()
20    }
21}
22
23impl Display for ExtType {
24    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
25        if self.0 == -1 {
26            f.write_str("-1 (Timestamp)")
27        } else {
28            write!(f, "{}", self.0)
29        }
30    }
31}
32
33#[derive(PartialEq, Eq)]
34pub struct Ext<'a> {
35    typ: ExtType,
36    data: Binary<'a>,
37}
38
39impl<'a> core::fmt::Debug for Ext<'a> {
40    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
41        f.debug_struct("Ext").field("typ", &self.typ).field("data", &self.data).finish()
42    }
43}
44
45impl<'a> Ext<'a> {
46    pub const fn new_from_binary(typ: i8, data: Binary<'a>) -> Self {
47        Ext { typ: ExtType(typ), data }
48    }
49    pub const fn new(typ: i8, data: &'a [u8]) -> Self {
50        Ext {
51            typ: ExtType(typ),
52            data: Binary::new(data),
53        }
54    }
55    #[inline(always)]
56    pub const fn get_type(&self) -> i8 {
57        self.typ.0
58    }
59    #[inline(always)]
60    pub const fn get_data(&self) -> &Binary<'a> {
61        &self.data
62    }
63}
64
65#[inline]
66pub(crate) fn get_ext_start(data_len: usize) -> Result<(Marker, usize), Error> {
67    let (marker, header_len) = match data_len {
68        #[cfg(feature = "fixext")]
69        1 | 2 | 4 | 8 | 16 => {
70            let header_len = 2;
71            let marker = match data_len {
72                1 => Marker::FixExt1,
73                2 => Marker::FixExt2,
74                4 => Marker::FixExt4,
75                8 => Marker::FixExt8,
76                16 => Marker::FixExt16,
77                _ => unreachable!(),
78            };
79            (marker, header_len)
80        }
81        #[cfg(feature = "ext8")]
82        0..=0xff => (Marker::Ext8, 3),
83        #[cfg(feature = "ext16")]
84        0x100..=0xffff => (Marker::Ext16, 4),
85        #[cfg(feature = "ext32")]
86        0x1_0000..=0xffff_ffff => (Marker::Ext32, 6),
87        _ => return Err(Error::OutOfBounds),
88    };
89    Ok((marker, header_len))
90}
91
92pub(crate) fn read_ext_len<B: zerocopy::ByteSlice>(buf: B) -> Result<(usize, usize), crate::decode::Error> {
93    if buf.len() < 2 {
94        return Err(crate::decode::Error::EndOfBuffer(Marker::Ext8));
95    }
96    let marker: Marker = buf[0].try_into().unwrap();
97    let (header_len, data_len) = match marker {
98        #[cfg(feature = "fixext")]
99        Marker::FixExt1 => (2, 1),
100        #[cfg(feature = "fixext")]
101        Marker::FixExt2 => (2, 2),
102        #[cfg(feature = "fixext")]
103        Marker::FixExt4 => (2, 4),
104        #[cfg(feature = "fixext")]
105        Marker::FixExt8 => (2, 8),
106        #[cfg(feature = "fixext")]
107        Marker::FixExt16 => (2, 16),
108        #[cfg(feature = "ext8")]
109        Marker::Ext8 => (3, buf[1] as usize),
110        #[cfg(feature = "ext16")]
111        Marker::Ext16 => {
112            if buf.len() < 4 {
113                return Err(crate::decode::Error::EndOfBuffer(Marker::Ext16));
114            }
115            (4, BigEndian::read_u16(&buf[1..3]) as usize)
116        }
117        #[cfg(feature = "ext32")]
118        Marker::Ext32 => {
119            if buf.len() < 6 {
120                return Err(crate::decode::Error::EndOfBuffer(Marker::Ext32));
121            }
122            (6, BigEndian::read_u32(&buf[1..5]) as usize)
123        }
124        _ => return Err(crate::decode::Error::InvalidType),
125    };
126    // let typ = buf[header_len - 1] as i8;
127    if buf.len() >= header_len + data_len {
128        Ok((header_len, data_len))
129    } else {
130        Err(crate::decode::Error::EndOfBuffer(Marker::Ext8))
131    }
132}
133
134pub fn serialize_ext(value: &Ext<'_>, buf: &mut [u8]) -> Result<usize, Error> {
135    let typ = value.get_type();
136    let data = value.get_data();
137
138    let (marker, header_len) = get_ext_start(data.len())?;
139    if buf.len() < data.len() + header_len {
140        return Err(Error::EndOfBuffer);
141    }
142    buf[0] = marker.to_u8();
143    if header_len > 2 {
144        #[cfg(all(feature = "ext8", not(any(feature = "ext16", feature = "ext32"))))]
145        {
146            buf[1] = data.len() as u8;
147        }
148        #[cfg(any(feature = "ext16", feature = "ext32"))]
149        {
150            BigEndian::write_uint(&mut buf[1..], data.len() as u64, header_len - 2);
151        }
152    }
153    buf[header_len - 1] = typ as u8;
154    buf[header_len..data.len() + header_len].clone_from_slice(data);
155    Ok(data.len() + header_len)
156}
157
158pub fn try_deserialize_ext(buf: &[u8]) -> Result<Ext<'_>, crate::decode::Error> {
159    if buf.len() < 3 {
160        return Err(crate::decode::Error::EndOfBuffer(Marker::Ext8));
161    }
162    let (header_len, data_len) = read_ext_len(buf)?;
163    if buf.len() < header_len + data_len {
164        return Err(crate::decode::Error::EndOfBuffer(Marker::Ext8));
165    }
166    let typ = buf[header_len - 1] as i8;
167    return Ok(Ext::new(typ, &buf[header_len..header_len + data_len]));
168}
169
170impl<'a> SerializeIntoSlice for &Ext<'a> {
171    fn write_into_slice(&self, buf: &mut [u8]) -> Result<usize, Error> {
172        serialize_ext(self, buf)
173    }
174}
175
/// Struct name used when (de)serializing an [`Ext`] through serde.
pub(crate) const TYPE_NAME: &str = "$Ext";
/// Field name carrying the ext type tag.
pub(crate) const FIELD_TYPE_NAME: &str = "type";
/// Field name carrying the ext payload bytes.
pub(crate) const FIELD_DATA_NAME: &str = "data";
179
#[cfg(feature = "serde")]
impl<'a> ::serde::ser::Serialize for Ext<'a> {
    /// Serializes as a two-field struct named `TYPE_NAME`: first the tag under
    /// `FIELD_TYPE_NAME`, then the payload under `FIELD_DATA_NAME`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let Ext { typ, data } = self;
        let mut fields = serializer.serialize_struct(TYPE_NAME, 2)?;
        fields.serialize_field(FIELD_TYPE_NAME, typ)?;
        fields.serialize_field(FIELD_DATA_NAME, data)?;
        fields.end()
    }
}
192
#[cfg(feature = "serde")]
impl<'de: 'a, 'a> ::serde::de::Deserialize<'de> for Ext<'a> {
    /// Deserializes an [`Ext`] from either a `(type, data)` sequence or a map/struct
    /// with `type` and `data` entries.
    ///
    /// Unknown map keys are skipped together with their values; a missing `type` or
    /// `data` entry is an error. If a key repeats, the last occurrence wins.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: ::serde::de::Deserializer<'de>,
    {
        struct ExtVisitor<'a>(PhantomData<&'a ()>);

        impl<'de: 'a, 'a> ::serde::de::Visitor<'de> for ExtVisitor<'a> {
            type Value = Ext<'a>;

            fn expecting(&self, formatter: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
                formatter.write_str("a MsgPack ext data")
            }

            fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                // Called from the MsgPack deserializer when it sees an ext tag or an
                // array tag: elements are the type tag followed by the payload.
                let typ: Option<ExtType> = seq.next_element()?;
                let data: Option<Binary> = seq.next_element()?;
                match (typ, data) {
                    (Some(typ), Some(data)) => Ok(Ext::new_from_binary(typ.0, data)),
                    (Some(_), None) => Err(::serde::de::Error::custom("ext data not found")),
                    _ => Err(::serde::de::Error::custom("ext type field not found")),
                }
            }

            fn visit_map<V>(self, mut map: V) -> Result<Self::Value, V::Error>
            where
                V: ::serde::de::MapAccess<'de>,
            {
                // Called from other deserializers, or from the MsgPack deserializer
                // when it sees a map tag.

                /// Key of a map entry.
                enum Field {
                    Type,
                    Data,
                    /// Any other key; its value is skipped.
                    Other,
                }

                impl<'de> ::serde::de::Deserialize<'de> for Field {
                    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
                    where
                        D: serde::Deserializer<'de>,
                    {
                        struct FieldVisitor;

                        impl<'de> ::serde::de::Visitor<'de> for FieldVisitor {
                            type Value = Field;

                            fn expecting(&self, formatter: &mut core::fmt::Formatter) -> core::fmt::Result {
                                formatter.write_str("`type` or `data`")
                            }

                            fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
                            where
                                E: serde::de::Error,
                            {
                                Ok(match v {
                                    "type" => Field::Type,
                                    "data" => Field::Data,
                                    _ => Field::Other,
                                })
                            }

                            fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
                            where
                                E: serde::de::Error,
                            {
                                Ok(match v {
                                    b"type" => Field::Type,
                                    b"data" => Field::Data,
                                    _ => Field::Other,
                                })
                            }
                        }

                        deserializer.deserialize_identifier(FieldVisitor)
                    }
                }

                let mut typ = None;
                let mut data = None;

                // Unknown keys decode to `Field::Other` instead of an error, so any error
                // from `next_key` is a genuine decoding failure and is propagated.
                while let Some(key) = map.next_key::<Field>()? {
                    match key {
                        Field::Type => typ = Some(map.next_value::<ExtType>()?),
                        Field::Data => data = Some(map.next_value::<Binary>()?),
                        Field::Other => {
                            // Consume and discard a value of any shape.
                            map.next_value::<::serde::de::IgnoredAny>()?;
                        }
                    }
                }

                match (typ, data) {
                    (Some(typ), Some(data)) => Ok(Ext::new_from_binary(typ.0, data)),
                    (Some(_), None) => Err(::serde::de::Error::custom("ext data not found")),
                    _ => Err(::serde::de::Error::custom("ext type field not found")),
                }
            }
        }

        static FIELDS: [&str; 2] = [FIELD_TYPE_NAME, FIELD_DATA_NAME];
        deserializer.deserialize_struct(TYPE_NAME, &FIELDS, ExtVisitor(PhantomData))
    }
}