1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod codec;
4mod config;
5mod error;
6mod handler;
7#[cfg(feature = "mocking")]
8mod mock_server;
9
10#[cfg(feature = "msgpack")]
11pub use codec::MsgPackCodec;
12pub use codec::{Codec, JsonCodec, RawCodec};
13pub use config::ConnectOptions;
14pub use error::Error;
15pub use handler::{ConnectionHandler, HandshakeContext, NoopHandler};
16#[cfg(all(feature = "mocking", feature = "msgpack"))]
17pub use mock_server::msgpack_echo_server;
18#[cfg(feature = "mocking")]
19pub use mock_server::{EchoControlMessage, auth_echo_server, echo_server, get_mock_address};
20
21use bytes::Bytes;
22use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream};
23use std::time::Duration;
24use tokio::{
25 net::TcpStream,
26 select,
27 sync::{mpsc, oneshot},
28 time::sleep,
29};
30use tokio_tungstenite::{
31 MaybeTlsStream, WebSocketStream, connect_async,
32 tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
33};
34
35#[cfg(feature = "tracing")]
36use tracing::{debug, error, info, instrument, trace};
37use url::Url;
38
/// A frame queued for the background socket task, paired with a oneshot channel on which the
/// task reports the result of writing that frame.
#[derive(Debug)]
struct TxChannelPayload {
    /// Frame to write to the WebSocket sink.
    message: Message,
    /// Where the socket task sends the outcome of writing `message`.
    response_tx: oneshot::Sender<Result<(), Error>>,
}
44
45pub struct Socketeer<C: Codec, Handler = NoopHandler, const CHANNEL_SIZE: usize = 4>
59where
60 Handler: ConnectionHandler<C>,
61{
62 url: Url,
63 options: ConnectOptions,
64 codec: C,
65 handler: Handler,
66 receiver: mpsc::Receiver<Message>,
67 sender: mpsc::Sender<TxChannelPayload>,
68 socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
69}
70
71impl<C: Codec, Handler, const CHANNEL_SIZE: usize> std::fmt::Debug
72 for Socketeer<C, Handler, CHANNEL_SIZE>
73where
74 Handler: ConnectionHandler<C>,
75{
76 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77 f.debug_struct("Socketeer")
78 .field("url", &self.url)
79 .finish_non_exhaustive()
80 }
81}
82
83impl<C, const CHANNEL_SIZE: usize> Socketeer<C, NoopHandler, CHANNEL_SIZE>
84where
85 C: Codec + Default,
86{
87 #[cfg_attr(feature = "tracing", instrument)]
93 pub async fn connect(url: &str) -> Result<Self, Error> {
94 Self::connect_with(url, ConnectOptions::default()).await
95 }
96
97 #[cfg_attr(feature = "tracing", instrument(skip(options)))]
102 pub async fn connect_with(url: &str, options: ConnectOptions) -> Result<Self, Error> {
103 Socketeer::connect_with_codec(url, options, C::default(), NoopHandler).await
104 }
105}
106
107impl<C, Handler, const CHANNEL_SIZE: usize> Socketeer<C, Handler, CHANNEL_SIZE>
108where
109 C: Codec,
110 Handler: ConnectionHandler<C>,
111{
112 #[cfg_attr(feature = "tracing", instrument(skip(options, codec, handler)))]
122 pub async fn connect_with_codec(
123 url: &str,
124 options: ConnectOptions,
125 codec: C,
126 mut handler: Handler,
127 ) -> Result<Self, Error> {
128 let url = Url::parse(url).map_err(|source| Error::UrlParse {
129 url: url.to_string(),
130 source,
131 })?;
132
133 let request = options.build_request(&url)?;
134 #[allow(unused_variables)]
135 let (socket, response) = connect_async(request).await?;
136 #[cfg(feature = "tracing")]
137 debug!("Connection Successful, connection info: \n{:#?}", response);
138
139 let (mut sink, mut stream) = socket.split();
140 {
141 let mut ctx = HandshakeContext::new(&mut sink, &mut stream, &codec);
142 handler.on_connected(&mut ctx).await?;
143 }
144
145 let keepalive_interval = options.keepalive_interval;
146 let keepalive_message = options.custom_keepalive_message.clone();
147
148 let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
149 let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
150
151 let socket_handle = tokio::spawn(async move {
152 socket_loop_split(
153 tx_rx,
154 rx_tx,
155 sink,
156 stream,
157 keepalive_interval,
158 keepalive_message,
159 )
160 .await
161 });
162 Ok(Socketeer {
163 url,
164 options,
165 codec,
166 handler,
167 receiver: rx_rx,
168 sender: tx_tx,
169 socket_handle,
170 })
171 }
172
173 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
181 pub async fn next_message(&mut self) -> Result<C::Rx, Error> {
182 let Some(message) = self.receiver.recv().await else {
183 return Err(Error::WebsocketClosed);
184 };
185 #[cfg(feature = "tracing")]
186 trace!("Received message: {:?}", message);
187 self.codec.decode(&message)
188 }
189
190 #[cfg_attr(feature = "tracing", instrument(skip(self, message)))]
198 pub async fn send(&self, message: C::Tx) -> Result<(), Error> {
199 let encoded = self.codec.encode(&message)?;
200 self.send_raw(encoded).await
201 }
202
203 pub async fn next_raw_message(&mut self) -> Result<Message, Error> {
213 self.receiver.recv().await.ok_or(Error::WebsocketClosed)
214 }
215
216 pub async fn send_raw(&self, message: Message) -> Result<(), Error> {
224 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
225 self.sender
226 .send(TxChannelPayload {
227 message,
228 response_tx: tx,
229 })
230 .await
231 .map_err(|_| Error::WebsocketClosed)?;
232 match rx.await {
233 Ok(result) => result,
234 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
235 }
236 }
237
238 pub async fn reconnect(self) -> Result<Self, Error> {
249 let url = self.url.as_str().to_owned();
250 let options = self.options.clone();
251 let codec = self.codec;
252 let mut handler = self.handler;
253 #[cfg(feature = "tracing")]
254 info!("Reconnecting");
255 handler.on_disconnected().await;
256 match send_close(&self.sender).await {
258 Ok(()) => (),
259 #[allow(unused_variables)]
260 Err(e) => {
261 #[cfg(feature = "tracing")]
262 error!("Socket Loop already stopped: {}", e);
263 }
264 }
265 Self::connect_with_codec(&url, options, codec, handler).await
266 }
267
268 #[cfg_attr(feature = "tracing", instrument(skip(self)))]
274 pub async fn close_connection(self) -> Result<(), Error> {
275 #[cfg(feature = "tracing")]
276 debug!("Closing Connection");
277 send_close(&self.sender).await?;
278 match self.socket_handle.await {
279 Ok(result) => result,
280 Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
281 }
282 }
283}
284
285pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
286type SocketSink = SplitSink<WebSocketStreamType, Message>;
287type SocketStream = SplitStream<WebSocketStreamType>;
288
289enum LoopState {
290 Running,
291 Error(Error),
292 Closed,
293}
294
295async fn send_close(sender: &mpsc::Sender<TxChannelPayload>) -> Result<(), Error> {
297 let (tx, rx) = oneshot::channel::<Result<(), Error>>();
298 sender
299 .send(TxChannelPayload {
300 message: Message::Close(Some(CloseFrame {
301 code: tungstenite::protocol::frame::coding::CloseCode::Normal,
302 reason: Utf8Bytes::from_static("Closing Connection"),
303 })),
304 response_tx: tx,
305 })
306 .await
307 .map_err(|_| Error::WebsocketClosed)?;
308 match rx.await {
309 Ok(result) => result,
310 Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
311 }
312}
313
314#[cfg_attr(
315 feature = "tracing",
316 instrument(skip(keepalive_interval, keepalive_message))
317)]
318async fn socket_loop_split(
319 mut receiver: mpsc::Receiver<TxChannelPayload>,
320 mut sender: mpsc::Sender<Message>,
321 mut sink: SocketSink,
322 mut stream: SocketStream,
323 keepalive_interval: Option<Duration>,
324 keepalive_message: Option<Message>,
325) -> Result<(), Error> {
326 let mut state = LoopState::Running;
327 while matches!(state, LoopState::Running) {
328 state = if let Some(interval) = keepalive_interval {
329 select! {
330 outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
331 incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
332 () = sleep(interval) => send_keepalive(&mut sink, keepalive_message.as_ref()).await,
333 }
334 } else {
335 select! {
336 outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
337 incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
338 }
339 };
340 }
341 match state {
342 LoopState::Error(e) => Err(e),
343 LoopState::Closed => Ok(()),
344 LoopState::Running => unreachable!("We only exit when closed or errored"),
345 }
346}
347
348#[cfg_attr(feature = "tracing", instrument)]
349async fn send_socket_message(
350 message: Option<TxChannelPayload>,
351 sink: &mut SocketSink,
352) -> LoopState {
353 if let Some(message) = message {
354 #[cfg(feature = "tracing")]
355 debug!("Sending message: {:?}", message);
356 let send_result = sink.send(message.message).await.map_err(Error::from);
357 let socket_error = send_result.is_err();
358 match message.response_tx.send(send_result) {
359 Ok(()) => {
360 if socket_error {
361 LoopState::Error(Error::WebsocketClosed)
362 } else {
363 LoopState::Running
364 }
365 }
366 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
367 }
368 } else {
369 #[cfg(feature = "tracing")]
370 error!("Socketeer dropped without closing connection");
371 LoopState::Error(Error::SocketeerDroppedWithoutClosing)
372 }
373}
374
375#[cfg_attr(feature = "tracing", instrument)]
376async fn socket_message_received(
377 message: Option<Result<Message, tungstenite::Error>>,
378 sender: &mut mpsc::Sender<Message>,
379 sink: &mut SocketSink,
380) -> LoopState {
381 const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
382 match message {
383 Some(Ok(message)) => match message {
384 Message::Ping(_) => {
385 let send_result = sink
386 .send(Message::Pong(PONG_BYTES))
387 .await
388 .map_err(Error::from);
389 match send_result {
390 Ok(()) => LoopState::Running,
391 Err(e) => {
392 #[cfg(feature = "tracing")]
393 error!("Error sending Pong: {:?}", e);
394 LoopState::Error(e)
395 }
396 }
397 }
398 Message::Close(_) => {
399 let close_result = sink.close().await;
400 match close_result {
401 Ok(()) => LoopState::Closed,
402 Err(e) => {
403 #[cfg(feature = "tracing")]
404 error!("Error sending Close: {:?}", e);
405 LoopState::Error(Error::from(e))
406 }
407 }
408 }
409 Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
410 Ok(()) => LoopState::Running,
411 Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
412 },
413 _ => LoopState::Running,
414 },
415 Some(Err(e)) => {
416 #[cfg(feature = "tracing")]
417 error!("Error receiving message: {:?}", e);
418 LoopState::Error(Error::WebsocketError(e))
419 }
420 None => {
421 #[cfg(feature = "tracing")]
422 info!("Websocket Closed, closing rx channel");
423 LoopState::Error(Error::WebsocketClosed)
424 }
425 }
426}
427
428#[cfg_attr(feature = "tracing", instrument)]
429async fn send_keepalive(sink: &mut SocketSink, custom_message: Option<&Message>) -> LoopState {
430 let message = if let Some(custom) = custom_message {
431 #[cfg(feature = "tracing")]
432 info!("Timeout waiting for message, sending custom keepalive");
433 custom.clone()
434 } else {
435 #[cfg(feature = "tracing")]
436 info!("Timeout waiting for message, sending Ping");
437 Message::Ping(Bytes::new())
438 };
439 let result = sink.send(message).await.map_err(Error::from);
440 match result {
441 Ok(()) => LoopState::Running,
442 Err(e) => {
443 #[cfg(feature = "tracing")]
444 error!("Error sending keepalive: {:?}", e);
445 LoopState::Error(e)
446 }
447 }
448}
449
/// Integration tests that drive `Socketeer` against the mock WebSocket servers from the
/// `mock_server` module. They are compiled only for tests with the `mocking` feature enabled.
/// The MessagePack tests additionally require the `msgpack` feature.
#[cfg(all(test, feature = "mocking"))]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// JSON codec that sends and receives the echo server's `EchoControlMessage`.
    type EchoJson = JsonCodec<EchoControlMessage, EchoControlMessage>;

    /// Starting the echo mock server through `get_mock_address` completes without panicking.
    #[tokio::test]
    async fn test_server_startup() {
        let _server_address = get_mock_address(echo_server).await;
    }

    /// `Socketeer::connect` succeeds against the echo server.
    #[tokio::test]
    async fn test_connection() {
        let server_address = get_mock_address(echo_server).await;
        let _socketeer: Socketeer<EchoJson> = Socketeer::connect(&format!("ws://{server_address}"))
            .await
            .unwrap();
    }

    /// A string that is not a URL is rejected with `Error::UrlParse`.
    #[tokio::test]
    async fn test_bad_url() {
        let error: Result<Socketeer<EchoJson>, Error> = Socketeer::connect("Not a URL").await;
        assert!(matches!(error.unwrap_err(), Error::UrlParse { .. }));
    }

    /// A message sent through the JSON codec comes back unchanged from the echo server.
    #[tokio::test]
    async fn test_send_receive() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    /// After a `SendPing` control message, a normal message is still echoed and the
    /// connection can later be closed cleanly.
    #[tokio::test]
    async fn test_ping_request() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        // Presumably asks the mock server to send a WebSocket Ping (TODO confirm in mock_server).
        let ping_request = EchoControlMessage::SendPing;
        socketeer.send(ping_request).await.unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(received_message, message);
        // NOTE(review): idles for 2.2 s before closing, presumably so the ping exchange and/or
        // a default keepalive happen first; the intent is not visible here — confirm.
        sleep(Duration::from_millis(2200)).await;
        socketeer.close_connection().await.unwrap();
    }

    /// Echo works both before and after `reconnect`, and the new connection closes cleanly.
    #[tokio::test]
    async fn test_reconnection() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer = socketeer.reconnect().await.unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
        socketeer.close_connection().await.unwrap();
    }

    /// After the `Close` control message the connection ends. The next read fails with
    /// `Error::WebsocketClosed`, and so does a later send.
    #[tokio::test]
    async fn test_closed_socket() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        let close_request = EchoControlMessage::Close;
        socketeer.send(close_request.clone()).await.unwrap();
        let response = socketeer.next_message().await;
        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
        // The read above only fails once the background task has stopped, so this send is
        // rejected rather than delivered.
        let send_result = socketeer.send(close_request).await;
        assert!(send_result.is_err());
        let error = send_result.unwrap_err();
        println!("Actual Error: {error:#?}");
        assert!(matches!(error, Error::WebsocketClosed));
    }

    /// A client-initiated `close_connection` completes without error.
    #[tokio::test]
    async fn test_close_request() {
        let server_address = get_mock_address(echo_server).await;
        let socketeer: Socketeer<EchoJson> = Socketeer::connect(&format!("ws://{server_address}"))
            .await
            .unwrap();
        socketeer.close_connection().await.unwrap();
    }

    /// `connect_with` and explicit default options give a working connection.
    #[tokio::test]
    async fn test_connect_with_default_options() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect_with(&format!("ws://{server_address}"), ConnectOptions::default())
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    /// With `RawCodec`, frames are sent and received as `Message` values unchanged.
    #[tokio::test]
    async fn test_raw_codec_message_roundtrip() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<RawCodec> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        // Valid `EchoControlMessage` JSON, presumably so the echo server accepts and echoes it.
        let raw_text = r#"{"Message":"raw hello"}"#;
        socketeer
            .send(Message::Text(raw_text.into()))
            .await
            .unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(received, Message::Text(raw_text.into()));
    }

    /// Disabling keepalives (`keepalive_interval: None`) still gives a working connection.
    #[tokio::test]
    async fn test_disabled_keepalive() {
        let server_address = get_mock_address(echo_server).await;
        let options = ConnectOptions {
            keepalive_interval: None,
            ..ConnectOptions::default()
        };
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect_with(&format!("ws://{server_address}"), options)
                .await
                .unwrap();
        let message = EchoControlMessage::Message("Hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received_message = socketeer.next_message().await.unwrap();
        assert_eq!(message, received_message);
    }

    /// A handler's `on_connected` authenticates against `auth_echo_server` over the raw socket
    /// before the client is returned. It runs exactly once, and normal echo traffic works
    /// afterwards.
    #[tokio::test]
    async fn test_handler_on_connected() {
        use serde::{Deserialize, Serialize};
        use std::sync::Arc;
        use tokio::sync::Mutex;

        // Shape of the auth server's reply to the handshake message.
        #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
        struct AuthResponse {
            status: String,
        }

        // Counts successful handshakes so the test can observe how often `on_connected` ran.
        struct TestAuthHandler {
            connected_count: Arc<Mutex<u32>>,
        }

        impl<C: Codec> ConnectionHandler<C> for TestAuthHandler {
            async fn on_connected(
                &mut self,
                ctx: &mut HandshakeContext<'_, C>,
            ) -> Result<(), Error> {
                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
                    .await?;
                let text = ctx.recv_text().await?;
                let response: AuthResponse = serde_json::from_str(&text).unwrap();
                assert_eq!(response.status, "authenticated");
                let mut count = self.connected_count.lock().await;
                *count += 1;
                Ok(())
            }
        }

        let connected_count = Arc::new(Mutex::new(0u32));
        let handler = TestAuthHandler {
            connected_count: connected_count.clone(),
        };

        let server_address = get_mock_address(auth_echo_server).await;
        let mut socketeer: Socketeer<EchoJson, TestAuthHandler> = Socketeer::connect_with_codec(
            &format!("ws://{server_address}"),
            ConnectOptions::default(),
            JsonCodec::new(),
            handler,
        )
        .await
        .unwrap();

        assert_eq!(*connected_count.lock().await, 1);

        let message = EchoControlMessage::Message("after auth".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(message, received);
    }

    /// `reconnect` calls `on_disconnected` once and runs `on_connected` again, after which the
    /// new connection echoes messages.
    #[tokio::test]
    async fn test_handler_reconnect() {
        use std::sync::Arc;
        use tokio::sync::Mutex;

        // Counts both lifecycle hooks so the test can check how often each ran.
        struct ReconnectHandler {
            connected_count: Arc<Mutex<u32>>,
            disconnected_count: Arc<Mutex<u32>>,
        }

        impl<C: Codec> ConnectionHandler<C> for ReconnectHandler {
            async fn on_connected(
                &mut self,
                ctx: &mut HandshakeContext<'_, C>,
            ) -> Result<(), Error> {
                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
                    .await?;
                let _response = ctx.recv_text().await?;
                let mut count = self.connected_count.lock().await;
                *count += 1;
                Ok(())
            }

            async fn on_disconnected(&mut self) {
                let mut count = self.disconnected_count.lock().await;
                *count += 1;
            }
        }

        let connected_count = Arc::new(Mutex::new(0u32));
        let disconnected_count = Arc::new(Mutex::new(0u32));
        let handler = ReconnectHandler {
            connected_count: connected_count.clone(),
            disconnected_count: disconnected_count.clone(),
        };

        let server_address = get_mock_address(auth_echo_server).await;
        let mut socketeer = Socketeer::<EchoJson, ReconnectHandler>::connect_with_codec(
            &format!("ws://{server_address}"),
            ConnectOptions::default(),
            JsonCodec::new(),
            handler,
        )
        .await
        .unwrap();

        assert_eq!(*connected_count.lock().await, 1);
        assert_eq!(*disconnected_count.lock().await, 0);

        let message = EchoControlMessage::Message("before reconnect".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(message, received);

        socketeer = socketeer.reconnect().await.unwrap();

        assert_eq!(*connected_count.lock().await, 2);
        assert_eq!(*disconnected_count.lock().await, 1);

        let message = EchoControlMessage::Message("after reconnect".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(message, received);

        socketeer.close_connection().await.unwrap();
    }

    /// MessagePack round trip against the MessagePack echo server.
    #[cfg(feature = "msgpack")]
    #[tokio::test]
    async fn test_msgpack_send_receive() {
        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;

        let server_address = get_mock_address(msgpack_echo_server).await;
        let mut socketeer: Socketeer<EchoMsgPack> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        let message = EchoControlMessage::Message("msgpack hello".to_string());
        socketeer.send(message.clone()).await.unwrap();
        let received = socketeer.next_message().await.unwrap();
        assert_eq!(message, received);
        socketeer.close_connection().await.unwrap();
    }

    /// `HandshakeContext::send`/`recv` encode and decode with the client's codec during the
    /// handshake.
    #[tokio::test]
    async fn test_handler_uses_codec_driven_send_recv() {
        struct TypedHandshakeHandler;

        impl ConnectionHandler<EchoJson> for TypedHandshakeHandler {
            async fn on_connected(
                &mut self,
                ctx: &mut HandshakeContext<'_, EchoJson>,
            ) -> Result<(), Error> {
                ctx.send(&EchoControlMessage::Message("handshake".into()))
                    .await?;
                let echoed = ctx.recv().await?;
                assert_eq!(echoed, EchoControlMessage::Message("handshake".into()));
                Ok(())
            }
        }

        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoJson, TypedHandshakeHandler> =
            Socketeer::connect_with_codec(
                &format!("ws://{server_address}"),
                ConnectOptions::default(),
                JsonCodec::new(),
                TypedHandshakeHandler,
            )
            .await
            .unwrap();

        let message = EchoControlMessage::Message("after handshake".into());
        socketeer.send(message.clone()).await.unwrap();
        assert_eq!(socketeer.next_message().await.unwrap(), message);
        socketeer.close_connection().await.unwrap();
    }

    /// During the handshake, asking the echo server to close makes `HandshakeContext::recv`
    /// return `Error::WebsocketClosed`. The handler still returns `Ok`, so connecting succeeds.
    #[tokio::test]
    async fn test_handshake_recv_close_with_raw_codec() {
        struct CloseExpecting;

        impl ConnectionHandler<RawCodec> for CloseExpecting {
            async fn on_connected(
                &mut self,
                ctx: &mut HandshakeContext<'_, RawCodec>,
            ) -> Result<(), Error> {
                // Presumably the JSON form of `EchoControlMessage::Close` (serde unit variant).
                ctx.send(&Message::Text(r#""Close""#.into())).await?;
                let err = ctx.recv().await.unwrap_err();
                assert!(matches!(err, Error::WebsocketClosed));
                Ok(())
            }
        }

        let server_address = get_mock_address(echo_server).await;
        let _socketeer: Socketeer<RawCodec, CloseExpecting> = Socketeer::connect_with_codec(
            &format!("ws://{server_address}"),
            ConnectOptions::default(),
            RawCodec::new(),
            CloseExpecting,
        )
        .await
        .unwrap();
    }

    /// Connecting with `extra_headers` set still gives a working connection. This test only
    /// checks the echo round trip; it does not verify the server saw `X-Test-Header`.
    #[tokio::test]
    async fn test_extra_headers_used() {
        let server_address = get_mock_address(echo_server).await;
        let mut headers = tokio_tungstenite::tungstenite::http::HeaderMap::new();
        headers.insert("X-Test-Header", "socketeer".parse().unwrap());
        let options = ConnectOptions {
            extra_headers: headers,
            ..ConnectOptions::default()
        };
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect_with(&format!("ws://{server_address}"), options)
                .await
                .unwrap();
        let message = EchoControlMessage::Message("hi".into());
        socketeer.send(message.clone()).await.unwrap();
        assert_eq!(socketeer.next_message().await.unwrap(), message);
        socketeer.close_connection().await.unwrap();
    }

    /// With a wrong token the auth server's reply contains "error". This handler accepts that
    /// and returns `Ok`, so connecting still succeeds.
    #[tokio::test]
    async fn test_auth_handler_bad_token() {
        struct BadTokenHandler;

        impl<C: Codec> ConnectionHandler<C> for BadTokenHandler {
            async fn on_connected(
                &mut self,
                ctx: &mut HandshakeContext<'_, C>,
            ) -> Result<(), Error> {
                ctx.send_text(r#"{"action":"auth","token":"WRONG"}"#)
                    .await?;
                let resp = ctx.recv_text().await?;
                assert!(resp.contains("error"));
                Ok(())
            }
        }

        let server_address = get_mock_address(auth_echo_server).await;
        let _socketeer: Socketeer<EchoJson, BadTokenHandler> = Socketeer::connect_with_codec(
            &format!("ws://{server_address}"),
            ConnectOptions::default(),
            JsonCodec::new(),
            BadTokenHandler,
        )
        .await
        .unwrap();
    }

    /// MessagePack variant of the ping test: after `SendPing`, a normal message still round-trips.
    #[cfg(feature = "msgpack")]
    #[tokio::test]
    async fn test_msgpack_send_ping() {
        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;

        let server_address = get_mock_address(msgpack_echo_server).await;
        let mut socketeer: Socketeer<EchoMsgPack> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        socketeer.send(EchoControlMessage::SendPing).await.unwrap();
        let message = EchoControlMessage::Message("after ping".into());
        socketeer.send(message.clone()).await.unwrap();
        assert_eq!(socketeer.next_message().await.unwrap(), message);
        socketeer.close_connection().await.unwrap();
    }

    /// MessagePack variant of the close test: after `Close`, reads fail with
    /// `Error::WebsocketClosed`.
    #[cfg(feature = "msgpack")]
    #[tokio::test]
    async fn test_msgpack_close_request() {
        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;

        let server_address = get_mock_address(msgpack_echo_server).await;
        let mut socketeer: Socketeer<EchoMsgPack> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        socketeer.send(EchoControlMessage::Close).await.unwrap();
        let result = socketeer.next_message().await;
        assert!(matches!(result.unwrap_err(), Error::WebsocketClosed));
    }

    /// The `Debug` output is labelled `Socketeer` and includes the `url` field.
    #[tokio::test]
    async fn test_socketeer_debug_format() {
        let server_address = get_mock_address(echo_server).await;
        let socketeer: Socketeer<EchoJson> = Socketeer::connect(&format!("ws://{server_address}"))
            .await
            .unwrap();
        let formatted = format!("{socketeer:?}");
        assert!(formatted.starts_with("Socketeer"));
        assert!(formatted.contains("url"));
    }

    /// `send_raw` and `next_raw_message` bypass the codec even on a client typed with `EchoJson`.
    #[tokio::test]
    async fn test_send_raw_next_raw_message() {
        let server_address = get_mock_address(echo_server).await;
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect(&format!("ws://{server_address}"))
                .await
                .unwrap();
        let raw_text = r#"{"Message":"raw recv"}"#;
        socketeer
            .send_raw(Message::Text(raw_text.into()))
            .await
            .unwrap();
        let frame = socketeer.next_raw_message().await.unwrap();
        assert_eq!(frame, Message::Text(raw_text.into()));
        socketeer.close_connection().await.unwrap();
    }

    /// During the handshake, a MessagePack payload sent with `send_binary` comes back as a
    /// binary frame readable with `recv_raw`.
    #[cfg(feature = "msgpack")]
    #[tokio::test]
    async fn test_handshake_send_binary_recv_raw() {
        struct BinaryHandshake;

        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;

        impl ConnectionHandler<EchoMsgPack> for BinaryHandshake {
            async fn on_connected(
                &mut self,
                ctx: &mut HandshakeContext<'_, EchoMsgPack>,
            ) -> Result<(), Error> {
                let payload =
                    rmp_serde::to_vec_named(&EchoControlMessage::Message("binary".into())).unwrap();
                ctx.send_binary(payload).await?;
                let echo = ctx.recv_raw().await?;
                assert!(matches!(echo, Message::Binary(_)));
                Ok(())
            }
        }

        let server_address = get_mock_address(msgpack_echo_server).await;
        let socketeer: Socketeer<EchoMsgPack, BinaryHandshake> = Socketeer::connect_with_codec(
            &format!("ws://{server_address}"),
            ConnectOptions::default(),
            MsgPackCodec::new(),
            BinaryHandshake,
        )
        .await
        .unwrap();
        socketeer.close_connection().await.unwrap();
    }

    /// `HandshakeContext::recv_text` rejects a binary frame with `Error::UnexpectedMessageType`.
    #[cfg(feature = "msgpack")]
    #[tokio::test]
    async fn test_handshake_recv_text_rejects_binary() {
        struct ExpectsTextOnBinary;

        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;

        impl ConnectionHandler<EchoMsgPack> for ExpectsTextOnBinary {
            async fn on_connected(
                &mut self,
                ctx: &mut HandshakeContext<'_, EchoMsgPack>,
            ) -> Result<(), Error> {
                let payload =
                    rmp_serde::to_vec_named(&EchoControlMessage::Message("hi".into())).unwrap();
                ctx.send_binary(payload).await?;
                let err = ctx.recv_text().await.unwrap_err();
                assert!(matches!(err, Error::UnexpectedMessageType(_)));
                Ok(())
            }
        }

        let server_address = get_mock_address(msgpack_echo_server).await;
        let socketeer: Socketeer<EchoMsgPack, ExpectsTextOnBinary> = Socketeer::connect_with_codec(
            &format!("ws://{server_address}"),
            ConnectOptions::default(),
            MsgPackCodec::new(),
            ExpectsTextOnBinary,
        )
        .await
        .unwrap();
        socketeer.close_connection().await.unwrap();
    }

    /// A binary custom keepalive, sent after every 100 ms of idling, does not disturb the
    /// connection. After ~350 ms idle, a normal message still round-trips.
    #[tokio::test]
    async fn test_binary_custom_keepalive() {
        let server_address = get_mock_address(echo_server).await;
        let options = ConnectOptions {
            keepalive_interval: Some(Duration::from_millis(100)),
            custom_keepalive_message: Some(Message::Binary(Bytes::from_static(b"keepalive"))),
            ..ConnectOptions::default()
        };
        let mut socketeer: Socketeer<EchoJson> =
            Socketeer::connect_with(&format!("ws://{server_address}"), options)
                .await
                .unwrap();

        // Long enough for about three keepalive intervals to elapse. The assertion below relies
        // on no keepalive frame reaching `next_message` first; presumably `echo_server` does
        // not echo these binary frames — verify against mock_server.
        sleep(Duration::from_millis(350)).await;

        let message = EchoControlMessage::Message("post-keepalive".into());
        socketeer.send(message.clone()).await.unwrap();
        assert_eq!(socketeer.next_message().await.unwrap(), message);
        socketeer.close_connection().await.unwrap();
    }
}