1use std::collections::HashMap;
7use std::sync::atomic::{AtomicBool, Ordering};
8use std::sync::Arc;
9use std::time::{Duration, SystemTime};
10use tokio::sync::{mpsc, RwLock};
11use tokio::task::JoinHandle;
12use tracing::{debug, error, info, warn};
13
14use crate::error::{PollingError, PollingResult};
15use crate::events::types::{EnrichedEvent, EventSource};
16use crate::polling::strategies::DeviceStatePoller;
17use crate::registry::{RegistrationId, SpeakerServicePair};
18
19#[derive(Debug)]
21pub struct PollingTask {
22 registration_id: RegistrationId,
24
25 speaker_service_pair: SpeakerServicePair,
27
28 current_interval: Duration,
30
31 task_handle: JoinHandle<()>,
33
34 shutdown_signal: Arc<AtomicBool>,
36
37 started_at: SystemTime,
39
40 error_count: Arc<RwLock<u32>>,
42
43 poll_count: Arc<RwLock<u64>>,
45}
46
47impl PollingTask {
48 pub fn start(
50 registration_id: RegistrationId,
51 speaker_service_pair: SpeakerServicePair,
52 initial_interval: Duration,
53 max_interval: Duration,
54 adaptive_polling: bool,
55 device_poller: Arc<DeviceStatePoller>,
56 event_sender: mpsc::UnboundedSender<EnrichedEvent>,
57 ) -> Self {
58 let shutdown_signal = Arc::new(AtomicBool::new(false));
59 let error_count = Arc::new(RwLock::new(0));
60 let poll_count = Arc::new(RwLock::new(0));
61
62 let task_registration_id = registration_id;
64 let task_pair = speaker_service_pair.clone();
65 let task_shutdown_signal = Arc::clone(&shutdown_signal);
66 let task_error_count = Arc::clone(&error_count);
67 let task_poll_count = Arc::clone(&poll_count);
68
69 let task_handle = tokio::spawn(async move {
70 Self::polling_loop(
71 task_registration_id,
72 task_pair,
73 initial_interval,
74 max_interval,
75 adaptive_polling,
76 device_poller,
77 event_sender,
78 task_shutdown_signal,
79 task_error_count,
80 task_poll_count,
81 )
82 .await;
83 });
84
85 Self {
86 registration_id,
87 speaker_service_pair,
88 current_interval: initial_interval,
89 task_handle,
90 shutdown_signal,
91 started_at: SystemTime::now(),
92 error_count,
93 poll_count,
94 }
95 }
96
97 #[allow(clippy::too_many_arguments)]
99 async fn polling_loop(
100 registration_id: RegistrationId,
101 pair: SpeakerServicePair,
102 mut current_interval: Duration,
103 max_interval: Duration,
104 adaptive_polling: bool,
105 device_poller: Arc<DeviceStatePoller>,
106 event_sender: mpsc::UnboundedSender<EnrichedEvent>,
107 shutdown_signal: Arc<AtomicBool>,
108 error_count: Arc<RwLock<u32>>,
109 poll_count: Arc<RwLock<u64>>,
110 ) {
111 info!(
112 speaker_ip = %pair.speaker_ip,
113 service = ?pair.service,
114 ?current_interval,
115 "Starting polling task"
116 );
117
118 let mut last_state: Option<String> = None;
120
121 loop {
122 if shutdown_signal.load(Ordering::Relaxed) {
124 info!(
125 speaker_ip = %pair.speaker_ip,
126 service = ?pair.service,
127 "Polling task shutting down"
128 );
129 break;
130 }
131
132 tokio::time::sleep(current_interval).await;
134
135 {
137 let mut count = poll_count.write().await;
138 *count += 1;
139 }
140
141 match device_poller.poll_device_state(&pair).await {
143 Ok(current_state) => {
144 {
146 let mut errors = error_count.write().await;
147 *errors = 0;
148 }
149
150 let state_changed = last_state.as_deref() != Some(current_state.as_str());
152
153 if state_changed {
154 last_state = Some(current_state.clone());
155 }
156
157 if state_changed {
158 debug!(
159 speaker_ip = %pair.speaker_ip,
160 service = ?pair.service,
161 "State change detected"
162 );
163
164 match device_poller.state_to_event_data(&pair.service, ¤t_state) {
166 Ok(event_data) => {
167 let enriched_event = EnrichedEvent::new(
168 registration_id,
169 pair.speaker_ip,
170 pair.service,
171 EventSource::PollingDetection {
172 poll_interval: current_interval,
173 },
174 event_data,
175 );
176
177 if event_sender.send(enriched_event).is_err() {
178 error!(
179 speaker_ip = %pair.speaker_ip,
180 service = ?pair.service,
181 "Failed to send polling event — channel closed"
182 );
183 return;
184 }
185 }
186 Err(e) => {
187 warn!(
188 speaker_ip = %pair.speaker_ip,
189 service = ?pair.service,
190 error = %e,
191 "Failed to convert state to event data"
192 );
193 }
194 }
195
196 if adaptive_polling {
198 current_interval = Self::calculate_adaptive_interval(
199 current_interval,
200 max_interval,
201 SystemTime::now(),
202 );
203 }
204 }
205 }
206 Err(e) => {
207 let error_count_value = {
209 let mut errors = error_count.write().await;
210 *errors += 1;
211 *errors
212 };
213
214 warn!(
215 speaker_ip = %pair.speaker_ip,
216 service = ?pair.service,
217 attempt = error_count_value,
218 error = %e,
219 "Polling error"
220 );
221
222 if error_count_value >= 5 {
224 error!(
225 speaker_ip = %pair.speaker_ip,
226 service = ?pair.service,
227 "Too many consecutive errors, stopping polling"
228 );
229 break;
230 }
231
232 let backoff_interval = current_interval * (2_u32.pow(error_count_value.min(6)));
234 let capped_interval = backoff_interval.min(max_interval);
235 tokio::time::sleep(capped_interval).await;
236 }
237 }
238 }
239
240 info!(
241 speaker_ip = %pair.speaker_ip,
242 service = ?pair.service,
243 "Polling task ended"
244 );
245 }
246
/// Picks the next polling interval from how long ago the state last changed.
///
/// * Change less than 30 seconds ago: halve the interval, but not below
///   2 seconds and never above the interval that was passed in.
/// * Change more than 300 seconds ago: double the interval, capped at
///   `max_interval`.
/// * Otherwise: keep `current_interval`.
///
/// A `last_change_time` that lies in the future is treated as "just changed".
fn calculate_adaptive_interval(
    current_interval: Duration,
    max_interval: Duration,
    last_change_time: SystemTime,
) -> Duration {
    const MIN_INTERVAL: Duration = Duration::from_secs(2);
    const RECENT_CHANGE: Duration = Duration::from_secs(30);
    const STALE_CHANGE: Duration = Duration::from_secs(300);

    let time_since_change = SystemTime::now()
        .duration_since(last_change_time)
        .unwrap_or(Duration::ZERO);

    if time_since_change < RECENT_CHANGE {
        // The floor must not turn "poll faster" into "poll slower" when the
        // caller already polls more often than the floor.
        (current_interval / 2)
            .max(MIN_INTERVAL)
            .min(current_interval)
    } else if time_since_change > STALE_CHANGE {
        // Saturating: plain `* 2` panics when the product overflows.
        current_interval.saturating_mul(2).min(max_interval)
    } else {
        current_interval
    }
}
267
268 pub fn registration_id(&self) -> RegistrationId {
270 self.registration_id
271 }
272
273 pub fn speaker_service_pair(&self) -> &SpeakerServicePair {
275 &self.speaker_service_pair
276 }
277
278 pub fn current_interval(&self) -> Duration {
280 self.current_interval
281 }
282
283 pub fn is_running(&self) -> bool {
285 !self.task_handle.is_finished()
286 }
287
288 pub async fn stats(&self) -> PollingTaskStats {
290 let error_count = *self.error_count.read().await;
291 let poll_count = *self.poll_count.read().await;
292
293 PollingTaskStats {
294 registration_id: self.registration_id,
295 speaker_service_pair: self.speaker_service_pair.clone(),
296 current_interval: self.current_interval,
297 started_at: self.started_at,
298 error_count,
299 poll_count,
300 is_running: self.is_running(),
301 }
302 }
303
304 pub async fn shutdown(self) -> PollingResult<()> {
306 self.shutdown_signal.store(true, Ordering::Relaxed);
308
309 match self.task_handle.await {
311 Ok(()) => Ok(()),
312 Err(e) => Err(PollingError::TaskSpawn(format!(
313 "Failed to await task completion: {e}"
314 ))),
315 }
316 }
317}
318
319#[derive(Debug, Clone)]
321pub struct PollingTaskStats {
322 pub registration_id: RegistrationId,
323 pub speaker_service_pair: SpeakerServicePair,
324 pub current_interval: Duration,
325 pub started_at: SystemTime,
326 pub error_count: u32,
327 pub poll_count: u64,
328 pub is_running: bool,
329}
330
331pub struct PollingScheduler {
333 active_tasks: Arc<RwLock<HashMap<RegistrationId, PollingTask>>>,
335
336 device_poller: Arc<DeviceStatePoller>,
338
339 event_sender: mpsc::UnboundedSender<EnrichedEvent>,
341
342 base_interval: Duration,
344
345 max_interval: Duration,
347
348 adaptive_polling: bool,
350
351 max_concurrent_tasks: usize,
353}
354
355impl PollingScheduler {
356 pub fn new(
358 event_sender: mpsc::UnboundedSender<EnrichedEvent>,
359 base_interval: Duration,
360 max_interval: Duration,
361 adaptive_polling: bool,
362 max_concurrent_tasks: usize,
363 ) -> Self {
364 Self {
365 active_tasks: Arc::new(RwLock::new(HashMap::new())),
366 device_poller: Arc::new(DeviceStatePoller::new()),
367 event_sender,
368 base_interval,
369 max_interval,
370 adaptive_polling,
371 max_concurrent_tasks,
372 }
373 }
374
375 pub async fn start_polling(
377 &self,
378 registration_id: RegistrationId,
379 pair: SpeakerServicePair,
380 ) -> PollingResult<()> {
381 let mut tasks = self.active_tasks.write().await;
382
383 if tasks.contains_key(®istration_id) {
385 return Ok(()); }
387
388 if tasks.len() >= self.max_concurrent_tasks {
390 return Err(PollingError::TooManyErrors {
391 error_count: tasks.len() as u32,
392 });
393 }
394
395 let task = PollingTask::start(
397 registration_id,
398 pair.clone(),
399 self.base_interval,
400 self.max_interval,
401 self.adaptive_polling,
402 Arc::clone(&self.device_poller),
403 self.event_sender.clone(),
404 );
405
406 tasks.insert(registration_id, task);
407
408 info!(
409 speaker_ip = %pair.speaker_ip,
410 service = ?pair.service,
411 "Started polling"
412 );
413
414 Ok(())
415 }
416
417 pub async fn stop_polling(&self, registration_id: RegistrationId) -> PollingResult<()> {
419 let mut tasks = self.active_tasks.write().await;
420
421 if let Some(task) = tasks.remove(®istration_id) {
422 let pair = task.speaker_service_pair().clone();
423 task.shutdown().await?;
425
426 info!(
427 speaker_ip = %pair.speaker_ip,
428 service = ?pair.service,
429 "Stopped polling"
430 );
431 }
432
433 Ok(())
434 }
435
436 pub async fn is_polling(&self, registration_id: RegistrationId) -> bool {
438 let tasks = self.active_tasks.read().await;
439 tasks.contains_key(®istration_id)
440 }
441
442 pub async fn stats(&self) -> PollingSchedulerStats {
444 let tasks = self.active_tasks.read().await;
445 let total_tasks = tasks.len();
446
447 let mut task_stats = Vec::new();
448 for task in tasks.values() {
449 task_stats.push(task.stats().await);
450 }
451
452 PollingSchedulerStats {
453 total_active_tasks: total_tasks,
454 max_concurrent_tasks: self.max_concurrent_tasks,
455 base_interval: self.base_interval,
456 max_interval: self.max_interval,
457 adaptive_polling: self.adaptive_polling,
458 task_stats,
459 }
460 }
461
462 pub async fn shutdown_all(&self) -> PollingResult<()> {
464 let mut tasks = self.active_tasks.write().await;
465
466 for (registration_id, task) in tasks.drain() {
467 match task.shutdown().await {
468 Ok(()) => {
469 debug!(%registration_id, "Shutdown polling task");
470 }
471 Err(e) => {
472 error!(%registration_id, error = %e, "Failed to shutdown polling task");
473 }
474 }
475 }
476
477 Ok(())
478 }
479}
480
481#[derive(Debug)]
483pub struct PollingSchedulerStats {
484 pub total_active_tasks: usize,
485 pub max_concurrent_tasks: usize,
486 pub base_interval: Duration,
487 pub max_interval: Duration,
488 pub adaptive_polling: bool,
489 pub task_stats: Vec<PollingTaskStats>,
490}
491
492impl std::fmt::Display for PollingSchedulerStats {
493 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494 writeln!(f, "Polling Scheduler Stats:")?;
495 writeln!(
496 f,
497 " Active tasks: {}/{}",
498 self.total_active_tasks, self.max_concurrent_tasks
499 )?;
500 writeln!(f, " Base interval: {:?}", self.base_interval)?;
501 writeln!(f, " Max interval: {:?}", self.max_interval)?;
502 writeln!(f, " Adaptive polling: {}", self.adaptive_polling)?;
503
504 if !self.task_stats.is_empty() {
505 writeln!(f, " Task details:")?;
506 for stat in &self.task_stats {
507 writeln!(
508 f,
509 " {}: {} {:?} (interval: {:?}, polls: {}, errors: {})",
510 stat.registration_id,
511 stat.speaker_service_pair.speaker_ip,
512 stat.speaker_service_pair.service,
513 stat.current_interval,
514 stat.poll_count,
515 stat.error_count
516 )?;
517 }
518 }
519
520 Ok(())
521 }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// A new scheduler reports its configuration and no tasks.
    #[tokio::test]
    async fn test_polling_scheduler_creation() {
        let (tx, _rx) = mpsc::unbounded_channel();
        let base = Duration::from_secs(5);
        let max = Duration::from_secs(30);
        let scheduler = PollingScheduler::new(tx, base, max, true, 10);

        let snapshot = scheduler.stats().await;
        assert_eq!(snapshot.total_active_tasks, 0);
        assert_eq!(snapshot.max_concurrent_tasks, 10);
        assert!(snapshot.adaptive_polling);
    }

    /// Starting registers the task; stopping removes it again.
    #[tokio::test]
    async fn test_polling_task_lifecycle() {
        let (tx, _rx) = mpsc::unbounded_channel();
        // Short intervals keep the stop call from waiting long.
        let scheduler = PollingScheduler::new(
            tx,
            Duration::from_millis(100),
            Duration::from_secs(1),
            false,
            5,
        );

        let id = RegistrationId::new(1);
        let pair = SpeakerServicePair::new(
            "192.168.1.100".parse().unwrap(),
            sonos_api::Service::AVTransport,
        );

        scheduler.start_polling(id, pair.clone()).await.unwrap();
        assert!(scheduler.is_polling(id).await);

        scheduler.stop_polling(id).await.unwrap();
        assert!(!scheduler.is_polling(id).await);
    }

    /// A recent change must not lengthen the interval; a stale one must not
    /// shorten it.
    #[test]
    fn test_adaptive_interval_calculation() {
        let current = Duration::from_secs(5);
        let max = Duration::from_secs(30);

        let changed_recently = SystemTime::now() - Duration::from_secs(10);
        let after_recent =
            PollingTask::calculate_adaptive_interval(current, max, changed_recently);
        assert!(after_recent <= current);

        let changed_long_ago = SystemTime::now() - Duration::from_secs(400);
        let after_stale =
            PollingTask::calculate_adaptive_interval(current, max, changed_long_ago);
        assert!(after_stale >= current);
    }
}