1use crate::errors::WebexError;
4use crate::types::MercuryActivity;
5use futures_util::{SinkExt, StreamExt};
6use serde_json::Value;
7use std::sync::Arc;
8use std::time::Duration;
9use tokio::sync::{mpsc, Mutex, Notify};
10use tokio::time;
11use tokio_tungstenite::{connect_async, tungstenite::Message};
12use tracing::{debug, error, info};
13use url::Url;
14use uuid::Uuid;
15
/// Notifications published by [`MercurySocket`] on the channel returned by
/// [`MercurySocket::take_event_rx`].
#[derive(Debug, Clone)]
pub enum MercuryEvent {
    /// Connection-established notification.
    /// NOTE(review): confirm which component emits this variant.
    Connected,
    /// The socket is no longer usable. The payload is a short reason tag:
    /// `"client"`, `"manual"`, `"auth-failed"`, `"permanent-failure"`,
    /// `"max-attempts-exceeded"` or `"reconnect-needed"` (presumably a cue for
    /// the owner to call `connect` again — confirm with callers).
    Disconnected(String),
    /// A reconnect cycle has started; carries the 1-based attempt number.
    Reconnecting(u32),
    /// A successfully parsed `conversation.activity` payload.
    Activity(MercuryActivity),
    /// The `data` object of a message whose `eventType` starts with `encryption.`.
    KmsResponse(Value),
    /// Human-readable error text (WebSocket errors, authorization or permanent
    /// close failures).
    Error(String),
}
26
/// Client for the Webex Mercury WebSocket.
///
/// Mutable state lives behind `Arc<Mutex<_>>` so the tasks spawned for each
/// connection can share it with this value.
pub struct MercurySocket {
    /// Delay between keep-alive `ping` frames.
    ping_interval: Duration,
    /// Pong timeout supplied to `new`.
    pong_timeout: Duration,
    /// Upper bound for the exponential reconnect delay.
    reconnect_backoff_max: Duration,
    /// Reconnect cycles allowed before giving up.
    max_reconnect_attempts: u32,

    /// Access token sent as `Bearer <token>` in the authorization frame.
    token: Arc<Mutex<String>>,
    /// URL given to `connect`; query parameters are appended before dialing.
    base_url: Arc<Mutex<String>>,
    /// `true` once the server has reported readiness; `false` after a close,
    /// a socket error or `disconnect`.
    connected: Arc<Mutex<bool>>,
    /// Whether close handling may attempt to reconnect; cleared by `disconnect`
    /// and by fatal close codes.
    should_reconnect: Arc<Mutex<bool>>,
    /// Reconnect cycles since the last `connect` call.
    reconnect_attempts: Arc<Mutex<u32>>,
    /// Woken by `disconnect` so waiting background tasks can stop.
    shutdown: Arc<Notify>,

    /// Sender cloned into background tasks to publish [`MercuryEvent`]s.
    event_tx: mpsc::UnboundedSender<MercuryEvent>,
    /// Receiver handed out once by `take_event_rx`.
    event_rx: Arc<Mutex<Option<mpsc::UnboundedReceiver<MercuryEvent>>>>,
}
44
45impl MercurySocket {
46 pub fn new(
47 ping_interval: Duration,
48 pong_timeout: Duration,
49 reconnect_backoff_max: Duration,
50 max_reconnect_attempts: u32,
51 ) -> Self {
52 let (event_tx, event_rx) = mpsc::unbounded_channel();
53
54 Self {
55 ping_interval,
56 pong_timeout,
57 reconnect_backoff_max,
58 max_reconnect_attempts,
59 token: Arc::new(Mutex::new(String::new())),
60 base_url: Arc::new(Mutex::new(String::new())),
61 connected: Arc::new(Mutex::new(false)),
62 should_reconnect: Arc::new(Mutex::new(true)),
63 reconnect_attempts: Arc::new(Mutex::new(0)),
64 shutdown: Arc::new(Notify::new()),
65 event_tx,
66 event_rx: Arc::new(Mutex::new(Some(event_rx))),
67 }
68 }
69
70 pub async fn take_event_rx(&self) -> Option<mpsc::UnboundedReceiver<MercuryEvent>> {
72 self.event_rx.lock().await.take()
73 }
74
75 pub async fn connect(&self, ws_url: &str, token: &str) -> Result<(), WebexError> {
77 *self.token.lock().await = token.to_string();
78 *self.base_url.lock().await = ws_url.to_string();
79 *self.should_reconnect.lock().await = true;
80 *self.reconnect_attempts.lock().await = 0;
81 self.connect_internal().await
82 }
83
84 async fn connect_internal(&self) -> Result<(), WebexError> {
85 let base_url = self.base_url.lock().await.clone();
86 let prepared_url = Self::prepare_url(&base_url)?;
87 debug!("Connecting to Mercury at {prepared_url}");
88
89 *self.connected.lock().await = false;
90
91 let (ws_stream, _) = connect_async(&prepared_url)
92 .await
93 .map_err(|e| WebexError::mercury_connection(format!("Failed to connect: {e}"), None))?;
94
95 let (mut write, mut read) = ws_stream.split();
96
97 let token = self.token.lock().await.clone();
99 let auth_msg = serde_json::json!({
100 "id": Uuid::new_v4().to_string(),
101 "type": "authorization",
102 "data": { "token": format!("Bearer {}", token) }
103 });
104 write
105 .send(Message::Text(auth_msg.to_string().into()))
106 .await
107 .map_err(|e| WebexError::mercury_connection(format!("Failed to send auth: {e}"), None))?;
108
109 let _ready_timeout = time::timeout(Duration::from_secs(30), async {
111 while let Some(msg) = read.next().await {
112 match msg {
113 Ok(Message::Text(text)) => {
114 let text_str: &str = &text;
115 if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
116 if Self::is_connection_ready(&parsed) {
117 return Ok(parsed);
118 }
119 }
120 }
121 Err(e) => {
122 return Err(WebexError::mercury_connection(
123 format!("WebSocket error during setup: {e}"),
124 None,
125 ));
126 }
127 _ => {}
128 }
129 }
130 Err(WebexError::mercury_connection("WebSocket closed during setup", None))
131 })
132 .await
133 .map_err(|_| WebexError::mercury_connection("Mercury connection timeout", None))??;
134
135 debug!("Mercury connection ready");
136 *self.connected.lock().await = true;
137
138 let event_tx = self.event_tx.clone();
140 let connected = self.connected.clone();
141 let should_reconnect = self.should_reconnect.clone();
142 let reconnect_attempts = self.reconnect_attempts.clone();
143 let max_reconnect = self.max_reconnect_attempts;
144 let backoff_max = self.reconnect_backoff_max;
145 let ping_interval = self.ping_interval;
146 let _pong_timeout = self.pong_timeout;
147 let shutdown = self.shutdown.clone();
148 let base_url_clone = self.base_url.clone();
149 let token_clone = self.token.clone();
150
151 let write = Arc::new(Mutex::new(write));
152 let write_clone = write.clone();
153
154 let ping_write = write.clone();
156 let ping_connected = connected.clone();
157 let ping_shutdown = shutdown.clone();
158 let _ping_event_tx = event_tx.clone();
159 tokio::spawn(async move {
160 let mut interval = time::interval(ping_interval);
161 interval.tick().await; loop {
164 tokio::select! {
165 _ = interval.tick() => {
166 if !*ping_connected.lock().await {
167 break;
168 }
169 let pong_id = Uuid::new_v4().to_string();
170 let ping_msg = serde_json::json!({
171 "id": pong_id,
172 "type": "ping"
173 });
174 let mut w = ping_write.lock().await;
175 if w.send(Message::Text(ping_msg.to_string().into())).await.is_err() {
176 break;
177 }
178 debug!("Sent ping: {pong_id}");
179 drop(w);
180
181 }
183 _ = ping_shutdown.notified() => break,
184 }
185 }
186 });
187
188 tokio::spawn(async move {
190 while let Some(msg) = read.next().await {
191 match msg {
192 Ok(Message::Text(text)) => {
193 let text_str: &str = &text;
194 debug!("WS message received ({} bytes)", text_str.len());
195 if let Ok(parsed) = serde_json::from_str::<Value>(text_str) {
196 Self::handle_message_static(&parsed, &event_tx, &write_clone).await;
197 } else {
198 debug!("Failed to parse WS message as JSON");
199 }
200 }
201 Ok(Message::Close(frame)) => {
202 let code = frame.as_ref().map(|f| f.code.into()).unwrap_or(1000u16);
203 let reason = frame.as_ref().map(|f| f.reason.to_string()).unwrap_or_default();
204 Self::handle_close_static(
205 code,
206 &reason,
207 &connected,
208 &should_reconnect,
209 &reconnect_attempts,
210 max_reconnect,
211 backoff_max,
212 &base_url_clone,
213 &token_clone,
214 &event_tx,
215 )
216 .await;
217 break;
218 }
219 Err(e) => {
220 error!("WebSocket error: {e}");
221 let _ = event_tx.send(MercuryEvent::Error(e.to_string()));
222 *connected.lock().await = false;
223 break;
224 }
225 _ => {}
226 }
227 }
228
229 if *should_reconnect.lock().await && *connected.lock().await == false {
231 }
233 });
234
235 Ok(())
236 }
237
238 fn prepare_url(base_url: &str) -> Result<String, WebexError> {
239 let mut url = Url::parse(base_url)
240 .map_err(|e| WebexError::mercury_connection(format!("Invalid URL: {e}"), None))?;
241 url.query_pairs_mut()
242 .append_pair("outboundWireFormat", "text")
243 .append_pair("bufferStates", "true")
244 .append_pair("aliasHttpStatus", "true")
245 .append_pair(
246 "clientTimestamp",
247 &std::time::SystemTime::now()
248 .duration_since(std::time::UNIX_EPOCH)
249 .unwrap_or_default()
250 .as_millis()
251 .to_string(),
252 );
253 Ok(url.to_string())
254 }
255
256 fn is_connection_ready(message: &Value) -> bool {
257 let event_type = message
258 .get("data")
259 .and_then(|d| d.get("eventType"))
260 .and_then(|e| e.as_str())
261 .unwrap_or("");
262 event_type.contains("mercury.buffer_state") || event_type.contains("mercury.registration_status")
263 }
264
265 async fn handle_message_static(
266 message: &Value,
267 event_tx: &mpsc::UnboundedSender<MercuryEvent>,
268 write: &Arc<Mutex<futures_util::stream::SplitSink<tokio_tungstenite::WebSocketStream<tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>>, Message>>>,
269 ) {
270 let msg_type = message.get("type").and_then(|t| t.as_str()).unwrap_or("");
271
272 match msg_type {
273 "pong" => {
274 let id = message.get("id").and_then(|i| i.as_str()).unwrap_or("");
275 debug!("Received pong: {id}");
276 }
277 "shutdown" => {
278 info!("Received shutdown message from Mercury");
279 }
281 _ => {
282 if let Some(data) = message.get("data") {
283 if let Some(event_type) = data.get("eventType").and_then(|e| e.as_str()) {
284 debug!("Mercury eventType: {event_type}");
285
286 if let Some(msg_id) = message.get("id").and_then(|i| i.as_str()) {
288 let ack = serde_json::json!({"messageId": msg_id, "type": "ack"});
289 let mut w = write.lock().await;
290 let _ = w.send(Message::Text(ack.to_string().into())).await;
291 }
292
293 if event_type.starts_with("encryption.") {
294 debug!("Emitting kms:response for eventType: {event_type}");
295 let _ = event_tx.send(MercuryEvent::KmsResponse(data.clone()));
296 } else if event_type == "conversation.activity" {
297 if let Some(activity_raw) = data.get("activity") {
298 match serde_json::from_value::<MercuryActivity>(activity_raw.clone()) {
299 Ok(activity) => {
300 debug!("Emitting activity: {}", activity.id);
301 let _ = event_tx.send(MercuryEvent::Activity(activity));
302 }
303 Err(e) => {
304 error!("Failed to parse activity: {e}");
305 debug!("Raw activity keys: {:?}", activity_raw.as_object().map(|o| o.keys().collect::<Vec<_>>()));
306 }
307 }
308 }
309 }
310 } else {
311 debug!("Unhandled Mercury message, type={msg_type:?}, keys={:?}", message.as_object().map(|o| o.keys().collect::<Vec<_>>()));
312 }
313 } else {
314 debug!("Unhandled Mercury message, type={msg_type:?}, no data field");
315 }
316 }
317 }
318 }
319
320 #[allow(clippy::too_many_arguments)]
321 async fn handle_close_static(
322 code: u16,
323 reason: &str,
324 connected: &Arc<Mutex<bool>>,
325 should_reconnect: &Arc<Mutex<bool>>,
326 reconnect_attempts: &Arc<Mutex<u32>>,
327 max_reconnect: u32,
328 backoff_max: Duration,
329 _base_url: &Arc<Mutex<String>>,
330 _token: &Arc<Mutex<String>>,
331 event_tx: &mpsc::UnboundedSender<MercuryEvent>,
332 ) {
333 info!("WebSocket closed with code {code}: {reason}");
334 *connected.lock().await = false;
335
336 if code == 4401 {
337 error!("Mercury authorization failed");
338 *should_reconnect.lock().await = false;
339 let _ = event_tx.send(MercuryEvent::Error("Mercury authorization failed".into()));
340 let _ = event_tx.send(MercuryEvent::Disconnected("auth-failed".into()));
341 return;
342 }
343
344 if code == 4400 || code == 4403 {
345 error!("Mercury permanent failure (code {code})");
346 *should_reconnect.lock().await = false;
347 let _ = event_tx.send(MercuryEvent::Error(format!("Mercury permanent failure (code {code})")));
348 let _ = event_tx.send(MercuryEvent::Disconnected("permanent-failure".into()));
349 return;
350 }
351
352 if *should_reconnect.lock().await {
353 let mut attempts = reconnect_attempts.lock().await;
354 if *attempts >= max_reconnect {
355 error!("Max reconnection attempts ({max_reconnect}) exceeded");
356 *should_reconnect.lock().await = false;
357 let _ = event_tx.send(MercuryEvent::Disconnected("max-attempts-exceeded".into()));
358 return;
359 }
360 *attempts += 1;
361 let attempt = *attempts;
362 let delay_secs = (2.0f64.powi(attempt as i32 - 1)).min(backoff_max.as_secs_f64());
363 drop(attempts);
364
365 info!("Reconnecting (attempt {attempt}/{max_reconnect}) in {delay_secs}s");
366 let _ = event_tx.send(MercuryEvent::Reconnecting(attempt));
367
368 time::sleep(Duration::from_secs_f64(delay_secs)).await;
369
370 let _ = event_tx.send(MercuryEvent::Disconnected("reconnect-needed".into()));
372 } else {
373 let _ = event_tx.send(MercuryEvent::Disconnected("manual".into()));
374 }
375 }
376
377 pub async fn disconnect(&self) {
379 info!("Disconnecting from Mercury");
380 *self.should_reconnect.lock().await = false;
381 *self.connected.lock().await = false;
382 self.shutdown.notify_waiters();
383 let _ = self.event_tx.send(MercuryEvent::Disconnected("client".into()));
384 }
385
386 pub async fn connected(&self) -> bool {
388 *self.connected.lock().await
389 }
390
391 pub async fn current_reconnect_attempts(&self) -> u32 {
393 *self.reconnect_attempts.lock().await
394 }
395}