wsio_server/
connection.rs

1use std::sync::Arc;
2
3use anyhow::{
4    Result,
5    bail,
6};
7use bson::oid::ObjectId;
8use dashmap::DashMap;
9use http::HeaderMap;
10use serde::de::DeserializeOwned;
11use tokio::{
12    spawn,
13    sync::{
14        Mutex,
15        RwLock,
16        mpsc::{
17            UnboundedReceiver,
18            UnboundedSender,
19            unbounded_channel,
20        },
21    },
22    task::JoinHandle,
23    time::sleep,
24};
25use tokio_tungstenite::tungstenite::Message;
26
27use crate::{
28    WsIoServer,
29    core::packet::{
30        WsIoPacket,
31        WsIoPacketType,
32    },
33    namespace::WsIoServerNamespace,
34    types::handler::{
35        WsIoServerConnectionEventHandler,
36        WsIoServerConnectionOnDisconnectHandler,
37    },
38};
39
/// Lifecycle state of a [`WsIoServerConnection`].
///
/// The derives make the state cheap to copy out of its lock, comparable with
/// `==`, and printable in diagnostics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum WsIoServerConnectionStatus {
    /// Set at the start of `activate`, before the namespace middleware runs
    /// and the connection is inserted into the namespace.
    Activating,
    /// Set by `init` when the namespace has an auth handler; the auth timeout
    /// task closes the connection if it is still in this state when it fires.
    AwaitingAuth,
    /// Set at the end of `cleanup`.
    Closed,
    /// Set at the start of `cleanup`.
    Closing,
    /// Initial state assigned by `new`.
    Created,
    /// Set by `activate` once the connection has been inserted into the
    /// namespace, just before the `Ready` packet is queued.
    Ready,
}
48
49pub struct WsIoServerConnection {
50    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
51    event_handlers: DashMap<String, WsIoServerConnectionEventHandler>,
52    headers: HeaderMap,
53    namespace: Arc<WsIoServerNamespace>,
54    on_disconnect_handler: Mutex<Option<WsIoServerConnectionOnDisconnectHandler>>,
55    sid: String,
56    status: RwLock<WsIoServerConnectionStatus>,
57    tx: UnboundedSender<Message>,
58}
59
60impl WsIoServerConnection {
61    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
62        let (tx, rx) = unbounded_channel();
63        (
64            Self {
65                auth_timeout_task: Mutex::new(None),
66                event_handlers: DashMap::new(),
67                headers,
68                namespace,
69                on_disconnect_handler: Mutex::new(None),
70                sid: ObjectId::new().to_string(),
71                status: RwLock::new(WsIoServerConnectionStatus::Created),
72                tx,
73            },
74            rx,
75        )
76    }
77
78    // Protected methods
79    async fn abort_auth_timeout_task(&self) {
80        if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
81            auth_timeout_task.abort();
82        }
83    }
84
85    async fn activate(self: &Arc<Self>) -> Result<()> {
86        *self.status.write().await = WsIoServerConnectionStatus::Activating;
87        if let Some(middleware) = &self.namespace.config.middleware {
88            middleware(self.clone()).await?;
89        }
90
91        self.namespace.insert_connection(self.clone());
92        *self.status.write().await = WsIoServerConnectionStatus::Ready;
93        let packet = WsIoPacket {
94            data: None,
95            key: None,
96            r#type: WsIoPacketType::Ready,
97        };
98
99        self.send_packet(&packet)?;
100        (self.namespace.config.on_connect_handler)(self.clone()).await
101    }
102
103    pub(crate) async fn cleanup(self: &Arc<Self>) {
104        *self.status.write().await = WsIoServerConnectionStatus::Closing;
105        self.abort_auth_timeout_task().await;
106        self.namespace.cleanup_connection(&self.sid);
107        if let Some(on_disconnect_handler) = self.on_disconnect_handler.lock().await.take() {
108            let _ = on_disconnect_handler(self.clone()).await;
109        }
110
111        *self.status.write().await = WsIoServerConnectionStatus::Closed;
112    }
113
114    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
115        let packet = match self.namespace.config.packet_codec.decode(bytes) {
116            Ok(packet) => packet,
117            Err(_) => return,
118        };
119
120        match packet.r#type {
121            WsIoPacketType::Auth => {
122                if let Some(auth_handler) = &self.namespace.config.auth_handler
123                    && (auth_handler)(self.clone(), packet.data.as_deref()).await.is_ok()
124                {
125                    self.abort_auth_timeout_task().await;
126                    if self.activate().await.is_ok() {
127                        return;
128                    }
129                }
130
131                self.close();
132            }
133            WsIoPacketType::Event => {}
134            _ => {}
135        }
136    }
137
138    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
139        self.send(Message::Text(format!("c{}", self.namespace.config.packet_codec).into()))?;
140        let require_auth = self.namespace.config.auth_handler.is_some();
141        let packet = WsIoPacket {
142            data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
143            key: Some(self.sid.clone()),
144            r#type: WsIoPacketType::Init,
145        };
146
147        if require_auth {
148            *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
149            let connection = self.clone();
150            self.auth_timeout_task.lock().await.replace(spawn(async move {
151                sleep(connection.namespace.config.auth_timeout).await;
152                if matches!(
153                    *connection.status.read().await,
154                    WsIoServerConnectionStatus::AwaitingAuth
155                ) {
156                    connection.close();
157                }
158            }));
159
160            self.send_packet(&packet)?;
161        } else {
162            self.send_packet(&packet)?;
163            self.activate().await?;
164        }
165
166        Ok(())
167    }
168
169    #[inline]
170    fn send(&self, message: Message) -> Result<()> {
171        Ok(self.tx.send(message)?)
172    }
173
174    #[inline]
175    fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
176        self.send(self.namespace.encode_packet_to_message(packet)?)
177    }
178
179    // Public methods
180
181    #[inline]
182    pub fn close(&self) {
183        let _ = self.send(Message::Close(None));
184    }
185
186    #[inline]
187    pub fn headers(&self) -> &HeaderMap {
188        &self.headers
189    }
190
191    #[inline]
192    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
193        self.namespace.clone()
194    }
195
196    pub fn on<H, Fut, D>(&self, event: &str, handler: H) -> Result<()>
197    where
198        H: Fn(Arc<WsIoServerConnection>, &D) -> Fut + Send + Sync + 'static,
199        Fut: Future<Output = Result<()>> + Send + 'static,
200        D: DeserializeOwned + Send + 'static,
201    {
202        if self.event_handlers.contains_key(event) {
203            bail!("Event {} handler already exists", event);
204        }
205
206        let handler = Arc::new(handler);
207        let packet_codec = self.namespace.config.packet_codec;
208        self.event_handlers.insert(
209            event.to_string(),
210            Arc::new(move |connection, bytes: Option<&[u8]>| {
211                let handler = handler.clone();
212                Box::pin(async move {
213                    let bytes = match bytes {
214                        Some(bytes) => bytes,
215                        None => packet_codec.empty_data_encoded(),
216                    };
217
218                    let data = packet_codec.decode_data::<D>(bytes)?;
219                    handler(connection, &data).await
220                })
221            }),
222        );
223
224        Ok(())
225    }
226
227    pub async fn on_disconnect<H, Fut>(&self, handler: H)
228    where
229        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
230        Fut: Future<Output = Result<()>> + Send + 'static,
231    {
232        self.on_disconnect_handler
233            .lock()
234            .await
235            .replace(Box::new(move |connection| Box::pin(handler(connection))));
236    }
237
238    #[inline]
239    pub fn server(&self) -> WsIoServer {
240        self.namespace.server()
241    }
242
243    #[inline]
244    pub fn sid(&self) -> &str {
245        &self.sid
246    }
247}