Skip to main content

webtrans_proto/
varint.rs

1//! QUIC variable-length integer encoding and decoding.
2
3// Based on Quinn: https://github.com/quinn-rs/quinn/tree/main/quinn-proto/src
4// Licensed under Apache-2.0 OR MIT
5
6use std::{convert::TryInto, fmt};
7
8use bytes::{Buf, BufMut};
9use thiserror::Error;
10use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
11
/// A QUIC variable-length integer: any value in the range `0..2^62`.
///
/// When encoded, the two most significant bits of the first byte hold the length
/// tag, which leaves 62 bits for the value itself. The type system does not model
/// that bound; checked constructors such as `VarInt::from_u64` enforce it instead.
#[derive(Copy, Clone, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct VarInt(pub(crate) u64);
18
19impl VarInt {
20    /// The largest representable value.
21    pub const MAX: Self = Self((1 << 62) - 1);
22    /// The largest encoded value length.
23    pub const MAX_SIZE: usize = 8;
24
25    /// Construct a `VarInt` infallibly.
26    pub const fn from_u32(x: u32) -> Self {
27        Self(x as u64)
28    }
29
30    /// Succeeds if `x` < 2^62.
31    pub fn from_u64(x: u64) -> Result<Self, VarIntBoundsExceeded> {
32        if x <= Self::MAX.0 {
33            Ok(Self(x))
34        } else {
35            Err(VarIntBoundsExceeded)
36        }
37    }
38
39    /// Create a `VarInt` without checking the bounds.
40    ///
41    /// # Safety
42    ///
43    /// `x` must be less than 2^62.
44    pub const unsafe fn from_u64_unchecked(x: u64) -> Self {
45        Self(x)
46    }
47
48    /// Extract the integer value.
49    pub const fn into_inner(self) -> u64 {
50        self.0
51    }
52
53    /// Compute the number of bytes needed to encode this value.
54    pub fn size(self) -> usize {
55        let x = self.0;
56        if x < (1 << 6) {
57            1
58        } else if x < (1 << 14) {
59            2
60        } else if x < (1 << 30) {
61            4
62        } else if x <= Self::MAX.0 {
63            8
64        } else {
65            unreachable!("malformed VarInt");
66        }
67    }
68}
69
70impl From<VarInt> for u64 {
71    fn from(x: VarInt) -> Self {
72        x.0
73    }
74}
75
76impl From<u8> for VarInt {
77    fn from(x: u8) -> Self {
78        Self(x.into())
79    }
80}
81
82impl From<u16> for VarInt {
83    fn from(x: u16) -> Self {
84        Self(x.into())
85    }
86}
87
88impl From<u32> for VarInt {
89    fn from(x: u32) -> Self {
90        Self(x.into())
91    }
92}
93
94impl std::convert::TryFrom<u64> for VarInt {
95    type Error = VarIntBoundsExceeded;
96    /// Succeeds if `x` < 2^62.
97    fn try_from(x: u64) -> Result<Self, VarIntBoundsExceeded> {
98        Self::from_u64(x)
99    }
100}
101
102impl std::convert::TryFrom<u128> for VarInt {
103    type Error = VarIntBoundsExceeded;
104    /// Succeeds if `x` < 2^62.
105    fn try_from(x: u128) -> Result<Self, VarIntBoundsExceeded> {
106        Self::from_u64(x.try_into().map_err(|_| VarIntBoundsExceeded)?)
107    }
108}
109
110impl std::convert::TryFrom<usize> for VarInt {
111    type Error = VarIntBoundsExceeded;
112    /// Succeeds if `x` < 2^62.
113    fn try_from(x: usize) -> Result<Self, VarIntBoundsExceeded> {
114        Self::try_from(x as u64)
115    }
116}
117
118impl fmt::Debug for VarInt {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        self.0.fmt(f)
121    }
122}
123
124impl fmt::Display for VarInt {
125    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
126        self.0.fmt(f)
127    }
128}
129
130impl VarInt {
131    /// Decode a QUIC varint from an in-memory buffer.
132    pub fn decode<B: Buf>(r: &mut B) -> Result<Self, VarIntUnexpectedEnd> {
133        if !r.has_remaining() {
134            return Err(VarIntUnexpectedEnd);
135        }
136        let mut buf = [0; 8];
137        buf[0] = r.get_u8();
138        let tag = buf[0] >> 6;
139        buf[0] &= 0b0011_1111;
140        let x = match tag {
141            0b00 => u64::from(buf[0]),
142            0b01 => {
143                if r.remaining() < 1 {
144                    return Err(VarIntUnexpectedEnd);
145                }
146                r.copy_to_slice(&mut buf[1..2]);
147                u64::from(u16::from_be_bytes(buf[..2].try_into().unwrap()))
148            }
149            0b10 => {
150                if r.remaining() < 3 {
151                    return Err(VarIntUnexpectedEnd);
152                }
153                r.copy_to_slice(&mut buf[1..4]);
154                u64::from(u32::from_be_bytes(buf[..4].try_into().unwrap()))
155            }
156            0b11 => {
157                if r.remaining() < 7 {
158                    return Err(VarIntUnexpectedEnd);
159                }
160                r.copy_to_slice(&mut buf[1..8]);
161                u64::from_be_bytes(buf)
162            }
163            _ => unreachable!(),
164        };
165        Ok(Self(x))
166    }
167
168    /// Read a QUIC varint from an async stream.
169    pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, VarIntUnexpectedEnd> {
170        // Eight bytes is the maximum encoded length.
171        let mut buf = [0; 8];
172
173        // Read the first byte because it encodes the length tag.
174        stream
175            .read_exact(&mut buf[0..1])
176            .await
177            .map_err(|_| VarIntUnexpectedEnd)?;
178
179        // 0b00 = 1 byte, 0b01 = 2 bytes, 0b10 = 4 bytes, 0b11 = 8 bytes.
180        let size = 1 << (buf[0] >> 6);
181        stream
182            .read_exact(&mut buf[1..size])
183            .await
184            .map_err(|_| VarIntUnexpectedEnd)?;
185
186        // Decode directly from the stack buffer slice.
187        let mut slice = &buf[..size];
188        let v = VarInt::decode(&mut slice).expect("buffer size is derived from the varint tag");
189
190        Ok(v)
191    }
192
193    /// Encode this value as a QUIC varint into the given buffer.
194    pub fn encode<B: BufMut>(&self, w: &mut B) {
195        let x = self.0;
196        if x < (1 << 6) {
197            w.put_u8(x as u8);
198        } else if x < (1 << 14) {
199            w.put_u16((0b01 << 14) | x as u16);
200        } else if x < (1 << 30) {
201            w.put_u32((0b10 << 30) | x as u32);
202        } else if x <= Self::MAX.0 {
203            w.put_u64((0b11 << 62) | x);
204        } else {
205            unreachable!("malformed VarInt")
206        }
207    }
208
209    /// Encode and write this value as a QUIC varint to an async stream.
210    pub async fn write<S: AsyncWrite + Unpin>(
211        &self,
212        stream: &mut S,
213    ) -> Result<(), VarIntUnexpectedEnd> {
214        // Keep the temporary buffer on the stack to avoid allocation.
215        let mut buf = [0u8; 8];
216        let mut cursor: &mut [u8] = &mut buf;
217        self.encode(&mut cursor);
218        let size = 8 - cursor.len();
219
220        let mut cursor = &buf[..size];
221        stream
222            .write_all_buf(&mut cursor)
223            .await
224            .map_err(|_| VarIntUnexpectedEnd)?;
225
226        Ok(())
227    }
228}
229
230/// Error returned when constructing a `VarInt` from a value >= 2^62.
231#[derive(Debug, Copy, Clone, Eq, PartialEq, Error)]
232#[error("value too large for varint encoding")]
233pub struct VarIntBoundsExceeded;
234
235#[derive(Error, Debug, Copy, Clone, Eq, PartialEq)]
236#[error("unexpected end of buffer")]
237/// Error returned when a varint decode reaches EOF before all bytes are available.
238pub struct VarIntUnexpectedEnd;