// wsio_server/connection.rs

1use std::sync::Arc;
2
3use anyhow::Result;
4use bson::oid::ObjectId;
5use http::HeaderMap;
6use tokio::{
7    spawn,
8    sync::{
9        Mutex,
10        RwLock,
11        mpsc::{
12            UnboundedReceiver,
13            UnboundedSender,
14            unbounded_channel,
15        },
16    },
17    task::JoinHandle,
18    time::sleep,
19};
20use tokio_tungstenite::tungstenite::Message;
21
22use crate::{
23    core::packet::{
24        WsIoPacket,
25        WsIoPacketType,
26    },
27    namespace::WsIoServerNamespace,
28    types::handler::WsIoServerConnectionOnDisconnectHandler,
29};
30
/// Lifecycle of a `WsIoServerConnection`, listed in the order a connection can move
/// through it (connections that need no authentication skip `AwaitingAuth`).
enum WsIoServerConnectionStatus {
    /// Freshly constructed; the handshake has not run yet.
    Created,
    /// The `Init` packet announced that authentication is required; waiting for `Auth`.
    AwaitingAuth,
    /// Being promoted to `Ready` (registration with the namespace is in progress).
    Activating,
    /// Registered with the namespace and announced to the client with a `Ready` packet.
    Ready,
    /// `cleanup` has started tearing the connection down.
    Closing,
    /// `cleanup` has finished.
    Closed,
}
39
/// Server-side state for a single client connected to a namespace.
///
/// Outgoing frames are not written to the socket here; they are queued on `tx`, the
/// sending half of the channel whose receiver is returned by `WsIoServerConnection::new`.
pub struct WsIoServerConnection {
    /// Auth-timeout task spawned by `init` when authentication is required; aborted once
    /// authentication succeeds or when the connection is cleaned up.
    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Headers handed to `new` (presumably the HTTP upgrade request's headers — confirm at
    /// the call site).
    headers: HeaderMap,
    /// Namespace this connection belongs to; supplies the packet codec, handlers and
    /// auth timeout.
    namespace: Arc<WsIoServerNamespace>,
    /// Handler registered via `on_disconnect`; taken and run at most once by `cleanup`.
    on_disconnect_handler: Mutex<Option<WsIoServerConnectionOnDisconnectHandler>>,
    /// Connection id: a freshly generated BSON `ObjectId` rendered as a string.
    sid: String,
    /// Current lifecycle state.
    status: RwLock<WsIoServerConnectionStatus>,
    /// Sending half of the outgoing-message channel.
    tx: UnboundedSender<Message>,
}
49
50impl WsIoServerConnection {
51    pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
52        let (tx, rx) = unbounded_channel();
53        (
54            Self {
55                auth_timeout_task: Mutex::new(None),
56                headers,
57                namespace,
58                on_disconnect_handler: Mutex::new(None),
59                sid: ObjectId::new().to_string(),
60                status: RwLock::new(WsIoServerConnectionStatus::Created),
61                tx,
62            },
63            rx,
64        )
65    }
66
67    // Protected methods
68    async fn abort_auth_timeout_task(&self) {
69        if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
70            auth_timeout_task.abort();
71        }
72    }
73
74    async fn activate(self: &Arc<Self>) -> Result<()> {
75        *self.status.write().await = WsIoServerConnectionStatus::Activating;
76        // TODO: middlewares
77        self.namespace.insert_connection(self.clone());
78        *self.status.write().await = WsIoServerConnectionStatus::Ready;
79        let packet = WsIoPacket {
80            data: None,
81            key: None,
82            r#type: WsIoPacketType::Ready,
83        };
84
85        self.send_packet(&packet)?;
86        (self.namespace.config.on_connect_handler)(self.clone()).await
87    }
88
89    pub(crate) async fn cleanup(self: &Arc<Self>) {
90        *self.status.write().await = WsIoServerConnectionStatus::Closing;
91        self.abort_auth_timeout_task().await;
92        self.namespace.cleanup_connection(&self.sid);
93        if let Some(on_disconnect_handler) = self.on_disconnect_handler.lock().await.take() {
94            let _ = on_disconnect_handler(self.clone()).await;
95        }
96
97        *self.status.write().await = WsIoServerConnectionStatus::Closed;
98    }
99
100    pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
101        let packet = match self.namespace.config.packet_codec.decode(bytes) {
102            Ok(packet) => packet,
103            Err(_) => return,
104        };
105
106        match packet.r#type {
107            WsIoPacketType::Auth => {
108                if let Some(auth_handler) = &self.namespace.config.auth_handler
109                    && (auth_handler)(self.clone(), packet.data.as_deref()).await.is_ok() {
110                        self.abort_auth_timeout_task().await;
111                        if self.activate().await.is_ok() {
112                            return;
113                        }
114                    }
115
116                self.close();
117            }
118            WsIoPacketType::Event => {}
119            _ => {}
120        }
121    }
122
123    pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
124        self.send(Message::Text(format!("c{}", self.namespace.config.packet_codec).into()))?;
125        let require_auth = self.namespace.config.auth_handler.is_some();
126        let packet = WsIoPacket {
127            data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
128            key: Some(self.sid.clone()),
129            r#type: WsIoPacketType::Init,
130        };
131
132        if require_auth {
133            *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
134            let connection = self.clone();
135            self.auth_timeout_task.lock().await.replace(spawn(async move {
136                sleep(connection.namespace.config.auth_timeout).await;
137                if matches!(
138                    *connection.status.read().await,
139                    WsIoServerConnectionStatus::AwaitingAuth
140                ) {
141                    connection.close();
142                }
143            }));
144
145            self.send_packet(&packet)?;
146        } else {
147            self.send_packet(&packet)?;
148            self.activate().await?;
149        }
150
151        Ok(())
152    }
153
154    #[inline]
155    fn send(&self, message: Message) -> Result<()> {
156        Ok(self.tx.send(message)?)
157    }
158
159    #[inline]
160    fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
161        self.send(self.namespace.encode_packet_to_message(packet)?)
162    }
163
164    // Public methods
165
166    #[inline]
167    pub fn close(&self) {
168        let _ = self.send(Message::Close(None));
169    }
170
171    #[inline]
172    pub fn headers(&self) -> &HeaderMap {
173        &self.headers
174    }
175
176    #[inline]
177    pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
178        self.namespace.clone()
179    }
180
181    pub async fn on_disconnect<H, Fut>(&self, handler: H)
182    where
183        H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
184        Fut: Future<Output = Result<()>> + Send + 'static,
185    {
186        self.on_disconnect_handler
187            .lock()
188            .await
189            .replace(Box::new(move |connection| Box::pin(handler(connection))));
190    }
191
192    #[inline]
193    pub fn sid(&self) -> &str {
194        &self.sid
195    }
196}