Skip to main content

sonos_stream/polling/
scheduler.rs

//! Polling task scheduler and management.
//!
//! This module runs one background polling task per registration, with support
//! for adaptive polling intervals, graceful shutdown, and delivery of detected
//! state changes as events on the shared event channel.
5
use std::collections::HashMap;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::{Duration, SystemTime};
use tokio::sync::{mpsc, Notify, RwLock};
use tokio::task::JoinHandle;
use tracing::{debug, error, info, warn};

use crate::error::{PollingError, PollingResult};
use crate::events::types::{EnrichedEvent, EventSource};
use crate::polling::strategies::DeviceStatePoller;
use crate::registry::{RegistrationId, SpeakerServicePair};
18
19/// A single polling task with state management
20#[derive(Debug)]
21pub struct PollingTask {
22    /// Registration ID this task is polling for
23    registration_id: RegistrationId,
24
25    /// Speaker/service pair being polled
26    speaker_service_pair: SpeakerServicePair,
27
28    /// Current polling interval
29    current_interval: Duration,
30
31    /// Task handle for the background polling loop
32    task_handle: JoinHandle<()>,
33
34    /// Shutdown signal for graceful termination
35    shutdown_signal: Arc<AtomicBool>,
36
37    /// When this task was started
38    started_at: SystemTime,
39
40    /// Number of consecutive errors
41    error_count: Arc<RwLock<u32>>,
42
43    /// Total number of polls performed
44    poll_count: Arc<RwLock<u64>>,
45}
46
47impl PollingTask {
48    /// Create and start a new polling task
49    pub fn start(
50        registration_id: RegistrationId,
51        speaker_service_pair: SpeakerServicePair,
52        initial_interval: Duration,
53        max_interval: Duration,
54        adaptive_polling: bool,
55        device_poller: Arc<DeviceStatePoller>,
56        event_sender: mpsc::UnboundedSender<EnrichedEvent>,
57    ) -> Self {
58        let shutdown_signal = Arc::new(AtomicBool::new(false));
59        let error_count = Arc::new(RwLock::new(0));
60        let poll_count = Arc::new(RwLock::new(0));
61
62        // Clone for the task
63        let task_registration_id = registration_id;
64        let task_pair = speaker_service_pair.clone();
65        let task_shutdown_signal = Arc::clone(&shutdown_signal);
66        let task_error_count = Arc::clone(&error_count);
67        let task_poll_count = Arc::clone(&poll_count);
68
69        let task_handle = tokio::spawn(async move {
70            Self::polling_loop(
71                task_registration_id,
72                task_pair,
73                initial_interval,
74                max_interval,
75                adaptive_polling,
76                device_poller,
77                event_sender,
78                task_shutdown_signal,
79                task_error_count,
80                task_poll_count,
81            )
82            .await;
83        });
84
85        Self {
86            registration_id,
87            speaker_service_pair,
88            current_interval: initial_interval,
89            task_handle,
90            shutdown_signal,
91            started_at: SystemTime::now(),
92            error_count,
93            poll_count,
94        }
95    }
96
97    /// Main polling loop
98    #[allow(clippy::too_many_arguments)]
99    async fn polling_loop(
100        registration_id: RegistrationId,
101        pair: SpeakerServicePair,
102        mut current_interval: Duration,
103        max_interval: Duration,
104        adaptive_polling: bool,
105        device_poller: Arc<DeviceStatePoller>,
106        event_sender: mpsc::UnboundedSender<EnrichedEvent>,
107        shutdown_signal: Arc<AtomicBool>,
108        error_count: Arc<RwLock<u32>>,
109        poll_count: Arc<RwLock<u64>>,
110    ) {
111        info!(
112            speaker_ip = %pair.speaker_ip,
113            service = ?pair.service,
114            ?current_interval,
115            "Starting polling task"
116        );
117
118        // Track last state locally within the loop
119        let mut last_state: Option<String> = None;
120
121        loop {
122            // Check for shutdown signal
123            if shutdown_signal.load(Ordering::Relaxed) {
124                info!(
125                    speaker_ip = %pair.speaker_ip,
126                    service = ?pair.service,
127                    "Polling task shutting down"
128                );
129                break;
130            }
131
132            // Sleep for the current interval
133            tokio::time::sleep(current_interval).await;
134
135            // Increment poll count
136            {
137                let mut count = poll_count.write().await;
138                *count += 1;
139            }
140
141            // Poll the device state
142            match device_poller.poll_device_state(&pair).await {
143                Ok(current_state) => {
144                    // Reset error count on success
145                    {
146                        let mut errors = error_count.write().await;
147                        *errors = 0;
148                    }
149
150                    // Check for state changes (compare without cloning)
151                    let state_changed = last_state.as_deref() != Some(current_state.as_str());
152
153                    if state_changed {
154                        last_state = Some(current_state.clone());
155                    }
156
157                    if state_changed {
158                        debug!(
159                            speaker_ip = %pair.speaker_ip,
160                            service = ?pair.service,
161                            "State change detected"
162                        );
163
164                        // Convert JSON snapshot to EventData and emit full-state event
165                        match device_poller.state_to_event_data(&pair.service, &current_state) {
166                            Ok(event_data) => {
167                                let enriched_event = EnrichedEvent::new(
168                                    registration_id,
169                                    pair.speaker_ip,
170                                    pair.service,
171                                    EventSource::PollingDetection {
172                                        poll_interval: current_interval,
173                                    },
174                                    event_data,
175                                );
176
177                                if event_sender.send(enriched_event).is_err() {
178                                    error!(
179                                        speaker_ip = %pair.speaker_ip,
180                                        service = ?pair.service,
181                                        "Failed to send polling event — channel closed"
182                                    );
183                                    return;
184                                }
185                            }
186                            Err(e) => {
187                                warn!(
188                                    speaker_ip = %pair.speaker_ip,
189                                    service = ?pair.service,
190                                    error = %e,
191                                    "Failed to convert state to event data"
192                                );
193                            }
194                        }
195
196                        // Adjust interval if adaptive polling is enabled
197                        if adaptive_polling {
198                            current_interval = Self::calculate_adaptive_interval(
199                                current_interval,
200                                max_interval,
201                                SystemTime::now(),
202                            );
203                        }
204                    }
205                }
206                Err(e) => {
207                    // Increment error count
208                    let error_count_value = {
209                        let mut errors = error_count.write().await;
210                        *errors += 1;
211                        *errors
212                    };
213
214                    warn!(
215                        speaker_ip = %pair.speaker_ip,
216                        service = ?pair.service,
217                        attempt = error_count_value,
218                        error = %e,
219                        "Polling error"
220                    );
221
222                    // Use exponential backoff for errors
223                    if error_count_value >= 5 {
224                        error!(
225                            speaker_ip = %pair.speaker_ip,
226                            service = ?pair.service,
227                            "Too many consecutive errors, stopping polling"
228                        );
229                        break;
230                    }
231
232                    // Exponential backoff up to max interval
233                    let backoff_interval = current_interval * (2_u32.pow(error_count_value.min(6)));
234                    let capped_interval = backoff_interval.min(max_interval);
235                    tokio::time::sleep(capped_interval).await;
236                }
237            }
238        }
239
240        info!(
241            speaker_ip = %pair.speaker_ip,
242            service = ?pair.service,
243            "Polling task ended"
244        );
245    }
246
247    /// Calculate adaptive polling interval based on recent activity
248    fn calculate_adaptive_interval(
249        current_interval: Duration,
250        max_interval: Duration,
251        last_change_time: SystemTime,
252    ) -> Duration {
253        let time_since_change = SystemTime::now()
254            .duration_since(last_change_time)
255            .unwrap_or(Duration::ZERO);
256
257        if time_since_change < Duration::from_secs(30) {
258            // Recent activity - poll faster
259            (current_interval / 2).max(Duration::from_secs(2))
260        } else if time_since_change > Duration::from_secs(300) {
261            // No recent activity - poll slower
262            (current_interval * 2).min(max_interval)
263        } else {
264            current_interval
265        }
266    }
267
268    /// Get the registration ID for this task
269    pub fn registration_id(&self) -> RegistrationId {
270        self.registration_id
271    }
272
273    /// Get the speaker/service pair for this task
274    pub fn speaker_service_pair(&self) -> &SpeakerServicePair {
275        &self.speaker_service_pair
276    }
277
278    /// Get the current polling interval
279    pub fn current_interval(&self) -> Duration {
280        self.current_interval
281    }
282
283    /// Check if the task is still running
284    pub fn is_running(&self) -> bool {
285        !self.task_handle.is_finished()
286    }
287
288    /// Get task statistics
289    pub async fn stats(&self) -> PollingTaskStats {
290        let error_count = *self.error_count.read().await;
291        let poll_count = *self.poll_count.read().await;
292
293        PollingTaskStats {
294            registration_id: self.registration_id,
295            speaker_service_pair: self.speaker_service_pair.clone(),
296            current_interval: self.current_interval,
297            started_at: self.started_at,
298            error_count,
299            poll_count,
300            is_running: self.is_running(),
301        }
302    }
303
304    /// Request graceful shutdown of this polling task
305    pub async fn shutdown(self) -> PollingResult<()> {
306        // Signal shutdown
307        self.shutdown_signal.store(true, Ordering::Relaxed);
308
309        // Wait for task to complete
310        match self.task_handle.await {
311            Ok(()) => Ok(()),
312            Err(e) => Err(PollingError::TaskSpawn(format!(
313                "Failed to await task completion: {e}"
314            ))),
315        }
316    }
317}
318
319/// Statistics for a polling task
320#[derive(Debug, Clone)]
321pub struct PollingTaskStats {
322    pub registration_id: RegistrationId,
323    pub speaker_service_pair: SpeakerServicePair,
324    pub current_interval: Duration,
325    pub started_at: SystemTime,
326    pub error_count: u32,
327    pub poll_count: u64,
328    pub is_running: bool,
329}
330
331/// Manages multiple polling tasks
332pub struct PollingScheduler {
333    /// Active polling tasks indexed by registration ID
334    active_tasks: Arc<RwLock<HashMap<RegistrationId, PollingTask>>>,
335
336    /// Device state poller for making actual polling requests
337    device_poller: Arc<DeviceStatePoller>,
338
339    /// Event sender for emitting synthetic events
340    event_sender: mpsc::UnboundedSender<EnrichedEvent>,
341
342    /// Base polling interval
343    base_interval: Duration,
344
345    /// Maximum polling interval for adaptive polling
346    max_interval: Duration,
347
348    /// Whether to use adaptive polling intervals
349    adaptive_polling: bool,
350
351    /// Maximum number of concurrent polling tasks
352    max_concurrent_tasks: usize,
353}
354
355impl PollingScheduler {
356    /// Create a new polling scheduler
357    pub fn new(
358        event_sender: mpsc::UnboundedSender<EnrichedEvent>,
359        base_interval: Duration,
360        max_interval: Duration,
361        adaptive_polling: bool,
362        max_concurrent_tasks: usize,
363    ) -> Self {
364        Self {
365            active_tasks: Arc::new(RwLock::new(HashMap::new())),
366            device_poller: Arc::new(DeviceStatePoller::new()),
367            event_sender,
368            base_interval,
369            max_interval,
370            adaptive_polling,
371            max_concurrent_tasks,
372        }
373    }
374
375    /// Start polling for a speaker/service pair
376    pub async fn start_polling(
377        &self,
378        registration_id: RegistrationId,
379        pair: SpeakerServicePair,
380    ) -> PollingResult<()> {
381        let mut tasks = self.active_tasks.write().await;
382
383        // Check if already polling
384        if tasks.contains_key(&registration_id) {
385            return Ok(()); // Already polling
386        }
387
388        // Check concurrent task limit
389        if tasks.len() >= self.max_concurrent_tasks {
390            return Err(PollingError::TooManyErrors {
391                error_count: tasks.len() as u32,
392            });
393        }
394
395        // Start new polling task
396        let task = PollingTask::start(
397            registration_id,
398            pair.clone(),
399            self.base_interval,
400            self.max_interval,
401            self.adaptive_polling,
402            Arc::clone(&self.device_poller),
403            self.event_sender.clone(),
404        );
405
406        tasks.insert(registration_id, task);
407
408        info!(
409            speaker_ip = %pair.speaker_ip,
410            service = ?pair.service,
411            "Started polling"
412        );
413
414        Ok(())
415    }
416
417    /// Stop polling for a registration ID
418    pub async fn stop_polling(&self, registration_id: RegistrationId) -> PollingResult<()> {
419        let mut tasks = self.active_tasks.write().await;
420
421        if let Some(task) = tasks.remove(&registration_id) {
422            let pair = task.speaker_service_pair().clone();
423            // Shutdown happens when task is dropped, but we can explicitly shut it down
424            task.shutdown().await?;
425
426            info!(
427                speaker_ip = %pair.speaker_ip,
428                service = ?pair.service,
429                "Stopped polling"
430            );
431        }
432
433        Ok(())
434    }
435
436    /// Check if a registration is currently being polled
437    pub async fn is_polling(&self, registration_id: RegistrationId) -> bool {
438        let tasks = self.active_tasks.read().await;
439        tasks.contains_key(&registration_id)
440    }
441
442    /// Get statistics for all active polling tasks
443    pub async fn stats(&self) -> PollingSchedulerStats {
444        let tasks = self.active_tasks.read().await;
445        let total_tasks = tasks.len();
446
447        let mut task_stats = Vec::new();
448        for task in tasks.values() {
449            task_stats.push(task.stats().await);
450        }
451
452        PollingSchedulerStats {
453            total_active_tasks: total_tasks,
454            max_concurrent_tasks: self.max_concurrent_tasks,
455            base_interval: self.base_interval,
456            max_interval: self.max_interval,
457            adaptive_polling: self.adaptive_polling,
458            task_stats,
459        }
460    }
461
462    /// Shutdown all polling tasks
463    pub async fn shutdown_all(&self) -> PollingResult<()> {
464        let mut tasks = self.active_tasks.write().await;
465
466        for (registration_id, task) in tasks.drain() {
467            match task.shutdown().await {
468                Ok(()) => {
469                    debug!(%registration_id, "Shutdown polling task");
470                }
471                Err(e) => {
472                    error!(%registration_id, error = %e, "Failed to shutdown polling task");
473                }
474            }
475        }
476
477        Ok(())
478    }
479}
480
481/// Statistics for the polling scheduler
482#[derive(Debug)]
483pub struct PollingSchedulerStats {
484    pub total_active_tasks: usize,
485    pub max_concurrent_tasks: usize,
486    pub base_interval: Duration,
487    pub max_interval: Duration,
488    pub adaptive_polling: bool,
489    pub task_stats: Vec<PollingTaskStats>,
490}
491
492impl std::fmt::Display for PollingSchedulerStats {
493    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
494        writeln!(f, "Polling Scheduler Stats:")?;
495        writeln!(
496            f,
497            "  Active tasks: {}/{}",
498            self.total_active_tasks, self.max_concurrent_tasks
499        )?;
500        writeln!(f, "  Base interval: {:?}", self.base_interval)?;
501        writeln!(f, "  Max interval: {:?}", self.max_interval)?;
502        writeln!(f, "  Adaptive polling: {}", self.adaptive_polling)?;
503
504        if !self.task_stats.is_empty() {
505            writeln!(f, "  Task details:")?;
506            for stat in &self.task_stats {
507                writeln!(
508                    f,
509                    "    {}: {} {:?} (interval: {:?}, polls: {}, errors: {})",
510                    stat.registration_id,
511                    stat.speaker_service_pair.speaker_ip,
512                    stat.speaker_service_pair.service,
513                    stat.current_interval,
514                    stat.poll_count,
515                    stat.error_count
516                )?;
517            }
518        }
519
520        Ok(())
521    }
522}
523
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::sync::mpsc;

    /// Build a scheduler on a fresh channel; the caller keeps the receiver alive.
    fn scheduler_with(
        base: Duration,
        max: Duration,
        adaptive: bool,
        limit: usize,
    ) -> (PollingScheduler, mpsc::UnboundedReceiver<EnrichedEvent>) {
        let (tx, rx) = mpsc::unbounded_channel();
        (PollingScheduler::new(tx, base, max, adaptive, limit), rx)
    }

    #[tokio::test]
    async fn test_polling_scheduler_creation() {
        let (scheduler, _rx) =
            scheduler_with(Duration::from_secs(5), Duration::from_secs(30), true, 10);

        let snapshot = scheduler.stats().await;
        assert_eq!(snapshot.total_active_tasks, 0);
        assert_eq!(snapshot.max_concurrent_tasks, 10);
        assert!(snapshot.adaptive_polling);
    }

    #[tokio::test]
    async fn test_polling_task_lifecycle() {
        // A short base interval keeps the test fast.
        let (scheduler, _rx) =
            scheduler_with(Duration::from_millis(100), Duration::from_secs(1), false, 5);

        let id = RegistrationId::new(1);
        let pair = SpeakerServicePair::new(
            "192.168.1.100".parse().unwrap(),
            sonos_api::Service::AVTransport,
        );

        scheduler.start_polling(id, pair).await.unwrap();
        assert!(scheduler.is_polling(id).await);

        scheduler.stop_polling(id).await.unwrap();
        assert!(!scheduler.is_polling(id).await);
    }

    #[test]
    fn test_adaptive_interval_calculation() {
        let current = Duration::from_secs(5);
        let max = Duration::from_secs(30);

        // Recent activity: the interval must not grow.
        let ten_seconds_ago = SystemTime::now() - Duration::from_secs(10);
        let sped_up = PollingTask::calculate_adaptive_interval(current, max, ten_seconds_ago);
        assert!(sped_up <= current);

        // Long inactivity: the interval must not shrink.
        let long_ago = SystemTime::now() - Duration::from_secs(400);
        let slowed_down = PollingTask::calculate_adaptive_interval(current, max, long_ago);
        assert!(slowed_down >= current);
    }
}