1use std::{
2 pin::Pin,
3 sync::Arc,
4};
5
6use anyhow::{
7 Result,
8 bail,
9};
10use bson::oid::ObjectId;
11use dashmap::DashMap;
12use http::HeaderMap;
13use serde::de::DeserializeOwned;
14use tokio::{
15 spawn,
16 sync::{
17 Mutex,
18 RwLock,
19 mpsc::{
20 UnboundedReceiver,
21 UnboundedSender,
22 unbounded_channel,
23 },
24 },
25 task::JoinHandle,
26 time::sleep,
27};
28use tokio_tungstenite::tungstenite::Message;
29
30use crate::{
31 WsIoServer,
32 core::packet::{
33 WsIoPacket,
34 WsIoPacketType,
35 },
36 namespace::WsIoServerNamespace,
37};
38
/// Lifecycle status of a [`WsIoServerConnection`].
#[derive(Debug)]
enum WsIoServerConnectionStatus {
    /// Set by `activate` while middleware and the connect handler run.
    Activating,
    /// Set when an auth packet has been accepted and the auth handler is running.
    Authenticating,
    /// Set by `init` when the namespace has an auth handler configured.
    AwaitingAuth,
    /// Set at the end of `cleanup`.
    Closed,
    /// Set by `close` and at the start of `cleanup`.
    Closing,
    /// Initial status assigned by `new`.
    Created,
    /// Set by `activate` once the ready packet has been queued.
    Ready,
}
49
/// Type-erased per-event callback stored in the connection's handler map.
///
/// It receives the connection and the optional raw payload bytes, and returns
/// a boxed future that is allowed to borrow those bytes for `'a`.
type EventHandler = Box<
    dyn for<'a> Fn(Arc<WsIoServerConnection>, Option<&'a [u8]>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'a>>
        + Send
        + Sync
        + 'static,
>;
56
/// Type-erased callback invoked with the connection during `cleanup`.
///
/// The returned future is `'static`, so it cannot borrow from the call site.
type OnCloseHandler = Box<
    dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
63
/// Server-side state for one client connection inside a namespace.
pub struct WsIoServerConnection {
    /// Handle of the spawned auth-timeout task, if one is pending.
    auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
    /// Handlers registered through `on`, keyed by event name.
    event_handlers: DashMap<String, EventHandler>,
    /// Headers passed to `new` when the connection was built.
    headers: HeaderMap,
    /// Namespace this connection belongs to.
    namespace: Arc<WsIoServerNamespace>,
    /// Handler registered through `on_close`; taken and run by `cleanup`.
    on_close_handler: Mutex<Option<OnCloseHandler>>,
    /// Session id: the string form of a freshly generated `ObjectId`.
    sid: String,
    /// Current lifecycle status.
    status: RwLock<WsIoServerConnectionStatus>,
    /// Sending half of the unbounded outbound message channel.
    tx: UnboundedSender<Message>,
}
74
75impl WsIoServerConnection {
76 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
77 let (tx, rx) = unbounded_channel();
78 (
79 Self {
80 auth_timeout_task: Mutex::new(None),
81 event_handlers: DashMap::new(),
82 headers,
83 namespace,
84 on_close_handler: Mutex::new(None),
85 sid: ObjectId::new().to_string(),
86 status: RwLock::new(WsIoServerConnectionStatus::Created),
87 tx,
88 },
89 rx,
90 )
91 }
92
93 async fn abort_auth_timeout_task(&self) {
95 if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
96 auth_timeout_task.abort();
97 }
98 }
99
100 async fn activate(self: &Arc<Self>) -> Result<()> {
101 {
102 let mut status = self.status.write().await;
103 match *status {
104 WsIoServerConnectionStatus::Authenticating | WsIoServerConnectionStatus::Created => {
105 *status = WsIoServerConnectionStatus::Activating
106 }
107 _ => bail!("Cannot activate connection in invalid status: {:#?}", *status),
108 }
109 }
110
111 if let Some(middleware) = &self.namespace.config.middleware {
112 middleware(self.clone()).await?;
113 }
114
115 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
116 on_connect_handler(self.clone()).await?;
117 }
118
119 self.namespace.insert_connection(self.clone());
120 self.send_packet(&WsIoPacket {
121 data: None,
122 key: None,
123 r#type: WsIoPacketType::Ready,
124 })?;
125
126 *self.status.write().await = WsIoServerConnectionStatus::Ready;
127 if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
128 on_ready_handler(self.clone()).await?;
129 }
130
131 Ok(())
132 }
133
134 async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
135 {
136 let mut status = self.status.write().await;
137 match *status {
138 WsIoServerConnectionStatus::AwaitingAuth => *status = WsIoServerConnectionStatus::Authenticating,
139 _ => bail!("Received auth packet in invalid status: {:#?}", *status),
140 }
141 }
142
143 if let Some(auth_handler) = &self.namespace.config.auth_handler {
144 (auth_handler)(self.clone(), packet_data).await?;
145 self.abort_auth_timeout_task().await;
146 let status = self.status.read().await;
147 if !matches!(*status, WsIoServerConnectionStatus::Authenticating) {
148 bail!("Auth packet processed while connection status was {:#?}", *status);
149 }
150
151 self.activate().await?;
152 Ok(())
153 } else {
154 bail!("Auth packet received but no auth handler is configured");
155 }
156 }
157
158 #[inline]
159 fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
160 Ok(self.tx.send(self.namespace.encode_packet_to_message(packet)?)?)
161 }
162
163 pub(crate) async fn cleanup(self: &Arc<Self>) {
165 *self.status.write().await = WsIoServerConnectionStatus::Closing;
166 self.abort_auth_timeout_task().await;
167 self.namespace.remove_connection(&self.sid);
168 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
169 let _ = on_close_handler(self.clone()).await;
170 }
171
172 *self.status.write().await = WsIoServerConnectionStatus::Closed;
173 }
174
175 pub(crate) async fn close(&self) {
176 {
177 let mut status = self.status.write().await;
178 if matches!(
179 *status,
180 WsIoServerConnectionStatus::Closed | WsIoServerConnectionStatus::Closing
181 ) {
182 return;
183 }
184
185 *status = WsIoServerConnectionStatus::Closing;
186 }
187
188 let _ = self.tx.send(Message::Close(None));
189 }
190
191 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
192 let packet = match self.namespace.config.packet_codec.decode(bytes) {
193 Ok(packet) => packet,
194 Err(_) => return,
195 };
196
197 if match packet.r#type {
198 WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
199 WsIoPacketType::Event => Ok(()),
200 _ => Ok(()),
201 }
202 .is_err()
203 {
204 self.close().await;
205 }
206 }
207
208 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
209 let require_auth = self.namespace.config.auth_handler.is_some();
210 let packet = WsIoPacket {
211 data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
212 key: None,
213 r#type: WsIoPacketType::Init,
214 };
215
216 if require_auth {
217 *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
218 let connection = self.clone();
219 *self.auth_timeout_task.lock().await = Some(spawn(async move {
220 sleep(connection.namespace.config.auth_timeout).await;
221 if matches!(
222 *connection.status.read().await,
223 WsIoServerConnectionStatus::AwaitingAuth
224 ) {
225 connection.close().await;
226 }
227 }));
228
229 self.send_packet(&packet)?;
230 } else {
231 self.send_packet(&packet)?;
232 self.activate().await?;
233 }
234
235 Ok(())
236 }
237
238 pub async fn disconnect(&self) {
240 let _ = self.send_packet(&WsIoPacket {
241 data: None,
242 key: None,
243 r#type: WsIoPacketType::Disconnect,
244 });
245
246 let _ = self.close().await;
248 }
249
    /// Returns the headers this connection was constructed with.
    #[inline]
    pub fn headers(&self) -> &HeaderMap {
        &self.headers
    }
254
255 #[inline]
256 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
257 self.namespace.clone()
258 }
259
260 #[inline]
261 pub fn off(&self, event: impl AsRef<str>) {
262 self.event_handlers.remove(event.as_ref());
263 }
264
265 #[inline]
266 pub fn on<H, Fut, D>(&self, event: impl AsRef<str>, handler: H) -> Result<()>
267 where
268 H: Fn(Arc<WsIoServerConnection>, &D) -> Fut + Send + Sync + 'static,
269 Fut: Future<Output = Result<()>> + Send + 'static,
270 D: DeserializeOwned + Send + 'static,
271 {
272 let event = event.as_ref();
273 if self.event_handlers.contains_key(event) {
274 bail!("Event {} handler already exists", event);
275 }
276
277 let handler = Arc::new(handler);
278 let packet_codec = self.namespace.config.packet_codec;
279 self.event_handlers.insert(
280 event.into(),
281 Box::new(move |connection, bytes: Option<&[u8]>| {
282 let handler = handler.clone();
283 Box::pin(async move {
284 let bytes = match bytes {
285 Some(bytes) => bytes,
286 None => packet_codec.empty_data_encoded(),
287 };
288
289 let data = packet_codec.decode_data::<D>(bytes)?;
290 handler(connection, &data).await
291 })
292 }),
293 );
294
295 Ok(())
296 }
297
298 pub async fn on_close<H, Fut>(&self, handler: H)
299 where
300 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
301 Fut: Future<Output = Result<()>> + Send + 'static,
302 {
303 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
304 }
305
    /// Returns the server handle obtained from this connection's namespace.
    #[inline]
    pub fn server(&self) -> WsIoServer {
        self.namespace.server()
    }
310
    /// Returns this connection's session id.
    #[inline]
    pub fn sid(&self) -> &str {
        &self.sid
    }
315}