sonos_stream/subscription/
event_detector.rs1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10use tokio::sync::{mpsc, RwLock};
11
12use callback_server::{FirewallDetectionCoordinator, FirewallStatus};
13use tracing::debug;
14
15use crate::broker::PollingReason;
16use crate::registry::{RegistrationId, SpeakerServicePair};
17
18struct MonitoredRegistration {
20 last_event_time: Instant,
21 pair: SpeakerServicePair,
22 polling_activated: bool,
23}
24
25pub struct EventDetector {
27 registrations: Arc<RwLock<HashMap<RegistrationId, MonitoredRegistration>>>,
29
30 event_timeout: Duration,
32
33 polling_activation_delay: Duration,
35
36 firewall_coordinator: Option<Arc<FirewallDetectionCoordinator>>,
38
39 polling_request_sender: Option<mpsc::UnboundedSender<PollingRequest>>,
41}
42
43#[derive(Debug, Clone)]
45pub struct PollingRequest {
46 pub registration_id: RegistrationId,
47 pub speaker_service_pair: SpeakerServicePair,
48 pub action: PollingAction,
49 pub reason: PollingReason,
50}
51
/// What a [`PollingRequest`] asks its consumer to do with polling for a registration.
///
/// Fieldless, so it is `Copy` and can be compared directly with `==`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PollingAction {
    /// Begin polling the registration.
    Start,
    /// Stop polling the registration.
    Stop,
}
57
58impl EventDetector {
59 pub fn new(event_timeout: Duration, polling_activation_delay: Duration) -> Self {
61 Self {
62 registrations: Arc::new(RwLock::new(HashMap::new())),
63 event_timeout,
64 polling_activation_delay,
65 firewall_coordinator: None,
66 polling_request_sender: None,
67 }
68 }
69
70 pub fn set_firewall_coordinator(&mut self, coordinator: Arc<FirewallDetectionCoordinator>) {
72 self.firewall_coordinator = Some(coordinator);
73 }
74
75 pub fn set_polling_request_sender(&mut self, sender: mpsc::UnboundedSender<PollingRequest>) {
77 self.polling_request_sender = Some(sender);
78 }
79
80 pub async fn record_event(&self, registration_id: RegistrationId) {
82 let mut registrations = self.registrations.write().await;
83 if let Some(reg) = registrations.get_mut(®istration_id) {
84 reg.last_event_time = Instant::now();
85 }
86 }
87
88 pub async fn should_start_polling(&self, registration_id: RegistrationId) -> bool {
90 let registrations = self.registrations.read().await;
91 registrations
92 .get(®istration_id)
93 .map(|reg| reg.last_event_time.elapsed() > self.event_timeout)
94 .unwrap_or(false)
95 }
96
97 pub async fn should_stop_polling(&self, registration_id: RegistrationId) -> bool {
99 let registrations = self.registrations.read().await;
100 registrations
101 .get(®istration_id)
102 .map(|reg| reg.last_event_time.elapsed() <= self.polling_activation_delay)
103 .unwrap_or(false)
104 }
105
106 pub async fn evaluate_firewall_status(
108 &self,
109 registration_id: RegistrationId,
110 pair: &SpeakerServicePair,
111 ) -> Option<PollingRequest> {
112 if let Some(firewall_coordinator) = &self.firewall_coordinator {
113 let status = firewall_coordinator
114 .get_device_status(pair.speaker_ip)
115 .await;
116
117 match status {
118 FirewallStatus::Blocked => {
119 Some(PollingRequest {
121 registration_id,
122 speaker_service_pair: pair.clone(),
123 action: PollingAction::Start,
124 reason: PollingReason::FirewallBlocked,
125 })
126 }
127 FirewallStatus::Accessible => {
128 None
130 }
131 FirewallStatus::Unknown => {
132 None
134 }
135 FirewallStatus::Error => {
136 Some(PollingRequest {
138 registration_id,
139 speaker_service_pair: pair.clone(),
140 action: PollingAction::Start,
141 reason: PollingReason::NetworkIssues,
142 })
143 }
144 }
145 } else {
146 None
148 }
149 }
150
151 pub async fn start_monitoring(&self) -> tokio::task::JoinHandle<()> {
154 let registrations = Arc::clone(&self.registrations);
155 let event_timeout = self.event_timeout;
156 let polling_request_sender = self.polling_request_sender.clone();
157
158 let check_interval = (event_timeout / 3).max(Duration::from_secs(1));
159
160 tokio::spawn(async move {
161 let mut interval = tokio::time::interval(check_interval);
162
163 loop {
164 interval.tick().await;
165
166 let now = Instant::now();
167
168 let timed_out: Vec<(RegistrationId, SpeakerServicePair)> = {
170 let regs = registrations.read().await;
171 regs.iter()
172 .filter(|(_, reg)| {
173 !reg.polling_activated
174 && now.duration_since(reg.last_event_time) > event_timeout
175 })
176 .map(|(id, reg)| (*id, reg.pair.clone()))
177 .collect()
178 };
179
180 for (registration_id, pair) in timed_out {
181 if let Some(sender) = &polling_request_sender {
182 let request = PollingRequest {
183 registration_id,
184 speaker_service_pair: pair,
185 action: PollingAction::Start,
186 reason: PollingReason::EventTimeout,
187 };
188
189 if sender.send(request).is_ok() {
190 let mut regs = registrations.write().await;
192 if let Some(reg) = regs.get_mut(®istration_id) {
193 reg.polling_activated = true;
194 }
195
196 debug!(
197 registration_id = %registration_id,
198 "Event timeout detected, sent polling request"
199 );
200 }
201 }
202 }
203 }
204 })
205 }
206
207 pub async fn register_subscription(
209 &self,
210 registration_id: RegistrationId,
211 pair: SpeakerServicePair,
212 ) {
213 let mut registrations = self.registrations.write().await;
214 registrations.insert(
215 registration_id,
216 MonitoredRegistration {
217 last_event_time: Instant::now(),
218 pair,
219 polling_activated: false,
220 },
221 );
222 }
223
224 pub async fn unregister_subscription(&self, registration_id: RegistrationId) {
226 let mut registrations = self.registrations.write().await;
227 registrations.remove(®istration_id);
228 }
229
230 pub async fn stats(&self) -> EventDetectorStats {
232 let registrations = self.registrations.read().await;
233 let total_monitored = registrations.len();
234
235 let now = Instant::now();
236 let mut timeout_count = 0;
237 let mut recent_events_count = 0;
238
239 for reg in registrations.values() {
240 let elapsed = now.duration_since(reg.last_event_time);
241 if elapsed > self.event_timeout {
242 timeout_count += 1;
243 } else if elapsed <= Duration::from_secs(60) {
244 recent_events_count += 1;
245 }
246 }
247
248 let firewall_status = FirewallStatus::Unknown;
250
251 EventDetectorStats {
252 total_monitored,
253 timeout_count,
254 recent_events_count,
255 firewall_status,
256 event_timeout: self.event_timeout,
257 }
258 }
259}
260
261#[derive(Debug)]
263pub struct EventDetectorStats {
264 pub total_monitored: usize,
265 pub timeout_count: usize,
266 pub recent_events_count: usize,
267 pub firewall_status: FirewallStatus,
268 pub event_timeout: Duration,
269}
270
271impl std::fmt::Display for EventDetectorStats {
272 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
273 writeln!(f, "Event Detector Stats:")?;
274 writeln!(f, " Total monitored: {}", self.total_monitored)?;
275 writeln!(f, " Timeout count: {}", self.timeout_count)?;
276 writeln!(f, " Recent events: {}", self.recent_events_count)?;
277 writeln!(f, " Firewall status: {:?}", self.firewall_status)?;
278 writeln!(f, " Event timeout: {:?}", self.event_timeout)?;
279 Ok(())
280 }
281}
282
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the speaker/service pair used throughout these tests.
    fn test_pair(service: sonos_api::Service) -> SpeakerServicePair {
        SpeakerServicePair::new("192.168.1.100".parse().unwrap(), service)
    }

    /// Moves a registration's last event time `by` into the past so it looks stale.
    async fn backdate(detector: &EventDetector, registration_id: RegistrationId, by: Duration) {
        let mut regs = detector.registrations.write().await;
        if let Some(reg) = regs.get_mut(&registration_id) {
            reg.last_event_time = Instant::now() - by;
        }
    }

    #[tokio::test]
    async fn test_event_detector_creation() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));

        assert_eq!(detector.event_timeout, Duration::from_secs(30));
        assert_eq!(detector.polling_activation_delay, Duration::from_secs(5));
    }

    #[tokio::test]
    async fn test_event_recording() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let registration_id = RegistrationId::new(1);

        // Unknown registrations never ask for polling.
        assert!(!detector.should_start_polling(registration_id).await);

        detector
            .register_subscription(registration_id, test_pair(sonos_api::Service::AVTransport))
            .await;
        assert!(!detector.should_start_polling(registration_id).await);

        // Silent for longer than the timeout: polling should be requested...
        backdate(&detector, registration_id, Duration::from_secs(60)).await;
        assert!(detector.should_start_polling(registration_id).await);

        // ...until a fresh event is recorded.
        detector.record_event(registration_id).await;
        assert!(!detector.should_start_polling(registration_id).await);
    }

    #[tokio::test]
    async fn test_should_stop_polling() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let registration_id = RegistrationId::new(7);

        assert!(!detector.should_stop_polling(registration_id).await);

        detector
            .register_subscription(registration_id, test_pair(sonos_api::Service::AVTransport))
            .await;
        // Just registered: the last event time is inside the activation delay.
        assert!(detector.should_stop_polling(registration_id).await);

        backdate(&detector, registration_id, Duration::from_secs(10)).await;
        assert!(!detector.should_stop_polling(registration_id).await);
    }

    #[tokio::test]
    async fn test_evaluate_firewall_status_without_coordinator() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let pair = test_pair(sonos_api::Service::AVTransport);

        assert!(detector
            .evaluate_firewall_status(RegistrationId::new(1), &pair)
            .await
            .is_none());
    }

    #[tokio::test]
    async fn test_stats_counts() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let fresh = RegistrationId::new(1);
        let stale = RegistrationId::new(2);

        detector
            .register_subscription(fresh, test_pair(sonos_api::Service::AVTransport))
            .await;
        detector
            .register_subscription(stale, test_pair(sonos_api::Service::RenderingControl))
            .await;
        backdate(&detector, stale, Duration::from_secs(120)).await;

        let stats = detector.stats().await;
        assert_eq!(stats.total_monitored, 2);
        assert_eq!(stats.timeout_count, 1);
        assert_eq!(stats.recent_events_count, 1);
        assert_eq!(stats.event_timeout, Duration::from_secs(30));
    }

    #[tokio::test]
    async fn test_subscription_registration() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let registration_id = RegistrationId::new(1);

        detector
            .register_subscription(registration_id, test_pair(sonos_api::Service::AVTransport))
            .await;

        let stats = detector.stats().await;
        assert_eq!(stats.total_monitored, 1);

        detector.unregister_subscription(registration_id).await;

        let stats = detector.stats().await;
        assert_eq!(stats.total_monitored, 0);
    }

    #[tokio::test]
    async fn test_register_and_unregister() {
        let detector = EventDetector::new(Duration::from_secs(30), Duration::from_secs(5));
        let registration_id = RegistrationId::new(1);
        let pair = test_pair(sonos_api::Service::AVTransport);

        detector
            .register_subscription(registration_id, pair.clone())
            .await;

        {
            let regs = detector.registrations.read().await;
            assert!(regs.contains_key(&registration_id));
            assert_eq!(regs[&registration_id].pair.speaker_ip, pair.speaker_ip);
        }

        detector.unregister_subscription(registration_id).await;

        let regs = detector.registrations.read().await;
        assert!(!regs.contains_key(&registration_id));
    }

    #[tokio::test]
    async fn test_event_timeout_sends_polling_request() {
        let mut detector =
            EventDetector::new(Duration::from_millis(50), Duration::from_secs(5));

        let (sender, mut receiver) = mpsc::unbounded_channel();
        detector.set_polling_request_sender(sender);

        let registration_id = RegistrationId::new(42);
        let pair = test_pair(sonos_api::Service::RenderingControl);

        detector
            .register_subscription(registration_id, pair.clone())
            .await;
        backdate(&detector, registration_id, Duration::from_secs(60)).await;

        let handle = detector.start_monitoring().await;

        let request = tokio::time::timeout(Duration::from_secs(2), receiver.recv()).await;

        assert!(
            request.is_ok(),
            "Should receive a polling request within timeout"
        );
        let request = request.unwrap().expect("Channel should have a message");
        assert_eq!(request.registration_id, registration_id);
        assert_eq!(request.speaker_service_pair.speaker_ip, pair.speaker_ip);
        assert!(matches!(request.action, PollingAction::Start));
        assert_eq!(request.reason, PollingReason::EventTimeout);

        // Don't leave the monitoring task running past the test.
        handle.abort();
    }
}