websocket_codec/
upgrade.rs1use std::result;
2use std::str;
3
4use base64::display::Base64Display;
5use bytes::{Buf, BytesMut};
6use httparse::{self, Header, Response};
7use sha1::{self, Sha1};
8use tokio_util::codec::{Decoder, Encoder};
9
10use crate::{Error, Result};
11
12type Sha1Digest = [u8; sha1::DIGEST_LENGTH];
13
14fn build_ws_accept(key: &str) -> Sha1Digest {
15 let mut s = Sha1::new();
16 s.update(key.as_bytes());
17 s.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
18 s.digest().bytes()
19}
20
21fn header<'a, 'header: 'a>(headers: &'a [Header<'header>], name: &'a str) -> result::Result<&'header [u8], String> {
22 let header = headers
23 .iter()
24 .find(|header| header.name.eq_ignore_ascii_case(name))
25 .ok_or_else(|| format!("server didn't respond with {name} header", name = name))?;
26
27 Ok(header.value)
28}
29
30fn validate_server_response(expected_ws_accept: &Sha1Digest, data: &[u8]) -> Result<Option<usize>> {
31 let mut headers = [httparse::EMPTY_HEADER; 20];
32 let mut response = Response::new(&mut headers);
33 let status = response.parse(data)?;
34 if !status.is_complete() {
35 return Ok(None);
36 }
37
38 let response_len = status.unwrap();
39 let code = response.code.unwrap();
40 if code != 101 {
41 return Err(format!("server responded with HTTP error {code}", code = code).into());
42 }
43
44 let ws_accept_header = header(response.headers, "Sec-WebSocket-Accept")?;
45 let mut ws_accept = Sha1Digest::default();
46 base64::decode_config_slice(&ws_accept_header, base64::STANDARD, &mut ws_accept)?;
47 if expected_ws_accept != &ws_accept {
48 return Err(format!(
49 "server responded with incorrect Sec-WebSocket-Accept header: expected {expected}, got {actual}",
50 expected = Base64Display::with_config(expected_ws_accept, base64::STANDARD),
51 actual = Base64Display::with_config(&ws_accept, base64::STANDARD),
52 )
53 .into());
54 }
55
56 Ok(Some(response_len))
57}
58
/// Reports whether `needle` occurs as a contiguous run anywhere in `haystack`, comparing
/// ASCII letters case-insensitively.
///
/// An empty `needle` is considered present in every haystack (including an empty one).
fn contains_ignore_ascii_case(haystack: &[u8], needle: &[u8]) -> bool {
    // `windows` panics on a zero width, so the empty needle is answered up front. A needle
    // longer than the haystack yields no windows and therefore `false`.
    needle.is_empty()
        || haystack
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle))
}
75pub struct ClientRequest {
77 ws_accept: Sha1Digest,
78}
79
80impl ClientRequest {
81 pub fn parse<'a, F>(header: F) -> Result<Self>
83 where
84 F: Fn(&'static str) -> Option<&'a str> + 'a,
85 {
86 let header = |name| header(name).ok_or_else(|| format!("client didn't provide {name} header", name = name));
87
88 let check_header = |name, expected| {
89 let actual = header(name)?;
90 if actual.eq_ignore_ascii_case(expected) {
91 Ok(())
92 } else {
93 Err(format!(
94 "client provided incorrect {name} header: expected {expected}, got {actual}",
95 name = name,
96 expected = expected,
97 actual = actual
98 ))
99 }
100 };
101
102 let check_header_contains = |name, expected: &str| {
103 let actual = header(name)?;
104 if contains_ignore_ascii_case(actual.as_bytes(), expected.as_bytes()) {
105 Ok(())
106 } else {
107 Err(format!(
108 "client provided incorrect {name} header: expected string containing {expected}, got {actual}",
109 name = name,
110 expected = expected,
111 actual = actual
112 ))
113 }
114 };
115
116 check_header("Upgrade", "websocket")?;
117 check_header_contains("Connection", "Upgrade")?;
118 check_header("Sec-WebSocket-Version", "13")?;
119
120 let key = header("Sec-WebSocket-Key")?;
121 let ws_accept = build_ws_accept(key);
122 Ok(Self { ws_accept })
123 }
124
125 pub fn ws_accept_buf(&self, s: &mut String) {
127 base64::encode_config_buf(&self.ws_accept, base64::STANDARD, s)
128 }
129
130 pub fn ws_accept(&self) -> String {
132 base64::encode_config(&self.ws_accept, base64::STANDARD)
133 }
134}
135
136pub struct UpgradeCodec {
138 ws_accept: Sha1Digest,
139}
140
141impl UpgradeCodec {
142 pub fn new(key: &str) -> Self {
146 UpgradeCodec {
147 ws_accept: build_ws_accept(key),
148 }
149 }
150}
151
152impl Decoder for UpgradeCodec {
153 type Item = ();
154 type Error = Error;
155
156 fn decode(&mut self, src: &mut BytesMut) -> Result<Option<()>> {
157 if let Some(response_len) = validate_server_response(&self.ws_accept, src)? {
158 src.advance(response_len);
159 Ok(Some(()))
160 } else {
161 Ok(None)
162 }
163 }
164}
165
166impl Encoder<()> for UpgradeCodec {
167 type Error = Error;
168
169 fn encode(&mut self, _item: (), _dst: &mut BytesMut) -> Result<()> {
170 unimplemented!()
171 }
172}
173
#[cfg(test)]
mod tests {
    // `super::` keeps the tests working regardless of where this module sits in the crate.
    use super::contains_ignore_ascii_case;

    #[test]
    fn does_not_contain() {
        assert!(!contains_ignore_ascii_case(b"World", b"hello"));
    }

    #[test]
    fn contains_exact() {
        assert!(contains_ignore_ascii_case(b"Hello", b"hello"));
    }

    #[test]
    fn contains_substring() {
        assert!(contains_ignore_ascii_case(b"Hello World", b"hello"));
    }

    /// A match at the very end of the haystack (e.g. the last token of a `Connection` header).
    #[test]
    fn contains_substring_at_end() {
        assert!(contains_ignore_ascii_case(b"keep-alive, Upgrade", b"UPGRADE"));
    }

    #[test]
    fn empty_needle_is_always_contained() {
        assert!(contains_ignore_ascii_case(b"", b""));
        assert!(contains_ignore_ascii_case(b"anything", b""));
    }

    #[test]
    fn needle_longer_than_haystack_is_not_contained() {
        assert!(!contains_ignore_ascii_case(b"Up", b"Upgrade"));
        assert!(!contains_ignore_ascii_case(b"", b"a"));
    }

    #[test]
    fn partial_match_is_not_contained() {
        assert!(!contains_ignore_ascii_case(b"Upgrad", b"upgrade"));
        assert!(!contains_ignore_ascii_case(b"abd", b"abc"));
    }
}