Skip to main content

timer_lib/
registry.rs

1use std::collections::HashMap;
2use std::sync::{
3    atomic::{AtomicU64, Ordering},
4    Arc,
5};
6use std::time::Duration;
7
8use tokio::sync::RwLock;
9
10use crate::errors::TimerError;
11use crate::timer::{Timer, TimerCallback, TimerOutcome, TimerState};
12
13/// A registry for tracking timers by identifier.
14#[derive(Clone, Default)]
15pub struct TimerRegistry {
16    timers: Arc<RwLock<HashMap<u64, Timer>>>,
17    next_id: Arc<AtomicU64>,
18}
19
20impl TimerRegistry {
21    /// Creates a new timer registry.
22    pub fn new() -> Self {
23        Self {
24            timers: Arc::new(RwLock::new(HashMap::new())),
25            next_id: Arc::new(AtomicU64::new(0)),
26        }
27    }
28
29    /// Inserts an existing timer and returns its identifier.
30    pub async fn insert(&self, timer: Timer) -> u64 {
31        let id = self.next_id.fetch_add(1, Ordering::Relaxed);
32        self.timers.write().await.insert(id, timer);
33        id
34    }
35
36    /// Starts and registers a one-time timer.
37    pub async fn start_once<F>(
38        &self,
39        delay: Duration,
40        callback: F,
41    ) -> Result<(u64, Timer), TimerError>
42    where
43        F: TimerCallback + 'static,
44    {
45        let timer = Timer::new();
46        let _ = timer.start_once(delay, callback).await?;
47        let id = self.insert(timer.clone()).await;
48        Ok((id, timer))
49    }
50
51    /// Starts and registers a recurring timer.
52    pub async fn start_recurring<F>(
53        &self,
54        interval: Duration,
55        callback: F,
56        expiration_count: Option<usize>,
57    ) -> Result<(u64, Timer), TimerError>
58    where
59        F: TimerCallback + 'static,
60    {
61        let timer = Timer::new();
62        let _ = timer
63            .start_recurring(interval, callback, expiration_count)
64            .await?;
65        let id = self.insert(timer.clone()).await;
66        Ok((id, timer))
67    }
68
69    /// Removes a timer from the registry and returns it.
70    pub async fn remove(&self, id: u64) -> Option<Timer> {
71        self.timers.write().await.remove(&id)
72    }
73
74    /// Returns true when the registry tracks the given timer identifier.
75    pub async fn contains(&self, id: u64) -> bool {
76        self.timers.read().await.contains_key(&id)
77    }
78
79    /// Stops a timer by identifier when it exists.
80    pub async fn stop(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
81        let timer = self.get(id).await;
82        match timer {
83            Some(timer) => timer.stop().await.map(Some),
84            None => Ok(None),
85        }
86    }
87
88    /// Cancels a timer by identifier when it exists.
89    pub async fn cancel(&self, id: u64) -> Result<Option<TimerOutcome>, TimerError> {
90        let timer = self.get(id).await;
91        match timer {
92            Some(timer) => timer.cancel().await.map(Some),
93            None => Ok(None),
94        }
95    }
96
97    /// Stops all timers currently tracked by the registry.
98    pub async fn stop_all(&self) {
99        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
100        for timer in timers {
101            let _ = timer.stop().await;
102        }
103    }
104
105    /// Waits for all tracked timers that have a joinable outcome.
106    pub async fn join_all(&self) -> Vec<(u64, TimerOutcome)> {
107        let timers: Vec<(u64, Timer)> = self
108            .timers
109            .read()
110            .await
111            .iter()
112            .map(|(id, timer)| (*id, timer.clone()))
113            .collect();
114
115        let mut outcomes = Vec::with_capacity(timers.len());
116        for (id, timer) in timers {
117            if let Ok(outcome) = timer.join().await {
118                outcomes.push((id, outcome));
119            }
120        }
121
122        outcomes
123    }
124
125    /// Cancels all timers currently tracked by the registry.
126    pub async fn cancel_all(&self) {
127        let timers: Vec<Timer> = self.timers.read().await.values().cloned().collect();
128        for timer in timers {
129            let _ = timer.cancel().await;
130        }
131    }
132
133    /// Lists all active timers.
134    pub async fn active_ids(&self) -> Vec<u64> {
135        let timers: Vec<(u64, Timer)> = self
136            .timers
137            .read()
138            .await
139            .iter()
140            .map(|(id, timer)| (*id, timer.clone()))
141            .collect();
142
143        let mut active = Vec::new();
144        for (id, timer) in timers {
145            if timer.get_state().await != TimerState::Stopped {
146                active.push(id);
147            }
148        }
149        active
150    }
151
152    /// Retrieves a timer by ID.
153    pub async fn get(&self, id: u64) -> Option<Timer> {
154        self.timers.read().await.get(&id).cloned()
155    }
156
157    /// Returns the number of tracked timers.
158    pub async fn len(&self) -> usize {
159        self.timers.read().await.len()
160    }
161
162    /// Returns true when the registry is empty.
163    pub async fn is_empty(&self) -> bool {
164        self.len().await == 0
165    }
166
167    /// Removes all tracked timers and returns the number removed.
168    pub async fn clear(&self) -> usize {
169        let mut timers = self.timers.write().await;
170        let removed = timers.len();
171        timers.clear();
172        removed
173    }
174}
175
#[cfg(test)]
mod tests {
    use super::*;
    use crate::timer::TimerFinishReason;
    use tokio::task::yield_now;
    use tokio::time::advance;

    /// Yields to the runtime five times, giving other tasks (presumably the timers'
    /// background tasks) a chance to make progress.
    async fn settle() {
        let mut remaining = 5;
        while remaining > 0 {
            yield_now().await;
            remaining -= 1;
        }
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_start_helpers_are_easy_to_use() {
        let registry = TimerRegistry::new();

        let (single_id, single_timer) = registry
            .start_once(Duration::from_secs(1), || async { Ok(()) })
            .await
            .unwrap();
        let (repeat_id, repeat_timer) = registry
            .start_recurring(Duration::from_secs(2), || async { Ok(()) }, None)
            .await
            .unwrap();

        // Identifiers are distinct and both timers are tracked.
        assert_ne!(single_id, repeat_id);
        assert_eq!(registry.len().await, 2);
        assert!(registry.get(single_id).await.is_some());

        // Move the paused clock past the one-shot delay, then let tasks run.
        advance(Duration::from_secs(1)).await;
        settle().await;
        let single_outcome = single_timer.join().await.unwrap();
        assert_eq!(single_outcome.reason, TimerFinishReason::Completed);

        // The recurring timer is expected to still be reported as active.
        let active = registry.active_ids().await;
        assert!(active.contains(&repeat_id));

        let _ = repeat_timer.cancel().await.unwrap();
    }

    #[tokio::test(flavor = "current_thread", start_paused = true)]
    async fn registry_supports_direct_timer_controls() {
        let registry = TimerRegistry::new();
        let (id, _handle) = registry
            .start_once(Duration::from_secs(5), || async { Ok(()) })
            .await
            .unwrap();
        assert!(registry.contains(id).await);

        // Cancelling through the registry yields a cancelled outcome...
        let cancelled = registry.cancel(id).await.unwrap().unwrap();
        assert_eq!(cancelled.reason, TimerFinishReason::Cancelled);

        // ...and the entry is still tracked until the registry is cleared.
        assert_eq!(registry.clear().await, 1);
        assert!(registry.is_empty().await);
    }
}