socketeer/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod error;
4#[cfg(feature = "mocking")]
5mod mock_server;
6#[cfg(feature = "mocking")]
7pub use mock_server::{EchoControlMessage, echo_server, get_mock_address};
8
9use bytes::Bytes;
10pub use error::Error;
11use futures::{SinkExt, StreamExt, stream::SplitSink};
12use serde::{Deserialize, Serialize};
13use std::{fmt::Debug, time::Duration};
14use tokio::{
15    net::TcpStream,
16    select,
17    sync::{mpsc, oneshot},
18    time::sleep,
19};
20use tokio_tungstenite::{
21    MaybeTlsStream, WebSocketStream, connect_async,
22    tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
23};
24
25#[cfg(feature = "tracing")]
26use tracing::{debug, error, info, instrument, trace};
27use url::Url;
28
29#[derive(Debug)]
30struct TxChannelPayload {
31    message: Message,
32    response_tx: oneshot::Sender<Result<(), Error>>,
33}
34
35/// A WebSocket client that manages the connection to a WebSocket server.
36/// The client can send and receive messages, and will transparently handle protocol messages.
37/// # Type Parameters
38/// - `RxMessage`: The type of message that the client will receive from the server.
39/// - `TxMessage`: The type of message that the client will send to the server.
40/// - `CHANNEL_SIZE`: The size of the internal channels used to communicate between
41///     the task managing the WebSocket connection and the client.
42#[derive(Debug)]
43pub struct Socketeer<RxMessage, TxMessage, const CHANNEL_SIZE: usize = 4> {
44    url: Url,
45    receiever: mpsc::Receiver<Message>,
46    sender: mpsc::Sender<TxChannelPayload>,
47    socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
48    _rx_message: std::marker::PhantomData<RxMessage>,
49    _tx_message: std::marker::PhantomData<TxMessage>,
50}
51
52impl<
53    RxMessage: for<'a> Deserialize<'a> + Debug,
54    TxMessage: Serialize + Debug,
55    const CHANNEL_SIZE: usize,
56> Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>
57{
58    /// Create a `Socketeer` connected to the provided URL.
59    /// Once connected, Socketeer manages the underlying WebSocket connection, transparently handling protocol messages.
60    /// # Errors
61    /// - If the URL cannot be parsed
62    /// - If the WebSocket connection to the requested URL fails
63    #[cfg_attr(feature = "tracing", instrument)]
64    pub async fn connect(
65        url: &str,
66    ) -> Result<Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>, Error> {
67        let url = Url::parse(url).map_err(|source| Error::UrlParse {
68            url: url.to_string(),
69            source,
70        })?;
71        #[allow(unused_variables)]
72        let (socket, response) = connect_async(url.as_str()).await?;
73        #[cfg(feature = "tracing")]
74        debug!("Connection Successful, connection info: \n{:#?}", response);
75
76        let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
77        let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
78
79        let socket_handle = tokio::spawn(async move { socket_loop(tx_rx, rx_tx, socket).await });
80        Ok(Socketeer {
81            url,
82            receiever: rx_rx,
83            sender: tx_tx,
84            socket_handle,
85            _rx_message: std::marker::PhantomData,
86            _tx_message: std::marker::PhantomData,
87        })
88    }
89
90    /// Wait for the next parsed message from the WebSocket connection.
91    ///
92    /// # Errors
93    ///
94    /// - If the WebSocket connection is closed or otherwise errored
95    /// - If the message cannot be deserialized
96    #[cfg_attr(feature = "tracing", instrument)]
97    pub async fn next_message(&mut self) -> Result<RxMessage, Error> {
98        let Some(message) = self.receiever.recv().await else {
99            return Err(Error::WebsocketClosed);
100        };
101        match message {
102            Message::Text(text) => {
103                #[cfg(feature = "tracing")]
104                trace!("Received text message: {:?}", text);
105                let message = serde_json::from_str(&text)?;
106                Ok(message)
107            }
108            Message::Binary(message) => {
109                #[cfg(feature = "tracing")]
110                trace!("Received binary message: {:?}", message);
111                let message = serde_json::from_slice(&message)?;
112                Ok(message)
113            }
114            _ => Err(Error::UnexpectedMessageType(message)),
115        }
116    }
117
118    /// Send a message to the WebSocket connection.
119    /// This function will wait for the message to be sent before returning.
120    ///
121    /// # Errors
122    ///
123    /// - If the message cannot be serialized
124    /// - If the WebSocket connection is closed, or otherwise errored
125    #[cfg_attr(feature = "tracing", instrument)]
126    pub async fn send(&self, message: TxMessage) -> Result<(), Error> {
127        #[cfg(feature = "tracing")]
128        trace!("Sending message: {:?}", message);
129
130        let (tx, rx) = oneshot::channel::<Result<(), Error>>();
131        let message = serde_json::to_string(&message)?;
132
133        self.sender
134            .send(TxChannelPayload {
135                message: Message::Text(message.into()),
136                response_tx: tx,
137            })
138            .await
139            .map_err(|_| Error::WebsocketClosed)?;
140        // We'll ensure that we always respond before dropping the tx channel
141        match rx.await {
142            Ok(result) => result,
143            Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
144        }
145    }
146
147    /// Consume self, closing down any remaining send/recieve, and return a new Socketeer instance if successful
148    /// This function attempts to close the connection gracefully before returning,
149    /// but will not return an error if the connection is already closed,
150    /// as its intended use is to re-establish a failed connection.
151    /// # Errors
152    /// - If a new connection cannot be established
153    #[cfg_attr(feature = "tracing", instrument)]
154    pub async fn reconnect(self) -> Result<Self, Error> {
155        let url = self.url.as_str().to_owned();
156        #[cfg(feature = "tracing")]
157        info!("Reconnecting");
158        match self.close_connection().await {
159            Ok(()) => (),
160            #[allow(unused_variables)]
161            Err(e) => {
162                #[cfg(feature = "tracing")]
163                error!("Socket Loop already stopped: {}", e);
164            }
165        }
166        Self::connect(&url).await
167    }
168
169    /// Close the WebSocket connection gracefully.
170    /// This function will wait for the connection to close before returning.
171    /// # Errors
172    /// - If the WebSocket connection is already closed
173    /// - If the WebSocket connection cannot be closed
174    #[cfg_attr(feature = "tracing", instrument)]
175    pub async fn close_connection(self) -> Result<(), Error> {
176        #[cfg(feature = "tracing")]
177        debug!("Closing Connection");
178        let (tx, rx) = oneshot::channel::<Result<(), Error>>();
179        self.sender
180            .send(TxChannelPayload {
181                message: Message::Close(Some(CloseFrame {
182                    code: tungstenite::protocol::frame::coding::CloseCode::Normal,
183                    reason: Utf8Bytes::from_static("Closing Connection"),
184                })),
185                response_tx: tx,
186            })
187            .await
188            .map_err(|_| Error::WebsocketClosed)?;
189        match rx.await {
190            Ok(result) => result,
191            Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
192        }?;
193        match self.socket_handle.await {
194            Ok(result) => result,
195            Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
196        }
197    }
198}
199
200pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
201type SocketSink = SplitSink<WebSocketStreamType, Message>;
202
203enum LoopState {
204    Running,
205    Error(Error),
206    Closed,
207}
208
209#[cfg_attr(feature = "tracing", instrument)]
210async fn socket_loop(
211    mut receiver: mpsc::Receiver<TxChannelPayload>,
212    mut sender: mpsc::Sender<Message>,
213    socket: WebSocketStreamType,
214) -> Result<(), Error> {
215    let mut state = LoopState::Running;
216    let (mut sink, mut stream) = socket.split();
217    while matches!(state, LoopState::Running) {
218        state = select! {
219            outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
220            incoming_message = stream.next() => socket_message_received( incoming_message,&mut sender, &mut sink).await,
221            () = sleep(Duration::from_secs(2)) => send_ping(&mut sink).await,
222        };
223    }
224    match state {
225        LoopState::Error(e) => Err(e),
226        LoopState::Closed => Ok(()),
227        LoopState::Running => unreachable!("We only exit when closed or errored"),
228    }
229}
230
231#[cfg_attr(feature = "tracing", instrument)]
232async fn send_socket_message(
233    message: Option<TxChannelPayload>,
234    sink: &mut SocketSink,
235) -> LoopState {
236    if let Some(message) = message {
237        #[cfg(feature = "tracing")]
238        debug!("Sending message: {:?}", message);
239        let send_result = sink.send(message.message).await.map_err(Error::from);
240        let socket_error = send_result.is_err();
241        match message.response_tx.send(send_result) {
242            Ok(()) => {
243                if socket_error {
244                    LoopState::Error(Error::WebsocketClosed)
245                } else {
246                    LoopState::Running
247                }
248            }
249            Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
250        }
251    } else {
252        #[cfg(feature = "tracing")]
253        error!("Socketeer dropped without closing connection");
254        LoopState::Error(Error::SocketeerDroppedWithoutClosing)
255    }
256}
257
258#[cfg_attr(feature = "tracing", instrument)]
259async fn socket_message_received(
260    message: Option<Result<Message, tungstenite::Error>>,
261    sender: &mut mpsc::Sender<Message>,
262    sink: &mut SocketSink,
263) -> LoopState {
264    const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
265    match message {
266        Some(Ok(message)) => match message {
267            Message::Ping(_) => {
268                let send_result = sink
269                    .send(Message::Pong(PONG_BYTES))
270                    .await
271                    .map_err(Error::from);
272                match send_result {
273                    Ok(()) => LoopState::Running,
274                    Err(e) => {
275                        #[cfg(feature = "tracing")]
276                        error!("Error sending Pong: {:?}", e);
277                        LoopState::Error(e)
278                    }
279                }
280            }
281            Message::Close(_) => {
282                let close_result = sink.close().await;
283                match close_result {
284                    Ok(()) => LoopState::Closed,
285                    Err(e) => {
286                        #[cfg(feature = "tracing")]
287                        error!("Error sending Close: {:?}", e);
288                        LoopState::Error(Error::from(e))
289                    }
290                }
291            }
292            Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
293                Ok(()) => LoopState::Running,
294                Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
295            },
296            _ => LoopState::Running,
297        },
298        Some(Err(e)) => {
299            #[cfg(feature = "tracing")]
300            error!("Error receiving message: {:?}", e);
301            LoopState::Error(Error::WebsocketError(e))
302        }
303        None => {
304            #[cfg(feature = "tracing")]
305            info!("Websocket Closed, closing rx channel");
306            LoopState::Error(Error::WebsocketClosed)
307        }
308    }
309}
310
311#[cfg_attr(feature = "tracing", instrument)]
312async fn send_ping(sink: &mut SocketSink) -> LoopState {
313    #[cfg(feature = "tracing")]
314    info!("Timeout waiting for message, sending Ping");
315    let result = sink
316        .send(Message::Ping(Bytes::new()))
317        .await
318        .map_err(Error::from);
319    match result {
320        Ok(()) => LoopState::Running,
321        Err(e) => {
322            #[cfg(feature = "tracing")]
323            error!("Error sending Ping: {:?}", e);
324            LoopState::Error(e)
325        }
326    }
327}
328
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Client type used by every test: `EchoControlMessage` in both directions.
    type EchoClient = Socketeer<EchoControlMessage, EchoControlMessage>;

    /// Start a mock echo server and return a client connected to it.
    async fn connect_to_echo_server() -> EchoClient {
        let server_address = get_mock_address(echo_server).await;
        EchoClient::connect(&format!("ws://{server_address}"))
            .await
            .unwrap()
    }

    /// Send a "Hello" message and assert that the same message comes back.
    async fn assert_hello_round_trip(client: &mut EchoClient) {
        let message = EchoControlMessage::Message("Hello".to_string());
        client.send(message.clone()).await.unwrap();
        let received_message = client.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    #[tokio::test]
    async fn test_server_startup() {
        let _server_address = get_mock_address(echo_server).await;
    }

    #[tokio::test]
    async fn test_connection() {
        let _socketeer = connect_to_echo_server().await;
    }

    #[tokio::test]
    async fn test_bad_url() {
        let result: Result<EchoClient, Error> = Socketeer::connect("Not a URL").await;
        assert!(matches!(result.unwrap_err(), Error::UrlParse { .. }));
    }

    #[tokio::test]
    async fn test_send_receive() {
        let mut socketeer = connect_to_echo_server().await;
        assert_hello_round_trip(&mut socketeer).await;
    }

    #[tokio::test]
    async fn test_ping_request() {
        let mut socketeer = connect_to_echo_server().await;
        socketeer.send(EchoControlMessage::SendPing).await.unwrap();
        // The server will respond with a ping request, which Socketeer will transparently respond to
        assert_hello_round_trip(&mut socketeer).await;
        // Stay idle for longer than the ping interval so that a ping is sent
        sleep(Duration::from_millis(2200)).await;
        // Ensure everything shuts down so we exercise the ping functionality fully
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_reconnection() {
        let mut socketeer = connect_to_echo_server().await;
        assert_hello_round_trip(&mut socketeer).await;
        socketeer = socketeer.reconnect().await.unwrap();
        assert_hello_round_trip(&mut socketeer).await;
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_closed_socket() {
        let mut socketeer = connect_to_echo_server().await;
        let close_request = EchoControlMessage::Close;
        socketeer.send(close_request.clone()).await.unwrap();
        let response = socketeer.next_message().await;
        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
        let send_result = socketeer.send(close_request).await;
        assert!(send_result.is_err());
        let error = send_result.unwrap_err();
        println!("Actual Error: {error:#?}");
        assert!(matches!(error, Error::WebsocketClosed));
    }

    #[tokio::test]
    async fn test_close_request() {
        let socketeer = connect_to_echo_server().await;
        socketeer.close_connection().await.unwrap();
    }
}