tide_websockets/
handler.rs1use std::future::Future;
2use std::marker::{PhantomData, Send};
3
4use crate::async_tungstenite::WebSocketStream;
5use crate::tungstenite::protocol::Role;
6use crate::WebSocketConnection;
7
8use async_dup::Arc;
9use async_std::task;
10use sha1::{Digest, Sha1};
11
12use tide::http::format_err;
13use tide::http::headers::{HeaderName, CONNECTION, UPGRADE};
14use tide::{Middleware, Request, Response, Result, StatusCode};
15
/// Fixed GUID that is appended to the client's `Sec-Websocket-Key` before hashing
/// to produce the `Sec-Websocket-Accept` response header.
const WEBSOCKET_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
17
18#[derive(Debug)]
98pub struct WebSocket<S, H> {
99 handler: Arc<H>,
100 ghostly_apparition: PhantomData<S>,
101 protocols: Vec<String>,
102}
103
104enum UpgradeStatus<S> {
105 Upgraded(Result<Response>),
106 NotUpgraded(Request<S>),
107}
108use UpgradeStatus::{NotUpgraded, Upgraded};
109
110fn header_contains_ignore_case<T>(req: &Request<T>, header_name: HeaderName, value: &str) -> bool {
111 req.header(header_name)
112 .map(|h| {
113 h.as_str()
114 .split(',')
115 .any(|s| s.trim().eq_ignore_ascii_case(value.trim()))
116 })
117 .unwrap_or(false)
118}
119
120impl<S, H, Fut> WebSocket<S, H>
121where
122 S: Send + Sync + Clone + 'static,
123 H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
124 Fut: Future<Output = Result<()>> + Send + 'static,
125{
126 pub fn new(handler: H) -> Self {
128 Self {
129 handler: Arc::new(handler),
130 ghostly_apparition: PhantomData,
131 protocols: Default::default(),
132 }
133 }
134
135 pub fn with_protocols(self, protocols: &[&str]) -> Self {
139 Self {
140 protocols: protocols.iter().map(ToString::to_string).collect(),
141 ..self
142 }
143 }
144
145 async fn handle_upgrade(&self, req: Request<S>) -> UpgradeStatus<S> {
146 let connection_upgrade = header_contains_ignore_case(&req, CONNECTION, "upgrade");
147 let upgrade_to_websocket = header_contains_ignore_case(&req, UPGRADE, "websocket");
148 let upgrade_requested = connection_upgrade && upgrade_to_websocket;
149
150 if !upgrade_requested {
151 return NotUpgraded(req);
152 }
153
154 let header = match req.header("Sec-Websocket-Key") {
155 Some(h) => h.as_str(),
156 None => return Upgraded(Err(format_err!("expected sec-websocket-key"))),
157 };
158
159 let protocol = req.header("Sec-Websocket-Protocol").and_then(|value| {
160 value
161 .as_str()
162 .split(',')
163 .map(str::trim)
164 .find(|req_p| self.protocols.iter().any(|p| p == req_p))
165 });
166
167 let mut response = Response::new(StatusCode::SwitchingProtocols);
168
169 response.insert_header(UPGRADE, "websocket");
170 response.insert_header(CONNECTION, "Upgrade");
171 let hash = Sha1::new().chain(header).chain(WEBSOCKET_GUID).finalize();
172 response.insert_header("Sec-Websocket-Accept", base64::encode(&hash[..]));
173 response.insert_header("Sec-Websocket-Version", "13");
174
175 if let Some(protocol) = protocol {
176 response.insert_header("Sec-Websocket-Protocol", protocol);
177 }
178
179 let http_res: &mut tide::http::Response = response.as_mut();
180 let upgrade_receiver = http_res.recv_upgrade().await;
181 let handler = self.handler.clone();
182
183 task::spawn(async move {
184 if let Some(stream) = upgrade_receiver.await {
185 let stream = WebSocketStream::from_raw_socket(stream, Role::Server, None).await;
186 handler(req, stream.into()).await
187 } else {
188 Err(format_err!("never received an upgrade!"))
189 }
190 });
191
192 Upgraded(Ok(response))
193 }
194}
195
196#[tide::utils::async_trait]
197impl<H, S, Fut> tide::Endpoint<S> for WebSocket<S, H>
198where
199 H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
200 Fut: Future<Output = Result<()>> + Send + 'static,
201 S: Send + Sync + Clone + 'static,
202{
203 async fn call(&self, req: Request<S>) -> Result {
204 match self.handle_upgrade(req).await {
205 Upgraded(result) => result,
206 NotUpgraded(_) => Ok(Response::new(StatusCode::UpgradeRequired)),
207 }
208 }
209}
210
211#[tide::utils::async_trait]
212impl<H, S, Fut> Middleware<S> for WebSocket<S, H>
213where
214 H: Fn(Request<S>, WebSocketConnection) -> Fut + Sync + Send + 'static,
215 Fut: Future<Output = Result<()>> + Send + 'static,
216 S: Send + Sync + Clone + 'static,
217{
218 async fn handle(&self, req: Request<S>, next: tide::Next<'_, S>) -> Result {
219 match self.handle_upgrade(req).await {
220 Upgraded(result) => result,
221 NotUpgraded(req) => Ok(next.run(req).await),
222 }
223 }
224}