1#![doc(html_root_url = "https://docs.rs/tsukuyomi-tungstenite/0.2.0")]
4#![deny(
5 missing_docs,
6 missing_debug_implementations,
7 nonstandard_style,
8 rust_2018_idioms,
9 rust_2018_compatibility,
10 unused
11)]
12#![doc(test(attr(deny(deprecated, unused,))))]
13#![forbid(clippy::unimplemented)]
14
15use {
16 futures::IntoFuture,
17 http::Response,
18 tsukuyomi::{error::Error, input::body::UpgradedIo, responder::Responder},
19};
20
21#[doc(no_inline)]
22pub use tungstenite::protocol::{Message, WebSocketConfig};
23
24pub type WebSocketStream = tokio_tungstenite::WebSocketStream<UpgradedIo>;
26
27#[derive(Debug, Clone)]
29pub struct Ws<F> {
30 on_upgrade: F,
31 config: Option<WebSocketConfig>,
32}
33
34impl<F, R> Ws<F>
35where
36 F: Fn(WebSocketStream) -> R + Send + 'static,
37 R: IntoFuture<Item = (), Error = ()>,
38 R::Future: Send + 'static,
39{
40 pub fn new(on_upgrade: F) -> Self {
42 Self {
43 on_upgrade,
44 config: None,
45 }
46 }
47
48 pub fn config(self, config: WebSocketConfig) -> Self {
50 Self {
51 config: Some(config),
52 ..self
53 }
54 }
55}
56
57impl<F, R> Responder for Ws<F>
58where
59 F: Fn(WebSocketStream) -> R + Send + 'static,
60 R: IntoFuture<Item = (), Error = ()>,
61 R::Future: Send + 'static,
62{
63 type Response = Response<()>;
64 type Error = Error;
65 type Respond = self::imp::WsRespond<F>; fn respond(self) -> Self::Respond {
68 self::imp::WsRespond(Some(self))
69 }
70}
71
72mod imp {
73 use {
74 super::{WebSocketStream, Ws},
75 futures::{Future, IntoFuture},
76 http::{
77 header::{
78 CONNECTION, SEC_WEBSOCKET_ACCEPT,
80 SEC_WEBSOCKET_KEY,
81 SEC_WEBSOCKET_VERSION,
82 UPGRADE,
83 },
84 Request, Response, StatusCode,
85 },
86 sha1::{Digest, Sha1},
87 tsukuyomi::{
88 error::HttpError,
89 future::{Poll, TryFuture},
90 input::{
91 body::{RequestBody, UpgradedIo},
92 Input,
93 },
94 },
95 tsukuyomi_server::rt::{DefaultExecutor, Executor},
96 tungstenite::protocol::Role,
97 };
98
99 #[allow(missing_debug_implementations)]
100 pub struct WsRespond<F>(pub(super) Option<Ws<F>>);
101
102 impl<F, R> TryFuture for WsRespond<F>
103 where
104 F: FnOnce(WebSocketStream) -> R + Send + 'static,
105 R: IntoFuture<Item = (), Error = ()>,
106 R::Future: Send + 'static,
107 {
108 type Ok = Response<()>;
109 type Error = tsukuyomi::Error;
110
111 fn poll_ready(&mut self, input: &mut Input<'_>) -> Poll<Self::Ok, Self::Error> {
112 let Ws { on_upgrade, config } =
113 self.0.take().expect("the future has already been polled");
114
115 let accept_hash = handshake(input)?;
116
117 let body = input
118 .locals
119 .remove(&RequestBody::KEY) .ok_or_else(|| {
121 tsukuyomi::error::internal_server_error(
122 "the request body has already been stolen by someone",
123 )
124 })?;
125
126 let task = body
127 .on_upgrade()
128 .map_err(|e| log::error!("failed to upgrade the request: {}", e))
129 .and_then(move |io: UpgradedIo| {
130 let transport = WebSocketStream::from_raw_socket(io, Role::Server, config);
131 on_upgrade(transport).into_future()
132 });
133
134 DefaultExecutor::current()
135 .spawn(Box::new(task))
136 .map_err(tsukuyomi::error::internal_server_error)?;
137
138 Ok(Response::builder()
139 .status(StatusCode::SWITCHING_PROTOCOLS)
140 .header(UPGRADE, "websocket")
141 .header(CONNECTION, "upgrade")
142 .header(SEC_WEBSOCKET_ACCEPT, &*accept_hash)
143 .body(())
144 .expect("should be a valid response")
145 .into())
146 }
147 }
148
149 #[derive(Debug, failure::Fail)]
150 enum HandshakeError {
151 #[fail(display = "The header is missing: `{}'", name)]
152 MissingHeader { name: &'static str },
153
154 #[fail(display = "The header value is invalid: `{}'", name)]
155 InvalidHeader { name: &'static str },
156
157 #[fail(display = "The value of `Sec-WebSocket-Key` is invalid")]
158 InvalidSecWebSocketKey,
159
160 #[fail(display = "The value of `Sec-WebSocket-Version` must be equal to '13'")]
161 InvalidSecWebSocketVersion,
162 }
163
164 impl HttpError for HandshakeError {
165 type Body = String;
166
167 fn into_response(self, _: &Request<()>) -> Response<Self::Body> {
168 Response::builder()
169 .status(StatusCode::BAD_REQUEST)
170 .body(self.to_string())
171 .expect("should be a valid response")
172 }
173 }
174
175 fn handshake(input: &mut Input<'_>) -> Result<String, HandshakeError> {
176 match input.request.headers().get(UPGRADE) {
177 Some(h) if h.as_bytes().eq_ignore_ascii_case(b"websocket") => (),
178 Some(..) => Err(HandshakeError::InvalidHeader { name: "Upgrade" })?,
179 None => Err(HandshakeError::MissingHeader { name: "Upgrade" })?,
180 }
181
182 match input.request.headers().get(CONNECTION) {
183 Some(h) if h.as_bytes().eq_ignore_ascii_case(b"upgrade") => (),
184 Some(..) => Err(HandshakeError::InvalidHeader { name: "Connection" })?,
185 None => Err(HandshakeError::MissingHeader { name: "Connection" })?,
186 }
187
188 match input.request.headers().get(SEC_WEBSOCKET_VERSION) {
189 Some(h) if h == "13" => {}
190 Some(..) => Err(HandshakeError::InvalidSecWebSocketVersion)?,
191 None => Err(HandshakeError::MissingHeader {
192 name: "Sec-WebSocket-Version",
193 })?,
194 }
195
196 let accept_hash = match input.request.headers().get(SEC_WEBSOCKET_KEY) {
197 Some(h) => {
198 if h.len() != 24 || {
199 h.as_bytes()
200 .into_iter()
201 .any(|&b| !b.is_ascii_alphanumeric() && b != b'+' && b != b'/' && b != b'=')
202 } {
203 Err(HandshakeError::InvalidSecWebSocketKey)?;
204 }
205
206 let mut m = Sha1::new();
207 m.input(h.as_bytes());
208 m.input(&b"258EAFA5-E914-47DA-95CA-C5AB0DC85B11"[..]);
209 base64::encode(&*m.result())
210 }
211 None => Err(HandshakeError::MissingHeader {
212 name: "Sec-WebSocket-Key",
213 })?,
214 };
215
216 Ok(accept_hash)
219 }
220}