wsio_server/connection/mod.rs

1use std::sync::Arc;
2
3use anyhow::{
4    Result,
5    bail,
6};
7use http::HeaderMap;
8use num_enum::{
9    IntoPrimitive,
10    TryFromPrimitive,
11};
12use serde::{
13    Serialize,
14    de::DeserializeOwned,
15};
16use tokio::{
17    select,
18    spawn,
19    sync::{
20        Mutex,
21        mpsc::{
22            Receiver,
23            Sender,
24            channel,
25        },
26    },
27    task::JoinHandle,
28    time::{
29        sleep,
30        timeout,
31    },
32};
33use tokio_tungstenite::tungstenite::Message;
34use tokio_util::sync::CancellationToken;
35
36#[cfg(feature = "connection-extensions")]
37mod extensions;
38
39#[cfg(feature = "connection-extensions")]
40use self::extensions::ConnectionExtensions;
41use crate::{
42    WsIoServer,
43    core::{
44        atomic::status::AtomicStatus,
45        channel_capacity_from_websocket_config,
46        event::registry::WsIoEventRegistry,
47        packet::{
48            WsIoPacket,
49            WsIoPacketType,
50        },
51        types::BoxAsyncUnaryResultHandler,
52        utils::task::abort_locked_task,
53    },
54    namespace::WsIoServerNamespace,
55};
56
57#[repr(u8)]
58#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
59enum ConnectionStatus {
60    Activating,
61    Authenticating,
62    AwaitingAuth,
63    Closed,
64    Closing,
65    Created,
66    Ready,
67}
68
69pub struct WsIoServerConnection {
70    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
71    cancel_token: CancellationToken,
72    event_registry: WsIoEventRegistry<WsIoServerConnection>,
73    #[cfg(feature = "connection-extensions")]
74    extensions: ConnectionExtensions,
75    headers: HeaderMap,
76    message_tx: Sender<Message>,
77    namespace: Arc<WsIoServerNamespace>,
78    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
79    sid: String,
80    status: AtomicStatus<ConnectionStatus>,
81}
82
83impl WsIoServerConnection {
84    #[inline]
85    pub(crate) fn new(
86        headers: HeaderMap,
87        namespace: Arc<WsIoServerNamespace>,
88        sid: String,
89    ) -> (Arc<Self>, Receiver<Message>) {
90        let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
91        let (message_tx, message_rx) = channel(channel_capacity);
92        (
93            Arc::new(Self {
94                auth_timeout_task: Mutex::new(None),
95                cancel_token: CancellationToken::new(),
96                event_registry: WsIoEventRegistry::new(namespace.config.packet_codec),
97                #[cfg(feature = "connection-extensions")]
98                extensions: ConnectionExtensions::new(),
99                headers,
100                message_tx,
101                namespace,
102                on_close_handler: Mutex::new(None),
103                sid,
104                status: AtomicStatus::new(ConnectionStatus::Created),
105            }),
106            message_rx,
107        )
108    }
109
110    // Private methods
111    async fn activate(self: &Arc<Self>) -> Result<()> {
112        // Verify current state; only valid from Authenticating or Created → Activating
113        let status = self.status.get();
114        match status {
115            ConnectionStatus::Authenticating | ConnectionStatus::Created => {
116                self.status.try_transition(status, ConnectionStatus::Activating)?
117            }
118            _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
119        }
120
121        // Invoke middleware with timeout protection if configured
122        if let Some(middleware) = &self.namespace.config.middleware {
123            timeout(
124                self.namespace.config.middleware_execution_timeout,
125                middleware(self.clone()),
126            )
127            .await??;
128        }
129
130        // Invoke on_connect handler with timeout protection if configured
131        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
132            timeout(
133                self.namespace.config.on_connect_handler_timeout,
134                on_connect_handler(self.clone()),
135            )
136            .await??;
137        }
138
139        // Insert connection into namespace
140        self.namespace.insert_connection(self.clone());
141
142        // Transition state to Ready
143        self.status
144            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
145
146        // Send ready packet
147        self.send_packet(&WsIoPacket::new_ready()).await?;
148
149        // Invoke on_ready handler if configured
150        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
151            // Run handler asynchronously in a detached task
152            let connection = self.clone();
153            self.spawn_task(async move { on_ready_handler(connection).await });
154        }
155
156        Ok(())
157    }
158
159    #[inline]
160    fn ensure_status_ready(&self) -> Result<()> {
161        let status = self.status.get();
162        if status != ConnectionStatus::Ready {
163            bail!("Cannot emit event in invalid status: {:#?}", status);
164        }
165
166        Ok(())
167    }
168
169    async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
170        // Verify current state; only valid from AwaitingAuth → Authenticating
171        let status = self.status.get();
172        match status {
173            ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
174            _ => bail!("Received auth packet in invalid status: {:#?}", status),
175        }
176
177        // Abort auth-timeout task if still active
178        abort_locked_task(&self.auth_timeout_task).await;
179
180        // Invoke auth handler with timeout protection if configured, otherwise raise error
181        if let Some(auth_handler) = &self.namespace.config.auth_handler {
182            timeout(
183                self.namespace.config.auth_handler_timeout,
184                auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
185            )
186            .await??;
187
188            // Activate connection
189            self.activate().await
190        } else {
191            bail!("Auth packet received but no auth handler is configured");
192        }
193    }
194
195    #[inline]
196    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
197        self.event_registry
198            .dispatch_event_packet(self.clone(), event, packet_data);
199
200        Ok(())
201    }
202
203    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
204        Ok(self
205            .message_tx
206            .send(self.namespace.encode_packet_to_message(packet)?)
207            .await?)
208    }
209
210    // Protected methods
211    pub(crate) async fn cleanup(self: &Arc<Self>) {
212        // Set connection state to Closing
213        self.status.store(ConnectionStatus::Closing);
214
215        // Abort auth-timeout task if still active
216        abort_locked_task(&self.auth_timeout_task).await;
217
218        // Remove connection from namespace
219        self.namespace.remove_connection(&self.sid);
220
221        // Cancel all ongoing operations via cancel token
222        self.cancel_token.cancel();
223
224        // Invoke on_close handler with timeout protection if configured
225        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
226            let _ = timeout(
227                self.namespace.config.on_close_handler_timeout,
228                on_close_handler(self.clone()),
229            )
230            .await;
231        }
232
233        // Set connection state to Closed
234        self.status.store(ConnectionStatus::Closed);
235    }
236
237    #[inline]
238    pub(crate) fn close(&self) {
239        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
240        match self.status.get() {
241            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
242            _ => self.status.store(ConnectionStatus::Closing),
243        }
244
245        // Send websocket close frame to initiate graceful shutdown
246        let _ = self.message_tx.try_send(Message::Close(None));
247    }
248
249    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
250        // TODO: lazy load
251        let packet = self.namespace.config.packet_codec.decode(bytes)?;
252        match packet.r#type {
253            WsIoPacketType::Auth => {
254                if let Some(packet_data) = packet.data.as_deref() {
255                    self.handle_auth_packet(packet_data).await
256                } else {
257                    bail!("Auth packet missing data");
258                }
259            }
260            WsIoPacketType::Event => {
261                if let Some(event) = packet.key.as_deref() {
262                    self.handle_event_packet(event, packet.data)
263                } else {
264                    bail!("Event packet missing key");
265                }
266            }
267            _ => Ok(()),
268        }
269    }
270
271    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
272        // Verify current state; only valid Created
273        let status = self.status.get();
274        if !matches!(status, ConnectionStatus::Created) {
275            bail!("Cannot init connection in invalid status: {:#?}", status);
276        }
277
278        // Determine if authentication is required
279        let requires_auth = self.namespace.config.auth_handler.is_some();
280
281        // Build Init packet to inform client whether auth is required
282        let packet = &WsIoPacket::new_init(self.namespace.config.packet_codec.encode_data(&requires_auth)?);
283
284        // If authentication is required
285        if requires_auth {
286            // Transition state to AwaitingAuth
287            self.status
288                .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
289
290            // Spawn auth-packet-timeout watchdog to close connection if auth not received in time
291            let connection = self.clone();
292            *self.auth_timeout_task.lock().await = Some(spawn(async move {
293                sleep(connection.namespace.config.auth_packet_timeout).await;
294                if connection.status.is(ConnectionStatus::AwaitingAuth) {
295                    connection.close();
296                }
297            }));
298
299            // Send Init packet to client (expecting auth response)
300            self.send_packet(packet).await
301        } else {
302            // Send Init packet to client (no auth required)
303            self.send_packet(packet).await?;
304
305            // Immediately activate connection
306            self.activate().await
307        }
308    }
309
310    // Public methods
311
312    #[inline]
313    pub fn cancel_token(&self) -> &CancellationToken {
314        &self.cancel_token
315    }
316
317    pub async fn disconnect(&self) {
318        let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
319        self.close()
320    }
321
322    pub async fn emit<D: Serialize>(&self, event: impl Into<String>, data: Option<&D>) -> Result<()> {
323        self.ensure_status_ready()?;
324        self.send_packet(&WsIoPacket::new_event(
325            event,
326            data.map(|data| self.namespace.config.packet_codec.encode_data(data))
327                .transpose()?,
328        ))
329        .await
330    }
331
332    pub async fn emit_message(&self, message: Message) -> Result<()> {
333        self.ensure_status_ready()?;
334        Ok(self.message_tx.send(message).await?)
335    }
336
337    #[cfg(feature = "connection-extensions")]
338    #[inline]
339    pub fn extensions(&self) -> &ConnectionExtensions {
340        &self.extensions
341    }
342
343    #[inline]
344    pub fn headers(&self) -> &HeaderMap {
345        &self.headers
346    }
347
348    #[inline]
349    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
350        self.namespace.clone()
351    }
352
353    #[inline]
354    pub fn off(&self, event: impl AsRef<str>) {
355        self.event_registry.off(event);
356    }
357
358    #[inline]
359    pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
360        self.event_registry.off_by_handler_id(event, handler_id);
361    }
362
363    #[inline]
364    pub fn on<H, Fut, D>(&self, event: impl Into<String>, handler: H) -> u32
365    where
366        H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
367        Fut: Future<Output = Result<()>> + Send + 'static,
368        D: DeserializeOwned + Send + Sync + 'static,
369    {
370        self.event_registry.on(event, handler)
371    }
372
373    pub async fn on_close<H, Fut>(&self, handler: H)
374    where
375        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
376        Fut: Future<Output = Result<()>> + Send + 'static,
377    {
378        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
379    }
380
381    #[inline]
382    pub fn server(&self) -> WsIoServer {
383        self.namespace.server()
384    }
385
386    #[inline]
387    pub fn sid(&self) -> &str {
388        &self.sid
389    }
390
391    #[inline]
392    pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
393        let cancel_token = self.cancel_token.clone();
394        spawn(async move {
395            select! {
396                _ = cancel_token.cancelled() => {},
397                _ = future => {},
398            }
399        });
400    }
401}