Skip to main content

solana_wincode_varint/
lib.rs

1//! Wincode schemas for LEB128 variable-length integer encoding.
2#![cfg_attr(docsrs, feature(doc_cfg))]
3use {
4    std::mem::MaybeUninit,
5    wincode::{
6        config::ConfigCore,
7        io::{Reader, Writer},
8        ReadError, ReadResult, SchemaRead, SchemaWrite, WriteResult,
9    },
10};
11
/// Wincode schema that encodes an integer using unsigned LEB128 (Little-Endian Base-128).
///
/// Each byte stores 7 bits of the value, least-significant group first. The most
/// significant bit of every byte is a continuation flag: `1` means more bytes follow,
/// `0` marks the last byte. The output is byte-for-byte identical to
/// `solana-serde-varint`, so wincode and bincode encodings are wire-compatible.
///
/// Supported types: `u16`, `u32`, `u64`. An encoded value occupies between one byte
/// and `ceil(BITS / 7)` bytes (at most 3, 5 and 10 bytes respectively).
///
/// # Decoding
///
/// Only the shortest (canonical) encoding is accepted. Decoding fails with a
/// [`ReadError::Custom`] error in three cases:
///
/// - a multi-byte encoding ends in a redundant `0x00` byte;
/// - the last byte carries payload bits that do not fit in the target type;
/// - the continuation flag is still set on the last byte that could fit.
///
/// Running out of input is reported by the underlying reader.
///
/// # Example
///
/// ```
/// use solana_wincode_varint::Leb128Int;
///
/// #[derive(wincode::SchemaRead, wincode::SchemaWrite)]
/// struct StructInts {
///     #[wincode(with = "Leb128Int<u32>")]
///     index: u32,
///     #[wincode(with = "Leb128Int<u64>")]
///     value: u64,
/// }
/// ```
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Leb128Int<T>(pub T);
34
macro_rules! impl_schema_read {
    ($type:ty) => {
        // SAFETY: the only `Ok(())` return in `read` is immediately preceded by
        // `dst.write(..)`, so `dst` is always initialized on success; every other
        // path returns an error without claiming initialization.
        unsafe impl<'de, C: ConfigCore> SchemaRead<'de, C> for Leb128Int<$type> {
            type Dst = $type;

            /// Decodes one unsigned LEB128 value into `dst`, rejecting non-canonical input.
            ///
            /// Fails with `"Last Byte Truncated"` if the final byte has payload bits
            /// beyond the type's width, `"Invalid Trailing Zeros"` if a multi-byte
            /// encoding ends in `0x00`, and `"Left Shift Overflows"` if every byte
            /// that could fit has its continuation flag set. Reader errors propagate.
            fn read(mut reader: impl Reader<'de>, dst: &mut MaybeUninit<$type>) -> ReadResult<()> {
                let mut acc: $type = 0;
                // Each byte contributes 7 payload bits: shifts 0, 7, 14, ... while
                // they remain inside the target type.
                for shift in (0..<$type>::BITS).step_by(7) {
                    let byte = reader.take_byte()?;
                    acc |= <$type>::from(byte & 0x7F) << shift;
                    if byte & 0x80 != 0 {
                        // Continuation flag set: another byte follows.
                        continue;
                    }
                    // Final byte: any payload bits pushed past the top of the type
                    // were dropped by the shift, so reading them back must match.
                    if (acc >> shift) as u8 != byte {
                        return Err(ReadError::Custom("Last Byte Truncated"));
                    }
                    // A zero terminator is canonical only as the single byte of `0`.
                    if byte == 0u8 && (shift != 0 || acc != 0) {
                        return Err(ReadError::Custom("Invalid Trailing Zeros"));
                    }
                    dst.write(acc);
                    return Ok(());
                }
                Err(ReadError::Custom("Left Shift Overflows"))
            }
        }
    };
}
67
macro_rules! impl_schema_write {
    ($type:ty) => {
        // SAFETY: `size_of` reports exactly the number of bytes `write` emits:
        // one byte per started 7-bit group of significant bits, and at least one
        // byte for zero. (Presumably wincode relies on this when sizing output
        // buffers; confirm against the `SchemaWrite` safety docs.)
        unsafe impl<C: ConfigCore> SchemaWrite<C> for Leb128Int<$type> {
            type Src = $type;

            /// Returns the encoded length of `src` in bytes: `ceil(significant_bits / 7)`,
            /// never less than one.
            fn size_of(src: &$type) -> WriteResult<usize> {
                // `leading_zeros()` never exceeds `BITS`, so this cannot underflow.
                let significant_bits = <$type>::BITS - src.leading_zeros();
                let groups = significant_bits.div_ceil(7);
                Ok(if groups == 0 { 1 } else { groups as usize })
            }

            /// Emits `src` as unsigned LEB128, one byte per `writer.write` call,
            /// least-significant 7-bit group first.
            fn write(mut writer: impl Writer, src: &$type) -> WriteResult<()> {
                let mut rest = *src;
                loop {
                    let low = (rest & 0x7F) as u8;
                    rest >>= 7;
                    if rest == 0 {
                        // Nothing left above this group: emit it without the flag.
                        return Ok(writer.write(&[low])?);
                    }
                    writer.write(&[low | 0x80])?;
                }
            }
        }
    };
}
90
91impl_schema_read!(u16);
92impl_schema_read!(u32);
93impl_schema_read!(u64);
94
95impl_schema_write!(u16);
96impl_schema_write!(u32);
97impl_schema_write!(u64);
98
#[cfg(test)]
mod tests {
    use {
        rand::Rng,
        serde_derive::{Deserialize, Serialize},
    };

    // Max encoded size: ceil(16/7) + ceil(32/7) + ceil(64/7) = 3 + 5 + 10 = 18 bytes.
    #[derive(
        Debug, Eq, PartialEq, Serialize, Deserialize, wincode::SchemaRead, wincode::SchemaWrite,
    )]
    struct Dummy {
        #[serde(with = "solana_serde_varint")]
        #[wincode(with = "crate::Leb128Int<u16>")]
        a: u16,
        #[serde(with = "solana_serde_varint")]
        #[wincode(with = "crate::Leb128Int<u32>")]
        b: u32,
        #[serde(with = "solana_serde_varint")]
        #[wincode(with = "crate::Leb128Int<u64>")]
        c: u64,
    }

    /// Round-trips `dummy` through wincode and bincode, asserting identical bytes
    /// and that both decode back to the original value.
    fn check(dummy: &Dummy) {
        let wincode_bytes = wincode::serialize(dummy).unwrap();
        let bincode_bytes = bincode::serialize(dummy).unwrap();
        assert_eq!(wincode_bytes, bincode_bytes);
        assert_eq!(
            &wincode::deserialize::<Dummy>(&wincode_bytes).unwrap(),
            dummy
        );
        assert_eq!(
            &bincode::deserialize::<Dummy>(&bincode_bytes).unwrap(),
            dummy
        );
    }

    /// Asserts that wincode rejects `buf` with `ReadError::Custom(msg)` and that
    /// bincode (via `solana-serde-varint`) rejects it too.
    fn check_custom_err(buf: &[u8], msg: &str) {
        let r = wincode::deserialize::<Dummy>(buf);
        assert!(
            matches!(&r, Err(wincode::ReadError::Custom(m)) if *m == msg),
            "expected Custom({msg:?}), got {r:?}"
        );
        assert!(bincode::deserialize::<Dummy>(buf).is_err());
    }

    #[test]
    fn edge_cases() {
        let cases = [
            Dummy { a: 0, b: 0, c: 0 },
            Dummy { a: 1, b: 1, c: 1 },
            Dummy {
                a: 0x7F,
                b: 0x7F,
                c: 0x7F,
            },
            Dummy {
                a: 0x80,
                b: 0x80,
                c: 0x80,
            },
            Dummy {
                a: 0x3FFF,
                b: 0x3FFF,
                c: 0x3FFF,
            },
            Dummy {
                a: 0x4000,
                b: 0x4000,
                c: 0x4000,
            },
            Dummy {
                a: u16::MAX,
                b: u32::MAX,
                c: u64::MAX,
            },
        ];
        for dummy in &cases {
            check(dummy);
        }
    }

    #[test]
    fn random() {
        let mut rng = rand::rng();
        for _ in 0..100_000 {
            check(&Dummy {
                a: rng.random::<u16>() >> rng.random_range(0..u16::BITS),
                b: rng.random::<u32>() >> rng.random_range(0..u32::BITS),
                c: rng.random::<u64>() >> rng.random_range(0..u64::BITS),
            });
        }
    }

    #[test]
    fn known_encoding() {
        // 300 = 0x12C -> [0x2C | 0x80, 0x12C >> 7] = [0xAC, 0x02]; 0 -> [0x00]; 1 -> [0x01].
        let dummy = Dummy { a: 300, b: 0, c: 1 };
        let bytes = wincode::serialize(&dummy).unwrap();
        assert_eq!(bytes, [0xACu8, 0x02, 0x00, 0x01]);
        check(&dummy);
    }

    #[test]
    fn max_u64_uses_ten_bytes() {
        // `a` = 0, `b` = 0, then `c` = nine 0xFF bytes (63 bits) and a final 0x01 (bit 63).
        let mut buf = vec![0x00u8, 0x00];
        buf.extend_from_slice(&[0xFF; 9]);
        buf.push(0x01);
        let expected = Dummy {
            a: 0,
            b: 0,
            c: u64::MAX,
        };
        assert_eq!(wincode::deserialize::<Dummy>(&buf).unwrap(), expected);
        assert_eq!(wincode::serialize(&expected).unwrap(), buf);
    }

    #[test]
    fn trailing_zeros() {
        // `a` = 0x80 (continue) then a redundant 0x00 terminator.
        check_custom_err(&[0x80, 0x00], "Invalid Trailing Zeros");
    }

    #[test]
    fn last_byte_truncated() {
        // `a` = 1, then `b` (u32) ends with 0x6f at shift 28, whose high bits exceed bit 31.
        check_custom_err(&[0x01, 0xe4, 0xd7, 0x88, 0xf6, 0x6f], "Last Byte Truncated");
    }

    #[test]
    fn last_byte_truncated_u64() {
        // `c` (u64): nine continuation bytes, then 0x02 at shift 63 where only bit 0 fits.
        let mut buf = vec![0x00u8, 0x00];
        buf.extend_from_slice(&[0x80; 9]);
        buf.push(0x02);
        check_custom_err(&buf, "Last Byte Truncated");
    }

    #[test]
    fn shift_overflow() {
        // `a` (u16): three continuation bytes exhaust its 3-byte maximum.
        check_custom_err(&[0x80, 0x80, 0x80], "Left Shift Overflows");
    }

    #[test]
    fn shift_overflow_u32() {
        // `a` = 0, then `b` (u32): five continuation bytes exhaust its 5-byte maximum.
        check_custom_err(&[0x00, 0x80, 0x80, 0x80, 0x80, 0x80], "Left Shift Overflows");
    }

    #[test]
    fn shift_overflow_u64() {
        // `a` = 0, `b` = 0, then `c` (u64): ten continuation bytes exhaust its maximum.
        let mut buf = vec![0x00u8, 0x00];
        buf.extend_from_slice(&[0x80; 10]);
        check_custom_err(&buf, "Left Shift Overflows");
    }

    #[test]
    fn short_buffer() {
        let buf = [0x80u8];
        let r = wincode::deserialize::<Dummy>(&buf);
        assert!(matches!(r, Err(wincode::ReadError::Io(_))));
        assert!(bincode::deserialize::<Dummy>(&buf).is_err());

        let r = wincode::deserialize::<Dummy>(&[]);
        assert!(matches!(r, Err(wincode::ReadError::Io(_))));
        assert!(bincode::deserialize::<Dummy>(&[]).is_err());
    }
}