// webtrans_proto/connect.rs

use std::{str::FromStr, sync::Arc};

use bytes::{Buf, BufMut, BytesMut};
use thiserror::Error;
use tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt};
use url::Url;

use super::{Frame, VarInt, qpack};
use crate::io::read_incremental;

14#[derive(Error, Debug, Clone)]
16pub enum ConnectError {
18 #[error("unexpected end of input")]
19 UnexpectedEnd,
21
22 #[error("qpack error")]
23 QpackError(#[from] qpack::DecodeError),
25
26 #[error("unexpected frame {0:?}")]
27 UnexpectedFrame(Frame),
29
30 #[error("invalid method")]
31 InvalidMethod,
33
34 #[error("invalid url")]
35 InvalidUrl(#[from] url::ParseError),
37
38 #[error("invalid status")]
39 InvalidStatus,
41
42 #[error("expected 200, got: {0:?}")]
43 WrongStatus(Option<http::StatusCode>),
45
46 #[error("expected connect, got: {0:?}")]
47 WrongMethod(Option<http::method::Method>),
49
50 #[error("expected https, got: {0:?}")]
51 WrongScheme(Option<String>),
53
54 #[error("expected authority header")]
55 WrongAuthority,
57
58 #[error("expected webtransport, got: {0:?}")]
59 WrongProtocol(Option<String>),
61
62 #[error("expected path header")]
63 WrongPath,
65
66 #[error("non-200 status: {0:?}")]
67 ErrorStatus(http::StatusCode),
69
70 #[error("io error: {0}")]
71 Io(Arc<std::io::Error>),
73}
74
75impl From<std::io::Error> for ConnectError {
76 fn from(err: std::io::Error) -> Self {
77 ConnectError::Io(Arc::new(err))
78 }
79}
80
81#[derive(Debug)]
82pub struct ConnectRequest {
84 pub url: Url,
86}
87
88impl ConnectRequest {
89 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
91 let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
92 if typ != Frame::HEADERS {
93 return Err(ConnectError::UnexpectedFrame(typ));
94 }
95
96 let headers = qpack::Headers::decode(&mut data)?;
99
100 let scheme = match headers.get(":scheme") {
101 Some("https") => "https",
102 Some(scheme) => Err(ConnectError::WrongScheme(Some(scheme.to_string())))?,
103 None => return Err(ConnectError::WrongScheme(None)),
104 };
105
106 let authority = headers
107 .get(":authority")
108 .ok_or(ConnectError::WrongAuthority)?;
109
110 let path_and_query = headers.get(":path").ok_or(ConnectError::WrongPath)?;
111
112 let method = headers.get(":method");
113 match method
114 .map(|method| method.try_into().map_err(|_| ConnectError::InvalidMethod))
115 .transpose()?
116 {
117 Some(http::Method::CONNECT) => (),
118 o => return Err(ConnectError::WrongMethod(o)),
119 };
120
121 let protocol = headers.get(":protocol");
122 if protocol != Some("webtransport") {
123 return Err(ConnectError::WrongProtocol(protocol.map(|s| s.to_string())));
124 }
125
126 let url = Url::parse(&format!("{scheme}://{authority}{path_and_query}"))?;
127
128 Ok(Self { url })
129 }
130
131 pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
133 read_incremental(
134 stream,
135 |cursor| Self::decode(cursor),
136 |err| matches!(err, ConnectError::UnexpectedEnd),
137 ConnectError::UnexpectedEnd,
138 )
139 .await
140 }
141
142 pub fn encode<B: BufMut>(&self, buf: &mut B) {
144 let mut headers = qpack::Headers::default();
145 headers.set(":method", "CONNECT");
146 headers.set(":scheme", self.url.scheme());
147 headers.set(":authority", self.url.authority());
148 let path_and_query = match self.url.query() {
149 Some(query) => format!("{}?{}", self.url.path(), query),
150 None => self.url.path().to_string(),
151 };
152 headers.set(":path", &path_and_query);
153 headers.set(":protocol", "webtransport");
154 encode_headers_frame(buf, &headers);
155 }
156
157 pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
159 let mut buf = BytesMut::new();
160 self.encode(&mut buf);
161 stream.write_all_buf(&mut buf).await?;
162 Ok(())
163 }
164}
165
166#[derive(Debug)]
167pub struct ConnectResponse {
169 pub status: http::status::StatusCode,
171}
172
173impl ConnectResponse {
174 pub fn decode<B: Buf>(buf: &mut B) -> Result<Self, ConnectError> {
176 let (typ, mut data) = Frame::read(buf).map_err(|_| ConnectError::UnexpectedEnd)?;
177 if typ != Frame::HEADERS {
178 return Err(ConnectError::UnexpectedFrame(typ));
179 }
180
181 let headers = qpack::Headers::decode(&mut data)?;
182
183 let status = match headers
184 .get(":status")
185 .map(|status| {
186 http::StatusCode::from_str(status).map_err(|_| ConnectError::InvalidStatus)
187 })
188 .transpose()?
189 {
190 Some(status) if status.is_success() => status,
191 o => return Err(ConnectError::WrongStatus(o)),
192 };
193
194 Ok(Self { status })
195 }
196
197 pub async fn read<S: AsyncRead + Unpin>(stream: &mut S) -> Result<Self, ConnectError> {
199 read_incremental(
200 stream,
201 |cursor| Self::decode(cursor),
202 |err| matches!(err, ConnectError::UnexpectedEnd),
203 ConnectError::UnexpectedEnd,
204 )
205 .await
206 }
207
208 pub fn encode<B: BufMut>(&self, buf: &mut B) {
210 let mut headers = qpack::Headers::default();
211 headers.set(":status", self.status.as_str());
212 headers.set("sec-webtransport-http3-draft", "draft02");
213 encode_headers_frame(buf, &headers);
214 }
215
216 pub async fn write<S: AsyncWrite + Unpin>(&self, stream: &mut S) -> Result<(), ConnectError> {
218 let mut buf = BytesMut::new();
219 self.encode(&mut buf);
220 stream.write_all_buf(&mut buf).await?;
221 Ok(())
222 }
223}
224
225fn encode_headers_frame<B: BufMut>(buf: &mut B, headers: &qpack::Headers) {
226 let mut payload = Vec::new();
228 headers.encode(&mut payload);
229
230 Frame::HEADERS.encode(buf);
231 VarInt::from_u32(payload.len() as u32).encode(buf);
232 buf.put_slice(&payload);
233}