// rust_web_server/websocket/mod.rs

//! WebSocket protocol support (RFC 6455).
//!
//! Provides the building blocks for adding WebSocket endpoints to a server:
//! upgrade detection, the opening handshake, and frame-level reads and writes.
//!
//! # Integration pattern
//!
//! Once the `101 Switching Protocols` response has been sent, a WebSocket
//! connection takes over the raw TCP stream. It therefore cannot be served
//! from a normal [`Controller::process`](crate::controller::Controller) call,
//! which has no access to the stream. The recommended pattern is to bypass
//! [`Server::run`](crate::server::Server) for upgraded connections and drive
//! them from your own accept loop:
//!
//! ```rust,no_run
//! use std::net::TcpListener;
//! use rust_web_server::websocket::{WebSocket, Frame};
//! use rust_web_server::request::Request;
//! use rust_web_server::response::Response;
//!
//! let listener = TcpListener::bind("127.0.0.1:7878").unwrap();
//! for stream in listener.incoming() {
//!     let mut stream = stream.unwrap();
//!     // Read the request bytes and decide whether this is a WebSocket upgrade
//!     // (simplified: real code must read the complete request head first).
//!     let raw = vec![0u8; 4096];
//!     // ... parse `raw` into a `Request`, then:
//!     //
//!     // if WebSocket::is_upgrade_request(&request) {
//!     //     let response = WebSocket::handshake_response(&request).unwrap();
//!     //     // write the 101 response to `stream`, then enter the frame loop:
//!     //     loop {
//!     //         match WebSocket::read_frame(&mut stream) { ... }
//!     //     }
//!     // } else {
//!     //     // handle it as a normal HTTP request (e.g. via Server::process)
//!     // }
//! }
//! ```
40
41#[cfg(test)]
42mod tests;
43
44use std::io::{Read, Write};
45
46use crate::header::Header;
47use crate::http::VERSION;
48use crate::request::Request;
49use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
50
/// One WebSocket frame, as returned by [`WebSocket::read_frame`] and accepted
/// by [`WebSocket::write_frame`].
///
/// `Text`, `Binary` and `Continuation` carry message data; `Ping`, `Pong` and
/// `Close` are control frames. The opcode of each variant is listed on it.
#[derive(Debug, PartialEq, Eq)]
pub enum Frame {
    /// Text data frame (opcode `0x1`); the payload is valid UTF-8.
    Text(String),
    /// Binary data frame (opcode `0x2`).
    Binary(Vec<u8>),
    /// Ping control frame (opcode `0x9`). A server should answer it with a
    /// pong carrying the same payload.
    Ping(Vec<u8>),
    /// Pong control frame (opcode `0xA`).
    Pong(Vec<u8>),
    /// Close control frame (opcode `0x8`): an optional status code followed by
    /// a reason string (empty when no reason is present).
    Close(Option<u16>, String),
    /// Continuation fragment (opcode `0x0`); `fin` is set on the final fragment
    /// of a fragmented message.
    Continuation { fin: bool, data: Vec<u8> },
}
67
68/// WebSocket protocol utilities.
69pub struct WebSocket;
70
71impl WebSocket {
72    const MAGIC: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
73
74    // ── Handshake ─────────────────────────────────────────────────────────────
75
76    /// Returns `true` if `request` is a valid WebSocket upgrade request
77    /// (has `Upgrade: websocket`, `Connection: upgrade`, and `Sec-WebSocket-Key`).
78    pub fn is_upgrade_request(request: &Request) -> bool {
79        let upgrade = request.get_header("Upgrade".to_string());
80        let connection = request.get_header("Connection".to_string());
81        let key = request.get_header("Sec-WebSocket-Key".to_string());
82
83        upgrade.map(|h| h.value.to_lowercase() == "websocket").unwrap_or(false)
84            && connection.map(|h| h.value.to_lowercase().contains("upgrade")).unwrap_or(false)
85            && key.is_some()
86    }
87
88    /// Build the HTTP `101 Switching Protocols` response for a WebSocket
89    /// opening handshake. Returns an error if `Sec-WebSocket-Key` is absent.
90    ///
91    /// Write the raw bytes from [`Response::generate_response`] to the stream,
92    /// then transition to frame-level I/O with [`WebSocket::read_frame`] /
93    /// [`WebSocket::write_frame`].
94    pub fn handshake_response(request: &Request) -> Result<Response, String> {
95        let key_header = request.get_header("Sec-WebSocket-Key".to_string())
96            .ok_or_else(|| "missing Sec-WebSocket-Key header".to_string())?;
97        let accept = Self::accept_key(&key_header.value);
98
99        let mut response = Response {
100            http_version: VERSION.http_1_1.to_string(),
101            status_code: *STATUS_CODE_REASON_PHRASE.n101_switching_protocols.status_code,
102            reason_phrase: STATUS_CODE_REASON_PHRASE.n101_switching_protocols.reason_phrase.to_string(),
103            headers: vec![
104                Header { name: "Upgrade".to_string(),              value: "websocket".to_string() },
105                Header { name: "Connection".to_string(),           value: "Upgrade".to_string() },
106                Header { name: "Sec-WebSocket-Accept".to_string(), value: accept },
107            ],
108            content_range_list: vec![],
109            stream_file: None,
110        };
111
112        if let Some(proto) = request.get_header("Sec-WebSocket-Protocol".to_string()) {
113            response.headers.push(Header {
114                name: "Sec-WebSocket-Protocol".to_string(),
115                value: proto.value.split(',').next().unwrap_or("").trim().to_string(),
116            });
117        }
118
119        Ok(response)
120    }
121
122    /// Compute the `Sec-WebSocket-Accept` value from the client's
123    /// `Sec-WebSocket-Key` using SHA-1 and base64 as specified in RFC 6455.
124    pub fn accept_key(client_key: &str) -> String {
125        let mut data = client_key.as_bytes().to_vec();
126        data.extend_from_slice(Self::MAGIC.as_bytes());
127        let hash = sha1(&data);
128        base64_encode(&hash)
129    }
130
131    // ── Frame I/O ─────────────────────────────────────────────────────────────
132
133    /// Read one WebSocket frame from `stream`.
134    ///
135    /// Handles client-to-server masking automatically. Returns an error if the
136    /// stream closes unexpectedly or contains a protocol violation.
137    pub fn read_frame(stream: &mut impl Read) -> Result<Frame, String> {
138        let mut header = [0u8; 2];
139        read_exact(stream, &mut header)?;
140
141        let fin = (header[0] & 0x80) != 0;
142        let opcode = header[0] & 0x0F;
143        let masked = (header[1] & 0x80) != 0;
144        let payload_len_byte = (header[1] & 0x7F) as usize;
145
146        let payload_len: usize = match payload_len_byte {
147            126 => {
148                let mut ext = [0u8; 2];
149                read_exact(stream, &mut ext)?;
150                u16::from_be_bytes(ext) as usize
151            }
152            127 => {
153                let mut ext = [0u8; 8];
154                read_exact(stream, &mut ext)?;
155                u64::from_be_bytes(ext) as usize
156            }
157            n => n,
158        };
159
160        let mask_key = if masked {
161            let mut mk = [0u8; 4];
162            read_exact(stream, &mut mk)?;
163            Some(mk)
164        } else {
165            None
166        };
167
168        let mut payload = vec![0u8; payload_len];
169        if payload_len > 0 {
170            read_exact(stream, &mut payload)?;
171        }
172
173        if let Some(key) = mask_key {
174            for (i, byte) in payload.iter_mut().enumerate() {
175                *byte ^= key[i % 4];
176            }
177        }
178
179        let frame = match opcode {
180            0x0 => Frame::Continuation { fin, data: payload },
181            0x1 => {
182                let text = String::from_utf8(payload)
183                    .map_err(|_| "text frame contains invalid UTF-8".to_string())?;
184                Frame::Text(text)
185            }
186            0x2 => Frame::Binary(payload),
187            0x8 => {
188                let code = if payload.len() >= 2 {
189                    Some(u16::from_be_bytes([payload[0], payload[1]]))
190                } else {
191                    None
192                };
193                let reason = if payload.len() > 2 {
194                    String::from_utf8_lossy(&payload[2..]).into_owned()
195                } else {
196                    String::new()
197                };
198                Frame::Close(code, reason)
199            }
200            0x9 => Frame::Ping(payload),
201            0xA => Frame::Pong(payload),
202            n   => return Err(format!("unknown opcode: 0x{:X}", n)),
203        };
204
205        Ok(frame)
206    }
207
208    /// Write a WebSocket frame to `stream` (server→client, unmasked).
209    pub fn write_frame(stream: &mut impl Write, frame: Frame) -> Result<(), String> {
210        let (opcode, payload, fin) = match frame {
211            Frame::Text(s)          => (0x1u8, s.into_bytes(), true),
212            Frame::Binary(b)        => (0x2, b, true),
213            Frame::Ping(b)          => (0x9, b, true),
214            Frame::Pong(b)          => (0xA, b, true),
215            Frame::Close(code, reason) => {
216                let mut payload = Vec::new();
217                if let Some(c) = code {
218                    payload.extend_from_slice(&c.to_be_bytes());
219                    payload.extend_from_slice(reason.as_bytes());
220                }
221                (0x8, payload, true)
222            }
223            Frame::Continuation { fin, data } => (0x0, data, fin),
224        };
225
226        let fin_bit: u8 = if fin { 0x80 } else { 0x00 };
227        let byte0 = fin_bit | opcode;
228
229        let payload_len = payload.len();
230        let mut header = Vec::with_capacity(10);
231        header.push(byte0);
232        match payload_len {
233            0..=125 => header.push(payload_len as u8),
234            126..=65535 => {
235                header.push(126u8);
236                header.extend_from_slice(&(payload_len as u16).to_be_bytes());
237            }
238            _ => {
239                header.push(127u8);
240                header.extend_from_slice(&(payload_len as u64).to_be_bytes());
241            }
242        }
243
244        stream.write_all(&header).map_err(|e| format!("write error: {}", e))?;
245        if !payload.is_empty() {
246            stream.write_all(&payload).map_err(|e| format!("write error: {}", e))?;
247        }
248        stream.flush().map_err(|e| format!("flush error: {}", e))?;
249        Ok(())
250    }
251
252    /// Convenience: send a text message.
253    pub fn send_text(stream: &mut impl Write, text: &str) -> Result<(), String> {
254        Self::write_frame(stream, Frame::Text(text.to_string()))
255    }
256
257    /// Convenience: send a close frame with a status code and reason.
258    pub fn send_close(stream: &mut impl Write, code: u16, reason: &str) -> Result<(), String> {
259        Self::write_frame(stream, Frame::Close(Some(code), reason.to_string()))
260    }
261
262    /// Convenience: reply to a ping with a pong carrying the same payload.
263    pub fn send_pong(stream: &mut impl Write, payload: Vec<u8>) -> Result<(), String> {
264        Self::write_frame(stream, Frame::Pong(payload))
265    }
266}
267
268// ── Internal utilities ─────────────────────────────────────────────────────────
269
270fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<(), String> {
271    r.read_exact(buf).map_err(|e| format!("read error: {}", e))
272}
273
/// SHA-1 digest of `data` (FIPS 180-4), returned as 20 big-endian bytes.
///
/// Used only to derive the WebSocket handshake accept key; SHA-1 is not
/// collision-resistant and must not be used for security purposes.
pub(crate) fn sha1(data: &[u8]) -> [u8; 20] {
    let mut state: [u32; 5] = [0x6745_2301, 0xEFCD_AB89, 0x98BA_DCFE, 0x1032_5476, 0xC3D2_E1F0];

    // Padding: a single 0x80 byte, then zeros until the length is 56 mod 64,
    // then the original message length in bits as a big-endian u64.
    let bit_len = (data.len() as u64) * 8;
    let zero_count = (119 - data.len() % 64) % 64;
    let mut padded = Vec::with_capacity(data.len() + 1 + zero_count + 8);
    padded.extend_from_slice(data);
    padded.push(0x80);
    padded.resize(padded.len() + zero_count, 0x00);
    padded.extend_from_slice(&bit_len.to_be_bytes());

    for block in padded.chunks_exact(64) {
        // Message schedule: 16 words from the block, 64 derived words.
        let mut schedule = [0u32; 80];
        for (word, bytes) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..80 {
            schedule[t] =
                (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16]).rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &w) in schedule.iter().enumerate() {
            let (f, k) = if t < 20 {
                ((b & c) | (!b & d), 0x5A82_7999u32)
            } else if t < 40 {
                (b ^ c ^ d, 0x6ED9_EBA1u32)
            } else if t < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1B_BCDCu32)
            } else {
                (b ^ c ^ d, 0xCA62_C1D6u32)
            };
            let next = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(w);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = next;
        }

        for (word, delta) in state.iter_mut().zip([a, b, c, d, e].iter()) {
            *word = word.wrapping_add(*delta);
        }
    }

    let mut digest = [0u8; 20];
    for (out, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        out.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
325
/// Encode `data` as standard base64 (RFC 4648 §4 alphabet, `=` padding).
pub(crate) fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Maps the low six bits of `v` to an alphabet character.
    let sextet = |v: u32| char::from(ALPHABET[(v & 0x3F) as usize]);

    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into the top 24 bits; missing bytes stay zero.
        let n = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));
        encoded.push(sextet(n >> 18));
        encoded.push(sextet(n >> 12));
        match group.len() {
            1 => encoded.push_str("=="),
            2 => {
                encoded.push(sextet(n >> 6));
                encoded.push('=');
            }
            _ => {
                encoded.push(sextet(n >> 6));
                encoded.push(sextet(n));
            }
        }
    }
    encoded
}