wsio_server/connection/mod.rs

1use std::sync::{
2    Arc,
3    LazyLock,
4    atomic::{
5        AtomicU64,
6        Ordering,
7    },
8};
9
10use anyhow::{
11    Result,
12    bail,
13};
14use arc_swap::ArcSwap;
15use http::{
16    HeaderMap,
17    Uri,
18};
19use num_enum::{
20    IntoPrimitive,
21    TryFromPrimitive,
22};
23use serde::{
24    Serialize,
25    de::DeserializeOwned,
26};
27use tokio::{
28    spawn,
29    sync::{
30        Mutex,
31        mpsc::{
32            Receiver,
33            Sender,
34            channel,
35        },
36    },
37    task::JoinHandle,
38    time::{
39        sleep,
40        timeout,
41    },
42};
43use tokio_tungstenite::tungstenite::Message;
44use tokio_util::sync::CancellationToken;
45
46#[cfg(feature = "connection-extensions")]
47mod extensions;
48
49#[cfg(feature = "connection-extensions")]
50use self::extensions::ConnectionExtensions;
51use crate::{
52    WsIoServer,
53    core::{
54        atomic::status::AtomicStatus,
55        channel_capacity_from_websocket_config,
56        event::registry::WsIoEventRegistry,
57        packet::{
58            WsIoPacket,
59            WsIoPacketType,
60        },
61        traits::task::spawner::TaskSpawner,
62        types::{
63            BoxAsyncUnaryResultHandler,
64            hashers::FxDashSet,
65        },
66        utils::task::abort_locked_task,
67    },
68    namespace::{
69        WsIoServerNamespace,
70        operators::broadcast::WsIoServerNamespaceBroadcastOperator,
71    },
72};
73
74// Enums
75#[repr(u8)]
76#[derive(Debug, Eq, IntoPrimitive, PartialEq, TryFromPrimitive)]
77enum ConnectionStatus {
78    Activating,
79    AwaitingInit,
80    Closed,
81    Closing,
82    Created,
83    Initiating,
84    Ready,
85}
86
87// Structs
88pub struct WsIoServerConnection {
89    cancel_token: ArcSwap<CancellationToken>,
90    event_registry: WsIoEventRegistry<WsIoServerConnection, WsIoServerConnection>,
91    #[cfg(feature = "connection-extensions")]
92    extensions: ConnectionExtensions,
93    headers: HeaderMap,
94    id: u64,
95    init_timeout_task: Mutex<Option<JoinHandle<()>>>,
96    joined_rooms: FxDashSet<String>,
97    message_tx: Sender<Arc<Message>>,
98    namespace: Arc<WsIoServerNamespace>,
99    on_close_handler: Mutex<Option<BoxAsyncUnaryResultHandler<Self>>>,
100    request_uri: Uri,
101    status: AtomicStatus<ConnectionStatus>,
102}
103
104impl TaskSpawner for WsIoServerConnection {
105    #[inline]
106    fn cancel_token(&self) -> Arc<CancellationToken> {
107        self.cancel_token.load_full()
108    }
109}
110
111impl WsIoServerConnection {
112    #[inline]
113    pub(crate) fn new(
114        headers: HeaderMap,
115        namespace: Arc<WsIoServerNamespace>,
116        request_uri: Uri,
117    ) -> (Arc<Self>, Receiver<Arc<Message>>) {
118        let channel_capacity = channel_capacity_from_websocket_config(&namespace.config.websocket_config);
119        let (message_tx, message_rx) = channel(channel_capacity);
120        (
121            Arc::new(Self {
122                cancel_token: ArcSwap::new(Arc::new(CancellationToken::new())),
123                event_registry: WsIoEventRegistry::new(),
124                #[cfg(feature = "connection-extensions")]
125                extensions: ConnectionExtensions::new(),
126                headers,
127                id: NEXT_CONNECTION_ID.fetch_add(1, Ordering::Relaxed),
128                init_timeout_task: Mutex::new(None),
129                joined_rooms: FxDashSet::default(),
130                message_tx,
131                namespace,
132                on_close_handler: Mutex::new(None),
133                request_uri,
134                status: AtomicStatus::new(ConnectionStatus::Created),
135            }),
136            message_rx,
137        )
138    }
139
140    // Private methods
141    #[inline]
142    fn handle_event_packet(self: &Arc<Self>, event: &str, packet_data: Option<Vec<u8>>) -> Result<()> {
143        if self.is_ready() {
144            self.event_registry.dispatch_event_packet(
145                self.clone(),
146                event,
147                &self.namespace.config.packet_codec,
148                packet_data,
149                self,
150            );
151        }
152
153        Ok(())
154    }
155
156    async fn handle_init_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
157        // Verify current state; only valid from AwaitingInit → Initiating
158        let status = self.status.get();
159        match status {
160            ConnectionStatus::AwaitingInit => self.status.try_transition(status, ConnectionStatus::Initiating)?,
161            _ => bail!("Received init packet in invalid status: {status:?}"),
162        }
163
164        // Abort init-timeout task
165        abort_locked_task(&self.init_timeout_task).await;
166
167        // Invoke init_response_handler with timeout protection if configured
168        if let Some(init_response_handler) = &self.namespace.config.init_response_handler {
169            timeout(
170                self.namespace.config.init_response_handler_timeout,
171                init_response_handler(self.clone(), packet_data, &self.namespace.config.packet_codec),
172            )
173            .await??
174        }
175
176        // Activate connection
177        self.status
178            .try_transition(ConnectionStatus::Initiating, ConnectionStatus::Activating)?;
179
180        // Invoke middleware with timeout protection if configured
181        if let Some(middleware) = &self.namespace.config.middleware {
182            timeout(
183                self.namespace.config.middleware_execution_timeout,
184                middleware(self.clone()),
185            )
186            .await??;
187
188            // Ensure connection is still in Activating state
189            self.status.ensure(ConnectionStatus::Activating, |status| {
190                format!("Cannot activate connection in invalid status: {status:?}")
191            })?;
192        }
193
194        // Invoke on_connect_handler with timeout protection if configured
195        if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
196            timeout(
197                self.namespace.config.on_connect_handler_timeout,
198                on_connect_handler(self.clone()),
199            )
200            .await??;
201        }
202
203        // Transition state to Ready
204        self.status
205            .try_transition(ConnectionStatus::Activating, ConnectionStatus::Ready)?;
206
207        // Insert connection into namespace
208        self.namespace.insert_connection(self.clone());
209
210        // Send ready packet
211        self.send_packet(&WsIoPacket::new_ready()).await?;
212
213        // Invoke on_ready_handler if configured
214        if let Some(on_ready_handler) = self.namespace.config.on_ready_handler.clone() {
215            // Run handler asynchronously in a detached task
216            self.spawn_task(on_ready_handler(self.clone()));
217        }
218
219        Ok(())
220    }
221
222    async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
223        self.send_message(self.namespace.encode_packet_to_message(packet)?)
224            .await
225    }
226
227    // Protected methods
228    pub(crate) async fn cleanup(self: &Arc<Self>) {
229        // Set connection state to Closing
230        self.status.store(ConnectionStatus::Closing);
231
232        // Remove connection from namespace
233        self.namespace.remove_connection(self.id);
234
235        // Leave all joined rooms
236        let joined_rooms = self.joined_rooms.iter().map(|entry| entry.clone()).collect::<Vec<_>>();
237        for room_name in &joined_rooms {
238            self.namespace.remove_connection_id_from_room(room_name, self.id);
239        }
240
241        self.joined_rooms.clear();
242
243        // Abort init-timeout task
244        abort_locked_task(&self.init_timeout_task).await;
245
246        // Cancel all ongoing operations via cancel token
247        self.cancel_token.load().cancel();
248
249        // Invoke on_close_handler with timeout protection if configured
250        if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
251            let _ = timeout(
252                self.namespace.config.on_close_handler_timeout,
253                on_close_handler(self.clone()),
254            )
255            .await;
256        }
257
258        // Set connection state to Closed
259        self.status.store(ConnectionStatus::Closed);
260    }
261
262    #[inline]
263    pub(crate) fn close(&self) {
264        // Skip if connection is already Closing or Closed, otherwise set connection state to Closing
265        match self.status.get() {
266            ConnectionStatus::Closed | ConnectionStatus::Closing => return,
267            _ => self.status.store(ConnectionStatus::Closing),
268        }
269
270        // Send websocket close frame to initiate graceful shutdown
271        let _ = self.message_tx.try_send(Arc::new(Message::Close(None)));
272    }
273
274    pub(crate) async fn emit_event_message(&self, message: Arc<Message>) -> Result<()> {
275        self.status.ensure(ConnectionStatus::Ready, |status| {
276            format!("Cannot emit in invalid status: {status:?}")
277        })?;
278
279        self.send_message(message).await
280    }
281
282    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, encoded_packet: &[u8]) -> Result<()> {
283        // TODO: lazy load
284        let packet = self.namespace.config.packet_codec.decode(encoded_packet)?;
285        match packet.r#type {
286            WsIoPacketType::Event => {
287                if let Some(event) = packet.key.as_deref() {
288                    self.handle_event_packet(event, packet.data)
289                } else {
290                    bail!("Event packet missing key");
291                }
292            }
293            WsIoPacketType::Init => self.handle_init_packet(packet.data.as_deref()).await,
294            _ => Ok(()),
295        }
296    }
297
298    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
299        // Verify current state; only valid Created
300        self.status.ensure(ConnectionStatus::Created, |status| {
301            format!("Cannot init connection in invalid status: {status:?}")
302        })?;
303
304        // Generate init request data if init request handler is configured
305        let init_request_data = if let Some(init_request_handler) = &self.namespace.config.init_request_handler {
306            timeout(
307                self.namespace.config.init_request_handler_timeout,
308                init_request_handler(self.clone(), &self.namespace.config.packet_codec),
309            )
310            .await??
311        } else {
312            None
313        };
314
315        // Transition state to AwaitingInit
316        self.status
317            .try_transition(ConnectionStatus::Created, ConnectionStatus::AwaitingInit)?;
318
319        // Spawn init-response-timeout watchdog to close connection if init not received in time
320        let connection = self.clone();
321        *self.init_timeout_task.lock().await = Some(spawn(async move {
322            sleep(connection.namespace.config.init_response_timeout).await;
323            if connection.status.is(ConnectionStatus::AwaitingInit) {
324                connection.close();
325            }
326        }));
327
328        // Send init packet
329        self.send_packet(&WsIoPacket::new_init(init_request_data)).await
330    }
331
332    pub(crate) async fn send_message(&self, message: Arc<Message>) -> Result<()> {
333        Ok(self.message_tx.send(message).await?)
334    }
335
336    // Public methods
337    pub async fn disconnect(&self) {
338        let _ = self.send_packet(&WsIoPacket::new_disconnect()).await;
339        self.close()
340    }
341
342    pub async fn emit<D: Serialize>(&self, event: impl AsRef<str>, data: Option<&D>) -> Result<()> {
343        self.emit_event_message(
344            self.namespace.encode_packet_to_message(&WsIoPacket::new_event(
345                event.as_ref(),
346                data.map(|data| self.namespace.config.packet_codec.encode_data(data))
347                    .transpose()?,
348            ))?,
349        )
350        .await
351    }
352
353    #[inline]
354    pub fn except(
355        self: &Arc<Self>,
356        room_names: impl IntoIterator<Item = impl AsRef<str>>,
357    ) -> WsIoServerNamespaceBroadcastOperator {
358        self.namespace.except(room_names).except_connection_ids(vec![self.id])
359    }
360
361    #[cfg(feature = "connection-extensions")]
362    #[inline]
363    pub fn extensions(&self) -> &ConnectionExtensions {
364        &self.extensions
365    }
366
367    #[inline]
368    pub fn headers(&self) -> &HeaderMap {
369        &self.headers
370    }
371
372    #[inline]
373    pub fn id(&self) -> u64 {
374        self.id
375    }
376
377    #[inline]
378    pub fn is_ready(&self) -> bool {
379        self.status.is(ConnectionStatus::Ready)
380    }
381
382    #[inline]
383    pub fn join(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
384        for room_name in room_names {
385            let room_name = room_name.as_ref();
386            self.namespace.add_connection_id_to_room(room_name, self.id);
387            self.joined_rooms.insert(room_name.into());
388        }
389    }
390
391    #[inline]
392    pub fn leave(self: &Arc<Self>, room_names: impl IntoIterator<Item = impl AsRef<str>>) {
393        for room_name in room_names {
394            self.namespace
395                .remove_connection_id_from_room(room_name.as_ref(), self.id);
396
397            self.joined_rooms.remove(room_name.as_ref());
398        }
399    }
400
401    #[inline]
402    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
403        self.namespace.clone()
404    }
405
406    #[inline]
407    pub fn off(&self, event: impl AsRef<str>) {
408        self.event_registry.off(event.as_ref());
409    }
410
411    #[inline]
412    pub fn off_by_handler_id(&self, event: impl AsRef<str>, handler_id: u32) {
413        self.event_registry.off_by_handler_id(event.as_ref(), handler_id);
414    }
415
416    #[inline]
417    pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> u32
418    where
419        H: Fn(Arc<WsIoServerConnection>, Arc<D>) -> Fut + Send + Sync + 'static,
420        Fut: Future<Output = Result<()>> + Send + 'static,
421        D: DeserializeOwned + Send + Sync + 'static,
422    {
423        self.event_registry.on(event.as_ref(), handler)
424    }
425
426    pub async fn on_close<H, Fut>(&self, handler: H)
427    where
428        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
429        Fut: Future<Output = Result<()>> + Send + 'static,
430    {
431        *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
432    }
433
434    #[inline]
435    pub fn request_uri(&self) -> &Uri {
436        &self.request_uri
437    }
438
439    #[inline]
440    pub fn server(&self) -> WsIoServer {
441        self.namespace.server()
442    }
443
444    #[inline]
445    pub fn to(
446        self: &Arc<Self>,
447        room_names: impl IntoIterator<Item = impl AsRef<str>>,
448    ) -> WsIoServerNamespaceBroadcastOperator {
449        self.namespace.to(room_names).except_connection_ids(vec![self.id])
450    }
451}
452
// Constants/Statics

/// Process-wide counter from which `WsIoServerConnection::new` draws connection ids, starting at 0.
static NEXT_CONNECTION_ID: LazyLock<AtomicU64> = LazyLock::new(AtomicU64::default);