wtransport_proto_lightyear_patch/
ids.rs1use crate::varint::VarInt;
2use std::fmt;
3use std::str::FromStr;
4
5#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
7pub struct StreamId(VarInt);
8
9impl StreamId {
10 pub const MAX: StreamId = StreamId(VarInt::MAX);
12
13 #[inline(always)]
15 pub const fn new(varint: VarInt) -> Self {
16 Self(varint)
17 }
18
19 #[inline(always)]
21 pub const fn is_bidirectional(self) -> bool {
22 self.0.into_inner() & 0x2 == 0
23 }
24
25 #[inline(always)]
27 pub const fn is_client_initiated(self) -> bool {
28 self.0.into_inner() & 0x1 == 0
29 }
30
31 #[inline(always)]
33 pub const fn is_local(self, is_server: bool) -> bool {
34 (self.0.into_inner() & 0x1) == (is_server as u64)
35 }
36
37 #[inline(always)]
39 pub const fn into_u64(self) -> u64 {
40 self.0.into_inner()
41 }
42
43 #[inline(always)]
45 pub const fn into_varint(self) -> VarInt {
46 self.0
47 }
48}
49
50impl From<StreamId> for VarInt {
51 #[inline(always)]
52 fn from(stream_id: StreamId) -> Self {
53 stream_id.0
54 }
55}
56
57impl fmt::Debug for StreamId {
58 #[inline(always)]
59 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
60 self.0.fmt(f)
61 }
62}
63
64impl fmt::Display for StreamId {
65 #[inline(always)]
66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67 self.0.fmt(f)
68 }
69}
70
/// Error returned when a stream identifier cannot be used as a session ID,
/// i.e. it is not a client-initiated bidirectional stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidSessionId;

impl fmt::Display for InvalidSessionId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid session ID: not a client-initiated bidirectional stream")
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` / error enums.
impl std::error::Error for InvalidSessionId {}
74
75#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
80pub struct SessionId(StreamId);
81
82impl SessionId {
83 #[inline(always)]
85 pub const fn into_u64(self) -> u64 {
86 self.0.into_u64()
87 }
88
89 #[inline(always)]
91 pub const fn into_varint(self) -> VarInt {
92 self.0.into_varint()
93 }
94
95 #[inline(always)]
97 pub const fn session_stream(self) -> StreamId {
98 self.0
99 }
100
101 pub fn try_from_session_stream(stream_id: StreamId) -> Result<Self, InvalidSessionId> {
106 if stream_id.is_bidirectional() && stream_id.is_client_initiated() {
107 Ok(Self(stream_id))
108 } else {
109 Err(InvalidSessionId)
110 }
111 }
112
113 #[inline(always)]
119 pub const unsafe fn from_session_stream_unchecked(stream_id: StreamId) -> Self {
120 debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
121 Self(stream_id)
122 }
123
124 #[inline(always)]
125 pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidSessionId> {
126 Self::try_from_session_stream(StreamId::new(varint))
127 }
128
129 #[cfg(test)]
130 pub(crate) fn maybe_invalid(varint: VarInt) -> Self {
131 Self(StreamId::new(varint))
132 }
133}
134
135impl fmt::Debug for SessionId {
136 #[inline(always)]
137 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
138 self.0.fmt(f)
139 }
140}
141
142impl fmt::Display for SessionId {
143 #[inline(always)]
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 self.0.fmt(f)
146 }
147}
148
/// Error returned when a value is too large to be a quarter stream ID
/// (greater than `QStreamId::MAX`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidQStreamId;

impl fmt::Display for InvalidQStreamId {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid quarter stream ID: value exceeds the maximum")
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` / error enums.
impl std::error::Error for InvalidQStreamId {}
152
153#[derive(Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
155pub struct QStreamId(VarInt);
156
157impl QStreamId {
158 pub const MAX: QStreamId =
161 unsafe { Self(VarInt::from_u64_unchecked(1_152_921_504_606_846_975)) };
162
163 #[inline(always)]
165 pub const fn from_session_id(session_id: SessionId) -> Self {
166 let value = session_id.into_u64() >> 2;
167 debug_assert!(value <= Self::MAX.into_u64());
168
169 let varint = unsafe { VarInt::from_u64_unchecked(value) };
171
172 Self(varint)
173 }
174
175 #[inline(always)]
179 pub const fn into_stream_id(self) -> StreamId {
180 let varint = unsafe {
182 debug_assert!(self.0.into_inner() << 2 <= VarInt::MAX.into_inner());
183 VarInt::from_u64_unchecked(self.0.into_inner() << 2)
184 };
185
186 StreamId::new(varint)
187 }
188
189 #[inline(always)]
191 pub const fn into_session_id(self) -> SessionId {
192 let stream_id = self.into_stream_id();
193
194 unsafe {
196 debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
197 SessionId::from_session_stream_unchecked(stream_id)
198 }
199 }
200
201 #[inline(always)]
203 pub const fn into_u64(self) -> u64 {
204 self.0.into_inner()
205 }
206
207 #[inline(always)]
209 pub const fn into_varint(self) -> VarInt {
210 self.0
211 }
212
213 pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidQStreamId> {
214 if varint <= Self::MAX.into_varint() {
215 Ok(Self(varint))
216 } else {
217 Err(InvalidQStreamId)
218 }
219 }
220
221 #[cfg(test)]
222 pub(crate) fn maybe_invalid(varint: VarInt) -> QStreamId {
223 Self(varint)
224 }
225}
226
227impl fmt::Debug for QStreamId {
228 #[inline(always)]
229 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
230 self.0.fmt(f)
231 }
232}
233
234impl fmt::Display for QStreamId {
235 #[inline(always)]
236 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
237 self.0.fmt(f)
238 }
239}
240
/// Error returned when a value is not a valid HTTP status code: either it is
/// not a number, or it lies outside `100..=599`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct InvalidStatusCode;

impl fmt::Display for InvalidStatusCode {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("invalid HTTP status code (expected an integer in 100..=599)")
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>` / error enums.
impl std::error::Error for InvalidStatusCode {}
244
245#[derive(Default, Copy, Clone, Eq, Hash, Ord, PartialEq, PartialOrd)]
247pub struct StatusCode(u16);
248
249impl StatusCode {
250 pub const MAX: Self = Self(599);
252
253 pub const MIN: Self = Self(100);
255
256 pub const OK: Self = Self(200);
258
259 pub const FORBIDDEN: Self = Self(403);
261
262 pub const NOT_FOUND: Self = Self(404);
264
265 #[inline(always)]
267 pub fn try_from_u32(value: u32) -> Result<Self, InvalidStatusCode> {
268 value.try_into()
269 }
270
271 #[inline(always)]
273 pub fn into_inner(self) -> u16 {
274 self.0
275 }
276
277 #[inline(always)]
279 pub fn is_successful(self) -> bool {
280 (200..300).contains(&self.0)
281 }
282}
283
284impl TryFrom<u8> for StatusCode {
285 type Error = InvalidStatusCode;
286
287 fn try_from(value: u8) -> Result<Self, Self::Error> {
288 if u16::from(value) >= Self::MIN.0 && u16::from(value) <= Self::MAX.0 {
289 Ok(Self(u16::from(value)))
290 } else {
291 Err(InvalidStatusCode)
292 }
293 }
294}
295
296impl TryFrom<u16> for StatusCode {
297 type Error = InvalidStatusCode;
298
299 fn try_from(value: u16) -> Result<Self, Self::Error> {
300 if (Self::MIN.0..=Self::MAX.0).contains(&value) {
301 Ok(Self(value))
302 } else {
303 Err(InvalidStatusCode)
304 }
305 }
306}
307
308impl TryFrom<u32> for StatusCode {
309 type Error = InvalidStatusCode;
310
311 fn try_from(value: u32) -> Result<Self, Self::Error> {
312 if value >= u32::from(Self::MIN.0) && value <= u32::from(Self::MAX.0) {
313 Ok(Self(value as u16))
314 } else {
315 Err(InvalidStatusCode)
316 }
317 }
318}
319
320impl TryFrom<u64> for StatusCode {
321 type Error = InvalidStatusCode;
322
323 fn try_from(value: u64) -> Result<Self, Self::Error> {
324 if value >= u64::from(Self::MIN.0) && value <= u64::from(Self::MAX.0) {
325 Ok(Self(value as u16))
326 } else {
327 Err(InvalidStatusCode)
328 }
329 }
330}
331
332impl FromStr for StatusCode {
333 type Err = InvalidStatusCode;
334
335 fn from_str(s: &str) -> Result<Self, Self::Err> {
336 Ok(Self(s.parse().map_err(|_| InvalidStatusCode)?))
337 }
338}
339
340impl fmt::Debug for StatusCode {
341 #[inline]
342 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
343 self.0.fmt(f)
344 }
345}
346
347impl fmt::Display for StatusCode {
348 #[inline]
349 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
350 self.0.fmt(f)
351 }
352}
353
#[cfg(test)]
mod tests {
    use utils::stream_types;
    use utils::StreamType;

    use super::*;

    #[test]
    fn stream_properties() {
        for (id, stream_type) in stream_types(1024) {
            let stream_id = StreamId::new(id);

            // Expected (bidirectional, client-initiated) flags per stream type.
            let (bidi, client) = match stream_type {
                StreamType::ClientBi => (true, true),
                StreamType::ServerBi => (true, false),
                StreamType::ClientUni => (false, true),
                StreamType::ServerUni => (false, false),
            };

            assert_eq!(stream_id.is_bidirectional(), bidi);
            assert_eq!(stream_id.is_client_initiated(), client);
            // A client owns client-initiated streams; a server owns the rest.
            assert_eq!(stream_id.is_local(false), client);
            assert_eq!(stream_id.is_local(true), !client);
        }
    }

    #[test]
    fn session_id() {
        for (id, stream_type) in stream_types(1024) {
            let expect_valid = matches!(stream_type, StreamType::ClientBi);
            let stream_id = StreamId::new(id);

            assert_eq!(SessionId::try_from_varint(id).is_ok(), expect_valid);
            assert_eq!(SessionId::try_from_session_stream(stream_id).is_ok(), expect_valid);
        }
    }

    #[test]
    fn qstream_id() {
        for (quarter, id) in stream_types(1024)
            .filter(|(_id, r#type)| matches!(r#type, StreamType::ClientBi))
            .map(|(id, _type)| id)
            .enumerate()
        {
            let session_id = SessionId::try_from_varint(id).unwrap();
            let qstream_id = QStreamId::from_session_id(session_id);

            assert_eq!(qstream_id.into_stream_id(), session_id.session_stream());
            assert_eq!(qstream_id.into_session_id(), session_id);
            assert_eq!(qstream_id.into_u64(), quarter as u64);
        }
    }

    #[test]
    fn qstream_id_bounds() {
        // MAX itself is accepted and survives a round-trip through its session ID.
        let max = QStreamId::try_from_varint(QStreamId::MAX.into_varint()).unwrap();
        assert_eq!(max, QStreamId::MAX);
        assert_eq!(QStreamId::from_session_id(max.into_session_id()), QStreamId::MAX);

        // Anything above MAX (here the largest varint) is rejected.
        assert!(QStreamId::try_from_varint(VarInt::MAX).is_err());
    }

    #[test]
    fn status_code_bounds() {
        assert!(StatusCode::try_from(99u16).is_err());
        assert_eq!(StatusCode::try_from(100u16).unwrap(), StatusCode::MIN);
        assert_eq!(StatusCode::try_from(599u16).unwrap(), StatusCode::MAX);
        assert!(StatusCode::try_from(600u16).is_err());

        assert!(StatusCode::try_from(99u8).is_err());
        assert_eq!(StatusCode::try_from(200u8).unwrap(), StatusCode::OK);

        assert!(StatusCode::try_from_u32(0).is_err());
        assert_eq!(StatusCode::try_from_u32(404).unwrap(), StatusCode::NOT_FOUND);
        // 65_736 truncates to 200 as a u16; it must still be rejected.
        assert!(StatusCode::try_from(65_736u32).is_err());
        assert!(StatusCode::try_from(65_736u64).is_err());
        assert!(StatusCode::try_from((1u64 << 32) + 200).is_err());

        assert_eq!("403".parse::<StatusCode>().unwrap(), StatusCode::FORBIDDEN);
        assert!("abc".parse::<StatusCode>().is_err());

        assert_eq!(StatusCode::OK.into_inner(), 200);
        assert!(StatusCode::OK.is_successful());
        assert!(!StatusCode::NOT_FOUND.is_successful());
        assert_eq!(StatusCode::NOT_FOUND.to_string(), "404");
    }

    mod utils {
        use super::*;

        #[derive(Copy, Clone, Debug)]
        pub enum StreamType {
            ClientBi,
            ServerBi,
            ClientUni,
            ServerUni,
        }

        /// Yields the first `max_id` stream IDs (0, 1, 2, ...) paired with the
        /// stream type their two low bits encode.
        pub fn stream_types(max_id: u32) -> impl Iterator<Item = (VarInt, StreamType)> {
            [
                StreamType::ClientBi,
                StreamType::ServerBi,
                StreamType::ClientUni,
                StreamType::ServerUni,
            ]
            .into_iter()
            .cycle()
            .enumerate()
            .map(|(index, r#type)| (VarInt::from_u32(index as u32), r#type))
            .take(max_id as usize)
        }
    }
}