Skip to main content

webex_message_handler/
mercury_socket.rs

1//! Mercury WebSocket connection with auth, heartbeat, and reconnection.
2
3use crate::errors::WebexError;
4use crate::types::MercuryActivity;
5use futures_util::{SinkExt, StreamExt};
6use serde_json::Value;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{mpsc, Mutex, Notify};
10use tokio::time;
11use tokio_tungstenite::{connect_async, tungstenite::Message};
12use tracing::{debug, error, info, warn};
13use url::Url;
14use uuid::Uuid;
15
16/// Type alias for the injected WebSocket factory.
17#[allow(dead_code)]
18type WsFactoryFn = Arc<
19    dyn Fn(String) -> std::pin::Pin<
20        Box<dyn std::future::Future<
21            Output = Result<Box<dyn crate::types::InjectedWebSocket>, Box<dyn std::error::Error + Send + Sync>>,
22        > + Send>,
23    > + Send + Sync,
24>;
25
26/// Type alias for the WebSocket write half to reduce type complexity.
27type WsSink = futures_util::stream::SplitSink<
28    tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
29    Message,
30>;
31
32/// Events emitted by the Mercury socket.
33#[derive(Debug, Clone)]
34pub enum MercuryEvent {
35    Connected,
36    Disconnected(String),
37    Reconnecting(u32),
38    Activity(Box<MercuryActivity>),
39    KmsResponse(Value),
40    Error(String),
41}
42
43/// Mercury WebSocket connection manager.
44pub struct MercurySocket {
45    #[allow(dead_code)]
46    ws_factory: Option<WsFactoryFn>,
47    ping_interval: Duration,
48    pong_timeout: Duration,
49    reconnect_backoff_max: Duration,
50    max_reconnect_attempts: u32,
51
52    token: Arc<Mutex<String>>,
53    base_url: Arc<Mutex<String>>,
54    connected: Arc<Mutex<bool>>,
55    should_reconnect: Arc<Mutex<bool>>,
56    reconnect_attempts: Arc<Mutex<u32>>,
57    shutdown: Arc<Notify>,
58
59    event_tx: mpsc::UnboundedSender<MercuryEvent>,
60    event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<MercuryEvent>>>>,
61}
62
63impl MercurySocket {
64    pub fn new(
65        _ws_factory: Option<WsFactoryFn>,
66        ping_interval: Duration,
67        pong_timeout: Duration,
68        reconnect_backoff_max: Duration,
69        max_reconnect_attempts: u32,
70    ) -> Self {
71        let (event_tx, event_rx) = mpsc::unbounded_channel();
72
73        Self {
74            ws_factory: _ws_factory,
75            ping_interval,
76            pong_timeout,
77            reconnect_backoff_max,
78            max_reconnect_attempts,
79            token: Arc::new(Mutex::new(String::new())),
80            base_url: Arc::new(Mutex::new(String::new())),
81            connected: Arc::new(Mutex::new(false)),
82            should_reconnect: Arc::new(Mutex::new(true)),
83            reconnect_attempts: Arc::new(Mutex::new(0)),
84            shutdown: Arc::new(Notify::new()),
85            event_tx,
86            event_rx: Arc::new(Mutex::new(Some(event_rx))),
87        }
88    }
89
90    /// Take the event receiver. Can only be called once.
91    pub async fn take_event_rx(&self) -> Option<mpsc::UnboundedReceiver<MercuryEvent>> {
92        self.event_rx.lock().await.take()
93    }
94
95    /// Connect to Mercury WebSocket.
96    pub async fn connect(&self, ws_url: &str, token: &str) -> Result<(), WebexError> {
97        *self.token.lock().await = token.to_string();
98        *self.base_url.lock().await = ws_url.to_string();
99        *self.should_reconnect.lock().await = true;
100        *self.reconnect_attempts.lock().await = 0;
101        self.connect_internal().await
102    }
103
104    async fn connect_internal(&self) -> Result<(), WebexError> {
105        let base_url = self.base_url.lock().await.clone();
106        let prepared_url = Self::prepare_url(&base_url)?;
107        debug!("Connecting to Mercury at {prepared_url}");
108
109        *self.connected.lock().await = false;
110
111        let (ws_stream, _) = connect_async(&prepared_url)
112            .await
113            .map_err(|e| WebexError::mercury_connection(format!("Failed to connect: {e}"), None))?;
114
115        let (mut write, mut read) = ws_stream.split();
116
117        // Send authorization
118        let token = self.token.lock().await.clone();
119        let auth_msg = serde_json::json!({
120            "id": Uuid::new_v4().to_string(),
121            "type": "authorization",
122            "data": { "token": format!("Bearer {}", token) }
123        });
124        write
125            .send(Message::Text(auth_msg.to_string()))
126            .await
127            .map_err(|e| WebexError::mercury_connection(format!("Failed to send auth: {e}"), None))?;
128
129        // Wait for connection ready
130        let _ready_timeout = time::timeout(Duration::from_secs(30), async {
131            while let Some(msg) = read.next().await {
132                match msg {
133                    Ok(Message::Text(text)) => {
134                        let text_str: &str = &text;
135                        if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
136                            if Self::is_connection_ready(&parsed) {
137                                return Ok(parsed);
138                            }
139                        }
140                    }
141                    Err(e) => {
142                        return Err(WebexError::mercury_connection(
143                            format!("WebSocket error during setup: {e}"),
144                            None,
145                        ));
146                    }
147                    _ => {}
148                }
149            }
150            Err(WebexError::mercury_connection("WebSocket closed during setup", None))
151        })
152        .await
153        .map_err(|_| WebexError::mercury_connection("Mercury connection timeout", None))??;
154
155        debug!("Mercury connection ready");
156        *self.connected.lock().await = true;
157
158        // Spawn read loop and ping loop
159        let event_tx = self.event_tx.clone();
160        let connected = self.connected.clone();
161        let should_reconnect = self.should_reconnect.clone();
162        let reconnect_attempts = self.reconnect_attempts.clone();
163        let max_reconnect = self.max_reconnect_attempts;
164        let backoff_max = self.reconnect_backoff_max;
165        let ping_interval = self.ping_interval;
166        let _pong_timeout = self.pong_timeout;
167        let shutdown = self.shutdown.clone();
168        let base_url_clone = self.base_url.clone();
169        let token_clone = self.token.clone();
170
171        let write = Arc::new(Mutex::new(write));
172        let write_clone = write.clone();
173
174        // Ping loop
175        let ping_write = write.clone();
176        let ping_connected = connected.clone();
177        let ping_shutdown = shutdown.clone();
178        let _ping_event_tx = event_tx.clone();
179        tokio::spawn(async move {
180            let mut interval = time::interval(ping_interval);
181            interval.tick().await; // skip first tick
182
183            loop {
184                tokio::select! {
185                    _ = interval.tick() => {
186                        if !*ping_connected.lock().await {
187                            break;
188                        }
189                        let pong_id = Uuid::new_v4().to_string();
190                        let ping_msg = serde_json::json!({
191                            "id": pong_id,
192                            "type": "ping"
193                        });
194                        let mut w = ping_write.lock().await;
195                        if w.send(Message::Text(ping_msg.to_string())).await.is_err() {
196                            break;
197                        }
198                        debug!("Sent ping: {pong_id}");
199                        drop(w);
200
201                        // Pong timeout handled by read loop
202                    }
203                    _ = ping_shutdown.notified() => break,
204                }
205            }
206        });
207
208        // Read loop
209        tokio::spawn(async move {
210            while let Some(msg) = read.next().await {
211                match msg {
212                    Ok(Message::Text(text)) => {
213                        let text_str: &str = &text;
214                        debug!("WS message received ({} bytes)", text_str.len());
215                        // Enforce WebSocket message size limit (1 MB)
216                        if text_str.len() > 1_048_576 {
217                            warn!("Dropping oversized Mercury message ({} bytes)", text_str.len());
218                            continue;
219                        }
220                        if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
221                            Self::handle_message_static(&parsed, &event_tx, &write_clone).await;
222                        } else {
223                            debug!("Failed to parse WS message as JSON");
224                        }
225                    }
226                    Ok(Message::Close(frame)) => {
227                        let code = frame.as_ref().map(|f| f.code.into()).unwrap_or(1000u16);
228                        let reason = frame.as_ref().map(|f| f.reason.to_string()).unwrap_or_default();
229                        Self::handle_close_static(
230                            code,
231                            &reason,
232                            &connected,
233                            &should_reconnect,
234                            &reconnect_attempts,
235                            max_reconnect,
236                            backoff_max,
237                            &base_url_clone,
238                            &token_clone,
239                            &event_tx,
240                        )
241                        .await;
242                        break;
243                    }
244                    Err(e) => {
245                        error!("WebSocket error: {e}");
246                        let _ = event_tx.send(MercuryEvent::Error(e.to_string()));
247                        *connected.lock().await = false;
248                        break;
249                    }
250                    _ => {}
251                }
252            }
253
254            // Connection ended — handle reconnection if needed
255            if *should_reconnect.lock().await && !*connected.lock().await {
256                // Reconnect logic handled in handle_close_static
257            }
258        });
259
260        Ok(())
261    }
262
263    fn prepare_url(base_url: &str) -> Result<String, WebexError> {
264        let mut url = Url::parse(base_url)
265            .map_err(|e| WebexError::mercury_connection(format!("Invalid URL: {e}"), None))?;
266        url.query_pairs_mut()
267            .append_pair("outboundWireFormat", "text")
268            .append_pair("bufferStates", "true")
269            .append_pair("aliasHttpStatus", "true")
270            .append_pair(
271                "clientTimestamp",
272                &std::time::SystemTime::now()
273                    .duration_since(std::time::UNIX_EPOCH)
274                    .unwrap_or_default()
275                    .as_millis()
276                    .to_string(),
277            );
278        Ok(url.to_string())
279    }
280
281    fn is_connection_ready(message: &Value) -> bool {
282        let event_type = message
283            .get("data")
284            .and_then(|d| d.get("eventType"))
285            .and_then(|e| e.as_str())
286            .unwrap_or("");
287        event_type.contains("mercury.buffer_state") || event_type.contains("mercury.registration_status")
288    }
289
290    async fn handle_message_static(
291        message: &Value,
292        event_tx: &mpsc::UnboundedSender<MercuryEvent>,
293        write: &Arc<Mutex<WsSink>>,
294    ) {
295        let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
296
297        match msg_type {
298            "pong" => {
299                let id = message.get("id").and_then(|i| i.as_str()).unwrap_or("");
300                debug!("Received pong: {id}");
301            }
302            "shutdown" => {
303                info!("Received shutdown message from Mercury");
304                // Reconnection will be triggered by connection close
305            }
306            _ => {
307                if let Some(data) = message.get("data") {
308                    if let Some(event_type) = data.get("eventType").and_then(|e| e.as_str()) {
309                        debug!("Mercury eventType: {event_type}");
310
311                        // Send ACK
312                        if let Some(msg_id) = message.get("id").and_then(|i| i.as_str()) {
313                            let ack = serde_json::json!({"messageId": msg_id, "type": "ack"});
314                            let mut w = write.lock().await;
315                            let _ = w.send(Message::Text(ack.to_string())).await;
316                        }
317
318                        if event_type.starts_with("encryption.") {
319                            debug!("Emitting kms:response for eventType: {event_type}");
320                            let _ = event_tx.send(MercuryEvent::KmsResponse(data.clone()));
321                        } else if event_type == "conversation.activity" {
322                            if let Some(activity_raw) = data.get("activity") {
323                                match serde_json::from_value::<MercuryActivity>(activity_raw.clone()) {
324                                    Ok(activity) => {
325                                        debug!("Emitting activity: {}", activity.id);
326                                        let _ = event_tx.send(MercuryEvent::Activity(Box::new(activity)));
327                                    }
328                                    Err(e) => {
329                                        error!("Failed to parse activity: {e}");
330                                        debug!("Raw activity keys: {:?}", activity_raw.as_object().map(|o| o.keys().collect::<Vec<_>>()));
331                                    }
332                                }
333                            }
334                        }
335                    } else {
336                        debug!("Unhandled Mercury message, type={msg_type:?}, keys={:?}", message.as_object().map(|o| o.keys().collect::<Vec<_>>()));
337                    }
338                } else {
339                    debug!("Unhandled Mercury message, type={msg_type:?}, no data field");
340                }
341            }
342        }
343    }
344
345    #[allow(clippy::too_many_arguments)]
346    async fn handle_close_static(
347        code: u16,
348        reason: &str,
349        connected: &Arc<Mutex<bool>>,
350        should_reconnect: &Arc<Mutex<bool>>,
351        reconnect_attempts: &Arc<Mutex<u32>>,
352        max_reconnect: u32,
353        backoff_max: Duration,
354        _base_url: &Arc<Mutex<String>>,
355        _token: &Arc<Mutex<String>>,
356        event_tx: &mpsc::UnboundedSender<MercuryEvent>,
357    ) {
358        info!("WebSocket closed with code {code}: {reason}");
359        *connected.lock().await = false;
360
361        if code == 4401 {
362            error!("Mercury authorization failed");
363            *should_reconnect.lock().await = false;
364            let _ = event_tx.send(MercuryEvent::Error("Mercury authorization failed".into()));
365            let _ = event_tx.send(MercuryEvent::Disconnected("auth-failed".into()));
366            return;
367        }
368
369        if code == 4400 || code == 4403 {
370            error!("Mercury permanent failure (code {code})");
371            *should_reconnect.lock().await = false;
372            let _ = event_tx.send(MercuryEvent::Error(format!("Mercury permanent failure (code {code})")));
373            let _ = event_tx.send(MercuryEvent::Disconnected("permanent-failure".into()));
374            return;
375        }
376
377        if *should_reconnect.lock().await {
378            let mut attempts = reconnect_attempts.lock().await;
379            if *attempts >= max_reconnect {
380                error!("Max reconnection attempts ({max_reconnect}) exceeded");
381                *should_reconnect.lock().await = false;
382                let _ = event_tx.send(MercuryEvent::Disconnected("max-attempts-exceeded".into()));
383                return;
384            }
385            *attempts += 1;
386            let attempt = *attempts;
387            let delay_secs = (2.0f64.powi(attempt as i32 - 1)).min(backoff_max.as_secs_f64());
388            drop(attempts);
389
390            info!("Reconnecting (attempt {attempt}/{max_reconnect}) in {delay_secs}s");
391            let _ = event_tx.send(MercuryEvent::Reconnecting(attempt));
392
393            time::sleep(Duration::from_secs_f64(delay_secs)).await;
394
395            // Signal that reconnection should happen (handler will re-connect)
396            let _ = event_tx.send(MercuryEvent::Disconnected("reconnect-needed".into()));
397        } else {
398            let _ = event_tx.send(MercuryEvent::Disconnected("manual".into()));
399        }
400    }
401
402    /// Disconnect from Mercury.
403    pub async fn disconnect(&self) {
404        info!("Disconnecting from Mercury");
405        *self.should_reconnect.lock().await = false;
406        *self.connected.lock().await = false;
407        self.shutdown.notify_waiters();
408        let _ = self.event_tx.send(MercuryEvent::Disconnected("client".into()));
409    }
410
411    /// Whether the WebSocket is currently connected.
412    pub async fn connected(&self) -> bool {
413        *self.connected.lock().await
414    }
415
416    /// Current reconnection attempt count.
417    pub async fn current_reconnect_attempts(&self) -> u32 {
418        *self.reconnect_attempts.lock().await
419    }
420}