workflow_websocket/server/
mod.rs1use async_trait::async_trait;
5use cfg_if::cfg_if;
6use downcast_rs::*;
7use futures::{future::FutureExt, select};
8use futures_util::{
9 stream::{SplitSink, SplitStream},
10 SinkExt, StreamExt,
11};
12use std::net::SocketAddr;
13use std::pin::Pin;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17pub use tokio::net::TcpListener;
18use tokio::net::TcpStream;
19use tokio::sync::mpsc::{
20 UnboundedReceiver as TokioUnboundedReceiver, UnboundedSender as TokioUnboundedSender,
21};
22use tokio_tungstenite::{accept_async_with_config, WebSocketStream};
23use tungstenite::Error as WebSocketError;
24use workflow_core::channel::DuplexChannel;
25use workflow_log::*;
26pub mod error;
27pub mod result;
28
29pub use error::Error;
30pub use result::Result;
31pub use tungstenite::protocol::WebSocketConfig;
32pub use tungstenite::Message;
33pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
36pub type WebSocketReceiver = SplitStream<WebSocketStream<TcpStream>>;
39pub type WebSocketSink = TokioUnboundedSender<Message>;
44
/// Connection and traffic statistics maintained by a `WebSocketServer`.
///
/// Every counter starts at zero (via `Default`). Each counter lives in its own
/// `Arc` — presumably so individual counters can be handed out and shared
/// independently of the whole struct; confirm against callers.
#[derive(Debug, Default)]
pub struct WebSocketCounters {
    /// Number of TCP connections the server started serving.
    pub total_connections: Arc<AtomicUsize>,
    /// Number of connections currently being served.
    pub active_connections: Arc<AtomicUsize>,
    /// Number of connections rejected before reaching the message loop
    /// (peer-address lookup failure or handler handshake failure).
    pub handshake_failures: Arc<AtomicUsize>,
    /// Payload bytes received from clients.
    pub rx_bytes: Arc<AtomicUsize>,
    /// Payload bytes sent to clients.
    pub tx_bytes: Arc<AtomicUsize>,
}
69
70#[async_trait]
77pub trait WebSocketHandler
78where
79 Arc<Self>: Sync,
80{
81 type Context: Send + Sync;
83
84 fn accept(&self, _peer: &SocketAddr) -> bool {
86 true
87 }
88
89 async fn connect(self: &Arc<Self>, _peer: &SocketAddr) -> Result<()> {
94 Ok(())
95 }
96
97 async fn disconnect(self: &Arc<Self>, _ctx: Self::Context, _result: Result<()>) {}
99
100 async fn handshake(
104 self: &Arc<Self>,
105 peer: &SocketAddr,
106 sender: &mut WebSocketSender,
107 receiver: &mut WebSocketReceiver,
108 sink: &WebSocketSink,
109 ) -> Result<Self::Context>;
110
111 async fn message(
114 self: &Arc<Self>,
115 ctx: &Self::Context,
116 msg: Message,
117 sink: &WebSocketSink,
118 ) -> Result<()>;
119
120 async fn ctl(self: &Arc<Self>, msg: Message, sender: &mut WebSocketSender) -> Result<()> {
121 if let Message::Ping(data) = msg {
122 sender.send(Message::Pong(data)).await?;
123 }
124 Ok(())
125 }
126}
127
128pub struct WebSocketServer<T>
132where
133 T: WebSocketHandler + Send + Sync + 'static + Sized,
134{
135 pub counters: Arc<WebSocketCounters>,
137 pub handler: Arc<T>,
138 pub stop: DuplexChannel,
139}
140
141impl<T> WebSocketServer<T>
142where
143 T: WebSocketHandler + Send + Sync + 'static,
144{
145 pub fn new(handler: Arc<T>, counters: Option<Arc<WebSocketCounters>>) -> Arc<Self> {
146 Arc::new(WebSocketServer {
147 counters: counters.unwrap_or_default(),
148 handler,
149 stop: DuplexChannel::oneshot(),
150 })
151 }
152
153 async fn handle_connection(
154 self: &Arc<Self>,
155 peer: SocketAddr,
156 stream: TcpStream,
157 config: Option<WebSocketConfig>,
158 ) -> Result<()> {
159 let ws_stream = accept_async_with_config(stream, config).await?;
160 self.handler.connect(&peer).await?;
161 let (mut ws_sender, mut ws_receiver) = ws_stream.split();
164 let (sink_sender, sink_receiver) = tokio::sync::mpsc::unbounded_channel::<Message>();
165
166 let ctx = match self
167 .handler
168 .handshake(&peer, &mut ws_sender, &mut ws_receiver, &sink_sender)
169 .await
170 {
171 Ok(ctx) => ctx,
172 Err(err) => {
173 self.counters
174 .handshake_failures
175 .fetch_add(1, Ordering::Relaxed);
176 return Err(err);
177 }
178 };
179
180 let result = self
181 .connection_task(&ctx, ws_sender, ws_receiver, sink_sender, sink_receiver)
182 .await;
183 self.handler.disconnect(ctx, result).await;
184 Ok(())
187 }
188
189 async fn connection_task(
190 self: &Arc<Self>,
191 ctx: &T::Context,
192 mut ws_sender: WebSocketSender,
193 mut ws_receiver: WebSocketReceiver,
194 sink_sender: TokioUnboundedSender<Message>,
195 mut sink_receiver: TokioUnboundedReceiver<Message>,
196 ) -> Result<()> {
197 loop {
198 tokio::select! {
199 msg = sink_receiver.recv() => {
200 let msg = msg.unwrap();
201 match msg {
202 Message::Binary(data) => {
203 self.counters.tx_bytes.fetch_add(data.len(), Ordering::Relaxed);
204 ws_sender.send(Message::Binary(data)).await?;
205 },
206 Message::Text(text) => {
207 self.counters.tx_bytes.fetch_add(text.len(), Ordering::Relaxed);
208 ws_sender.send(Message::Text(text)).await?;
209 },
210 Message::Close(_) => {
211 ws_sender.send(msg).await?;
212 break;
213 },
214 Message::Ping(data) => {
215 self.counters.tx_bytes.fetch_add(data.len(), Ordering::Relaxed);
216 ws_sender.send(Message::Ping(data)).await?;
217 },
218 Message::Pong(data) => {
219 self.counters.tx_bytes.fetch_add(data.len(), Ordering::Relaxed);
220 ws_sender.send(Message::Pong(data)).await?;
221 },
222 msg => {
223 ws_sender.send(msg).await?;
224 }
225 }
226 },
227 msg = ws_receiver.next() => {
228 match msg {
229 Some(msg) => {
230 let msg = msg?;
231 match msg {
232 Message::Binary(data) => {
233 self.counters.rx_bytes.fetch_add(data.len(), Ordering::Relaxed);
234 self.handler.message(ctx, Message::Binary(data), &sink_sender).await?;
235 },
236 Message::Text(text) => {
237 self.counters.rx_bytes.fetch_add(text.len(), Ordering::Relaxed);
238 self.handler.message(ctx, Message::Text(text), &sink_sender).await?;
239 },
240 Message::Close(_) => {
241 self.handler.message(ctx, msg, &sink_sender).await?;
242 break;
243 },
244 Message::Ping(data) => {
245 self.counters.rx_bytes.fetch_add(data.len(), Ordering::Relaxed);
246 cfg_if! {
247 if #[cfg(feature = "ping-pong")] {
248 self.handler.ctl(Message::Ping(data), &mut ws_sender).await?;
249 } else {
250 ws_sender.send(Message::Pong(data)).await?;
251 }
252 }
253 },
254 Message::Pong(data) => {
255 self.counters.rx_bytes.fetch_add(data.len(), Ordering::Relaxed);
256 cfg_if! {
257 if #[cfg(feature = "ping-pong")] {
258 self.handler.ctl(Message::Pong(data), &mut ws_sender).await?;
259 } else {
260 }
262 }
263 },
264 _ => {
265 }
266 }
267 }
268 None => {
269 return Err(Error::AbnormalClose);
270 }
271 }
272 }
273 }
274 }
275
276 Ok(())
277 }
278
279 pub async fn bind(self: &Arc<Self>, addr: &str) -> Result<TcpListener> {
280 let listener = TcpListener::bind(&addr).await.map_err(|err| {
281 Error::Listen(format!(
282 "WebSocket server unable to listen on `{addr}`: {err}",
283 ))
284 })?;
285 Ok(listener)
287 }
288
289 async fn accept(self: &Arc<Self>, stream: TcpStream, config: Option<WebSocketConfig>) {
290 let peer = match stream.peer_addr() {
291 Ok(peer_address) => peer_address,
292 Err(_) => {
293 self.counters
294 .handshake_failures
295 .fetch_add(1, Ordering::Relaxed);
296 return;
297 }
298 };
299
300 self.counters
301 .total_connections
302 .fetch_add(1, Ordering::Relaxed);
303 self.counters
304 .active_connections
305 .fetch_add(1, Ordering::Relaxed);
306
307 let self_ = self.clone();
308 tokio::spawn(async move {
309 if let Err(e) = self_.handle_connection(peer, stream, config).await {
310 match e {
311 Error::WebSocketError(WebSocketError::ConnectionClosed)
312 | Error::WebSocketError(WebSocketError::Protocol(_))
313 | Error::WebSocketError(WebSocketError::Utf8) => (),
314 err => log_error!("Error processing connection: {}", err),
315 }
316 }
317 self_
318 .counters
319 .active_connections
320 .fetch_sub(1, Ordering::Relaxed)
321 });
322 }
323
324 pub async fn listen(
325 self: &Arc<Self>,
326 listener: TcpListener,
327 config: Option<WebSocketConfig>,
328 ) -> Result<()> {
329 loop {
330 select! {
331 stream = listener.accept().fuse() => {
332 if let Ok((stream,socket_addr)) = stream {
333 if self.handler.accept(&socket_addr) {
334 self.accept(stream, config).await;
335 }
336 }
337 },
338 _ = self.stop.request.receiver.recv().fuse() => break,
339 }
340 }
341
342 self.stop
343 .response
344 .sender
345 .send(())
346 .await
347 .map_err(|err| Error::Done(err.to_string()))
348 }
349
350 pub fn stop(&self) -> Result<()> {
351 self.stop
352 .request
353 .sender
354 .try_send(())
355 .map_err(|err| Error::Stop(err.to_string()))
356 }
357
358 pub async fn join(&self) -> Result<()> {
359 self.stop
360 .response
361 .receiver
362 .recv()
363 .await
364 .map_err(|err| Error::Join(err.to_string()))
365 }
366
367 pub async fn stop_and_join(&self) -> Result<()> {
368 self.stop()?;
369 self.join().await
370 }
371}
372
373#[async_trait]
413pub trait WebSocketServerTrait: DowncastSync {
414 async fn bind(self: Arc<Self>, addr: &str) -> Result<TcpListener>;
415 async fn listen(
416 self: Arc<Self>,
417 listener: TcpListener,
418 config: Option<WebSocketConfig>,
419 ) -> Result<()>;
420 fn stop(&self) -> Result<()>;
421 async fn join(&self) -> Result<()>;
422 async fn stop_and_join(&self) -> Result<()>;
423}
424impl_downcast!(sync WebSocketServerTrait);
425
426#[async_trait]
427impl<T> WebSocketServerTrait for WebSocketServer<T>
428where
429 T: WebSocketHandler + Send + Sync + 'static + Sized,
430{
431 async fn bind(self: Arc<Self>, addr: &str) -> Result<TcpListener> {
432 WebSocketServer::<T>::bind(&self, addr).await
433 }
434
435 async fn listen(
436 self: Arc<Self>,
437 listener: TcpListener,
438 config: Option<WebSocketConfig>,
439 ) -> Result<()> {
440 WebSocketServer::<T>::listen(&self, listener, config).await
441 }
442
443 fn stop(&self) -> Result<()> {
444 WebSocketServer::<T>::stop(self)
445 }
446
447 async fn join(&self) -> Result<()> {
448 WebSocketServer::<T>::join(self).await
449 }
450
451 async fn stop_and_join(&self) -> Result<()> {
452 WebSocketServer::<T>::stop_and_join(self).await
453 }
454}
455
456pub mod handshake {
457 use super::*;
463
464 pub type HandshakeFn = Pin<Box<dyn Send + Sync + Fn(&str) -> Result<()>>>;
466
467 pub async fn greeting<'ws>(
471 timeout_duration: Duration,
472 _sender: &'ws mut WebSocketSender,
473 receiver: &'ws mut WebSocketReceiver,
474 handler: HandshakeFn,
475 ) -> Result<()> {
476 let delay = tokio::time::sleep(timeout_duration);
477 tokio::select! {
478 msg = receiver.next() => {
479 if let Some(Ok(msg)) = msg {
480 if msg.is_text() || msg.is_binary() {
481 return handler(msg.to_text()?);
482 }
483 }
484 Err(Error::MalformedHandshake)
485 }
486 _ = delay => {
487 Err(Error::ConnectionTimeout)
488 }
489 }
490 }
491}