1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod error;
4#[cfg(feature = "mocking")]
5mod mock_server;
6#[cfg(feature = "mocking")]
7pub use mock_server::{EchoControlMessage, echo_server, get_mock_address};
8
9use bytes::Bytes;
10pub use error::Error;
11use futures::{SinkExt, StreamExt, stream::SplitSink};
12use serde::{Deserialize, Serialize};
13use std::{fmt::Debug, time::Duration};
14use tokio::{
15 net::TcpStream,
16 select,
17 sync::{mpsc, oneshot},
18 time::sleep,
19};
20use tokio_tungstenite::{
21 MaybeTlsStream, WebSocketStream, connect_async,
22 tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
23};
24
25#[cfg(feature = "tracing")]
26use tracing::{debug, error, info, instrument, trace};
27use url::Url;
28
29#[derive(Debug)]
30struct TxChannelPayload {
31 message: Message,
32 response_tx: oneshot::Sender<Result<(), Error>>,
33}
34
35#[derive(Debug)]
43pub struct Socketeer<RxMessage, TxMessage, const CHANNEL_SIZE: usize = 4> {
44 url: Url,
45 receiever: mpsc::Receiver<Message>,
46 sender: mpsc::Sender<TxChannelPayload>,
47 socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
48 _rx_message: std::marker::PhantomData<RxMessage>,
49 _tx_message: std::marker::PhantomData<TxMessage>,
50}
51
52impl<
53 RxMessage: for<'a> Deserialize<'a> + Debug,
54 TxMessage: Serialize + Debug,
55 const CHANNEL_SIZE: usize,
56> Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>
57{
58 #[cfg_attr(feature = "tracing", instrument)]
64 pub async fn connect(
65 url: &str,
66 ) -> Result<Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>, Error> {
67 let url = Url::parse(url).map_err(|source| Error::UrlParse {
68 url: url.to_string(),
69 source,
70 })?;
71 #[allow(unused_variables)]
72 let (socket, response) = connect_async(url.as_str()).await?;
73 #[cfg(feature = "tracing")]
74 debug!("Connection Successful, connection info: \n{:#?}", response);
75
76 let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
77 let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
78
79 let socket_handle = tokio::spawn(async move { socket_loop(tx_rx, rx_tx, socket).await });
80 Ok(Socketeer {
81 url,
82 receiever: rx_rx,
83 sender: tx_tx,
84 socket_handle,
85 _rx_message: std::marker::PhantomData,
86 _tx_message: std::marker::PhantomData,
87 })
88 }
89
90 #[cfg_attr(feature = "tracing", instrument)]
97 pub async fn next_message(&mut self) -> Result<RxMessage, Error> {
98 let Some(message) = self.receiever.recv().await else {
99 return Err(Error::WebsocketClosed);
100 };
101 match message {
102 Message::Text(text) => {
103 #[cfg(feature = "tracing")]
104 trace!("Received text message: {:?}", text);
105 let message = serde_json::from_str(&text)?;
106 Ok(message)
107 }
108 Message::Binary(message) => {
109 #[cfg(feature = "tracing")]
110 trace!("Received binary message: {:?}", message);
111 let message = serde_json::from_slice(&message)?;
112 Ok(message)
113 }
114 _ => Err(Error::UnexpectedMessageType(message)),
115 }
116 }
117
118 #[cfg_attr(feature = "tracing", instrument)]
126 pub async fn send(&self, message: TxMessage) -> Result<(), Error> {
127 #[cfg(feature = "tracing")]
128 trace!("Sending message: {:?}", message);
129
130 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
131 let message = serde_json::to_string(&message)?;
132
133 self.sender
134 .send(TxChannelPayload {
135 message: Message::Text(message.into()),
136 response_tx: tx,
137 })
138 .await
139 .map_err(|_| Error::WebsocketClosed)?;
140 match rx.await {
142 Ok(result) => result,
143 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
144 }
145 }
146
147 #[cfg_attr(feature = "tracing", instrument)]
154 pub async fn reconnect(self) -> Result<Self, Error> {
155 let url = self.url.as_str().to_owned();
156 #[cfg(feature = "tracing")]
157 info!("Reconnecting");
158 match self.close_connection().await {
159 Ok(()) => (),
160 #[allow(unused_variables)]
161 Err(e) => {
162 #[cfg(feature = "tracing")]
163 error!("Socket Loop already stopped: {}", e);
164 }
165 }
166 Self::connect(&url).await
167 }
168
169 #[cfg_attr(feature = "tracing", instrument)]
175 pub async fn close_connection(self) -> Result<(), Error> {
176 #[cfg(feature = "tracing")]
177 debug!("Closing Connection");
178 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
179 self.sender
180 .send(TxChannelPayload {
181 message: Message::Close(Some(CloseFrame {
182 code: tungstenite::protocol::frame::coding::CloseCode::Normal,
183 reason: Utf8Bytes::from_static("Closing Connection"),
184 })),
185 response_tx: tx,
186 })
187 .await
188 .map_err(|_| Error::WebsocketClosed)?;
189 match rx.await {
190 Ok(result) => result,
191 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
192 }?;
193 match self.socket_handle.await {
194 Ok(result) => result,
195 Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
196 }
197 }
198}
199
200pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
201type SocketSink = SplitSink<WebSocketStreamType, Message>;
202
203enum LoopState {
204 Running,
205 Error(Error),
206 Closed,
207}
208
209#[cfg_attr(feature = "tracing", instrument)]
210async fn socket_loop(
211 mut receiver: mpsc::Receiver<TxChannelPayload>,
212 mut sender: mpsc::Sender<Message>,
213 socket: WebSocketStreamType,
214) -> Result<(), Error> {
215 let mut state = LoopState::Running;
216 let (mut sink, mut stream) = socket.split();
217 while matches!(state, LoopState::Running) {
218 state = select! {
219 outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
220 incoming_message = stream.next() => socket_message_received( incoming_message,&mut sender, &mut sink).await,
221 () = sleep(Duration::from_secs(2)) => send_ping(&mut sink).await,
222 };
223 }
224 match state {
225 LoopState::Error(e) => Err(e),
226 LoopState::Closed => Ok(()),
227 LoopState::Running => unreachable!("We only exit when closed or errored"),
228 }
229}
230
231#[cfg_attr(feature = "tracing", instrument)]
232async fn send_socket_message(
233 message: Option<TxChannelPayload>,
234 sink: &mut SocketSink,
235) -> LoopState {
236 if let Some(message) = message {
237 #[cfg(feature = "tracing")]
238 debug!("Sending message: {:?}", message);
239 let send_result = sink.send(message.message).await.map_err(Error::from);
240 let socket_error = send_result.is_err();
241 match message.response_tx.send(send_result) {
242 Ok(()) => {
243 if socket_error {
244 LoopState::Error(Error::WebsocketClosed)
245 } else {
246 LoopState::Running
247 }
248 }
249 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
250 }
251 } else {
252 #[cfg(feature = "tracing")]
253 error!("Socketeer dropped without closing connection");
254 LoopState::Error(Error::SocketeerDroppedWithoutClosing)
255 }
256}
257
258#[cfg_attr(feature = "tracing", instrument)]
259async fn socket_message_received(
260 message: Option<Result<Message, tungstenite::Error>>,
261 sender: &mut mpsc::Sender<Message>,
262 sink: &mut SocketSink,
263) -> LoopState {
264 const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
265 match message {
266 Some(Ok(message)) => match message {
267 Message::Ping(_) => {
268 let send_result = sink
269 .send(Message::Pong(PONG_BYTES))
270 .await
271 .map_err(Error::from);
272 match send_result {
273 Ok(()) => LoopState::Running,
274 Err(e) => {
275 #[cfg(feature = "tracing")]
276 error!("Error sending Pong: {:?}", e);
277 LoopState::Error(e)
278 }
279 }
280 }
281 Message::Close(_) => {
282 let close_result = sink.close().await;
283 match close_result {
284 Ok(()) => LoopState::Closed,
285 Err(e) => {
286 #[cfg(feature = "tracing")]
287 error!("Error sending Close: {:?}", e);
288 LoopState::Error(Error::from(e))
289 }
290 }
291 }
292 Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
293 Ok(()) => LoopState::Running,
294 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
295 },
296 _ => LoopState::Running,
297 },
298 Some(Err(e)) => {
299 #[cfg(feature = "tracing")]
300 error!("Error receiving message: {:?}", e);
301 LoopState::Error(Error::WebsocketError(e))
302 }
303 None => {
304 #[cfg(feature = "tracing")]
305 info!("Websocket Closed, closing rx channel");
306 LoopState::Error(Error::WebsocketClosed)
307 }
308 }
309}
310
311#[cfg_attr(feature = "tracing", instrument)]
312async fn send_ping(sink: &mut SocketSink) -> LoopState {
313 #[cfg(feature = "tracing")]
314 info!("Timeout waiting for message, sending Ping");
315 let result = sink
316 .send(Message::Ping(Bytes::new()))
317 .await
318 .map_err(Error::from);
319 match result {
320 Ok(()) => LoopState::Running,
321 Err(e) => {
322 #[cfg(feature = "tracing")]
323 error!("Error sending Ping: {:?}", e);
324 LoopState::Error(e)
325 }
326 }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Client type used by every test: `EchoControlMessage` in both directions,
    /// default channel size.
    type EchoSocketeer = Socketeer<EchoControlMessage, EchoControlMessage>;

    /// Starts a fresh mock echo server and connects a client to it, panicking on failure.
    async fn connect_to_echo_server() -> EchoSocketeer {
        let server_address = get_mock_address(echo_server).await;
        EchoSocketeer::connect(&format!("ws://{server_address}"))
            .await
            .unwrap()
    }

    /// Sends `message` and asserts that the echo server returns it unchanged.
    async fn assert_round_trip(socketeer: &mut EchoSocketeer, message: EchoControlMessage) {
        socketeer.send(message.clone()).await.unwrap();
        let echoed = socketeer.next_message().await.unwrap();
        assert_eq!(message, echoed);
    }

    #[tokio::test]
    async fn test_server_startup() {
        let _server_address = get_mock_address(echo_server).await;
    }

    #[tokio::test]
    async fn test_connection() {
        let _socketeer = connect_to_echo_server().await;
    }

    #[tokio::test]
    async fn test_bad_url() {
        let result = EchoSocketeer::connect("Not a URL").await;
        assert!(matches!(result.unwrap_err(), Error::UrlParse { .. }));
    }

    #[tokio::test]
    async fn test_send_receive() {
        let mut socketeer = connect_to_echo_server().await;
        assert_round_trip(&mut socketeer, EchoControlMessage::Message("Hello".to_string())).await;
    }

    #[tokio::test]
    async fn test_ping_request() {
        let mut socketeer = connect_to_echo_server().await;
        socketeer.send(EchoControlMessage::SendPing).await.unwrap();
        assert_round_trip(&mut socketeer, EchoControlMessage::Message("Hello".to_string())).await;
        // Stay idle past the socket loop's 2 s ping interval so the client pings as well.
        sleep(Duration::from_millis(2200)).await;
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_reconnection() {
        let mut socketeer = connect_to_echo_server().await;
        assert_round_trip(&mut socketeer, EchoControlMessage::Message("Hello".to_string())).await;
        socketeer = socketeer.reconnect().await.unwrap();
        assert_round_trip(&mut socketeer, EchoControlMessage::Message("Hello".to_string())).await;
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_closed_socket() {
        let mut socketeer = connect_to_echo_server().await;
        let close_request = EchoControlMessage::Close;
        socketeer.send(close_request.clone()).await.unwrap();
        let response = socketeer.next_message().await;
        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
        let error = socketeer.send(close_request).await.unwrap_err();
        println!("Actual Error: {error:#?}");
        assert!(matches!(error, Error::WebsocketClosed));
    }

    #[tokio::test]
    async fn test_close_request() {
        let socketeer = connect_to_echo_server().await;
        socketeer.close_connection().await.unwrap();
    }
}