1use bytes::Bytes;
2use serde::de::IntoDeserializer;
3use std::io::{BufRead, BufReader, Read};
4
5use crate::utils::trim_bytes;
6
7use super::error::{Error, Result};
8
9pub fn from_reader<'de, R: Read, T: serde::de::Deserialize<'de>>(reader: R) -> Result<T> {
11 let mut de =
12 serde_xml_rs::Deserializer::new_from_reader(reader).non_contiguous_seq_elements(true);
13 T::deserialize(&mut de).map_err(Into::into)
14}
15
16pub fn from_str<'de, T: serde::de::Deserialize<'de>>(s: &'de str) -> Result<T> {
18 from_reader(s.as_bytes())
19}
20
21pub fn from_string<'de, T: serde::de::Deserialize<'de>>(s: String) -> Result<T> {
23 from_reader(s.as_bytes())
24}
25
26pub fn from_bytes<'de, T: serde::de::Deserialize<'de>>(s: &'de Bytes) -> Result<T> {
28 from_reader(s.as_ref())
29}
30
/// Generates a `deserialize_*` method that parses the text content of the
/// innermost open tag and passes the parsed value to the matching
/// `visit_*` method.
///
/// The tag is closed only when the visitor accepted the value; a visitor
/// error is returned unchanged and leaves the tag open.
macro_rules! deserialize_type {
    ($deserialize:ident, $visit:ident) => {
        fn $deserialize<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
        where
            V: serde::de::Visitor<'de>,
        {
            let current = self.top_tag()?;
            let outcome = visitor.$visit(current.content().parse()?);
            if outcome.is_err() {
                return outcome;
            }
            self.close_tag()?;
            outcome
        }
    };
}
46
/// Expands to `Err(Error::Custom { .. })` whose `field` is an owned copy of
/// the given message.
macro_rules! custom_error {
    ($info:expr) => {{
        let field = $info.to_owned();
        Err(Error::Custom { field })
    }};
}
54
/// Kind of markup recognised by the tokenizer.
///
/// The enum is fieldless, so it is `Copy` and its equality is total (`Eq`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum EventType {
    /// Markup starting with `?xml`.
    Statement,
    /// Markup ending with `/>`.
    EmptyTag,
    /// An opening tag.
    Tag,
    /// Markup starting with `/`.
    TagClose,
    /// Markup starting with `!--`.
    Comment,
}
63
64#[derive(Debug, Clone)]
65struct Event {
66 pub type_: EventType,
67 pub value: Vec<u8>,
68 pub content: Vec<u8>,
69}
70
71impl Event {
72 pub fn new(type_: EventType, value: Vec<u8>, content: Vec<u8>) -> Self {
73 Self {
74 type_,
75 value,
76 content,
77 }
78 }
79
80 #[inline]
81 fn is_tag(&self) -> bool {
82 self.type_ == EventType::Tag
83 }
84
85 #[inline]
86 fn tag<'d>(&'d self) -> std::borrow::Cow<'d, str> {
87 String::from_utf8_lossy(&self.value)
88 }
89
90 #[inline]
91 fn content(&self) -> std::borrow::Cow<'_, str> {
92 String::from_utf8_lossy(&self.content)
93 }
94}
95
96struct Deserializer<R: Read> {
97 source: BufReader<R>,
98 tags: Vec<Event>,
99 next_tag_cache: Option<Event>,
100 init: bool,
101}
102
103impl<R: Read> Deserializer<R> {
104 pub fn new(r: R) -> Self {
105 Self {
106 source: BufReader::new(r),
107 tags: vec![],
108 next_tag_cache: None,
109 init: false,
110 }
111 }
112
113 #[inline]
114 fn top_tag(&self) -> Result<&Event> {
115 if let Some(tag) = self.tags.last() {
116 Ok(tag)
117 } else {
118 custom_error!("error tag")
119 }
120 }
121
122 fn next_tag(&mut self) -> Result<Event> {
123 let tag = self.next_tag_cache.take();
124 if let Some(tag) = tag {
125 return Ok(tag);
126 }
127 loop {
128 let event = self.next_event()?;
129 match event.type_ {
130 EventType::EmptyTag | EventType::TagClose | EventType::Tag => return Ok(event),
131 _ => continue,
132 }
133 }
134 }
135
136 fn next_tag_ref(&mut self) -> Result<&Event> {
137 if self.next_tag_cache.is_none() {
138 self.next_tag_cache = Some(self.next_tag()?);
139 }
140 Ok(unsafe { self.next_tag_cache.as_ref().unwrap_unchecked() })
141 }
142
143 fn close_tag(&mut self) -> Result<()> {
144 let next_tag = self.next_tag()?;
145 let top_tag = self.top_tag()?;
146 if !next_tag.is_tag() {
147 if top_tag.value == next_tag.value {
148 self.tags.pop();
149 } else {
150 return Err(Error::UnexpectedToken {
151 token: top_tag.tag().to_string(),
152 found: next_tag.tag().to_string(),
153 });
154 }
155 } else {
156 self.tags.push(next_tag);
157 self.close_tag()?;
158 self.close_tag()?;
159 }
160 Ok(())
161 }
162
163 fn next_event(&mut self) -> Result<Event> {
164 if !self.init {
165 let mut buf = vec![];
166 self.source.read_until(b'<', &mut buf)?;
167 self.init = true;
168 }
169 let mut buf = vec![];
170 self.source.read_until(b'>', &mut buf)?;
171
172 if buf.len() == 0 {
173 return custom_error!("Incorrect XML syntax");
174 }
175
176 let data = if buf.ends_with(b"/>") {
177 (EventType::EmptyTag, &buf[..buf.len() - 2])
178 } else if buf.starts_with(b"/") {
179 (EventType::TagClose, &buf[1..buf.len() - 1])
180 } else if buf.starts_with(b"!--") {
181 (EventType::Comment, &buf[..buf.len() - 1])
182 } else if buf.starts_with(b"?xml") {
183 (EventType::Statement, &buf[..buf.len() - 1])
184 } else {
185 let mut i = 0;
186 for b in buf.iter() {
187 i += 1;
188 if *b == b' ' {
189 break;
190 }
191 }
192 (EventType::Tag, &buf[..i - 1])
193 };
194 let mut content = vec![];
195 self.source.read_until(b'<', &mut content)?;
196 let i = if content.len() > 1 {
197 content.len() - 1
198 } else {
199 0
200 };
201 let content = trim_bytes(&content[..i]).to_owned();
202 let event = Event::new(data.0, data.1.to_owned(), content);
203 return Ok(event);
204 }
205}
206
207impl<'de, 'a, R: Read> serde::de::Deserializer<'de> for &'a mut Deserializer<R> {
208 type Error = Error;
209
210 fn deserialize_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
211 where
212 V: serde::de::Visitor<'de>,
213 {
214 self.close_tag()?;
215 visitor.visit_unit()
216 }
217
218 deserialize_type!(deserialize_bool, visit_bool);
219 deserialize_type!(deserialize_i8, visit_i8);
220 deserialize_type!(deserialize_i16, visit_i16);
221 deserialize_type!(deserialize_i32, visit_i32);
222 deserialize_type!(deserialize_i64, visit_i64);
223 deserialize_type!(deserialize_u8, visit_u8);
224 deserialize_type!(deserialize_u16, visit_u16);
225 deserialize_type!(deserialize_u32, visit_u32);
226 deserialize_type!(deserialize_u64, visit_u64);
227 deserialize_type!(deserialize_f32, visit_f32);
228 deserialize_type!(deserialize_f64, visit_f64);
229 deserialize_type!(deserialize_string, visit_string);
230
231 fn deserialize_str<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
232 where
233 V: serde::de::Visitor<'de>,
234 {
235 let tag = self.top_tag()?;
236 let result = visitor.visit_str(&tag.content());
237 if result.is_ok() {
238 self.close_tag()?;
239 }
240 result
241 }
242
243 fn deserialize_bytes<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
244 where
245 V: serde::de::Visitor<'de>,
246 {
247 let tag = self.top_tag()?;
248 let result = visitor.visit_bytes(&tag.content);
249 if result.is_ok() {
250 self.close_tag()?;
251 }
252 result
253 }
254
255 fn deserialize_byte_buf<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
256 where
257 V: serde::de::Visitor<'de>,
258 {
259 let tag = self.top_tag()?;
260 let result = visitor.visit_byte_buf(tag.content.clone());
261 if result.is_ok() {
262 self.close_tag()?;
263 }
264 result
265 }
266
267 serde::forward_to_deserialize_any! {
268 char
269 map
270 unit
271 unit_struct
272 newtype_struct
273 tuple
274 tuple_struct
275 identifier
276 }
277
278 fn deserialize_option<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
279 where
280 V: serde::de::Visitor<'de>,
281 {
282 visitor.visit_some(self)
283 }
284
285 fn deserialize_seq<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
286 where
287 V: serde::de::Visitor<'de>,
288 {
289 let tag = self.top_tag()?.clone();
290 let s = SeqAccess {
291 de: self,
292 tag,
293 is_over: false,
294 };
295 visitor.visit_seq(s)
296 }
297
298 fn deserialize_struct<V>(
299 self,
300 name: &'static str,
301 _: &'static [&'static str],
302 visitor: V,
303 ) -> std::result::Result<V::Value, Self::Error>
304 where
305 V: serde::de::Visitor<'de>,
306 {
307 if !self.init {
308 loop {
309 let event = self.next_tag()?;
310 if event.type_ == EventType::Tag && event.value == name.as_bytes() {
311 self.tags.push(event);
312 break;
313 }
314 }
315 }
316 let map_value = visitor.visit_map(self)?;
317 Ok(map_value)
318 }
319
320 fn deserialize_enum<V>(
321 self,
322 _: &'static str,
323 _: &'static [&'static str],
324 visitor: V,
325 ) -> std::result::Result<V::Value, Self::Error>
326 where
327 V: serde::de::Visitor<'de>,
328 {
329 let tag = self.top_tag()?;
330 let result = visitor.visit_enum(tag.content().into_deserializer());
331 if result.is_ok() {
332 self.close_tag()?;
333 }
334 result
335 }
336
337 fn deserialize_ignored_any<V>(self, visitor: V) -> std::result::Result<V::Value, Self::Error>
338 where
339 V: serde::de::Visitor<'de>,
340 {
341 self.close_tag()?;
342 visitor.visit_unit()
343 }
344}
345
346impl<'de, 'a, R: Read> serde::de::MapAccess<'de> for &'a mut Deserializer<R> {
347 type Error = Error;
348
349 fn next_key_seed<K>(&mut self, seed: K) -> std::result::Result<Option<K::Value>, Self::Error>
350 where
351 K: serde::de::DeserializeSeed<'de>,
352 {
353 loop {
354 let event = self.next_tag_ref()?;
355 if event.is_tag() {
356 let event = self.next_tag()?;
357 let cs = event.clone();
358 self.tags.push(event);
359 return seed.deserialize(cs.tag().into_deserializer()).map(Some);
360 } else {
361 self.close_tag()?;
362 return Ok(None);
363 }
364 }
365 }
366
367 fn next_value_seed<V>(&mut self, seed: V) -> std::result::Result<V::Value, Self::Error>
368 where
369 V: serde::de::DeserializeSeed<'de>,
370 {
371 seed.deserialize(&mut **self)
372 }
373}
374
375struct SeqAccess<'a, R: Read> {
376 de: &'a mut Deserializer<R>,
377 tag: Event,
378 is_over: bool,
379}
380
381impl<'de, 'a, R: Read> serde::de::SeqAccess<'de> for SeqAccess<'a, R> {
382 type Error = Error;
383
384 fn next_element_seed<T>(
385 &mut self,
386 seed: T,
387 ) -> std::result::Result<Option<T::Value>, Self::Error>
388 where
389 T: serde::de::DeserializeSeed<'de>,
390 {
391 if self.is_over {
392 return Ok(None);
393 };
394 let result = seed.deserialize(&mut *self.de).map(Some);
395 let next_tag = self.de.next_tag_ref()?;
396 if next_tag.is_tag() && next_tag.value == self.tag.value {
397 let next_tag = self.de.next_tag()?;
398 self.de.tags.push(next_tag);
399 self.is_over = false;
400 } else {
401 self.is_over = true;
402 }
403 result
404 }
405}