penguin_mux/
ws.rs

1//! Generic WebSocket
2//
3// SPDX-License-Identifier: Apache-2.0 OR GPL-3.0-or-later
4
5use bytes::Bytes;
6use std::task::{Context, Poll};
7
8/// Types of messages we need
9#[derive(Clone, PartialEq, Eq)]
10pub enum Message {
11    /// Binary message or any payload
12    Binary(Bytes),
13    /// Ping message. Note that the payload is discarded.
14    Ping,
15    /// Pong message. Note that the payload is discarded.
16    Pong,
17    /// Close message. Note that the payload is discarded.
18    Close,
19}
20
21impl std::fmt::Debug for Message {
22    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23        match self {
24            Self::Binary(data) => f.debug_struct("Binary").field("len", &data.len()).finish(),
25            Self::Ping => f.debug_struct("Ping").finish(),
26            Self::Pong => f.debug_struct("Pong").finish(),
27            Self::Close => f.debug_struct("Close").finish(),
28        }
29    }
30}
31
32/// A generic WebSocket stream
33///
34/// Specialized for our [`Message`] type similar to [`futures_util::Stream`] and [`futures_util::Sink`].
35/// See [`futures_util::Stream`] and [`futures_util::Sink`] for more details on the required methods.
36pub trait WebSocket: Send + 'static {
37    /// Attempt to prepare the `Sink` to receive a value.
38    ///
39    /// # Errors
40    /// Indicates the underlying sink is permanently be unable to receive items.
41    fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
42    /// Begin the process of sending a value to the sink.
43    ///
44    /// # Errors
45    /// Indicates the underlying sink is permanently be unable to receive items.
46    fn start_send_unpin(&mut self, item: Message) -> Result<(), crate::Error>;
47    /// Flush any remaining output from this sink.
48    ///
49    /// # Errors
50    /// Indicates the underlying sink is permanently be unable to receive items.
51    fn poll_flush_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
52    /// Flush any remaining output and close this sink, if necessary.
53    ///
54    /// # Errors
55    /// Indicates the underlying sink is unable to be closed properly but is nonetheless
56    /// permanently be unable to receive items.
57    fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
58    /// Attempt to pull out the next value of this stream.
59    ///
60    /// # Errors
61    /// Indicates the underlying stream is otherwise unable to produce items.
62    fn poll_next_unpin(
63        &mut self,
64        cx: &mut Context<'_>,
65    ) -> Poll<Option<Result<Message, crate::Error>>>;
66}
67
#[cfg(feature = "tungstenite")]
mod tokio_tungstenite {
    //! `WebSocket` implementation backed by `tokio_tungstenite::WebSocketStream`,
    //! plus conversions between `tungstenite` types and this crate's types.

    use std::{
        pin::Pin,
        task::{Context, Poll},
    };

    use bytes::Bytes;
    use futures_util::{Sink, Stream};
    use tokio_tungstenite::tungstenite;
    use tracing::error;

    use super::{Message, WebSocket};

    impl From<tungstenite::Message> for Message {
        /// Converts a message received from `tungstenite` into our [`Message`].
        ///
        /// Ping/Pong/Close payloads are dropped. A text message is not
        /// expected; it is logged and passed on as binary data.
        ///
        /// # Panics
        /// On `tungstenite::Message::Frame`, which `tungstenite` documents as
        /// never being produced while reading messages.
        #[inline]
        fn from(msg: tungstenite::Message) -> Self {
            match msg {
                tungstenite::Message::Binary(data) => Self::Binary(data),
                tungstenite::Message::Text(text) => {
                    // NOTE(review): this logs the full, peer-controlled text
                    // payload at error level; consider logging only its length,
                    // as `Message`'s `Debug` impl does for binary payloads.
                    error!("Received text message: {text}");
                    Self::Binary(Bytes::from(text))
                }
                tungstenite::Message::Ping(_) => Self::Ping,
                tungstenite::Message::Pong(_) => Self::Pong,
                tungstenite::Message::Close(_) => Self::Close,
                tungstenite::Message::Frame(_) => {
                    unreachable!("`Frame` message should not be received")
                }
            }
        }
    }

    impl From<Message> for tungstenite::Message {
        /// Converts our [`Message`] into a `tungstenite` message for sending.
        ///
        /// Control messages are sent with an empty payload, and `Close` is
        /// sent without a close frame (no code or reason).
        #[inline]
        fn from(msg: Message) -> Self {
            match msg {
                Message::Binary(data) => Self::Binary(data),
                Message::Ping => Self::Ping(Bytes::new()),
                Message::Pong => Self::Pong(Bytes::new()),
                Message::Close => Self::Close(None),
            }
        }
    }

    impl From<tungstenite::Error> for crate::Error {
        /// Wraps a `tungstenite` error in [`crate::Error::WebSocket`].
        #[inline]
        fn from(e: tungstenite::Error) -> Self {
            Self::WebSocket(Box::new(e))
        }
    }

    // Each method delegates to the `Sink`/`Stream` impl of `WebSocketStream`,
    // converting messages and errors at the boundary. `Pin::new` requires
    // `WebSocketStream<RW>: Unpin`, which the `RW: Unpin` bound provides.
    impl<RW> WebSocket for tokio_tungstenite::WebSocketStream<RW>
    where
        RW: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
    {
        #[inline]
        fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
            Pin::new(self).poll_ready(cx).map_err(Into::into)
        }

        #[inline]
        fn start_send_unpin(&mut self, item: Message) -> Result<(), crate::Error> {
            let item: tungstenite::Message = item.into();
            Pin::new(self).start_send(item).map_err(Into::into)
        }

        #[inline]
        fn poll_flush_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
            Pin::new(self).poll_flush(cx).map_err(Into::into)
        }

        #[inline]
        fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
            let this = Pin::new(self);
            futures_util::Sink::poll_close(this, cx).map_err(Into::into)
        }

        #[inline]
        fn poll_next_unpin(
            &mut self,
            cx: &mut Context<'_>,
        ) -> Poll<Option<Result<Message, crate::Error>>> {
            Pin::new(self)
                .poll_next(cx)
                .map(|opt| opt.map(|res| res.map(Into::into).map_err(Into::into)))
        }
    }

    #[cfg(test)]
    mod tests {
        use super::*;
        use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;

        #[test]
        fn test_binary_message() {
            let msg = tungstenite::Message::Binary(Bytes::from_static(b"Hello"));
            let converted: Message = msg.clone().into();
            assert_eq!(converted, Message::Binary(Bytes::from_static(b"Hello")));
            assert_eq!(tungstenite::Message::from(converted), msg);
        }

        #[test]
        fn test_text_message() {
            let msg = tungstenite::Message::Text("Hello".into());
            let converted: Message = msg.into();
            assert_eq!(converted, Message::Binary(Bytes::from_static(b"Hello")));
            assert_eq!(
                tungstenite::Message::from(converted),
                tungstenite::Message::Binary(Bytes::from_static(b"Hello"))
            );
        }

        #[test]
        fn test_ping_message() {
            let msg = tungstenite::Message::Ping(Bytes::from_static(b"Ping"));
            let converted: Message = msg.into();
            assert_eq!(converted, Message::Ping);
            assert_eq!(
                tungstenite::Message::from(converted),
                tungstenite::Message::Ping(Bytes::new())
            );
        }

        #[test]
        fn test_pong_message() {
            let msg = tungstenite::Message::Pong(Bytes::from_static(b"Pong"));
            let converted: Message = msg.into();
            assert_eq!(converted, Message::Pong);
            assert_eq!(
                tungstenite::Message::from(converted),
                tungstenite::Message::Pong(Bytes::new())
            );
        }

        #[test]
        fn test_close_message() {
            // 1000 is the "normal closure" code; use the dedicated variant
            // rather than `Reserved(1000)`.
            let close_msg =
                tungstenite::Message::Close(Some(tungstenite::protocol::frame::CloseFrame {
                    code: CloseCode::Normal,
                    reason: "Normal".into(),
                }));
            let converted: Message = close_msg.into();
            assert_eq!(converted, Message::Close);
            // Outgoing close carries no frame: code and reason are dropped.
            assert_eq!(
                tungstenite::Message::from(converted),
                tungstenite::Message::Close(None)
            );
        }
    }
}