solana_wincode_varint/
lib.rs1#![cfg_attr(docsrs, feature(doc_cfg))]
3use {
4 std::mem::MaybeUninit,
5 wincode::{
6 config::ConfigCore,
7 io::{Reader, Writer},
8 ReadError, ReadResult, SchemaRead, SchemaWrite, WriteResult,
9 },
10};
11
/// Field adapter that encodes an unsigned integer as an LEB128 varint with
/// wincode: 7 payload bits per byte, least-significant group first, and the
/// high bit (`0x80`) set on every byte except the last.
///
/// Select it per field, e.g. `#[wincode(with = "crate::Leb128Int<u64>")]`.
/// `SchemaRead`/`SchemaWrite` are implemented in this module for `u16`,
/// `u32` and `u64`. Decoding rejects a zero final byte that follows a
/// continuation byte, and encodings whose value does not fit the target
/// type. This module's tests check that the bytes match `solana_serde_varint`
/// under bincode.
pub struct Leb128Int<T>(pub T);
34
/// Implements `SchemaRead` for `Leb128Int<$type>`, decoding an unsigned
/// LEB128 varint (7 payload bits per byte, least-significant group first,
/// high bit `0x80` meaning "another byte follows").
///
/// Malformed input is rejected with `ReadError::Custom`:
/// - `"Last Byte Truncated"`: the final byte carries bits that do not fit in
///   `$type`;
/// - `"Invalid Trailing Zeros"`: the final byte is `0x00` after at least one
///   continuation byte (non-canonical encoding);
/// - `"Left Shift Overflows"`: every byte position that fits in `$type` had
///   its continuation bit set.
///
/// Errors from `reader.take_byte()` (e.g. running out of input) are
/// propagated unchanged.
macro_rules! impl_schema_read {
    ($type:ty) => {
        // SAFETY: the only `Ok(())` return is immediately preceded by
        // `dst.write(..)`, so `dst` is initialized whenever `read` succeeds.
        unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for Leb128Int<$type> {
            type Dst = $type;

            fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit<$type>) -> ReadResult<()> {
                let width = <$type>::BITS;
                let mut acc: $type = 0;
                // Every bit offset at which a 7-bit group can still start
                // inside `$type`: 0, 7, 14, ... while below `width`.
                for shift in (0..width).step_by(7) {
                    let byte = reader.take_byte()?;
                    acc |= ((byte & 0x7F) as $type) << shift;
                    if byte & 0x80 != 0 {
                        // Continuation bit set: keep reading.
                        continue;
                    }
                    // Final byte. Shifting the accumulator back down must
                    // reproduce it exactly, otherwise high bits fell off the
                    // top of `$type`.
                    if (acc >> shift) as u8 != byte {
                        return Err(ReadError::Custom("Last Byte Truncated"));
                    }
                    // A zero final byte is only canonical as the sole byte,
                    // i.e. the encoding of the value 0.
                    if byte == 0 && (shift != 0 || acc != 0) {
                        return Err(ReadError::Custom("Invalid Trailing Zeros"));
                    }
                    dst.write(acc);
                    return Ok(());
                }
                // Ran out of room in `$type` while bytes still announced more data.
                Err(ReadError::Custom("Left Shift Overflows"))
            }
        }
    };
}
67
/// Implements `SchemaWrite` for `Leb128Int<$type>`, encoding an unsigned
/// integer as an LEB128 varint: 7 payload bits per byte, least-significant
/// group first, with the high bit `0x80` set on every byte except the last.
macro_rules! impl_schema_write {
    ($type:ty) => {
        // NOTE(review): `SchemaWrite` is an `unsafe` trait. Presumably its
        // contract requires `size_of` to report exactly the number of bytes
        // `write` emits; confirm against the wincode docs. Both functions
        // below use one byte per started 7-bit group, with a minimum of one.
        unsafe impl<C: ConfigCore> SchemaWrite<C> for Leb128Int<$type> {
            type Src = $type;

            /// Returns the encoded length of `src`: one byte per started
            /// 7-bit group of significant bits, and at least one byte (zero
            /// encodes as a single `0x00`).
            fn size_of(src: &$type) -> WriteResult<usize> {
                // `leading_zeros()` never exceeds `BITS`, so this cannot underflow.
                let bits = <$type>::BITS - src.leading_zeros();
                Ok(bits.div_ceil(7).max(1) as usize)
            }

            /// Encodes `src` into a stack buffer and hands the whole
            /// encoding to `writer` in a single `write` call, instead of
            /// one call (and one error check) per byte.
            fn write(mut writer: impl Writer, src: &$type) -> WriteResult<()> {
                // Longest possible encoding for this width (3 / 5 / 10 bytes
                // for u16 / u32 / u64), so the indexing below never overflows.
                const MAX_LEN: usize = (<$type>::BITS as usize).div_ceil(7);
                let mut buf = [0u8; MAX_LEN];
                let mut len = 0;
                let mut value = *src;
                while value >= 0x80 {
                    // Low 7 bits plus the continuation flag.
                    buf[len] = ((value & 0x7F) | 0x80) as u8;
                    len += 1;
                    value >>= 7;
                }
                // Final byte: `value < 0x80` here, so the continuation bit is clear.
                buf[len] = value as u8;
                len += 1;
                // `Ok(..?)` rather than returning the call directly, so `?`
                // applies any `From` conversion between the writer's error
                // type and the `WriteResult` error type.
                Ok(writer.write(&buf[..len])?)
            }
        }
    };
}
90
91impl_schema_read!(u16);
92impl_schema_read!(u32);
93impl_schema_read!(u64);
94
95impl_schema_write!(u16);
96impl_schema_write!(u32);
97impl_schema_write!(u64);
98
#[cfg(test)]
mod tests {
    use {
        rand::Rng,
        serde_derive::{Deserialize, Serialize},
    };

    /// Fixture with one field per supported width. Each field is encoded
    /// with `solana_serde_varint` under serde/bincode and with `Leb128Int`
    /// under wincode, so the two encoders can be compared byte for byte.
    /// Field order on the wire is `a` (u16), `b` (u32), `c` (u64).
    #[derive(
        Debug, Eq, PartialEq, Serialize, Deserialize, wincode::SchemaRead, wincode::SchemaWrite,
    )]
    struct Dummy {
        #[serde(with = "solana_serde_varint")]
        #[wincode(with = "crate::Leb128Int<u16>")]
        a: u16,
        #[serde(with = "solana_serde_varint")]
        #[wincode(with = "crate::Leb128Int<u32>")]
        b: u32,
        #[serde(with = "solana_serde_varint")]
        #[wincode(with = "crate::Leb128Int<u64>")]
        c: u64,
    }

    /// Asserts that wincode and bincode produce identical bytes for `dummy`
    /// and that each decodes its bytes back to `dummy`.
    fn check(dummy: &Dummy) {
        let wincode_bytes = wincode::serialize(dummy).unwrap();
        let bincode_bytes = bincode::serialize(dummy).unwrap();
        assert_eq!(wincode_bytes, bincode_bytes);
        assert_eq!(
            &wincode::deserialize::<Dummy>(&wincode_bytes).unwrap(),
            dummy
        );
        assert_eq!(
            &bincode::deserialize::<Dummy>(&bincode_bytes).unwrap(),
            dummy
        );
    }

    /// Asserts that wincode rejects `buf` with `ReadError::Custom(expected)`
    /// and that bincode (via `solana_serde_varint`) rejects it as well.
    fn check_custom_err(buf: &[u8], expected: &str) {
        let r = wincode::deserialize::<Dummy>(buf);
        assert!(
            matches!(&r, Err(wincode::ReadError::Custom(msg)) if *msg == expected),
            "expected Custom({expected:?}), got {r:?}"
        );
        assert!(bincode::deserialize::<Dummy>(buf).is_err());
    }

    #[test]
    fn edge_cases() {
        let cases = [
            Dummy { a: 0, b: 0, c: 0 },
            Dummy { a: 1, b: 1, c: 1 },
            Dummy {
                a: 0x7F,
                b: 0x7F,
                c: 0x7F,
            },
            Dummy {
                a: 0x80,
                b: 0x80,
                c: 0x80,
            },
            Dummy {
                a: 0x3FFF,
                b: 0x3FFF,
                c: 0x3FFF,
            },
            Dummy {
                a: 0x4000,
                b: 0x4000,
                c: 0x4000,
            },
            Dummy {
                a: u16::MAX,
                b: u32::MAX,
                c: u64::MAX,
            },
        ];
        for dummy in &cases {
            check(dummy);
        }
    }

    #[test]
    fn random() {
        let mut rng = rand::rng();
        for _ in 0..100_000 {
            // Random right shifts spread values across all encoded lengths.
            check(&Dummy {
                a: rng.random::<u16>() >> rng.random_range(0..u16::BITS),
                b: rng.random::<u32>() >> rng.random_range(0..u32::BITS),
                c: rng.random::<u64>() >> rng.random_range(0..u64::BITS),
            });
        }
    }

    #[test]
    fn trailing_zeros() {
        check_custom_err(&[0x80, 0x00], "Invalid Trailing Zeros");
    }

    #[test]
    fn last_byte_truncated() {
        check_custom_err(&[0x01, 0xe4, 0xd7, 0x88, 0xf6, 0x6f], "Last Byte Truncated");
    }

    #[test]
    fn shift_overflow() {
        check_custom_err(&[0x80, 0x80, 0x80], "Left Shift Overflows");
    }

    /// Exercises every decode error on each field width, not just the one
    /// field the focused tests above happen to hit.
    #[test]
    fn errors_on_every_width() {
        // Leading 0x00 bytes decode the preceding fields as zero, so the
        // malformed bytes land in `b` (u32) or `c` (u64).
        check_custom_err(&[0x00, 0x80, 0x00], "Invalid Trailing Zeros");
        check_custom_err(&[0x00, 0x00, 0x80, 0x00], "Invalid Trailing Zeros");

        check_custom_err(&[0xFF, 0xFF, 0x04], "Last Byte Truncated");
        let mut truncated_u64 = vec![0x00u8, 0x00];
        truncated_u64.extend([0xFFu8; 9]);
        truncated_u64.push(0x02);
        check_custom_err(&truncated_u64, "Last Byte Truncated");

        check_custom_err(&[0x00, 0x80, 0x80, 0x80, 0x80, 0x80], "Left Shift Overflows");
        let mut overflow_u64 = vec![0x00u8, 0x00];
        overflow_u64.extend([0x80u8; 10]);
        check_custom_err(&overflow_u64, "Left Shift Overflows");
    }

    #[test]
    fn short_buffer() {
        let buf = [0x80u8];
        let r = wincode::deserialize::<Dummy>(&buf);
        assert!(matches!(r, Err(wincode::ReadError::Io(_))));
        assert!(bincode::deserialize::<Dummy>(&buf).is_err());

        let r = wincode::deserialize::<Dummy>(&[]);
        assert!(matches!(r, Err(wincode::ReadError::Io(_))));
        assert!(bincode::deserialize::<Dummy>(&[]).is_err());
    }
}