Skip to main content

socketeer/
lib.rs

1#![doc = include_str!("../README.md")]
2#![deny(missing_docs)]
3mod config;
4mod error;
5mod handler;
6#[cfg(feature = "mocking")]
7mod mock_server;
8
9pub use config::ConnectOptions;
10pub use error::Error;
11pub use handler::{ConnectionHandler, HandshakeContext, NoopHandler};
12#[cfg(feature = "mocking")]
13pub use mock_server::{EchoControlMessage, auth_echo_server, echo_server, get_mock_address};
14
15use bytes::Bytes;
16use futures::{SinkExt, StreamExt, stream::SplitSink, stream::SplitStream};
17use serde::{Deserialize, Serialize};
18use std::{fmt::Debug, time::Duration};
19use tokio::{
20    net::TcpStream,
21    select,
22    sync::{mpsc, oneshot},
23    time::sleep,
24};
25use tokio_tungstenite::{
26    MaybeTlsStream, WebSocketStream, connect_async,
27    tungstenite::{self, Message, Utf8Bytes, protocol::CloseFrame},
28};
29
30#[cfg(feature = "tracing")]
31use tracing::{debug, error, info, instrument, trace};
32use url::Url;
33
34#[derive(Debug)]
35struct TxChannelPayload {
36    message: Message,
37    response_tx: oneshot::Sender<Result<(), Error>>,
38}
39
40/// A WebSocket client that manages the connection to a WebSocket server.
41/// The client can send and receive messages, and will transparently handle protocol messages.
42///
43/// # Type Parameters
44///
45/// - `RxMessage`: The type of message that the client will receive from the server.
46/// - `TxMessage`: The type of message that the client will send to the server.
47/// - `Handler`: A [`ConnectionHandler`] for lifecycle hooks (auth, subscriptions).
48///   Defaults to [`NoopHandler`] for the simple case.
49/// - `CHANNEL_SIZE`: The size of the internal channels used to communicate between
50///   the task managing the WebSocket connection and the client.
51pub struct Socketeer<RxMessage, TxMessage, Handler = NoopHandler, const CHANNEL_SIZE: usize = 4> {
52    url: Url,
53    options: ConnectOptions,
54    handler: Handler,
55    receiver: mpsc::Receiver<Message>,
56    sender: mpsc::Sender<TxChannelPayload>,
57    socket_handle: tokio::task::JoinHandle<Result<(), Error>>,
58    _rx_message: std::marker::PhantomData<RxMessage>,
59    _tx_message: std::marker::PhantomData<TxMessage>,
60}
61
62impl<RxMessage, TxMessage, Handler, const CHANNEL_SIZE: usize> std::fmt::Debug
63    for Socketeer<RxMessage, TxMessage, Handler, CHANNEL_SIZE>
64{
65    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
66        f.debug_struct("Socketeer")
67            .field("url", &self.url)
68            .finish_non_exhaustive()
69    }
70}
71
72impl<
73    RxMessage: for<'a> Deserialize<'a> + Debug,
74    TxMessage: Serialize + Debug,
75    const CHANNEL_SIZE: usize,
76> Socketeer<RxMessage, TxMessage, NoopHandler, CHANNEL_SIZE>
77{
78    /// Create a `Socketeer` connected to the provided URL with default options.
79    /// Once connected, Socketeer manages the underlying WebSocket connection, transparently handling protocol messages.
80    /// # Errors
81    /// - If the URL cannot be parsed
82    /// - If the WebSocket connection to the requested URL fails
83    #[cfg_attr(feature = "tracing", instrument)]
84    pub async fn connect(url: &str) -> Result<Self, Error> {
85        Self::connect_with(url, ConnectOptions::default()).await
86    }
87
88    /// Create a `Socketeer` connected to the provided URL with custom connection options.
89    /// # Errors
90    /// - If the URL cannot be parsed
91    /// - If the WebSocket connection to the requested URL fails
92    #[cfg_attr(feature = "tracing", instrument(skip(options)))]
93    pub async fn connect_with(url: &str, options: ConnectOptions) -> Result<Self, Error> {
94        Socketeer::connect_with_handler(url, options, NoopHandler).await
95    }
96}
97
98impl<
99    RxMessage: for<'a> Deserialize<'a> + Debug,
100    TxMessage: Serialize + Debug,
101    Handler: ConnectionHandler,
102    const CHANNEL_SIZE: usize,
103> Socketeer<RxMessage, TxMessage, Handler, CHANNEL_SIZE>
104{
105    /// Create a `Socketeer` with a custom [`ConnectionHandler`] for lifecycle hooks.
106    ///
107    /// The handler's [`ConnectionHandler::on_connected`] method is called after the
108    /// WebSocket upgrade completes, before the socket loop starts. This is where
109    /// you should perform authentication handshakes and initial subscriptions.
110    /// # Errors
111    /// - If the URL cannot be parsed
112    /// - If the WebSocket connection to the requested URL fails
113    /// - If the handler's `on_connected` returns an error
114    #[cfg_attr(feature = "tracing", instrument(skip(options, handler)))]
115    pub async fn connect_with_handler(
116        url: &str,
117        options: ConnectOptions,
118        mut handler: Handler,
119    ) -> Result<Self, Error> {
120        let url = Url::parse(url).map_err(|source| Error::UrlParse {
121            url: url.to_string(),
122            source,
123        })?;
124
125        let request = options.build_request(&url)?;
126        #[allow(unused_variables)]
127        let (socket, response) = connect_async(request).await?;
128        #[cfg(feature = "tracing")]
129        debug!("Connection Successful, connection info: \n{:#?}", response);
130
131        let (mut sink, mut stream) = socket.split();
132        {
133            let mut ctx = HandshakeContext::new(&mut sink, &mut stream);
134            handler.on_connected(&mut ctx).await?;
135        }
136
137        let keepalive_interval = options.keepalive_interval;
138        let keepalive_message = options.custom_keepalive_message.clone();
139
140        let (tx_tx, tx_rx) = mpsc::channel::<TxChannelPayload>(CHANNEL_SIZE);
141        let (rx_tx, rx_rx) = mpsc::channel::<Message>(CHANNEL_SIZE);
142
143        let socket_handle = tokio::spawn(async move {
144            socket_loop_split(
145                tx_rx,
146                rx_tx,
147                sink,
148                stream,
149                keepalive_interval,
150                keepalive_message,
151            )
152            .await
153        });
154        Ok(Socketeer {
155            url,
156            options,
157            handler,
158            receiver: rx_rx,
159            sender: tx_tx,
160            socket_handle,
161            _rx_message: std::marker::PhantomData,
162            _tx_message: std::marker::PhantomData,
163        })
164    }
165
166    /// Wait for the next parsed message from the WebSocket connection.
167    ///
168    /// # Errors
169    ///
170    /// - If the WebSocket connection is closed or otherwise errored
171    /// - If the message cannot be deserialized
172    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
173    pub async fn next_message(&mut self) -> Result<RxMessage, Error> {
174        let Some(message) = self.receiver.recv().await else {
175            return Err(Error::WebsocketClosed);
176        };
177        match message {
178            Message::Text(text) => {
179                #[cfg(feature = "tracing")]
180                trace!("Received text message: {:?}", text);
181                let message = serde_json::from_str(&text)?;
182                Ok(message)
183            }
184            Message::Binary(message) => {
185                #[cfg(feature = "tracing")]
186                trace!("Received binary message: {:?}", message);
187                let message = serde_json::from_slice(&message)?;
188                Ok(message)
189            }
190            _ => Err(Error::UnexpectedMessageType(Box::new(message))),
191        }
192    }
193
194    /// Send a message to the WebSocket connection.
195    /// This function will wait for the message to be sent before returning.
196    ///
197    /// # Errors
198    ///
199    /// - If the message cannot be serialized
200    /// - If the WebSocket connection is closed, or otherwise errored
201    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
202    pub async fn send(&self, message: TxMessage) -> Result<(), Error> {
203        #[cfg(feature = "tracing")]
204        trace!("Sending message: {:?}", message);
205
206        let (tx, rx) = oneshot::channel::<Result<(), Error>>();
207        let message = serde_json::to_string(&message)?;
208
209        self.sender
210            .send(TxChannelPayload {
211                message: Message::Text(message.into()),
212                response_tx: tx,
213            })
214            .await
215            .map_err(|_| Error::WebsocketClosed)?;
216        // We'll ensure that we always respond before dropping the tx channel
217        match rx.await {
218            Ok(result) => result,
219            Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
220        }
221    }
222
223    /// Receive the next raw [`Message`] from the WebSocket connection without deserialization.
224    ///
225    /// This is useful for protocols that don't use JSON or need to inspect the raw message.
226    ///
227    /// # Errors
228    ///
229    /// - If the WebSocket connection is closed or otherwise errored
230    pub async fn next_raw_message(&mut self) -> Result<Message, Error> {
231        self.receiver.recv().await.ok_or(Error::WebsocketClosed)
232    }
233
234    /// Send a raw [`Message`] to the WebSocket connection without serialization.
235    ///
236    /// This is useful for sending non-JSON messages (e.g., plain text keepalives)
237    /// or binary data that is already encoded.
238    ///
239    /// # Errors
240    ///
241    /// - If the WebSocket connection is closed, or otherwise errored
242    pub async fn send_raw(&self, message: Message) -> Result<(), Error> {
243        let (tx, rx) = oneshot::channel::<Result<(), Error>>();
244        self.sender
245            .send(TxChannelPayload {
246                message,
247                response_tx: tx,
248            })
249            .await
250            .map_err(|_| Error::WebsocketClosed)?;
251        match rx.await {
252            Ok(result) => result,
253            Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
254        }
255    }
256
257    /// Consume self, closing down any remaining send/receive, and return a new Socketeer instance if successful.
258    /// This function attempts to close the connection gracefully before returning,
259    /// but will not return an error if the connection is already closed,
260    /// as its intended use is to re-establish a failed connection.
261    ///
262    /// The handler's [`ConnectionHandler::on_disconnected`] is called before closing,
263    /// and [`ConnectionHandler::on_connected`] is called after reconnecting.
264    /// # Errors
265    /// - If a new connection cannot be established
266    /// - If the handler's `on_connected` returns an error
267    pub async fn reconnect(self) -> Result<Self, Error> {
268        let url = self.url.as_str().to_owned();
269        let options = self.options.clone();
270        let mut handler = self.handler;
271        #[cfg(feature = "tracing")]
272        info!("Reconnecting");
273        handler.on_disconnected().await;
274        // Attempt graceful close, but don't fail if already closed
275        match send_close(&self.sender).await {
276            Ok(()) => (),
277            #[allow(unused_variables)]
278            Err(e) => {
279                #[cfg(feature = "tracing")]
280                error!("Socket Loop already stopped: {}", e);
281            }
282        }
283        Self::connect_with_handler(&url, options, handler).await
284    }
285
286    /// Close the WebSocket connection gracefully.
287    /// This function will wait for the connection to close before returning.
288    /// # Errors
289    /// - If the WebSocket connection is already closed
290    /// - If the WebSocket connection cannot be closed
291    #[cfg_attr(feature = "tracing", instrument(skip(self)))]
292    pub async fn close_connection(self) -> Result<(), Error> {
293        #[cfg(feature = "tracing")]
294        debug!("Closing Connection");
295        send_close(&self.sender).await?;
296        match self.socket_handle.await {
297            Ok(result) => result,
298            Err(_) => unreachable!("Socket loop does not panic, and is not cancelled"),
299        }
300    }
301}
302
303pub(crate) type WebSocketStreamType = WebSocketStream<MaybeTlsStream<TcpStream>>;
304type SocketSink = SplitSink<WebSocketStreamType, Message>;
305type SocketStream = SplitStream<WebSocketStreamType>;
306
307enum LoopState {
308    Running,
309    Error(Error),
310    Closed,
311}
312
313/// Send a close frame via the tx channel and wait for confirmation.
314async fn send_close(sender: &mpsc::Sender<TxChannelPayload>) -> Result<(), Error> {
315    let (tx, rx) = oneshot::channel::<Result<(), Error>>();
316    sender
317        .send(TxChannelPayload {
318            message: Message::Close(Some(CloseFrame {
319                code: tungstenite::protocol::frame::coding::CloseCode::Normal,
320                reason: Utf8Bytes::from_static("Closing Connection"),
321            })),
322            response_tx: tx,
323        })
324        .await
325        .map_err(|_| Error::WebsocketClosed)?;
326    match rx.await {
327        Ok(result) => result,
328        Err(_) => unreachable!("Socket loop always sends response before dropping one-shot"),
329    }
330}
331
332#[cfg_attr(
333    feature = "tracing",
334    instrument(skip(keepalive_interval, keepalive_message))
335)]
336async fn socket_loop_split(
337    mut receiver: mpsc::Receiver<TxChannelPayload>,
338    mut sender: mpsc::Sender<Message>,
339    mut sink: SocketSink,
340    mut stream: SocketStream,
341    keepalive_interval: Option<Duration>,
342    keepalive_message: Option<String>,
343) -> Result<(), Error> {
344    let mut state = LoopState::Running;
345    while matches!(state, LoopState::Running) {
346        state = if let Some(interval) = keepalive_interval {
347            select! {
348                outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
349                incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
350                () = sleep(interval) => send_keepalive(&mut sink, keepalive_message.as_deref()).await,
351            }
352        } else {
353            select! {
354                outgoing_message = receiver.recv() => send_socket_message(outgoing_message, &mut sink).await,
355                incoming_message = stream.next() => socket_message_received(incoming_message, &mut sender, &mut sink).await,
356            }
357        };
358    }
359    match state {
360        LoopState::Error(e) => Err(e),
361        LoopState::Closed => Ok(()),
362        LoopState::Running => unreachable!("We only exit when closed or errored"),
363    }
364}
365
366#[cfg_attr(feature = "tracing", instrument)]
367async fn send_socket_message(
368    message: Option<TxChannelPayload>,
369    sink: &mut SocketSink,
370) -> LoopState {
371    if let Some(message) = message {
372        #[cfg(feature = "tracing")]
373        debug!("Sending message: {:?}", message);
374        let send_result = sink.send(message.message).await.map_err(Error::from);
375        let socket_error = send_result.is_err();
376        match message.response_tx.send(send_result) {
377            Ok(()) => {
378                if socket_error {
379                    LoopState::Error(Error::WebsocketClosed)
380                } else {
381                    LoopState::Running
382                }
383            }
384            Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
385        }
386    } else {
387        #[cfg(feature = "tracing")]
388        error!("Socketeer dropped without closing connection");
389        LoopState::Error(Error::SocketeerDroppedWithoutClosing)
390    }
391}
392
393#[cfg_attr(feature = "tracing", instrument)]
394async fn socket_message_received(
395    message: Option<Result<Message, tungstenite::Error>>,
396    sender: &mut mpsc::Sender<Message>,
397    sink: &mut SocketSink,
398) -> LoopState {
399    const PONG_BYTES: Bytes = Bytes::from_static(b"pong");
400    match message {
401        Some(Ok(message)) => match message {
402            Message::Ping(_) => {
403                let send_result = sink
404                    .send(Message::Pong(PONG_BYTES))
405                    .await
406                    .map_err(Error::from);
407                match send_result {
408                    Ok(()) => LoopState::Running,
409                    Err(e) => {
410                        #[cfg(feature = "tracing")]
411                        error!("Error sending Pong: {:?}", e);
412                        LoopState::Error(e)
413                    }
414                }
415            }
416            Message::Close(_) => {
417                let close_result = sink.close().await;
418                match close_result {
419                    Ok(()) => LoopState::Closed,
420                    Err(e) => {
421                        #[cfg(feature = "tracing")]
422                        error!("Error sending Close: {:?}", e);
423                        LoopState::Error(Error::from(e))
424                    }
425                }
426            }
427            Message::Text(_) | Message::Binary(_) => match sender.send(message).await {
428                Ok(()) => LoopState::Running,
429                Err(_) => LoopState::Error(Error::SocketeerDroppedWithoutClosing),
430            },
431            _ => LoopState::Running,
432        },
433        Some(Err(e)) => {
434            #[cfg(feature = "tracing")]
435            error!("Error receiving message: {:?}", e);
436            LoopState::Error(Error::WebsocketError(e))
437        }
438        None => {
439            #[cfg(feature = "tracing")]
440            info!("Websocket Closed, closing rx channel");
441            LoopState::Error(Error::WebsocketClosed)
442        }
443    }
444}
445
446#[cfg_attr(feature = "tracing", instrument)]
447async fn send_keepalive(sink: &mut SocketSink, custom_message: Option<&str>) -> LoopState {
448    let message = if let Some(text) = custom_message {
449        #[cfg(feature = "tracing")]
450        info!("Timeout waiting for message, sending custom keepalive");
451        Message::Text(text.into())
452    } else {
453        #[cfg(feature = "tracing")]
454        info!("Timeout waiting for message, sending Ping");
455        Message::Ping(Bytes::new())
456    };
457    let result = sink.send(message).await.map_err(Error::from);
458    match result {
459        Ok(()) => LoopState::Running,
460        Err(e) => {
461            #[cfg(feature = "tracing")]
462            error!("Error sending keepalive: {:?}", e);
463            LoopState::Error(e)
464        }
465    }
466}
467
#[cfg(all(test, feature = "mocking"))]
mod tests {
    use super::*;
    use tokio::time::sleep;

    /// Client type used against the mock echo servers.
    type EchoClient<H = NoopHandler> = Socketeer<EchoControlMessage, EchoControlMessage, H>;

    /// Open a default-configured client against the mock server at `address`.
    async fn connect_echo(address: impl std::fmt::Display) -> EchoClient {
        Socketeer::connect(&format!("ws://{address}"))
            .await
            .unwrap()
    }

    /// Send `text` as an echo message and assert the server returns it unchanged.
    async fn assert_round_trip<H: ConnectionHandler>(client: &mut EchoClient<H>, text: &str) {
        let outgoing = EchoControlMessage::Message(text.to_string());
        client.send(outgoing.clone()).await.unwrap();
        let incoming = client.next_message().await.unwrap();
        assert_eq!(outgoing, incoming);
    }

    #[tokio::test]
    async fn test_server_startup() {
        let _server_address = get_mock_address(echo_server).await;
    }

    #[tokio::test]
    async fn test_connection() {
        let address = get_mock_address(echo_server).await;
        let _client = connect_echo(address).await;
    }

    #[tokio::test]
    async fn test_bad_url() {
        let outcome: Result<EchoClient, Error> = Socketeer::connect("Not a URL").await;
        assert!(matches!(outcome.unwrap_err(), Error::UrlParse { .. }));
    }

    #[tokio::test]
    async fn test_send_receive() {
        let address = get_mock_address(echo_server).await;
        let mut client = connect_echo(address).await;
        assert_round_trip(&mut client, "Hello").await;
    }

    #[tokio::test]
    async fn test_ping_request() {
        let address = get_mock_address(echo_server).await;
        let mut client = connect_echo(address).await;
        // Ask the server to ping us; Socketeer answers it without surfacing anything.
        client.send(EchoControlMessage::SendPing).await.unwrap();
        assert_round_trip(&mut client, "Hello").await;
        // Stay idle, presumably longer than the default keepalive interval, so the
        // client sends its own ping.
        sleep(Duration::from_millis(2200)).await;
        // Shut down cleanly so the ping path is exercised end to end.
        client.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_reconnection() {
        let address = get_mock_address(echo_server).await;
        let mut client = connect_echo(address).await;
        assert_round_trip(&mut client, "Hello").await;
        client = client.reconnect().await.unwrap();
        assert_round_trip(&mut client, "Hello").await;
        client.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_closed_socket() {
        let address = get_mock_address(echo_server).await;
        let mut client = connect_echo(address).await;
        client.send(EchoControlMessage::Close).await.unwrap();
        let response = client.next_message().await;
        assert!(matches!(response.unwrap_err(), Error::WebsocketClosed));
        let error = client.send(EchoControlMessage::Close).await.unwrap_err();
        println!("Actual Error: {error:#?}");
        assert!(matches!(error, Error::WebsocketClosed));
    }

    #[tokio::test]
    async fn test_close_request() {
        let address = get_mock_address(echo_server).await;
        let client = connect_echo(address).await;
        client.close_connection().await.unwrap();
    }

    #[tokio::test]
    async fn test_connect_with_default_options() {
        let address = get_mock_address(echo_server).await;
        let mut client: EchoClient =
            Socketeer::connect_with(&format!("ws://{address}"), ConnectOptions::default())
                .await
                .unwrap();
        assert_round_trip(&mut client, "Hello").await;
    }

    #[tokio::test]
    async fn test_send_raw_receive_raw() {
        let address = get_mock_address(echo_server).await;
        let mut client = connect_echo(address).await;
        let raw_text = r#"{"Message":"raw hello"}"#;
        client
            .send_raw(Message::Text(raw_text.into()))
            .await
            .unwrap();
        let echoed = client.next_raw_message().await.unwrap();
        assert_eq!(echoed, Message::Text(raw_text.into()));
    }

    #[tokio::test]
    async fn test_disabled_keepalive() {
        let address = get_mock_address(echo_server).await;
        let options = ConnectOptions {
            keepalive_interval: None,
            ..ConnectOptions::default()
        };
        let mut client: EchoClient = Socketeer::connect_with(&format!("ws://{address}"), options)
            .await
            .unwrap();
        assert_round_trip(&mut client, "Hello").await;
    }

    #[tokio::test]
    async fn test_handler_on_connected() {
        use serde::{Deserialize, Serialize};
        use std::sync::Arc;
        use tokio::sync::Mutex;

        #[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
        struct AuthResponse {
            status: String,
        }

        struct TestAuthHandler {
            connected_count: Arc<Mutex<u32>>,
        }

        impl ConnectionHandler for TestAuthHandler {
            async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> {
                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
                    .await?;
                let response: AuthResponse = ctx.recv_json().await?;
                assert_eq!(response.status, "authenticated");
                *self.connected_count.lock().await += 1;
                Ok(())
            }
        }

        let connected_count = Arc::new(Mutex::new(0u32));
        let handler = TestAuthHandler {
            connected_count: Arc::clone(&connected_count),
        };

        let address = get_mock_address(auth_echo_server).await;
        let mut client: EchoClient<TestAuthHandler> = Socketeer::connect_with_handler(
            &format!("ws://{address}"),
            ConnectOptions::default(),
            handler,
        )
        .await
        .unwrap();

        assert_eq!(*connected_count.lock().await, 1);
        assert_round_trip(&mut client, "after auth").await;
    }

    #[tokio::test]
    async fn test_handler_reconnect() {
        use std::sync::Arc;
        use tokio::sync::Mutex;

        struct ReconnectHandler {
            connected_count: Arc<Mutex<u32>>,
            disconnected_count: Arc<Mutex<u32>>,
        }

        impl ConnectionHandler for ReconnectHandler {
            async fn on_connected(&mut self, ctx: &mut HandshakeContext<'_>) -> Result<(), Error> {
                ctx.send_text(r#"{"action":"auth","token":"test-token"}"#)
                    .await?;
                let _response = ctx.recv_text().await?;
                *self.connected_count.lock().await += 1;
                Ok(())
            }

            async fn on_disconnected(&mut self) {
                *self.disconnected_count.lock().await += 1;
            }
        }

        let connected_count = Arc::new(Mutex::new(0u32));
        let disconnected_count = Arc::new(Mutex::new(0u32));
        let handler = ReconnectHandler {
            connected_count: Arc::clone(&connected_count),
            disconnected_count: Arc::clone(&disconnected_count),
        };

        let address = get_mock_address(auth_echo_server).await;
        let mut client = EchoClient::<ReconnectHandler>::connect_with_handler(
            &format!("ws://{address}"),
            ConnectOptions::default(),
            handler,
        )
        .await
        .unwrap();

        assert_eq!(*connected_count.lock().await, 1);
        assert_eq!(*disconnected_count.lock().await, 0);

        // The connection works before reconnecting.
        assert_round_trip(&mut client, "before reconnect").await;

        // Reconnecting fires both hooks once more.
        client = client.reconnect().await.unwrap();

        assert_eq!(*connected_count.lock().await, 2);
        assert_eq!(*disconnected_count.lock().await, 1);

        // The new connection works too.
        assert_round_trip(&mut client, "after reconnect").await;

        client.close_connection().await.unwrap();
    }
}