rust_web_server/websocket/
mod.rs1#[cfg(test)]
42mod tests;
43
44use std::io::{Read, Write};
45
46use crate::header::Header;
47use crate::http::VERSION;
48use crate::request::Request;
49use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
50
/// A single WebSocket frame, classified by its opcode.
///
/// `WebSocket::read_frame` produces these values and `WebSocket::write_frame`
/// consumes them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Text frame (opcode `0x1`); the payload is valid UTF-8.
    Text(String),
    /// Binary frame (opcode `0x2`).
    Binary(Vec<u8>),
    /// Ping control frame (opcode `0x9`) with its application data.
    Ping(Vec<u8>),
    /// Pong control frame (opcode `0xA`) with its application data.
    Pong(Vec<u8>),
    /// Close control frame (opcode `0x8`): optional status code followed by a
    /// textual reason (empty when none was supplied).
    Close(Option<u16>, String),
    /// Continuation frame (opcode `0x0`); `fin` mirrors the frame's FIN bit.
    Continuation { fin: bool, data: Vec<u8> },
}
67
/// Stateless namespace for the WebSocket helpers: handshake construction and
/// frame encoding/decoding are exposed as associated functions on this type.
#[derive(Debug, Clone, Copy, Default)]
pub struct WebSocket;
70
71impl WebSocket {
/// Fixed GUID from RFC 6455 that is appended to the client's
/// `Sec-WebSocket-Key` before hashing to derive `Sec-WebSocket-Accept`.
const MAGIC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
73
74 pub fn is_upgrade_request(request: &Request) -> bool {
79 let upgrade = request.get_header("Upgrade".to_string());
80 let connection = request.get_header("Connection".to_string());
81 let key = request.get_header("Sec-WebSocket-Key".to_string());
82
83 upgrade.map(|h| h.value.to_lowercase() == "websocket").unwrap_or(false)
84 && connection.map(|h| h.value.to_lowercase().contains("upgrade")).unwrap_or(false)
85 && key.is_some()
86 }
87
88 pub fn handshake_response(request: &Request) -> Result<Response, String> {
95 let key_header = request.get_header("Sec-WebSocket-Key".to_string())
96 .ok_or_else(|| "missing Sec-WebSocket-Key header".to_string())?;
97 let accept = Self::accept_key(&key_header.value);
98
99 let mut response = Response {
100 http_version: VERSION.http_1_1.to_string(),
101 status_code: *STATUS_CODE_REASON_PHRASE.n101_switching_protocols.status_code,
102 reason_phrase: STATUS_CODE_REASON_PHRASE.n101_switching_protocols.reason_phrase.to_string(),
103 headers: vec![
104 Header { name: "Upgrade".to_string(), value: "websocket".to_string() },
105 Header { name: "Connection".to_string(), value: "Upgrade".to_string() },
106 Header { name: "Sec-WebSocket-Accept".to_string(), value: accept },
107 ],
108 content_range_list: vec![],
109 stream_file: None,
110 stream_pipe: None,
111 };
112
113 if let Some(proto) = request.get_header("Sec-WebSocket-Protocol".to_string()) {
114 response.headers.push(Header {
115 name: "Sec-WebSocket-Protocol".to_string(),
116 value: proto.value.split(',').next().unwrap_or("").trim().to_string(),
117 });
118 }
119
120 Ok(response)
121 }
122
123 pub fn accept_key(client_key: &str) -> String {
126 let mut data = client_key.as_bytes().to_vec();
127 data.extend_from_slice(Self::MAGIC.as_bytes());
128 let hash = sha1(&data);
129 base64_encode(&hash)
130 }
131
132 pub fn read_frame(stream: &mut impl Read) -> Result<Frame, String> {
139 let mut header = [0u8; 2];
140 read_exact(stream, &mut header)?;
141
142 let fin = (header[0] & 0x80) != 0;
143 let opcode = header[0] & 0x0F;
144 let masked = (header[1] & 0x80) != 0;
145 let payload_len_byte = (header[1] & 0x7F) as usize;
146
147 let payload_len: usize = match payload_len_byte {
148 126 => {
149 let mut ext = [0u8; 2];
150 read_exact(stream, &mut ext)?;
151 u16::from_be_bytes(ext) as usize
152 }
153 127 => {
154 let mut ext = [0u8; 8];
155 read_exact(stream, &mut ext)?;
156 u64::from_be_bytes(ext) as usize
157 }
158 n => n,
159 };
160
161 let mask_key = if masked {
162 let mut mk = [0u8; 4];
163 read_exact(stream, &mut mk)?;
164 Some(mk)
165 } else {
166 None
167 };
168
169 let mut payload = vec![0u8; payload_len];
170 if payload_len > 0 {
171 read_exact(stream, &mut payload)?;
172 }
173
174 if let Some(key) = mask_key {
175 for (i, byte) in payload.iter_mut().enumerate() {
176 *byte ^= key[i % 4];
177 }
178 }
179
180 let frame = match opcode {
181 0x0 => Frame::Continuation { fin, data: payload },
182 0x1 => {
183 let text = String::from_utf8(payload)
184 .map_err(|_| "text frame contains invalid UTF-8".to_string())?;
185 Frame::Text(text)
186 }
187 0x2 => Frame::Binary(payload),
188 0x8 => {
189 let code = if payload.len() >= 2 {
190 Some(u16::from_be_bytes([payload[0], payload[1]]))
191 } else {
192 None
193 };
194 let reason = if payload.len() > 2 {
195 String::from_utf8_lossy(&payload[2..]).into_owned()
196 } else {
197 String::new()
198 };
199 Frame::Close(code, reason)
200 }
201 0x9 => Frame::Ping(payload),
202 0xA => Frame::Pong(payload),
203 n => return Err(format!("unknown opcode: 0x{:X}", n)),
204 };
205
206 Ok(frame)
207 }
208
209 pub fn write_frame(stream: &mut impl Write, frame: Frame) -> Result<(), String> {
211 let (opcode, payload, fin) = match frame {
212 Frame::Text(s) => (0x1u8, s.into_bytes(), true),
213 Frame::Binary(b) => (0x2, b, true),
214 Frame::Ping(b) => (0x9, b, true),
215 Frame::Pong(b) => (0xA, b, true),
216 Frame::Close(code, reason) => {
217 let mut payload = Vec::new();
218 if let Some(c) = code {
219 payload.extend_from_slice(&c.to_be_bytes());
220 payload.extend_from_slice(reason.as_bytes());
221 }
222 (0x8, payload, true)
223 }
224 Frame::Continuation { fin, data } => (0x0, data, fin),
225 };
226
227 let fin_bit: u8 = if fin { 0x80 } else { 0x00 };
228 let byte0 = fin_bit | opcode;
229
230 let payload_len = payload.len();
231 let mut header = Vec::with_capacity(10);
232 header.push(byte0);
233 match payload_len {
234 0..=125 => header.push(payload_len as u8),
235 126..=65535 => {
236 header.push(126u8);
237 header.extend_from_slice(&(payload_len as u16).to_be_bytes());
238 }
239 _ => {
240 header.push(127u8);
241 header.extend_from_slice(&(payload_len as u64).to_be_bytes());
242 }
243 }
244
245 stream.write_all(&header).map_err(|e| format!("write error: {}", e))?;
246 if !payload.is_empty() {
247 stream.write_all(&payload).map_err(|e| format!("write error: {}", e))?;
248 }
249 stream.flush().map_err(|e| format!("flush error: {}", e))?;
250 Ok(())
251 }
252
253 pub fn send_text(stream: &mut impl Write, text: &str) -> Result<(), String> {
255 Self::write_frame(stream, Frame::Text(text.to_string()))
256 }
257
258 pub fn send_close(stream: &mut impl Write, code: u16, reason: &str) -> Result<(), String> {
260 Self::write_frame(stream, Frame::Close(Some(code), reason.to_string()))
261 }
262
263 pub fn send_pong(stream: &mut impl Write, payload: Vec<u8>) -> Result<(), String> {
265 Self::write_frame(stream, Frame::Pong(payload))
266 }
267}
268
/// Fills `buf` completely from `r`, converting any I/O failure (including a
/// premature end of stream) into a `"read error: …"` message.
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<(), String> {
    match r.read_exact(buf) {
        Ok(()) => Ok(()),
        Err(cause) => Err(format!("read error: {}", cause)),
    }
}
274
/// Computes the SHA-1 digest of `data` and returns the 20-byte result.
///
/// The message is padded with a `0x80` byte, zero bytes up to 56 modulo 64,
/// and the original length in bits as a big-endian `u64`; each 64-byte block
/// is then run through the 80-round compression function.
pub(crate) fn sha1(data: &[u8]) -> [u8; 20] {
    const INITIAL_STATE: [u32; 5] =
        [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0];

    let bit_len = (data.len() as u64) * 8;

    // Build the padded message: data, 0x80, zero fill, 64-bit bit length.
    let mut padded = Vec::with_capacity(data.len() + 72);
    padded.extend_from_slice(data);
    padded.push(0x80);
    let zero_fill = (64 + 56 - padded.len() % 64) % 64;
    padded.resize(padded.len() + zero_fill, 0x00);
    padded.extend_from_slice(&bit_len.to_be_bytes());

    let mut state = INITIAL_STATE;

    for block in padded.chunks(64) {
        // Message schedule: 16 big-endian words expanded to 80.
        let mut schedule = [0u32; 80];
        for (word, bytes) in schedule.iter_mut().zip(block.chunks(4)) {
            *word = u32::from_be_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]);
        }
        for t in 16..80 {
            let mixed = schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16];
            schedule[t] = mixed.rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;

        for (round, &word) in schedule.iter().enumerate() {
            // Round function and constant change every 20 rounds.
            let (f, k): (u32, u32) = if round < 20 {
                ((b & c) | (!b & d), 0x5A827999)
            } else if round < 40 {
                (b ^ c ^ d, 0x6ED9EBA1)
            } else if round < 60 {
                ((b & c) | (b & d) | (c & d), 0x8F1BBCDC)
            } else {
                (b ^ c ^ d, 0xCA62C1D6)
            };

            let next = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(k)
                .wrapping_add(word);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = next;
        }

        for (acc, value) in state.iter_mut().zip([a, b, c, d, e].iter()) {
            *acc = acc.wrapping_add(*value);
        }
    }

    // Serialize the five state words big-endian.
    let mut digest = [0u8; 20];
    for (dst, word) in digest.chunks_mut(4).zip(state.iter()) {
        dst.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
326
/// Encodes `data` as standard Base64 (alphabet `A-Z a-z 0-9 + /`) with `=`
/// padding.
pub(crate) fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

    // Picks the 6-bit group starting at `shift` and maps it to its symbol.
    let symbol = |bits: u32, shift: u32| ALPHABET[((bits >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into the top 24 bits; missing bytes stay zero.
        let bits = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | ((byte as u32) << (16 - 8 * i)));

        encoded.push(symbol(bits, 18));
        encoded.push(symbol(bits, 12));
        encoded.push(if group.len() > 1 { symbol(bits, 6) } else { '=' });
        encoded.push(if group.len() > 2 { symbol(bits, 0) } else { '=' });
    }
    encoded
}