Skip to main content

socketeer/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod codec;
4mod config;
5mod error;
6mod handler;
7#[cfg(feature = "mocking")]
8mod mock_server;
9
10#[cfg(feature = "msgpack")]
11pub use codec::MsgPackCodec;
12pub use codec::{Codec, JsonCodec, RawCodec};
13pub use config::ConnectOptions;
14pub use error::Error;
15pub use handler::{ConnectionHandler, HandshakeContext, NoopHandler};
16#[cfg(all(feature = "mocking", feature = "msgpack"))]
17pub use mock_server::msgpack_echo_server;
18#[cfg(feature = "mocking")]
19pub use mock_server::{EchoControlMessage, auth_echo_server, echo_server, get_mock_address};
20
21use bytes::Bytes;
22use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream};
23use std::time::Duration;
24use tokio::{
25    net::TcpStream,
26    select,
27    sync::{mpsc, oneshot},
28    time::sleep,
29};
30use tokio_tungstenite::{
31    MaybeTlsStream, WebSocketStream, connect_async,
32    tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
33};
34
35#[cfg(feature = "tracing")]
36use tracing::{debug, error, info, instrument, trace};
37use url::Url;
38
39#[derive(Debug)]
40struct TxChannelPayload {
41    message: Message,
42    response_tx: oneshot::Sender<Result<(), Error>>,
43}
44
45/// A WebSocket client that manages the connection to a WebSocket server.
46/// The client can send and receive messages, and will transparently handle protocol messages.
47///
48/// # Type Parameters
49///
50/// - `C`: A [`Codec`] that defines the connection's `Tx` and `Rx` types and how
51///   they map to WebSocket frames. Use [`JsonCodec`] for the common case,
52///   [`MsgPackCodec`] (behind the `msgpack` feature) for `MessagePack`, or
53///   [`RawCodec`] for direct [`Message`] access.
54/// - `Handler`: A [`ConnectionHandler`] for lifecycle hooks (auth, subscriptions).
55///   Defaults to [`NoopHandler`] for the simple case.
56/// - `CHANNEL_SIZE`: The size of the internal channels used to communicate between
57///   the task managing the WebSocket connection and the client.
58pub struct Socketeer<C: Codec, Handler = NoopHandler, const CHANNEL_SIZE: usize = 4>
59where
60    Handler: ConnectionHandler<C>,
61{
62    url: Url,
63    options: ConnectOptions,
64    codec: C,
65    handler: Handler,
66    receiver: mpsc::Receiver<Message>,
67    sender: mpsc::Sender<TxChannelPayload>,
68    socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
69}
70
71impl<C: Codec, Handler, const CHANNEL_SIZE: usize> std::fmt::Debug
72    for Socketeer<C, Handler, CHANNEL_SIZE>
73where
74    Handler: ConnectionHandler<C>,
75{
76    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
77        f.debug_struct("Socketeer")
78            .field("url", &self.url)
79            .finish_non_exhaustive()
80    }
81}
82
83impl<C, const CHANNEL_SIZE: usize> Socketeer<C, NoopHandler, CHANNEL_SIZE>
84where
85    C: Codec + Default,
86{
87    /// Create a `Socketeer` connected to the provided URL with default options.
88    /// Once connected, Socketeer manages the underlying WebSocket connection, transparently handling protocol messages.
89    /// # Errors
90    /// - If the URL cannot be parsed
91    /// - If the WebSocket connection to the requested URL fails
92    #[cfg_attr(feature = "tracing", instrument)]
93    pub async fn connect(url: &str) -> Result<Self, Error> {
94        Self::connect_with(url, ConnectOptions::default()).await
95    }
96
97    /// Create a `Socketeer` connected to the provided URL with custom connection options.
98    /// # Errors
99    /// - If the URL cannot be parsed
100    /// - If the WebSocket connection to the requested URL fails
101    #[cfg_attr(feature = "tracing", instrument(skip(options)))]
102    pub async fn connect_with(url: &str, options: ConnectOptions) -> Result<Self, Error> {
103        Socketeer::connect_with_codec(url, options, C::default(), NoopHandler).await
104    }
105}
106
107impl<C, Handler, const CHANNEL_SIZE: usize> Socketeer<C, Handler, CHANNEL_SIZE>
108where
109    C: Codec,
110    Handler: ConnectionHandler<C>,
111{
112    /// Create a `Socketeer` with an explicit codec and [`ConnectionHandler`].
113    ///
114    /// The handler's [`ConnectionHandler::on_connected`] method is called after the
115    /// WebSocket upgrade completes, before the socket loop starts. This is where
116    /// you should perform authentication handshakes and initial subscriptions.
117    /// # Errors
118    /// - If the URL cannot be parsed
119    /// - If the WebSocket connection to the requested URL fails
120    /// - If the handler's `on_connected` returns an error
121    #[cfg_attr(feature = "tracing", instrument(skip(options, codec, handler)))]
122    pub async fn connect_with_codec(
123        url: &str,
124        options: ConnectOptions,
125        codec: C,
126        mut handler: Handler,
127    ) -> Result<Self, Error> {
128        let url = Url::parse(url).map_err(|source| Error::UrlParse {
129            url: url.to_string(),
130            source,
131        })?;
132
133        let request = options.build_request(&url)?;
134        #[allow(unused_variables)]
135        let (socket, response) = connect_async(request).await?;
136        #[cfg(feature = "tracing")]
137        debug!("Connection Successful, connection info: \n{:#?}", response);
138
139        let (mut sink, mut stream) = socket.split();
140        {
141            let mut ctx = HandshakeContext::new(&mut sink, &mut stream, &codec);
142            handler.on_connected(&mut ctx).await?;
143        }
144
145        let keepalive_interval = options.keepalive_interval;
146        let keepalive_message = options.custom_keepalive_message.clone();
147
148        let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
149        let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
150
151        let socket_handle = tokio::spawn(async move {
152            socket_loop_split(
153                tx_rx,
154                rx_tx,
155                sink,
156                stream,
157                keepalive_interval,
158                keepalive_message,
159            )
160            .await
161        });
162        Ok(Socketeer {
163            url,
164            options,
165            codec,
166            handler,
167            receiver: rx_rx,
168            sender: tx_tx,
169            socket_handle,
170        })
171    }
172
173    /// Wait for the next message from the WebSocket connection, decoded by the
174    /// connection's [`Codec`].
175    ///
176    /// # Errors
177    ///
178    /// - If the WebSocket connection is closed or otherwise errored
179    /// - If the codec fails to decode the frame
180    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
181    pub async fn next_message(&mut self) -> Result<C::Rx, Error> {
182        let Some(message) = self.receiver.recv().await else {
183            return Err(Error::WebsocketClosed);
184        };
185        #[cfg(feature = "tracing")]
186        trace!("Received message: {:?}", message);
187        self.codec.decode(&message)
188    }
189
190    /// Encode and send a message via the connection's [`Codec`].
191    /// This function will wait for the message to be sent before returning.
192    ///
193    /// # Errors
194    ///
195    /// - If the codec fails to encode the value
196    /// - If the WebSocket connection is closed, or otherwise errored
197    #[cfg_attr(feature = "tracing", instrument(skip(self, message)))]
198    pub async fn send(&self, message: C::Tx) -> Result<(), Error> {
199        let encoded = self.codec.encode(&message)?;
200        self.send_raw(encoded).await
201    }
202
203    /// Receive the next raw [`Message`] from the WebSocket connection without
204    /// running the codec.
205    ///
206    /// Useful when you need to inspect the underlying frame type or handle a
207    /// message that the codec would reject.
208    ///
209    /// # Errors
210    ///
211    /// - If the WebSocket connection is closed or otherwise errored
212    pub async fn next_raw_message(&mut self) -> Result<Message, Error> {
213        self.receiver.recv().await.ok_or(Error::WebsocketClosed)
214    }
215
216    /// Send a raw [`Message`] to the WebSocket connection without running the codec.
217    ///
218    /// Useful for sending control frames or pre-encoded payloads.
219    ///
220    /// # Errors
221    ///
222    /// - If the WebSocket connection is closed, or otherwise errored
223    pub async fn send_raw(&self, message: Message) -> Result<(), Error> {
224        let (tx, rx) = oneshot::channel::<Result<(), Error>>();
225        self.sender
226            .send(TxChannelPayload {
227                message,
228                response_tx: tx,
229            })
230            .await
231            .map_err(|_| Error::WebsocketClosed)?;
232        match rx.await {
233            Ok(result) => result,
234            Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
235        }
236    }
237
238    /// Consume self, closing down any remaining send/receive, and return a new Socketeer instance if successful.
239    /// This function attempts to close the connection gracefully before returning,
240    /// but will not return an error if the connection is already closed,
241    /// as its intended use is to re-establish a failed connection.
242    ///
243    /// The handler's [`ConnectionHandler::on_disconnected`] is called before closing,
244    /// and [`ConnectionHandler::on_connected`] is called after reconnecting.
245    /// # Errors
246    /// - If a new connection cannot be established
247    /// - If the handler's `on_connected` returns an error
248    pub async fn reconnect(self) -> Result<Self, Error> {
249        let url = self.url.as_str().to_owned();
250        let options = self.options.clone();
251        let codec = self.codec;
252        let mut handler = self.handler;
253        #[cfg(feature = "tracing")]
254        info!("Reconnecting");
255        handler.on_disconnected().await;
256        // Attempt graceful close, but don't fail if already closed
257        match send_close(&self.sender).await {
258            Ok(()) => (),
259            #[allow(unused_variables)]
260            Err(e) => {
261                #[cfg(feature = "tracing")]
262                error!("Socket Loop already stopped: {}", e);
263            }
264        }
265        Self::connect_with_codec(&url, options, codec, handler).await
266    }
267
268    /// Close the WebSocket connection gracefully.
269    /// This function will wait for the connection to close before returning.
270    /// # Errors
271    /// - If the WebSocket connection is already closed
272    /// - If the WebSocket connection cannot be closed
273    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
274    pub async fn close_connection(self) -> Result<(), Error> {
275        #[cfg(feature = "tracing")]
276        debug!("Closing Connection");
277        send_close(&self.sender).await?;
278        match self.socket_handle.await {
279            Ok(result) => result,
280            Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
281        }
282    }
283}
284
285pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
286type SocketSink = SplitSink<WebSocketStreamType, Message>;
287type SocketStream = SplitStream<WebSocketStreamType>;
288
289enum LoopState {
290    Running,
291    Error(Error),
292    Closed,
293}
294
295/// Send a close frame via the tx channel and wait for confirmation.
296async fn send_close(sender: &mpsc::Sender<TxChannelPayload>) -> Result<(), Error> {
297    let (tx, rx) = oneshot::channel::<Result<(), Error>>();
298    sender
299        .send(TxChannelPayload {
300            message: Message::Close(Some(CloseFrame {
301                code: tungstenite::protocol::frame::coding::CloseCode::Normal,
302                reason: Utf8Bytes::from_static("Closing Connection"),
303            })),
304            response_tx: tx,
305        })
306        .await
307        .map_err(|_| Error::WebsocketClosed)?;
308    match rx.await {
309        Ok(result) => result,
310        Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
311    }
312}
313
314#[cfg_attr(
315    feature = "tracing",
316    instrument(skip(keepalive_interval, keepalive_message))
317)]
318async fn socket_loop_split(
319    mut receiver: mpsc::Receiver<TxChannelPayload>,
320    mut sender: mpsc::Sender<Message>,
321    mut sink: SocketSink,
322    mut stream: SocketStream,
323    keepalive_interval: Option<Duration>,
324    keepalive_message: Option<Message>,
325) -> Result<(), Error> {
326    let mut state = LoopState::Running;
327    while matches!(state, LoopState::Running) {
328        state = if let Some(interval) = keepalive_interval {
329            select! {
330                outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
331                incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
332                () = sleep(interval) => send_keepalive(&mut sink, keepalive_message.as_ref()).await,
333            }
334        } else {
335            select! {
336                outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
337                incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
338            }
339        };
340    }
341    match state {
342        LoopState::Error(e) => Err(e),
343        LoopState::Closed => Ok(()),
344        LoopState::Running => unreachable!("We only exit when closed or errored"),
345    }
346}
347
348#[cfg_attr(feature = "tracing", instrument)]
349async fn send_socket_message(
350    message: Option<TxChannelPayload>,
351    sink: &mut SocketSink,
352) -> LoopState {
353    if let Some(message) = message {
354        #[cfg(feature = "tracing")]
355        debug!("Sending message: {:?}", message);
356        let send_result = sink.send(message.message).await.map_err(Error::from);
357        let socket_error = send_result.is_err();
358        match message.response_tx.send(send_result) {
359            Ok(()) => {
360                if socket_error {
361                    LoopState::Error(Error::WebsocketClosed)
362                } else {
363                    LoopState::Running
364                }
365            }
366            Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
367        }
368    } else {
369        #[cfg(feature = "tracing")]
370        error!("Socketeer dropped without closing connection");
371        LoopState::Error(Error::SocketeerDroppedWithoutClosing)
372    }
373}
374
375#[cfg_attr(feature = "tracing", instrument)]
376async fn socket_message_received(
377    message: Option<Result<Message, tungstenite::Error>>,
378    sender: &mut mpsc::Sender<Message>,
379    sink: &mut SocketSink,
380) -> LoopState {
381    const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
382    match message {
383        Some(Ok(message)) => match message {
384            Message::Ping(_) => {
385                let send_result = sink
386                    .send(Message::Pong(PONG_BYTES))
387                    .await
388                    .map_err(Error::from);
389                match send_result {
390                    Ok(()) => LoopState::Running,
391                    Err(e) => {
392                        #[cfg(feature = "tracing")]
393                        error!("Error sending Pong: {:?}", e);
394                        LoopState::Error(e)
395                    }
396                }
397            }
398            Message::Close(_) => {
399                let close_result = sink.close().await;
400                match close_result {
401                    Ok(()) => LoopState::Closed,
402                    Err(e) => {
403                        #[cfg(feature = "tracing")]
404                        error!("Error sending Close: {:?}", e);
405                        LoopState::Error(Error::from(e))
406                    }
407                }
408            }
409            Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
410                Ok(()) => LoopState::Running,
411                Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
412            },
413            _ => LoopState::Running,
414        },
415        Some(Err(e)) => {
416            #[cfg(feature = "tracing")]
417            error!("Error receiving message: {:?}", e);
418            LoopState::Error(Error::WebsocketError(e))
419        }
420        None => {
421            #[cfg(feature = "tracing")]
422            info!("Websocket Closed, closing rx channel");
423            LoopState::Error(Error::WebsocketClosed)
424        }
425    }
426}
427
428#[cfg_attr(feature = "tracing", instrument)]
429async fn send_keepalive(sink: &mut SocketSink, custom_message: Option<&Message>) -> LoopState {
430    let message = if let Some(custom) = custom_message {
431        #[cfg(feature = "tracing")]
432        info!("Timeout waiting for message, sending custom keepalive");
433        custom.clone()
434    } else {
435        #[cfg(feature = "tracing")]
436        info!("Timeout waiting for message, sending Ping");
437        Message::Ping(Bytes::new())
438    };
439    let result = sink.send(message).await.map_err(Error::from);
440    match result {
441        Ok(()) => LoopState::Running,
442        Err(e) => {
443            #[cfg(feature = "tracing")]
444            error!("Error sending keepalive: {:?}", e);
445            LoopState::Error(e)
446        }
447    }
448}
449
450#[cfg(all(test, feature = "mocking"))]
451mod tests {
452    use super::*;
453    use tokio::time::sleep;
454
455    type EchoJson = JsonCodec<EchoControlMessage, EchoControlMessage>;
456
457    #[tokio::test]
458    async fn test_server_startup() {
459        let _server_address = get_mock_address(echo_server).await;
460    }
461
462    #[tokio::test]
463    async fn test_connection() {
464        let server_address = get_mock_address(echo_server).await;
465        let _socketeer: Socketeer<EchoJson> = Socketeer::connect(&format!("ws://{server_address}"))
466            .await
467            .unwrap();
468    }
469
470    #[tokio::test]
471    async fn test_bad_url() {
472        let error: Result<Socketeer<EchoJson>, Error> = Socketeer::connect("Not a URL").await;
473        assert!(matches!(error.unwrap_err(), Error::UrlParse { .. }));
474    }
475
476    #[tokio::test]
477    async fn test_send_receive() {
478        let server_address = get_mock_address(echo_server).await;
479        let mut socketeer: Socketeer<EchoJson> =
480            Socketeer::connect(&format!("ws://{server_address}"))
481                .await
482                .unwrap();
483        let message = EchoControlMessage::Message("Hello".to_string());
484        socketeer.send(message.clone()).await.unwrap();
485        let received_message = socketeer.next_message().await.unwrap();
486        assert_eq!(message, received_message);
487    }
488
489    #[tokio::test]
490    async fn test_ping_request() {
491        let server_address = get_mock_address(echo_server).await;
492        let mut socketeer: Socketeer<EchoJson> =
493            Socketeer::connect(&format!("ws://{server_address}"))
494                .await
495                .unwrap();
496        let ping_request = EchoControlMessage::SendPing;
497        socketeer.send(ping_request).await.unwrap();
498        // The server will respond with a ping request, which Socketeer will transparently respond to
499        let message = EchoControlMessage::Message("Hello".to_string());
500        socketeer.send(message.clone()).await.unwrap();
501        let received_message = socketeer.next_message().await.unwrap();
502        assert_eq!(received_message, message);
503        // We should send a ping in here
504        sleep(Duration::from_millis(2200)).await;
505        // Ensure everything shuts down so we exercize the ping functionality fully
506        socketeer.close_connection().await.unwrap();
507    }
508
509    #[tokio::test]
510    async fn test_reconnection() {
511        let server_address = get_mock_address(echo_server).await;
512        let mut socketeer: Socketeer<EchoJson> =
513            Socketeer::connect(&format!("ws://{server_address}"))
514                .await
515                .unwrap();
516        let message = EchoControlMessage::Message("Hello".to_string());
517        socketeer.send(message.clone()).await.unwrap();
518        let received_message = socketeer.next_message().await.unwrap();
519        assert_eq!(message, received_message);
520        socketeer = socketeer.reconnect().await.unwrap();
521        let message = EchoControlMessage::Message("Hello".to_string());
522        socketeer.send(message.clone()).await.unwrap();
523        let received_message = socketeer.next_message().await.unwrap();
524        assert_eq!(message, received_message);
525        socketeer.close_connection().await.unwrap();
526    }
527
528    #[tokio::test]
529    async fn test_closed_socket() {
530        let server_address = get_mock_address(echo_server).await;
531        let mut socketeer: Socketeer<EchoJson> =
532            Socketeer::connect(&format!("ws://{server_address}"))
533                .await
534                .unwrap();
535        let close_request = EchoControlMessage::Close;
536        socketeer.send(close_request.clone()).await.unwrap();
537        let response = socketeer.next_message().await;
538        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
539        let send_result = socketeer.send(close_request).await;
540        assert!(send_result.is_err());
541        let error = send_result.unwrap_err();
542        println!("Actual Error: {error:#?}");
543        assert!(matches!(error, Error::WebsocketClosed));
544    }
545
546    #[tokio::test]
547    async fn test_close_request() {
548        let server_address = get_mock_address(echo_server).await;
549        let socketeer: Socketeer<EchoJson> = Socketeer::connect(&format!("ws://{server_address}"))
550            .await
551            .unwrap();
552        socketeer.close_connection().await.unwrap();
553    }
554
555    #[tokio::test]
556    async fn test_connect_with_default_options() {
557        let server_address = get_mock_address(echo_server).await;
558        let mut socketeer: Socketeer<EchoJson> =
559            Socketeer::connect_with(&format!("ws://{server_address}"), ConnectOptions::default())
560                .await
561                .unwrap();
562        let message = EchoControlMessage::Message("Hello".to_string());
563        socketeer.send(message.clone()).await.unwrap();
564        let received_message = socketeer.next_message().await.unwrap();
565        assert_eq!(message, received_message);
566    }
567
568    #[tokio::test]
569    async fn test_raw_codec_message_roundtrip() {
570        // Typed send/next_message round-trip when the codec is RawCodec — the
571        // codec is identity, so frames pass through unchanged.
572        let server_address = get_mock_address(echo_server).await;
573        let mut socketeer: Socketeer<RawCodec> =
574            Socketeer::connect(&format!("ws://{server_address}"))
575                .await
576                .unwrap();
577        let raw_text = r#"{"Message":"raw hello"}"#;
578        socketeer
579            .send(Message::Text(raw_text.into()))
580            .await
581            .unwrap();
582        let received = socketeer.next_message().await.unwrap();
583        assert_eq!(received, Message::Text(raw_text.into()));
584    }
585
586    #[tokio::test]
587    async fn test_disabled_keepalive() {
588        let server_address = get_mock_address(echo_server).await;
589        let options = ConnectOptions {
590            keepalive_interval: None,
591            ..ConnectOptions::default()
592        };
593        let mut socketeer: Socketeer<EchoJson> =
594            Socketeer::connect_with(&format!("ws://{server_address}"), options)
595                .await
596                .unwrap();
597        let message = EchoControlMessage::Message("Hello".to_string());
598        socketeer.send(message.clone()).await.unwrap();
599        let received_message = socketeer.next_message().await.unwrap();
600        assert_eq!(message, received_message);
601    }
602
603    #[tokio::test]
604    async fn test_handler_on_connected() {
605        use serde::{Deserialize, Serialize};
606        use std::sync::Arc;
607        use tokio::sync::Mutex;
608
609        #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
610        struct AuthResponse {
611            status: String,
612        }
613
614        struct TestAuthHandler {
615            connected_count: Arc<Mutex<u32>>,
616        }
617
618        impl<C: Codec> ConnectionHandler<C> for TestAuthHandler {
619            async fn on_connected(
620                &mut self,
621                ctx: &mut HandshakeContext<'_, C>,
622            ) -> Result<(), Error> {
623                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
624                    .await?;
625                let text = ctx.recv_text().await?;
626                let response: AuthResponse = serde_json::from_str(&text).unwrap();
627                assert_eq!(response.status, "authenticated");
628                let mut count = self.connected_count.lock().await;
629                *count += 1;
630                Ok(())
631            }
632        }
633
634        let connected_count = Arc::new(Mutex::new(0u32));
635        let handler = TestAuthHandler {
636            connected_count: connected_count.clone(),
637        };
638
639        let server_address = get_mock_address(auth_echo_server).await;
640        let mut socketeer: Socketeer<EchoJson, TestAuthHandler> = Socketeer::connect_with_codec(
641            &format!("ws://{server_address}"),
642            ConnectOptions::default(),
643            JsonCodec::new(),
644            handler,
645        )
646        .await
647        .unwrap();
648
649        assert_eq!(*connected_count.lock().await, 1);
650
651        let message = EchoControlMessage::Message("after auth".to_string());
652        socketeer.send(message.clone()).await.unwrap();
653        let received = socketeer.next_message().await.unwrap();
654        assert_eq!(message, received);
655    }
656
657    #[tokio::test]
658    async fn test_handler_reconnect() {
659        use std::sync::Arc;
660        use tokio::sync::Mutex;
661
662        struct ReconnectHandler {
663            connected_count: Arc<Mutex<u32>>,
664            disconnected_count: Arc<Mutex<u32>>,
665        }
666
667        impl<C: Codec> ConnectionHandler<C> for ReconnectHandler {
668            async fn on_connected(
669                &mut self,
670                ctx: &mut HandshakeContext<'_, C>,
671            ) -> Result<(), Error> {
672                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
673                    .await?;
674                let _response = ctx.recv_text().await?;
675                let mut count = self.connected_count.lock().await;
676                *count += 1;
677                Ok(())
678            }
679
680            async fn on_disconnected(&mut self) {
681                let mut count = self.disconnected_count.lock().await;
682                *count += 1;
683            }
684        }
685
686        let connected_count = Arc::new(Mutex::new(0u32));
687        let disconnected_count = Arc::new(Mutex::new(0u32));
688        let handler = ReconnectHandler {
689            connected_count: connected_count.clone(),
690            disconnected_count: disconnected_count.clone(),
691        };
692
693        let server_address = get_mock_address(auth_echo_server).await;
694        let mut socketeer = Socketeer::<EchoJson, ReconnectHandler>::connect_with_codec(
695            &format!("ws://{server_address}"),
696            ConnectOptions::default(),
697            JsonCodec::new(),
698            handler,
699        )
700        .await
701        .unwrap();
702
703        assert_eq!(*connected_count.lock().await, 1);
704        assert_eq!(*disconnected_count.lock().await, 0);
705
706        // Send a message to verify connection works
707        let message = EchoControlMessage::Message("before reconnect".to_string());
708        socketeer.send(message.clone()).await.unwrap();
709        let received = socketeer.next_message().await.unwrap();
710        assert_eq!(message, received);
711
712        // Reconnect — handler should fire again
713        socketeer = socketeer.reconnect().await.unwrap();
714
715        assert_eq!(*connected_count.lock().await, 2);
716        assert_eq!(*disconnected_count.lock().await, 1);
717
718        // Verify connection still works after reconnect
719        let message = EchoControlMessage::Message("after reconnect".to_string());
720        socketeer.send(message.clone()).await.unwrap();
721        let received = socketeer.next_message().await.unwrap();
722        assert_eq!(message, received);
723
724        socketeer.close_connection().await.unwrap();
725    }
726
727    #[cfg(feature = "msgpack")]
728    #[tokio::test]
729    async fn test_msgpack_send_receive() {
730        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;
731
732        let server_address = get_mock_address(msgpack_echo_server).await;
733        let mut socketeer: Socketeer<EchoMsgPack> =
734            Socketeer::connect(&format!("ws://{server_address}"))
735                .await
736                .unwrap();
737        let message = EchoControlMessage::Message("msgpack hello".to_string());
738        socketeer.send(message.clone()).await.unwrap();
739        let received = socketeer.next_message().await.unwrap();
740        assert_eq!(message, received);
741        socketeer.close_connection().await.unwrap();
742    }
743
744    #[tokio::test]
745    async fn test_handler_uses_codec_driven_send_recv() {
746        // Exercises HandshakeContext::send / recv (the codec-driven path).
747        // Other handler tests only cover the raw send_text / recv_text helpers.
748        struct TypedHandshakeHandler;
749
750        impl ConnectionHandler<EchoJson> for TypedHandshakeHandler {
751            async fn on_connected(
752                &mut self,
753                ctx: &mut HandshakeContext<'_, EchoJson>,
754            ) -> Result<(), Error> {
755                ctx.send(&EchoControlMessage::Message("handshake".into()))
756                    .await?;
757                let echoed = ctx.recv().await?;
758                assert_eq!(echoed, EchoControlMessage::Message("handshake".into()));
759                Ok(())
760            }
761        }
762
763        let server_address = get_mock_address(echo_server).await;
764        let mut socketeer: Socketeer<EchoJson, TypedHandshakeHandler> =
765            Socketeer::connect_with_codec(
766                &format!("ws://{server_address}"),
767                ConnectOptions::default(),
768                JsonCodec::new(),
769                TypedHandshakeHandler,
770            )
771            .await
772            .unwrap();
773
774        // Confirm normal traffic still flows after the typed handshake.
775        let message = EchoControlMessage::Message("after handshake".into());
776        socketeer.send(message.clone()).await.unwrap();
777        assert_eq!(socketeer.next_message().await.unwrap(), message);
778        socketeer.close_connection().await.unwrap();
779    }
780
781    #[tokio::test]
782    async fn test_handshake_recv_close_with_raw_codec() {
783        // Regression: with RawCodec, recv_raw returns Ok(Message::Close(_)) and
784        // RawCodec::decode is the identity, so a peer-initiated close used to
785        // surface as Ok(Close) instead of Err(WebsocketClosed). recv must
786        // intercept Close before delegating to the codec.
787        struct CloseExpecting;
788
789        impl ConnectionHandler<RawCodec> for CloseExpecting {
790            async fn on_connected(
791                &mut self,
792                ctx: &mut HandshakeContext<'_, RawCodec>,
793            ) -> Result<(), Error> {
794                // Ask the echo server to close (JSON unit-variant for EchoControlMessage::Close).
795                ctx.send(&Message::Text(r#""Close""#.into())).await?;
796                let err = ctx.recv().await.unwrap_err();
797                assert!(matches!(err, Error::WebsocketClosed));
798                Ok(())
799            }
800        }
801
802        let server_address = get_mock_address(echo_server).await;
803        let _socketeer: Socketeer<RawCodec, CloseExpecting> = Socketeer::connect_with_codec(
804            &format!("ws://{server_address}"),
805            ConnectOptions::default(),
806            RawCodec::new(),
807            CloseExpecting,
808        )
809        .await
810        .unwrap();
811    }
812
813    #[tokio::test]
814    async fn test_extra_headers_used() {
815        // Cover ConnectOptions::build_request's loop body that copies
816        // `extra_headers` onto the upgrade request.
817        let server_address = get_mock_address(echo_server).await;
818        let mut headers = tokio_tungstenite::tungstenite::http::HeaderMap::new();
819        headers.insert("X-Test-Header", "socketeer".parse().unwrap());
820        let options = ConnectOptions {
821            extra_headers: headers,
822            ..ConnectOptions::default()
823        };
824        let mut socketeer: Socketeer<EchoJson> =
825            Socketeer::connect_with(&format!("ws://{server_address}"), options)
826                .await
827                .unwrap();
828        let message = EchoControlMessage::Message("hi".into());
829        socketeer.send(message.clone()).await.unwrap();
830        assert_eq!(socketeer.next_message().await.unwrap(), message);
831        socketeer.close_connection().await.unwrap();
832    }
833
834    #[tokio::test]
835    async fn test_auth_handler_bad_token() {
836        // Covers auth_echo_server's bad-token branch (sends {"status":"error"}
837        // and shuts down). The handler observes the error response, returns
838        // Ok, then a subsequent send fails because the server has closed.
839        struct BadTokenHandler;
840
841        impl<C: Codec> ConnectionHandler<C> for BadTokenHandler {
842            async fn on_connected(
843                &mut self,
844                ctx: &mut HandshakeContext<'_, C>,
845            ) -> Result<(), Error> {
846                ctx.send_text(r#"{"action":"auth","token":"WRONG"}"#)
847                    .await?;
848                let resp = ctx.recv_text().await?;
849                assert!(resp.contains("error"));
850                Ok(())
851            }
852        }
853
854        let server_address = get_mock_address(auth_echo_server).await;
855        let _socketeer: Socketeer<EchoJson, BadTokenHandler> = Socketeer::connect_with_codec(
856            &format!("ws://{server_address}"),
857            ConnectOptions::default(),
858            JsonCodec::new(),
859            BadTokenHandler,
860        )
861        .await
862        .unwrap();
863    }
864
865    #[cfg(feature = "msgpack")]
866    #[tokio::test]
867    async fn test_msgpack_send_ping() {
868        // Covers the SendPing arm of msgpack_echo_server.
869        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;
870
871        let server_address = get_mock_address(msgpack_echo_server).await;
872        let mut socketeer: Socketeer<EchoMsgPack> =
873            Socketeer::connect(&format!("ws://{server_address}"))
874                .await
875                .unwrap();
876        socketeer.send(EchoControlMessage::SendPing).await.unwrap();
877        // Server replies with a Ping; Socketeer auto-Pongs. Round-trip a real
878        // message to confirm the connection is still alive.
879        let message = EchoControlMessage::Message("after ping".into());
880        socketeer.send(message.clone()).await.unwrap();
881        assert_eq!(socketeer.next_message().await.unwrap(), message);
882        socketeer.close_connection().await.unwrap();
883    }
884
885    #[cfg(feature = "msgpack")]
886    #[tokio::test]
887    async fn test_msgpack_close_request() {
888        // Covers the Close arm of msgpack_echo_server.
889        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;
890
891        let server_address = get_mock_address(msgpack_echo_server).await;
892        let mut socketeer: Socketeer<EchoMsgPack> =
893            Socketeer::connect(&format!("ws://{server_address}"))
894                .await
895                .unwrap();
896        socketeer.send(EchoControlMessage::Close).await.unwrap();
897        let result = socketeer.next_message().await;
898        assert!(matches!(result.unwrap_err(), Error::WebsocketClosed));
899    }
900
901    #[tokio::test]
902    async fn test_socketeer_debug_format() {
903        let server_address = get_mock_address(echo_server).await;
904        let socketeer: Socketeer<EchoJson> = Socketeer::connect(&format!("ws://{server_address}"))
905            .await
906            .unwrap();
907        let formatted = format!("{socketeer:?}");
908        assert!(formatted.starts_with("Socketeer"));
909        assert!(formatted.contains("url"));
910    }
911
912    #[tokio::test]
913    async fn test_send_raw_next_raw_message() {
914        // Cover the raw send/receive escape hatches on a typed (non-RawCodec)
915        // connection: send_raw bypasses encoding, next_raw_message bypasses
916        // decoding, so we can speak frames the codec wouldn't otherwise
917        // produce or accept.
918        let server_address = get_mock_address(echo_server).await;
919        let mut socketeer: Socketeer<EchoJson> =
920            Socketeer::connect(&format!("ws://{server_address}"))
921                .await
922                .unwrap();
923        let raw_text = r#"{"Message":"raw recv"}"#;
924        socketeer
925            .send_raw(Message::Text(raw_text.into()))
926            .await
927            .unwrap();
928        let frame = socketeer.next_raw_message().await.unwrap();
929        assert_eq!(frame, Message::Text(raw_text.into()));
930        socketeer.close_connection().await.unwrap();
931    }
932
933    #[cfg(feature = "msgpack")]
934    #[tokio::test]
935    async fn test_handshake_send_binary_recv_raw() {
936        // Cover HandshakeContext::send_binary by sending a pre-encoded
937        // msgpack frame from on_connected and reading the binary echo back
938        // via recv_raw.
939        struct BinaryHandshake;
940
941        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;
942
943        impl ConnectionHandler<EchoMsgPack> for BinaryHandshake {
944            async fn on_connected(
945                &mut self,
946                ctx: &mut HandshakeContext<'_, EchoMsgPack>,
947            ) -> Result<(), Error> {
948                let payload =
949                    rmp_serde::to_vec_named(&EchoControlMessage::Message("binary".into())).unwrap();
950                ctx.send_binary(payload).await?;
951                let echo = ctx.recv_raw().await?;
952                assert!(matches!(echo, Message::Binary(_)));
953                Ok(())
954            }
955        }
956
957        let server_address = get_mock_address(msgpack_echo_server).await;
958        let socketeer: Socketeer<EchoMsgPack, BinaryHandshake> = Socketeer::connect_with_codec(
959            &format!("ws://{server_address}"),
960            ConnectOptions::default(),
961            MsgPackCodec::new(),
962            BinaryHandshake,
963        )
964        .await
965        .unwrap();
966        socketeer.close_connection().await.unwrap();
967    }
968
969    #[cfg(feature = "msgpack")]
970    #[tokio::test]
971    async fn test_handshake_recv_text_rejects_binary() {
972        // Cover the non-Text branch of HandshakeContext::recv_text by pointing
973        // it at a server that only speaks binary frames.
974        struct ExpectsTextOnBinary;
975
976        type EchoMsgPack = MsgPackCodec<EchoControlMessage, EchoControlMessage>;
977
978        impl ConnectionHandler<EchoMsgPack> for ExpectsTextOnBinary {
979            async fn on_connected(
980                &mut self,
981                ctx: &mut HandshakeContext<'_, EchoMsgPack>,
982            ) -> Result<(), Error> {
983                let payload =
984                    rmp_serde::to_vec_named(&EchoControlMessage::Message("hi".into())).unwrap();
985                ctx.send_binary(payload).await?;
986                // recv_text must reject the echoed Binary frame.
987                let err = ctx.recv_text().await.unwrap_err();
988                assert!(matches!(err, Error::UnexpectedMessageType(_)));
989                Ok(())
990            }
991        }
992
993        let server_address = get_mock_address(msgpack_echo_server).await;
994        let socketeer: Socketeer<EchoMsgPack, ExpectsTextOnBinary> = Socketeer::connect_with_codec(
995            &format!("ws://{server_address}"),
996            ConnectOptions::default(),
997            MsgPackCodec::new(),
998            ExpectsTextOnBinary,
999        )
1000        .await
1001        .unwrap();
1002        socketeer.close_connection().await.unwrap();
1003    }
1004
1005    #[tokio::test]
1006    async fn test_binary_custom_keepalive() {
1007        // The widening of custom_keepalive_message from Option<String> to
1008        // Option<Message> is otherwise unexercised. echo_server silently
1009        // ignores Binary frames, so the receive queue stays clean and we can
1010        // verify the connection survives a binary keepalive cycle.
1011        let server_address = get_mock_address(echo_server).await;
1012        let options = ConnectOptions {
1013            keepalive_interval: Some(Duration::from_millis(100)),
1014            custom_keepalive_message: Some(Message::Binary(Bytes::from_static(b"keepalive"))),
1015            ..ConnectOptions::default()
1016        };
1017        let mut socketeer: Socketeer<EchoJson> =
1018            Socketeer::connect_with(&format!("ws://{server_address}"), options)
1019                .await
1020                .unwrap();
1021
1022        // Wait long enough for at least a couple of keepalive ticks to fire.
1023        sleep(Duration::from_millis(350)).await;
1024
1025        let message = EchoControlMessage::Message("post-keepalive".into());
1026        socketeer.send(message.clone()).await.unwrap();
1027        assert_eq!(socketeer.next_message().await.unwrap(), message);
1028        socketeer.close_connection().await.unwrap();
1029    }
1030}