Skip to main content

tokio_websockets/upgrade/
client_request.rs

1//! A [`Codec`] to parse client HTTP Upgrade handshakes and validate them.
2use std::str::FromStr;
3
4use base64::{Engine, engine::general_purpose::STANDARD};
5use bytes::{Buf, BytesMut};
6use http::{HeaderMap, header::SET_COOKIE};
7use httparse::Request;
8use tokio_util::codec::Decoder;
9
10use crate::{sha::digest, upgrade::Error};
11
/// The fixed leading part of the `101 Switching Protocols` response head.
///
/// It contains the status line, the `Upgrade` and `Connection` headers, and
/// the name of the `Sec-WebSocket-Accept` header. The accept value is written
/// directly after the trailing space.
const SWITCHING_PROTOCOLS_BODY: &[u8] = b"HTTP/1.1 101 Switching Protocols\r\n\
    Upgrade: websocket\r\n\
    Connection: Upgrade\r\n\
    Sec-WebSocket-Accept: ";
15
/// Reports whether `needle` occurs anywhere in `haystack` as a contiguous run
/// of bytes, comparing ASCII letters case-insensitively.
///
/// An empty `needle` is considered to be contained in every `haystack`.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `windows(0)` would panic, so the empty needle is answered up front.
    // A haystack shorter than the needle yields no windows and therefore `false`.
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
33
34/// A client's opening handshake.
35struct ClientRequest {
36    /// The SHA-1 digest of the `Sec-WebSocket-Key` header.
37    ws_accept: [u8; 20],
38}
39
40impl ClientRequest {
41    /// Parses the client's opening handshake.
42    ///
43    /// # Errors
44    ///
45    /// This method fails when a header required for the WebSocket protocol is
46    /// missing in the handshake.
47    pub fn parse<'a, F>(header: F) -> Result<Self, Error>
48    where
49        F: Fn(&'static str) -> Option<&'a str> + 'a,
50    {
51        let find_header = |name| header(name).ok_or(super::Error::MissingHeader(name));
52
53        let check_header = |name, expected, err| {
54            let actual = find_header(name)?;
55            if actual.eq_ignore_ascii_case(expected) {
56                Ok(())
57            } else {
58                Err(err)
59            }
60        };
61
62        let check_header_contains = |name, expected: &str, err| {
63            let actual = find_header(name)?;
64            if contains_ignore_ascii_case(actual.as_bytes(), expected.as_bytes()) {
65                Ok(())
66            } else {
67                Err(err)
68            }
69        };
70
71        check_header("Upgrade", "websocket", Error::UpgradeNotWebSocket)?;
72        check_header_contains("Connection", "Upgrade", Error::ConnectionNotUpgrade)?;
73        check_header(
74            "Sec-WebSocket-Version",
75            "13",
76            Error::UnsupportedWebSocketVersion,
77        )?;
78
79        let key = find_header("Sec-WebSocket-Key")?;
80        let ws_accept = digest(key.as_bytes());
81        Ok(Self { ws_accept })
82    }
83
84    /// Returns the value that the client expects to see in the server's
85    /// `Sec-WebSocket-Accept` header.
86    #[must_use]
87    pub fn ws_accept(&self) -> String {
88        STANDARD.encode(self.ws_accept)
89    }
90}
91
92/// A codec that implements a [`Decoder`] for HTTP/1.1 upgrade requests and
93/// yields the request and a HTTP/1.1 response to reply with.
94///
95/// It does not implement an [`Encoder`].
96///
97/// [`Encoder`]: tokio_util::codec::Encoder
98pub struct Codec<'a> {
99    /// List of headers to add to the Switching Protocols response.
100    pub response_headers: &'a HeaderMap,
101}
102
103impl Decoder for Codec<'_> {
104    type Error = crate::Error;
105    type Item = (http::Request<()>, Vec<u8>);
106
107    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<Self::Item>, Self::Error> {
108        let mut headers = [httparse::EMPTY_HEADER; 64];
109        let mut request = Request::new(&mut headers);
110        let status = request.parse(src).map_err(Error::Parsing)?;
111
112        if !status.is_complete() {
113            return Ok(None);
114        }
115
116        let request_len = status.unwrap();
117
118        let mut builder = http::request::Builder::new();
119        if let Some(m) = request.method {
120            let method =
121                http::method::Method::from_bytes(m.as_bytes()).expect("httparse method is valid");
122            builder = builder.method(method);
123        }
124
125        if let Some(uri) = request.path {
126            builder = builder.uri(uri);
127        }
128
129        match request.version {
130            Some(0) => builder = builder.version(http::Version::HTTP_10),
131            Some(1) => builder = builder.version(http::Version::HTTP_11),
132            _ => Err(Error::Parsing(httparse::Error::Version))?,
133        }
134
135        let mut header_map = http::HeaderMap::with_capacity(request.headers.len());
136
137        for header in request.headers {
138            let name = http::HeaderName::from_str(header.name)
139                .map_err(|_| Error::Parsing(httparse::Error::HeaderName))?;
140            let value = http::HeaderValue::from_bytes(header.value)
141                .map_err(|_| Error::Parsing(httparse::Error::HeaderValue))?;
142
143            header_map.insert(name, value);
144        }
145
146        // You have to build the request before you can assign headers: https://github.com/hyperium/http/issues/91
147        let mut request = builder
148            .body(())
149            .expect("httparse sees the request as valid");
150        *request.headers_mut() = header_map;
151
152        let ws_accept =
153            ClientRequest::parse(|name| request.headers().get(name).and_then(|h| h.to_str().ok()))?
154                .ws_accept();
155
156        src.advance(request_len);
157
158        // Preallocate the size without extra headers
159        let mut resp = Vec::with_capacity(SWITCHING_PROTOCOLS_BODY.len() + ws_accept.len() + 4);
160
161        resp.extend_from_slice(SWITCHING_PROTOCOLS_BODY);
162        resp.extend_from_slice(ws_accept.as_bytes());
163        resp.extend_from_slice(b"\r\n");
164
165        for name in self.response_headers.keys() {
166            let values = self.response_headers.get_all(name).iter();
167
168            if name == SET_COOKIE {
169                // Set-Cookie is treated differently because if multiple values are present,
170                // multiple header entries should be used rather than one
171                for value in values {
172                    resp.extend_from_slice(name.as_str().as_bytes());
173                    resp.extend_from_slice(b": ");
174                    resp.extend_from_slice(value.as_bytes());
175                    resp.extend_from_slice(b"\r\n");
176                }
177            } else {
178                // All other header values of the same key should be concatenated with a comma
179                resp.extend_from_slice(name.as_str().as_bytes());
180                resp.extend_from_slice(b": ");
181
182                let mut values = values.peekable();
183                while let Some(value) = values.next() {
184                    resp.extend_from_slice(value.as_bytes());
185
186                    if values.peek().is_some() {
187                        resp.push(b',');
188                    }
189                }
190
191                resp.extend_from_slice(b"\r\n");
192            }
193        }
194
195        resp.extend_from_slice(b"\r\n");
196
197        Ok(Some((request, resp)))
198    }
199}