wsio_server/
connection.rs1use std::{
2 pin::Pin,
3 sync::Arc,
4};
5
6use anyhow::{
7 Result,
8 bail,
9};
10use bson::oid::ObjectId;
11use http::HeaderMap;
12use tokio::{
13 spawn,
14 sync::{
15 Mutex,
16 RwLock,
17 mpsc::{
18 Receiver,
19 Sender,
20 channel,
21 },
22 },
23 task::JoinHandle,
24 time::sleep,
25};
26use tokio_tungstenite::tungstenite::Message;
27use tokio_util::sync::CancellationToken;
28
29use crate::{
30 WsIoServer,
31 core::packet::{
32 WsIoPacket,
33 WsIoPacketType,
34 },
35 namespace::WsIoServerNamespace,
36};
37
/// Lifecycle state of a [`WsIoServerConnection`], listed in the order a connection moves
/// through them.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum WsIoServerConnectionStatus {
    /// Freshly constructed; `init` has not run yet.
    Created,
    /// The namespace requires auth and the connection is waiting for an auth packet.
    AwaitingAuth,
    /// An auth packet arrived and the configured auth handler is running.
    Authenticating,
    /// Middleware and the on-connect handler are running.
    Activating,
    /// The `Ready` packet has been queued and the connection is fully active.
    Ready,
    /// A close was initiated, or cleanup is in progress.
    Closing,
    /// Cleanup has finished.
    Closed,
}
48
49type OnCloseHandler = Box<
50 dyn Fn(Arc<WsIoServerConnection>) -> Pin<Box<dyn Future<Output = Result<()>> + Send + 'static>>
51 + Send
52 + Sync
53 + 'static,
54>;
55
56pub struct WsIoServerConnection {
57 auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
58 cancel_token: CancellationToken,
59 headers: HeaderMap,
60 message_tx: Sender<Message>,
61 namespace: Arc<WsIoServerNamespace>,
62 on_close_handler: Mutex<Option<OnCloseHandler>>,
63 sid: String,
64 status: RwLock<WsIoServerConnectionStatus>,
65}
66
67impl WsIoServerConnection {
68 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, Receiver<Message>) {
69 let (message_tx, message_rx) = channel(512);
71 (
72 Self {
73 auth_timeout_task: Mutex::new(None),
74 cancel_token: CancellationToken::new(),
75 headers,
76 message_tx,
77 namespace,
78 on_close_handler: Mutex::new(None),
79 sid: ObjectId::new().to_string(),
80 status: RwLock::new(WsIoServerConnectionStatus::Created),
81 },
82 message_rx,
83 )
84 }
85
86 async fn abort_auth_timeout_task(&self) {
88 if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
89 auth_timeout_task.abort();
90 }
91 }
92
93 async fn activate(self: &Arc<Self>) -> Result<()> {
94 {
95 let mut status = self.status.write().await;
96 match *status {
97 WsIoServerConnectionStatus::Authenticating | WsIoServerConnectionStatus::Created => {
98 *status = WsIoServerConnectionStatus::Activating
99 }
100 _ => bail!("Cannot activate connection in invalid status: {:#?}", *status),
101 }
102 }
103
104 if let Some(middleware) = &self.namespace.config.middleware {
105 middleware(self.clone()).await?;
106 }
107
108 if let Some(on_connect_handler) = &self.namespace.config.on_connect_handler {
109 on_connect_handler(self.clone()).await?;
110 }
111
112 self.namespace.insert_connection(self.clone());
113 self.send_packet(&WsIoPacket {
114 data: None,
115 key: None,
116 r#type: WsIoPacketType::Ready,
117 })
118 .await?;
119
120 *self.status.write().await = WsIoServerConnectionStatus::Ready;
121 if let Some(on_ready_handler) = &self.namespace.config.on_ready_handler {
122 on_ready_handler(self.clone()).await?;
123 }
124
125 Ok(())
126 }
127
128 async fn handle_auth_packet(self: &Arc<Self>, packet_data: Option<&[u8]>) -> Result<()> {
129 {
130 let mut status = self.status.write().await;
131 match *status {
132 WsIoServerConnectionStatus::AwaitingAuth => *status = WsIoServerConnectionStatus::Authenticating,
133 _ => bail!("Received auth packet in invalid status: {:#?}", *status),
134 }
135 }
136
137 if let Some(auth_handler) = &self.namespace.config.auth_handler {
138 (auth_handler)(self.clone(), packet_data).await?;
139 self.abort_auth_timeout_task().await;
140 self.activate().await
141 } else {
142 bail!("Auth packet received but no auth handler is configured");
143 }
144 }
145
146 async fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
147 Ok(self
148 .message_tx
149 .send(self.namespace.encode_packet_to_message(packet)?)
150 .await?)
151 }
152
153 pub(crate) async fn cleanup(self: &Arc<Self>) {
155 *self.status.write().await = WsIoServerConnectionStatus::Closing;
156 self.abort_auth_timeout_task().await;
157 self.namespace.remove_connection(&self.sid);
158 self.cancel_token.cancel();
159 if let Some(on_close_handler) = self.on_close_handler.lock().await.take() {
160 let _ = on_close_handler(self.clone()).await;
161 }
162
163 *self.status.write().await = WsIoServerConnectionStatus::Closed;
164 }
165
166 pub(crate) async fn close(&self) {
167 {
168 let mut status = self.status.write().await;
169 if matches!(
170 *status,
171 WsIoServerConnectionStatus::Closed | WsIoServerConnectionStatus::Closing
172 ) {
173 return;
174 }
175
176 *status = WsIoServerConnectionStatus::Closing;
177 }
178
179 let _ = self.message_tx.send(Message::Close(None)).await;
180 }
181
182 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) -> Result<()> {
183 let packet = self.namespace.config.packet_codec.decode(bytes)?;
184 match packet.r#type {
185 WsIoPacketType::Auth => self.handle_auth_packet(packet.data.as_deref()).await,
186 _ => Ok(()),
187 }
188 }
189
190 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
191 let require_auth = self.namespace.config.auth_handler.is_some();
192 let packet = WsIoPacket {
193 data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
194 key: None,
195 r#type: WsIoPacketType::Init,
196 };
197
198 if require_auth {
199 *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
200 let connection = self.clone();
201 *self.auth_timeout_task.lock().await = Some(spawn(async move {
202 sleep(connection.namespace.config.auth_timeout).await;
203 if matches!(
204 *connection.status.read().await,
205 WsIoServerConnectionStatus::AwaitingAuth
206 ) {
207 connection.close().await;
208 }
209 }));
210
211 self.send_packet(&packet).await
212 } else {
213 self.send_packet(&packet).await?;
214 self.activate().await
215 }
216 }
217
218 #[inline]
221 pub fn cancel_token(&self) -> &CancellationToken {
222 &self.cancel_token
223 }
224
225 pub async fn disconnect(&self) {
226 let _ = self
227 .send_packet(&WsIoPacket {
228 data: None,
229 key: None,
230 r#type: WsIoPacketType::Disconnect,
231 })
232 .await;
233
234 self.close().await
235 }
236
237 #[inline]
238 pub fn headers(&self) -> &HeaderMap {
239 &self.headers
240 }
241
242 #[inline]
243 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
244 self.namespace.clone()
245 }
246
247 pub async fn on_close<H, Fut>(&self, handler: H)
248 where
249 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
250 Fut: Future<Output = Result<()>> + Send + 'static,
251 {
252 *self.on_close_handler.lock().await = Some(Box::new(move |connection| Box::pin(handler(connection))));
253 }
254
255 #[inline]
256 pub fn server(&self) -> WsIoServer {
257 self.namespace.server()
258 }
259
260 #[inline]
261 pub fn sid(&self) -> &str {
262 &self.sid
263 }
264}