use crate::auth::{AccessToken, AuthManager};
use crate::error::{WebullError, WebullResult};
use crate::streaming::events::{
    ConnectionState, ConnectionStatus, ErrorEvent, Event, EventType, HeartbeatEvent,
};
use crate::streaming::subscription::{SubscriptionRequest, UnsubscriptionRequest};
use crate::utils::serialization::{from_json, to_json};
use futures_util::future::{select, Either};
use futures_util::{SinkExt, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde_json::json;
use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::time::{sleep, timeout};
use tokio_tungstenite::tungstenite::client::IntoClientRequest;
use tokio_tungstenite::{
    connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream,
};
use url::Url;
use uuid::Uuid;
21
22pub struct WebSocketClient {
24 base_url: String,
26
27 auth_manager: Arc<AuthManager>,
29
30 connection_state: Arc<Mutex<ConnectionState>>,
32
33 event_sender: Option<Sender<Event>>,
35
36 last_heartbeat: Arc<Mutex<Instant>>,
38
39 heartbeat_interval: u64,
41
42 reconnect_attempts: Arc<Mutex<u32>>,
44
45 max_reconnect_attempts: u32,
47
48 reconnect_delay: u64,
50}
51
52impl WebSocketClient {
53 pub fn new(base_url: String, auth_manager: Arc<AuthManager>) -> Self {
55 Self {
56 base_url,
57 auth_manager,
58 connection_state: Arc::new(Mutex::new(ConnectionState::Disconnected)),
59 event_sender: None,
60 last_heartbeat: Arc::new(Mutex::new(Instant::now())),
61 heartbeat_interval: 30,
62 reconnect_attempts: Arc::new(Mutex::new(0)),
63 max_reconnect_attempts: 5,
64 reconnect_delay: 5,
65 }
66 }
67
68 pub async fn connect(&mut self) -> WebullResult<Receiver<Event>> {
70 let (tx, rx) = mpsc::channel(100);
72 self.event_sender = Some(tx.clone());
73
74 *self.connection_state.lock().unwrap() = ConnectionState::Reconnecting;
76
77 *self.reconnect_attempts.lock().unwrap() = 0;
79
80 let base_url = self.base_url.clone();
82 let auth_manager = self.auth_manager.clone();
83 let connection_state = self.connection_state.clone();
84 let last_heartbeat = self.last_heartbeat.clone();
85 let heartbeat_interval = self.heartbeat_interval;
86 let reconnect_attempts = self.reconnect_attempts.clone();
87 let max_reconnect_attempts = self.max_reconnect_attempts;
88 let reconnect_delay = self.reconnect_delay;
89
90 tokio::spawn(async move {
91 loop {
92 let attempts = *reconnect_attempts.lock().unwrap();
94 if attempts > max_reconnect_attempts {
95 let event = Event {
97 event_type: EventType::Connection,
98 timestamp: chrono::Utc::now(),
99 data: crate::streaming::events::EventData::Connection(ConnectionStatus {
100 status: ConnectionState::Failed,
101 connection_id: None,
102 message: Some("Maximum reconnect attempts exceeded".to_string()),
103 }),
104 };
105
106 let _ = tx.send(event).await;
107
108 *connection_state.lock().unwrap() = ConnectionState::Failed;
110
111 break;
112 }
113
114 *reconnect_attempts.lock().unwrap() = attempts + 1;
116
117 let token = match auth_manager.get_token().await {
119 Ok(token) => token,
120 Err(e) => {
121 let event = Event {
123 event_type: EventType::Error,
124 timestamp: chrono::Utc::now(),
125 data: crate::streaming::events::EventData::Error(ErrorEvent {
126 code: "AUTH_ERROR".to_string(),
127 message: format!("Authentication error: {}", e),
128 }),
129 };
130
131 let _ = tx.send(event).await;
132
133 sleep(Duration::from_secs(reconnect_delay)).await;
135 continue;
136 }
137 };
138
139 match Self::connect_websocket(&base_url, &token).await {
141 Ok(ws_stream) => {
142 *connection_state.lock().unwrap() = ConnectionState::Connected;
144
145 *reconnect_attempts.lock().unwrap() = 0;
147
148 let connection_id = Uuid::new_v4().to_string();
150 let event = Event {
151 event_type: EventType::Connection,
152 timestamp: chrono::Utc::now(),
153 data: crate::streaming::events::EventData::Connection(
154 ConnectionStatus {
155 status: ConnectionState::Connected,
156 connection_id: Some(connection_id.clone()),
157 message: Some("Connection established".to_string()),
158 },
159 ),
160 };
161
162 let _ = tx.send(event).await;
163
164 if let Err(e) = Self::handle_websocket(
166 ws_stream,
167 tx.clone(),
168 last_heartbeat.clone(),
169 heartbeat_interval,
170 )
171 .await
172 {
173 let event = Event {
175 event_type: EventType::Error,
176 timestamp: chrono::Utc::now(),
177 data: crate::streaming::events::EventData::Error(ErrorEvent {
178 code: "WS_ERROR".to_string(),
179 message: format!("WebSocket error: {}", e),
180 }),
181 };
182
183 let _ = tx.send(event).await;
184 }
185
186 *connection_state.lock().unwrap() = ConnectionState::Disconnected;
188
189 let event = Event {
191 event_type: EventType::Connection,
192 timestamp: chrono::Utc::now(),
193 data: crate::streaming::events::EventData::Connection(
194 ConnectionStatus {
195 status: ConnectionState::Disconnected,
196 connection_id: Some(connection_id),
197 message: Some("Connection closed".to_string()),
198 },
199 ),
200 };
201
202 let _ = tx.send(event).await;
203 }
204 Err(e) => {
205 let event = Event {
207 event_type: EventType::Error,
208 timestamp: chrono::Utc::now(),
209 data: crate::streaming::events::EventData::Error(ErrorEvent {
210 code: "WS_CONNECT_ERROR".to_string(),
211 message: format!("WebSocket connection error: {}", e),
212 }),
213 };
214
215 let _ = tx.send(event).await;
216 }
217 }
218
219 sleep(Duration::from_secs(reconnect_delay)).await;
221
222 *connection_state.lock().unwrap() = ConnectionState::Reconnecting;
224
225 let event = Event {
227 event_type: EventType::Connection,
228 timestamp: chrono::Utc::now(),
229 data: crate::streaming::events::EventData::Connection(ConnectionStatus {
230 status: ConnectionState::Reconnecting,
231 connection_id: None,
232 message: Some("Reconnecting...".to_string()),
233 }),
234 };
235
236 let _ = tx.send(event).await;
237 }
238 });
239
240 Ok(rx)
241 }
242
243 pub async fn disconnect(&mut self) -> WebullResult<()> {
245 *self.connection_state.lock().unwrap() = ConnectionState::Disconnected;
247
248 *self.reconnect_attempts.lock().unwrap() = self.max_reconnect_attempts + 1;
250
251 Ok(())
252 }
253
254 pub async fn subscribe(&self, request: SubscriptionRequest) -> WebullResult<()> {
256 if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
258 return Err(WebullError::InvalidRequest(
259 "Not connected to WebSocket server".to_string(),
260 ));
261 }
262
263 let message = json!({
265 "action": "SUBSCRIBE",
266 "request": request,
267 });
268
269 if let Some(tx) = &self.event_sender {
271 let _message_str = to_json(&message)?;
272
273 let event = Event {
275 event_type: EventType::Heartbeat,
276 timestamp: chrono::Utc::now(),
277 data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
278 id: Uuid::new_v4().to_string(),
279 }),
280 };
281
282 tx.send(event).await.map_err(|e| {
283 WebullError::InvalidRequest(format!("Failed to send message: {}", e))
284 })?;
285 }
286
287 Ok(())
288 }
289
290 pub async fn unsubscribe(&self, request: UnsubscriptionRequest) -> WebullResult<()> {
292 if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
294 return Err(WebullError::InvalidRequest(
295 "Not connected to WebSocket server".to_string(),
296 ));
297 }
298
299 let message = json!({
301 "action": "UNSUBSCRIBE",
302 "request": request,
303 });
304
305 if let Some(tx) = &self.event_sender {
307 let _message_str = to_json(&message)?;
308
309 let event = Event {
311 event_type: EventType::Heartbeat,
312 timestamp: chrono::Utc::now(),
313 data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
314 id: Uuid::new_v4().to_string(),
315 }),
316 };
317
318 tx.send(event).await.map_err(|e| {
319 WebullError::InvalidRequest(format!("Failed to send message: {}", e))
320 })?;
321 }
322
323 Ok(())
324 }
325
326 async fn connect_websocket(
328 base_url: &str,
329 token: &AccessToken,
330 ) -> WebullResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
331 let ws_url = format!("{}/ws", base_url.replace("http", "ws"));
333 let url = Url::parse(&ws_url)
334 .map_err(|e| WebullError::InvalidRequest(format!("Invalid WebSocket URL: {}", e)))?;
335
336 let mut headers = HeaderMap::new();
338 headers.insert(
339 AUTHORIZATION,
340 HeaderValue::from_str(&format!("Bearer {}", token.token)).unwrap(),
341 );
342
343 let (ws_stream, _) = connect_async(url).await.map_err(|e| {
345 WebullError::InvalidRequest(format!("WebSocket connection error: {}", e))
346 })?;
347
348 Ok(ws_stream)
349 }
350
351 async fn handle_websocket(
353 mut ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
354 tx: Sender<Event>,
355 last_heartbeat: Arc<Mutex<Instant>>,
356 heartbeat_interval: u64,
357 ) -> WebullResult<()> {
358 let tx_clone = tx.clone();
360 let last_heartbeat_clone = last_heartbeat.clone();
361
362 tokio::spawn(async move {
363 loop {
364 sleep(Duration::from_secs(heartbeat_interval)).await;
366
367 let now = Instant::now();
369 let last = *last_heartbeat_clone.lock().unwrap();
370
371 if now.duration_since(last).as_secs() >= heartbeat_interval {
372 let heartbeat = json!({
374 "type": "HEARTBEAT",
375 "id": Uuid::new_v4().to_string(),
376 });
377
378 let _message = Message::Text(to_json(&heartbeat).unwrap());
380
381 let event = Event {
383 event_type: EventType::Heartbeat,
384 timestamp: chrono::Utc::now(),
385 data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
386 id: Uuid::new_v4().to_string(),
387 }),
388 };
389
390 if tx_clone.send(event).await.is_err() {
392 break;
394 }
395
396 *last_heartbeat_clone.lock().unwrap() = now;
398 }
399 }
400 });
401
402 while let Some(message) = ws_stream.next().await {
404 match message {
405 Ok(Message::Text(text)) => {
406 match from_json::<Event>(&text) {
408 Ok(event) => {
409 if tx.send(event).await.is_err() {
411 break;
413 }
414 }
415 Err(e) => {
416 let event = Event {
418 event_type: EventType::Error,
419 timestamp: chrono::Utc::now(),
420 data: crate::streaming::events::EventData::Error(ErrorEvent {
421 code: "PARSE_ERROR".to_string(),
422 message: format!("Failed to parse message: {}", e),
423 }),
424 };
425
426 if tx.send(event).await.is_err() {
427 break;
429 }
430 }
431 }
432 }
433 Ok(Message::Binary(_)) => {
434 }
436 Ok(Message::Ping(data)) => {
437 if let Err(e) = ws_stream.send(Message::Pong(data)).await {
439 let event = Event {
441 event_type: EventType::Error,
442 timestamp: chrono::Utc::now(),
443 data: crate::streaming::events::EventData::Error(ErrorEvent {
444 code: "PONG_ERROR".to_string(),
445 message: format!("Failed to send pong: {}", e),
446 }),
447 };
448
449 if tx.send(event).await.is_err() {
450 break;
452 }
453 }
454
455 *last_heartbeat.lock().unwrap() = Instant::now();
457 }
458 Ok(Message::Pong(_)) => {
459 *last_heartbeat.lock().unwrap() = Instant::now();
461 }
462 Ok(Message::Close(_)) => {
463 break;
465 }
466 Ok(Message::Frame(_)) => {
467 }
469 Err(e) => {
470 let event = Event {
472 event_type: EventType::Error,
473 timestamp: chrono::Utc::now(),
474 data: crate::streaming::events::EventData::Error(ErrorEvent {
475 code: "WS_ERROR".to_string(),
476 message: format!("WebSocket error: {}", e),
477 }),
478 };
479
480 if tx.send(event).await.is_err() {
481 break;
483 }
484
485 break;
487 }
488 }
489 }
490
491 Ok(())
492 }
493}