zerodds_websocket_bridge/
handshake.rs1use alloc::string::{String, ToString};
29use alloc::vec::Vec;
30
/// Fixed GUID that RFC 6455 (§1.3, §4.2.2) appends to the client's
/// `Sec-WebSocket-Key` before hashing, to form `Sec-WebSocket-Accept`.
pub const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

/// The only `Sec-WebSocket-Version` value the request parser accepts.
pub const WEBSOCKET_VERSION: &str = "13";
36
/// Reasons a WebSocket opening handshake is rejected.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum HandshakeError {
    /// The request line or a header line could not be parsed.
    MalformedRequest,
    /// `Sec-WebSocket-Key` was absent or empty.
    MissingKey,
    /// No acceptable `Upgrade: websocket` header was present.
    NotWebSocketUpgrade,
    /// The `Connection` header did not carry the `Upgrade` token.
    NotUpgradeConnection,
    /// `Sec-WebSocket-Version` was unsupported (carries the offered value) or
    /// absent (carries an empty string).
    UnsupportedVersion(String),
    /// An HTTP status other than the expected one (carries the status code).
    /// NOTE(review): not produced by the request parser; presumably used by
    /// client-side response validation elsewhere.
    UnexpectedStatus(u16),
    /// `Sec-WebSocket-Accept` did not match the expected digest.
    /// NOTE(review): presumably produced by client-side validation elsewhere.
    AcceptMismatch,
}
55
56impl core::fmt::Display for HandshakeError {
57 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
58 match self {
59 Self::MalformedRequest => f.write_str("malformed handshake request"),
60 Self::MissingKey => f.write_str("missing Sec-WebSocket-Key"),
61 Self::NotWebSocketUpgrade => f.write_str("Upgrade header is not websocket"),
62 Self::NotUpgradeConnection => f.write_str("Connection header is not Upgrade"),
63 Self::UnsupportedVersion(v) => write!(f, "unsupported version: {v}"),
64 Self::UnexpectedStatus(s) => write!(f, "unexpected status: {s}"),
65 Self::AcceptMismatch => f.write_str("Sec-WebSocket-Accept mismatch"),
66 }
67 }
68}
69
/// Lets [`HandshakeError`] be used as a `std::error::Error` (e.g. boxed as
/// `Box<dyn Error>`) when the crate is built with the `std` feature.
#[cfg(feature = "std")]
impl std::error::Error for HandshakeError {}
72
/// Data extracted from a client's opening-handshake request.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ClientHandshake {
    /// Request-target from the request line (for example `/chat`).
    pub path: String,
    /// Trimmed `Host` header value; empty when the header was not sent.
    pub host: String,
    /// Trimmed `Sec-WebSocket-Key` header value.
    pub key: String,
    /// Subprotocols offered via `Sec-WebSocket-Protocol`, in client order.
    pub protocols: Vec<String>,
    /// Extensions offered via `Sec-WebSocket-Extensions`, in client order.
    pub extensions: Vec<String>,
}
87
/// The server's reply to an opening handshake.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct ServerHandshake {
    /// HTTP status code to send (`build_server_response` always uses 101).
    pub status: u16,
    /// Value for the `Sec-WebSocket-Accept` header.
    pub accept: String,
    /// Selected subprotocol; the header is omitted when `None`.
    pub protocol: Option<String>,
    /// Accepted extensions; the header is omitted when the list is empty.
    pub extensions: Vec<String>,
}
100
101#[must_use]
103pub fn compute_accept(client_key: &str) -> String {
104 let mut concatenated = String::with_capacity(client_key.len() + WEBSOCKET_GUID.len());
105 concatenated.push_str(client_key.trim());
106 concatenated.push_str(WEBSOCKET_GUID);
107 let digest = sha1(concatenated.as_bytes());
108 base64_encode(&digest)
109}
110
111pub fn parse_client_request(input: &str) -> Result<ClientHandshake, HandshakeError> {
116 let mut lines = input.split("\r\n");
117 let request_line = lines.next().ok_or(HandshakeError::MalformedRequest)?;
118 let mut req_parts = request_line.split_whitespace();
119 let _method = req_parts.next().ok_or(HandshakeError::MalformedRequest)?;
120 let path = req_parts
121 .next()
122 .ok_or(HandshakeError::MalformedRequest)?
123 .to_string();
124
125 let mut hs = ClientHandshake {
126 path,
127 ..Default::default()
128 };
129 let mut upgrade_ok = false;
130 let mut connection_ok = false;
131 let mut version_seen = false;
132 for line in lines {
133 if line.is_empty() {
134 break;
135 }
136 let (k, v) = line
137 .split_once(':')
138 .ok_or(HandshakeError::MalformedRequest)?;
139 let k = k.trim().to_ascii_lowercase();
140 let v = v.trim();
141 match k.as_str() {
142 "host" => hs.host = v.to_string(),
143 "upgrade" => upgrade_ok = v.eq_ignore_ascii_case("websocket"),
144 "connection" => {
145 connection_ok = v
146 .split(',')
147 .any(|part| part.trim().eq_ignore_ascii_case("upgrade"));
148 }
149 "sec-websocket-key" => hs.key = v.to_string(),
150 "sec-websocket-version" => {
151 version_seen = true;
152 if v != WEBSOCKET_VERSION {
153 return Err(HandshakeError::UnsupportedVersion(v.to_string()));
154 }
155 }
156 "sec-websocket-protocol" => {
157 hs.protocols
158 .extend(v.split(',').map(|s| s.trim().to_string()));
159 }
160 "sec-websocket-extensions" => {
161 hs.extensions
162 .extend(v.split(',').map(|s| s.trim().to_string()));
163 }
164 _ => {}
165 }
166 }
167 if !upgrade_ok {
168 return Err(HandshakeError::NotWebSocketUpgrade);
169 }
170 if !connection_ok {
171 return Err(HandshakeError::NotUpgradeConnection);
172 }
173 if hs.key.is_empty() {
174 return Err(HandshakeError::MissingKey);
175 }
176 if !version_seen {
177 return Err(HandshakeError::UnsupportedVersion(String::new()));
178 }
179 Ok(hs)
180}
181
182#[must_use]
184pub fn build_server_response(req: &ClientHandshake) -> ServerHandshake {
185 ServerHandshake {
186 status: 101,
187 accept: compute_accept(&req.key),
188 protocol: req.protocols.first().cloned(),
189 extensions: req.extensions.clone(),
190 }
191}
192
193#[must_use]
195pub fn render_server_response(resp: &ServerHandshake) -> String {
196 let mut out = alloc::format!(
197 "HTTP/1.1 {} Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: {}\r\n",
198 resp.status,
199 resp.accept
200 );
201 if let Some(p) = &resp.protocol {
202 out.push_str(&alloc::format!("Sec-WebSocket-Protocol: {p}\r\n"));
203 }
204 if !resp.extensions.is_empty() {
205 out.push_str(&alloc::format!(
206 "Sec-WebSocket-Extensions: {}\r\n",
207 resp.extensions.join(", ")
208 ));
209 }
210 out.push_str("\r\n");
211 out
212}
213
/// Computes the SHA-1 digest (FIPS 180-4) of `bytes`.
///
/// In this module it is only used to derive `Sec-WebSocket-Accept`. SHA-1 is
/// not collision-resistant and should not be reused for security purposes.
fn sha1(bytes: &[u8]) -> [u8; 20] {
    let mut state: [u32; 5] = [
        0x6745_2301,
        0xEFCD_AB89,
        0x98BA_DCFE,
        0x1032_5476,
        0xC3D2_E1F0,
    ];

    // Message padding: a single 0x80 byte, zeros until the length is 56 mod 64,
    // then the original message length in bits as a big-endian u64.
    let message_bits = (bytes.len() as u64) * 8;
    let mut padded = Vec::with_capacity(bytes.len() + 64);
    padded.extend_from_slice(bytes);
    padded.push(0x80);
    let zero_fill = (64 + 56 - padded.len() % 64) % 64;
    padded.resize(padded.len() + zero_fill, 0);
    padded.extend_from_slice(&message_bits.to_be_bytes());

    for block in padded.chunks_exact(64) {
        // Message schedule: 16 big-endian words from the block, expanded to 80.
        let mut schedule = [0u32; 80];
        for (slot, word) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..80 {
            schedule[t] = (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16])
                .rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &word) in schedule.iter().enumerate() {
            // Round function and additive constant change every 20 rounds.
            let (mix, round_constant) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDC)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6)
            };
            let next = a
                .rotate_left(5)
                .wrapping_add(mix)
                .wrapping_add(e)
                .wrapping_add(round_constant)
                .wrapping_add(word);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = next;
        }
        for (slot, delta) in state.iter_mut().zip([a, b, c, d, e].iter()) {
            *slot = slot.wrapping_add(*delta);
        }
    }

    let mut digest = [0u8; 20];
    for (chunk, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        chunk.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
276
/// Encodes `bytes` as standard base64 (RFC 4648 §4) with `=` padding.
fn base64_encode(bytes: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    let mut encoded = String::with_capacity(bytes.len().div_ceil(3) * 4);
    for group in bytes.chunks(3) {
        // Pack up to three octets big-endian into 24 bits; a short final group
        // is zero-filled on the right.
        let mut octets = [0u8; 3];
        octets[..group.len()].copy_from_slice(group);
        let packed =
            (u32::from(octets[0]) << 16) | (u32::from(octets[1]) << 8) | u32::from(octets[2]);

        // n input octets carry n + 1 significant sextets; the remainder of the
        // 4-character quantum is padding.
        let significant = group.len() + 1;
        for slot in 0..4 {
            if slot < significant {
                let shift = 18 - 6 * slot;
                encoded.push(char::from(ALPHABET[((packed >> shift) & 0x3f) as usize]));
            } else {
                encoded.push('=');
            }
        }
    }
    encoded
}
309
#[cfg(test)]
// `unwrap`/`expect`/`panic!` are how a test signals failure; these clippy lints
// are presumably denied for the library code itself.
#[allow(clippy::expect_used, clippy::unwrap_used, clippy::panic)]
mod tests {
    use super::*;

    /// Lower-case hex rendering, for comparing digests with published vectors.
    fn hex(bytes: &[u8]) -> String {
        bytes.iter().map(|b| alloc::format!("{b:02x}")).collect()
    }

    #[test]
    fn rfc6455_section_1_3_accept_test_vector() {
        // Worked example from RFC 6455 §1.3.
        let key = "dGhlIHNhbXBsZSBub25jZQ==";
        let expected = "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=";
        assert_eq!(compute_accept(key), expected);
    }

    #[test]
    fn accept_ignores_surrounding_whitespace() {
        assert_eq!(
            compute_accept(" dGhlIHNhbXBsZSBub25jZQ==\t"),
            "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="
        );
    }

    #[test]
    fn parses_minimal_client_handshake() {
        let req = "GET /chat HTTP/1.1\r\n\
                   Host: server.example.com\r\n\
                   Upgrade: websocket\r\n\
                   Connection: Upgrade\r\n\
                   Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n\
                   Sec-WebSocket-Version: 13\r\n\
                   \r\n";
        let h = parse_client_request(req).unwrap();
        assert_eq!(h.path, "/chat");
        assert_eq!(h.host, "server.example.com");
        assert_eq!(h.key, "dGhlIHNhbXBsZSBub25jZQ==");
    }

    #[test]
    fn parses_protocols_and_extensions() {
        let req = "GET / HTTP/1.1\r\n\
                   Host: x\r\n\
                   Upgrade: websocket\r\n\
                   Connection: Upgrade\r\n\
                   Sec-WebSocket-Key: a\r\n\
                   Sec-WebSocket-Version: 13\r\n\
                   Sec-WebSocket-Protocol: chat, superchat\r\n\
                   Sec-WebSocket-Extensions: permessage-deflate\r\n\
                   \r\n";
        let h = parse_client_request(req).unwrap();
        assert_eq!(
            h.protocols,
            alloc::vec!["chat".to_string(), "superchat".into()]
        );
        assert_eq!(h.extensions, alloc::vec!["permessage-deflate".to_string()]);
    }

    #[test]
    fn rejects_missing_upgrade() {
        let req = "GET / HTTP/1.1\r\n\
                   Connection: Upgrade\r\n\
                   Sec-WebSocket-Key: a\r\n\
                   Sec-WebSocket-Version: 13\r\n\
                   \r\n";
        assert_eq!(
            parse_client_request(req),
            Err(HandshakeError::NotWebSocketUpgrade)
        );
    }

    #[test]
    fn rejects_connection_without_upgrade_token() {
        let req = "GET / HTTP/1.1\r\n\
                   Upgrade: websocket\r\n\
                   Connection: keep-alive\r\n\
                   Sec-WebSocket-Key: a\r\n\
                   Sec-WebSocket-Version: 13\r\n\
                   \r\n";
        assert_eq!(
            parse_client_request(req),
            Err(HandshakeError::NotUpgradeConnection)
        );
    }

    #[test]
    fn rejects_wrong_version() {
        let req = "GET / HTTP/1.1\r\n\
                   Upgrade: websocket\r\n\
                   Connection: Upgrade\r\n\
                   Sec-WebSocket-Key: a\r\n\
                   Sec-WebSocket-Version: 8\r\n\
                   \r\n";
        assert!(matches!(
            parse_client_request(req),
            Err(HandshakeError::UnsupportedVersion(_))
        ));
    }

    #[test]
    fn rejects_missing_version() {
        let req = "GET / HTTP/1.1\r\n\
                   Upgrade: websocket\r\n\
                   Connection: Upgrade\r\n\
                   Sec-WebSocket-Key: a\r\n\
                   \r\n";
        assert_eq!(
            parse_client_request(req),
            Err(HandshakeError::UnsupportedVersion(String::new()))
        );
    }

    #[test]
    fn rejects_missing_key() {
        let req = "GET / HTTP/1.1\r\n\
                   Upgrade: websocket\r\n\
                   Connection: Upgrade\r\n\
                   Sec-WebSocket-Version: 13\r\n\
                   \r\n";
        assert_eq!(parse_client_request(req), Err(HandshakeError::MissingKey));
    }

    #[test]
    fn server_response_includes_accept() {
        let req = ClientHandshake {
            key: "dGhlIHNhbXBsZSBub25jZQ==".into(),
            ..Default::default()
        };
        let resp = build_server_response(&req);
        assert_eq!(resp.status, 101);
        assert_eq!(resp.accept, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=");
    }

    #[test]
    fn server_response_selects_first_protocol_and_echoes_extensions() {
        let req = ClientHandshake {
            key: "dGhlIHNhbXBsZSBub25jZQ==".into(),
            protocols: alloc::vec!["chat".into(), "superchat".into()],
            extensions: alloc::vec!["permessage-deflate".into()],
            ..Default::default()
        };
        let resp = build_server_response(&req);
        assert_eq!(resp.protocol.as_deref(), Some("chat"));
        assert_eq!(resp.extensions, alloc::vec!["permessage-deflate".to_string()]);
    }

    #[test]
    fn render_server_response_format() {
        let resp = ServerHandshake {
            status: 101,
            accept: "abc".into(),
            protocol: Some("chat".into()),
            extensions: alloc::vec![],
        };
        let s = render_server_response(&resp);
        assert!(s.contains("HTTP/1.1 101"));
        assert!(s.contains("Upgrade: websocket"));
        assert!(s.contains("Sec-WebSocket-Accept: abc"));
        assert!(s.contains("Sec-WebSocket-Protocol: chat"));
    }

    #[test]
    fn render_server_response_exact_bytes() {
        let resp = ServerHandshake {
            status: 101,
            accept: "abc".into(),
            protocol: None,
            extensions: alloc::vec!["a".into(), "b".into()],
        };
        assert_eq!(
            render_server_response(&resp),
            "HTTP/1.1 101 Switching Protocols\r\n\
             Upgrade: websocket\r\n\
             Connection: Upgrade\r\n\
             Sec-WebSocket-Accept: abc\r\n\
             Sec-WebSocket-Extensions: a, b\r\n\
             \r\n"
        );
    }

    #[test]
    fn display_messages() {
        assert_eq!(
            HandshakeError::MissingKey.to_string(),
            "missing Sec-WebSocket-Key"
        );
        assert_eq!(
            HandshakeError::UnsupportedVersion("8".into()).to_string(),
            "unsupported version: 8"
        );
        assert_eq!(
            HandshakeError::UnexpectedStatus(400).to_string(),
            "unexpected status: 400"
        );
    }

    #[test]
    fn base64_rfc4648_vectors() {
        // RFC 4648 §10 test vectors (encode direction only).
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foob"), "Zm9vYg==");
        assert_eq!(base64_encode(b"fooba"), "Zm9vYmE=");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn sha1_known_vector_abc() {
        // One-block example from the SHA-1 specification.
        let h = sha1(b"abc");
        let expected: [u8; 20] = [
            0xa9, 0x99, 0x3e, 0x36, 0x47, 0x06, 0x81, 0x6a, 0xba, 0x3e, 0x25, 0x71, 0x78, 0x50,
            0xc2, 0x6c, 0x9c, 0xd0, 0xd8, 0x9d,
        ];
        assert_eq!(h, expected);
    }

    #[test]
    fn sha1_known_vectors_empty_and_two_block() {
        assert_eq!(
            hex(&sha1(b"")),
            "da39a3ee5e6b4b0d3255bfef95601890afd80709"
        );
        // 56 bytes: the length field no longer fits, so padding spills into a
        // second 64-byte block.
        assert_eq!(
            hex(&sha1(
                b"abcdbcdecdefdefgefghfghighijhijkijkljklmklmnlmnomnopnopq"
            )),
            "84983e441c3bd26ebaae4aa1f95129e5e54670f1"
        );
    }

    #[test]
    fn connection_header_with_keep_alive_still_detects_upgrade() {
        // Browsers commonly send `Connection: keep-alive, Upgrade` and may
        // capitalise `WebSocket`; both must still be accepted.
        let req = "GET / HTTP/1.1\r\n\
                   Host: x\r\n\
                   Upgrade: WebSocket\r\n\
                   Connection: keep-alive, Upgrade\r\n\
                   Sec-WebSocket-Key: a\r\n\
                   Sec-WebSocket-Version: 13\r\n\
                   \r\n";
        assert!(parse_client_request(req).is_ok());
    }
}