ws_reconnect_client/
client.rs1use backon::{BackoffBuilder, ExponentialBuilder};
2use futures_util::{SinkExt, StreamExt};
3use serde::{de::DeserializeOwned, Serialize};
4use std::time::Duration;
5use tokio_tungstenite::tungstenite::Message;
6
7use crate::{
8 connect_with_retry, MessageStream, PingManager, Result, WebSocketError, WsConnectionConfig,
9 WsReader, WsWriter,
10};
11
12#[derive(Clone)]
14pub struct WebSocketClient<T>
15where
16 T: DeserializeOwned,
17{
18 config: WsConnectionConfig,
19 _phantom: std::marker::PhantomData<T>,
20}
21
22impl<T> WebSocketClient<T>
23where
24 T: DeserializeOwned,
25{
26 pub fn new(config: WsConnectionConfig) -> Self {
28 Self {
29 config,
30 _phantom: std::marker::PhantomData,
31 }
32 }
33
34 pub async fn connect(&self) -> Result<(WsWriter, WsReader)> {
38 connect_with_retry(&self.config).await
39 }
40
41 pub async fn connect_stream(&self) -> Result<MessageStream<T>> {
45 use std::sync::Arc;
46 use tokio::sync::Mutex;
47
48 let (writer, reader) = self.connect().await?;
49 let shared_writer = Arc::new(Mutex::new(Some(writer)));
50 Ok(MessageStream::new(
51 reader,
52 shared_writer,
53 self.config.ping_interval_secs,
54 ))
55 }
56
57 pub async fn connect_and_subscribe<S: Serialize>(
61 &self,
62 subscription: Option<&S>,
63 ) -> Result<(WsWriter, WsReader)> {
64 let (mut writer, reader) = self.connect().await?;
65
66 if let Some(sub) = subscription {
67 send_subscription(&mut writer, sub).await?;
68 }
69
70 Ok((writer, reader))
71 }
72
73 pub async fn listen<S, F>(&self, subscription: Option<S>, mut handler: F) -> Result<()>
100 where
101 S: Serialize + Clone,
102 F: FnMut(T) -> Result<()>,
103 {
104 if !self.config.auto_reconnect {
105 return self.listen_once(subscription.as_ref(), &mut handler).await;
107 }
108
109 let backoff = ExponentialBuilder::default()
111 .with_min_delay(Duration::from_millis(self.config.initial_backoff_ms))
112 .with_max_delay(Duration::from_millis(self.config.max_backoff_ms))
113 .with_max_times(self.config.max_retries);
114
115 let mut backoff_iter = backoff.build();
116 let mut attempt = 0;
117
118 loop {
119 match self.listen_once(subscription.as_ref(), &mut handler).await {
120 Ok(_) => {
121 backoff_iter = backoff.build();
123 attempt = 0;
124 }
125 Err(e) => {
126 attempt += 1;
128
129 if attempt >= self.config.max_retries {
130 return Err(e);
132 }
133
134 if let Some(delay) = backoff_iter.next() {
136 tokio::time::sleep(delay).await;
137 } else {
138 return Err(e);
140 }
141 }
142 }
143 }
144 }
145
146 async fn listen_once<S, F>(&self, subscription: Option<&S>, handler: &mut F) -> Result<()>
148 where
149 S: Serialize,
150 F: FnMut(T) -> Result<()>,
151 {
152 let (mut writer, mut reader) = self.connect_and_subscribe(subscription).await?;
153
154 let mut ping_manager = if self.config.ping_interval_secs > 0 {
155 Some(PingManager::new(self.config.ping_interval_secs))
156 } else {
157 None
158 };
159
160 loop {
161 tokio::select! {
162 msg = reader.next() => {
164 match msg {
165 Some(Ok(Message::Text(text))) => {
166 self.handle_text_message(&text, handler)?;
167 }
168
169 Some(Ok(Message::Ping(ping))) => {
170 writer.send(Message::Pong(ping))
171 .await
172 .map_err(|_| WebSocketError::SendError)?;
173 }
174
175 Some(Ok(Message::Pong(_))) => {
176 }
178
179 Some(Ok(Message::Close(_))) => {
180 return Err(WebSocketError::ConnectionClosed);
181 }
182
183 Some(Err(e)) => {
184 return Err(WebSocketError::Tungstenite(e));
185 }
186
187 None => {
188 return Err(WebSocketError::ConnectionClosed);
189 }
190
191 _ => {}
192 }
193 }
194
195 _ = async {
197 if let Some(ref mut pm) = ping_manager {
198 pm.wait_for_next_ping().await;
199 } else {
200 std::future::pending::<()>().await;
202 }
203 } => {
204 writer.send(Message::Ping(vec![].into()))
205 .await
206 .map_err(|_| WebSocketError::SendError)?;
207 }
208 }
209 }
210 }
211
212 fn handle_text_message<F>(&self, text: &str, handler: &mut F) -> Result<()>
214 where
215 F: FnMut(T) -> Result<()>,
216 {
217 if text.trim().is_empty() {
219 return Ok(());
220 }
221
222 if text.trim_start().starts_with('[') {
224 match serde_json::from_str::<Vec<T>>(text) {
225 Ok(messages) => {
226 for msg in messages {
227 handler(msg)?;
228 }
229 return Ok(());
230 }
231 Err(e) => {
232 eprintln!("⚠️ WebSocket: Failed to parse as array: {}", e);
234 }
235 }
236 }
237
238 match serde_json::from_str::<T>(text) {
240 Ok(msg) => {
241 handler(msg)?;
242 Ok(())
243 }
244 Err(e) => {
245 eprintln!("⚠️ WebSocket: Failed to parse message: {}", e);
248 eprintln!("📨 Raw message: {}", text);
249 Err(WebSocketError::SerializationError(e))
251 }
252 }
253 }
254}
255
256pub async fn send_subscription<S: Serialize>(writer: &mut WsWriter, subscription: &S) -> Result<()> {
272 let sub_json = serde_json::to_string(subscription)?;
273 writer
274 .send(Message::Text(sub_json.into()))
275 .await
276 .map_err(|_| WebSocketError::SendError)
277}
278
279pub struct WebSocketClientBuilder {
281 config: WsConnectionConfig,
282}
283
284impl WebSocketClientBuilder {
285 pub fn new(url: impl Into<String>) -> Self {
286 Self {
287 config: WsConnectionConfig::new(url),
288 }
289 }
290
291 pub fn with_config(config: WsConnectionConfig) -> Self {
292 Self { config }
293 }
294
295 pub fn ping_interval(mut self, seconds: u64) -> Self {
296 self.config = self.config.with_ping_interval(seconds);
297 self
298 }
299
300 pub fn auto_reconnect(mut self, enabled: bool) -> Self {
301 self.config = self.config.with_auto_reconnect(enabled);
302 self
303 }
304
305 pub fn max_retries(mut self, retries: usize) -> Self {
306 self.config = self.config.with_retries(retries);
307 self
308 }
309
310 pub fn backoff(mut self, initial_ms: u64, max_ms: u64) -> Self {
311 self.config = self.config.with_backoff(initial_ms, max_ms);
312 self
313 }
314
315 pub fn build<T: DeserializeOwned>(self) -> WebSocketClient<T> {
316 WebSocketClient::new(self.config)
317 }
318}