1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//! Types and utilities around the QUIC Stream identifier

use crate::{endpoint, stream::StreamType, varint::VarInt};

/// The ID of a stream.
///
/// A stream ID is a 62-bit integer (0 to 2^62-1) that is unique for all streams
/// on a connection.
///
/// The ID is stored as a [`VarInt`]. The wrapped value is private; use
/// [`StreamId::from_varint`] and [`StreamId::as_varint`] (or the `From`
/// conversions into [`VarInt`] and `u64`) to move between the representations.
///
/// Equality, ordering and hashing are derived from the wrapped value.
#[derive(PartialEq, Eq, PartialOrd, Ord, Debug, Copy, Clone, Hash)]
pub struct StreamId(VarInt);

// Stream IDs can be converted into `VarInt` and `u64`

impl From<StreamId> for VarInt {
    fn from(id: StreamId) -> Self {
        id.0
    }
}

impl From<StreamId> for u64 {
    fn from(id: StreamId) -> Self {
        id.0.as_u64()
    }
}

impl StreamId {
    /// Creates a Stream ID from a [`VarInt`].
    ///
    /// This is always a safe conversion, since Stream IDs and [`VarInt`]s
    /// share the same range.
    #[inline]
    pub const fn from_varint(id: VarInt) -> StreamId {
        StreamId(id)
    }

    /// Converts the stream id into a [`VarInt`]
    #[inline]
    pub const fn as_varint(self) -> VarInt {
        self.0
    }

    /// Returns the initial Stream ID for a given stream type.
    ///
    /// E.g. the initial Stream ID for a server initiated unidirectional Stream
    /// is Stream ID `3`.
    ///
    /// Example:
    ///
    /// ```
    /// # use s2n_quic_core::{endpoint, stream::{StreamId, StreamType}};
    /// let stream_id = StreamId::initial(endpoint::Type::Server, StreamType::Unidirectional);
    /// // Initial server initiated unidirectional Stream ID is 3
    /// assert_eq!(3u64, stream_id.as_varint().as_u64());
    /// ```
    #[inline]
    pub fn initial(initator: endpoint::Type, stream_type: StreamType) -> StreamId {
        //= https://www.rfc-editor.org/rfc/rfc9000#section-2.1
        //# The two least significant bits from a stream ID therefore identify a
        //# stream as one of four types, as summarized in Table 1.
        //#
        //#        +======+==================================+
        //#        | Bits | Stream Type                      |
        //#        +======+==================================+
        //#        | 0x00 | Client-Initiated, Bidirectional  |
        //#        +------+----------------------------------+
        //#        | 0x01 | Server-Initiated, Bidirectional  |
        //#        +------+----------------------------------+
        //#        | 0x02 | Client-Initiated, Unidirectional |
        //#        +------+----------------------------------+
        //#        | 0x03 | Server-Initiated, Unidirectional |
        //#        +------+----------------------------------+

        match (
            stream_type == StreamType::Bidirectional,
            initator == endpoint::Type::Client,
        ) {
            (true, true) => StreamId(VarInt::from_u32(0)),
            (true, false) => StreamId(VarInt::from_u32(1)),
            (false, true) => StreamId(VarInt::from_u32(2)),
            (false, false) => StreamId(VarInt::from_u32(3)),
        }
    }

    /// Returns the n-th `StreamId` for a certain type of `Stream`.
    ///
    /// The 0th `StreamId` thereby represents the `StreamId` which is returned
    /// by the [`Self::initial`] method. All further `StreamId`s of a certain type
    /// will be spaced apart by 4.
    ///
    /// nth() will return `None` if the resulting `StreamId` would not be valid.
    #[inline]
    pub fn nth(initiator: endpoint::Type, stream_type: StreamType, n: u64) -> Option<StreamId> {
        let initial = Self::initial(initiator, stream_type);
        // We calculate as much as possible with u64, to reduce the number of
        // overflow checks for the maximum Stream ID to the last operation
        let id = VarInt::new(n.checked_mul(4)?.checked_add(initial.into())?).ok()?;
        Some(StreamId(id))
    }

    /// Returns the next [`StreamId`] which is of the same type the one referred
    /// to. E.g. if the method is called on a Stream ID for an unidirectional
    /// client initiated stream, the Stream ID of the next unidirectional client
    /// initiated stream will be returned.
    ///
    /// Returns `None` if the next Stream ID would not be valid, due to being out
    /// of bounds.
    ///
    /// Example:
    ///
    /// ```
    /// # use s2n_quic_core::{endpoint, stream::{StreamId, StreamType}};
    /// let stream_id = StreamId::initial(endpoint::Type::Client, StreamType::Unidirectional);
    /// // Initial client initiated unidirectional Stream ID is 2
    /// assert_eq!(2u64, stream_id.as_varint().as_u64());
    /// // Get the next client initiated Stream ID
    /// let next_stream_id = stream_id.next_of_type();
    /// assert_eq!(6u64, next_stream_id.expect("Next Stream ID is valid").as_varint().as_u64());
    /// ```
    #[inline]
    pub fn next_of_type(self) -> Option<StreamId> {
        // Stream IDs increase in steps of 4, since the 2 least significant bytes
        // are used to indicate the stream type
        self.0
            .checked_add(VarInt::from_u32(4))
            .map(StreamId::from_varint)
    }

    /// Returns whether the client or server initiated the Stream
    #[inline]
    pub fn initiator(self) -> endpoint::Type {
        //= https://www.rfc-editor.org/rfc/rfc9000#section-2.1
        //# The least significant bit (0x1) of the stream ID identifies the
        //# initiator of the stream.  Client-initiated streams have even-numbered
        //# stream IDs (with the bit set to 0)
        if Into::<u64>::into(self.0) & 0x01u64 == 0 {
            endpoint::Type::Client
        } else {
            endpoint::Type::Server
        }
    }

    /// Returns whether the Stream is unidirectional or bidirectional.
    #[inline]
    pub fn stream_type(self) -> StreamType {
        //= https://www.rfc-editor.org/rfc/rfc9000#section-2.1
        //# The second least significant bit (0x2) of the stream ID distinguishes
        //# between bidirectional streams (with the bit set to 0) and
        //# unidirectional streams (with the bit set to 1).
        if Into::<u64>::into(self.0) & 0x02 == 0 {
            StreamType::Bidirectional
        } else {
            StreamType::Unidirectional
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::varint::MAX_VARINT_VALUE;

    /// Stream directions exercised by every test in this module.
    const STREAM_TYPES: [StreamType; 2] = [StreamType::Bidirectional, StreamType::Unidirectional];

    /// Stream initiators exercised by every test in this module.
    const INITIATORS: [endpoint::Type; 2] = [endpoint::Type::Client, endpoint::Type::Server];

    /// The initial ID of each stream kind must report that same kind back.
    #[test]
    fn initial_stream_ids() {
        for &stream_type in &STREAM_TYPES {
            for &initiator in &INITIATORS {
                let id = StreamId::initial(initiator, stream_type);
                assert_eq!(stream_type, id.stream_type());
                assert_eq!(initiator, id.initiator());
            }
        }
    }

    /// `next_of_type` must succeed up to the largest ID and fail beyond it.
    #[test]
    fn stream_id_overflow() {
        // The largest possible Stream ID has to be constructible
        let largest_varint = VarInt::new((1 << 62) - 1).unwrap();
        let _largest_id = StreamId::from_varint(largest_varint);

        // Four below the maximum there is room for exactly one more step
        let last_incrementable_varint = largest_varint - 4;
        let last_incrementable_id = StreamId::from_varint(last_incrementable_varint);
        assert!(last_incrementable_id.next_of_type().is_some());

        // Any base ID above that is itself valid, but its successor is not
        for increment in 1..5 {
            let base_varint = last_incrementable_varint + increment;
            let base_id = StreamId::from_varint(base_varint);
            assert!(base_id.next_of_type().is_none());
        }
    }

    /// `nth` starts at the initial ID and advances in steps of 4.
    #[test]
    fn nth_stream_id() {
        for &stream_type in &STREAM_TYPES {
            for &initiator in &INITIATORS {
                // Index 0 is the initial Stream ID
                let first = StreamId::nth(initiator, stream_type, 0).unwrap();
                assert_eq!(StreamId::initial(initiator, stream_type), first);

                for n in 1..10 {
                    let nth = StreamId::nth(initiator, stream_type, n).unwrap();
                    let expected_distance = VarInt::from_u32(n as u32 * 4);
                    assert_eq!(expected_distance, nth.0 - first.0);
                }
            }
        }
    }

    /// `nth` must return `None` for an index whose ID would be out of range.
    #[test]
    fn invalid_nth_stream_id() {
        // NOTE(review): presumably MAX_VARINT_VALUE is the largest valid VarInt,
        // so multiplying half of it by 4 leaves the valid range - confirm.
        let out_of_range_index: u64 = (MAX_VARINT_VALUE / 2).into();

        for &stream_type in &STREAM_TYPES {
            for &initiator in &INITIATORS {
                assert_eq!(
                    None,
                    StreamId::nth(initiator, stream_type, out_of_range_index)
                );
            }
        }
    }
}