// wsio_server/connection.rs

use std::{
    pin::Pin,
    sync::Arc,
};

use anyhow::{
    Result,
    bail,
};
use bson::oid::ObjectId;
use dashmap::{
    DashMap,
    mapref::entry::Entry,
};
use http::HeaderMap;
use serde::de::DeserializeOwned;
use tokio::{
    spawn,
    sync::{
        Mutex,
        RwLock,
        mpsc::{
            UnboundedReceiver,
            UnboundedSender,
            unbounded_channel,
        },
    },
    task::JoinHandle,
    time::sleep,
};
use tokio_tungstenite::tungstenite::Message;

use crate::{
    WsIoServer,
    core::packet::{
        WsIoPacket,
        WsIoPacketType,
    },
    namespace::WsIoServerNamespace,
};

/// Lifecycle state of a [`WsIoServerConnection`].
///
/// Variants are listed roughly in the order a connection moves through them.
#[derive(Debug)]
enum WsIoServerConnectionStatus {
    /// Just constructed; `init` has not moved it anywhere yet.
    Created,
    /// `init` found an auth handler configured and is waiting for the client's auth packet.
    AwaitingAuth,
    /// An auth packet was received and the auth handler is running.
    Authenticating,
    /// Middleware and the connect handler are running.
    Activating,
    /// Registered with the namespace and announced to the client.
    Ready,
    /// Shutdown has started (`close` or `cleanup`).
    Closing,
    /// `cleanup` has finished.
    Closed,
}

50type EventHandler = Box<
51    dyn for<'a> Fn(Arc<WsIoServerConnection>, Option<&'a [u8]>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
52        + Send
53        + Sync
54        + 'static,
55>;
56
57type OnCloseHandler = Box<
58    dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
59        + Send
60        + Sync
61        + 'static,
62>;
63
64pub struct WsIoServerConnection {
65    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
66    event_handlers: DashMap<String, EventHandler>,
67    headers: HeaderMap,
68    namespace: Arc<WsIoServerNamespace>,
69    on_close_handler: Mutex<Option<OnCloseHandler>>,
70    sid: String,
71    status: RwLock<WsIoServerConnectionStatus>,
72    tx: UnboundedSender<Message>,
73}
74
75impl WsIoServerConnection {
76    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
77        let (tx, rx) = unbounded_channel();
78        (
79            Self {
80                auth_timeout_task: Mutex::new(None),
81                event_handlers: DashMap::new(),
82                headers,
83                namespace,
84                on_close_handler: Mutex::new(None),
85                sid: ObjectId::new().to_string(),
86                status: RwLock::new(WsIoServerConnectionStatus::Created),
87                tx,
88            },
89            rx,
90        )
91    }
92
93    // Private methods
94    async fn abort_auth_timeout_task(&self) {
95        if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
96            auth_timeout_task.abort();
97        }
98    }
99
100    async fn activate(self: &Arc<Self>) -> Result<()> {
101        {
102            let mut status = self.status.write().await;
103            match *status {
104                WsIoServerConnectionStatus::Authenticating | WsIoServerConnectionStatus::Created => {
105                    *status = WsIoServerConnectionStatus::Activating
106                }
107                _ => bail!("Cannot activate connection in invalid status: {:#?}", *status),
108            }
109        }
110
111        if let Some(middleware) = &self.namespace.config.middleware {
112            middleware(self.clone()).await?;
113        }
114
115        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
116            on_connect_handler(self.clone()).await?;
117        }
118
119        self.namespace.insert_connection(self.clone());
120        self.send_packet(&WsIoPacket {
121            data: None,
122            key: None,
123            r#type: WsIoPacketType::Ready,
124        })?;
125
126        *self.status.write().await = WsIoServerConnectionStatus::Ready;
127        if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
128            on_ready_handler(self.clone()).await?;
129        }
130
131        Ok(())
132    }
133
134    async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
135        {
136            let mut status = self.status.write().await;
137            match *status {
138                WsIoServerConnectionStatus::AwaitingAuth => *status = WsIoServerConnectionStatus::Authenticating,
139                _ => bail!("Received auth packet in invalid status: {:#?}", *status),
140            }
141        }
142
143        if let Some(auth_handler) = &self.namespace.config.auth_handler {
144            (auth_handler)(self.clone(), packet_data).await?;
145            self.abort_auth_timeout_task().await;
146            let status = self.status.read().await;
147            if !matches!(*status, WsIoServerConnectionStatus::Authenticating) {
148                bail!("Auth packet processed while connection status was {:#?}", *status);
149            }
150
151            self.activate().await?;
152            Ok(())
153        } else {
154            bail!("Auth packet received but no auth handler is configured");
155        }
156    }
157
158    #[inline]
159    fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
160        Ok(self.tx.send(self.namespace.encode_packet_to_message(packet)?)?)
161    }
162
163    // Protected methods
164    pub(crate) async fn cleanup(self: &Arc<Self>) {
165        *self.status.write().await = WsIoServerConnectionStatus::Closing;
166        self.abort_auth_timeout_task().await;
167        self.namespace.remove_connection(&self.sid);
168        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
169            let _ = on_close_handler(self.clone()).await;
170        }
171
172        *self.status.write().await = WsIoServerConnectionStatus::Closed;
173    }
174
175    pub(crate) async fn close(&self) {
176        {
177            let mut status = self.status.write().await;
178            if matches!(
179                *status,
180                WsIoServerConnectionStatus::Closed | WsIoServerConnectionStatus::Closing
181            ) {
182                return;
183            }
184
185            *status = WsIoServerConnectionStatus::Closing;
186        }
187
188        let _ = self.tx.send(Message::Close(None));
189    }
190
191    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
192        let packet = match self.namespace.config.packet_codec.decode(bytes) {
193            Ok(packet) => packet,
194            Err(_) => return,
195        };
196
197        if match packet.r#type {
198            WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
199            WsIoPacketType::Event => Ok(()),
200            _ => Ok(()),
201        }
202        .is_err()
203        {
204            self.close().await;
205        }
206    }
207
208    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
209        let require_auth = self.namespace.config.auth_handler.is_some();
210        let packet = WsIoPacket {
211            data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
212            key: None,
213            r#type: WsIoPacketType::Init,
214        };
215
216        if require_auth {
217            *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
218            let connection = self.clone();
219            *self.auth_timeout_task.lock().await = Some(spawn(async move {
220                sleep(connection.namespace.config.auth_timeout).await;
221                if matches!(
222                    *connection.status.read().await,
223                    WsIoServerConnectionStatus::AwaitingAuth
224                ) {
225                    connection.close().await;
226                }
227            }));
228
229            self.send_packet(&packet)?;
230        } else {
231            self.send_packet(&packet)?;
232            self.activate().await?;
233        }
234
235        Ok(())
236    }
237
238    // Public methods
239    pub async fn disconnect(&self) {
240        let _ = self.send_packet(&WsIoPacket {
241            data: None,
242            key: None,
243            r#type: WsIoPacketType::Disconnect,
244        });
245
246        // TODO: Should we wait for the disconnect packet to be sent or use spawn?
247        let _ = self.close().await;
248    }
249
250    #[inline]
251    pub fn headers(&self) -> &HeaderMap {
252        &self.headers
253    }
254
255    #[inline]
256    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
257        self.namespace.clone()
258    }
259
260    #[inline]
261    pub fn off(&self, event: impl AsRef<str>) {
262        self.event_handlers.remove(event.as_ref());
263    }
264
265    #[inline]
266    pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> Result<()>
267    where
268        H: Fn(Arc<WsIoServerConnection>, &D) -> Fut + Send + Sync + 'static,
269        Fut: Future<Output = Result<()>> + Send + 'static,
270        D: DeserializeOwned + Send + 'static,
271    {
272        let event = event.as_ref();
273        if self.event_handlers.contains_key(event) {
274            bail!("Event {} handler already exists", event);
275        }
276
277        let handler = Arc::new(handler);
278        let packet_codec = self.namespace.config.packet_codec;
279        self.event_handlers.insert(
280            event.into(),
281            Box::new(move |connection, bytes: Option<&[u8]>| {
282                let handler = handler.clone();
283                Box::pin(async move {
284                    let bytes = match bytes {
285                        Some(bytes) => bytes,
286                        None => packet_codec.empty_data_encoded(),
287                    };
288
289                    let data = packet_codec.decode_data::<D>(bytes)?;
290                    handler(connection, &data).await
291                })
292            }),
293        );
294
295        Ok(())
296    }
297
298    pub async fn on_close<H, Fut>(&self, handler: H)
299    where
300        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
301        Fut: Future<Output = Result<()>> + Send + 'static,
302    {
303        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
304    }
305
306    #[inline]
307    pub fn server(&self) -> WsIoServer {
308        self.namespace.server()
309    }
310
311    #[inline]
312    pub fn sid(&self) -> &str {
313        &self.sid
314    }
315}