simple_tlv/
decoder.rs

1use core::convert::TryInto;
2use crate::{Decodable, ErrorKind, Length, Result};
3
4/// SIMPLE-TLV decoder.
5#[derive(Debug)]
6pub struct Decoder<'a> {
7    /// Byte slice being decoded.
8    ///
9    /// In the event an error was previously encountered this will be set to
10    /// `None` to prevent further decoding while in a bad state.
11    bytes: Option<&'a [u8]>,
12
13    /// Position within the decoded slice.
14    position: Length,
15}
16
17impl<'a> Decoder<'a> {
18    /// Create a new decoder for the given byte slice.
19    pub fn new(bytes: &'a [u8]) -> Self {
20        Self {
21            bytes: Some(bytes),
22            position: Length::zero(),
23        }
24    }
25
26    /// Decode a value which impls the [`Decodable`] trait.
27    pub fn decode<T: Decodable<'a>>(&mut self) -> Result<T> {
28        if self.is_failed() {
29            self.error(ErrorKind::Failed)?;
30        }
31
32        T::decode(self).map_err(|e| {
33            self.bytes.take();
34            e.nested(self.position)
35        })
36    }
37
38    /// Decode a TaggedValue with tag checked to be as expected, returning the value
39    pub fn decode_tagged_value<V: Decodable<'a>>(&mut self, tag: crate::Tag) -> Result<V> {
40        let tagged: crate::TaggedSlice = self.decode()?;
41        tagged.tag().assert_eq(tag)?;
42        Self::new(tagged.as_bytes()).decode()
43    }
44
45    /// Decode a TaggedSlice with tag checked to be as expected, returning the value
46    pub fn decode_tagged_slice(&mut self, tag: crate::Tag) -> Result<&'a [u8]> {
47        let tagged: crate::TaggedSlice = self.decode()?;
48        tagged.tag().assert_eq(tag)?;
49        Ok(tagged.as_bytes())
50    }
51
52    /// Return an error with the given [`ErrorKind`], annotating it with
53    /// context about where the error occurred.
54    pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
55        self.bytes.take();
56        Err(kind.at(self.position))
57    }
58
59    /// Did the decoding operation fail due to an error?
60    pub fn is_failed(&self) -> bool {
61        self.bytes.is_none()
62    }
63
64    /// Finish decoding, returning the given value if there is no
65    /// remaining data, or an error otherwise
66    pub fn finish<T>(self, value: T) -> Result<T> {
67        if self.is_failed() {
68            Err(ErrorKind::Failed.at(self.position))
69        } else if !self.is_finished() {
70            Err(ErrorKind::TrailingData {
71                decoded: self.position,
72                remaining: self.remaining_len()?,
73            }
74            .at(self.position))
75        } else {
76            Ok(value)
77        }
78    }
79
80    /// Have we decoded all of the bytes in this [`Decoder`]?
81    ///
82    /// Returns `false` if we're not finished decoding or if a fatal error
83    /// has occurred.
84    pub fn is_finished(&self) -> bool {
85        self.remaining().map(|rem| rem.is_empty()).unwrap_or(false)
86    }
87
88    /// Decode a single byte, updating the internal cursor.
89    pub(crate) fn byte(&mut self) -> Result<u8> {
90        match self.bytes(1u8)? {
91            [byte] => Ok(*byte),
92            _ => self.error(ErrorKind::Truncated),
93        }
94    }
95
96    /// Obtain a slice of bytes of the given length from the current cursor
97    /// position, or return an error if we have insufficient data.
98    pub(crate) fn bytes(&mut self, len: impl TryInto<Length>) -> Result<&'a [u8]> {
99        if self.is_failed() {
100            self.error(ErrorKind::Failed)?;
101        }
102
103        let len = len
104            .try_into()
105            .or_else(|_| self.error(ErrorKind::Overflow))?;
106
107        let result = self
108            .remaining()?
109            .get(..len.to_usize())
110            .ok_or(ErrorKind::Truncated)?;
111
112        self.position = (self.position + len)?;
113        Ok(result)
114    }
115
116    /// Obtain the remaining bytes in this decoder from the current cursor
117    /// position.
118    fn remaining(&self) -> Result<&'a [u8]> {
119        self.bytes
120            .and_then(|b| b.get(self.position.into()..))
121            .ok_or_else(|| ErrorKind::Truncated.at(self.position))
122    }
123
124    /// Get the number of bytes still remaining in the buffer.
125    fn remaining_len(&self) -> Result<Length> {
126        self.remaining()?.len().try_into()
127    }
128}
129
130impl<'a> From<&'a [u8]> for Decoder<'a> {
131    fn from(bytes: &'a [u8]) -> Decoder<'a> {
132        Decoder::new(bytes)
133    }
134}
135
#[cfg(test)]
mod tests {
    use crate::{Decodable, Tag, TaggedSlice};
    use core::convert::TryFrom;

    /// A tag byte followed by a zero length byte decodes to an empty value.
    #[test]
    fn zero_length() {
        let encoded: &[u8] = &[0x2A, 0x00];
        let decoded = TaggedSlice::from_bytes(encoded).unwrap();

        let tag = Tag::try_from(42).unwrap();
        let expected = TaggedSlice::from(tag, &[]).unwrap();

        assert_eq!(decoded, expected);
    }
}
148// #[cfg(test)]
149// mod tests {
150//     use super::Decoder;
151//     use crate::{Decodable, ErrorKind, Length, Tag};
152
153//     #[test]
154//     fn truncated_message() {
155//         let mut decoder = Decoder::new(&[]);
156//         let err = bool::decode(&mut decoder).err().unwrap();
157//         assert_eq!(ErrorKind::Truncated, err.kind());
158//         assert_eq!(Some(Length::zero()), err.position());
159//     }
160
161//     #[test]
162//     fn invalid_field_length() {
163//         let mut decoder = Decoder::new(&[0x02, 0x01]);
164//         let err = i8::decode(&mut decoder).err().unwrap();
165//         assert_eq!(ErrorKind::Length { tag: Tag::Integer }, err.kind());
166//         assert_eq!(Some(Length::from(2u8)), err.position());
167//     }
168
169//     #[test]
170//     fn trailing_data() {
171//         let mut decoder = Decoder::new(&[0x02, 0x01, 0x2A, 0x00]);
172//         let x = decoder.decode().unwrap();
173//         assert_eq!(42i8, x);
174
175//         let err = decoder.finish(x).err().unwrap();
176//         assert_eq!(
177//             ErrorKind::TrailingData {
178//                 decoded: 3u8.into(),
179//                 remaining: 1u8.into()
180//             },
181//             err.kind()
182//         );
183//         assert_eq!(Some(Length::from(3u8)), err.position());
184//     }
185// }