// thruster_socketio/socketio_upgrade.rs

1use crypto::digest::Digest;
2use futures_util::sink::SinkExt;
3use futures_util::stream::StreamExt;
4use std::boxed::Box;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use thruster::{Context, MiddlewareResult};
9use tokio::time::{self, Duration};
10use tokio_tungstenite::tungstenite::Message;
11
12use crate::sid::generate_sid;
13use crate::socketio::{
14    InternalMessage, SocketIOSocket, SocketIOWrapper as SocketIO, WSSocketMessage,
15    SOCKETIO_EVENT_OPEN, SOCKETIO_PING,
16};
17use crate::socketio_context::SocketIOContext;
18
/// Fixed GUID (RFC 6455) that is appended to the client's `Sec-WebSocket-Key` before
/// SHA-1 hashing; the base64 of that digest is returned as `Sec-WebSocket-Accept`.
const WEBSOCKET_SEC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
20
21#[derive(Debug, Serialize)]
22#[serde(rename_all = "camelCase")]
23struct HandshakeResponseData {
24    sid: String,
25    upgrades: Vec<String>,
26    ping_interval: usize,
27    ping_timeout: usize,
28}
29
30#[derive(Debug, Serialize)]
31#[serde(rename_all = "camelCase")]
32struct HandshakeResponse {
33    r#type: String,
34    data: HandshakeResponseData,
35}
36
/// Engine.IO protocol revisions understood by this module, selected from the request's `EIO`
/// query parameter.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AllowedVersions {
    /// Engine.IO v3 — the fallback for any `EIO` value other than `4`, including a missing one.
    V3,
    /// Engine.IO v4 (`EIO=4`); in this revision the server drives the keepalive.
    V4,
}
41
42/// Handles any incoming socket.io requests for a particular context by using the passed in handler.
43///
44/// Defaults to a maximum message capacity of 16. If there are more connections, then messages can
45/// (and will!) be dropped.
46pub async fn handle_io<T: Context + SocketIOContext + Default>(
47    context: T,
48    handler: fn(SocketIOSocket) -> Pin<Box<dyn Future<Output = Result<SocketIOSocket, ()>> + Send>>,
49) -> MiddlewareResult<T> {
50    handle_io_with_capacity(context, handler, 16).await
51}
52
53/// Handles any incoming socket.io requests for a particular context by using the passed in handler.
54pub async fn handle_io_with_capacity<T: Context + SocketIOContext + Default>(
55    mut context: T,
56    handler: fn(SocketIOSocket) -> Pin<Box<dyn Future<Output = Result<SocketIOSocket, ()>> + Send>>,
57    message_capacity: usize,
58) -> MiddlewareResult<T> {
59    let param_map = match context.route().split('?').collect::<Vec<&str>>().get(1) {
60        Some(val) => {
61            let mut map = HashMap::new();
62
63            for el in val.split('&') {
64                let mut split = el.split('=');
65
66                map.insert(split.next().unwrap_or(""), split.next().unwrap_or(""));
67            }
68
69            map
70        }
71        None => HashMap::new(),
72    };
73
74    let version = match param_map.get("EIO") {
75        Some(&"4") => AllowedVersions::V4,
76        _ => AllowedVersions::V3,
77    };
78
79    let mut request = context.into_request();
80
81    // Theoretically should check this and the transport query param
82    if request.headers().contains_key(hyper::header::UPGRADE) {
83        let request_accept_key = request
84            .headers()
85            .get("Sec-WebSocket-Key")
86            .unwrap()
87            .to_str()
88            .unwrap();
89        let mut hasher = crypto::sha1::Sha1::new();
90        hasher.input_str(&format!("{}{}", request_accept_key, WEBSOCKET_SEC));
91
92        let mut accept_buffer = vec![0; hasher.output_bits() / 8];
93        hasher.result(&mut accept_buffer);
94        let accept_value = base64::encode(&accept_buffer);
95
96        context = T::default();
97        thruster::Context::status(&mut context, 101);
98        context.set("upgrade", "websocket");
99        context.set("Sec-WebSocket-Accept", &accept_value);
100        context.set("connection", "Upgrade");
101
102        let sid = generate_sid();
103        let body = serde_json::to_string(&HandshakeResponseData {
104            sid: sid.clone(), // must be unique
105            upgrades: vec!["websocket".to_string()],
106            ping_interval: 25000,
107            ping_timeout: 20000,
108        })
109        .unwrap();
110
111        let encoded_opener = format!("0{}", body);
112
113        // Spawn a separate future to handle this connection
114        tokio::spawn(async move {
115            let upgraded_req = hyper::upgrade::on(&mut request)
116                .await
117                .expect("Could not upgrade request to websocket");
118
119            let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
120                upgraded_req,
121                tokio_tungstenite::tungstenite::protocol::Role::Server,
122                None,
123            )
124            .await;
125            let (mut ws_sender, mut ws_receiver) = ws_stream.split();
126
127            // TODO(trezm): Handle errors here
128            let _ = ws_sender.send(Message::Text(encoded_opener)).await;
129            // TODO(trezm): Handle errors here
130
131            match version {
132                AllowedVersions::V3 => {
133                    let _ = ws_sender
134                        .send(Message::Text(SOCKETIO_EVENT_OPEN.to_string()))
135                        .await;
136                }
137                AllowedVersions::V4 => {
138                    let _ = ws_sender
139                        .send(Message::Text(format!(
140                            "{}{{\"sid\":\"{}\"}}",
141                            SOCKETIO_EVENT_OPEN,
142                            sid.clone()
143                        )))
144                        .await;
145                }
146            }
147
148            let mut msg_fut = ws_receiver.next();
149            let socket_wrapper = SocketIO::new(sid.clone(), ws_sender, message_capacity);
150            let sender = socket_wrapper.sender();
151
152            tokio::spawn(async move {
153                socket_wrapper.listen().await;
154            });
155
156            // Keepalive in v4 is the server's responsibility.
157            if let AllowedVersions::V4 = version {
158                let keepalive_sender = sender.clone();
159                tokio::spawn(async move {
160                    let mut interval = time::interval(Duration::from_millis(25000));
161
162                    loop {
163                        interval.tick().await;
164
165                        let res = keepalive_sender.send(InternalMessage::WS(WSSocketMessage::Pong));
166
167                        if res.is_err() {
168                            break;
169                        }
170                    }
171                });
172            };
173
174            let socket = SocketIOSocket::new(sid.clone(), sender.clone());
175            let _ = (handler)(socket)
176                .await
177                .expect("The handler should return a socket");
178
179            loop {
180                match msg_fut.await {
181                    Some(Ok(Message::Text(ws_payload))) => {
182                        // TODO(trezm): Handle errors here
183                        match ws_payload.as_ref() {
184                            SOCKETIO_PING => {
185                                let _ = sender.send(InternalMessage::WS(WSSocketMessage::Ping));
186                            }
187                            val => {
188                                let _ = sender.send(InternalMessage::WS(
189                                    WSSocketMessage::RawMessage(val.to_string()),
190                                ));
191                            }
192                        };
193                    }
194                    Some(Ok(Message::Frame(_ws_payload))) => {
195                        // TODO(trezm): Do this...
196                    }
197                    Some(Ok(Message::Binary(_ws_payload))) => {
198                        // TODO(trezm): Do this...
199                    }
200                    Some(Ok(Message::Ping(_))) => {
201                        let _ = sender.send(InternalMessage::WS(WSSocketMessage::WsPing));
202                        break;
203                    }
204                    Some(Ok(Message::Pong(_))) => {
205                        let _ = sender.send(InternalMessage::WS(WSSocketMessage::WsPong));
206                        break;
207                    }
208                    Some(Err(_e)) => {
209                        break;
210                    }
211                    Some(Ok(Message::Close(_e))) => {
212                        break;
213                    }
214                    None => {
215                        break;
216                    }
217                }
218
219                msg_fut = ws_receiver.next();
220            }
221
222            // Cleanup the socket
223            let _ = sender.send(InternalMessage::WS(WSSocketMessage::Close));
224        });
225
226        Ok(context)
227    } else {
228        let polling_enabled = request
229            .uri()
230            .to_string()
231            .split('?')
232            .nth(1)
233            .map(|query_string| {
234                query_string.split('&').fold(HashMap::new(), |mut acc, x| {
235                    let mut pieces = x.split('=');
236                    acc.insert(
237                        pieces.next().unwrap_or_default(),
238                        pieces.next().unwrap_or_default(),
239                    );
240
241                    acc
242                })
243            })
244            .unwrap_or_default()
245            .get("transport")
246            .map(|v| v.contains("polling"))
247            .unwrap_or(false);
248
249        context = T::default();
250        if !polling_enabled {
251            thruster::Context::status(&mut context, 400);
252            context.set_body(
253                "Polling transport disabled, but no upgrade header for websocket."
254                    .as_bytes()
255                    .to_vec(),
256            );
257
258            Ok(context)
259        } else {
260            context.set_body(
261                "Polling transport is not implemented yet."
262                    .as_bytes()
263                    .to_vec(),
264            );
265            thruster::Context::status(&mut context, 400);
266
267            Ok(context)
268        }
269    }
270}