wsio_server/
connection.rs

1use std::sync::Arc;
2
3use anyhow::{
4    Result,
5    bail,
6};
7use bson::oid::ObjectId;
8use dashmap::DashMap;
9use http::HeaderMap;
10use serde::de::DeserializeOwned;
11use tokio::{
12    spawn,
13    sync::{
14        Mutex,
15        RwLock,
16        mpsc::{
17            UnboundedReceiver,
18            UnboundedSender,
19            unbounded_channel,
20        },
21    },
22    task::JoinHandle,
23    time::sleep,
24};
25use tokio_tungstenite::tungstenite::Message;
26
27use crate::{
28    WsIoServer,
29    core::packet::{
30        WsIoPacket,
31        WsIoPacketType,
32    },
33    namespace::WsIoServerNamespace,
34    types::handler::{
35        WsIoServerConnectionEventHandler,
36        WsIoServerConnectionOnCloseHandler,
37    },
38};
39
/// Lifecycle status of a `WsIoServerConnection`.
///
/// The value is a plain tag, so it is cheap to copy, compare and print.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WsIoServerConnectionStatus {
    /// Activation is in progress (middleware / connect handlers are running).
    Activating,
    /// An auth handler is configured and no accepted `Auth` packet has been
    /// processed yet.
    AwaitingAuth,
    /// Cleanup has finished.
    Closed,
    /// A close was requested or cleanup is in progress.
    Closing,
    /// Initial status assigned when the connection object is constructed.
    Created,
    /// Activation completed and the `Ready` packet was queued.
    Ready,
}
48
49pub struct WsIoServerConnection {
50    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
51    event_handlers: DashMap<String, WsIoServerConnectionEventHandler>,
52    headers: HeaderMap,
53    namespace: Arc<WsIoServerNamespace>,
54    on_close_handler: Mutex<Option<WsIoServerConnectionOnCloseHandler>>,
55    sid: String,
56    status: RwLock<WsIoServerConnectionStatus>,
57    tx: UnboundedSender<Message>,
58}
59
60impl WsIoServerConnection {
61    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
62        let (tx, rx) = unbounded_channel();
63        (
64            Self {
65                auth_timeout_task: Mutex::new(None),
66                event_handlers: DashMap::new(),
67                headers,
68                namespace,
69                on_close_handler: Mutex::new(None),
70                sid: ObjectId::new().to_string(),
71                status: RwLock::new(WsIoServerConnectionStatus::Created),
72                tx,
73            },
74            rx,
75        )
76    }
77
78    // Protected methods
79    async fn abort_auth_timeout_task(&self) {
80        if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
81            auth_timeout_task.abort();
82        }
83    }
84
85    async fn activate(self: &Arc<Self>) -> Result<()> {
86        if matches!(
87            *self.status.read().await,
88            WsIoServerConnectionStatus::AwaitingAuth | WsIoServerConnectionStatus::Created
89        ) {
90            return Ok(());
91        }
92
93        *self.status.write().await = WsIoServerConnectionStatus::Activating;
94        if let Some(middleware) = &self.namespace.config.middleware {
95            middleware(self.clone()).await?;
96        }
97
98        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
99            on_connect_handler(self.clone()).await?;
100        }
101
102        self.namespace.insert_connection(self.clone());
103        self.send_packet(&WsIoPacket {
104            data: None,
105            key: None,
106            r#type: WsIoPacketType::Ready,
107        })?;
108
109        *self.status.write().await = WsIoServerConnectionStatus::Ready;
110        if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
111            on_ready_handler(self.clone()).await?;
112        }
113
114        Ok(())
115    }
116
117    pub(crate) async fn cleanup(self: &Arc<Self>) {
118        *self.status.write().await = WsIoServerConnectionStatus::Closing;
119        self.abort_auth_timeout_task().await;
120        self.namespace.cleanup_connection(&self.sid);
121        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
122            let _ = on_close_handler(self.clone()).await;
123        }
124
125        *self.status.write().await = WsIoServerConnectionStatus::Closed;
126    }
127
128    pub(crate) async fn close(&self) {
129        {
130            let mut status = self.status.write().await;
131            if matches!(
132                *status,
133                WsIoServerConnectionStatus::Closed | WsIoServerConnectionStatus::Closing
134            ) {
135                return;
136            }
137
138            *status = WsIoServerConnectionStatus::Closing;
139        }
140
141        let _ = self.tx.send(Message::Close(None));
142    }
143
144    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
145        let packet = match self.namespace.config.packet_codec.decode(bytes) {
146            Ok(packet) => packet,
147            Err(_) => return,
148        };
149
150        match packet.r#type {
151            WsIoPacketType::Auth => {
152                if let Some(auth_handler) = &self.namespace.config.auth_handler
153                    && (auth_handler)(self.clone(), packet.data.as_deref()).await.is_ok()
154                {
155                    self.abort_auth_timeout_task().await;
156                    if matches!(*self.status.read().await, WsIoServerConnectionStatus::AwaitingAuth)
157                        && self.activate().await.is_ok()
158                    {
159                        return;
160                    }
161                }
162
163                self.close().await;
164            }
165            WsIoPacketType::Event => {}
166            _ => {}
167        }
168    }
169
170    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
171        let require_auth = self.namespace.config.auth_handler.is_some();
172        let packet = WsIoPacket {
173            data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
174            key: Some(self.sid.clone()),
175            r#type: WsIoPacketType::Init,
176        };
177
178        if require_auth {
179            *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
180            let connection = self.clone();
181            *self.auth_timeout_task.lock().await = Some(spawn(async move {
182                sleep(connection.namespace.config.auth_timeout).await;
183                if matches!(
184                    *connection.status.read().await,
185                    WsIoServerConnectionStatus::AwaitingAuth
186                ) {
187                    connection.close().await;
188                }
189            }));
190
191            self.send_packet(&packet)?;
192        } else {
193            self.send_packet(&packet)?;
194            self.activate().await?;
195        }
196
197        Ok(())
198    }
199
200    #[inline]
201    fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
202        Ok(self.tx.send(self.namespace.encode_packet_to_message(packet)?)?)
203    }
204
205    // Public methods
206    pub async fn disconnect(&self) {
207        let _ = self.send_packet(&WsIoPacket {
208            data: None,
209            key: None,
210            r#type: WsIoPacketType::Disconnect,
211        });
212
213        // TODO: Should we wait for the disconnect packet to be sent or use spawn?
214        let _ = self.close().await;
215    }
216
217    #[inline]
218    pub fn headers(&self) -> &HeaderMap {
219        &self.headers
220    }
221
222    #[inline]
223    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
224        self.namespace.clone()
225    }
226
227    #[inline]
228    pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> Result<()>
229    where
230        H: Fn(Arc<WsIoServerConnection>, &D) -> Fut + Send + Sync + 'static,
231        Fut: Future<Output = Result<()>> + Send + 'static,
232        D: DeserializeOwned + Send + 'static,
233    {
234        let event = event.as_ref();
235        if self.event_handlers.contains_key(event) {
236            bail!("Event {} handler already exists", event);
237        }
238
239        let handler = Arc::new(handler);
240        let packet_codec = self.namespace.config.packet_codec;
241        self.event_handlers.insert(
242            event.into(),
243            Box::new(move |connection, bytes: Option<&[u8]>| {
244                let handler = handler.clone();
245                Box::pin(async move {
246                    let bytes = match bytes {
247                        Some(bytes) => bytes,
248                        None => packet_codec.empty_data_encoded(),
249                    };
250
251                    let data = packet_codec.decode_data::<D>(bytes)?;
252                    handler(connection, &data).await
253                })
254            }),
255        );
256
257        Ok(())
258    }
259
260    pub async fn on_close<H, Fut>(&self, handler: H)
261    where
262        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
263        Fut: Future<Output = Result<()>> + Send + 'static,
264    {
265        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
266    }
267
268    #[inline]
269    pub fn server(&self) -> WsIoServer {
270        self.namespace.server()
271    }
272
273    #[inline]
274    pub fn sid(&self) -> &str {
275        &self.sid
276    }
277}