wsio_server/connection/mod.rs

1use std::{
2    pin::Pin,
3    sync::Arc,
4};
5
6use anyhow::{
7    Result,
8    bail,
9};
10use http::HeaderMap;
11use num_enum::{
12    IntoPrimitive,
13    TryFromPrimitive,
14};
15use serde::Serialize;
16use tokio::{
17    select,
18    spawn,
19    sync::{
20        Mutex,
21        mpsc::{
22            Receiver,
23            Sender,
24            channel,
25        },
26    },
27    task::JoinHandle,
28    time::{
29        sleep,
30        timeout,
31    },
32};
33use tokio_tungstenite::tungstenite::Message;
34use tokio_util::sync::CancellationToken;
35
36#[cfg(feature = "connection-extensions")]
37mod extensions;
38
39#[cfg(feature = "connection-extensions")]
40use self::extensions::WsIoServerConnectionExtensions;
41use crate::{
42    WsIoServer,
43    core::{
44        atomic::status::AtomicStatus,
45        packet::{
46            WsIoPacket,
47            WsIoPacketType,
48        },
49        utils::task::abort_locked_task,
50    },
51    namespace::WsIoServerNamespace,
52};
53
54#[repr(u8)]
55#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
56enum ConnectionStatus {
57    Activating,
58    Authenticating,
59    AwaitingAuth,
60    Closed,
61    Closing,
62    Created,
63    Ready,
64}
65
66type OnCloseHandler = Box<
67    dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
68        + Send
69        + Sync
70        + 'static,
71>;
72
73pub struct WsIoServerConnection {
74    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
75    cancel_token: CancellationToken,
76    #[cfg(feature = "connection-extensions")]
77    extensions: WsIoServerConnectionExtensions,
78    headers: HeaderMap,
79    message_tx: Sender<Message>,
80    namespace: Arc<WsIoServerNamespace>,
81    on_close_handler: Mutex<Option<OnCloseHandler>>,
82    sid: String,
83    status: AtomicStatus<ConnectionStatus>,
84}
85
86impl WsIoServerConnection {
87    pub(crate) fn new(
88        headers: HeaderMap,
89        namespace: Arc<WsIoServerNamespace>,
90        sid: String,
91    ) -> (Arc<Self>, Receiver<Message>) {
92        let channel_capacity = (namespace.config.websocket_config.max_write_buffer_size
93            / namespace.config.websocket_config.write_buffer_size)
94            .clamp(64, 4096);
95
96        let (message_tx, message_rx) = channel(channel_capacity);
97        (
98            Arc::new(Self {
99                auth_timeout_task: Mutex::new(None),
100                cancel_token: CancellationToken::new(),
101                #[cfg(feature = "connection-extensions")]
102                extensions: WsIoServerConnectionExtensions::new(),
103                headers,
104                message_tx,
105                namespace,
106                on_close_handler: Mutex::new(None),
107                sid,
108                status: AtomicStatus::new(ConnectionStatus::Created),
109            }),
110            message_rx,
111        )
112    }
113
114    // Private methods
115    async fn activate(self: &Arc<Self>) -> Result<()> {
116        // Verify current state; only valid from Authenticating or Created → Activating
117        let status = self.status.get();
118        match status {
119            ConnectionStatus::Authenticating | ConnectionStatus::Created => {
120                self.status.try_transition(status, ConnectionStatus::Activating)?
121            }
122            _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
123        }
124
125        // Invoke middleware with timeout protection if configured
126        if let Some(middleware) = &self.namespace.config.middleware {
127            timeout(
128                self.namespace.config.middleware_execution_timeout,
129                middleware(self.clone()),
130            )
131            .await??;
132        }
133
134        // Invoke on_connect handler with timeout protection if configured
135        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
136            timeout(
137                self.namespace.config.on_connect_handler_timeout,
138                on_connect_handler(self.clone()),
139            )
140            .await??;
141        }
142
143        // Insert connection into namespace
144        self.namespace.insert_connection(self.clone());
145
146        // Transition state to Ready
147        self.status
148            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
149
150        // Send ready packet
151        self.send_packet(&WsIoPacket {
152            data: None,
153            key: None,
154            r#type: WsIoPacketType::Ready,
155        })
156        .await?;
157
158        // Invoke on_ready handler if configured
159        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
160            // Run handler asynchronously in a detached task
161            let connection = self.clone();
162            self.spawn_task(async move { on_ready_handler(connection).await });
163        }
164
165        Ok(())
166    }
167
168    async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
169        // Verify current state; only valid from AwaitingAuth → Authenticating
170        let status = self.status.get();
171        match status {
172            ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
173            _ => bail!("Received auth packet in invalid status: {:#?}", status),
174        }
175
176        // Abort auth-timeout task if still active
177        abort_locked_task(&self.auth_timeout_task).await;
178
179        // Invoke auth handler with timeout protection if configured, otherwise raise error
180        if let Some(auth_handler) = &self.namespace.config.auth_handler {
181            timeout(
182                self.namespace.config.auth_handler_timeout,
183                (auth_handler)(self.clone(), packet_data),
184            )
185            .await??;
186
187            // Activate connection
188            self.activate().await
189        } else {
190            bail!("Auth packet received but no auth handler is configured");
191        }
192    }
193
194    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
195        Ok(self
196            .message_tx
197            .send(self.namespace.encode_packet_to_message(packet)?)
198            .await?)
199    }
200
201    // Protected methods
202    pub(crate) async fn cleanup(self: &Arc<Self>) {
203        // Set connection state to Closing
204        self.status.store(ConnectionStatus::Closing);
205
206        // Abort auth-timeout task if still active
207        abort_locked_task(&self.auth_timeout_task).await;
208
209        // Remove connection from namespace
210        self.namespace.remove_connection(&self.sid);
211
212        // Cancel all ongoing operations via cancel token
213        self.cancel_token.cancel();
214
215        // Invoke on_close handler with timeout protection if configured
216        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
217            let _ = timeout(
218                self.namespace.config.on_close_handler_timeout,
219                on_close_handler(self.clone()),
220            )
221            .await;
222        }
223
224        // Set connection state to Closed
225        self.status.store(ConnectionStatus::Closed);
226    }
227
228    #[inline]
229    pub(crate) fn close(&self) {
230        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
231        match self.status.get() {
232            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
233            _ => self.status.store(ConnectionStatus::Closing),
234        }
235
236        // Send websocket close frame to initiate graceful shutdown
237        let _ = self.message_tx.try_send(Message::Close(None));
238    }
239
240    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
241        let packet = self.namespace.config.packet_codec.decode(bytes)?;
242        match packet.r#type {
243            WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
244            _ => Ok(()),
245        }
246    }
247
248    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
249        // Verify current state; only valid Created
250        let status = self.status.get();
251        if !matches!(status, ConnectionStatus::Created) {
252            bail!("Cannot init connection in invalid status: {:#?}", status);
253        }
254
255        // Determine if authentication is required
256        let requires_auth = self.namespace.config.auth_handler.is_some();
257
258        // Build Init packet to inform client whether auth is required
259        let packet = WsIoPacket {
260            data: Some(self.namespace.config.packet_codec.encode_data(&requires_auth)?),
261            key: None,
262            r#type: WsIoPacketType::Init,
263        };
264
265        // If authentication is required
266        if requires_auth {
267            // Transition state to AwaitingAuth
268            self.status
269                .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
270
271            // Spawn auth-packet-timeout watchdog to close connection if auth not received in time
272            let connection = self.clone();
273            *self.auth_timeout_task.lock().await = Some(spawn(async move {
274                sleep(connection.namespace.config.auth_packet_timeout).await;
275                if connection.status.is(ConnectionStatus::AwaitingAuth) {
276                    connection.close();
277                }
278            }));
279
280            // Send Init packet to client (expecting auth response)
281            self.send_packet(&packet).await
282        } else {
283            // Send Init packet to client (no auth required)
284            self.send_packet(&packet).await?;
285
286            // Immediately activate connection
287            self.activate().await
288        }
289    }
290
291    // Public methods
292
293    #[inline]
294    pub fn cancel_token(&self) -> &CancellationToken {
295        &self.cancel_token
296    }
297
298    pub async fn disconnect(&self) {
299        let _ = self
300            .send_packet(&WsIoPacket {
301                data: None,
302                key: None,
303                r#type: WsIoPacketType::Disconnect,
304            })
305            .await;
306
307        self.close()
308    }
309
310    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
311        let status = self.status.get();
312        if status != ConnectionStatus::Ready {
313            bail!("Cannot emit event in invalid status: {:#?}", status);
314        }
315
316        self.send_packet(&WsIoPacket {
317            data: data
318                .map(|data| self.namespace.config.packet_codec.encode_data(data))
319                .transpose()?,
320            key: Some(event.as_ref().to_string()),
321            r#type: WsIoPacketType::Event,
322        })
323        .await
324    }
325
326    #[cfg(feature = "connection-extensions")]
327    #[inline]
328    pub fn extensions(&self) -> &WsIoServerConnectionExtensions {
329        &self.extensions
330    }
331
332    #[inline]
333    pub fn headers(&self) -> &HeaderMap {
334        &self.headers
335    }
336
337    #[inline]
338    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
339        self.namespace.clone()
340    }
341
342    pub async fn on_close<H, Fut>(&self, handler: H)
343    where
344        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
345        Fut: Future<Output = Result<()>> + Send + 'static,
346    {
347        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
348    }
349
350    #[inline]
351    pub fn server(&self) -> WsIoServer {
352        self.namespace.server()
353    }
354
355    #[inline]
356    pub fn sid(&self) -> &str {
357        &self.sid
358    }
359
360    #[inline]
361    pub fn spawn_task<F: Future<Output = Result<()>> + Send + 'static>(&self, future: F) {
362        let cancel_token = self.cancel_token.clone();
363        spawn(async move {
364            select! {
365                _ = cancel_token.cancelled() => {},
366                _ = future => {},
367            }
368        });
369    }
370}