tokio_tungstenite_keepalive/
lib.rs1#![warn(clippy::pedantic)]
2
3use std::sync::Arc;
4use std::marker::PhantomData;
5use once_cell::sync::OnceCell;
6use std::{pin::Pin, task::Poll};
7
8use futures::{StreamExt, SinkExt, FutureExt};
9use tokio_tungstenite::{WebSocketStream, tungstenite, tungstenite::Message};
10
11use tokio::{
12 io::{AsyncRead, AsyncWrite},
13 sync::mpsc::{UnboundedSender as Sender, UnboundedReceiver as Receiver},
14};
15
16#[pin_project::pin_project]
17pub struct KeptAliveWebSocket<S> {
18 #[pin]
19 next_chan: Receiver<Message>,
20 send_chan: Sender<Message>,
21 err_cell: Arc<OnceCell<tungstenite::Error>>,
22
23 phantom: PhantomData<S>
24}
25
26impl<S> KeptAliveWebSocket<S>
27where
28 S: AsyncRead + AsyncWrite + Unpin + Send + Sync + 'static,
29 <WebSocketStream<S> as futures::Stream>::Item: Is<Type=Result<Message, tungstenite::Error>> + Send,
30 WebSocketStream<S>: futures::Stream,
31{
32 pub fn new(mut websocket: WebSocketStream<S>) -> Self {
35 let (mut next_chan_send, next_chan_recv) = tokio::sync::mpsc::unbounded_channel();
36 let (send_chan_send, mut send_chan_recv) = tokio::sync::mpsc::unbounded_channel();
37
38 let err_cell = Arc::new(OnceCell::new());
39 let err_cell_clone = err_cell.clone();
40
41 tokio::spawn(async move {
42 if let Err(err) = Self::handle_msgs(&mut websocket, &mut next_chan_send, &mut send_chan_recv).await {
43 err_cell_clone.set(err).expect("Error has been set before!");
44 }
45 });
46
47 Self {
48 next_chan: next_chan_recv,
49 send_chan: send_chan_send,
50 err_cell,
51
52 phantom: PhantomData
53 }
54 }
55
56 async fn handle_msgs(ws: &mut WebSocketStream<S>, next_chan: &mut Sender<Message>, send_chan: &mut Receiver<Message>) -> Result<(), tungstenite::Error> {
57 loop {
58 futures::select! {
59 ws_msg = ws.next() => {
60 let ws_msg = if let Some(msg) = ws_msg {
61 narrow(msg)?
62 } else {
63 return Ok(());
64 };
65
66 if next_chan.send(ws_msg).is_err() {
67 return Ok(())
68 }
69 },
70 to_send = send_chan.recv().fuse() => {
71 if let Some(to_send) = to_send {
72 ws.send(to_send).await?;
73 } else{
74 return Ok(())
75 }
76 }
77 }
78 }
79 }
80}
81
82impl<S> KeptAliveWebSocket<S> {
83 pub fn send(&self, msg: Message) -> Result<(), &tungstenite::Error> {
89 if let Some(err) = self.err_cell.get() {
90 return Err(err)
91 }
92
93 self.send_chan.send(msg).expect("Background task has been closed!");
94 Ok(())
95 }
96
97 #[must_use]
102 pub fn poison(&self) -> Option<&tungstenite::Error> {
103 self.err_cell.get()
104 }
105}
106
107impl<S> futures::Stream for KeptAliveWebSocket<S> {
108 type Item = Message;
109
110 fn poll_next(self: Pin<&mut Self>, cx: &mut std::task::Context<'_>) -> Poll<Option<Self::Item>> {
111 if self.err_cell.get().is_some() {
112 Poll::Ready(None)
113 } else {
114 self.project().next_chan.poll_recv(cx)
115 }
116 }
117}
118
/// Relates a type to the type named by its associated [`Is::Type`].
///
/// Used in this file as a bound of the form `X: Is<Type = Y>`, together with
/// the blanket implementation that maps every type to itself.
pub trait Is {
    /// The type `Self` is converted into.
    type Type;

    /// Consumes `self` and returns it as [`Is::Type`].
    fn into(self) -> Self::Type;
}
124
125impl<T> Is for T {
126 type Type = T;
127 fn into(self) -> Self::Type {
128 self
129 }
130}
131
132fn narrow<T: Is<Type=U>, U>(t: T) -> U {
133 t.into()
134}