wsio_server/
connection.rs1use std::sync::Arc;
2
3use anyhow::Result;
4use bson::oid::ObjectId;
5use http::HeaderMap;
6use tokio::{
7 spawn,
8 sync::{
9 Mutex,
10 RwLock,
11 mpsc::{
12 UnboundedReceiver,
13 UnboundedSender,
14 unbounded_channel,
15 },
16 },
17 task::JoinHandle,
18 time::sleep,
19};
20use tokio_tungstenite::tungstenite::Message;
21
22use crate::{
23 core::packet::{
24 WsIoPacket,
25 WsIoPacketType,
26 },
27 namespace::WsIoServerNamespace,
28 types::handler::WsIoServerConnectionOnDisconnectHandler,
29};
30
/// Lifecycle state of a server-side connection.
///
/// Transitions performed in this file: `Created` → `AwaitingAuth` (only when the namespace has an
/// auth handler) → `Activating` → `Ready`, and `Closing` → `Closed` while cleaning up.
enum WsIoServerConnectionStatus {
    /// Freshly constructed; the handshake has not started yet.
    Created,
    /// `Init` packet sent; waiting for the client's `Auth` packet (bounded by a timeout task).
    AwaitingAuth,
    /// Being registered with the namespace.
    Activating,
    /// Registered with the namespace; the `Ready` packet is sent right after entering this state.
    Ready,
    /// Cleanup in progress.
    Closing,
    /// Cleanup finished.
    Closed,
}
39
40pub struct WsIoServerConnection {
41 auth_timeout_task: Mutex<Option<JoinHandle<()>>>,
42 headers: HeaderMap,
43 namespace: Arc<WsIoServerNamespace>,
44 on_disconnect_handler: Mutex<Option<WsIoServerConnectionOnDisconnectHandler>>,
45 sid: String,
46 status: RwLock<WsIoServerConnectionStatus>,
47 tx: UnboundedSender<Message>,
48}
49
50impl WsIoServerConnection {
51 pub(crate) fn new(headers: HeaderMap, namespace: Arc<WsIoServerNamespace>) -> (Self, UnboundedReceiver<Message>) {
52 let (tx, rx) = unbounded_channel();
53 (
54 Self {
55 auth_timeout_task: Mutex::new(None),
56 headers,
57 namespace,
58 on_disconnect_handler: Mutex::new(None),
59 sid: ObjectId::new().to_string(),
60 status: RwLock::new(WsIoServerConnectionStatus::Created),
61 tx,
62 },
63 rx,
64 )
65 }
66
67 async fn abort_auth_timeout_task(&self) {
69 if let Some(auth_timeout_task) = self.auth_timeout_task.lock().await.take() {
70 auth_timeout_task.abort();
71 }
72 }
73
74 async fn activate(self: &Arc<Self>) -> Result<()> {
75 *self.status.write().await = WsIoServerConnectionStatus::Activating;
76 self.namespace.insert_connection(self.clone());
78 *self.status.write().await = WsIoServerConnectionStatus::Ready;
79 let packet = WsIoPacket {
80 data: None,
81 key: None,
82 r#type: WsIoPacketType::Ready,
83 };
84
85 self.send_packet(&packet)?;
86 (self.namespace.config.on_connect_handler)(self.clone()).await
87 }
88
89 pub(crate) async fn cleanup(self: &Arc<Self>) {
90 *self.status.write().await = WsIoServerConnectionStatus::Closing;
91 self.abort_auth_timeout_task().await;
92 self.namespace.cleanup_connection(&self.sid);
93 if let Some(on_disconnect_handler) = self.on_disconnect_handler.lock().await.take() {
94 let _ = on_disconnect_handler(self.clone()).await;
95 }
96
97 *self.status.write().await = WsIoServerConnectionStatus::Closed;
98 }
99
100 pub(crate) async fn handle_incoming_packet(self: &Arc<Self>, bytes: &[u8]) {
101 let packet = match self.namespace.config.packet_codec.decode(bytes) {
102 Ok(packet) => packet,
103 Err(_) => return,
104 };
105
106 match packet.r#type {
107 WsIoPacketType::Auth => {
108 if let Some(auth_handler) = &self.namespace.config.auth_handler
109 && (auth_handler)(self.clone(), packet.data.as_deref()).await.is_ok() {
110 self.abort_auth_timeout_task().await;
111 if self.activate().await.is_ok() {
112 return;
113 }
114 }
115
116 self.close();
117 }
118 WsIoPacketType::Event => {}
119 _ => {}
120 }
121 }
122
123 pub(crate) async fn init(self: &Arc<Self>) -> Result<()> {
124 self.send(Message::Text(format!("c{}", self.namespace.config.packet_codec).into()))?;
125 let require_auth = self.namespace.config.auth_handler.is_some();
126 let packet = WsIoPacket {
127 data: Some(self.namespace.config.packet_codec.encode_data(&require_auth)?),
128 key: Some(self.sid.clone()),
129 r#type: WsIoPacketType::Init,
130 };
131
132 if require_auth {
133 *self.status.write().await = WsIoServerConnectionStatus::AwaitingAuth;
134 let connection = self.clone();
135 self.auth_timeout_task.lock().await.replace(spawn(async move {
136 sleep(connection.namespace.config.auth_timeout).await;
137 if matches!(
138 *connection.status.read().await,
139 WsIoServerConnectionStatus::AwaitingAuth
140 ) {
141 connection.close();
142 }
143 }));
144
145 self.send_packet(&packet)?;
146 } else {
147 self.send_packet(&packet)?;
148 self.activate().await?;
149 }
150
151 Ok(())
152 }
153
154 #[inline]
155 fn send(&self, message: Message) -> Result<()> {
156 Ok(self.tx.send(message)?)
157 }
158
159 #[inline]
160 fn send_packet(&self, packet: &WsIoPacket) -> Result<()> {
161 self.send(self.namespace.encode_packet_to_message(packet)?)
162 }
163
164 #[inline]
167 pub fn close(&self) {
168 let _ = self.send(Message::Close(None));
169 }
170
171 #[inline]
172 pub fn headers(&self) -> &HeaderMap {
173 &self.headers
174 }
175
176 #[inline]
177 pub fn namespace(&self) -> Arc<WsIoServerNamespace> {
178 self.namespace.clone()
179 }
180
181 pub async fn on_disconnect<H, Fut>(&self, handler: H)
182 where
183 H: Fn(Arc<WsIoServerConnection>) -> Fut + Send + Sync + 'static,
184 Fut: Future<Output = Result<()>> + Send + 'static,
185 {
186 self.on_disconnect_handler
187 .lock()
188 .await
189 .replace(Box::new(move |connection| Box::pin(handler(connection))));
190 }
191
192 #[inline]
193 pub fn sid(&self) -> &str {
194 &self.sid
195 }
196}