solomka_program/
serde_varint.rs

1#![allow(clippy::integer_arithmetic)]
2use {
3    serde::{
4        de::{Error as _, SeqAccess, Visitor},
5        ser::SerializeTuple,
6        Deserializer, Serializer,
7    },
8    std::{fmt, marker::PhantomData},
9};
10
11pub trait VarInt: Sized {
12    fn visit_seq<'de, A>(seq: A) -> Result<Self, A::Error>
13    where
14        A: SeqAccess<'de>;
15
16    fn serialize<S>(self, serializer: S) -> Result<S::Ok, S::Error>
17    where
18        S: Serializer;
19}
20
21struct VarIntVisitor<T> {
22    phantom: PhantomData<T>,
23}
24
25impl<'de, T> Visitor<'de> for VarIntVisitor<T>
26where
27    T: VarInt,
28{
29    type Value = T;
30
31    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
32        formatter.write_str("a VarInt")
33    }
34
35    fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
36    where
37        A: SeqAccess<'de>,
38    {
39        T::visit_seq(seq)
40    }
41}
42
43pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
44where
45    T: Copy + VarInt,
46    S: Serializer,
47{
48    (*value).serialize(serializer)
49}
50
51pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
52where
53    D: Deserializer<'de>,
54    T: VarInt,
55{
56    deserializer.deserialize_tuple(
57        (std::mem::size_of::<T>() * 8 + 6) / 7,
58        VarIntVisitor {
59            phantom: PhantomData::default(),
60        },
61    )
62}
63
64macro_rules! impl_var_int {
65    ($type:ty) => {
66        impl VarInt for $type {
67            fn visit_seq<'de, A>(mut seq: A) -> Result<Self, A::Error>
68            where
69                A: SeqAccess<'de>,
70            {
71                let mut out = 0;
72                let mut shift = 0u32;
73                while shift < <$type>::BITS {
74                    let byte = match seq.next_element::<u8>()? {
75                        None => return Err(A::Error::custom("Invalid Sequence")),
76                        Some(byte) => byte,
77                    };
78                    out |= ((byte & 0x7F) as Self) << shift;
79                    if byte & 0x80 == 0 {
80                        // Last byte should not have been truncated when it was
81                        // shifted to the left above.
82                        if (out >> shift) as u8 != byte {
83                            return Err(A::Error::custom("Last Byte Truncated"));
84                        }
85                        // Last byte can be zero only if there was only one
86                        // byte and the output is also zero.
87                        if byte == 0u8 && (shift != 0 || out != 0) {
88                            return Err(A::Error::custom("Invalid Trailing Zeros"));
89                        }
90                        return Ok(out);
91                    }
92                    shift += 7;
93                }
94                Err(A::Error::custom("Left Shift Overflows"))
95            }
96
97            fn serialize<S>(mut self, serializer: S) -> Result<S::Ok, S::Error>
98            where
99                S: Serializer,
100            {
101                let bits = <$type>::BITS - self.leading_zeros();
102                let num_bytes = ((bits + 6) / 7).max(1) as usize;
103                let mut seq = serializer.serialize_tuple(num_bytes)?;
104                while self >= 0x80 {
105                    let byte = ((self & 0x7F) | 0x80) as u8;
106                    seq.serialize_element(&byte)?;
107                    self >>= 7;
108                }
109                seq.serialize_element(&(self as u8))?;
110                seq.end()
111            }
112        }
113    };
114}
115
116impl_var_int!(u32);
117impl_var_int!(u64);
118
#[cfg(test)]
mod tests {
    use rand::Rng;

    /// Fixture mixing varint-encoded fields (`a`, `c`) with fields that use
    /// the default encoding (`b`, `d`).
    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Dummy {
        #[serde(with = "super")]
        a: u32,
        b: u64,
        #[serde(with = "super")]
        c: u64,
        d: u32,
    }

    /// Serializes `dummy` with bincode, optionally checks the encoded length,
    /// then deserializes and checks the result equals the input.
    fn assert_round_trip(dummy: &Dummy, expected_len: Option<usize>) {
        let bytes = bincode::serialize(dummy).unwrap();
        if let Some(len) = expected_len {
            assert_eq!(bytes.len(), len);
        }
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(&other, dummy);
    }

    /// Asserts that decoding `buffer` as a `Dummy` fails and that the error's
    /// `Debug` rendering is exactly `expected`.
    fn assert_rejected(buffer: &[u8], expected: &str) {
        let out = bincode::deserialize::<Dummy>(buffer);
        assert!(out.is_err());
        assert_eq!(format!("{:?}", out), expected);
    }

    #[test]
    fn test_serde_varint() {
        // Upper bounds on the encoded size: ceil(32 / 7) and ceil(64 / 7).
        assert_eq!((std::mem::size_of::<u32>() * 8 + 6) / 7, 5);
        assert_eq!((std::mem::size_of::<u64>() * 8 + 6) / 7, 10);
        let dummy = Dummy {
            a: 698,
            b: 370,
            c: 146,
            d: 796,
        };
        // 2 (a, varint) + 8 (b) + 2 (c, varint) + 4 (d) bytes.
        assert_round_trip(&dummy, Some(16));
    }

    #[test]
    fn test_serde_varint_zero() {
        let dummy = Dummy {
            a: 0,
            b: 0,
            c: 0,
            d: 0,
        };
        // Zero still takes one byte per varint field: 1 + 8 + 1 + 4.
        assert_round_trip(&dummy, Some(14));
    }

    #[test]
    fn test_serde_varint_max() {
        let dummy = Dummy {
            a: u32::MAX,
            b: u64::MAX,
            c: u64::MAX,
            d: u32::MAX,
        };
        // Maximum varint lengths: 5 + 8 + 10 + 4.
        assert_round_trip(&dummy, Some(27));
    }

    #[test]
    fn test_serde_varint_rand() {
        let mut rng = rand::thread_rng();
        for _ in 0..100_000 {
            // The random right shift spreads the samples over all bit widths.
            let dummy = Dummy {
                a: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
                b: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                c: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                d: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
            };
            assert_round_trip(&dummy, None);
        }
    }

    #[test]
    fn test_serde_varint_trailing_zeros() {
        // Four continuation bytes followed by a zero terminator.
        assert_rejected(
            &[0x93, 0xc2, 0xa9, 0x8d, 0x0],
            r#"Err(Custom("Invalid Trailing Zeros"))"#,
        );
        // Zero written in two bytes instead of one.
        assert_rejected(&[0x80, 0x0], r#"Err(Custom("Invalid Trailing Zeros"))"#);
    }

    #[test]
    fn test_serde_varint_last_byte_truncated() {
        // The fifth byte (0x6f) does not fit in the top bits of a u32.
        assert_rejected(
            &[0xe4, 0xd7, 0x88, 0xf6, 0x6f, 0xd4, 0xb9, 0x59],
            r#"Err(Custom("Last Byte Truncated"))"#,
        );
    }

    #[test]
    fn test_serde_varint_shift_overflow() {
        // All five bytes have the continuation bit set.
        assert_rejected(
            &[0x84, 0xdf, 0x96, 0xfa, 0xef],
            r#"Err(Custom("Left Shift Overflows"))"#,
        );
    }

    #[test]
    fn test_serde_varint_short_buffer() {
        // The input ends while the continuation bit is still set.
        assert_rejected(&[0x84, 0xdf, 0x96, 0xfa], r#"Err(Io(Kind(UnexpectedEof)))"#);
    }

    #[test]
    fn test_serde_varint_fuzz() {
        let mut rng = rand::thread_rng();
        let mut buffer = [0u8; 36];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            if let Ok(dummy) = bincode::deserialize::<Dummy>(&buffer) {
                // Whatever decodes must re-encode to the same leading bytes.
                let bytes = bincode::serialize(&dummy).unwrap();
                assert_eq!(bytes, &buffer[..bytes.len()]);
            } else {
                num_errors += 1;
            }
        }
        assert!(num_errors > 2_000);
    }
}