1#![allow(clippy::integer_arithmetic)]
2use {
3 serde::{
4 de::{Error as _, SeqAccess, Visitor},
5 ser::SerializeTuple,
6 Deserializer, Serializer,
7 },
8 std::{fmt, marker::PhantomData},
9};
10
11pub trait VarInt: Sized {
12 fn visit_seq<'de, A>(seq: A) -> Result<Self, A::Error>
13 where
14 A: SeqAccess<'de>;
15
16 fn serialize<S>(self, serializer: S) -> Result<S::Ok, S::Error>
17 where
18 S: Serializer;
19}
20
21struct VarIntVisitor<T> {
22 phantom: PhantomData<T>,
23}
24
25impl<'de, T> Visitor<'de> for VarIntVisitor<T>
26where
27 T: VarInt,
28{
29 type Value = T;
30
31 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
32 formatter.write_str("a VarInt")
33 }
34
35 fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
36 where
37 A: SeqAccess<'de>,
38 {
39 T::visit_seq(seq)
40 }
41}
42
43pub fn serialize<S, T>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
44where
45 T: Copy + VarInt,
46 S: Serializer,
47{
48 (*value).serialize(serializer)
49}
50
51pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
52where
53 D: Deserializer<'de>,
54 T: VarInt,
55{
56 deserializer.deserialize_tuple(
57 (std::mem::size_of::<T>() * 8 + 6) / 7,
58 VarIntVisitor {
59 phantom: PhantomData::default(),
60 },
61 )
62}
63
/// Implements [`VarInt`] for an unsigned integer type.
///
/// The value is split into 7-bit groups, least significant group first. Every
/// byte except the last has its high bit (`0x80`) set. When decoding, a final
/// byte whose bits do not fit in the type is rejected, as is a zero final byte
/// following other bytes (a longer-than-necessary encoding).
macro_rules! impl_var_int {
    ($type:ty) => {
        impl VarInt for $type {
            fn visit_seq<'de, A>(mut seq: A) -> Result<Self, A::Error>
            where
                A: SeqAccess<'de>,
            {
                let width = <$type>::BITS;
                let mut value: Self = 0;
                // One iteration per 7-bit group whose start still lies inside `Self`.
                for shift in (0..width).step_by(7) {
                    let byte = seq
                        .next_element::<u8>()?
                        .ok_or_else(|| A::Error::custom("Invalid Sequence"))?;
                    value |= ((byte & 0x7F) as Self) << shift;
                    if byte & 0x80 != 0 {
                        // Continuation bit set: another group follows.
                        continue;
                    }
                    // The final group must survive the shift intact; bits pushed
                    // past the top of `Self` mean the encoded value does not fit.
                    if (value >> shift) as u8 != byte {
                        return Err(A::Error::custom("Last Byte Truncated"));
                    }
                    // A zero final byte is only valid as the single-byte encoding
                    // of zero; after other groups it is redundant.
                    if byte == 0u8 && (shift != 0 || value != 0) {
                        return Err(A::Error::custom("Invalid Trailing Zeros"));
                    }
                    return Ok(value);
                }
                // Every group that fits in `Self` carried a continuation bit.
                Err(A::Error::custom("Left Shift Overflows"))
            }

            fn serialize<S>(mut self, serializer: S) -> Result<S::Ok, S::Error>
            where
                S: Serializer,
            {
                let significant_bits = <$type>::BITS - self.leading_zeros();
                // Zero still occupies one byte on the wire.
                let len = ((significant_bits + 6) / 7).max(1) as usize;
                let mut tuple = serializer.serialize_tuple(len)?;
                loop {
                    let low = (self & 0x7F) as u8;
                    self >>= 7;
                    if self == 0 {
                        // Last group: high bit left clear.
                        tuple.serialize_element(&low)?;
                        break;
                    }
                    tuple.serialize_element(&(low | 0x80))?;
                }
                tuple.end()
            }
        }
    };
}
115
116impl_var_int!(u32);
117impl_var_int!(u64);
118
#[cfg(test)]
mod tests {
    use rand::Rng;

    /// Interleaves varint fields (`a`, `c`) with ordinarily serialized fields
    /// (`b`, `d`) so the tests exercise both encodings side by side.
    #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
    struct Dummy {
        #[serde(with = "super")]
        a: u32,
        b: u64,
        #[serde(with = "super")]
        c: u64,
        d: u32,
    }

    #[test]
    fn test_serde_varint() {
        // Maximum encoded lengths: ceil(32 / 7) = 5 and ceil(64 / 7) = 10 bytes.
        assert_eq!((std::mem::size_of::<u32>() * 8 + 6) / 7, 5);
        assert_eq!((std::mem::size_of::<u64>() * 8 + 6) / 7, 10);
        let dummy = Dummy {
            a: 698,
            b: 370,
            c: 146,
            d: 796,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        // a and c need two bytes each; b and d keep 8 and 4 bytes.
        assert_eq!(bytes.len(), 16);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    #[test]
    fn test_serde_varint_zero() {
        let dummy = Dummy {
            a: 0,
            b: 0,
            c: 0,
            d: 0,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        // Zero still takes one byte per varint field: 1 + 8 + 1 + 4.
        assert_eq!(bytes.len(), 14);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    #[test]
    fn test_serde_varint_max() {
        let dummy = Dummy {
            a: u32::MAX,
            b: u64::MAX,
            c: u64::MAX,
            d: u32::MAX,
        };
        let bytes = bincode::serialize(&dummy).unwrap();
        // Maximum-length varints: 5 + 8 + 10 + 4.
        assert_eq!(bytes.len(), 27);
        let other: Dummy = bincode::deserialize(&bytes).unwrap();
        assert_eq!(other, dummy);
    }

    /// Pins the exact wire bytes. Round-trip and fuzz tests alone cannot
    /// detect a change applied symmetrically to both encoder and decoder
    /// (e.g. emitting groups most-significant first), which would silently
    /// break compatibility with previously written data.
    #[test]
    fn test_serde_varint_known_bytes() {
        let dummy = Dummy {
            a: 300,
            b: 1,
            c: 128,
            d: 2,
        };
        let expected: [u8; 16] = [
            0xac, 0x02, // a = 300
            0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // b = 1
            0x80, 0x01, // c = 128
            0x02, 0x00, 0x00, 0x00, // d = 2
        ];
        assert_eq!(bincode::serialize(&dummy).unwrap(), expected);
        assert_eq!(bincode::deserialize::<Dummy>(&expected).unwrap(), dummy);

        let dummy = Dummy {
            a: u32::MAX,
            b: 0,
            c: u64::MAX,
            d: 0,
        };
        let expected: [u8; 27] = [
            0xff, 0xff, 0xff, 0xff, 0x0f, // a = u32::MAX
            0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // b = 0
            0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01, // c = u64::MAX
            0x00, 0x00, 0x00, 0x00, // d = 0
        ];
        assert_eq!(bincode::serialize(&dummy).unwrap(), expected);
        assert_eq!(bincode::deserialize::<Dummy>(&expected).unwrap(), dummy);
    }

    #[test]
    fn test_serde_varint_rand() {
        let mut rng = rand::thread_rng();
        for _ in 0..100_000 {
            // Random right shifts spread the values across all encoded lengths.
            let dummy = Dummy {
                a: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
                b: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                c: rng.gen::<u64>() >> rng.gen_range(0, u64::BITS),
                d: rng.gen::<u32>() >> rng.gen_range(0, u32::BITS),
            };
            let bytes = bincode::serialize(&dummy).unwrap();
            let other: Dummy = bincode::deserialize(&bytes).unwrap();
            assert_eq!(other, dummy);
        }
    }

    #[test]
    fn test_serde_varint_trailing_zeros() {
        // Field `a`: four continuation bytes, then a zero final byte.
        let buffer = [0x93, 0xc2, 0xa9, 0x8d, 0x0];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{:?}", out),
            r#"Err(Custom("Invalid Trailing Zeros"))"#
        );
        // Zero encoded in two bytes instead of one.
        let buffer = [0x80, 0x0];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{:?}", out),
            r#"Err(Custom("Invalid Trailing Zeros"))"#
        );
    }

    #[test]
    fn test_serde_varint_last_byte_truncated() {
        // Field `a`: the fifth byte (0x6f) lands at shift 28, so its upper
        // bits would fall beyond bit 31 of a u32.
        let buffer = [0xe4, 0xd7, 0x88, 0xf6, 0x6f, 0xd4, 0xb9, 0x59];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{:?}", out),
            r#"Err(Custom("Last Byte Truncated"))"#
        );
    }

    #[test]
    fn test_serde_varint_shift_overflow() {
        // Field `a`: five bytes, all with the continuation bit set; a u32 has
        // no room for a sixth group.
        let buffer = [0x84, 0xdf, 0x96, 0xfa, 0xef];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(
            format!("{:?}", out),
            r#"Err(Custom("Left Shift Overflows"))"#
        );
    }

    #[test]
    fn test_serde_varint_short_buffer() {
        // Input ends while field `a` still expects another byte.
        let buffer = [0x84, 0xdf, 0x96, 0xfa];
        let out = bincode::deserialize::<Dummy>(&buffer);
        assert!(out.is_err());
        assert_eq!(format!("{:?}", out), r#"Err(Io(Kind(UnexpectedEof)))"#);
    }

    #[test]
    fn test_serde_varint_fuzz() {
        let mut rng = rand::thread_rng();
        let mut buffer = [0u8; 36];
        let mut num_errors = 0;
        for _ in 0..200_000 {
            rng.fill(&mut buffer[..]);
            match bincode::deserialize::<Dummy>(&buffer) {
                Err(_) => {
                    num_errors += 1;
                }
                Ok(dummy) => {
                    // Anything accepted must be canonical: re-encoding
                    // reproduces exactly the prefix that was consumed.
                    let bytes = bincode::serialize(&dummy).unwrap();
                    assert_eq!(bytes, &buffer[..bytes.len()]);
                }
            }
        }
        // Random input should regularly hit the validation errors.
        assert!(num_errors > 2_000);
    }
}