session_rs/ws/
handshake.rs1use base64::Engine;
2use sha1::{Digest, Sha1};
3use std::sync::Arc;
4use tokio::{
5 io::{AsyncBufReadExt, AsyncWriteExt, BufReader},
6 net::TcpStream,
7 sync::Mutex,
8};
9
10use super::WebSocket;
11
12pub async fn handle_websocket_handshake(stream: &mut TcpStream) -> std::io::Result<()> {
13 let (read_half, mut write_half) = stream.split();
14 let mut reader = BufReader::new(read_half);
15
16 let mut request_line = String::new();
17 reader.read_line(&mut request_line).await?;
18 let request_line = request_line.trim_end();
19
20 if request_line.starts_with("HEAD") {
21 write_half
22 .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 0\r\n\r\n")
23 .await?;
24 return Ok(());
25 }
26
27 if !request_line.starts_with("GET") {
28 return Err(std::io::Error::new(
29 std::io::ErrorKind::InvalidData,
30 "Invalid HTTP method",
31 ));
32 }
33
34 use std::collections::HashMap;
35 let mut headers = HashMap::new();
36 let mut line = String::new();
37
38 loop {
39 line.clear();
40 reader.read_line(&mut line).await?;
41 if line == "\r\n" {
42 break;
43 }
44 if let Some((k, v)) = line.split_once(':') {
45 headers.insert(k.trim().to_lowercase(), v.trim().to_string());
46 }
47 }
48
49 if headers
50 .get("upgrade")
51 .map(|v| !v.eq_ignore_ascii_case("websocket"))
52 .unwrap_or(true)
53 {
54 write_half
55 .write_all(b"HTTP/1.1 200 OK\r\nContent-Length: 2\r\n\r\nOK")
56 .await?;
57 return Ok(());
58 }
59
60 let key = headers
61 .get("sec-websocket-key")
62 .ok_or_else(|| std::io::Error::new(std::io::ErrorKind::InvalidData, "Missing key"))?;
63
64 use base64::Engine;
65 use base64::engine::general_purpose::STANDARD as Base64;
66 use sha1::{Digest, Sha1};
67
68 let mut hasher = Sha1::new();
69 hasher.update(key.as_bytes());
70 hasher.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
71 let accept = Base64.encode(hasher.finalize());
72
73 let response = format!(
74 "HTTP/1.1 101 Switching Protocols\r\n\
75 Upgrade: websocket\r\n\
76 Connection: Upgrade\r\n\
77 Sec-WebSocket-Accept: {}\r\n\r\n",
78 accept
79 );
80
81 write_half.write_all(response.as_bytes()).await?;
82 Ok(())
83}
84
85impl WebSocket {
86 pub async fn handshake(mut stream: TcpStream) -> super::Result<Self> {
87 handle_websocket_handshake(&mut stream).await?;
88
89 let (read, write) = stream.into_split();
90
91 Ok(Self {
92 id: rand::random(),
93 reader: Arc::new(Mutex::new(read)),
94 writer: Arc::new(Mutex::new(write)),
95 is_server: false,
96 })
97 }
98
99 pub async fn connect(addr: &str, path: &str) -> super::Result<Self> {
101 let mut stream = TcpStream::connect(addr).await?;
103
104 let key_bytes: [u8; 16] = rand::random();
106 let key = base64::prelude::BASE64_STANDARD.encode(&key_bytes);
107
108 let request = format!(
110 "GET {} HTTP/1.1\r\n\
111 Host: {}\r\n\
112 Upgrade: websocket\r\n\
113 Connection: Upgrade\r\n\
114 Sec-WebSocket-Key: {}\r\n\
115 Sec-WebSocket-Version: 13\r\n\
116 \r\n",
117 path, addr, key
118 );
119 stream.write_all(request.as_bytes()).await?;
120 stream.flush().await?;
121
122 let mut reader = BufReader::new(&mut stream);
124 let mut status_line = String::new();
125 reader.read_line(&mut status_line).await?;
126 if !status_line.starts_with("HTTP/1.1 101") {
127 return Err(super::Error::HandshakeFailed(format!(
128 "Expected 101 Switching Protocols, got: {}",
129 status_line.trim_end()
130 )));
131 }
132
133 let mut sec_accept = None;
135 loop {
136 let mut line = String::new();
137 reader.read_line(&mut line).await?;
138 let line = line.trim_end();
139 if line.is_empty() {
140 break; }
142 if let Some((k, v)) = line.split_once(':') {
143 if k.eq_ignore_ascii_case("sec-websocket-accept") {
144 sec_accept = Some(v.trim().to_string());
145 }
146 }
147 }
148
149 let expected = {
151 let mut sha1 = Sha1::new();
152 sha1.update(key.as_bytes());
153 sha1.update(b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11");
154 base64::prelude::BASE64_STANDARD.encode(sha1.finalize())
155 };
156 if sec_accept.as_deref() != Some(expected.as_str()) {
157 return Err(super::Error::HandshakeFailed(
158 "Sec-WebSocket-Accept mismatch".into(),
159 ));
160 }
161
162 let (read, write) = stream.into_split();
164
165 Ok(Self {
166 id: rand::random(),
167 reader: Arc::new(Mutex::new(read)),
168 writer: Arc::new(Mutex::new(write)),
169 is_server: true,
170 })
171 }
172}