Skip to main content

session_rs/ws/
handshake.rs

1use base64::Engine;
2use sha1::{Digest, Sha1};
3use std::sync::Arc;
4use tokio::{
5    io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
6    net::TcpStream,
7    sync::Mutex,
8};
9
10use super::WebSocket;
11
12pub async fn handle_websocket_handshake(stream: &mut TcpStream) -> std::io::Result<()> {
13    let (read_half, mut write_half) = stream.split();
14    let mut reader = BufReader::new(read_half);
15
16    let mut request_line = String::new();
17    reader.read_line(&mut request_line).await?;
18    let request_line = request_line.trim_end();
19
20    if request_line.starts_with("HEAD") {
21        write_half
22            .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
23            .await?;
24        return Ok(());
25    }
26
27    if !request_line.starts_with("GET") {
28        return Err(std::io::Error::new(
29            std::io::ErrorKind::InvalidData,
30            "Invalid HTTP method",
31        ));
32    }
33
34    use std::collections::HashMap;
35    let mut headers = HashMap::new();
36    let mut line = String::new();
37
38    loop {
39        line.clear();
40        reader.read_line(&mut line).await?;
41        if line == "\r\n" {
42            break;
43        }
44        if let Some((k, v)) = line.split_once(':') {
45            headers.insert(k.trim().to_lowercase(), v.trim().to_string());
46        }
47    }
48
49    if headers
50        .get("upgrade")
51        .map(|v| !v.eq_ignore_ascii_case("websocket"))
52        .unwrap_or(true)
53    {
54        write_half
55            .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
56            .await?;
57        return Ok(());
58    }
59
60    let key = headers
61        .get("sec-websocket-key")
62        .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "Missing key"))?;
63
64    use base64::Engine;
65    use base64::engine::general_purpose::STANDARD as Base64;
66    use sha1::{Digest, Sha1};
67
68    let mut hasher = Sha1::new();
69    hasher.update(key.as_bytes());
70    hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
71    let accept = Base64.encode(hasher.finalize());
72
73    let response = format!(
74        "HTTP/1.1 101 Switching Protocols\r\n\
75         Upgrade: websocket\r\n\
76         Connection: Upgrade\r\n\
77         Sec-WebSocket-Accept: {}\r\n\r\n",
78        accept
79    );
80
81    write_half.write_all(response.as_bytes()).await?;
82    Ok(())
83}
84
85impl WebSocket {
86    pub async fn handshake(mut stream: TcpStream) -> super::Result<Self> {
87        handle_websocket_handshake(&mut stream).await?;
88
89        let (read, write) = stream.into_split();
90
91        Ok(Self {
92            id: rand::random(),
93            reader: Arc::new(Mutex::new(read)),
94            writer: Arc::new(Mutex::new(write)),
95            is_server: false,
96        })
97    }
98
99    /// Connect to a WebSocket server and perform the handshake
100    pub async fn connect(addr: &str, path: &str) -> super::Result<Self> {
101        // 1. TCP connect
102        let mut stream = TcpStream::connect(addr).await?;
103
104        // 2. Generate Sec-WebSocket-Key
105        let key_bytes: [u8; 16] = rand::random();
106        let key = base64::prelude::BASE64_STANDARD.encode(&key_bytes);
107
108        // 3. Send HTTP Upgrade request
109        let request = format!(
110            "GET {} HTTP/1.1\r\n\
111             Host: {}\r\n\
112             Upgrade: websocket\r\n\
113             Connection: Upgrade\r\n\
114             Sec-WebSocket-Key: {}\r\n\
115             Sec-WebSocket-Version: 13\r\n\
116             \r\n",
117            path, addr, key
118        );
119        stream.write_all(request.as_bytes()).await?;
120        stream.flush().await?;
121
122        // 4. Read HTTP response
123        let mut reader = BufReader::new(&mut stream);
124        let mut status_line = String::new();
125        reader.read_line(&mut status_line).await?;
126        if !status_line.starts_with("HTTP/1.1 101") {
127            return Err(super::Error::HandshakeFailed(format!(
128                "Expected 101 Switching Protocols, got: {}",
129                status_line.trim_end()
130            )));
131        }
132
133        // Read headers
134        let mut sec_accept = None;
135        loop {
136            let mut line = String::new();
137            reader.read_line(&mut line).await?;
138            let line = line.trim_end();
139            if line.is_empty() {
140                break; // end of headers
141            }
142            if let Some((k, v)) = line.split_once(':') {
143                if k.eq_ignore_ascii_case("sec-websocket-accept") {
144                    sec_accept = Some(v.trim().to_string());
145                }
146            }
147        }
148
149        // 5. Verify Sec-WebSocket-Accept
150        let expected = {
151            let mut sha1 = Sha1::new();
152            sha1.update(key.as_bytes());
153            sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
154            base64::prelude::BASE64_STANDARD.encode(sha1.finalize())
155        };
156        if sec_accept.as_deref() != Some(expected.as_str()) {
157            return Err(super::Error::HandshakeFailed(
158                "Sec-WebSocket-Accept mismatch".into(),
159            ));
160        }
161
162        // 6. Upgrade succeeded, split stream
163        let (read, write) = stream.into_split();
164
165        Ok(Self {
166            id: rand::random(),
167            reader: Arc::new(Mutex::new(read)),
168            writer: Arc::new(Mutex::new(write)),
169            is_server: true,
170        })
171    }
172}