1use crate::ws::error::*;
4use crate::ws::events::WebsocketEvent;
5use crate::ws::worker::{ControlMessage, WorkerLoop};
6
7pub struct WebSocketClient {
9 config: crate::config::WebSocketConfig,
10 tls_config: Option<crate::config::TLSConfig>,
11 callback: Option<crate::ws::MessageCallback>,
12 control_tx: Option<tokio::sync::mpsc::UnboundedSender<ControlMessage>>,
13 worker_handle: Option<tokio::task::JoinHandle<WebsocketResult<()>>>,
14 is_connected: std::sync::Arc<tokio::sync::RwLock<bool>>,
15}
16impl WebSocketClient {
17 pub fn new(
19 config: crate::config::WebSocketConfig,
20 tls_config: Option<crate::config::TLSConfig>,
21 ) -> Self {
22 Self {
23 config,
24 tls_config,
25 callback: None,
26 control_tx: None,
27 worker_handle: None,
28 is_connected: std::sync::Arc::new(tokio::sync::RwLock::new(false)),
29 }
30 }
31
32 pub fn on_message<F>(&mut self, callback: F)
34 where
35 F: Fn(WebsocketEvent) + Send + Sync + 'static,
36 {
37 self.callback = Some(std::sync::Arc::new(callback));
38 }
39
40 pub async fn start_background(&mut self) -> WebsocketResult<()> {
42 if self.worker_handle.is_some() {
43 return Err(WebsocketError::AlreadyConnected);
44 }
45
46 let (control_tx, control_rx) = tokio::sync::mpsc::unbounded_channel();
47 self.control_tx = Some(control_tx);
48
49 let worker_loop = WorkerLoop::new(
50 self.config.clone(),
51 self.tls_config.clone(),
52 self.callback.clone(),
53 std::sync::Arc::clone(&self.is_connected),
54 );
55
56 let worker_handle = tokio::spawn(async move { worker_loop.run(control_rx).await });
57
58 self.worker_handle = Some(worker_handle);
59 Ok(())
60 }
61
62 pub async fn start_blocking(&mut self) -> WebsocketResult<()> {
64 let (control_tx, control_rx) = tokio::sync::mpsc::unbounded_channel();
65 self.control_tx = Some(control_tx);
66
67 let worker_loop = WorkerLoop::new(
68 self.config.clone(),
69 self.tls_config.clone(),
70 self.callback.clone(),
71 std::sync::Arc::clone(&self.is_connected),
72 );
73
74 worker_loop.run(control_rx).await
76 }
77
78 pub async fn stop_background(&mut self) -> WebsocketResult<()> {
80 if let Some(tx) = &self.control_tx {
81 let _ = tx.send(ControlMessage::Stop);
82 }
83
84 if let Some(handle) = self.worker_handle.take() {
85 let _ = tokio::time::timeout(std::time::Duration::from_secs(5), handle).await;
87 }
88
89 self.control_tx = None;
90 *self.is_connected.write().await = false;
91
92 Ok(())
93 }
94
95 pub async fn is_connected(&self) -> bool {
97 *self.is_connected.read().await
98 }
99
100 pub async fn reconnect(&self) -> WebsocketResult<()> {
102 if let Some(tx) = &self.control_tx {
103 tx.send(ControlMessage::Reconnect)
104 .map_err(|_| WebsocketError::ChannelError)?;
105 Ok(())
106 } else {
107 Err(WebsocketError::NotConnected)
108 }
109 }
110}
111impl Drop for WebSocketClient {
112 fn drop(&mut self) {
113 if let Some(tx) = &self.control_tx {
115 let _ = tx.send(ControlMessage::Stop);
116 }
117 }
118}
119impl std::fmt::Debug for WebSocketClient {
120 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
121 f.debug_struct("WebsocketClient")
122 .field("url", &self.config.url)
123 .field("is_connected", &self.is_connected)
124 .field("has_tls_config", &self.tls_config.is_some())
125 .finish()
126 }
127}