// webex_message_handler/mercury_socket.rs

//! Mercury WebSocket connection with auth, heartbeat, and reconnection.

use std::sync::Arc;
use std::time::Duration;

use futures_util::{SinkExt, StreamExt};
use serde_json::Value;
use tokio::sync::{mpsc, Mutex, Notify};
use tokio::time;
use tokio_tungstenite::{connect_async, tungstenite::Message};
use tracing::{debug, error, info};
use url::Url;
use uuid::Uuid;

use crate::errors::WebexError;
use crate::types::MercuryActivity;

16/// Events emitted by the Mercury socket.
17#[derive(Debug, Clone)]
18pub enum MercuryEvent {
19    Connected,
20    Disconnected(String),
21    Reconnecting(u32),
22    Activity(MercuryActivity),
23    KmsResponse(Value),
24    Error(String),
25}
26
27/// Mercury WebSocket connection manager.
28pub struct MercurySocket {
29    #[allow(dead_code)]
30    ws_factory: Option<Arc<dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Box<dyn crate::types::InjectedWebSocket>, Box<dyn std::error::Error + Send + Sync>>> + Send>> + Send + Sync>>,
31    ping_interval: Duration,
32    pong_timeout: Duration,
33    reconnect_backoff_max: Duration,
34    max_reconnect_attempts: u32,
35
36    token: Arc<Mutex<String>>,
37    base_url: Arc<Mutex<String>>,
38    connected: Arc<Mutex<bool>>,
39    should_reconnect: Arc<Mutex<bool>>,
40    reconnect_attempts: Arc<Mutex<u32>>,
41    shutdown: Arc<Notify>,
42
43    event_tx: mpsc::UnboundedSender<MercuryEvent>,
44    event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<MercuryEvent>>>>,
45}
46
47impl MercurySocket {
48    pub fn new(
49        _ws_factory: Option<Arc<dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Box<dyn crate::types::InjectedWebSocket>, Box<dyn std::error::Error + Send + Sync>>> + Send>> + Send + Sync>>,
50        ping_interval: Duration,
51        pong_timeout: Duration,
52        reconnect_backoff_max: Duration,
53        max_reconnect_attempts: u32,
54    ) -> Self {
55        let (event_tx, event_rx) = mpsc::unbounded_channel();
56
57        Self {
58            ws_factory: _ws_factory,
59            ping_interval,
60            pong_timeout,
61            reconnect_backoff_max,
62            max_reconnect_attempts,
63            token: Arc::new(Mutex::new(String::new())),
64            base_url: Arc::new(Mutex::new(String::new())),
65            connected: Arc::new(Mutex::new(false)),
66            should_reconnect: Arc::new(Mutex::new(true)),
67            reconnect_attempts: Arc::new(Mutex::new(0)),
68            shutdown: Arc::new(Notify::new()),
69            event_tx,
70            event_rx: Arc::new(Mutex::new(Some(event_rx))),
71        }
72    }
73
74    /// Take the event receiver. Can only be called once.
75    pub async fn take_event_rx(&self) -> Option<mpsc::UnboundedReceiver<MercuryEvent>> {
76        self.event_rx.lock().await.take()
77    }
78
79    /// Connect to Mercury WebSocket.
80    pub async fn connect(&self, ws_url: &str, token: &str) -> Result<(), WebexError> {
81        *self.token.lock().await = token.to_string();
82        *self.base_url.lock().await = ws_url.to_string();
83        *self.should_reconnect.lock().await = true;
84        *self.reconnect_attempts.lock().await = 0;
85        self.connect_internal().await
86    }
87
88    async fn connect_internal(&self) -> Result<(), WebexError> {
89        let base_url = self.base_url.lock().await.clone();
90        let prepared_url = Self::prepare_url(&base_url)?;
91        debug!("Connecting to Mercury at {prepared_url}");
92
93        *self.connected.lock().await = false;
94
95        let (ws_stream, _) = connect_async(&prepared_url)
96            .await
97            .map_err(|e| WebexError::mercury_connection(format!("Failed to connect: {e}"), None))?;
98
99        let (mut write, mut read) = ws_stream.split();
100
101        // Send authorization
102        let token = self.token.lock().await.clone();
103        let auth_msg = serde_json::json!({
104            "id": Uuid::new_v4().to_string(),
105            "type": "authorization",
106            "data": { "token": format!("Bearer {}", token) }
107        });
108        write
109            .send(Message::Text(auth_msg.to_string().into()))
110            .await
111            .map_err(|e| WebexError::mercury_connection(format!("Failed to send auth: {e}"), None))?;
112
113        // Wait for connection ready
114        let _ready_timeout = time::timeout(Duration::from_secs(30), async {
115            while let Some(msg) = read.next().await {
116                match msg {
117                    Ok(Message::Text(text)) => {
118                        let text_str: &str = &text;
119                        if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
120                            if Self::is_connection_ready(&parsed) {
121                                return Ok(parsed);
122                            }
123                        }
124                    }
125                    Err(e) => {
126                        return Err(WebexError::mercury_connection(
127                            format!("WebSocket error during setup: {e}"),
128                            None,
129                        ));
130                    }
131                    _ => {}
132                }
133            }
134            Err(WebexError::mercury_connection("WebSocket closed during setup", None))
135        })
136        .await
137        .map_err(|_| WebexError::mercury_connection("Mercury connection timeout", None))??;
138
139        debug!("Mercury connection ready");
140        *self.connected.lock().await = true;
141
142        // Spawn read loop and ping loop
143        let event_tx = self.event_tx.clone();
144        let connected = self.connected.clone();
145        let should_reconnect = self.should_reconnect.clone();
146        let reconnect_attempts = self.reconnect_attempts.clone();
147        let max_reconnect = self.max_reconnect_attempts;
148        let backoff_max = self.reconnect_backoff_max;
149        let ping_interval = self.ping_interval;
150        let _pong_timeout = self.pong_timeout;
151        let shutdown = self.shutdown.clone();
152        let base_url_clone = self.base_url.clone();
153        let token_clone = self.token.clone();
154
155        let write = Arc::new(Mutex::new(write));
156        let write_clone = write.clone();
157
158        // Ping loop
159        let ping_write = write.clone();
160        let ping_connected = connected.clone();
161        let ping_shutdown = shutdown.clone();
162        let _ping_event_tx = event_tx.clone();
163        tokio::spawn(async move {
164            let mut interval = time::interval(ping_interval);
165            interval.tick().await; // skip first tick
166
167            loop {
168                tokio::select! {
169                    _ = interval.tick() => {
170                        if !*ping_connected.lock().await {
171                            break;
172                        }
173                        let pong_id = Uuid::new_v4().to_string();
174                        let ping_msg = serde_json::json!({
175                            "id": pong_id,
176                            "type": "ping"
177                        });
178                        let mut w = ping_write.lock().await;
179                        if w.send(Message::Text(ping_msg.to_string().into())).await.is_err() {
180                            break;
181                        }
182                        debug!("Sent ping: {pong_id}");
183                        drop(w);
184
185                        // Pong timeout handled by read loop
186                    }
187                    _ = ping_shutdown.notified() => break,
188                }
189            }
190        });
191
192        // Read loop
193        tokio::spawn(async move {
194            while let Some(msg) = read.next().await {
195                match msg {
196                    Ok(Message::Text(text)) => {
197                        let text_str: &str = &text;
198                        debug!("WS message received ({} bytes)", text_str.len());
199                        if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
200                            Self::handle_message_static(&parsed, &event_tx, &write_clone).await;
201                        } else {
202                            debug!("Failed to parse WS message as JSON");
203                        }
204                    }
205                    Ok(Message::Close(frame)) => {
206                        let code = frame.as_ref().map(|f| f.code.into()).unwrap_or(1000u16);
207                        let reason = frame.as_ref().map(|f| f.reason.to_string()).unwrap_or_default();
208                        Self::handle_close_static(
209                            code,
210                            &reason,
211                            &connected,
212                            &should_reconnect,
213                            &reconnect_attempts,
214                            max_reconnect,
215                            backoff_max,
216                            &base_url_clone,
217                            &token_clone,
218                            &event_tx,
219                        )
220                        .await;
221                        break;
222                    }
223                    Err(e) => {
224                        error!("WebSocket error: {e}");
225                        let _ = event_tx.send(MercuryEvent::Error(e.to_string()));
226                        *connected.lock().await = false;
227                        break;
228                    }
229                    _ => {}
230                }
231            }
232
233            // Connection ended — handle reconnection if needed
234            if *should_reconnect.lock().await && *connected.lock().await == false {
235                // Reconnect logic handled in handle_close_static
236            }
237        });
238
239        Ok(())
240    }
241
242    fn prepare_url(base_url: &str) -> Result<String, WebexError> {
243        let mut url = Url::parse(base_url)
244            .map_err(|e| WebexError::mercury_connection(format!("Invalid URL: {e}"), None))?;
245        url.query_pairs_mut()
246            .append_pair("outboundWireFormat", "text")
247            .append_pair("bufferStates", "true")
248            .append_pair("aliasHttpStatus", "true")
249            .append_pair(
250                "clientTimestamp",
251                &std::time::SystemTime::now()
252                    .duration_since(std::time::UNIX_EPOCH)
253                    .unwrap_or_default()
254                    .as_millis()
255                    .to_string(),
256            );
257        Ok(url.to_string())
258    }
259
260    fn is_connection_ready(message: &Value) -> bool {
261        let event_type = message
262            .get("data")
263            .and_then(|d| d.get("eventType"))
264            .and_then(|e| e.as_str())
265            .unwrap_or("");
266        event_type.contains("mercury.buffer_state") || event_type.contains("mercury.registration_status")
267    }
268
269    async fn handle_message_static(
270        message: &Value,
271        event_tx: &mpsc::UnboundedSender<MercuryEvent>,
272        write: &Arc<Mutex<futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>>,
273    ) {
274        let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
275
276        match msg_type {
277            "pong" => {
278                let id = message.get("id").and_then(|i| i.as_str()).unwrap_or("");
279                debug!("Received pong: {id}");
280            }
281            "shutdown" => {
282                info!("Received shutdown message from Mercury");
283                // Reconnection will be triggered by connection close
284            }
285            _ => {
286                if let Some(data) = message.get("data") {
287                    if let Some(event_type) = data.get("eventType").and_then(|e| e.as_str()) {
288                        debug!("Mercury eventType: {event_type}");
289
290                        // Send ACK
291                        if let Some(msg_id) = message.get("id").and_then(|i| i.as_str()) {
292                            let ack = serde_json::json!({"messageId": msg_id, "type": "ack"});
293                            let mut w = write.lock().await;
294                            let _ = w.send(Message::Text(ack.to_string().into())).await;
295                        }
296
297                        if event_type.starts_with("encryption.") {
298                            debug!("Emitting kms:response for eventType: {event_type}");
299                            let _ = event_tx.send(MercuryEvent::KmsResponse(data.clone()));
300                        } else if event_type == "conversation.activity" {
301                            if let Some(activity_raw) = data.get("activity") {
302                                match serde_json::from_value::<MercuryActivity>(activity_raw.clone()) {
303                                    Ok(activity) => {
304                                        debug!("Emitting activity: {}", activity.id);
305                                        let _ = event_tx.send(MercuryEvent::Activity(activity));
306                                    }
307                                    Err(e) => {
308                                        error!("Failed to parse activity: {e}");
309                                        debug!("Raw activity keys: {:?}", activity_raw.as_object().map(|o| o.keys().collect::<Vec<_>>()));
310                                    }
311                                }
312                            }
313                        }
314                    } else {
315                        debug!("Unhandled Mercury message, type={msg_type:?}, keys={:?}", message.as_object().map(|o| o.keys().collect::<Vec<_>>()));
316                    }
317                } else {
318                    debug!("Unhandled Mercury message, type={msg_type:?}, no data field");
319                }
320            }
321        }
322    }
323
324    #[allow(clippy::too_many_arguments)]
325    async fn handle_close_static(
326        code: u16,
327        reason: &str,
328        connected: &Arc<Mutex<bool>>,
329        should_reconnect: &Arc<Mutex<bool>>,
330        reconnect_attempts: &Arc<Mutex<u32>>,
331        max_reconnect: u32,
332        backoff_max: Duration,
333        _base_url: &Arc<Mutex<String>>,
334        _token: &Arc<Mutex<String>>,
335        event_tx: &mpsc::UnboundedSender<MercuryEvent>,
336    ) {
337        info!("WebSocket closed with code {code}: {reason}");
338        *connected.lock().await = false;
339
340        if code == 4401 {
341            error!("Mercury authorization failed");
342            *should_reconnect.lock().await = false;
343            let _ = event_tx.send(MercuryEvent::Error("Mercury authorization failed".into()));
344            let _ = event_tx.send(MercuryEvent::Disconnected("auth-failed".into()));
345            return;
346        }
347
348        if code == 4400 || code == 4403 {
349            error!("Mercury permanent failure (code {code})");
350            *should_reconnect.lock().await = false;
351            let _ = event_tx.send(MercuryEvent::Error(format!("Mercury permanent failure (code {code})")));
352            let _ = event_tx.send(MercuryEvent::Disconnected("permanent-failure".into()));
353            return;
354        }
355
356        if *should_reconnect.lock().await {
357            let mut attempts = reconnect_attempts.lock().await;
358            if *attempts >= max_reconnect {
359                error!("Max reconnection attempts ({max_reconnect}) exceeded");
360                *should_reconnect.lock().await = false;
361                let _ = event_tx.send(MercuryEvent::Disconnected("max-attempts-exceeded".into()));
362                return;
363            }
364            *attempts += 1;
365            let attempt = *attempts;
366            let delay_secs = (2.0f64.powi(attempt as i32 - 1)).min(backoff_max.as_secs_f64());
367            drop(attempts);
368
369            info!("Reconnecting (attempt {attempt}/{max_reconnect}) in {delay_secs}s");
370            let _ = event_tx.send(MercuryEvent::Reconnecting(attempt));
371
372            time::sleep(Duration::from_secs_f64(delay_secs)).await;
373
374            // Signal that reconnection should happen (handler will re-connect)
375            let _ = event_tx.send(MercuryEvent::Disconnected("reconnect-needed".into()));
376        } else {
377            let _ = event_tx.send(MercuryEvent::Disconnected("manual".into()));
378        }
379    }
380
381    /// Disconnect from Mercury.
382    pub async fn disconnect(&self) {
383        info!("Disconnecting from Mercury");
384        *self.should_reconnect.lock().await = false;
385        *self.connected.lock().await = false;
386        self.shutdown.notify_waiters();
387        let _ = self.event_tx.send(MercuryEvent::Disconnected("client".into()));
388    }
389
390    /// Whether the WebSocket is currently connected.
391    pub async fn connected(&self) -> bool {
392        *self.connected.lock().await
393    }
394
395    /// Current reconnection attempt count.
396    pub async fn current_reconnect_attempts(&self) -> u32 {
397        *self.reconnect_attempts.lock().await
398    }
399}