1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod config;
4mod error;
5mod handler;
6#[cfg(feature = "mocking")]
7mod mock_server;
8
9pub use config::ConnectOptions;
10pub use error::Error;
11pub use handler::{ConnectionHandler, HandshakeContext, NoopHandler};
12#[cfg(feature = "mocking")]
13pub use mock_server::{EchoControlMessage, auth_echo_server, echo_server, get_mock_address};
14
15use bytes::Bytes;
16use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, time::Duration};
19use tokio::{
20 net::TcpStream,
21 select,
22 sync::{mpsc, oneshot},
23 time::sleep,
24};
25use tokio_tungstenite::{
26 MaybeTlsStream, WebSocketStream, connect_async,
27 tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
28};
29
30#[cfg(feature = "tracing")]
31use tracing::{debug, error, info, instrument, trace};
32use url::Url;
33
/// An outgoing frame queued for the background socket loop, paired with a one-shot
/// channel on which the loop reports the outcome of writing that frame.
#[derive(Debug)]
struct TxChannelPayload {
    /// The frame to write to the WebSocket sink.
    message: Message,
    /// Receives the result of writing `message` to the sink.
    response_tx: oneshot::Sender<Result<(), Error>>,
}
39
/// A typed WebSocket client.
///
/// Outgoing `TxMessage` values are serialized to JSON text frames. Incoming text or binary
/// frames are deserialized from JSON into `RxMessage`. The socket I/O runs on a background
/// Tokio task, and this handle talks to it over two bounded channels of capacity
/// `CHANNEL_SIZE`. `Handler` receives connection lifecycle callbacks (see
/// [`ConnectionHandler`]).
pub struct Socketeer<RxMessage, TxMessage, Handler = NoopHandler, const CHANNEL_SIZE: usize = 4> {
    /// The URL that was connected to; reused when reconnecting.
    url: Url,
    /// The options used to connect; reused when reconnecting.
    options: ConnectOptions,
    /// Lifecycle callbacks invoked on connect and reconnect.
    handler: Handler,
    /// Incoming text/binary messages forwarded by the socket task.
    receiver: mpsc::Receiver<Message>,
    /// Outgoing frames (each with a response channel) consumed by the socket task.
    sender: mpsc::Sender<TxChannelPayload>,
    /// Handle to the background socket task; resolves to the loop's final result.
    socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
    // Type markers only; no `RxMessage`/`TxMessage` values are stored.
    _rx_message: std::marker::PhantomData<RxMessage>,
    _tx_message: std::marker::PhantomData<TxMessage>,
}
61
62impl<RxMessage, TxMessage, Handler, const CHANNEL_SIZE: usize> std::fmt::Debug
63 for Socketeer<RxMessage, TxMessage, Handler, CHANNEL_SIZE>
64{
65 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66 f.debug_struct("Socketeer")
67 .field("url", &self.url)
68 .finish_non_exhaustive()
69 }
70}
71
72impl<
73 RxMessage: for<'a> Deserialize<'a> + Debug,
74 TxMessage: Serialize + Debug,
75 const CHANNEL_SIZE: usize,
76> Socketeer<RxMessage, TxMessage, NoopHandler, CHANNEL_SIZE>
77{
78 #[cfg_attr(feature = "tracing", instrument)]
84 pub async fn connect(url: &str) -> Result<Self, Error> {
85 Self::connect_with(url, ConnectOptions::default()).await
86 }
87
88 #[cfg_attr(feature = "tracing", instrument(skip(options)))]
93 pub async fn connect_with(url: &str, options: ConnectOptions) -> Result<Self, Error> {
94 Socketeer::connect_with_handler(url, options, NoopHandler).await
95 }
96}
97
98impl<
99 RxMessage: for<'a> Deserialize<'a> + Debug,
100 TxMessage: Serialize + Debug,
101 Handler: ConnectionHandler,
102 const CHANNEL_SIZE: usize,
103> Socketeer<RxMessage, TxMessage, Handler, CHANNEL_SIZE>
104{
105 #[cfg_attr(feature = "tracing", instrument(skip(options, handler)))]
115 pub async fn connect_with_handler(
116 url: &str,
117 options: ConnectOptions,
118 mut handler: Handler,
119 ) -> Result<Self, Error> {
120 let url = Url::parse(url).map_err(|source| Error::UrlParse {
121 url: url.to_string(),
122 source,
123 })?;
124
125 let request = options.build_request(&url)?;
126 #[allow(unused_variables)]
127 let (socket, response) = connect_async(request).await?;
128 #[cfg(feature = "tracing")]
129 debug!("Connection Successful, connection info: \n{:#?}", response);
130
131 let (mut sink, mut stream) = socket.split();
132 {
133 let mut ctx = HandshakeContext::new(&mut sink, &mut stream);
134 handler.on_connected(&mut ctx).await?;
135 }
136
137 let keepalive_interval = options.keepalive_interval;
138 let keepalive_message = options.custom_keepalive_message.clone();
139
140 let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
141 let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
142
143 let socket_handle = tokio::spawn(async move {
144 socket_loop_split(
145 tx_rx,
146 rx_tx,
147 sink,
148 stream,
149 keepalive_interval,
150 keepalive_message,
151 )
152 .await
153 });
154 Ok(Socketeer {
155 url,
156 options,
157 handler,
158 receiver: rx_rx,
159 sender: tx_tx,
160 socket_handle,
161 _rx_message: std::marker::PhantomData,
162 _tx_message: std::marker::PhantomData,
163 })
164 }
165
166 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
173 pub async fn next_message(&mut self) -> Result<RxMessage, Error> {
174 let Some(message) = self.receiver.recv().await else {
175 return Err(Error::WebsocketClosed);
176 };
177 match message {
178 Message::Text(text) => {
179 #[cfg(feature = "tracing")]
180 trace!("Received text message: {:?}", text);
181 let message = serde_json::from_str(&text)?;
182 Ok(message)
183 }
184 Message::Binary(message) => {
185 #[cfg(feature = "tracing")]
186 trace!("Received binary message: {:?}", message);
187 let message = serde_json::from_slice(&message)?;
188 Ok(message)
189 }
190 _ => Err(Error::UnexpectedMessageType(Box::new(message))),
191 }
192 }
193
194 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
202 pub async fn send(&self, message: TxMessage) -> Result<(), Error> {
203 #[cfg(feature = "tracing")]
204 trace!("Sending message: {:?}", message);
205
206 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
207 let message = serde_json::to_string(&message)?;
208
209 self.sender
210 .send(TxChannelPayload {
211 message: Message::Text(message.into()),
212 response_tx: tx,
213 })
214 .await
215 .map_err(|_| Error::WebsocketClosed)?;
216 match rx.await {
218 Ok(result) => result,
219 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
220 }
221 }
222
223 pub async fn next_raw_message(&mut self) -> Result<Message, Error> {
231 self.receiver.recv().await.ok_or(Error::WebsocketClosed)
232 }
233
234 pub async fn send_raw(&self, message: Message) -> Result<(), Error> {
243 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
244 self.sender
245 .send(TxChannelPayload {
246 message,
247 response_tx: tx,
248 })
249 .await
250 .map_err(|_| Error::WebsocketClosed)?;
251 match rx.await {
252 Ok(result) => result,
253 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
254 }
255 }
256
257 pub async fn reconnect(self) -> Result<Self, Error> {
268 let url = self.url.as_str().to_owned();
269 let options = self.options.clone();
270 let mut handler = self.handler;
271 #[cfg(feature = "tracing")]
272 info!("Reconnecting");
273 handler.on_disconnected().await;
274 match send_close(&self.sender).await {
276 Ok(()) => (),
277 #[allow(unused_variables)]
278 Err(e) => {
279 #[cfg(feature = "tracing")]
280 error!("Socket Loop already stopped: {}", e);
281 }
282 }
283 Self::connect_with_handler(&url, options, handler).await
284 }
285
286 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
292 pub async fn close_connection(self) -> Result<(), Error> {
293 #[cfg(feature = "tracing")]
294 debug!("Closing Connection");
295 send_close(&self.sender).await?;
296 match self.socket_handle.await {
297 Ok(result) => result,
298 Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
299 }
300 }
301}
302
/// The WebSocket stream type returned by `connect_async` (TCP, optionally wrapped in TLS).
pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
/// Write half of the split WebSocket, owned by the socket loop.
type SocketSink = SplitSink<WebSocketStreamType, Message>;
/// Read half of the split WebSocket, owned by the socket loop.
type SocketStream = SplitStream<WebSocketStreamType>;
306
/// Outcome of handling one event in the socket loop; anything but `Running` ends the loop.
enum LoopState {
    /// Keep processing events.
    Running,
    /// Stop the loop; the socket task returns this error.
    Error(Error),
    /// The connection closed cleanly; the socket task returns `Ok(())`.
    Closed,
}
312
313async fn send_close(sender: &mpsc::Sender<TxChannelPayload>) -> Result<(), Error> {
315 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
316 sender
317 .send(TxChannelPayload {
318 message: Message::Close(Some(CloseFrame {
319 code: tungstenite::protocol::frame::coding::CloseCode::Normal,
320 reason: Utf8Bytes::from_static("Closing Connection"),
321 })),
322 response_tx: tx,
323 })
324 .await
325 .map_err(|_| Error::WebsocketClosed)?;
326 match rx.await {
327 Ok(result) => result,
328 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
329 }
330}
331
332#[cfg_attr(
333 feature = "tracing",
334 instrument(skip(keepalive_interval, keepalive_message))
335)]
336async fn socket_loop_split(
337 mut receiver: mpsc::Receiver<TxChannelPayload>,
338 mut sender: mpsc::Sender<Message>,
339 mut sink: SocketSink,
340 mut stream: SocketStream,
341 keepalive_interval: Option<Duration>,
342 keepalive_message: Option<String>,
343) -> Result<(), Error> {
344 let mut state = LoopState::Running;
345 while matches!(state, LoopState::Running) {
346 state = if let Some(interval) = keepalive_interval {
347 select! {
348 outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
349 incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
350 () = sleep(interval) => send_keepalive(&mut sink, keepalive_message.as_deref()).await,
351 }
352 } else {
353 select! {
354 outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
355 incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
356 }
357 };
358 }
359 match state {
360 LoopState::Error(e) => Err(e),
361 LoopState::Closed => Ok(()),
362 LoopState::Running => unreachable!("We only exit when closed or errored"),
363 }
364}
365
366#[cfg_attr(feature = "tracing", instrument)]
367async fn send_socket_message(
368 message: Option<TxChannelPayload>,
369 sink: &mut SocketSink,
370) -> LoopState {
371 if let Some(message) = message {
372 #[cfg(feature = "tracing")]
373 debug!("Sending message: {:?}", message);
374 let send_result = sink.send(message.message).await.map_err(Error::from);
375 let socket_error = send_result.is_err();
376 match message.response_tx.send(send_result) {
377 Ok(()) => {
378 if socket_error {
379 LoopState::Error(Error::WebsocketClosed)
380 } else {
381 LoopState::Running
382 }
383 }
384 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
385 }
386 } else {
387 #[cfg(feature = "tracing")]
388 error!("Socketeer dropped without closing connection");
389 LoopState::Error(Error::SocketeerDroppedWithoutClosing)
390 }
391}
392
393#[cfg_attr(feature = "tracing", instrument)]
394async fn socket_message_received(
395 message: Option<Result<Message, tungstenite::Error>>,
396 sender: &mut mpsc::Sender<Message>,
397 sink: &mut SocketSink,
398) -> LoopState {
399 const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
400 match message {
401 Some(Ok(message)) => match message {
402 Message::Ping(_) => {
403 let send_result = sink
404 .send(Message::Pong(PONG_BYTES))
405 .await
406 .map_err(Error::from);
407 match send_result {
408 Ok(()) => LoopState::Running,
409 Err(e) => {
410 #[cfg(feature = "tracing")]
411 error!("Error sending Pong: {:?}", e);
412 LoopState::Error(e)
413 }
414 }
415 }
416 Message::Close(_) => {
417 let close_result = sink.close().await;
418 match close_result {
419 Ok(()) => LoopState::Closed,
420 Err(e) => {
421 #[cfg(feature = "tracing")]
422 error!("Error sending Close: {:?}", e);
423 LoopState::Error(Error::from(e))
424 }
425 }
426 }
427 Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
428 Ok(()) => LoopState::Running,
429 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
430 },
431 _ => LoopState::Running,
432 },
433 Some(Err(e)) => {
434 #[cfg(feature = "tracing")]
435 error!("Error receiving message: {:?}", e);
436 LoopState::Error(Error::WebsocketError(e))
437 }
438 None => {
439 #[cfg(feature = "tracing")]
440 info!("Websocket Closed, closing rx channel");
441 LoopState::Error(Error::WebsocketClosed)
442 }
443 }
444}
445
446#[cfg_attr(feature = "tracing", instrument)]
447async fn send_keepalive(sink: &mut SocketSink, custom_message: Option<&str>) -> LoopState {
448 let message = if let Some(text) = custom_message {
449 #[cfg(feature = "tracing")]
450 info!("Timeout waiting for message, sending custom keepalive");
451 Message::Text(text.into())
452 } else {
453 #[cfg(feature = "tracing")]
454 info!("Timeout waiting for message, sending Ping");
455 Message::Ping(Bytes::new())
456 };
457 let result = sink.send(message).await.map_err(Error::from);
458 match result {
459 Ok(()) => LoopState::Running,
460 Err(e) => {
461 #[cfg(feature = "tracing")]
462 error!("Error sending keepalive: {:?}", e);
463 LoopState::Error(e)
464 }
465 }
466}
467
// Integration tests that drive `Socketeer` against the mock servers re-exported from
// `mock_server` (compiled only for tests with the `mocking` feature enabled).
#[cfg(all(test, feature = "mocking"))]
mod tests {
    use super::*;
    use tokio::time::sleep;

    // The echo mock server starts and yields an address.
    #[tokio::test]
    async fn test_server_startup() {
        let _server_address = get_mock_address(echo_server).await;
    }

    // A plain `connect` to the echo server succeeds.
    #[tokio::test]
    async fn test_connection() {
        let server_address = get_mock_address(echo_server).await;
        let _socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
    }

    // An unparseable URL is rejected with `Error::UrlParse`.
    #[tokio::test]
    async fn test_bad_url() {
        let error: Result<Socketeer<EchoControlMessage, EchoControlMessage>, Error> =
            Socketeer::connect("Not a URL").await;
        assert!(matches!(error.unwrap_err(), Error::UrlParse { .. }));
    }

    // A typed message sent to the echo server is received back unchanged.
    #[tokio::test]
    async fn test_send_receive() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    // Asks the server to ping us (`EchoControlMessage::SendPing`), then checks that echo
    // still works and that the connection can still be closed cleanly.
    #[tokio::test]
    async fn test_ping_request() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let ping_request = EchoControlMessage::SendPing;
        socketeer.send(ping_request).await.unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(received_message, message);
        // NOTE(review): presumably gives the mock server time to complete its ping/pong
        // exchange (and possibly a keepalive) before closing; confirm against `mock_server`.
        sleep(Duration::from_millis(2200)).await;
        socketeer.close_connection().await.unwrap();
    }

    // Echo works both before and after `reconnect`, and the new connection closes cleanly.
    #[tokio::test]
    async fn test_reconnection() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer = socketeer.reconnect().await.unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer.close_connection().await.unwrap();
    }

    // After asking the server to close (`EchoControlMessage::Close`), reading reports
    // `Error::WebsocketClosed` and a further send also fails with `Error::WebsocketClosed`.
    #[tokio::test]
    async fn test_closed_socket() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        let close_request = EchoControlMessage::Close;
        socketeer.send(close_request.clone()).await.unwrap();
        let response = socketeer.next_message().await;
        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
        let send_result = socketeer.send(close_request).await;
        assert!(send_result.is_err());
        let error = send_result.unwrap_err();
        println!("Actual Error: {error:#?}");
        assert!(matches!(error, Error::WebsocketClosed));
    }

    // A client-initiated close completes without error.
    #[tokio::test]
    async fn test_close_request() {
        let server_address = get_mock_address(echo_server).await;
        let socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}",))
                .await
                .unwrap();
        socketeer.close_connection().await.unwrap();
    }

    // `connect_with` using the default options behaves like `connect`.
    #[tokio::test]
    async fn test_connect_with_default_options() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect_with(&format!("ws://{server_address}"), ConnectOptions::default())
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    // A raw text frame round-trips through the echo server byte-for-byte.
    #[tokio::test]
    async fn test_send_raw_receive_raw() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        let raw_text = r#"{"Message":"raw hello"}"#;
        socketeer
            .send_raw(Message::Text(raw_text.into()))
            .await
            .unwrap();
        let received = socketeer.next_raw_message().await.unwrap();
        assert_eq!(received, Message::Text(raw_text.into()));
    }

    // With `keepalive_interval: None` (the loop's no-keepalive path), echo still works.
    #[tokio::test]
    async fn test_disabled_keepalive() {
        let server_address = get_mock_address(echo_server).await;
        let options = ConnectOptions {
            keepalive_interval: None,
            ..ConnectOptions::default()
        };
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage> =
            Socketeer::connect_with(&format!("ws://{server_address}"), options)
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    // A handler's `on_connected` hook runs once during connect. It performs an auth exchange
    // with `auth_echo_server`, and typed traffic works afterwards.
    #[tokio::test]
    async fn test_handler_on_connected() {
        use serde::{Deserialize, Serialize};
        use std::sync::Arc;
        use tokio::sync::Mutex;

        // Shape of the auth server's reply to the auth request.
        #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
        struct AuthResponse {
            status: String,
        }

        // Counts successful `on_connected` calls through a shared counter.
        struct TestAuthHandler {
            connected_count: Arc<Mutex<u32>>,
        }

        impl ConnectionHandler for TestAuthHandler {
            async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> {
                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
                    .await?;
                let response: AuthResponse = ctx.recv_json().await?;
                assert_eq!(response.status, "authenticated");
                let mut count = self.connected_count.lock().await;
                *count += 1;
                Ok(())
            }
        }

        let connected_count = Arc::new(Mutex::new(0u32));
        let handler = TestAuthHandler {
            connected_count: connected_count.clone(),
        };

        let server_address = get_mock_address(auth_echo_server).await;
        let mut socketeer: Socketeer<EchoControlMessage, EchoControlMessage, TestAuthHandler> =
            Socketeer::connect_with_handler(
                &format!("ws://{server_address}"),
                ConnectOptions::default(),
                handler,
            )
            .await
            .unwrap();

        assert_eq!(*connected_count.lock().await, 1);

        let message = EchoControlMessage::Message("after auth".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(message, received);
    }

    // `reconnect` calls `on_disconnected` once and `on_connected` again, and echo works on
    // both the original and the new connection.
    #[tokio::test]
    async fn test_handler_reconnect() {
        use std::sync::Arc;
        use tokio::sync::Mutex;

        // Counts lifecycle callbacks through shared counters.
        struct ReconnectHandler {
            connected_count: Arc<Mutex<u32>>,
            disconnected_count: Arc<Mutex<u32>>,
        }

        impl ConnectionHandler for ReconnectHandler {
            async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> {
                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
                    .await?;
                let _response = ctx.recv_text().await?;
                let mut count = self.connected_count.lock().await;
                *count += 1;
                Ok(())
            }

            async fn on_disconnected(&mut self) {
                let mut count = self.disconnected_count.lock().await;
                *count += 1;
            }
        }

        let connected_count = Arc::new(Mutex::new(0u32));
        let disconnected_count = Arc::new(Mutex::new(0u32));
        let handler = ReconnectHandler {
            connected_count: connected_count.clone(),
            disconnected_count: disconnected_count.clone(),
        };

        let server_address = get_mock_address(auth_echo_server).await;
        let mut socketeer =
            Socketeer::<EchoControlMessage, EchoControlMessage, ReconnectHandler>::connect_with_handler(
                &format!("ws://{server_address}"),
                ConnectOptions::default(),
                handler,
            )
            .await
            .unwrap();

        assert_eq!(*connected_count.lock().await, 1);
        assert_eq!(*disconnected_count.lock().await, 0);

        // Echo on the original connection.
        let message = EchoControlMessage::Message("before reconnect".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(message, received);

        socketeer = socketeer.reconnect().await.unwrap();

        assert_eq!(*connected_count.lock().await, 2);
        assert_eq!(*disconnected_count.lock().await, 1);

        // Echo on the new connection.
        let message = EchoControlMessage::Message("after reconnect".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(message, received);

        socketeer.close_connection().await.unwrap();
    }
}