// salvo_quinn/proto/stream.rs

1use bytes::{Buf, BufMut};
2use std::{
3    convert::TryFrom,
4    fmt::{self, Display},
5    ops::Add,
6};
7
8use super::{
9    coding::{BufExt, BufMutExt, Decode, Encode, UnexpectedEnd},
10    varint::VarInt,
11};
12
/// Numeric type tag of a stream, written on the wire as a variable-length
/// integer.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StreamType(u64);
15
/// Declares one associated `StreamType` constant for every `NAME = value,`
/// entry. Each entry, the last one included, must end with a comma.
macro_rules! stream_types {
    {$($konst:ident = $code:expr,)*} => {
        impl StreamType {
            $(
                pub const $konst: StreamType = StreamType($code);
            )*
        }
    };
}
23
24stream_types! {
25    CONTROL = 0x00,
26    PUSH = 0x01,
27    ENCODER = 0x02,
28    DECODER = 0x03,
29}
30
31impl StreamType {
32    pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE;
33
34    pub fn value(&self) -> u64 {
35        self.0
36    }
37    /// returns a StreamType type with random number of the 0x1f * N + 0x21
38    /// format within the range of the Varint implementation
39    pub fn grease() -> Self {
40        StreamType(fastrand::u64(0..0x210842108421083) * 0x1f + 0x21)
41    }
42}
43
44impl Decode for StreamType {
45    fn decode<B: Buf>(buf: &mut B) -> Result<Self, UnexpectedEnd> {
46        Ok(StreamType(buf.get_var()?))
47    }
48}
49
50impl Encode for StreamType {
51    fn encode<W: BufMut>(&self, buf: &mut W) {
52        buf.write_var(self.0);
53    }
54}
55
56impl fmt::Display for StreamType {
57    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
58        match self {
59            &StreamType::CONTROL => write!(f, "Control"),
60            &StreamType::ENCODER => write!(f, "Encoder"),
61            &StreamType::DECODER => write!(f, "Decoder"),
62            x => write!(f, "StreamType({})", x.0),
63        }
64    }
65}
66
/// Identifier for a stream.
///
/// Bit 0 records which side opened the stream, bit 1 its directionality, and
/// the remaining bits its index among streams of the same kind.
#[derive(Debug, Copy, Clone, Eq, PartialEq, Ord, PartialOrd, Hash)]
pub struct StreamId(
    // In test builds the raw value is crate-visible so tests can build and
    // inspect arbitrary ids directly.
    #[cfg(not(test))] u64,
    #[cfg(test)] pub(crate) u64,
);
70
71impl fmt::Display for StreamId {
72    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
73        let initiator = match self.initiator() {
74            Side::Client => "client",
75            Side::Server => "server",
76        };
77        let dir = match self.dir() {
78            Dir::Uni => "uni",
79            Dir::Bi => "bi",
80        };
81        write!(
82            f,
83            "{} {}directional stream {}",
84            initiator,
85            dir,
86            self.index()
87        )
88    }
89}
90
91impl StreamId {
92    pub(crate) fn first_request() -> Self {
93        Self::new(0, Dir::Bi, Side::Client)
94    }
95
96    /// Is this a client-initiated request?
97    pub fn is_request(&self) -> bool {
98        self.dir() == Dir::Bi && self.initiator() == Side::Client
99    }
100
101    /// Is this a server push?
102    pub fn is_push(&self) -> bool {
103        self.dir() == Dir::Uni && self.initiator() == Side::Server
104    }
105
106    /// Which side of a connection initiated the stream
107    pub(crate) fn initiator(self) -> Side {
108        if self.0 & 0x1 == 0 {
109            Side::Client
110        } else {
111            Side::Server
112        }
113    }
114
115    /// Create a new StreamId
116    fn new(index: u64, dir: Dir, initiator: Side) -> Self {
117        StreamId((index as u64) << 2 | (dir as u64) << 1 | initiator as u64)
118    }
119
120    /// Distinguishes streams of the same initiator and directionality
121    fn index(self) -> u64 {
122        self.0 >> 2
123    }
124
125    /// Which directions data flows in
126    fn dir(self) -> Dir {
127        if self.0 & 0x2 == 0 {
128            Dir::Bi
129        } else {
130            Dir::Uni
131        }
132    }
133}
134
135impl TryFrom<u64> for StreamId {
136    type Error = InvalidStreamId;
137    fn try_from(v: u64) -> Result<Self, Self::Error> {
138        if v > VarInt::MAX.0 {
139            return Err(InvalidStreamId(v));
140        }
141        Ok(Self(v))
142    }
143}
144
/// Invalid StreamId, for example because it's too large.
///
/// Carries the rejected raw value. Being a plain `u64` wrapper, it is `Copy`
/// and has full (`Eq`) equality.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidStreamId(u64);
148
149impl Display for InvalidStreamId {
150    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
151        write!(f, "invalid stream id: {:x}", self.0)
152    }
153}
154
155impl Encode for StreamId {
156    fn encode<B: bytes::BufMut>(&self, buf: &mut B) {
157        VarInt::from_u64(self.0).unwrap().encode(buf);
158    }
159}
160
161impl Add<usize> for StreamId {
162    type Output = StreamId;
163
164    #[allow(clippy::suspicious_arithmetic_impl)]
165    fn add(self, rhs: usize) -> Self::Output {
166        let index = u64::min(
167            u64::saturating_add(self.index(), rhs as u64),
168            VarInt::MAX.0 >> 2,
169        );
170        Self::new(index, self.dir(), self.initiator())
171    }
172}
173
/// Which endpoint of a connection a stream originates from; the discriminant
/// is the value stored in bit 0 of a stream id.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum Side {
    /// The endpoint that initiates the connection
    Client = 0,
    /// The endpoint that accepts the connection
    Server = 1,
}
181
/// Directionality of a stream: data in both directions, or only from the
/// side that opened it. The discriminant is the value stored in bit 1 of a
/// stream id.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
enum Dir {
    /// Both endpoints may send data
    Bi = 0,
    /// Only the stream's initiator sends data
    Uni = 1,
}