wsio_server/connection/mod.rs

1use std::{
2    pin::Pin,
3    sync::Arc,
4};
5
6use anyhow::{
7    Result,
8    bail,
9};
10use http::HeaderMap;
11use num_enum::{
12    IntoPrimitive,
13    TryFromPrimitive,
14};
15use serde::Serialize;
16use tokio::{
17    select,
18    spawn,
19    sync::{
20        Mutex,
21        mpsc::{
22            Receiver,
23            Sender,
24            channel,
25        },
26    },
27    task::JoinHandle,
28    time::sleep,
29};
30use tokio_tungstenite::tungstenite::Message;
31use tokio_util::sync::CancellationToken;
32
33#[cfg(feature = "connection-extensions")]
34mod extensions;
35
36#[cfg(feature = "connection-extensions")]
37use self::extensions::WsIoServerConnectionExtensions;
38use crate::{
39    WsIoServer,
40    core::{
41        atomic::status::AtomicStatus,
42        packet::{
43            WsIoPacket,
44            WsIoPacketType,
45        },
46    },
47    namespace::WsIoServerNamespace,
48};
49
50#[repr(u8)]
51#[derive(Debug, IntoPrimitive, TryFromPrimitive)]
52enum ConnectionStatus {
53    Activating,
54    Authenticating,
55    AwaitingAuth,
56    Closed,
57    Closing,
58    Created,
59    Ready,
60}
61
62type OnCloseHandler = Box<
63    dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
64        + Send
65        + Sync
66        + 'static,
67>;
68
/// Server-side state for a single client connection within a namespace.
///
/// Built by `WsIoServerConnection::new`, which hands it out as
/// `Arc<WsIoServerConnection>`.
// Field order is also drop order; keep it stable.
pub struct WsIoServerConnection {
    /// Handle of the task spawned during `init` (when auth is required) that
    /// closes the connection if it is still awaiting auth after the timeout.
    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Cancelled during cleanup; tasks started with `spawn_task` stop when it fires.
    cancel_token: CancellationToken,
    /// Feature-gated (`connection-extensions`) per-connection extensions.
    #[cfg(feature = "connection-extensions")]
    extensions: WsIoServerConnectionExtensions,
    /// Headers supplied at construction (presumably the upgrade request's
    /// headers — NOTE(review): confirm with the caller of `new`).
    headers: HeaderMap,
    /// Sending half of the outgoing message channel; `new` returns the receiver.
    message_tx: Sender<Message>,
    /// Namespace this connection belongs to; provides config and packet codec.
    namespace: Arc<WsIoServerNamespace>,
    /// User close callback; taken and run at most once during cleanup.
    on_close_handler: Mutex<Option<OnCloseHandler>>,
    /// Connection id; used as the key when removing it from the namespace.
    sid: String,
    /// Current lifecycle state.
    status: AtomicStatus<ConnectionStatus>,
}
81
82impl WsIoServerConnection {
83    pub(crate) fn new(
84        headers: HeaderMap,
85        namespace: Arc<WsIoServerNamespace>,
86        sid: String,
87    ) -> (Arc<Self>, Receiver<Message>) {
88        // TODO: use config set buf size
89        let (message_tx, message_rx) = channel(512);
90        (
91            Arc::new(Self {
92                auth_timeout_task: Mutex::new(None),
93                cancel_token: CancellationToken::new(),
94                #[cfg(feature = "connection-extensions")]
95                extensions: WsIoServerConnectionExtensions::new(),
96                headers,
97                message_tx,
98                namespace,
99                on_close_handler: Mutex::new(None),
100                sid,
101                status: AtomicStatus::new(ConnectionStatus::Created),
102            }),
103            message_rx,
104        )
105    }
106
107    // Private methods
108    async fn abort_auth_timeout_task(&self) {
109        if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
110            auth_timeout_task.abort();
111        }
112    }
113
114    async fn activate(self: &Arc<Self>) -> Result<()> {
115        let status = self.status.get();
116        match status {
117            ConnectionStatus::Authenticating | ConnectionStatus::Created => {
118                self.status.try_transition(status, ConnectionStatus::Activating)?
119            }
120            _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
121        }
122
123        if let Some(middleware) = &self.namespace.config.middleware {
124            middleware(self.clone()).await?;
125        }
126
127        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
128            on_connect_handler(self.clone()).await?;
129        }
130
131        self.namespace.insert_connection(self.clone());
132        self.send_packet(&WsIoPacket {
133            data: None,
134            key: None,
135            r#type: WsIoPacketType::Ready,
136        })
137        .await?;
138
139        self.status
140            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
141
142        if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
143            on_ready_handler(self.clone()).await?;
144        }
145
146        Ok(())
147    }
148
149    async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
150        let status = self.status.get();
151        match status {
152            ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
153            _ => bail!("Received auth packet in invalid status: {:#?}", status),
154        }
155
156        if let Some(auth_handler) = &self.namespace.config.auth_handler {
157            (auth_handler)(self.clone(), packet_data).await?;
158            self.abort_auth_timeout_task().await;
159            self.activate().await
160        } else {
161            bail!("Auth packet received but no auth handler is configured");
162        }
163    }
164
165    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
166        Ok(self
167            .message_tx
168            .send(self.namespace.encode_packet_to_message(packet)?)
169            .await?)
170    }
171
172    // Protected methods
173    pub(crate) async fn cleanup(self: &Arc<Self>) {
174        self.status.store(ConnectionStatus::Closing);
175        self.abort_auth_timeout_task().await;
176        self.namespace.remove_connection(&self.sid);
177        self.cancel_token.cancel();
178        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
179            let _ = on_close_handler(self.clone()).await;
180        }
181
182        self.status.store(ConnectionStatus::Closed);
183    }
184
185    pub(crate) async fn close(&self) {
186        match self.status.get() {
187            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
188            _ => self.status.store(ConnectionStatus::Closing),
189        }
190
191        let _ = self.message_tx.send(Message::Close(None)).await;
192    }
193
194    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
195        let packet = self.namespace.config.packet_codec.decode(bytes)?;
196        match packet.r#type {
197            WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
198            _ => Ok(()),
199        }
200    }
201
202    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
203        let require_auth = self.namespace.config.auth_handler.is_some();
204        let packet = WsIoPacket {
205            data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
206            key: None,
207            r#type: WsIoPacketType::Init,
208        };
209
210        if require_auth {
211            self.status.store(ConnectionStatus::AwaitingAuth);
212            let connection = self.clone();
213            *self.auth_timeout_task.lock().await = Some(spawn(async move {
214                sleep(connection.namespace.config.auth_timeout).await;
215                if matches!(connection.status.get(), ConnectionStatus::AwaitingAuth) {
216                    connection.close().await;
217                }
218            }));
219
220            self.send_packet(&packet).await
221        } else {
222            self.send_packet(&packet).await?;
223            self.activate().await
224        }
225    }
226
227    // Public methods
228
229    #[inline]
230    pub fn cancel_token(&self) -> &CancellationToken {
231        &self.cancel_token
232    }
233
234    pub async fn disconnect(&self) {
235        let _ = self
236            .send_packet(&WsIoPacket {
237                data: None,
238                key: None,
239                r#type: WsIoPacketType::Disconnect,
240            })
241            .await;
242
243        self.close().await
244    }
245
246    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
247        self.send_packet(&WsIoPacket {
248            data: data
249                .map(|data| self.namespace.config.packet_codec.encode_data(data))
250                .transpose()?,
251            key: Some(event.as_ref().to_string()),
252            r#type: WsIoPacketType::Event,
253        })
254        .await
255    }
256
257    #[cfg(feature = "connection-extensions")]
258    #[inline]
259    pub fn extensions(&self) -> &WsIoServerConnectionExtensions {
260        &self.extensions
261    }
262
263    #[inline]
264    pub fn headers(&self) -> &HeaderMap {
265        &self.headers
266    }
267
268    #[inline]
269    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
270        self.namespace.clone()
271    }
272
273    pub async fn on_close<H, Fut>(&self, handler: H)
274    where
275        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
276        Fut: Future<Output = Result<()>> + Send + 'static,
277    {
278        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
279    }
280
281    #[inline]
282    pub fn server(&self) -> WsIoServer {
283        self.namespace.server()
284    }
285
286    #[inline]
287    pub fn sid(&self) -> &str {
288        &self.sid
289    }
290
291    #[inline]
292    pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
293        let cancel_token = self.cancel_token().clone();
294        spawn(async move {
295            select! {
296                _ = cancel_token.cancelled() => {},
297                _ = future => {},
298            }
299        });
300    }
301}