use crate::varint::VarInt;
use std::fmt;
use std::str::FromStr;
/// A QUIC stream ID, stored as a [`VarInt`].
///
/// Ordering, equality and hashing follow the underlying numeric value.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StreamId(VarInt);
impl StreamId {
pub const MAX: StreamId = StreamId(VarInt::MAX);
#[inline(always)]
pub const fn new(varint: VarInt) -> Self {
Self(varint)
}
#[inline(always)]
pub const fn is_bidirectional(self) -> bool {
self.0.into_inner() & 0x2 == 0
}
#[inline(always)]
pub const fn is_client_initiated(self) -> bool {
self.0.into_inner() & 0x1 == 0
}
#[inline(always)]
pub const fn is_local(self, is_server: bool) -> bool {
(self.0.into_inner() & 0x1) == (is_server as u64)
}
#[inline(always)]
pub const fn into_u64(self) -> u64 {
self.0.into_inner()
}
#[inline(always)]
pub const fn into_varint(self) -> VarInt {
self.0
}
}
/// Unwraps a stream ID into its [`VarInt`] value.
impl From<StreamId> for VarInt {
    #[inline(always)]
    fn from(stream_id: StreamId) -> Self {
        stream_id.into_varint()
    }
}
impl fmt::Debug for StreamId {
    /// Formats as the inner [`VarInt`]'s `Debug` output, forwarding all flags in `f`.
    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}
impl fmt::Display for StreamId {
    /// Formats as the inner [`VarInt`]'s `Display` output, forwarding all flags in `f`.
    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
#[derive(Debug, thiserror::Error)]
#[error("invalid session ID")]
pub struct InvalidSessionId;
/// A session ID, represented by the stream ID of its session stream.
///
/// The checked constructors only accept bidirectional, client-initiated stream IDs.
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct SessionId(StreamId);
impl SessionId {
#[inline(always)]
pub const fn into_u64(self) -> u64 {
self.0.into_u64()
}
#[inline(always)]
pub const fn into_varint(self) -> VarInt {
self.0.into_varint()
}
#[inline(always)]
pub const fn session_stream(self) -> StreamId {
self.0
}
pub fn try_from_session_stream(stream_id: StreamId) -> Result<Self, InvalidSessionId> {
if stream_id.is_bidirectional() && stream_id.is_client_initiated() {
Ok(Self(stream_id))
} else {
Err(InvalidSessionId)
}
}
#[inline(always)]
pub const unsafe fn from_session_stream_unchecked(stream_id: StreamId) -> Self {
debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
Self(stream_id)
}
#[inline(always)]
pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidSessionId> {
Self::try_from_session_stream(StreamId::new(varint))
}
#[cfg(test)]
pub(crate) fn maybe_invalid(varint: VarInt) -> Self {
Self(StreamId::new(varint))
}
}
impl fmt::Debug for SessionId {
    /// Formats as the session stream ID's `Debug` output, forwarding all flags in `f`.
    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}
impl fmt::Display for SessionId {
    /// Formats as the session stream ID's `Display` output, forwarding all flags in `f`.
    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
#[derive(Debug, thiserror::Error)]
#[error("invalid QStream ID")]
pub struct InvalidQStreamId;
/// A quarter stream ID: a session stream ID divided by four (presumably the HTTP/3
/// datagram "Quarter Stream ID" of RFC 9297).
#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct QStreamId(VarInt);
impl QStreamId {
pub const MAX: QStreamId =
unsafe { Self(VarInt::from_u64_unchecked(1_152_921_504_606_846_975)) };
#[inline(always)]
pub const fn from_session_id(session_id: SessionId) -> Self {
let value = session_id.into_u64() >> 2;
debug_assert!(value <= Self::MAX.into_u64());
let varint = unsafe { VarInt::from_u64_unchecked(value) };
Self(varint)
}
#[inline(always)]
pub const fn into_stream_id(self) -> StreamId {
let varint = unsafe {
debug_assert!(self.0.into_inner() << 2 <= VarInt::MAX.into_inner());
VarInt::from_u64_unchecked(self.0.into_inner() << 2)
};
StreamId::new(varint)
}
#[inline(always)]
pub const fn into_session_id(self) -> SessionId {
let stream_id = self.into_stream_id();
unsafe {
debug_assert!(stream_id.is_bidirectional() && stream_id.is_client_initiated());
SessionId::from_session_stream_unchecked(stream_id)
}
}
#[inline(always)]
pub const fn into_u64(self) -> u64 {
self.0.into_inner()
}
#[inline(always)]
pub const fn into_varint(self) -> VarInt {
self.0
}
pub(crate) fn try_from_varint(varint: VarInt) -> Result<Self, InvalidQStreamId> {
if varint <= Self::MAX.into_varint() {
Ok(Self(varint))
} else {
Err(InvalidQStreamId)
}
}
#[cfg(test)]
pub(crate) fn maybe_invalid(varint: VarInt) -> QStreamId {
Self(varint)
}
}
impl fmt::Debug for QStreamId {
    /// Formats as the inner [`VarInt`]'s `Debug` output, forwarding all flags in `f`.
    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}
impl fmt::Display for QStreamId {
    /// Formats as the inner [`VarInt`]'s `Display` output, forwarding all flags in `f`.
    #[inline(always)]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
#[derive(Debug, thiserror::Error)]
#[error("invalid HTTP status code")]
pub struct InvalidStatusCode;
/// An HTTP status code.
///
/// The `TryFrom` conversions only accept values in
/// [`StatusCode::MIN`]`..=`[`StatusCode::MAX`].
///
/// NOTE(review): the derived `Default` yields `StatusCode(0)`, which lies outside that
/// range. Confirm whether any caller relies on it before tightening.
#[derive(Clone, Copy, Default, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct StatusCode(u16);
impl StatusCode {
pub const MAX: Self = Self(599);
pub const MIN: Self = Self(100);
pub const OK: Self = Self(200);
pub const FORBIDDEN: Self = Self(403);
pub const NOT_FOUND: Self = Self(404);
pub const TOO_MANY_REQUESTS: Self = Self(429);
#[inline(always)]
pub fn try_from_u32(value: u32) -> Result<Self, InvalidStatusCode> {
value.try_into()
}
#[inline(always)]
pub fn into_inner(self) -> u16 {
self.0
}
#[inline(always)]
pub fn is_successful(self) -> bool {
(200..300).contains(&self.0)
}
}
/// Widens to `u16`, then accepts only values in `MIN..=MAX`.
impl TryFrom<u8> for StatusCode {
    type Error = InvalidStatusCode;

    fn try_from(value: u8) -> Result<Self, Self::Error> {
        let code = u16::from(value);
        if code < Self::MIN.0 || code > Self::MAX.0 {
            return Err(InvalidStatusCode);
        }
        Ok(Self(code))
    }
}
/// Accepts only values in `MIN..=MAX`.
impl TryFrom<u16> for StatusCode {
    type Error = InvalidStatusCode;

    fn try_from(value: u16) -> Result<Self, Self::Error> {
        Some(value)
            .filter(|code| (Self::MIN.0..=Self::MAX.0).contains(code))
            .map(Self)
            .ok_or(InvalidStatusCode)
    }
}
/// Accepts only values in `MIN..=MAX`.
impl TryFrom<u32> for StatusCode {
    type Error = InvalidStatusCode;

    fn try_from(value: u32) -> Result<Self, Self::Error> {
        // Anything too wide for a u16 is certainly above MAX. Narrowing checks this
        // instead of a truncating `as` cast.
        let narrow = u16::try_from(value).map_err(|_| InvalidStatusCode)?;
        if (Self::MIN.0..=Self::MAX.0).contains(&narrow) {
            Ok(Self(narrow))
        } else {
            Err(InvalidStatusCode)
        }
    }
}
/// Accepts only values in `MIN..=MAX`.
impl TryFrom<u64> for StatusCode {
    type Error = InvalidStatusCode;

    fn try_from(value: u64) -> Result<Self, Self::Error> {
        // Anything too wide for a u16 is certainly above MAX. Narrowing checks this
        // instead of a truncating `as` cast.
        let narrow = u16::try_from(value).map_err(|_| InvalidStatusCode)?;
        if (Self::MIN.0..=Self::MAX.0).contains(&narrow) {
            Ok(Self(narrow))
        } else {
            Err(InvalidStatusCode)
        }
    }
}
impl FromStr for StatusCode {
type Err = InvalidStatusCode;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Ok(Self(s.parse().map_err(|_| InvalidStatusCode)?))
}
}
impl fmt::Debug for StatusCode {
    /// Formats as the bare `u16` (`Debug`), forwarding all flags in `f`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Debug::fmt(&self.0, f)
    }
}
impl fmt::Display for StatusCode {
    /// Formats as the bare `u16` (`Display`), forwarding all flags in `f`.
    #[inline]
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use utils::stream_types;
    use utils::StreamType;

    /// Every stream type reports the expected direction, initiator and locality.
    #[test]
    fn stream_properties() {
        for (id, stream_type) in stream_types(1024) {
            let stream_id = StreamId::new(id);
            match stream_type {
                StreamType::ClientBi => {
                    assert!(stream_id.is_bidirectional());
                    assert!(stream_id.is_client_initiated());
                    assert!(stream_id.is_local(false));
                    assert!(!stream_id.is_local(true));
                }
                StreamType::ServerBi => {
                    assert!(stream_id.is_bidirectional());
                    assert!(!stream_id.is_client_initiated());
                    assert!(!stream_id.is_local(false));
                    assert!(stream_id.is_local(true));
                }
                StreamType::ClientUni => {
                    assert!(!stream_id.is_bidirectional());
                    assert!(stream_id.is_client_initiated());
                    assert!(stream_id.is_local(false));
                    assert!(!stream_id.is_local(true));
                }
                StreamType::ServerUni => {
                    assert!(!stream_id.is_bidirectional());
                    assert!(!stream_id.is_client_initiated());
                    assert!(!stream_id.is_local(false));
                    assert!(stream_id.is_local(true));
                }
            }
        }
    }

    /// Only client-initiated bidirectional streams are valid session IDs.
    #[test]
    fn session_id() {
        for (id, stream_type) in stream_types(1024) {
            if let StreamType::ClientBi = stream_type {
                assert!(SessionId::try_from_varint(id).is_ok());
                assert!(SessionId::try_from_session_stream(StreamId::new(id)).is_ok());
            } else {
                assert!(SessionId::try_from_varint(id).is_err());
                assert!(SessionId::try_from_session_stream(StreamId::new(id)).is_err());
            }
        }
    }

    /// The n-th client bidirectional stream has quarter stream ID n, and the mapping
    /// round-trips.
    #[test]
    fn qstream_id() {
        for (quarter, id) in stream_types(1024)
            .filter(|(_id, r#type)| matches!(r#type, StreamType::ClientBi))
            .map(|(id, _type)| id)
            .enumerate()
        {
            let session_id = SessionId::try_from_varint(id).unwrap();
            let qstream_id = QStreamId::from_session_id(session_id);
            assert_eq!(qstream_id.into_stream_id(), session_id.session_stream());
            assert_eq!(qstream_id.into_session_id(), session_id);
            assert_eq!(qstream_id.into_u64(), quarter as u64);
        }
    }

    /// StatusCode integer conversions accept exactly `MIN..=MAX`, textual parsing works
    /// for valid codes, and `is_successful` is exactly the 2xx class.
    #[test]
    fn status_code_conversions() {
        assert!(StatusCode::try_from(99u8).is_err());
        assert_eq!(StatusCode::try_from(100u8).unwrap(), StatusCode::MIN);
        assert!(StatusCode::try_from(255u8).is_ok());

        assert!(StatusCode::try_from(99u16).is_err());
        assert_eq!(StatusCode::try_from(100u16).unwrap(), StatusCode::MIN);
        assert_eq!(StatusCode::try_from(599u16).unwrap(), StatusCode::MAX);
        assert!(StatusCode::try_from(600u16).is_err());

        assert!(StatusCode::try_from(70_000u32).is_err());
        assert_eq!(StatusCode::try_from_u32(200).unwrap(), StatusCode::OK);
        assert!(StatusCode::try_from_u32(600).is_err());

        assert!(StatusCode::try_from(u64::MAX).is_err());
        assert_eq!(StatusCode::try_from(404u64).unwrap(), StatusCode::NOT_FOUND);

        assert_eq!(
            "429".parse::<StatusCode>().unwrap(),
            StatusCode::TOO_MANY_REQUESTS
        );
        assert!("abc".parse::<StatusCode>().is_err());

        assert!(!StatusCode::try_from(199u16).unwrap().is_successful());
        assert!(StatusCode::OK.is_successful());
        assert!(StatusCode::try_from(299u16).unwrap().is_successful());
        assert!(!StatusCode::try_from(300u16).unwrap().is_successful());
        assert_eq!(StatusCode::FORBIDDEN.into_inner(), 403);
    }

    mod utils {
        use super::*;

        #[derive(Copy, Clone, Debug)]
        pub enum StreamType {
            ClientBi,
            ServerBi,
            ClientUni,
            ServerUni,
        }

        /// Yields IDs `0..max_id` paired with their stream type. The type cycles
        /// ClientBi, ServerBi, ClientUni, ServerUni, matching the two low bits of the ID.
        pub fn stream_types(max_id: u32) -> impl Iterator<Item = (VarInt, StreamType)> {
            [
                StreamType::ClientBi,
                StreamType::ServerBi,
                StreamType::ClientUni,
                StreamType::ServerUni,
            ]
            .into_iter()
            .cycle()
            .enumerate()
            .map(|(index, r#type)| (VarInt::from_u32(index as u32), r#type))
            .take(max_id as usize)
        }
    }
}