sml_rs/parser/
tlf.rs

//! A Type-Length-Field (TLF) is a building block for many SML data structures.
2
3use core::fmt;
4
5use crate::parser::ParseError;
6
7use super::{take_byte, SmlParse};
8
9use super::ResTy;
10
/// Errors that can occur while decoding a `TypeLengthField`.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum TlfParseError {
    /// The decoded length does not fit into the TLF's `u32` length value
    TlfLengthOverflow,
    /// The TLF uses a bit pattern that is reserved for future use
    TlfReserved,
    /// The encoded length is smaller than the number of bytes occupied by the TLF itself
    TlfLengthUnderflow,
    /// A TLF byte following the first one does not have its type field set to `000`
    TlfNextByteTypeMismatch,
    /// The type field of the first TLF byte does not denote a known type
    TlfInvalidTy,
}
25
26impl From<TlfParseError> for ParseError {
27    fn from(x: TlfParseError) -> Self {
28        ParseError::InvalidTlf(x)
29    }
30}
31
32impl fmt::Display for TlfParseError {
33    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34        <Self as fmt::Debug>::fmt(self, f)
35    }
36}
37
// Only available with the `std` feature, since `std::error::Error` needs `std`.
// The trait's default methods suffice; `Display` and `Debug` are implemented above.
#[cfg(feature = "std")]
impl std::error::Error for TlfParseError {}
40
41#[derive(Debug, PartialEq, Eq, Clone)]
42pub(crate) struct TypeLengthField {
43    pub ty: Ty,
44    pub len: u32,
45}
46
47impl TypeLengthField {
48    #[allow(unused)]
49    pub(crate) fn new(ty: Ty, len: u32) -> TypeLengthField {
50        TypeLengthField { ty, len }
51    }
52}
53
54impl<'i> SmlParse<'i> for TypeLengthField {
55    fn parse(input: &[u8]) -> ResTy<Self> {
56        let (mut input, (mut has_more_bytes, ty, mut len)) = tlf_first_byte(input)?;
57        let mut tlf_len = 1;
58
59        // reserved for future usages
60        if matches!(ty, Ty::Boolean) && has_more_bytes {
61            return Err(TlfParseError::TlfReserved.into());
62        }
63
64        while has_more_bytes {
65            tlf_len += 1;
66
67            let (input_new, (has_more_bytes_new, len_new)) = tlf_next_byte(input)?;
68            input = input_new;
69            has_more_bytes = has_more_bytes_new;
70
71            len = match len.checked_shl(4) {
72                Some(l) => l,
73                None => {
74                    return Err(TlfParseError::TlfLengthOverflow.into());
75                }
76            };
77            len += len_new & 0b1111;
78        }
79
80        // For some reason, the length of the tlf is part of `len` for primitive types.
81        // Therefore, it has to be subtracted here
82        if !matches!(ty, Ty::ListOf) {
83            len = match len.checked_sub(tlf_len) {
84                Some(l) => l,
85                None => {
86                    return Err(TlfParseError::TlfLengthUnderflow.into());
87                }
88            }
89        }
90
91        Ok((input, TypeLengthField { ty, len }))
92    }
93}
94
95fn tlf_byte(input: &[u8]) -> ResTy<(bool, u8, u32)> {
96    let (input, b) = take_byte(input)?;
97    let len = b & 0x0F;
98    let ty = (b >> 4) & 0x07;
99    let has_more_bytes = (b & 0x80) != 0;
100    Ok((input, (has_more_bytes, ty, len as u32)))
101}
102
103fn tlf_first_byte(input: &[u8]) -> ResTy<(bool, Ty, u32)> {
104    let (input, (has_more_bytes, ty, len)) = tlf_byte(input)?;
105    let ty = Ty::from_byte(ty)?;
106    Ok((input, (has_more_bytes, ty, len)))
107}
108
109fn tlf_next_byte(input: &[u8]) -> ResTy<(bool, u32)> {
110    let (input, (has_more_bytes, ty, len)) = tlf_byte(input)?;
111    if ty != 0x00 {
112        return Err(TlfParseError::TlfNextByteTypeMismatch.into());
113    }
114    Ok((input, (has_more_bytes, len)))
115}
116
/// The data type announced by a TLF, decoded from the 3-bit type field of its first byte.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub(crate) enum Ty {
    /// Type field `0b000`.
    OctetString,
    /// Type field `0b100`.
    Boolean,
    /// Type field `0b101`.
    Integer,
    /// Type field `0b110`.
    Unsigned,
    /// Type field `0b111`. Its length does not include the TLF bytes.
    ListOf,
}
125
126impl Ty {
127    fn from_byte(ty_num: u8) -> Result<Ty, ParseError> {
128        Ok(match ty_num {
129            0b000 => Ty::OctetString,
130            0b100 => Ty::Boolean,
131            0b101 => Ty::Integer,
132            0b110 => Ty::Unsigned,
133            0b111 => Ty::ListOf,
134            _ => {
135                return Err(TlfParseError::TlfInvalidTy.into());
136            }
137        })
138    }
139}
140
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn different_types() {
        let test_cases = [
            (&[0b0000_0001], TypeLengthField::new(Ty::OctetString, 0)),
            (&[0b0100_0001], TypeLengthField::new(Ty::Boolean, 0)),
            (&[0b0101_0001], TypeLengthField::new(Ty::Integer, 0)),
            (&[0b0110_0001], TypeLengthField::new(Ty::Unsigned, 0)),
            (&[0b0111_0000], TypeLengthField::new(Ty::ListOf, 0)),
        ];

        test_cases.iter().for_each(|(input, exp)| {
            assert_eq!(
                &TypeLengthField::parse_complete(*input).expect("Decode error"),
                exp
            )
        });
    }

    #[test]
    fn reserved() {
        // single-byte: reserved / invalid type fields
        assert!(TypeLengthField::parse(&[0b1100_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b0001_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b0010_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b0011_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1001_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1010_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1011_0000]).is_err());

        // multi-byte: bytes after the first must have type field `000`
        assert!(TypeLengthField::parse(&[0b1000_0010, 0b0001_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1000_0010, 0b0010_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1000_0010, 0b0011_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1000_0010, 0b0101_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1000_0010, 0b0110_0000]).is_err());
        assert!(TypeLengthField::parse(&[0b1000_0010, 0b0111_0000]).is_err());
    }

    #[test]
    fn boolean_multi_byte_is_reserved() {
        // The following byte is well-formed, so the error can only come from the
        // "Boolean with more bytes" reservation.
        assert!(TypeLengthField::parse(&[0b1100_0000, 0b0000_0010]).is_err());
    }

    #[test]
    fn len_single_byte() {
        // for primitive data types, the tlf length is part of the length field.
        // for complex data types, it is not.

        // single-byte tlf for primitive type
        assert_eq!(
            TypeLengthField::parse_complete(&[0b0000_0001]).expect("Decode error"),
            TypeLengthField::new(Ty::OctetString, 0)
        );
        assert_eq!(
            TypeLengthField::parse_complete(&[0b0000_1000]).expect("Decode error"),
            TypeLengthField::new(Ty::OctetString, 7)
        );
        assert_eq!(
            TypeLengthField::parse_complete(&[0b0000_1111]).expect("Decode error"),
            TypeLengthField::new(Ty::OctetString, 14)
        );
        // length 0 for primitive types is an error
        assert!(TypeLengthField::parse(&[0b0000_0000]).is_err());

        // single-byte tlf for complex type
        assert_eq!(
            TypeLengthField::parse_complete(&[0b0111_0000]).expect("Decode error"),
            TypeLengthField::new(Ty::ListOf, 0)
        );
        assert_eq!(
            TypeLengthField::parse_complete(&[0b0111_1000]).expect("Decode error"),
            TypeLengthField::new(Ty::ListOf, 8)
        );
        assert_eq!(
            TypeLengthField::parse_complete(&[0b0111_1111]).expect("Decode error"),
            TypeLengthField::new(Ty::ListOf, 15)
        );
    }

    #[test]
    fn len_multi_byte() {
        // for primitive data types, the tlf length is part of the length field.
        // for complex data types, it is not.

        // multi-byte tlf for primitive type
        assert_eq!(
            TypeLengthField::parse_complete(&[0b1000_0010, 0b0000_0011]).expect("Decode error"),
            TypeLengthField::new(Ty::OctetString, 0b0010_0011 - 2)
        );
        assert_eq!(
            TypeLengthField::parse_complete(&[0b1000_0010, 0b1000_0011, 0b0000_1111])
                .expect("Decode error"),
            TypeLengthField::new(Ty::OctetString, 0b0010_0011_1111 - 3)
        );

        // multi-byte tlf for complex type
        assert_eq!(
            TypeLengthField::parse_complete(&[0b1111_0010, 0b0000_0011]).expect("Decode error"),
            TypeLengthField::new(Ty::ListOf, 0b0010_0011)
        );
        assert_eq!(
            TypeLengthField::parse_complete(&[0b1111_0010, 0b1000_0011, 0b0000_1111])
                .expect("Decode error"),
            TypeLengthField::new(Ty::ListOf, 0b0010_0011_1111)
        );
    }

    #[test]
    fn len_max_eight_nibbles() {
        // Eight bytes carry eight length nibbles: the largest length a u32 can hold.
        let list = [
            0b1111_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b0000_1111,
        ];
        assert_eq!(
            TypeLengthField::parse_complete(&list).expect("Decode error"),
            TypeLengthField::new(Ty::ListOf, u32::MAX)
        );

        // Same length for a primitive type: the 8 TLF bytes are subtracted.
        let octets = [
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b1000_1111,
            0b0000_1111,
        ];
        assert_eq!(
            TypeLengthField::parse_complete(&octets).expect("Decode error"),
            TypeLengthField::new(Ty::OctetString, u32::MAX - 8)
        );
    }

    #[test]
    fn len_multi_byte_underflow() {
        // two TLF bytes but an encoded length of 1 for a primitive type
        assert!(TypeLengthField::parse(&[0b1000_0000, 0b0000_0001]).is_err());
    }

    #[test]
    fn truncated_input() {
        let empty: &[u8] = &[];
        assert!(TypeLengthField::parse(empty).is_err());
        // the "more bytes" flag is set, but the input ends
        assert!(TypeLengthField::parse(&[0b1000_0010]).is_err());
        assert!(TypeLengthField::parse(&[0b1000_0010, 0b1000_0011]).is_err());
    }
}