tsukuyomi_tungstenite/
lib.rs

1//! The basic WebSocket support for Tsukuyomi, powered by tungstenite.
2
3#![doc(html_root_url = "https://docs.rs/tsukuyomi-tungstenite/0.2.0")]
4#![deny(
5    missing_docs,
6    missing_debug_implementations,
7    nonstandard_style,
8    rust_2018_idioms,
9    rust_2018_compatibility,
10    unused
11)]
12#![doc(test(attr(deny(deprecated, unused,))))]
13#![forbid(clippy::unimplemented)]
14
15use {
16    futures::IntoFuture,
17    http::Response,
18    tsukuyomi::{error::Error, input::body::UpgradedIo, responder::Responder},
19};
20
21#[doc(no_inline)]
22pub use tungstenite::protocol::{Message, WebSocketConfig};
23
/// A transport for exchanging data frames with the peer.
///
/// This is `tokio_tungstenite::WebSocketStream` specialized to the `UpgradedIo`
/// connection type; it is the value handed to the closure passed to `Ws::new`.
pub type WebSocketStream = tokio_tungstenite::WebSocketStream<UpgradedIo>;
26
27/// A `Responder` that handles an WebSocket connection.
28#[derive(Debug, Clone)]
29pub struct Ws<F> {
30    on_upgrade: F,
31    config: Option<WebSocketConfig>,
32}
33
34impl<F, R> Ws<F>
35where
36    F: Fn(WebSocketStream) -> R + Send + 'static,
37    R: IntoFuture<Item = (), Error = ()>,
38    R::Future: Send + 'static,
39{
40    /// Crates a `Ws` with the specified closure.
41    pub fn new(on_upgrade: F) -> Self {
42        Self {
43            on_upgrade,
44            config: None,
45        }
46    }
47
48    /// Sets the configuration of upgraded WebSocket connection.
49    pub fn config(self, config: WebSocketConfig) -> Self {
50        Self {
51            config: Some(config),
52            ..self
53        }
54    }
55}
56
57impl<F, R> Responder for Ws<F>
58where
59    F: Fn(WebSocketStream) -> R + Send + 'static,
60    R: IntoFuture<Item = (), Error = ()>,
61    R::Future: Send + 'static,
62{
63    type Response = Response<()>;
64    type Error = Error;
65    type Respond = self::imp::WsRespond<F>; // private
66
67    fn respond(self) -> Self::Respond {
68        self::imp::WsRespond(Some(self))
69    }
70}
71
72mod imp {
73    use {
74        super::{WebSocketStream, Ws},
75        futures::{Future, IntoFuture},
76        http::{
77            header::{
78                CONNECTION, //
79                SEC_WEBSOCKET_ACCEPT,
80                SEC_WEBSOCKET_KEY,
81                SEC_WEBSOCKET_VERSION,
82                UPGRADE,
83            },
84            Request, Response, StatusCode,
85        },
86        sha1::{Digest, Sha1},
87        tsukuyomi::{
88            error::HttpError,
89            future::{Poll, TryFuture},
90            input::{
91                body::{RequestBody, UpgradedIo},
92                Input,
93            },
94        },
95        tsukuyomi_server::rt::{DefaultExecutor, Executor},
96        tungstenite::protocol::Role,
97    };
98
99    #[allow(missing_debug_implementations)]
100    pub struct WsRespond<F>(pub(super) Option<Ws<F>>);
101
102    impl<F, R> TryFuture for WsRespond<F>
103    where
104        F: FnOnce(WebSocketStream) -> R + Send + 'static,
105        R: IntoFuture<Item = (), Error = ()>,
106        R::Future: Send + 'static,
107    {
108        type Ok = Response<()>;
109        type Error = tsukuyomi::Error;
110
111        fn poll_ready(&mut self, input: &mut Input<'_>) -> Poll<Self::Ok, Self::Error> {
112            let Ws { on_upgrade, config } =
113                self.0.take().expect("the future has already been polled");
114
115            let accept_hash = handshake(input)?;
116
117            let body = input
118                .locals
119                .remove(&RequestBody::KEY) //
120                .ok_or_else(|| {
121                    tsukuyomi::error::internal_server_error(
122                        "the request body has already been stolen by someone",
123                    )
124                })?;
125
126            let task = body
127                .on_upgrade()
128                .map_err(|e| log::error!("failed to upgrade the request: {}", e))
129                .and_then(move |io: UpgradedIo| {
130                    let transport = WebSocketStream::from_raw_socket(io, Role::Server, config);
131                    on_upgrade(transport).into_future()
132                });
133
134            DefaultExecutor::current()
135                .spawn(Box::new(task))
136                .map_err(tsukuyomi::error::internal_server_error)?;
137
138            Ok(Response::builder()
139                .status(StatusCode::SWITCHING_PROTOCOLS)
140                .header(UPGRADE, "websocket")
141                .header(CONNECTION, "upgrade")
142                .header(SEC_WEBSOCKET_ACCEPT, &*accept_hash)
143                .body(())
144                .expect("should be a valid response")
145                .into())
146        }
147    }
148
149    #[derive(Debug, failure::Fail)]
150    enum HandshakeError {
151        #[fail(display = "The header is missing: `{}'", name)]
152        MissingHeader { name: &'static str },
153
154        #[fail(display = "The header value is invalid: `{}'", name)]
155        InvalidHeader { name: &'static str },
156
157        #[fail(display = "The value of `Sec-WebSocket-Key` is invalid")]
158        InvalidSecWebSocketKey,
159
160        #[fail(display = "The value of `Sec-WebSocket-Version` must be equal to '13'")]
161        InvalidSecWebSocketVersion,
162    }
163
164    impl HttpError for HandshakeError {
165        type Body = String;
166
167        fn into_response(self, _: &Request<()>) -> Response<Self::Body> {
168            Response::builder()
169                .status(StatusCode::BAD_REQUEST)
170                .body(self.to_string())
171                .expect("should be a valid response")
172        }
173    }
174
175    fn handshake(input: &mut Input<'_>) -> Result<String, HandshakeError> {
176        match input.request.headers().get(UPGRADE) {
177            Some(h) if h.as_bytes().eq_ignore_ascii_case(b"websocket") => (),
178            Some(..) => Err(HandshakeError::InvalidHeader { name: "Upgrade" })?,
179            None => Err(HandshakeError::MissingHeader { name: "Upgrade" })?,
180        }
181
182        match input.request.headers().get(CONNECTION) {
183            Some(h) if h.as_bytes().eq_ignore_ascii_case(b"upgrade") => (),
184            Some(..) => Err(HandshakeError::InvalidHeader { name: "Connection" })?,
185            None => Err(HandshakeError::MissingHeader { name: "Connection" })?,
186        }
187
188        match input.request.headers().get(SEC_WEBSOCKET_VERSION) {
189            Some(h) if h == "13" => {}
190            Some(..) => Err(HandshakeError::InvalidSecWebSocketVersion)?,
191            None => Err(HandshakeError::MissingHeader {
192                name: "Sec-WebSocket-Version",
193            })?,
194        }
195
196        let accept_hash = match input.request.headers().get(SEC_WEBSOCKET_KEY) {
197            Some(h) => {
198                if h.len() != 24 || {
199                    h.as_bytes()
200                        .into_iter()
201                        .any(|&b| !b.is_ascii_alphanumeric() && b != b'+' && b != b'/' && b != b'=')
202                } {
203                    Err(HandshakeError::InvalidSecWebSocketKey)?;
204                }
205
206                let mut m = Sha1::new();
207                m.input(h.as_bytes());
208                m.input(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
209                base64::encode(&*m.result())
210            }
211            None => Err(HandshakeError::MissingHeader {
212                name: "Sec-WebSocket-Key",
213            })?,
214        };
215
216        // TODO: Sec-WebSocket-Protocol, Sec-WebSocket-Extension
217
218        Ok(accept_hash)
219    }
220}