1use crate::errors::WebexError;
4use crate::types::MercuryActivity;
5use futures_util::{SinkExt, StreamExt};
6use serde_json::Value;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{mpsc, Mutex, Notify};
10use tokio::time;
11use tokio_tungstenite::{connect_async, tungstenite::Message};
12use tracing::{debug, error, info, warn};
13use url::Url;
14use uuid::Uuid;
15
16#[allow(dead_code)]
18type WsFactoryFn = Arc<
19 dyn Fn(String) -> std::pin::Pin<
20 Box<dyn std::future::Future<
21 Output = Result<Box<dyn crate::types::InjectedWebSocket>, Box<dyn std::error::Error + Send + Sync>>,
22 > + Send>,
23 > + Send + Sync,
24>;
25
26type WsSink = futures_util::stream::SplitSink<
28 tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>,
29 Message,
30>;
31
32#[derive(Debug, Clone)]
34pub enum MercuryEvent {
35 Connected,
36 Disconnected(String),
37 Reconnecting(u32),
38 Activity(Box<MercuryActivity>),
39 KmsResponse(Value),
40 Error(String),
41}
42
43pub struct MercurySocket {
45 #[allow(dead_code)]
46 ws_factory: Option<WsFactoryFn>,
47 ping_interval: Duration,
48 pong_timeout: Duration,
49 reconnect_backoff_max: Duration,
50 max_reconnect_attempts: u32,
51
52 token: Arc<Mutex<String>>,
53 base_url: Arc<Mutex<String>>,
54 connected: Arc<Mutex<bool>>,
55 should_reconnect: Arc<Mutex<bool>>,
56 reconnect_attempts: Arc<Mutex<u32>>,
57 shutdown: Arc<Notify>,
58
59 event_tx: mpsc::UnboundedSender<MercuryEvent>,
60 event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<MercuryEvent>>>>,
61}
62
63impl MercurySocket {
64 pub fn new(
65 _ws_factory: Option<WsFactoryFn>,
66 ping_interval: Duration,
67 pong_timeout: Duration,
68 reconnect_backoff_max: Duration,
69 max_reconnect_attempts: u32,
70 ) -> Self {
71 let (event_tx, event_rx) = mpsc::unbounded_channel();
72
73 Self {
74 ws_factory: _ws_factory,
75 ping_interval,
76 pong_timeout,
77 reconnect_backoff_max,
78 max_reconnect_attempts,
79 token: Arc::new(Mutex::new(String::new())),
80 base_url: Arc::new(Mutex::new(String::new())),
81 connected: Arc::new(Mutex::new(false)),
82 should_reconnect: Arc::new(Mutex::new(true)),
83 reconnect_attempts: Arc::new(Mutex::new(0)),
84 shutdown: Arc::new(Notify::new()),
85 event_tx,
86 event_rx: Arc::new(Mutex::new(Some(event_rx))),
87 }
88 }
89
90 pub async fn take_event_rx(&self) -> Option<mpsc::UnboundedReceiver<MercuryEvent>> {
92 self.event_rx.lock().await.take()
93 }
94
95 pub async fn connect(&self, ws_url: &str, token: &str) -> Result<(), WebexError> {
97 *self.token.lock().await = token.to_string();
98 *self.base_url.lock().await = ws_url.to_string();
99 *self.should_reconnect.lock().await = true;
100 *self.reconnect_attempts.lock().await = 0;
101 self.connect_internal().await
102 }
103
104 async fn connect_internal(&self) -> Result<(), WebexError> {
105 let base_url = self.base_url.lock().await.clone();
106 let prepared_url = Self::prepare_url(&base_url)?;
107 debug!("Connecting to Mercury at {prepared_url}");
108
109 *self.connected.lock().await = false;
110
111 let (ws_stream, _) = connect_async(&prepared_url)
112 .await
113 .map_err(|e| WebexError::mercury_connection(format!("Failed to connect: {e}"), None))?;
114
115 let (mut write, mut read) = ws_stream.split();
116
117 let token = self.token.lock().await.clone();
119 let auth_msg = serde_json::json!({
120 "id": Uuid::new_v4().to_string(),
121 "type": "authorization",
122 "data": { "token": format!("Bearer {}", token) }
123 });
124 write
125 .send(Message::Text(auth_msg.to_string()))
126 .await
127 .map_err(|e| WebexError::mercury_connection(format!("Failed to send auth: {e}"), None))?;
128
129 let _ready_timeout = time::timeout(Duration::from_secs(30), async {
131 while let Some(msg) = read.next().await {
132 match msg {
133 Ok(Message::Text(text)) => {
134 let text_str: &str = &text;
135 if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
136 if Self::is_connection_ready(&parsed) {
137 return Ok(parsed);
138 }
139 }
140 }
141 Err(e) => {
142 return Err(WebexError::mercury_connection(
143 format!("WebSocket error during setup: {e}"),
144 None,
145 ));
146 }
147 _ => {}
148 }
149 }
150 Err(WebexError::mercury_connection("WebSocket closed during setup", None))
151 })
152 .await
153 .map_err(|_| WebexError::mercury_connection("Mercury connection timeout", None))??;
154
155 debug!("Mercury connection ready");
156 *self.connected.lock().await = true;
157
158 let event_tx = self.event_tx.clone();
160 let connected = self.connected.clone();
161 let should_reconnect = self.should_reconnect.clone();
162 let reconnect_attempts = self.reconnect_attempts.clone();
163 let max_reconnect = self.max_reconnect_attempts;
164 let backoff_max = self.reconnect_backoff_max;
165 let ping_interval = self.ping_interval;
166 let _pong_timeout = self.pong_timeout;
167 let shutdown = self.shutdown.clone();
168 let base_url_clone = self.base_url.clone();
169 let token_clone = self.token.clone();
170
171 let write = Arc::new(Mutex::new(write));
172 let write_clone = write.clone();
173
174 let ping_write = write.clone();
176 let ping_connected = connected.clone();
177 let ping_shutdown = shutdown.clone();
178 let _ping_event_tx = event_tx.clone();
179 tokio::spawn(async move {
180 let mut interval = time::interval(ping_interval);
181 interval.tick().await; loop {
184 tokio::select! {
185 _ = interval.tick() => {
186 if !*ping_connected.lock().await {
187 break;
188 }
189 let pong_id = Uuid::new_v4().to_string();
190 let ping_msg = serde_json::json!({
191 "id": pong_id,
192 "type": "ping"
193 });
194 let mut w = ping_write.lock().await;
195 if w.send(Message::Text(ping_msg.to_string())).await.is_err() {
196 break;
197 }
198 debug!("Sent ping: {pong_id}");
199 drop(w);
200
201 }
203 _ = ping_shutdown.notified() => break,
204 }
205 }
206 });
207
208 tokio::spawn(async move {
210 while let Some(msg) = read.next().await {
211 match msg {
212 Ok(Message::Text(text)) => {
213 let text_str: &str = &text;
214 debug!("WS message received ({} bytes)", text_str.len());
215 if text_str.len() > 1_048_576 {
217 warn!("Dropping oversized Mercury message ({} bytes)", text_str.len());
218 continue;
219 }
220 if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
221 Self::handle_message_static(&parsed, &event_tx, &write_clone).await;
222 } else {
223 debug!("Failed to parse WS message as JSON");
224 }
225 }
226 Ok(Message::Close(frame)) => {
227 let code = frame.as_ref().map(|f| f.code.into()).unwrap_or(1000u16);
228 let reason = frame.as_ref().map(|f| f.reason.to_string()).unwrap_or_default();
229 Self::handle_close_static(
230 code,
231 &reason,
232 &connected,
233 &should_reconnect,
234 &reconnect_attempts,
235 max_reconnect,
236 backoff_max,
237 &base_url_clone,
238 &token_clone,
239 &event_tx,
240 )
241 .await;
242 break;
243 }
244 Err(e) => {
245 error!("WebSocket error: {e}");
246 let _ = event_tx.send(MercuryEvent::Error(e.to_string()));
247 *connected.lock().await = false;
248 break;
249 }
250 _ => {}
251 }
252 }
253
254 if *should_reconnect.lock().await && !*connected.lock().await {
256 }
258 });
259
260 Ok(())
261 }
262
263 fn prepare_url(base_url: &str) -> Result<String, WebexError> {
264 let mut url = Url::parse(base_url)
265 .map_err(|e| WebexError::mercury_connection(format!("Invalid URL: {e}"), None))?;
266 url.query_pairs_mut()
267 .append_pair("outboundWireFormat", "text")
268 .append_pair("bufferStates", "true")
269 .append_pair("aliasHttpStatus", "true")
270 .append_pair(
271 "clientTimestamp",
272 &std::time::SystemTime::now()
273 .duration_since(std::time::UNIX_EPOCH)
274 .unwrap_or_default()
275 .as_millis()
276 .to_string(),
277 );
278 Ok(url.to_string())
279 }
280
281 fn is_connection_ready(message: &Value) -> bool {
282 let event_type = message
283 .get("data")
284 .and_then(|d| d.get("eventType"))
285 .and_then(|e| e.as_str())
286 .unwrap_or("");
287 event_type.contains("mercury.buffer_state") || event_type.contains("mercury.registration_status")
288 }
289
290 async fn handle_message_static(
291 message: &Value,
292 event_tx: &mpsc::UnboundedSender<MercuryEvent>,
293 write: &Arc<Mutex<WsSink>>,
294 ) {
295 let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
296
297 match msg_type {
298 "pong" => {
299 let id = message.get("id").and_then(|i| i.as_str()).unwrap_or("");
300 debug!("Received pong: {id}");
301 }
302 "shutdown" => {
303 info!("Received shutdown message from Mercury");
304 }
306 _ => {
307 if let Some(data) = message.get("data") {
308 if let Some(event_type) = data.get("eventType").and_then(|e| e.as_str()) {
309 debug!("Mercury eventType: {event_type}");
310
311 if let Some(msg_id) = message.get("id").and_then(|i| i.as_str()) {
313 let ack = serde_json::json!({"messageId": msg_id, "type": "ack"});
314 let mut w = write.lock().await;
315 let _ = w.send(Message::Text(ack.to_string())).await;
316 }
317
318 if event_type.starts_with("encryption.") {
319 debug!("Emitting kms:response for eventType: {event_type}");
320 let _ = event_tx.send(MercuryEvent::KmsResponse(data.clone()));
321 } else if event_type == "conversation.activity" {
322 if let Some(activity_raw) = data.get("activity") {
323 match serde_json::from_value::<MercuryActivity>(activity_raw.clone()) {
324 Ok(activity) => {
325 debug!("Emitting activity: {}", activity.id);
326 let _ = event_tx.send(MercuryEvent::Activity(Box::new(activity)));
327 }
328 Err(e) => {
329 error!("Failed to parse activity: {e}");
330 debug!("Raw activity keys: {:?}", activity_raw.as_object().map(|o| o.keys().collect::<Vec<_>>()));
331 }
332 }
333 }
334 }
335 } else {
336 debug!("Unhandled Mercury message, type={msg_type:?}, keys={:?}", message.as_object().map(|o| o.keys().collect::<Vec<_>>()));
337 }
338 } else {
339 debug!("Unhandled Mercury message, type={msg_type:?}, no data field");
340 }
341 }
342 }
343 }
344
345 #[allow(clippy::too_many_arguments)]
346 async fn handle_close_static(
347 code: u16,
348 reason: &str,
349 connected: &Arc<Mutex<bool>>,
350 should_reconnect: &Arc<Mutex<bool>>,
351 reconnect_attempts: &Arc<Mutex<u32>>,
352 max_reconnect: u32,
353 backoff_max: Duration,
354 _base_url: &Arc<Mutex<String>>,
355 _token: &Arc<Mutex<String>>,
356 event_tx: &mpsc::UnboundedSender<MercuryEvent>,
357 ) {
358 info!("WebSocket closed with code {code}: {reason}");
359 *connected.lock().await = false;
360
361 if code == 4401 {
362 error!("Mercury authorization failed");
363 *should_reconnect.lock().await = false;
364 let _ = event_tx.send(MercuryEvent::Error("Mercury authorization failed".into()));
365 let _ = event_tx.send(MercuryEvent::Disconnected("auth-failed".into()));
366 return;
367 }
368
369 if code == 4400 || code == 4403 {
370 error!("Mercury permanent failure (code {code})");
371 *should_reconnect.lock().await = false;
372 let _ = event_tx.send(MercuryEvent::Error(format!("Mercury permanent failure (code {code})")));
373 let _ = event_tx.send(MercuryEvent::Disconnected("permanent-failure".into()));
374 return;
375 }
376
377 if *should_reconnect.lock().await {
378 let mut attempts = reconnect_attempts.lock().await;
379 if *attempts >= max_reconnect {
380 error!("Max reconnection attempts ({max_reconnect}) exceeded");
381 *should_reconnect.lock().await = false;
382 let _ = event_tx.send(MercuryEvent::Disconnected("max-attempts-exceeded".into()));
383 return;
384 }
385 *attempts += 1;
386 let attempt = *attempts;
387 let delay_secs = (2.0f64.powi(attempt as i32 - 1)).min(backoff_max.as_secs_f64());
388 drop(attempts);
389
390 info!("Reconnecting (attempt {attempt}/{max_reconnect}) in {delay_secs}s");
391 let _ = event_tx.send(MercuryEvent::Reconnecting(attempt));
392
393 time::sleep(Duration::from_secs_f64(delay_secs)).await;
394
395 let _ = event_tx.send(MercuryEvent::Disconnected("reconnect-needed".into()));
397 } else {
398 let _ = event_tx.send(MercuryEvent::Disconnected("manual".into()));
399 }
400 }
401
402 pub async fn disconnect(&self) {
404 info!("Disconnecting from Mercury");
405 *self.should_reconnect.lock().await = false;
406 *self.connected.lock().await = false;
407 self.shutdown.notify_waiters();
408 let _ = self.event_tx.send(MercuryEvent::Disconnected("client".into()));
409 }
410
411 pub async fn connected(&self) -> bool {
413 *self.connected.lock().await
414 }
415
416 pub async fn current_reconnect_attempts(&self) -> u32 {
418 *self.reconnect_attempts.lock().await
419 }
420}