Skip to main content

webex_message_handler/
mercury_socket.rs

1//! Mercury WebSocket connection with auth, heartbeat, and reconnection.
2
3use crate::errors::WebexError;
4use crate::types::MercuryActivity;
5use futures_util::{SinkExt, StreamExt};
6use serde_json::Value;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{mpsc, Mutex, Notify};
10use tokio::time;
11use tokio_tungstenite::{connect_async, tungstenite::Message};
12use tracing::{debug, error, info};
13use url::Url;
14use uuid::Uuid;
15
16/// Events emitted by the Mercury socket.
17#[derive(Debug, Clone)]
18pub enum MercuryEvent {
19    Connected,
20    Disconnected(String),
21    Reconnecting(u32),
22    Activity(MercuryActivity),
23    KmsResponse(Value),
24    Error(String),
25}
26
27/// Mercury WebSocket connection manager.
28pub struct MercurySocket {
29    ping_interval: Duration,
30    pong_timeout: Duration,
31    reconnect_backoff_max: Duration,
32    max_reconnect_attempts: u32,
33
34    token: Arc<Mutex<String>>,
35    base_url: Arc<Mutex<String>>,
36    connected: Arc<Mutex<bool>>,
37    should_reconnect: Arc<Mutex<bool>>,
38    reconnect_attempts: Arc<Mutex<u32>>,
39    shutdown: Arc<Notify>,
40
41    event_tx: mpsc::UnboundedSender<MercuryEvent>,
42    event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<MercuryEvent>>>>,
43}
44
45impl MercurySocket {
46    pub fn new(
47        ping_interval: Duration,
48        pong_timeout: Duration,
49        reconnect_backoff_max: Duration,
50        max_reconnect_attempts: u32,
51    ) -> Self {
52        let (event_tx, event_rx) = mpsc::unbounded_channel();
53
54        Self {
55            ping_interval,
56            pong_timeout,
57            reconnect_backoff_max,
58            max_reconnect_attempts,
59            token: Arc::new(Mutex::new(String::new())),
60            base_url: Arc::new(Mutex::new(String::new())),
61            connected: Arc::new(Mutex::new(false)),
62            should_reconnect: Arc::new(Mutex::new(true)),
63            reconnect_attempts: Arc::new(Mutex::new(0)),
64            shutdown: Arc::new(Notify::new()),
65            event_tx,
66            event_rx: Arc::new(Mutex::new(Some(event_rx))),
67        }
68    }
69
70    /// Take the event receiver. Can only be called once.
71    pub async fn take_event_rx(&self) -> Option<mpsc::UnboundedReceiver<MercuryEvent>> {
72        self.event_rx.lock().await.take()
73    }
74
75    /// Connect to Mercury WebSocket.
76    pub async fn connect(&self, ws_url: &str, token: &str) -> Result<(), WebexError> {
77        *self.token.lock().await = token.to_string();
78        *self.base_url.lock().await = ws_url.to_string();
79        *self.should_reconnect.lock().await = true;
80        *self.reconnect_attempts.lock().await = 0;
81        self.connect_internal().await
82    }
83
84    async fn connect_internal(&self) -> Result<(), WebexError> {
85        let base_url = self.base_url.lock().await.clone();
86        let prepared_url = Self::prepare_url(&base_url)?;
87        debug!("Connecting to Mercury at {prepared_url}");
88
89        *self.connected.lock().await = false;
90
91        let (ws_stream, _) = connect_async(&prepared_url)
92            .await
93            .map_err(|e| WebexError::mercury_connection(format!("Failed to connect: {e}"), None))?;
94
95        let (mut write, mut read) = ws_stream.split();
96
97        // Send authorization
98        let token = self.token.lock().await.clone();
99        let auth_msg = serde_json::json!({
100            "id": Uuid::new_v4().to_string(),
101            "type": "authorization",
102            "data": { "token": format!("Bearer {}", token) }
103        });
104        write
105            .send(Message::Text(auth_msg.to_string().into()))
106            .await
107            .map_err(|e| WebexError::mercury_connection(format!("Failed to send auth: {e}"), None))?;
108
109        // Wait for connection ready
110        let _ready_timeout = time::timeout(Duration::from_secs(30), async {
111            while let Some(msg) = read.next().await {
112                match msg {
113                    Ok(Message::Text(text)) => {
114                        let text_str: &str = &text;
115                        if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
116                            if Self::is_connection_ready(&parsed) {
117                                return Ok(parsed);
118                            }
119                        }
120                    }
121                    Err(e) => {
122                        return Err(WebexError::mercury_connection(
123                            format!("WebSocket error during setup: {e}"),
124                            None,
125                        ));
126                    }
127                    _ => {}
128                }
129            }
130            Err(WebexError::mercury_connection("WebSocket closed during setup", None))
131        })
132        .await
133        .map_err(|_| WebexError::mercury_connection("Mercury connection timeout", None))??;
134
135        debug!("Mercury connection ready");
136        *self.connected.lock().await = true;
137
138        // Spawn read loop and ping loop
139        let event_tx = self.event_tx.clone();
140        let connected = self.connected.clone();
141        let should_reconnect = self.should_reconnect.clone();
142        let reconnect_attempts = self.reconnect_attempts.clone();
143        let max_reconnect = self.max_reconnect_attempts;
144        let backoff_max = self.reconnect_backoff_max;
145        let ping_interval = self.ping_interval;
146        let _pong_timeout = self.pong_timeout;
147        let shutdown = self.shutdown.clone();
148        let base_url_clone = self.base_url.clone();
149        let token_clone = self.token.clone();
150
151        let write = Arc::new(Mutex::new(write));
152        let write_clone = write.clone();
153
154        // Ping loop
155        let ping_write = write.clone();
156        let ping_connected = connected.clone();
157        let ping_shutdown = shutdown.clone();
158        let _ping_event_tx = event_tx.clone();
159        tokio::spawn(async move {
160            let mut interval = time::interval(ping_interval);
161            interval.tick().await; // skip first tick
162
163            loop {
164                tokio::select! {
165                    _ = interval.tick() => {
166                        if !*ping_connected.lock().await {
167                            break;
168                        }
169                        let pong_id = Uuid::new_v4().to_string();
170                        let ping_msg = serde_json::json!({
171                            "id": pong_id,
172                            "type": "ping"
173                        });
174                        let mut w = ping_write.lock().await;
175                        if w.send(Message::Text(ping_msg.to_string().into())).await.is_err() {
176                            break;
177                        }
178                        debug!("Sent ping: {pong_id}");
179                        drop(w);
180
181                        // Pong timeout handled by read loop
182                    }
183                    _ = ping_shutdown.notified() => break,
184                }
185            }
186        });
187
188        // Read loop
189        tokio::spawn(async move {
190            while let Some(msg) = read.next().await {
191                match msg {
192                    Ok(Message::Text(text)) => {
193                        let text_str: &str = &text;
194                        debug!("WS message received ({} bytes)", text_str.len());
195                        if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
196                            Self::handle_message_static(&parsed, &event_tx, &write_clone).await;
197                        } else {
198                            debug!("Failed to parse WS message as JSON");
199                        }
200                    }
201                    Ok(Message::Close(frame)) => {
202                        let code = frame.as_ref().map(|f| f.code.into()).unwrap_or(1000u16);
203                        let reason = frame.as_ref().map(|f| f.reason.to_string()).unwrap_or_default();
204                        Self::handle_close_static(
205                            code,
206                            &reason,
207                            &connected,
208                            &should_reconnect,
209                            &reconnect_attempts,
210                            max_reconnect,
211                            backoff_max,
212                            &base_url_clone,
213                            &token_clone,
214                            &event_tx,
215                        )
216                        .await;
217                        break;
218                    }
219                    Err(e) => {
220                        error!("WebSocket error: {e}");
221                        let _ = event_tx.send(MercuryEvent::Error(e.to_string()));
222                        *connected.lock().await = false;
223                        break;
224                    }
225                    _ => {}
226                }
227            }
228
229            // Connection ended — handle reconnection if needed
230            if *should_reconnect.lock().await && *connected.lock().await == false {
231                // Reconnect logic handled in handle_close_static
232            }
233        });
234
235        Ok(())
236    }
237
238    fn prepare_url(base_url: &str) -> Result<String, WebexError> {
239        let mut url = Url::parse(base_url)
240            .map_err(|e| WebexError::mercury_connection(format!("Invalid URL: {e}"), None))?;
241        url.query_pairs_mut()
242            .append_pair("outboundWireFormat", "text")
243            .append_pair("bufferStates", "true")
244            .append_pair("aliasHttpStatus", "true")
245            .append_pair(
246                "clientTimestamp",
247                &std::time::SystemTime::now()
248                    .duration_since(std::time::UNIX_EPOCH)
249                    .unwrap_or_default()
250                    .as_millis()
251                    .to_string(),
252            );
253        Ok(url.to_string())
254    }
255
256    fn is_connection_ready(message: &Value) -> bool {
257        let event_type = message
258            .get("data")
259            .and_then(|d| d.get("eventType"))
260            .and_then(|e| e.as_str())
261            .unwrap_or("");
262        event_type.contains("mercury.buffer_state") || event_type.contains("mercury.registration_status")
263    }
264
265    async fn handle_message_static(
266        message: &Value,
267        event_tx: &mpsc::UnboundedSender<MercuryEvent>,
268        write: &Arc<Mutex<futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>>,
269    ) {
270        let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
271
272        match msg_type {
273            "pong" => {
274                let id = message.get("id").and_then(|i| i.as_str()).unwrap_or("");
275                debug!("Received pong: {id}");
276            }
277            "shutdown" => {
278                info!("Received shutdown message from Mercury");
279                // Reconnection will be triggered by connection close
280            }
281            _ => {
282                if let Some(data) = message.get("data") {
283                    if let Some(event_type) = data.get("eventType").and_then(|e| e.as_str()) {
284                        debug!("Mercury eventType: {event_type}");
285
286                        // Send ACK
287                        if let Some(msg_id) = message.get("id").and_then(|i| i.as_str()) {
288                            let ack = serde_json::json!({"messageId": msg_id, "type": "ack"});
289                            let mut w = write.lock().await;
290                            let _ = w.send(Message::Text(ack.to_string().into())).await;
291                        }
292
293                        if event_type.starts_with("encryption.") {
294                            debug!("Emitting kms:response for eventType: {event_type}");
295                            let _ = event_tx.send(MercuryEvent::KmsResponse(data.clone()));
296                        } else if event_type == "conversation.activity" {
297                            if let Some(activity_raw) = data.get("activity") {
298                                match serde_json::from_value::<MercuryActivity>(activity_raw.clone()) {
299                                    Ok(activity) => {
300                                        debug!("Emitting activity: {}", activity.id);
301                                        let _ = event_tx.send(MercuryEvent::Activity(activity));
302                                    }
303                                    Err(e) => {
304                                        error!("Failed to parse activity: {e}");
305                                        debug!("Raw activity keys: {:?}", activity_raw.as_object().map(|o| o.keys().collect::<Vec<_>>()));
306                                    }
307                                }
308                            }
309                        }
310                    } else {
311                        debug!("Unhandled Mercury message, type={msg_type:?}, keys={:?}", message.as_object().map(|o| o.keys().collect::<Vec<_>>()));
312                    }
313                } else {
314                    debug!("Unhandled Mercury message, type={msg_type:?}, no data field");
315                }
316            }
317        }
318    }
319
320    #[allow(clippy::too_many_arguments)]
321    async fn handle_close_static(
322        code: u16,
323        reason: &str,
324        connected: &Arc<Mutex<bool>>,
325        should_reconnect: &Arc<Mutex<bool>>,
326        reconnect_attempts: &Arc<Mutex<u32>>,
327        max_reconnect: u32,
328        backoff_max: Duration,
329        _base_url: &Arc<Mutex<String>>,
330        _token: &Arc<Mutex<String>>,
331        event_tx: &mpsc::UnboundedSender<MercuryEvent>,
332    ) {
333        info!("WebSocket closed with code {code}: {reason}");
334        *connected.lock().await = false;
335
336        if code == 4401 {
337            error!("Mercury authorization failed");
338            *should_reconnect.lock().await = false;
339            let _ = event_tx.send(MercuryEvent::Error("Mercury authorization failed".into()));
340            let _ = event_tx.send(MercuryEvent::Disconnected("auth-failed".into()));
341            return;
342        }
343
344        if code == 4400 || code == 4403 {
345            error!("Mercury permanent failure (code {code})");
346            *should_reconnect.lock().await = false;
347            let _ = event_tx.send(MercuryEvent::Error(format!("Mercury permanent failure (code {code})")));
348            let _ = event_tx.send(MercuryEvent::Disconnected("permanent-failure".into()));
349            return;
350        }
351
352        if *should_reconnect.lock().await {
353            let mut attempts = reconnect_attempts.lock().await;
354            if *attempts >= max_reconnect {
355                error!("Max reconnection attempts ({max_reconnect}) exceeded");
356                *should_reconnect.lock().await = false;
357                let _ = event_tx.send(MercuryEvent::Disconnected("max-attempts-exceeded".into()));
358                return;
359            }
360            *attempts += 1;
361            let attempt = *attempts;
362            let delay_secs = (2.0f64.powi(attempt as i32 - 1)).min(backoff_max.as_secs_f64());
363            drop(attempts);
364
365            info!("Reconnecting (attempt {attempt}/{max_reconnect}) in {delay_secs}s");
366            let _ = event_tx.send(MercuryEvent::Reconnecting(attempt));
367
368            time::sleep(Duration::from_secs_f64(delay_secs)).await;
369
370            // Signal that reconnection should happen (handler will re-connect)
371            let _ = event_tx.send(MercuryEvent::Disconnected("reconnect-needed".into()));
372        } else {
373            let _ = event_tx.send(MercuryEvent::Disconnected("manual".into()));
374        }
375    }
376
377    /// Disconnect from Mercury.
378    pub async fn disconnect(&self) {
379        info!("Disconnecting from Mercury");
380        *self.should_reconnect.lock().await = false;
381        *self.connected.lock().await = false;
382        self.shutdown.notify_waiters();
383        let _ = self.event_tx.send(MercuryEvent::Disconnected("client".into()));
384    }
385
386    /// Whether the WebSocket is currently connected.
387    pub async fn connected(&self) -> bool {
388        *self.connected.lock().await
389    }
390
391    /// Current reconnection attempt count.
392    pub async fn current_reconnect_attempts(&self) -> u32 {
393        *self.reconnect_attempts.lock().await
394    }
395}