workflow_websocket/server/
mod.rs

1//!
2//! async WebSocket server functionality (requires tokio executor)
3//!
4use async_trait::async_trait;
5use cfg_if::cfg_if;
6use downcast_rs::*;
7use futures::{future::FutureExt, select};
8use futures_util::{
9    stream::{SplitSink, SplitStream},
10    SinkExt, StreamExt,
11};
12use std::net::SocketAddr;
13use std::pin::Pin;
14use std::sync::atomic::{AtomicUsize, Ordering};
15use std::sync::Arc;
16use std::time::Duration;
17pub use tokio::net::TcpListener;
18use tokio::net::TcpStream;
19use tokio::sync::mpsc::{
20    UnboundedReceiver as TokioUnboundedReceiver, UnboundedSender as TokioUnboundedSender,
21};
22use tokio_tungstenite::{accept_async_with_config, WebSocketStream};
23use tungstenite::Error as WebSocketError;
24use workflow_core::channel::DuplexChannel;
25use workflow_log::*;
26pub mod error;
27pub mod result;
28
29pub use error::Error;
30pub use result::Result;
31pub use tungstenite::protocol::WebSocketConfig;
32pub use tungstenite::Message;
33/// WebSocket stream sender for dispatching [`tungstenite::Message`].
34/// This stream object must have a mutable reference and can not be cloned.
35pub type WebSocketSender = SplitSink<WebSocketStream<TcpStream>, Message>;
36/// WebSocket stream receiver for receiving [`tungstenite::Message`].
37/// This stream object must have a mutable reference and can not be cloned.
38pub type WebSocketReceiver = SplitStream<WebSocketStream<TcpStream>>;
39/// WebSocketSink [`tokio::sync::mpsc::UnboundedSender`] for dispatching
40/// messages from within the [`WebSocketHandler::message`]. This is an
41/// `MPSC` channel that can be cloned and retained externally for the
42/// lifetime of the WebSocket connection.
43pub type WebSocketSink = TokioUnboundedSender<Message>;
44
/// Atomic counters that allow tracking connection counts
/// and cumulative message sizes in bytes (bandwidth consumption
/// without accounting for the websocket framing overhead).
/// These counters can be created and supplied externally or
/// supplied as `None`, in which case the server creates its own set.
///
/// [`Default`] yields a fresh set of independent counters, all starting at zero.
#[derive(Debug, Default)]
pub struct WebSocketCounters {
    /// Number of accepted connections whose peer address could be resolved.
    pub total_connections: Arc<AtomicUsize>,
    /// Number of connections currently being processed.
    pub active_connections: Arc<AtomicUsize>,
    /// Number of connections dropped because the peer address was
    /// unavailable or the handler's handshake returned an error.
    pub handshake_failures: Arc<AtomicUsize>,
    /// Cumulative payload size of received binary, text, ping and pong messages.
    pub rx_bytes: Arc<AtomicUsize>,
    /// Cumulative payload size of sent binary, text, ping and pong messages.
    pub tx_bytes: Arc<AtomicUsize>,
}
69
70/// WebSocketHandler trait that represents the WebSocket processor
71/// functionality.  This trait is supplied to the WebSocket
72/// which subsequently invokes it's functions during websocket
73/// connection and messages.  The trait can override `with_handshake()` method
74/// to enable invocation of the `handshake()` method upon receipt of the
75/// first valid websocket message from the incoming connection.
76#[async_trait]
77pub trait WebSocketHandler
78where
79    Arc<Self>: Sync,
80{
81    /// Context type used by impl trait to represent websocket connection
82    type Context: Send + Sync;
83
84    /// Called to determine if the connection should be accepted.
85    fn accept(&self, _peer: &SocketAddr) -> bool {
86        true
87    }
88
89    /// Called immediately when connection is established.
90    /// This function should return an error to terminate the connection.
91    /// If the server manages a client ban list, it should process it
92    /// in this function and return an [`Error`] to prevent further processing.
93    async fn connect(self: &Arc<Self>, _peer: &SocketAddr) -> Result<()> {
94        Ok(())
95    }
96
97    /// Called upon websocket disconnection
98    async fn disconnect(self: &Arc<Self>, _ctx: Self::Context, _result: Result<()>) {}
99
100    /// Called after [`Self::connect()`], after creating the [`tokio::sync::mpsc`] sender `sink`
101    /// channel, allowing the server to execute additional handshake communication phase,
102    /// or retain the sink for external message dispatch (such as server-side notifications).
103    async fn handshake(
104        self: &Arc<Self>,
105        peer: &SocketAddr,
106        sender: &mut WebSocketSender,
107        receiver: &mut WebSocketReceiver,
108        sink: &WebSocketSink,
109    ) -> Result<Self::Context>;
110
111    /// Called for every websocket message
112    /// This function can return an error to terminate the connection
113    async fn message(
114        self: &Arc<Self>,
115        ctx: &Self::Context,
116        msg: Message,
117        sink: &WebSocketSink,
118    ) -> Result<()>;
119
120    async fn ctl(self: &Arc<Self>, msg: Message, sender: &mut WebSocketSender) -> Result<()> {
121        if let Message::Ping(data) = msg {
122            sender.send(Message::Pong(data)).await?;
123        }
124        Ok(())
125    }
126}
127
128/// WebSocketServer that provides the main websocket connection
129/// and message processing loop that delivers messages to the
130/// installed WebSocketHandler trait.
131pub struct WebSocketServer<T>
132where
133    T: WebSocketHandler + Send + Sync + 'static + Sized,
134{
135    // pub connections: AtomicU64,
136    pub counters: Arc<WebSocketCounters>,
137    pub handler: Arc<T>,
138    pub stop: DuplexChannel,
139}
140
141impl<T> WebSocketServer<T>
142where
143    T: WebSocketHandler + Send + Sync + 'static,
144{
145    pub fn new(handler: Arc<T>, counters: Option<Arc<WebSocketCounters>>) -> Arc<Self> {
146        Arc::new(WebSocketServer {
147            counters: counters.unwrap_or_default(),
148            handler,
149            stop: DuplexChannel::oneshot(),
150        })
151    }
152
153    async fn handle_connection(
154        self: &Arc<Self>,
155        peer: SocketAddr,
156        stream: TcpStream,
157        config: Option<WebSocketConfig>,
158    ) -> Result<()> {
159        let ws_stream = accept_async_with_config(stream, config).await?;
160        self.handler.connect(&peer).await?;
161        // log_trace!("WebSocket connected: {}", peer);
162
163        let (mut ws_sender, mut ws_receiver) = ws_stream.split();
164        let (sink_sender, sink_receiver) = tokio::sync::mpsc::unbounded_channel::<Message>();
165
166        let ctx = match self
167            .handler
168            .handshake(&peer, &mut ws_sender, &mut ws_receiver, &sink_sender)
169            .await
170        {
171            Ok(ctx) => ctx,
172            Err(err) => {
173                self.counters
174                    .handshake_failures
175                    .fetch_add(1, Ordering::Relaxed);
176                return Err(err);
177            }
178        };
179
180        let result = self
181            .connection_task(&ctx, ws_sender, ws_receiver, sink_sender, sink_receiver)
182            .await;
183        self.handler.disconnect(ctx, result).await;
184        // log_trace!("WebSocket disconnected: {}", peer);
185
186        Ok(())
187    }
188
189    async fn connection_task(
190        self: &Arc<Self>,
191        ctx: &T::Context,
192        mut ws_sender: WebSocketSender,
193        mut ws_receiver: WebSocketReceiver,
194        sink_sender: TokioUnboundedSender<Message>,
195        mut sink_receiver: TokioUnboundedReceiver<Message>,
196    ) -> Result<()> {
197        loop {
198            tokio::select! {
199                msg = sink_receiver.recv() => {
200                    let msg = msg.unwrap();
201                    match msg {
202                        Message::Binary(data)  => {
203                            self.counters.tx_bytes.fetch_add(data.len(), Ordering::Relaxed);
204                            ws_sender.send(Message::Binary(data)).await?;
205                        },
206                        Message::Text(text)  => {
207                            self.counters.tx_bytes.fetch_add(text.len(), Ordering::Relaxed);
208                            ws_sender.send(Message::Text(text)).await?;
209                        },
210                        Message::Close(_) => {
211                            ws_sender.send(msg).await?;
212                            break;
213                        },
214                        Message::Ping(data) => {
215                            self.counters.tx_bytes.fetch_add(data.len(), Ordering::Relaxed);
216                            ws_sender.send(Message::Ping(data)).await?;
217                        },
218                        Message::Pong(data) => {
219                            self.counters.tx_bytes.fetch_add(data.len(), Ordering::Relaxed);
220                            ws_sender.send(Message::Pong(data)).await?;
221                        },
222                        msg => {
223                            ws_sender.send(msg).await?;
224                        }
225                    }
226                },
227                msg = ws_receiver.next() => {
228                    match msg {
229                        Some(msg) => {
230                            let msg = msg?;
231                            match msg {
232                                Message::Binary(data)  => {
233                                    self.counters.rx_bytes.fetch_add(data.len(), Ordering::Relaxed);
234                                    self.handler.message(ctx, Message::Binary(data), &sink_sender).await?;
235                                },
236                                Message::Text(text)  => {
237                                    self.counters.rx_bytes.fetch_add(text.len(), Ordering::Relaxed);
238                                    self.handler.message(ctx, Message::Text(text), &sink_sender).await?;
239                                },
240                                Message::Close(_) => {
241                                    self.handler.message(ctx, msg, &sink_sender).await?;
242                                    break;
243                                },
244                                Message::Ping(data) => {
245                                    self.counters.rx_bytes.fetch_add(data.len(), Ordering::Relaxed);
246                                    cfg_if! {
247                                        if #[cfg(feature = "ping-pong")] {
248                                            self.handler.ctl(Message::Ping(data), &mut ws_sender).await?;
249                                        } else {
250                                            ws_sender.send(Message::Pong(data)).await?;
251                                        }
252                                    }
253                                },
254                                Message::Pong(data) => {
255                                    self.counters.rx_bytes.fetch_add(data.len(), Ordering::Relaxed);
256                                    cfg_if! {
257                                        if #[cfg(feature = "ping-pong")] {
258                                            self.handler.ctl(Message::Pong(data), &mut ws_sender).await?;
259                                        } else {
260                                            // ignore pong
261                                        }
262                                    }
263                                },
264                                _ => {
265                                }
266                            }
267                        }
268                        None => {
269                            return Err(Error::AbnormalClose);
270                        }
271                    }
272                }
273            }
274        }
275
276        Ok(())
277    }
278
279    pub async fn bind(self: &Arc<Self>, addr: &str) -> Result<TcpListener> {
280        let listener = TcpListener::bind(&addr).await.map_err(|err| {
281            Error::Listen(format!(
282                "WebSocket server unable to listen on `{addr}`: {err}",
283            ))
284        })?;
285        // log_trace!("WebSocket server listening on: {}", addr);
286        Ok(listener)
287    }
288
289    async fn accept(self: &Arc<Self>, stream: TcpStream, config: Option<WebSocketConfig>) {
290        let peer = match stream.peer_addr() {
291            Ok(peer_address) => peer_address,
292            Err(_) => {
293                self.counters
294                    .handshake_failures
295                    .fetch_add(1, Ordering::Relaxed);
296                return;
297            }
298        };
299
300        self.counters
301            .total_connections
302            .fetch_add(1, Ordering::Relaxed);
303        self.counters
304            .active_connections
305            .fetch_add(1, Ordering::Relaxed);
306
307        let self_ = self.clone();
308        tokio::spawn(async move {
309            if let Err(e) = self_.handle_connection(peer, stream, config).await {
310                match e {
311                    Error::WebSocketError(WebSocketError::ConnectionClosed)
312                    | Error::WebSocketError(WebSocketError::Protocol(_))
313                    | Error::WebSocketError(WebSocketError::Utf8) => (),
314                    err => log_error!("Error processing connection: {}", err),
315                }
316            }
317            self_
318                .counters
319                .active_connections
320                .fetch_sub(1, Ordering::Relaxed)
321        });
322    }
323
324    pub async fn listen(
325        self: &Arc<Self>,
326        listener: TcpListener,
327        config: Option<WebSocketConfig>,
328    ) -> Result<()> {
329        loop {
330            select! {
331                stream = listener.accept().fuse() => {
332                    if let Ok((stream,socket_addr)) = stream {
333                        if self.handler.accept(&socket_addr) {
334                            self.accept(stream, config).await;
335                        }
336                    }
337                },
338                _ = self.stop.request.receiver.recv().fuse() => break,
339            }
340        }
341
342        self.stop
343            .response
344            .sender
345            .send(())
346            .await
347            .map_err(|err| Error::Done(err.to_string()))
348    }
349
350    pub fn stop(&self) -> Result<()> {
351        self.stop
352            .request
353            .sender
354            .try_send(())
355            .map_err(|err| Error::Stop(err.to_string()))
356    }
357
358    pub async fn join(&self) -> Result<()> {
359        self.stop
360            .response
361            .receiver
362            .recv()
363            .await
364            .map_err(|err| Error::Join(err.to_string()))
365    }
366
367    pub async fn stop_and_join(&self) -> Result<()> {
368        self.stop()?;
369        self.join().await
370    }
371}
372
373/// Base WebSocketServer trait allows the [`WebSocketServer<T>`] struct
374/// to be retained by the trait reference by casting it to the trait
375/// as follows:
376///
377/// ```rust
378/// use std::sync::Arc;
379/// use async_trait::async_trait;
380/// use workflow_websocket::server::{Result,WebSocketServerTrait,WebSocketConfig,TcpListener};
381///
382/// struct Server{}
383///
384/// #[async_trait]
385/// impl WebSocketServerTrait for Server {
386///     async fn bind(self: Arc<Self>, addr: &str) -> Result<TcpListener>{
387///         unimplemented!()
388///     }
389///     async fn listen(self: Arc<Self>, listener : TcpListener, config: Option<WebSocketConfig>) -> Result<()>{
390///         unimplemented!()
391///     }
392///     fn stop(&self) -> Result<()>{
393///         unimplemented!()
394///     }
395///     async fn join(&self) -> Result<()>{
396///         unimplemented!()
397///     }
398///     async fn stop_and_join(&self) -> Result<()>{
399///         unimplemented!()
400///     }
401/// }
402/// let server_trait: Arc<dyn WebSocketServerTrait> = Arc::new(Server{});
403/// let server = server_trait.downcast_arc::<Server>();
404/// ```
405/// This can help simplify web socket handling in case the supplied
406/// `T` generic contains complex generic types that typically
407/// results in generics propagating up into the ownership type chain.
408///
409/// This trait is used in the [`workflow-rpc`](https://docs.rs/workflow-rpc)
410/// crate to isolate `RpcHandler` generics from the RpcServer owning the WebSocket.
411///
412#[async_trait]
413pub trait WebSocketServerTrait: DowncastSync {
414    async fn bind(self: Arc<Self>, addr: &str) -> Result<TcpListener>;
415    async fn listen(
416        self: Arc<Self>,
417        listener: TcpListener,
418        config: Option<WebSocketConfig>,
419    ) -> Result<()>;
420    fn stop(&self) -> Result<()>;
421    async fn join(&self) -> Result<()>;
422    async fn stop_and_join(&self) -> Result<()>;
423}
424impl_downcast!(sync WebSocketServerTrait);
425
426#[async_trait]
427impl<T> WebSocketServerTrait for WebSocketServer<T>
428where
429    T: WebSocketHandler + Send + Sync + 'static + Sized,
430{
431    async fn bind(self: Arc<Self>, addr: &str) -> Result<TcpListener> {
432        WebSocketServer::<T>::bind(&self, addr).await
433    }
434
435    async fn listen(
436        self: Arc<Self>,
437        listener: TcpListener,
438        config: Option<WebSocketConfig>,
439    ) -> Result<()> {
440        WebSocketServer::<T>::listen(&self, listener, config).await
441    }
442
443    fn stop(&self) -> Result<()> {
444        WebSocketServer::<T>::stop(self)
445    }
446
447    async fn join(&self) -> Result<()> {
448        WebSocketServer::<T>::join(self).await
449    }
450
451    async fn stop_and_join(&self) -> Result<()> {
452        WebSocketServer::<T>::stop_and_join(self).await
453    }
454}
455
456pub mod handshake {
457    //!
458    //! Module containing simple convenience handshake functions
459    //! such as `greeting()`
460    //!     
461
462    use super::*;
463
464    /// Handshake closure function type for [`greeting()`] handshake
465    pub type HandshakeFn = Pin<Box<dyn Send + Sync + Fn(&str) -> Result<()>>>;
466
467    /// Simple greeting handshake where supplied closure receives
468    /// the first message from the client and should return
469    /// `Ok(())` to proceed or [`Error`] to abort the connection.
470    pub async fn greeting<'ws>(
471        timeout_duration: Duration,
472        _sender: &'ws mut WebSocketSender,
473        receiver: &'ws mut WebSocketReceiver,
474        handler: HandshakeFn,
475    ) -> Result<()> {
476        let delay = tokio::time::sleep(timeout_duration);
477        tokio::select! {
478            msg = receiver.next() => {
479                if let Some(Ok(msg)) = msg {
480                    if msg.is_text() || msg.is_binary() {
481                        return handler(msg.to_text()?);
482                    }
483                }
484                Err(Error::MalformedHandshake)
485            }
486            _ = delay => {
487                Err(Error::ConnectionTimeout)
488            }
489        }
490    }
491}