// tide_websockets/handler.rs

1use std::future::Future;
2use std::marker::{PhantomData, Send};
3
4use crate::async_tungstenite::WebSocketStream;
5use crate::tungstenite::protocol::Role;
6use crate::WebSocketConnection;
7
8use async_dup::Arc;
9use async_std::task;
10use sha1::{Digest, Sha1};
11
12use tide::http::format_err;
13use tide::http::headers::{HeaderName, CONNECTION, UPGRADE};
14use tide::{Middleware, Request, Response, Result, StatusCode};
15
/// Fixed GUID from RFC 6455 that is appended to the client's
/// `Sec-WebSocket-Key` before hashing to build `Sec-WebSocket-Accept`.
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
17
18/// # endpoint/middleware handler for websockets in tide
19///
20/// This can either be used as a middleware or as an
21/// endpoint. Regardless of which approach is taken, the handler
22/// function provided to [`WebSocket::new`] is only called if the
23/// request correctly negotiates an upgrade to the websocket protocol.
24///
25/// ## As a middleware
26///
27/// If used as a middleware, the endpoint will be executed if the
28/// request is not a websocket upgrade.
29///
30/// ### Example
31///
32/// ```rust
33/// use async_std::prelude::*;
34/// use tide_websockets::{Message, WebSocket};
35///
36/// #[async_std::main]
37/// async fn main() -> Result<(), std::io::Error> {
38///     let mut app = tide::new();
39///
40///     app.at("/ws")
41///         .with(WebSocket::new(|_request, mut stream| async move {
42///             while let Some(Ok(Message::Text(input))) = stream.next().await {
43///                 let output: String = input.chars().rev().collect();
44///
45///                 stream
46///                     .send_string(format!("{} | {}", &input, &output))
47///                     .await?;
48///             }
49///
50///             Ok(())
51///         }))
52///        .get(|_| async move { Ok("this was not a websocket request") });
53///
54/// # if false {
55///     app.listen("127.0.0.1:8080").await?;
56/// # }
57///     Ok(())
58/// }
59/// ```
60///
61/// ## As an endpoint
62///
63/// If used as an endpoint but the request is
64/// not a websocket request, tide will reply with a `426 Upgrade
65/// Required` status code.
66///
67/// ### example
68///
69/// ```rust
70/// use async_std::prelude::*;
71/// use tide_websockets::{Message, WebSocket};
72///
73/// #[async_std::main]
74/// async fn main() -> Result<(), std::io::Error> {
75///     let mut app = tide::new();
76///
77///     app.at("/ws")
78///         .get(WebSocket::new(|_request, mut stream| async move {
79///             while let Some(Ok(Message::Text(input))) = stream.next().await {
80///                 let output: String = input.chars().rev().collect();
81///
82///                 stream
83///                     .send_string(format!("{} | {}", &input, &output))
84///                     .await?;
85///             }
86///
87///             Ok(())
88///         }));
89///
90/// # if false {
91///     app.listen("127.0.0.1:8080").await?;
92/// # }
93///     Ok(())
94/// }
95/// ```
96///
97#[derive(Debug)]
98pub struct WebSocket<S, H> {
99    handler: Arc<H>,
100    ghostly_apparition: PhantomData<S>,
101    protocols: Vec<String>,
102}
103
104enum UpgradeStatus<S> {
105    Upgraded(Result<Response>),
106    NotUpgraded(Request<S>),
107}
108use UpgradeStatus::{NotUpgraded, Upgraded};
109
110fn header_contains_ignore_case<T>(req: &Request<T>, header_name: HeaderName, value: &str) -> bool {
111    req.header(header_name)
112        .map(|h| {
113            h.as_str()
114                .split(',')
115                .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
116        })
117        .unwrap_or(false)
118}
119
120impl<S, H, Fut> WebSocket<S, H>
121where
122    S: Send + Sync + Clone + 'static,
123    H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
124    Fut: Future<Output = Result<()>> + Send + 'static,
125{
126    /// Build a new WebSocket with a handler function that
127    pub fn new(handler: H) -> Self {
128        Self {
129            handler: Arc::new(handler),
130            ghostly_apparition: PhantomData,
131            protocols: Default::default(),
132        }
133    }
134
135    /// `protocols` is a sequence of known protocols. On successful handshake,
136    /// the returned response headers contain the first protocol in this list
137    /// which the server also knows.
138    pub fn with_protocols(self, protocols: &[&str]) -> Self {
139        Self {
140            protocols: protocols.iter().map(ToString::to_string).collect(),
141            ..self
142        }
143    }
144
145    async fn handle_upgrade(&self, req: Request<S>) -> UpgradeStatus<S> {
146        let connection_upgrade = header_contains_ignore_case(&req, CONNECTION, "upgrade");
147        let upgrade_to_websocket = header_contains_ignore_case(&req, UPGRADE, "websocket");
148        let upgrade_requested = connection_upgrade && upgrade_to_websocket;
149
150        if !upgrade_requested {
151            return NotUpgraded(req);
152        }
153
154        let header = match req.header("Sec-Websocket-Key") {
155            Some(h) => h.as_str(),
156            None => return Upgraded(Err(format_err!("expected sec-websocket-key"))),
157        };
158
159        let protocol = req.header("Sec-Websocket-Protocol").and_then(|value| {
160            value
161                .as_str()
162                .split(',')
163                .map(str::trim)
164                .find(|req_p| self.protocols.iter().any(|p| p == req_p))
165        });
166
167        let mut response = Response::new(StatusCode::SwitchingProtocols);
168
169        response.insert_header(UPGRADE, "websocket");
170        response.insert_header(CONNECTION, "Upgrade");
171        let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
172        response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
173        response.insert_header("Sec-Websocket-Version", "13");
174
175        if let Some(protocol) = protocol {
176            response.insert_header("Sec-Websocket-Protocol", protocol);
177        }
178
179        let http_res: &mut tide::http::Response = response.as_mut();
180        let upgrade_receiver = http_res.recv_upgrade().await;
181        let handler = self.handler.clone();
182
183        task::spawn(async move {
184            if let Some(stream) = upgrade_receiver.await {
185                let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
186                handler(req, stream.into()).await
187            } else {
188                Err(format_err!("never received an upgrade!"))
189            }
190        });
191
192        Upgraded(Ok(response))
193    }
194}
195
196#[tide::utils::async_trait]
197impl<H, S, Fut> tide::Endpoint<S> for WebSocket<S, H>
198where
199    H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
200    Fut: Future<Output = Result<()>> + Send + 'static,
201    S: Send + Sync + Clone + 'static,
202{
203    async fn call(&self, req: Request<S>) -> Result {
204        match self.handle_upgrade(req).await {
205            Upgraded(result) => result,
206            NotUpgraded(_) => Ok(Response::new(StatusCode::UpgradeRequired)),
207        }
208    }
209}
210
211#[tide::utils::async_trait]
212impl<H, S, Fut> Middleware<S> for WebSocket<S, H>
213where
214    H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
215    Fut: Future<Output = Result<()>> + Send + 'static,
216    S: Send + Sync + Clone + 'static,
217{
218    async fn handle(&self, req: Request<S>, next: tide::Next<'_, S>) -> Result {
219        match self.handle_upgrade(req).await {
220            Upgraded(result) => result,
221            NotUpgraded(req) => Ok(next.run(req).await),
222        }
223    }
224}