sec_http3/proto/
stream.rs1use bytes::{Buf, BufMut};
2use std::{
3 convert::TryFrom,
4 fmt::{self, Display},
5 ops::Add,
6};
7
8use crate::webtransport::SessionId;
9
10use super::{
11 coding::{BufExt, BufMutExt, Decode, Encode, UnexpectedEnd},
12 varint::VarInt,
13};
14
/// An HTTP/3 stream type code: the variable-length integer that opens a
/// stream and says what it carries.
///
/// Modelled as an open newtype rather than an enum so that any value — the
/// named constants below, values produced by `grease()`, or whatever a peer
/// sends — can be represented.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct StreamType(u64);

/// Expands a list of `NAME = value,` pairs into `pub const NAME: StreamType`
/// associated constants on [`StreamType`].
macro_rules! stream_types {
    {$($ident:ident = $code:expr,)*} => {
        impl StreamType {
            $(pub const $ident: StreamType = StreamType($code);)*
        }
    }
}

// Stream type codes this crate knows by name (HTTP/3, QPACK and WebTransport).
stream_types! {
    CONTROL = 0x00,
    PUSH = 0x01,
    ENCODER = 0x02,
    DECODER = 0x03,
    WEBTRANSPORT_BIDI = 0x41,
    WEBTRANSPORT_UNI = 0x54,
}
34
35impl StreamType {
36 pub const MAX_ENCODED_SIZE: usize = VarInt::MAX_SIZE;
37
38 pub fn value(&self) -> u64 {
39 self.0
40 }
41 pub fn grease() -> Self {
44 StreamType(fastrand::u64(0..0x210842108421083) * 0x1f + 0x21)
45 }
46}
47
48impl Decode for StreamType {
49 fn decode<B: Buf>(buf: &mut B) -> Result<Self, UnexpectedEnd> {
50 Ok(StreamType(buf.get_var()?))
51 }
52}
53
54impl Encode for StreamType {
55 fn encode<W: BufMut>(&self, buf: &mut W) {
56 buf.write_var(self.0);
57 }
58}
59
60impl fmt::Display for StreamType {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 match self {
63 &StreamType::CONTROL => write!(f, "Control"),
64 &StreamType::ENCODER => write!(f, "Encoder"),
65 &StreamType::DECODER => write!(f, "Decoder"),
66 &StreamType::WEBTRANSPORT_UNI => write!(f, "WebTransportUni"),
67 x => write!(f, "StreamType({})", x.0),
68 }
69 }
70}
71
/// A stream identifier.
///
/// Bit 0 records which side opened the stream, bit 1 whether it is uni- or
/// bidirectional, and the remaining high bits are the stream index. The raw
/// value is crate-visible only in test builds.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct StreamId(#[cfg(not(test))] u64, #[cfg(test)] pub(crate) u64);
75
76impl fmt::Display for StreamId {
77 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
78 let initiator = match self.initiator() {
79 Side::Client => "client",
80 Side::Server => "server",
81 };
82 let dir = match self.dir() {
83 Dir::Uni => "uni",
84 Dir::Bi => "bi",
85 };
86 write!(
87 f,
88 "{} {}directional stream {}",
89 initiator,
90 dir,
91 self.index()
92 )
93 }
94}
95
96impl StreamId {
97 pub(crate) const FIRST_REQUEST: Self = Self::new(0, Dir::Bi, Side::Client);
98
99 pub fn is_request(&self) -> bool {
101 self.dir() == Dir::Bi && self.initiator() == Side::Client
102 }
103
104 pub fn is_push(&self) -> bool {
106 self.dir() == Dir::Uni && self.initiator() == Side::Server
107 }
108
109 pub(crate) fn initiator(self) -> Side {
111 if self.0 & 0x1 == 0 {
112 Side::Client
113 } else {
114 Side::Server
115 }
116 }
117
118 const fn new(index: u64, dir: Dir, initiator: Side) -> Self {
120 StreamId((index) << 2 | (dir as u64) << 1 | initiator as u64)
121 }
122
123 pub fn index(self) -> u64 {
125 self.0 >> 2
126 }
127
128 fn dir(self) -> Dir {
130 if self.0 & 0x2 == 0 {
131 Dir::Bi
132 } else {
133 Dir::Uni
134 }
135 }
136
137 pub(crate) fn into_inner(self) -> u64 {
138 self.0
139 }
140}
141
142impl TryFrom<u64> for StreamId {
143 type Error = InvalidStreamId;
144 fn try_from(v: u64) -> Result<Self, Self::Error> {
145 if v > VarInt::MAX.0 {
146 return Err(InvalidStreamId(v));
147 }
148 Ok(Self(v))
149 }
150}
151
152impl From<VarInt> for StreamId {
153 fn from(v: VarInt) -> Self {
154 Self(v.0)
155 }
156}
157
158impl From<StreamId> for VarInt {
159 fn from(v: StreamId) -> Self {
160 Self(v.0)
161 }
162}
163
/// Error holding a `u64` that was rejected as a stream ID, as returned by
/// `StreamId::try_from` when the value exceeds `VarInt::MAX`.
#[derive(PartialEq, Debug)]
pub struct InvalidStreamId(pub(crate) u64);

impl Display for InvalidStreamId {
    /// Formats the rejected value in lowercase hexadecimal (no `0x` prefix).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let InvalidStreamId(rejected) = self;
        write!(f, "invalid stream id: {:x}", rejected)
    }
}
172}
173
174impl Encode for StreamId {
175 fn encode<B: bytes::BufMut>(&self, buf: &mut B) {
176 VarInt::from_u64(self.0).unwrap().encode(buf);
177 }
178}
179
180impl Add<usize> for StreamId {
181 type Output = StreamId;
182
183 #[allow(clippy::suspicious_arithmetic_impl)]
184 fn add(self, rhs: usize) -> Self::Output {
185 let index = u64::min(
186 u64::saturating_add(self.index(), rhs as u64),
187 VarInt::MAX.0 >> 2,
188 );
189 Self::new(index, self.dir(), self.initiator())
190 }
191}
192
193impl From<SessionId> for StreamId {
194 fn from(value: SessionId) -> Self {
195 Self(value.into_inner())
196 }
197}
198
/// The endpoint that opened a stream. The discriminant is the value stored
/// in bit 0 of a stream ID.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum Side {
    /// Client-opened stream (bit 0 clear).
    Client = 0,
    /// Server-opened stream (bit 0 set).
    Server = 1,
}
206
/// Stream direction. The discriminant is the value stored in bit 1 of a
/// stream ID.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Dir {
    /// Bidirectional stream (bit 1 clear).
    Bi = 0,
    /// Unidirectional stream (bit 1 set).
    Uni = 1,
}