// webull_rs/streaming/client.rs

use crate::auth::{AccessToken, AuthManager};
use crate::error::{WebullError, WebullResult};
use crate::streaming::events::{
    ConnectionState, ConnectionStatus, ErrorEvent, Event, EventType, HeartbeatEvent,
};
use crate::streaming::subscription::{SubscriptionRequest, UnsubscriptionRequest};
use crate::utils::serialization::{from_json, to_json};
use futures_util::{SinkExt, StreamExt};
use reqwest::header::{HeaderMap, HeaderValue, AUTHORIZATION};
use serde_json::json;
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tokio::net::TcpStream;
use tokio::sync::mpsc::{self, Receiver, Sender};
use tokio::time::sleep;
use tokio_tungstenite::{
    connect_async, tungstenite::protocol::Message, MaybeTlsStream, WebSocketStream,
};
use url::Url;
use uuid::Uuid;

22/// WebSocket client for streaming data from Webull.
23pub struct WebSocketClient {
24    /// Base URL for WebSocket connections
25    base_url: String,
26
27    /// Authentication manager
28    auth_manager: Arc<AuthManager>,
29
30    /// Connection state
31    connection_state: Arc<Mutex<ConnectionState>>,
32
33    /// Event sender
34    event_sender: Option<Sender<Event>>,
35
36    /// Last heartbeat time
37    last_heartbeat: Arc<Mutex<Instant>>,
38
39    /// Heartbeat interval in seconds
40    heartbeat_interval: u64,
41
42    /// Reconnect attempts
43    reconnect_attempts: Arc<Mutex<u32>>,
44
45    /// Maximum reconnect attempts
46    max_reconnect_attempts: u32,
47
48    /// Reconnect delay in seconds
49    reconnect_delay: u64,
50}
51
52impl WebSocketClient {
53    /// Create a new WebSocket client.
54    pub fn new(base_url: String, auth_manager: Arc<AuthManager>) -> Self {
55        Self {
56            base_url,
57            auth_manager,
58            connection_state: Arc::new(Mutex::new(ConnectionState::Disconnected)),
59            event_sender: None,
60            last_heartbeat: Arc::new(Mutex::new(Instant::now())),
61            heartbeat_interval: 30,
62            reconnect_attempts: Arc::new(Mutex::new(0)),
63            max_reconnect_attempts: 5,
64            reconnect_delay: 5,
65        }
66    }
67
68    /// Connect to the WebSocket server.
69    pub async fn connect(&mut self) -> WebullResult<Receiver<Event>> {
70        // Create a channel for events
71        let (tx, rx) = mpsc::channel(100);
72        self.event_sender = Some(tx.clone());
73
74        // Set the connection state to reconnecting
75        *self.connection_state.lock().unwrap() = ConnectionState::Reconnecting;
76
77        // Reset reconnect attempts
78        *self.reconnect_attempts.lock().unwrap() = 0;
79
80        // Start the connection task
81        let base_url = self.base_url.clone();
82        let auth_manager = self.auth_manager.clone();
83        let connection_state = self.connection_state.clone();
84        let last_heartbeat = self.last_heartbeat.clone();
85        let heartbeat_interval = self.heartbeat_interval;
86        let reconnect_attempts = self.reconnect_attempts.clone();
87        let max_reconnect_attempts = self.max_reconnect_attempts;
88        let reconnect_delay = self.reconnect_delay;
89
90        tokio::spawn(async move {
91            loop {
92                // Check if we've exceeded the maximum reconnect attempts
93                let attempts = *reconnect_attempts.lock().unwrap();
94                if attempts > max_reconnect_attempts {
95                    // Send a connection failed event
96                    let event = Event {
97                        event_type: EventType::Connection,
98                        timestamp: chrono::Utc::now(),
99                        data: crate::streaming::events::EventData::Connection(ConnectionStatus {
100                            status: ConnectionState::Failed,
101                            connection_id: None,
102                            message: Some("Maximum reconnect attempts exceeded".to_string()),
103                        }),
104                    };
105
106                    let _ = tx.send(event).await;
107
108                    // Set the connection state to failed
109                    *connection_state.lock().unwrap() = ConnectionState::Failed;
110
111                    break;
112                }
113
114                // Increment reconnect attempts
115                *reconnect_attempts.lock().unwrap() = attempts + 1;
116
117                // Get the authentication token
118                let token = match auth_manager.get_token().await {
119                    Ok(token) => token,
120                    Err(e) => {
121                        // Send an error event
122                        let event = Event {
123                            event_type: EventType::Error,
124                            timestamp: chrono::Utc::now(),
125                            data: crate::streaming::events::EventData::Error(ErrorEvent {
126                                code: "AUTH_ERROR".to_string(),
127                                message: format!("Authentication error: {}", e),
128                            }),
129                        };
130
131                        let _ = tx.send(event).await;
132
133                        // Wait before retrying
134                        sleep(Duration::from_secs(reconnect_delay)).await;
135                        continue;
136                    }
137                };
138
139                // Connect to the WebSocket server
140                match Self::connect_websocket(&base_url, &token).await {
141                    Ok(ws_stream) => {
142                        // Set the connection state to connected
143                        *connection_state.lock().unwrap() = ConnectionState::Connected;
144
145                        // Reset reconnect attempts
146                        *reconnect_attempts.lock().unwrap() = 0;
147
148                        // Send a connection established event
149                        let connection_id = Uuid::new_v4().to_string();
150                        let event = Event {
151                            event_type: EventType::Connection,
152                            timestamp: chrono::Utc::now(),
153                            data: crate::streaming::events::EventData::Connection(
154                                ConnectionStatus {
155                                    status: ConnectionState::Connected,
156                                    connection_id: Some(connection_id.clone()),
157                                    message: Some("Connection established".to_string()),
158                                },
159                            ),
160                        };
161
162                        let _ = tx.send(event).await;
163
164                        // Handle the WebSocket connection
165                        if let Err(e) = Self::handle_websocket(
166                            ws_stream,
167                            tx.clone(),
168                            last_heartbeat.clone(),
169                            heartbeat_interval,
170                        )
171                        .await
172                        {
173                            // Send an error event
174                            let event = Event {
175                                event_type: EventType::Error,
176                                timestamp: chrono::Utc::now(),
177                                data: crate::streaming::events::EventData::Error(ErrorEvent {
178                                    code: "WS_ERROR".to_string(),
179                                    message: format!("WebSocket error: {}", e),
180                                }),
181                            };
182
183                            let _ = tx.send(event).await;
184                        }
185
186                        // Set the connection state to disconnected
187                        *connection_state.lock().unwrap() = ConnectionState::Disconnected;
188
189                        // Send a disconnection event
190                        let event = Event {
191                            event_type: EventType::Connection,
192                            timestamp: chrono::Utc::now(),
193                            data: crate::streaming::events::EventData::Connection(
194                                ConnectionStatus {
195                                    status: ConnectionState::Disconnected,
196                                    connection_id: Some(connection_id),
197                                    message: Some("Connection closed".to_string()),
198                                },
199                            ),
200                        };
201
202                        let _ = tx.send(event).await;
203                    }
204                    Err(e) => {
205                        // Send an error event
206                        let event = Event {
207                            event_type: EventType::Error,
208                            timestamp: chrono::Utc::now(),
209                            data: crate::streaming::events::EventData::Error(ErrorEvent {
210                                code: "WS_CONNECT_ERROR".to_string(),
211                                message: format!("WebSocket connection error: {}", e),
212                            }),
213                        };
214
215                        let _ = tx.send(event).await;
216                    }
217                }
218
219                // Wait before reconnecting
220                sleep(Duration::from_secs(reconnect_delay)).await;
221
222                // Set the connection state to reconnecting
223                *connection_state.lock().unwrap() = ConnectionState::Reconnecting;
224
225                // Send a reconnecting event
226                let event = Event {
227                    event_type: EventType::Connection,
228                    timestamp: chrono::Utc::now(),
229                    data: crate::streaming::events::EventData::Connection(ConnectionStatus {
230                        status: ConnectionState::Reconnecting,
231                        connection_id: None,
232                        message: Some("Reconnecting...".to_string()),
233                    }),
234                };
235
236                let _ = tx.send(event).await;
237            }
238        });
239
240        Ok(rx)
241    }
242
243    /// Disconnect from the WebSocket server.
244    pub async fn disconnect(&mut self) -> WebullResult<()> {
245        // Set the connection state to disconnected
246        *self.connection_state.lock().unwrap() = ConnectionState::Disconnected;
247
248        // Reset reconnect attempts
249        *self.reconnect_attempts.lock().unwrap() = self.max_reconnect_attempts + 1;
250
251        Ok(())
252    }
253
254    /// Subscribe to a topic.
255    pub async fn subscribe(&self, request: SubscriptionRequest) -> WebullResult<()> {
256        // Check if we're connected
257        if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
258            return Err(WebullError::InvalidRequest(
259                "Not connected to WebSocket server".to_string(),
260            ));
261        }
262
263        // Send the subscription request
264        let message = json!({
265            "action": "SUBSCRIBE",
266            "request": request,
267        });
268
269        // Send the message
270        if let Some(tx) = &self.event_sender {
271            let _message_str = to_json(&message)?;
272
273            // Create a heartbeat event
274            let event = Event {
275                event_type: EventType::Heartbeat,
276                timestamp: chrono::Utc::now(),
277                data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
278                    id: Uuid::new_v4().to_string(),
279                }),
280            };
281
282            tx.send(event).await.map_err(|e| {
283                WebullError::InvalidRequest(format!("Failed to send message: {}", e))
284            })?;
285        }
286
287        Ok(())
288    }
289
290    /// Unsubscribe from a topic.
291    pub async fn unsubscribe(&self, request: UnsubscriptionRequest) -> WebullResult<()> {
292        // Check if we're connected
293        if *self.connection_state.lock().unwrap() != ConnectionState::Connected {
294            return Err(WebullError::InvalidRequest(
295                "Not connected to WebSocket server".to_string(),
296            ));
297        }
298
299        // Send the unsubscription request
300        let message = json!({
301            "action": "UNSUBSCRIBE",
302            "request": request,
303        });
304
305        // Send the message
306        if let Some(tx) = &self.event_sender {
307            let _message_str = to_json(&message)?;
308
309            // Create a heartbeat event
310            let event = Event {
311                event_type: EventType::Heartbeat,
312                timestamp: chrono::Utc::now(),
313                data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
314                    id: Uuid::new_v4().to_string(),
315                }),
316            };
317
318            tx.send(event).await.map_err(|e| {
319                WebullError::InvalidRequest(format!("Failed to send message: {}", e))
320            })?;
321        }
322
323        Ok(())
324    }
325
326    /// Connect to the WebSocket server.
327    async fn connect_websocket(
328        base_url: &str,
329        token: &AccessToken,
330    ) -> WebullResult<WebSocketStream<MaybeTlsStream<TcpStream>>> {
331        // Create the WebSocket URL
332        let ws_url = format!("{}/ws", base_url.replace("http", "ws"));
333        let url = Url::parse(&ws_url)
334            .map_err(|e| WebullError::InvalidRequest(format!("Invalid WebSocket URL: {}", e)))?;
335
336        // Create the request headers
337        let mut headers = HeaderMap::new();
338        headers.insert(
339            AUTHORIZATION,
340            HeaderValue::from_str(&format!("Bearer {}", token.token)).unwrap(),
341        );
342
343        // Connect to the WebSocket server
344        let (ws_stream, _) = connect_async(url).await.map_err(|e| {
345            WebullError::InvalidRequest(format!("WebSocket connection error: {}", e))
346        })?;
347
348        Ok(ws_stream)
349    }
350
351    /// Handle the WebSocket connection.
352    async fn handle_websocket(
353        mut ws_stream: WebSocketStream<MaybeTlsStream<TcpStream>>,
354        tx: Sender<Event>,
355        last_heartbeat: Arc<Mutex<Instant>>,
356        heartbeat_interval: u64,
357    ) -> WebullResult<()> {
358        // Start the heartbeat task
359        let tx_clone = tx.clone();
360        let last_heartbeat_clone = last_heartbeat.clone();
361
362        tokio::spawn(async move {
363            loop {
364                // Sleep for the heartbeat interval
365                sleep(Duration::from_secs(heartbeat_interval)).await;
366
367                // Check if we need to send a heartbeat
368                let now = Instant::now();
369                let last = *last_heartbeat_clone.lock().unwrap();
370
371                if now.duration_since(last).as_secs() >= heartbeat_interval {
372                    // Create a heartbeat message
373                    let heartbeat = json!({
374                        "type": "HEARTBEAT",
375                        "id": Uuid::new_v4().to_string(),
376                    });
377
378                    // Send the heartbeat message
379                    let _message = Message::Text(to_json(&heartbeat).unwrap());
380
381                    // Create a heartbeat event
382                    let event = Event {
383                        event_type: EventType::Heartbeat,
384                        timestamp: chrono::Utc::now(),
385                        data: crate::streaming::events::EventData::Heartbeat(HeartbeatEvent {
386                            id: Uuid::new_v4().to_string(),
387                        }),
388                    };
389
390                    // Send the heartbeat event
391                    if tx_clone.send(event).await.is_err() {
392                        // Channel closed, exit the task
393                        break;
394                    }
395
396                    // Update the last heartbeat time
397                    *last_heartbeat_clone.lock().unwrap() = now;
398                }
399            }
400        });
401
402        // Handle incoming messages
403        while let Some(message) = ws_stream.next().await {
404            match message {
405                Ok(Message::Text(text)) => {
406                    // Parse the message
407                    match from_json::<Event>(&text) {
408                        Ok(event) => {
409                            // Send the event
410                            if tx.send(event).await.is_err() {
411                                // Channel closed, exit the loop
412                                break;
413                            }
414                        }
415                        Err(e) => {
416                            // Send an error event
417                            let event = Event {
418                                event_type: EventType::Error,
419                                timestamp: chrono::Utc::now(),
420                                data: crate::streaming::events::EventData::Error(ErrorEvent {
421                                    code: "PARSE_ERROR".to_string(),
422                                    message: format!("Failed to parse message: {}", e),
423                                }),
424                            };
425
426                            if tx.send(event).await.is_err() {
427                                // Channel closed, exit the loop
428                                break;
429                            }
430                        }
431                    }
432                }
433                Ok(Message::Binary(_)) => {
434                    // Ignore binary messages
435                }
436                Ok(Message::Ping(data)) => {
437                    // Respond with a pong
438                    if let Err(e) = ws_stream.send(Message::Pong(data)).await {
439                        // Send an error event
440                        let event = Event {
441                            event_type: EventType::Error,
442                            timestamp: chrono::Utc::now(),
443                            data: crate::streaming::events::EventData::Error(ErrorEvent {
444                                code: "PONG_ERROR".to_string(),
445                                message: format!("Failed to send pong: {}", e),
446                            }),
447                        };
448
449                        if tx.send(event).await.is_err() {
450                            // Channel closed, exit the loop
451                            break;
452                        }
453                    }
454
455                    // Update the last heartbeat time
456                    *last_heartbeat.lock().unwrap() = Instant::now();
457                }
458                Ok(Message::Pong(_)) => {
459                    // Update the last heartbeat time
460                    *last_heartbeat.lock().unwrap() = Instant::now();
461                }
462                Ok(Message::Close(_)) => {
463                    // Connection closed
464                    break;
465                }
466                Ok(Message::Frame(_)) => {
467                    // Ignore frame messages
468                }
469                Err(e) => {
470                    // Send an error event
471                    let event = Event {
472                        event_type: EventType::Error,
473                        timestamp: chrono::Utc::now(),
474                        data: crate::streaming::events::EventData::Error(ErrorEvent {
475                            code: "WS_ERROR".to_string(),
476                            message: format!("WebSocket error: {}", e),
477                        }),
478                    };
479
480                    if tx.send(event).await.is_err() {
481                        // Channel closed, exit the loop
482                        break;
483                    }
484
485                    // Exit the loop on error
486                    break;
487                }
488            }
489        }
490
491        Ok(())
492    }
493}