wtransport_proto_lightyear_patch/
ids.rs

1use crate::varint::VarInt;
2use std::fmt;
3use std::str::FromStr;
4
5/// QUIC stream id.
6#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
7pub struct StreamId(VarInt);
8
9impl StreamId {
10    /// The largest stream id.
11    pub const MAX: StreamId = StreamId(VarInt::MAX);
12
13    /// New stream id.
14    #[inline(always)]
15    pub const fn new(varint: VarInt) -> Self {
16        Self(varint)
17    }
18
19    /// Checks whether a stream is bi-directional or not.
20    #[inline(always)]
21    pub const fn is_bidirectional(self) -> bool {
22        self.0.into_inner() & 0x2 == 0
23    }
24
25    /// Checks whether a stream is client-initiated or not.
26    #[inline(always)]
27    pub const fn is_client_initiated(self) -> bool {
28        self.0.into_inner() & 0x1 == 0
29    }
30
31    /// Checks whether a stream is locally initiated or not.
32    #[inline(always)]
33    pub const fn is_local(self, is_server: bool) -> bool {
34        (self.0.into_inner() & 0x1) == (is_server as u64)
35    }
36
37    /// Returns the integer value as `u64`.
38    #[inline(always)]
39    pub const fn into_u64(self) -> u64 {
40        self.0.into_inner()
41    }
42
43    /// Returns the stream id as [`VarInt`] value.
44    #[inline(always)]
45    pub const fn into_varint(self) -> VarInt {
46        self.0
47    }
48}
49
50impl From<StreamId> for VarInt {
51    #[inline(always)]
52    fn from(stream_id: StreamId) -> Self {
53        stream_id.0
54    }
55}
56
57impl fmt::Debug for StreamId {
58    #[inline(always)]
59    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60        self.0.fmt(f)
61    }
62}
63
64impl fmt::Display for StreamId {
65    #[inline(always)]
66    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67        self.0.fmt(f)
68    }
69}
70
71/// Error for invalid Session ID value.
72#[derive(Debug)]
73pub struct InvalidSessionId;
74
75/// A WebTransport session id.
76///
77/// Internally, it corresponds to a *bidirectional* *client-initiated* QUIC stream,
78/// that is, a webtransport *session stream*.
79#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
80pub struct SessionId(StreamId);
81
82impl SessionId {
83    /// Returns the integer value as `u64`.
84    #[inline(always)]
85    pub const fn into_u64(self) -> u64 {
86        self.0.into_u64()
87    }
88
89    /// Returns the session id as [`VarInt`] value.
90    #[inline(always)]
91    pub const fn into_varint(self) -> VarInt {
92        self.0.into_varint()
93    }
94
95    /// Returns the corresponding session QUIC stream.
96    #[inline(always)]
97    pub const fn session_stream(self) -> StreamId {
98        self.0
99    }
100
101    /// Tries to create a session id from its session stream.
102    ///
103    /// `stream_id` must be *bidirectional* and *client-initiated*, otherwise
104    /// an [`Err`] is returned.
105    pub fn try_from_session_stream(stream_id: StreamId) -> Result<Self, InvalidSessionId> {
106        if stream_id.is_bidirectional() && stream_id.is_client_initiated() {
107            Ok(Self(stream_id))
108        } else {
109            Err(InvalidSessionId)
110        }
111    }
112
113    /// Creates a session id without checking session stream properties.
114    ///
115    /// # Safety
116    ///
117    /// `stream_id` must be *bidirectional* and *client-initiated*.
118    #[inline(always)]
119    pub const unsafe fn from_session_stream_unchecked(stream_id: StreamId) -> Self {
120        debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
121        Self(stream_id)
122    }
123
124    #[inline(always)]
125    pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidSessionId> {
126        Self::try_from_session_stream(StreamId::new(varint))
127    }
128
129    #[cfg(test)]
130    pub(crate) fn maybe_invalid(varint: VarInt) -> Self {
131        Self(StreamId::new(varint))
132    }
133}
134
135impl fmt::Debug for SessionId {
136    #[inline(always)]
137    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138        self.0.fmt(f)
139    }
140}
141
142impl fmt::Display for SessionId {
143    #[inline(always)]
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        self.0.fmt(f)
146    }
147}
148
149/// Error for invalid Quarter Stream ID value (too large).
150#[derive(Debug)]
151pub struct InvalidQStreamId;
152
153/// HTTP3 Quarter Stream ID.
154#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
155pub struct QStreamId(VarInt);
156
157impl QStreamId {
158    /// The largest quarter stream id.
159    // SAFETY: value is less than max varint
160    pub const MAX: QStreamId =
161        unsafe { Self(VarInt::from_u64_unchecked(1_152_921_504_606_846_975)) };
162
163    /// Creates a quarter stream id from its corresponding [`SessionId`]
164    #[inline(always)]
165    pub const fn from_session_id(session_id: SessionId) -> Self {
166        let value = session_id.into_u64() >> 2;
167        debug_assert!(value <= Self::MAX.into_u64());
168
169        // SAFETY: after bitwise operation from stream id, result is surely a varint
170        let varint = unsafe { VarInt::from_u64_unchecked(value) };
171
172        Self(varint)
173    }
174
175    /// Returns its corresponding [`StreamId`].
176    ///
177    /// This is a *client-initiated* *bidirectional* stream.
178    #[inline(always)]
179    pub const fn into_stream_id(self) -> StreamId {
180        // SAFETY: Quarter Stream ID origin from a valid Stream ID
181        let varint = unsafe {
182            debug_assert!(self.0.into_inner() << 2 <= VarInt::MAX.into_inner());
183            VarInt::from_u64_unchecked(self.0.into_inner() << 2)
184        };
185
186        StreamId::new(varint)
187    }
188
189    /// Returns its corresponding [`SessionId`].
190    #[inline(always)]
191    pub const fn into_session_id(self) -> SessionId {
192        let stream_id = self.into_stream_id();
193
194        // SAFETY: corresponding stream for qstream is bidirectional and client-initiated
195        unsafe {
196            debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
197            SessionId::from_session_stream_unchecked(stream_id)
198        }
199    }
200
201    /// Returns the integer value as `u64`.
202    #[inline(always)]
203    pub const fn into_u64(self) -> u64 {
204        self.0.into_inner()
205    }
206
207    /// Returns the quarter stream id as [`VarInt`] value.
208    #[inline(always)]
209    pub const fn into_varint(self) -> VarInt {
210        self.0
211    }
212
213    pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidQStreamId> {
214        if varint <= Self::MAX.into_varint() {
215            Ok(Self(varint))
216        } else {
217            Err(InvalidQStreamId)
218        }
219    }
220
221    #[cfg(test)]
222    pub(crate) fn maybe_invalid(varint: VarInt) -> QStreamId {
223        Self(varint)
224    }
225}
226
227impl fmt::Debug for QStreamId {
228    #[inline(always)]
229    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230        self.0.fmt(f)
231    }
232}
233
234impl fmt::Display for QStreamId {
235    #[inline(always)]
236    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237        self.0.fmt(f)
238    }
239}
240
/// Error for invalid HTTP status code.
#[derive(Debug)]
pub struct InvalidStatusCode;

/// HTTP status code (rfc9110).
///
/// Every fallible constructor (`TryFrom<u8/u16/u32/u64>`, [`FromStr`],
/// [`StatusCode::try_from_u32`]) only accepts values in `MIN..=MAX` (`100..=599`).
///
/// NOTE(review): the derived [`Default`] yields `StatusCode(0)`, which lies outside
/// that range. It is kept for backward compatibility; do not treat it as a real code.
#[derive(Default, Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct StatusCode(u16);

impl StatusCode {
    /// The largest code.
    pub const MAX: Self = Self(599);

    /// The smallest code.
    pub const MIN: Self = Self(100);

    /// HTTP 200 OK status code.
    pub const OK: Self = Self(200);

    /// HTTP 403 Forbidden status code.
    pub const FORBIDDEN: Self = Self(403);

    /// HTTP 404 Not Found status code.
    pub const NOT_FOUND: Self = Self(404);

    /// Tries to construct from `u32`.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidStatusCode`] if `value` is outside `MIN..=MAX`.
    #[inline(always)]
    pub fn try_from_u32(value: u32) -> Result<Self, InvalidStatusCode> {
        Self::try_from(value)
    }

    /// Extracts the integer value as `u16`.
    #[inline(always)]
    pub const fn into_inner(self) -> u16 {
        self.0
    }

    /// Returns true if the status code is 2xx.
    #[inline(always)]
    pub const fn is_successful(self) -> bool {
        self.0 >= 200 && self.0 < 300
    }
}

impl TryFrom<u8> for StatusCode {
    type Error = InvalidStatusCode;

    /// Widens to `u16` and applies the same range check as `TryFrom<u16>`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        Self::try_from(u16::from(value))
    }
}

impl TryFrom<u16> for StatusCode {
    type Error = InvalidStatusCode;

    /// The single source of truth for validation: accepts only `MIN..=MAX`.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        if (Self::MIN.0..=Self::MAX.0).contains(&value) {
            Ok(Self(value))
        } else {
            Err(InvalidStatusCode)
        }
    }
}

impl TryFrom<u32> for StatusCode {
    type Error = InvalidStatusCode;

    /// Narrows to `u16` without truncation, then applies the `TryFrom<u16>` range check.
    fn try_from(value: u32) -> Result<Self, Self::Error> {
        let code = u16::try_from(value).map_err(|_| InvalidStatusCode)?;
        Self::try_from(code)
    }
}

impl TryFrom<u64> for StatusCode {
    type Error = InvalidStatusCode;

    /// Narrows to `u16` without truncation, then applies the `TryFrom<u16>` range check.
    fn try_from(value: u64) -> Result<Self, Self::Error> {
        let code = u16::try_from(value).map_err(|_| InvalidStatusCode)?;
        Self::try_from(code)
    }
}

impl FromStr for StatusCode {
    type Err = InvalidStatusCode;

    /// Parses a decimal string such as `"200"`.
    ///
    /// Digits are parsed with `u16::from_str` (which also tolerates a leading `+`),
    /// and the result is range-checked like every other constructor, so strings
    /// such as `"0"` or `"999"` are rejected.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let code: u16 = s.parse().map_err(|_| InvalidStatusCode)?;
        Self::try_from(code)
    }
}

impl fmt::Debug for StatusCode {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}

impl fmt::Display for StatusCode {
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}

#[cfg(test)]
mod tests {
    use self::utils::{stream_types, StreamType};

    use super::*;

    #[test]
    fn stream_properties() {
        for (id, stream_type) in stream_types(1024) {
            let stream_id = StreamId::new(id);

            // Expected (bidirectional, client-initiated) flags for each stream type.
            let (bidi, client) = match stream_type {
                StreamType::ClientBi => (true, true),
                StreamType::ServerBi => (true, false),
                StreamType::ClientUni => (false, true),
                StreamType::ServerUni => (false, false),
            };

            assert_eq!(stream_id.is_bidirectional(), bidi);
            assert_eq!(stream_id.is_client_initiated(), client);
            // A client owns client-initiated streams; a server owns the rest.
            assert_eq!(stream_id.is_local(false), client);
            assert_eq!(stream_id.is_local(true), !client);
        }
    }

    #[test]
    fn session_id() {
        for (id, stream_type) in stream_types(1024) {
            // Only client-initiated bidirectional streams may back a session.
            let expect_valid = matches!(stream_type, StreamType::ClientBi);

            assert_eq!(SessionId::try_from_varint(id).is_ok(), expect_valid);
            assert_eq!(
                SessionId::try_from_session_stream(StreamId::new(id)).is_ok(),
                expect_valid
            );
        }
    }

    #[test]
    fn qstream_id() {
        let session_streams = stream_types(1024)
            .filter(|&(_, stream_type)| matches!(stream_type, StreamType::ClientBi))
            .map(|(id, _)| id);

        // The n-th session stream must map to quarter stream id n.
        for (quarter, id) in session_streams.enumerate() {
            let session_id = SessionId::try_from_varint(id).unwrap();
            let qstream_id = QStreamId::from_session_id(session_id);

            assert_eq!(qstream_id.into_stream_id(), session_id.session_stream());
            assert_eq!(qstream_id.into_session_id(), session_id);
            assert_eq!(qstream_id.into_u64(), quarter as u64);
        }
    }

    mod utils {
        use super::*;

        #[derive(Copy, Clone, Debug)]
        pub enum StreamType {
            ClientBi,
            ServerBi,
            ClientUni,
            ServerUni,
        }

        /// Yields ids `0..max_id`, each paired with the stream type encoded by its
        /// two low-order bits.
        pub fn stream_types(max_id: u32) -> impl Iterator<Item = (VarInt, StreamType)> {
            const CYCLE: [StreamType; 4] = [
                StreamType::ClientBi,
                StreamType::ServerBi,
                StreamType::ClientUni,
                StreamType::ServerUni,
            ];

            (0..max_id).map(|id| (VarInt::from_u32(id), CYCLE[id as usize % CYCLE.len()]))
        }
    }
}