1use crate::error::{OverlayError, Result};
8#[cfg(feature = "nat")]
9use crate::nat::ConnectionType;
10use serde::{Deserialize, Serialize};
11use std::collections::HashMap;
12use std::net::{IpAddr, SocketAddr};
13use std::path::{Path, PathBuf};
14use std::time::{Duration, Instant};
15#[cfg(unix)]
16use tokio::io::{AsyncReadExt, AsyncWriteExt};
17use tokio::process::Command;
18use tokio::sync::RwLock;
19use tracing::{debug, info, warn};
20
/// Default interval between health-check rounds, used by
/// [`OverlayHealthChecker::default_for_interface`].
pub const DEFAULT_CHECK_INTERVAL: Duration = Duration::from_secs(30);

/// Maximum age, in seconds, of a peer's last handshake for the peer to be considered
/// healthy. This is the default handshake timeout of [`OverlayHealthChecker::new`].
///
/// NOTE(review): WireGuard only re-handshakes while packets are being sent, so an idle
/// peer without persistent keepalive may be reported unhealthy — confirm this is intended.
pub const HANDSHAKE_TIMEOUT_SECS: u64 = 180;

/// Timeout, in seconds, used by [`OverlayHealthChecker::ping_peer`] and
/// [`OverlayHealthChecker::tcp_check`].
pub const PING_TIMEOUT_SECS: u64 = 5;
29
/// Health snapshot for a single overlay peer, as produced by
/// [`OverlayHealthChecker::check_all`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PeerStatus {
    /// Peer public key. `check_all` reports it base64-encoded. If the UAPI hex key
    /// cannot be decoded, the raw hex string is used instead.
    pub public_key: String,

    /// Overlay address of the peer: the first `/32` IPv4 or `/128` IPv6 entry in its
    /// allowed IPs, if any.
    pub overlay_ip: Option<IpAddr>,

    /// Whether the last handshake is younger than the checker's handshake timeout.
    pub healthy: bool,

    /// Seconds elapsed since the last handshake, or `None` if no handshake was recorded.
    pub last_handshake_secs: Option<u64>,

    /// Last ping round-trip time in milliseconds. `check_all` always leaves this `None`.
    /// Presumably it is filled in by callers of `ping_peer` — TODO confirm.
    pub last_ping_ms: Option<u64>,

    /// Failure counter. `check_all` sets it to `1` when the peer is unhealthy and to `0`
    /// otherwise.
    ///
    /// NOTE(review): the name suggests a running count of consecutive failures, but
    /// nothing visible here accumulates it — confirm the intended meaning.
    pub failure_count: u32,

    /// Unix timestamp, in seconds, of the check that produced this status.
    pub last_check: u64,

    /// NAT connection type of the peer. Uses `Default` when the field is absent from
    /// serialized input.
    #[cfg(feature = "nat")]
    #[serde(default)]
    pub connection_type: ConnectionType,
}
59
/// Aggregate health report for one overlay interface, returned by
/// [`OverlayHealthChecker::check_all`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OverlayHealth {
    /// Name of the overlay interface that was checked.
    pub interface: String,

    /// Number of peers reported by the interface.
    pub total_peers: usize,

    /// Number of peers considered healthy.
    pub healthy_peers: usize,

    /// Number of peers considered unhealthy (`total_peers - healthy_peers`).
    pub unhealthy_peers: usize,

    /// Per-peer details.
    pub peers: Vec<PeerStatus>,

    /// Unix timestamp, in seconds, at which the check was started.
    pub last_check: u64,
}
81
/// Raw per-peer statistics parsed from a UAPI `get` response.
#[derive(Debug, Clone)]
pub struct WgPeerStats {
    /// Peer public key, base64-encoded. The original hex is kept if it cannot be decoded.
    pub public_key: String,
    /// Remote endpoint as reported by UAPI (e.g. `192.168.1.5:51820`), if any.
    pub endpoint: Option<String>,
    /// Allowed-IP entries in CIDR notation, e.g. `10.200.0.2/32`.
    pub allowed_ips: Vec<String>,
    /// Last handshake time in seconds since the Unix epoch. This is `None` when UAPI
    /// reported `0` or an unparseable value.
    pub last_handshake_time: Option<u64>,
    /// Received byte counter (UAPI `rx_bytes`).
    pub transfer_rx: u64,
    /// Transmitted byte counter (UAPI `tx_bytes`).
    pub transfer_tx: u64,
}
92
/// Checks the peers of one overlay interface through its UAPI socket and caches the
/// last-known status of each peer.
pub struct OverlayHealthChecker {
    /// Interface name. The UAPI socket is `<uapi_sock_dir>/<interface>.sock`.
    interface: String,

    /// Delay between rounds of the `run` loop.
    check_interval: Duration,

    /// Maximum handshake age for a peer to be considered healthy.
    handshake_timeout: Duration,

    /// Directory containing the UAPI socket (`/var/run/wireguard` unless overridden).
    uapi_sock_dir: PathBuf,

    /// Last-known status keyed by peer public key. It is written by `run` and read by
    /// `get_cached_status`.
    peer_status: RwLock<HashMap<String, PeerStatus>>,
}
116
117impl OverlayHealthChecker {
118 #[must_use]
124 pub fn new(interface: &str, check_interval: Duration) -> Self {
125 Self {
126 interface: interface.to_string(),
127 check_interval,
128 handshake_timeout: Duration::from_secs(HANDSHAKE_TIMEOUT_SECS),
129 uapi_sock_dir: PathBuf::from("/var/run/wireguard"),
130 peer_status: RwLock::new(HashMap::new()),
131 }
132 }
133
134 #[must_use]
136 pub fn default_for_interface(interface: &str) -> Self {
137 Self::new(interface, DEFAULT_CHECK_INTERVAL)
138 }
139
140 #[must_use]
142 pub fn with_handshake_timeout(mut self, timeout: Duration) -> Self {
143 self.handshake_timeout = timeout;
144 self
145 }
146
147 #[must_use]
153 pub fn with_uapi_sock_dir(mut self, dir: impl Into<PathBuf>) -> Self {
154 self.uapi_sock_dir = dir.into();
155 self
156 }
157
158 #[must_use]
160 pub fn uapi_sock_dir(&self) -> &Path {
161 &self.uapi_sock_dir
162 }
163
164 pub async fn run<F>(&self, mut on_status_change: F)
168 where
169 F: FnMut(&str, bool) + Send + 'static,
170 {
171 info!(
172 interface = %self.interface,
173 interval_secs = self.check_interval.as_secs(),
174 "Starting health check loop"
175 );
176
177 loop {
178 match self.check_all().await {
179 Ok(health) => {
180 for peer in &health.peers {
181 let mut cache = self.peer_status.write().await;
183 let changed = cache
184 .get(&peer.public_key)
185 .is_none_or(|prev| prev.healthy != peer.healthy);
186
187 if changed {
188 on_status_change(&peer.public_key, peer.healthy);
189 }
190
191 cache.insert(peer.public_key.clone(), peer.clone());
192 }
193 }
194 Err(e) => {
195 warn!(error = %e, "Health check failed");
196 }
197 }
198
199 tokio::time::sleep(self.check_interval).await;
200 }
201 }
202
203 #[allow(clippy::similar_names)]
209 pub async fn check_all(&self) -> Result<OverlayHealth> {
210 let now = current_timestamp();
211 let stats = self.get_wg_stats().await?;
212
213 let mut peers = Vec::with_capacity(stats.len());
214 let mut healthy_count = 0;
215
216 for stat in stats {
217 let healthy = self.is_peer_healthy(&stat);
218
219 if healthy {
220 healthy_count += 1;
221 }
222
223 let overlay_ip: Option<IpAddr> = stat.allowed_ips.iter().find_map(|ip_str| {
225 if ip_str.ends_with("/32") {
226 ip_str
227 .trim_end_matches("/32")
228 .parse::<IpAddr>()
229 .ok()
230 .filter(IpAddr::is_ipv4)
231 } else if ip_str.ends_with("/128") {
232 ip_str
233 .trim_end_matches("/128")
234 .parse::<IpAddr>()
235 .ok()
236 .filter(IpAddr::is_ipv6)
237 } else {
238 None
239 }
240 });
241
242 let status = PeerStatus {
243 public_key: stat.public_key,
244 overlay_ip,
245 healthy,
246 last_handshake_secs: stat.last_handshake_time.map(|t| now.saturating_sub(t)),
247 last_ping_ms: None, failure_count: u32::from(!healthy),
249 last_check: now,
250 #[cfg(feature = "nat")]
251 connection_type: ConnectionType::default(),
252 };
253
254 peers.push(status);
255 }
256
257 let total = peers.len();
258 Ok(OverlayHealth {
259 interface: self.interface.clone(),
260 total_peers: total,
261 healthy_peers: healthy_count,
262 unhealthy_peers: total - healthy_count,
263 peers,
264 last_check: now,
265 })
266 }
267
268 fn is_peer_healthy(&self, stats: &WgPeerStats) -> bool {
270 let now = current_timestamp();
271 let timeout_secs = self.handshake_timeout.as_secs();
272
273 stats
274 .last_handshake_time
275 .is_some_and(|t| now.saturating_sub(t) < timeout_secs)
276 }
277
278 pub async fn ping_peer(&self, overlay_ip: IpAddr) -> Result<Duration> {
289 let start = Instant::now();
290
291 #[cfg(target_os = "macos")]
294 let timeout_arg = (PING_TIMEOUT_SECS * 1000).to_string();
295 #[cfg(not(target_os = "macos"))]
296 let timeout_arg = PING_TIMEOUT_SECS.to_string();
297
298 let mut cmd = match overlay_ip {
300 IpAddr::V4(_) => Command::new("ping"),
301 IpAddr::V6(_) => {
302 #[cfg(target_os = "macos")]
303 {
304 Command::new("ping6")
305 }
306 #[cfg(not(target_os = "macos"))]
307 {
308 let mut c = Command::new("ping");
309 c.arg("-6");
310 c
311 }
312 }
313 };
314
315 cmd.args([
316 "-c",
317 "1", "-W",
319 &timeout_arg,
320 &overlay_ip.to_string(),
321 ]);
322
323 let output =
324 tokio::time::timeout(Duration::from_secs(PING_TIMEOUT_SECS), cmd.output()).await;
325
326 match output {
327 Ok(Ok(result)) if result.status.success() => Ok(start.elapsed()),
328 Ok(Ok(_)) => Err(OverlayError::PeerUnreachable {
329 ip: overlay_ip,
330 reason: "ping failed".to_string(),
331 }),
332 Ok(Err(e)) => Err(OverlayError::PeerUnreachable {
333 ip: overlay_ip,
334 reason: e.to_string(),
335 }),
336 Err(_) => Err(OverlayError::PeerUnreachable {
337 ip: overlay_ip,
338 reason: "timeout".to_string(),
339 }),
340 }
341 }
342
343 pub async fn tcp_check(&self, overlay_ip: IpAddr, port: u16) -> Result<Duration> {
351 let start = Instant::now();
352
353 let addr = SocketAddr::new(overlay_ip, port);
354 let result = tokio::time::timeout(
355 Duration::from_secs(PING_TIMEOUT_SECS),
356 tokio::net::TcpStream::connect(addr),
357 )
358 .await;
359
360 match result {
361 Ok(Ok(_stream)) => Ok(start.elapsed()),
362 Ok(Err(e)) => Err(OverlayError::PeerUnreachable {
363 ip: overlay_ip,
364 reason: e.to_string(),
365 }),
366 Err(_) => Err(OverlayError::PeerUnreachable {
367 ip: overlay_ip,
368 reason: "timeout".to_string(),
369 }),
370 }
371 }
372
373 async fn get_wg_stats(&self) -> Result<Vec<WgPeerStats>> {
379 let sock_path = self
380 .uapi_sock_dir
381 .join(format!("{}.sock", self.interface))
382 .to_string_lossy()
383 .into_owned();
384
385 let response = match uapi_get_raw(&sock_path).await {
386 Ok(resp) => resp,
387 Err(e) => {
388 let msg = e.to_string();
389 if msg.contains("No such file")
391 || msg.contains("Connection refused")
392 || msg.contains("not found")
393 {
394 return Ok(Vec::new());
395 }
396 return Err(OverlayError::TransportCommand(msg));
397 }
398 };
399
400 let peers = parse_uapi_get_response(&response);
401
402 debug!(interface = %self.interface, peer_count = peers.len(), "Retrieved overlay peer stats via UAPI");
403 Ok(peers)
404 }
405
406 pub async fn get_cached_status(&self, public_key: &str) -> Option<PeerStatus> {
408 let cache = self.peer_status.read().await;
409 cache.get(public_key).cloned()
410 }
411
412 pub fn check_interval(&self) -> Duration {
414 self.check_interval
415 }
416
417 pub fn interface(&self) -> &str {
419 &self.interface
420 }
421}
422
/// Returns the current wall-clock time as whole seconds since the Unix epoch.
///
/// A system clock set before 1970 yields `0` instead of panicking.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
430
431#[cfg(unix)]
438async fn uapi_get_raw(sock_path: &str) -> std::result::Result<String, Box<dyn std::error::Error>> {
439 let mut stream = tokio::net::UnixStream::connect(sock_path).await?;
440 stream.write_all(b"get=1\n\n").await?;
441 stream.shutdown().await?;
442 let mut response = String::new();
443 stream.read_to_string(&mut response).await?;
444 Ok(response)
445}
446
/// Fallback for platforms without Unix domain sockets.
///
/// The UAPI socket cannot be reached here, so every call fails with an
/// [`std::io::ErrorKind::NotFound`] error.
#[cfg(not(unix))]
// Kept `async` so that the signature matches the Unix implementation and callers can
// `.await` either one.
#[allow(clippy::unused_async)]
async fn uapi_get_raw(_sock_path: &str) -> std::result::Result<String, Box<dyn std::error::Error>> {
    let unsupported = std::io::Error::new(
        std::io::ErrorKind::NotFound,
        "UAPI Unix socket not supported on this platform",
    );
    Err(unsupported.into())
}
461
462fn hex_key_to_base64(hex_key: &str) -> String {
464 use base64::{engine::general_purpose::STANDARD, Engine as _};
465 match hex::decode(hex_key) {
466 Ok(bytes) => STANDARD.encode(bytes),
467 Err(_) => hex_key.to_string(), }
469}
470
471fn parse_uapi_get_response(response: &str) -> Vec<WgPeerStats> {
479 let mut peers = Vec::new();
480 let mut current_peer: Option<WgPeerStats> = None;
481 let mut in_peer = false;
482
483 for line in response.lines() {
484 let line = line.trim();
485 if line.is_empty() || line.starts_with("errno=") {
486 continue;
487 }
488
489 let Some((key, value)) = line.split_once('=') else {
490 continue;
491 };
492
493 match key {
494 "public_key" => {
495 if let Some(peer) = current_peer.take() {
497 peers.push(peer);
498 }
499 in_peer = true;
500 current_peer = Some(WgPeerStats {
501 public_key: hex_key_to_base64(value),
502 endpoint: None,
503 allowed_ips: Vec::new(),
504 last_handshake_time: None,
505 transfer_rx: 0,
506 transfer_tx: 0,
507 });
508 }
509 "endpoint" if in_peer => {
510 if let Some(ref mut peer) = current_peer {
511 if value != "(none)" {
512 peer.endpoint = Some(value.to_string());
513 }
514 }
515 }
516 "allowed_ip" if in_peer => {
517 if let Some(ref mut peer) = current_peer {
518 peer.allowed_ips.push(value.to_string());
519 }
520 }
521 "last_handshake_time_sec" if in_peer => {
522 if let Some(ref mut peer) = current_peer {
523 if let Ok(t) = value.parse::<u64>() {
524 if t > 0 {
525 peer.last_handshake_time = Some(t);
526 }
527 }
528 }
529 }
530 "rx_bytes" if in_peer => {
531 if let Some(ref mut peer) = current_peer {
532 peer.transfer_rx = value.parse().unwrap_or(0);
533 }
534 }
535 "tx_bytes" if in_peer => {
536 if let Some(ref mut peer) = current_peer {
537 peer.transfer_tx = value.parse().unwrap_or(0);
538 }
539 }
540 _ => {}
542 }
543 }
544
545 if let Some(peer) = current_peer {
547 peers.push(peer);
548 }
549
550 peers
551}
552
#[cfg(test)]
mod tests {
    // Unit tests for the pure parts of the health checker:
    // - JSON round trips of the report types;
    // - handshake-age health evaluation;
    // - UAPI response parsing and hex-to-base64 key conversion.
    // None of these tests opens a socket or spawns a process.
    use super::*;

    /// A `PeerStatus` with an IPv4 overlay address survives a JSON round trip.
    #[test]
    fn test_peer_status_serialization_v4() {
        let status = PeerStatus {
            public_key: "test_key".to_string(),
            overlay_ip: Some("10.200.0.5".parse::<IpAddr>().unwrap()),
            healthy: true,
            last_handshake_secs: Some(10),
            last_ping_ms: Some(5),
            failure_count: 0,
            last_check: 1_234_567_890,
            #[cfg(feature = "nat")]
            connection_type: ConnectionType::default(),
        };

        let json = serde_json::to_string(&status).unwrap();
        let deserialized: PeerStatus = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.public_key, "test_key");
        assert!(deserialized.healthy);
        assert_eq!(
            deserialized.overlay_ip,
            Some("10.200.0.5".parse::<IpAddr>().unwrap())
        );
    }

    /// A `PeerStatus` with an IPv6 overlay address survives a JSON round trip.
    #[test]
    fn test_peer_status_serialization_v6() {
        let status = PeerStatus {
            public_key: "test_key_v6".to_string(),
            overlay_ip: Some("fd00::5".parse::<IpAddr>().unwrap()),
            healthy: true,
            last_handshake_secs: Some(10),
            last_ping_ms: Some(5),
            failure_count: 0,
            last_check: 1_234_567_890,
            #[cfg(feature = "nat")]
            connection_type: ConnectionType::default(),
        };

        let json = serde_json::to_string(&status).unwrap();
        let deserialized: PeerStatus = serde_json::from_str(&json).unwrap();

        assert_eq!(deserialized.public_key, "test_key_v6");
        assert!(deserialized.healthy);
        assert_eq!(
            deserialized.overlay_ip,
            Some("fd00::5".parse::<IpAddr>().unwrap())
        );
    }

    /// `OverlayHealth` serializes and includes the interface name.
    #[test]
    fn test_overlay_health_serialization() {
        let health = OverlayHealth {
            interface: "zl-overlay0".to_string(),
            total_peers: 2,
            healthy_peers: 1,
            unhealthy_peers: 1,
            peers: vec![],
            last_check: 1_234_567_890,
        };

        let json = serde_json::to_string_pretty(&health).unwrap();
        assert!(json.contains("zl-overlay0"));
    }

    /// The constructor stores the interface name and the check interval.
    #[test]
    fn test_health_checker_creation() {
        let checker = OverlayHealthChecker::new("wg0", Duration::from_secs(60));
        assert_eq!(checker.interface(), "wg0");
        assert_eq!(checker.check_interval(), Duration::from_secs(60));
    }

    /// A handshake 60 s ago is within the default handshake timeout.
    /// The 30 s argument is the check interval, not the handshake timeout.
    #[test]
    fn test_is_peer_healthy_recent_handshake() {
        let checker = OverlayHealthChecker::new("wg0", Duration::from_secs(30));

        let now = current_timestamp();
        let stats = WgPeerStats {
            public_key: "key".to_string(),
            endpoint: None,
            allowed_ips: vec![],
            // 60 s ago: younger than the default `HANDSHAKE_TIMEOUT_SECS`.
            last_handshake_time: Some(now - 60),
            transfer_rx: 0,
            transfer_tx: 0,
        };

        assert!(checker.is_peer_healthy(&stats));
    }

    /// A handshake 300 s ago is older than the default handshake timeout.
    #[test]
    fn test_is_peer_healthy_stale_handshake() {
        let checker = OverlayHealthChecker::new("wg0", Duration::from_secs(30));

        let now = current_timestamp();
        let stats = WgPeerStats {
            public_key: "key".to_string(),
            endpoint: None,
            allowed_ips: vec![],
            // 300 s ago: older than the default `HANDSHAKE_TIMEOUT_SECS`.
            last_handshake_time: Some(now - 300),
            transfer_rx: 0,
            transfer_tx: 0,
        };

        assert!(!checker.is_peer_healthy(&stats));
    }

    /// A peer that has never completed a handshake is unhealthy.
    #[test]
    fn test_is_peer_healthy_no_handshake() {
        let checker = OverlayHealthChecker::new("wg0", Duration::from_secs(30));

        let stats = WgPeerStats {
            public_key: "key".to_string(),
            endpoint: None,
            allowed_ips: vec![],
            last_handshake_time: None,
            transfer_rx: 0,
            transfer_tx: 0,
        };

        assert!(!checker.is_peer_healthy(&stats));
    }

    /// A single-peer response is parsed field by field. Interface-level keys are
    /// ignored, and the hex key is converted to base64.
    #[test]
    fn test_parse_uapi_get_response() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        // UAPI reports keys as hex; the parser must convert them to base64.
        let key_bytes = [0xABu8; 32];
        let hex_key = hex::encode(key_bytes);
        let expected_b64 = STANDARD.encode(key_bytes);

        let response = format!(
            "private_key=0000000000000000000000000000000000000000000000000000000000000000\n\
             listen_port=51820\n\
             public_key={hex_key}\n\
             endpoint=192.168.1.5:51820\n\
             allowed_ip=10.200.0.2/32\n\
             last_handshake_time_sec=1700000000\n\
             last_handshake_time_nsec=0\n\
             rx_bytes=12345\n\
             tx_bytes=67890\n\
             persistent_keepalive_interval=25\n\
             errno=0\n"
        );

        let peers = parse_uapi_get_response(&response);
        assert_eq!(peers.len(), 1);

        let peer = &peers[0];
        assert_eq!(peer.public_key, expected_b64);
        assert_eq!(peer.endpoint, Some("192.168.1.5:51820".to_string()));
        assert_eq!(peer.allowed_ips, vec!["10.200.0.2/32".to_string()]);
        assert_eq!(peer.last_handshake_time, Some(1_700_000_000));
        assert_eq!(peer.transfer_rx, 12345);
        assert_eq!(peer.transfer_tx, 67890);
    }

    /// Each `public_key=` line starts a new peer. Repeated `allowed_ip` lines
    /// accumulate on the current peer.
    #[test]
    fn test_parse_uapi_get_response_multiple_peers() {
        let key1 = hex::encode([0x01u8; 32]);
        let key2 = hex::encode([0x02u8; 32]);

        let response = format!(
            "private_key=0000000000000000000000000000000000000000000000000000000000000000\n\
             listen_port=51820\n\
             public_key={key1}\n\
             endpoint=10.0.0.1:51820\n\
             allowed_ip=10.200.0.2/32\n\
             rx_bytes=100\n\
             tx_bytes=200\n\
             public_key={key2}\n\
             endpoint=10.0.0.2:51821\n\
             allowed_ip=10.200.0.3/32\n\
             allowed_ip=10.200.1.0/24\n\
             rx_bytes=300\n\
             tx_bytes=400\n\
             errno=0\n"
        );

        let peers = parse_uapi_get_response(&response);
        assert_eq!(peers.len(), 2);
        assert_eq!(peers[0].transfer_rx, 100);
        assert_eq!(peers[1].transfer_rx, 300);
        assert_eq!(peers[1].allowed_ips.len(), 2);
    }

    /// A response with only interface-level keys yields no peers.
    #[test]
    fn test_parse_uapi_get_response_empty() {
        let response = "private_key=0000\nlisten_port=51820\nerrno=0\n";
        let peers = parse_uapi_get_response(response);
        assert!(peers.is_empty());
    }

    /// Valid hex keys are re-encoded as standard base64.
    #[test]
    fn test_hex_key_to_base64_roundtrip() {
        use base64::{engine::general_purpose::STANDARD, Engine as _};

        let key_bytes = [0xCDu8; 32];
        let hex_key = hex::encode(key_bytes);
        let b64 = hex_key_to_base64(&hex_key);
        let expected = STANDARD.encode(key_bytes);
        assert_eq!(b64, expected);
    }
}