// tinycbor/collections/fixed.rs

//! Fixed length collections and structures.
use core::mem::MaybeUninit;

use crate::{CborLen, Decode, Decoder, Encode};

6/// An error that can occur when decoding fixed length structures and collections.
7#[derive(#[automatically_derived]
impl<E: ::core::fmt::Debug> ::core::fmt::Debug for Error<E> {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        match self {
            Error::Missing => ::core::fmt::Formatter::write_str(f, "Missing"),
            Error::Surplus => ::core::fmt::Formatter::write_str(f, "Surplus"),
            Error::Collection(__self_0) =>
                ::core::fmt::Formatter::debug_tuple_field1_finish(f,
                    "Collection", &__self_0),
        }
    }
}Debug, #[automatically_derived]
impl<E: ::core::clone::Clone> ::core::clone::Clone for Error<E> {
    #[inline]
    fn clone(&self) -> Error<E> {
        match self {
            Error::Missing => Error::Missing,
            Error::Surplus => Error::Surplus,
            Error::Collection(__self_0) =>
                Error::Collection(::core::clone::Clone::clone(__self_0)),
        }
    }
}Clone, #[automatically_derived]
impl<E: ::core::marker::Copy> ::core::marker::Copy for Error<E> { }Copy, #[automatically_derived]
impl<E: ::core::cmp::PartialEq> ::core::cmp::PartialEq for Error<E> {
    #[inline]
    fn eq(&self, other: &Error<E>) -> bool {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
        __self_discr == __arg1_discr &&
            match (self, other) {
                (Error::Collection(__self_0), Error::Collection(__arg1_0)) =>
                    __self_0 == __arg1_0,
                _ => true,
            }
    }
}PartialEq, #[automatically_derived]
impl<E: ::core::cmp::Eq> ::core::cmp::Eq for Error<E> {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) -> () {
        let _: ::core::cmp::AssertParamIsEq<super::Error<E>>;
    }
}Eq, #[automatically_derived]
impl<E: ::core::hash::Hash> ::core::hash::Hash for Error<E> {
    #[inline]
    fn hash<__H: ::core::hash::Hasher>(&self, state: &mut __H) -> () {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        ::core::hash::Hash::hash(&__self_discr, state);
        match self {
            Error::Collection(__self_0) =>
                ::core::hash::Hash::hash(__self_0, state),
            _ => {}
        }
    }
}Hash)]
8pub enum Error<E> {
9    /// Not enough elements.
10    Missing,
11    /// Unexpected surplus elements.
12    Surplus,
13    /// Either the header or an element caused an error.
14    Collection(super::Error<E>),
15}
16
17impl<E> Error<E> {
18    /// Map a function on the inner error.
19    pub fn map<O>(self, f: impl FnOnce(E) -> O) -> Error<O> {
20        match self {
21            Error::Missing => Error::Missing,
22            Error::Surplus => Error::Surplus,
23            Error::Collection(e) => Error::Collection(e.map(f)),
24        }
25    }
26}
27
28impl<E: core::fmt::Display> core::fmt::Display for Error<E> {
29    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
30        match self {
31            Error::Missing => f.write_fmt(format_args!("missing elements"))write!(f, "missing elements"),
32            Error::Surplus => f.write_fmt(format_args!("too many elements"))write!(f, "too many elements"),
33            Error::Collection(e) => f.write_fmt(format_args!("{0}", e))write!(f, "{e}"),
34        }
35    }
36}
37
38impl<E> From<super::Error<E>> for Error<E> {
39    fn from(e: super::Error<E>) -> Self {
40        Error::Collection(e)
41    }
42}
43
44impl<E: core::error::Error + 'static> core::error::Error for Error<E> {
45    fn source(&self) -> Option<&(dyn core::error::Error + 'static)> {
46        match self {
47            Error::Missing => None,
48            Error::Surplus => None,
49            Error::Collection(e) => Some(e),
50        }
51    }
52}
53
/// Drop guard for a `[T; N]` that is being initialized front to back.
///
/// Decoding writes `data[i]` and then increments `initialized`. Decoding can stop early
/// in two ways: returning an error through `?`, or unwinding from a panic inside an
/// element decoder. In either case, dropping the guard drops exactly the elements
/// written so far, so they are neither leaked nor double-dropped.
///
/// Leaking would not be unsound, since leaks are allowed in Rust's memory-safety model.
/// Avoiding it is still nice for dependents, for example ones that catch unwinding panics.
struct Guard<T, const N: usize> {
    /// Backing storage; only `data[..initialized]` holds initialized values.
    data: [MaybeUninit<T>; N],
    /// Number of leading elements of `data` that are initialized.
    initialized: usize,
}

impl<T, const N: usize> Guard<T, N> {
    /// Consumes the guard and returns the fully initialized array.
    ///
    /// # Safety
    ///
    /// The caller must ensure that *all* `N` elements of `data` are initialized,
    /// i.e. that `initialized == N`. The whole array is read, not just a prefix.
    unsafe fn assume_init(self) -> [T; N] {
        debug_assert_eq!(
            self.initialized, N,
            "Guard::assume_init called on a partially initialized array"
        );
        // Suppress the guard's `Drop`: ownership of the elements moves into the returned
        // array, so the guard must not drop them as well.
        let this = core::mem::ManuallyDrop::new(self);
        // SAFETY: `[MaybeUninit<T>; N]` has the same layout as `[T; N]`. The caller
        // guarantees every element is initialized. `this` is never dropped, so each
        // element is moved out exactly once.
        unsafe { core::ptr::from_ref(&this.data).cast::<[T; N]>().read() }
    }
}

impl<T, const N: usize> Drop for Guard<T, N> {
    fn drop(&mut self) {
        for elem in &mut self.data[..self.initialized] {
            // SAFETY: only the first `initialized` slots have been written.
            unsafe { elem.assume_init_drop() };
        }
    }
}

84impl<'a, T, const N: usize> Decode<'a> for [T; N]
85where
86    T: Decode<'a>,
87{
88    type Error = Error<T::Error>;
89
90    fn decode(d: &mut Decoder<'a>) -> Result<Self, Self::Error> {
91        let mut visitor = d.array_visitor().map_err(super::Error::Malformed)?;
92        let mut guard = Guard {
93            data: [const { MaybeUninit::uninit() }; N],
94            initialized: 0,
95        };
96
97        for elem in &mut guard.data {
98            elem.write(
99                visitor
100                    .visit::<T>()
101                    .ok_or(Error::Missing)?
102                    .map_err(super::Error::Element)?,
103            );
104            guard.initialized += 1;
105        }
106        // Safety: All elements have been initialized.
107        let array = unsafe { guard.assume_init() };
108
109        if visitor.remaining() != Some(0) {
110            return Err(Error::Surplus);
111        }
112
113        Ok(array)
114    }
115}
116
117impl<T: Encode, const N: usize> Encode for [T; N] {
118    fn encode<W: embedded_io::Write>(&self, e: &mut crate::Encoder<W>) -> Result<(), W::Error> {
119        e.array(N)?;
120        for item in self {
121            item.encode(e)?;
122        }
123        Ok(())
124    }
125}
126
127impl<T: CborLen, const N: usize> CborLen for [T; N] {
128    fn cbor_len(&self) -> usize {
129        N.cbor_len() + self.iter().map(|x| x.cbor_len()).sum::<usize>()
130    }
131}
132
133// Map encoding
134
135impl<'a, K, V, const N: usize> Decode<'a> for [(K, V); N]
136where
137    K: Decode<'a>,
138    V: Decode<'a>,
139{
140    type Error = Error<super::map::Error<K::Error, V::Error>>;
141
142    fn decode(d: &mut Decoder<'a>) -> Result<Self, Self::Error> {
143        let mut visitor = d.map_visitor().map_err(super::Error::Malformed)?;
144        let mut guard = Guard {
145            data: [const { MaybeUninit::uninit() }; N],
146            initialized: 0,
147        };
148
149        for elem in &mut guard.data {
150            let v = visitor
151                .visit()
152                .ok_or(Error::Missing)?
153                .map_err(super::Error::Element)?;
154            elem.write(v);
155            guard.initialized += 1;
156        }
157        // Safety: All elements have been initialized.
158        let array = unsafe { guard.assume_init() };
159
160        if visitor.remaining() != Some(0) {
161            return Err(Error::Surplus);
162        }
163        Ok(array)
164    }
165}
166
167impl<K: Encode, V: Encode, const N: usize> Encode for [(K, V); N] {
168    fn encode<W: embedded_io::Write>(&self, e: &mut crate::Encoder<W>) -> Result<(), W::Error> {
169        e.map(N)?;
170        for (k, v) in self {
171            k.encode(e)?;
172            v.encode(e)?;
173        }
174        Ok(())
175    }
176}
177
178impl<K: CborLen, V: CborLen, const N: usize> CborLen for [(K, V); N] {
179    fn cbor_len(&self) -> usize {
180        N.cbor_len()
181            + self
182                .iter()
183                .map(|(k, v)| k.cbor_len() + v.cbor_len())
184                .sum::<usize>()
185    }
186}
187
#[cfg(test)]
mod tests {
    use crate::{Decode, Decoder, test};

    /// CBOR encoding of an empty definite-length array.
    const EMPTY_ARRAY: &[u8] = &[0x80];

    #[test]
    fn empty() {
        assert!(test::<[isize; 0]>([], EMPTY_ARRAY).unwrap());
        assert!(test::<[i32; 0]>([], EMPTY_ARRAY).unwrap());
    }

    #[test]
    fn small() {
        assert!(test([42u16], &[0x81, 0x18, 0x2a]).unwrap());
        assert!(test([true], &[0x81, 0xf5]).unwrap());
        assert!(test([-1i32], &[0x81, 0x20]).unwrap());

        assert!(test([1usize, 2usize], &[0x82, 0x01, 0x02]).unwrap());
        assert!(test([true, false], &[0x82, 0xf5, 0xf4]).unwrap());

        assert!(test(["a", "b", "c"], &[0x83, 0x61, 0x61, 0x61, 0x62, 0x61, 0x63]).unwrap());
    }

    #[test]
    fn nested() {
        assert!(
            test(
                [[1u64, 2], [3, 4]],
                &[0x82, 0x82, 0x01, 0x02, 0x82, 0x03, 0x04]
            )
            .unwrap()
        );

        assert!(
            test(
                [[[1u64, 2], [3, 4]], [[5, 6], [7, 8]]],
                &[
                    0x82, 0x82, 0x82, 0x01, 0x02, 0x82, 0x03, 0x04, 0x82, 0x82, 0x05, 0x06, 0x82,
                    0x07, 0x08
                ]
            )
            .unwrap()
        );
    }

    #[test]
    fn missing() {
        use super::Error;
        let cbor = &[0x82, 0x01, 0x02];
        let result = <[u16; 3]>::decode(&mut Decoder(cbor));
        assert!(matches!(result, Err(Error::Missing)));
    }

    #[test]
    fn surplus() {
        use super::Error;
        let cbor = &[0x83, 0x01, 0x02, 0x03];
        let result = <[u16; 2]>::decode(&mut Decoder(cbor));
        assert!(matches!(result, Err(Error::Surplus)));
    }

    #[test]
    fn long() {
        let arr: [u32; 25] = core::array::from_fn(|i| i as u32);
        // 0x98 0x19: array header with a one-byte length of 25.
        let mut cbor = vec![0x98, 25];
        // 0..=23 fit directly in the initial byte; 24 needs the 0x18 one-byte prefix.
        cbor.extend(0..24u8);
        cbor.extend([0x18, 24]);

        assert!(test(arr, &cbor).unwrap());
    }

    #[test]
    fn map() {
        assert!(
            test(
                [("a", 1u16), ("b", 2u16)],
                &[0xA2, 0x61, 0x61, 0x01, 0x61, 0x62, 0x02]
            )
            .unwrap()
        );
    }

    #[test]
    fn map_missing() {
        use super::Error;
        // {"a": 1} has one entry, but two are expected.
        let cbor = &[0xA1, 0x61, 0x61, 0x01];
        let result = <[(&str, u16); 2]>::decode(&mut Decoder(cbor));
        assert!(matches!(result, Err(Error::Missing)));
    }

    #[test]
    fn map_surplus() {
        use super::Error;
        // {"a": 1, "b": 2} has two entries, but only one is expected.
        let cbor = &[0xA2, 0x61, 0x61, 0x01, 0x61, 0x62, 0x02];
        let result = <[(&str, u16); 1]>::decode(&mut Decoder(cbor));
        assert!(matches!(result, Err(Error::Surplus)));
    }
}
217                &[0x82, 0x82, 0x01, 0x02, 0x82, 0x03, 0x04]
218            )
219            .unwrap()
220        );
221
222        assert!(
223            test(
224                [[[1u64, 2], [3, 4]], [[5, 6], [7, 8]]],
225                &[
226                    0x82, 0x82, 0x82, 0x01, 0x02, 0x82, 0x03, 0x04, 0x82, 0x82, 0x05, 0x06, 0x82,
227                    0x07, 0x08
228                ]
229            )
230            .unwrap()
231        );
232    }
233
234    #[test]
235    fn missing() {
236        use super::Error;
237        let cbor = &[0x82, 0x01, 0x02];
238        let result = <[u16; 3]>::decode(&mut Decoder(cbor));
239        assert!(matches!(result, Err(Error::Missing)));
240    }
241
242    #[test]
243    fn surplus() {
244        use super::Error;
245        let cbor = &[0x83, 0x01, 0x02, 0x03];
246        let result = <[u16; 2]>::decode(&mut Decoder(cbor));
247        assert!(matches!(result, Err(Error::Surplus)));
248    }
249
250    #[test]
251    fn long() {
252        let arr: [u32; 25] = core::array::from_fn(|i| i as u32);
253        let mut cbor = vec![0x98, 25];
254        for i in 0..25 {
255            if i < 24 {
256                cbor.push(i as u8);
257            } else {
258                cbor.push(0x18);
259                cbor.push(i as u8);
260            }
261        }
262
263        assert!(test(arr, &cbor).unwrap());
264    }
265
266    #[test]
267    fn map() {
268        assert!(
269            test(
270                [("a", 1u16), ("b", 2u16)],
271                &[0xA2, 0x61, 0x61, 0x01, 0x61, 0x62, 0x02]
272            )
273            .unwrap()
274        );
275    }
276}