1use chrono::{DateTime, Duration, Utc};
2
3use super::{Error, Result};
4use std::io::{Read, Write};
5
6pub trait SiaEncodable {
7 fn encoded_length(&self) -> usize;
8 fn encode<W: Write>(&self, w: &mut W) -> Result<()>;
9}
10
11pub trait SiaDecodable: Sized {
12 fn decode<R: Read>(r: &mut R) -> Result<Self>;
13}
14
15impl SiaEncodable for u8 {
16 fn encoded_length(&self) -> usize {
17 1
18 }
19
20 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
21 w.write_all(&[*self])?;
22 Ok(())
23 }
24}
25
26impl SiaDecodable for u8 {
27 fn decode<R: Read>(r: &mut R) -> Result<Self> {
28 let mut buf = [0; 1];
29 r.read_exact(&mut buf)?;
30 Ok(buf[0])
31 }
32}
33
34impl SiaEncodable for bool {
35 fn encoded_length(&self) -> usize {
36 1
37 }
38
39 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
40 (*self as u8).encode(w)
41 }
42}
43
44impl SiaDecodable for bool {
45 fn decode<R: Read>(r: &mut R) -> Result<Self> {
46 let v = u8::decode(r)?;
47 match v {
48 0 => Ok(false),
49 1 => Ok(true),
50 _ => Err(Error::InvalidValue("requires 0 or 1".into())),
51 }
52 }
53}
54
55impl SiaEncodable for DateTime<Utc> {
56 fn encoded_length(&self) -> usize {
57 8
58 }
59
60 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
61 self.timestamp().encode(w)
62 }
63}
64
65impl SiaDecodable for DateTime<Utc> {
66 fn decode<R: Read>(r: &mut R) -> Result<Self> {
67 let timestamp = i64::decode(r)?;
68 DateTime::from_timestamp_secs(timestamp)
69 .ok_or_else(|| Error::InvalidValue(format!("invalid timestamp: {timestamp}")))
70 }
71}
72
73impl SiaEncodable for Duration {
74 fn encoded_length(&self) -> usize {
75 8
76 }
77
78 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
79 self.num_nanoseconds()
80 .ok_or_else(|| Error::InvalidValue("duration too large".into()))?
81 .encode(w)
82 }
83}
84
85impl SiaDecodable for Duration {
86 fn decode<R: Read>(r: &mut R) -> Result<Self> {
87 let ns = u64::decode(r)?;
88 if ns > i64::MAX as u64 {
89 return Err(Error::InvalidValue(format!(
90 "duration {ns} must be less than {}",
91 i64::MAX
92 )));
93 }
94 Ok(Duration::nanoseconds(ns as i64))
95 }
96}
97
98impl<T: SiaEncodable> SiaEncodable for [T] {
99 fn encoded_length(&self) -> usize {
100 let mut len = 0;
101 len += self.len().encoded_length();
102 for item in self {
103 len += item.encoded_length();
104 }
105 len
106 }
107
108 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
109 self.len().encode(w)?;
110 for item in self {
111 item.encode(w)?;
112 }
113 Ok(())
114 }
115}
116
117impl<T: SiaEncodable> SiaEncodable for Option<T> {
118 fn encoded_length(&self) -> usize {
119 1 + match self {
120 Some(v) => v.encoded_length(),
121 None => 0,
122 }
123 }
124 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
125 match self {
126 Some(v) => {
127 true.encode(w)?;
128 v.encode(w)
129 }
130 None => false.encode(w),
131 }
132 }
133}
134
135impl<T: SiaDecodable> SiaDecodable for Option<T> {
136 fn decode<R: Read>(r: &mut R) -> Result<Self> {
137 match bool::decode(r)? {
138 true => Ok(Some(T::decode(r)?)),
139 false => Ok(None),
140 }
141 }
142}
143
144macro_rules! impl_sia_numeric {
145 ($($t:ty),*) => {
146 $(
147 impl SiaEncodable for $t {
148 fn encoded_length(&self) -> usize {
149 8
150 }
151
152 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
153 w.write_all(&(*self as u64).to_le_bytes())?;
154 Ok(())
155 }
156 }
157
158 impl SiaDecodable for $t {
159 fn decode<R: Read>(r: &mut R) -> Result<Self> {
160 let mut buf = [0u8; 8];
161 r.read_exact(&mut buf)?;
162 Ok(u64::from_le_bytes(buf) as Self)
163 }
164 }
165 )*
166 }
167}
168
169impl_sia_numeric!(u16, u32, usize, i16, i32, i64, u64);
170
171impl<T> SiaEncodable for Vec<T>
172where
173 T: SiaEncodable,
174{
175 fn encoded_length(&self) -> usize {
176 let mut len = 0;
177 len += self.len().encoded_length();
178 for item in self {
179 len += item.encoded_length();
180 }
181 len
182 }
183 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
184 self.len().encode(w)?;
185 for item in self {
186 item.encode(w)?;
187 }
188 Ok(())
189 }
190}
191
192impl<T> SiaDecodable for Vec<T>
193where
194 T: SiaDecodable,
195{
196 fn decode<R: Read>(r: &mut R) -> Result<Self> {
197 let mut vec = Vec::new();
198 for _ in 0..usize::decode(r)? {
201 vec.push(T::decode(r)?);
202 }
203 Ok(vec)
204 }
205}
206
207impl SiaEncodable for String {
208 fn encoded_length(&self) -> usize {
209 self.as_bytes().encoded_length()
210 }
211
212 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
213 self.as_bytes().encode(w)
214 }
215}
216
217impl SiaDecodable for String {
218 fn decode<R: Read>(r: &mut R) -> Result<Self> {
219 let buf = Vec::<u8>::decode(r)?;
220 String::from_utf8(buf).map_err(|e| Error::InvalidValue(e.to_string()))
221 }
222}
223
224impl SiaEncodable for bytes::Bytes {
225 fn encoded_length(&self) -> usize {
226 8 + self.len()
227 }
228
229 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
230 (self.len() as u64).encode(w)?;
231 w.write_all(self)?;
232 Ok(())
233 }
234}
235
236impl SiaDecodable for bytes::Bytes {
237 fn decode<R: Read>(r: &mut R) -> Result<Self> {
238 let len = u64::decode(r)? as usize;
239 let mut buf = vec![0u8; len];
240 r.read_exact(&mut buf)?;
241 Ok(bytes::Bytes::from(buf))
242 }
243}
244
245impl<const N: usize> SiaEncodable for [u8; N] {
246 fn encoded_length(&self) -> usize {
247 N
248 }
249 fn encode<W: Write>(&self, w: &mut W) -> Result<()> {
250 w.write_all(self)?;
251 Ok(())
252 }
253}
254
255impl<const N: usize> SiaDecodable for [u8; N] {
256 fn decode<R: Read>(r: &mut R) -> Result<Self> {
257 let mut arr = [0u8; N];
258 r.read_exact(&mut arr)?;
259 Ok(arr)
260 }
261}
262
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes `value` and checks the bytes and `encoded_length` against `expected_bytes`.
    /// Then decodes `expected_bytes`, checking the result equals `value` with no input left.
    fn test_roundtrip<T: SiaEncodable + SiaDecodable + std::fmt::Debug + PartialEq>(
        value: T,
        expected_bytes: Vec<u8>,
    ) {
        let mut encoded_bytes = Vec::new();
        value
            .encode(&mut encoded_bytes)
            .unwrap_or_else(|e| panic!("failed to encode: {e:?}"));

        assert_eq!(
            encoded_bytes, expected_bytes,
            "encoding mismatch for {value:?}"
        );
        // `encoded_length` must predict exactly what `encode` writes.
        assert_eq!(
            value.encoded_length(),
            expected_bytes.len(),
            "encoded_length mismatch for {value:?}"
        );

        let mut bytes = &expected_bytes[..];
        let decoded = T::decode(&mut bytes).unwrap_or_else(|e| panic!("failed to decode: {e:?}"));
        assert_eq!(decoded, value, "decoding mismatch for {value:?}");

        assert_eq!(bytes.len(), 0, "leftover bytes for {value:?}");
    }

    /// Asserts that decoding `input` as `T` returns an error.
    fn assert_decode_fails<T: SiaDecodable + std::fmt::Debug>(input: &[u8]) {
        let mut src = input;
        let result = T::decode(&mut src);
        assert!(
            result.is_err(),
            "expected decoding {input:?} to fail, got {result:?}"
        );
    }

    #[test]
    fn test_numerics() {
        test_roundtrip(1u8, vec![1]);
        test_roundtrip(2u16, vec![2, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(3u32, vec![3, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(4u64, vec![4, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(5usize, vec![5, 0, 0, 0, 0, 0, 0, 0]);
        test_roundtrip(-1i16, vec![255, 255, 255, 255, 255, 255, 255, 255]);
        test_roundtrip(-2i32, vec![254, 255, 255, 255, 255, 255, 255, 255]);
        test_roundtrip(-3i64, vec![253, 255, 255, 255, 255, 255, 255, 255]);
    }

    #[test]
    fn test_bools() {
        test_roundtrip(false, vec![0]);
        test_roundtrip(true, vec![1]);
        assert_decode_fails::<bool>(&[2]);
    }

    #[test]
    fn test_options() {
        test_roundtrip(None::<u8>, vec![0]);
        test_roundtrip(Some(7u8), vec![1, 7]);
        test_roundtrip(
            Some("hi".to_string()),
            vec![1, 2, 0, 0, 0, 0, 0, 0, 0, 104, 105],
        );
    }

    #[test]
    fn test_time() {
        let timestamp: DateTime<Utc> =
            DateTime::from_timestamp_secs(1_700_000_000).expect("valid timestamp");
        test_roundtrip(timestamp, vec![0, 241, 83, 101, 0, 0, 0, 0]);
        test_roundtrip(
            Duration::nanoseconds(1_000),
            vec![232, 3, 0, 0, 0, 0, 0, 0],
        );
        // 2^63 nanoseconds does not fit in an i64.
        assert_decode_fails::<Duration>(&[0, 0, 0, 0, 0, 0, 0, 128]);
    }

    #[test]
    fn test_strings() {
        test_roundtrip(
            "hello".to_string(),
            vec![5, 0, 0, 0, 0, 0, 0, 0, 104, 101, 108, 108, 111],
        );
        test_roundtrip("".to_string(), vec![0, 0, 0, 0, 0, 0, 0, 0]);
        // Length prefix claims 5 bytes but only 1 follows.
        assert_decode_fails::<String>(&[5, 0, 0, 0, 0, 0, 0, 0, 104]);
        // 0xFF never appears in valid UTF-8.
        assert_decode_fails::<String>(&[1, 0, 0, 0, 0, 0, 0, 0, 0xFF]);
    }

    #[test]
    fn test_bytes() {
        test_roundtrip(
            bytes::Bytes::from_static(b"abc"),
            vec![3, 0, 0, 0, 0, 0, 0, 0, 97, 98, 99],
        );
        // Length prefix claims 1000 bytes but only 3 follow.
        assert_decode_fails::<bytes::Bytes>(&[232, 3, 0, 0, 0, 0, 0, 0, 1, 2, 3]);
    }

    #[test]
    fn test_fixed_arrays() {
        test_roundtrip([1u8, 2u8, 3u8], vec![1, 2, 3]);
        test_roundtrip([0u8; 4], vec![0, 0, 0, 0]);
    }

    #[test]
    fn test_vectors() {
        test_roundtrip(vec![1u8, 2u8, 3u8], vec![3, 0, 0, 0, 0, 0, 0, 0, 1, 2, 3]);
        test_roundtrip(
            vec![100u64, 200u64],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // length
                100, 0, 0, 0, 0, 0, 0, 0, // 100
                200, 0, 0, 0, 0, 0, 0, 0, // 200
            ],
        );
        test_roundtrip(
            vec!["a".to_string(), "bc".to_string()],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // outer length
                1, 0, 0, 0, 0, 0, 0, 0, 97, // "a"
                2, 0, 0, 0, 0, 0, 0, 0, 98, 99, // "bc"
            ],
        );
    }

    #[test]
    fn test_nested() {
        test_roundtrip(
            vec![vec![1u8, 2u8], vec![3u8, 4u8]],
            vec![
                2, 0, 0, 0, 0, 0, 0, 0, // outer length
                2, 0, 0, 0, 0, 0, 0, 0, 1, 2, // [1, 2]
                2, 0, 0, 0, 0, 0, 0, 0, 3, 4, // [3, 4]
            ],
        );
    }
}
363}