vlqencoding/
lib.rs

/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

//! VLQ (Variable-length quantity) encoding.

10use std::io;
11use std::io::Read;
12use std::io::Write;
13
/// Serialization of integers of type `T` as VLQ (variable-length quantity) bytes.
pub trait VLQEncode<T> {
    /// Encodes `value` as a VLQ byte sequence and writes it to this stream.
    ///
    /// Small values take fewer bytes than large ones: every output byte carries
    /// seven bits of the integer, and its high bit tells whether more bytes follow.
    ///
    /// # Errors
    ///
    /// Returns the error produced by the underlying stream if a write fails.
    ///
    /// # Examples
    ///
    /// Unsigned integers are written as-is:
    ///
    /// ```
    /// use vlqencoding::VLQEncode;
    /// let mut v = vec![];
    ///
    /// let x = 120u8;
    /// v.write_vlq(x)
    ///     .expect("writing an encoded u8 to a vec should work");
    /// assert_eq!(v, vec![120]);
    ///
    /// let x = 22742734291u64;
    /// v.write_vlq(x)
    ///     .expect("writing an encoded u64 to a vec should work");
    ///
    /// assert_eq!(v, vec![120, 211, 171, 202, 220, 84]);
    /// ```
    ///
    /// Signed integers go through a zig-zag mapping first, so values close to
    /// zero (positive or negative) stay short:
    ///
    /// ```
    /// use vlqencoding::VLQEncode;
    /// let mut v = vec![];
    ///
    /// let x = -3i8;
    /// v.write_vlq(x)
    ///     .expect("writing an encoded i8 to a vec should work");
    /// assert_eq!(v, vec![5]);
    ///
    /// let x = 1000i16;
    /// v.write_vlq(x)
    ///     .expect("writing an encoded i16 to a vec should work");
    /// assert_eq!(v, vec![5, 208, 15]);
    /// ```
    fn write_vlq(&mut self, value: T) -> io::Result<()>;
}
53
/// Deserialization of VLQ (variable-length quantity) bytes into integers of type `T`.
pub trait VLQDecode<T> {
    /// Consumes one VLQ-encoded integer from this stream and returns it.
    ///
    /// The stream position advances past every byte that was read, including
    /// when decoding fails part-way through.
    ///
    /// # Errors
    ///
    /// The implementations in this crate fail with `ErrorKind::InvalidData` when
    /// the encoded value does not fit in `T`, and pass through any error raised
    /// by the underlying reader (for instance when the input ends too early).
    ///
    /// # Examples
    ///
    /// ```
    /// use std::io::Cursor;
    /// use std::io::ErrorKind;
    /// use std::io::Seek;
    /// use std::io::SeekFrom;
    ///
    /// use vlqencoding::VLQDecode;
    ///
    /// let mut c = Cursor::new(vec![120u8, 211, 171, 202, 220, 84]);
    ///
    /// let x: Result<u8, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), 120u8);
    ///
    /// let x: Result<u16, _> = c.read_vlq();
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    ///
    /// c.seek(SeekFrom::Start(1)).expect("seek should work");
    /// let x: Result<u64, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), 22742734291u64);
    /// ```
    ///
    /// Signed integers are recovered by undoing the zig-zag mapping:
    ///
    /// ```
    /// use std::io::Cursor;
    /// use std::io::ErrorKind;
    /// use std::io::Seek;
    /// use std::io::SeekFrom;
    ///
    /// use vlqencoding::VLQDecode;
    ///
    /// let mut c = Cursor::new(vec![5u8, 208, 15]);
    ///
    /// let x: Result<i8, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), -3i8);
    ///
    /// let x: Result<i8, _> = c.read_vlq();
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    ///
    /// c.seek(SeekFrom::Start(1)).expect("seek should work");
    /// let x: Result<i32, _> = c.read_vlq();
    /// assert_eq!(x.unwrap(), 1000i32);
    /// ```
    fn read_vlq(&mut self) -> io::Result<T>;
}
104
/// Random-access decoding of VLQ (variable-length quantity) bytes into integers of type `T`.
pub trait VLQDecodeAt<T> {
    /// Decodes one VLQ-encoded integer that starts at byte `offset`.
    ///
    /// On success the result is `Ok((decoded_integer, bytes_read))`, where
    /// `bytes_read` is the length of the encoded form.
    ///
    /// This is the counterpart of `VLQDecode::read_vlq` for immutable
    /// `AsRef<[u8]>` buffers: nothing is consumed and no mutable `io::Read`
    /// object is needed.
    ///
    /// # Errors
    ///
    /// The implementations in this crate fail with `ErrorKind::InvalidData` when
    /// the buffer ends before the integer is complete (which includes an
    /// `offset` at or past the end) or when the value does not fit in `T`.
    ///
    /// # Examples
    ///
    /// ```
    /// use std::io::ErrorKind;
    ///
    /// use vlqencoding::VLQDecodeAt;
    ///
    /// let c = &[120u8, 211, 171, 202, 220, 84, 255];
    ///
    /// let x: Result<(u8, _), _> = c.read_vlq_at(0);
    /// assert_eq!(x.unwrap(), (120u8, 1));
    ///
    /// let x: Result<(u64, _), _> = c.read_vlq_at(1);
    /// assert_eq!(x.unwrap(), (22742734291u64, 5));
    ///
    /// let x: Result<(u64, _), _> = c.read_vlq_at(6);
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    ///
    /// let x: Result<(u64, _), _> = c.read_vlq_at(7);
    /// assert_eq!(x.unwrap_err().kind(), ErrorKind::InvalidData);
    /// ```
    fn read_vlq_at(&self, offset: usize) -> io::Result<(T, usize)>;
}
136
137macro_rules! impl_unsigned_primitive {
138    ($T: ident) => {
139        impl<W: Write + ?Sized> VLQEncode<$T> for W {
140            fn write_vlq(&mut self, value: $T) -> io::Result<()> {
141                let mut buf = [0u8];
142                let mut value = value;
143                loop {
144                    let mut byte = (value & 127) as u8;
145                    let next = value >> 7;
146                    if next != 0 {
147                        byte |= 128;
148                    }
149                    buf[0] = byte;
150                    self.write_all(&buf)?;
151                    value = next;
152                    if value == 0 {
153                        break;
154                    }
155                }
156                Ok(())
157            }
158        }
159
160        impl<R: Read + ?Sized> VLQDecode<$T> for R {
161            fn read_vlq(&mut self) -> io::Result<$T> {
162                let mut buf = [0u8];
163                let mut value = 0 as $T;
164                let mut base = 1 as $T;
165                let base_multiplier = (1 << 7) as $T;
166                loop {
167                    self.read_exact(&mut buf)?;
168                    let byte = buf[0];
169                    value = ($T::from(byte & 127))
170                        .checked_mul(base)
171                        .and_then(|v| v.checked_add(value))
172                        .ok_or(io::ErrorKind::InvalidData)?;
173                    if byte & 128 == 0 {
174                        break;
175                    }
176                    base = base
177                        .checked_mul(base_multiplier)
178                        .ok_or(io::ErrorKind::InvalidData)?;
179                }
180                Ok(value)
181            }
182        }
183
184        impl<R: AsRef<[u8]>> VLQDecodeAt<$T> for R {
185            fn read_vlq_at(&self, offset: usize) -> io::Result<($T, usize)> {
186                let buf = self.as_ref();
187                let mut size = 0;
188                let mut value = 0 as $T;
189                let mut base = 1 as $T;
190                let base_multiplier = (1 << 7) as $T;
191                loop {
192                    if let Some(byte) = buf.get(offset + size) {
193                        size += 1;
194                        value = ($T::from(byte & 127))
195                            .checked_mul(base)
196                            .and_then(|v| v.checked_add(value))
197                            .ok_or(io::ErrorKind::InvalidData)?;
198                        if byte & 128 == 0 {
199                            break;
200                        }
201                        base = base
202                            .checked_mul(base_multiplier)
203                            .ok_or(io::ErrorKind::InvalidData)?;
204                    } else {
205                        return Err(io::ErrorKind::InvalidData.into());
206                    }
207                }
208                Ok((value, size))
209            }
210        }
211    };
212}
213
214impl_unsigned_primitive!(usize);
215impl_unsigned_primitive!(u64);
216impl_unsigned_primitive!(u32);
217impl_unsigned_primitive!(u16);
218impl_unsigned_primitive!(u8);
219
220macro_rules! impl_signed_primitive {
221    ($T: ty, $U: ty) => {
222        impl<W: Write + ?Sized> VLQEncode<$T> for W {
223            fn write_vlq(&mut self, v: $T) -> io::Result<()> {
224                self.write_vlq(((v << 1) ^ (v >> (<$U>::BITS - 1))) as $U)
225            }
226        }
227
228        impl<R: Read + ?Sized> VLQDecode<$T> for R {
229            fn read_vlq(&mut self) -> io::Result<$T> {
230                (self.read_vlq() as Result<$U, _>).map(|n| ((n >> 1) as $T) ^ -((n & 1) as $T))
231            }
232        }
233
234        impl<R: AsRef<[u8]>> VLQDecodeAt<$T> for R {
235            fn read_vlq_at(&self, offset: usize) -> io::Result<($T, usize)> {
236                (self.read_vlq_at(offset) as Result<($U, _), _>)
237                    .map(|(n, s)| (((n >> 1) as $T) ^ -((n & 1) as $T), s))
238            }
239        }
240    };
241}
242
243impl_signed_primitive!(isize, usize);
244impl_signed_primitive!(i64, u64);
245impl_signed_primitive!(i32, u32);
246impl_signed_primitive!(i16, u16);
247impl_signed_primitive!(i8, u8);
248
#[cfg(test)]
mod tests {
    use std::io;
    use std::io::Cursor;
    use std::io::Seek;
    use std::io::SeekFrom;

    use quickcheck::quickcheck;

    use super::*;

    // Encodes `$N`, decodes it again through both `read_vlq_at` and `read_vlq`, and evaluates
    // to `true` when both decoders return the original value and agree on the encoded length.
    macro_rules! check_round_trip {
        ($N: expr) => {{
            let original = $N;
            let mut encoded = vec![];
            encoded.write_vlq(original).expect("write");

            // `from_slice` and `from_stream` start out as copies of `original` only so that
            // the compiler can infer which integer type `read_vlq_at` and `read_vlq` return.
            #[allow(unused_assignments)]
            let mut from_slice = original;
            let (value, consumed) = encoded.read_vlq_at(0).unwrap();
            from_slice = value;

            let mut cursor = Cursor::new(encoded);
            #[allow(unused_assignments)]
            let mut from_stream = original;
            from_stream = cursor.read_vlq().unwrap();

            from_stream == original
                && original == from_slice
                && consumed == cursor.position() as usize
        }};
    }

    #[test]
    fn test_round_trip_manual() {
        // Every power of two with its two neighbours, plus two arbitrary wide values; each
        // candidate is checked together with its bitwise complement.
        let around_powers = (0..64).flat_map(|b| [1u64 << b, (1 << b) + 1, (1 << b) - 1]);
        let arbitrary = [0xb3a73ce2ff2, 0xab54a98ceb1f0ad2];
        for i in around_powers.chain(arbitrary).flat_map(|i| [i, !i]) {
            assert!(check_round_trip!(i as i8));
            assert!(check_round_trip!(i as i16));
            assert!(check_round_trip!(i as i32));
            assert!(check_round_trip!(i as i64));
            assert!(check_round_trip!(i as isize));
            assert!(check_round_trip!(i as u8));
            assert!(check_round_trip!(i as u16));
            assert!(check_round_trip!(i as u32));
            assert!(check_round_trip!(i));
            assert!(check_round_trip!(i as usize));
        }
    }

    #[test]
    fn test_read_errors() {
        // Nothing to read at all.
        let mut empty = Cursor::new(Vec::<u8>::new());
        let outcome: io::Result<u64> = empty.read_vlq();
        assert_eq!(outcome.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);

        // Both bytes carry the continuation bit, so a wide integer runs out of input.
        let mut truncated = Cursor::new(vec![255, 129]);
        let outcome: io::Result<u64> = truncated.read_vlq();
        assert_eq!(outcome.unwrap_err().kind(), io::ErrorKind::UnexpectedEof);

        // The same bytes overflow a `u8` before the input runs out.
        truncated.seek(SeekFrom::Start(0)).unwrap();
        let outcome: io::Result<u8> = truncated.read_vlq();
        assert_eq!(outcome.unwrap_err().kind(), io::ErrorKind::InvalidData);
    }

    #[test]
    fn test_zig_zag() {
        // Pairs of (signed input, unsigned value expected on the wire).
        let cases = [
            (0, 0),
            (-1, 1),
            (1, 2),
            (-2, 3),
            (-127, 253),
            (127, 254),
            (-128i8, 255u8),
        ];
        let mut cursor = Cursor::new(Vec::<u8>::new());
        for (signed, expected) in cases {
            cursor.seek(SeekFrom::Start(0)).expect("seek");
            cursor.write_vlq(signed).expect("write");
            cursor.seek(SeekFrom::Start(0)).expect("seek");
            let decoded: u8 = cursor.read_vlq().unwrap();
            assert_eq!(decoded, expected);
        }
    }

    quickcheck! {
        fn test_round_trip_u64_quickcheck(x: u64) -> bool {
            check_round_trip!(x)
        }

        fn test_round_trip_i64_quickcheck(x: i64) -> bool {
            check_round_trip!(x)
        }

        fn test_round_trip_u8_quickcheck(x: u8) -> bool {
            check_round_trip!(x)
        }

        fn test_round_trip_i8_quickcheck(x: i8) -> bool {
            check_round_trip!(x)
        }
    }
}