wtransport_proto_lightyear_patch/
session.rs1use crate::headers::Headers;
2use crate::ids::InvalidStatusCode;
3use crate::ids::StatusCode;
4use url::Url;
5
/// Error returned by [`SessionRequest::new`] when the given URL cannot be used.
///
/// Every variant except [`UrlParseError::SchemeNotHttps`] and
/// [`UrlParseError::Unknown`] is the counterpart of the `url::ParseError`
/// variant with the same name.
#[derive(Debug)]
pub enum UrlParseError {
    /// Counterpart of `url::ParseError::EmptyHost`.
    EmptyHost,

    /// Counterpart of `url::ParseError::IdnaError`.
    IdnaError,

    /// Counterpart of `url::ParseError::InvalidPort`.
    InvalidPort,

    /// Counterpart of `url::ParseError::InvalidIpv4Address`.
    InvalidIpv4Address,

    /// Counterpart of `url::ParseError::InvalidIpv6Address`.
    InvalidIpv6Address,

    /// Counterpart of `url::ParseError::InvalidDomainCharacter`.
    InvalidDomainCharacter,

    /// Counterpart of `url::ParseError::RelativeUrlWithoutBase`.
    RelativeUrlWithoutBase,

    /// Counterpart of `url::ParseError::RelativeUrlWithCannotBeABaseBase`.
    RelativeUrlWithCannotBeABaseBase,

    /// Counterpart of `url::ParseError::SetHostOnCannotBeABaseUrl`.
    SetHostOnCannotBeABaseUrl,

    /// Counterpart of `url::ParseError::Overflow`.
    Overflow,

    /// Any `url::ParseError` variant that has no dedicated counterpart here.
    Unknown,

    /// The URL parsed correctly but its scheme is not `https`.
    SchemeNotHttps,
}

impl std::fmt::Display for UrlParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::EmptyHost => "empty host",
            Self::IdnaError => "invalid international domain name",
            Self::InvalidPort => "invalid port number",
            Self::InvalidIpv4Address => "invalid IPv4 address",
            Self::InvalidIpv6Address => "invalid IPv6 address",
            Self::InvalidDomainCharacter => "invalid domain character",
            Self::RelativeUrlWithoutBase => "relative URL without a base",
            Self::RelativeUrlWithCannotBeABaseBase => {
                "relative URL with a cannot-be-a-base base"
            }
            Self::SetHostOnCannotBeABaseUrl => {
                "a cannot-be-a-base URL doesn't have a host to set"
            }
            Self::Overflow => "URL length overflow",
            Self::Unknown => "unknown URL parse error",
            Self::SchemeNotHttps => "URL scheme is not 'https'",
        };

        f.write_str(message)
    }
}

// Lets callers propagate this error through `Box<dyn std::error::Error>`.
impl std::error::Error for UrlParseError {}
45
/// Error returned when a header set cannot be interpreted as a session
/// request or a session response.
#[derive(Debug)]
pub enum HeadersParseError {
    /// The `:method` header is absent.
    MissingMethod,

    /// The `:method` header is present but is not `CONNECT`.
    MethodNotConnect,

    /// The `:scheme` header is absent.
    MissingScheme,

    /// The `:scheme` header is present but is not `https`.
    SchemeNotHttps,

    /// The `:protocol` header is absent.
    MissingProtocol,

    /// The `:protocol` header is present but is not `webtransport`.
    ProtocolNotWebTransport,

    /// The `:authority` header is absent.
    MissingAuthority,

    /// The `:path` header is absent.
    MissingPath,

    /// The `:status` header is absent.
    MissingStatusCode,

    /// The `:status` header is present but does not parse as a status code.
    InvalidStatusCode,
}

impl std::fmt::Display for HeadersParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::MissingMethod => "missing ':method' header",
            Self::MethodNotConnect => "':method' is not 'CONNECT'",
            Self::MissingScheme => "missing ':scheme' header",
            Self::SchemeNotHttps => "':scheme' is not 'https'",
            Self::MissingProtocol => "missing ':protocol' header",
            Self::ProtocolNotWebTransport => "':protocol' is not 'webtransport'",
            Self::MissingAuthority => "missing ':authority' header",
            Self::MissingPath => "missing ':path' header",
            Self::MissingStatusCode => "missing ':status' header",
            Self::InvalidStatusCode => "invalid ':status' value",
        };

        f.write_str(message)
    }
}

// Lets callers propagate this error through `Box<dyn std::error::Error>`.
impl std::error::Error for HeadersParseError {}
79
/// Error returned by [`SessionRequest::insert`] when the key is one of
/// [`SessionRequest::RESERVED_HEADERS`].
#[derive(Debug)]
pub struct ReservedHeader;

impl std::fmt::Display for ReservedHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("header name is reserved")
    }
}

// Lets callers propagate this error through `Box<dyn std::error::Error>`.
impl std::error::Error for ReservedHeader {}
87
88#[derive(Debug)]
90pub struct SessionRequest(Headers);
91
92impl SessionRequest {
93 pub const RESERVED_HEADERS: &'static [&'static str] =
105 &[":method", ":scheme", ":protocol", ":authority", ":path"];
106
107 pub fn new<S>(url: S) -> Result<Self, UrlParseError>
109 where
110 S: AsRef<str>,
111 {
112 let url = Url::parse(url.as_ref())?;
113
114 if url.scheme() != "https" {
115 return Err(UrlParseError::SchemeNotHttps);
116 }
117
118 let path = format!(
119 "{}{}",
120 url.path(),
121 url.query().map(|s| format!("?{}", s)).unwrap_or_default()
122 );
123
124 let headers = [
125 (":method", "CONNECT"),
126 (":scheme", "https"),
127 (":protocol", "webtransport"),
128 (":authority", url.authority()),
129 (":path", &path),
130 ]
131 .into_iter()
132 .collect();
133
134 Ok(Self(headers))
135 }
136
137 pub fn authority(&self) -> &str {
139 self.0
140 .get(":authority")
141 .expect("Session request must contain ':authority' field")
142 }
143
144 pub fn path(&self) -> &str {
146 self.0
147 .get(":path")
148 .expect("Session request must contain ':path' field")
149 }
150
151 pub fn origin(&self) -> Option<&str> {
153 self.0.get("origin")
154 }
155
156 pub fn user_agent(&self) -> Option<&str> {
158 self.0.get("user-agent")
159 }
160
161 pub fn get<K>(&self, key: K) -> Option<&str>
163 where
164 K: AsRef<str>,
165 {
166 self.0.get(key)
167 }
168
169 pub fn insert<K, V>(&mut self, key: K, value: V) -> Result<(), ReservedHeader>
179 where
180 K: ToString,
181 V: ToString,
182 {
183 let key = key.to_string();
184
185 if Self::RESERVED_HEADERS.iter().any(|rh| rh == &key) {
186 return Err(ReservedHeader);
187 }
188
189 self.0.insert(key, value);
190 Ok(())
191 }
192
193 pub fn headers(&self) -> &Headers {
195 &self.0
196 }
197}
198
199impl TryFrom<Headers> for SessionRequest {
200 type Error = HeadersParseError;
201
202 fn try_from(headers: Headers) -> Result<Self, Self::Error> {
203 if headers
204 .get(":method")
205 .ok_or(HeadersParseError::MissingMethod)?
206 != "CONNECT"
207 {
208 return Err(HeadersParseError::MethodNotConnect);
209 }
210
211 if headers
212 .get(":scheme")
213 .ok_or(HeadersParseError::MissingScheme)?
214 != "https"
215 {
216 return Err(HeadersParseError::SchemeNotHttps);
217 }
218
219 if headers
220 .get(":protocol")
221 .ok_or(HeadersParseError::MissingProtocol)?
222 != "webtransport"
223 {
224 return Err(HeadersParseError::ProtocolNotWebTransport);
225 }
226
227 headers
228 .get(":authority")
229 .ok_or(HeadersParseError::MissingAuthority)?;
230
231 headers.get(":path").ok_or(HeadersParseError::MissingPath)?;
232
233 Ok(Self(headers))
234 }
235}
236
237impl From<url::ParseError> for UrlParseError {
238 fn from(error: url::ParseError) -> Self {
239 match error {
240 url::ParseError::EmptyHost => UrlParseError::EmptyHost,
241 url::ParseError::IdnaError => UrlParseError::IdnaError,
242 url::ParseError::InvalidPort => UrlParseError::InvalidPort,
243 url::ParseError::InvalidIpv4Address => UrlParseError::InvalidIpv4Address,
244 url::ParseError::InvalidIpv6Address => UrlParseError::InvalidIpv6Address,
245 url::ParseError::InvalidDomainCharacter => UrlParseError::InvalidDomainCharacter,
246 url::ParseError::RelativeUrlWithoutBase => UrlParseError::RelativeUrlWithoutBase,
247 url::ParseError::RelativeUrlWithCannotBeABaseBase => {
248 UrlParseError::RelativeUrlWithCannotBeABaseBase
249 }
250 url::ParseError::SetHostOnCannotBeABaseUrl => UrlParseError::SetHostOnCannotBeABaseUrl,
251 url::ParseError::Overflow => UrlParseError::Overflow,
252 _ => UrlParseError::Unknown,
253 }
254 }
255}
256
257pub struct SessionResponse(Headers);
259
260impl SessionResponse {
261 pub fn with_status_code(status_code: StatusCode) -> Self {
263 let headers = [(":status", status_code.to_string())].into_iter().collect();
264 Self(headers)
265 }
266
267 pub fn ok() -> Self {
269 Self::with_status_code(StatusCode::OK)
270 }
271
272 pub fn forbidden() -> Self {
274 Self::with_status_code(StatusCode::FORBIDDEN)
275 }
276
277 pub fn not_found() -> Self {
279 Self::with_status_code(StatusCode::NOT_FOUND)
280 }
281
282 pub fn code(&self) -> StatusCode {
284 self.0
285 .get(":status")
286 .expect("Status code is always present")
287 .parse()
288 .expect("Status code value must be valid")
289 }
290
291 pub fn add<K, V>(&mut self, key: K, value: V)
295 where
296 K: ToString,
297 V: ToString,
298 {
299 self.0.insert(key, value);
300 }
301
302 pub fn headers(&self) -> &Headers {
304 &self.0
305 }
306}
307
308impl TryFrom<Headers> for SessionResponse {
309 type Error = HeadersParseError;
310
311 fn try_from(headers: Headers) -> Result<Self, Self::Error> {
312 let status_code = headers
313 .get(":status")
314 .ok_or(HeadersParseError::MissingStatusCode)?
315 .parse()
316 .map_err(|InvalidStatusCode| HeadersParseError::InvalidStatusCode)?;
317
318 Ok(Self::with_status_code(status_code))
319 }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete and valid set of request headers.
    const VALID_REQUEST: [(&str, &str); 5] = [
        (":method", "CONNECT"),
        (":scheme", "https"),
        (":protocol", "webtransport"),
        (":authority", "localhost:4433"),
        (":path", "/"),
    ];

    /// `VALID_REQUEST` with the header called `name` removed.
    fn request_without(name: &str) -> Headers {
        VALID_REQUEST
            .into_iter()
            .filter(|(key, _)| *key != name)
            .collect()
    }

    /// `VALID_REQUEST` with the header called `name` set to `value`.
    fn request_with(name: &str, value: &'static str) -> Headers {
        VALID_REQUEST
            .into_iter()
            .map(|(key, original)| {
                if key == name {
                    (key, value)
                } else {
                    (key, original)
                }
            })
            .collect()
    }

    #[test]
    fn parse_url() {
        let request = SessionRequest::new("https://localhost:4433/foo/bar?p1=1&p2=2").unwrap();
        assert_eq!(request.authority(), "localhost:4433");
        assert_eq!(request.path(), "/foo/bar?p1=1&p2=2");
        assert_eq!(request.get(":method").unwrap(), "CONNECT");
        assert_eq!(request.get(":protocol").unwrap(), "webtransport");
    }

    #[test]
    fn not_https() {
        let error = SessionRequest::new("http://localhost:4433");
        assert!(matches!(error, Err(UrlParseError::SchemeNotHttps)));
    }

    #[test]
    fn parse_headers() {
        let headers = VALID_REQUEST.into_iter().collect::<Headers>();
        assert!(SessionRequest::try_from(headers).is_ok());
    }

    #[test]
    fn parse_headers_error_method() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":method")),
            Err(HeadersParseError::MissingMethod),
        ));

        assert!(matches!(
            SessionRequest::try_from(request_with(":method", "GET")),
            Err(HeadersParseError::MethodNotConnect),
        ));
    }

    #[test]
    fn parse_headers_error_scheme() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":scheme")),
            Err(HeadersParseError::MissingScheme),
        ));

        assert!(matches!(
            SessionRequest::try_from(request_with(":scheme", "http")),
            Err(HeadersParseError::SchemeNotHttps),
        ));
    }

    #[test]
    fn parse_headers_error_protocol() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":protocol")),
            Err(HeadersParseError::MissingProtocol),
        ));

        assert!(matches!(
            SessionRequest::try_from(request_with(":protocol", "websocket")),
            Err(HeadersParseError::ProtocolNotWebTransport),
        ));
    }

    #[test]
    fn parse_headers_error_authority() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":authority")),
            Err(HeadersParseError::MissingAuthority),
        ));
    }

    #[test]
    fn parse_headers_error_path() {
        assert!(matches!(
            SessionRequest::try_from(request_without(":path")),
            Err(HeadersParseError::MissingPath),
        ));
    }

    #[test]
    fn insert() {
        let mut request = SessionRequest::new("https://example.com").unwrap();
        request.insert("version", "test").unwrap();
        assert_eq!(request.get("version").unwrap(), "test");
    }

    #[test]
    fn insert_reserved() {
        let mut request = SessionRequest::new("https://example.com").unwrap();

        let attempts = [
            (":method", "GET"),
            (":scheme", "ftp"),
            (":protocol", "web"),
            (":authority", "me"),
            (":path", "example"),
        ];

        for (name, value) in attempts {
            assert!(matches!(request.insert(name, value), Err(ReservedHeader)));
        }

        // A rejected insert must leave the stored value alone.
        assert_eq!(request.get(":method").unwrap(), "CONNECT");
    }
}