wsio_server/
connection.rs1use std::sync::Arc;
2
3use anyhow::{
4 Result,
5 bail,
6};
7use bson::oid::ObjectId;
8use dashmap::DashMap;
9use http::HeaderMap;
10use serde::de::DeserializeOwned;
11use tokio::{
12 spawn,
13 sync::{
14 Mutex,
15 RwLock,
16 mpsc::{
17 UnboundedReceiver,
18 UnboundedSender,
19 unbounded_channel,
20 },
21 },
22 task::JoinHandle,
23 time::sleep,
24};
25use tokio_tungstenite::tungstenite::Message;
26
27use crate::{
28 WsIoServer,
29 core::packet::{
30 WsIoPacket,
31 WsIoPacketType,
32 },
33 namespace::WsIoServerNamespace,
34 types::handler::{
35 WsIoServerConnectionEventHandler,
36 WsIoServerConnectionOnDisconnectHandler,
37 },
38};
39
/// Lifecycle state of a server-side connection.
///
/// Progression driven by `init`, `activate` and `cleanup`:
/// `Created` → (`AwaitingAuth`, only when the namespace has an auth handler)
/// → `Activating` → `Ready` → `Closing` → `Closed`.
enum WsIoServerConnectionStatus {
    /// Freshly constructed; `init` has not moved it anywhere yet.
    Created,
    /// `init` armed the auth timeout and is waiting for the client's `Auth` packet.
    AwaitingAuth,
    /// Middleware is running / the connection is being registered with its namespace.
    Activating,
    /// Registered with the namespace; the `Ready` packet is sent right after entering this state.
    Ready,
    /// `cleanup` is tearing the connection down.
    Closing,
    /// `cleanup` has finished.
    Closed,
}
48
/// A single client connection inside a `WsIoServerNamespace`.
///
/// Outbound WebSocket messages are queued on an unbounded channel (`tx`); the
/// matching receiver is handed back from `new` to whatever drives the socket.
///
/// Field declaration order is also drop order; keep that in mind before
/// reordering fields.
pub struct WsIoServerConnection {
    /// Timer task spawned by `init` when auth is required; it closes the
    /// connection if the status is still `AwaitingAuth` after the configured
    /// timeout. Aborted after a successful auth and during `cleanup`.
    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Handlers registered through `on`, keyed by event name.
    event_handlers: DashMap<String, WsIoServerConnectionEventHandler>,
    /// Headers passed to `new` (presumably those of the WebSocket upgrade
    /// request — confirm against the caller).
    headers: HeaderMap,
    /// Namespace this connection belongs to.
    namespace: Arc<WsIoServerNamespace>,
    /// Optional handler set via `on_disconnect`; taken and run once by `cleanup`.
    on_disconnect_handler: Mutex<Option<WsIoServerConnectionOnDisconnectHandler>>,
    /// Connection id: a freshly generated BSON `ObjectId` rendered as a string.
    sid: String,
    /// Current lifecycle state.
    status: RwLock<WsIoServerConnectionStatus>,
    /// Sending half of the outbound message channel.
    tx: UnboundedSender<Message>,
}
59
60impl WsIoServerConnection {
61 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
62 let (tx, rx) = unbounded_channel();
63 (
64 Self {
65 auth_timeout_task: Mutex::new(None),
66 event_handlers: DashMap::new(),
67 headers,
68 namespace,
69 on_disconnect_handler: Mutex::new(None),
70 sid: ObjectId::new().to_string(),
71 status: RwLock::new(WsIoServerConnectionStatus::Created),
72 tx,
73 },
74 rx,
75 )
76 }
77
78 async fn abort_auth_timeout_task(&self) {
80 if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
81 auth_timeout_task.abort();
82 }
83 }
84
85 async fn activate(self: &Arc<Self>) -> Result<()> {
86 *self.status.write().await = WsIoServerConnectionStatus::Activating;
87 if let Some(middleware) = &self.namespace.config.middleware {
88 middleware(self.clone()).await?;
89 }
90
91 self.namespace.insert_connection(self.clone());
92 *self.status.write().await = WsIoServerConnectionStatus::Ready;
93 let packet = WsIoPacket {
94 data: None,
95 key: None,
96 r#type: WsIoPacketType::Ready,
97 };
98
99 self.send_packet(&packet)?;
100 (self.namespace.config.on_connect_handler)(self.clone()).await
101 }
102
103 pub(crate) async fn cleanup(self: &Arc<Self>) {
104 *self.status.write().await = WsIoServerConnectionStatus::Closing;
105 self.abort_auth_timeout_task().await;
106 self.namespace.cleanup_connection(&self.sid);
107 if let Some(on_disconnect_handler) = self.on_disconnect_handler.lock().await.take() {
108 let _ = on_disconnect_handler(self.clone()).await;
109 }
110
111 *self.status.write().await = WsIoServerConnectionStatus::Closed;
112 }
113
114 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
115 let packet = match self.namespace.config.packet_codec.decode(bytes) {
116 Ok(packet) => packet,
117 Err(_) => return,
118 };
119
120 match packet.r#type {
121 WsIoPacketType::Auth => {
122 if let Some(auth_handler) = &self.namespace.config.auth_handler
123 && (auth_handler)(self.clone(), packet.data.as_deref()).await.is_ok()
124 {
125 self.abort_auth_timeout_task().await;
126 if self.activate().await.is_ok() {
127 return;
128 }
129 }
130
131 self.close();
132 }
133 WsIoPacketType::Event => {}
134 _ => {}
135 }
136 }
137
138 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
139 self.send(Message::Text(format!("c{}", self.namespace.config.packet_codec).into()))?;
140 let require_auth = self.namespace.config.auth_handler.is_some();
141 let packet = WsIoPacket {
142 data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
143 key: Some(self.sid.clone()),
144 r#type: WsIoPacketType::Init,
145 };
146
147 if require_auth {
148 *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
149 let connection = self.clone();
150 self.auth_timeout_task.lock().await.replace(spawn(async move {
151 sleep(connection.namespace.config.auth_timeout).await;
152 if matches!(
153 *connection.status.read().await,
154 WsIoServerConnectionStatus::AwaitingAuth
155 ) {
156 connection.close();
157 }
158 }));
159
160 self.send_packet(&packet)?;
161 } else {
162 self.send_packet(&packet)?;
163 self.activate().await?;
164 }
165
166 Ok(())
167 }
168
169 #[inline]
170 fn send(&self, message: Message) -> Result<()> {
171 Ok(self.tx.send(message)?)
172 }
173
174 #[inline]
175 fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
176 self.send(self.namespace.encode_packet_to_message(packet)?)
177 }
178
179 #[inline]
182 pub fn close(&self) {
183 let _ = self.send(Message::Close(None));
184 }
185
186 #[inline]
187 pub fn headers(&self) -> &HeaderMap {
188 &self.headers
189 }
190
191 #[inline]
192 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
193 self.namespace.clone()
194 }
195
196 pub fn on<H, Fut, D>(&self, event: &str, handler: H) -> Result<()>
197 where
198 H: Fn(Arc<WsIoServerConnection>, &D) -> Fut + Send + Sync + 'static,
199 Fut: Future<Output = Result<()>> + Send + 'static,
200 D: DeserializeOwned + Send + 'static,
201 {
202 if self.event_handlers.contains_key(event) {
203 bail!("Event {} handler already exists", event);
204 }
205
206 let handler = Arc::new(handler);
207 let packet_codec = self.namespace.config.packet_codec;
208 self.event_handlers.insert(
209 event.to_string(),
210 Arc::new(move |connection, bytes: Option<&[u8]>| {
211 let handler = handler.clone();
212 Box::pin(async move {
213 let bytes = match bytes {
214 Some(bytes) => bytes,
215 None => packet_codec.empty_data_encoded(),
216 };
217
218 let data = packet_codec.decode_data::<D>(bytes)?;
219 handler(connection, &data).await
220 })
221 }),
222 );
223
224 Ok(())
225 }
226
227 pub async fn on_disconnect<H, Fut>(&self, handler: H)
228 where
229 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
230 Fut: Future<Output = Result<()>> + Send + 'static,
231 {
232 self.on_disconnect_handler
233 .lock()
234 .await
235 .replace(Box::new(move |connection| Box::pin(handler(connection))));
236 }
237
238 #[inline]
239 pub fn server(&self) -> WsIoServer {
240 self.namespace.server()
241 }
242
243 #[inline]
244 pub fn sid(&self) -> &str {
245 &self.sid
246 }
247}