1use http;
2use bytes::BytesMut;
3use sha1::Digest;
4use std::collections::HashMap;
5use std::fmt::Debug;
6
7use crate::errors::WsError;
8
/// Fixed GUID from RFC 6455 that `cal_accept_key` appends to the client's
/// `Sec-WebSocket-Key` before SHA-1 hashing to derive `Sec-WebSocket-Accept`.
const GUID: &[u8] = b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
10
/// WebSocket close status codes (RFC 6455 §7.4.1).
///
/// Each associated function returns the numeric code carried in a Close
/// frame. The functions are `const`, so the codes can also be used to
/// initialise constants and statics.
pub struct StatusCode;

impl StatusCode {
    /// 1000 — normal closure: the purpose of the connection has been fulfilled.
    pub const fn normal() -> u16 {
        1000
    }

    /// 1001 — the endpoint is going away (e.g. server shutdown or a browser
    /// navigating away from the page).
    pub const fn going_away() -> u16 {
        1001
    }

    /// 1002 — the endpoint is terminating the connection due to a protocol error.
    pub const fn protocol_error() -> u16 {
        1002
    }

    /// 1003 — the endpoint is terminating the connection because it received
    /// a type of data it cannot accept (e.g. binary when only text is supported).
    pub const fn terminate() -> u16 {
        1003
    }

    /// 1004 — reserved; its meaning may be defined in the future.
    pub const fn reserved() -> u16 {
        1004
    }

    /// 1005 — reserved; must not be sent in a Close frame. Used locally to
    /// signal that no status code was present.
    pub const fn app_reserved() -> u16 {
        1005
    }

    /// 1006 — reserved; must not be sent in a Close frame. Used locally to
    /// signal that the connection closed abnormally, without a Close frame.
    pub const fn abnormal_reserved() -> u16 {
        1006
    }

    /// 1007 — received message data inconsistent with its type
    /// (e.g. non-UTF-8 data inside a text message).
    pub const fn non_consistent() -> u16 {
        1007
    }

    /// 1008 — received a message that violates the endpoint's policy.
    pub const fn violate_policy() -> u16 {
        1008
    }

    /// 1009 — received a message too big to process.
    pub const fn too_big() -> u16 {
        1009
    }

    /// 1010 — sent by a client because the server did not negotiate one or
    /// more extensions the client expected.
    pub const fn require_ext() -> u16 {
        1010
    }

    /// 1011 — sent by a server that hit an unexpected condition preventing it
    /// from fulfilling the request.
    pub const fn unexpected_condition() -> u16 {
        1011
    }

    /// 1015 — reserved; must not be sent in a Close frame. Used locally to
    /// signal a TLS handshake failure.
    pub const fn platform_fail() -> u16 {
        1015
    }
}
113
/// Connection scheme of a WebSocket endpoint: plain `ws` or TLS-based `wss`.
#[derive(Debug, PartialEq, Eq)]
pub enum Mode {
    /// `ws://` — plain WebSocket.
    WS,
    /// `wss://` — WebSocket over TLS.
    WSS,
}

impl Mode {
    /// Port used when the URI does not specify one: 80 for `ws`, 443 for `wss`.
    pub fn default_port(&self) -> u16 {
        match *self {
            Self::WSS => 443,
            Self::WS => 80,
        }
    }
}
132
#[cfg(feature = "sync")]
mod blocking {
    use http;
    use std::{
        collections::HashMap,
        io::{Read, Write},
    };

    use bytes::{BufMut, BytesMut};

    use crate::errors::WsError;

    use super::{handle_parse_handshake, perform_parse_req, prepare_handshake};

    /// Upper bound on the size of an HTTP handshake head (start line plus
    /// headers). Without it, a peer that never sends `\r\n\r\n` would make us
    /// buffer an unbounded amount of data.
    const MAX_HANDSHAKE_HEAD_LEN: usize = 64 * 1024;

    /// Reads from `stream` until the blank line (`\r\n\r\n`) that ends an HTTP
    /// head and returns everything read, terminator included.
    ///
    /// Bytes are read one at a time, so nothing past the terminator is
    /// consumed from `stream`.
    ///
    /// # Errors
    /// Propagates I/O errors from `read_exact` (including EOF before the
    /// terminator) and returns `WsError::HandShakeFailed` once the head would
    /// exceed `MAX_HANDSHAKE_HEAD_LEN` bytes.
    fn read_head<S: Read>(stream: &mut S) -> Result<BytesMut, WsError> {
        let mut head = BytesMut::with_capacity(1024);
        let mut buf = [0u8; 1];
        while !head.ends_with(b"\r\n\r\n") {
            if head.len() >= MAX_HANDSHAKE_HEAD_LEN {
                return Err(WsError::HandShakeFailed(format!(
                    "handshake head exceeds {MAX_HANDSHAKE_HEAD_LEN} bytes"
                )));
            }
            stream.read_exact(&mut buf)?;
            head.put_u8(buf[0]);
        }
        Ok(head)
    }

    /// Performs the client side of the opening handshake over a blocking stream.
    ///
    /// Writes and flushes the request built by `prepare_handshake`, then reads
    /// the server's response head and parses it with `perform_parse_req`.
    ///
    /// Returns the generated `Sec-WebSocket-Key` together with the parsed
    /// response. The response is not validated here (see
    /// `standard_handshake_resp_check`).
    ///
    /// # Errors
    /// I/O failures, a response head larger than `MAX_HANDSHAKE_HEAD_LEN`, or a
    /// response that cannot be parsed.
    pub fn req_handshake<S: Read + Write>(
        stream: &mut S,
        uri: &http::Uri,
        protocols: &[String],
        extensions: &[String],
        version: u8,
        extra_headers: HashMap<String, String>,
    ) -> Result<(String, http::Response<()>), WsError> {
        let (key, req_str) = prepare_handshake(protocols, extensions, extra_headers, uri, version);
        stream.write_all(req_str.as_bytes())?;
        stream.flush()?;
        let read_bytes = read_head(stream)?;
        perform_parse_req(read_bytes, key)
    }

    /// Reads a client's opening-handshake request head from a blocking stream
    /// and parses it with `handle_parse_handshake`. Nothing is written back.
    ///
    /// # Errors
    /// I/O failures, a request head larger than `MAX_HANDSHAKE_HEAD_LEN`, or a
    /// request that cannot be parsed.
    pub fn handle_handshake<S: Read + Write>(stream: &mut S) -> Result<http::Request<()>, WsError> {
        let req_bytes = read_head(stream)?;
        handle_parse_handshake(req_bytes)
    }
}
188
189#[cfg(feature = "sync")]
190pub use blocking::*;
191
#[cfg(feature = "async")]
mod non_blocking {
    use http;
    use std::collections::HashMap;

    use bytes::{BufMut, BytesMut};
    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};

    use crate::{errors::WsError, protocol::prepare_handshake};

    use super::{handle_parse_handshake, perform_parse_req};

    /// Upper bound on the size of an HTTP handshake head (start line plus
    /// headers). Without it, a peer that never sends `\r\n\r\n` would make us
    /// buffer an unbounded amount of data.
    const MAX_HANDSHAKE_HEAD_LEN: usize = 64 * 1024;

    /// Asynchronously reads from `stream` until the blank line (`\r\n\r\n`)
    /// that ends an HTTP head and returns everything read, terminator included.
    ///
    /// Bytes are read one at a time, so nothing past the terminator is
    /// consumed from `stream`.
    ///
    /// # Errors
    /// Propagates I/O errors from `read_exact` (including EOF before the
    /// terminator) and returns `WsError::HandShakeFailed` once the head would
    /// exceed `MAX_HANDSHAKE_HEAD_LEN` bytes.
    async fn read_head<S: AsyncRead + Unpin>(stream: &mut S) -> Result<BytesMut, WsError> {
        let mut head = BytesMut::with_capacity(1024);
        let mut buf = [0u8; 1];
        while !head.ends_with(b"\r\n\r\n") {
            if head.len() >= MAX_HANDSHAKE_HEAD_LEN {
                return Err(WsError::HandShakeFailed(format!(
                    "handshake head exceeds {MAX_HANDSHAKE_HEAD_LEN} bytes"
                )));
            }
            stream.read_exact(&mut buf).await?;
            head.put_u8(buf[0]);
        }
        Ok(head)
    }

    /// Performs the client side of the opening handshake over an async stream.
    ///
    /// Writes and flushes the request built by `prepare_handshake`, then reads
    /// the server's response head and parses it with `perform_parse_req`.
    ///
    /// Returns the generated `Sec-WebSocket-Key` together with the parsed
    /// response. The response is not validated here (see
    /// `standard_handshake_resp_check`).
    ///
    /// # Errors
    /// I/O failures, a response head larger than `MAX_HANDSHAKE_HEAD_LEN`, or a
    /// response that cannot be parsed.
    pub async fn async_req_handshake<S: AsyncRead + AsyncWrite + Unpin>(
        stream: &mut S,
        uri: &http::Uri,
        protocols: &[String],
        extensions: &[String],
        version: u8,
        extra_headers: HashMap<String, String>,
    ) -> Result<(String, http::Response<()>), WsError> {
        let (key, req_str) = prepare_handshake(protocols, extensions, extra_headers, uri, version);
        stream.write_all(req_str.as_bytes()).await?;
        // Flush like the blocking variant does: with a buffering stream
        // (e.g. `tokio::io::BufWriter`) the request could otherwise stay in the
        // buffer while we wait for a response that never arrives.
        stream.flush().await?;
        let read_bytes = read_head(stream).await?;
        perform_parse_req(read_bytes, key)
    }

    /// Reads a client's opening-handshake request head from an async stream
    /// and parses it with `handle_parse_handshake`. Nothing is written back.
    ///
    /// # Errors
    /// I/O failures, a request head larger than `MAX_HANDSHAKE_HEAD_LEN`, or a
    /// request that cannot be parsed.
    pub async fn async_handle_handshake<S: AsyncRead + AsyncWrite + Unpin>(
        stream: &mut S,
    ) -> Result<http::Request<()>, WsError> {
        let req_bytes = read_head(stream).await?;
        handle_parse_handshake(req_bytes)
    }
}
246
247#[cfg(feature = "async")]
248pub use non_blocking::*;
249
250pub fn gen_key() -> String {
252 let r: [u8; 16] = rand::random();
253 base64::encode(r)
254}
255
256pub fn cal_accept_key(source: &[u8]) -> String {
258 let mut sha1 = sha1::Sha1::default();
259 sha1.update(source);
260 sha1.update(GUID);
261 base64::encode(sha1.finalize())
262}
263
264pub fn standard_handshake_resp_check(key: &[u8], resp: &http::Response<()>) -> Result<(), WsError> {
269 tracing::debug!("handshake response {:?}", resp);
270 if resp.status() != http::StatusCode::SWITCHING_PROTOCOLS {
271 return Err(WsError::HandShakeFailed(format!(
272 "expect 101 response, got {}",
273 resp.status()
274 )));
275 }
276 let expect_key = cal_accept_key(key);
277 if let Some(accept_key) = resp.headers().get("sec-websocket-accept") {
278 if accept_key.to_str().unwrap_or_default() != expect_key {
279 return Err(WsError::HandShakeFailed("mismatch key".to_string()));
280 }
281 } else {
282 return Err(WsError::HandShakeFailed(
283 "missing `sec-websocket-accept` header".to_string(),
284 ));
285 }
286 Ok(())
287}
288
289pub fn standard_handshake_req_check(req: &http::Request<()>) -> Result<(), WsError> {
291 if let Some(val) = req.headers().get("upgrade") {
292 if val != "websocket" {
293 return Err(WsError::HandShakeFailed(format!(
294 "expect `websocket`, got {val:?}"
295 )));
296 }
297 } else {
298 return Err(WsError::HandShakeFailed(
299 "missing `upgrade` header".to_string(),
300 ));
301 }
302
303 if let Some(val) = req.headers().get("sec-websocket-key") {
304 if val.is_empty() {
305 return Err(WsError::HandShakeFailed(
306 "empty sec-websocket-key".to_string(),
307 ));
308 }
309 } else {
310 return Err(WsError::HandShakeFailed(
311 "missing `sec-websocket-key` header".to_string(),
312 ));
313 }
314 Ok(())
315}
316
317pub fn prepare_handshake(
321 protocols: &[String],
322 extensions: &[String],
323 extra_headers: HashMap<String, String>,
324 uri: &http::Uri,
325 version: u8,
326) -> (String, String) {
327 let key = gen_key();
328 let mut headers = vec![
329 format!(
330 "Host: {}{}",
331 uri.host().unwrap_or_default(),
332 uri.port_u16().map(|p| format!(":{p}")).unwrap_or_default()
333 ),
334 "Upgrade: websocket".to_string(),
335 "Connection: Upgrade".to_string(),
336 format!("Sec-Websocket-Key: {key}"),
337 format!("Sec-WebSocket-Version: {version}"),
338 ];
339 for pro in protocols {
340 headers.push(format!("Sec-WebSocket-Protocol: {pro}"))
341 }
342 for ext in extensions {
343 headers.push(format!("Sec-WebSocket-Extensions: {ext}"))
344 }
345 for (k, v) in extra_headers.iter() {
346 headers.push(format!("{k}: {v}"));
347 }
348 let req_str = format!(
349 "{method} {path} {version:?}\r\n{headers}\r\n\r\n",
350 method = http::Method::GET,
351 path = uri
352 .path_and_query()
353 .map(|full_path| full_path.to_string())
354 .unwrap_or_default(),
355 version = http::Version::HTTP_11,
356 headers = headers.join("\r\n")
357 );
358 tracing::debug!("handshake request\n{}", req_str);
359 (key, req_str)
360}
361
362pub fn perform_parse_req(
364 read_bytes: BytesMut,
365 key: String,
366) -> Result<(String, http::Response<()>), WsError> {
367 let mut headers = [httparse::EMPTY_HEADER; 64];
368 let mut resp = httparse::Response::new(&mut headers);
369 let _parse_status = resp
370 .parse(&read_bytes)
371 .map_err(|_| WsError::HandShakeFailed("invalid response".to_string()))?;
372 let mut resp_builder = http::Response::builder()
373 .status(resp.code.unwrap_or_default())
374 .version(match resp.version.unwrap_or(1) {
375 0 => http::Version::HTTP_10,
376 1 => http::Version::HTTP_11,
377 v => {
378 tracing::warn!("unknown http 1.{} version", v);
379 http::Version::HTTP_11
380 }
381 });
382 for header in resp.headers.iter() {
383 resp_builder = resp_builder.header(header.name, header.value);
384 }
385 tracing::debug!("protocol handshake complete");
386 Ok((key, resp_builder.body(()).unwrap()))
387}
388
389pub fn handle_parse_handshake(req_bytes: BytesMut) -> Result<http::Request<()>, WsError> {
391 let mut headers = [httparse::EMPTY_HEADER; 64];
392 let mut req = httparse::Request::new(&mut headers);
393 let _parse_status = req
394 .parse(&req_bytes)
395 .map_err(|_| WsError::HandShakeFailed("invalid request".to_string()))?;
396 let mut req_builder = http::Request::builder()
397 .method(req.method.unwrap_or_default())
398 .uri(req.path.unwrap_or_default())
399 .version(match req.version.unwrap_or(1) {
400 0 => http::Version::HTTP_10,
401 1 => http::Version::HTTP_11,
402 v => {
403 tracing::warn!("unknown http 1.{} version", v);
404 http::Version::HTTP_11
405 }
406 });
407 for header in req.headers.iter() {
408 req_builder = req_builder.header(header.name, header.value);
409 }
410 req_builder
411 .body(())
412 .map_err(|e| WsError::HandShakeFailed(e.to_string()))
413}