Skip to main content

sia_core/encoding/
v2.rs

1use chrono::{DateTime, Duration, Utc};
2
3use super::{Error, Result};
4use std::io::{Read, Write};
5
6pub trait SiaEncodable {
7    fn encoded_length(&self) -> usize;
8    fn encode<W: Write>(&self, w: &mut W) -> Result<()>;
9}
10
11pub trait SiaDecodable: Sized {
12    fn decode<R: Read>(r: &mut R) -> Result<Self>;
13}
14
15impl SiaEncodable for u8 {
16    fn encoded_length(&self) -> usize {
17        1
18    }
19
20    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
21        w.write_all(&[*self])?;
22        Ok(())
23    }
24}
25
26impl SiaDecodable for u8 {
27    fn decode<R: Read>(r: &mut R) -> Result<Self> {
28        let mut buf = [0; 1];
29        r.read_exact(&mut buf)?;
30        Ok(buf[0])
31    }
32}
33
34impl SiaEncodable for bool {
35    fn encoded_length(&self) -> usize {
36        1
37    }
38
39    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
40        (*self as u8).encode(w)
41    }
42}
43
44impl SiaDecodable for bool {
45    fn decode<R: Read>(r: &mut R) -> Result<Self> {
46        let v = u8::decode(r)?;
47        match v {
48            0 => Ok(false),
49            1 => Ok(true),
50            _ => Err(Error::InvalidValue("requires 0 or 1".into())),
51        }
52    }
53}
54
55impl SiaEncodable for DateTime<Utc> {
56    fn encoded_length(&self) -> usize {
57        8
58    }
59
60    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
61        self.timestamp().encode(w)
62    }
63}
64
65impl SiaDecodable for DateTime<Utc> {
66    fn decode<R: Read>(r: &mut R) -> Result<Self> {
67        let timestamp = i64::decode(r)?;
68        DateTime::from_timestamp_secs(timestamp)
69            .ok_or_else(|| Error::InvalidValue(format!("invalid timestamp: {timestamp}")))
70    }
71}
72
73impl SiaEncodable for Duration {
74    fn encoded_length(&self) -> usize {
75        8
76    }
77
78    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
79        self.num_nanoseconds()
80            .ok_or_else(|| Error::InvalidValue("duration too large".into()))?
81            .encode(w)
82    }
83}
84
85impl SiaDecodable for Duration {
86    fn decode<R: Read>(r: &mut R) -> Result<Self> {
87        let ns = u64::decode(r)?;
88        if ns > i64::MAX as u64 {
89            return Err(Error::InvalidValue(format!(
90                "duration {ns} must be less than {}",
91                i64::MAX
92            )));
93        }
94        Ok(Duration::nanoseconds(ns as i64))
95    }
96}
97
98impl<T: SiaEncodable> SiaEncodable for [T] {
99    fn encoded_length(&self) -> usize {
100        let mut len = 0;
101        len += self.len().encoded_length();
102        for item in self {
103            len += item.encoded_length();
104        }
105        len
106    }
107
108    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
109        self.len().encode(w)?;
110        for item in self {
111            item.encode(w)?;
112        }
113        Ok(())
114    }
115}
116
117impl<T: SiaEncodable> SiaEncodable for Option<T> {
118    fn encoded_length(&self) -> usize {
119        1 + match self {
120            Some(v) => v.encoded_length(),
121            None => 0,
122        }
123    }
124    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
125        match self {
126            Some(v) => {
127                true.encode(w)?;
128                v.encode(w)
129            }
130            None => false.encode(w),
131        }
132    }
133}
134
135impl<T: SiaDecodable> SiaDecodable for Option<T> {
136    fn decode<R: Read>(r: &mut R) -> Result<Self> {
137        match bool::decode(r)? {
138            true => Ok(Some(T::decode(r)?)),
139            false => Ok(None),
140        }
141    }
142}
143
144macro_rules! impl_sia_numeric {
145    ($($t:ty),*) => {
146        $(
147            impl SiaEncodable for $t {
148                fn encoded_length(&self) -> usize {
149                    8
150                }
151
152                fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
153                    w.write_all(&(*self as u64).to_le_bytes())?;
154                    Ok(())
155                }
156            }
157
158            impl SiaDecodable for $t {
159                fn decode<R: Read>(r: &mut R) -> Result<Self> {
160                    let mut buf = [0u8; 8];
161                    r.read_exact(&mut buf)?;
162                    Ok(u64::from_le_bytes(buf) as Self)
163                }
164            }
165        )*
166    }
167}
168
169impl_sia_numeric!(u16, u32, usize, i16, i32, i64, u64);
170
171impl<T> SiaEncodable for Vec<T>
172where
173    T: SiaEncodable,
174{
175    fn encoded_length(&self) -> usize {
176        let mut len = 0;
177        len += self.len().encoded_length();
178        for item in self {
179            len += item.encoded_length();
180        }
181        len
182    }
183    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
184        self.len().encode(w)?;
185        for item in self {
186            item.encode(w)?;
187        }
188        Ok(())
189    }
190}
191
192impl<T> SiaDecodable for Vec<T>
193where
194    T: SiaDecodable,
195{
196    fn decode<R: Read>(r: &mut R) -> Result<Self> {
197        let mut vec = Vec::new();
198        // note: the vec is not pre-allocated
199        // to prevent abuse by sending a large len
200        for _ in 0..usize::decode(r)? {
201            vec.push(T::decode(r)?);
202        }
203        Ok(vec)
204    }
205}
206
207impl SiaEncodable for String {
208    fn encoded_length(&self) -> usize {
209        self.as_bytes().encoded_length()
210    }
211
212    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
213        self.as_bytes().encode(w)
214    }
215}
216
217impl SiaDecodable for String {
218    fn decode<R: Read>(r: &mut R) -> Result<Self> {
219        let buf = Vec::<u8>::decode(r)?;
220        String::from_utf8(buf).map_err(|e| Error::InvalidValue(e.to_string()))
221    }
222}
223
224impl SiaEncodable for bytes::Bytes {
225    fn encoded_length(&self) -> usize {
226        8 + self.len()
227    }
228
229    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
230        (self.len() as u64).encode(w)?;
231        w.write_all(self)?;
232        Ok(())
233    }
234}
235
236impl SiaDecodable for bytes::Bytes {
237    fn decode<R: Read>(r: &mut R) -> Result<Self> {
238        let len = u64::decode(r)? as usize;
239        let mut buf = vec![0u8; len];
240        r.read_exact(&mut buf)?;
241        Ok(bytes::Bytes::from(buf))
242    }
243}
244
245impl<const N: usize> SiaEncodable for [u8; N] {
246    fn encoded_length(&self) -> usize {
247        N
248    }
249    fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
250        w.write_all(self)?;
251        Ok(())
252    }
253}
254
255impl<const N: usize> SiaDecodable for [u8; N] {
256    fn decode<R: Read>(r: &mut R) -> Result<Self> {
257        let mut arr = [0u8; N];
258        r.read_exact(&mut arr)?;
259        Ok(arr)
260    }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` and compares the result against `expected_bytes`.
    /// Checks that `encoded_length` matches the number of bytes written, then
    /// decodes the bytes back and requires the original value with nothing left over.
    fn test_roundtrip<T: SiaEncodable + SiaDecodable + std::fmt::Debug + PartialEq>(
        value: T,
        expected_bytes: Vec<u8>,
    ) {
        let mut encoded_bytes = Vec::new();
        value
            .encode(&mut encoded_bytes)
            .unwrap_or_else(|e| panic!("failed to encode: {e:?}"));

        assert_eq!(
            encoded_bytes, expected_bytes,
            "encoding mismatch for {value:?}"
        );
        assert_eq!(
            value.encoded_length(),
            expected_bytes.len(),
            "encoded_length mismatch for {value:?}"
        );

        let mut remaining = &expected_bytes[..];
        let decoded =
            T::decode(&mut remaining).unwrap_or_else(|e| panic!("failed to decode: {e:?}"));
        assert_eq!(decoded, value, "decoding mismatch for {value:?}");

        assert_eq!(remaining.len(), 0, "leftover bytes for {value:?}");
    }

    #[test]
    fn test_numerics() {
        test_roundtrip(1u8, vec![1]);
        test_roundtrip(2u16, vec![2, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(3u32, vec![3, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(4u64, vec![4, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(5usize, vec![5, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(-1i16, vec![255, 255, 255, 255, 255, 255, 255, 255]);
        test_roundtrip(-2i32, vec![254, 255, 255, 255, 255, 255, 255, 255]);
        test_roundtrip(-3i64, vec![253, 255, 255, 255, 255, 255, 255, 255]);
    }

    #[test]
    fn test_bool() {
        test_roundtrip(false, vec![0]);
        test_roundtrip(true, vec![1]);

        // Any byte other than 0 or 1 is rejected.
        let mut invalid: &[u8] = &[2];
        assert!(matches!(
            bool::decode(&mut invalid),
            Err(Error::InvalidValue(_))
        ));
    }

    #[test]
    fn test_option() {
        test_roundtrip(None::<u8>, vec![0]);
        test_roundtrip(Some(7u8), vec![1, 7]);
        test_roundtrip(Some(1u64), vec![1, 1, 0, 0, 0, 0, 0, 0, 0]);
    }

    #[test]
    fn test_time() {
        let at = DateTime::<Utc>::from_timestamp_secs(1).expect("valid timestamp");
        test_roundtrip(at, vec![1, 0, 0, 0, 0, 0, 0, 0]);

        // 1_000_000_000 ns == 0x3B9ACA00
        test_roundtrip(
            Duration::nanoseconds(1_000_000_000),
            vec![0, 202, 154, 59, 0, 0, 0, 0],
        );
    }

    #[test]
    fn test_bytes() {
        test_roundtrip(
            bytes::Bytes::from_static(b"hi"),
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // length prefix
                104, 105, // "hi"
            ],
        );

        // A length prefix larger than the remaining input must be an error.
        let mut truncated: &[u8] = &[5, 0, 0, 0, 0, 0, 0, 0, 1, 2];
        assert!(bytes::Bytes::decode(&mut truncated).is_err());
    }

    #[test]
    fn test_strings() {
        test_roundtrip(
            "hello".to_string(),
            vec![
                5, 0, 0, 0, 0, 0, 0, 0, // length prefix
                104, 101, 108, 108, 111, // "hello"
            ],
        );
        test_roundtrip(
            "".to_string(),
            vec![0, 0, 0, 0, 0, 0, 0, 0], // empty string length
        );
    }

    #[test]
    fn test_fixed_arrays() {
        test_roundtrip([1u8, 2u8, 3u8], vec![1, 2, 3]);
        test_roundtrip([0u8; 4], vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_vectors() {
        test_roundtrip(
            vec![1u8, 2u8, 3u8],
            vec![
                3, 0, 0, 0, 0, 0, 0, 0, // length prefix
                1, 2, 3, // values
            ],
        );
        test_roundtrip(
            vec![100u64, 200u64],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // length prefix
                100, 0, 0, 0, 0, 0, 0, 0, // 100u64
                200, 0, 0, 0, 0, 0, 0, 0, // 200u64
            ],
        );
        test_roundtrip(
            vec!["a".to_string(), "bc".to_string()],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // vector length
                1, 0, 0, 0, 0, 0, 0, 0, // first string length
                97, // "a"
                2, 0, 0, 0, 0, 0, 0, 0, // second string length
                98, 99, // "bc"
            ],
        );
    }

    #[test]
    fn test_nested() {
        test_roundtrip(
            vec![vec![1u8, 2u8], vec![3u8, 4u8]],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // outer vec length
                2, 0, 0, 0, 0, 0, 0, 0, // first inner vec length
                1, 2, // first inner vec contents
                2, 0, 0, 0, 0, 0, 0, 0, // second inner vec length
                3, 4, // second inner vec contents
            ],
        );
    }
}
359                3, 4, // second inner vec contents
360            ],
361        );
362    }
363}