1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod error;
4#[cfg(feature = "mocking")]
5mod mock_server;
6#[cfg(feature = "mocking")]
7pub use mock_server::{EchoControlMessage, echo_server, get_mock_address};
8
9use bytes::Bytes;
10pub use error::Error;
11use futures::{SinkExt, StreamExt, stream::SplitSink};
12use serde::{Deserialize, Serialize};
13use std::{fmt::Debug, time::Duration};
14use tokio::{
15 net::TcpStream,
16 select,
17 sync::{mpsc, oneshot},
18 time::sleep,
19};
20use tokio_tungstenite::{
21 MaybeTlsStream, WebSocketStream, connect_async,
22 tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
23};
24
25#[cfg(feature = "tracing")]
26use tracing::{debug, error, info, instrument, trace};
27use url::Url;
28
29#[derive(Debug)]
30struct TxChannelPayload {
31 message: Message,
32 response_tx: oneshot::Sender<Result<(), Error>>,
33}
34
35#[derive(Debug)]
45pub struct Socketeer<RxMessage, TxMessage, const CHANNEL_SIZE: usize = 4> {
46 url: Url,
47 receiever: mpsc::Receiver<Message>,
48 sender: mpsc::Sender<TxChannelPayload>,
49 socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
50 _rx_message: std::marker::PhantomData<RxMessage>,
51 _tx_message: std::marker::PhantomData<TxMessage>,
52}
53
54impl<
55 RxMessage: for<'a> Deserialize<'a> + Debug,
56 TxMessage: Serialize + Debug,
57 const CHANNEL_SIZE: usize,
58> Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>
59{
60 #[cfg_attr(feature = "tracing", instrument)]
66 pub async fn connect(
67 url: &str,
68 ) -> Result<Socketeer<RxMessage, TxMessage, CHANNEL_SIZE>, Error> {
69 let url = Url::parse(url).map_err(|source| Error::UrlParse {
70 url: url.to_string(),
71 source,
72 })?;
73 #[allow(unused_variables)]
74 let (socket, response) = connect_async(url.as_str()).await?;
75 #[cfg(feature = "tracing")]
76 debug!("Connection Successful, connection info: \n{:#?}", response);
77
78 let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
79 let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
80
81 let socket_handle = tokio::spawn(async move { socket_loop(tx_rx, rx_tx, socket).await });
82 Ok(Socketeer {
83 url,
84 receiever: rx_rx,
85 sender: tx_tx,
86 socket_handle,
87 _rx_message: std::marker::PhantomData,
88 _tx_message: std::marker::PhantomData,
89 })
90 }
91
92 #[cfg_attr(feature = "tracing", instrument)]
99 pub async fn next_message(&mut self) -> Result<RxMessage, Error> {
100 let Some(message) = self.receiever.recv().await else {
101 return Err(Error::WebsocketClosed);
102 };
103 match message {
104 Message::Text(text) => {
105 #[cfg(feature = "tracing")]
106 trace!("Received text message: {:?}", text);
107 let message = serde_json::from_str(&text)?;
108 Ok(message)
109 }
110 Message::Binary(message) => {
111 #[cfg(feature = "tracing")]
112 trace!("Received binary message: {:?}", message);
113 let message = serde_json::from_slice(&message)?;
114 Ok(message)
115 }
116 _ => Err(Error::UnexpectedMessageType(Box::new(message))),
117 }
118 }
119
120 #[cfg_attr(feature = "tracing", instrument)]
128 pub async fn send(&self, message: TxMessage) -> Result<(), Error> {
129 #[cfg(feature = "tracing")]
130 trace!("Sending message: {:?}", message);
131
132 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
133 let message = serde_json::to_string(&message)?;
134
135 self.sender
136 .send(TxChannelPayload {
137 message: Message::Text(message.into()),
138 response_tx: tx,
139 })
140 .await
141 .map_err(|_| Error::WebsocketClosed)?;
142 match rx.await {
144 Ok(result) => result,
145 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
146 }
147 }
148
149 #[cfg_attr(feature = "tracing", instrument)]
156 pub async fn reconnect(self) -> Result<Self, Error> {
157 let url = self.url.as_str().to_owned();
158 #[cfg(feature = "tracing")]
159 info!("Reconnecting");
160 match self.close_connection().await {
161 Ok(()) => (),
162 #[allow(unused_variables)]
163 Err(e) => {
164 #[cfg(feature = "tracing")]
165 error!("Socket Loop already stopped: {}", e);
166 }
167 }
168 Self::connect(&url).await
169 }
170
171 #[cfg_attr(feature = "tracing", instrument)]
177 pub async fn close_connection(self) -> Result<(), Error> {
178 #[cfg(feature = "tracing")]
179 debug!("Closing Connection");
180 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
181 self.sender
182 .send(TxChannelPayload {
183 message: Message::Close(Some(CloseFrame {
184 code: tungstenite::protocol::frame::coding::CloseCode::Normal,
185 reason: Utf8Bytes::from_static("Closing Connection"),
186 })),
187 response_tx: tx,
188 })
189 .await
190 .map_err(|_| Error::WebsocketClosed)?;
191 match rx.await {
192 Ok(result) => result,
193 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
194 }?;
195 match self.socket_handle.await {
196 Ok(result) => result,
197 Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
198 }
199 }
200}
201
202pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
203type SocketSink = SplitSink<WebSocketStreamType, Message>;
204
205enum LoopState {
206 Running,
207 Error(Error),
208 Closed,
209}
210
211#[cfg_attr(feature = "tracing", instrument)]
212async fn socket_loop(
213 mut receiver: mpsc::Receiver<TxChannelPayload>,
214 mut sender: mpsc::Sender<Message>,
215 socket: WebSocketStreamType,
216) -> Result<(), Error> {
217 let mut state = LoopState::Running;
218 let (mut sink, mut stream) = socket.split();
219 while matches!(state, LoopState::Running) {
220 state = select! {
221 outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
222 incoming_message = stream.next() => socket_message_received( incoming_message,&mut sender, &mut sink).await,
223 () = sleep(Duration::from_secs(2)) => send_ping(&mut sink).await,
224 };
225 }
226 match state {
227 LoopState::Error(e) => Err(e),
228 LoopState::Closed => Ok(()),
229 LoopState::Running => unreachable!("We only exit when closed or errored"),
230 }
231}
232
233#[cfg_attr(feature = "tracing", instrument)]
234async fn send_socket_message(
235 message: Option<TxChannelPayload>,
236 sink: &mut SocketSink,
237) -> LoopState {
238 if let Some(message) = message {
239 #[cfg(feature = "tracing")]
240 debug!("Sending message: {:?}", message);
241 let send_result = sink.send(message.message).await.map_err(Error::from);
242 let socket_error = send_result.is_err();
243 match message.response_tx.send(send_result) {
244 Ok(()) => {
245 if socket_error {
246 LoopState::Error(Error::WebsocketClosed)
247 } else {
248 LoopState::Running
249 }
250 }
251 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
252 }
253 } else {
254 #[cfg(feature = "tracing")]
255 error!("Socketeer dropped without closing connection");
256 LoopState::Error(Error::SocketeerDroppedWithoutClosing)
257 }
258}
259
260#[cfg_attr(feature = "tracing", instrument)]
261async fn socket_message_received(
262 message: Option<Result<Message, tungstenite::Error>>,
263 sender: &mut mpsc::Sender<Message>,
264 sink: &mut SocketSink,
265) -> LoopState {
266 const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
267 match message {
268 Some(Ok(message)) => match message {
269 Message::Ping(_) => {
270 let send_result = sink
271 .send(Message::Pong(PONG_BYTES))
272 .await
273 .map_err(Error::from);
274 match send_result {
275 Ok(()) => LoopState::Running,
276 Err(e) => {
277 #[cfg(feature = "tracing")]
278 error!("Error sending Pong: {:?}", e);
279 LoopState::Error(e)
280 }
281 }
282 }
283 Message::Close(_) => {
284 let close_result = sink.close().await;
285 match close_result {
286 Ok(()) => LoopState::Closed,
287 Err(e) => {
288 #[cfg(feature = "tracing")]
289 error!("Error sending Close: {:?}", e);
290 LoopState::Error(Error::from(e))
291 }
292 }
293 }
294 Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
295 Ok(()) => LoopState::Running,
296 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
297 },
298 _ => LoopState::Running,
299 },
300 Some(Err(e)) => {
301 #[cfg(feature = "tracing")]
302 error!("Error receiving message: {:?}", e);
303 LoopState::Error(Error::WebsocketError(e))
304 }
305 None => {
306 #[cfg(feature = "tracing")]
307 info!("Websocket Closed, closing rx channel");
308 LoopState::Error(Error::WebsocketClosed)
309 }
310 }
311}
312
313#[cfg_attr(feature = "tracing", instrument)]
314async fn send_ping(sink: &mut SocketSink) -> LoopState {
315 #[cfg(feature = "tracing")]
316 info!("Timeout waiting for message, sending Ping");
317 let result = sink
318 .send(Message::Ping(Bytes::new()))
319 .await
320 .map_err(Error::from);
321 match result {
322 Ok(()) => LoopState::Running,
323 Err(e) => {
324 #[cfg(feature = "tracing")]
325 error!("Error sending Ping: {:?}", e);
326 LoopState::Error(e)
327 }
328 }
329}
330
// Every test needs the mock echo server, which only exists with the `mocking` feature;
// gating on it keeps `cargo test` without that feature compiling.
#[cfg(all(test, feature = "mocking"))]
mod tests {
    //! Tests that drive a `Socketeer` against the mock echo server from the `mocking`
    //! feature.
    use super::*;
    use tokio::time::sleep;

    /// Client type used by the tests: the echo server speaks `EchoControlMessage` both ways.
    type EchoSocketeer = Socketeer<EchoControlMessage, EchoControlMessage>;

    /// Starts a new mock echo server and connects a client to it.
    async fn connect_to_echo_server() -> EchoSocketeer {
        let server_address = get_mock_address(echo_server).await;
        Socketeer::connect(&format!("ws://{server_address}"))
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_server_startup() {
        let _server_address = get_mock_address(echo_server).await;
    }

    #[tokio::test]
    async fn test_connection() {
        let _socketeer = connect_to_echo_server().await;
    }

    #[tokio::test]
    async fn test_bad_url() {
        let error: Result<EchoSocketeer, Error> = Socketeer::connect("Not a URL").await;
        assert!(matches!(error.unwrap_err(), Error::UrlParse { .. }));
    }

    #[tokio::test]
    async fn test_send_receive() {
        let mut socketeer = connect_to_echo_server().await;
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    #[tokio::test]
    async fn test_ping_request() {
        let mut socketeer = connect_to_echo_server().await;
        // Presumably asks the server to ping the client; the connection must keep working.
        let ping_request = EchoControlMessage::SendPing;
        socketeer.send(ping_request).await.unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(received_message, message);
        // Stay idle past the socket loop's two-second timeout so the client sends its own
        // ping before closing.
        sleep(Duration::from_millis(2200)).await;
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_reconnection() {
        let mut socketeer = connect_to_echo_server().await;
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer = socketeer.reconnect().await.unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_closed_socket() {
        let mut socketeer = connect_to_echo_server().await;
        // Asks the server to close the connection from its side.
        let close_request = EchoControlMessage::Close;
        socketeer.send(close_request.clone()).await.unwrap();
        let response = socketeer.next_message().await;
        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
        let send_result = socketeer.send(close_request).await;
        assert!(send_result.is_err());
        let error = send_result.unwrap_err();
        println!("Actual Error: {error:#?}");
        assert!(matches!(error, Error::WebsocketClosed));
    }

    #[tokio::test]
    async fn test_close_request() {
        let socketeer = connect_to_echo_server().await;
        socketeer.close_connection().await.unwrap();
    }
}