tokio_tungstenite_keepalive/
lib.rs

1#![warn(clippy::pedantic)]
2
3use std::sync::Arc;
4use std::marker::PhantomData;
5use once_cell::sync::OnceCell;
6use std::{pin::Pin, task::Poll};
7
8use futures::{StreamExt, SinkExt, FutureExt};
9use tokio_tungstenite::{WebSocketStream, tungstenite, tungstenite::Message};
10
11use tokio::{
12    io::{AsyncRead, AsyncWrite},
13    sync::mpsc::{UnboundedSender as Sender, UnboundedReceiver as Receiver},
14};
15
16#[pin_project::pin_project]
17pub struct KeptAliveWebSocket<S> {
18    #[pin]
19    next_chan: Receiver<Message>,
20    send_chan: Sender<Message>,
21    err_cell: Arc<OnceCell<tungstenite::Error>>,
22
23    phantom: PhantomData<S>
24}
25
26impl<S> KeptAliveWebSocket<S>
27where
28    S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
29    <WebSocketStream<S> as futures::Stream>::Item: Is<Type=Result<Message, tungstenite::Error>> + Send,
30    WebSocketStream<S>: futures::Stream,
31{
32    /// Wraps a websocket with to handle Ping messages as soon as they are recieved, while buffering
33    /// other messages to be recieved when consumed as a [`Stream`].
34    pub fn new(mut websocket: WebSocketStream<S>) -> Self {
35        let (mut next_chan_send, next_chan_recv) = tokio::sync::mpsc::unbounded_channel();
36        let (send_chan_send, mut send_chan_recv) = tokio::sync::mpsc::unbounded_channel();
37
38        let err_cell = Arc::new(OnceCell::new());
39        let err_cell_clone = err_cell.clone();
40
41        tokio::spawn(async move {
42            if let Err(err) = Self::handle_msgs(&mut websocket, &mut next_chan_send, &mut send_chan_recv).await {
43                err_cell_clone.set(err).expect("Error has been set before!");
44            }
45        });
46
47        Self {
48            next_chan: next_chan_recv,
49            send_chan: send_chan_send,
50            err_cell,
51
52            phantom: PhantomData
53        }
54    }
55
56    async fn handle_msgs(ws: &mut WebSocketStream<S>, next_chan: &mut Sender<Message>, send_chan: &mut Receiver<Message>) -> Result<(), tungstenite::Error> {
57        loop {
58            futures::select! {
59                ws_msg = ws.next() => {
60                    let ws_msg = if let Some(msg) = ws_msg {
61                        narrow(msg)?
62                    } else {
63                        return Ok(());
64                    };
65
66                    if next_chan.send(ws_msg).is_err() {
67                        return Ok(())
68                    }
69                },
70                to_send = send_chan.recv().fuse() => {
71                    if let Some(to_send) = to_send {
72                        ws.send(to_send).await?;
73                    } else{
74                        return Ok(())
75                    }
76                }
77            }
78        }
79    }
80}
81
82impl<S> KeptAliveWebSocket<S> {
83    /// Sends a message to the websocket, without waiting until it has been sent.
84    ///
85    /// # Errors
86    /// This errors if the websocket has returned an error previously, and this
87    /// [`KeptAliveWebSocket`] has been poisoned.
88    pub fn send(&self, msg: Message) -> Result<(), &tungstenite::Error> {
89        if let Some(err) = self.err_cell.get() {
90            return Err(err)
91        }
92
93        self.send_chan.send(msg).expect("Background task has been closed!");
94        Ok(())
95    }
96
97    /// Returns the current poison error, if the websocket has failed to send a message.
98    ///
99    /// If this is [`Some`], [`KeptAliveWebSocket::send`] and [`Stream`] methods will error
100    /// or be a no-op.
101    #[must_use]
102    pub fn poison(&self) -> Option<&tungstenite::Error> {
103        self.err_cell.get()
104    }
105}
106
107impl<S> futures::Stream for KeptAliveWebSocket<S> {
108    type Item = Message;
109
110    fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
111        if self.err_cell.get().is_some() {
112            Poll::Ready(None)
113        } else {
114            self.project().next_chan.poll_recv(cx)
115        }
116    }
117}
118
/// Type-equality helper: `T: Is<Type = U>` holds exactly when `T` and `U` are the same
/// type.
///
/// It is used to constrain the item type of a [`WebSocketStream`] to
/// `Result<Message, tungstenite::Error>` in a `where` clause.
pub trait Is {
    /// The type `Self` is equal to.
    type Type;

    /// Converts `self` into [`Is::Type`]; with the blanket implementation this is the
    /// identity.
    fn into(self) -> Self::Type;
}

/// The only implementation: every type is equal to itself.
impl<T> Is for T {
    type Type = T;

    fn into(self) -> Self::Type {
        self
    }
}

/// Converts a value whose type is only known through an [`Is`] bound into the concrete
/// type named by that bound.
fn narrow<T: Is<Type=U>, U>(t: T) -> U {
    // Fully qualified so the call cannot be mistaken for `std::convert::Into::into`.
    <T as Is>::into(t)
}