thruster_socketio/
socketio_upgrade.rs1use crypto::digest::Digest;
2use futures_util::sink::SinkExt;
3use futures_util::stream::StreamExt;
4use std::boxed::Box;
5use std::collections::HashMap;
6use std::future::Future;
7use std::pin::Pin;
8use thruster::{Context, MiddlewareResult};
9use tokio::time::{self, Duration};
10use tokio_tungstenite::tungstenite::Message;
11
12use crate::sid::generate_sid;
13use crate::socketio::{
14 InternalMessage, SocketIOSocket, SocketIOWrapper as SocketIO, WSSocketMessage,
15 SOCKETIO_EVENT_OPEN, SOCKETIO_PING,
16};
17use crate::socketio_context::SocketIOContext;
18
/// Fixed GUID that is appended to the client's `Sec-WebSocket-Key` before the
/// SHA-1 / base64 step that produces the `Sec-WebSocket-Accept` response header.
const WEBSOCKET_SEC: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";
20
21#[derive(Debug, Serialize)]
22#[serde(rename_all = "camelCase")]
23struct HandshakeResponseData {
24 sid: String,
25 upgrades: Vec<String>,
26 ping_interval: usize,
27 ping_timeout: usize,
28}
29
30#[derive(Debug, Serialize)]
31#[serde(rename_all = "camelCase")]
32struct HandshakeResponse {
33 r#type: String,
34 data: HandshakeResponseData,
35}
36
/// Protocol revisions the upgrade handler distinguishes, selected through the
/// `EIO` query parameter of the incoming request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AllowedVersions {
    /// Fallback used whenever `EIO` is absent or is anything other than `4`.
    V3,
    /// Selected when the request carries `EIO=4`.
    V4,
}
41
42pub async fn handle_io<T: Context + SocketIOContext + Default>(
47 context: T,
48 handler: fn(SocketIOSocket) -> Pin<Box<dyn Future<Output = Result<SocketIOSocket, ()>> + Send>>,
49) -> MiddlewareResult<T> {
50 handle_io_with_capacity(context, handler, 16).await
51}
52
53pub async fn handle_io_with_capacity<T: Context + SocketIOContext + Default>(
55 mut context: T,
56 handler: fn(SocketIOSocket) -> Pin<Box<dyn Future<Output = Result<SocketIOSocket, ()>> + Send>>,
57 message_capacity: usize,
58) -> MiddlewareResult<T> {
59 let param_map = match context.route().split('?').collect::<Vec<&str>>().get(1) {
60 Some(val) => {
61 let mut map = HashMap::new();
62
63 for el in val.split('&') {
64 let mut split = el.split('=');
65
66 map.insert(split.next().unwrap_or(""), split.next().unwrap_or(""));
67 }
68
69 map
70 }
71 None => HashMap::new(),
72 };
73
74 let version = match param_map.get("EIO") {
75 Some(&"4") => AllowedVersions::V4,
76 _ => AllowedVersions::V3,
77 };
78
79 let mut request = context.into_request();
80
81 if request.headers().contains_key(hyper::header::UPGRADE) {
83 let request_accept_key = request
84 .headers()
85 .get("Sec-WebSocket-Key")
86 .unwrap()
87 .to_str()
88 .unwrap();
89 let mut hasher = crypto::sha1::Sha1::new();
90 hasher.input_str(&format!("{}{}", request_accept_key, WEBSOCKET_SEC));
91
92 let mut accept_buffer = vec![0; hasher.output_bits() / 8];
93 hasher.result(&mut accept_buffer);
94 let accept_value = base64::encode(&accept_buffer);
95
96 context = T::default();
97 thruster::Context::status(&mut context, 101);
98 context.set("upgrade", "websocket");
99 context.set("Sec-WebSocket-Accept", &accept_value);
100 context.set("connection", "Upgrade");
101
102 let sid = generate_sid();
103 let body = serde_json::to_string(&HandshakeResponseData {
104 sid: sid.clone(), upgrades: vec!["websocket".to_string()],
106 ping_interval: 25000,
107 ping_timeout: 20000,
108 })
109 .unwrap();
110
111 let encoded_opener = format!("0{}", body);
112
113 tokio::spawn(async move {
115 let upgraded_req = hyper::upgrade::on(&mut request)
116 .await
117 .expect("Could not upgrade request to websocket");
118
119 let ws_stream = tokio_tungstenite::WebSocketStream::from_raw_socket(
120 upgraded_req,
121 tokio_tungstenite::tungstenite::protocol::Role::Server,
122 None,
123 )
124 .await;
125 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
126
127 let _ = ws_sender.send(Message::Text(encoded_opener)).await;
129 match version {
132 AllowedVersions::V3 => {
133 let _ = ws_sender
134 .send(Message::Text(SOCKETIO_EVENT_OPEN.to_string()))
135 .await;
136 }
137 AllowedVersions::V4 => {
138 let _ = ws_sender
139 .send(Message::Text(format!(
140 "{}{{\"sid\":\"{}\"}}",
141 SOCKETIO_EVENT_OPEN,
142 sid.clone()
143 )))
144 .await;
145 }
146 }
147
148 let mut msg_fut = ws_receiver.next();
149 let socket_wrapper = SocketIO::new(sid.clone(), ws_sender, message_capacity);
150 let sender = socket_wrapper.sender();
151
152 tokio::spawn(async move {
153 socket_wrapper.listen().await;
154 });
155
156 if let AllowedVersions::V4 = version {
158 let keepalive_sender = sender.clone();
159 tokio::spawn(async move {
160 let mut interval = time::interval(Duration::from_millis(25000));
161
162 loop {
163 interval.tick().await;
164
165 let res = keepalive_sender.send(InternalMessage::WS(WSSocketMessage::Pong));
166
167 if res.is_err() {
168 break;
169 }
170 }
171 });
172 };
173
174 let socket = SocketIOSocket::new(sid.clone(), sender.clone());
175 let _ = (handler)(socket)
176 .await
177 .expect("The handler should return a socket");
178
179 loop {
180 match msg_fut.await {
181 Some(Ok(Message::Text(ws_payload))) => {
182 match ws_payload.as_ref() {
184 SOCKETIO_PING => {
185 let _ = sender.send(InternalMessage::WS(WSSocketMessage::Ping));
186 }
187 val => {
188 let _ = sender.send(InternalMessage::WS(
189 WSSocketMessage::RawMessage(val.to_string()),
190 ));
191 }
192 };
193 }
194 Some(Ok(Message::Frame(_ws_payload))) => {
195 }
197 Some(Ok(Message::Binary(_ws_payload))) => {
198 }
200 Some(Ok(Message::Ping(_))) => {
201 let _ = sender.send(InternalMessage::WS(WSSocketMessage::WsPing));
202 break;
203 }
204 Some(Ok(Message::Pong(_))) => {
205 let _ = sender.send(InternalMessage::WS(WSSocketMessage::WsPong));
206 break;
207 }
208 Some(Err(_e)) => {
209 break;
210 }
211 Some(Ok(Message::Close(_e))) => {
212 break;
213 }
214 None => {
215 break;
216 }
217 }
218
219 msg_fut = ws_receiver.next();
220 }
221
222 let _ = sender.send(InternalMessage::WS(WSSocketMessage::Close));
224 });
225
226 Ok(context)
227 } else {
228 let polling_enabled = request
229 .uri()
230 .to_string()
231 .split('?')
232 .nth(1)
233 .map(|query_string| {
234 query_string.split('&').fold(HashMap::new(), |mut acc, x| {
235 let mut pieces = x.split('=');
236 acc.insert(
237 pieces.next().unwrap_or_default(),
238 pieces.next().unwrap_or_default(),
239 );
240
241 acc
242 })
243 })
244 .unwrap_or_default()
245 .get("transport")
246 .map(|v| v.contains("polling"))
247 .unwrap_or(false);
248
249 context = T::default();
250 if !polling_enabled {
251 thruster::Context::status(&mut context, 400);
252 context.set_body(
253 "Polling transport disabled, but no upgrade header for websocket."
254 .as_bytes()
255 .to_vec(),
256 );
257
258 Ok(context)
259 } else {
260 context.set_body(
261 "Polling transport is not implemented yet."
262 .as_bytes()
263 .to_vec(),
264 );
265 thruster::Context::status(&mut context, 400);
266
267 Ok(context)
268 }
269 }
270}