1use crate::errors::WebexError;
4use crate::types::MercuryActivity;
5use futures_util::{SinkExt, StreamExt};
6use serde_json::Value;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{mpsc, Mutex, Notify};
10use tokio::time;
11use tokio_tungstenite::{connect_async, tungstenite::Message};
12use tracing::{debug, error, info};
13use url::Url;
14use uuid::Uuid;
15
16#[derive(Debug, Clone)]
18pub enum MercuryEvent {
19 Connected,
20 Disconnected(String),
21 Reconnecting(u32),
22 Activity(MercuryActivity),
23 KmsResponse(Value),
24 Error(String),
25}
26
27pub struct MercurySocket {
29 #[allow(dead_code)]
30 ws_factory: Option<Arc<dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Box<dyn crate::types::InjectedWebSocket>, Box<dyn std::error::Error + Send + Sync>>> + Send>> + Send + Sync>>,
31 ping_interval: Duration,
32 pong_timeout: Duration,
33 reconnect_backoff_max: Duration,
34 max_reconnect_attempts: u32,
35
36 token: Arc<Mutex<String>>,
37 base_url: Arc<Mutex<String>>,
38 connected: Arc<Mutex<bool>>,
39 should_reconnect: Arc<Mutex<bool>>,
40 reconnect_attempts: Arc<Mutex<u32>>,
41 shutdown: Arc<Notify>,
42
43 event_tx: mpsc::UnboundedSender<MercuryEvent>,
44 event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<MercuryEvent>>>>,
45}
46
47impl MercurySocket {
48 pub fn new(
49 _ws_factory: Option<Arc<dyn Fn(String) -> std::pin::Pin<Box<dyn std::future::Future<Output = Result<Box<dyn crate::types::InjectedWebSocket>, Box<dyn std::error::Error + Send + Sync>>> + Send>> + Send + Sync>>,
50 ping_interval: Duration,
51 pong_timeout: Duration,
52 reconnect_backoff_max: Duration,
53 max_reconnect_attempts: u32,
54 ) -> Self {
55 let (event_tx, event_rx) = mpsc::unbounded_channel();
56
57 Self {
58 ws_factory: _ws_factory,
59 ping_interval,
60 pong_timeout,
61 reconnect_backoff_max,
62 max_reconnect_attempts,
63 token: Arc::new(Mutex::new(String::new())),
64 base_url: Arc::new(Mutex::new(String::new())),
65 connected: Arc::new(Mutex::new(false)),
66 should_reconnect: Arc::new(Mutex::new(true)),
67 reconnect_attempts: Arc::new(Mutex::new(0)),
68 shutdown: Arc::new(Notify::new()),
69 event_tx,
70 event_rx: Arc::new(Mutex::new(Some(event_rx))),
71 }
72 }
73
74 pub async fn take_event_rx(&self) -> Option<mpsc::UnboundedReceiver<MercuryEvent>> {
76 self.event_rx.lock().await.take()
77 }
78
79 pub async fn connect(&self, ws_url: &str, token: &str) -> Result<(), WebexError> {
81 *self.token.lock().await = token.to_string();
82 *self.base_url.lock().await = ws_url.to_string();
83 *self.should_reconnect.lock().await = true;
84 *self.reconnect_attempts.lock().await = 0;
85 self.connect_internal().await
86 }
87
88 async fn connect_internal(&self) -> Result<(), WebexError> {
89 let base_url = self.base_url.lock().await.clone();
90 let prepared_url = Self::prepare_url(&base_url)?;
91 debug!("Connecting to Mercury at {prepared_url}");
92
93 *self.connected.lock().await = false;
94
95 let (ws_stream, _) = connect_async(&prepared_url)
96 .await
97 .map_err(|e| WebexError::mercury_connection(format!("Failed to connect: {e}"), None))?;
98
99 let (mut write, mut read) = ws_stream.split();
100
101 let token = self.token.lock().await.clone();
103 let auth_msg = serde_json::json!({
104 "id": Uuid::new_v4().to_string(),
105 "type": "authorization",
106 "data": { "token": format!("Bearer {}", token) }
107 });
108 write
109 .send(Message::Text(auth_msg.to_string().into()))
110 .await
111 .map_err(|e| WebexError::mercury_connection(format!("Failed to send auth: {e}"), None))?;
112
113 let _ready_timeout = time::timeout(Duration::from_secs(30), async {
115 while let Some(msg) = read.next().await {
116 match msg {
117 Ok(Message::Text(text)) => {
118 let text_str: &str = &text;
119 if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
120 if Self::is_connection_ready(&parsed) {
121 return Ok(parsed);
122 }
123 }
124 }
125 Err(e) => {
126 return Err(WebexError::mercury_connection(
127 format!("WebSocket error during setup: {e}"),
128 None,
129 ));
130 }
131 _ => {}
132 }
133 }
134 Err(WebexError::mercury_connection("WebSocket closed during setup", None))
135 })
136 .await
137 .map_err(|_| WebexError::mercury_connection("Mercury connection timeout", None))??;
138
139 debug!("Mercury connection ready");
140 *self.connected.lock().await = true;
141
142 let event_tx = self.event_tx.clone();
144 let connected = self.connected.clone();
145 let should_reconnect = self.should_reconnect.clone();
146 let reconnect_attempts = self.reconnect_attempts.clone();
147 let max_reconnect = self.max_reconnect_attempts;
148 let backoff_max = self.reconnect_backoff_max;
149 let ping_interval = self.ping_interval;
150 let _pong_timeout = self.pong_timeout;
151 let shutdown = self.shutdown.clone();
152 let base_url_clone = self.base_url.clone();
153 let token_clone = self.token.clone();
154
155 let write = Arc::new(Mutex::new(write));
156 let write_clone = write.clone();
157
158 let ping_write = write.clone();
160 let ping_connected = connected.clone();
161 let ping_shutdown = shutdown.clone();
162 let _ping_event_tx = event_tx.clone();
163 tokio::spawn(async move {
164 let mut interval = time::interval(ping_interval);
165 interval.tick().await; loop {
168 tokio::select! {
169 _ = interval.tick() => {
170 if !*ping_connected.lock().await {
171 break;
172 }
173 let pong_id = Uuid::new_v4().to_string();
174 let ping_msg = serde_json::json!({
175 "id": pong_id,
176 "type": "ping"
177 });
178 let mut w = ping_write.lock().await;
179 if w.send(Message::Text(ping_msg.to_string().into())).await.is_err() {
180 break;
181 }
182 debug!("Sent ping: {pong_id}");
183 drop(w);
184
185 }
187 _ = ping_shutdown.notified() => break,
188 }
189 }
190 });
191
192 tokio::spawn(async move {
194 while let Some(msg) = read.next().await {
195 match msg {
196 Ok(Message::Text(text)) => {
197 let text_str: &str = &text;
198 debug!("WS message received ({} bytes)", text_str.len());
199 if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
200 Self::handle_message_static(&parsed, &event_tx, &write_clone).await;
201 } else {
202 debug!("Failed to parse WS message as JSON");
203 }
204 }
205 Ok(Message::Close(frame)) => {
206 let code = frame.as_ref().map(|f| f.code.into()).unwrap_or(1000u16);
207 let reason = frame.as_ref().map(|f| f.reason.to_string()).unwrap_or_default();
208 Self::handle_close_static(
209 code,
210 &reason,
211 &connected,
212 &should_reconnect,
213 &reconnect_attempts,
214 max_reconnect,
215 backoff_max,
216 &base_url_clone,
217 &token_clone,
218 &event_tx,
219 )
220 .await;
221 break;
222 }
223 Err(e) => {
224 error!("WebSocket error: {e}");
225 let _ = event_tx.send(MercuryEvent::Error(e.to_string()));
226 *connected.lock().await = false;
227 break;
228 }
229 _ => {}
230 }
231 }
232
233 if *should_reconnect.lock().await && *connected.lock().await == false {
235 }
237 });
238
239 Ok(())
240 }
241
242 fn prepare_url(base_url: &str) -> Result<String, WebexError> {
243 let mut url = Url::parse(base_url)
244 .map_err(|e| WebexError::mercury_connection(format!("Invalid URL: {e}"), None))?;
245 url.query_pairs_mut()
246 .append_pair("outboundWireFormat", "text")
247 .append_pair("bufferStates", "true")
248 .append_pair("aliasHttpStatus", "true")
249 .append_pair(
250 "clientTimestamp",
251 &std::time::SystemTime::now()
252 .duration_since(std::time::UNIX_EPOCH)
253 .unwrap_or_default()
254 .as_millis()
255 .to_string(),
256 );
257 Ok(url.to_string())
258 }
259
260 fn is_connection_ready(message: &Value) -> bool {
261 let event_type = message
262 .get("data")
263 .and_then(|d| d.get("eventType"))
264 .and_then(|e| e.as_str())
265 .unwrap_or("");
266 event_type.contains("mercury.buffer_state") || event_type.contains("mercury.registration_status")
267 }
268
269 async fn handle_message_static(
270 message: &Value,
271 event_tx: &mpsc::UnboundedSender<MercuryEvent>,
272 write: &Arc<Mutex<futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>>,
273 ) {
274 let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
275
276 match msg_type {
277 "pong" => {
278 let id = message.get("id").and_then(|i| i.as_str()).unwrap_or("");
279 debug!("Received pong: {id}");
280 }
281 "shutdown" => {
282 info!("Received shutdown message from Mercury");
283 }
285 _ => {
286 if let Some(data) = message.get("data") {
287 if let Some(event_type) = data.get("eventType").and_then(|e| e.as_str()) {
288 debug!("Mercury eventType: {event_type}");
289
290 if let Some(msg_id) = message.get("id").and_then(|i| i.as_str()) {
292 let ack = serde_json::json!({"messageId": msg_id, "type": "ack"});
293 let mut w = write.lock().await;
294 let _ = w.send(Message::Text(ack.to_string().into())).await;
295 }
296
297 if event_type.starts_with("encryption.") {
298 debug!("Emitting kms:response for eventType: {event_type}");
299 let _ = event_tx.send(MercuryEvent::KmsResponse(data.clone()));
300 } else if event_type == "conversation.activity" {
301 if let Some(activity_raw) = data.get("activity") {
302 match serde_json::from_value::<MercuryActivity>(activity_raw.clone()) {
303 Ok(activity) => {
304 debug!("Emitting activity: {}", activity.id);
305 let _ = event_tx.send(MercuryEvent::Activity(activity));
306 }
307 Err(e) => {
308 error!("Failed to parse activity: {e}");
309 debug!("Raw activity keys: {:?}", activity_raw.as_object().map(|o| o.keys().collect::<Vec<_>>()));
310 }
311 }
312 }
313 }
314 } else {
315 debug!("Unhandled Mercury message, type={msg_type:?}, keys={:?}", message.as_object().map(|o| o.keys().collect::<Vec<_>>()));
316 }
317 } else {
318 debug!("Unhandled Mercury message, type={msg_type:?}, no data field");
319 }
320 }
321 }
322 }
323
324 #[allow(clippy::too_many_arguments)]
325 async fn handle_close_static(
326 code: u16,
327 reason: &str,
328 connected: &Arc<Mutex<bool>>,
329 should_reconnect: &Arc<Mutex<bool>>,
330 reconnect_attempts: &Arc<Mutex<u32>>,
331 max_reconnect: u32,
332 backoff_max: Duration,
333 _base_url: &Arc<Mutex<String>>,
334 _token: &Arc<Mutex<String>>,
335 event_tx: &mpsc::UnboundedSender<MercuryEvent>,
336 ) {
337 info!("WebSocket closed with code {code}: {reason}");
338 *connected.lock().await = false;
339
340 if code == 4401 {
341 error!("Mercury authorization failed");
342 *should_reconnect.lock().await = false;
343 let _ = event_tx.send(MercuryEvent::Error("Mercury authorization failed".into()));
344 let _ = event_tx.send(MercuryEvent::Disconnected("auth-failed".into()));
345 return;
346 }
347
348 if code == 4400 || code == 4403 {
349 error!("Mercury permanent failure (code {code})");
350 *should_reconnect.lock().await = false;
351 let _ = event_tx.send(MercuryEvent::Error(format!("Mercury permanent failure (code {code})")));
352 let _ = event_tx.send(MercuryEvent::Disconnected("permanent-failure".into()));
353 return;
354 }
355
356 if *should_reconnect.lock().await {
357 let mut attempts = reconnect_attempts.lock().await;
358 if *attempts >= max_reconnect {
359 error!("Max reconnection attempts ({max_reconnect}) exceeded");
360 *should_reconnect.lock().await = false;
361 let _ = event_tx.send(MercuryEvent::Disconnected("max-attempts-exceeded".into()));
362 return;
363 }
364 *attempts += 1;
365 let attempt = *attempts;
366 let delay_secs = (2.0f64.powi(attempt as i32 - 1)).min(backoff_max.as_secs_f64());
367 drop(attempts);
368
369 info!("Reconnecting (attempt {attempt}/{max_reconnect}) in {delay_secs}s");
370 let _ = event_tx.send(MercuryEvent::Reconnecting(attempt));
371
372 time::sleep(Duration::from_secs_f64(delay_secs)).await;
373
374 let _ = event_tx.send(MercuryEvent::Disconnected("reconnect-needed".into()));
376 } else {
377 let _ = event_tx.send(MercuryEvent::Disconnected("manual".into()));
378 }
379 }
380
381 pub async fn disconnect(&self) {
383 info!("Disconnecting from Mercury");
384 *self.should_reconnect.lock().await = false;
385 *self.connected.lock().await = false;
386 self.shutdown.notify_waiters();
387 let _ = self.event_tx.send(MercuryEvent::Disconnected("client".into()));
388 }
389
390 pub async fn connected(&self) -> bool {
392 *self.connected.lock().await
393 }
394
395 pub async fn current_reconnect_attempts(&self) -> u32 {
397 *self.reconnect_attempts.lock().await
398 }
399}