// rustybit_serde_bencode/de.rs

1use core::str;
2
3use serde::{Deserialize, Deserializer};
4
5use crate::{Error, ErrorKind};
6
/// Bencode deserializer that reads from a borrowed byte slice.
///
/// Byte strings and UTF-8 strings are handed to visitors as borrows of the
/// original input (`'de`) rather than copies.
#[derive(Debug)]
pub struct BencodeDeserializer<'de> {
    /// The not-yet-consumed remainder of the input.
    input: &'de [u8],
    /// Number of bytes consumed so far; attached to errors as their position.
    position: usize,
}
11
12impl<'de> BencodeDeserializer<'de> {
13    pub fn from_str(input: &'de str) -> Self {
14        BencodeDeserializer {
15            input: input.as_bytes(),
16            position: 0,
17        }
18    }
19
20    pub fn from_bytes(input: &'de [u8]) -> Self {
21        BencodeDeserializer { input, position: 0 }
22    }
23
24    pub fn error(&self, kind: ErrorKind) -> Error {
25        let err: Error = kind.into();
26        err.set_position(self.position)
27    }
28
29    pub(crate) fn move_cursor(&mut self, by: usize) {
30        self.position += by;
31        self.input = &self.input[by..]
32    }
33}
34
35impl<'de> BencodeDeserializer<'de> {
36    fn parse_integer(&mut self) -> Result<i64, Error> {
37        match self.input.iter().position(|byte| *byte == b'e') {
38            Some(pos) => {
39                let integer = &self.input[..pos];
40                let parsed_integer: i64 = std::str::from_utf8(integer)
41                    .map_err(|_| self.error(ErrorKind::BadInputData("Bad integer")))?
42                    .parse()
43                    .map_err(|_| {
44                        self.error(ErrorKind::BadInputData(
45                            "Unnable to parse integer from the provided data",
46                        ))
47                    })?;
48                self.move_cursor(pos + 1);
49                Ok(parsed_integer)
50            }
51            _ => Err(self.error(ErrorKind::BadInputData("expected closing delimiter for integer"))),
52        }
53    }
54
55    fn parse_bytes(&mut self) -> Result<&'de [u8], Error> {
56        match self.input.iter().position(|byte| *byte == b':') {
57            Some(delim_pos) => {
58                let raw_bytes_len = &self.input[..delim_pos];
59                let bytes_len: usize = std::str::from_utf8(raw_bytes_len)
60                    .map_err(|_| self.error(ErrorKind::BadInputData("bytes length is not valid utf8")))?
61                    .parse()
62                    .map_err(|_| self.error(ErrorKind::BadInputData("expected valid bytes length")))?;
63                // Skip the delimiter as well
64                self.move_cursor(delim_pos + 1);
65                let raw_bytes = &self.input[..bytes_len];
66                self.move_cursor(bytes_len);
67                Ok(raw_bytes)
68            }
69            _ => Err(self.error(ErrorKind::BadInputData("expected bytes delimiter ':'"))),
70        }
71    }
72
73    fn parse_bytes_checked(&mut self) -> Result<&'de [u8], Error> {
74        match self
75            .input
76            .first()
77            .ok_or(self.error(ErrorKind::UnexpectedEof("bytes")))?
78        {
79            b'0'..=b'9' => self.parse_bytes(),
80            _ => Err(self.error(ErrorKind::BadInputData("expected bytes length"))),
81        }
82    }
83}
84
85impl<'de, 'a> serde::de::Deserializer<'de> for &'a mut BencodeDeserializer<'de> {
86    type Error = Error;
87
88    fn deserialize_bool<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
89    where
90        V: serde::de::Visitor<'de>,
91    {
92        Err(self.error(ErrorKind::Unsupported("bool")))
93    }
94
95    fn deserialize_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
96    where
97        V: serde::de::Visitor<'de>,
98    {
99        match self
100            .input
101            .first()
102            .ok_or(self.error(ErrorKind::UnexpectedEof("any bencode value")))?
103        {
104            b'd' => self.deserialize_map(visitor),
105            b'l' => self.deserialize_seq(visitor),
106            b'i' => self.deserialize_i64(visitor),
107            _ => self.deserialize_bytes(visitor),
108        }
109    }
110
111    fn deserialize_i8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
112    where
113        V: serde::de::Visitor<'de>,
114    {
115        self.deserialize_i64(visitor)
116    }
117
118    fn deserialize_i16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
119    where
120        V: serde::de::Visitor<'de>,
121    {
122        self.deserialize_i64(visitor)
123    }
124
125    fn deserialize_i32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
126    where
127        V: serde::de::Visitor<'de>,
128    {
129        self.deserialize_i64(visitor)
130    }
131
132    fn deserialize_i64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
133    where
134        V: serde::de::Visitor<'de>,
135    {
136        match self
137            .input
138            .first()
139            .ok_or(self.error(ErrorKind::UnexpectedEof("integer")))?
140        {
141            b'i' => {
142                // Skip int label
143                self.move_cursor(1);
144                visitor.visit_i64(self.parse_integer()?)
145            }
146            _ => Err(self.error(ErrorKind::BadInputData("expected integer label 'i'"))),
147        }
148    }
149
150    fn deserialize_u8<V>(self, visitor: V) -> Result<V::Value, Self::Error>
151    where
152        V: serde::de::Visitor<'de>,
153    {
154        self.deserialize_i64(visitor)
155    }
156
157    fn deserialize_u16<V>(self, visitor: V) -> Result<V::Value, Self::Error>
158    where
159        V: serde::de::Visitor<'de>,
160    {
161        self.deserialize_i64(visitor)
162    }
163
164    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value, Self::Error>
165    where
166        V: serde::de::Visitor<'de>,
167    {
168        self.deserialize_i64(visitor)
169    }
170
171    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value, Self::Error>
172    where
173        V: serde::de::Visitor<'de>,
174    {
175        self.deserialize_i64(visitor)
176    }
177
178    fn deserialize_f32<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
179    where
180        V: serde::de::Visitor<'de>,
181    {
182        Err(self.error(ErrorKind::Unsupported("f32")))
183    }
184
185    fn deserialize_f64<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
186    where
187        V: serde::de::Visitor<'de>,
188    {
189        Err(self.error(ErrorKind::Unsupported("f64")))
190    }
191
192    fn deserialize_char<V>(self, _visitor: V) -> Result<V::Value, Self::Error>
193    where
194        V: serde::de::Visitor<'de>,
195    {
196        Err(self.error(ErrorKind::Unsupported("char")))
197    }
198
199    fn deserialize_str<V>(self, visitor: V) -> Result<V::Value, Self::Error>
200    where
201        V: serde::de::Visitor<'de>,
202    {
203        let str = str::from_utf8(self.parse_bytes_checked()?)
204            .map_err(|_| self.error(ErrorKind::BadInputData("expected valid utf8 string")))?;
205        visitor.visit_borrowed_str(str)
206    }
207
208    fn deserialize_string<V>(self, visitor: V) -> Result<V::Value, Self::Error>
209    where
210        V: serde::de::Visitor<'de>,
211    {
212        self.deserialize_str(visitor)
213    }
214
215    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value, Self::Error>
216    where
217        V: serde::de::Visitor<'de>,
218    {
219        visitor.visit_borrowed_bytes(self.parse_bytes_checked()?)
220    }
221
222    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value, Self::Error>
223    where
224        V: serde::de::Visitor<'de>,
225    {
226        self.deserialize_bytes(visitor)
227    }
228
229    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value, Self::Error>
230    where
231        V: serde::de::Visitor<'de>,
232    {
233        match self
234            .input
235            .first()
236            .ok_or(self.error(ErrorKind::UnexpectedEof("null string")))?
237        {
238            b'0' => {
239                let _ = self.parse_bytes()?;
240                visitor.visit_none()
241            }
242            _ => visitor.visit_some(&mut *self),
243        }
244    }
245
246    fn deserialize_unit<V>(self, visitor: V) -> Result<V::Value, Self::Error>
247    where
248        V: serde::de::Visitor<'de>,
249    {
250        let bytes = self.parse_bytes_checked()?;
251        if !bytes.is_empty() {
252            return Err(self.error(ErrorKind::BadInputData("expected bencode string of length 0")));
253        }
254        visitor.visit_unit()
255    }
256
257    fn deserialize_unit_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Self::Error>
258    where
259        V: serde::de::Visitor<'de>,
260    {
261        self.deserialize_unit(visitor)
262    }
263
264    fn deserialize_newtype_struct<V>(self, _name: &'static str, visitor: V) -> Result<V::Value, Self::Error>
265    where
266        V: serde::de::Visitor<'de>,
267    {
268        visitor.visit_newtype_struct(self)
269    }
270
271    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value, Self::Error>
272    where
273        V: serde::de::Visitor<'de>,
274    {
275        match self
276            .input
277            .first()
278            .ok_or(self.error(ErrorKind::UnexpectedEof("bencode list")))?
279        {
280            b'l' => {
281                // Skip list label
282                self.move_cursor(1);
283                let value = visitor.visit_seq(BencodeAccessor { de: self });
284                match self
285                    .input
286                    .first()
287                    .ok_or(self.error(ErrorKind::UnexpectedEof("bencode list end")))?
288                {
289                    b'e' => {
290                        // Skip list end
291                        self.move_cursor(1);
292                        value
293                    }
294                    _ => Err(self.error(ErrorKind::BadInputData("expected bencode list end"))),
295                }
296            }
297            _ => Err(self.error(ErrorKind::BadInputData("expected bencode list"))),
298        }
299    }
300
301    fn deserialize_tuple<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
302    where
303        V: serde::de::Visitor<'de>,
304    {
305        self.deserialize_seq(visitor)
306    }
307
308    fn deserialize_tuple_struct<V>(self, _name: &'static str, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
309    where
310        V: serde::de::Visitor<'de>,
311    {
312        self.deserialize_seq(visitor)
313    }
314
315    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value, Self::Error>
316    where
317        V: serde::de::Visitor<'de>,
318    {
319        match self
320            .input
321            .first()
322            .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary")))?
323        {
324            b'd' => {
325                // Skip dict label
326                self.move_cursor(1);
327                let value = visitor.visit_map(BencodeAccessor { de: self });
328                match self
329                    .input
330                    .first()
331                    .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary end")))?
332                {
333                    b'e' => {
334                        // Skip dict end
335                        self.move_cursor(1);
336                        value
337                    }
338                    _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary end"))),
339                }
340            }
341            _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary"))),
342        }
343    }
344
345    fn deserialize_struct<V>(
346        self,
347        _name: &'static str,
348        _fields: &'static [&'static str],
349        visitor: V,
350    ) -> Result<V::Value, Self::Error>
351    where
352        V: serde::de::Visitor<'de>,
353    {
354        self.deserialize_map(visitor)
355    }
356
357    fn deserialize_enum<V>(
358        self,
359        _name: &'static str,
360        _variants: &'static [&'static str],
361        visitor: V,
362    ) -> Result<V::Value, Self::Error>
363    where
364        V: serde::de::Visitor<'de>,
365    {
366        match self
367            .input
368            .first()
369            .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary")))?
370        {
371            b'd' => {
372                // Skip dict label
373                self.move_cursor(1);
374                let value = visitor.visit_enum(BencodeAccessor { de: self });
375                match self
376                    .input
377                    .first()
378                    .ok_or(self.error(ErrorKind::UnexpectedEof("bencode dictionary end")))?
379                {
380                    b'e' => {
381                        // Skip dict end
382                        self.move_cursor(1);
383                        value
384                    }
385                    _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary end"))),
386                }
387            }
388            _ => Err(self.error(ErrorKind::BadInputData("expected bencode dictionary"))),
389        }
390    }
391
392    fn deserialize_identifier<V>(self, visitor: V) -> Result<V::Value, Self::Error>
393    where
394        V: serde::de::Visitor<'de>,
395    {
396        self.deserialize_str(visitor)
397    }
398
399    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value, Self::Error>
400    where
401        V: serde::de::Visitor<'de>,
402    {
403        self.deserialize_any(visitor)
404    }
405}
406
407struct BencodeAccessor<'a, 'de> {
408    de: &'a mut BencodeDeserializer<'de>,
409}
410
411impl<'a, 'de> serde::de::SeqAccess<'de> for BencodeAccessor<'a, 'de> {
412    type Error = Error;
413
414    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>, Self::Error>
415    where
416        T: serde::de::DeserializeSeed<'de>,
417    {
418        if *self
419            .de
420            .input
421            .first()
422            .ok_or(self.de.error(ErrorKind::UnexpectedEof("next seq element or end")))?
423            == b'e'
424        {
425            return Ok(None);
426        }
427        seed.deserialize(&mut *self.de).map(Some)
428    }
429}
430
431impl<'a, 'de> serde::de::MapAccess<'de> for BencodeAccessor<'a, 'de> {
432    type Error = Error;
433
434    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
435    where
436        K: serde::de::DeserializeSeed<'de>,
437    {
438        if *self
439            .de
440            .input
441            .first()
442            .ok_or(self.de.error(ErrorKind::UnexpectedEof("next dict key or end")))?
443            == b'e'
444        {
445            return Ok(None);
446        }
447        seed.deserialize(&mut *self.de).map(Some)
448    }
449
450    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
451    where
452        V: serde::de::DeserializeSeed<'de>,
453    {
454        seed.deserialize(&mut *self.de)
455    }
456}
457
458impl<'de, 'a> serde::de::EnumAccess<'de> for BencodeAccessor<'a, 'de> {
459    type Error = Error;
460    type Variant = Self;
461
462    fn variant_seed<V>(self, seed: V) -> Result<(V::Value, Self::Variant), Self::Error>
463    where
464        V: serde::de::DeserializeSeed<'de>,
465    {
466        Ok((seed.deserialize(&mut *self.de)?, self))
467    }
468}
469
470impl<'de, 'a> serde::de::VariantAccess<'de> for BencodeAccessor<'a, 'de> {
471    type Error = Error;
472
473    fn unit_variant(self) -> Result<(), Self::Error> {
474        let bytes = self.de.parse_bytes_checked()?;
475        if !bytes.is_empty() {
476            return Err(self
477                .de
478                .error(ErrorKind::BadInputData("expected bencode string of length 0")));
479        }
480        Ok(())
481    }
482
483    fn newtype_variant_seed<T>(self, seed: T) -> Result<T::Value, Self::Error>
484    where
485        T: serde::de::DeserializeSeed<'de>,
486    {
487        seed.deserialize(&mut *self.de)
488    }
489
490    fn tuple_variant<V>(self, _len: usize, visitor: V) -> Result<V::Value, Self::Error>
491    where
492        V: serde::de::Visitor<'de>,
493    {
494        self.de.deserialize_seq(visitor)
495    }
496
497    fn struct_variant<V>(self, _fields: &'static [&'static str], visitor: V) -> Result<V::Value, Self::Error>
498    where
499        V: serde::de::Visitor<'de>,
500    {
501        self.de.deserialize_map(visitor)
502    }
503}
504
505pub fn from_bytes<'de, T: Deserialize<'de>>(input: &'de [u8]) -> Result<T, Error> {
506    let mut deserializer = BencodeDeserializer::from_bytes(input);
507    let deserialized = T::deserialize(&mut deserializer)?;
508    if !deserializer.input.is_empty() {
509        return Err(ErrorKind::Custom(format!(
510            "Trailing bytes after deserialization: {}",
511            deserializer.input.len()
512        ))
513        .into());
514    }
515    Ok(deserialized)
516}
517
518pub fn from_str<'de, T: Deserialize<'de>>(input: &'de str) -> Result<T, Error> {
519    from_bytes(input.as_bytes())
520}