s2n_quic_core/stream/
id.rs

1// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2// SPDX-License-Identifier: Apache-2.0
3
4//! Types and utilities around the QUIC Stream identifier
5
6use crate::{endpoint, stream::StreamType, varint::VarInt};
7#[cfg(any(test, feature = "generator"))]
8use bolero_generator::prelude::*;
9
10/// The ID of a stream.
11///
12/// A stream ID is a 62-bit integer (0 to 2^62-1) that is unique for all streams
13/// on a connection.
14#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone, Hash)]
15#[cfg_attr(any(feature = "generator", test), derive(TypeGenerator))]
16pub struct StreamId(VarInt);
17
18// Stream IDs can be converted into `VarInt` and `u64`
19
20impl From<StreamId> for VarInt {
21    fn from(id: StreamId) -> Self {
22        id.0
23    }
24}
25
26impl From<StreamId> for u64 {
27    fn from(id: StreamId) -> Self {
28        id.0.as_u64()
29    }
30}
31
32impl StreamId {
33    /// Creates a Stream ID from a [`VarInt`].
34    ///
35    /// This is always a safe conversion, since Stream IDs and [`VarInt`]s
36    /// share the same range.
37    #[inline]
38    pub const fn from_varint(id: VarInt) -> StreamId {
39        StreamId(id)
40    }
41
42    /// Converts the stream id into a [`VarInt`]
43    #[inline]
44    pub const fn as_varint(self) -> VarInt {
45        self.0
46    }
47
48    /// Returns the initial Stream ID for a given stream type.
49    ///
50    /// E.g. the initial Stream ID for a server initiated unidirectional Stream
51    /// is Stream ID `3`.
52    ///
53    /// Example:
54    ///
55    /// ```
56    /// # use s2n_quic_core::{endpoint, stream::{StreamId, StreamType}};
57    /// let stream_id = StreamId::initial(endpoint::Type::Server, StreamType::Unidirectional);
58    /// // Initial server initiated unidirectional Stream ID is 3
59    /// assert_eq!(3u64, stream_id.as_varint().as_u64());
60    /// ```
61    #[inline]
62    pub fn initial(initiator: endpoint::Type, stream_type: StreamType) -> StreamId {
63        //= https://www.rfc-editor.org/rfc/rfc9000#section-2.1
64        //# The two least significant bits from a stream ID therefore identify a
65        //# stream as one of four types, as summarized in Table 1.
66        //#
67        //#        +======+==================================+
68        //#        | Bits | Stream Type                      |
69        //#        +======+==================================+
70        //#        | 0x00 | Client-Initiated, Bidirectional  |
71        //#        +------+----------------------------------+
72        //#        | 0x01 | Server-Initiated, Bidirectional  |
73        //#        +------+----------------------------------+
74        //#        | 0x02 | Client-Initiated, Unidirectional |
75        //#        +------+----------------------------------+
76        //#        | 0x03 | Server-Initiated, Unidirectional |
77        //#        +------+----------------------------------+
78
79        match (
80            stream_type == StreamType::Bidirectional,
81            initiator == endpoint::Type::Client,
82        ) {
83            (true, true) => StreamId(VarInt::from_u32(0)),
84            (true, false) => StreamId(VarInt::from_u32(1)),
85            (false, true) => StreamId(VarInt::from_u32(2)),
86            (false, false) => StreamId(VarInt::from_u32(3)),
87        }
88    }
89
90    /// Returns the n-th `StreamId` for a certain type of `Stream`.
91    ///
92    /// The 0th `StreamId` thereby represents the `StreamId` which is returned
93    /// by the [`Self::initial`] method. All further `StreamId`s of a certain type
94    /// will be spaced apart by 4.
95    ///
96    /// nth() will return `None` if the resulting `StreamId` would not be valid.
97    #[inline]
98    pub fn nth(initiator: endpoint::Type, stream_type: StreamType, n: u64) -> Option<StreamId> {
99        let initial = Self::initial(initiator, stream_type);
100        // We calculate as much as possible with u64, to reduce the number of
101        // overflow checks for the maximum Stream ID to the last operation
102        let id = VarInt::new(n.checked_mul(4)?.checked_add(initial.into())?).ok()?;
103        Some(StreamId(id))
104    }
105
106    /// Returns the next [`StreamId`] which is of the same type the one referred
107    /// to. E.g. if the method is called on a Stream ID for an unidirectional
108    /// client initiated stream, the Stream ID of the next unidirectional client
109    /// initiated stream will be returned.
110    ///
111    /// Returns `None` if the next Stream ID would not be valid, due to being out
112    /// of bounds.
113    ///
114    /// Example:
115    ///
116    /// ```
117    /// # use s2n_quic_core::{endpoint, stream::{StreamId, StreamType}};
118    /// let stream_id = StreamId::initial(endpoint::Type::Client, StreamType::Unidirectional);
119    /// // Initial client initiated unidirectional Stream ID is 2
120    /// assert_eq!(2u64, stream_id.as_varint().as_u64());
121    /// // Get the next client initiated Stream ID
122    /// let next_stream_id = stream_id.next_of_type();
123    /// assert_eq!(6u64, next_stream_id.expect("Next Stream ID is valid").as_varint().as_u64());
124    /// ```
125    #[inline]
126    pub fn next_of_type(self) -> Option<StreamId> {
127        // Stream IDs increase in steps of 4, since the 2 least significant bytes
128        // are used to indicate the stream type
129        self.0
130            .checked_add(VarInt::from_u32(4))
131            .map(StreamId::from_varint)
132    }
133
134    /// Returns whether the client or server initiated the Stream
135    #[inline]
136    pub fn initiator(self) -> endpoint::Type {
137        //= https://www.rfc-editor.org/rfc/rfc9000#section-2.1
138        //# The least significant bit (0x1) of the stream ID identifies the
139        //# initiator of the stream.  Client-initiated streams have even-numbered
140        //# stream IDs (with the bit set to 0)
141        if Into::<u64>::into(self.0) & 0x01u64 == 0 {
142            endpoint::Type::Client
143        } else {
144            endpoint::Type::Server
145        }
146    }
147
148    /// Returns whether the Stream is unidirectional or bidirectional.
149    #[inline]
150    pub fn stream_type(self) -> StreamType {
151        //= https://www.rfc-editor.org/rfc/rfc9000#section-2.1
152        //# The second least significant bit (0x2) of the stream ID distinguishes
153        //# between bidirectional streams (with the bit set to 0) and
154        //# unidirectional streams (with the bit set to 1).
155        if Into::<u64>::into(self.0) & 0x02 == 0 {
156            StreamType::Bidirectional
157        } else {
158            StreamType::Unidirectional
159        }
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::*;
    use crate::varint::MAX_VARINT_VALUE;

    /// Runs `check` once for every combination of stream type and initiator.
    ///
    /// Stream types form the outer loop and initiators the inner loop.
    fn for_each_kind(mut check: impl FnMut(endpoint::Type, StreamType)) {
        for &stream_type in &[StreamType::Bidirectional, StreamType::Unidirectional] {
            for &initiator in &[endpoint::Type::Client, endpoint::Type::Server] {
                check(initiator, stream_type);
            }
        }
    }

    #[test]
    fn initial_stream_ids() {
        for_each_kind(|initiator, stream_type| {
            // The initial ID must report the type and initiator it was built from
            let id = StreamId::initial(initiator, stream_type);
            assert_eq!(stream_type, id.stream_type());
            assert_eq!(initiator, id.initiator());
        });
    }

    #[test]
    fn stream_id_overflow() {
        // The highest possible Stream ID can be constructed
        let max_varint = VarInt::new((1 << 62) - 1).unwrap();
        let _max_stream_id = StreamId::from_varint(max_varint);

        // Four below the maximum is the last ID that can still be advanced
        let last_incrementable = max_varint - 4;
        assert!(StreamId::from_varint(last_incrementable)
            .next_of_type()
            .is_some());

        // Every ID above that is valid by itself, but advancing it would
        // leave the valid range
        for increment in 1..5 {
            let stream_id = StreamId::from_varint(last_incrementable + increment);
            assert!(stream_id.next_of_type().is_none());
        }
    }

    #[test]
    fn nth_stream_id() {
        for_each_kind(|initiator, stream_type| {
            // Position 0 is the initial StreamId
            let first = StreamId::nth(initiator, stream_type, 0).unwrap();
            assert_eq!(StreamId::initial(initiator, stream_type), first);

            // Each following position is another 4 away from the first ID
            for n in 1..10 {
                let nth = StreamId::nth(initiator, stream_type, n).unwrap();
                assert_eq!(VarInt::from_u32(n as u32 * 4), nth.0 - first.0);
            }
        });
    }

    #[test]
    fn invalid_nth_stream_id() {
        for_each_kind(|initiator, stream_type| {
            // Multiplying this position by 4 exceeds the maximum VarInt value
            let too_large: u64 = (MAX_VARINT_VALUE / 2).into();
            assert_eq!(None, StreamId::nth(initiator, stream_type, too_large));
        });
    }
}