1use super::{
2 BodyContent, MultiMap, Request, Response, ResponseChunkMeta, ResponseRangeMeta, WsRouter,
3 WsRouterValue,
4};
5use super::{BodyType, ConnectionData};
6use sha1::{Digest, Sha1};
7use std::cell::RefCell;
8use std::collections::BTreeMap;
9use std::collections::HashMap;
10use std::io::prelude::*;
11use std::io::{BufReader, BufWriter};
12use std::net::TcpStream;
13use std::ops::Deref;
14use std::rc::Rc;
15use std::sync::mpsc;
16use std::sync::{Arc, Mutex};
17use std::thread;
18
/// Looks up header `k` in `header`, ignoring ASCII case in the header name.
///
/// HTTP field names are ASCII tokens, so an ASCII case-insensitive comparison is
/// enough and avoids allocating lowercase copies of every key on each lookup.
///
/// Returns `(true, Some(value))` for the first matching key in the map's iteration
/// order, or `(false, None)` when no key matches. If several keys differ only in
/// case, which one wins is unspecified (it depends on `HashMap` ordering).
pub(crate) fn exist_pair<'a>(k: &str, header: &'a HashMap<&str, &str>) -> (bool, Option<&'a str>) {
    match header
        .iter()
        .find(|(name, _)| name.eq_ignore_ascii_case(k))
    {
        Some((_, &value)) => (true, Some(value)),
        None => (false, None),
    }
}
36
37pub(crate) fn is_websocket_upgrade(
38 method: &str,
39 header: &HashMap<&str, &str>,
40) -> (bool, String, String) {
41 if method.to_lowercase() == "get" {
42 let upgrade = exist_pair("Upgrade", header);
43 let connection = exist_pair("Connection", header);
44 let sec_websocket_key = exist_pair("Sec-WebSocket-Key", header);
45 let version = exist_pair("Sec-WebSocket-Version", &header);
46 if upgrade.0 == true
47 && connection.0 == true
48 && sec_websocket_key.0 == true
49 && version.0 == true
50 {
51 if "websocket" == upgrade.1.unwrap().to_lowercase()
52 && "upgrade" == connection.1.unwrap().to_lowercase()
53 {
54 return (
55 true,
56 version.1.unwrap().trim().to_string(),
57 sec_websocket_key.1.unwrap().trim().to_string(),
58 );
59 }
60 }
61 }
62 (false, String::new(), String::new())
63}
64
/// Event delivered to a WebSocket router.
///
/// `Clone`, `PartialEq` and `Eq` are derived in addition to `Debug`, so events can be
/// duplicated and compared (for example in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum WsMessage {
    /// The connection has just been upgraded. Dispatched once, before any message.
    Open,
    /// A complete, unmasked data message: its payload and the opcode of its first
    /// frame (1 = text, 2 = binary; the same opcodes `write_string`/`write_binary` send).
    Message(Vec<u8>, u8),
    /// Sent when the connection's reader stops (e.g. on a close frame, an unmasked
    /// frame, or an I/O error). The dispatcher stops after delivering it.
    Close,
}
71
/// Handle for writing frames to one upgraded WebSocket connection.
///
/// Every clone shares the same socket and the same write mutex, so frames written
/// through different handles (and threads) are serialized rather than interleaved.
pub struct Websocket {
    /// Socket shared with the connection's reader thread and with every other handle.
    conn: Arc<TcpStream>,
    /// Serializes `write` calls across clones; it guards no data of its own.
    write_mutex: Arc<Mutex<()>>,
    /// Maximum payload bytes per frame when `write` splits a message longer than
    /// 125 bytes (configured from `server_config.ws_frame_size`).
    fragment_size: usize,
}
77impl Websocket {
78 pub fn clone(&self) -> Self {
80 Websocket {
81 conn: Arc::clone(&self.conn),
82 write_mutex: Arc::clone(&self.write_mutex),
83 fragment_size: self.fragment_size,
84 }
85 }
86 pub fn write(&self, data: Vec<u8>, opcode: u8) {
89 let lock_guard = self.write_mutex.lock().unwrap();
90 let len = data.len();
92 let mut writer = BufWriter::new(self.conn.as_ref());
94 if len <= 125 {
95 let first_byte = 0b10000000u8 | opcode;
96 let second_byte = len as u8;
97 let mut buffs = Vec::new();
98 buffs.push(first_byte);
99 buffs.push(second_byte);
100 buffs.extend_from_slice(&data);
101 match writer.write_all(&buffs) {
102 Ok(_) => {
103 println!("write ok!");
104 }
105 Err(_) => {
106 let _ = self.conn.shutdown(std::net::Shutdown::Both);
107 return;
108 }
109 }
110 } else if len >= 126 {
111 let fragment_size = self.fragment_size;
112 let mut fragment_count = {
113 let mut count = len / fragment_size;
114 if (len % fragment_size) != 0 {
115 count += 1;
116 }
117 count
118 };
119 let mut start_pos = 0usize;
121 while fragment_count > 0 {
122 let mut buffs = Vec::new();
123 fragment_count -= 1;
124 let mut end_pos = start_pos + fragment_size;
125 if end_pos >= len {
126 end_pos = len;
127 }
128 let fragment_data = &data[start_pos..end_pos];
129 let first_byte = {
130 if start_pos == 0 {
131 if fragment_count > 0 {
133 0b00000000u8 | opcode
135 } else {
136 0b10000000u8 | opcode }
138 } else {
139 if fragment_count > 0 {
140 0
142 } else {
143 0b10000000u8
145 }
146 }
147 };
148 buffs.push(first_byte);
150 let actual_write_size = end_pos - start_pos;
151 if actual_write_size >= 126 {
152 if actual_write_size <= u16::MAX as usize {
153 let second_byte = 126u8;
154 let payload = (actual_write_size as u16).to_be_bytes();
155 buffs.push(second_byte);
156 buffs.extend_from_slice(&payload);
157 } else if actual_write_size as u64 <= u64::MAX {
158 let second_byte = 127u8;
159 let payload = (actual_write_size as u64).to_be_bytes();
160 buffs.push(second_byte);
161 buffs.extend_from_slice(&payload);
162 }
163 } else {
164 let second_byte = actual_write_size as u8;
165 buffs.push(second_byte);
166 }
167 buffs.extend_from_slice(fragment_data);
169 match writer.write_all(&buffs) {
170 Ok(_) => {
171 start_pos = end_pos;
172 continue;
174 }
175 Err(_) => {
176 let _ = self.conn.shutdown(std::net::Shutdown::Both);
177 return;
178 }
179 }
180 }
181 }
182 match writer.flush() {
183 Ok(_) => {}
184 Err(_) => {
185 let _ = self.conn.shutdown(std::net::Shutdown::Both);
186 return;
187 }
188 }
189 drop(lock_guard);
191 }
192
193 pub fn write_string(&self, s: &str) {
196 let mut vec = Vec::new();
197 vec.extend_from_slice(s.as_bytes());
198 self.write(vec, 1);
199 }
200 pub fn write_binary(&self, data: Vec<u8>) {
203 self.write(data, 2);
204 }
205}
206pub struct WebsocketEvent {
207 ws: Websocket,
208 pub message: WsMessage,
209}
210
211impl WebsocketEvent {
212 pub fn get_conn(&self) -> &Websocket {
213 &self.ws
214 }
215}
216
217impl Deref for WebsocketEvent {
218 type Target = Websocket;
219 fn deref(&self) -> &Self::Target {
220 &self.ws
221 }
222}
223
224pub(crate) fn construct_http_event_for_websocket(
225 stream: &mut TcpStream,
226 method: &str,
227 url: &str,
228 http_version: &str,
229 head_map: &HashMap<&str, &str>,
230 connection_config: Arc<ConnectionData>,
231) -> (bool, Option<Arc<dyn WsRouter + Send + Sync>>) {
232 let conn = Rc::new(RefCell::new(stream));
233 let request = Request {
234 header_pair: head_map.clone(),
235 url,
236 method,
237 version: http_version,
238 body: BodyContent::None,
239 conn_: Rc::clone(&conn),
240 secret_key: Arc::clone(&connection_config.server_config.secret_key),
241 ctx: RefCell::new(BTreeMap::new()),
242 };
243 let mut response = Response {
244 header_pair: MultiMap::new(),
245 version: http_version,
246 method,
247 http_state: 200,
249 body: BodyType::None,
250 chunked: ResponseChunkMeta::new(connection_config.server_config.chunk_size),
251 conn_: Rc::clone(&conn),
252 range: ResponseRangeMeta::None,
253 request_header: head_map.clone(),
254 charset: None,
255 };
256
257 let ws_middleware_result = invoke_ws_middlewares(&connection_config, &request, &mut response);
258
259 if !ws_middleware_result.0 {
260 let mut stream = conn.borrow_mut();
261 if !response.chunked.enable {
262 match super::write_once(*stream, &mut response) {
263 Ok(_) => {}
264 Err(e) => {
265 if connection_config.server_config.open_log {
266 let now = super::get_current_date();
267 println!(
268 "[{}] >>> error in write_once in websocket.rs; type: [{}], line: [{}], msg: [{}]",
269 now,
270 e.kind().to_string(),
271 line!(),
272 ToString::to_string(&e)
273 );
274 }
275 }
276 }
277 } else {
278 match super::write_chunk(*stream, &mut response) {
280 Ok(_) => {}
281 Err(e) => {
282 if connection_config.server_config.open_log {
283 let now = super::get_current_date();
284 println!(
285 "[{}] >>> error in write_chunk in websocket.rs; type: [{}], line: [{}], msg: [{}]",
286 now,
287 e.kind().to_string(),
288 line!(),
289 ToString::to_string(&e)
290 );
291 }
292 }
293 }
294 }
295 (false, None)
296 } else {
297 (true, ws_middleware_result.1)
298 }
299}
300
301fn invoke_ws_middlewares_help(result: &WsRouterValue, req: &Request, res: &mut Response) -> bool {
302 match &result.0 {
303 Some(middlewares) => {
304 for middleware in middlewares {
306 if !middleware.call(req, res) {
307 return false;
308 }
309 }
310 return true;
311 }
312 None => true,
313 }
314}
315
316fn invoke_ws_middlewares(
317 connection_data: &ConnectionData,
318 req: &Request,
319 res: &mut Response,
320) -> (bool, Option<Arc<dyn WsRouter + Send + Sync>>) {
321 let url = req.url.split_once("?");
322 let url = match url {
323 Some((url, _)) => url,
324 None => req.url,
325 };
326 let ws_router_map = &connection_data.ws_router_map;
327 let not_found = connection_data
328 .router_map
329 .get("NEVER_FOUND_FOR_ALL")
330 .unwrap();
331 match ws_router_map.get(url) {
332 Some(result) => {
333 return (
334 invoke_ws_middlewares_help(result, req, res),
335 Some(Arc::clone(&result.1)),
336 );
337 }
338 None => {
339 not_found.1.call(req, res);
340 return (false, None);
341 }
342 }
343}
344
345pub(crate) fn handle_websocket_connection(
346 mut stream: TcpStream,
347 header: HashMap<&str, &str>,
348 ws_version: String,
349 secret_key: String,
350 ws_router: Arc<dyn WsRouter + Send + Sync>,
351 connection_data: Arc<ConnectionData>,
352) {
353 let mut request_meta_header: HashMap<String, String> = HashMap::new();
355 for (k, v) in header {
356 request_meta_header.insert(k.to_string(), v.to_string());
357 }
358 let key = format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11", secret_key);
359 let mut hasher = Sha1::new();
360 hasher.update(key.as_bytes());
361 let hash = hasher.finalize();
362 let result: &[u8] = hash.as_ref();
363 let r = base64::encode(result);
364 let response = format!("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version:{}\r\nSec-WebSocket-Accept:{}\r\n\r\n",ws_version,r);
367 match stream.write(response.as_bytes()) {
368 Ok(_) => {
369 let _ = stream.set_read_timeout(Some(std::time::Duration::from_millis(
370 connection_data.server_config.ws_read_timeout as u64,
371 )));
372 let _ = stream.set_write_timeout(Some(std::time::Duration::from_millis(
373 connection_data.server_config.ws_write_timeout as u64,
374 )));
375 switch_to_websocket(Arc::new(stream), ws_router, connection_data);
376 }
377 Err(_) => {}
378 }
379}
380
381fn read_ws_data(
382 reader: &mut BufReader<&TcpStream>,
383 data_len: usize,
384 data_buffs: &mut Vec<u8>,
385 mask_key: [u8; 4],
386) -> bool {
387 let current_len = data_buffs.len();
388 data_buffs.resize(current_len + data_len, b'\0');
389 match reader.read_exact(&mut data_buffs[current_len..]) {
391 Ok(_) => {
392 decode_ws_data(&mut data_buffs[current_len..], mask_key);
393 true
394 }
395 Err(_) => false,
396 }
397}
398
/// Applies the 4-byte WebSocket masking key to `raw_data` in place (RFC 6455 §5.3).
///
/// Byte `i` is XORed with `mask_key[i % 4]`. XOR is its own inverse, so the same call
/// both masks and unmasks.
fn decode_ws_data(raw_data: &mut [u8], mask_key: [u8; 4]) {
    for (byte, key) in raw_data.iter_mut().zip(mask_key.iter().cycle()) {
        *byte ^= *key;
    }
}
409
410fn switch_to_websocket(
411 stream: Arc<TcpStream>,
412 ws_router: Arc<dyn WsRouter + Send + Sync>,
413 connection_data: Arc<ConnectionData>,
414) {
415 let fragment_size: usize = connection_data.server_config.ws_frame_size;
417 let server_log_open = connection_data.server_config.open_log;
418 let ws_handler = Websocket {
419 conn: Arc::clone(&stream),
420 write_mutex: Arc::new(Mutex::new(())),
421 fragment_size: fragment_size,
422 };
423 let event = WebsocketEvent {
424 ws: ws_handler.clone(),
425 message: WsMessage::Open,
426 };
427 ws_router.call(event);
428 let read_stream = Arc::clone(&stream);
429 let (tx, rx) = mpsc::channel::<WebsocketEvent>();
430 let _write_thread = thread::spawn(move || loop {
431 match rx.recv() {
432 Ok(event) => {
433 let is_close = if let WsMessage::Close = event.message {
434 true
435 } else {
436 false
437 };
438 ws_router.call(event);
439 if is_close {
440 break;
441 }
442 }
443 Err(e) => {
444 if server_log_open {
445 let now = super::get_current_date();
446 println!(
447 "[{}] >>> error in websocket write thread in websocket.rs; type: [RecvError], line: [{}], msg: [{}]",
448 now,
449 line!(),
450 ToString::to_string(&e)
451 );
452 }
453 break;
454 }
455 }
456 });
457 let _read_thread = thread::spawn(move || {
458 let inner = read_stream.as_ref();
459 let sender = tx;
460 let mut reader = BufReader::new(inner);
461 'Restart: loop {
462 let mut data_buffs = Vec::new();
463 let mut opcode = 0;
464 let mut first_entry = true;
465 'ReadFrame: loop {
466 let mut buff = [b'\0'; 2];
467 match reader.read_exact(&mut buff) {
468 Ok(_) => {
469 let first_byte = buff[0];
470 let fin = (first_byte >> 7) & 1; if first_entry {
476 opcode = first_byte & 0b00001111u8; first_entry = false;
478 }
479 let second_byte = buff[1];
480 let mask = (second_byte >> 7) & 1;
481 if mask != 1 {
482 break 'Restart;
483 }
484 if opcode == 8 {
485 break 'Restart;
488 }
489 if opcode == 9 {
490 ws_handler.write(Vec::new(), 10);
493 continue 'Restart;
494 }
495 if opcode == 10 {
496 continue 'Restart;
499 }
500 let payload = second_byte & 0b01111111u8;
502 let data_len = if payload <= 125 {
503 payload as usize
505 } else if payload == 126 {
506 let mut two_bytes = [b'\0'; 2];
508 match reader.read_exact(&mut two_bytes) {
509 Ok(_) => {
510 let endian = [
511 b'\0',
512 b'\0',
513 b'\0',
514 b'\0',
515 b'\0',
516 b'\0',
517 two_bytes[0],
518 two_bytes[1],
519 ];
520 usize::from_be_bytes(endian)
521 }
522 Err(e) => {
523 if server_log_open {
524 let now = super::get_current_date();
525 println!(
526 "[{}] >>> error in websocket read thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
527 now,
528 e.kind().to_string(),
529 line!(),
530 ToString::to_string(&e)
531 );
532 }
533 break 'Restart;
534 }
535 }
536 } else if payload == 127 {
537 let mut eight_bytes = [b'\0'; 8];
539 match reader.read_exact(&mut eight_bytes) {
540 Ok(_) => usize::from_be_bytes(eight_bytes),
541 Err(e) => {
542 if server_log_open {
543 let now = super::get_current_date();
544 println!(
545 "[{}] >>> error in websocket read thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
546 now,
547 e.kind().to_string(),
548 line!(),
549 ToString::to_string(&e)
550 );
551 }
552 break 'Restart;
553 }
554 }
555 } else {
556 break 'Restart;
558 }; let mut mask_key_buffs = [b'\0'; 4];
562 match reader.read_exact(&mut mask_key_buffs) {
563 Ok(_) => {
564 let r = read_ws_data(
566 &mut reader,
567 data_len,
568 &mut data_buffs,
569 mask_key_buffs,
570 );
571 if r == false {
573 break 'Restart;
574 } else {
575 if fin == 1 {
576 let event = WebsocketEvent {
579 ws: ws_handler.clone(),
580 message: WsMessage::Message(data_buffs, opcode),
581 };
582 match sender.send(event) {
584 Ok(_) => {}
585 Err(e) => {
586 if server_log_open {
587 let now = super::get_current_date();
588 println!(
589 "[{}] >>> error in websocket read thread in websocket.rs; type: [SendError], line: [{}], msg: [{}]",
590 now,
591 line!(),
592 ToString::to_string(&e)
593 );
594 }
595 break 'Restart; }
597 }
598 continue 'Restart; } else {
600 continue 'ReadFrame; }
603 }
604 }
605 Err(e) => {
606 if server_log_open {
607 let now = super::get_current_date();
608 println!(
609 "[{}] >>> error in websocket read thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
610 now,
611 e.kind().to_string(),
612 line!(),
613 ToString::to_string(&e)
614 );
615 }
616 break 'Restart;
617 }
618 }
619 }
620 Err(e) => {
621 if server_log_open {
622 let now = super::get_current_date();
623 println!(
624 "[{}] >>> error in websocket write thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
625 now,
626 e.kind().to_string(),
627 line!(),
628 ToString::to_string(&e)
629 );
630 }
631 break 'Restart;
632 }
633 };
634 }
635 }
636 let event = WebsocketEvent {
637 ws: ws_handler.clone(),
638 message: WsMessage::Close,
639 };
640 match sender.send(event) {
641 Ok(_) => {}
642 Err(e) => {
643 if server_log_open {
644 let now = super::get_current_date();
645 println!(
646 "[{}] >>> error in websocket read thread in websocket.rs; type: [SendError], line: [{}], msg: [{}]",
647 now,
648 line!(),
649 ToString::to_string(&e)
650 );
651 }
652 }
653 }
654 });
655}