wsio_server/
connection.rs1use std::sync::Arc;
2
3use anyhow::{
4 Result,
5 bail,
6};
7use bson::oid::ObjectId;
8use dashmap::DashMap;
9use http::HeaderMap;
10use serde::de::DeserializeOwned;
11use tokio::{
12 spawn,
13 sync::{
14 Mutex,
15 RwLock,
16 mpsc::{
17 UnboundedReceiver,
18 UnboundedSender,
19 unbounded_channel,
20 },
21 },
22 task::JoinHandle,
23 time::sleep,
24};
25use tokio_tungstenite::tungstenite::Message;
26
27use crate::{
28 WsIoServer,
29 core::packet::{
30 WsIoPacket,
31 WsIoPacketType,
32 },
33 namespace::WsIoServerNamespace,
34 types::handler::{
35 WsIoServerConnectionEventHandler,
36 WsIoServerConnectionOnCloseHandler,
37 },
38};
39
/// Lifecycle state of a `WsIoServerConnection`.
///
/// Variants are kept in alphabetical order, like the fields of the connection struct.
enum WsIoServerConnectionStatus {
    /// Set while the middleware and connect hooks of `activate` are running.
    Activating,
    /// Set by `init` when the namespace has an auth handler and the client's `Auth`
    /// packet has not been accepted yet.
    AwaitingAuth,
    /// Final state, set at the end of `cleanup`.
    Closed,
    /// Set by `close` and at the start of `cleanup`.
    Closing,
    /// Initial state assigned by `new`.
    Created,
    /// Set by `activate` once the `Ready` packet has been queued.
    Ready,
}
48
49pub struct WsIoServerConnection {
50 auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
51 event_handlers: DashMap<String, WsIoServerConnectionEventHandler>,
52 headers: HeaderMap,
53 namespace: Arc<WsIoServerNamespace>,
54 on_close_handler: Mutex<Option<WsIoServerConnectionOnCloseHandler>>,
55 sid: String,
56 status: RwLock<WsIoServerConnectionStatus>,
57 tx: UnboundedSender<Message>,
58}
59
60impl WsIoServerConnection {
61 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
62 let (tx, rx) = unbounded_channel();
63 (
64 Self {
65 auth_timeout_task: Mutex::new(None),
66 event_handlers: DashMap::new(),
67 headers,
68 namespace,
69 on_close_handler: Mutex::new(None),
70 sid: ObjectId::new().to_string(),
71 status: RwLock::new(WsIoServerConnectionStatus::Created),
72 tx,
73 },
74 rx,
75 )
76 }
77
78 async fn abort_auth_timeout_task(&self) {
80 if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
81 auth_timeout_task.abort();
82 }
83 }
84
85 async fn activate(self: &Arc<Self>) -> Result<()> {
86 if matches!(
87 *self.status.read().await,
88 WsIoServerConnectionStatus::AwaitingAuth | WsIoServerConnectionStatus::Created
89 ) {
90 return Ok(());
91 }
92
93 *self.status.write().await = WsIoServerConnectionStatus::Activating;
94 if let Some(middleware) = &self.namespace.config.middleware {
95 middleware(self.clone()).await?;
96 }
97
98 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
99 on_connect_handler(self.clone()).await?;
100 }
101
102 self.namespace.insert_connection(self.clone());
103 self.send_packet(&WsIoPacket {
104 data: None,
105 key: None,
106 r#type: WsIoPacketType::Ready,
107 })?;
108
109 *self.status.write().await = WsIoServerConnectionStatus::Ready;
110 if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
111 on_ready_handler(self.clone()).await?;
112 }
113
114 Ok(())
115 }
116
117 pub(crate) async fn cleanup(self: &Arc<Self>) {
118 *self.status.write().await = WsIoServerConnectionStatus::Closing;
119 self.abort_auth_timeout_task().await;
120 self.namespace.cleanup_connection(&self.sid);
121 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
122 let _ = on_close_handler(self.clone()).await;
123 }
124
125 *self.status.write().await = WsIoServerConnectionStatus::Closed;
126 }
127
128 pub(crate) async fn close(&self) {
129 {
130 let mut status = self.status.write().await;
131 if matches!(
132 *status,
133 WsIoServerConnectionStatus::Closed | WsIoServerConnectionStatus::Closing
134 ) {
135 return;
136 }
137
138 *status = WsIoServerConnectionStatus::Closing;
139 }
140
141 let _ = self.tx.send(Message::Close(None));
142 }
143
144 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
145 let packet = match self.namespace.config.packet_codec.decode(bytes) {
146 Ok(packet) => packet,
147 Err(_) => return,
148 };
149
150 match packet.r#type {
151 WsIoPacketType::Auth => {
152 if let Some(auth_handler) = &self.namespace.config.auth_handler
153 && (auth_handler)(self.clone(), packet.data.as_deref()).await.is_ok()
154 {
155 self.abort_auth_timeout_task().await;
156 if matches!(*self.status.read().await, WsIoServerConnectionStatus::AwaitingAuth)
157 && self.activate().await.is_ok()
158 {
159 return;
160 }
161 }
162
163 self.close().await;
164 }
165 WsIoPacketType::Event => {}
166 _ => {}
167 }
168 }
169
170 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
171 let require_auth = self.namespace.config.auth_handler.is_some();
172 let packet = WsIoPacket {
173 data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
174 key: Some(self.sid.clone()),
175 r#type: WsIoPacketType::Init,
176 };
177
178 if require_auth {
179 *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
180 let connection = self.clone();
181 *self.auth_timeout_task.lock().await = Some(spawn(async move {
182 sleep(connection.namespace.config.auth_timeout).await;
183 if matches!(
184 *connection.status.read().await,
185 WsIoServerConnectionStatus::AwaitingAuth
186 ) {
187 connection.close().await;
188 }
189 }));
190
191 self.send_packet(&packet)?;
192 } else {
193 self.send_packet(&packet)?;
194 self.activate().await?;
195 }
196
197 Ok(())
198 }
199
200 #[inline]
201 fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
202 Ok(self.tx.send(self.namespace.encode_packet_to_message(packet)?)?)
203 }
204
205 pub async fn disconnect(&self) {
207 let _ = self.send_packet(&WsIoPacket {
208 data: None,
209 key: None,
210 r#type: WsIoPacketType::Disconnect,
211 });
212
213 let _ = self.close().await;
215 }
216
217 #[inline]
218 pub fn headers(&self) -> &HeaderMap {
219 &self.headers
220 }
221
222 #[inline]
223 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
224 self.namespace.clone()
225 }
226
227 #[inline]
228 pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> Result<()>
229 where
230 H: Fn(Arc<WsIoServerConnection>, &D) -> Fut + Send + Sync + 'static,
231 Fut: Future<Output = Result<()>> + Send + 'static,
232 D: DeserializeOwned + Send + 'static,
233 {
234 let event = event.as_ref();
235 if self.event_handlers.contains_key(event) {
236 bail!("Event {} handler already exists", event);
237 }
238
239 let handler = Arc::new(handler);
240 let packet_codec = self.namespace.config.packet_codec;
241 self.event_handlers.insert(
242 event.into(),
243 Box::new(move |connection, bytes: Option<&[u8]>| {
244 let handler = handler.clone();
245 Box::pin(async move {
246 let bytes = match bytes {
247 Some(bytes) => bytes,
248 None => packet_codec.empty_data_encoded(),
249 };
250
251 let data = packet_codec.decode_data::<D>(bytes)?;
252 handler(connection, &data).await
253 })
254 }),
255 );
256
257 Ok(())
258 }
259
260 pub async fn on_close<H, Fut>(&self, handler: H)
261 where
262 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
263 Fut: Future<Output = Result<()>> + Send + 'static,
264 {
265 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
266 }
267
268 #[inline]
269 pub fn server(&self) -> WsIoServer {
270 self.namespace.server()
271 }
272
273 #[inline]
274 pub fn sid(&self) -> &str {
275 &self.sid
276 }
277}