rustfs_rsc/xml/
de.rs

1use bytes::Bytes;
2use serde::de::IntoDeserializer;
3use std::io::{BufRead, BufReader, Read};
4
5use crate::utils::trim_bytes;
6
7use super::error::{Error, Result};
8
9/// A convenience method for deserialize some object from a reader.
10pub fn from_reader<'de, R: Read, T: serde::de::Deserialize<'de>>(reader: R) -> Result<T> {
11    let mut de =
12        serde_xml_rs::Deserializer::new_from_reader(reader).non_contiguous_seq_elements(true);
13    T::deserialize(&mut de).map_err(Into::into)
14}
15
16/// A convenience method for deserialize some object from a str.
17pub fn from_str<'de, T: serde::de::Deserialize<'de>>(s: &'de str) -> Result<T> {
18    from_reader(s.as_bytes())
19}
20
21/// A convenience method for deserialize some object from a string.
22pub fn from_string<'de, T: serde::de::Deserialize<'de>>(s: String) -> Result<T> {
23    from_reader(s.as_bytes())
24}
25
26/// A convenience method for deserialize some object from a [Bytes].
27pub fn from_bytes<'de, T: serde::de::Deserialize<'de>>(s: &'de Bytes) -> Result<T> {
28    from_reader(s.as_ref())
29}
30
/// Generates a `serde::de::Deserializer` method that parses the text content of
/// the innermost open element with `str::parse` and hands the parsed value to
/// the matching visitor method.
///
/// The element is closed (via `close_tag`) only when the visitor succeeds; a
/// visitor error is returned as-is without touching the input further.
macro_rules! deserialize_type {
    ($deserialize:ident, $visit:ident) => {
        fn $deserialize<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
        where
            V: serde::de::Visitor<'de>,
        {
            let outcome = visitor.$visit(self.top_tag()?.content().parse()?);
            outcome.and_then(|value| self.close_tag().map(|()| value))
        }
    };
}
46
/// Expands to `Err(Error::Custom { .. })` carrying an owned copy of the given
/// message, so it can be used directly as a return value.
macro_rules! custom_error {
    ($info:expr) => {
        Err(Error::Custom {
            field: String::from($info),
        })
    };
}
54
/// Kind of a single piece of markup read from the XML source.
///
/// This is a fieldless enum, so it is `Copy` and `Eq`: it can be compared and
/// passed around without cloning.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EventType {
    /// Prolog markup such as the `<?xml ...?>` declaration; skipped when looking for tags.
    Statement,
    /// A self-closing element, `<name/>`.
    EmptyTag,
    /// An opening tag, `<name ...>`.
    Tag,
    /// A closing tag, `</name>`.
    TagClose,
    /// A comment, `<!-- ... -->`; skipped when looking for tags.
    Comment,
}
63
64#[derive(Debug, Clone)]
65struct Event {
66    pub type_: EventType,
67    pub value: Vec<u8>,
68    pub content: Vec<u8>,
69}
70
71impl Event {
72    pub fn new(type_: EventType, value: Vec<u8>, content: Vec<u8>) -> Self {
73        Self {
74            type_,
75            value,
76            content,
77        }
78    }
79
80    #[inline]
81    fn is_tag(&self) -> bool {
82        self.type_ == EventType::Tag
83    }
84
85    #[inline]
86    fn tag<'d>(&'d self) -> std::borrow::Cow<'d, str> {
87        String::from_utf8_lossy(&self.value)
88    }
89
90    #[inline]
91    fn content(&self) -> std::borrow::Cow<'_, str> {
92        String::from_utf8_lossy(&self.content)
93    }
94}
95
96struct Deserializer<R: Read> {
97    source: BufReader<R>,
98    tags: Vec<Event>,
99    next_tag_cache: Option<Event>,
100    init: bool,
101}
102
103impl<R: Read> Deserializer<R> {
104    pub fn new(r: R) -> Self {
105        Self {
106            source: BufReader::new(r),
107            tags: vec![],
108            next_tag_cache: None,
109            init: false,
110        }
111    }
112
113    #[inline]
114    fn top_tag(&self) -> Result<&Event> {
115        if let Some(tag) = self.tags.last() {
116            Ok(tag)
117        } else {
118            custom_error!("error tag")
119        }
120    }
121
122    fn next_tag(&mut self) -> Result<Event> {
123        let tag = self.next_tag_cache.take();
124        if let Some(tag) = tag {
125            return Ok(tag);
126        }
127        loop {
128            let event = self.next_event()?;
129            match event.type_ {
130                EventType::EmptyTag | EventType::TagClose | EventType::Tag => return Ok(event),
131                _ => continue,
132            }
133        }
134    }
135
136    fn next_tag_ref(&mut self) -> Result<&Event> {
137        if self.next_tag_cache.is_none() {
138            self.next_tag_cache = Some(self.next_tag()?);
139        }
140        Ok(unsafe { self.next_tag_cache.as_ref().unwrap_unchecked() })
141    }
142
143    fn close_tag(&mut self) -> Result<()> {
144        let next_tag = self.next_tag()?;
145        let top_tag = self.top_tag()?;
146        if !next_tag.is_tag() {
147            if top_tag.value == next_tag.value {
148                self.tags.pop();
149            } else {
150                return Err(Error::UnexpectedToken {
151                    token: top_tag.tag().to_string(),
152                    found: next_tag.tag().to_string(),
153                });
154            }
155        } else {
156            self.tags.push(next_tag);
157            self.close_tag()?;
158            self.close_tag()?;
159        }
160        Ok(())
161    }
162
163    fn next_event(&mut self) -> Result<Event> {
164        if !self.init {
165            let mut buf = vec![];
166            self.source.read_until(b'<', &mut buf)?;
167            self.init = true;
168        }
169        let mut buf = vec![];
170        self.source.read_until(b'>', &mut buf)?;
171
172        if buf.len() == 0 {
173            return custom_error!("Incorrect XML syntax");
174        }
175
176        let data = if buf.ends_with(b"/>") {
177            (EventType::EmptyTag, &buf[..buf.len() - 2])
178        } else if buf.starts_with(b"/") {
179            (EventType::TagClose, &buf[1..buf.len() - 1])
180        } else if buf.starts_with(b"!--") {
181            (EventType::Comment, &buf[..buf.len() - 1])
182        } else if buf.starts_with(b"?xml") {
183            (EventType::Statement, &buf[..buf.len() - 1])
184        } else {
185            let mut i = 0;
186            for b in buf.iter() {
187                i += 1;
188                if *b == b' ' {
189                    break;
190                }
191            }
192            (EventType::Tag, &buf[..i - 1])
193        };
194        let mut content = vec![];
195        self.source.read_until(b'<', &mut content)?;
196        let i = if content.len() > 1 {
197            content.len() - 1
198        } else {
199            0
200        };
201        let content = trim_bytes(&content[..i]).to_owned();
202        let event = Event::new(data.0, data.1.to_owned(), content);
203        return Ok(event);
204    }
205}
206
207impl<'de, 'a, R: Read> serde::de::Deserializer<'de> for &'a mut Deserializer<R> {
208    type Error = Error;
209
210    fn deserialize_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
211    where
212        V: serde::de::Visitor<'de>,
213    {
214        self.close_tag()?;
215        visitor.visit_unit()
216    }
217
218    deserialize_type!(deserialize_bool, visit_bool);
219    deserialize_type!(deserialize_i8, visit_i8);
220    deserialize_type!(deserialize_i16, visit_i16);
221    deserialize_type!(deserialize_i32, visit_i32);
222    deserialize_type!(deserialize_i64, visit_i64);
223    deserialize_type!(deserialize_u8, visit_u8);
224    deserialize_type!(deserialize_u16, visit_u16);
225    deserialize_type!(deserialize_u32, visit_u32);
226    deserialize_type!(deserialize_u64, visit_u64);
227    deserialize_type!(deserialize_f32, visit_f32);
228    deserialize_type!(deserialize_f64, visit_f64);
229    deserialize_type!(deserialize_string, visit_string);
230
231    fn deserialize_str<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
232    where
233        V: serde::de::Visitor<'de>,
234    {
235        let tag = self.top_tag()?;
236        let result = visitor.visit_str(&tag.content());
237        if result.is_ok() {
238            self.close_tag()?;
239        }
240        result
241    }
242
243    fn deserialize_bytes<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
244    where
245        V: serde::de::Visitor<'de>,
246    {
247        let tag = self.top_tag()?;
248        let result = visitor.visit_bytes(&tag.content);
249        if result.is_ok() {
250            self.close_tag()?;
251        }
252        result
253    }
254
255    fn deserialize_byte_buf<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
256    where
257        V: serde::de::Visitor<'de>,
258    {
259        let tag = self.top_tag()?;
260        let result = visitor.visit_byte_buf(tag.content.clone());
261        if result.is_ok() {
262            self.close_tag()?;
263        }
264        result
265    }
266
267    serde::forward_to_deserialize_any! {
268        char
269        map
270        unit
271        unit_struct
272        newtype_struct
273        tuple
274        tuple_struct
275        identifier
276    }
277
278    fn deserialize_option<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
279    where
280        V: serde::de::Visitor<'de>,
281    {
282        visitor.visit_some(self)
283    }
284
285    fn deserialize_seq<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
286    where
287        V: serde::de::Visitor<'de>,
288    {
289        let tag = self.top_tag()?.clone();
290        let s = SeqAccess {
291            de: self,
292            tag,
293            is_over: false,
294        };
295        visitor.visit_seq(s)
296    }
297
298    fn deserialize_struct<V>(
299        self,
300        name: &'static str,
301        _: &'static [&'static str],
302        visitor: V,
303    ) -> std::result::Result<V::Value, Self::Error>
304    where
305        V: serde::de::Visitor<'de>,
306    {
307        if !self.init {
308            loop {
309                let event = self.next_tag()?;
310                if event.type_ == EventType::Tag && event.value == name.as_bytes() {
311                    self.tags.push(event);
312                    break;
313                }
314            }
315        }
316        let map_value = visitor.visit_map(self)?;
317        Ok(map_value)
318    }
319
320    fn deserialize_enum<V>(
321        self,
322        _: &'static str,
323        _: &'static [&'static str],
324        visitor: V,
325    ) -> std::result::Result<V::Value, Self::Error>
326    where
327        V: serde::de::Visitor<'de>,
328    {
329        let tag = self.top_tag()?;
330        let result = visitor.visit_enum(tag.content().into_deserializer());
331        if result.is_ok() {
332            self.close_tag()?;
333        }
334        result
335    }
336
337    fn deserialize_ignored_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
338    where
339        V: serde::de::Visitor<'de>,
340    {
341        self.close_tag()?;
342        visitor.visit_unit()
343    }
344}
345
346impl<'de, 'a, R: Read> serde::de::MapAccess<'de> for &'a mut Deserializer<R> {
347    type Error = Error;
348
349    fn next_key_seed<K>(&mut self, seed: K) -> std::result::Result<Option<K::Value>, Self::Error>
350    where
351        K: serde::de::DeserializeSeed<'de>,
352    {
353        loop {
354            let event = self.next_tag_ref()?;
355            if event.is_tag() {
356                let event = self.next_tag()?;
357                let cs = event.clone();
358                self.tags.push(event);
359                return seed.deserialize(cs.tag().into_deserializer()).map(Some);
360            } else {
361                self.close_tag()?;
362                return Ok(None);
363            }
364        }
365    }
366
367    fn next_value_seed<V>(&mut self, seed: V) -> std::result::Result<V::Value, Self::Error>
368    where
369        V: serde::de::DeserializeSeed<'de>,
370    {
371        seed.deserialize(&mut **self)
372    }
373}
374
375struct SeqAccess<'a, R: Read> {
376    de: &'a mut Deserializer<R>,
377    tag: Event,
378    is_over: bool,
379}
380
381impl<'de, 'a, R: Read> serde::de::SeqAccess<'de> for SeqAccess<'a, R> {
382    type Error = Error;
383
384    fn next_element_seed<T>(
385        &mut self,
386        seed: T,
387    ) -> std::result::Result<Option<T::Value>, Self::Error>
388    where
389        T: serde::de::DeserializeSeed<'de>,
390    {
391        if self.is_over {
392            return Ok(None);
393        };
394        let result = seed.deserialize(&mut *self.de).map(Some);
395        let next_tag = self.de.next_tag_ref()?;
396        if next_tag.is_tag() && next_tag.value == self.tag.value {
397            let next_tag = self.de.next_tag()?;
398            self.de.tags.push(next_tag);
399            self.is_over = false;
400        } else {
401            self.is_over = true;
402        }
403        result
404    }
405}