webtransport_proto/
connect.rs

1use std::str::FromStr;
2
3use bytes::{Buf, BufMut};
4use url::Url;
5
6use super::{qpack, Frame, VarInt};
7
8use thiserror::Error;
9
10// Errors that can occur during the connect request.
11#[derive(Error, Debug, Clone)]
12pub enum ConnectError {
13    #[error("unexpected end of input")]
14    UnexpectedEnd,
15
16    #[error("qpack error")]
17    QpackError(#[from] qpack::DecodeError),
18
19    #[error("unexpected frame {0:?}")]
20    UnexpectedFrame(Frame),
21
22    #[error("invalid method")]
23    InvalidMethod,
24
25    #[error("invalid url")]
26    InvalidUrl(#[from] url::ParseError),
27
28    #[error("invalid status")]
29    InvalidStatus,
30
31    #[error("expected 200, got: {0:?}")]
32    WrongStatus(Option<http::StatusCode>),
33
34    #[error("expected connect, got: {0:?}")]
35    WrongMethod(Option<http::method::Method>),
36
37    #[error("expected https, got: {0:?}")]
38    WrongScheme(Option<String>),
39
40    #[error("expected authority header")]
41    WrongAuthority,
42
43    #[error("expected webtransport, got: {0:?}")]
44    WrongProtocol(Option<String>),
45
46    #[error("expected path header")]
47    WrongPath,
48
49    #[error("non-200 status: {0:?}")]
50    ErrorStatus(http::StatusCode),
51}
52
53#[derive(Debug)]
54pub struct ConnectRequest {
55    pub url: Url,
56}
57
58impl ConnectRequest {
59    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
60        let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
61        if typ != Frame::HEADERS {
62            return Err(ConnectError::UnexpectedFrame(typ));
63        }
64
65        // We no longer return UnexpectedEnd because we know the buffer should be large enough.
66
67        let headers = qpack::Headers::decode(&mut data)?;
68
69        let scheme = match headers.get(":scheme") {
70            Some("https") => "https",
71            Some(scheme) => Err(ConnectError::WrongScheme(Some(scheme.to_string())))?,
72            None => return Err(ConnectError::WrongScheme(None)),
73        };
74
75        let authority = headers
76            .get(":authority")
77            .ok_or(ConnectError::WrongAuthority)?;
78
79        let path = headers.get(":path").ok_or(ConnectError::WrongPath)?;
80
81        let method = headers.get(":method");
82        match method
83            .map(|method| method.try_into().map_err(|_| ConnectError::InvalidMethod))
84            .transpose()?
85        {
86            Some(http::Method::CONNECT) => (),
87            o => return Err(ConnectError::WrongMethod(o)),
88        };
89
90        let protocol = headers.get(":protocol");
91        if protocol != Some("webtransport") {
92            return Err(ConnectError::WrongProtocol(protocol.map(|s| s.to_string())));
93        }
94
95        let url = Url::parse(&format!("{}://{}{}", scheme, authority, path))?;
96
97        Ok(Self { url })
98    }
99
100    pub fn encode<B: BufMut>(&self, buf: &mut B) {
101        let mut headers = qpack::Headers::default();
102        headers.set(":method", "CONNECT");
103        headers.set(":scheme", self.url.scheme());
104        headers.set(":authority", self.url.authority());
105        headers.set(":path", self.url.path());
106        headers.set(":protocol", "webtransport");
107
108        // Use a temporary buffer so we can compute the size.
109        let mut tmp = Vec::new();
110        headers.encode(&mut tmp);
111        let size = VarInt::from_u32(tmp.len() as u32);
112
113        Frame::HEADERS.encode(buf);
114        size.encode(buf);
115        buf.put_slice(&tmp);
116    }
117}
118
119#[derive(Debug)]
120pub struct ConnectResponse {
121    pub status: http::status::StatusCode,
122}
123
124impl ConnectResponse {
125    pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
126        let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
127        if typ != Frame::HEADERS {
128            return Err(ConnectError::UnexpectedFrame(typ));
129        }
130
131        let headers = qpack::Headers::decode(&mut data)?;
132
133        let status = match headers
134            .get(":status")
135            .map(|status| {
136                http::StatusCode::from_str(status).map_err(|_| ConnectError::InvalidStatus)
137            })
138            .transpose()?
139        {
140            Some(status) if status.is_success() => status,
141            o => return Err(ConnectError::WrongStatus(o)),
142        };
143
144        Ok(Self { status })
145    }
146
147    pub fn encode<B: BufMut>(&self, buf: &mut B) {
148        let mut headers = qpack::Headers::default();
149        headers.set(":status", self.status.as_str());
150        headers.set("sec-webtransport-http3-draft", "draft02");
151
152        // Use a temporary buffer so we can compute the size.
153        let mut tmp = Vec::new();
154        headers.encode(&mut tmp);
155        let size = VarInt::from_u32(tmp.len() as u32);
156
157        Frame::HEADERS.encode(buf);
158        size.encode(buf);
159        buf.put_slice(&tmp);
160    }
161}