1use anyhow::Result;
2use futures_util::{SinkExt, StreamExt};
3use std::time::Duration;
4use tokio::sync::{mpsc, watch};
5use tokio_tungstenite::tungstenite;
6use tungstenite::error::{Error as WsError, ProtocolError, UrlError};
7use tungstenite::protocol::frame::coding::CloseCode;
8
9use super::types::BinanceTradeEvent;
10use crate::event::{AppEvent, WsConnectionStatus};
11use crate::model::tick::Tick;
12
/// Retry-delay generator that multiplies the delay by a constant factor on
/// every call, capping the grown value at an upper bound.
struct ExponentialBackoff {
    /// Delay handed out by the next `next_delay` call.
    current: Duration,
    /// Delay restored by `reset`.
    initial: Duration,
    /// Ceiling applied to the delay after each growth step.
    max: Duration,
    /// Multiplier applied to the delay after every `next_delay` call.
    factor: f64,
}

impl ExponentialBackoff {
    /// Builds a backoff whose first delay is `initial`.
    fn new(initial: Duration, max: Duration, factor: f64) -> Self {
        Self {
            initial,
            current: initial,
            max,
            factor,
        }
    }

    /// Returns the current delay and advances to `min(current * factor, max)`.
    ///
    /// # Panics
    /// Via `Duration::from_secs_f64` if the grown value is negative
    /// (e.g. a negative `factor`).
    fn next_delay(&mut self) -> Duration {
        let grown_secs = self.current.as_secs_f64() * self.factor;
        let capped = Duration::from_secs_f64(grown_secs.min(self.max.as_secs_f64()));
        std::mem::replace(&mut self.current, capped)
    }

    /// Restarts the sequence so the next delay is `initial` again.
    fn reset(&mut self) {
        self.current = self.initial;
    }
}
43
/// Binance trade-stream client that holds one WebSocket endpoint per market.
///
/// Which endpoint is used is decided per connection from the instrument label
/// (a trailing ` (FUT)` selects the futures endpoint).
pub struct BinanceWsClient {
    /// Endpoint passed to the WebSocket connect call for spot instruments.
    spot_url: String,
    /// Endpoint passed to the WebSocket connect call for ` (FUT)` instruments.
    futures_url: String,
}
48
49impl BinanceWsClient {
50 pub fn new(ws_base_url: &str, futures_ws_base_url: &str) -> Self {
54 Self {
55 spot_url: ws_base_url.to_string(),
56 futures_url: futures_ws_base_url.to_string(),
57 }
58 }
59
60 pub async fn connect_and_run(
63 &self,
64 tick_tx: mpsc::Sender<Tick>,
65 status_tx: mpsc::Sender<AppEvent>,
66 mut symbol_rx: watch::Receiver<String>,
67 mut shutdown: watch::Receiver<bool>,
68 ) -> Result<()> {
69 let mut backoff =
70 ExponentialBackoff::new(Duration::from_secs(1), Duration::from_secs(60), 2.0);
71 let mut attempt: u32 = 0;
72
73 loop {
74 attempt += 1;
75 let instrument = symbol_rx.borrow().clone();
76 let (symbol, is_futures) = parse_instrument_symbol(&instrument);
77 let streams = vec![format!("{}@trade", symbol.to_lowercase())];
78 let ws_url = if is_futures {
79 &self.futures_url
80 } else {
81 &self.spot_url
82 };
83 match self
84 .connect_once(
85 ws_url,
86 &streams,
87 &tick_tx,
88 &status_tx,
89 &mut symbol_rx,
90 &mut shutdown,
91 )
92 .await
93 {
94 Ok(()) => {
95 let _ = status_tx
97 .send(AppEvent::WsStatus(WsConnectionStatus::Disconnected))
98 .await;
99 break;
100 }
101 Err(e) => {
102 let _ = status_tx
103 .send(AppEvent::WsStatus(WsConnectionStatus::Disconnected))
104 .await;
105 tracing::warn!(attempt, error = %e, "WS connection attempt failed");
106 let _ = status_tx
107 .send(AppEvent::LogMessage(format!(
108 "WS error (attempt #{}): {}",
109 attempt, e
110 )))
111 .await;
112
113 let delay = backoff.next_delay();
114 let _ = status_tx
115 .send(AppEvent::WsStatus(WsConnectionStatus::Reconnecting {
116 attempt,
117 delay_ms: delay.as_millis() as u64,
118 }))
119 .await;
120
121 tokio::select! {
122 _ = tokio::time::sleep(delay) => continue,
123 _ = shutdown.changed() => {
124 let _ = status_tx
125 .send(AppEvent::LogMessage("Shutdown during reconnect".to_string()))
126 .await;
127 break;
128 }
129 }
130 }
131 }
132 }
133 Ok(())
134 }
135
136 async fn connect_once(
137 &self,
138 ws_url: &str,
139 streams: &[String],
140 tick_tx: &mpsc::Sender<Tick>,
141 status_tx: &mpsc::Sender<AppEvent>,
142 symbol_rx: &mut watch::Receiver<String>,
143 shutdown: &mut watch::Receiver<bool>,
144 ) -> Result<()> {
145 let _ = status_tx
146 .send(AppEvent::LogMessage(format!("Connecting to {}", ws_url)))
147 .await;
148
149 let (ws_stream, resp) = tokio_tungstenite::connect_async(ws_url)
150 .await
151 .map_err(|e| {
152 let detail = format_ws_error(&e);
153 let _ = status_tx.try_send(AppEvent::LogMessage(detail.clone()));
154 anyhow::anyhow!("WebSocket connect failed: {}", detail)
155 })?;
156
157 tracing::debug!(status = %resp.status(), "WebSocket HTTP upgrade response");
158
159 let (mut write, mut read) = ws_stream.split();
160
161 let subscribe_msg = serde_json::json!({
163 "method": "SUBSCRIBE",
164 "params": streams,
165 "id": 1
166 });
167 write
168 .send(tungstenite::Message::Text(subscribe_msg.to_string()))
169 .await
170 .map_err(|e| {
171 let detail = format_ws_error(&e);
172 anyhow::anyhow!("Failed to send SUBSCRIBE: {}", detail)
173 })?;
174
175 let _ = status_tx
176 .send(AppEvent::LogMessage(format!(
177 "Subscribed to: {}",
178 streams.join(", ")
179 )))
180 .await;
181
182 let _ = status_tx
184 .send(AppEvent::WsStatus(WsConnectionStatus::Connected))
185 .await;
186 let _ = status_tx
187 .send(AppEvent::LogMessage("WebSocket connected".to_string()))
188 .await;
189
190 loop {
191 tokio::select! {
192 msg = read.next() => {
193 match msg {
194 Some(Ok(tungstenite::Message::Text(text))) => {
195 self.handle_text_message(&text, tick_tx, status_tx).await;
196 }
197 Some(Ok(tungstenite::Message::Ping(_))) => {
198 }
200 Some(Ok(tungstenite::Message::Close(frame))) => {
201 let detail = match &frame {
202 Some(cf) => format!(
203 "Server closed: code={} reason=\"{}\"",
204 format_close_code(&cf.code),
205 cf.reason
206 ),
207 None => "Server closed: no close frame".to_string(),
208 };
209 let _ = status_tx
210 .send(AppEvent::LogMessage(detail.clone()))
211 .await;
212 return Err(anyhow::anyhow!("{}", detail));
213 }
214 Some(Ok(other)) => {
215 tracing::trace!(msg_type = ?other, "Unhandled WS message type");
216 }
217 Some(Err(e)) => {
218 let detail = format_ws_error(&e);
219 let _ = status_tx
220 .send(AppEvent::LogMessage(format!("WS read error: {}", detail)))
221 .await;
222 return Err(anyhow::anyhow!("WebSocket read error: {}", detail));
223 }
224 None => {
225 return Err(anyhow::anyhow!(
226 "WebSocket stream ended unexpectedly (connection dropped)"
227 ));
228 }
229 }
230 }
231 _ = shutdown.changed() => {
232 let unsub_msg = serde_json::json!({
234 "method": "UNSUBSCRIBE",
235 "params": streams,
236 "id": 2
237 });
238 let _ = write
239 .send(tungstenite::Message::Text(unsub_msg.to_string()))
240 .await;
241 let _ = write.send(tungstenite::Message::Close(None)).await;
242 return Ok(());
243 }
244 _ = symbol_rx.changed() => {
245 let _ = write.send(tungstenite::Message::Close(None)).await;
246 return Err(anyhow::anyhow!("Symbol changed, reconnecting WebSocket"));
247 }
248 }
249 }
250 }
251
252 async fn handle_text_message(
253 &self,
254 text: &str,
255 tick_tx: &mpsc::Sender<Tick>,
256 status_tx: &mpsc::Sender<AppEvent>,
257 ) {
258 if let Ok(val) = serde_json::from_str::<serde_json::Value>(text) {
260 if val.get("result").is_some() && val.get("id").is_some() {
261 tracing::debug!(id = %val["id"], "Subscription response received");
262 return;
263 }
264 }
265
266 match serde_json::from_str::<BinanceTradeEvent>(text) {
267 Ok(event) => {
268 let tick = Tick {
269 price: event.price,
270 qty: event.qty,
271 timestamp_ms: event.event_time,
272 is_buyer_maker: event.is_buyer_maker,
273 trade_id: event.trade_id,
274 };
275 if tick_tx.try_send(tick).is_err() {
276 tracing::warn!("Tick channel full, dropping tick");
277 }
278 }
279 Err(e) => {
280 tracing::debug!(error = %e, raw = %text, "Failed to parse WS message");
281 let _ = status_tx
282 .send(AppEvent::LogMessage(format!(
283 "WS parse skip: {}",
284 &text[..text.len().min(80)]
285 )))
286 .await;
287 }
288 }
289 }
290}
291
/// Splits an instrument label into its upper-cased symbol and a futures flag.
///
/// Surrounding whitespace is ignored. A label ending in the exact suffix
/// ` (FUT)` is treated as a futures instrument and has the suffix removed.
/// Any other label is a spot instrument.
fn parse_instrument_symbol(instrument: &str) -> (String, bool) {
    let label = instrument.trim();
    match label.strip_suffix(" (FUT)") {
        Some(base) => (base.to_ascii_uppercase(), true),
        None => (label.to_ascii_uppercase(), false),
    }
}
299
300fn format_ws_error(err: &WsError) -> String {
302 match err {
303 WsError::ConnectionClosed => "Connection closed normally".to_string(),
304 WsError::AlreadyClosed => "Attempted operation on already-closed connection".to_string(),
305 WsError::Io(io_err) => {
306 format!("IO error [kind={}]: {}", io_err.kind(), io_err)
307 }
308 WsError::Tls(tls_err) => format!("TLS error: {}", tls_err),
309 WsError::Capacity(cap_err) => format!("Capacity error: {}", cap_err),
310 WsError::Protocol(proto_err) => {
311 let detail = match proto_err {
312 ProtocolError::ResetWithoutClosingHandshake => {
313 "connection reset without closing handshake (server may have dropped)"
314 }
315 ProtocolError::SendAfterClosing => "tried to send after close frame",
316 ProtocolError::ReceivedAfterClosing => "received data after close frame",
317 ProtocolError::HandshakeIncomplete => "handshake incomplete",
318 _ => "",
319 };
320 if detail.is_empty() {
321 format!("Protocol error: {}", proto_err)
322 } else {
323 format!("Protocol error: {} ({})", proto_err, detail)
324 }
325 }
326 WsError::WriteBufferFull(_) => "Write buffer full (backpressure)".to_string(),
327 WsError::Utf8 => "UTF-8 encoding error in frame data".to_string(),
328 WsError::AttackAttempt => "Attack attempt detected by WebSocket library".to_string(),
329 WsError::Url(url_err) => {
330 let hint = match url_err {
331 UrlError::TlsFeatureNotEnabled => "TLS feature not compiled in",
332 UrlError::NoHostName => "no host name in URL",
333 UrlError::UnableToConnect(addr) => {
334 return format!(
335 "URL error: unable to connect to {} (DNS/network failure?)",
336 addr
337 );
338 }
339 UrlError::UnsupportedUrlScheme => "only ws:// or wss:// are supported",
340 UrlError::EmptyHostName => "empty host name in URL",
341 UrlError::NoPathOrQuery => "no path/query in URL",
342 };
343 format!("URL error: {} — {}", url_err, hint)
344 }
345 WsError::Http(resp) => {
346 let status = resp.status();
347 let body_preview = resp
348 .body()
349 .as_ref()
350 .and_then(|b| std::str::from_utf8(b).ok())
351 .unwrap_or("")
352 .chars()
353 .take(200)
354 .collect::<String>();
355 format!(
356 "HTTP error: status={} ({}), body=\"{}\"",
357 status.as_u16(),
358 status.canonical_reason().unwrap_or("unknown"),
359 body_preview
360 )
361 }
362 WsError::HttpFormat(e) => format!("HTTP format error: {}", e),
363 }
364}
365
366fn format_close_code(code: &CloseCode) -> String {
368 let (num, label) = match code {
369 CloseCode::Normal => (1000, "Normal"),
370 CloseCode::Away => (1001, "Going Away"),
371 CloseCode::Protocol => (1002, "Protocol Error"),
372 CloseCode::Unsupported => (1003, "Unsupported Data"),
373 CloseCode::Status => (1005, "No Status"),
374 CloseCode::Abnormal => (1006, "Abnormal Closure"),
375 CloseCode::Invalid => (1007, "Invalid Payload"),
376 CloseCode::Policy => (1008, "Policy Violation"),
377 CloseCode::Size => (1009, "Message Too Big"),
378 CloseCode::Extension => (1010, "Extension Required"),
379 CloseCode::Error => (1011, "Internal Error"),
380 CloseCode::Restart => (1012, "Service Restart"),
381 CloseCode::Again => (1013, "Try Again Later"),
382 CloseCode::Tls => (1015, "TLS Handshake Failure"),
383 CloseCode::Reserved(n) => (*n, "Reserved"),
384 CloseCode::Iana(n) => (*n, "IANA"),
385 CloseCode::Library(n) => (*n, "Library"),
386 CloseCode::Bad(n) => (*n, "Bad"),
387 };
388 format!("{} ({})", num, label)
389}