Skip to main content

callback_server/
firewall_detection.rs

//! Firewall detection coordinator for callback server.
//!
//! This module implements a per-device firewall detection system that monitors
//! real UPnP event delivery to determine whether callback servers can receive
//! external requests from Sonos devices on the local network.
6
7use std::collections::HashMap;
8use std::net::IpAddr;
9use std::sync::Arc;
10use std::time::{Duration, SystemTime};
11use tokio::sync::{mpsc, RwLock};
12use tracing::{debug, info, warn};
13
/// Outcome of firewall detection for one device.
///
/// A freshly created value is [`FirewallStatus::Unknown`].
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FirewallStatus {
    /// No verdict yet: detection has not been run or is still running.
    #[default]
    Unknown,
    /// The device was able to deliver requests to the callback server.
    Accessible,
    /// The callback server looks unreachable from this device (likely a firewall).
    Blocked,
    /// Detection could not be completed because of some other error.
    Error,
}
27
/// Tunables that control how firewall detection behaves.
#[derive(Clone, Debug)]
pub struct FirewallDetectionConfig {
    /// How long to wait for the first event from a device before giving up on it.
    pub event_wait_timeout: Duration,
    /// When `true`, a completed per-device verdict is reused for later subscriptions.
    pub enable_caching: bool,
    /// Upper bound on the number of per-device states kept in the cache.
    pub max_cached_devices: usize,
}

impl Default for FirewallDetectionConfig {
    /// A 15 second event window, caching enabled, and room for 100 devices.
    fn default() -> Self {
        let event_wait_timeout = Duration::from_secs(15);
        let max_cached_devices = 100;
        FirewallDetectionConfig {
            event_wait_timeout,
            enable_caching: true,
            max_cached_devices,
        }
    }
}
48
49/// Per-device firewall detection state.
50#[derive(Debug, Clone)]
51pub struct DeviceFirewallState {
52    pub device_ip: IpAddr,
53    pub status: FirewallStatus,
54    pub first_subscription_time: SystemTime,
55    pub first_event_time: Option<SystemTime>,
56    pub detection_completed: bool,
57    pub timeout_duration: Duration,
58}
59
60/// Result of a firewall detection operation.
61#[derive(Debug, Clone)]
62pub struct DetectionResult {
63    pub device_ip: IpAddr,
64    pub status: FirewallStatus,
65    pub reason: DetectionReason,
66}
67
/// What caused detection for a device to finish.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DetectionReason {
    /// An event arrived from the device before the timeout expired.
    EventReceived,
    /// The timeout expired without any event arriving.
    Timeout,
    /// Creating the subscription failed.
    SubscriptionFailed,
}
78
79/// Coordinates per-device firewall detection by monitoring real UPnP event delivery.
80///
81/// The coordinator tracks firewall status on a per-device basis, triggering detection
82/// when the first subscription is created for a device and monitoring for event arrivals
83/// to determine connectivity status.
84pub struct FirewallDetectionCoordinator {
85    /// Per-device detection states
86    device_states: Arc<RwLock<HashMap<IpAddr, Arc<RwLock<DeviceFirewallState>>>>>,
87
88    /// Configuration for detection behavior
89    config: FirewallDetectionConfig,
90
91    /// Channel for notifying when detection completes
92    detection_complete_tx: mpsc::UnboundedSender<DetectionResult>,
93
94    /// Handle for the timeout monitoring task
95    _timeout_task_handle: tokio::task::JoinHandle<()>,
96}
97
98impl FirewallDetectionCoordinator {
99    /// Create a new firewall detection coordinator.
100    pub fn new(config: FirewallDetectionConfig) -> Self {
101        let (detection_complete_tx, mut detection_complete_rx) = mpsc::unbounded_channel();
102
103        let device_states = Arc::new(RwLock::new(HashMap::new()));
104
105        // Spawn background task for timeout monitoring
106        let timeout_task_handle = {
107            let device_states = device_states.clone();
108            let detection_complete_tx = detection_complete_tx.clone();
109            tokio::spawn(async move {
110                Self::monitor_timeouts(device_states, detection_complete_tx).await;
111            })
112        };
113
114        // Spawn task to handle detection results (logging for now)
115        tokio::spawn(async move {
116            while let Some(result) = detection_complete_rx.recv().await {
117                match result.reason {
118                    DetectionReason::EventReceived => {
119                        info!(
120                            device_ip = %result.device_ip,
121                            reason = ?result.reason,
122                            status = ?result.status,
123                            "Firewall detection: Events accessible from device"
124                        );
125                    }
126                    DetectionReason::Timeout => {
127                        warn!(
128                            device_ip = %result.device_ip,
129                            reason = ?result.reason,
130                            status = ?result.status,
131                            "Firewall detection: No events received within timeout"
132                        );
133                    }
134                    DetectionReason::SubscriptionFailed => {
135                        warn!(
136                            device_ip = %result.device_ip,
137                            reason = ?result.reason,
138                            status = ?result.status,
139                            "Firewall detection: Subscription failed for device"
140                        );
141                    }
142                }
143            }
144        });
145
146        Self {
147            device_states,
148            config,
149            detection_complete_tx,
150            _timeout_task_handle: timeout_task_handle,
151        }
152    }
153
154    /// Called when the first subscription is created for a device.
155    ///
156    /// Returns the cached status if already known, otherwise starts monitoring
157    /// and returns Unknown status while detection is in progress.
158    pub async fn on_first_subscription(&self, device_ip: IpAddr) -> FirewallStatus {
159        if !self.config.enable_caching {
160            // Caching disabled - always return Unknown and start fresh detection
161            self.start_detection_for_device(device_ip).await;
162            return FirewallStatus::Unknown;
163        }
164
165        let device_states = self.device_states.read().await;
166
167        // Check if we already have cached status
168        if let Some(state_arc) = device_states.get(&device_ip) {
169            let state = state_arc.read().await;
170            if state.detection_completed {
171                debug!(
172                    device_ip = %device_ip,
173                    status = ?state.status,
174                    "Firewall detection: Using cached status for device"
175                );
176                return state.status;
177            }
178        }
179
180        drop(device_states); // Release read lock before starting detection
181
182        // First subscription for this device - start monitoring
183        self.start_detection_for_device(device_ip).await;
184
185        debug!(
186            device_ip = %device_ip,
187            timeout = ?self.config.event_wait_timeout,
188            "Firewall detection: Started monitoring device for events"
189        );
190
191        FirewallStatus::Unknown
192    }
193
194    /// Called when any event is received from a device.
195    ///
196    /// If detection is in progress for this device, marks it as accessible.
197    pub async fn on_event_received(&self, device_ip: IpAddr) {
198        let device_states = self.device_states.read().await;
199
200        if let Some(state_arc) = device_states.get(&device_ip) {
201            let mut state = state_arc.write().await;
202
203            if !state.detection_completed {
204                // First event received - mark as accessible
205                state.first_event_time = Some(SystemTime::now());
206                state.status = FirewallStatus::Accessible;
207                state.detection_completed = true;
208
209                let elapsed = SystemTime::now()
210                    .duration_since(state.first_subscription_time)
211                    .unwrap_or(Duration::ZERO);
212
213                // Notify completion
214                let _ = self.detection_complete_tx.send(DetectionResult {
215                    device_ip,
216                    status: FirewallStatus::Accessible,
217                    reason: DetectionReason::EventReceived,
218                });
219
220                info!(
221                    device_ip = %device_ip,
222                    elapsed = ?elapsed,
223                    status = ?FirewallStatus::Accessible,
224                    "Firewall detection: Event received from device, marking as accessible"
225                );
226            }
227        }
228    }
229
230    /// Get the current cached status for a device.
231    pub async fn get_device_status(&self, device_ip: IpAddr) -> FirewallStatus {
232        let device_states = self.device_states.read().await;
233
234        if let Some(state_arc) = device_states.get(&device_ip) {
235            let state = state_arc.read().await;
236            state.status
237        } else {
238            FirewallStatus::Unknown
239        }
240    }
241
242    /// Clear cached status for a device (useful for testing).
243    pub async fn clear_device_cache(&self, device_ip: IpAddr) {
244        let mut device_states = self.device_states.write().await;
245        device_states.remove(&device_ip);
246        debug!(
247            device_ip = %device_ip,
248            "Firewall detection: Cleared cache for device"
249        );
250    }
251
252    /// Start detection monitoring for a specific device.
253    async fn start_detection_for_device(&self, device_ip: IpAddr) {
254        let mut device_states = self.device_states.write().await;
255
256        // Create new detection state
257        let new_state = Arc::new(RwLock::new(DeviceFirewallState {
258            device_ip,
259            status: FirewallStatus::Unknown,
260            first_subscription_time: SystemTime::now(),
261            first_event_time: None,
262            detection_completed: false,
263            timeout_duration: self.config.event_wait_timeout,
264        }));
265
266        // Enforce maximum cache size
267        if device_states.len() >= self.config.max_cached_devices {
268            // Remove oldest entry (this is a simple LRU-like behavior)
269            if let Some(oldest_ip) = device_states.keys().next().copied() {
270                device_states.remove(&oldest_ip);
271                debug!(
272                    oldest_ip = %oldest_ip,
273                    cache_size = self.config.max_cached_devices,
274                    "Firewall detection: Removed oldest cached entry due to cache being full"
275                );
276            }
277        }
278
279        device_states.insert(device_ip, new_state);
280    }
281
282    /// Background task that monitors for timeouts.
283    async fn monitor_timeouts(
284        device_states: Arc<RwLock<HashMap<IpAddr, Arc<RwLock<DeviceFirewallState>>>>>,
285        detection_complete_tx: mpsc::UnboundedSender<DetectionResult>,
286    ) {
287        let mut interval = tokio::time::interval(Duration::from_secs(1));
288
289        loop {
290            interval.tick().await;
291
292            let device_states_read = device_states.read().await;
293
294            for (device_ip, state_arc) in device_states_read.iter() {
295                let mut state = state_arc.write().await;
296
297                if !state.detection_completed {
298                    let elapsed = SystemTime::now()
299                        .duration_since(state.first_subscription_time)
300                        .unwrap_or(Duration::ZERO);
301
302                    if elapsed >= state.timeout_duration {
303                        // Timeout reached - mark as blocked
304                        state.status = FirewallStatus::Blocked;
305                        state.detection_completed = true;
306
307                        // Notify completion
308                        let _ = detection_complete_tx.send(DetectionResult {
309                            device_ip: *device_ip,
310                            status: FirewallStatus::Blocked,
311                            reason: DetectionReason::Timeout,
312                        });
313
314                        warn!(
315                            device_ip = %device_ip,
316                            timeout = ?state.timeout_duration,
317                            status = ?FirewallStatus::Blocked,
318                            "Firewall detection: No events received within timeout, marking as blocked"
319                        );
320                    }
321                }
322            }
323        }
324    }
325
326    /// Get statistics about the coordinator state.
327    pub async fn get_stats(&self) -> CoordinatorStats {
328        let device_states = self.device_states.read().await;
329
330        let mut stats = CoordinatorStats {
331            total_devices: device_states.len(),
332            accessible_devices: 0,
333            blocked_devices: 0,
334            unknown_devices: 0,
335            error_devices: 0,
336        };
337
338        for state_arc in device_states.values() {
339            let state = state_arc.read().await;
340            match state.status {
341                FirewallStatus::Accessible => stats.accessible_devices += 1,
342                FirewallStatus::Blocked => stats.blocked_devices += 1,
343                FirewallStatus::Unknown => stats.unknown_devices += 1,
344                FirewallStatus::Error => stats.error_devices += 1,
345            }
346        }
347
348        stats
349    }
350}
351
/// Snapshot of how many tracked devices are in each detection state.
#[derive(Clone, Debug)]
pub struct CoordinatorStats {
    /// Number of devices currently tracked.
    pub total_devices: usize,
    /// Devices whose status is `Accessible`.
    pub accessible_devices: usize,
    /// Devices whose status is `Blocked`.
    pub blocked_devices: usize,
    /// Devices whose status is `Unknown`.
    pub unknown_devices: usize,
    /// Devices whose status is `Error`.
    pub error_devices: usize,
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Address `192.168.1.<host>` used as a fake device.
    fn lan_ip(host: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, host))
    }

    /// Coordinator built from the default configuration.
    fn default_coordinator() -> FirewallDetectionCoordinator {
        FirewallDetectionCoordinator::new(FirewallDetectionConfig::default())
    }

    #[tokio::test]
    async fn test_coordinator_creation() {
        // Construction inside a runtime must not panic.
        let _coordinator = default_coordinator();
    }

    #[tokio::test]
    async fn test_first_subscription_starts_monitoring() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        // While monitoring is under way the answer is Unknown...
        assert_eq!(
            coordinator.on_first_subscription(device).await,
            FirewallStatus::Unknown
        );

        // ...and the device is now tracked with that status.
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Unknown
        );
    }

    #[tokio::test]
    async fn test_event_received_marks_accessible() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;

        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Accessible
        );
    }

    #[tokio::test]
    async fn test_timeout_marks_blocked() {
        // A very short window so the once-per-second monitor trips on its next tick.
        let coordinator = FirewallDetectionCoordinator::new(FirewallDetectionConfig {
            event_wait_timeout: Duration::from_millis(100),
            ..Default::default()
        });
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;

        // Long enough for the timeout plus one monitor tick.
        tokio::time::sleep(Duration::from_millis(1200)).await;

        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Blocked
        );
    }

    #[tokio::test]
    async fn test_cached_status_reused() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;

        // A repeat subscription is answered from the cache.
        assert_eq!(
            coordinator.on_first_subscription(device).await,
            FirewallStatus::Accessible
        );
    }

    #[tokio::test]
    async fn test_clear_device_cache() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Accessible
        );

        coordinator.clear_device_cache(device).await;

        // Forgotten devices report Unknown again.
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Unknown
        );
    }

    #[tokio::test]
    async fn test_stats() {
        let coordinator = default_coordinator();
        let accessible = lan_ip(100);
        let pending = lan_ip(101);

        coordinator.on_first_subscription(accessible).await;
        coordinator.on_event_received(accessible).await;
        coordinator.on_first_subscription(pending).await;

        let stats = coordinator.get_stats().await;
        assert_eq!(stats.total_devices, 2);
        assert_eq!(stats.accessible_devices, 1);
        assert_eq!(stats.unknown_devices, 1);
        assert_eq!(stats.blocked_devices, 0);
        assert_eq!(stats.error_devices, 0);
    }
}