snarkvm_utilities/
bytes.rs

1// Copyright (c) 2019-2025 Provable Inc.
2// This file is part of the snarkVM library.
3
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at:
7
8// http://www.apache.org/licenses/LICENSE-2.0
9
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use std::{
17    fmt,
18    io::{Read, Result as IoResult, Write},
19    marker::PhantomData,
20    net::{IpAddr, Ipv4Addr, Ipv6Addr, SocketAddr},
21};
22
23use serde::{
24    Deserializer,
25    Serializer,
26    de::{self, Error, SeqAccess, Visitor},
27    ser::{self, SerializeTuple},
28};
29use smol_str::SmolStr;
30
31use crate::error;
32
/// Takes as input a sequence of structs, and converts them to a series of little-endian bytes.
/// All types that implement `ToBytes` can be automatically converted to bytes in this manner.
///
/// Expands to an `std::io::Result<Vec<u8>>` holding the encodings of every argument,
/// concatenated in argument order. The output buffer starts with room for 64 bytes.
#[macro_export]
macro_rules! to_bytes_le {
    ($($x:expr),*) => {{
        let mut buffer = Vec::with_capacity(64);
        let outcome = $crate::push_bytes_to_vec!(buffer, $($x),*);
        outcome.map(|()| buffer)
    }};
}
43
/// Appends the little-endian encoding of each expression to `$buffer`, in argument order.
///
/// Expands to an `std::io::Result<()>`. Writing stops at the first failure: that error is
/// returned and later expressions are neither evaluated nor written. (The previous
/// `Result::and` chain evaluated its argument eagerly and kept serializing after an error.)
///
/// NOTE(review): `ToBytes` is resolved at the call site, so it must be in scope there.
#[macro_export]
macro_rules! push_bytes_to_vec {
    // Two or more expressions: write the head, then recurse on the tail only on success.
    ($buffer:expr, $y:expr, $($x:expr),*) => ({
        match ToBytes::write_le(&$y, &mut $buffer) {
            Ok(()) => $crate::push_bytes_to_vec!($buffer, $($x),*),
            Err(e) => Err(e),
        }
    });

    // Exactly one expression: write it and return its result.
    ($buffer:expr, $x:expr) => ({
        ToBytes::write_le(&$x, &mut $buffer)
    })
}
54
55pub trait ToBytes {
56    /// Writes `self` into `writer` as little-endian bytes.
57    fn write_le<W: Write>(&self, writer: W) -> IoResult<()>
58    where
59        Self: Sized;
60
61    /// Returns `self` as a byte array in little-endian order.
62    fn to_bytes_le(&self) -> anyhow::Result<Vec<u8>>
63    where
64        Self: Sized,
65    {
66        Ok(to_bytes_le![self]?)
67    }
68}
69
70pub trait FromBytes {
71    /// Reads `Self` from `reader` as little-endian bytes.
72    fn read_le<R: Read>(reader: R) -> IoResult<Self>
73    where
74        Self: Sized;
75
76    /// Returns `Self` from a byte array in little-endian order.
77    fn from_bytes_le(bytes: &[u8]) -> anyhow::Result<Self>
78    where
79        Self: Sized,
80    {
81        Ok(Self::read_le(bytes)?)
82    }
83}
84
85pub struct ToBytesSerializer<T: ToBytes>(PhantomData<T>);
86
87impl<T: ToBytes> ToBytesSerializer<T> {
88    /// Serializes a static-sized object as a byte array (without length encoding).
89    pub fn serialize<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
90        let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
91        let mut tuple = serializer.serialize_tuple(bytes.len())?;
92        for byte in &bytes {
93            tuple.serialize_element(byte)?;
94        }
95        tuple.end()
96    }
97
98    /// Serializes a dynamically-sized object as a byte array with length encoding.
99    pub fn serialize_with_size_encoding<S: Serializer>(object: &T, serializer: S) -> Result<S::Ok, S::Error> {
100        let bytes = object.to_bytes_le().map_err(ser::Error::custom)?;
101        serializer.serialize_bytes(&bytes)
102    }
103}
104
105pub struct FromBytesDeserializer<T: FromBytes>(PhantomData<T>);
106
107impl<'de, T: FromBytes> FromBytesDeserializer<T> {
108    /// Deserializes a static-sized byte array (without length encoding).
109    ///
110    /// This method fails if `deserializer` is given an insufficient `size`.
111    pub fn deserialize<D: Deserializer<'de>>(deserializer: D, name: &str, size: usize) -> Result<T, D::Error> {
112        let mut buffer = Vec::with_capacity(size);
113        deserializer.deserialize_tuple(size, FromBytesVisitor::new(&mut buffer, name))?;
114        FromBytes::read_le(&*buffer).map_err(de::Error::custom)
115    }
116
117    /// Deserializes a static-sized byte array, with a u8 length encoding at the start.
118    pub fn deserialize_with_u8<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
119        deserializer.deserialize_tuple(1usize << 8usize, FromBytesWithU8Visitor::<T>::new(name))
120    }
121
122    /// Deserializes a static-sized byte array, with a u16 length encoding at the start.
123    pub fn deserialize_with_u16<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
124        deserializer.deserialize_tuple(1usize << 16usize, FromBytesWithU16Visitor::<T>::new(name))
125    }
126
127    /// Deserializes a dynamically-sized byte array.
128    pub fn deserialize_with_size_encoding<D: Deserializer<'de>>(deserializer: D, name: &str) -> Result<T, D::Error> {
129        let mut buffer = Vec::with_capacity(32);
130        deserializer.deserialize_bytes(FromBytesVisitor::new(&mut buffer, name))?;
131        FromBytes::read_le(&*buffer).map_err(de::Error::custom)
132    }
133
134    /// Attempts to deserialize a byte array (without length encoding).
135    ///
136    /// This method does *not* fail if `deserializer` is given an insufficient `size`,
137    /// however this method fails if `FromBytes` fails to read the value of `T`.
138    pub fn deserialize_extended<D: Deserializer<'de>>(
139        deserializer: D,
140        name: &str,
141        size_a: usize,
142        size_b: usize,
143    ) -> Result<T, D::Error> {
144        // Order the given sizes from smallest to largest.
145        let (size_a, size_b) = match size_a < size_b {
146            true => (size_a, size_b),
147            false => (size_b, size_a),
148        };
149
150        // Ensure 'size_b' is within bounds.
151        if size_b > i32::MAX as usize {
152            return Err(D::Error::custom(format!("size_b ({size_b}) exceeds maximum")));
153        }
154
155        // Reserve a new `Vec` with the larger size capacity.
156        let mut buffer = Vec::with_capacity(size_b);
157
158        // Attempt to deserialize on the larger size, to load up to the maximum buffer size.
159        match deserializer.deserialize_tuple(size_b, FromBytesVisitor::new(&mut buffer, name)) {
160            // Deserialized a full buffer, attempt to read up to `size_b`.
161            Ok(()) => FromBytes::read_le(&buffer[..size_b]).map_err(de::Error::custom),
162            // Deserialized a partial buffer, attempt to read up to `size_a`, if exactly `size_a` was read.
163            Err(error) => match buffer.len() == size_a {
164                true => FromBytes::read_le(&buffer[..size_a]).map_err(de::Error::custom),
165                false => Err(error),
166            },
167        }
168    }
169}
170
171pub struct FromBytesVisitor<'a>(&'a mut Vec<u8>, SmolStr);
172
173impl<'a> FromBytesVisitor<'a> {
174    /// Initializes a new `FromBytesVisitor` with the given `buffer` and `name`.
175    pub fn new(buffer: &'a mut Vec<u8>, name: &str) -> Self {
176        Self(buffer, SmolStr::new(name))
177    }
178}
179
180impl<'de> Visitor<'de> for FromBytesVisitor<'_> {
181    type Value = ();
182
183    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
184        formatter.write_str(&format!("a valid {} ", self.1))
185    }
186
187    fn visit_borrowed_bytes<E: serde::de::Error>(self, bytes: &'de [u8]) -> Result<Self::Value, E> {
188        self.0.extend_from_slice(bytes);
189        Ok(())
190    }
191
192    fn visit_seq<S: SeqAccess<'de>>(self, mut seq: S) -> Result<Self::Value, S::Error> {
193        while let Some(byte) = seq.next_element()? {
194            self.0.push(byte);
195        }
196        Ok(())
197    }
198}
199
200struct FromBytesWithU8Visitor<T: FromBytes>(String, PhantomData<T>);
201
202impl<T: FromBytes> FromBytesWithU8Visitor<T> {
203    /// Initializes a new `FromBytesWithU8Visitor` with the given `name`.
204    pub fn new(name: &str) -> Self {
205        Self(name.to_string(), PhantomData)
206    }
207}
208
209impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU8Visitor<T> {
210    type Value = T;
211
212    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
213        formatter.write_str(&format!("a valid {} ", self.0))
214    }
215
216    fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
217        // Read the size of the object.
218        let length: u8 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
219
220        // Initialize the vector with the correct length.
221        let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 1);
222        // Push the length into the vector.
223        bytes.push(length);
224        // Read the bytes.
225        for i in 0..length {
226            // Push the next byte into the vector.
227            bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 1, &self))?);
228        }
229        // Deserialize the vector.
230        FromBytes::read_le(&*bytes).map_err(de::Error::custom)
231    }
232}
233
234struct FromBytesWithU16Visitor<T: FromBytes>(String, PhantomData<T>);
235
236impl<T: FromBytes> FromBytesWithU16Visitor<T> {
237    /// Initializes a new `FromBytesWithU16Visitor` with the given `name`.
238    pub fn new(name: &str) -> Self {
239        Self(name.to_string(), PhantomData)
240    }
241}
242
243impl<'de, T: FromBytes> Visitor<'de> for FromBytesWithU16Visitor<T> {
244    type Value = T;
245
246    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
247        formatter.write_str(&format!("a valid {} ", self.0))
248    }
249
250    fn visit_seq<V: SeqAccess<'de>>(self, mut seq: V) -> Result<Self::Value, V::Error> {
251        // Read the size of the object.
252        let length: u16 = seq.next_element()?.ok_or_else(|| Error::invalid_length(0, &self))?;
253
254        // Initialize the vector with the correct length.
255        let mut bytes: Vec<u8> = Vec::with_capacity((length as usize) + 2);
256        // Push the length into the vector.
257        bytes.extend(length.to_le_bytes());
258        // Read the bytes.
259        for i in 0..length {
260            // Push the next byte into the vector.
261            bytes.push(seq.next_element()?.ok_or_else(|| Error::invalid_length((i as usize) + 2, &self))?);
262        }
263        // Deserialize the vector.
264        FromBytes::read_le(&*bytes).map_err(de::Error::custom)
265    }
266}
267
268impl ToBytes for () {
269    #[inline]
270    fn write_le<W: Write>(&self, _writer: W) -> IoResult<()> {
271        Ok(())
272    }
273}
274
275impl FromBytes for () {
276    #[inline]
277    fn read_le<R: Read>(_bytes: R) -> IoResult<Self> {
278        Ok(())
279    }
280}
281
282impl ToBytes for bool {
283    #[inline]
284    fn write_le<W: Write>(&self, writer: W) -> IoResult<()> {
285        u8::write_le(&(*self as u8), writer)
286    }
287}
288
289impl FromBytes for bool {
290    #[inline]
291    fn read_le<R: Read>(reader: R) -> IoResult<Self> {
292        match u8::read_le(reader) {
293            Ok(0) => Ok(false),
294            Ok(1) => Ok(true),
295            Ok(_) => Err(error("FromBytes::read failed")),
296            Err(err) => Err(err),
297        }
298    }
299}
300
301impl ToBytes for SocketAddr {
302    #[inline]
303    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
304        // Write the IP address.
305        match self.ip() {
306            IpAddr::V4(ipv4) => {
307                0u8.write_le(&mut writer)?;
308                u32::from(ipv4).write_le(&mut writer)?;
309            }
310            IpAddr::V6(ipv6) => {
311                1u8.write_le(&mut writer)?;
312                u128::from(ipv6).write_le(&mut writer)?;
313            }
314        }
315        // Write the port.
316        self.port().write_le(&mut writer)?;
317        Ok(())
318    }
319}
320
321impl FromBytes for SocketAddr {
322    #[inline]
323    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
324        // Read the IP address.
325        let ip = match u8::read_le(&mut reader)? {
326            0 => IpAddr::V4(Ipv4Addr::from(u32::read_le(&mut reader)?)),
327            1 => IpAddr::V6(Ipv6Addr::from(u128::read_le(&mut reader)?)),
328            _ => return Err(error("Invalid IP address")),
329        };
330        // Read the port.
331        let port = u16::read_le(&mut reader)?;
332        Ok(SocketAddr::new(ip, port))
333    }
334}
335
336macro_rules! impl_bytes_for_integer {
337    ($int:ty) => {
338        impl ToBytes for $int {
339            #[inline]
340            fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
341                writer.write_all(&self.to_le_bytes())
342            }
343        }
344
345        impl FromBytes for $int {
346            #[inline]
347            fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
348                let mut bytes = [0u8; core::mem::size_of::<$int>()];
349                reader.read_exact(&mut bytes)?;
350                Ok(<$int>::from_le_bytes(bytes))
351            }
352        }
353    };
354}
355
356impl_bytes_for_integer!(u8);
357impl_bytes_for_integer!(u16);
358impl_bytes_for_integer!(u32);
359impl_bytes_for_integer!(u64);
360impl_bytes_for_integer!(u128);
361
362impl_bytes_for_integer!(i8);
363impl_bytes_for_integer!(i16);
364impl_bytes_for_integer!(i32);
365impl_bytes_for_integer!(i64);
366impl_bytes_for_integer!(i128);
367
368impl<const N: usize> ToBytes for [u8; N] {
369    #[inline]
370    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
371        writer.write_all(self)
372    }
373}
374
375impl<const N: usize> FromBytes for [u8; N] {
376    #[inline]
377    fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
378        let mut arr = [0u8; N];
379        reader.read_exact(&mut arr)?;
380        Ok(arr)
381    }
382}
383
384macro_rules! impl_bytes_for_integer_array {
385    ($int:ty) => {
386        impl<const N: usize> ToBytes for [$int; N] {
387            #[inline]
388            fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
389                for num in self {
390                    writer.write_all(&num.to_le_bytes())?;
391                }
392                Ok(())
393            }
394        }
395
396        impl<const N: usize> FromBytes for [$int; N] {
397            #[inline]
398            fn read_le<R: Read>(mut reader: R) -> IoResult<Self> {
399                let mut res: [$int; N] = [0; N];
400                for num in res.iter_mut() {
401                    let mut bytes = [0u8; core::mem::size_of::<$int>()];
402                    reader.read_exact(&mut bytes)?;
403                    *num = <$int>::from_le_bytes(bytes);
404                }
405                Ok(res)
406            }
407        }
408    };
409}
410
411// u8 has a dedicated, faster implementation above
412impl_bytes_for_integer_array!(u16);
413impl_bytes_for_integer_array!(u32);
414impl_bytes_for_integer_array!(u64);
415
416impl<L: ToBytes, R: ToBytes> ToBytes for (L, R) {
417    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
418        self.0.write_le(&mut writer)?;
419        self.1.write_le(&mut writer)?;
420        Ok(())
421    }
422}
423
424impl<L: FromBytes, R: FromBytes> FromBytes for (L, R) {
425    #[inline]
426    fn read_le<Reader: Read>(mut reader: Reader) -> IoResult<Self> {
427        let left: L = FromBytes::read_le(&mut reader)?;
428        let right: R = FromBytes::read_le(&mut reader)?;
429        Ok((left, right))
430    }
431}
432
433impl<T: ToBytes> ToBytes for Vec<T> {
434    #[inline]
435    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
436        for item in self {
437            item.write_le(&mut writer)?;
438        }
439        Ok(())
440    }
441}
442
443impl<'a, T: 'a + ToBytes> ToBytes for &'a [T] {
444    #[inline]
445    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
446        for item in *self {
447            item.write_le(&mut writer)?;
448        }
449        Ok(())
450    }
451}
452
453impl<'a, T: 'a + ToBytes> ToBytes for &'a T {
454    #[inline]
455    fn write_le<W: Write>(&self, mut writer: W) -> IoResult<()> {
456        (*self).write_le(&mut writer)
457    }
458}
459
/// Expands `bytes` into bits, least-significant bit of each byte first.
///
/// Yields exactly `8 * bytes.len()` booleans, and the iterator can be traversed from
/// either end.
#[inline]
pub fn bits_from_bytes_le(bytes: &[u8]) -> impl DoubleEndedIterator<Item = bool> + '_ {
    bytes.iter().flat_map(|&byte| (0..8u8).map(move |shift| (byte & (1 << shift)) != 0))
}
464
/// Packs `bits` into bytes, least-significant bit first.
///
/// Each group of 8 consecutive bits becomes one byte, with `bits[8 * k + i]` stored in bit `i`
/// of byte `k`. A trailing group of fewer than 8 bits fills the low bits of a final byte and
/// leaves its high bits zero, so the output has `ceil(bits.len() / 8)` bytes. For inputs whose
/// length is a multiple of 8, this is the inverse of `bits_from_bytes_le`.
#[inline]
pub fn bytes_from_bits_le(bits: &[bool]) -> Vec<u8> {
    // `chunks` is exact-size, so `collect` allocates the final length up front.
    bits.chunks(8)
        .map(|chunk| chunk.iter().enumerate().fold(0u8, |byte, (i, &bit)| byte | (u8::from(bit) << i)))
        .collect()
}
482
/// A wrapper around a `Write` instance that limits the number of bytes that can be written.
///
/// A write that straddles the limit is truncated to the remaining budget. Once the budget is
/// exhausted, any further non-empty write fails with an `ErrorKind::Other` error whose message
/// names the limit. With `write_all`, this means bytes are written up to the limit and then
/// the call errors.
pub struct LimitedWriter<W: Write> {
    /// The wrapped writer.
    writer: W,
    /// The total byte budget; only reported in the error message.
    limit: usize,
    /// Bytes that may still be written.
    remaining: usize,
}

impl<W: Write> LimitedWriter<W> {
    /// Wraps `writer` so that at most `limit` bytes can be written through it.
    pub fn new(writer: W, limit: usize) -> Self {
        Self { writer, limit, remaining: limit }
    }
}

impl<W: Write> Write for LimitedWriter<W> {
    fn write(&mut self, buf: &[u8]) -> IoResult<usize> {
        if self.remaining == 0 && !buf.is_empty() {
            return Err(std::io::Error::other(format!("Byte limit exceeded: {}", self.limit)));
        }

        // Forward only as many bytes as the remaining budget allows.
        let max_write = std::cmp::min(buf.len(), self.remaining);
        let written = self.writer.write(&buf[..max_write])?;
        // An inner writer that over-reports (violating the `Write` contract) must not underflow
        // `remaining`: that would panic in debug builds and wrap in release, lifting the limit.
        self.remaining = self.remaining.saturating_sub(written);
        Ok(written)
    }

    fn flush(&mut self) -> IoResult<()> {
        self.writer.flush()
    }
}
516
#[cfg(test)]
mod test {
    use super::*;
    use crate::TestRng;

    use rand::Rng;

    /// Number of iterations used by every randomized test in this module.
    const ITERATIONS: usize = 1000;

    #[test]
    fn test_macro_empty() {
        let array: Vec<u8> = vec![];
        let bytes_a: Vec<u8> = to_bytes_le![array].unwrap();
        assert_eq!(&array, &bytes_a);
        assert_eq!(0, bytes_a.len());

        let bytes_b: Vec<u8> = array.to_bytes_le().unwrap();
        assert_eq!(&array, &bytes_b);
        assert_eq!(0, bytes_b.len());
    }

    #[test]
    fn test_macro() {
        let array1 = [1u8; 32];
        let array2 = [2u8; 16];
        let array3 = [3u8; 8];
        let bytes = to_bytes_le![array1, array2, array3].unwrap();
        assert_eq!(bytes.len(), 56);

        let mut actual_bytes = Vec::new();
        actual_bytes.extend_from_slice(&array1);
        actual_bytes.extend_from_slice(&array2);
        actual_bytes.extend_from_slice(&array3);
        assert_eq!(bytes, actual_bytes);
    }

    #[test]
    fn test_bits_from_bytes_le() {
        assert_eq!(bits_from_bytes_le(&[204, 76]).collect::<Vec<bool>>(), [
            false, false, true, true, false, false, true, true, // 204
            false, false, true, true, false, false, true, false, // 76
        ]);
    }

    #[test]
    fn test_bytes_from_bits_le() {
        let bits = [
            false, false, true, true, false, false, true, true, // 204
            false, false, true, true, false, false, true, false, // 76
        ];
        assert_eq!(bytes_from_bits_le(&bits), [204, 76]);
    }

    #[test]
    fn test_bytes_from_bits_le_partial_byte() {
        // A trailing group of fewer than 8 bits fills the low bits of a final byte.
        assert_eq!(bytes_from_bits_le(&[true, false, true]), [5]);
        assert_eq!(bytes_from_bits_le(&[true; 9]), [255, 1]);
        assert!(bytes_from_bits_le(&[]).is_empty());
    }

    #[test]
    fn test_from_bits_le_to_bytes_le_roundtrip() {
        let mut rng = TestRng::default();

        for _ in 0..ITERATIONS {
            let given_bytes: [u8; 32] = rng.r#gen();

            let bits = bits_from_bytes_le(&given_bytes).collect::<Vec<_>>();
            let recovered_bytes = bytes_from_bits_le(&bits);

            assert_eq!(given_bytes.to_vec(), recovered_bytes);
        }
    }

    #[test]
    fn test_socketaddr_bytes() {
        fn random_ipv4_address(rng: &mut TestRng) -> Ipv4Addr {
            Ipv4Addr::new(rng.r#gen(), rng.r#gen(), rng.r#gen(), rng.r#gen())
        }

        fn random_ipv6_address(rng: &mut TestRng) -> Ipv6Addr {
            Ipv6Addr::new(
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
                rng.r#gen(),
            )
        }

        fn random_port(rng: &mut TestRng) -> u16 {
            rng.gen_range(1025..=65535) // excluding well-known ports
        }

        let rng = &mut TestRng::default();

        // Uses the shared iteration count (previously a hard-coded 1_000_000, which made this
        // the slowest test in the module without adding meaningful coverage).
        for _ in 0..ITERATIONS {
            let ipv4 = SocketAddr::new(IpAddr::V4(random_ipv4_address(rng)), random_port(rng));
            let bytes = ipv4.to_bytes_le().unwrap();
            let ipv4_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv4, ipv4_2);

            let ipv6 = SocketAddr::new(IpAddr::V6(random_ipv6_address(rng)), random_port(rng));
            let bytes = ipv6.to_bytes_le().unwrap();
            let ipv6_2 = SocketAddr::read_le(&bytes[..]).unwrap();
            assert_eq!(ipv6, ipv6_2);
        }
    }

    #[test]
    fn test_limited_writer() {
        let mut sink = Vec::new();
        let mut writer = LimitedWriter::new(&mut sink, 4);
        writer.write_all(&[1, 2, 3]).unwrap();
        // Only one more byte fits; the remainder of this write must be rejected.
        assert!(writer.write_all(&[4, 5]).is_err());
        // Empty writes are always accepted, even with the budget exhausted.
        assert_eq!(writer.write(&[0u8; 0]).unwrap(), 0);
        assert_eq!(sink, [1, 2, 3, 4]);
    }
}