1use bytes::Bytes;
6use std::task::{Context, Poll};
7
8#[derive(Clone, PartialEq, Eq)]
10pub enum Message {
11 Binary(Bytes),
13 Ping,
15 Pong,
17 Close,
19}
20
21impl std::fmt::Debug for Message {
22 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
23 match self {
24 Self::Binary(data) => f.debug_struct("Binary").field("len", &data.len()).finish(),
25 Self::Ping => f.debug_struct("Ping").finish(),
26 Self::Pong => f.debug_struct("Pong").finish(),
27 Self::Close => f.debug_struct("Close").finish(),
28 }
29 }
30}
31
32pub trait WebSocket: Send + 'static {
37 fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
42 fn start_send_unpin(&mut self, item: Message) -> Result<(), crate::Error>;
47 fn poll_flush_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
52 fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>>;
58 fn poll_next_unpin(
63 &mut self,
64 cx: &mut Context<'_>,
65 ) -> Poll<Option<Result<Message, crate::Error>>>;
66}
67
68#[cfg(feature = "tungstenite")]
69mod tokio_tungstenite {
70 use std::{
71 pin::Pin,
72 task::{Context, Poll},
73 };
74
75 use bytes::Bytes;
76 use futures_util::{Sink, Stream};
77 use tokio_tungstenite::tungstenite;
78 use tracing::error;
79
80 use super::{Message, WebSocket};
81 impl From<tungstenite::Message> for Message {
82 #[inline]
83 fn from(msg: tungstenite::Message) -> Self {
84 match msg {
85 tungstenite::Message::Binary(data) => Self::Binary(data),
86 tungstenite::Message::Text(text) => {
87 error!("Received text message: {text}");
88 Self::Binary(Bytes::from(text))
89 }
90 tungstenite::Message::Ping(_) => Self::Ping,
91 tungstenite::Message::Pong(_) => Self::Pong,
92 tungstenite::Message::Close(_) => Self::Close,
93 tungstenite::Message::Frame(_) => {
94 unreachable!("`Frame` message should not be received")
95 }
96 }
97 }
98 }
99
100 impl From<Message> for tungstenite::Message {
101 #[inline]
102 fn from(msg: Message) -> Self {
103 match msg {
104 Message::Binary(data) => Self::Binary(data),
105 Message::Ping => Self::Ping(Bytes::new()),
106 Message::Pong => Self::Pong(Bytes::new()),
107 Message::Close => Self::Close(None),
108 }
109 }
110 }
111
112 impl From<tungstenite::Error> for crate::Error {
113 #[inline]
114 fn from(e: tungstenite::Error) -> Self {
115 Self::WebSocket(Box::new(e))
116 }
117 }
118
119 impl<RW> WebSocket for tokio_tungstenite::WebSocketStream<RW>
120 where
121 RW: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin + Send + 'static,
122 {
123 #[inline]
124 fn poll_ready_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
125 Pin::new(self).poll_ready(cx).map_err(Into::into)
126 }
127
128 #[inline]
129 fn start_send_unpin(&mut self, item: Message) -> Result<(), crate::Error> {
130 let item: tungstenite::Message = item.into();
131 Pin::new(self).start_send(item).map_err(Into::into)
132 }
133
134 #[inline]
135 fn poll_flush_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
136 Pin::new(self).poll_flush(cx).map_err(Into::into)
137 }
138
139 #[inline]
140 fn poll_close_unpin(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), crate::Error>> {
141 let this = Pin::new(self);
142 futures_util::Sink::poll_close(this, cx).map_err(Into::into)
143 }
144
145 #[inline]
146 fn poll_next_unpin(
147 &mut self,
148 cx: &mut Context<'_>,
149 ) -> Poll<Option<Result<Message, crate::Error>>> {
150 Pin::new(self)
151 .poll_next(cx)
152 .map(|opt| opt.map(|res| res.map(Into::into).map_err(Into::into)))
153 }
154 }
155
156 #[cfg(test)]
157 mod tests {
158 use super::*;
159 use tokio_tungstenite::tungstenite::protocol::frame::coding::CloseCode;
160
161 #[test]
162 fn test_binary_message() {
163 let msg = tungstenite::Message::Binary(Bytes::from_static(b"Hello"));
164 let converted: Message = msg.clone().into();
165 assert_eq!(converted, Message::Binary(Bytes::from_static(b"Hello")));
166 assert_eq!(tungstenite::Message::from(converted), msg);
167 }
168
169 #[test]
170 fn test_text_message() {
171 let msg = tungstenite::Message::Text("Hello".into());
172 let converted: Message = msg.into();
173 assert_eq!(converted, Message::Binary(Bytes::from_static(b"Hello")));
174 assert_eq!(
175 tungstenite::Message::from(converted),
176 tungstenite::Message::Binary(Bytes::from_static(b"Hello"))
177 );
178 }
179
180 #[test]
181 fn test_ping_message() {
182 let msg = tungstenite::Message::Ping(Bytes::from_static(b"Ping"));
183 let converted: Message = msg.into();
184 assert_eq!(converted, Message::Ping);
185 assert_eq!(
186 tungstenite::Message::from(converted),
187 tungstenite::Message::Ping(Bytes::new())
188 );
189
190 let msg = tungstenite::Message::Pong(Bytes::from_static(b"Pong"));
191 let converted: Message = msg.into();
192 assert_eq!(converted, Message::Pong);
193 assert_eq!(
194 tungstenite::Message::from(converted),
195 tungstenite::Message::Pong(Bytes::new())
196 );
197 }
198
199 #[test]
200 fn test_close_message() {
201 let close_msg =
202 tungstenite::Message::Close(Some(tungstenite::protocol::frame::CloseFrame {
203 code: CloseCode::Reserved(1000),
204 reason: "Normal".into(),
205 }));
206 let converted: Message = close_msg.into();
207 assert_eq!(converted, Message::Close);
208 }
209 }
210}