webtransport_proto/
connect.rs1use std::str::FromStr;
2
3use bytes::{Buf, BufMut};
4use url::Url;
5
6use super::{qpack, Frame, VarInt};
7
8use thiserror::Error;
9
10#[derive(Error, Debug, Clone)]
12pub enum ConnectError {
13 #[error("unexpected end of input")]
14 UnexpectedEnd,
15
16 #[error("qpack error")]
17 QpackError(#[from] qpack::DecodeError),
18
19 #[error("unexpected frame {0:?}")]
20 UnexpectedFrame(Frame),
21
22 #[error("invalid method")]
23 InvalidMethod,
24
25 #[error("invalid url")]
26 InvalidUrl(#[from] url::ParseError),
27
28 #[error("invalid status")]
29 InvalidStatus,
30
31 #[error("expected 200, got: {0:?}")]
32 WrongStatus(Option<http::StatusCode>),
33
34 #[error("expected connect, got: {0:?}")]
35 WrongMethod(Option<http::method::Method>),
36
37 #[error("expected https, got: {0:?}")]
38 WrongScheme(Option<String>),
39
40 #[error("expected authority header")]
41 WrongAuthority,
42
43 #[error("expected webtransport, got: {0:?}")]
44 WrongProtocol(Option<String>),
45
46 #[error("expected path header")]
47 WrongPath,
48
49 #[error("non-200 status: {0:?}")]
50 ErrorStatus(http::StatusCode),
51}
52
53#[derive(Debug)]
54pub struct ConnectRequest {
55 pub url: Url,
56}
57
58impl ConnectRequest {
59 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
60 let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
61 if typ != Frame::HEADERS {
62 return Err(ConnectError::UnexpectedFrame(typ));
63 }
64
65 let headers = qpack::Headers::decode(&mut data)?;
68
69 let scheme = match headers.get(":scheme") {
70 Some("https") => "https",
71 Some(scheme) => Err(ConnectError::WrongScheme(Some(scheme.to_string())))?,
72 None => return Err(ConnectError::WrongScheme(None)),
73 };
74
75 let authority = headers
76 .get(":authority")
77 .ok_or(ConnectError::WrongAuthority)?;
78
79 let path = headers.get(":path").ok_or(ConnectError::WrongPath)?;
80
81 let method = headers.get(":method");
82 match method
83 .map(|method| method.try_into().map_err(|_| ConnectError::InvalidMethod))
84 .transpose()?
85 {
86 Some(http::Method::CONNECT) => (),
87 o => return Err(ConnectError::WrongMethod(o)),
88 };
89
90 let protocol = headers.get(":protocol");
91 if protocol != Some("webtransport") {
92 return Err(ConnectError::WrongProtocol(protocol.map(|s| s.to_string())));
93 }
94
95 let url = Url::parse(&format!("{}://{}{}", scheme, authority, path))?;
96
97 Ok(Self { url })
98 }
99
100 pub fn encode<B: BufMut>(&self, buf: &mut B) {
101 let mut headers = qpack::Headers::default();
102 headers.set(":method", "CONNECT");
103 headers.set(":scheme", self.url.scheme());
104 headers.set(":authority", self.url.authority());
105 headers.set(":path", self.url.path());
106 headers.set(":protocol", "webtransport");
107
108 let mut tmp = Vec::new();
110 headers.encode(&mut tmp);
111 let size = VarInt::from_u32(tmp.len() as u32);
112
113 Frame::HEADERS.encode(buf);
114 size.encode(buf);
115 buf.put_slice(&tmp);
116 }
117}
118
119#[derive(Debug)]
120pub struct ConnectResponse {
121 pub status: http::status::StatusCode,
122}
123
124impl ConnectResponse {
125 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
126 let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
127 if typ != Frame::HEADERS {
128 return Err(ConnectError::UnexpectedFrame(typ));
129 }
130
131 let headers = qpack::Headers::decode(&mut data)?;
132
133 let status = match headers
134 .get(":status")
135 .map(|status| {
136 http::StatusCode::from_str(status).map_err(|_| ConnectError::InvalidStatus)
137 })
138 .transpose()?
139 {
140 Some(status) if status.is_success() => status,
141 o => return Err(ConnectError::WrongStatus(o)),
142 };
143
144 Ok(Self { status })
145 }
146
147 pub fn encode<B: BufMut>(&self, buf: &mut B) {
148 let mut headers = qpack::Headers::default();
149 headers.set(":status", self.status.as_str());
150 headers.set("sec-webtransport-http3-draft", "draft02");
151
152 let mut tmp = Vec::new();
154 headers.encode(&mut tmp);
155 let size = VarInt::from_u32(tmp.len() as u32);
156
157 Frame::HEADERS.encode(buf);
158 size.encode(buf);
159 buf.put_slice(&tmp);
160 }
161}