Skip to main content

ws_reconnect_client/
client.rs

1use backon::{BackoffBuilder, ExponentialBuilder};
2use futures_util::{SinkExt, StreamExt};
3use serde::{de::DeserializeOwned, Serialize};
4use std::time::Duration;
5use tokio_tungstenite::tungstenite::Message;
6
7use crate::{
8    connect_with_retry, MessageStream, PingManager, Result, WebSocketError, WsConnectionConfig,
9    WsReader, WsWriter,
10};
11
12/// High-level WebSocket client with automatic reconnection, ping/pong handling, and JSON support
13#[derive(Clone)]
14pub struct WebSocketClient<T>
15where
16    T: DeserializeOwned,
17{
18    config: WsConnectionConfig,
19    _phantom: std::marker::PhantomData<T>,
20}
21
22impl<T> WebSocketClient<T>
23where
24    T: DeserializeOwned,
25{
26    /// Create a new WebSocket client with the given configuration
27    pub fn new(config: WsConnectionConfig) -> Self {
28        Self {
29            config,
30            _phantom: std::marker::PhantomData,
31        }
32    }
33
34    /// Connect to the WebSocket server without sending any subscription
35    ///
36    /// Returns a writer and reader for manual message handling.
37    pub async fn connect(&self) -> Result<(WsWriter, WsReader)> {
38        connect_with_retry(&self.config).await
39    }
40
41    /// Connect to the WebSocket server and return a MessageStream
42    ///
43    /// The MessageStream automatically handles PING/PONG and yields parsed messages.
44    pub async fn connect_stream(&self) -> Result<MessageStream<T>> {
45        use std::sync::Arc;
46        use tokio::sync::Mutex;
47
48        let (writer, reader) = self.connect().await?;
49        let shared_writer = Arc::new(Mutex::new(Some(writer)));
50        Ok(MessageStream::new(
51            reader,
52            shared_writer,
53            self.config.ping_interval_secs,
54        ))
55    }
56
57    /// Connect and optionally send an initial subscription message
58    ///
59    /// Deprecated: Use `connect()` + `send_subscription()` or `connect_stream()` instead.
60    pub async fn connect_and_subscribe<S: Serialize>(
61        &self,
62        subscription: Option<&S>,
63    ) -> Result<(WsWriter, WsReader)> {
64        let (mut writer, reader) = self.connect().await?;
65
66        if let Some(sub) = subscription {
67            send_subscription(&mut writer, sub).await?;
68        }
69
70        Ok((writer, reader))
71    }
72
73    /// Start listening to messages with automatic reconnection and ping/pong handling
74    ///
75    /// Uses exponential backoff for reconnection attempts (same config as initial connection).
76    ///
77    /// # Arguments
78    /// * `subscription` - Optional subscription message to send on each connection
79    /// * `handler` - Callback function that processes each received message
80    ///
81    /// # Example
82    /// ```no_run
83    /// # use ws_reconnect_client::{WebSocketClient, WsConnectionConfig};
84    /// # use serde::{Deserialize, Serialize};
85    /// # #[derive(Debug, Deserialize, Serialize)]
86    /// # struct MyMessage { value: String }
87    /// # #[tokio::main]
88    /// # async fn main() -> Result<(), Box<dyn std::error::Error>> {
89    /// # let config = WsConnectionConfig::new("wss://example.com");
90    /// # let subscription = MyMessage { value: "sub".into() };
91    /// let client = WebSocketClient::<MyMessage>::new(config);
92    /// client.listen(Some(&subscription), |msg| {
93    ///     println!("Received: {:?}", msg);
94    ///     Ok(())
95    /// }).await?;
96    /// # Ok(())
97    /// # }
98    /// ```
99    pub async fn listen<S, F>(&self, subscription: Option<S>, mut handler: F) -> Result<()>
100    where
101        S: Serialize + Clone,
102        F: FnMut(T) -> Result<()>,
103    {
104        if !self.config.auto_reconnect {
105            // No auto-reconnect, just listen once
106            return self.listen_once(subscription.as_ref(), &mut handler).await;
107        }
108
109        // Auto-reconnect enabled: use exponential backoff for reconnections
110        let backoff = ExponentialBuilder::default()
111            .with_min_delay(Duration::from_millis(self.config.initial_backoff_ms))
112            .with_max_delay(Duration::from_millis(self.config.max_backoff_ms))
113            .with_max_times(self.config.max_retries);
114
115        let mut backoff_iter = backoff.build();
116        let mut attempt = 0;
117
118        loop {
119            match self.listen_once(subscription.as_ref(), &mut handler).await {
120                Ok(_) => {
121                    // Connection closed normally, reset backoff and reconnect
122                    backoff_iter = backoff.build();
123                    attempt = 0;
124                }
125                Err(e) => {
126                    // Connection error
127                    attempt += 1;
128
129                    if attempt >= self.config.max_retries {
130                        // Max retries reached
131                        return Err(e);
132                    }
133
134                    // Get next backoff delay from backon
135                    if let Some(delay) = backoff_iter.next() {
136                        tokio::time::sleep(delay).await;
137                    } else {
138                        // Backoff exhausted (shouldn't happen due to max_retries check)
139                        return Err(e);
140                    }
141                }
142            }
143        }
144    }
145
146    /// Listen once (single connection lifecycle)
147    async fn listen_once<S, F>(&self, subscription: Option<&S>, handler: &mut F) -> Result<()>
148    where
149        S: Serialize,
150        F: FnMut(T) -> Result<()>,
151    {
152        let (mut writer, mut reader) = self.connect_and_subscribe(subscription).await?;
153
154        let mut ping_manager = if self.config.ping_interval_secs > 0 {
155            Some(PingManager::new(self.config.ping_interval_secs))
156        } else {
157            None
158        };
159
160        loop {
161            tokio::select! {
162                // Handle incoming messages
163                msg = reader.next() => {
164                    match msg {
165                        Some(Ok(Message::Text(text))) => {
166                            self.handle_text_message(&text, handler)?;
167                        }
168
169                        Some(Ok(Message::Ping(ping))) => {
170                            writer.send(Message::Pong(ping))
171                                .await
172                                .map_err(|_| WebSocketError::SendError)?;
173                        }
174
175                        Some(Ok(Message::Pong(_))) => {
176                            // Pong received, connection is alive
177                        }
178
179                        Some(Ok(Message::Close(_))) => {
180                            return Err(WebSocketError::ConnectionClosed);
181                        }
182
183                        Some(Err(e)) => {
184                            return Err(WebSocketError::Tungstenite(e));
185                        }
186
187                        None => {
188                            return Err(WebSocketError::ConnectionClosed);
189                        }
190
191                        _ => {}
192                    }
193                }
194
195                // Send periodic pings if enabled
196                _ = async {
197                    if let Some(ref mut pm) = ping_manager {
198                        pm.wait_for_next_ping().await;
199                    } else {
200                        // Never completes if pings are disabled
201                        std::future::pending::<()>().await;
202                    }
203                } => {
204                    writer.send(Message::Ping(vec![].into()))
205                        .await
206                        .map_err(|_| WebSocketError::SendError)?;
207                }
208            }
209        }
210    }
211
212    /// Handle text message - try to parse as single message or array of messages
213    fn handle_text_message<F>(&self, text: &str, handler: &mut F) -> Result<()>
214    where
215        F: FnMut(T) -> Result<()>,
216    {
217        // Skip empty messages (keepalive frames from some servers)
218        if text.trim().is_empty() {
219            return Ok(());
220        }
221
222        // Try to parse as array first
223        if text.trim_start().starts_with('[') {
224            match serde_json::from_str::<Vec<T>>(text) {
225                Ok(messages) => {
226                    for msg in messages {
227                        handler(msg)?;
228                    }
229                    return Ok(());
230                }
231                Err(e) => {
232                    // Array parsing failed, log the error and try single message
233                    eprintln!("⚠️  WebSocket: Failed to parse as array: {}", e);
234                }
235            }
236        }
237
238        // Try to parse as single message
239        match serde_json::from_str::<T>(text) {
240            Ok(msg) => {
241                handler(msg)?;
242                Ok(())
243            }
244            Err(e) => {
245                // Could be a subscription confirmation or other non-T message
246                // Log raw message for debugging
247                eprintln!("⚠️  WebSocket: Failed to parse message: {}", e);
248                eprintln!("📨 Raw message: {}", text);
249                // Don't fail, just skip it
250                Err(WebSocketError::SerializationError(e))
251            }
252        }
253    }
254}
255
256/// Send a subscription message over an existing WebSocket connection
257///
258/// This is a helper function for sending JSON-serialized subscription messages.
259///
260/// # Example
261/// ```no_run
262/// # use ws_reconnect_client::{WsWriter, send_subscription};
263/// # #[derive(serde::Serialize)]
264/// # struct Subscription { channel: String }
265/// # async fn example(writer: &mut WsWriter) -> Result<(), Box<dyn std::error::Error>> {
266/// let sub = Subscription { channel: "market".to_string() };
267/// send_subscription(writer, &sub).await?;
268/// # Ok(())
269/// # }
270/// ```
271pub async fn send_subscription<S: Serialize>(writer: &mut WsWriter, subscription: &S) -> Result<()> {
272    let sub_json = serde_json::to_string(subscription)?;
273    writer
274        .send(Message::Text(sub_json.into()))
275        .await
276        .map_err(|_| WebSocketError::SendError)
277}
278
279/// Builder pattern for WebSocketClient
280pub struct WebSocketClientBuilder {
281    config: WsConnectionConfig,
282}
283
284impl WebSocketClientBuilder {
285    pub fn new(url: impl Into<String>) -> Self {
286        Self {
287            config: WsConnectionConfig::new(url),
288        }
289    }
290
291    pub fn with_config(config: WsConnectionConfig) -> Self {
292        Self { config }
293    }
294
295    pub fn ping_interval(mut self, seconds: u64) -> Self {
296        self.config = self.config.with_ping_interval(seconds);
297        self
298    }
299
300    pub fn auto_reconnect(mut self, enabled: bool) -> Self {
301        self.config = self.config.with_auto_reconnect(enabled);
302        self
303    }
304
305    pub fn max_retries(mut self, retries: usize) -> Self {
306        self.config = self.config.with_retries(retries);
307        self
308    }
309
310    pub fn backoff(mut self, initial_ms: u64, max_ms: u64) -> Self {
311        self.config = self.config.with_backoff(initial_ms, max_ms);
312        self
313    }
314
315    pub fn build<T: DeserializeOwned>(self) -> WebSocketClient<T> {
316        WebSocketClient::new(self.config)
317    }
318}