1use std::path::Path;
32use std::sync::Arc;
33use std::time::Duration;
34
35use futures_util::StreamExt;
36use rustls::ClientConfig;
37use rustls_pki_types::CertificateDer;
38use serde::{Deserialize, Serialize};
39use tokio::sync::broadcast;
40use tokio_tungstenite::Connector;
41use tokio_tungstenite::tungstenite::{self, ClientRequestBuilder};
42use tokio_util::sync::CancellationToken;
43use url::Url;
44
45use crate::error::Error;
46use crate::transport::TlsMode;
47
/// Capacity of the broadcast channel created in `WebSocketHandle::connect`
/// that fans parsed events out to every subscriber.
///
/// With tokio's broadcast channel, a receiver that falls more than this many
/// events behind skips the oldest ones (it observes `RecvError::Lagged`).
const EVENT_CHANNEL_CAPACITY: usize = 1024;
51
/// A single event received over the controller WebSocket.
///
/// For `"events"` messages each payload entry is deserialized directly into
/// this type; if that fails, or for any other message type, the event is
/// assembled field-by-field by `event_from_raw`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UnifiEvent {
    /// Event key (e.g. `EVT_WU_Connected`). When built by `event_from_raw`
    /// and the payload has no string `"key"`, this is the envelope's message
    /// type instead (e.g. `device:sync`).
    pub key: String,

    /// Subsystem the event belongs to (e.g. `wlan`, `lan`); `event_from_raw`
    /// uses `"unknown"` when absent.
    pub subsystem: String,

    /// Identifier of the site that produced the event; `event_from_raw` uses
    /// an empty string when absent.
    pub site_id: String,

    /// Human-readable description, if the payload carried one.
    #[serde(default)]
    pub message: Option<String>,

    /// Timestamp exactly as sent by the controller; not parsed here.
    #[serde(default)]
    pub datetime: Option<String>,

    /// Remaining payload fields. When deserialized from JSON, `flatten`
    /// collects every key not matched by the named fields above.
    #[serde(flatten)]
    pub extra: serde_json::Value,
}
81
/// Reconnection policy for the WebSocket background task.
#[derive(Debug, Clone)]
pub struct ReconnectConfig {
    /// Delay before the first reconnect; doubled for each further attempt.
    pub initial_delay: Duration,

    /// Ceiling applied to the doubled delay before jitter is added.
    pub max_delay: Duration,

    /// Once this many failed reconnect attempts have been made, the next
    /// failure stops the task. `None` retries forever.
    pub max_retries: Option<u32>,
}

impl Default for ReconnectConfig {
    /// One-second initial delay, 30-second ceiling, unlimited retries.
    fn default() -> Self {
        const INITIAL: Duration = Duration::from_secs(1);
        const CEILING: Duration = Duration::from_secs(30);

        let max_retries = None;
        Self {
            initial_delay: INITIAL,
            max_delay: CEILING,
            max_retries,
        }
    }
}
107
108pub struct WebSocketHandle {
115 event_rx: broadcast::Receiver<Arc<UnifiEvent>>,
116 cancel: CancellationToken,
117}
118
119impl WebSocketHandle {
120 pub fn connect(
126 ws_url: Url,
127 reconnect: ReconnectConfig,
128 cancel: CancellationToken,
129 cookie: Option<String>,
130 tls_mode: TlsMode,
131 ) -> Result<Self, Error> {
132 let (event_tx, event_rx) = broadcast::channel(EVENT_CHANNEL_CAPACITY);
133
134 let task_cancel = cancel.clone();
135 tokio::spawn(async move {
136 ws_loop(ws_url, event_tx, reconnect, task_cancel, cookie, tls_mode).await;
137 });
138
139 Ok(Self { event_rx, cancel })
140 }
141
142 pub fn subscribe(&self) -> broadcast::Receiver<Arc<UnifiEvent>> {
147 self.event_rx.resubscribe()
148 }
149
150 pub fn shutdown(&self) {
152 self.cancel.cancel();
153 }
154}
155
156async fn ws_loop(
160 ws_url: Url,
161 event_tx: broadcast::Sender<Arc<UnifiEvent>>,
162 reconnect: ReconnectConfig,
163 cancel: CancellationToken,
164 cookie: Option<String>,
165 tls_mode: TlsMode,
166) {
167 let mut attempt: u32 = 0;
168
169 loop {
170 tokio::select! {
171 biased;
172 () = cancel.cancelled() => break,
173 result = connect_and_read(&ws_url, &event_tx, &cancel, cookie.as_deref(), &tls_mode) => {
174 match result {
175 Ok(()) => {
178 tracing::info!("WebSocket disconnected cleanly, reconnecting");
179 attempt = 0;
180 }
181 Err(e) => {
182 tracing::warn!(error = %e, attempt, "WebSocket error");
183
184 if let Some(max) = reconnect.max_retries {
185 if attempt >= max {
186 tracing::error!(
187 max_retries = max,
188 "WebSocket reconnection limit reached, giving up"
189 );
190 break;
191 }
192 }
193
194 let delay = calculate_backoff(attempt, &reconnect);
195 let delay_ms = u64::try_from(delay.as_millis()).unwrap_or(u64::MAX);
196 tracing::info!(
197 delay_ms,
198 attempt,
199 "Waiting before reconnect"
200 );
201
202 tokio::select! {
203 biased;
204 () = cancel.cancelled() => break,
205 () = tokio::time::sleep(delay) => {}
206 }
207
208 attempt += 1;
209 }
210 }
211 }
212 }
213 }
214
215 #[allow(unreachable_code)]
218 {
219 tracing::debug!("WebSocket loop exiting");
220 }
221}
222
223async fn connect_and_read(
230 url: &Url,
231 event_tx: &broadcast::Sender<Arc<UnifiEvent>>,
232 cancel: &CancellationToken,
233 cookie: Option<&str>,
234 tls_mode: &TlsMode,
235) -> Result<(), Error> {
236 tracing::info!(url = %url, "Connecting to WebSocket");
237
238 let uri: tungstenite::http::Uri = url
239 .as_str()
240 .parse()
241 .map_err(|e: tungstenite::http::uri::InvalidUri| Error::WebSocketConnect(e.to_string()))?;
242
243 let mut request = ClientRequestBuilder::new(uri);
244 if let Some(cookie_val) = cookie {
245 request = request.with_header("Cookie", cookie_val);
246 }
247
248 let connector = build_tls_connector(tls_mode)?;
249
250 let (ws_stream, _response) =
251 tokio_tungstenite::connect_async_tls_with_config(request, None, false, connector)
252 .await
253 .map_err(|e| Error::WebSocketConnect(e.to_string()))?;
254
255 tracing::info!("WebSocket connected");
256
257 let (_write, mut read) = ws_stream.split();
258
259 loop {
260 tokio::select! {
261 biased;
262 () = cancel.cancelled() => return Ok(()),
263 frame = read.next() => {
264 match frame {
265 Some(Ok(tungstenite::Message::Text(text))) => {
266 parse_and_broadcast(&text, event_tx);
267 }
268 Some(Ok(tungstenite::Message::Ping(_))) => {
269 tracing::trace!("WebSocket ping");
271 }
272 Some(Ok(tungstenite::Message::Close(frame))) => {
273 if let Some(ref cf) = frame {
274 tracing::info!(
275 code = %cf.code,
276 reason = %cf.reason,
277 "WebSocket close frame received"
278 );
279 } else {
280 tracing::info!("WebSocket close frame received (no payload)");
281 }
282 return Ok(());
283 }
284 Some(Err(e)) => {
285 return Err(Error::WebSocketConnect(e.to_string()));
286 }
287 None => {
288 tracing::info!("WebSocket stream ended");
290 return Ok(());
291 }
292 _ => {
293 }
295 }
296 }
297 }
298 }
299}
300
301#[derive(Debug, Deserialize)]
307struct WsEnvelope {
308 #[allow(dead_code)]
309 meta: WsMeta,
310 data: Vec<serde_json::Value>,
311}
312
313#[derive(Debug, Deserialize)]
314struct WsMeta {
315 #[allow(dead_code)]
316 rc: String,
317 #[serde(default)]
318 message: Option<String>,
319}
320
321fn parse_and_broadcast(text: &str, event_tx: &broadcast::Sender<Arc<UnifiEvent>>) {
323 let envelope: WsEnvelope = match serde_json::from_str(text) {
324 Ok(e) => e,
325 Err(e) => {
326 tracing::debug!(error = %e, "Failed to parse WebSocket envelope");
327 return;
328 }
329 };
330
331 let msg_type = envelope.meta.message.as_deref().unwrap_or("");
332
333 for data in envelope.data {
337 let event = match msg_type {
338 "events" => match serde_json::from_value::<UnifiEvent>(data.clone()) {
339 Ok(evt) => evt,
340 Err(e) => {
341 tracing::debug!(
342 error = %e,
343 msg_type,
344 "Could not deserialize event, constructing from raw data"
345 );
346 event_from_raw(msg_type, &data)
347 }
348 },
349 _ => event_from_raw(msg_type, &data),
351 };
352
353 let _ = event_tx.send(Arc::new(event));
355 }
356}
357
358fn event_from_raw(msg_type: &str, data: &serde_json::Value) -> UnifiEvent {
361 UnifiEvent {
362 key: data["key"].as_str().unwrap_or(msg_type).to_string(),
363 subsystem: data["subsystem"].as_str().unwrap_or("unknown").to_string(),
364 site_id: data["site_id"].as_str().unwrap_or("").to_string(),
365 message: data["msg"]
366 .as_str()
367 .or_else(|| data["message"].as_str())
368 .map(String::from),
369 datetime: data["datetime"].as_str().map(String::from),
370 extra: data.clone(),
371 }
372}
373
374fn build_tls_connector(tls_mode: &TlsMode) -> Result<Option<Connector>, Error> {
382 match tls_mode {
383 TlsMode::System => Ok(None),
384 TlsMode::CustomCa(path) => {
385 let root_store = load_root_store(path)?;
386 let tls_config = ClientConfig::builder()
387 .with_root_certificates(root_store)
388 .with_no_client_auth();
389 Ok(Some(Connector::Rustls(Arc::new(tls_config))))
390 }
391 TlsMode::DangerAcceptInvalid => {
392 let tls_config = ClientConfig::builder()
393 .dangerous()
394 .with_custom_certificate_verifier(Arc::new(NoVerifier))
395 .with_no_client_auth();
396 Ok(Some(Connector::Rustls(Arc::new(tls_config))))
397 }
398 }
399}
400
401fn load_root_store(path: &Path) -> Result<rustls::RootCertStore, Error> {
403 use rustls_pki_types::pem::PemObject;
404
405 let mut root_store = rustls::RootCertStore::empty();
406 for cert in CertificateDer::pem_file_iter(path)
407 .map_err(|e| Error::Tls(format!("failed to read CA cert: {e}")))?
408 {
409 let cert = cert.map_err(|e| Error::Tls(format!("invalid PEM in CA file: {e}")))?;
410 root_store
411 .add(cert)
412 .map_err(|e| Error::Tls(format!("invalid CA cert: {e}")))?;
413 }
414 Ok(root_store)
415}
416
417#[derive(Debug)]
422struct NoVerifier;
423
424impl rustls::client::danger::ServerCertVerifier for NoVerifier {
425 fn verify_server_cert(
426 &self,
427 _end_entity: &CertificateDer<'_>,
428 _intermediates: &[CertificateDer<'_>],
429 _server_name: &rustls::pki_types::ServerName<'_>,
430 _ocsp_response: &[u8],
431 _now: rustls::pki_types::UnixTime,
432 ) -> Result<rustls::client::danger::ServerCertVerified, rustls::Error> {
433 Ok(rustls::client::danger::ServerCertVerified::assertion())
434 }
435
436 fn verify_tls12_signature(
437 &self,
438 _message: &[u8],
439 _cert: &CertificateDer<'_>,
440 _dss: &rustls::DigitallySignedStruct,
441 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
442 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
443 }
444
445 fn verify_tls13_signature(
446 &self,
447 _message: &[u8],
448 _cert: &CertificateDer<'_>,
449 _dss: &rustls::DigitallySignedStruct,
450 ) -> Result<rustls::client::danger::HandshakeSignatureValid, rustls::Error> {
451 Ok(rustls::client::danger::HandshakeSignatureValid::assertion())
452 }
453
454 fn supported_verify_schemes(&self) -> Vec<rustls::SignatureScheme> {
455 rustls::crypto::ring::default_provider()
456 .signature_verification_algorithms
457 .supported_schemes()
458 }
459}
460
461fn calculate_backoff(attempt: u32, config: &ReconnectConfig) -> Duration {
469 let base = config.initial_delay.as_secs_f64()
470 * 2.0_f64.powi(i32::try_from(attempt).unwrap_or(i32::MAX));
471 let capped = base.min(config.max_delay.as_secs_f64());
472
473 let jitter_factor = 1.0 + 0.25 * ((f64::from(attempt) * 7.3).sin());
476 let with_jitter = (capped * jitter_factor).max(0.0);
477
478 Duration::from_secs_f64(with_jitter)
479}
480
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Sender/receiver pair with a small buffer for the broadcast tests.
    fn test_channel() -> (
        broadcast::Sender<Arc<UnifiEvent>>,
        broadcast::Receiver<Arc<UnifiEvent>>,
    ) {
        broadcast::channel(16)
    }

    #[test]
    fn default_reconnect_config() {
        let ReconnectConfig {
            initial_delay,
            max_delay,
            max_retries,
        } = ReconnectConfig::default();

        assert_eq!(initial_delay, Duration::from_secs(1));
        assert_eq!(max_delay, Duration::from_secs(30));
        assert!(max_retries.is_none());
    }

    #[test]
    fn backoff_increases_exponentially() {
        let config = ReconnectConfig::default();
        let [d0, d1, d2] = [0, 1, 2].map(|attempt| calculate_backoff(attempt, &config));

        assert!(d1 > d0, "d1 ({d1:?}) should be greater than d0 ({d0:?})");
        assert!(d2 > d1, "d2 ({d2:?}) should be greater than d1 ({d1:?})");
    }

    #[test]
    fn backoff_caps_at_max_delay() {
        let config = ReconnectConfig {
            initial_delay: Duration::from_secs(1),
            max_delay: Duration::from_secs(10),
            max_retries: None,
        };
        // Jitter may push the delay up to 25% above `max_delay`, so allow headroom.
        let ceiling = Duration::from_secs(13);

        let d10 = calculate_backoff(10, &config);
        assert!(
            d10 <= ceiling,
            "delay at attempt 10 ({d10:?}) should be capped near max_delay"
        );
    }

    #[test]
    fn parse_event_from_raw_json() {
        let raw = serde_json::json!({
            "key": "EVT_WU_Connected",
            "subsystem": "wlan",
            "site_id": "abc123",
            "msg": "User[aa:bb:cc:dd:ee:ff] connected",
            "datetime": "2026-02-10T12:00:00Z",
            "user": "aa:bb:cc:dd:ee:ff",
            "ssid": "MyNetwork"
        });

        let evt = event_from_raw("events", &raw);

        assert_eq!(evt.key, "EVT_WU_Connected");
        assert_eq!(evt.subsystem, "wlan");
        assert_eq!(evt.site_id, "abc123");
        assert_eq!(
            evt.message.as_deref(),
            Some("User[aa:bb:cc:dd:ee:ff] connected")
        );
        assert_eq!(evt.datetime.as_deref(), Some("2026-02-10T12:00:00Z"));
    }

    #[test]
    fn parse_sync_event_from_raw_json() {
        let raw = serde_json::json!({
            "mac": "aa:bb:cc:dd:ee:ff",
            "state": 1,
            "site_id": "site1"
        });

        let evt = event_from_raw("device:sync", &raw);

        // With no "key"/"subsystem" in the payload, the fallbacks apply.
        assert_eq!(evt.key, "device:sync");
        assert_eq!(evt.subsystem, "unknown");
        assert_eq!(evt.site_id, "site1");
    }

    #[test]
    fn deserialize_unifi_event() {
        let json = r#"{
            "key": "EVT_SW_Disconnected",
            "subsystem": "lan",
            "site_id": "default",
            "message": "Switch lost contact",
            "datetime": "2026-02-10T13:00:00Z",
            "sw": "aa:bb:cc:dd:ee:ff",
            "port": 4
        }"#;

        let evt: UnifiEvent = serde_json::from_str(json).unwrap();

        assert_eq!(evt.key, "EVT_SW_Disconnected");
        assert_eq!(evt.subsystem, "lan");
        assert_eq!(evt.site_id, "default");
        assert_eq!(evt.message.as_deref(), Some("Switch lost contact"));
        // Unknown keys land in the flattened `extra` map.
        assert_eq!(evt.extra["sw"], "aa:bb:cc:dd:ee:ff");
        assert_eq!(evt.extra["port"], 4);
    }

    #[test]
    fn parse_and_broadcast_events_message() {
        let (tx, mut rx) = test_channel();
        let frame = serde_json::json!({
            "meta": { "rc": "ok", "message": "events" },
            "data": [{
                "key": "EVT_WU_Connected",
                "subsystem": "wlan",
                "site_id": "default",
                "msg": "Client connected",
                "user": "aa:bb:cc:dd:ee:ff"
            }]
        })
        .to_string();

        parse_and_broadcast(&frame, &tx);

        let evt = rx.try_recv().unwrap();
        assert_eq!(evt.key, "EVT_WU_Connected");
        assert_eq!(evt.subsystem, "wlan");
    }

    #[test]
    fn parse_and_broadcast_sync_message() {
        let (tx, mut rx) = test_channel();
        let frame = serde_json::json!({
            "meta": { "rc": "ok", "message": "device:sync" },
            "data": [{
                "mac": "aa:bb:cc:dd:ee:ff",
                "state": 1,
                "site_id": "site1"
            }]
        })
        .to_string();

        parse_and_broadcast(&frame, &tx);

        let evt = rx.try_recv().unwrap();
        assert_eq!(evt.key, "device:sync");
        assert_eq!(evt.site_id, "site1");
    }

    #[test]
    fn parse_and_broadcast_malformed_json() {
        let (tx, mut rx) = test_channel();

        parse_and_broadcast("not json at all", &tx);

        // Nothing should have been broadcast.
        assert!(rx.try_recv().is_err());
    }
}