socketeer/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod error;
4#[cfg(feature = "mocking")]
5mod mock_server;
6#[cfg(feature = "mocking")]
7pub use mock_server::{EchoControlMessage, echo_server, get_mock_address};
8
9use bytes::Bytes;
10pub use error::Error;
11use futures::{SinkExt, StreamExt, stream::SplitSink};
12use serde::{Deserialize, Serialize};
13use std::{fmt::Debug, time::Duration};
14use tokio::{
15    net::TcpStream,
16    select,
17    sync::{mpsc, oneshot},
18    time::sleep,
19};
20use tokio_tungstenite::{
21    MaybeTlsStream, WebSocketStream, connect_async,
22    tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
23};
24
25#[cfg(feature = "tracing")]
26use tracing::{debug, error, info, instrument, trace};
27use url::Url;
28
29#[derive(Debug)]
30struct TxChannelPayload {
31    message: Message,
32    response_tx: oneshot::Sender<Result<(), Error>>,
33}
34
35/// A WebSocket client that manages the connection to a WebSocket server.
36/// The client can send and receive messages, and will transparently handle protocol messages.
37///
38/// # Type Parameters
39///
40/// - `RxMessage`: The type of message that the client will receive from the server.
41/// - `TxMessage`: The type of message that the client will send to the server.
42/// - `CHANNEL_SIZE`: The size of the internal channels used to communicate between
43///   the task managing the WebSocket connection and the client.
44#[derive(Debug)]
45pub struct Socketeer<RxMessage, TxMessage, const CHANNEL_SIZE: usize = 4> {
46    url: Url,
47    receiever: mpsc::Receiver<Message>,
48    sender: mpsc::Sender<TxChannelPayload>,
49    socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
50    _rx_message: std::marker::PhantomData<RxMessage>,
51    _tx_message: std::marker::PhantomData<TxMessage>,
52}
53
54impl<
55    RxMessage: for<'a> Deserialize<'a> + Debug,
56    TxMessage: Serialize + Debug,
57    const CHANNEL_SIZE: usize,
58> Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>
59{
60    /// Create a `Socketeer` connected to the provided URL.
61    /// Once connected, Socketeer manages the underlying WebSocket connection, transparently handling protocol messages.
62    /// # Errors
63    /// - If the URL cannot be parsed
64    /// - If the WebSocket connection to the requested URL fails
65    #[cfg_attr(feature = "tracing", instrument)]
66    pub async fn connect(
67        url: &str,
68    ) -> Result<Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>, Error> {
69        let url = Url::parse(url).map_err(|source| Error::UrlParse {
70            url: url.to_string(),
71            source,
72        })?;
73        #[allow(unused_variables)]
74        let (socket, response) = connect_async(url.as_str()).await?;
75        #[cfg(feature = "tracing")]
76        debug!("Connection Successful, connection info: \n{:#?}", response);
77
78        let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
79        let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
80
81        let socket_handle = tokio::spawn(async move { socket_loop(tx_rx, rx_tx, socket).await });
82        Ok(Socketeer {
83            url,
84            receiever: rx_rx,
85            sender: tx_tx,
86            socket_handle,
87            _rx_message: std::marker::PhantomData,
88            _tx_message: std::marker::PhantomData,
89        })
90    }
91
92    /// Wait for the next parsed message from the WebSocket connection.
93    ///
94    /// # Errors
95    ///
96    /// - If the WebSocket connection is closed or otherwise errored
97    /// - If the message cannot be deserialized
98    #[cfg_attr(feature = "tracing", instrument)]
99    pub async fn next_message(&mut self) -> Result<RxMessage, Error> {
100        let Some(message) = self.receiever.recv().await else {
101            return Err(Error::WebsocketClosed);
102        };
103        match message {
104            Message::Text(text) => {
105                #[cfg(feature = "tracing")]
106                trace!("Received text message: {:?}", text);
107                let message = serde_json::from_str(&text)?;
108                Ok(message)
109            }
110            Message::Binary(message) => {
111                #[cfg(feature = "tracing")]
112                trace!("Received binary message: {:?}", message);
113                let message = serde_json::from_slice(&message)?;
114                Ok(message)
115            }
116            _ => Err(Error::UnexpectedMessageType(Box::new(message))),
117        }
118    }
119
120    /// Send a message to the WebSocket connection.
121    /// This function will wait for the message to be sent before returning.
122    ///
123    /// # Errors
124    ///
125    /// - If the message cannot be serialized
126    /// - If the WebSocket connection is closed, or otherwise errored
127    #[cfg_attr(feature = "tracing", instrument)]
128    pub async fn send(&self, message: TxMessage) -> Result<(), Error> {
129        #[cfg(feature = "tracing")]
130        trace!("Sending message: {:?}", message);
131
132        let (tx, rx) = oneshot::channel::<Result<(), Error>>();
133        let message = serde_json::to_string(&message)?;
134
135        self.sender
136            .send(TxChannelPayload {
137                message: Message::Text(message.into()),
138                response_tx: tx,
139            })
140            .await
141            .map_err(|_| Error::WebsocketClosed)?;
142        // We'll ensure that we always respond before dropping the tx channel
143        match rx.await {
144            Ok(result) => result,
145            Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
146        }
147    }
148
149    /// Consume self, closing down any remaining send/recieve, and return a new Socketeer instance if successful
150    /// This function attempts to close the connection gracefully before returning,
151    /// but will not return an error if the connection is already closed,
152    /// as its intended use is to re-establish a failed connection.
153    /// # Errors
154    /// - If a new connection cannot be established
155    #[cfg_attr(feature = "tracing", instrument)]
156    pub async fn reconnect(self) -> Result<Self, Error> {
157        let url = self.url.as_str().to_owned();
158        #[cfg(feature = "tracing")]
159        info!("Reconnecting");
160        match self.close_connection().await {
161            Ok(()) => (),
162            #[allow(unused_variables)]
163            Err(e) => {
164                #[cfg(feature = "tracing")]
165                error!("Socket Loop already stopped: {}", e);
166            }
167        }
168        Self::connect(&url).await
169    }
170
171    /// Close the WebSocket connection gracefully.
172    /// This function will wait for the connection to close before returning.
173    /// # Errors
174    /// - If the WebSocket connection is already closed
175    /// - If the WebSocket connection cannot be closed
176    #[cfg_attr(feature = "tracing", instrument)]
177    pub async fn close_connection(self) -> Result<(), Error> {
178        #[cfg(feature = "tracing")]
179        debug!("Closing Connection");
180        let (tx, rx) = oneshot::channel::<Result<(), Error>>();
181        self.sender
182            .send(TxChannelPayload {
183                message: Message::Close(Some(CloseFrame {
184                    code: tungstenite::protocol::frame::coding::CloseCode::Normal,
185                    reason: Utf8Bytes::from_static("Closing Connection"),
186                })),
187                response_tx: tx,
188            })
189            .await
190            .map_err(|_| Error::WebsocketClosed)?;
191        match rx.await {
192            Ok(result) => result,
193            Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
194        }?;
195        match self.socket_handle.await {
196            Ok(result) => result,
197            Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
198        }
199    }
200}
201
202pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
203type SocketSink = SplitSink<WebSocketStreamType, Message>;
204
205enum LoopState {
206    Running,
207    Error(Error),
208    Closed,
209}
210
211#[cfg_attr(feature = "tracing", instrument)]
212async fn socket_loop(
213    mut receiver: mpsc::Receiver<TxChannelPayload>,
214    mut sender: mpsc::Sender<Message>,
215    socket: WebSocketStreamType,
216) -> Result<(), Error> {
217    let mut state = LoopState::Running;
218    let (mut sink, mut stream) = socket.split();
219    while matches!(state, LoopState::Running) {
220        state = select! {
221            outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
222            incoming_message = stream.next() => socket_message_received( incoming_message,&mut sender, &mut sink).await,
223            () = sleep(Duration::from_secs(2)) => send_ping(&mut sink).await,
224        };
225    }
226    match state {
227        LoopState::Error(e) => Err(e),
228        LoopState::Closed => Ok(()),
229        LoopState::Running => unreachable!("We only exit when closed or errored"),
230    }
231}
232
233#[cfg_attr(feature = "tracing", instrument)]
234async fn send_socket_message(
235    message: Option<TxChannelPayload>,
236    sink: &mut SocketSink,
237) -> LoopState {
238    if let Some(message) = message {
239        #[cfg(feature = "tracing")]
240        debug!("Sending message: {:?}", message);
241        let send_result = sink.send(message.message).await.map_err(Error::from);
242        let socket_error = send_result.is_err();
243        match message.response_tx.send(send_result) {
244            Ok(()) => {
245                if socket_error {
246                    LoopState::Error(Error::WebsocketClosed)
247                } else {
248                    LoopState::Running
249                }
250            }
251            Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
252        }
253    } else {
254        #[cfg(feature = "tracing")]
255        error!("Socketeer dropped without closing connection");
256        LoopState::Error(Error::SocketeerDroppedWithoutClosing)
257    }
258}
259
260#[cfg_attr(feature = "tracing", instrument)]
261async fn socket_message_received(
262    message: Option<Result<Message, tungstenite::Error>>,
263    sender: &mut mpsc::Sender<Message>,
264    sink: &mut SocketSink,
265) -> LoopState {
266    const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
267    match message {
268        Some(Ok(message)) => match message {
269            Message::Ping(_) => {
270                let send_result = sink
271                    .send(Message::Pong(PONG_BYTES))
272                    .await
273                    .map_err(Error::from);
274                match send_result {
275                    Ok(()) => LoopState::Running,
276                    Err(e) => {
277                        #[cfg(feature = "tracing")]
278                        error!("Error sending Pong: {:?}", e);
279                        LoopState::Error(e)
280                    }
281                }
282            }
283            Message::Close(_) => {
284                let close_result = sink.close().await;
285                match close_result {
286                    Ok(()) => LoopState::Closed,
287                    Err(e) => {
288                        #[cfg(feature = "tracing")]
289                        error!("Error sending Close: {:?}", e);
290                        LoopState::Error(Error::from(e))
291                    }
292                }
293            }
294            Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
295                Ok(()) => LoopState::Running,
296                Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
297            },
298            _ => LoopState::Running,
299        },
300        Some(Err(e)) => {
301            #[cfg(feature = "tracing")]
302            error!("Error receiving message: {:?}", e);
303            LoopState::Error(Error::WebsocketError(e))
304        }
305        None => {
306            #[cfg(feature = "tracing")]
307            info!("Websocket Closed, closing rx channel");
308            LoopState::Error(Error::WebsocketClosed)
309        }
310    }
311}
312
313#[cfg_attr(feature = "tracing", instrument)]
314async fn send_ping(sink: &mut SocketSink) -> LoopState {
315    #[cfg(feature = "tracing")]
316    info!("Timeout waiting for message, sending Ping");
317    let result = sink
318        .send(Message::Ping(Bytes::new()))
319        .await
320        .map_err(Error::from);
321    match result {
322        Ok(()) => LoopState::Running,
323        Err(e) => {
324            #[cfg(feature = "tracing")]
325            error!("Error sending Ping: {:?}", e);
326            LoopState::Error(e)
327        }
328    }
329}
330
// These tests rely on the mock echo server, which is only compiled with the
// `mocking` feature; gate the module the same way so `cargo test` without that
// feature still builds.
#[cfg(all(test, feature = "mocking"))]
mod tests {
    use super::*;
    use tokio::time::sleep;

    #[tokio::test]
    async fn test_server_startup() {
        let _server_address = get_mock_address(echo_server).await;
    }

    #[tokio::test]
    async fn test_connection() {
        let server_address = get_mock_address(echo_server).await;
        let _socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
    }

    #[tokio::test]
    async fn test_bad_url() {
        let error: Result<Socketeer<EchoControlMessage, EchoControlMessage>, Error> =
            Socketeer::connect("Not a URL").await;
        assert!(matches!(error.unwrap_err(), Error::UrlParse { .. }));
    }

    #[tokio::test]
    async fn test_send_receive() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    #[tokio::test]
    async fn test_ping_request() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let ping_request = EchoControlMessage::SendPing;
        socketeer.send(ping_request).await.unwrap();
        // The server will respond with a ping request, which Socketeer will transparently respond to
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(received_message, message);
        // Stay idle past the 2 s ping delay so the client sends its own Ping
        sleep(Duration::from_millis(2200)).await;
        // Ensure everything shuts down so we exercise the ping functionality fully
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_reconnection() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer = socketeer.reconnect().await.unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_closed_socket() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let close_request = EchoControlMessage::Close;
        socketeer.send(close_request.clone()).await.unwrap();
        let response = socketeer.next_message().await;
        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
        let error = socketeer
            .send(close_request)
            .await
            .expect_err("sending on a closed socket must fail");
        assert!(
            matches!(error, Error::WebsocketClosed),
            "unexpected error: {error:#?}"
        );
    }

    #[tokio::test]
    async fn test_close_request() {
        let server_address = get_mock_address(echo_server).await;
        let socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        socketeer.close_connection().await.unwrap();
    }
}