1use std::collections::HashMap;
8use std::net::IpAddr;
9use std::sync::Arc;
10use std::time::{Duration, SystemTime};
11use tokio::sync::{mpsc, RwLock};
12use tracing::{debug, info, warn};
13
/// Verdict on whether event notifications from a device reach this host.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum FirewallStatus {
    /// No verdict yet: detection is still pending, or the device is not tracked.
    #[default]
    Unknown,
    /// An event was received from the device.
    Accessible,
    /// No event arrived from the device within the configured wait timeout.
    Blocked,
    /// Detection ended in an error (NOTE(review): nothing in the coordinator assigns this yet).
    Error,
}
27
/// Tuning knobs for `FirewallDetectionCoordinator`.
#[derive(Clone, Debug)]
pub struct FirewallDetectionConfig {
    /// How long to wait, after detection starts, for an event before declaring the device blocked.
    pub event_wait_timeout: Duration,
    /// When `true`, a finished detection result is returned on later subscriptions instead of
    /// starting over.
    pub enable_caching: bool,
    /// Upper bound on how many devices are tracked at once.
    pub max_cached_devices: usize,
}
38
39impl Default for FirewallDetectionConfig {
40 fn default() -> Self {
41 Self {
42 event_wait_timeout: Duration::from_secs(15),
43 enable_caching: true,
44 max_cached_devices: 100,
45 }
46 }
47}
48
49#[derive(Debug, Clone)]
51pub struct DeviceFirewallState {
52 pub device_ip: IpAddr,
53 pub status: FirewallStatus,
54 pub first_subscription_time: SystemTime,
55 pub first_event_time: Option<SystemTime>,
56 pub detection_completed: bool,
57 pub timeout_duration: Duration,
58}
59
60#[derive(Debug, Clone)]
62pub struct DetectionResult {
63 pub device_ip: IpAddr,
64 pub status: FirewallStatus,
65 pub reason: DetectionReason,
66}
67
/// Why a detection run concluded.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum DetectionReason {
    /// An event from the device arrived.
    EventReceived,
    /// The wait timeout elapsed without any event.
    Timeout,
    /// The subscription itself failed (NOTE(review): not emitted by the coordinator yet).
    SubscriptionFailed,
}
78
79pub struct FirewallDetectionCoordinator {
85 device_states: Arc<RwLock<HashMap<IpAddr, Arc<RwLock<DeviceFirewallState>>>>>,
87
88 config: FirewallDetectionConfig,
90
91 detection_complete_tx: mpsc::UnboundedSender<DetectionResult>,
93
94 _timeout_task_handle: tokio::task::JoinHandle<()>,
96}
97
98impl FirewallDetectionCoordinator {
99 pub fn new(config: FirewallDetectionConfig) -> Self {
101 let (detection_complete_tx, mut detection_complete_rx) = mpsc::unbounded_channel();
102
103 let device_states = Arc::new(RwLock::new(HashMap::new()));
104
105 let timeout_task_handle = {
107 let device_states = device_states.clone();
108 let detection_complete_tx = detection_complete_tx.clone();
109 tokio::spawn(async move {
110 Self::monitor_timeouts(device_states, detection_complete_tx).await;
111 })
112 };
113
114 tokio::spawn(async move {
116 while let Some(result) = detection_complete_rx.recv().await {
117 match result.reason {
118 DetectionReason::EventReceived => {
119 info!(
120 device_ip = %result.device_ip,
121 reason = ?result.reason,
122 status = ?result.status,
123 "Firewall detection: Events accessible from device"
124 );
125 }
126 DetectionReason::Timeout => {
127 warn!(
128 device_ip = %result.device_ip,
129 reason = ?result.reason,
130 status = ?result.status,
131 "Firewall detection: No events received within timeout"
132 );
133 }
134 DetectionReason::SubscriptionFailed => {
135 warn!(
136 device_ip = %result.device_ip,
137 reason = ?result.reason,
138 status = ?result.status,
139 "Firewall detection: Subscription failed for device"
140 );
141 }
142 }
143 }
144 });
145
146 Self {
147 device_states,
148 config,
149 detection_complete_tx,
150 _timeout_task_handle: timeout_task_handle,
151 }
152 }
153
154 pub async fn on_first_subscription(&self, device_ip: IpAddr) -> FirewallStatus {
159 if !self.config.enable_caching {
160 self.start_detection_for_device(device_ip).await;
162 return FirewallStatus::Unknown;
163 }
164
165 let device_states = self.device_states.read().await;
166
167 if let Some(state_arc) = device_states.get(&device_ip) {
169 let state = state_arc.read().await;
170 if state.detection_completed {
171 debug!(
172 device_ip = %device_ip,
173 status = ?state.status,
174 "Firewall detection: Using cached status for device"
175 );
176 return state.status;
177 }
178 }
179
180 drop(device_states); self.start_detection_for_device(device_ip).await;
184
185 debug!(
186 device_ip = %device_ip,
187 timeout = ?self.config.event_wait_timeout,
188 "Firewall detection: Started monitoring device for events"
189 );
190
191 FirewallStatus::Unknown
192 }
193
194 pub async fn on_event_received(&self, device_ip: IpAddr) {
198 let device_states = self.device_states.read().await;
199
200 if let Some(state_arc) = device_states.get(&device_ip) {
201 let mut state = state_arc.write().await;
202
203 if !state.detection_completed {
204 state.first_event_time = Some(SystemTime::now());
206 state.status = FirewallStatus::Accessible;
207 state.detection_completed = true;
208
209 let elapsed = SystemTime::now()
210 .duration_since(state.first_subscription_time)
211 .unwrap_or(Duration::ZERO);
212
213 let _ = self.detection_complete_tx.send(DetectionResult {
215 device_ip,
216 status: FirewallStatus::Accessible,
217 reason: DetectionReason::EventReceived,
218 });
219
220 info!(
221 device_ip = %device_ip,
222 elapsed = ?elapsed,
223 status = ?FirewallStatus::Accessible,
224 "Firewall detection: Event received from device, marking as accessible"
225 );
226 }
227 }
228 }
229
230 pub async fn get_device_status(&self, device_ip: IpAddr) -> FirewallStatus {
232 let device_states = self.device_states.read().await;
233
234 if let Some(state_arc) = device_states.get(&device_ip) {
235 let state = state_arc.read().await;
236 state.status
237 } else {
238 FirewallStatus::Unknown
239 }
240 }
241
242 pub async fn clear_device_cache(&self, device_ip: IpAddr) {
244 let mut device_states = self.device_states.write().await;
245 device_states.remove(&device_ip);
246 debug!(
247 device_ip = %device_ip,
248 "Firewall detection: Cleared cache for device"
249 );
250 }
251
252 async fn start_detection_for_device(&self, device_ip: IpAddr) {
254 let mut device_states = self.device_states.write().await;
255
256 let new_state = Arc::new(RwLock::new(DeviceFirewallState {
258 device_ip,
259 status: FirewallStatus::Unknown,
260 first_subscription_time: SystemTime::now(),
261 first_event_time: None,
262 detection_completed: false,
263 timeout_duration: self.config.event_wait_timeout,
264 }));
265
266 if device_states.len() >= self.config.max_cached_devices {
268 if let Some(oldest_ip) = device_states.keys().next().copied() {
270 device_states.remove(&oldest_ip);
271 debug!(
272 oldest_ip = %oldest_ip,
273 cache_size = self.config.max_cached_devices,
274 "Firewall detection: Removed oldest cached entry due to cache being full"
275 );
276 }
277 }
278
279 device_states.insert(device_ip, new_state);
280 }
281
282 async fn monitor_timeouts(
284 device_states: Arc<RwLock<HashMap<IpAddr, Arc<RwLock<DeviceFirewallState>>>>>,
285 detection_complete_tx: mpsc::UnboundedSender<DetectionResult>,
286 ) {
287 let mut interval = tokio::time::interval(Duration::from_secs(1));
288
289 loop {
290 interval.tick().await;
291
292 let device_states_read = device_states.read().await;
293
294 for (device_ip, state_arc) in device_states_read.iter() {
295 let mut state = state_arc.write().await;
296
297 if !state.detection_completed {
298 let elapsed = SystemTime::now()
299 .duration_since(state.first_subscription_time)
300 .unwrap_or(Duration::ZERO);
301
302 if elapsed >= state.timeout_duration {
303 state.status = FirewallStatus::Blocked;
305 state.detection_completed = true;
306
307 let _ = detection_complete_tx.send(DetectionResult {
309 device_ip: *device_ip,
310 status: FirewallStatus::Blocked,
311 reason: DetectionReason::Timeout,
312 });
313
314 warn!(
315 device_ip = %device_ip,
316 timeout = ?state.timeout_duration,
317 status = ?FirewallStatus::Blocked,
318 "Firewall detection: No events received within timeout, marking as blocked"
319 );
320 }
321 }
322 }
323 }
324 }
325
326 pub async fn get_stats(&self) -> CoordinatorStats {
328 let device_states = self.device_states.read().await;
329
330 let mut stats = CoordinatorStats {
331 total_devices: device_states.len(),
332 accessible_devices: 0,
333 blocked_devices: 0,
334 unknown_devices: 0,
335 error_devices: 0,
336 };
337
338 for state_arc in device_states.values() {
339 let state = state_arc.read().await;
340 match state.status {
341 FirewallStatus::Accessible => stats.accessible_devices += 1,
342 FirewallStatus::Blocked => stats.blocked_devices += 1,
343 FirewallStatus::Unknown => stats.unknown_devices += 1,
344 FirewallStatus::Error => stats.error_devices += 1,
345 }
346 }
347
348 stats
349 }
350}
351
/// Snapshot of how many tracked devices are in each detection state.
#[derive(Clone, Debug)]
pub struct CoordinatorStats {
    /// Number of devices currently tracked.
    pub total_devices: usize,
    /// Devices whose status is `Accessible`.
    pub accessible_devices: usize,
    /// Devices whose status is `Blocked`.
    pub blocked_devices: usize,
    /// Devices whose status is `Unknown` (detection still pending).
    pub unknown_devices: usize,
    /// Devices whose status is `Error`.
    pub error_devices: usize,
}
361
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::Ipv4Addr;

    /// Address on the 192.168.1.0/24 test subnet with the given host octet.
    fn lan_ip(host: u8) -> IpAddr {
        IpAddr::V4(Ipv4Addr::new(192, 168, 1, host))
    }

    /// Coordinator built from `FirewallDetectionConfig::default()`.
    fn default_coordinator() -> FirewallDetectionCoordinator {
        FirewallDetectionCoordinator::new(FirewallDetectionConfig::default())
    }

    #[tokio::test]
    async fn test_coordinator_creation() {
        // Construction alone must succeed inside a runtime.
        let _coordinator = default_coordinator();
    }

    #[tokio::test]
    async fn test_first_subscription_starts_monitoring() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        assert_eq!(
            coordinator.on_first_subscription(device).await,
            FirewallStatus::Unknown
        );
        // The device is tracked now, but no verdict exists yet.
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Unknown
        );
    }

    #[tokio::test]
    async fn test_event_received_marks_accessible() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;

        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Accessible
        );
    }

    #[tokio::test]
    async fn test_timeout_marks_blocked() {
        let coordinator = FirewallDetectionCoordinator::new(FirewallDetectionConfig {
            event_wait_timeout: Duration::from_millis(100),
            ..Default::default()
        });
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;
        // The monitor polls about once per second, so wait past one full poll.
        tokio::time::sleep(Duration::from_millis(1200)).await;

        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Blocked
        );
    }

    #[tokio::test]
    async fn test_cached_status_reused() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;

        // A second subscription must hand back the finished verdict.
        assert_eq!(
            coordinator.on_first_subscription(device).await,
            FirewallStatus::Accessible
        );
    }

    #[tokio::test]
    async fn test_clear_device_cache() {
        let coordinator = default_coordinator();
        let device = lan_ip(100);

        coordinator.on_first_subscription(device).await;
        coordinator.on_event_received(device).await;
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Accessible
        );

        coordinator.clear_device_cache(device).await;
        assert_eq!(
            coordinator.get_device_status(device).await,
            FirewallStatus::Unknown
        );
    }

    #[tokio::test]
    async fn test_stats() {
        let coordinator = default_coordinator();
        let reachable = lan_ip(100);
        let pending = lan_ip(101);

        coordinator.on_first_subscription(reachable).await;
        coordinator.on_event_received(reachable).await;
        coordinator.on_first_subscription(pending).await;

        let stats = coordinator.get_stats().await;
        assert_eq!(stats.total_devices, 2);
        assert_eq!(stats.accessible_devices, 1);
        assert_eq!(stats.unknown_devices, 1);
        assert_eq!(stats.blocked_devices, 0);
        assert_eq!(stats.error_devices, 0);
    }
}