wsio_server/
connection.rs

1use std::{
2    pin::Pin,
3    sync::Arc,
4};
5
6use anyhow::{
7    Result,
8    bail,
9};
10use bson::oid::ObjectId;
11use http::HeaderMap;
12use tokio::{
13    spawn,
14    sync::{
15        Mutex,
16        RwLock,
17        mpsc::{
18            Receiver,
19            Sender,
20            channel,
21        },
22    },
23    task::JoinHandle,
24    time::sleep,
25};
26use tokio_tungstenite::tungstenite::Message;
27use tokio_util::sync::CancellationToken;
28
29use crate::{
30    WsIoServer,
31    core::packet::{
32        WsIoPacket,
33        WsIoPacketType,
34    },
35    namespace::WsIoServerNamespace,
36};
37
/// Lifecycle states of a `WsIoServerConnection`.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum WsIoServerConnectionStatus {
    /// Middleware and the connect handler are running; the connection is not ready yet.
    Activating,
    /// An auth packet was received and the namespace's auth handler is running.
    Authenticating,
    /// The init packet announced that authentication is required; waiting for an auth packet.
    AwaitingAuth,
    /// Cleanup has finished.
    Closed,
    /// Set by `close` before queuing a close frame, and by `cleanup` while it runs.
    Closing,
    /// Freshly constructed; `init` has not run yet.
    Created,
    /// Activation completed and the ready packet was sent.
    Ready,
}
48
49type OnCloseHandler = Box<
50    dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
51        + Send
52        + Sync
53        + 'static,
54>;
55
56pub struct WsIoServerConnection {
57    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
58    cancel_token: CancellationToken,
59    headers: HeaderMap,
60    message_tx: Sender<Message>,
61    namespace: Arc<WsIoServerNamespace>,
62    on_close_handler: Mutex<Option<OnCloseHandler>>,
63    sid: String,
64    status: RwLock<WsIoServerConnectionStatus>,
65}
66
67impl WsIoServerConnection {
68    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, Receiver<Message>) {
69        // TODO: use config set buf size
70        let (message_tx, message_rx) = channel(512);
71        (
72            Self {
73                auth_timeout_task: Mutex::new(None),
74                cancel_token: CancellationToken::new(),
75                headers,
76                message_tx,
77                namespace,
78                on_close_handler: Mutex::new(None),
79                sid: ObjectId::new().to_string(),
80                status: RwLock::new(WsIoServerConnectionStatus::Created),
81            },
82            message_rx,
83        )
84    }
85
86    // Private methods
87    async fn abort_auth_timeout_task(&self) {
88        if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
89            auth_timeout_task.abort();
90        }
91    }
92
93    async fn activate(self: &Arc<Self>) -> Result<()> {
94        {
95            let mut status = self.status.write().await;
96            match *status {
97                WsIoServerConnectionStatus::Authenticating | WsIoServerConnectionStatus::Created => {
98                    *status = WsIoServerConnectionStatus::Activating
99                }
100                _ => bail!("Cannot activate connection in invalid status: {:#?}", *status),
101            }
102        }
103
104        if let Some(middleware) = &self.namespace.config.middleware {
105            middleware(self.clone()).await?;
106        }
107
108        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
109            on_connect_handler(self.clone()).await?;
110        }
111
112        self.namespace.insert_connection(self.clone());
113        self.send_packet(&WsIoPacket {
114            data: None,
115            key: None,
116            r#type: WsIoPacketType::Ready,
117        })
118        .await?;
119
120        *self.status.write().await = WsIoServerConnectionStatus::Ready;
121        if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
122            on_ready_handler(self.clone()).await?;
123        }
124
125        Ok(())
126    }
127
128    async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
129        {
130            let mut status = self.status.write().await;
131            match *status {
132                WsIoServerConnectionStatus::AwaitingAuth => *status = WsIoServerConnectionStatus::Authenticating,
133                _ => bail!("Received auth packet in invalid status: {:#?}", *status),
134            }
135        }
136
137        if let Some(auth_handler) = &self.namespace.config.auth_handler {
138            (auth_handler)(self.clone(), packet_data).await?;
139            self.abort_auth_timeout_task().await;
140            self.activate().await
141        } else {
142            bail!("Auth packet received but no auth handler is configured");
143        }
144    }
145
146    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
147        Ok(self
148            .message_tx
149            .send(self.namespace.encode_packet_to_message(packet)?)
150            .await?)
151    }
152
153    // Protected methods
154    pub(crate) async fn cleanup(self: &Arc<Self>) {
155        *self.status.write().await = WsIoServerConnectionStatus::Closing;
156        self.abort_auth_timeout_task().await;
157        self.namespace.remove_connection(&self.sid);
158        self.cancel_token.cancel();
159        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
160            let _ = on_close_handler(self.clone()).await;
161        }
162
163        *self.status.write().await = WsIoServerConnectionStatus::Closed;
164    }
165
166    pub(crate) async fn close(&self) {
167        {
168            let mut status = self.status.write().await;
169            if matches!(
170                *status,
171                WsIoServerConnectionStatus::Closed | WsIoServerConnectionStatus::Closing
172            ) {
173                return;
174            }
175
176            *status = WsIoServerConnectionStatus::Closing;
177        }
178
179        let _ = self.message_tx.send(Message::Close(None)).await;
180    }
181
182    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
183        let packet = self.namespace.config.packet_codec.decode(bytes)?;
184        match packet.r#type {
185            WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
186            _ => Ok(()),
187        }
188    }
189
190    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
191        let require_auth = self.namespace.config.auth_handler.is_some();
192        let packet = WsIoPacket {
193            data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
194            key: None,
195            r#type: WsIoPacketType::Init,
196        };
197
198        if require_auth {
199            *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
200            let connection = self.clone();
201            *self.auth_timeout_task.lock().await = Some(spawn(async move {
202                sleep(connection.namespace.config.auth_timeout).await;
203                if matches!(
204                    *connection.status.read().await,
205                    WsIoServerConnectionStatus::AwaitingAuth
206                ) {
207                    connection.close().await;
208                }
209            }));
210
211            self.send_packet(&packet).await
212        } else {
213            self.send_packet(&packet).await?;
214            self.activate().await
215        }
216    }
217
218    // Public methods
219
220    #[inline]
221    pub fn cancel_token(&self) -> &CancellationToken {
222        &self.cancel_token
223    }
224
225    pub async fn disconnect(&self) {
226        let _ = self
227            .send_packet(&WsIoPacket {
228                data: None,
229                key: None,
230                r#type: WsIoPacketType::Disconnect,
231            })
232            .await;
233
234        self.close().await
235    }
236
237    #[inline]
238    pub fn headers(&self) -> &HeaderMap {
239        &self.headers
240    }
241
242    #[inline]
243    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
244        self.namespace.clone()
245    }
246
247    pub async fn on_close<H, Fut>(&self, handler: H)
248    where
249        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
250        Fut: Future<Output = Result<()>> + Send + 'static,
251    {
252        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
253    }
254
255    #[inline]
256    pub fn server(&self) -> WsIoServer {
257        self.namespace.server()
258    }
259
260    #[inline]
261    pub fn sid(&self) -> &str {
262        &self.sid
263    }
264}