// wormhole_io/read_write.rs

1use std::{io, marker::PhantomData};
2
/// A type that can be decoded from a byte stream.
pub trait Readable {
    /// Decodes one value of `Self` from `reader`.
    ///
    /// # Errors
    ///
    /// Returns any I/O error reported by `reader`, or an implementation-specific error when the
    /// bytes read do not form a valid value.
    fn read<R: io::Read>(reader: &mut R) -> io::Result<Self>
    where
        Self: Sized;
}
9
/// A type that can be encoded into a byte stream.
pub trait Writeable {
    /// Encodes `self` into `writer`.
    ///
    /// # Errors
    ///
    /// Returns an error if `self` cannot be encoded or if `writer` reports an I/O error.
    fn write<W: io::Write>(&self, writer: &mut W) -> io::Result<()>;
}
15
16impl Readable for u8 {
17    fn read<R>(reader: &mut R) -> io::Result<Self>
18    where
19        R: io::Read,
20    {
21        let mut buf = [0u8; 1];
22        reader.read_exact(&mut buf)?;
23        Ok(buf[0])
24    }
25}
26
27impl Writeable for u8 {
28    fn write<W>(&self, writer: &mut W) -> io::Result<()>
29    where
30        W: io::Write,
31    {
32        writer.write_all(&[*self])
33    }
34}
35
36impl Readable for bool {
37    fn read<R>(reader: &mut R) -> io::Result<Self>
38    where
39        R: io::Read,
40    {
41        match u8::read(reader)? {
42            0 => Ok(false),
43            1 => Ok(true),
44            _ => Err(io::Error::new(
45                io::ErrorKind::InvalidData,
46                "invalid bool value",
47            )),
48        }
49    }
50}
51
52impl Writeable for bool {
53    fn write<W>(&self, writer: &mut W) -> io::Result<()>
54    where
55        W: io::Write,
56    {
57        writer.write_all(&[u8::from(*self)])
58    }
59}
60
61macro_rules! impl_for_int {
62    ($type:ty) => {
63        impl Readable for $type {
64            fn read<R>(reader: &mut R) -> io::Result<Self>
65            where
66                R: io::Read,
67            {
68                let mut buf = [0u8; std::mem::size_of::<$type>()];
69                reader.read_exact(&mut buf)?;
70                Ok(Self::from_be_bytes(buf))
71            }
72        }
73
74        impl Writeable for $type {
75            fn write<W>(&self, writer: &mut W) -> io::Result<()>
76            where
77                W: io::Write,
78            {
79                writer.write_all(&self.to_be_bytes())
80            }
81        }
82    };
83}
84
85impl_for_int!(u16);
86impl_for_int!(u32);
87impl_for_int!(u64);
88impl_for_int!(u128);
89
90impl_for_int!(i8);
91impl_for_int!(i16);
92impl_for_int!(i32);
93impl_for_int!(i64);
94impl_for_int!(i128);
95
96impl<const N: usize> Readable for [u8; N] {
97    fn read<R>(reader: &mut R) -> io::Result<Self>
98    where
99        Self: Sized,
100        R: io::Read,
101    {
102        let mut buf = [0u8; N];
103        reader.read_exact(&mut buf)?;
104        Ok(buf)
105    }
106}
107
108impl<const N: usize> Writeable for [u8; N] {
109    fn write<W>(&self, writer: &mut W) -> io::Result<()>
110    where
111        W: io::Write,
112    {
113        writer.write_all(self)
114    }
115}
116
117macro_rules! impl_for_int_array {
118    ($type:ty) => {
119        impl<const N: usize> Readable for [$type; N] {
120            fn read<R>(reader: &mut R) -> io::Result<Self>
121            where
122                R: io::Read,
123            {
124                let mut buf = [Default::default(); N];
125                for i in 0..N {
126                    buf[i] = <$type>::read(reader)?;
127                }
128                Ok(buf)
129            }
130        }
131
132        impl<const N: usize> Writeable for [$type; N] {
133            fn write<W>(&self, writer: &mut W) -> io::Result<()>
134            where
135                W: io::Write,
136            {
137                for i in 0..N {
138                    self[i].write(writer)?;
139                }
140                Ok(())
141            }
142        }
143    };
144}
145
146impl_for_int_array!(u16);
147impl_for_int_array!(u32);
148impl_for_int_array!(u64);
149impl_for_int_array!(u128);
150
151impl_for_int_array!(i8);
152impl_for_int_array!(i16);
153impl_for_int_array!(i32);
154impl_for_int_array!(i64);
155impl_for_int_array!(i128);
156
157impl<T> Readable for Option<T>
158where
159    T: Readable,
160{
161    fn read<R>(reader: &mut R) -> io::Result<Self>
162    where
163        Self: Sized,
164        R: io::Read,
165    {
166        match bool::read(reader)? {
167            true => Ok(Some(T::read(reader)?)),
168            false => Ok(None),
169        }
170    }
171}
172
173impl<T> Writeable for Option<T>
174where
175    T: Writeable,
176{
177    fn write<W>(&self, writer: &mut W) -> io::Result<()>
178    where
179        W: io::Write,
180    {
181        match self {
182            Some(value) => {
183                true.write(writer)?;
184                value.write(writer)
185            }
186            None => false.write(writer),
187        }
188    }
189}
190
191/// Wrapper for `Vec<u8>`. Encoding is similar to Borsh, where the length is encoded as u32 (but in
192/// this case, it's big endian).
193#[derive(Debug, Clone, PartialEq, Eq, Default)]
194pub struct WriteableBytes<L>
195where
196    u32: From<L>,
197    L: Sized + Readable + Writeable + TryFrom<usize>,
198{
199    phantom: PhantomData<L>,
200    inner: Vec<u8>,
201}
202
203impl<L> WriteableBytes<L>
204where
205    u32: From<L>,
206    L: Sized + Readable + Writeable + TryFrom<usize>,
207{
208    pub fn new(inner: Vec<u8>) -> Self {
209        Self {
210            phantom: PhantomData,
211            inner,
212        }
213    }
214
215    pub fn try_encoded_len(&self) -> io::Result<L> {
216        match L::try_from(self.inner.len()) {
217            Ok(len) => Ok(len),
218            Err(_) => Err(io::Error::new(
219                io::ErrorKind::InvalidData,
220                "L overflow when converting from usize",
221            )),
222        }
223    }
224
225    pub fn written_size(&self) -> usize {
226        std::mem::size_of::<L>() + self.inner.len()
227    }
228}
229
230impl<L> TryFrom<Vec<u8>> for WriteableBytes<L>
231where
232    u32: From<L>,
233    L: Sized + Readable + Writeable + TryFrom<usize>,
234{
235    type Error = <L as TryFrom<usize>>::Error;
236
237    fn try_from(vec: Vec<u8>) -> Result<Self, Self::Error> {
238        match L::try_from(vec.len()) {
239            Ok(_) => Ok(Self {
240                phantom: PhantomData,
241                inner: vec,
242            }),
243            Err(e) => Err(e),
244        }
245    }
246}
247
248impl<L> From<WriteableBytes<L>> for Vec<u8>
249where
250    u32: From<L>,
251    L: Sized + Readable + Writeable + TryFrom<usize>,
252{
253    fn from(bytes: WriteableBytes<L>) -> Self {
254        bytes.inner
255    }
256}
257
258impl<L> std::ops::Deref for WriteableBytes<L>
259where
260    L: Sized + Readable + Writeable,
261    u32: From<L>,
262    L: TryFrom<usize>,
263{
264    type Target = Vec<u8>;
265
266    fn deref(&self) -> &Self::Target {
267        &self.inner
268    }
269}
270
271impl<L> std::ops::DerefMut for WriteableBytes<L>
272where
273    u32: From<L>,
274    L: Sized + Readable + Writeable + TryFrom<usize>,
275{
276    fn deref_mut(&mut self) -> &mut Self::Target {
277        &mut self.inner
278    }
279}
280
281impl Readable for WriteableBytes<u8> {
282    fn read<R>(reader: &mut R) -> io::Result<Self>
283    where
284        Self: Sized,
285        R: io::Read,
286    {
287        let len = u8::read(reader)?;
288        let mut inner: Vec<u8> = vec![0u8; len.into()];
289        reader.read_exact(&mut inner)?;
290        Ok(Self {
291            phantom: PhantomData,
292            inner,
293        })
294    }
295}
296
297impl Readable for WriteableBytes<u16> {
298    fn read<R>(reader: &mut R) -> io::Result<Self>
299    where
300        Self: Sized,
301        R: io::Read,
302    {
303        let len = u16::read(reader)?;
304        let mut inner = vec![0u8; len.into()];
305        reader.read_exact(&mut inner)?;
306        Ok(Self {
307            phantom: PhantomData,
308            inner,
309        })
310    }
311}
312
313impl Readable for WriteableBytes<u32> {
314    fn read<R>(reader: &mut R) -> io::Result<Self>
315    where
316        Self: Sized,
317        R: io::Read,
318    {
319        let len = u32::read(reader)?;
320        match len.try_into() {
321            Ok(len) => {
322                let mut inner = vec![0u8; len];
323                reader.read_exact(&mut inner)?;
324                Ok(Self {
325                    phantom: PhantomData,
326                    inner,
327                })
328            }
329            Err(_) => Err(io::Error::new(
330                io::ErrorKind::InvalidData,
331                "u32 overflow when converting to usize",
332            )),
333        }
334    }
335}
336
337impl<L> Writeable for WriteableBytes<L>
338where
339    u32: From<L>,
340    L: Sized + Readable + Writeable + TryFrom<usize>,
341{
342    fn write<W>(&self, writer: &mut W) -> io::Result<()>
343    where
344        W: io::Write,
345    {
346        match self.try_encoded_len() {
347            Ok(len) => {
348                len.write(writer)?;
349                writer.write_all(&self.inner)
350            }
351            Err(e) => Err(e),
352        }
353    }
354}
355
#[cfg(feature = "alloy")]
impl<const N: usize> Readable for alloy_primitives::FixedBytes<N> {
    /// Reads exactly `N` raw bytes, using the `[u8; N]` encoding.
    fn read<R>(reader: &mut R) -> io::Result<Self>
    where
        Self: Sized,
        R: io::Read,
    {
        let raw = <[u8; N]>::read(reader)?;
        Ok(Self(raw))
    }
}
366
#[cfg(feature = "alloy")]
impl<const N: usize> Writeable for alloy_primitives::FixedBytes<N> {
    /// Writes the `N` underlying bytes verbatim, using the `[u8; N]` encoding.
    fn write<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: io::Write,
    {
        let bytes: &[u8; N] = &self.0;
        bytes.write(writer)
    }
}
376
#[cfg(feature = "alloy")]
impl<const BITS: usize, const LIMBS: usize> Readable for alloy_primitives::Uint<BITS, LIMBS> {
    /// Reads the big-endian encoding of the integer. The number of bytes read is the length of
    /// `to_be_bytes_vec()` for this type.
    ///
    /// # Errors
    ///
    /// Returns [`io::ErrorKind::InvalidData`] if the bytes encode a value that does not fit in
    /// `BITS` bits. This can happen when `BITS` is not a multiple of 8 and the top byte has
    /// excess bits set; previously this case panicked on untrusted input.
    fn read<R>(reader: &mut R) -> io::Result<Self>
    where
        Self: Sized,
        R: io::Read,
    {
        // The zero value's encoding sizes the buffer, exactly as before.
        let mut buf = Self::default().to_be_bytes_vec();
        reader.read_exact(buf.as_mut_slice())?;

        Self::try_from_be_slice(buf.as_slice()).ok_or_else(|| {
            io::Error::new(io::ErrorKind::InvalidData, "value does not fit in Uint")
        })
    }
}
390
#[cfg(feature = "alloy")]
impl<const BITS: usize, const LIMBS: usize> Writeable for alloy_primitives::Uint<BITS, LIMBS> {
    /// Writes the integer's big-endian byte encoding.
    fn write<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: io::Write,
    {
        let be_bytes = self.to_be_bytes_vec();
        writer.write_all(&be_bytes)
    }
}
400
#[cfg(feature = "alloy")]
impl Readable for alloy_primitives::Address {
    /// Reads the 20 raw address bytes.
    fn read<R>(reader: &mut R) -> io::Result<Self>
    where
        Self: Sized,
        R: io::Read,
    {
        let bytes = alloy_primitives::FixedBytes::<20>::read(reader)?;
        Ok(Self(bytes))
    }
}
411
#[cfg(feature = "alloy")]
impl Writeable for alloy_primitives::Address {
    /// Writes the 20 raw address bytes.
    fn write<W>(&self, writer: &mut W) -> io::Result<()>
    where
        W: io::Write,
    {
        let fixed = &self.0;
        fixed.write(writer)
    }
}
421
#[cfg(test)]
pub mod test {
    use super::*;
    use hex_literal::hex;

    #[test]
    fn u8_read_write() {
        const EXPECTED_SIZE: usize = 1;

        let value = 69u8;

        let mut encoded = Vec::<u8>::with_capacity(EXPECTED_SIZE);
        let mut writer = std::io::Cursor::new(&mut encoded);
        value.write(&mut writer).unwrap();

        let expected = hex!("45");
        assert_eq!(encoded, expected);

        // Decoding must give back the original value and consume every encoded byte.
        let mut reader = encoded.as_slice();
        assert_eq!(u8::read(&mut reader).unwrap(), value);
        assert!(reader.is_empty());
    }

    #[test]
    fn bool_read_write() {
        let mut encoded = Vec::<u8>::with_capacity(2);
        true.write(&mut encoded).unwrap();
        false.write(&mut encoded).unwrap();

        let expected = hex!("0100");
        assert_eq!(encoded, expected);

        let mut reader = encoded.as_slice();
        assert!(bool::read(&mut reader).unwrap());
        assert!(!bool::read(&mut reader).unwrap());
        assert!(reader.is_empty());

        // Any byte other than 0 or 1 is rejected.
        let mut invalid: &[u8] = &[2];
        let err = bool::read(&mut invalid).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
    }

    #[test]
    fn u64_read_write() {
        const EXPECTED_SIZE: usize = 8;

        let value = 69u64;
        let mut encoded = Vec::<u8>::with_capacity(EXPECTED_SIZE);
        let mut writer = std::io::Cursor::new(&mut encoded);
        value.write(&mut writer).unwrap();

        let expected = hex!("0000000000000045");
        assert_eq!(encoded, expected);

        let mut reader = encoded.as_slice();
        assert_eq!(u64::read(&mut reader).unwrap(), value);
        assert!(reader.is_empty());
    }

    #[test]
    fn u8_array_read_write() {
        let data = [1, 2, 8, 16, 32, 64, 69u8];

        let mut encoded = Vec::<u8>::with_capacity(data.len());
        let mut writer = std::io::Cursor::new(&mut encoded);
        data.write(&mut writer).unwrap();

        let expected = hex!("01020810204045");
        assert_eq!(encoded, expected);

        let mut reader = encoded.as_slice();
        assert_eq!(<[u8; 7]>::read(&mut reader).unwrap(), data);
        assert!(reader.is_empty());
    }

    #[test]
    fn u64_array_read_write() {
        let data = [1, 2, 8, 16, 32, 64, 69u64];
        const EXPECTED_SIZE: usize = 56;

        let mut encoded = Vec::<u8>::with_capacity(EXPECTED_SIZE);
        let mut writer = std::io::Cursor::new(&mut encoded);
        data.write(&mut writer).unwrap();

        let expected = hex!("0000000000000001000000000000000200000000000000080000000000000010000000000000002000000000000000400000000000000045");
        assert_eq!(encoded, expected);

        let mut reader = encoded.as_slice();
        assert_eq!(<[u64; 7]>::read(&mut reader).unwrap(), data);
        assert!(reader.is_empty());
    }

    #[test]
    fn variable_bytes_read_write_u8() {
        let data = b"All your base are belong to us.";
        let bytes = WriteableBytes::<u8>::new(data.to_vec());

        let mut encoded = Vec::<u8>::with_capacity(1 + data.len());
        let mut writer = std::io::Cursor::new(&mut encoded);
        bytes.write(&mut writer).unwrap();

        let expected = hex!("1f416c6c20796f75722062617365206172652062656c6f6e6720746f2075732e");
        assert_eq!(encoded, expected);
        assert_eq!(bytes.written_size(), encoded.len());

        let mut reader = encoded.as_slice();
        assert_eq!(WriteableBytes::<u8>::read(&mut reader).unwrap(), bytes);
        assert!(reader.is_empty());
    }

    #[test]
    fn variable_bytes_read_write_u16() {
        let data = b"All your base are belong to us.";
        let bytes = WriteableBytes::<u16>::new(data.to_vec());

        let mut encoded = Vec::<u8>::with_capacity(2 + data.len());
        let mut writer = std::io::Cursor::new(&mut encoded);
        bytes.write(&mut writer).unwrap();

        let expected = hex!("001f416c6c20796f75722062617365206172652062656c6f6e6720746f2075732e");
        assert_eq!(encoded, expected);
        assert_eq!(bytes.written_size(), encoded.len());

        let mut reader = encoded.as_slice();
        assert_eq!(WriteableBytes::<u16>::read(&mut reader).unwrap(), bytes);
        assert!(reader.is_empty());
    }

    #[test]
    fn variable_bytes_read_write_u32() {
        let data = b"All your base are belong to us.";
        let bytes = WriteableBytes::<u32>::new(data.to_vec());

        let mut encoded = Vec::<u8>::with_capacity(4 + data.len());
        let mut writer = std::io::Cursor::new(&mut encoded);
        bytes.write(&mut writer).unwrap();

        let expected =
            hex!("0000001f416c6c20796f75722062617365206172652062656c6f6e6720746f2075732e");
        assert_eq!(encoded, expected);
        assert_eq!(bytes.written_size(), encoded.len());

        let mut reader = encoded.as_slice();
        assert_eq!(WriteableBytes::<u32>::read(&mut reader).unwrap(), bytes);
        assert!(reader.is_empty());
    }

    #[test]
    fn variable_bytes_rejects_oversized_payload() {
        // 256 bytes cannot be described by a `u8` length prefix.
        assert!(WriteableBytes::<u8>::try_from(vec![0u8; 256]).is_err());

        let bytes = WriteableBytes::<u8>::new(vec![0; 256]);
        let mut encoded = Vec::<u8>::new();
        let err = bytes.write(&mut encoded).unwrap_err();
        assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
        // Nothing is emitted when the length check fails.
        assert!(encoded.is_empty());
    }

    #[test]
    fn mem_take() {
        let data = b"All your base are belong to us.";
        let mut bytes = WriteableBytes::<u16>::new(data.to_vec());

        let taken = std::mem::take(&mut bytes);
        assert_eq!(taken.as_slice(), data);
    }

    #[test]
    fn option_some() {
        let value = Some(69u64);

        let mut encoded = Vec::<u8>::with_capacity(1 + 8);
        let mut writer = std::io::Cursor::new(&mut encoded);
        value.write(&mut writer).unwrap();

        let expected = hex!("010000000000000045");
        assert_eq!(encoded, expected);

        let mut reader = encoded.as_slice();
        assert_eq!(Option::<u64>::read(&mut reader).unwrap(), value);
        assert!(reader.is_empty());
    }

    #[test]
    fn option_none() {
        let value: Option<[u8; 64]> = None;

        let mut encoded = Vec::<u8>::with_capacity(1);
        let mut writer = std::io::Cursor::new(&mut encoded);
        value.write(&mut writer).unwrap();

        let expected = hex!("00");
        assert_eq!(encoded, expected);

        let mut reader = encoded.as_slice();
        assert_eq!(Option::<[u8; 64]>::read(&mut reader).unwrap(), value);
        assert!(reader.is_empty());
    }
}