wtransport_proto_lightyear_patch/
session.rs

1use crate::headers::Headers;
2use crate::ids::InvalidStatusCode;
3use crate::ids::StatusCode;
4use url::Url;
5
/// Error when parsing URL.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UrlParseError {
    /// Missing host part in the URL.
    EmptyHost,

    /// Invalid international domain name.
    IdnaError,

    /// Invalid port number.
    InvalidPort,

    /// Invalid IPv4 address.
    InvalidIpv4Address,

    /// Invalid IPv6 address.
    InvalidIpv6Address,

    /// Invalid domain character.
    InvalidDomainCharacter,

    /// Relative URL without a base.
    RelativeUrlWithoutBase,

    /// Relative URL with a cannot-be-a-base base.
    RelativeUrlWithCannotBeABaseBase,

    /// A cannot-be-a-base URL doesn't have a host to set.
    SetHostOnCannotBeABaseUrl,

    /// URLs more than 4 GB are not supported.
    Overflow,

    /// Unknown error during URL parsing.
    Unknown,

    /// The URL scheme is not `https` (WebTransport only supports the `https` scheme).
    SchemeNotHttps,
}

impl std::fmt::Display for UrlParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message to be written.
        let message = match self {
            UrlParseError::EmptyHost => "empty host",
            UrlParseError::IdnaError => "invalid international domain name",
            UrlParseError::InvalidPort => "invalid port number",
            UrlParseError::InvalidIpv4Address => "invalid IPv4 address",
            UrlParseError::InvalidIpv6Address => "invalid IPv6 address",
            UrlParseError::InvalidDomainCharacter => "invalid domain character",
            UrlParseError::RelativeUrlWithoutBase => "relative URL without a base",
            UrlParseError::RelativeUrlWithCannotBeABaseBase => {
                "relative URL with a cannot-be-a-base base"
            }
            UrlParseError::SetHostOnCannotBeABaseUrl => {
                "a cannot-be-a-base URL doesn't have a host to set"
            }
            UrlParseError::Overflow => "URLs more than 4 GB are not supported",
            UrlParseError::Unknown => "unknown URL parse error",
            UrlParseError::SchemeNotHttps => "URL scheme is not 'https'",
        };
        f.write_str(message)
    }
}

impl std::error::Error for UrlParseError {}
45
/// Error when parsing [`Headers`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum HeadersParseError {
    /// Method field is missing.
    MissingMethod,

    /// Method is not 'CONNECT'.
    MethodNotConnect,

    /// Scheme field is missing.
    MissingScheme,

    /// Scheme is not 'https'.
    SchemeNotHttps,

    /// Protocol field is missing.
    MissingProtocol,

    /// Protocol is not 'webtransport'.
    ProtocolNotWebTransport,

    /// Authority field is missing.
    MissingAuthority,

    /// Path field is missing.
    MissingPath,

    /// Status field is missing.
    MissingStatusCode,

    /// The status code value is not valid.
    InvalidStatusCode,
}

impl std::fmt::Display for HeadersParseError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Exhaustive on purpose: adding a variant must force a message to be written.
        let message = match self {
            HeadersParseError::MissingMethod => "missing ':method' header field",
            HeadersParseError::MethodNotConnect => "':method' is not 'CONNECT'",
            HeadersParseError::MissingScheme => "missing ':scheme' header field",
            HeadersParseError::SchemeNotHttps => "':scheme' is not 'https'",
            HeadersParseError::MissingProtocol => "missing ':protocol' header field",
            HeadersParseError::ProtocolNotWebTransport => "':protocol' is not 'webtransport'",
            HeadersParseError::MissingAuthority => "missing ':authority' header field",
            HeadersParseError::MissingPath => "missing ':path' header field",
            HeadersParseError::MissingStatusCode => "missing ':status' header field",
            HeadersParseError::InvalidStatusCode => "invalid ':status' value",
        };
        f.write_str(message)
    }
}

impl std::error::Error for HeadersParseError {}
79
/// An error when attempting to insert a value for a reserved header.
///
/// It is returned as an error when trying to insert a key-value pair into
/// [`SessionRequest`] where the key is one of the
/// [reserved headers](SessionRequest::RESERVED_HEADERS).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ReservedHeader;

impl std::fmt::Display for ReservedHeader {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("cannot insert a value for a reserved header")
    }
}

impl std::error::Error for ReservedHeader {}
87
88/// A CONNECT WebTransport request.
89#[derive(Debug)]
90pub struct SessionRequest(Headers);
91
92impl SessionRequest {
93    /// A collection of reserved headers used in the WebTransport protocol.
94    ///
95    /// Reserved headers have special significance in the WebTransport protocol and
96    /// cannot be used as additional headers with the [`insert`](Self::insert) method.
97    ///
98    /// The following headers are considered reserved:
99    /// - `:method`
100    /// - `:scheme`
101    /// - `:protocol`
102    /// - `:authority`
103    /// - `:path`
104    pub const RESERVED_HEADERS: &'static [&'static str] =
105        &[":method", ":scheme", ":protocol", ":authority", ":path"];
106
107    /// Parses an URL to build a Session request.
108    pub fn new<S>(url: S) -> Result<Self, UrlParseError>
109    where
110        S: AsRef<str>,
111    {
112        let url = Url::parse(url.as_ref())?;
113
114        if url.scheme() != "https" {
115            return Err(UrlParseError::SchemeNotHttps);
116        }
117
118        let path = format!(
119            "{}{}",
120            url.path(),
121            url.query().map(|s| format!("?{}", s)).unwrap_or_default()
122        );
123
124        let headers = [
125            (":method", "CONNECT"),
126            (":scheme", "https"),
127            (":protocol", "webtransport"),
128            (":authority", url.authority()),
129            (":path", &path),
130        ]
131        .into_iter()
132        .collect();
133
134        Ok(Self(headers))
135    }
136
137    /// Returns the `:authority` field of the request.
138    pub fn authority(&self) -> &str {
139        self.0
140            .get(":authority")
141            .expect("Session request must contain ':authority' field")
142    }
143
144    /// Returns the `:path` field of the request.
145    pub fn path(&self) -> &str {
146        self.0
147            .get(":path")
148            .expect("Session request must contain ':path' field")
149    }
150
151    /// Returns the `origin` field of the request if present.
152    pub fn origin(&self) -> Option<&str> {
153        self.0.get("origin")
154    }
155
156    /// Returns the `user-agent` field of the request if present.
157    pub fn user_agent(&self) -> Option<&str> {
158        self.0.get("user-agent")
159    }
160
161    /// Gets a field from the request (if present).
162    pub fn get<K>(&self, key: K) -> Option<&str>
163    where
164        K: AsRef<str>,
165    {
166        self.0.get(key)
167    }
168
169    /// Inserts a key-value pair into the header map, checking for reserved headers.
170    ///
171    /// This method inserts a key-value pair into the header map after ensuring that
172    /// the specified key is not one of the [reserved headers](Self::RESERVED_HEADERS).
173    /// If the key is reserved, the method returns an `Err(ReservedHeader)` indicating
174    /// the attempt to insert a value for a reserved header.
175    ///
176    /// If the key already exists in the header map, the corresponding value is updated with
177    /// the new value.
178    pub fn insert<K, V>(&mut self, key: K, value: V) -> Result<(), ReservedHeader>
179    where
180        K: ToString,
181        V: ToString,
182    {
183        let key = key.to_string();
184
185        if Self::RESERVED_HEADERS.iter().any(|rh| rh == &key) {
186            return Err(ReservedHeader);
187        }
188
189        self.0.insert(key, value);
190        Ok(())
191    }
192
193    /// Returns the whole headers associated with the request.
194    pub fn headers(&self) -> &Headers {
195        &self.0
196    }
197}
198
199impl TryFrom<Headers> for SessionRequest {
200    type Error = HeadersParseError;
201
202    fn try_from(headers: Headers) -> Result<Self, Self::Error> {
203        if headers
204            .get(":method")
205            .ok_or(HeadersParseError::MissingMethod)?
206            != "CONNECT"
207        {
208            return Err(HeadersParseError::MethodNotConnect);
209        }
210
211        if headers
212            .get(":scheme")
213            .ok_or(HeadersParseError::MissingScheme)?
214            != "https"
215        {
216            return Err(HeadersParseError::SchemeNotHttps);
217        }
218
219        if headers
220            .get(":protocol")
221            .ok_or(HeadersParseError::MissingProtocol)?
222            != "webtransport"
223        {
224            return Err(HeadersParseError::ProtocolNotWebTransport);
225        }
226
227        headers
228            .get(":authority")
229            .ok_or(HeadersParseError::MissingAuthority)?;
230
231        headers.get(":path").ok_or(HeadersParseError::MissingPath)?;
232
233        Ok(Self(headers))
234    }
235}
236
237impl From<url::ParseError> for UrlParseError {
238    fn from(error: url::ParseError) -> Self {
239        match error {
240            url::ParseError::EmptyHost => UrlParseError::EmptyHost,
241            url::ParseError::IdnaError => UrlParseError::IdnaError,
242            url::ParseError::InvalidPort => UrlParseError::InvalidPort,
243            url::ParseError::InvalidIpv4Address => UrlParseError::InvalidIpv4Address,
244            url::ParseError::InvalidIpv6Address => UrlParseError::InvalidIpv6Address,
245            url::ParseError::InvalidDomainCharacter => UrlParseError::InvalidDomainCharacter,
246            url::ParseError::RelativeUrlWithoutBase => UrlParseError::RelativeUrlWithoutBase,
247            url::ParseError::RelativeUrlWithCannotBeABaseBase => {
248                UrlParseError::RelativeUrlWithCannotBeABaseBase
249            }
250            url::ParseError::SetHostOnCannotBeABaseUrl => UrlParseError::SetHostOnCannotBeABaseUrl,
251            url::ParseError::Overflow => UrlParseError::Overflow,
252            _ => UrlParseError::Unknown,
253        }
254    }
255}
256
257/// A WebTransport CONNECT response.
258pub struct SessionResponse(Headers);
259
260impl SessionResponse {
261    /// Constructs from [`StatusCode`].
262    pub fn with_status_code(status_code: StatusCode) -> Self {
263        let headers = [(":status", status_code.to_string())].into_iter().collect();
264        Self(headers)
265    }
266
267    /// Constructs with [`StatusCode::OK`].
268    pub fn ok() -> Self {
269        Self::with_status_code(StatusCode::OK)
270    }
271
272    /// Constructs with [`StatusCode::FORBIDDEN`].
273    pub fn forbidden() -> Self {
274        Self::with_status_code(StatusCode::FORBIDDEN)
275    }
276
277    /// Constructs with [`StatusCode::NOT_FOUND`].
278    pub fn not_found() -> Self {
279        Self::with_status_code(StatusCode::NOT_FOUND)
280    }
281
282    /// Returns the status code.
283    pub fn code(&self) -> StatusCode {
284        self.0
285            .get(":status")
286            .expect("Status code is always present")
287            .parse()
288            .expect("Status code value must be valid")
289    }
290
291    /// Adds a header field to the response.
292    ///
293    /// If the key is already present, the value is updated.
294    pub fn add<K, V>(&mut self, key: K, value: V)
295    where
296        K: ToString,
297        V: ToString,
298    {
299        self.0.insert(key, value);
300    }
301
302    /// Returns the whole headers associated with the request.
303    pub fn headers(&self) -> &Headers {
304        &self.0
305    }
306}
307
308impl TryFrom<Headers> for SessionResponse {
309    type Error = HeadersParseError;
310
311    fn try_from(headers: Headers) -> Result<Self, Self::Error> {
312        let status_code = headers
313            .get(":status")
314            .ok_or(HeadersParseError::MissingStatusCode)?
315            .parse()
316            .map_err(|InvalidStatusCode| HeadersParseError::InvalidStatusCode)?;
317
318        Ok(Self::with_status_code(status_code))
319    }
320}
321
#[cfg(test)]
mod tests {
    use super::*;

    /// A complete, valid set of CONNECT request fields.
    const VALID_REQUEST: [(&str, &str); 5] = [
        (":method", "CONNECT"),
        (":scheme", "https"),
        (":protocol", "webtransport"),
        (":authority", "localhost:4433"),
        (":path", "/"),
    ];

    /// Parses [`VALID_REQUEST`] with the `missing` field left out.
    fn parse_without(missing: &str) -> Result<SessionRequest, HeadersParseError> {
        let headers = VALID_REQUEST
            .iter()
            .copied()
            .filter(|&(key, _)| key != missing)
            .collect::<Headers>();
        SessionRequest::try_from(headers)
    }

    /// Parses [`VALID_REQUEST`] with the value of `field` replaced by `value`.
    fn parse_with(field: &str, value: &str) -> Result<SessionRequest, HeadersParseError> {
        let headers = VALID_REQUEST
            .iter()
            .map(|&(key, original)| (key, if key == field { value } else { original }))
            .collect::<Headers>();
        SessionRequest::try_from(headers)
    }

    #[test]
    fn parse_url() {
        let request = SessionRequest::new("https://localhost:4433/foo/bar?p1=1&p2=2").unwrap();
        assert_eq!(request.authority(), "localhost:4433");
        assert_eq!(request.path(), "/foo/bar?p1=1&p2=2");
        assert_eq!(request.get(":method").unwrap(), "CONNECT");
        assert_eq!(request.get(":protocol").unwrap(), "webtransport");
    }

    #[test]
    fn parse_url_defaults() {
        let request = SessionRequest::new("https://example.com").unwrap();
        assert_eq!(request.authority(), "example.com");
        assert_eq!(request.path(), "/");
        assert_eq!(request.get(":scheme").unwrap(), "https");
    }

    #[test]
    fn parse_url_drops_fragment() {
        let request = SessionRequest::new("https://localhost:4433/foo?x=1#section").unwrap();
        assert_eq!(request.path(), "/foo?x=1");
    }

    #[test]
    fn not_https() {
        let error = SessionRequest::new("http://localhost:4433");
        assert!(matches!(error, Err(UrlParseError::SchemeNotHttps)));
    }

    #[test]
    fn parse_headers() {
        let headers = VALID_REQUEST.iter().copied().collect::<Headers>();
        assert!(SessionRequest::try_from(headers).is_ok());
    }

    #[test]
    fn parse_headers_error_method() {
        assert!(matches!(
            parse_without(":method"),
            Err(HeadersParseError::MissingMethod)
        ));
        assert!(matches!(
            parse_with(":method", "GET"),
            Err(HeadersParseError::MethodNotConnect)
        ));
    }

    #[test]
    fn parse_headers_error_scheme() {
        assert!(matches!(
            parse_without(":scheme"),
            Err(HeadersParseError::MissingScheme)
        ));
        assert!(matches!(
            parse_with(":scheme", "http"),
            Err(HeadersParseError::SchemeNotHttps)
        ));
    }

    #[test]
    fn parse_headers_error_protocol() {
        assert!(matches!(
            parse_without(":protocol"),
            Err(HeadersParseError::MissingProtocol)
        ));
        assert!(matches!(
            parse_with(":protocol", "connect-udp"),
            Err(HeadersParseError::ProtocolNotWebTransport)
        ));
    }

    #[test]
    fn parse_headers_error_authority_and_path() {
        assert!(matches!(
            parse_without(":authority"),
            Err(HeadersParseError::MissingAuthority)
        ));
        assert!(matches!(
            parse_without(":path"),
            Err(HeadersParseError::MissingPath)
        ));
    }

    #[test]
    fn insert() {
        let mut request = SessionRequest::new("https://example.com").unwrap();
        request.insert("version", "test").unwrap();
        assert_eq!(request.get("version").unwrap(), "test");
    }

    #[test]
    fn insert_reserved() {
        let mut request = SessionRequest::new("https://example.com").unwrap();

        for &key in SessionRequest::RESERVED_HEADERS {
            assert!(
                matches!(request.insert(key, "value"), Err(ReservedHeader)),
                "{} should be reserved",
                key
            );
        }

        // Rejected inserts must leave the original pseudo-headers untouched.
        assert_eq!(request.get(":method").unwrap(), "CONNECT");
        assert_eq!(request.authority(), "example.com");
        assert_eq!(request.path(), "/");
    }
}