Skip to main content

rustant_core/nodes/
discovery.rs

1//! Node discovery — finds nodes on the local network or locally.
2//!
3//! Supports local-only discovery (this machine) and mDNS-based LAN
4//! peer discovery using the `_rustant._tcp.local.` service name.
5//! The mDNS layer is trait-abstracted for testability.
6
7use super::types::{Capability, NodeId, NodeInfo, Platform};
8use async_trait::async_trait;
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11
12// ── mDNS constants ──────────────────────────────────────────────────
13
/// IPv4 multicast group reserved for mDNS traffic.
pub const MDNS_MULTICAST_ADDR: &str = "224.0.0.251";
/// UDP port on which mDNS traffic is exchanged.
pub const MDNS_PORT: u16 = 5353;
/// Service type under which Rustant nodes advertise themselves.
pub const RUSTANT_SERVICE_NAME: &str = "_rustant._tcp.local.";
20
21// ── mDNS service record ─────────────────────────────────────────────
22
23/// An mDNS service record describing a Rustant node on the LAN.
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct MdnsServiceRecord {
26    /// The service name (always `_rustant._tcp.local.`).
27    pub service_name: String,
28    /// Human-readable instance name (e.g. "DevMac-Rustant").
29    pub instance_name: String,
30    /// IP address of the advertising node.
31    pub address: String,
32    /// Gateway port the node listens on.
33    pub port: u16,
34    /// Platform of the advertising node.
35    pub platform: Platform,
36    /// Node id.
37    pub node_id: String,
38    /// Comma-separated capability list.
39    pub capabilities_csv: String,
40}
41
42impl MdnsServiceRecord {
43    /// Parse capabilities from the CSV field.
44    pub fn parse_capabilities(&self) -> Vec<Capability> {
45        if self.capabilities_csv.is_empty() {
46            return Vec::new();
47        }
48        self.capabilities_csv
49            .split(',')
50            .filter_map(|s| match s.trim() {
51                "shell" => Some(Capability::Shell),
52                "filesystem" => Some(Capability::FileSystem),
53                "applescript" => Some(Capability::AppleScript),
54                "automator" => Some(Capability::Automator),
55                "screenshot" => Some(Capability::Screenshot),
56                "clipboard" => Some(Capability::Clipboard),
57                "notifications" => Some(Capability::Notifications),
58                "browser" => Some(Capability::Browser),
59                "camera" => Some(Capability::Camera),
60                "screen_record" => Some(Capability::ScreenRecord),
61                "location" => Some(Capability::Location),
62                other if other.starts_with("app_control:") => {
63                    Some(Capability::AppControl(other[12..].to_string()))
64                }
65                other if other.starts_with("custom:") => {
66                    Some(Capability::Custom(other[7..].to_string()))
67                }
68                _ => None,
69            })
70            .collect()
71    }
72
73    /// Build capability CSV from a slice of capabilities.
74    pub fn capabilities_to_csv(caps: &[Capability]) -> String {
75        caps.iter()
76            .map(|c| c.to_string())
77            .collect::<Vec<_>>()
78            .join(",")
79    }
80
81    /// Convert to a `DiscoveredNode`.
82    pub fn to_discovered_node(&self) -> DiscoveredNode {
83        DiscoveredNode {
84            node_id: NodeId::new(&self.node_id),
85            address: self.address.clone(),
86            port: self.port,
87            platform: self.platform,
88            capabilities: self.parse_capabilities(),
89            discovered_at: Utc::now(),
90        }
91    }
92}
93
94// ── mDNS service trait ──────────────────────────────────────────────
95
96/// Trait abstracting mDNS network operations for testability.
97#[async_trait]
98pub trait MdnsTransport: Send + Sync {
99    /// Register (advertise) this node on the local network.
100    async fn register(&self, record: &MdnsServiceRecord) -> Result<(), String>;
101
102    /// Unregister (stop advertising) this node.
103    async fn unregister(&self) -> Result<(), String>;
104
105    /// Perform a single discovery scan and return found service records.
106    async fn discover(&self, timeout_ms: u64) -> Result<Vec<MdnsServiceRecord>, String>;
107}
108
109// ── mDNS discovery coordinator ──────────────────────────────────────
110
111/// Configuration for mDNS-based node discovery.
112#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct MdnsConfig {
114    /// Whether mDNS discovery is enabled.
115    pub enabled: bool,
116    /// Scan interval in seconds for background discovery.
117    pub scan_interval_secs: u64,
118    /// Timeout in milliseconds for each discovery scan.
119    pub scan_timeout_ms: u64,
120    /// Maximum age in seconds before a node is considered stale.
121    pub stale_threshold_secs: u64,
122}
123
124impl Default for MdnsConfig {
125    fn default() -> Self {
126        Self {
127            enabled: false,
128            scan_interval_secs: 30,
129            scan_timeout_ms: 3000,
130            stale_threshold_secs: 120,
131        }
132    }
133}
134
135/// mDNS-based discovery coordinator.
136///
137/// Wraps an `MdnsTransport` implementation and manages registration,
138/// scanning, and stale-node pruning.
139pub struct MdnsDiscovery {
140    transport: Box<dyn MdnsTransport>,
141    config: MdnsConfig,
142    /// The record this node advertises (set after `register()`).
143    local_record: Option<MdnsServiceRecord>,
144    /// Nodes found via mDNS scans.
145    found: Vec<DiscoveredNode>,
146}
147
148impl MdnsDiscovery {
149    pub fn new(transport: Box<dyn MdnsTransport>, config: MdnsConfig) -> Self {
150        Self {
151            transport,
152            config,
153            local_record: None,
154            found: Vec::new(),
155        }
156    }
157
158    /// Register this node on the network.
159    pub async fn register(&mut self, record: MdnsServiceRecord) -> Result<(), String> {
160        self.transport.register(&record).await?;
161        self.local_record = Some(record);
162        Ok(())
163    }
164
165    /// Unregister this node.
166    pub async fn unregister(&mut self) -> Result<(), String> {
167        self.transport.unregister().await?;
168        self.local_record = None;
169        Ok(())
170    }
171
172    /// Whether this node is currently registered/advertising.
173    pub fn is_registered(&self) -> bool {
174        self.local_record.is_some()
175    }
176
177    /// Perform a single discovery scan. Returns newly found nodes.
178    pub async fn scan(&mut self) -> Result<Vec<DiscoveredNode>, String> {
179        let records = self.transport.discover(self.config.scan_timeout_ms).await?;
180        let local_id = self.local_record.as_ref().map(|r| r.node_id.as_str());
181
182        let mut new_nodes = Vec::new();
183        for record in records {
184            // Skip our own advertisement.
185            if let Some(lid) = local_id {
186                if record.node_id == lid {
187                    continue;
188                }
189            }
190
191            let already_known = self.found.iter().any(|n| n.node_id.0 == record.node_id);
192            let discovered = record.to_discovered_node();
193
194            if already_known {
195                // Refresh timestamp for existing node.
196                if let Some(existing) = self
197                    .found
198                    .iter_mut()
199                    .find(|n| n.node_id.0 == record.node_id)
200                {
201                    existing.discovered_at = Utc::now();
202                    existing.capabilities = discovered.capabilities;
203                }
204            } else {
205                new_nodes.push(discovered.clone());
206                self.found.push(discovered);
207            }
208        }
209
210        Ok(new_nodes)
211    }
212
213    /// All currently known remote nodes from mDNS.
214    pub fn found_nodes(&self) -> &[DiscoveredNode] {
215        &self.found
216    }
217
218    /// Remove stale nodes that haven't been refreshed within the threshold.
219    pub fn prune_stale(&mut self) -> usize {
220        let now = Utc::now();
221        let threshold = self.config.stale_threshold_secs as i64;
222        let before = self.found.len();
223        self.found.retain(|node| {
224            let age = now.signed_duration_since(node.discovered_at);
225            age.num_seconds() < threshold
226        });
227        before - self.found.len()
228    }
229
230    /// Clear all found nodes.
231    pub fn clear(&mut self) {
232        self.found.clear();
233    }
234
235    /// Access the config.
236    pub fn config(&self) -> &MdnsConfig {
237        &self.config
238    }
239}
240
241// ── Real UDP mDNS transport ─────────────────────────────────────────
242
243/// A real mDNS transport that uses UDP multicast.
244///
245/// Sends and receives mDNS-like JSON packets on `224.0.0.251:5353`.
246/// This is a simplified Rustant-specific protocol — not full RFC 6762 —
247/// but uses the standard mDNS multicast group for LAN discovery.
248pub struct UdpMdnsTransport {
249    bind_addr: String,
250}
251
252impl UdpMdnsTransport {
253    pub fn new() -> Self {
254        Self {
255            bind_addr: format!("0.0.0.0:{}", MDNS_PORT),
256        }
257    }
258
259    pub fn with_bind_addr(addr: impl Into<String>) -> Self {
260        Self {
261            bind_addr: addr.into(),
262        }
263    }
264}
265
266impl Default for UdpMdnsTransport {
267    fn default() -> Self {
268        Self::new()
269    }
270}
271
272#[async_trait]
273impl MdnsTransport for UdpMdnsTransport {
274    async fn register(&self, record: &MdnsServiceRecord) -> Result<(), String> {
275        let socket = tokio::net::UdpSocket::bind("0.0.0.0:0")
276            .await
277            .map_err(|e| format!("Failed to bind UDP socket: {e}"))?;
278
279        let payload =
280            serde_json::to_vec(record).map_err(|e| format!("Failed to serialize record: {e}"))?;
281
282        let dest = format!("{}:{}", MDNS_MULTICAST_ADDR, MDNS_PORT);
283        socket
284            .send_to(&payload, &dest)
285            .await
286            .map_err(|e| format!("Failed to send mDNS announcement: {e}"))?;
287
288        Ok(())
289    }
290
291    async fn unregister(&self) -> Result<(), String> {
292        // In a full implementation, send a "goodbye" packet.
293        // For now, simply stop advertising.
294        Ok(())
295    }
296
297    async fn discover(&self, timeout_ms: u64) -> Result<Vec<MdnsServiceRecord>, String> {
298        use std::net::Ipv4Addr;
299
300        let socket = tokio::net::UdpSocket::bind(&self.bind_addr)
301            .await
302            .map_err(|e| format!("Failed to bind mDNS socket: {e}"))?;
303
304        let multicast_addr: Ipv4Addr = MDNS_MULTICAST_ADDR
305            .parse()
306            .map_err(|e| format!("Invalid multicast addr: {e}"))?;
307
308        socket
309            .join_multicast_v4(multicast_addr, Ipv4Addr::UNSPECIFIED)
310            .map_err(|e| format!("Failed to join multicast group: {e}"))?;
311
312        let mut buf = vec![0u8; 4096];
313        let mut records = Vec::new();
314        let deadline = tokio::time::Instant::now() + tokio::time::Duration::from_millis(timeout_ms);
315
316        loop {
317            let remaining = deadline.saturating_duration_since(tokio::time::Instant::now());
318            if remaining.is_zero() {
319                break;
320            }
321
322            match tokio::time::timeout(remaining, socket.recv_from(&mut buf)).await {
323                Ok(Ok((len, _addr))) => {
324                    if let Ok(record) = serde_json::from_slice::<MdnsServiceRecord>(&buf[..len]) {
325                        if record.service_name == RUSTANT_SERVICE_NAME {
326                            records.push(record);
327                        }
328                    }
329                }
330                Ok(Err(_)) => break,
331                Err(_) => break, // timeout
332            }
333        }
334
335        Ok(records)
336    }
337}
338
339// ── Original local discovery ────────────────────────────────────────
340
341/// A node discovered on the network with connection metadata.
342#[derive(Debug, Clone)]
343pub struct DiscoveredNode {
344    pub node_id: NodeId,
345    pub address: String,
346    pub port: u16,
347    pub platform: Platform,
348    pub capabilities: Vec<Capability>,
349    pub discovered_at: DateTime<Utc>,
350}
351
352/// Discovers nodes available for task execution (local + mDNS).
353#[derive(Debug, Clone, Default)]
354pub struct NodeDiscovery {
355    discovered: Vec<NodeInfo>,
356    network_discovered: Vec<DiscoveredNode>,
357}
358
359impl NodeDiscovery {
360    pub fn new() -> Self {
361        Self::default()
362    }
363
364    /// Discover the local machine as a node.
365    pub fn discover_local(&mut self) -> NodeInfo {
366        let platform = Self::detect_platform();
367        let hostname = Self::get_hostname();
368        let info = NodeInfo {
369            node_id: NodeId::new(format!("local-{}", hostname)),
370            name: format!("Local ({})", hostname),
371            platform,
372            hostname,
373            registered_at: Utc::now(),
374            os_version: None,
375            agent_version: env!("CARGO_PKG_VERSION").to_string(),
376            uptime_secs: 0,
377        };
378        self.discovered.push(info.clone());
379        info
380    }
381
382    /// Number of discovered nodes.
383    pub fn discovered_count(&self) -> usize {
384        self.discovered.len()
385    }
386
387    /// All discovered nodes.
388    pub fn discovered_nodes(&self) -> &[NodeInfo] {
389        &self.discovered
390    }
391
392    /// Clear all discovered nodes.
393    pub fn clear(&mut self) {
394        self.discovered.clear();
395        self.network_discovered.clear();
396    }
397
398    /// Add a network-discovered node.
399    pub fn add_network_node(&mut self, node: DiscoveredNode) {
400        self.network_discovered.push(node);
401    }
402
403    /// List all network-discovered nodes.
404    pub fn network_nodes(&self) -> &[DiscoveredNode] {
405        &self.network_discovered
406    }
407
408    /// Remove stale network-discovered nodes older than `max_age_secs`.
409    /// Returns the number of removed entries.
410    pub fn remove_stale(&mut self, max_age_secs: u64) -> usize {
411        let now = Utc::now();
412        let before = self.network_discovered.len();
413        self.network_discovered.retain(|node| {
414            let age = now.signed_duration_since(node.discovered_at);
415            age.num_seconds() < max_age_secs as i64
416        });
417        before - self.network_discovered.len()
418    }
419
420    /// Detect the current platform.
421    fn detect_platform() -> Platform {
422        if cfg!(target_os = "macos") {
423            Platform::MacOS
424        } else if cfg!(target_os = "linux") {
425            Platform::Linux
426        } else if cfg!(target_os = "windows") {
427            Platform::Windows
428        } else {
429            Platform::Unknown
430        }
431    }
432
433    /// Get the hostname.
434    fn get_hostname() -> String {
435        std::env::var("HOSTNAME")
436            .or_else(|_| std::env::var("HOST"))
437            .unwrap_or_else(|_| "unknown".to_string())
438    }
439}
440
#[cfg(test)]
mod tests {
    //! Unit tests for local discovery, mDNS service records, and the
    //! `MdnsDiscovery` coordinator (driven by an in-memory mock transport).

    use super::*;
    use std::sync::{Arc, Mutex};

    // ── Mock mDNS transport ─────────────────────────────────────────

    /// In-memory `MdnsTransport` that records the last registration and
    /// returns a fixed set of scan results. Fields are `Arc`-shared so a test
    /// can keep a handle after the mock is boxed into `MdnsDiscovery`.
    struct MockMdnsTransport {
        registered: Arc<Mutex<Option<MdnsServiceRecord>>>,
        scan_results: Arc<Mutex<Vec<MdnsServiceRecord>>>,
    }

    impl MockMdnsTransport {
        /// Mock whose scans find nothing.
        fn new() -> Self {
            Self::with_scan_results(Vec::new())
        }

        /// Mock whose every scan returns `results`.
        fn with_scan_results(results: Vec<MdnsServiceRecord>) -> Self {
            Self {
                registered: Arc::new(Mutex::new(None)),
                scan_results: Arc::new(Mutex::new(results)),
            }
        }
    }

    #[async_trait]
    impl MdnsTransport for MockMdnsTransport {
        async fn register(&self, record: &MdnsServiceRecord) -> Result<(), String> {
            *self.registered.lock().unwrap() = Some(record.clone());
            Ok(())
        }

        async fn unregister(&self) -> Result<(), String> {
            *self.registered.lock().unwrap() = None;
            Ok(())
        }

        async fn discover(&self, _timeout_ms: u64) -> Result<Vec<MdnsServiceRecord>, String> {
            Ok(self.scan_results.lock().unwrap().clone())
        }
    }

    /// A Linux record advertising `shell,filesystem`.
    fn sample_record(node_id: &str, addr: &str, port: u16) -> MdnsServiceRecord {
        MdnsServiceRecord {
            service_name: RUSTANT_SERVICE_NAME.to_string(),
            instance_name: format!("{}-Rustant", node_id),
            address: addr.to_string(),
            port,
            platform: Platform::Linux,
            node_id: node_id.to_string(),
            capabilities_csv: "shell,filesystem".to_string(),
        }
    }

    /// A capability-less node on port 9000, discovered `age_secs` ago.
    fn network_node(id: &str, addr: &str, platform: Platform, age_secs: i64) -> DiscoveredNode {
        DiscoveredNode {
            node_id: NodeId::new(id),
            address: addr.into(),
            port: 9000,
            platform,
            capabilities: vec![],
            discovered_at: Utc::now() - chrono::Duration::seconds(age_secs),
        }
    }

    // ── Local discovery tests ───────────────────────────────────────

    #[test]
    fn test_discovery_new() {
        let disc = NodeDiscovery::new();
        assert_eq!(disc.discovered_count(), 0);
    }

    #[test]
    fn test_discover_local() {
        let mut disc = NodeDiscovery::new();
        let info = disc.discover_local();
        assert!(info.node_id.0.starts_with("local-"));
        assert_eq!(disc.discovered_count(), 1);
    }

    #[test]
    fn test_discovery_clear() {
        let mut disc = NodeDiscovery::new();
        disc.discover_local();
        assert_eq!(disc.discovered_count(), 1);
        disc.clear();
        assert_eq!(disc.discovered_count(), 0);
    }

    #[test]
    fn test_discovered_node_creation() {
        let node = DiscoveredNode {
            node_id: NodeId::new("remote-1"),
            address: "192.168.1.10".into(),
            port: 8080,
            platform: Platform::Linux,
            capabilities: vec![Capability::Shell, Capability::FileSystem],
            discovered_at: Utc::now(),
        };
        assert_eq!(node.address, "192.168.1.10");
        assert_eq!(node.port, 8080);
        assert_eq!(node.capabilities.len(), 2);
    }

    #[test]
    fn test_discovery_add_and_list() {
        let mut disc = NodeDiscovery::new();
        disc.add_network_node(network_node("remote-1", "10.0.0.1", Platform::MacOS, 0));
        disc.add_network_node(network_node("remote-2", "10.0.0.2", Platform::Linux, 0));

        assert_eq!(disc.network_nodes().len(), 2);
    }

    #[test]
    fn test_discovery_remove_stale() {
        let mut disc = NodeDiscovery::new();
        // One node discovered 1000 seconds ago, one just now.
        disc.add_network_node(network_node("old", "10.0.0.1", Platform::MacOS, 1000));
        disc.add_network_node(network_node("new", "10.0.0.2", Platform::Linux, 0));

        let removed = disc.remove_stale(600); // max age 600s
        assert_eq!(removed, 1);
        assert_eq!(disc.network_nodes().len(), 1);
        assert_eq!(disc.network_nodes()[0].node_id, NodeId::new("new"));
    }

    #[test]
    fn test_discovery_no_stale() {
        let mut disc = NodeDiscovery::new();
        disc.add_network_node(network_node("fresh", "10.0.0.1", Platform::MacOS, 0));

        let removed = disc.remove_stale(600);
        assert_eq!(removed, 0);
        assert_eq!(disc.network_nodes().len(), 1);
    }

    // ── mDNS service record tests ───────────────────────────────────

    #[test]
    fn test_mdns_constants() {
        assert_eq!(MDNS_MULTICAST_ADDR, "224.0.0.251");
        assert_eq!(MDNS_PORT, 5353);
        assert_eq!(RUSTANT_SERVICE_NAME, "_rustant._tcp.local.");
    }

    #[test]
    fn test_mdns_config_default() {
        let config = MdnsConfig::default();
        assert!(!config.enabled);
        assert_eq!(config.scan_interval_secs, 30);
        assert_eq!(config.scan_timeout_ms, 3000);
        assert_eq!(config.stale_threshold_secs, 120);
    }

    #[test]
    fn test_mdns_config_serialization() {
        let config = MdnsConfig {
            enabled: true,
            scan_interval_secs: 60,
            scan_timeout_ms: 5000,
            stale_threshold_secs: 300,
        };
        let json = serde_json::to_string(&config).unwrap();
        let restored: MdnsConfig = serde_json::from_str(&json).unwrap();
        assert!(restored.enabled);
        assert_eq!(restored.scan_interval_secs, 60);
    }

    #[test]
    fn test_mdns_service_record_parse_capabilities() {
        let record = sample_record("node-1", "10.0.0.1", 8080);
        let caps = record.parse_capabilities();
        assert_eq!(caps.len(), 2);
        assert_eq!(caps[0], Capability::Shell);
        assert_eq!(caps[1], Capability::FileSystem);
    }

    #[test]
    fn test_mdns_service_record_parse_empty_capabilities() {
        let mut record = sample_record("node-1", "10.0.0.1", 8080);
        record.capabilities_csv = String::new();
        let caps = record.parse_capabilities();
        assert!(caps.is_empty());
    }

    #[test]
    fn test_mdns_service_record_parse_all_capability_types() {
        let mut record = sample_record("node-1", "10.0.0.1", 8080);
        record.capabilities_csv = "shell,filesystem,applescript,automator,screenshot,clipboard,notifications,browser,camera,screen_record,location,app_control:Safari,custom:gpu".to_string();
        let caps = record.parse_capabilities();
        assert_eq!(caps.len(), 13);
        assert_eq!(caps[0], Capability::Shell);
        assert_eq!(caps[6], Capability::Notifications);
        assert_eq!(caps[11], Capability::AppControl("Safari".to_string()));
        assert_eq!(caps[12], Capability::Custom("gpu".to_string()));
    }

    #[test]
    fn test_mdns_service_record_capabilities_to_csv() {
        let caps = vec![
            Capability::Shell,
            Capability::FileSystem,
            Capability::Screenshot,
        ];
        let csv = MdnsServiceRecord::capabilities_to_csv(&caps);
        assert_eq!(csv, "shell,filesystem,screenshot");
    }

    #[test]
    fn test_mdns_service_record_to_discovered_node() {
        let record = sample_record("node-x", "192.168.1.50", 9090);
        let node = record.to_discovered_node();
        assert_eq!(node.node_id, NodeId::new("node-x"));
        assert_eq!(node.address, "192.168.1.50");
        assert_eq!(node.port, 9090);
        assert_eq!(node.platform, Platform::Linux);
        assert_eq!(node.capabilities.len(), 2);
    }

    #[test]
    fn test_mdns_service_record_serialization() {
        let record = sample_record("node-1", "10.0.0.1", 8080);
        let json = serde_json::to_string(&record).unwrap();
        assert!(json.contains("_rustant._tcp.local."));
        let restored: MdnsServiceRecord = serde_json::from_str(&json).unwrap();
        assert_eq!(restored.node_id, "node-1");
        assert_eq!(restored.address, "10.0.0.1");
    }

    // ── MdnsDiscovery coordinator tests ─────────────────────────────

    #[tokio::test]
    async fn test_mdns_discovery_register() {
        let transport = MockMdnsTransport::new();
        let registered = transport.registered.clone();
        let config = MdnsConfig::default();
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        assert!(!disc.is_registered());

        let record = sample_record("local-1", "127.0.0.1", 8080);
        disc.register(record).await.unwrap();

        assert!(disc.is_registered());
        let reg = registered.lock().unwrap();
        assert_eq!(reg.as_ref().unwrap().node_id, "local-1");
    }

    #[tokio::test]
    async fn test_mdns_discovery_unregister() {
        let transport = MockMdnsTransport::new();
        let config = MdnsConfig::default();
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        let record = sample_record("local-1", "127.0.0.1", 8080);
        disc.register(record).await.unwrap();
        assert!(disc.is_registered());

        disc.unregister().await.unwrap();
        assert!(!disc.is_registered());
    }

    #[tokio::test]
    async fn test_mdns_discovery_scan_finds_remote_nodes() {
        let remote1 = sample_record("remote-a", "192.168.1.10", 9000);
        let remote2 = sample_record("remote-b", "192.168.1.11", 9001);
        let transport = MockMdnsTransport::with_scan_results(vec![remote1, remote2]);
        let config = MdnsConfig::default();
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        let new_nodes = disc.scan().await.unwrap();
        assert_eq!(new_nodes.len(), 2);
        assert_eq!(disc.found_nodes().len(), 2);
        assert_eq!(disc.found_nodes()[0].node_id, NodeId::new("remote-a"));
        assert_eq!(disc.found_nodes()[1].node_id, NodeId::new("remote-b"));
    }

    #[tokio::test]
    async fn test_mdns_discovery_scan_skips_self() {
        let local = sample_record("local-1", "127.0.0.1", 8080);
        let remote = sample_record("remote-a", "192.168.1.10", 9000);
        let transport = MockMdnsTransport::with_scan_results(vec![local, remote]);
        let config = MdnsConfig::default();
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        // Register so MdnsDiscovery knows its own node_id.
        let own = sample_record("local-1", "127.0.0.1", 8080);
        disc.register(own).await.unwrap();

        let new_nodes = disc.scan().await.unwrap();
        // Only the remote node should be discovered, not ourselves.
        assert_eq!(new_nodes.len(), 1);
        assert_eq!(new_nodes[0].node_id, NodeId::new("remote-a"));
        assert_eq!(disc.found_nodes().len(), 1);
    }

    #[tokio::test]
    async fn test_mdns_discovery_scan_refreshes_known_nodes() {
        let remote = sample_record("remote-a", "192.168.1.10", 9000);
        let transport = MockMdnsTransport::with_scan_results(vec![remote]);
        let config = MdnsConfig::default();
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        // First scan: discovers the node.
        let new1 = disc.scan().await.unwrap();
        assert_eq!(new1.len(), 1);

        // Second scan: same node, should refresh, not duplicate.
        let new2 = disc.scan().await.unwrap();
        assert_eq!(new2.len(), 0); // not new
        assert_eq!(disc.found_nodes().len(), 1); // still just one
    }

    #[tokio::test]
    async fn test_mdns_discovery_scan_empty() {
        let transport = MockMdnsTransport::with_scan_results(vec![]);
        let config = MdnsConfig::default();
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        let new_nodes = disc.scan().await.unwrap();
        assert!(new_nodes.is_empty());
        assert!(disc.found_nodes().is_empty());
    }

    #[tokio::test]
    async fn test_mdns_discovery_prune_stale() {
        let remote = sample_record("remote-a", "192.168.1.10", 9000);
        let transport = MockMdnsTransport::with_scan_results(vec![remote]);
        let config = MdnsConfig {
            stale_threshold_secs: 1, // very short threshold
            ..Default::default()
        };
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        disc.scan().await.unwrap();
        assert_eq!(disc.found_nodes().len(), 1);

        // Manually backdate the discovered_at to make it stale.
        disc.found[0].discovered_at = Utc::now() - chrono::Duration::seconds(10);

        let pruned = disc.prune_stale();
        assert_eq!(pruned, 1);
        assert!(disc.found_nodes().is_empty());
    }

    #[tokio::test]
    async fn test_mdns_discovery_clear() {
        let remote = sample_record("remote-a", "192.168.1.10", 9000);
        let transport = MockMdnsTransport::with_scan_results(vec![remote]);
        let config = MdnsConfig::default();
        let mut disc = MdnsDiscovery::new(Box::new(transport), config);

        disc.scan().await.unwrap();
        assert_eq!(disc.found_nodes().len(), 1);

        disc.clear();
        assert!(disc.found_nodes().is_empty());
    }

    #[test]
    fn test_udp_mdns_transport_default() {
        let transport = UdpMdnsTransport::default();
        assert_eq!(transport.bind_addr, "0.0.0.0:5353");
    }

    #[test]
    fn test_udp_mdns_transport_custom_bind() {
        let transport = UdpMdnsTransport::with_bind_addr("0.0.0.0:15353");
        assert_eq!(transport.bind_addr, "0.0.0.0:15353");
    }
}