// websocket_codec/upgrade.rs

1use std::result;
2use std::str;
3
4use base64::display::Base64Display;
5use bytes::{Buf, BytesMut};
6use httparse::{self, Header, Response};
7use sha1::{self, Sha1};
8use tokio_util::codec::{Decoder, Encoder};
9
10use crate::{Error, Result};
11
12type Sha1Digest = [u8; sha1::DIGEST_LENGTH];
13
14fn build_ws_accept(key: &str) -> Sha1Digest {
15    let mut s = Sha1::new();
16    s.update(key.as_bytes());
17    s.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
18    s.digest().bytes()
19}
20
21fn header<'a, 'header: 'a>(headers: &'a [Header<'header>], name: &'a str) -> result::Result<&'header [u8], String> {
22    let header = headers
23        .iter()
24        .find(|header| header.name.eq_ignore_ascii_case(name))
25        .ok_or_else(|| format!("server didn't respond with {name} header", name = name))?;
26
27    Ok(header.value)
28}
29
30fn validate_server_response(expected_ws_accept: &Sha1Digest, data: &[u8]) -> Result<Option<usize>> {
31    let mut headers = [httparse::EMPTY_HEADER; 20];
32    let mut response = Response::new(&mut headers);
33    let status = response.parse(data)?;
34    if !status.is_complete() {
35        return Ok(None);
36    }
37
38    let response_len = status.unwrap();
39    let code = response.code.unwrap();
40    if code != 101 {
41        return Err(format!("server responded with HTTP error {code}", code = code).into());
42    }
43
44    let ws_accept_header = header(response.headers, "Sec-WebSocket-Accept")?;
45    let mut ws_accept = Sha1Digest::default();
46    base64::decode_config_slice(&ws_accept_header, base64::STANDARD, &mut ws_accept)?;
47    if expected_ws_accept != &ws_accept {
48        return Err(format!(
49            "server responded with incorrect Sec-WebSocket-Accept header: expected {expected}, got {actual}",
50            expected = Base64Display::with_config(expected_ws_accept, base64::STANDARD),
51            actual = Base64Display::with_config(&ws_accept, base64::STANDARD),
52        )
53        .into());
54    }
55
56    Ok(Some(response_len))
57}
58
/// Reports whether `needle` occurs anywhere in `haystack`.
///
/// ASCII letters are compared case-insensitively; all other bytes must match exactly. An
/// empty `needle` is always found, even in an empty `haystack`.
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `windows(0)` panics, so the empty needle is answered before any windows are built.
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
74
75/// A client's opening handshake.
76pub struct ClientRequest {
77    ws_accept: Sha1Digest,
78}
79
80impl ClientRequest {
81    /// Parses the client's opening handshake.
82    pub fn parse<'a, F>(header: F) -> Result<Self>
83    where
84        F: Fn(&'static str) -> Option<&'a str> + 'a,
85    {
86        let header = |name| header(name).ok_or_else(|| format!("client didn't provide {name} header", name = name));
87
88        let check_header = |name, expected| {
89            let actual = header(name)?;
90            if actual.eq_ignore_ascii_case(expected) {
91                Ok(())
92            } else {
93                Err(format!(
94                    "client provided incorrect {name} header: expected {expected}, got {actual}",
95                    name = name,
96                    expected = expected,
97                    actual = actual
98                ))
99            }
100        };
101
102        let check_header_contains = |name, expected: &str| {
103            let actual = header(name)?;
104            if contains_ignore_ascii_case(actual.as_bytes(), expected.as_bytes()) {
105                Ok(())
106            } else {
107                Err(format!(
108                    "client provided incorrect {name} header: expected string containing {expected}, got {actual}",
109                    name = name,
110                    expected = expected,
111                    actual = actual
112                ))
113            }
114        };
115
116        check_header("Upgrade", "websocket")?;
117        check_header_contains("Connection", "Upgrade")?;
118        check_header("Sec-WebSocket-Version", "13")?;
119
120        let key = header("Sec-WebSocket-Key")?;
121        let ws_accept = build_ws_accept(key);
122        Ok(Self { ws_accept })
123    }
124
125    /// Copies the value that the client expects to see in the server's `Sec-WebSocket-Accept` header into a `String`.
126    pub fn ws_accept_buf(&self, s: &mut String) {
127        base64::encode_config_buf(&self.ws_accept, base64::STANDARD, s)
128    }
129
130    /// Returns the value that the client expects to see in the server's `Sec-WebSocket-Accept` header.
131    pub fn ws_accept(&self) -> String {
132        base64::encode_config(&self.ws_accept, base64::STANDARD)
133    }
134}
135
136/// Tokio decoder for parsing the server's response to the client's HTTP `Connection: Upgrade` request.
137pub struct UpgradeCodec {
138    ws_accept: Sha1Digest,
139}
140
141impl UpgradeCodec {
142    /// Returns a new `UpgradeCodec` object.
143    ///
144    /// The `key` parameter provides the string passed to the server via the HTTP `Sec-WebSocket-Key` header.
145    pub fn new(key: &str) -> Self {
146        UpgradeCodec {
147            ws_accept: build_ws_accept(key),
148        }
149    }
150}
151
152impl Decoder for UpgradeCodec {
153    type Item = ();
154    type Error = Error;
155
156    fn decode(&mut self, src: &mut BytesMut) -> Result<Option<()>> {
157        if let Some(response_len) = validate_server_response(&self.ws_accept, src)? {
158            src.advance(response_len);
159            Ok(Some(()))
160        } else {
161            Ok(None)
162        }
163    }
164}
165
166impl Encoder<()> for UpgradeCodec {
167    type Error = Error;
168
169    fn encode(&mut self, _item: (), _dst: &mut BytesMut) -> Result<()> {
170        unimplemented!()
171    }
172}
173
#[cfg(test)]
mod tests {
    // `super::` keeps the tests working if this module is moved or renamed.
    use super::contains_ignore_ascii_case;

    #[test]
    fn does_not_contain() {
        assert!(!contains_ignore_ascii_case(b"World", b"hello"));
    }

    #[test]
    fn contains_exact() {
        assert!(contains_ignore_ascii_case(b"Hello", b"hello"));
    }

    #[test]
    fn contains_substring() {
        assert!(contains_ignore_ascii_case(b"Hello World", b"hello"));
    }

    #[test]
    fn contains_at_end_of_header_list() {
        assert!(contains_ignore_ascii_case(b"keep-alive, Upgrade", b"upgrade"));
    }

    #[test]
    fn empty_needle_always_matches() {
        assert!(contains_ignore_ascii_case(b"", b""));
        assert!(contains_ignore_ascii_case(b"anything", b""));
    }

    #[test]
    fn needle_longer_than_haystack() {
        assert!(!contains_ignore_ascii_case(b"Up", b"Upgrade"));
        assert!(!contains_ignore_ascii_case(b"", b"a"));
    }
}