simple_tlv/decoder.rs
1use core::convert::TryInto;
2use crate::{Decodable, ErrorKind, Length, Result};
3
4/// SIMPLE-TLV decoder.
5#[derive(Debug)]
6pub struct Decoder<'a> {
7 /// Byte slice being decoded.
8 ///
9 /// In the event an error was previously encountered this will be set to
10 /// `None` to prevent further decoding while in a bad state.
11 bytes: Option<&'a [u8]>,
12
13 /// Position within the decoded slice.
14 position: Length,
15}
16
17impl<'a> Decoder<'a> {
18 /// Create a new decoder for the given byte slice.
19 pub fn new(bytes: &'a [u8]) -> Self {
20 Self {
21 bytes: Some(bytes),
22 position: Length::zero(),
23 }
24 }
25
26 /// Decode a value which impls the [`Decodable`] trait.
27 pub fn decode<T: Decodable<'a>>(&mut self) -> Result<T> {
28 if self.is_failed() {
29 self.error(ErrorKind::Failed)?;
30 }
31
32 T::decode(self).map_err(|e| {
33 self.bytes.take();
34 e.nested(self.position)
35 })
36 }
37
38 /// Decode a TaggedValue with tag checked to be as expected, returning the value
39 pub fn decode_tagged_value<V: Decodable<'a>>(&mut self, tag: crate::Tag) -> Result<V> {
40 let tagged: crate::TaggedSlice = self.decode()?;
41 tagged.tag().assert_eq(tag)?;
42 Self::new(tagged.as_bytes()).decode()
43 }
44
45 /// Decode a TaggedSlice with tag checked to be as expected, returning the value
46 pub fn decode_tagged_slice(&mut self, tag: crate::Tag) -> Result<&'a [u8]> {
47 let tagged: crate::TaggedSlice = self.decode()?;
48 tagged.tag().assert_eq(tag)?;
49 Ok(tagged.as_bytes())
50 }
51
52 /// Return an error with the given [`ErrorKind`], annotating it with
53 /// context about where the error occurred.
54 pub fn error<T>(&mut self, kind: ErrorKind) -> Result<T> {
55 self.bytes.take();
56 Err(kind.at(self.position))
57 }
58
59 /// Did the decoding operation fail due to an error?
60 pub fn is_failed(&self) -> bool {
61 self.bytes.is_none()
62 }
63
64 /// Finish decoding, returning the given value if there is no
65 /// remaining data, or an error otherwise
66 pub fn finish<T>(self, value: T) -> Result<T> {
67 if self.is_failed() {
68 Err(ErrorKind::Failed.at(self.position))
69 } else if !self.is_finished() {
70 Err(ErrorKind::TrailingData {
71 decoded: self.position,
72 remaining: self.remaining_len()?,
73 }
74 .at(self.position))
75 } else {
76 Ok(value)
77 }
78 }
79
80 /// Have we decoded all of the bytes in this [`Decoder`]?
81 ///
82 /// Returns `false` if we're not finished decoding or if a fatal error
83 /// has occurred.
84 pub fn is_finished(&self) -> bool {
85 self.remaining().map(|rem| rem.is_empty()).unwrap_or(false)
86 }
87
88 /// Decode a single byte, updating the internal cursor.
89 pub(crate) fn byte(&mut self) -> Result<u8> {
90 match self.bytes(1u8)? {
91 [byte] => Ok(*byte),
92 _ => self.error(ErrorKind::Truncated),
93 }
94 }
95
96 /// Obtain a slice of bytes of the given length from the current cursor
97 /// position, or return an error if we have insufficient data.
98 pub(crate) fn bytes(&mut self, len: impl TryInto<Length>) -> Result<&'a [u8]> {
99 if self.is_failed() {
100 self.error(ErrorKind::Failed)?;
101 }
102
103 let len = len
104 .try_into()
105 .or_else(|_| self.error(ErrorKind::Overflow))?;
106
107 let result = self
108 .remaining()?
109 .get(..len.to_usize())
110 .ok_or(ErrorKind::Truncated)?;
111
112 self.position = (self.position + len)?;
113 Ok(result)
114 }
115
116 /// Obtain the remaining bytes in this decoder from the current cursor
117 /// position.
118 fn remaining(&self) -> Result<&'a [u8]> {
119 self.bytes
120 .and_then(|b| b.get(self.position.into()..))
121 .ok_or_else(|| ErrorKind::Truncated.at(self.position))
122 }
123
124 /// Get the number of bytes still remaining in the buffer.
125 fn remaining_len(&self) -> Result<Length> {
126 self.remaining()?.len().try_into()
127 }
128}
129
130impl<'a> From<&'a [u8]> for Decoder<'a> {
131 fn from(bytes: &'a [u8]) -> Decoder<'a> {
132 Decoder::new(bytes)
133 }
134}
135
#[cfg(test)]
mod tests {
    use core::convert::TryFrom;

    use crate::{Decodable, Tag, TaggedSlice};

    /// A SIMPLE-TLV item whose length byte is zero decodes to a tagged slice
    /// with empty contents.
    #[test]
    fn zero_length() {
        let encoded: &[u8] = &[0x2A, 0x00];
        let decoded = TaggedSlice::from_bytes(encoded).unwrap();

        let tag = Tag::try_from(42).unwrap();
        let expected = TaggedSlice::from(tag, &[]).unwrap();

        assert_eq!(decoded, expected);
    }
}
148// #[cfg(test)]
149// mod tests {
150// use super::Decoder;
151// use crate::{Decodable, ErrorKind, Length, Tag};
152
153// #[test]
154// fn truncated_message() {
155// let mut decoder = Decoder::new(&[]);
156// let err = bool::decode(&mut decoder).err().unwrap();
157// assert_eq!(ErrorKind::Truncated, err.kind());
158// assert_eq!(Some(Length::zero()), err.position());
159// }
160
161// #[test]
162// fn invalid_field_length() {
163// let mut decoder = Decoder::new(&[0x02, 0x01]);
164// let err = i8::decode(&mut decoder).err().unwrap();
165// assert_eq!(ErrorKind::Length { tag: Tag::Integer }, err.kind());
166// assert_eq!(Some(Length::from(2u8)), err.position());
167// }
168
169// #[test]
170// fn trailing_data() {
171// let mut decoder = Decoder::new(&[0x02, 0x01, 0x2A, 0x00]);
172// let x = decoder.decode().unwrap();
173// assert_eq!(42i8, x);
174
175// let err = decoder.finish(x).err().unwrap();
176// assert_eq!(
177// ErrorKind::TrailingData {
178// decoded: 3u8.into(),
179// remaining: 1u8.into()
180// },
181// err.kind()
182// );
183// assert_eq!(Some(Length::from(3u8)), err.position());
184// }
185// }