// xmlity_quick_xml/de.rs

1use std::ops::Deref;
2
3use quick_xml::{
4    events::{attributes::Attribute, BytesStart, Event},
5    name::QName as QuickName,
6    NsReader,
7};
8
9use xmlity::{
10    de::{self, Error as _, Unexpected, Visitor},
11    Deserialize, ExpandedName, LocalName, QName, XmlNamespace,
12};
13
14use crate::{xml_namespace_from_resolve_result, HasQuickXmlAlternative, OwnedQuickName};
15
16use super::Error;
17
18pub fn from_str<'a, T>(s: &'a str) -> Result<T, Error>
19where
20    T: Deserialize<'a>,
21{
22    let mut deserializer = Deserializer::from(s.as_bytes());
23    T::deserialize(&mut deserializer)
24}
25
26pub enum Peeked<'a> {
27    None,
28    Text,
29    CData,
30    Element {
31        name: QName<'a>,
32        namespace: Option<XmlNamespace<'a>>,
33    },
34}
35
36#[derive(Debug, Clone)]
37pub struct Deserializer<'i> {
38    reader: NsReader<&'i [u8]>,
39    current_depth: i16,
40    peeked_event: Option<Event<'i>>,
41}
42
43impl<'i> From<NsReader<&'i [u8]>> for Deserializer<'i> {
44    fn from(reader: NsReader<&'i [u8]>) -> Self {
45        Self::new(reader)
46    }
47}
48
49impl<'i> From<&'i [u8]> for Deserializer<'i> {
50    fn from(buffer: &'i [u8]) -> Self {
51        Self::new(NsReader::from_reader(buffer))
52    }
53}
54
55impl<'i> Deserializer<'i> {
56    pub fn new(reader: NsReader<&'i [u8]>) -> Self {
57        Self {
58            reader,
59            current_depth: 0,
60            peeked_event: None,
61        }
62    }
63
64    fn read_event(&mut self) -> Result<Option<Event<'i>>, Error> {
65        while let Ok(event) = self.reader.read_event() {
66            match event {
67                Event::Eof => return Ok(None),
68                Event::Text(text) if text.clone().into_inner().trim_ascii().is_empty() => {
69                    continue;
70                }
71                event => return Ok(Some(event)),
72            }
73        }
74
75        Ok(None)
76    }
77
78    fn read_until_element_end(&mut self, name: &QuickName, depth: i16) -> Result<(), Error> {
79        while let Some(event) = self.peek_event() {
80            let correct_name = match event {
81                Event::End(ref e) if e.name() == *name => true,
82                Event::Eof => return Err(Error::Unexpected(Unexpected::Eof)),
83                _ => false,
84            };
85
86            if correct_name && self.current_depth == depth {
87                return Ok(());
88            }
89
90            self.next_event();
91        }
92
93        Err(Error::Unexpected(de::Unexpected::Eof))
94    }
95
96    pub fn peek_event(&mut self) -> Option<&Event<'i>> {
97        if self.peeked_event.is_some() {
98            return self.peeked_event.as_ref();
99        }
100
101        self.peeked_event = self.read_event().ok().flatten();
102        self.peeked_event.as_ref()
103    }
104
105    pub fn next_event(&mut self) -> Option<Event<'i>> {
106        let event = if self.peeked_event.is_some() {
107            self.peeked_event.take()
108        } else {
109            self.read_event().ok().flatten()
110        };
111
112        if matches!(event, Some(Event::End(_))) {
113            self.current_depth -= 1;
114        }
115        if matches!(event, Some(Event::Start(_))) {
116            self.current_depth += 1;
117        }
118
119        event
120    }
121
122    pub fn create_sub_seq_access<'p>(&'p mut self) -> SubSeqAccess<'p, 'i> {
123        SubSeqAccess::Filled {
124            current: Some(self.clone()),
125            parent: self,
126        }
127    }
128
129    pub fn try_deserialize<T, E>(
130        &mut self,
131        closure: impl for<'a> FnOnce(&'a mut Deserializer<'i>) -> Result<T, E>,
132    ) -> Result<T, E> {
133        let mut sub_deserializer = self.clone();
134        let res = closure(&mut sub_deserializer);
135
136        if res.is_ok() {
137            *self = sub_deserializer;
138        }
139        res
140    }
141
142    pub fn expand_name<'a>(&self, qname: QuickName<'a>) -> ExpandedName<'a> {
143        let (resolve_result, _) = self.reader.resolve(qname, false);
144        let namespace = xml_namespace_from_resolve_result(resolve_result).map(|ns| ns.into_owned());
145
146        ExpandedName::new(LocalName::from_quick_xml(qname.local_name()), namespace)
147    }
148
149    pub fn resolve_bytes_start<'a>(&self, bytes_start: &'a BytesStart<'a>) -> ExpandedName<'a> {
150        self.expand_name(bytes_start.name())
151    }
152
153    pub fn resolve_attribute<'a>(&self, attribute: &'a Attribute<'a>) -> ExpandedName<'a> {
154        self.expand_name(attribute.key)
155    }
156}
157
158pub struct ElementAccess<'a, 'r> {
159    deserializer: Option<&'a mut Deserializer<'r>>,
160    attribute_index: usize,
161    bytes_start: BytesStart<'r>,
162    start_depth: i16,
163    empty: bool,
164}
165
166impl Drop for ElementAccess<'_, '_> {
167    fn drop(&mut self) {
168        self.try_end().ok();
169    }
170}
171
172impl<'r> ElementAccess<'_, 'r> {
173    fn deserializer(&self) -> &Deserializer<'r> {
174        self.deserializer
175            .as_ref()
176            .expect("Should not be called after ElementAccess has been consumed")
177    }
178
179    fn try_end(&mut self) -> Result<(), Error> {
180        if self.empty {
181            return Ok(());
182        }
183
184        if let Some(deserializer) = self.deserializer.as_mut() {
185            deserializer.read_until_element_end(&self.bytes_start.name(), self.start_depth)?;
186        }
187        Ok(())
188    }
189}
190
191pub struct AttributeAccess<'a> {
192    name: ExpandedName<'a>,
193    value: String,
194}
195
196impl<'a> de::AttributeAccess<'a> for AttributeAccess<'a> {
197    type Error = Error;
198
199    fn name(&self) -> ExpandedName<'_> {
200        self.name.clone()
201    }
202
203    fn value(&self) -> &str {
204        self.value.as_str()
205    }
206}
207
/// Sequence access that never yields an element.
struct EmptySeqAccess;
209
210impl<'de> de::SeqAccess<'de> for EmptySeqAccess {
211    type Error = Error;
212    type SubAccess<'s>
213        = EmptySeqAccess
214    where
215        Self: 's;
216
217    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
218    where
219        T: Deserialize<'de>,
220    {
221        Ok(None)
222    }
223
224    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
225    where
226        T: Deserialize<'de>,
227    {
228        Ok(None)
229    }
230
231    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
232        Ok(EmptySeqAccess)
233    }
234}
235
236struct AttributeDeserializer<'a> {
237    name: ExpandedName<'a>,
238    value: String,
239}
240
241impl<'a> xmlity::Deserializer<'a> for AttributeDeserializer<'a> {
242    type Error = Error;
243
244    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
245    where
246        V: Visitor<'a>,
247    {
248        visitor.visit_attribute(AttributeAccess {
249            name: self.name,
250            value: self.value,
251        })
252    }
253
254    fn deserialize_seq<V>(self, _: V) -> Result<V::Value, Self::Error>
255    where
256        V: Visitor<'a>,
257    {
258        Err(Self::Error::Unexpected(de::Unexpected::Seq))
259    }
260}
261
262pub struct SubAttributesAccess<'a, 'r> {
263    deserializer: &'a Deserializer<'r>,
264    bytes_start: &'a BytesStart<'r>,
265    attribute_index: usize,
266    write_attribute_to: &'a mut usize,
267}
268
269impl Drop for SubAttributesAccess<'_, '_> {
270    fn drop(&mut self) {
271        *self.write_attribute_to = self.attribute_index;
272    }
273}
274
275fn next_attribute<'a, 'de, T: Deserialize<'de>>(
276    deserializer: &'a Deserializer<'_>,
277    bytes_start: &'a BytesStart<'_>,
278    attribute_index: &'a mut usize,
279) -> Result<Option<T>, Error> {
280    let (attribute, key) = loop {
281        let Some(attribute) = bytes_start.attributes().nth(*attribute_index) else {
282            return Ok(None);
283        };
284
285        let attribute = attribute?;
286
287        let key = deserializer.resolve_attribute(&attribute).into_owned();
288
289        const XMLNS_NAMESPACE: XmlNamespace<'static> =
290            XmlNamespace::new_dangerous("http://www.w3.org/2000/xmlns/");
291
292        if key.namespace() == Some(&XMLNS_NAMESPACE) {
293            *attribute_index += 1;
294            continue;
295        }
296
297        break (attribute, key);
298    };
299
300    let value = String::from_utf8(attribute.value.into_owned())
301        .expect("attribute value should be valid utf8");
302
303    let deserializer = AttributeDeserializer { name: key, value };
304
305    let res = T::deserialize(deserializer)?;
306
307    // Only increment the index if the deserialization was successful
308    *attribute_index += 1;
309
310    Ok(Some(res))
311}
312
313impl<'de> de::AttributesAccess<'de> for SubAttributesAccess<'_, 'de> {
314    type Error = Error;
315
316    type SubAccess<'a>
317        = SubAttributesAccess<'a, 'de>
318    where
319        Self: 'a;
320
321    fn next_attribute<T>(&mut self) -> Result<Option<T>, Self::Error>
322    where
323        T: Deserialize<'de>,
324    {
325        next_attribute(
326            self.deserializer,
327            self.bytes_start,
328            &mut self.attribute_index,
329        )
330    }
331
332    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
333        Ok(Self::SubAccess {
334            deserializer: self.deserializer,
335            bytes_start: self.bytes_start,
336            attribute_index: self.attribute_index,
337            write_attribute_to: self.write_attribute_to,
338        })
339    }
340}
341
342impl<'de> de::AttributesAccess<'de> for ElementAccess<'_, 'de> {
343    type Error = Error;
344
345    type SubAccess<'a>
346        = SubAttributesAccess<'a, 'de>
347    where
348        Self: 'a;
349
350    fn next_attribute<T>(&mut self) -> Result<Option<T>, Self::Error>
351    where
352        T: Deserialize<'de>,
353    {
354        next_attribute(
355            self.deserializer
356                .as_ref()
357                .expect("deserializer should be set"),
358            &self.bytes_start,
359            &mut self.attribute_index,
360        )
361    }
362
363    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
364        Ok(Self::SubAccess {
365            bytes_start: &self.bytes_start,
366            attribute_index: self.attribute_index,
367            write_attribute_to: &mut self.attribute_index,
368            deserializer: self
369                .deserializer
370                .as_ref()
371                .expect("Should not be called after ElementAccess has been consumed"),
372        })
373    }
374}
375
376impl<'a, 'de> de::ElementAccess<'de> for ElementAccess<'a, 'de> {
377    type ChildrenAccess = ChildrenAccess<'a, 'de>;
378
379    fn name(&self) -> ExpandedName<'_> {
380        self.deserializer().resolve_bytes_start(&self.bytes_start)
381    }
382
383    fn children(mut self) -> Result<Self::ChildrenAccess, Self::Error> {
384        Ok(if self.empty {
385            ChildrenAccess::Empty
386        } else {
387            let deserializer = self
388                .deserializer
389                .take()
390                .expect("Should not be called after ElementAccess has been consumed");
391
392            ChildrenAccess::Filled {
393                expected_end: QName::from_quick_xml(self.bytes_start.name()).into_owned(),
394                start_depth: self.start_depth,
395                deserializer,
396            }
397        })
398    }
399}
400
401pub enum ChildrenAccess<'a, 'r> {
402    Filled {
403        expected_end: QName<'static>,
404        deserializer: &'a mut Deserializer<'r>,
405        start_depth: i16,
406    },
407    Empty,
408}
409
410impl Drop for ChildrenAccess<'_, '_> {
411    fn drop(&mut self) {
412        let ChildrenAccess::Filled {
413            expected_end,
414            deserializer,
415            start_depth,
416        } = self
417        else {
418            return;
419        };
420
421        deserializer
422            .read_until_element_end(&OwnedQuickName::new(expected_end).as_ref(), *start_depth)
423            .unwrap();
424    }
425}
426
427impl<'r> de::SeqAccess<'r> for ChildrenAccess<'_, 'r> {
428    type Error = Error;
429
430    type SubAccess<'s>
431        = SubSeqAccess<'s, 'r>
432    where
433        Self: 's;
434
435    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
436    where
437        T: Deserialize<'r>,
438    {
439        let ChildrenAccess::Filled {
440            expected_end,
441            deserializer,
442            start_depth,
443        } = self
444        else {
445            return Ok(None);
446        };
447
448        if deserializer.peek_event().is_none() {
449            return Ok(None);
450        }
451
452        let current_depth = deserializer.current_depth;
453
454        if let Some(Event::End(bytes_end)) = deserializer.peek_event() {
455            if OwnedQuickName::new(expected_end).as_ref() != bytes_end.name()
456                && current_depth == *start_depth
457            {
458                return Err(Error::custom(format!(
459                    "Expected end of element {}, found end of element {}",
460                    expected_end,
461                    QName::from_quick_xml(bytes_end.name())
462                )));
463            }
464
465            return Ok(None);
466        }
467
468        deserializer
469            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize(deserializer))
470            .map(Some)
471    }
472
473    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
474    where
475        T: Deserialize<'r>,
476    {
477        let ChildrenAccess::Filled {
478            expected_end,
479            deserializer,
480            start_depth,
481        } = self
482        else {
483            return Ok(None);
484        };
485
486        if deserializer.peek_event().is_none() {
487            return Ok(None);
488        }
489
490        let current_depth = deserializer.current_depth;
491
492        if let Some(Event::End(bytes_end)) = deserializer.peek_event() {
493            if OwnedQuickName::new(expected_end).as_ref() != bytes_end.name()
494                && current_depth == *start_depth
495            {
496                return Err(Error::custom(format!(
497                    "Expected end of element {}, found end of element {}",
498                    expected_end,
499                    QName::from_quick_xml(bytes_end.name())
500                )));
501            }
502
503            return Ok(None);
504        }
505
506        deserializer
507            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize_seq(deserializer))
508            .map(Some)
509    }
510
511    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
512        let ChildrenAccess::Filled { deserializer, .. } = self else {
513            return Ok(SubSeqAccess::Empty);
514        };
515
516        Ok(deserializer.create_sub_seq_access())
517    }
518}
519
520pub struct SeqAccess<'a, 'r> {
521    deserializer: &'a mut Deserializer<'r>,
522}
523
524#[allow(clippy::large_enum_variant)]
525pub enum SubSeqAccess<'p, 'r> {
526    Filled {
527        current: Option<Deserializer<'r>>,
528        parent: &'p mut Deserializer<'r>,
529    },
530    Empty,
531}
532
533impl Drop for SubSeqAccess<'_, '_> {
534    fn drop(&mut self) {
535        if let SubSeqAccess::Filled { current, parent } = self {
536            **parent = current.take().expect("SubSeqAccess dropped twice");
537        }
538    }
539}
540
541impl<'r> de::SeqAccess<'r> for SubSeqAccess<'_, 'r> {
542    type Error = Error;
543
544    type SubAccess<'s>
545        = SubSeqAccess<'s, 'r>
546    where
547        Self: 's;
548
549    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
550    where
551        T: Deserialize<'r>,
552    {
553        let Self::Filled { current, .. } = self else {
554            return Ok(None);
555        };
556
557        let deserializer = current.as_mut().expect("SubSeqAccess used after drop");
558
559        if deserializer.peek_event().is_none() {
560            return Ok(None);
561        }
562
563        deserializer
564            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize_seq(deserializer))
565            .map(Some)
566    }
567
568    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
569    where
570        T: Deserialize<'r>,
571    {
572        let Self::Filled { current, .. } = self else {
573            return Ok(None);
574        };
575
576        let deserializer = current.as_mut().expect("SubSeqAccess used after drop");
577
578        if deserializer.peek_event().is_none() {
579            return Ok(None);
580        }
581
582        deserializer
583            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize(deserializer))
584            .map(Some)
585    }
586
587    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
588        let Self::Filled { current, .. } = self else {
589            return Ok(SubSeqAccess::Empty);
590        };
591
592        Ok(current
593            .as_mut()
594            .expect("SubSeqAccess used after drop")
595            .create_sub_seq_access())
596    }
597}
598
599impl<'r> de::SeqAccess<'r> for SeqAccess<'_, 'r> {
600    type Error = Error;
601
602    type SubAccess<'s>
603        = SubSeqAccess<'s, 'r>
604    where
605        Self: 's;
606
607    fn next_element_seq<T>(&mut self) -> Result<Option<T>, Self::Error>
608    where
609        T: Deserialize<'r>,
610    {
611        if self.deserializer.peek_event().is_none() {
612            return Ok(None);
613        }
614
615        self.deserializer
616            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize_seq(deserializer))
617            .map(Some)
618    }
619
620    fn next_element<T>(&mut self) -> Result<Option<T>, Self::Error>
621    where
622        T: Deserialize<'r>,
623    {
624        if self.deserializer.peek_event().is_none() {
625            return Ok(None);
626        }
627
628        self.deserializer
629            .try_deserialize(|deserializer| Deserialize::<'r>::deserialize(deserializer))
630            .map(Some)
631    }
632
633    fn sub_access(&mut self) -> Result<Self::SubAccess<'_>, Self::Error> {
634        Ok(SubSeqAccess::Filled {
635            current: Some(self.deserializer.clone()),
636            parent: self.deserializer,
637        })
638    }
639}
640
641impl<'r> xmlity::Deserializer<'r> for &mut Deserializer<'r> {
642    type Error = Error;
643
644    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
645    where
646        V: de::Visitor<'r>,
647    {
648        let event = self.next_event().ok_or_else(|| Error::custom("EOF"))?;
649
650        match event {
651            Event::Start(bytes_start) => {
652                let element_name = OwnedQuickName(bytes_start.name().0.to_owned());
653
654                let value = Visitor::visit_element(
655                    visitor,
656                    ElementAccess {
657                        bytes_start,
658                        start_depth: self.current_depth,
659                        deserializer: Some(self),
660                        empty: false,
661                        attribute_index: 0,
662                    },
663                )?;
664
665                let end_event = self.next_event().ok_or_else(|| Error::custom("EOF"))?;
666
667                let success = if let Event::End(bytes_end) = &end_event {
668                    bytes_end.name() == element_name.as_ref()
669                } else {
670                    false
671                };
672
673                if success {
674                    Ok(value)
675                } else {
676                    Err(Error::custom("No matching end element"))
677                }
678            }
679            Event::End(_bytes_end) => Err(Error::custom("Unexpected end element")),
680            Event::Empty(bytes_start) => visitor.visit_element(ElementAccess {
681                bytes_start: bytes_start.into_owned().clone(),
682                start_depth: self.current_depth,
683                deserializer: Some(self),
684                empty: true,
685                attribute_index: 0,
686            }),
687            Event::Text(bytes_text) => visitor.visit_text(bytes_text.deref()),
688            Event::CData(bytes_cdata) => visitor.visit_cdata(bytes_cdata.deref()),
689            Event::Comment(bytes_text) => visitor.visit_comment(bytes_text.deref()),
690            Event::Decl(bytes_decl) => visitor.visit_decl(
691                bytes_decl.version()?,
692                match bytes_decl.encoding() {
693                    Some(Ok(encoding)) => Some(encoding),
694                    Some(Err(err)) => return Err(Error::QuickXml(err.into())),
695                    None => None,
696                },
697                match bytes_decl.standalone() {
698                    Some(Ok(standalone)) => Some(standalone),
699                    Some(Err(err)) => return Err(Error::QuickXml(err.into())),
700                    None => None,
701                },
702            ),
703            Event::PI(bytes_pi) => visitor.visit_pi(bytes_pi.deref()),
704            Event::DocType(bytes_text) => visitor.visit_doctype(bytes_text.deref()),
705            Event::Eof => Err(Error::custom("Unexpected EOF")),
706        }
707    }
708
709    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
710    where
711        V: de::Visitor<'r>,
712    {
713        visitor.visit_seq(SeqAccess { deserializer: self })
714    }
715}
716
717impl<'r> xmlity::Deserializer<'r> for Deserializer<'r> {
718    type Error = Error;
719
720    fn deserialize_any<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
721    where
722        V: de::Visitor<'r>,
723    {
724        (&mut self).deserialize_any(visitor)
725    }
726
727    fn deserialize_seq<V>(mut self, visitor: V) -> Result<V::Value, Self::Error>
728    where
729        V: de::Visitor<'r>,
730    {
731        (&mut self).deserialize_seq(visitor)
732    }
733}