rust_xfinal/http_parser/
websocket.rs

1use super::{
2    BodyContent, MultiMap, Request, Response, ResponseChunkMeta, ResponseRangeMeta, WsRouter,
3    WsRouterValue,
4};
5use super::{BodyType, ConnectionData};
6use sha1::{Digest, Sha1};
7use std::cell::RefCell;
8use std::collections::BTreeMap;
9use std::collections::HashMap;
10use std::io::prelude::*;
11use std::io::{BufReader, BufWriter};
12use std::net::TcpStream;
13use std::ops::Deref;
14use std::rc::Rc;
15use std::sync::mpsc;
16use std::sync::{Arc, Mutex};
17use std::thread;
18
/// Looks up a header by name, ignoring case.
///
/// * `k` - the header name to search for (any casing).
/// * `header` - the parsed request headers.
///
/// Returns `(true, Some(value))` when a header whose name equals `k` case-insensitively
/// exists, and `(false, None)` otherwise. If several keys differ only by case, which one
/// is returned depends on the map's iteration order.
pub(crate) fn exist_pair<'a>(k: &str, header: &'a HashMap<&str, &str>) -> (bool, Option<&'a str>) {
    // Lowercase the needle once instead of once per candidate key.
    let wanted = k.to_lowercase();
    // Walking the entries yields the value directly, so no second lookup or `unwrap` is needed.
    let value: Option<&'a str> = header
        .iter()
        .find(|(name, _)| name.to_lowercase() == wanted)
        .map(|(_, v)| *v);
    (value.is_some(), value)
}
36
37pub(crate) fn is_websocket_upgrade(
38    method: &str,
39    header: &HashMap<&str, &str>,
40) -> (bool, String, String) {
41    if method.to_lowercase() == "get" {
42        let upgrade = exist_pair("Upgrade", header);
43        let connection = exist_pair("Connection", header);
44        let sec_websocket_key = exist_pair("Sec-WebSocket-Key", header);
45        let version = exist_pair("Sec-WebSocket-Version", &header);
46        if upgrade.0 == true
47            && connection.0 == true
48            && sec_websocket_key.0 == true
49            && version.0 == true
50        {
51            if "websocket" == upgrade.1.unwrap().to_lowercase()
52                && "upgrade" == connection.1.unwrap().to_lowercase()
53            {
54                return (
55                    true,
56                    version.1.unwrap().trim().to_string(),
57                    sec_websocket_key.1.unwrap().trim().to_string(),
58                );
59            }
60        }
61    }
62    (false, String::new(), String::new())
63}
64
/// The kinds of event delivered to a websocket route handler.
#[derive(Debug)]
pub enum WsMessage {
    /// Delivered once, right after the connection has been switched to the websocket protocol.
    Open,
    /// A complete message: the unmasked payload (fragments already concatenated) and the
    /// opcode taken from the first frame of the message.
    Message(Vec<u8>, u8),
    /// Delivered once the reading loop has stopped (close frame, protocol violation or I/O error).
    Close,
}
71
/// Handle to one upgraded websocket connection.
///
/// Cloning a handle is cheap: every clone shares the same TCP stream and the same write
/// lock, so frames written through different clones never interleave.
pub struct Websocket {
    /// The underlying TCP stream.
    conn: Arc<TcpStream>,
    /// Serializes writers. It guards no data, only access to `conn`.
    write_mutex: Arc<Mutex<()>>,
    /// Largest payload, in bytes, placed in a single frame when a message is fragmented.
    fragment_size: usize,
}

impl Clone for Websocket {
    /// > Share the websocket connection
    fn clone(&self) -> Self {
        Websocket {
            conn: Arc::clone(&self.conn),
            write_mutex: Arc::clone(&self.write_mutex),
            fragment_size: self.fragment_size,
        }
    }
}

impl Websocket {
    /// FIN bit of the first header byte: set on the final (or only) frame of a message.
    const FIN: u8 = 0b1000_0000;
    /// Largest payload length that fits directly in the 7-bit length field.
    const MAX_SHORT_PAYLOAD: usize = 125;

    /// > Write the data whose type is specified by the second parameter to the peer
    /// >> The method is thread safe
    ///
    /// Payloads of at most 125 bytes are sent as one frame. Larger payloads are split into
    /// frames of at most `fragment_size` bytes (a `fragment_size` of zero disables
    /// splitting). If writing fails, the connection is shut down in both directions.
    pub fn write(&self, data: Vec<u8>, opcode: u8) {
        // The mutex protects no data, so a lock poisoned by a panicking writer is still usable.
        let _guard = self
            .write_mutex
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        let mut writer = BufWriter::new(self.conn.as_ref());
        let outcome = self
            .send_frames(&mut writer, &data, opcode)
            .and_then(|_| writer.flush());
        if outcome.is_err() {
            let _ = self.conn.shutdown(std::net::Shutdown::Both);
        }
    }

    /// Encodes `data` as one or more unmasked frames and hands them to `writer`.
    fn send_frames(
        &self,
        writer: &mut BufWriter<&TcpStream>,
        data: &[u8],
        opcode: u8,
    ) -> std::io::Result<()> {
        let len = data.len();
        if len <= Self::MAX_SHORT_PAYLOAD {
            return writer.write_all(&Self::encode_frame(Self::FIN | opcode, data));
        }
        // A zero fragment size would otherwise divide by zero; treat it as "do not split".
        let fragment_size = if self.fragment_size == 0 {
            len
        } else {
            self.fragment_size
        };
        // `len` is at least 126 here, so the subtraction cannot underflow.
        let last_index = (len - 1) / fragment_size;
        for (index, fragment) in data.chunks(fragment_size).enumerate() {
            // Only the first frame carries the real opcode; the rest are continuation frames.
            let opcode_bits = if index == 0 { opcode } else { 0 };
            let fin_bit = if index == last_index { Self::FIN } else { 0 };
            writer.write_all(&Self::encode_frame(fin_bit | opcode_bits, fragment))?;
        }
        Ok(())
    }

    /// Builds one complete unmasked frame: first byte, payload length field, payload.
    fn encode_frame(first_byte: u8, payload: &[u8]) -> Vec<u8> {
        let len = payload.len();
        // 10 bytes is the largest possible header (2 + 8-byte extended length).
        let mut frame = Vec::with_capacity(len + 10);
        frame.push(first_byte);
        if len <= Self::MAX_SHORT_PAYLOAD {
            frame.push(len as u8);
        } else if len <= u16::MAX as usize {
            // 126 announces a 16-bit big-endian length.
            frame.push(126);
            frame.extend_from_slice(&(len as u16).to_be_bytes());
        } else {
            // 127 announces a 64-bit big-endian length.
            frame.push(127);
            frame.extend_from_slice(&(len as u64).to_be_bytes());
        }
        frame.extend_from_slice(payload);
        frame
    }

    /// > Convenience method to write string data (opcode 1) to the peer
    /// >> The method is thread safe
    pub fn write_string(&self, s: &str) {
        self.write(s.as_bytes().to_vec(), 1);
    }

    /// > Convenience method to write binary data (opcode 2) to the peer
    /// >> The method is thread safe
    pub fn write_binary(&self, data: Vec<u8>) {
        self.write(data, 2);
    }
}
/// One event handed to a websocket route handler: what happened, plus a handle to the
/// connection it happened on.
pub struct WebsocketEvent {
    /// Handle to the connection this event belongs to.
    ws: Websocket,
    /// What happened: open, a complete message, or close.
    pub message: WsMessage,
}
210
impl WebsocketEvent {
    /// Returns the connection handle carried by this event.
    pub fn get_conn(&self) -> &Websocket {
        &self.ws
    }
}
216
/// Dereferences to the carried [`Websocket`], so its methods can be called directly on
/// the event.
impl Deref for WebsocketEvent {
    type Target = Websocket;
    /// Returns the same handle as `get_conn`.
    fn deref(&self) -> &Self::Target {
        &self.ws
    }
}
223
224pub(crate) fn construct_http_event_for_websocket(
225    stream: &mut TcpStream,
226    method: &str,
227    url: &str,
228    http_version: &str,
229    head_map: &HashMap<&str, &str>,
230    connection_config: Arc<ConnectionData>,
231) -> (bool, Option<Arc<dyn WsRouter + Send + Sync>>) {
232    let conn = Rc::new(RefCell::new(stream));
233    let request = Request {
234        header_pair: head_map.clone(),
235        url,
236        method,
237        version: http_version,
238        body: BodyContent::None,
239        conn_: Rc::clone(&conn),
240        secret_key: Arc::clone(&connection_config.server_config.secret_key),
241        ctx: RefCell::new(BTreeMap::new()),
242    };
243    let mut response = Response {
244        header_pair: MultiMap::new(),
245        version: http_version,
246        method,
247        //url,
248        http_state: 200,
249        body: BodyType::None,
250        chunked: ResponseChunkMeta::new(connection_config.server_config.chunk_size),
251        conn_: Rc::clone(&conn),
252        range: ResponseRangeMeta::None,
253        request_header: head_map.clone(),
254        charset: None,
255    };
256
257    let ws_middleware_result = invoke_ws_middlewares(&connection_config, &request, &mut response);
258
259    if !ws_middleware_result.0 {
260        let mut stream = conn.borrow_mut();
261        if !response.chunked.enable {
262            match super::write_once(*stream, &mut response) {
263                Ok(_) => {}
264                Err(e) => {
265                    if connection_config.server_config.open_log {
266                        let now = super::get_current_date();
267                        println!(
268							"[{}] >>> error in write_once in websocket.rs; type: [{}], line: [{}], msg: [{}]",
269							now,
270							e.kind().to_string(),
271							line!(),
272							ToString::to_string(&e)
273						);
274                    }
275                }
276            }
277        } else {
278            // chunked transfer
279            match super::write_chunk(*stream, &mut response) {
280                Ok(_) => {}
281                Err(e) => {
282                    if connection_config.server_config.open_log {
283                        let now = super::get_current_date();
284                        println!(
285							"[{}] >>> error in write_chunk in websocket.rs; type: [{}], line: [{}], msg: [{}]",
286							now,
287							e.kind().to_string(),
288							line!(),
289							ToString::to_string(&e)
290						);
291                    }
292                }
293            }
294        }
295        (false, None)
296    } else {
297        (true, ws_middleware_result.1)
298    }
299}
300
301fn invoke_ws_middlewares_help(result: &WsRouterValue, req: &Request, res: &mut Response) -> bool {
302    match &result.0 {
303        Some(middlewares) => {
304            // at least one middleware
305            for middleware in middlewares {
306                if !middleware.call(req, res) {
307                    return false;
308                }
309            }
310            return true;
311        }
312        None => true,
313    }
314}
315
316fn invoke_ws_middlewares(
317    connection_data: &ConnectionData,
318    req: &Request,
319    res: &mut Response,
320) -> (bool, Option<Arc<dyn WsRouter + Send + Sync>>) {
321    let url = req.url.split_once("?");
322    let url = match url {
323        Some((url, _)) => url,
324        None => req.url,
325    };
326    let ws_router_map = &connection_data.ws_router_map;
327    let not_found = connection_data
328        .router_map
329        .get("NEVER_FOUND_FOR_ALL")
330        .unwrap();
331    match ws_router_map.get(url) {
332        Some(result) => {
333            return (
334                invoke_ws_middlewares_help(result, req, res),
335                Some(Arc::clone(&result.1)),
336            );
337        }
338        None => {
339            not_found.1.call(req, res);
340            return (false, None);
341        }
342    }
343}
344
345pub(crate) fn handle_websocket_connection(
346    mut stream: TcpStream,
347    header: HashMap<&str, &str>,
348    ws_version: String,
349    secret_key: String,
350    ws_router: Arc<dyn WsRouter + Send + Sync>,
351    connection_data: Arc<ConnectionData>,
352) {
353    //println!("handle_websocket_connection");
354    let mut request_meta_header: HashMap<String, String> = HashMap::new();
355    for (k, v) in header {
356        request_meta_header.insert(k.to_string(), v.to_string());
357    }
358    let key = format!("{}258EAFA5-E914-47DA-95CA-C5AB0DC85B11", secret_key);
359    let mut hasher = Sha1::new();
360    hasher.update(key.as_bytes());
361    let hash = hasher.finalize();
362    let result: &[u8] = hash.as_ref();
363    let r = base64::encode(result);
364    //println!("websocket: {r}");
365    //let shared_tcp_connection = Arc::new(stream);
366    let response = format!("HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version:{}\r\nSec-WebSocket-Accept:{}\r\n\r\n",ws_version,r);
367    match stream.write(response.as_bytes()) {
368        Ok(_) => {
369            let _ = stream.set_read_timeout(Some(std::time::Duration::from_millis(
370                connection_data.server_config.ws_read_timeout as u64,
371            )));
372            let _ = stream.set_write_timeout(Some(std::time::Duration::from_millis(
373                connection_data.server_config.ws_write_timeout as u64,
374            )));
375            switch_to_websocket(Arc::new(stream), ws_router, connection_data);
376        }
377        Err(_) => {}
378    }
379}
380
381fn read_ws_data(
382    reader: &mut BufReader<&TcpStream>,
383    data_len: usize,
384    data_buffs: &mut Vec<u8>,
385    mask_key: [u8; 4],
386) -> bool {
387    let current_len = data_buffs.len();
388    data_buffs.resize(current_len + data_len, b'\0');
389    //println!("current pos:{current_len}, size:{}", current_len + data_len);
390    match reader.read_exact(&mut data_buffs[current_len..]) {
391        Ok(_) => {
392            decode_ws_data(&mut data_buffs[current_len..], mask_key);
393            true
394        }
395        Err(_) => false,
396    }
397}
398
/// Unmasks (or masks — the operation is its own inverse) a payload in place by XOR-ing
/// byte `i` with `mask_key[i % 4]`.
fn decode_ws_data(raw_data: &mut [u8], mask_key: [u8; 4]) {
    // Cycling the key pairs every payload byte with the key byte at `index % 4`.
    for (byte, key) in raw_data.iter_mut().zip(mask_key.iter().cycle()) {
        *byte ^= *key;
    }
}
409
410fn switch_to_websocket(
411    stream: Arc<TcpStream>,
412    ws_router: Arc<dyn WsRouter + Send + Sync>,
413    connection_data: Arc<ConnectionData>,
414) {
415    //println!("invoke switch_to_websocket");
416    let fragment_size: usize = connection_data.server_config.ws_frame_size;
417    let server_log_open = connection_data.server_config.open_log;
418    let ws_handler = Websocket {
419        conn: Arc::clone(&stream),
420        write_mutex: Arc::new(Mutex::new(())),
421        fragment_size: fragment_size,
422    };
423    let event = WebsocketEvent {
424        ws: ws_handler.clone(),
425        message: WsMessage::Open,
426    };
427    ws_router.call(event);
428    let read_stream = Arc::clone(&stream);
429    let (tx, rx) = mpsc::channel::<WebsocketEvent>();
430    let _write_thread = thread::spawn(move || loop {
431        match rx.recv() {
432            Ok(event) => {
433                let is_close = if let WsMessage::Close = event.message {
434                    true
435                } else {
436                    false
437                };
438                ws_router.call(event);
439                if is_close {
440                    break;
441                }
442            }
443            Err(e) => {
444                if server_log_open {
445                    let now = super::get_current_date();
446                    println!(
447						"[{}] >>> error in websocket write thread in websocket.rs; type: [RecvError], line: [{}], msg: [{}]",
448						now,
449						line!(),
450						ToString::to_string(&e)
451					);
452                }
453                break;
454            }
455        }
456    });
457    let _read_thread = thread::spawn(move || {
458        let inner = read_stream.as_ref();
459        let sender = tx;
460        let mut reader = BufReader::new(inner);
461        'Restart: loop {
462            let mut data_buffs = Vec::new();
463            let mut opcode = 0;
464            let mut first_entry = true;
465            'ReadFrame: loop {
466                let mut buff = [b'\0'; 2];
467                match reader.read_exact(&mut buff) {
468                    Ok(_) => {
469                        let first_byte = buff[0];
470                        let fin = (first_byte >> 7) & 1; // 1是最后一个包或完整的包, 0是分包
471                                                         // println!(
472                                                         //     "first_entry:{first_entry}, first_byte:{:b}, second_byte:{:b}",
473                                                         //     first_byte, buff[1]
474                                                         // );
475                        if first_entry {
476                            opcode = first_byte & 0b00001111u8; //如果是分片传输,只记录首次的frame中的opcode
477                            first_entry = false;
478                        }
479                        let second_byte = buff[1];
480                        let mask = (second_byte >> 7) & 1;
481                        if mask != 1 {
482                            break 'Restart;
483                        }
484                        if opcode == 8 {
485                            // //关闭连接
486                            //构造close 通知
487                            break 'Restart;
488                        }
489                        if opcode == 9 {
490                            //ping
491                            // 构造pong消息
492                            ws_handler.write(Vec::new(), 10);
493                            continue 'Restart;
494                        }
495                        if opcode == 10 {
496                            //pong
497                            // 客户端回应pong消息
498                            continue 'Restart;
499                        }
500                        // 其他情况的opcode是消息体
501                        let payload = second_byte & 0b01111111u8;
502                        let data_len = if payload <= 125 {
503                            // 就是data的实际大小
504                            payload as usize
505                        } else if payload == 126 {
506                            // 后两个字节表示长度
507                            let mut two_bytes = [b'\0'; 2];
508                            match reader.read_exact(&mut two_bytes) {
509                                Ok(_) => {
510                                    let endian = [
511                                        b'\0',
512                                        b'\0',
513                                        b'\0',
514                                        b'\0',
515                                        b'\0',
516                                        b'\0',
517                                        two_bytes[0],
518                                        two_bytes[1],
519                                    ];
520                                    usize::from_be_bytes(endian)
521                                }
522                                Err(e) => {
523                                    if server_log_open {
524                                        let now = super::get_current_date();
525                                        println!(
526												"[{}] >>> error in websocket read thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
527												now,
528												e.kind().to_string(),
529												line!(),
530												ToString::to_string(&e)
531											);
532                                    }
533                                    break 'Restart;
534                                }
535                            }
536                        } else if payload == 127 {
537                            // 后 8个字节表示长度
538                            let mut eight_bytes = [b'\0'; 8];
539                            match reader.read_exact(&mut eight_bytes) {
540                                Ok(_) => usize::from_be_bytes(eight_bytes),
541                                Err(e) => {
542                                    if server_log_open {
543                                        let now = super::get_current_date();
544                                        println!(
545												"[{}] >>> error in websocket read thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
546												now,
547												e.kind().to_string(),
548												line!(),
549												ToString::to_string(&e)
550											);
551                                    }
552                                    break 'Restart;
553                                }
554                            }
555                        } else {
556                            // payload 无效值
557                            break 'Restart;
558                        }; // data_len end
559                           //println!("data_len: {}", data_len);
560                           // 读取 Masking-key, 4个字节
561                        let mut mask_key_buffs = [b'\0'; 4];
562                        match reader.read_exact(&mut mask_key_buffs) {
563                            Ok(_) => {
564                                // read data part;
565                                let r = read_ws_data(
566                                    &mut reader,
567                                    data_len,
568                                    &mut data_buffs,
569                                    mask_key_buffs,
570                                );
571                                //println!("read data part: {r}");
572                                if r == false {
573                                    break 'Restart;
574                                } else {
575                                    if fin == 1 {
576                                        // 构造消息事件
577                                        //println!("total size:{}", data_buffs.len());
578                                        let event = WebsocketEvent {
579                                            ws: ws_handler.clone(),
580                                            message: WsMessage::Message(data_buffs, opcode),
581                                        };
582                                        //println!("{:?}", event.message);
583                                        match sender.send(event) {
584                                            Ok(_) => {}
585                                            Err(e) => {
586                                                if server_log_open {
587                                                    let now = super::get_current_date();
588                                                    println!(
589															"[{}] >>> error in websocket read thread in websocket.rs; type: [SendError], line: [{}], msg: [{}]",
590															now,
591															line!(),
592															ToString::to_string(&e)
593														);
594                                                }
595                                                break 'Restart; // 发送消息错误,关闭当前线程
596                                            }
597                                        }
598                                        continue 'Restart; // 读完完整消息体,重新初始状态等待下一次消息
599                                    } else {
600                                        //println!("continue to read data");
601                                        continue 'ReadFrame; //非完整消息,继续循环ReadFrame块功能
602                                    }
603                                }
604                            }
605                            Err(e) => {
606                                if server_log_open {
607                                    let now = super::get_current_date();
608                                    println!(
609											"[{}] >>> error in websocket read thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
610											now,
611											e.kind().to_string(),
612											line!(),
613											ToString::to_string(&e)
614										);
615                                }
616                                break 'Restart;
617                            }
618                        }
619                    }
620                    Err(e) => {
621                        if server_log_open {
622                            let now = super::get_current_date();
623                            println!(
624								"[{}] >>> error in websocket write thread in websocket.rs; type: [{}], line: [{}], msg: [{}]",
625								now,
626								e.kind().to_string(),
627								line!(),
628								ToString::to_string(&e)
629							);
630                        }
631                        break 'Restart;
632                    }
633                };
634            }
635        }
636        let event = WebsocketEvent {
637            ws: ws_handler.clone(),
638            message: WsMessage::Close,
639        };
640        match sender.send(event) {
641            Ok(_) => {}
642            Err(e) => {
643                if server_log_open {
644                    let now = super::get_current_date();
645                    println!(
646						"[{}] >>> error in websocket read thread in websocket.rs; type: [SendError], line: [{}], msg: [{}]",
647						now,
648						line!(),
649						ToString::to_string(&e)
650					);
651                }
652            }
653        }
654    });
655}