tokio_websocket_client/
lib.rs

1#![deny(clippy::all, clippy::pedantic, clippy::nursery)]
2#![doc=include_str!("../README.md")]
3
4mod client;
5pub(crate) mod command;
6mod connector;
7mod handler;
8mod message;
9mod stream_wrapper;
10
11pub use crate::{
12    client::Client,
13    connector::Connector,
14    handler::{Handler, RetryStrategy},
15    message::{CloseCode, Message},
16};
17
18#[doc(hidden)]
19pub use crate::stream_wrapper::StreamWrapper;
20
21use crate::command::Command;
22use futures::{Sink, SinkExt, StreamExt};
23
/// Outcome reported by the event helpers to the background loop.
///
/// The loop only acts on [`LoopControl::Break`], which stops the background
/// task; [`LoopControl::Continue`] lets a helper return early while keeping the
/// loop running.
#[derive(Clone, Debug, PartialEq, Eq)]
enum LoopControl {
    /// Keep processing events.
    Continue,
    /// Stop the background task.
    Break,
}
29
30/// Connect to a websocket server using the provided connector.
31///
32/// This function will indefinitly try to connect to the server
33/// unless the [`handler::on_connect_failure`](Handler::on_connect_failure) returns a [`RetryStrategy::Close`].
34pub async fn connect<C, H>(mut connector: C, mut handler: H) -> Option<Client>
35where
36    C: Connector + 'static,
37    H: Handler + 'static,
38    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
39{
40    let (to_send_tx, to_send_rx) = flume::bounded(C::request_queue_size());
41    let (command_tx, command_rx) = flume::bounded(1);
42    let (confirm_close_tx, confirm_close_rx) = flume::bounded(1);
43
44    let Ok(stream) = connect_stream(&mut connector, &mut handler).await else {
45        return None;
46    };
47
48    tokio::spawn(async move {
49        background_task(to_send_rx, command_rx, confirm_close_tx, stream, connector, handler).await;
50    });
51
52    Some(Client {
53        to_send: to_send_tx,
54        command_tx,
55        confirm_close_rx,
56    })
57}
58
59async fn reconnect<C, H>(
60    stream: &mut StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
61    connector: &mut C,
62    handler: &mut H,
63) -> Result<(), LoopControl>
64where
65    C: Connector,
66    H: Handler,
67    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
68{
69    if let Err(reason) = stream.close().await {
70        log::error!("{reason}");
71    }
72
73    *stream = connect_stream(connector, handler).await?;
74
75    Ok(())
76}
77
78async fn connect_stream<C, H>(
79    connector: &mut C,
80    handler: &mut H,
81) -> Result<
82    StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
83    LoopControl,
84>
85where
86    C: Connector,
87    H: Handler,
88    <StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error> as Sink<
89        C::Item,
90    >>::Error: std::error::Error,
91{
92    loop {
93        let stream = match C::connect().await {
94            Ok(stream) => stream,
95            Err(reason) => {
96                log::error!("Failed to connect: {reason}");
97                if handler.on_connect_failure().await == RetryStrategy::Close {
98                    log::error!("Stop retrying to connect.");
99                    break Err(LoopControl::Break);
100                }
101                let delay = connector.retry_delay().await;
102                log::error!("Retrying in {}s", delay.as_secs());
103                tokio::time::sleep(delay).await;
104                continue;
105            }
106        };
107
108        log::info!("Connection Successfully established");
109        handler.on_connect().await;
110
111        break Ok(stream);
112    }
113}
114
115async fn handle_disconnect<C, H>(
116    connector: &mut C,
117    handler: &mut H,
118    stream: &mut StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
119) -> Result<(), LoopControl>
120where
121    C: Connector,
122    H: Handler,
123    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
124{
125    match handler.on_disconnect().await {
126        RetryStrategy::Close => {
127            log::error!("Do not retry to connect.");
128            if let Err(reason) = stream.close().await {
129                log::error!("Failed to close the stream: {reason}");
130            }
131            return Err(LoopControl::Break);
132        }
133
134        RetryStrategy::Retry => {
135            if let Err(control) = reconnect(stream, connector, handler).await {
136                log::error!("Do not retry to connect.");
137                if let Err(reason) = stream.close().await {
138                    log::error!("Failed to close the stream: {reason}");
139                }
140                // only Err on RetryStrategy::Close
141                return Err(control);
142            }
143        }
144    }
145
146    Ok(())
147}
148
149async fn handle_ping_pong<C, H>(
150    ponged: &mut bool,
151    last_ping: u8,
152    connector: &mut C,
153    handler: &mut H,
154    stream: &mut StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
155) -> Result<(), LoopControl>
156where
157    C: Connector,
158    H: Handler,
159    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
160{
161    if *ponged {
162        if let Err(reason) = stream.send(Message::Ping(vec![last_ping]).into()).await {
163            log::error!("Failed to send Ping: {reason}");
164            handle_disconnect(connector, handler, stream).await?;
165            return Err(LoopControl::Continue);
166        }
167        *ponged = false;
168    } else {
169        log::error!("Last ping has not been ponged");
170        handle_disconnect(connector, handler, stream).await?;
171    }
172
173    Ok(())
174}
175
176async fn handle_message_to_send<C, H>(
177    send_result: Result<Message, flume::RecvError>,
178    connector: &mut C,
179    handler: &mut H,
180    stream: &mut StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
181) -> Result<(), LoopControl>
182where
183    C: Connector,
184    H: Handler,
185    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
186{
187    if let Ok(message) = send_result {
188        if let Err(reason) = stream.send(message.into()).await {
189            log::error!("Failed to send message on stream: {reason}");
190            handle_disconnect(connector, handler, stream).await?;
191        }
192    } else {
193        log::info!("Closing the stream, all clients have been dropped");
194        return Err(LoopControl::Break);
195    }
196
197    Ok(())
198}
199
200async fn handle_reconnect<C, H>(
201    connector: &mut C,
202    handler: &mut H,
203    stream: &mut StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
204) -> Result<(), LoopControl>
205where
206    C: Connector,
207    H: Handler,
208    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
209{
210    if reconnect(stream, connector, handler).await.is_err() {
211        if let Err(reason) = stream.close().await {
212            log::error!("Failed to close the stream: {reason}");
213        }
214        // only Err on RetryStrategy::Close
215        return Err(LoopControl::Break);
216    }
217
218    Ok(())
219}
220
221async fn handle_message<C, H>(
222    message: Result<C::Item, C::Error>,
223    want_to_close: bool,
224    ponged: &mut bool,
225    last_ping: &mut u8,
226    connector: &mut C,
227    handler: &mut H,
228    stream: &mut StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
229) -> Result<(), LoopControl>
230where
231    C: Connector,
232    H: Handler,
233    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
234{
235    match message {
236        Ok(message) => match message.into() {
237            Message::Text(ref text) => {
238                handler.on_text(text).await;
239            }
240            Message::Binary(ref buf) => {
241                handler.on_binary(buf).await;
242            }
243            Message::Ping(data) => {
244                if let Err(reason) = stream.send(Message::Pong(data).into()).await {
245                    if !want_to_close {
246                        log::error!("Failed to send Pong to stream: {reason}");
247                        handle_disconnect(connector, handler, stream).await?;
248                    }
249                }
250            }
251            Message::Pong(buf) => {
252                if buf.len() != 1 {
253                    log::error!("Pong data is invalid: {buf:?}");
254                    handle_reconnect(connector, handler, stream).await?;
255                    return Err(LoopControl::Continue);
256                }
257
258                if buf[0] == *last_ping {
259                    *ponged = true;
260                    *last_ping = last_ping.wrapping_add(1);
261                } else if !want_to_close {
262                    log::error!(
263                        "Pong data is invalid, expected {last_ping} got {:?}",
264                        buf[0]
265                    );
266                    handle_reconnect(connector, handler, stream).await?;
267                }
268            }
269            Message::Close(code, reason) => {
270                if want_to_close {
271                    return Err(LoopControl::Break);
272                }
273
274                log::info!("Server closed with code {}: {reason}", u16::from(&code));
275                
276                if let Err(reason) = stream.send(C::Item::from(Message::Close(code.clone(), String::default()))).await {
277                    log::error!("Failed to send back Close to stream: {reason}");
278                }
279                
280                match handler.on_close(code, &reason).await {
281                    RetryStrategy::Close => {
282                        log::error!("Do not retry to connect.");
283                        if let Err(reason) = stream.close().await {
284                            log::error!("Failed to close the stream: {reason}");
285                        }
286                        return Err(LoopControl::Break);
287                    }
288                    RetryStrategy::Retry => {
289                        handle_reconnect(connector, handler, stream).await?;
290                    }
291                }
292            }
293        },
294        Err(reason) => {
295            log::error!("Failed to read stream: {reason}");
296            if !want_to_close {
297                handle_reconnect(connector, handler, stream).await?;
298            }
299        }
300    }
301
302    Ok(())
303}
304
305async fn handle_command<C, H>(
306    command: Result<Command, flume::RecvError>,
307    want_to_close: &mut bool,
308    connector: &mut C,
309    handler: &mut H,
310    stream: &mut StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
311) -> Result<(), LoopControl>
312where
313    C: Connector,
314    H: Handler,
315    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
316{
317    match command {
318        Ok(Command::Reconnect) => {
319            log::info!("Forcing reconnection");
320            handle_reconnect(connector, handler, stream).await?;
321        }
322        Ok(Command::Close) => {
323            log::info!("Client requested to close the connection");
324            if let Err(reason) = stream
325                .send(C::Item::from(Message::Close(
326                    CloseCode::Normal,
327                    "Client is explicitly closing the stream".to_string(),
328                )))
329                .await
330            {
331                log::error!("Failed to send Close on the stream: {reason}");
332                return Err(LoopControl::Break);
333            }
334            *want_to_close = true;
335        }
336        Err(_) => {
337            log::info!("Closing the stream, all clients have been dropped");
338            return Err(LoopControl::Break);
339        }
340    }
341
342    Ok(())
343}
344
345#[allow(clippy::too_many_lines, clippy::redundant_pub_crate)]
346async fn background_task<C, H>(
347    to_send: flume::Receiver<Message>,
348    command_rx: flume::Receiver<Command>,
349    confirm_close_tx: flume::Sender<()>,
350    mut stream: StreamWrapper<'static, C::BackendStream, C::BackendMessage, C::Item, C::Error>,
351    mut connector: C,
352    mut handler: H,
353) where
354    C: Connector,
355    H: Handler,
356    <C::BackendStream as Sink<C::BackendMessage>>::Error: std::error::Error + Send,
357{
358    let mut ping_interval = tokio::time::interval(C::ping_interval());
359    let mut last_ping = 0u8;
360    let mut ponged = true; // initially true to avoid mistaking it for a failed ping/pong
361    let mut want_to_close = false;
362
363    loop {
364        tokio::select! {
365            _ = ping_interval.tick() => {
366                // do not ping if we are closing the connection.
367                if !want_to_close && matches!(handle_ping_pong(&mut ponged, last_ping, &mut connector, &mut handler, &mut stream).await, Err(LoopControl::Break)) {
368                    break;
369                }
370            },
371            res = to_send.recv_async() => {
372                if !want_to_close && matches!(handle_message_to_send(res, &mut connector, &mut handler, &mut stream).await, Err(LoopControl::Break)) {
373                    break;
374                }
375            }
376            res = command_rx.recv_async() => {
377                if !want_to_close && matches!(handle_command(res, &mut want_to_close, &mut connector, &mut handler, &mut stream).await, Err(LoopControl::Break)) {
378                    break;
379                }
380            }
381            Some(message) = stream.next() => {
382                if matches!(handle_message(message, want_to_close, &mut ponged, &mut last_ping, &mut connector, &mut handler, &mut stream).await, Err(LoopControl::Break)) {
383                    break;
384                }
385            }
386        }
387    }
388
389    if let Err(reason) = stream.close().await {
390        log::error!("{reason}");
391    }
392
393    if let Err(reason) = confirm_close_tx.send_async(()).await {
394        log::error!("Failed to send closing confirmation: {reason}");
395    }
396    
397    log::trace!("Background task complete");
398}