wsio_server/connection/mod.rs

1use std::sync::{
2    Arc,
3    LazyLock,
4    atomic::{
5        AtomicU64,
6        Ordering,
7    },
8};
9
10use anyhow::{
11    Result,
12    bail,
13};
14use arc_swap::ArcSwap;
15use http::HeaderMap;
16use num_enum::{
17    IntoPrimitive,
18    TryFromPrimitive,
19};
20use serde::{
21    Serialize,
22    de::DeserializeOwned,
23};
24use tokio::{
25    spawn,
26    sync::{
27        Mutex,
28        mpsc::{
29            Receiver,
30            Sender,
31            channel,
32        },
33    },
34    task::JoinHandle,
35    time::{
36        sleep,
37        timeout,
38    },
39};
40use tokio_tungstenite::tungstenite::Message;
41use tokio_util::sync::CancellationToken;
42
43#[cfg(feature = "connection-extensions")]
44mod extensions;
45
46#[cfg(feature = "connection-extensions")]
47use self::extensions::ConnectionExtensions;
48use crate::{
49    WsIoServer,
50    core::{
51        atomic::status::AtomicStatus,
52        channel_capacity_from_websocket_config,
53        event::registry::WsIoEventRegistry,
54        packet::{
55            WsIoPacket,
56            WsIoPacketType,
57        },
58        traits::task::spawner::TaskSpawner,
59        types::{
60            BoxAsyncUnaryResultHandler,
61            hashers::FxDashSet,
62        },
63        utils::task::abort_locked_task,
64    },
65    namespace::WsIoServerNamespace,
66};
67
68// Enums
69#[repr(u8)]
70#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
71enum ConnectionStatus {
72    Activating,
73    Authenticating,
74    AwaitingAuth,
75    Closed,
76    Closing,
77    Created,
78    Ready,
79}
80
81// Structs
82pub struct WsIoServerConnection {
83    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
84    cancel_token: ArcSwap<CancellationToken>,
85    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
86    #[cfg(feature = "connection-extensions")]
87    extensions: ConnectionExtensions,
88    headers: HeaderMap,
89    id: u64,
90    joined_rooms: FxDashSet<String>,
91    message_tx: Sender<Arc<Message>>,
92    namespace: Arc<WsIoServerNamespace>,
93    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
94    status: AtomicStatus<ConnectionStatus>,
95}
96
97impl TaskSpawner for WsIoServerConnection {
98    #[inline]
99    fn cancel_token(&self) -> Arc<CancellationToken> {
100        self.cancel_token.load_full()
101    }
102}
103
104impl WsIoServerConnection {
105    #[inline]
106    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Arc<Self>, Receiver<Arc<Message>>) {
107        let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
108        let (message_tx, message_rx) = channel(channel_capacity);
109        (
110            Arc::new(Self {
111                auth_timeout_task: Mutex::new(None),
112                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
113                event_registry: WsIoEventRegistry::new(),
114                #[cfg(feature = "connection-extensions")]
115                extensions: ConnectionExtensions::new(),
116                headers,
117                id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
118                joined_rooms: FxDashSet::default(),
119                message_tx,
120                namespace,
121                on_close_handler: Mutex::new(None),
122                status: AtomicStatus::new(ConnectionStatus::Created),
123            }),
124            message_rx,
125        )
126    }
127
128    // Private methods
129    async fn activate(self: &Arc<Self>) -> Result<()> {
130        // Verify current state; only valid from Authenticating or Created → Activating
131        let status = self.status.get();
132        match status {
133            ConnectionStatus::Authenticating | ConnectionStatus::Created => {
134                self.status.try_transition(status, ConnectionStatus::Activating)?
135            }
136            _ => bail!("Cannot activate connection in invalid status: {:#?}", status),
137        }
138
139        // Invoke middleware with timeout protection if configured
140        if let Some(middleware) = &self.namespace.config.middleware {
141            timeout(
142                self.namespace.config.middleware_execution_timeout,
143                middleware(self.clone()),
144            )
145            .await??;
146
147            // Ensure connection is still in Activating state
148            self.status.ensure(ConnectionStatus::Activating, |status| {
149                format!("Cannot activate connection in invalid status: {:#?}", status)
150            })?;
151        }
152
153        // Invoke on_connect handler with timeout protection if configured
154        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
155            timeout(
156                self.namespace.config.on_connect_handler_timeout,
157                on_connect_handler(self.clone()),
158            )
159            .await??;
160        }
161
162        // Transition state to Ready
163        self.status
164            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
165
166        // Insert connection into namespace
167        self.namespace.insert_connection(self.clone());
168
169        // Send ready packet
170        self.send_packet(&WsIoPacket::new_ready()).await?;
171
172        // Invoke on_ready handler if configured
173        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
174            // Run handler asynchronously in a detached task
175            self.spawn_task(on_ready_handler(self.clone()));
176        }
177
178        Ok(())
179    }
180
181    async fn handle_auth_packet(self: &Arc<Self>, packet_data: &[u8]) -> Result<()> {
182        // Verify current state; only valid from AwaitingAuth → Authenticating
183        let status = self.status.get();
184        match status {
185            ConnectionStatus::AwaitingAuth => self.status.try_transition(status, ConnectionStatus::Authenticating)?,
186            _ => bail!("Received auth packet in invalid status: {:#?}", status),
187        }
188
189        // Abort auth-timeout task if still active
190        abort_locked_task(&self.auth_timeout_task).await;
191
192        // Invoke auth handler with timeout protection if configured, otherwise raise error
193        if let Some(auth_handler) = &self.namespace.config.auth_handler {
194            timeout(
195                self.namespace.config.auth_handler_timeout,
196                auth_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
197            )
198            .await??;
199
200            // Activate connection
201            self.activate().await
202        } else {
203            bail!("Auth packet received but no auth handler is configured");
204        }
205    }
206
207    #[inline]
208    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
209        self.event_registry.dispatch_event_packet(
210            self.clone(),
211            event,
212            &self.namespace.config.packet_codec,
213            packet_data,
214            self,
215        );
216
217        Ok(())
218    }
219
220    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
221        self.send_message(self.namespace.encode_packet_to_message(packet)?)
222            .await
223    }
224
225    // Protected methods
226    pub(crate) async fn cleanup(self: &Arc<Self>) {
227        // Set connection state to Closing
228        self.status.store(ConnectionStatus::Closing);
229
230        // Remove connection from namespace
231        self.namespace.remove_connection(self.id);
232
233        // Leave all joined rooms
234        let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
235        for room_name in &joined_rooms {
236            self.namespace.remove_connection_id_from_room(room_name, self.id);
237        }
238
239        self.joined_rooms.clear();
240
241        // Abort auth-timeout task if still active
242        abort_locked_task(&self.auth_timeout_task).await;
243
244        // Cancel all ongoing operations via cancel token
245        self.cancel_token.load().cancel();
246
247        // Invoke on_close handler with timeout protection if configured
248        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
249            let _ = timeout(
250                self.namespace.config.on_close_handler_timeout,
251                on_close_handler(self.clone()),
252            )
253            .await;
254        }
255
256        // Set connection state to Closed
257        self.status.store(ConnectionStatus::Closed);
258    }
259
260    #[inline]
261    pub(crate) fn close(&self) {
262        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
263        match self.status.get() {
264            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
265            _ => self.status.store(ConnectionStatus::Closing),
266        }
267
268        // Send websocket close frame to initiate graceful shutdown
269        let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
270    }
271
272    pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
273        self.status.ensure(ConnectionStatus::Ready, |status| {
274            format!("Cannot emit in invalid status: {:#?}", status)
275        })?;
276
277        self.send_message(message).await
278    }
279
280    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
281        // TODO: lazy load
282        let packet = self.namespace.config.packet_codec.decode(bytes)?;
283        match packet.r#type {
284            WsIoPacketType::Auth => {
285                if let Some(packet_data) = packet.data.as_deref() {
286                    self.handle_auth_packet(packet_data).await
287                } else {
288                    bail!("Auth packet missing data");
289                }
290            }
291            WsIoPacketType::Event => {
292                if let Some(event) = packet.key.as_deref() {
293                    self.handle_event_packet(event, packet.data)
294                } else {
295                    bail!("Event packet missing key");
296                }
297            }
298            _ => Ok(()),
299        }
300    }
301
302    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
303        // Verify current state; only valid Created
304        self.status.ensure(ConnectionStatus::Created, |status| {
305            format!("Cannot init connection in invalid status: {:#?}", status)
306        })?;
307
308        // Determine if authentication is required
309        let requires_auth = self.namespace.config.auth_handler.is_some();
310
311        // Build Init packet to inform client whether auth is required
312        let packet = &WsIoPacket::new_init(self.namespace.config.packet_codec.encode_data(&requires_auth)?);
313
314        // If authentication is required
315        if requires_auth {
316            // Transition state to AwaitingAuth
317            self.status
318                .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingAuth)?;
319
320            // Spawn auth-packet-timeout watchdog to close connection if auth not received in time
321            let connection = self.clone();
322            *self.auth_timeout_task.lock().await = Some(spawn(async move {
323                sleep(connection.namespace.config.auth_packet_timeout).await;
324                if connection.status.is(ConnectionStatus::AwaitingAuth) {
325                    connection.close();
326                }
327            }));
328
329            // Send Init packet to client (expecting auth response)
330            self.send_packet(packet).await
331        } else {
332            // Send Init packet to client (no auth required)
333            self.send_packet(packet).await?;
334
335            // Immediately activate connection
336            self.activate().await
337        }
338    }
339
340    pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
341        Ok(self.message_tx.send(message).await?)
342    }
343
344    // Public methods
345    pub async fn disconnect(&self) {
346        let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
347        self.close()
348    }
349
350    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
351        self.emit_event_message(
352            self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
353                event.as_ref(),
354                data.map(|data| self.namespace.config.packet_codec.encode_data(data))
355                    .transpose()?,
356            ))?,
357        )
358        .await
359    }
360
361    #[cfg(feature = "connection-extensions")]
362    #[inline]
363    pub fn extensions(&self) -> &ConnectionExtensions {
364        &self.extensions
365    }
366
367    #[inline]
368    pub fn headers(&self) -> &HeaderMap {
369        &self.headers
370    }
371
372    #[inline]
373    pub fn id(&self) -> u64 {
374        self.id
375    }
376
377    #[inline]
378    pub fn join<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
379        for room_name in room_names {
380            let room_name = room_name.as_ref();
381            self.namespace.add_connection_id_to_room(room_name, self.id);
382            self.joined_rooms.insert(room_name.to_string());
383        }
384    }
385
386    #[inline]
387    pub fn leave<I: IntoIterator<Item = S>, S: AsRef<str>>(self: &Arc<Self>, room_names: I) {
388        for room_name in room_names {
389            self.namespace
390                .remove_connection_id_from_room(room_name.as_ref(), self.id);
391
392            self.joined_rooms.remove(room_name.as_ref());
393        }
394    }
395
396    #[inline]
397    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
398        self.namespace.clone()
399    }
400
401    #[inline]
402    pub fn off(&self, event: impl AsRef<str>) {
403        self.event_registry.off(event.as_ref());
404    }
405
406    #[inline]
407    pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
408        self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
409    }
410
411    #[inline]
412    pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
413    where
414        H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
415        Fut: Future<Output = Result<()>> + Send + 'static,
416        D: DeserializeOwned + Send + Sync + 'static,
417    {
418        self.event_registry.on(event.as_ref(), handler)
419    }
420
421    pub async fn on_close<H, Fut>(&self, handler: H)
422    where
423        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
424        Fut: Future<Output = Result<()>> + Send + 'static,
425    {
426        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
427    }
428
429    #[inline]
430    pub fn server(&self) -> WsIoServer {
431        self.namespace.server()
432    }
433}
434
// Constants/Statics

/// Process-wide source of connection ids. Each `WsIoServerConnection::new` takes the
/// current value and increments it, starting from 0.
///
/// NOTE(review): `AtomicU64::new` is a `const fn`, so a plain
/// `static NEXT_CONNECTION_ID: AtomicU64 = AtomicU64::new(0);` would work without `LazyLock`.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);