rust_web_server/websocket/
mod.rs1#[cfg(test)]
42mod tests;
43
44use std::io::{Read, Write};
45
46use crate::header::Header;
47use crate::http::VERSION;
48use crate::request::Request;
49use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
50
/// A single WebSocket frame, decoded or ready to be encoded (RFC 6455 §5).
///
/// RFC 6455 restricts control frames (`Ping`, `Pong`, `Close`) to at most
/// 125 payload bytes and forbids fragmenting them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Frame {
    /// Text data frame (opcode `0x1`); the payload is held as a UTF-8 `String`.
    Text(String),
    /// Binary data frame (opcode `0x2`).
    Binary(Vec<u8>),
    /// Ping control frame (opcode `0x9`); RFC 6455 expects the peer to answer
    /// with a `Pong` carrying the same payload.
    Ping(Vec<u8>),
    /// Pong control frame (opcode `0xA`).
    Pong(Vec<u8>),
    /// Close control frame (opcode `0x8`): optional status code plus a reason
    /// string (empty when the frame carried none).
    Close(Option<u16>, String),
    /// Continuation fragment (opcode `0x0`) of a fragmented message; `fin`
    /// is set on the final fragment.
    Continuation {
        fin: bool,
        data: Vec<u8>,
    },
}
67
68pub struct WebSocket;
70
71impl WebSocket {
72 const MAGIC: &'static str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
73
74 pub fn is_upgrade_request(request: &Request) -> bool {
79 let upgrade = request.get_header("Upgrade".to_string());
80 let connection = request.get_header("Connection".to_string());
81 let key = request.get_header("Sec-WebSocket-Key".to_string());
82
83 upgrade.map(|h| h.value.to_lowercase() == "websocket").unwrap_or(false)
84 && connection.map(|h| h.value.to_lowercase().contains("upgrade")).unwrap_or(false)
85 && key.is_some()
86 }
87
88 pub fn handshake_response(request: &Request) -> Result<Response, String> {
95 let key_header = request.get_header("Sec-WebSocket-Key".to_string())
96 .ok_or_else(|| "missing Sec-WebSocket-Key header".to_string())?;
97 let accept = Self::accept_key(&key_header.value);
98
99 let mut response = Response {
100 http_version: VERSION.http_1_1.to_string(),
101 status_code: *STATUS_CODE_REASON_PHRASE.n101_switching_protocols.status_code,
102 reason_phrase: STATUS_CODE_REASON_PHRASE.n101_switching_protocols.reason_phrase.to_string(),
103 headers: vec![
104 Header { name: "Upgrade".to_string(), value: "websocket".to_string() },
105 Header { name: "Connection".to_string(), value: "Upgrade".to_string() },
106 Header { name: "Sec-WebSocket-Accept".to_string(), value: accept },
107 ],
108 content_range_list: vec![],
109 stream_file: None,
110 };
111
112 if let Some(proto) = request.get_header("Sec-WebSocket-Protocol".to_string()) {
113 response.headers.push(Header {
114 name: "Sec-WebSocket-Protocol".to_string(),
115 value: proto.value.split(',').next().unwrap_or("").trim().to_string(),
116 });
117 }
118
119 Ok(response)
120 }
121
122 pub fn accept_key(client_key: &str) -> String {
125 let mut data = client_key.as_bytes().to_vec();
126 data.extend_from_slice(Self::MAGIC.as_bytes());
127 let hash = sha1(&data);
128 base64_encode(&hash)
129 }
130
131 pub fn read_frame(stream: &mut impl Read) -> Result<Frame, String> {
138 let mut header = [0u8; 2];
139 read_exact(stream, &mut header)?;
140
141 let fin = (header[0] & 0x80) != 0;
142 let opcode = header[0] & 0x0F;
143 let masked = (header[1] & 0x80) != 0;
144 let payload_len_byte = (header[1] & 0x7F) as usize;
145
146 let payload_len: usize = match payload_len_byte {
147 126 => {
148 let mut ext = [0u8; 2];
149 read_exact(stream, &mut ext)?;
150 u16::from_be_bytes(ext) as usize
151 }
152 127 => {
153 let mut ext = [0u8; 8];
154 read_exact(stream, &mut ext)?;
155 u64::from_be_bytes(ext) as usize
156 }
157 n => n,
158 };
159
160 let mask_key = if masked {
161 let mut mk = [0u8; 4];
162 read_exact(stream, &mut mk)?;
163 Some(mk)
164 } else {
165 None
166 };
167
168 let mut payload = vec![0u8; payload_len];
169 if payload_len > 0 {
170 read_exact(stream, &mut payload)?;
171 }
172
173 if let Some(key) = mask_key {
174 for (i, byte) in payload.iter_mut().enumerate() {
175 *byte ^= key[i % 4];
176 }
177 }
178
179 let frame = match opcode {
180 0x0 => Frame::Continuation { fin, data: payload },
181 0x1 => {
182 let text = String::from_utf8(payload)
183 .map_err(|_| "text frame contains invalid UTF-8".to_string())?;
184 Frame::Text(text)
185 }
186 0x2 => Frame::Binary(payload),
187 0x8 => {
188 let code = if payload.len() >= 2 {
189 Some(u16::from_be_bytes([payload[0], payload[1]]))
190 } else {
191 None
192 };
193 let reason = if payload.len() > 2 {
194 String::from_utf8_lossy(&payload[2..]).into_owned()
195 } else {
196 String::new()
197 };
198 Frame::Close(code, reason)
199 }
200 0x9 => Frame::Ping(payload),
201 0xA => Frame::Pong(payload),
202 n => return Err(format!("unknown opcode: 0x{:X}", n)),
203 };
204
205 Ok(frame)
206 }
207
208 pub fn write_frame(stream: &mut impl Write, frame: Frame) -> Result<(), String> {
210 let (opcode, payload, fin) = match frame {
211 Frame::Text(s) => (0x1u8, s.into_bytes(), true),
212 Frame::Binary(b) => (0x2, b, true),
213 Frame::Ping(b) => (0x9, b, true),
214 Frame::Pong(b) => (0xA, b, true),
215 Frame::Close(code, reason) => {
216 let mut payload = Vec::new();
217 if let Some(c) = code {
218 payload.extend_from_slice(&c.to_be_bytes());
219 payload.extend_from_slice(reason.as_bytes());
220 }
221 (0x8, payload, true)
222 }
223 Frame::Continuation { fin, data } => (0x0, data, fin),
224 };
225
226 let fin_bit: u8 = if fin { 0x80 } else { 0x00 };
227 let byte0 = fin_bit | opcode;
228
229 let payload_len = payload.len();
230 let mut header = Vec::with_capacity(10);
231 header.push(byte0);
232 match payload_len {
233 0..=125 => header.push(payload_len as u8),
234 126..=65535 => {
235 header.push(126u8);
236 header.extend_from_slice(&(payload_len as u16).to_be_bytes());
237 }
238 _ => {
239 header.push(127u8);
240 header.extend_from_slice(&(payload_len as u64).to_be_bytes());
241 }
242 }
243
244 stream.write_all(&header).map_err(|e| format!("write error: {}", e))?;
245 if !payload.is_empty() {
246 stream.write_all(&payload).map_err(|e| format!("write error: {}", e))?;
247 }
248 stream.flush().map_err(|e| format!("flush error: {}", e))?;
249 Ok(())
250 }
251
252 pub fn send_text(stream: &mut impl Write, text: &str) -> Result<(), String> {
254 Self::write_frame(stream, Frame::Text(text.to_string()))
255 }
256
257 pub fn send_close(stream: &mut impl Write, code: u16, reason: &str) -> Result<(), String> {
259 Self::write_frame(stream, Frame::Close(Some(code), reason.to_string()))
260 }
261
262 pub fn send_pong(stream: &mut impl Write, payload: Vec<u8>) -> Result<(), String> {
264 Self::write_frame(stream, Frame::Pong(payload))
265 }
266}
267
/// Fills `buf` completely from `r`, converting any I/O failure (including a
/// premature end of stream) into a `"read error: …"` string.
fn read_exact<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<(), String> {
    match r.read_exact(buf) {
        Ok(()) => Ok(()),
        Err(io_err) => Err(format!("read error: {}", io_err)),
    }
}
273
/// Computes the SHA-1 digest of `data` (20 bytes, big-endian state words).
///
/// Minimal in-crate implementation. In this module it is used to derive the
/// WebSocket `Sec-WebSocket-Accept` value.
pub(crate) fn sha1(data: &[u8]) -> [u8; 20] {
    const ROUND_CONSTANTS: [u32; 4] = [0x5A827999, 0x6ED9EBA1, 0x8F1BBCDC, 0xCA62C1D6];
    let mut state: [u32; 5] = [0x67452301, 0xEFCDAB89, 0x98BADCFE, 0x10325476, 0xC3D2E1F0];

    // Padding: 0x80, then zeros until length ≡ 56 (mod 64), then the
    // original bit length as a big-endian u64.
    let mut message = Vec::with_capacity(data.len() + 72);
    message.extend_from_slice(data);
    message.push(0x80);
    let zero_count = (119 - data.len() % 64) % 64;
    message.resize(message.len() + zero_count, 0);
    message.extend_from_slice(&((data.len() as u64) * 8).to_be_bytes());

    for block in message.chunks_exact(64) {
        // Message schedule: 16 words from the block, 64 derived words.
        let mut schedule = [0u32; 80];
        for (slot, word) in schedule.iter_mut().zip(block.chunks_exact(4)) {
            *slot = u32::from_be_bytes([word[0], word[1], word[2], word[3]]);
        }
        for t in 16..80 {
            schedule[t] = (schedule[t - 3] ^ schedule[t - 8] ^ schedule[t - 14] ^ schedule[t - 16])
                .rotate_left(1);
        }

        let [mut a, mut b, mut c, mut d, mut e] = state;
        for (t, &wt) in schedule.iter().enumerate() {
            let round = t / 20;
            let f = match round {
                0 => (b & c) | (!b & d),
                2 => (b & c) | (b & d) | (c & d),
                _ => b ^ c ^ d,
            };
            let mixed = a
                .rotate_left(5)
                .wrapping_add(f)
                .wrapping_add(e)
                .wrapping_add(ROUND_CONSTANTS[round])
                .wrapping_add(wt);
            e = d;
            d = c;
            c = b.rotate_left(30);
            b = a;
            a = mixed;
        }

        for (word, delta) in state.iter_mut().zip([a, b, c, d, e].iter()) {
            *word = word.wrapping_add(*delta);
        }
    }

    let mut digest = [0u8; 20];
    for (dst, word) in digest.chunks_exact_mut(4).zip(state.iter()) {
        dst.copy_from_slice(&word.to_be_bytes());
    }
    digest
}
325
/// Encodes `data` as standard base64 (RFC 4648 alphabet, `=` padding).
pub(crate) fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    // Picks the 6-bit group of `bits` starting at `shift` and maps it to a symbol.
    let sextet = |bits: u32, shift: u32| ALPHABET[((bits >> shift) & 0x3F) as usize] as char;

    let mut encoded = String::with_capacity((data.len() + 2) / 3 * 4);
    for group in data.chunks(3) {
        // Pack up to three bytes into the top 24 bits; missing bytes stay zero.
        let bits = group
            .iter()
            .enumerate()
            .fold(0u32, |acc, (i, &byte)| acc | (u32::from(byte) << (16 - 8 * i)));

        encoded.push(sextet(bits, 18));
        encoded.push(sextet(bits, 12));
        match group.len() {
            1 => encoded.push_str("=="),
            2 => {
                encoded.push(sextet(bits, 6));
                encoded.push('=');
            }
            _ => {
                encoded.push(sextet(bits, 6));
                encoded.push(sextet(bits, 0));
            }
        }
    }
    encoded
}