ws_tool/
protocol.rs

1use http;
2use bytes::BytesMut;
3use sha1::Digest;
4use std::collections::HashMap;
5use std::fmt::Debug;
6
7use crate::errors::WsError;
8
/// Fixed GUID from RFC 6455 that `cal_accept_key` appends to the client's key
/// before hashing it to derive the accept value.
const GUID: &[u8] = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11".as_bytes();
10
/// helper struct for using close code
///
/// Each associated function returns one of the close status codes listed in
/// RFC 6455 section 7.4.1. The functions are `const fn`, so they can also be
/// used to initialise constants.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct StatusCode;

impl StatusCode {
    /// 1000 indicates a normal closure, meaning that the purpose for
    /// which the connection was established has been fulfilled.
    pub const fn normal() -> u16 {
        1000
    }

    /// 1001 indicates that an endpoint is "going away", such as a server
    /// going down or a browser having navigated away from a page.
    pub const fn going_away() -> u16 {
        1001
    }

    /// 1002 indicates that an endpoint is terminating the connection due
    /// to a protocol error.
    pub const fn protocol_error() -> u16 {
        1002
    }

    /// 1003 indicates that an endpoint is terminating the connection
    /// because it has received a type of data it cannot accept (e.g., an
    /// endpoint that understands only text data MAY send this if it
    /// receives a binary message).
    pub const fn terminate() -> u16 {
        1003
    }

    /// 1004 is reserved. The specific meaning might be defined in the future.
    pub const fn reserved() -> u16 {
        1004
    }

    /// 1005 is a reserved value and MUST NOT be set as a status code in a
    /// Close control frame by an endpoint.  It is designated for use in
    /// applications expecting a status code to indicate that no status
    /// code was actually present.
    pub const fn app_reserved() -> u16 {
        1005
    }

    /// 1006 is a reserved value and MUST NOT be set as a status code in a
    /// Close control frame by an endpoint.  It is designated for use in
    /// applications expecting a status code to indicate that the
    /// connection was closed abnormally, e.g., without sending or
    /// receiving a Close control frame.
    pub const fn abnormal_reserved() -> u16 {
        1006
    }

    /// 1007 indicates that an endpoint is terminating the connection
    /// because it has received data within a message that was not
    /// consistent with the type of the message (e.g., non-UTF-8 \[RFC3629\]
    /// data within a text message).
    pub const fn non_consistent() -> u16 {
        1007
    }

    /// 1008 indicates that an endpoint is terminating the connection
    /// because it has received a message that violates its policy.  This
    /// is a generic status code that can be returned when there is no
    /// other more suitable status code (e.g., 1003 or 1009) or if there
    /// is a need to hide specific details about the policy.
    pub const fn violate_policy() -> u16 {
        1008
    }

    /// 1009 indicates that an endpoint is terminating the connection
    /// because it has received a message that is too big for it to
    /// process.
    pub const fn too_big() -> u16 {
        1009
    }

    /// 1010 indicates that an endpoint (client) is terminating the
    /// connection because it has expected the server to negotiate one or
    /// more extension, but the server didn't return them in the response
    /// message of the WebSocket handshake.  The list of extensions that
    /// are needed SHOULD appear in the /reason/ part of the Close frame.
    /// Note that this status code is not used by the server, because it
    /// can fail the WebSocket handshake instead.
    pub const fn require_ext() -> u16 {
        1010
    }

    /// 1011 indicates that a server is terminating the connection because
    /// it encountered an unexpected condition that prevented it from
    /// fulfilling the request.
    pub const fn unexpected_condition() -> u16 {
        1011
    }

    /// 1015 is a reserved value and MUST NOT be set as a status code in a
    /// Close control frame by an endpoint.  It is designated for use in
    /// applications expecting a status code to indicate that the
    /// connection was closed due to a failure to perform a TLS handshake
    /// (e.g., the server certificate can't be verified).
    pub const fn platform_fail() -> u16 {
        1015
    }
}
113
/// websocket connection mode
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Mode {
    /// plain mode `ws://great.nice`
    WS,
    /// tls mode `wss://secret.wow`
    WSS,
}

impl Mode {
    /// Return the default port for this mode: 80 for `WS`, 443 for `WSS`.
    ///
    /// This is a `const fn`, so it can be used to initialise constants.
    pub const fn default_port(&self) -> u16 {
        match self {
            Mode::WS => 80,
            Mode::WSS => 443,
        }
    }
}
132
133#[cfg(feature = "sync")]
134mod blocking {
135    use http;
136    use std::{
137        collections::HashMap,
138        io::{Read, Write},
139    };
140
141    use bytes::{BufMut, BytesMut};
142
143    use crate::errors::WsError;
144
145    use super::{handle_parse_handshake, perform_parse_req, prepare_handshake};
146
147    /// perform http upgrade
148    ///
149    /// **NOTE**: low level api
150    pub fn req_handshake<S: Read + Write>(
151        stream: &mut S,
152        uri: &http::Uri,
153        protocols: &[String],
154        extensions: &[String],
155        version: u8,
156        extra_headers: HashMap<String, String>,
157    ) -> Result<(String, http::Response<()>), WsError> {
158        let (key, req_str) = prepare_handshake(protocols, extensions, extra_headers, uri, version);
159        stream.write_all(req_str.as_bytes())?;
160        stream.flush()?;
161        let mut read_bytes = BytesMut::with_capacity(1024);
162        let mut buf: [u8; 1] = [0; 1];
163        loop {
164            stream.read_exact(&mut buf)?;
165            read_bytes.put_u8(buf[0]);
166            let header_complete = read_bytes.ends_with(&[b'\r', b'\n', b'\r', b'\n']);
167            if header_complete {
168                break;
169            }
170        }
171        perform_parse_req(read_bytes, key)
172    }
173
174    /// handle protocol handshake
175    pub fn handle_handshake<S: Read + Write>(stream: &mut S) -> Result<http::Request<()>, WsError> {
176        let mut req_bytes = BytesMut::with_capacity(1024);
177        let mut buf = [0u8];
178        loop {
179            stream.read_exact(&mut buf)?;
180            req_bytes.put_u8(buf[0]);
181            if req_bytes.ends_with(&[b'\r', b'\n', b'\r', b'\n']) {
182                break;
183            }
184        }
185        handle_parse_handshake(req_bytes)
186    }
187}
188
189#[cfg(feature = "sync")]
190pub use blocking::*;
191
192#[cfg(feature = "async")]
193mod non_blocking {
194    use http;
195    use std::collections::HashMap;
196
197    use bytes::{BufMut, BytesMut};
198    use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
199
200    use crate::{errors::WsError, protocol::prepare_handshake};
201
202    use super::{handle_parse_handshake, perform_parse_req};
203
204    /// perform http upgrade
205    ///
206    /// **NOTE**: low level api
207    pub async fn async_req_handshake<S: AsyncRead + AsyncWrite + Unpin>(
208        stream: &mut S,
209        uri: &http::Uri,
210        protocols: &[String],
211        extensions: &[String],
212        version: u8,
213        extra_headers: HashMap<String, String>,
214    ) -> Result<(String, http::Response<()>), WsError> {
215        let (key, req_str) = prepare_handshake(protocols, extensions, extra_headers, uri, version);
216        stream.write_all(req_str.as_bytes()).await?;
217        let mut read_bytes = BytesMut::with_capacity(1024);
218        let mut buf = [0u8];
219        loop {
220            stream.read_exact(&mut buf).await?;
221            read_bytes.put_u8(buf[0]);
222            let header_complete = read_bytes.ends_with(&[b'\r', b'\n', b'\r', b'\n']);
223            if header_complete {
224                break;
225            }
226        }
227        perform_parse_req(read_bytes, key)
228    }
229
230    /// async version of handling protocol handshake
231    pub async fn async_handle_handshake<S: AsyncRead + AsyncWrite + Unpin>(
232        stream: &mut S,
233    ) -> Result<http::Request<()>, WsError> {
234        let mut req_bytes = BytesMut::with_capacity(1024);
235        let mut buf = [0u8];
236        loop {
237            stream.read_exact(&mut buf).await?;
238            req_bytes.put_u8(buf[0]);
239            if req_bytes.ends_with(&[b'\r', b'\n', b'\r', b'\n']) {
240                break;
241            }
242        }
243        handle_parse_handshake(req_bytes)
244    }
245}
246
247#[cfg(feature = "async")]
248pub use non_blocking::*;
249
250/// generate random key
251pub fn gen_key() -> String {
252    let r: [u8; 16] = rand::random();
253    base64::encode(r)
254}
255
256/// cal accept key
257pub fn cal_accept_key(source: &[u8]) -> String {
258    let mut sha1 = sha1::Sha1::default();
259    sha1.update(source);
260    sha1.update(GUID);
261    base64::encode(sha1.finalize())
262}
263
264/// perform standard protocol handshake response check
265///
266/// 1. check status code
267/// 2. check `sec-websocket-accept` header & value
268pub fn standard_handshake_resp_check(key: &[u8], resp: &http::Response<()>) -> Result<(), WsError> {
269    tracing::debug!("handshake response {:?}", resp);
270    if resp.status() != http::StatusCode::SWITCHING_PROTOCOLS {
271        return Err(WsError::HandShakeFailed(format!(
272            "expect 101 response, got {}",
273            resp.status()
274        )));
275    }
276    let expect_key = cal_accept_key(key);
277    if let Some(accept_key) = resp.headers().get("sec-websocket-accept") {
278        if accept_key.to_str().unwrap_or_default() != expect_key {
279            return Err(WsError::HandShakeFailed("mismatch key".to_string()));
280        }
281    } else {
282        return Err(WsError::HandShakeFailed(
283            "missing `sec-websocket-accept` header".to_string(),
284        ));
285    }
286    Ok(())
287}
288
289/// perform rfc standard check
290pub fn standard_handshake_req_check(req: &http::Request<()>) -> Result<(), WsError> {
291    if let Some(val) = req.headers().get("upgrade") {
292        if val != "websocket" {
293            return Err(WsError::HandShakeFailed(format!(
294                "expect `websocket`, got {val:?}"
295            )));
296        }
297    } else {
298        return Err(WsError::HandShakeFailed(
299            "missing `upgrade` header".to_string(),
300        ));
301    }
302
303    if let Some(val) = req.headers().get("sec-websocket-key") {
304        if val.is_empty() {
305            return Err(WsError::HandShakeFailed(
306                "empty sec-websocket-key".to_string(),
307            ));
308        }
309    } else {
310        return Err(WsError::HandShakeFailed(
311            "missing `sec-websocket-key` header".to_string(),
312        ));
313    }
314    Ok(())
315}
316
317/// build protocol http reqeust
318///
319/// return (key, request_str)
320pub fn prepare_handshake(
321    protocols: &[String],
322    extensions: &[String],
323    extra_headers: HashMap<String, String>,
324    uri: &http::Uri,
325    version: u8,
326) -> (String, String) {
327    let key = gen_key();
328    let mut headers = vec![
329        format!(
330            "Host: {}{}",
331            uri.host().unwrap_or_default(),
332            uri.port_u16().map(|p| format!(":{p}")).unwrap_or_default()
333        ),
334        "Upgrade: websocket".to_string(),
335        "Connection: Upgrade".to_string(),
336        format!("Sec-Websocket-Key: {key}"),
337        format!("Sec-WebSocket-Version: {version}"),
338    ];
339    for pro in protocols {
340        headers.push(format!("Sec-WebSocket-Protocol: {pro}"))
341    }
342    for ext in extensions {
343        headers.push(format!("Sec-WebSocket-Extensions: {ext}"))
344    }
345    for (k, v) in extra_headers.iter() {
346        headers.push(format!("{k}: {v}"));
347    }
348    let req_str = format!(
349        "{method} {path} {version:?}\r\n{headers}\r\n\r\n",
350        method = http::Method::GET,
351        path = uri
352            .path_and_query()
353            .map(|full_path| full_path.to_string())
354            .unwrap_or_default(),
355        version = http::Version::HTTP_11,
356        headers = headers.join("\r\n")
357    );
358    tracing::debug!("handshake request\n{}", req_str);
359    (key, req_str)
360}
361
362/// parse protocol response
363pub fn perform_parse_req(
364    read_bytes: BytesMut,
365    key: String,
366) -> Result<(String, http::Response<()>), WsError> {
367    let mut headers = [httparse::EMPTY_HEADER; 64];
368    let mut resp = httparse::Response::new(&mut headers);
369    let _parse_status = resp
370        .parse(&read_bytes)
371        .map_err(|_| WsError::HandShakeFailed("invalid response".to_string()))?;
372    let mut resp_builder = http::Response::builder()
373        .status(resp.code.unwrap_or_default())
374        .version(match resp.version.unwrap_or(1) {
375            0 => http::Version::HTTP_10,
376            1 => http::Version::HTTP_11,
377            v => {
378                tracing::warn!("unknown http 1.{} version", v);
379                http::Version::HTTP_11
380            }
381        });
382    for header in resp.headers.iter() {
383        resp_builder = resp_builder.header(header.name, header.value);
384    }
385    tracing::debug!("protocol handshake complete");
386    Ok((key, resp_builder.body(()).unwrap()))
387}
388
389/// parse http request, used by server building
390pub fn handle_parse_handshake(req_bytes: BytesMut) -> Result<http::Request<()>, WsError> {
391    let mut headers = [httparse::EMPTY_HEADER; 64];
392    let mut req = httparse::Request::new(&mut headers);
393    let _parse_status = req
394        .parse(&req_bytes)
395        .map_err(|_| WsError::HandShakeFailed("invalid request".to_string()))?;
396    let mut req_builder = http::Request::builder()
397        .method(req.method.unwrap_or_default())
398        .uri(req.path.unwrap_or_default())
399        .version(match req.version.unwrap_or(1) {
400            0 => http::Version::HTTP_10,
401            1 => http::Version::HTTP_11,
402            v => {
403                tracing::warn!("unknown http 1.{} version", v);
404                http::Version::HTTP_11
405            }
406        });
407    for header in req.headers.iter() {
408        req_builder = req_builder.header(header.name, header.value);
409    }
410    req_builder
411        .body(())
412        .map_err(|e| WsError::HandShakeFailed(e.to_string()))
413}