1use std::{
2 pin::Pin,
3 sync::Arc,
4};
5
6use anyhow::{
7 Result,
8 bail,
9};
10use http::HeaderMap;
11use num_enum::{
12 IntoPrimitive,
13 TryFromPrimitive,
14};
15use serde::Serialize;
16use tokio::{
17 select,
18 spawn,
19 sync::{
20 Mutex,
21 mpsc::{
22 Receiver,
23 Sender,
24 channel,
25 },
26 },
27 task::JoinHandle,
28 time::{
29 sleep,
30 timeout,
31 },
32};
33use tokio_tungstenite::tungstenite::Message;
34use tokio_util::sync::CancellationToken;
35
36#[cfg(feature = "connection-extensions")]
37mod extensions;
38
39#[cfg(feature = "connection-extensions")]
40use self::extensions::WsIoServerConnectionExtensions;
41use crate::{
42 WsIoServer,
43 core::{
44 atomic::status::AtomicStatus,
45 packet::{
46 WsIoPacket,
47 WsIoPacketType,
48 },
49 utils::task::abort_locked_task,
50 },
51 namespace::WsIoServerNamespace,
52};
53
54#[repr(u8)]
55#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
56enum ConnectionStatus {
57 Activating,
58 Authenticating,
59 AwaitingAuth,
60 Closed,
61 Closing,
62 Created,
63 Ready,
64}
65
66type OnCloseHandler = Box<
67 dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
68 + Send
69 + Sync
70 + 'static,
71>;
72
73pub struct WsIoServerConnection {
74 auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
75 cancel_token: CancellationToken,
76 #[cfg(feature = "connection-extensions")]
77 extensions: WsIoServerConnectionExtensions,
78 headers: HeaderMap,
79 message_tx: Sender<Message>,
80 namespace: Arc<WsIoServerNamespace>,
81 on_close_handler: Mutex<Option<OnCloseHandler>>,
82 sid: String,
83 status: AtomicStatus<ConnectionStatus>,
84}
85
86impl WsIoServerConnection {
87 pub(crate) fn new(
88 headers: HeaderMap,
89 namespace: Arc<WsIoServerNamespace>,
90 sid: String,
91 ) -> (Arc<Self>, Receiver<Message>) {
92 let channel_capacity = (namespace.config.websocket_config.max_write_buffer_size
93 / namespace.config.websocket_config.write_buffer_size)
94 .clamp(64, 4096);
95
96 let (message_tx, message_rx) = channel(channel_capacity);
97 (
98 Arc::new(Self {
99 auth_timeout_task: Mutex::new(None),
100 cancel_token: CancellationToken::new(),
101 #[cfg(feature = "connection-extensions")]
102 extensions: WsIoServerConnectionExtensions::new(),
103 headers,
104 message_tx,
105 namespace,
106 on_close_handler: Mutex::new(None),
107 sid,
108 status: AtomicStatus::new(ConnectionStatus::Created),
109 }),
110 message_rx,
111 )
112 }
113
114 async fn activate(self: &Arc<Self>) -> Result<()> {
116 let status = self.status.get();
118 match status {
119 ConnectionStatus::Authenticating | ConnectionStatus::Created => {
120 self.status.try_transition(status, ConnectionStatus::Activating)?
121 }
122 _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
123 }
124
125 if let Some(middleware) = &self.namespace.config.middleware {
127 timeout(
128 self.namespace.config.middleware_execution_timeout,
129 middleware(self.clone()),
130 )
131 .await??;
132 }
133
134 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
136 timeout(
137 self.namespace.config.on_connect_handler_timeout,
138 on_connect_handler(self.clone()),
139 )
140 .await??;
141 }
142
143 self.namespace.insert_connection(self.clone());
145
146 self.status
148 .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
149
150 self.send_packet(&WsIoPacket {
152 data: None,
153 key: None,
154 r#type: WsIoPacketType::Ready,
155 })
156 .await?;
157
158 if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
160 let connection = self.clone();
162 self.spawn_task(async move { on_ready_handler(connection).await });
163 }
164
165 Ok(())
166 }
167
168 async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
169 let status = self.status.get();
171 match status {
172 ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
173 _ => bail!("Received auth packet in invalid status: {:#?}", status),
174 }
175
176 abort_locked_task(&self.auth_timeout_task).await;
178
179 if let Some(auth_handler) = &self.namespace.config.auth_handler {
181 timeout(
182 self.namespace.config.auth_handler_timeout,
183 (auth_handler)(self.clone(), packet_data),
184 )
185 .await??;
186
187 self.activate().await
189 } else {
190 bail!("Auth packet received but no auth handler is configured");
191 }
192 }
193
194 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
195 Ok(self
196 .message_tx
197 .send(self.namespace.encode_packet_to_message(packet)?)
198 .await?)
199 }
200
201 pub(crate) async fn cleanup(self: &Arc<Self>) {
203 self.status.store(ConnectionStatus::Closing);
205
206 abort_locked_task(&self.auth_timeout_task).await;
208
209 self.namespace.remove_connection(&self.sid);
211
212 self.cancel_token.cancel();
214
215 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
217 let _ = timeout(
218 self.namespace.config.on_close_handler_timeout,
219 on_close_handler(self.clone()),
220 )
221 .await;
222 }
223
224 self.status.store(ConnectionStatus::Closed);
226 }
227
228 #[inline]
229 pub(crate) fn close(&self) {
230 match self.status.get() {
232 ConnectionStatus::Closed | ConnectionStatus::Closing => return,
233 _ => self.status.store(ConnectionStatus::Closing),
234 }
235
236 let _ = self.message_tx.try_send(Message::Close(None));
238 }
239
240 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
241 let packet = self.namespace.config.packet_codec.decode(bytes)?;
242 match packet.r#type {
243 WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
244 _ => Ok(()),
245 }
246 }
247
248 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
249 let status = self.status.get();
251 if !matches!(status, ConnectionStatus::Created) {
252 bail!("Cannot init connection in invalid status: {:#?}", status);
253 }
254
255 let requires_auth = self.namespace.config.auth_handler.is_some();
257
258 let packet = WsIoPacket {
260 data: Some(self.namespace.config.packet_codec.encode_data(&requires_auth)?),
261 key: None,
262 r#type: WsIoPacketType::Init,
263 };
264
265 if requires_auth {
267 self.status
269 .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
270
271 let connection = self.clone();
273 *self.auth_timeout_task.lock().await = Some(spawn(async move {
274 sleep(connection.namespace.config.auth_packet_timeout).await;
275 if connection.status.is(ConnectionStatus::AwaitingAuth) {
276 connection.close();
277 }
278 }));
279
280 self.send_packet(&packet).await
282 } else {
283 self.send_packet(&packet).await?;
285
286 self.activate().await
288 }
289 }
290
291 #[inline]
294 pub fn cancel_token(&self) -> &CancellationToken {
295 &self.cancel_token
296 }
297
298 pub async fn disconnect(&self) {
299 let _ = self
300 .send_packet(&WsIoPacket {
301 data: None,
302 key: None,
303 r#type: WsIoPacketType::Disconnect,
304 })
305 .await;
306
307 self.close()
308 }
309
310 pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
311 let status = self.status.get();
312 if status != ConnectionStatus::Ready {
313 bail!("Cannot emit event in invalid status: {:#?}", status);
314 }
315
316 self.send_packet(&WsIoPacket {
317 data: data
318 .map(|data| self.namespace.config.packet_codec.encode_data(data))
319 .transpose()?,
320 key: Some(event.as_ref().to_string()),
321 r#type: WsIoPacketType::Event,
322 })
323 .await
324 }
325
/// Returns the connection's extension storage.
///
/// Only available with the `connection-extensions` feature.
#[cfg(feature = "connection-extensions")]
#[inline]
pub fn extensions(&self) -> &WsIoServerConnectionExtensions {
    &self.extensions
}
331
332 #[inline]
333 pub fn headers(&self) -> &HeaderMap {
334 &self.headers
335 }
336
337 #[inline]
338 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
339 self.namespace.clone()
340 }
341
342 pub async fn on_close<H, Fut>(&self, handler: H)
343 where
344 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
345 Fut: Future<Output = Result<()>> + Send + 'static,
346 {
347 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
348 }
349
350 #[inline]
351 pub fn server(&self) -> WsIoServer {
352 self.namespace.server()
353 }
354
355 #[inline]
356 pub fn sid(&self) -> &str {
357 &self.sid
358 }
359
360 #[inline]
361 pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
362 let cancel_token = self.cancel_token.clone();
363 spawn(async move {
364 select! {
365 _ = cancel_token.cancelled() => {},
366 _ = future => {},
367 }
368 });
369 }
370}